use crate::{
    error::{BackendError, LlamaCoreError},
    BaseMetadata, Graph, CHAT_GRAPHS, EMBEDDING_GRAPHS, MAX_BUFFER_SIZE,
};
use bitflags::bitflags;
use chat_prompts::PromptTemplateType;
use serde_json::Value;
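
/// Generate a unique chat completion id in the `chatcmpl-<uuid>` format.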
pub(crate) fn gen_chat_id() -> String {
    format!("chatcmpl-{}", uuid::Uuid::new_v4())
}
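
/// Generate a unique response id in the `resp_<48 hex chars>` format, built from two UUIDs.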
pub(crate) fn gen_response_id() -> String {
    let uuid1 = uuid::Uuid::new_v4();
    let uuid2 = uuid::Uuid::new_v4();

    let part1 = uuid1.simple().to_string();
    let part2 = uuid2.simple().to_string();

    format!("resp_{}{}", part1, &part2[..16])
}
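
/// Return the names of the available chat models.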
pub fn chat_model_names() -> Result<Vec<String>, LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Get the names of the chat models.");

    let chat_graphs = match CHAT_GRAPHS.get() {
        Some(chat_graphs) => chat_graphs,
        None => {
            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{err_msg}");

            return Err(LlamaCoreError::Operation(err_msg.into()));
        }
    };

    let chat_graphs = chat_graphs.lock().map_err(|e| {
        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    let mut model_names = Vec::new();
    for model_name in chat_graphs.keys() {
        model_names.push(model_name.clone());
    }

    Ok(model_names)
}
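
/// Return the names of the available embedding models.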
pub fn embedding_model_names() -> Result<Vec<String>, LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Get the names of the embedding models.");

    let embedding_graphs = match EMBEDDING_GRAPHS.get() {
        Some(embedding_graphs) => embedding_graphs,
        None => {
            return Err(LlamaCoreError::Operation(String::from(
                "Fail to get the underlying value of `EMBEDDING_GRAPHS`.",
            )));
        }
    };

    let embedding_graphs = match embedding_graphs.lock() {
        Ok(embedding_graphs) => embedding_graphs,
        Err(e) => {
            let err_msg = format!("Fail to acquire the lock of `EMBEDDING_GRAPHS`. {e}");

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            return Err(LlamaCoreError::Operation(err_msg));
        }
    };

    let mut model_names = Vec::new();
    for model_name in embedding_graphs.keys() {
        model_names.push(model_name.clone());
    }

    Ok(model_names)
}
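
/// Return the prompt template type of the named chat model, falling back to the first
/// available chat model when the name is not given or not found.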
pub fn chat_prompt_template(name: Option<&str>) -> Result<PromptTemplateType, LlamaCoreError> {
    #[cfg(feature = "logging")]
    match name {
        Some(name) => {
            info!(target: "stdout", "Get the chat prompt template type from the chat model named {name}.")
        }
        None => {
            info!(target: "stdout", "Get the chat prompt template type from the default chat model.")
        }
    }

    let chat_graphs = match CHAT_GRAPHS.get() {
        Some(chat_graphs) => chat_graphs,
        None => {
            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{err_msg}");

            return Err(LlamaCoreError::Operation(err_msg.into()));
        }
    };

    let chat_graphs = chat_graphs.lock().map_err(|e| {
        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    match name {
        Some(model_name) => match chat_graphs.contains_key(model_name) {
            true => {
                let graph = chat_graphs.get(model_name).unwrap();
                let prompt_template = graph.metadata.prompt_template();

                #[cfg(feature = "logging")]
                info!(target: "stdout", "prompt_template: {}", &prompt_template);

                Ok(prompt_template)
            }
            false => match chat_graphs.iter().next() {
                Some((_, graph)) => {
                    let prompt_template = graph.metadata.prompt_template();

                    #[cfg(feature = "logging")]
                    info!(target: "stdout", "prompt_template: {}", &prompt_template);

                    Ok(prompt_template)
                }
                None => {
                    let err_msg = "There is no model available in the chat graphs.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    Err(LlamaCoreError::Operation(err_msg.into()))
                }
            },
        },
        None => match chat_graphs.iter().next() {
            Some((_, graph)) => {
                let prompt_template = graph.metadata.prompt_template();

                #[cfg(feature = "logging")]
                info!(target: "stdout", "prompt_template: {}", &prompt_template);

                Ok(prompt_template)
            }
            None => {
                let err_msg = "There is no model available in the chat graphs.";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                Err(LlamaCoreError::Operation(err_msg.into()))
            }
        },
    }
}
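
/// Read the full output tensor at the given index from the graph's output buffer.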
pub(crate) fn get_output_buffer<M>(
    graph: &Graph<M>,
    index: usize,
) -> Result<Vec<u8>, LlamaCoreError>
where
    M: BaseMetadata + serde::Serialize + Clone + Default,
{
    let mut output_buffer: Vec<u8> = Vec::with_capacity(MAX_BUFFER_SIZE);

    let output_size: usize = graph.get_output(index, &mut output_buffer).map_err(|e| {
        let err_msg = format!("Fail to get the generated output tensor. {e}");

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Backend(BackendError::GetOutput(err_msg))
    })?;

    unsafe {
        output_buffer.set_len(output_size);
    }

    Ok(output_buffer)
}
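
/// Read a single (streaming) output chunk at the given index from the graph's output buffer.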
pub(crate) fn get_output_buffer_single<M>(
    graph: &Graph<M>,
    index: usize,
) -> Result<Vec<u8>, LlamaCoreError>
where
    M: BaseMetadata + serde::Serialize + Clone + Default,
{
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Get output buffer generated by the model named {} in the stream mode.", graph.name());

    let mut output_buffer: Vec<u8> = Vec::with_capacity(MAX_BUFFER_SIZE);

    let output_size: usize = graph
        .get_output_single(index, &mut output_buffer)
        .map_err(|e| {
            let err_msg = format!("Fail to get the generated output tensor in the stream mode. {e}");

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::Backend(BackendError::GetOutput(err_msg))
        })?;

    unsafe {
        output_buffer.set_len(output_size);
    }

    Ok(output_buffer)
}
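
/// Set the input tensor at the given index with raw `u8` data (shape `[1]`).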
pub(crate) fn set_tensor_data_u8<M>(
    graph: &mut Graph<M>,
    idx: usize,
    tensor_data: &[u8],
) -> Result<(), LlamaCoreError>
where
    M: BaseMetadata + serde::Serialize + Clone + Default,
{
    if graph
        .set_input(idx, wasmedge_wasi_nn::TensorType::U8, &[1], tensor_data)
        .is_err()
    {
        let err_msg = format!("Fail to set input tensor at index {idx}");

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        return Err(LlamaCoreError::Operation(err_msg));
    };

    Ok(())
}
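
/// Read the prompt/completion token counts from output tensor index 1 of the given graph.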
pub(crate) fn get_token_info_by_graph<M>(graph: &Graph<M>) -> Result<TokenInfo, LlamaCoreError>
where
    M: BaseMetadata + serde::Serialize + Clone + Default,
{
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Get token info from the model named {}", graph.name());

    let output_buffer = get_output_buffer(graph, 1)?;
    let token_info: Value = match serde_json::from_slice(&output_buffer[..]) {
        Ok(token_info) => token_info,
        Err(e) => {
            let err_msg = format!("Fail to deserialize token info: {e}");

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            return Err(LlamaCoreError::Operation(err_msg));
        }
    };

    let prompt_tokens = match token_info["input_tokens"].as_u64() {
        Some(prompt_tokens) => prompt_tokens,
        None => {
            let err_msg = "Fail to convert `input_tokens` to u64.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{err_msg}");

            return Err(LlamaCoreError::Operation(err_msg.into()));
        }
    };
    let completion_tokens = match token_info["output_tokens"].as_u64() {
        Some(completion_tokens) => completion_tokens,
        None => {
            let err_msg = "Fail to convert `output_tokens` to u64.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{err_msg}");

            return Err(LlamaCoreError::Operation(err_msg.into()));
        }
    };

    Ok(TokenInfo {
        prompt_tokens,
        completion_tokens,
    })
}
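
/// Read token counts from the named chat model, falling back to the first available
/// chat model when the name is not given or not found.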
pub(crate) fn get_token_info_by_graph_name(
    name: Option<&String>,
) -> Result<TokenInfo, LlamaCoreError> {
    let chat_graphs = match CHAT_GRAPHS.get() {
        Some(chat_graphs) => chat_graphs,
        None => {
            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{err_msg}");

            return Err(LlamaCoreError::Operation(err_msg.into()));
        }
    };

    let chat_graphs = chat_graphs.lock().map_err(|e| {
        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    match name {
        Some(model_name) => match chat_graphs.contains_key(model_name) {
            true => {
                let graph = chat_graphs.get(model_name).unwrap();
                get_token_info_by_graph(graph)
            }
            false => match chat_graphs.iter().next() {
                Some((_, graph)) => get_token_info_by_graph(graph),
                None => {
                    let err_msg = "There is no model available in the chat graphs.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    Err(LlamaCoreError::Operation(err_msg.into()))
                }
            },
        },
        None => match chat_graphs.iter().next() {
            Some((_, graph)) => get_token_info_by_graph(graph),
            None => {
                let err_msg = "There is no model available in the chat graphs.";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                Err(LlamaCoreError::Operation(err_msg.into()))
            }
        },
    }
}
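
/// Number of tokens consumed by the prompt and produced as the completion.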
#[derive(Debug)]
pub(crate) struct TokenInfo {
    pub(crate) prompt_tokens: u64,
    pub(crate) completion_tokens: u64,
}
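
/// Map a Rust element type to its `wasmedge_wasi_nn` tensor type and shape representation.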
pub(crate) trait TensorType {
    fn tensor_type() -> wasmedge_wasi_nn::TensorType;
    fn shape(shape: impl AsRef<[usize]>) -> Vec<usize> {
        shape.as_ref().to_vec()
    }
}

impl TensorType for u8 {
    fn tensor_type() -> wasmedge_wasi_nn::TensorType {
        wasmedge_wasi_nn::TensorType::U8
    }
}

impl TensorType for f32 {
    fn tensor_type() -> wasmedge_wasi_nn::TensorType {
        wasmedge_wasi_nn::TensorType::F32
    }
}
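
/// Set the input tensor at the given index with typed data and an explicit shape.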
pub(crate) fn set_tensor_data<T, M>(
    graph: &mut Graph<M>,
    idx: usize,
    tensor_data: &[T],
    shape: impl AsRef<[usize]>,
) -> Result<(), LlamaCoreError>
where
    T: TensorType,
    M: BaseMetadata + serde::Serialize + Clone + Default,
{
    if graph
        .set_input(idx, T::tensor_type(), &T::shape(shape), tensor_data)
        .is_err()
    {
        let err_msg = format!("Fail to set input tensor at index {idx}");

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        return Err(LlamaCoreError::Operation(err_msg));
    };

    Ok(())
}
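
// Running modes supported by the crate (chat, embeddings, tts, rag), combinable as bitflags.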
bitflags! {
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct RunningMode: u32 {
        const UNSET = 0b00000000;
        const CHAT = 0b00000001;
        const EMBEDDINGS = 0b00000010;
        const TTS = 0b00000100;
        const RAG = 0b00001000;
    }
}
impl std::fmt::Display for RunningMode {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let mut mode = String::new();

        if self.contains(RunningMode::CHAT) {
            mode.push_str("chat, ");
        }
        if self.contains(RunningMode::EMBEDDINGS) {
            mode.push_str("embeddings, ");
        }
        if self.contains(RunningMode::TTS) {
            mode.push_str("tts, ");
        }

        mode = mode.trim_end_matches(", ").to_string();

        write!(f, "{mode}")
    }
}