llama_core/utils.rs

//! Define utility functions.

use crate::{
    error::{BackendError, LlamaCoreError},
    BaseMetadata, Graph, CHAT_GRAPHS, EMBEDDING_GRAPHS, MAX_BUFFER_SIZE,
};
use bitflags::bitflags;
use chat_prompts::PromptTemplateType;
use serde_json::Value;

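/// Generate a unique chat completion id of the form `chatcmpl-<uuid>`.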
pub(crate) fn gen_chat_id() -> String {
    format!("chatcmpl-{}", uuid::Uuid::new_v4())
}

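/// Generate a unique response id of the form `resp_<48 hex chars>`.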
pub(crate) fn gen_response_id() -> String {
    let uuid1 = uuid::Uuid::new_v4();
    let uuid2 = uuid::Uuid::new_v4();

    // Generate 48 hexadecimal characters (same length as the original hard-coded value)
    let part1 = uuid1.simple().to_string(); // 32 characters
    let part2 = uuid2.simple().to_string(); // 32 characters

    format!("resp_{}{}", part1, &part2[..16]) // "resp_" + 48 characters = 53 characters in total
}

/// Return the names of the chat models.
pub fn chat_model_names() -> Result<Vec<String>, LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Get the names of the chat models.");

    let chat_graphs = match CHAT_GRAPHS.get() {
        Some(chat_graphs) => chat_graphs,
        None => {
            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{err_msg}");

            return Err(LlamaCoreError::Operation(err_msg.into()));
        }
    };

    let chat_graphs = chat_graphs.lock().map_err(|e| {
        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    let mut model_names = Vec::new();
    for model_name in chat_graphs.keys() {
        model_names.push(model_name.clone());
    }

    Ok(model_names)
}

/// Return the names of the embedding models.
pub fn embedding_model_names() -> Result<Vec<String>, LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Get the names of the embedding models.");

    let embedding_graphs = match EMBEDDING_GRAPHS.get() {
        Some(embedding_graphs) => embedding_graphs,
        None => {
            return Err(LlamaCoreError::Operation(String::from(
                "Fail to get the underlying value of `EMBEDDING_GRAPHS`.",
            )));
        }
    };

    let embedding_graphs = match embedding_graphs.lock() {
        Ok(embedding_graphs) => embedding_graphs,
        Err(e) => {
            let err_msg = format!("Fail to acquire the lock of `EMBEDDING_GRAPHS`. {e}");

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            return Err(LlamaCoreError::Operation(err_msg));
        }
    };

    let mut model_names = Vec::new();
    for model_name in embedding_graphs.keys() {
        model_names.push(model_name.clone());
    }

    Ok(model_names)
}

/// Get the chat prompt template type from the given model name.
pub fn chat_prompt_template(name: Option<&str>) -> Result<PromptTemplateType, LlamaCoreError> {
    #[cfg(feature = "logging")]
    match name {
        Some(name) => {
            info!(target: "stdout", "Get the chat prompt template type from the chat model named {name}.")
        }
        None => {
            info!(target: "stdout", "Get the chat prompt template type from the default chat model.")
        }
    }

    let chat_graphs = match CHAT_GRAPHS.get() {
        Some(chat_graphs) => chat_graphs,
        None => {
            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{err_msg}");

            return Err(LlamaCoreError::Operation(err_msg.into()));
        }
    };

    let chat_graphs = chat_graphs.lock().map_err(|e| {
        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    match name {
        Some(model_name) => match chat_graphs.contains_key(model_name) {
            true => {
                let graph = chat_graphs.get(model_name).unwrap();
                let prompt_template = graph.metadata.prompt_template();

                #[cfg(feature = "logging")]
                info!(target: "stdout", "prompt_template: {}", &prompt_template);

                Ok(prompt_template)
            }
            false => match chat_graphs.iter().next() {
                Some((_, graph)) => {
                    let prompt_template = graph.metadata.prompt_template();

                    #[cfg(feature = "logging")]
                    info!(target: "stdout", "prompt_template: {}", &prompt_template);

                    Ok(prompt_template)
                }
                None => {
                    let err_msg = "There is no model available in the chat graphs.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    Err(LlamaCoreError::Operation(err_msg.into()))
                }
            },
        },
        None => match chat_graphs.iter().next() {
            Some((_, graph)) => {
                let prompt_template = graph.metadata.prompt_template();

                #[cfg(feature = "logging")]
                info!(target: "stdout", "prompt_template: {}", &prompt_template);

                Ok(prompt_template)
            }
            None => {
                let err_msg = "There is no model available in the chat graphs.";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                Err(LlamaCoreError::Operation(err_msg.into()))
            }
        },
    }
}

/// Get output buffer generated by model.
pub(crate) fn get_output_buffer<M>(
    graph: &Graph<M>,
    index: usize,
) -> Result<Vec<u8>, LlamaCoreError>
where
    M: BaseMetadata + serde::Serialize + Clone + Default,
{
    let mut output_buffer: Vec<u8> = Vec::with_capacity(MAX_BUFFER_SIZE);

    let output_size: usize = graph.get_output(index, &mut output_buffer).map_err(|e| {
        let err_msg = format!("Fail to get the generated output tensor. {e}");

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Backend(BackendError::GetOutput(err_msg))
    })?;

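    // SAFETY: `get_output` is assumed to have written `output_size` initialized bytes
    // (at most MAX_BUFFER_SIZE) into the buffer's capacity.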
    unsafe {
        output_buffer.set_len(output_size);
    }

    Ok(output_buffer)
}

/// Get output buffer generated by model in the stream mode.
pub(crate) fn get_output_buffer_single<M>(
    graph: &Graph<M>,
    index: usize,
) -> Result<Vec<u8>, LlamaCoreError>
where
    M: BaseMetadata + serde::Serialize + Clone + Default,
{
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Get output buffer generated by the model named {} in the stream mode.", graph.name());

    let mut output_buffer: Vec<u8> = Vec::with_capacity(MAX_BUFFER_SIZE);

    let output_size: usize = graph
        .get_output_single(index, &mut output_buffer)
        .map_err(|e| {
            let err_msg = format!("Fail to get the generated output tensor in the stream mode. {e}");

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::Backend(BackendError::GetOutput(err_msg))
        })?;

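    // SAFETY: `get_output_single` is assumed to have written `output_size` initialized bytes
    // (at most MAX_BUFFER_SIZE) into the buffer's capacity.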
    unsafe {
        output_buffer.set_len(output_size);
    }

    Ok(output_buffer)
}

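/// Feed the given bytes into the graph as the `u8` input tensor at index `idx`.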
pub(crate) fn set_tensor_data_u8<M>(
    graph: &mut Graph<M>,
    idx: usize,
    tensor_data: &[u8],
) -> Result<(), LlamaCoreError>
where
    M: BaseMetadata + serde::Serialize + Clone + Default,
{
    if graph
        .set_input(idx, wasmedge_wasi_nn::TensorType::U8, &[1], tensor_data)
        .is_err()
    {
        let err_msg = format!("Fail to set input tensor at index {idx}");

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        return Err(LlamaCoreError::Operation(err_msg));
    };

    Ok(())
}

/// Get the token information from the graph.
pub(crate) fn get_token_info_by_graph<M>(graph: &Graph<M>) -> Result<TokenInfo, LlamaCoreError>
where
    M: BaseMetadata + serde::Serialize + Clone + Default,
{
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Get token info from the model named {}", graph.name());

    let output_buffer = get_output_buffer(graph, 1)?;
    let token_info: Value = match serde_json::from_slice(&output_buffer[..]) {
        Ok(token_info) => token_info,
        Err(e) => {
            let err_msg = format!("Fail to deserialize token info: {e}");

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            return Err(LlamaCoreError::Operation(err_msg));
        }
    };

    let prompt_tokens = match token_info["input_tokens"].as_u64() {
        Some(prompt_tokens) => prompt_tokens,
        None => {
            let err_msg = "Fail to convert `input_tokens` to u64.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{err_msg}");

            return Err(LlamaCoreError::Operation(err_msg.into()));
        }
    };
    let completion_tokens = match token_info["output_tokens"].as_u64() {
        Some(completion_tokens) => completion_tokens,
        None => {
            let err_msg = "Fail to convert `output_tokens` to u64.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{err_msg}");

            return Err(LlamaCoreError::Operation(err_msg.into()));
        }
    };

    Ok(TokenInfo {
        prompt_tokens,
        completion_tokens,
    })
}

/// Get the token information from the graph by the model name.
pub(crate) fn get_token_info_by_graph_name(
    name: Option<&String>,
) -> Result<TokenInfo, LlamaCoreError> {
    let chat_graphs = match CHAT_GRAPHS.get() {
        Some(chat_graphs) => chat_graphs,
        None => {
            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{err_msg}");

            return Err(LlamaCoreError::Operation(err_msg.into()));
        }
    };

    let chat_graphs = chat_graphs.lock().map_err(|e| {
        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    match name {
        Some(model_name) => match chat_graphs.contains_key(model_name) {
            true => {
                let graph = chat_graphs.get(model_name).unwrap();
                get_token_info_by_graph(graph)
            }
            false => match chat_graphs.iter().next() {
                Some((_, graph)) => get_token_info_by_graph(graph),
                None => {
                    let err_msg = "There is no model available in the chat graphs.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    Err(LlamaCoreError::Operation(err_msg.into()))
                }
            },
        },
        None => match chat_graphs.iter().next() {
            Some((_, graph)) => get_token_info_by_graph(graph),
            None => {
                let err_msg = "There is no model available in the chat graphs.";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                Err(LlamaCoreError::Operation(err_msg.into()))
            }
        },
    }
}

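/// Token usage counted for a single request: prompt tokens in, completion tokens out.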
#[derive(Debug)]
pub(crate) struct TokenInfo {
    pub(crate) prompt_tokens: u64,
    pub(crate) completion_tokens: u64,
}

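/// Map a Rust element type to its wasi-nn tensor type and tensor shape.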
pub(crate) trait TensorType {
    fn tensor_type() -> wasmedge_wasi_nn::TensorType;
    fn shape(shape: impl AsRef<[usize]>) -> Vec<usize> {
        shape.as_ref().to_vec()
    }
}

impl TensorType for u8 {
    fn tensor_type() -> wasmedge_wasi_nn::TensorType {
        wasmedge_wasi_nn::TensorType::U8
    }
}

impl TensorType for f32 {
    fn tensor_type() -> wasmedge_wasi_nn::TensorType {
        wasmedge_wasi_nn::TensorType::F32
    }
}

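/// Set the input tensor at the given index with typed data and the given shape.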
pub(crate) fn set_tensor_data<T, M>(
    graph: &mut Graph<M>,
    idx: usize,
    tensor_data: &[T],
    shape: impl AsRef<[usize]>,
) -> Result<(), LlamaCoreError>
where
    T: TensorType,
    M: BaseMetadata + serde::Serialize + Clone + Default,
{
    if graph
        .set_input(idx, T::tensor_type(), &T::shape(shape), tensor_data)
        .is_err()
    {
        let err_msg = format!("Fail to set input tensor at index {idx}");

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        return Err(LlamaCoreError::Operation(err_msg));
    };

    Ok(())
}

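// Bit flags describing which serving modes are enabled for the running instance.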
bitflags! {
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct RunningMode: u32 {
        const UNSET = 0b00000000;
        const CHAT = 0b00000001;
        const EMBEDDINGS = 0b00000010;
        const TTS = 0b00000100;
        const RAG = 0b00001000;
    }
}
impl std::fmt::Display for RunningMode {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let mut mode = String::new();

        if self.contains(RunningMode::CHAT) {
            mode.push_str("chat, ");
        }
        if self.contains(RunningMode::EMBEDDINGS) {
            mode.push_str("embeddings, ");
        }
        if self.contains(RunningMode::TTS) {
            mode.push_str("tts, ");
        }

        mode = mode.trim_end_matches(", ").to_string();

        write!(f, "{mode}")
    }
}