1use crate::{
4 error,
5 metadata::ggml::GgmlMetadata,
6 running_mode,
7 utils::{
8 gen_chat_id, get_output_buffer, get_output_buffer_single, get_token_info_by_graph,
9 get_token_info_by_graph_name, set_tensor_data_u8,
10 },
11 Graph, RunningMode, CACHED_UTF8_ENCODINGS, CHAT_GRAPHS, OUTPUT_TENSOR,
12};
13use chat_prompts::{BuildChatPrompt, ChatPrompt, PromptTemplateType};
14use either::{Either, Left, Right};
15use endpoints::{
16 chat::{
17 ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionChunkChoiceDelta,
18 ChatCompletionObject, ChatCompletionObjectChoice, ChatCompletionObjectMessage,
19 ChatCompletionRequest, ChatCompletionRequestMessage, ChatCompletionRole,
20 ChatCompletionUserMessageContent, ContentPart, Function, ToolCall, ToolCallForChunk,
21 ToolChoice,
22 },
23 common::{FinishReason, Usage},
24};
25use error::{BackendError, LlamaCoreError};
26use std::{
27 collections::VecDeque,
28 pin::Pin,
29 sync::{
30 atomic::{AtomicBool, Ordering},
31 Mutex, OnceLock,
32 },
33 task::{Context, Poll, Waker},
34 time::SystemTime,
35};
36
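// Wakers for chat streams that are waiting their turn; presumably drained by the
// `ChatStream` implementation (not shown in this excerpt) to wake the next pending stream.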
37static CHAT_STREAM_WAKER_QUEUE: OnceLock<Mutex<VecDeque<Waker>>> = OnceLock::new();
39
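// Whether a chat stream is currently being driven; presumably used together with the
// waker queue above to serialize streaming access to the model.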
40static CHAT_STREAM_ACTIVE: AtomicBool = AtomicBool::new(false);
42
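/// Entry point for chat completions. Dispatches on `chat_request.stream`: streaming
/// requests are handled by `chat_stream` and returned as the `Left` variant, while
/// non-streaming requests go through `chat_once` and are returned as the `Right`
/// variant. The accompanying `bool` reports whether the response contains tool calls.
///
/// A minimal usage sketch (illustrative only; assumes the chat model has already been
/// initialized and `request` is a populated `ChatCompletionRequest`):
///
/// ```ignore
/// request.stream = Some(false);
/// let (result, _includes_tool_calls) = chat(&mut request).await?;
/// if let Either::Right(completion) = result {
///     println!("{:?}", completion.choices[0].message.content);
/// }
/// ```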
43pub async fn chat(
45 chat_request: &mut ChatCompletionRequest,
46) -> Result<
47 (
48 Either<impl futures::TryStream<Ok = String, Error = LlamaCoreError>, ChatCompletionObject>,
49 bool,
50 ),
51 LlamaCoreError,
52> {
53 #[cfg(feature = "logging")]
54 {
55 debug!(target: "stdout", "tool choice: {:?}", chat_request.tool_choice.as_ref());
56 debug!(target: "stdout", "tools: {:?}", chat_request.tools.as_ref());
57 debug!(target: "stdout", "stream mode: {:?}", chat_request.stream);
58 }
59
60 let result = match chat_request.stream {
61 Some(true) => match chat_stream(chat_request).await {
62 Ok((stream, include_tool_calls)) => Ok((Left(stream), include_tool_calls)),
63 Err(e) => Err(e),
64 },
65 Some(false) | None => match chat_once(chat_request).await {
66 Ok((chat_completion_object, include_tool_calls)) => {
67 Ok((Right(chat_completion_object), include_tool_calls))
68 }
69 Err(e) => Err(e),
70 },
71 };
72
73 #[cfg(feature = "logging")]
74 info!(target: "stdout", "Reset the model metadata");
75
76 result
77}
78
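/// Handles a chat completion request in stream mode: checks the running mode, builds
/// the chat prompt, updates `n_predict`, feeds the prompt to the model, and returns a
/// `ChatStream` plus a flag indicating whether tool calls are included. When tools are
/// in use, the completion is computed up front by `chat_stream_for_tool` and replayed
/// as pre-built chunks.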
79async fn chat_stream(
80 chat_request: &mut ChatCompletionRequest,
81) -> Result<
82 (
83 impl futures::TryStream<Ok = String, Error = LlamaCoreError>,
84 bool,
85 ),
86 LlamaCoreError,
87> {
88 #[cfg(feature = "logging")]
 89 info!(target: "stdout", "Process the chat completion request in stream mode");
90
91 let running_mode = running_mode()?;
92 if !running_mode.contains(RunningMode::CHAT) && !running_mode.contains(RunningMode::RAG) {
 93 let err_msg = "Chat completions are only supported in chat or RAG mode.";
94
95 #[cfg(feature = "logging")]
96 error!(target: "stdout", "{err_msg}");
97
98 return Err(LlamaCoreError::Operation(err_msg.to_string()));
99 }
100
101 let model_name = chat_request.model.clone();
102 let id = match &chat_request.user {
103 Some(id) => id.clone(),
104 None => gen_chat_id(),
105 };
106 #[cfg(feature = "logging")]
107 info!(target: "stdout", "user: {}", &id);
108
109 #[cfg(feature = "logging")]
110 info!(target: "stdout", "Check model metadata");
111
112 let mut metadata = check_model_metadata(chat_request)?;
114
115 let include_usage = match chat_request.stream_options {
117 Some(ref stream_options) => stream_options.include_usage.unwrap_or_default(),
118 None => metadata.include_usage,
119 };
120 #[cfg(feature = "logging")]
121 info!(target: "stdout", "include_usage: {include_usage}");
122
123 #[cfg(feature = "logging")]
124 info!(target: "stdout", "Build the chat prompt");
125
 126 let (prompt, available_completion_tokens, tool_use) =
128 build_prompt(model_name.as_ref(), chat_request)?;
129
130 #[cfg(feature = "logging")]
131 {
132 info!(target: "stdout", "prompt:\n{}", &prompt);
 133 info!(target: "stdout", "available_completion_tokens: {available_completion_tokens}");
134 info!(target: "stdout", "tool_use: {tool_use}");
135 }
136
137 #[cfg(feature = "logging")]
138 info!(target: "stdout", "Update the n_predict");
139
 140 update_n_predict(chat_request, &mut metadata, available_completion_tokens)?;
142
143 #[cfg(feature = "logging")]
144 info!(target: "stdout", "Feed the prompt to the model");
145
146 set_prompt(chat_request.model.as_ref(), &prompt)?;
148
149 let stream = match tool_use {
150 false => (ChatStream::new(model_name, id, include_usage, None), false),
151 true => {
152 let chat_graphs = match CHAT_GRAPHS.get() {
153 Some(chat_graphs) => chat_graphs,
154 None => {
 155 let err_msg = "Failed to get the underlying value of `CHAT_GRAPHS`.";
156
157 #[cfg(feature = "logging")]
158 error!(target: "stdout", "{}", &err_msg);
159
160 return Err(LlamaCoreError::Operation(err_msg.into()));
161 }
162 };
163
164 let mut chat_graphs = chat_graphs.lock().map_err(|e| {
 165 let err_msg = format!("Failed to acquire the lock of `CHAT_GRAPHS`. {e}");
166
167 #[cfg(feature = "logging")]
168 error!(target: "stdout", "{}", &err_msg);
169
170 LlamaCoreError::Operation(err_msg)
171 })?;
172
173 match model_name {
174 Some(model_name) => match chat_graphs.contains_key(&model_name) {
175 true => {
176 let graph = chat_graphs.get_mut(&model_name).unwrap();
177 chat_stream_for_tool(graph, id, include_usage)?
178 }
179 false => match chat_graphs.iter_mut().next() {
180 Some((_, graph)) => chat_stream_for_tool(graph, id, include_usage)?,
181 None => {
182 let err_msg = "There is no model available in the chat graphs.";
183
184 #[cfg(feature = "logging")]
185 error!(target: "stdout", "{}", &err_msg);
186
187 return Err(LlamaCoreError::Operation(err_msg.into()));
188 }
189 },
190 },
191 None => match chat_graphs.iter_mut().next() {
192 Some((_, graph)) => chat_stream_for_tool(graph, id, include_usage)?,
193 None => {
194 let err_msg = "There is no model available in the chat graphs.";
195
196 #[cfg(feature = "logging")]
197 error!(target: "stdout", "{}", &err_msg);
198
199 return Err(LlamaCoreError::Operation(err_msg.into()));
200 }
201 },
202 }
203 }
204 };
205
206 #[cfg(feature = "logging")]
207 info!(target: "stdout", "End of the chat completion stream.");
208
209 Ok(stream)
210}
211
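/// Computes a tool-call completion on the given graph and packages the result as a
/// `ChatStream` of pre-built SSE chunks (tool-call/content chunk, usage chunk, and the
/// `data: [DONE]` terminator). Context-full and prompt-too-long backend errors are
/// turned into a final chunk with `finish_reason: length` instead of failing the stream.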
212fn chat_stream_for_tool(
213 graph: &mut Graph<GgmlMetadata>,
214 id: impl Into<String>,
215 include_usage: bool,
216) -> Result<(ChatStream, bool), LlamaCoreError> {
217 #[cfg(feature = "logging")]
 218 info!(target: "stdout", "Handling chat request with tools using the model named {}.", graph.name());
219
220 let id = id.into();
221
222 match graph.compute() {
223 Ok(_) => {
224 let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
226 let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
227 let err_msg = format!(
228 "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
229 );
230
231 #[cfg(feature = "logging")]
232 error!(target: "stdout", "{}", &err_msg);
233
234 LlamaCoreError::Operation(err_msg)
235 })?;
236
237 #[cfg(feature = "logging")]
238 info!(target: "stdout", "raw generation:\n{output}");
239
240 let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
242 LlamaCoreError::Operation(format!("Failed to post-process the output. {e}"))
243 })?;
244
245 #[cfg(feature = "logging")]
246 info!(target: "stdout", "post-processed generation:\n{}", &message);
247
248 let token_info = get_token_info_by_graph(graph)?;
250
251 #[cfg(feature = "logging")]
252 info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
253
254 let usage = Some(Usage {
255 prompt_tokens: token_info.prompt_tokens,
256 completion_tokens: token_info.completion_tokens,
257 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
258 });
259
260 let created = SystemTime::now()
261 .duration_since(std::time::UNIX_EPOCH)
262 .map_err(|e| {
263 let err_msg = format!("Failed to get the current time. Reason: {e}");
264
265 #[cfg(feature = "logging")]
266 error!(target: "stdout", "{}", &err_msg);
267
268 LlamaCoreError::Operation(err_msg)
269 })?;
270
271 if graph.metadata.prompt_template != PromptTemplateType::MistralTool
272 && graph.metadata.prompt_template != PromptTemplateType::ChatMLTool
273 && graph.metadata.prompt_template != PromptTemplateType::GroqLlama3Tool
274 && graph.metadata.prompt_template != PromptTemplateType::Llama3Tool
275 && graph.metadata.prompt_template != PromptTemplateType::InternLM2Tool
276 && graph.metadata.prompt_template != PromptTemplateType::NemotronTool
277 && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV32
278 && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV31
279 && graph.metadata.prompt_template != PromptTemplateType::MistralSmallTool
280 && graph.metadata.prompt_template != PromptTemplateType::Llama4Chat
281 && graph.metadata.prompt_template != PromptTemplateType::Qwen3NoThink
282 && graph.metadata.prompt_template != PromptTemplateType::Smol3NoThink
283 && graph.metadata.prompt_template != PromptTemplateType::Gemma3
284 && graph.metadata.prompt_template != PromptTemplateType::GptOss
285 && graph.metadata.prompt_template != PromptTemplateType::Qwen3Agent
286 && graph.metadata.prompt_template != PromptTemplateType::SeedOssNoThink
287 && graph.metadata.prompt_template != PromptTemplateType::SeedOssThink
288 {
289 let err_msg = format!("Unsupported prompt template: {}. The tool use is only supported for 'mistral-tool', 'chatml-tool', 'groq-llama3-tool', 'llama-3-tool', 'internlm-2-tool', 'nemotron-tool', 'functionary-31', 'functionary-32', 'mistral-small-tool', 'llama-4-chat', 'qwen3-no-think', 'smol-3-no-think', 'gemma-3', 'gpt-oss', 'qwen3-agent', 'seed-oss-no-think', and 'seed-oss-think' prompt templates.", graph.metadata.prompt_template);
290
291 #[cfg(feature = "logging")]
292 error!(target: "stdout", "{}", &err_msg);
293
294 return Err(LlamaCoreError::Operation(err_msg));
295 }
296
297 let parsed_result = parse_tool_calls(&message, graph.metadata.prompt_template)?;
298
299 let content = if parsed_result.tool_calls.is_empty() {
300 Some(parsed_result.raw.clone())
301 } else {
302 parsed_result.content.clone()
303 };
304
305 let (tool_calls, include_tool_calls) = match parsed_result.tool_calls.is_empty() {
306 false => {
307 let tool_calls: Vec<ToolCallForChunk> = parsed_result
308 .tool_calls
309 .into_iter()
310 .enumerate()
311 .map(|(index, tool_call)| ToolCallForChunk {
312 index,
313 id: tool_call.id,
314 ty: tool_call.ty,
315 function: tool_call.function,
316 })
317 .collect();
318 (tool_calls, true)
319 }
320 true => (vec![], false),
321 };
322
323 let tool_call_chunk = {
325 let chat_completion_chunk = ChatCompletionChunk {
326 id: id.clone(),
327 object: "chat.completion.chunk".to_string(),
328 created: created.as_secs(),
329 model: graph.name().to_owned(),
330 system_fingerprint: "fp_44709d6fcb".to_string(),
331 choices: vec![ChatCompletionChunkChoice {
332 index: 0,
333 delta: ChatCompletionChunkChoiceDelta {
334 role: ChatCompletionRole::Assistant,
335 content,
336 tool_calls,
337 },
338 logprobs: None,
339 finish_reason: None,
340 }],
341 usage: None,
342 };
343 let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
344 let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
345
346 #[cfg(feature = "logging")]
347 error!(target: "stdout", "{}", &err_msg);
348
349 LlamaCoreError::Operation(err_msg)
350 })?;
351
352 format!("data: {chunk_str}\n\n")
353 };
354
355 let usage_chunk = {
357 let chat_completion_chunk = ChatCompletionChunk {
358 id: id.clone(),
359 object: "chat.completion.chunk".to_string(),
360 created: created.as_secs(),
361 model: graph.name().to_owned(),
362 system_fingerprint: "fp_44709d6fcb".to_string(),
363 choices: vec![],
364 usage,
365 };
366 let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
367 let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
368
369 #[cfg(feature = "logging")]
370 error!(target: "stdout", "{}", &err_msg);
371
372 LlamaCoreError::Operation(err_msg)
373 })?;
374
375 format!("data: {chunk_str}\n\n")
376 };
377
378 let ending_chunk = "data: [DONE]\n\n".to_string();
380
381 let chunks = vec![tool_call_chunk, usage_chunk, ending_chunk];
382
383 let stream = ChatStream::new(
384 Some(graph.name().to_owned()),
385 id,
386 include_usage,
387 Some(chunks),
388 );
389
390 Ok((stream, include_tool_calls))
391 }
392 Err(wasmedge_wasi_nn::Error::BackendError(wasmedge_wasi_nn::BackendError::ContextFull)) => {
393 let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
395 let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
396 let err_msg = format!(
397 "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
398 );
399
400 #[cfg(feature = "logging")]
401 error!(target: "stdout", "{}", &err_msg);
402
403 LlamaCoreError::Operation(err_msg)
404 })?;
405
406 let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
408 let err_msg = format!("Failed to post-process the output. {e}");
409
410 #[cfg(feature = "logging")]
411 error!(target: "stdout", "{}", &err_msg);
412
413 LlamaCoreError::Operation(err_msg)
414 })?;
415
416 let token_info = get_token_info_by_graph(graph)?;
418
419 #[cfg(feature = "logging")]
420 info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
421
422 let usage = Some(Usage {
423 prompt_tokens: token_info.prompt_tokens,
424 completion_tokens: token_info.completion_tokens,
425 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
426 });
427
428 let created = SystemTime::now()
429 .duration_since(std::time::UNIX_EPOCH)
430 .map_err(|e| {
431 let err_msg = format!("Failed to get the current time. Reason: {e}");
432
433 #[cfg(feature = "logging")]
434 error!(target: "stdout", "{}", &err_msg);
435
436 LlamaCoreError::Operation(err_msg)
437 })?;
438
439 let context_full_chunk = {
441 let chat_completion_chunk = ChatCompletionChunk {
442 id: id.clone(),
443 object: "chat.completion.chunk".to_string(),
444 created: created.as_secs(),
445 model: graph.name().to_owned(),
446 system_fingerprint: "fp_44709d6fcb".to_string(),
447 choices: vec![ChatCompletionChunkChoice {
448 index: 0,
449 delta: ChatCompletionChunkChoiceDelta {
450 role: ChatCompletionRole::Assistant,
451 content: Some(message),
452 tool_calls: vec![],
453 },
454 logprobs: None,
455 finish_reason: Some(FinishReason::length),
456 }],
457 usage: None,
458 };
459
460 let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
462 let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
463
464 #[cfg(feature = "logging")]
465 error!(target: "stdout", "{}", &err_msg);
466
467 LlamaCoreError::Operation(err_msg)
468 })?;
469
470 format!("data: {chunk_str}\n\n")
471 };
472
473 let usage_chunk = {
475 let chat_completion_chunk = ChatCompletionChunk {
476 id: id.clone(),
477 object: "chat.completion.chunk".to_string(),
478 created: created.as_secs(),
479 model: graph.name().to_owned(),
480 system_fingerprint: "fp_44709d6fcb".to_string(),
481 choices: vec![],
482 usage,
483 };
484
485 let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
487 let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
488
489 #[cfg(feature = "logging")]
490 error!(target: "stdout", "{}", &err_msg);
491
492 LlamaCoreError::Operation(err_msg)
493 })?;
494
495 format!("data: {chunk_str}\n\n")
496 };
497
498 let ending_chunk = "data: [DONE]\n\n".to_string();
500
501 let chunks = vec![context_full_chunk, usage_chunk, ending_chunk];
502
503 let stream = ChatStream::new(
504 Some(graph.name().to_owned()),
505 id,
506 include_usage,
507 Some(chunks),
508 );
509
510 Ok((stream, false))
511 }
512 Err(wasmedge_wasi_nn::Error::BackendError(
513 wasmedge_wasi_nn::BackendError::PromptTooLong,
514 )) => {
515 #[cfg(feature = "logging")]
516 warn!(target: "stdout", "The prompt is too long. Please reduce the length of your input and try again.");
517
518 let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
520 let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
521 let err_msg = format!(
522 "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
523 );
524
525 #[cfg(feature = "logging")]
526 error!(target: "stdout", "{}", &err_msg);
527
528 LlamaCoreError::Operation(err_msg)
529 })?;
530
531 let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
533 let err_msg = format!("Failed to post-process the output. {e}");
534
535 #[cfg(feature = "logging")]
536 error!(target: "stdout", "{}", &err_msg);
537
538 LlamaCoreError::Operation(err_msg)
539 })?;
540
541 let token_info = get_token_info_by_graph(graph)?;
543
544 #[cfg(feature = "logging")]
545 info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
546
547 let usage = Some(Usage {
548 prompt_tokens: token_info.prompt_tokens,
549 completion_tokens: token_info.completion_tokens,
550 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
551 });
552
553 let created = SystemTime::now()
554 .duration_since(std::time::UNIX_EPOCH)
555 .map_err(|e| {
556 let err_msg = format!("Failed to get the current time. Reason: {e}");
557
558 #[cfg(feature = "logging")]
559 error!(target: "stdout", "{}", &err_msg);
560
561 LlamaCoreError::Operation(err_msg)
562 })?;
563
564 let prompt_too_long_chunk = {
566 let chat_completion_chunk = ChatCompletionChunk {
567 id: id.clone(),
568 object: "chat.completion.chunk".to_string(),
569 created: created.as_secs(),
570 model: graph.name().to_owned(),
571 system_fingerprint: "fp_44709d6fcb".to_string(),
572 choices: vec![ChatCompletionChunkChoice {
573 index: 0,
574 delta: ChatCompletionChunkChoiceDelta {
575 role: ChatCompletionRole::Assistant,
576 content: Some(message),
577 tool_calls: vec![],
578 },
579 logprobs: None,
580 finish_reason: Some(FinishReason::length),
581 }],
582 usage: None,
583 };
584
585 let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
587 let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
588
589 #[cfg(feature = "logging")]
590 error!(target: "stdout", "{}", &err_msg);
591
592 LlamaCoreError::Operation(err_msg)
593 })?;
594
595 format!("data: {chunk_str}\n\n")
596 };
597
598 let usage_chunk = {
600 let chat_completion_chunk = ChatCompletionChunk {
601 id: id.clone(),
602 object: "chat.completion.chunk".to_string(),
603 created: created.as_secs(),
604 model: graph.name().to_owned(),
605 system_fingerprint: "fp_44709d6fcb".to_string(),
606 choices: vec![],
607 usage,
608 };
609
610 let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
612 let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");
613
614 #[cfg(feature = "logging")]
615 error!(target: "stdout", "{}", &err_msg);
616
617 LlamaCoreError::Operation(err_msg)
618 })?;
619
620 format!("data: {chunk_str}\n\n")
621 };
622
623 let ending_chunk = "data: [DONE]\n\n".to_string();
625
626 let chunks = vec![prompt_too_long_chunk, usage_chunk, ending_chunk];
627
628 let stream = ChatStream::new(
629 Some(graph.name().to_owned()),
630 id,
631 include_usage,
632 Some(chunks),
633 );
634
635 Ok((stream, false))
636 }
637 Err(e) => {
638 let err_msg = format!("Failed to compute the chat completion. Reason: {e}");
639
640 #[cfg(feature = "logging")]
641 error!(target: "stdout", "{}", &err_msg);
642
643 Err(LlamaCoreError::Backend(BackendError::Compute(err_msg)))
644 }
645 }
646}
647
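/// Handles a chat completion request in non-stream mode: checks the running mode,
/// builds the prompt, runs inference via `compute`, and resets the model metadata
/// before returning the completion object and a tool-call flag.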
648async fn chat_once(
649 chat_request: &mut ChatCompletionRequest,
650) -> Result<(ChatCompletionObject, bool), LlamaCoreError> {
651 #[cfg(feature = "logging")]
652 info!(target: "stdout", "Processing chat completion request in non-stream mode");
653
654 let running_mode = running_mode()?;
655 if !running_mode.contains(RunningMode::CHAT) && !running_mode.contains(RunningMode::RAG) {
 656 let err_msg = "Chat completions are only supported in chat or RAG mode.";
657
658 #[cfg(feature = "logging")]
659 error!(target: "stdout", "{err_msg}");
660
661 return Err(LlamaCoreError::Operation(err_msg.to_string()));
662 }
663
664 let model_name = chat_request.model.clone();
665 let id = match &chat_request.user {
666 Some(id) => id.clone(),
667 None => gen_chat_id(),
668 };
669
670 #[cfg(feature = "logging")]
671 info!(target: "stdout", "user: {}", &id);
672
673 #[cfg(feature = "logging")]
674 info!(target: "stdout", "Check model metadata");
675
676 let mut metadata = check_model_metadata(chat_request)?;
678
679 #[cfg(feature = "logging")]
680 info!(target: "stdout", "Build the chat prompt");
681
 682 let (prompt, available_completion_tokens, tool_use) =
684 build_prompt(model_name.as_ref(), chat_request)?;
685
686 #[cfg(feature = "logging")]
687 {
688 info!(target: "stdout", "prompt:\n{}", &prompt);
 689 info!(target: "stdout", "available_completion_tokens: {available_completion_tokens}");
690 info!(target: "stdout", "tool_use: {tool_use}");
691 }
692
693 #[cfg(feature = "logging")]
694 info!(target: "stdout", "Update n_predict");
695
 696 update_n_predict(chat_request, &mut metadata, available_completion_tokens)?;
698
699 #[cfg(feature = "logging")]
700 info!(target: "stdout", "Feed the prompt to the model");
701
702 set_prompt(model_name.as_ref(), &prompt)?;
704
705 #[cfg(feature = "logging")]
706 info!(target: "stdout", "Compute chat completion.");
707
708 let res = compute(model_name.as_ref(), id, tool_use);
710
711 #[cfg(feature = "logging")]
712 info!(target: "stdout", "End of the chat completion");
713
714 reset_model_metadata(model_name.as_ref())?;
716
717 res
718}
719
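/// Resolves the target graph from `CHAT_GRAPHS` (by model name, falling back to the
/// first available graph) and delegates the actual inference to `compute_by_graph`.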
720fn compute(
721 model_name: Option<&String>,
722 id: impl Into<String>,
723 tool_use: bool,
724) -> Result<(ChatCompletionObject, bool), LlamaCoreError> {
725 let chat_graphs = match CHAT_GRAPHS.get() {
726 Some(chat_graphs) => chat_graphs,
727 None => {
 728 let err_msg = "Failed to get the underlying value of `CHAT_GRAPHS`.";
729
730 #[cfg(feature = "logging")]
731 error!(target: "stdout", "{}", &err_msg);
732
733 return Err(LlamaCoreError::Operation(err_msg.into()));
734 }
735 };
736
737 let mut chat_graphs = chat_graphs.lock().map_err(|e| {
 738 let err_msg = format!("Failed to acquire the lock of `CHAT_GRAPHS`. {e}");
739
740 #[cfg(feature = "logging")]
741 error!(target: "stdout", "{}", &err_msg);
742
743 LlamaCoreError::Operation(err_msg)
744 })?;
745
746 match model_name {
747 Some(model_name) => match chat_graphs.contains_key(model_name) {
748 true => {
749 let graph = chat_graphs.get_mut(model_name).unwrap();
750 compute_by_graph(graph, id, tool_use)
751 }
752 false => match chat_graphs.iter_mut().next() {
753 Some((_, graph)) => compute_by_graph(graph, id, tool_use),
754 None => {
755 let err_msg = "There is no model available in the chat graphs.";
756
757 #[cfg(feature = "logging")]
758 error!(target: "stdout", "{}", &err_msg);
759
760 Err(LlamaCoreError::Operation(err_msg.into()))
761 }
762 },
763 },
764 None => match chat_graphs.iter_mut().next() {
765 Some((_, graph)) => compute_by_graph(graph, id, tool_use),
766 None => {
767 let err_msg = "There is no model available in the chat graphs.";
768
769 #[cfg(feature = "logging")]
770 error!(target: "stdout", "{}", &err_msg);
771
772 Err(LlamaCoreError::Operation(err_msg.into()))
773 }
774 },
775 }
776}
777
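/// Runs inference on the given graph and assembles a `ChatCompletionObject`. When
/// `tool_use` is set, the output is parsed for tool calls and the finish reason is set
/// accordingly; context-full and prompt-too-long errors are mapped to
/// `finish_reason: length` with whatever output was produced.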
778fn compute_by_graph(
779 graph: &mut Graph<GgmlMetadata>,
780 id: impl Into<String>,
781 tool_use: bool,
782) -> Result<(ChatCompletionObject, bool), LlamaCoreError> {
783 #[cfg(feature = "logging")]
784 info!(target: "stdout", "Compute chat completion by the model named {}.", graph.name());
785
786 match graph.compute() {
787 Ok(_) => {
788 let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
790 let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
791 let err_msg = format!(
792 "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
793 );
794
795 #[cfg(feature = "logging")]
796 error!(target: "stdout", "{}", &err_msg);
797
798 LlamaCoreError::Operation(err_msg)
799 })?;
800
801 #[cfg(feature = "logging")]
802 info!(target: "stdout", "raw generation: {output}");
803
804 let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
806 LlamaCoreError::Operation(format!("Failed to post-process the output. {e}"))
807 })?;
808
809 #[cfg(feature = "logging")]
810 info!(target: "stdout", "post-processed generation:\n{}", &message);
811
812 let token_info = get_token_info_by_graph(graph)?;
814
815 #[cfg(feature = "logging")]
816 info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
817
818 let created = SystemTime::now()
819 .duration_since(std::time::UNIX_EPOCH)
820 .map_err(|e| {
821 let err_msg = format!("Failed to get the current time. Reason: {e}");
822
823 #[cfg(feature = "logging")]
824 error!(target: "stdout", "{}", &err_msg);
825
826 LlamaCoreError::Operation(err_msg)
827 })?;
828
829 match tool_use {
830 true => {
831 if graph.metadata.prompt_template != PromptTemplateType::MistralTool
832 && graph.metadata.prompt_template != PromptTemplateType::ChatMLTool
833 && graph.metadata.prompt_template != PromptTemplateType::GroqLlama3Tool
834 && graph.metadata.prompt_template != PromptTemplateType::Llama3Tool
835 && graph.metadata.prompt_template != PromptTemplateType::InternLM2Tool
836 && graph.metadata.prompt_template != PromptTemplateType::NemotronTool
837 && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV32
838 && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV31
839 && graph.metadata.prompt_template != PromptTemplateType::MistralSmallTool
840 && graph.metadata.prompt_template != PromptTemplateType::Llama4Chat
841 && graph.metadata.prompt_template != PromptTemplateType::Qwen3NoThink
842 && graph.metadata.prompt_template != PromptTemplateType::Smol3NoThink
843 && graph.metadata.prompt_template != PromptTemplateType::Gemma3
844 && graph.metadata.prompt_template != PromptTemplateType::GptOss
845 && graph.metadata.prompt_template != PromptTemplateType::Qwen3Agent
846 && graph.metadata.prompt_template != PromptTemplateType::SeedOssNoThink
847 && graph.metadata.prompt_template != PromptTemplateType::SeedOssThink
848 {
849 let err_msg = format!("Unsupported prompt template: {}. The tool use is only supported for 'mistral-tool', 'chatml-tool', 'groq-llama3-tool', 'llama-3-tool', 'internlm-2-tool', 'nemotron-tool', 'functionary-31', 'functionary-32', 'mistral-small-tool', 'llama-4-chat', 'qwen3-no-think', 'smol-3-no-think', 'gemma-3', 'gpt-oss', 'qwen3-agent', 'seed-oss-no-think', and 'seed-oss-think' prompt templates.", graph.metadata.prompt_template);
850
851 #[cfg(feature = "logging")]
852 error!(target: "stdout", "{}", &err_msg);
853
854 return Err(LlamaCoreError::Operation(err_msg));
855 }
856
857 let parsed_result = parse_tool_calls(&message, graph.metadata.prompt_template)?;
858
859 let (finish_reason, content, include_tool_calls) =
860 if parsed_result.tool_calls.is_empty() {
861 (FinishReason::stop, Some(parsed_result.raw.clone()), false)
862 } else if graph.metadata.prompt_template != PromptTemplateType::Qwen3Agent {
863 (
864 FinishReason::tool_calls,
865 Some(parsed_result.raw.clone()),
866 true,
867 )
868 } else {
869 (
870 FinishReason::tool_calls,
871 parsed_result.content.clone(),
872 true,
873 )
874 };
875
876 let res = ChatCompletionObject {
877 id: id.into(),
878 object: String::from("chat.completion"),
879 created: created.as_secs(),
880 model: graph.name().to_owned(),
881 choices: vec![ChatCompletionObjectChoice {
882 index: 0,
883 message: ChatCompletionObjectMessage {
884 role: ChatCompletionRole::Assistant,
885 content,
886 tool_calls: parsed_result.tool_calls,
887 function_call: None,
888 },
889 finish_reason,
890 logprobs: None,
891 }],
892 usage: Usage {
893 prompt_tokens: token_info.prompt_tokens,
894 completion_tokens: token_info.completion_tokens,
895 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
896 },
897 };
898
899 Ok((res, include_tool_calls))
901 }
902 false => {
903 let res = ChatCompletionObject {
905 id: id.into(),
906 object: String::from("chat.completion"),
907 created: created.as_secs(),
908 model: graph.name().to_owned(),
909 choices: vec![ChatCompletionObjectChoice {
910 index: 0,
911 message: ChatCompletionObjectMessage {
912 role: ChatCompletionRole::Assistant,
913 content: Some(message),
914 tool_calls: vec![],
915 function_call: None,
916 },
917 finish_reason: FinishReason::stop,
918 logprobs: None,
919 }],
920 usage: Usage {
921 prompt_tokens: token_info.prompt_tokens,
922 completion_tokens: token_info.completion_tokens,
923 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
924 },
925 };
926
927 Ok((res, false))
928 }
929 }
930 }
931 Err(wasmedge_wasi_nn::Error::BackendError(wasmedge_wasi_nn::BackendError::ContextFull)) => {
932 let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
934 let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
935 let err_msg = format!(
936 "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
937 );
938
939 #[cfg(feature = "logging")]
940 error!(target: "stdout", "{}", &err_msg);
941
942 LlamaCoreError::Operation(err_msg)
943 })?;
944
945 let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
947 let err_msg = format!("Failed to post-process the output. {e}");
948
949 #[cfg(feature = "logging")]
950 error!(target: "stdout", "{}", &err_msg);
951
952 LlamaCoreError::Operation(err_msg)
953 })?;
954
955 let token_info = get_token_info_by_graph(graph)?;
957
958 #[cfg(feature = "logging")]
959 info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
960
961 let created = SystemTime::now()
962 .duration_since(std::time::UNIX_EPOCH)
963 .map_err(|e| {
964 let err_msg = format!("Failed to get the current time. Reason: {e}");
965
966 #[cfg(feature = "logging")]
967 error!(target: "stdout", "{}", &err_msg);
968
969 LlamaCoreError::Operation(err_msg)
970 })?;
971
972 let res = ChatCompletionObject {
974 id: id.into(),
975 object: String::from("chat.completion"),
976 created: created.as_secs(),
977 model: graph.name().to_owned(),
978 choices: vec![ChatCompletionObjectChoice {
979 index: 0,
980 message: ChatCompletionObjectMessage {
981 role: ChatCompletionRole::Assistant,
982 content: Some(message),
983 tool_calls: vec![],
984 function_call: None,
985 },
986 finish_reason: FinishReason::length,
987 logprobs: None,
988 }],
989 usage: Usage {
990 prompt_tokens: token_info.prompt_tokens,
991 completion_tokens: token_info.completion_tokens,
992 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
993 },
994 };
995
996 Ok((res, false))
997 }
998 Err(wasmedge_wasi_nn::Error::BackendError(
999 wasmedge_wasi_nn::BackendError::PromptTooLong,
1000 )) => {
1001 #[cfg(feature = "logging")]
1002 warn!(target: "stdout", "The prompt is too long. Please reduce the length of your input and try again.");
1003
1004 let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
1006 let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
1007 let err_msg = format!(
1008 "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
1009 );
1010
1011 #[cfg(feature = "logging")]
1012 error!(target: "stdout", "{}", &err_msg);
1013
1014 LlamaCoreError::Operation(err_msg)
1015 })?;
1016
1017 let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
1019 let err_msg = format!("Failed to post-process the output. {e}");
1020
1021 #[cfg(feature = "logging")]
1022 error!(target: "stdout", "{}", &err_msg);
1023
1024 LlamaCoreError::Operation(err_msg)
1025 })?;
1026
1027 let token_info = get_token_info_by_graph(graph)?;
1029
1030 #[cfg(feature = "logging")]
1031 info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
1032
1033 let usage = Usage {
1034 prompt_tokens: token_info.prompt_tokens,
1035 completion_tokens: token_info.completion_tokens,
1036 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
1037 };
1038
1039 let created = SystemTime::now()
1040 .duration_since(std::time::UNIX_EPOCH)
1041 .map_err(|e| {
1042 let err_msg = format!("Failed to get the current time. Reason: {e}");
1043
1044 #[cfg(feature = "logging")]
1045 error!(target: "stdout", "{}", &err_msg);
1046
1047 LlamaCoreError::Operation(err_msg)
1048 })?;
1049
1050 let res = ChatCompletionObject {
1052 id: id.into(),
1053 object: String::from("chat.completion"),
1054 created: created.as_secs(),
1055 model: graph.name().to_owned(),
1056 choices: vec![ChatCompletionObjectChoice {
1057 index: 0,
1058 message: ChatCompletionObjectMessage {
1059 role: ChatCompletionRole::Assistant,
1060 content: Some(message),
1061 tool_calls: vec![],
1062 function_call: None,
1063 },
1064 finish_reason: FinishReason::length,
1065 logprobs: None,
1066 }],
1067 usage,
1068 };
1069
1070 Ok((res, false))
1071 }
1072 Err(e) => {
1073 let err_msg = format!("Failed to compute the chat completion. Reason: {e}");
1074
1075 #[cfg(feature = "logging")]
1076 error!(target: "stdout", "{}", &err_msg);
1077
1078 Err(LlamaCoreError::Backend(BackendError::Compute(err_msg)))
1079 }
1080 }
1081}
1082
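/// Extracts tool calls from the raw model output according to the tool-call syntax of
/// the given prompt template (e.g. `<tool_call>...</tool_call>` for ChatML/Qwen-style
/// templates, `[TOOL_CALLS]` for Mistral-Small, fenced JSON for Gemma 3) and returns a
/// `ParseResult` carrying the raw text, optional content, and the parsed `ToolCall`s.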
1083fn parse_tool_calls(
1084 input: &str,
1085 prompt_template: PromptTemplateType,
1086) -> Result<ParseResult, LlamaCoreError> {
1087 match prompt_template {
1088 PromptTemplateType::MistralTool => match regex::Regex::new(r"\[\{.*?\}\]") {
1089 Ok(re) => {
1090 let mut values: Vec<serde_json::Value> = vec![];
1091 for cap in re.captures_iter(input) {
1092 let matched = &cap[0];
1093
1094 #[cfg(feature = "logging")]
1095 info!(target: "stdout", "captured: {matched}");
1096
1097 match serde_json::from_str::<Vec<serde_json::Value>>(matched) {
1098 Ok(group) => values.extend(group),
1099 Err(e) => {
1100 let err_msg =
1101 format!("Failed to deserialize generated tool calls. Reason: {e}");
1102
1103 #[cfg(feature = "logging")]
1104 error!(target: "stdout", "{}", &err_msg);
1105
1106 return Err(LlamaCoreError::Operation(err_msg));
1107 }
1108 }
1109 }
1110
1111 let mut tool_calls: Vec<ToolCall> = vec![];
1112 for value in values.iter() {
1113 let name = match value.get("name") {
1114 Some(name) => name.to_string().replace("\"", ""),
1115 None => {
1116 let err_msg = format!(
1117 "Failed to get the name of the function. Tool call: {value:?}"
1118 );
1119
1120 #[cfg(feature = "logging")]
1121 error!(target: "stdout", "{}", &err_msg);
1122
1123 return Err(LlamaCoreError::Operation(err_msg));
1124 }
1125 };
1126
1127 let arguments = match value.get("arguments") {
1128 Some(arguments) => arguments.to_string(),
1129 None => {
1130 let err_msg = format!(
1131 "Failed to get the arguments of the function. Tool call: {value:?}"
1132 );
1133
1134 #[cfg(feature = "logging")]
1135 error!(target: "stdout", "{}", &err_msg);
1136
1137 return Err(LlamaCoreError::Operation(err_msg));
1138 }
1139 };
1140
1141 let function = Function { name, arguments };
1142
1143 let tool_call = ToolCall {
1144 id: "call_abc123".to_string(),
1145 ty: "function".to_string(),
1146 function,
1147 };
1148
1149 tool_calls.push(tool_call);
1150 }
1151
1152 let parsed = ParseResult {
1153 raw: input.to_owned(),
1154 content: None,
1155 tool_calls,
1156 };
1157
1158 #[cfg(feature = "logging")]
1159 info!(target: "stdout", "parsed result: {parsed:?}");
1160
1161 Ok(parsed)
1162 }
1163 Err(e) => {
1164 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1165
1166 #[cfg(feature = "logging")]
1167 error!(target: "stdout", "{}", &err_msg);
1168
1169 Err(LlamaCoreError::Operation(err_msg))
1170 }
1171 },
1172 PromptTemplateType::ChatMLTool => {
1173 match regex::Regex::new(r"<tool_call>(.*?)</tool_call>") {
1174 Ok(re) => {
1175 let mut values: Vec<serde_json::Value> = vec![];
1176 for cap in re.captures_iter(input) {
 1177 let matched = cap[1].replace("\\n", "");
 1178
 1179 #[cfg(feature = "logging")]
1180 info!(target: "stdout", "captured: {}", &matched);
1181
1182 match serde_json::from_str::<serde_json::Value>(&matched) {
1183 Ok(value) => values.push(value),
1184 Err(e) => {
1185 let err_msg = format!(
1186 "Failed to deserialize generated tool calls. Reason: {e}"
1187 );
1188
1189 #[cfg(feature = "logging")]
1190 error!(target: "stdout", "{}", &err_msg);
1191
1192 return Err(LlamaCoreError::Operation(err_msg));
1193 }
1194 }
1195 }
1196
1197 let mut tool_calls: Vec<ToolCall> = vec![];
1198 for value in values.iter() {
1199 let name = match value.get("name") {
1200 Some(name) => name.to_string().replace("\"", ""),
1201 None => {
1202 let err_msg = format!(
1203 "Failed to get the name of the function. Tool call: {value:?}"
1204 );
1205
1206 #[cfg(feature = "logging")]
1207 error!(target: "stdout", "{}", &err_msg);
1208
1209 return Err(LlamaCoreError::Operation(err_msg));
1210 }
1211 };
1212
1213 let arguments = match value.get("arguments") {
1214 Some(arguments) => arguments.to_string(),
1215 None => {
1216 let err_msg = format!(
1217 "Failed to get the arguments of the function. Tool call: {value:?}"
1218 );
1219
1220 #[cfg(feature = "logging")]
1221 error!(target: "stdout", "{}", &err_msg);
1222
1223 return Err(LlamaCoreError::Operation(err_msg));
1224 }
1225 };
1226
1227 let function = Function { name, arguments };
1228
1229 let tool_call = ToolCall {
1230 id: "call_abc123".to_string(),
1231 ty: "function".to_string(),
1232 function,
1233 };
1234
1235 tool_calls.push(tool_call);
1236 }
1237
1238 let parsed = ParseResult {
1239 raw: input.to_owned(),
1240 content: None,
1241 tool_calls,
1242 };
1243
1244 #[cfg(feature = "logging")]
1245 info!(target: "stdout", "parsed result: {parsed:?}");
1246
1247 Ok(parsed)
1248 }
1249 Err(e) => {
1250 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1251
1252 #[cfg(feature = "logging")]
1253 error!(target: "stdout", "{}", &err_msg);
1254
1255 Err(LlamaCoreError::Operation(err_msg))
1256 }
1257 }
1258 }
1259 PromptTemplateType::GroqLlama3Tool => {
1260 #[cfg(feature = "logging")]
1261 info!(target: "stdout", "raw input: {input}");
1262
1263 match regex::Regex::new(r"(?s)<tool_call>((.|\r|\n)*?)</tool_call>") {
1264 Ok(re) => {
1265 let mut values: Vec<serde_json::Value> = vec![];
1266 for cap in re.captures_iter(input) {
1267 let matched = cap[1].trim();
1268
1269 #[cfg(feature = "logging")]
1270 info!(target: "stdout", "captured: {matched}");
1271
1272 match serde_json::from_str::<serde_json::Value>(matched) {
1273 Ok(value) => values.push(value),
1274 Err(e) => {
1275 let err_msg = format!(
1276 "Failed to deserialize generated tool calls. Reason: {e}"
1277 );
1278
1279 #[cfg(feature = "logging")]
1280 error!(target: "stdout", "{}", &err_msg);
1281
1282 return Err(LlamaCoreError::Operation(err_msg));
1283 }
1284 }
1285 }
1286
1287 let mut tool_calls: Vec<ToolCall> = vec![];
1288 for value in values.iter() {
1289 let name = match value.get("name") {
1290 Some(name) => name.to_string().replace("\"", ""),
1291 None => {
1292 let err_msg = format!(
1293 "Failed to get the name of the function. Tool call: {value:?}"
1294 );
1295
1296 #[cfg(feature = "logging")]
1297 error!(target: "stdout", "{}", &err_msg);
1298
1299 return Err(LlamaCoreError::Operation(err_msg));
1300 }
1301 };
1302
1303 let arguments = match value.get("arguments") {
1304 Some(arguments) => {
1305 if arguments.is_string() {
1306 arguments.as_str().unwrap().to_string()
1307 } else if arguments.is_object() {
1308 let map = arguments.as_object().unwrap();
1309
1310 #[cfg(feature = "logging")]
1311 info!(target: "stdout", "func arguments: {map:?}");
1312
1313 serde_json::to_string(map).unwrap()
1314 } else {
1315 serde_json::to_string(arguments).unwrap()
1316 }
1317 }
1318 None => {
1319 let err_msg = format!(
1320 "Failed to get the arguments of the function. Tool call: {value:?}"
1321 );
1322
1323 #[cfg(feature = "logging")]
1324 error!(target: "stdout", "{}", &err_msg);
1325
1326 return Err(LlamaCoreError::Operation(err_msg));
1327 }
1328 };
1329
1330 let function = Function { name, arguments };
1331
1332 let tool_call = ToolCall {
1333 id: "call_abc123".to_string(),
1334 ty: "function".to_string(),
1335 function,
1336 };
1337
1338 tool_calls.push(tool_call);
1339 }
1340
1341 let parsed = if tool_calls.is_empty() {
1342 ParseResult {
1343 raw: input.to_owned(),
1344 content: Some(input.to_owned()),
1345 tool_calls: vec![],
1346 }
1347 } else {
1348 ParseResult {
1349 raw: input.to_owned(),
1350 content: None,
1351 tool_calls,
1352 }
1353 };
1354
1355 #[cfg(feature = "logging")]
1356 info!(target: "stdout", "parsed result: {parsed:?}");
1357
1358 Ok(parsed)
1359 }
1360 Err(e) => {
1361 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1362
1363 #[cfg(feature = "logging")]
1364 error!(target: "stdout", "{}", &err_msg);
1365
1366 Err(LlamaCoreError::Operation(err_msg))
1367 }
1368 }
1369 }
1370 PromptTemplateType::Llama3Tool => {
1371 #[cfg(feature = "logging")]
1372 info!(target: "stdout", "raw input: {input}");
1373
1374 let re = match regex::Regex::new(r"^\{(.|\r|\n)*\}$") {
1375 Ok(re) => re,
1376 Err(e) => {
1377 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1378
1379 #[cfg(feature = "logging")]
1380 error!(target: "stdout", "{}", &err_msg);
1381
1382 return Err(LlamaCoreError::Operation(err_msg));
1383 }
1384 };
1385
1386 if re.is_match(input) {
1387 match serde_json::from_str::<serde_json::Value>(input) {
1388 Ok(value) => {
1389 let values: Vec<serde_json::Value> = vec![value];
1390
1391 let mut tool_calls: Vec<ToolCall> = vec![];
1392 for value in values.iter() {
1393 let name = match value.get("name") {
1394 Some(name) => name.to_string().replace("\"", ""),
1395 None => {
1396 let err_msg = format!(
1397 "Failed to get the name of the function. Tool call: {value:?}"
1398 );
1399
1400 #[cfg(feature = "logging")]
1401 error!(target: "stdout", "{}", &err_msg);
1402
1403 return Err(LlamaCoreError::Operation(err_msg));
1404 }
1405 };
1406
1407 let arguments = match value.get("parameters") {
1408 Some(arguments) => arguments.to_string(),
1409 None => {
1410 let err_msg = format!(
1411 "Failed to get the arguments of the function. Tool call: {value:?}"
1412 );
1413
1414 #[cfg(feature = "logging")]
1415 error!(target: "stdout", "{}", &err_msg);
1416
1417 return Err(LlamaCoreError::Operation(err_msg));
1418 }
1419 };
1420
1421 let function = Function { name, arguments };
1422
1423 let tool_call = ToolCall {
1424 id: "call_abc123".to_string(),
1425 ty: "function".to_string(),
1426 function,
1427 };
1428
1429 tool_calls.push(tool_call);
1430 }
1431
1432 let parsed = ParseResult {
1433 raw: input.to_owned(),
1434 content: None,
1435 tool_calls,
1436 };
1437
1438 #[cfg(feature = "logging")]
1439 info!(target: "stdout", "parsed result: {parsed:?}");
1440
1441 Ok(parsed)
1442 }
1443 Err(e) => {
1444 let err_msg =
1445 format!("Failed to deserialize generated tool calls. Reason: {e}");
1446
1447 #[cfg(feature = "logging")]
1448 error!(target: "stdout", "{}", &err_msg);
1449
1450 Err(LlamaCoreError::Operation(err_msg))
1451 }
1452 }
1453 } else {
1454 let parsed = ParseResult {
1455 raw: input.to_owned(),
1456 content: None,
1457 tool_calls: vec![],
1458 };
1459
1460 #[cfg(feature = "logging")]
1461 info!(target: "stdout", "parsed result: {parsed:?}");
1462
1463 Ok(parsed)
1464 }
1465 }
1466 PromptTemplateType::InternLM2Tool => {
1467 #[cfg(feature = "logging")]
1468 info!(target: "stdout", "raw input: {input}");
1469
1470 let blocks: Vec<&str> = input.trim().split("<|action_start|><|plugin|>").collect();
1471
1472 #[cfg(feature = "logging")]
1473 info!(target: "stdout", "blocks: {blocks:?}");
1474
1475 let mut tool_calls: Vec<ToolCall> = vec![];
1476 let mut content = String::new();
1477 for block in blocks {
1478 let block = block.trim();
1479 if !block.is_empty() {
1480 if block.ends_with("<|action_end|>") {
1481 let value = block.trim().trim_end_matches("<|action_end|>");
1482
1483 #[cfg(feature = "logging")]
1484 info!(target: "stdout", "tool call: {value}");
1485
1486 match serde_json::from_str::<serde_json::Value>(value) {
1487 Ok(value) => {
1488 let name = match value.get("name") {
1489 Some(name) => name.to_string().replace("\"", ""),
1490 None => {
1491 let err_msg = format!(
1492 "Failed to get the name of the function. Tool call: {value:?}"
1493 );
1494
1495 #[cfg(feature = "logging")]
1496 error!(target: "stdout", "{}", &err_msg);
1497
1498 return Err(LlamaCoreError::Operation(err_msg));
1499 }
1500 };
1501
1502 let arguments = match value.get("parameters") {
1503 Some(arguments) => arguments.to_string(),
1504 None => {
1505 let err_msg = format!(
1506 "Failed to get the arguments of the function. Tool call: {value:?}"
1507 );
1508
1509 #[cfg(feature = "logging")]
1510 error!(target: "stdout", "{}", &err_msg);
1511
1512 return Err(LlamaCoreError::Operation(err_msg));
1513 }
1514 };
1515
1516 let function = Function { name, arguments };
1517
1518 let tool_call = ToolCall {
1519 id: "call_abc123".to_string(),
1520 ty: "function".to_string(),
1521 function,
1522 };
1523
1524 tool_calls.push(tool_call);
1525 }
1526 Err(e) => {
1527 let err_msg = format!(
1528 "Failed to deserialize generated tool calls. Reason: {e}"
1529 );
1530
1531 #[cfg(feature = "logging")]
1532 error!(target: "stdout", "{}", &err_msg);
1533
1534 return Err(LlamaCoreError::Operation(err_msg));
1535 }
1536 }
1537 } else {
1538 content.push_str(block);
1539 content.push('\n');
1540 }
1541 }
1542 }
1543
1544 let parsed = match content.is_empty() {
1545 true => ParseResult {
1546 raw: input.to_owned(),
1547 content: None,
1548 tool_calls,
1549 },
1550 false => ParseResult {
1551 raw: input.to_owned(),
1552 content: Some(content.trim().to_owned()),
1553 tool_calls,
1554 },
1555 };
1556
1557 #[cfg(feature = "logging")]
1558 info!(target: "stdout", "parsed result: {parsed:?}");
1559
1560 Ok(parsed)
1561 }
1562 PromptTemplateType::NemotronTool => {
1563 #[cfg(feature = "logging")]
1564 info!(target: "stdout", "raw input: {input}");
1565
1566 match regex::Regex::new(r"(?s)<toolcall>\s*(.*?)\s*</toolcall>") {
1567 Ok(re) => {
1568 let mut values: Vec<serde_json::Value> = vec![];
1569 for cap in re.captures_iter(input) {
1570 #[cfg(feature = "logging")]
1571 info!(target: "stdout", "captured: {}", &cap[0]);
1572
1573 #[cfg(feature = "logging")]
1574 info!(target: "stdout", "extracted: {}", &cap[1]);
1575
1576 let matched = cap[1].trim();
1577
1578 #[cfg(feature = "logging")]
1579 info!(target: "stdout", "captured: {matched}");
1580
1581 match serde_json::from_str::<serde_json::Value>(matched) {
1582 Ok(value) => values.push(value),
1583 Err(e) => {
1584 let err_msg = format!(
1585 "Failed to deserialize generated tool calls. Reason: {e}"
1586 );
1587
1588 #[cfg(feature = "logging")]
1589 error!(target: "stdout", "{}", &err_msg);
1590
1591 return Err(LlamaCoreError::Operation(err_msg));
1592 }
1593 }
1594 }
1595
1596 let mut tool_calls: Vec<ToolCall> = vec![];
1597 for value in values.iter() {
1598 let name = match value.get("name") {
1599 Some(name) => name.to_string().replace("\"", ""),
1600 None => {
1601 let err_msg = format!(
1602 "Failed to get the name of the function. Tool call: {value:?}"
1603 );
1604
1605 #[cfg(feature = "logging")]
1606 error!(target: "stdout", "{}", &err_msg);
1607
1608 return Err(LlamaCoreError::Operation(err_msg));
1609 }
1610 };
1611
1612 let arguments = match value.get("arguments") {
1613 Some(arguments) => arguments.to_string(),
1614 None => {
1615 let err_msg = format!(
1616 "Failed to get the arguments of the function. Tool call: {value:?}"
1617 );
1618
1619 #[cfg(feature = "logging")]
1620 error!(target: "stdout", "{}", &err_msg);
1621
1622 return Err(LlamaCoreError::Operation(err_msg));
1623 }
1624 };
1625
1626 let function = Function { name, arguments };
1627
1628 let tool_call = ToolCall {
1629 id: "call_abc123".to_string(),
1630 ty: "function".to_string(),
1631 function,
1632 };
1633
1634 tool_calls.push(tool_call);
1635 }
1636
1637 let parsed = ParseResult {
1638 raw: input.to_owned(),
1639 content: None,
1640 tool_calls,
1641 };
1642
1643 #[cfg(feature = "logging")]
1644 info!(target: "stdout", "parsed result: {parsed:?}");
1645
1646 Ok(parsed)
1647 }
1648 Err(e) => {
1649 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1650
1651 #[cfg(feature = "logging")]
1652 error!(target: "stdout", "{}", &err_msg);
1653
1654 Err(LlamaCoreError::Operation(err_msg))
1655 }
1656 }
1657 }
1658 PromptTemplateType::FunctionaryV32 => {
1659 #[cfg(feature = "logging")]
1660 info!(target: "stdout", "raw input: {input}");
1661
1662 match regex::Regex::new(r">>>\s*(\w+)\s*\{(.*)\}<\|eot_id\|>") {
1663 Ok(re) => {
1664 let mut tool_calls: Vec<ToolCall> = vec![];
1665 for cap in re.captures_iter(input) {
1666 #[cfg(feature = "logging")]
1667 info!(target: "stdout", "func_name: {}", &cap[1]);
1668
1669 #[cfg(feature = "logging")]
1670 info!(target: "stdout", "arguments: {}", &cap[2]);
1671
1672 let tool_call = ToolCall {
1673 id: "call_abc123".to_string(),
1674 ty: "function".to_string(),
1675 function: Function {
1676 name: cap[1].to_string(),
1677 arguments: cap[2].to_string(),
1678 },
1679 };
1680
1681 tool_calls.push(tool_call);
1682 }
1683
1684 let parsed = ParseResult {
1685 raw: input.to_owned(),
1686 content: None,
1687 tool_calls,
1688 };
1689
1690 #[cfg(feature = "logging")]
1691 info!(target: "stdout", "parsed result: {parsed:?}");
1692
1693 Ok(parsed)
1694 }
1695 Err(e) => {
1696 let warn_msg = format!("Failed to create a regex pattern. Reason: {e}");
1697
1698 #[cfg(feature = "logging")]
1699 warn!(target: "stdout", "{}", &warn_msg);
1700
1701 Ok(ParseResult {
1702 raw: input.to_owned(),
1703 content: None,
1704 tool_calls: vec![],
1705 })
1706 }
1707 }
1708 }
1709 PromptTemplateType::FunctionaryV31 => {
1710 #[cfg(feature = "logging")]
1711 info!(target: "stdout", "raw input: {input}");
1712
1713 match regex::Regex::new(r"<function=(\w+)>\s*(\{.*?\})</function>") {
1714 Ok(re) => {
1715 let mut tool_calls: Vec<ToolCall> = vec![];
1716 for cap in re.captures_iter(input) {
1717 #[cfg(feature = "logging")]
1718 info!(target: "stdout", "func_name: {}", &cap[1]);
1719
1720 #[cfg(feature = "logging")]
1721 info!(target: "stdout", "arguments: {}", &cap[2]);
1722
1723 let tool_call = ToolCall {
1724 id: "call_abc123".to_string(),
1725 ty: "function".to_string(),
1726 function: Function {
1727 name: cap[1].to_string(),
1728 arguments: cap[2].to_string(),
1729 },
1730 };
1731
1732 tool_calls.push(tool_call);
1733 }
1734
1735 let parsed = ParseResult {
1736 raw: input.to_owned(),
1737 content: None,
1738 tool_calls,
1739 };
1740
1741 #[cfg(feature = "logging")]
1742 info!(target: "stdout", "parsed result: {parsed:?}");
1743
1744 Ok(parsed)
1745 }
1746 Err(e) => {
1747 let warn_msg = format!("Failed to create a regex pattern. Reason: {e}");
1748
1749 #[cfg(feature = "logging")]
1750 warn!(target: "stdout", "{}", &warn_msg);
1751
1752 Ok(ParseResult {
1753 raw: input.to_owned(),
1754 content: None,
1755 tool_calls: vec![],
1756 })
1757 }
1758 }
1759 }
1760 PromptTemplateType::MistralSmallTool => {
1761 #[cfg(feature = "logging")]
1762 info!(target: "stdout", "raw input: {input}");
1763
1764 match regex::Regex::new(r"\[TOOL_CALLS\]\s*(\[(.*?)\])") {
1765 Ok(re) => {
1766 let mut values: Vec<serde_json::Value> = vec![];
1767 if let Some(cap) = re.captures(input) {
1768 let matched = cap[1].trim();
1769
1770 #[cfg(feature = "logging")]
1771 info!(target: "stdout", "captured: {matched}");
1772
1773 match serde_json::from_str::<Vec<serde_json::Value>>(matched) {
1774 Ok(vals) => values = vals,
1775 Err(e) => {
1776 let err_msg = format!(
1777 "Failed to deserialize generated tool calls. Reason: {e}"
1778 );
1779
1780 #[cfg(feature = "logging")]
1781 error!(target: "stdout", "{}", &err_msg);
1782
1783 return Err(LlamaCoreError::Operation(err_msg));
1784 }
1785 }
1786 };
1787
1788 let mut tool_calls: Vec<ToolCall> = vec![];
1789 for value in values.iter() {
1790 if let Some(object_map) = value.as_object() {
1791 if object_map.contains_key("function") {
1792 let mut function = Function {
1793 name: String::new(),
1794 arguments: String::new(),
1795 };
1796
1797 let value = object_map.get("function").unwrap();
1798 let func_map = value.as_object().unwrap();
1799 if func_map.contains_key("name") {
1800 let func_name = func_map.get("name").unwrap().as_str().unwrap();
1801 println!("Function name: {func_name:?}");
1802
1803 function.name = func_name.to_string();
1804 }
1805 if func_map.contains_key("arguments") {
1806 let args = func_map.get("arguments").unwrap();
1807 let arguments = args.to_string();
1808 println!("Arguments: {arguments:?}");
1809
1810 function.arguments = arguments;
1811 }
1812
1813 let tool_call = ToolCall {
1814 id: "call_abc123".to_string(),
1815 ty: "function".to_string(),
1816 function,
1817 };
1818
1819 tool_calls.push(tool_call);
1820 } else if object_map.contains_key("name") {
1821 let mut function = Function {
1822 name: String::new(),
1823 arguments: String::new(),
1824 };
1825
1826 let name = object_map.get("name").unwrap().as_str().unwrap();
1827 println!("name: {name:?}");
1828 function.name = name.to_string();
1829
1830 if object_map.contains_key("arguments") {
1831 let args = object_map.get("arguments").unwrap();
1832 let arguments = args.to_string();
1833 println!("Arguments: {arguments:?}");
1834
1835 function.arguments = arguments;
1836 }
1837
1838 let tool_call = ToolCall {
1839 id: "call_abc123".to_string(),
1840 ty: "function".to_string(),
1841 function,
1842 };
1843
1844 tool_calls.push(tool_call);
1845 }
1846 }
1847 }
1848
1849 let parsed = ParseResult {
1850 raw: input.to_owned(),
1851 content: None,
1852 tool_calls,
1853 };
1854
1855 #[cfg(feature = "logging")]
1856 info!(target: "stdout", "parsed result: {parsed:?}");
1857
1858 Ok(parsed)
1859 }
1860 Err(e) => {
1861 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1862
1863 #[cfg(feature = "logging")]
1864 error!(target: "stdout", "{}", &err_msg);
1865
1866 Err(LlamaCoreError::Operation(err_msg))
1867 }
1868 }
1869 }
1870 PromptTemplateType::Llama4Chat => {
1871 #[cfg(feature = "logging")]
1872 info!(target: "stdout", "raw input: {input:?}");
1873
1874 let mut tool_calls: Vec<ToolCall> = vec![];
1875 if let Ok(value) = serde_json::from_str::<serde_json::Value>(input) {
1876 match value.as_object() {
1877 Some(object_map) => {
1878 #[cfg(feature = "logging")]
1879 debug!(target: "stdout", "object_map: {object_map:?}");
1880
1881 if object_map.contains_key("name") {
1883 let name = object_map.get("name").unwrap().as_str().unwrap();
1884
1885 #[cfg(feature = "logging")]
1886 debug!(target: "stdout", "name: {name:?}");
1887
1888 let mut function = Function {
1889 name: name.to_string(),
1890 arguments: String::new(),
1891 };
1892
1893 if object_map.contains_key("parameters") {
1895 let args = object_map.get("parameters").unwrap();
1896 let arguments = args.to_string();
1897
1898 #[cfg(feature = "logging")]
1899 debug!(target: "stdout", "arguments: {:?}", &arguments);
1900
1901 function.arguments = arguments;
1902 }
1903
1904 tool_calls.push(ToolCall {
1905 id: "call_abc123".to_string(),
1906 ty: "function".to_string(),
1907 function,
1908 });
1909 } else {
1910 let err_msg = format!(
1911 "Failed to get the name of the function. raw input: {input:?}"
1912 );
1913
1914 #[cfg(feature = "logging")]
1915 error!(target: "stdout", "{}", &err_msg);
1916
1917 return Err(LlamaCoreError::Operation(err_msg));
1918 }
1919 }
1920 None => {
1921 let err_msg = format!("Failed to parse the JSON string. JSON: {input}");
1922
1923 #[cfg(feature = "logging")]
1924 error!(target: "stdout", "{}", &err_msg);
1925
1926 return Err(LlamaCoreError::Operation(err_msg));
1927 }
1928 }
1929 }
1930
1931 let parsed = ParseResult {
1932 raw: input.to_owned(),
1933 content: None,
1934 tool_calls,
1935 };
1936
1937 #[cfg(feature = "logging")]
1938 info!(target: "stdout", "parsed result: {parsed:?}");
1939
1940 Ok(parsed)
1941 }
1942 PromptTemplateType::Qwen3NoThink | PromptTemplateType::Smol3NoThink => {
1943 #[cfg(feature = "logging")]
1944 info!(target: "stdout", "raw input: {input:?}");
1945
1946 match regex::Regex::new(r"(?s)<tool_call>((.|\r|\n)*?)</tool_call>") {
1947 Ok(re) => {
1948 let mut values: Vec<serde_json::Value> = vec![];
1949 for cap in re.captures_iter(input) {
1950 let mut matched = cap[1].trim();
1951
1952 if matched.starts_with("\\n") {
1953 matched = matched.trim_start_matches("\\n");
1954 }
1955
1956 if matched.ends_with("\\n") {
1957 matched = matched.trim_end_matches("\\n");
1958 }
1959
1960 #[cfg(feature = "logging")]
1961 info!(target: "stdout", "captured: {matched:#?}");
1962
1963 if !matched.is_empty() {
1964 match serde_json::from_str::<serde_json::Value>(matched) {
1965 Ok(value) => values.push(value),
1966 Err(e) => {
1967 let err_msg = format!(
1968 "Failed to deserialize generated tool calls: {matched:#?}. Reason: {e}"
1969 );
1970
1971 #[cfg(feature = "logging")]
1972 error!(target: "stdout", "{}", &err_msg);
1973
1974 return Err(LlamaCoreError::Operation(err_msg));
1975 }
1976 }
1977 }
1978 }
1979
1980 let mut tool_calls: Vec<ToolCall> = vec![];
1981 for value in values.iter() {
1982 let name = match value.get("name") {
1983 Some(name) => name.to_string().replace("\"", ""),
1984 None => {
1985 let err_msg = format!(
1986 "Failed to get the name of the function. Tool call: {value:?}"
1987 );
1988
1989 #[cfg(feature = "logging")]
1990 error!(target: "stdout", "{}", &err_msg);
1991
1992 return Err(LlamaCoreError::Operation(err_msg));
1993 }
1994 };
1995
1996 let arguments = match value.get("arguments") {
1997 Some(arguments) => {
1998 if arguments.is_string() {
1999 arguments.as_str().unwrap().to_string()
2000 } else if arguments.is_object() {
2001 let map = arguments.as_object().unwrap();
2002
2003 #[cfg(feature = "logging")]
2004 info!(target: "stdout", "func arguments: {map:?}");
2005
2006 serde_json::to_string(map).unwrap()
2007 } else {
2008 serde_json::to_string(arguments).unwrap()
2009 }
2010 }
2011 None => {
2012 let err_msg = format!(
2013 "Failed to get the arguments of the function. Tool call: {value:?}"
2014 );
2015
2016 #[cfg(feature = "logging")]
2017 error!(target: "stdout", "{}", &err_msg);
2018
2019 return Err(LlamaCoreError::Operation(err_msg));
2020 }
2021 };
2022
2023 let function = Function { name, arguments };
2024
2025 let tool_call = ToolCall {
2026 id: "call_abc123".to_string(),
2027 ty: "function".to_string(),
2028 function,
2029 };
2030
2031 tool_calls.push(tool_call);
2032 }
2033
2034 let parsed = if tool_calls.is_empty() {
2035 ParseResult {
2036 raw: input.to_owned(),
2037 content: Some(input.to_owned()),
2038 tool_calls: vec![],
2039 }
2040 } else {
2041 ParseResult {
2042 raw: input.to_owned(),
2043 content: None,
2044 tool_calls,
2045 }
2046 };
2047
2048 #[cfg(feature = "logging")]
2049 info!(target: "stdout", "parsed result: {parsed:?}");
2050
2051 Ok(parsed)
2052 }
2053 Err(e) => {
2054 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
2055
2056 #[cfg(feature = "logging")]
2057 error!(target: "stdout", "{}", &err_msg);
2058
2059 Err(LlamaCoreError::Operation(err_msg))
2060 }
2061 }
2062 }
2063 PromptTemplateType::Gemma3 => {
2064 #[cfg(feature = "logging")]
2065 info!(target: "stdout", "raw input: {input:?}");
2066
2067 match regex::Regex::new(r"(?s)```json\s*(.*?)\s*```") {
2068 Ok(re) => {
2069 let mut values: Vec<serde_json::Value> = vec![];
2070 for cap in re.captures_iter(input) {
2071 let mut matched = cap[1].trim();
2072
2073 if matched.starts_with("\\n") {
2074 matched = matched.trim_start_matches("\\n");
2075 }
2076
2077 if matched.ends_with("\\n") {
2078 matched = matched.trim_end_matches("\\n");
2079 }
2080
2081 #[cfg(feature = "logging")]
2082 info!(target: "stdout", "captured: {matched:#?}");
2083
2084 if !matched.is_empty() {
2085 match serde_json::from_str::<serde_json::Value>(matched) {
2086 Ok(value) => values.push(value),
2087 Err(e) => {
2088 let err_msg = format!(
2089 "Failed to deserialize generated tool calls: {matched:#?}. Reason: {e}"
2090 );
2091
2092 #[cfg(feature = "logging")]
2093 error!(target: "stdout", "{}", &err_msg);
2094
2095 return Err(LlamaCoreError::Operation(err_msg));
2096 }
2097 }
2098 }
2099 }
2100
2101 let mut tool_calls: Vec<ToolCall> = vec![];
2102 for value in values.iter() {
2103 let name = match value.get("name") {
2104 Some(name) => name.to_string().replace("\"", ""),
2105 None => {
2106 let err_msg = format!(
2107 "Failed to get the name of the function. Tool call: {value:?}"
2108 );
2109
2110 #[cfg(feature = "logging")]
2111 error!(target: "stdout", "{}", &err_msg);
2112
2113 return Err(LlamaCoreError::Operation(err_msg));
2114 }
2115 };
2116
2117 let arguments = match value.get("arguments") {
2118 Some(arguments) => {
2119 if arguments.is_string() {
2120 arguments.as_str().unwrap().to_string()
2121 } else if arguments.is_object() {
2122 let map = arguments.as_object().unwrap();
2123
2124 #[cfg(feature = "logging")]
2125 info!(target: "stdout", "func arguments: {map:?}");
2126
2127 serde_json::to_string(map).unwrap()
2128 } else {
2129 serde_json::to_string(arguments).unwrap()
2130 }
2131 }
2132 None => {
2133 let err_msg = format!(
2134 "Failed to get the arguments of the function. Tool call: {value:?}"
2135 );
2136
2137 #[cfg(feature = "logging")]
2138 error!(target: "stdout", "{}", &err_msg);
2139
2140 return Err(LlamaCoreError::Operation(err_msg));
2141 }
2142 };
2143
2144 let function = Function { name, arguments };
2145
2146 let tool_call = ToolCall {
2147 id: "call_abc123".to_string(),
2148 ty: "function".to_string(),
2149 function,
2150 };
2151
2152 tool_calls.push(tool_call);
2153 }
2154
2155 let parsed = if tool_calls.is_empty() {
2156 ParseResult {
2157 raw: input.to_owned(),
2158 content: Some(input.to_owned()),
2159 tool_calls: vec![],
2160 }
2161 } else {
2162 ParseResult {
2163 raw: input.to_owned(),
2164 content: None,
2165 tool_calls,
2166 }
2167 };
2168
2169 #[cfg(feature = "logging")]
2170 info!(target: "stdout", "parsed result: {parsed:?}");
2171
2172 Ok(parsed)
2173 }
2174 Err(e) => {
2175 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
2176
2177 #[cfg(feature = "logging")]
2178 error!(target: "stdout", "{}", &err_msg);
2179
2180 Err(LlamaCoreError::Operation(err_msg))
2181 }
2182 }
2183 }
2184 PromptTemplateType::GptOss => {
2185 #[cfg(feature = "logging")]
2186 info!(target: "stdout", "raw input: {input:?}");
2187
2188 match regex::Regex::new(
2190 r"<\|channel\|>commentary to=functions\.([^<\s]+)\s*<\|constrain\|>json<\|message\|>([^<]*)<\|call\|>$",
2191 ) {
2192 Ok(re) => {
2193 if let Some(cap) = re.captures(input) {
2194 let function_name = cap[1].trim();
2195 let arguments = cap[2].trim();
2196
2197 #[cfg(feature = "logging")]
2198 info!(target: "stdout", "extracted function_name: {function_name}, arguments: {arguments}");
2199
2200 let function = Function {
2201 name: function_name.to_string(),
2202 arguments: arguments.to_string(),
2203 };
2204
2205 let tool_call = ToolCall {
2206 id: "call_abc123".to_string(),
2207 ty: "function".to_string(),
2208 function,
2209 };
2210
2211 let parsed = ParseResult {
2212 raw: input.to_owned(),
2213 content: None,
2214 tool_calls: vec![tool_call],
2215 };
2216
2217 #[cfg(feature = "logging")]
2218 info!(target: "stdout", "parsed result: {parsed:?}");
2219
2220 Ok(parsed)
2221 } else {
2222 match regex::Regex::new(r"(?s)```json\s*(.*?)\s*```") {
2223 Ok(re) => {
2224 let mut values: Vec<serde_json::Value> = vec![];
2225 for cap in re.captures_iter(input) {
2226 let mut matched = cap[1].trim();
2227
2228 if matched.starts_with("\\n") {
2229 matched = matched.trim_start_matches("\\n");
2230 }
2231
2232 if matched.ends_with("\\n") {
2233 matched = matched.trim_end_matches("\\n");
2234 }
2235
2236 #[cfg(feature = "logging")]
2237 info!(target: "stdout", "captured: {matched:#?}");
2238
2239 if !matched.is_empty() {
2240 match serde_json::from_str::<serde_json::Value>(matched) {
2241 Ok(value) => values.push(value),
2242 Err(e) => {
2243 let err_msg = format!(
2244 "Failed to deserialize generated tool calls: {matched:#?}. Reason: {e}"
2245 );
2246
2247 #[cfg(feature = "logging")]
2248 error!(target: "stdout", "{}", &err_msg);
2249
2250 return Err(LlamaCoreError::Operation(err_msg));
2251 }
2252 }
2253 }
2254 }
2255
2256 let mut tool_calls: Vec<ToolCall> = vec![];
2257 for value in values.iter() {
2258 let name = match value.get("name") {
2259 Some(name) => name.to_string().replace("\"", ""),
2260 None => {
2261 let err_msg = format!(
2262 "Failed to get the name of the function. Tool call: {value:?}"
2263 );
2264
2265 #[cfg(feature = "logging")]
2266 error!(target: "stdout", "{}", &err_msg);
2267
2268 return Err(LlamaCoreError::Operation(err_msg));
2269 }
2270 };
2271
2272 let arguments = match value.get("arguments") {
2273 Some(arguments) => {
2274 if arguments.is_string() {
2275 arguments.as_str().unwrap().to_string()
2276 } else if arguments.is_object() {
2277 let map = arguments.as_object().unwrap();
2278
2279 #[cfg(feature = "logging")]
2280 info!(target: "stdout", "func arguments: {map:?}");
2281
2282 serde_json::to_string(map).unwrap()
2283 } else {
2284 serde_json::to_string(arguments).unwrap()
2285 }
2286 }
2287 None => {
2288 let err_msg = format!(
2289 "Failed to get the arguments of the function. Tool call: {value:?}"
2290 );
2291
2292 #[cfg(feature = "logging")]
2293 error!(target: "stdout", "{}", &err_msg);
2294
2295 return Err(LlamaCoreError::Operation(err_msg));
2296 }
2297 };
2298
2299 let function = Function { name, arguments };
2300
2301 let tool_call = ToolCall {
2302 id: "call_abc123".to_string(),
2303 ty: "function".to_string(),
2304 function,
2305 };
2306
2307 tool_calls.push(tool_call);
2308 }
2309
2310 let parsed = if tool_calls.is_empty() {
2311 ParseResult {
2312 raw: input.to_owned(),
2313 content: Some(input.to_owned()),
2314 tool_calls: vec![],
2315 }
2316 } else {
2317 ParseResult {
2318 raw: input.to_owned(),
2319 content: Some(input.to_owned()),
2320 tool_calls,
2321 }
2322 };
2323
2324 #[cfg(feature = "logging")]
2325 info!(target: "stdout", "parsed result: {parsed:?}");
2326
2327 Ok(parsed)
2328 }
2329 Err(e) => {
2330 let err_msg =
2331 format!("Failed to create a regex pattern. Reason: {e}");
2332
2333 #[cfg(feature = "logging")]
2334 error!(target: "stdout", "{}", &err_msg);
2335
2336 Err(LlamaCoreError::Operation(err_msg))
2337 }
2338 }
2339 }
2340 }
2341 Err(e) => {
2342 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
2343
2344 #[cfg(feature = "logging")]
2345 error!(target: "stdout", "{}", &err_msg);
2346
2347 Err(LlamaCoreError::Operation(err_msg))
2348 }
2349 }
2350 }
2351 PromptTemplateType::Qwen3Agent => {
2352 #[cfg(feature = "logging")]
2353 info!(target: "stdout", "Raw input to tool call parser: {input:?}");
2354
match regex::Regex::new(r"(?s)<action>(.*?)</action>")
    .unwrap()
    .captures(input)
{
2360 Some(captures) => {
2361 let action = captures.get(1).unwrap().as_str();
2362
2363 #[cfg(feature = "logging")]
2364 info!(target: "stdout", "Action: {action}");
2365
2366 match serde_json::from_str::<serde_json::Value>(action) {
2367 Ok(value) => {
2368 let name = match value.get("name") {
2369 Some(name) => name.to_string().replace("\"", ""),
2370 None => {
2371 let err_msg = format!(
2372 "Failed to get the name of the function. Tool call: {value:?}"
2373 );
2374
2375 #[cfg(feature = "logging")]
2376 error!(target: "stdout", "{}", &err_msg);
2377
2378 return Err(LlamaCoreError::Operation(err_msg));
2379 }
2380 };
2381
2382 let arguments = match value.get("arguments") {
2383 Some(arguments) => {
2384 if arguments.is_string() {
2385 arguments.as_str().unwrap().to_string()
2386 } else if arguments.is_object() {
2387 let map = arguments.as_object().unwrap();
2388
2389 #[cfg(feature = "logging")]
2390 info!(target: "stdout", "func arguments: {map:?}");
2391
2392 serde_json::to_string(map).unwrap()
2393 } else {
2394 serde_json::to_string(arguments).unwrap()
2395 }
2396 }
2397 None => {
2398 let err_msg = format!(
2399 "Failed to get the arguments of the function. Tool call: {value:?}"
2400 );
2401
2402 #[cfg(feature = "logging")]
2403 error!(target: "stdout", "{}", &err_msg);
2404
2405 return Err(LlamaCoreError::Operation(err_msg));
2406 }
2407 };
2408
2409 let function = Function { name, arguments };
2410
2411 let tool_call = ToolCall {
2412 id: "call_abc123".to_string(),
2413 ty: "function".to_string(),
2414 function,
2415 };
2416
2417 Ok(ParseResult {
2418 raw: input.to_owned(),
2419 content: Some(input.to_owned()),
2420 tool_calls: vec![tool_call],
2421 })
2422 }
2423 Err(e) => {
2424 let err_msg = format!(
2425 "Failed to deserialize generated tool calls: {action:#?}. Reason: {e}"
2426 );
2427
2428 #[cfg(feature = "logging")]
2429 error!(target: "stdout", "{}", &err_msg);
2430
2431 Err(LlamaCoreError::Operation(err_msg))
2432 }
2433 }
2434 }
2435 None => match input.contains("<final_answer>") {
2436 true => Ok(ParseResult {
2437 raw: input.to_owned(),
2438 content: Some(input.to_owned()),
2439 tool_calls: vec![],
2440 }),
2441 false => {
2442 let content = format!("<final_answer>{}</final_answer>", input.trim());
2443
2444 Ok(ParseResult {
2445 raw: input.to_owned(),
2446 content: Some(content),
2447 tool_calls: vec![],
2448 })
2449 }
2450 },
2451 }
2452 }
2453 PromptTemplateType::SeedOssNoThink | PromptTemplateType::SeedOssThink => {
2454 #[cfg(feature = "logging")]
2455 info!(target: "stdout", "Raw input to tool call parser: {input:?}");
2456
2457 match regex::Regex::new(r"```json\n([\s\S]*?)\n") {
2458 Ok(re) => {
2459 let mut values: Vec<serde_json::Value> = vec![];
2460 for cap in re.captures_iter(input) {
2461 let mut matched = cap[1].trim();
2462
2463 if matched.starts_with("\\n") {
2464 matched = matched.trim_start_matches("\\n");
2465 }
2466
2467 if matched.ends_with("\\n") {
2468 matched = matched.trim_end_matches("\\n");
2469 }
2470
2471 #[cfg(feature = "logging")]
2472 info!(target: "stdout", "captured: {matched:#?}");
2473
2474 if !matched.is_empty() {
2475 match serde_json::from_str::<serde_json::Value>(matched) {
2476 Ok(value) => values.push(value),
2477 Err(e) => {
2478 let err_msg = format!(
2479 "Failed to deserialize generated tool calls: {matched:#?}. Reason: {e}"
2480 );
2481
2482 #[cfg(feature = "logging")]
2483 error!(target: "stdout", "{}", &err_msg);
2484
2485 return Err(LlamaCoreError::Operation(err_msg));
2486 }
2487 }
2488 }
2489 }
2490
2491 let mut tool_calls: Vec<ToolCall> = vec![];
2492 for value in values.iter() {
2493 let name = match value.get("name") {
2494 Some(name) => name.to_string().replace("\"", ""),
2495 None => {
2496 let err_msg = format!(
2497 "Failed to get the name of the function. Tool call: {value:?}"
2498 );
2499
2500 #[cfg(feature = "logging")]
2501 error!(target: "stdout", "{}", &err_msg);
2502
2503 return Err(LlamaCoreError::Operation(err_msg));
2504 }
2505 };
2506
2507 let arguments = match value.get("arguments") {
2508 Some(arguments) => {
2509 if arguments.is_string() {
2510 arguments.as_str().unwrap().to_string()
2511 } else if arguments.is_object() {
2512 let map = arguments.as_object().unwrap();
2513
2514 #[cfg(feature = "logging")]
2515 info!(target: "stdout", "func arguments: {map:?}");
2516
2517 serde_json::to_string(map).unwrap()
2518 } else {
2519 serde_json::to_string(arguments).unwrap()
2520 }
2521 }
2522 None => {
2523 let err_msg = format!(
2524 "Failed to get the arguments of the function. Tool call: {value:?}"
2525 );
2526
2527 #[cfg(feature = "logging")]
2528 error!(target: "stdout", "{}", &err_msg);
2529
2530 return Err(LlamaCoreError::Operation(err_msg));
2531 }
2532 };
2533
2534 let function = Function { name, arguments };
2535
2536 let tool_call = ToolCall {
2537 id: "call_abc123".to_string(),
2538 ty: "function".to_string(),
2539 function,
2540 };
2541
2542 tool_calls.push(tool_call);
2543 }
2544
2545 let parsed = if tool_calls.is_empty() {
2546 ParseResult {
2547 raw: input.to_owned(),
2548 content: Some(input.to_owned()),
2549 tool_calls: vec![],
2550 }
2551 } else {
2552 ParseResult {
2553 raw: input.to_owned(),
2554 content: None,
2555 tool_calls,
2556 }
2557 };
2558
2559 #[cfg(feature = "logging")]
2560 info!(target: "stdout", "parsed result: {parsed:?}");
2561
2562 Ok(parsed)
2563 }
2564 Err(e) => {
2565 let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
2566
2567 #[cfg(feature = "logging")]
2568 error!(target: "stdout", "{}", &err_msg);
2569
2570 Err(LlamaCoreError::Operation(err_msg))
2571 }
2572 }
2573 }
2574 _ => {
2575 let err_msg = format!(
2576 "The tool use is only supported for prompt templates: {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, and {}.",
2577 PromptTemplateType::MistralTool,
2578 PromptTemplateType::ChatMLTool,
2579 PromptTemplateType::GroqLlama3Tool,
2580 PromptTemplateType::Llama3Tool,
2581 PromptTemplateType::InternLM2Tool,
2582 PromptTemplateType::NemotronTool,
2583 PromptTemplateType::FunctionaryV32,
2584 PromptTemplateType::MistralSmallTool,
2585 PromptTemplateType::Llama4Chat,
2586 PromptTemplateType::Qwen3NoThink,
2587 PromptTemplateType::Smol3NoThink,
2588 PromptTemplateType::Gemma3,
2589 PromptTemplateType::GptOss,
2590 PromptTemplateType::Qwen3Agent,
2591 PromptTemplateType::SeedOssNoThink,
2592 PromptTemplateType::SeedOssThink
2593 );
2594
2595 #[cfg(feature = "logging")]
2596 error!(target: "stdout", "{}", &err_msg);
2597
2598 Err(LlamaCoreError::Operation(err_msg))
2599 }
2600 }
2601}
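
// A self-contained sketch of the `<tool_call>...</tool_call>` extraction used by the
// Qwen3NoThink/Smol3NoThink branch above. The sample model output and the function
// name inside it are hypothetical; only the `regex` and `serde_json` crates are used.
#[cfg(test)]
mod tool_call_regex_sketch {
    #[test]
    fn extracts_tool_call_payload() {
        let input =
            "<tool_call>\n{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Berlin\"}}\n</tool_call>";

        // Same pattern as the branch above: capture everything between the tags.
        let re = regex::Regex::new(r"(?s)<tool_call>((.|\r|\n)*?)</tool_call>").unwrap();
        let cap = re.captures(input).expect("the pattern should match");

        // The captured payload is plain JSON with `name` and `arguments` fields.
        let payload: serde_json::Value =
            serde_json::from_str(cap[1].trim()).expect("the payload should be valid JSON");
        assert_eq!(payload["name"], "get_weather");
        assert_eq!(payload["arguments"]["city"], "Berlin");
    }
}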
2602
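/// Syncs the per-request sampling parameters (`temperature`, `top_p`,
/// `frequency_penalty`, `presence_penalty`) into the model metadata, rejects
/// URL-based images for image-capable prompt templates, disables embeddings
/// mode, and pushes the updated metadata to the backend graph when anything
/// changed.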
2603fn check_model_metadata(
2604 chat_request: &ChatCompletionRequest,
2605) -> Result<GgmlMetadata, LlamaCoreError> {
2606 let mut should_update = false;
2607 let mut metadata = get_model_metadata(chat_request.model.as_ref())?;
2608
2609 if metadata.prompt_template.is_image_supported() {
2611 if let Some(ChatCompletionRequestMessage::User(user_message)) = chat_request.messages.last()
2612 {
2613 if let ChatCompletionUserMessageContent::Parts(parts) = user_message.content() {
2614 for part in parts {
2615 if let ContentPart::Image(image_part) = part {
2616 let image = image_part.image();
2617
2618 if image.is_url() {
2619 let err_msg = "The image is provided in URL format. Only base64 format is supported.".to_string();
2620
2621 #[cfg(feature = "logging")]
2622 error!(target: "stdout", "{}", &err_msg);
2623
2624 return Err(LlamaCoreError::Operation(err_msg));
2625 } else {
2626 #[cfg(feature = "logging")]
2627 info!(target: "stdout", "The image is provided in base64 format.");
2628
2629 break;
2632 }
2633 }
2634 }
2635 }
2636 }
2637 }
2638
2639 if let Some(temp) = chat_request.temperature {
2641 if metadata.temperature != temp {
2642 metadata.temperature = temp;
2644
2645 if !should_update {
2646 should_update = true;
2647 }
2648 }
2649 }
2650
2651 if let Some(top_p) = chat_request.top_p {
2653 if metadata.top_p != top_p {
2654 metadata.top_p = top_p;
2656
2657 if !should_update {
2658 should_update = true;
2659 }
2660 }
2661 }
2662
2663 if let Some(frequency_penalty) = chat_request.frequency_penalty {
2665 if metadata.frequency_penalty != frequency_penalty {
2666 metadata.frequency_penalty = frequency_penalty;
2668
2669 if !should_update {
2670 should_update = true;
2671 }
2672 }
2673 }
2674
2675 if let Some(presence_penalty) = chat_request.presence_penalty {
2677 if metadata.presence_penalty != presence_penalty {
2678 metadata.presence_penalty = presence_penalty;
2680
2681 if !should_update {
2682 should_update = true;
2683 }
2684 }
2685 }
2686
2687 if metadata.embeddings {
2689 metadata.embeddings = false;
2690
2691 if !should_update {
2692 should_update = true;
2693 }
2694 }
2695
2696 if should_update {
2697 #[cfg(feature = "logging")]
2698 info!(target: "stdout", "Update the model metadata.");
2699
2700 update_model_metadata(chat_request.model.as_ref(), &metadata)?;
2702 }
2703
2704 Ok(metadata)
2705}
2706
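/// Clamps `n_predict` for the current request: `max_completion_tokens` takes
/// precedence when provided; after that, any negative value and any positive
/// value smaller than the remaining context budget are replaced with
/// `available_completion_tokens`.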
2707fn update_n_predict(
2708 chat_request: &ChatCompletionRequest,
2709 metadata: &mut GgmlMetadata,
2710 available_completion_tokens: u64,
2711) -> Result<(), LlamaCoreError> {
2712 let mut should_update = false;
2713
2714 #[cfg(feature = "logging")]
2715 info!(target: "stdout", "n_predict: {}", metadata.n_predict);
2716
2717 if let Some(max_completion_tokens) = chat_request.max_completion_tokens {
2723 if metadata.n_predict != max_completion_tokens {
2724 #[cfg(feature = "logging")]
2725 info!(target: "stdout", "Update n_predict with max_completion_tokens from {} to {}", metadata.n_predict, max_completion_tokens);
2726
2727 metadata.n_predict = max_completion_tokens;
2728
2729 if !should_update {
2730 should_update = true;
2731 }
2732 }
2733 }
2734
2735 if metadata.n_predict == -2 {
2737 #[cfg(feature = "logging")]
2738 info!(target: "stdout", "Update n_predict with available_completion_tokens from {} to {}", metadata.n_predict, available_completion_tokens);
2739
2740 metadata.n_predict = available_completion_tokens as i32;
2742
2743 if !should_update {
2744 should_update = true;
2745 }
2746 }
2747
2748 if metadata.n_predict == -1
2749 || (metadata.n_predict > 0 && metadata.n_predict < available_completion_tokens as i32)
2750 || (metadata.n_predict < 0 && metadata.n_predict != -2)
2751 {
2753 #[cfg(feature = "logging")]
2754 info!(target: "stdout", "Update n_predict with available_completion_tokens from {} to {}", metadata.n_predict, available_completion_tokens);
2755
2756 metadata.n_predict = available_completion_tokens as i32;
2758
2759 if !should_update {
2760 should_update = true;
2761 }
2762 }
2763
2764 if should_update {
2765 #[cfg(feature = "logging")]
2766 info!(target: "stdout", "Update the model metadata.");
2767
2768 update_model_metadata(chat_request.model.as_ref(), metadata)?;
2770 }
2771
2772 Ok(())
2773}
2774
2775fn post_process(
2777 output: impl AsRef<str>,
2778 template_ty: &PromptTemplateType,
2779) -> Result<String, String> {
2780 let output = if *template_ty == PromptTemplateType::Baichuan2 {
2781 if output.as_ref().contains("用户:") {
2782 output.as_ref().trim_end_matches("用户:").trim().to_owned()
2783 } else {
2784 output.as_ref().trim().to_owned()
2785 }
2786 } else if *template_ty == PromptTemplateType::OpenChat {
2787 if output.as_ref().contains("<|end_of_turn|>") {
2788 output
2789 .as_ref()
2790 .trim_end_matches("<|end_of_turn|>")
2791 .trim()
2792 .to_owned()
2793 } else {
2794 output.as_ref().trim().to_owned()
2795 }
2796 } else if *template_ty == PromptTemplateType::GemmaInstruct
2797 || *template_ty == PromptTemplateType::Gemma3
2798 {
2799 let s = output.as_ref().trim();
2800 if s.ends_with("<end_of_turn>") {
2801 s.trim_end_matches("<end_of_turn>").trim().to_owned()
2802 } else {
2803 s.to_owned()
2804 }
2805 } else if *template_ty == PromptTemplateType::ChatML
2806 || *template_ty == PromptTemplateType::ChatMLTool
2807 || *template_ty == PromptTemplateType::InternLM2Tool
2808 || *template_ty == PromptTemplateType::MiniCPMV
2809 {
2810 let mut s = output.as_ref().trim();
2811 if s.ends_with("<|endoftext|>") {
2812 s = s.trim_end_matches("<|endoftext|>").trim();
2813 }
2814
2815 if s.starts_with(":") {
2816 s = s.trim_start_matches(":").trim();
2817 }
2818
2819 let x = {
2821 let pat = r#"<think>
2822
2823</think>
2824"#;
2825 if s.contains(pat) {
2826 let x = s.replace(pat, "");
2827 if x.starts_with("()") {
2828 x.trim_start_matches("()").to_owned()
2829 } else {
2830 x.to_owned()
2831 }
2832 } else {
2833 s.to_owned()
2834 }
2835 };
2836 s = x.trim();
2837
2838 if s.contains("<|im_start|>") && s.contains("<|im_end|>") {
2839 let idx_start = s.find("<|im_start|>").unwrap();
2840 let idx_end = s.find("<|im_end|>").unwrap();
2841
2842 match idx_start <= idx_end {
2843 true => s.split("<|im_start|>").collect::<Vec<_>>()[0]
2844 .trim()
2845 .to_owned(),
2846 false => s.split("<|im_end|>").collect::<Vec<_>>()[0]
2847 .trim()
2848 .to_owned(),
2849 }
2850 } else if s.contains("<|im_start|>") {
2851 s.split("<|im_start|>").collect::<Vec<_>>()[0]
2852 .trim()
2853 .to_owned()
2854 } else if s.contains("<|im_end|>") {
2855 let output = s.trim_end_matches("<|im_end|>").trim();
2856 if output.starts_with(": ") {
2857 output.trim_start_matches(": ").to_owned()
2858 } else {
2859 output.to_owned()
2860 }
2861 } else {
2862 s.to_owned()
2863 }
2864 } else if *template_ty == PromptTemplateType::Zephyr
2865 || *template_ty == PromptTemplateType::MistralLite
2866 || *template_ty == PromptTemplateType::MistralTool
2867 || *template_ty == PromptTemplateType::MistralInstruct
2868 || *template_ty == PromptTemplateType::MistralSmallChat
2869 || *template_ty == PromptTemplateType::MistralSmallTool
2870 || *template_ty == PromptTemplateType::BreezeInstruct
2871 {
2872 if output.as_ref().contains("</s><") {
2873 output.as_ref().trim_end_matches("</s><").trim().to_owned()
2874 } else if output.as_ref().contains("</s>") {
2875 output
2876 .as_ref()
2877 .strip_suffix("</s>")
2878 .unwrap()
2879 .trim()
2880 .to_owned()
2881 } else {
2882 output.as_ref().trim().to_owned()
2883 }
2884 } else if *template_ty == PromptTemplateType::DeepseekChat {
2885 if output.as_ref().contains("<|end_of_sentence|>") {
2886 output
2887 .as_ref()
2888 .trim_end_matches("<|end_of_sentence|>")
2889 .trim()
2890 .replace("<|end_of_sentence|>", " ")
2891 .trim()
2892 .to_owned()
2893 } else {
2894 output.as_ref().trim().to_owned()
2895 }
2896 } else if *template_ty == PromptTemplateType::HumanAssistant {
2897 if output.as_ref().contains("Human:") {
2898 output.as_ref().trim_end_matches("Human:").trim().to_owned()
2899 } else {
2900 output.as_ref().trim().to_owned()
2901 }
2902 } else if *template_ty == PromptTemplateType::SolarInstruct {
2903 let s = output.as_ref().trim();
2904
2905 if s.starts_with("### Answer") {
2906 let s = s.trim_start_matches("###").trim();
2907
2908 if s.starts_with("Answer:\n") {
2909 s.replace("Answer:\n", "Answer: ")
2910 } else {
2911 s.to_owned()
2912 }
2913 } else {
2914 s.to_owned()
2915 }
2916 } else if *template_ty == PromptTemplateType::Llama2Chat
2917 || *template_ty == PromptTemplateType::NemotronTool
2918 || *template_ty == PromptTemplateType::NemotronChat
2919 {
2920 let s = output.as_ref().trim();
2921 if s.ends_with("</s>") {
2922 s.trim_end_matches("</s>").trim().to_owned()
2923 } else {
2924 s.to_owned()
2925 }
2926 } else if *template_ty == PromptTemplateType::Llama3Chat
2927 || *template_ty == PromptTemplateType::GroqLlama3Tool
2928 || *template_ty == PromptTemplateType::Llama3Tool
2929 || *template_ty == PromptTemplateType::FunctionaryV32
2930 {
2931 let s = output.as_ref().trim();
2932 if s.ends_with("<|eot_id|>") {
2933 s.trim_end_matches("<|eot_id|>").trim().to_owned()
2934 } else {
2935 s.to_owned()
2936 }
2937 } else if *template_ty == PromptTemplateType::Phi3Chat {
2938 let s = output.as_ref().trim();
2939 if s.ends_with("<|end|>") {
2940 s.trim_end_matches("<|end|>").trim().to_owned()
2941 } else {
2942 s.to_owned()
2943 }
2944 } else if *template_ty == PromptTemplateType::Phi4Chat {
2945 let mut s = output.as_ref().trim();
2946
2947 if s.starts_with("think>") {
2948 s = s.trim_start_matches("think>").trim();
2949 }
2950
2951 if s.ends_with("<|im_end|>") {
2952 s.trim_end_matches("<|im_end|>").trim().to_owned()
2953 } else if s.ends_with("<|end|>") {
2954 s.trim_end_matches("<|end|>").trim().to_owned()
2955 } else {
2956 s.to_owned()
2957 }
2958 } else if *template_ty == PromptTemplateType::FunctionaryV31 {
2959 let mut s = output.as_ref().trim();
2960 if s.ends_with("<|eot_id|>") {
2961 s = s.trim_end_matches("<|eot_id|>").trim();
2962 }
2963 if s.ends_with("<|eom_id|>") {
2964 s = s.trim_end_matches("<|eom_id|>").trim();
2965 }
2966 s.to_owned()
2967 } else if *template_ty == PromptTemplateType::MoxinChat
2968 || *template_ty == PromptTemplateType::MoxinInstruct
2969 {
2970 let s = output.as_ref().trim();
2971 if s.ends_with("</s>") {
2972 s.trim_end_matches("</s>").trim().to_owned()
2973 } else if s.ends_with("[INST]") {
2974 s.trim_end_matches("[INST]").trim().to_owned()
2975 } else {
2976 s.to_owned()
2977 }
2978 } else if *template_ty == PromptTemplateType::Falcon3 {
2979 let s = output.as_ref().trim();
2980 if s.ends_with("<|endoftext|>") {
2981 s.trim_end_matches("<|endoftext|>").trim().to_owned()
2982 } else {
2983 s.to_owned()
2984 }
2985 } else if *template_ty == PromptTemplateType::Megrez {
2986 let s = output.as_ref().trim();
2987 if s.ends_with("<|turn_end|>") {
2988 s.trim_end_matches("<|turn_end|>").trim().to_owned()
2989 } else {
2990 s.to_owned()
2991 }
2992 } else if *template_ty == PromptTemplateType::Qwen2vl
2993 || *template_ty == PromptTemplateType::Qwen3NoThink
2994 || *template_ty == PromptTemplateType::ChatMLThink
2995 {
2996 let mut s = output.as_ref().trim();
2997
2998 if s.starts_with(":") {
2999 s = s.trim_start_matches(":").trim();
3000 }
3001
3002 if s.starts_with("</think>") {
3003 s = s.trim_start_matches("</think>").trim();
3004 }
3005
3006 if s.ends_with("<|im_end|>") {
3007 s.trim_end_matches("<|im_end|>").trim().to_owned()
3008 } else {
3009 s.to_owned()
3010 }
3011 } else if *template_ty == PromptTemplateType::VicunaLlava {
3012 let s = output.as_ref().trim();
3013 if s.ends_with("</s>") {
3014 s.trim_end_matches("</s>").trim().to_owned()
3015 } else {
3016 s.to_owned()
3017 }
3018 } else if *template_ty == PromptTemplateType::ExaoneDeepChat
3019 || *template_ty == PromptTemplateType::ExaoneChat
3020 {
3021 let mut s = output.as_ref().trim();
3022
3023 if s.ends_with("[|endofturn|]") {
3024 s = s.trim_end_matches("[|endofturn|]").trim();
3025 }
3026
3027 s.to_owned()
3028 } else if *template_ty == PromptTemplateType::Llama4Chat {
3029 let mut s = output.as_ref().trim();
3030
3031 if s.ends_with("<|eot|>") {
3032 s = s.trim_end_matches("<|eot|>").trim();
3033 }
3034
3035 s.to_owned()
3036 } else if *template_ty == PromptTemplateType::Smolvl {
3037 let mut s = output.as_ref().trim();
3038
3039 if s.starts_with(":") {
3040 s = s.trim_start_matches(":").trim();
3041 }
3042
3043 if s.ends_with("<end_of_utterance>") {
3044 s = s.trim_end_matches("<end_of_utterance>").trim();
3045 }
3046
3047 if s.contains("<end_of_utterance>:") {
3048 let parts = s.split("<end_of_utterance>:").collect::<Vec<_>>();
3049 parts.last().unwrap().trim().to_owned()
3050 } else {
3051 s.to_owned()
3052 }
3053 } else if *template_ty == PromptTemplateType::Smol3NoThink {
3054 let mut s = output.as_ref().trim();
3055
3056 if s.ends_with("<|im_end|>") {
3057 s = s.trim_end_matches("<|im_end|>").trim();
3058 }
3059
3060 let re = regex::Regex::new(r"(?s)^<think>.*?</think>\s*").unwrap();
3061 re.replace(s, "").to_string()
3062 } else if *template_ty == PromptTemplateType::GptOss {
3063 let s = output.as_ref().trim();
3064
3065 let re =
3066 regex::Regex::new(r"(?s).*<\|channel\|>final<\|message\|>(.*?)<\|return\|>$").unwrap();
3067
3068 if let Some(caps) = re.captures(s) {
3069 let extracted = &caps[1];
3070 extracted.to_owned()
3071 } else {
3072 s.to_owned()
3073 }
3074 } else if *template_ty == PromptTemplateType::Qwen3Agent {
3075 let mut s = output.as_ref().trim();
3076
3077 if s.starts_with(":") {
3078 s = s.trim_start_matches(":").trim();
3079 }
3080
3081 if s.starts_with("</think>") {
3082 s = s.trim_start_matches("</think>").trim();
3083 }
3084
3085 if s.ends_with("<|im_end|>") {
3086 s = s.trim_end_matches("<|im_end|>").trim();
3087 }
3088
3089 if s.contains("<final_answer>") && !s.contains("</final_answer>") {
3090 format!("{s}</final_answer>")
3091 } else {
3092 s.to_owned()
3093 }
3094 } else if *template_ty == PromptTemplateType::SeedOssNoThink {
3095 let s = output.as_ref().trim();
3096
3097 let re = regex::Regex::new(r"(?s)</seed:think>\s*(.*?)\s*<seed:eos>").unwrap();
3098
3099 if let Some(caps) = re.captures(s) {
3100 let extracted = &caps[1];
3101 extracted.to_owned()
3102 } else {
3103 s.to_owned()
3104 }
3105 } else {
3106 output.as_ref().trim().to_owned()
3107 };
3108
3109 Ok(output)
3110}
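
// A minimal sanity-check sketch for `post_process`, exercising only the ChatML
// branch shown above; the sample string is hypothetical model output.
#[cfg(test)]
mod post_process_sketch {
    use super::*;

    #[test]
    fn strips_chatml_end_token() {
        let cleaned = post_process("Hello there<|im_end|>", &PromptTemplateType::ChatML)
            .expect("post_process should not fail");
        assert_eq!(cleaned, "Hello there");
    }
}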
3111
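/// Builds the chat prompt for the target model and returns
/// `(prompt, available_completion_tokens, tool_use)`. The prompt is fed to the
/// backend to count its tokens; when the count exceeds 4/5 of the context size,
/// the oldest non-essential messages are pruned from the history and the prompt
/// is rebuilt until it fits (or an unrecoverable error is reported).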
3112fn build_prompt(
3124 model_name: Option<&String>,
3125 chat_request: &mut ChatCompletionRequest,
3126) -> Result<(String, u64, bool), LlamaCoreError> {
3127 let metadata = get_model_metadata(model_name)?;
3128 let ctx_size = metadata.ctx_size as u64;
3129 let chat_prompt = ChatPrompt::from(metadata.prompt_template);
3130
3131 let max_prompt_tokens = ctx_size * 4 / 5;
3133
3134 loop {
3150
3151 if chat_request.messages.is_empty() {
3152 let err_msg = "The messages in the chat request are empty.";
3153
3154 #[cfg(feature = "logging")]
3155 error!(target: "stdout", "{err_msg}");
3156
3157 return Err(LlamaCoreError::Operation(err_msg.to_owned()));
3158 }
3159
3160 #[cfg(feature = "logging")]
3161 {
3162 let mut role_chain = String::new();
3163 for (idx, message) in chat_request.messages.iter().enumerate() {
3164 if idx == chat_request.messages.len() - 1 {
3165 role_chain.push_str(&format!("{}", message.role()));
3166 } else {
3167 role_chain.push_str(&format!("{} -> ", message.role()));
3168 }
3169 }
3170 info!(target: "stdout", "Role chain: {role_chain}");
3171 }
3172
3173 let (prompt, tool_use) = match chat_request.tool_choice.as_ref() {
3174 Some(tool_choice) => match tool_choice {
3175 ToolChoice::None => {
3176 match chat_prompt.build_with_tools(&mut chat_request.messages, Some(&[])) {
3177 Ok(prompt) => (prompt, false),
3178 Err(e) => {
3179 let err_msg = format!("Fail to build chat prompts. Reason: {e}");
3180
3181 #[cfg(feature = "logging")]
3182 error!(target: "stdout", "{}", &err_msg);
3183
3184 return Err(LlamaCoreError::Operation(err_msg));
3185 }
3186 }
3187 }
3188 _ => match chat_request.tools.as_ref() {
3189 Some(tools) => match chat_prompt
3190 .build_with_tools(&mut chat_request.messages, Some(tools.as_slice()))
3191 {
3192 Ok(prompt) => (prompt, true),
3193 Err(e) => {
3194 let err_msg = format!("Fail to build chat prompts. Reason: {e}");
3195
3196 #[cfg(feature = "logging")]
3197 error!(target: "stdout", "{}", &err_msg);
3198
3199 return Err(LlamaCoreError::Operation(err_msg));
3200 }
3201 },
3202 None => {
3203 #[cfg(feature = "logging")]
3204 warn!(target: "stdout", "The tool choice without tools is not supported.");
3205
3206 match chat_prompt.build_with_tools(&mut chat_request.messages, None) {
3207 Ok(prompt) => (prompt, false),
3208 Err(e) => {
3209 let err_msg = format!("Fail to build chat prompts. Reason: {e}");
3210
3211 #[cfg(feature = "logging")]
3212 error!(target: "stdout", "{}", &err_msg);
3213
3214 return Err(LlamaCoreError::Operation(err_msg));
3215 }
3216 }
3217 }
3218 },
3219 },
3220 None => match chat_prompt.build_with_tools(&mut chat_request.messages, None) {
3221 Ok(prompt) => (prompt, false),
3222 Err(e) => {
3223 let err_msg = format!("Fail to build chat prompts. Reason: {e}");
3224
3225 #[cfg(feature = "logging")]
3226 error!(target: "stdout", "{}", &err_msg);
3227
3228 return Err(LlamaCoreError::Operation(err_msg));
3229 }
3230 },
3231 };
3232 #[cfg(feature = "logging")]
3233 info!(target: "stdout", "Try to set prompt: {prompt}");
3234
3235 set_prompt(model_name, &prompt)?;
3237
3238 let token_info = get_token_info_by_graph_name(model_name)?;
3240
3241 match token_info.prompt_tokens > max_prompt_tokens {
3242 true => {
3243 match chat_request.messages[0].role() {
3244 ChatCompletionRole::System => {
3245 if chat_request.messages.len() == 4
3247 && chat_request.messages[1].role() == ChatCompletionRole::User
3248 && chat_request.messages[2].role() == ChatCompletionRole::Assistant
3249 && chat_request.messages[3].role() == ChatCompletionRole::Tool
3250 {
3251 let err_msg = format!(
3252 "The number of prompt tokens ({}) is greater than the max prompt tokens ({}). Please increase the context size.",
3253 token_info.prompt_tokens, max_prompt_tokens
3254 );
3255
3256 #[cfg(feature = "logging")]
3257 error!(target: "stdout", "{}", &err_msg);
3258
3259 return Err(LlamaCoreError::Operation(err_msg));
3260 }
3261
3262 if chat_request.messages.len() > 2 {
3263 #[cfg(feature = "logging")]
3264 info!(target: "stdout", "Prune chat history: current length {}", chat_request.messages.len());
3265
3266 if chat_request.messages[1].role() == ChatCompletionRole::User {
3269 let user_message = chat_request.messages.remove(1);
3270
3271 #[cfg(feature = "logging")]
3272 info!(target: "stdout", "Remove a user message from the chat history: {user_message:?}");
3273 }
3274
3275 while chat_request.messages[1].role() != ChatCompletionRole::User {
3278 let message = chat_request.messages.remove(1);
3279
3280 #[cfg(feature = "logging")]
3281 info!(target: "stdout", "Remove a {} message from the chat history: {:?}", message.role(), message);
3282
3283 if chat_request.messages.len() == 1 {
3284 let err_msg = format!("The last message in the chat history should be a user message, but found a {} message.", message.role());
3285
3286 #[cfg(feature = "logging")]
3287 error!(target: "stdout", "{err_msg}");
3288
3289 return Err(LlamaCoreError::Operation(err_msg));
3290 }
3291 }
3292 } else if token_info.prompt_tokens > ctx_size {
3293 let err_msg = format!(
3294 "The number of prompt tokens ({}) is greater than the context size ({}). Please increase the context size, or simplify the input message.",
3295 token_info.prompt_tokens, ctx_size
3296 );
3297
3298 #[cfg(feature = "logging")]
3299 error!(target: "stdout", "{}", &err_msg);
3300
3301 return Err(LlamaCoreError::Operation(err_msg));
3302 } else {
3303 return Ok((prompt, ctx_size - token_info.prompt_tokens, tool_use));
3304 }
3305 }
3306 ChatCompletionRole::User => {
if chat_request.messages.len() == 3
    && chat_request.messages[1].role() == ChatCompletionRole::Assistant
    && chat_request.messages[2].role() == ChatCompletionRole::Tool
3312 {
3313 let err_msg = format!(
3314 "The number of prompt tokens ({}) is greater than the max prompt tokens ({}). Please increase the context size.",
3315 token_info.prompt_tokens, max_prompt_tokens
3316 );
3317
3318 #[cfg(feature = "logging")]
3319 error!(target: "stdout", "{}", &err_msg);
3320
3321 return Err(LlamaCoreError::Operation(err_msg));
3322 }
3323
3324 if chat_request.messages.len() > 1 {
3325 if chat_request.messages[0].role() == ChatCompletionRole::User {
3330 let user_message = chat_request.messages.remove(0);
3331
3332 #[cfg(feature = "logging")]
3333 info!(target: "stdout", "Remove a user message from the chat history: {user_message:?}");
3334 }
3335
3336 while chat_request.messages[0].role() != ChatCompletionRole::User {
3339 let message = chat_request.messages.remove(0);
3340
3341 #[cfg(feature = "logging")]
3342 info!(target: "stdout", "Remove a {} message from the chat history: {:?}", message.role(), message);
3343
3344 if chat_request.messages.is_empty() {
3345 let err_msg = format!("The last message in the chat history should be a user message, but found a {} message.", message.role());
3346
3347 #[cfg(feature = "logging")]
3348 error!(target: "stdout", "{err_msg}");
3349
3350 return Err(LlamaCoreError::Operation(err_msg));
3351 }
3352 }
3353 } else if token_info.prompt_tokens > ctx_size {
3354 let err_msg = format!(
3355 "The number of prompt tokens ({}) is greater than the context size ({}). Please increase the context size, or simplify the input message.",
3356 token_info.prompt_tokens, ctx_size
3357 );
3358
3359 #[cfg(feature = "logging")]
3360 error!(target: "stdout", "{}", &err_msg);
3361
3362 return Err(LlamaCoreError::Operation(err_msg));
3363 } else {
3364 return Ok((prompt, ctx_size - token_info.prompt_tokens, tool_use));
3365 }
3366 }
3367 _ => {
3368 #[cfg(feature = "logging")]
3369 info!(target: "stdout", "remove a {} message from the message queue", chat_request.messages[0].role());
3370
3371 chat_request.messages.remove(0);
3372 }
3373 }
3374
3375 continue;
3376 }
3377 false => return Ok((prompt, ctx_size - max_prompt_tokens, tool_use)),
3378 }
3379 }
3380}
3381
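/// Writes the UTF-8 bytes of `prompt` into input tensor 0 of the named chat
/// graph, falling back to the first registered graph when the name is missing
/// or unknown.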
3382fn set_prompt(model_name: Option<&String>, prompt: impl AsRef<str>) -> Result<(), LlamaCoreError> {
3383 let chat_graphs = match CHAT_GRAPHS.get() {
3384 Some(chat_graphs) => chat_graphs,
3385 None => {
3386 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3387
3388 #[cfg(feature = "logging")]
3389 error!(target: "stdout", "{}", &err_msg);
3390
3391 return Err(LlamaCoreError::Operation(err_msg.into()));
3392 }
3393 };
3394
3395 let mut chat_graphs = chat_graphs.lock().map_err(|e| {
3396 let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3397
3398 #[cfg(feature = "logging")]
3399 error!(target: "stdout", "{}", &err_msg);
3400
3401 LlamaCoreError::Operation(err_msg)
3402 })?;
3403
3404 match model_name {
3405 Some(model_name) => {
3406 #[cfg(feature = "logging")]
3407 info!(target: "stdout", "Set prompt to the chat model named {model_name}");
3408
3409 match chat_graphs.contains_key(model_name) {
3410 true => {
3411 let graph = chat_graphs.get_mut(model_name).unwrap();
3412 let tensor_data = prompt.as_ref().as_bytes().to_vec();
3413 set_tensor_data_u8(graph, 0, &tensor_data)
3414 }
3415 false => match chat_graphs.iter_mut().next() {
3416 Some((_, graph)) => {
3417 let tensor_data = prompt.as_ref().as_bytes().to_vec();
3418 set_tensor_data_u8(graph, 0, &tensor_data)
3419 }
3420 None => {
3421 let err_msg = "There is no model available in the chat graphs.";
3422
3423 #[cfg(feature = "logging")]
3424 error!(target: "stdout", "{}", &err_msg);
3425
3426 Err(LlamaCoreError::Operation(err_msg.into()))
3427 }
3428 },
3429 }
3430 }
3431 None => {
3432 #[cfg(feature = "logging")]
3433 info!(target: "stdout", "Set prompt to the default chat model.");
3434
3435 match chat_graphs.iter_mut().next() {
3436 Some((_, graph)) => {
3437 let tensor_data = prompt.as_ref().as_bytes().to_vec();
3438 set_tensor_data_u8(graph, 0, &tensor_data)
3439 }
3440 None => {
3441 let err_msg = "There is no model available in the chat graphs while trying to set prompt to the default model.";
3442
3443 #[cfg(feature = "logging")]
3444 error!(target: "stdout", "{err_msg}");
3445
3446 Err(LlamaCoreError::Operation(err_msg.into()))
3447 }
3448 }
3449 }
3450 }
3451}
3452
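/// Returns a clone of the metadata of the named chat graph, falling back to the
/// first registered graph when the name is missing or unknown.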
3453fn get_model_metadata(model_name: Option<&String>) -> Result<GgmlMetadata, LlamaCoreError> {
3455 let chat_graphs = match CHAT_GRAPHS.get() {
3456 Some(chat_graphs) => chat_graphs,
3457 None => {
3458 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3459
3460 #[cfg(feature = "logging")]
3461 error!(target: "stdout", "{err_msg}");
3462
3463 return Err(LlamaCoreError::Operation(err_msg.into()));
3464 }
3465 };
3466
3467 let chat_graphs = chat_graphs.lock().map_err(|e| {
3468 let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3469
3470 #[cfg(feature = "logging")]
3471 error!(target: "stdout", "{}", &err_msg);
3472
3473 LlamaCoreError::Operation(err_msg)
3474 })?;
3475
3476 match model_name {
3477 Some(model_name) => match chat_graphs.contains_key(model_name) {
3478 true => {
3479 let graph = chat_graphs.get(model_name).unwrap();
3480 Ok(graph.metadata.clone())
3481 }
3482 false => match chat_graphs.iter().next() {
3483 Some((_, graph)) => Ok(graph.metadata.clone()),
3484 None => {
3485 let err_msg = "There is no model available in the chat graphs.";
3486
3487 #[cfg(feature = "logging")]
3488 error!(target: "stdout", "{}", &err_msg);
3489
3490 Err(LlamaCoreError::Operation(err_msg.into()))
3491 }
3492 },
3493 },
3494 None => match chat_graphs.iter().next() {
3495 Some((_, graph)) => Ok(graph.metadata.clone()),
3496 None => {
3497 let err_msg = "There is no model available in the chat graphs.";
3498
3499 #[cfg(feature = "logging")]
3500 error!(target: "stdout", "{err_msg}");
3501
3502 Err(LlamaCoreError::Operation(err_msg.into()))
3503 }
3504 },
3505 }
3506}
3507
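/// Serializes `metadata` to JSON and writes it into input tensor 1 of the
/// target chat graph, falling back to the first registered graph when the name
/// is missing or unknown.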
3508fn update_model_metadata(
3509 model_name: Option<&String>,
3510 metadata: &GgmlMetadata,
3511) -> Result<(), LlamaCoreError> {
3512 let config = match serde_json::to_string(metadata) {
3513 Ok(config) => config,
3514 Err(e) => {
3515 let err_msg = format!("Fail to serialize metadata to a JSON string. {e}");
3516
3517 #[cfg(feature = "logging")]
3518 error!(target: "stdout", "{}", &err_msg);
3519
3520 return Err(LlamaCoreError::Operation(err_msg));
3521 }
3522 };
3523
3524 let chat_graphs = match CHAT_GRAPHS.get() {
3525 Some(chat_graphs) => chat_graphs,
3526 None => {
3527 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3528
3529 #[cfg(feature = "logging")]
3530 error!(target: "stdout", "{err_msg}");
3531
3532 return Err(LlamaCoreError::Operation(err_msg.into()));
3533 }
3534 };
3535
3536 let mut chat_graphs = chat_graphs.lock().map_err(|e| {
3537 let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. Reason: {e}");
3538
3539 #[cfg(feature = "logging")]
3540 error!(target: "stdout", "{}", &err_msg);
3541
3542 LlamaCoreError::Operation(err_msg)
3543 })?;
3544
3545 match model_name {
3546 Some(model_name) => {
3547 match chat_graphs.contains_key(model_name) {
3548 true => {
3549 let graph = chat_graphs.get_mut(model_name).unwrap();
3550 set_tensor_data_u8(graph, 1, config.as_bytes())
3552 }
3553 false => match chat_graphs.iter_mut().next() {
3554 Some((_, graph)) => {
3555 set_tensor_data_u8(graph, 1, config.as_bytes())
3557 }
3558 None => {
3559 let err_msg = "There is no model available in the chat graphs.";
3560
3561 #[cfg(feature = "logging")]
3562 error!(target: "stdout", "{}", &err_msg);
3563
3564 Err(LlamaCoreError::Operation(err_msg.into()))
3565 }
3566 },
3567 }
3568 }
3569 None => {
3570 match chat_graphs.iter_mut().next() {
3571 Some((_, graph)) => {
3572 set_tensor_data_u8(graph, 1, config.as_bytes())
3574 }
3575 None => {
3576 let err_msg = "There is no model available in the chat graphs.";
3577
3578 #[cfg(feature = "logging")]
3579 error!(target: "stdout", "{err_msg}");
3580
3581 Err(LlamaCoreError::Operation(err_msg.into()))
3582 }
3583 }
3584 }
3585 }
3586}
3587
3588fn reset_model_metadata(model_name: Option<&String>) -> Result<(), LlamaCoreError> {
3589 let metadata = get_model_metadata(model_name)?;
3591
3592 update_model_metadata(model_name, &metadata)
3594}
3595
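// The three small state machines below drive the framing of the SSE stream for
// the normal, context-full, and prompt-too-long paths: emit the message
// chunk(s), then (optionally) a usage chunk, then `data: [DONE]`, and finally
// the `[GGML] End of sequence` marker that terminates the stream.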
3596#[derive(Debug, Clone, Copy, PartialEq, Eq)]
3597enum ContextFullState {
3598 Message,
3599 Usage,
3600 Done,
3601 EndOfSequence,
3602}
3603
3604#[derive(Debug, Clone, Copy, PartialEq, Eq)]
3605enum StreamState {
3606 Usage,
3607 NoUsage,
3608 Done,
3609 EndOfSequence,
3610}
3611
3612#[derive(Debug, Clone, Copy, PartialEq, Eq)]
3613enum PromptTooLongState {
3614 Message,
3615 Usage,
3616 Done,
3617 EndOfSequence,
3618}
3619
3620struct ChatStream {
3621 id: String,
3622 model: Option<String>,
3623 include_usage: bool,
3624 context_full_state: ContextFullState,
3625 prompt_too_long_state: PromptTooLongState,
3626 stream_state: StreamState,
3627 cache: Option<VecDeque<String>>,
3628 is_waiting: bool,
3629 has_lock: bool,
3630}
3631impl ChatStream {
3632 fn new(
3633 model: Option<String>,
3634 id: String,
3635 include_usage: bool,
3636 cache: Option<Vec<String>>,
3637 ) -> Self {
3638 let has_lock = CHAT_STREAM_ACTIVE
3640 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
3641 .is_ok();
3642
3643 #[cfg(feature = "logging")]
3644 if !has_lock {
3645 info!(target: "stdout", "Lock acquisition failed in ChatStream::new, creating with waiting status");
3646 }
3647
3648 ChatStream {
3649 id,
3650 model,
3651 include_usage,
3652 context_full_state: ContextFullState::Message,
3653 prompt_too_long_state: PromptTooLongState::Message,
3654 stream_state: if include_usage {
3655 StreamState::Usage
3656 } else {
3657 StreamState::NoUsage
3658 },
3659 cache: cache.map(VecDeque::from),
3660 is_waiting: !has_lock,
3661 has_lock,
3662 }
3663 }
3664
3665 fn try_acquire_lock(&mut self) -> bool {
3667 if self.has_lock {
3668 return true;
3669 }
3670
3671 let acquired = CHAT_STREAM_ACTIVE
3672 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
3673 .is_ok();
3674
3675 if acquired {
3676 self.has_lock = true;
3677 self.is_waiting = false;
3678 }
3679
3680 acquired
3681 }
3682}
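
// `ChatStream` serializes access to the backend with a single `AtomicBool`
// compare-and-swap: whoever flips it from `false` to `true` owns the stream,
// everyone else registers a `Waker` and retries. A minimal, self-contained
// sketch of that acquire/release pattern (the names below are illustrative):
#[cfg(test)]
mod stream_gate_sketch {
    use std::sync::atomic::{AtomicBool, Ordering};

    static GATE: AtomicBool = AtomicBool::new(false);

    fn try_acquire() -> bool {
        // Succeeds only for the caller that observes the gate as `false`.
        GATE.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
            .is_ok()
    }

    fn release() {
        GATE.store(false, Ordering::SeqCst);
    }

    #[test]
    fn only_one_owner_at_a_time() {
        assert!(try_acquire());
        assert!(!try_acquire());
        release();
        assert!(try_acquire());
        release();
    }
}
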
3683impl Drop for ChatStream {
3684 fn drop(&mut self) {
3685 if self.has_lock || (self.cache.is_none() && !self.is_waiting) {
3687 #[cfg(feature = "logging")]
3688 info!(target: "stdout", "Cleaning up context for ChatStream {}", &self.id);
3689
3690 match &self.model {
3691 Some(model_name) => {
3692 match CHAT_GRAPHS.get() {
3693 Some(chat_graphs) => {
3694 match chat_graphs.lock() {
3695 Ok(mut chat_graphs) => match chat_graphs.contains_key(model_name) {
3696 true => {
3697 let graph = chat_graphs.get_mut(model_name).unwrap();
3698
3699 if let Err(e) = graph.finish_single() {
3701 let err_msg = format!(
3702 "Failed to clean up the context. Reason: {e}"
3703 );
3704
3705 #[cfg(feature = "logging")]
3706 error!(target: "stdout", "{}", &err_msg);
3707
3708 #[cfg(not(feature = "logging"))]
3709 println!(
3710 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3711 &err_msg
3712 );
3713 }
3714 }
3715 false => match chat_graphs.iter_mut().next() {
3716 Some((_, graph)) => {
3717 if let Err(e) = graph.finish_single() {
3719 let err_msg = format!(
3720 "Failed to clean up the context. Reason: {e}"
3721 );
3722
3723 #[cfg(feature = "logging")]
3724 error!(target: "stdout", "{}", &err_msg);
3725
3726 #[cfg(not(feature = "logging"))]
3727 println!(
3728 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3729 &err_msg
3730 );
3731 }
3732 }
3733 None => {
3734 let err_msg =
3735 "There is no model available in the chat graphs.";
3736
3737 #[cfg(feature = "logging")]
3738 error!(target: "stdout", "{}", &err_msg);
3739
3740 #[cfg(not(feature = "logging"))]
3741 println!(
3742 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3743 &err_msg
3744 );
3745 }
3746 },
3747 },
3748 Err(e) => {
3749 let err_msg =
3750 format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3751
3752 #[cfg(feature = "logging")]
3753 error!(target: "stdout", "{}", &err_msg);
3754
3755 #[cfg(not(feature = "logging"))]
3756 println!(
3757 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3758 &err_msg
3759 );
3760 }
3761 }
3762 }
3763 None => {
3764 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3765
3766 #[cfg(feature = "logging")]
3767 error!(target: "stdout", "{}", &err_msg);
3768
3769 #[cfg(not(feature = "logging"))]
3770 println!(
3771 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3772 &err_msg
3773 );
3774 }
3775 };
3776 }
3777 None => {
3778 match CHAT_GRAPHS.get() {
3779 Some(chat_graphs) => {
3780 match chat_graphs.lock() {
3781 Ok(mut chat_graphs) => match chat_graphs.iter_mut().next() {
3782 Some((_, graph)) => {
3783 if let Err(e) = graph.finish_single() {
3785 let err_msg = format!(
3786 "Failed to clean up the context. Reason: {e}"
3787 );
3788
3789 #[cfg(feature = "logging")]
3790 error!(target: "stdout", "{}", &err_msg);
3791
3792 #[cfg(not(feature = "logging"))]
3793 println!(
3794 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3795 &err_msg
3796 );
3797 }
3798 }
3799 None => {
3800 let err_msg =
3801 "There is no model available in the chat graphs.";
3802
3803 #[cfg(feature = "logging")]
3804 error!(target: "stdout", "{err_msg}");
3805
3806 #[cfg(not(feature = "logging"))]
3807 println!(
3808 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3809 err_msg
3810 );
3811 }
3812 },
3813 Err(e) => {
3814 let err_msg =
3815 format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3816
3817 #[cfg(feature = "logging")]
3818 error!(target: "stdout", "{}", &err_msg);
3819
3820 #[cfg(not(feature = "logging"))]
3821 println!(
3822 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3823 &err_msg
3824 );
3825 }
3826 }
3827 }
3828 None => {
3829 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3830
3831 #[cfg(feature = "logging")]
3832 error!(target: "stdout", "{}", &err_msg);
3833
3834 #[cfg(not(feature = "logging"))]
3835 println!(
3836 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3837 &err_msg
3838 );
3839 }
3840 };
3841 }
3842 }
3843
3844 #[cfg(feature = "logging")]
3845 info!(target: "stdout", "Model context cleanup done!");
3846 }
3847
3848 if let Err(e) = reset_model_metadata(self.model.as_ref()) {
3850 let err_msg = format!("Fail to reset model metadata. Reason: {e}");
3851
3852 #[cfg(feature = "logging")]
3853 error!(target: "stdout", "{}", &err_msg);
3854
3855 #[cfg(not(feature = "logging"))]
3856 println!("[ERROR][llama_core] {}", &err_msg);
3857 }
3858 #[cfg(feature = "logging")]
3859 info!(target: "stdout", "Model metadata reset done!");
3860
3861 if self.has_lock {
3863 CHAT_STREAM_ACTIVE.store(false, Ordering::SeqCst);
3865
3866 #[cfg(feature = "logging")]
3867 info!(target: "stdout", "Lock from ChatStream {} released", &self.id);
3868
3869 if let Ok(mut queue) = get_chat_stream_waker_queue().lock() {
3871 if let Some(waker) = queue.pop_front() {
3872 #[cfg(feature = "logging")]
3873 info!(target: "stdout", "Waking up a waiting ChatStream");
3874
3875 waker.wake();
3876 }
3877 }
3878 }
3879 }
3880}
3881impl futures::Stream for ChatStream {
3882 type Item = Result<String, LlamaCoreError>;
3883
3884 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
3885 let this = self.get_mut();
3886
3887 if this.is_waiting {
3889 if !this.try_acquire_lock() {
3890 if let Ok(mut queue) = get_chat_stream_waker_queue().lock() {
3892 queue.retain(|w| !w.will_wake(cx.waker()));
3894 queue.push_back(cx.waker().clone());
3896
3897 #[cfg(feature = "logging")]
3898 debug!(target: "stdout", "ChatStream {} is waiting for lock, added waker to queue", &this.id);
3899 }
3900
3901 return Poll::Pending;
3902 }
3903
3904 #[cfg(feature = "logging")]
3905 info!(target: "stdout", "ChatStream {} acquired lock and is now active", &this.id);
3906 }
3908
3909 if !this.has_lock && !this.try_acquire_lock() {
3911 this.is_waiting = true;
3913
3914 if let Ok(mut queue) = get_chat_stream_waker_queue().lock() {
3916 queue.retain(|w| !w.will_wake(cx.waker()));
3917 queue.push_back(cx.waker().clone());
3918 }
3919
3920 return Poll::Pending;
3921 }
3922
3923 if this.cache.is_none() {
3924 let res = compute_stream(
3925 this.model.clone(),
3926 this.id.clone(),
3927 this.include_usage,
3928 &mut this.prompt_too_long_state,
3929 &mut this.context_full_state,
3930 &mut this.stream_state,
3931 );
3932
3933 match res {
3934 Ok(x) => {
3935 #[cfg(feature = "logging")]
3936 info!(target: "stdout", "next item for ChatStream {}: {}", &this.id, &x);
3937
3938 if x != "[GGML] End of sequence" && !x.is_empty() {
3939 Poll::Ready(Some(Ok(x)))
3940 } else {
3941 Poll::Ready(None)
3943 }
3944 }
3945 Err(e) => Poll::Ready(Some(Err(e))),
3946 }
3947 } else {
3948 let x = this.cache.as_mut().unwrap().pop_front();
3949
3950 #[cfg(feature = "logging")]
3951 info!(target: "stdout", "Get the next item from the cache for ChatStream {}: {:?}", &this.id, &x);
3952
3953 match x {
3954 Some(x) => Poll::Ready(Some(Ok(x))),
3955 None => Poll::Ready(None),
3956 }
3957 }
3958 }
3959}
3960
3961fn get_chat_stream_waker_queue() -> &'static Mutex<VecDeque<Waker>> {
3963 CHAT_STREAM_WAKER_QUEUE.get_or_init(|| {
3964 #[cfg(feature = "logging")]
3965 info!(target: "stdout", "Initializing ChatStream waker queue");
3966 Mutex::new(VecDeque::new())
3967 })
3968}
3969
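/// Runs one `compute_single` step on the target graph and converts the decoded
/// token into an SSE `data:` chunk. Tokens that are not valid UTF-8 on their
/// own are buffered in `CACHED_UTF8_ENCODINGS` until the buffered bytes form a
/// valid string (or exceed the 4-byte length of a single UTF-8 character, in
/// which case the buffer is dropped).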
3970fn compute_stream(
3971 model_name: Option<String>,
3972 id: String,
3973 include_usage: bool,
3974 prompt_too_long_state: &mut PromptTooLongState,
3975 context_full_state: &mut ContextFullState,
3976 stream_state: &mut StreamState,
3977) -> Result<String, LlamaCoreError> {
3978 #[cfg(feature = "logging")]
3979 info!(target: "stdout", "Computing stream chunk for ChatStream {}", &id);
3980
3981 #[cfg(feature = "logging")]
3982 debug!(target: "stdout", "prompt_too_long_state: {:?}", *prompt_too_long_state);
3983 #[cfg(feature = "logging")]
3984 debug!(target: "stdout", "context_full_state: {:?}", *context_full_state);
3985 #[cfg(feature = "logging")]
3986 debug!(target: "stdout", "stream_state: {:?}", *stream_state);
3987
3988 if *prompt_too_long_state == PromptTooLongState::EndOfSequence
3989 || *context_full_state == ContextFullState::EndOfSequence
3990 || *stream_state == StreamState::EndOfSequence
3991 {
3992 #[cfg(feature = "logging")]
3993 info!(target: "stdout", "Return the chat stream chunk!");
3994
3995 return Ok("[GGML] End of sequence".to_string());
3996 }
3997
3998 let chat_graphs = match CHAT_GRAPHS.get() {
3999 Some(chat_graphs) => chat_graphs,
4000 None => {
4001 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
4002
4003 #[cfg(feature = "logging")]
4004 error!(target: "stdout", "{}", &err_msg);
4005
4006 return Err(LlamaCoreError::Operation(err_msg.into()));
4007 }
4008 };
4009
4010 let mut chat_graphs = chat_graphs.lock().map_err(|e| {
4012 let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
4013
4014 #[cfg(feature = "logging")]
4015 error!(target: "stdout", "{}", &err_msg);
4016
4017 LlamaCoreError::Operation(err_msg)
4018 })?;
4019
4020 let res = match &model_name {
4022 Some(model_name) => {
4023 match chat_graphs.contains_key(model_name) {
4024 true => {
4025 let graph = chat_graphs.get_mut(model_name).unwrap();
4026 match graph.compute_single() {
4028 Ok(_) => {
4029 #[cfg(feature = "logging")]
4030 debug!(target: "stdout", "Compute the chat stream chunk successfully.");
4031
4032 match stream_state {
4034 StreamState::Usage | StreamState::NoUsage => {
4035 let output_buffer =
4037 get_output_buffer_single(graph, OUTPUT_TENSOR)?;
4038
4039 #[cfg(feature = "logging")]
4040 info!(target: "stdout", "retrieved the output buffer");
4041
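// A multi-byte UTF-8 character may be split across output chunks. Bytes that do not yet form
// valid UTF-8 are buffered in `CACHED_UTF8_ENCODINGS` and retried with the next chunk; if more
// than 4 bytes accumulate without decoding, the buffer is discarded.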
4042 let output = match String::from_utf8(output_buffer.clone()) {
4044 Ok(token) => token,
4045 Err(_) => {
4046 let mutex = CACHED_UTF8_ENCODINGS
4047 .get_or_init(|| Mutex::new(Vec::new()));
4048 let mut cached_encodings = mutex.lock().map_err(|e| {
4049 let err_msg = format!(
"Failed to acquire the lock of `CACHED_UTF8_ENCODINGS`. Reason: {e}"
4051 );
4052
4053 #[cfg(feature = "logging")]
4054 error!(target: "stdout", "{}", &err_msg);
4055
4057 LlamaCoreError::Operation(err_msg)
4058 })?;
4059
4060 cached_encodings.extend_from_slice(&output_buffer[..]);
4062
4063 match String::from_utf8(cached_encodings.to_vec()) {
4064 Ok(token) => {
4065 cached_encodings.clear();
4067
4068 token
4069 }
4070 Err(e) => {
4071 if cached_encodings.len() > 4 {
let err_msg = format!("Failed to convert the cached bytes to a UTF-8 string: more than 4 bytes are buffered, which exceeds the maximum length of a single code point. {e}");
4074
4075 #[cfg(feature = "logging")]
4076 error!(target: "stdout", "{}", &err_msg);
4077
4078 #[cfg(feature = "logging")]
4079 error!(target: "stdout", "The cached buffer: {:?}", &cached_encodings[..]);
4080
4081 cached_encodings.clear();
4088
4089 String::from("")
4090 } else {
let warn_msg = format!("Failed to convert the cached bytes to a UTF-8 string; keeping them for the next chunk. {e}");
4092
4093 #[cfg(feature = "logging")]
4094 warn!(target: "stdout", "{}", &warn_msg);
4095
4096 String::from("")
4097 }
4098 }
4099 }
4100 }
4101 };
4102
4103 #[cfg(feature = "logging")]
4104 info!(target: "stdout", "decoded the output buffer");
4105
4106 let created = SystemTime::now()
4107 .duration_since(std::time::UNIX_EPOCH)
4108 .map_err(|e| {
4109 let err_msg = format!(
4110 "Failed to get the current time. Reason: {e}"
4111 );
4112
4113 #[cfg(feature = "logging")]
4114 error!(target: "stdout", "{}", &err_msg);
4115
4116 LlamaCoreError::Operation(err_msg)
4117 })?;
4118
4119 let chat_completion_chunk = ChatCompletionChunk {
4120 id,
4121 object: "chat.completion.chunk".to_string(),
4122 created: created.as_secs(),
4123 model: graph.name().to_owned(),
4124 system_fingerprint: "fp_44709d6fcb".to_string(),
4125 choices: vec![ChatCompletionChunkChoice {
4126 index: 0,
4127 delta: ChatCompletionChunkChoiceDelta {
4128 role: ChatCompletionRole::Assistant,
4129 content: Some(output),
4130 tool_calls: vec![],
4131 },
4132 logprobs: None,
4133 finish_reason: None,
4134 }],
4135 usage: None,
4136 };
4137
4138 #[cfg(feature = "logging")]
4139 info!(target: "stdout", "created chat completion chunk");
4140
4141 let chunk_str = serde_json::to_string(&chat_completion_chunk)
4143 .map_err(|e| {
4144 let err_msg = format!(
4145 "Failed to serialize chat completion chunk. Reason: {e}"
4146 );
4147
4148 #[cfg(feature = "logging")]
4149 error!(target: "stdout", "{}", &err_msg);
4150
4151 LlamaCoreError::Operation(err_msg)
4152 })?;
4153
4154 Ok(format!("data: {chunk_str}\n\n"))
4155 }
4156 StreamState::Done => {
4157 *stream_state = StreamState::EndOfSequence;
4158
4159 Ok("data: [DONE]\n\n".to_string())
4160 }
4161 StreamState::EndOfSequence => {
4162 Ok("[GGML] End of sequence".to_string())
4163 }
4164 }
4165 }
4166 Err(wasmedge_wasi_nn::Error::BackendError(
4167 wasmedge_wasi_nn::BackendError::EndOfSequence,
4168 )) => {
4169 #[cfg(feature = "logging")]
4170 debug!(target: "stdout", "End of sequence");
4171
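// Normal end of generation: emit a usage chunk first if the client requested usage, then
// `data: [DONE]`, and finally the end-of-sequence sentinel.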
4172 match stream_state {
4173 StreamState::Usage => {
4174 *stream_state = StreamState::Done;
4175
4176 let token_info = get_token_info_by_graph(graph)?;
4178
4179 let usage = Some(Usage {
4180 prompt_tokens: token_info.prompt_tokens,
4181 completion_tokens: token_info.completion_tokens,
4182 total_tokens: token_info.prompt_tokens
4183 + token_info.completion_tokens,
4184 });
4185
4186 #[cfg(feature = "logging")]
4187 info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
4188
4189 let created = SystemTime::now()
4190 .duration_since(std::time::UNIX_EPOCH)
4191 .map_err(|e| {
4192 let err_msg = format!(
4193 "Failed to get the current time. Reason: {e}"
4194 );
4195
4196 #[cfg(feature = "logging")]
4197 error!(target: "stdout", "{}", &err_msg);
4198
4199 LlamaCoreError::Operation(err_msg)
4200 })?;
4201
4202 let chat_completion_chunk = ChatCompletionChunk {
4203 id,
4204 object: "chat.completion.chunk".to_string(),
4205 created: created.as_secs(),
4206 model: graph.name().to_owned(),
4207 system_fingerprint: "fp_44709d6fcb".to_string(),
4208 choices: vec![],
4209 usage,
4210 };
4211
4212 let chunk_str = serde_json::to_string(&chat_completion_chunk)
4214 .map_err(|e| {
4215 let err_msg = format!(
4216 "Failed to serialize chat completion chunk. Reason: {e}"
4217 );
4218
4219 #[cfg(feature = "logging")]
4220 error!(target: "stdout", "{}", &err_msg);
4221
4222 LlamaCoreError::Operation(err_msg)
4223 })?;
4224
4225 Ok(format!("data: {chunk_str}\n\n"))
4226 }
4227 StreamState::Done | StreamState::NoUsage => {
4228 *stream_state = StreamState::EndOfSequence;
4229
4230 Ok("data: [DONE]\n\n".to_string())
4231 }
4232 StreamState::EndOfSequence => {
4233 Ok("[GGML] End of sequence".to_string())
4234 }
4235 }
4236 }
4237 Err(wasmedge_wasi_nn::Error::BackendError(
4238 wasmedge_wasi_nn::BackendError::ContextFull,
4239 )) => {
4240 #[cfg(feature = "logging")]
4241 debug!(target: "stdout", "Context full");
4242
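// The context window is full: send a `<|WASMEDGE-GGML-CONTEXT-FULL|>` message with
// `finish_reason: length`, then (optionally) a usage chunk, `data: [DONE]`, and the
// end-of-sequence sentinel.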
4243 match context_full_state {
4244 ContextFullState::Message => {
4245 match include_usage {
4246 true => *context_full_state = ContextFullState::Usage,
4247 false => *context_full_state = ContextFullState::Done,
4248 }
4249
4250 let created = SystemTime::now()
4251 .duration_since(std::time::UNIX_EPOCH)
4252 .map_err(|e| {
4253 let err_msg = format!(
4254 "Failed to get the current time. Reason: {e}"
4255 );
4256
4257 #[cfg(feature = "logging")]
4258 error!(target: "stdout", "{}", &err_msg);
4259
4260 LlamaCoreError::Operation(err_msg)
4261 })?;
4262
4263 let chat_completion_chunk = ChatCompletionChunk {
4264 id,
4265 object: "chat.completion.chunk".to_string(),
4266 created: created.as_secs(),
4267 model: graph.name().to_owned(),
4268 system_fingerprint: "fp_44709d6fcb".to_string(),
4269 choices: vec![ChatCompletionChunkChoice {
4270 index: 0,
4271 delta: ChatCompletionChunkChoiceDelta {
4272 role: ChatCompletionRole::Assistant,
4273 content: Some(
4274 "<|WASMEDGE-GGML-CONTEXT-FULL|>".to_string(),
4275 ),
4276 tool_calls: vec![],
4277 },
4278 logprobs: None,
4279 finish_reason: Some(FinishReason::length),
4280 }],
4281 usage: None,
4282 };
4283
4284 let chunk_str = serde_json::to_string(&chat_completion_chunk)
4286 .map_err(|e| {
4287 let err_msg = format!(
4288 "Failed to serialize chat completion chunk. Reason: {e}"
4289 );
4290
4291 #[cfg(feature = "logging")]
4292 error!(target: "stdout", "{}", &err_msg);
4293
4294 LlamaCoreError::Operation(err_msg)
4295 })?;
4296
4297 Ok(format!("data: {chunk_str}\n\n"))
4298 }
4299 ContextFullState::Usage => {
4300 *context_full_state = ContextFullState::Done;
4301
4302 let token_info = get_token_info_by_graph(graph)?;
4304
4305 let usage = Some(Usage {
4306 prompt_tokens: token_info.prompt_tokens,
4307 completion_tokens: token_info.completion_tokens,
4308 total_tokens: token_info.prompt_tokens
4309 + token_info.completion_tokens,
4310 });
4311
4312 let created = SystemTime::now()
4313 .duration_since(std::time::UNIX_EPOCH)
4314 .map_err(|e| {
4315 let err_msg = format!(
4316 "Failed to get the current time. Reason: {e}"
4317 );
4318
4319 #[cfg(feature = "logging")]
4320 error!(target: "stdout", "{}", &err_msg);
4321
4322 LlamaCoreError::Operation(err_msg)
4323 })?;
4324
4325 let chat_completion_chunk = ChatCompletionChunk {
4326 id,
4327 object: "chat.completion.chunk".to_string(),
4328 created: created.as_secs(),
4329 model: graph.name().to_owned(),
4330 system_fingerprint: "fp_44709d6fcb".to_string(),
4331 choices: vec![],
4332 usage,
4333 };
4334
4335 let chunk_str = serde_json::to_string(&chat_completion_chunk)
4337 .map_err(|e| {
4338 let err_msg = format!(
4339 "Failed to serialize chat completion chunk. Reason: {e}"
4340 );
4341
4342 #[cfg(feature = "logging")]
4343 error!(target: "stdout", "{}", &err_msg);
4344
4345 LlamaCoreError::Operation(err_msg)
4346 })?;
4347
4348 Ok(format!("data: {chunk_str}\n\n"))
4349 }
4350 ContextFullState::Done => {
4351 *context_full_state = ContextFullState::EndOfSequence;
4352
4353 Ok("data: [DONE]\n\n".to_string())
4354 }
4355 ContextFullState::EndOfSequence => {
4356 Ok("[GGML] End of sequence".to_string())
4357 }
4358 }
4359 }
4360 Err(wasmedge_wasi_nn::Error::BackendError(
4361 wasmedge_wasi_nn::BackendError::PromptTooLong,
4362 )) => {
4363 #[cfg(feature = "logging")]
4364 debug!(target: "stdout", "Prompt too long");
4365
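// The prompt alone does not fit in the context window: report `finish_reason: length` with no
// content, then (optionally) a usage chunk, `data: [DONE]`, and the sentinel.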
4366 match prompt_too_long_state {
4367 PromptTooLongState::Message => {
4368 match include_usage {
4369 true => *prompt_too_long_state = PromptTooLongState::Usage,
4370 false => *prompt_too_long_state = PromptTooLongState::Done,
4371 }
4372
4373 let created = SystemTime::now()
4374 .duration_since(std::time::UNIX_EPOCH)
4375 .map_err(|e| {
4376 let err_msg = format!(
4377 "Failed to get the current time. Reason: {e}"
4378 );
4379
4380 #[cfg(feature = "logging")]
4381 error!(target: "stdout", "{}", &err_msg);
4382
4383 LlamaCoreError::Operation(err_msg)
4384 })?;
4385
4386 let chat_completion_chunk = ChatCompletionChunk {
4387 id,
4388 object: "chat.completion.chunk".to_string(),
4389 created: created.as_secs(),
4390 model: graph.name().to_owned(),
4391 system_fingerprint: "fp_44709d6fcb".to_string(),
4392 choices: vec![ChatCompletionChunkChoice {
4393 index: 0,
4394 delta: ChatCompletionChunkChoiceDelta {
4395 role: ChatCompletionRole::Assistant,
4396 content: None,
4397 tool_calls: vec![],
4398 },
4399 logprobs: None,
4400 finish_reason: Some(FinishReason::length),
4401 }],
4402 usage: None,
4403 };
4404
4405 let chunk_str = serde_json::to_string(&chat_completion_chunk)
4407 .map_err(|e| {
4408 let err_msg = format!(
4409 "Failed to serialize chat completion chunk. Reason: {e}"
4410 );
4411
4412 #[cfg(feature = "logging")]
4413 error!(target: "stdout", "{}", &err_msg);
4414
4415 LlamaCoreError::Operation(err_msg)
4416 })?;
4417
4418 Ok(format!("data: {chunk_str}\n\n"))
4419 }
4420 PromptTooLongState::Usage => {
4421 *prompt_too_long_state = PromptTooLongState::Done;
4422
4423 let token_info = get_token_info_by_graph(graph)?;
4425
4426 let usage = Some(Usage {
4427 prompt_tokens: token_info.prompt_tokens,
4428 completion_tokens: token_info.completion_tokens,
4429 total_tokens: token_info.prompt_tokens
4430 + token_info.completion_tokens,
4431 });
4432
4433 let created = SystemTime::now()
4434 .duration_since(std::time::UNIX_EPOCH)
4435 .map_err(|e| {
4436 let err_msg = format!(
4437 "Failed to get the current time. Reason: {e}"
4438 );
4439
4440 #[cfg(feature = "logging")]
4441 error!(target: "stdout", "{}", &err_msg);
4442
4443 LlamaCoreError::Operation(err_msg)
4444 })?;
4445
4446 let chat_completion_chunk = ChatCompletionChunk {
4447 id,
4448 object: "chat.completion.chunk".to_string(),
4449 created: created.as_secs(),
4450 model: graph.name().to_owned(),
4451 system_fingerprint: "fp_44709d6fcb".to_string(),
4452 choices: vec![],
4453 usage,
4454 };
4455
4456 let chunk_str = serde_json::to_string(&chat_completion_chunk)
4458 .map_err(|e| {
4459 let err_msg = format!(
4460 "Failed to serialize chat completion chunk. Reason: {e}"
4461 );
4462
4463 #[cfg(feature = "logging")]
4464 error!(target: "stdout", "{}", &err_msg);
4465
4466 LlamaCoreError::Operation(err_msg)
4467 })?;
4468
4469 Ok(format!("data: {chunk_str}\n\n"))
4470 }
4471 PromptTooLongState::Done => {
4472 *prompt_too_long_state = PromptTooLongState::EndOfSequence;
4473
4474 Ok("data: [DONE]\n\n".to_string())
4475 }
4476 PromptTooLongState::EndOfSequence => {
4477 Ok("[GGML] End of sequence".to_string())
4478 }
4479 }
4480 }
4481 Err(e) => {
4482 let err_msg =
4483 format!("Failed to compute the chat completion. Reason: {e}");
4484
4485 #[cfg(feature = "logging")]
4486 error!(target: "stdout", "{}", &err_msg);
4487
4488 Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
4489 err_msg,
4490 )))
4491 }
4492 }
4493 }
4494 false => {
4495 match chat_graphs.iter_mut().next() {
4496 Some((_, graph)) => {
4497 match graph.compute_single() {
4499 Ok(_) => {
4500 #[cfg(feature = "logging")]
4501 debug!(target: "stdout", "Compute the chat stream chunk successfully.");
4502
4503 match stream_state {
4504 StreamState::Usage | StreamState::NoUsage => {
4505 let output_buffer =
4507 get_output_buffer_single(graph, OUTPUT_TENSOR)?;
4508
4509 #[cfg(feature = "logging")]
4510 info!(target: "stdout", "retrieved the output buffer");
4511
4512 let output = match String::from_utf8(
4514 output_buffer.clone(),
4515 ) {
4516 Ok(token) => token,
4517 Err(_) => {
4518 let mutex = CACHED_UTF8_ENCODINGS
4519 .get_or_init(|| Mutex::new(Vec::new()));
4520 let mut cached_encodings = mutex.lock().map_err(|e| {
4521 let err_msg = format!(
"Failed to acquire the lock of `CACHED_UTF8_ENCODINGS`. Reason: {e}"
4523 );
4524
4525 #[cfg(feature = "logging")]
4526 error!(target: "stdout", "{}", &err_msg);
4527
4529 LlamaCoreError::Operation(err_msg)
4530 })?;
4531
4532 cached_encodings
4534 .extend_from_slice(&output_buffer[..]);
4535
4536 match String::from_utf8(
4537 cached_encodings.to_vec(),
4538 ) {
4539 Ok(token) => {
4540 cached_encodings.clear();
4542
4543 token
4544 }
4545 Err(e) => {
4546 if cached_encodings.len() > 4 {
let err_msg = format!("Failed to convert the cached bytes to a UTF-8 string: more than 4 bytes are buffered, which exceeds the maximum length of a single code point. {e}");
4549
4550 #[cfg(feature = "logging")]
4551 error!(target: "stdout", "{}", &err_msg);
4552
4553 #[cfg(feature = "logging")]
4554 error!(target: "stdout", "The cached buffer: {:?}", &cached_encodings[..]);
4555
4556 cached_encodings.clear();
4564
4565 String::from("")
4566 } else {
let warn_msg = format!("Failed to convert the cached bytes to a UTF-8 string; keeping them for the next chunk. {e}");
4568
4569 #[cfg(feature = "logging")]
4570 warn!(target: "stdout", "{}", &warn_msg);
4571
4572 String::from("")
4573 }
4574 }
4575 }
4576 }
4577 };
4578
4579 #[cfg(feature = "logging")]
4580 info!(target: "stdout", "decoded the output buffer");
4581
4582 let created = SystemTime::now()
4583 .duration_since(std::time::UNIX_EPOCH)
4584 .map_err(|e| {
4585 let err_msg = format!(
4586 "Failed to get the current time. Reason: {e}"
4587 );
4588
4589 #[cfg(feature = "logging")]
4590 error!(target: "stdout", "{}", &err_msg);
4591
4592 LlamaCoreError::Operation(err_msg)
4593 })?;
4594
4595 let chat_completion_chunk = ChatCompletionChunk {
4596 id,
4597 object: "chat.completion.chunk".to_string(),
4598 created: created.as_secs(),
4599 model: graph.name().to_owned(),
4600 system_fingerprint: "fp_44709d6fcb".to_string(),
4601 choices: vec![ChatCompletionChunkChoice {
4602 index: 0,
4603 delta: ChatCompletionChunkChoiceDelta {
4604 role: ChatCompletionRole::Assistant,
4605 content: Some(output),
4606 tool_calls: vec![],
4607 },
4608 logprobs: None,
4609 finish_reason: None,
4610 }],
4611 usage: None,
4612 };
4613
4614 #[cfg(feature = "logging")]
4615 info!(target: "stdout", "created chat completion chunk");
4616
4617 let chunk_str =
4619 serde_json::to_string(&chat_completion_chunk)
4620 .map_err(|e| {
4621 let err_msg = format!(
4622 "Failed to serialize chat completion chunk. Reason: {e}"
4623 );
4624
4625 #[cfg(feature = "logging")]
4626 error!(target: "stdout", "{}", &err_msg);
4627
4628 LlamaCoreError::Operation(err_msg)
4629 })?;
4630
4631 Ok(format!("data: {chunk_str}\n\n"))
4632 }
4633 StreamState::Done => {
4634 *stream_state = StreamState::EndOfSequence;
4635
4636 Ok("data: [DONE]\n\n".to_string())
4637 }
4638 StreamState::EndOfSequence => {
4639 Ok("[GGML] End of sequence".to_string())
4640 }
4641 }
4642 }
4643 Err(wasmedge_wasi_nn::Error::BackendError(
4644 wasmedge_wasi_nn::BackendError::EndOfSequence,
4645 )) => {
4646 #[cfg(feature = "logging")]
4647 debug!(target: "stdout", "End of sequence");
4648
4649 match stream_state {
4650 StreamState::Usage => {
4651 *stream_state = StreamState::Done;
4652
4653 let token_info = get_token_info_by_graph(graph)?;
4655
4656 let usage = Some(Usage {
4657 prompt_tokens: token_info.prompt_tokens,
4658 completion_tokens: token_info.completion_tokens,
4659 total_tokens: token_info.prompt_tokens
4660 + token_info.completion_tokens,
4661 });
4662
4663 #[cfg(feature = "logging")]
4664 info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
4665
4666 let created = SystemTime::now()
4667 .duration_since(std::time::UNIX_EPOCH)
4668 .map_err(|e| {
4669 let err_msg = format!(
4670 "Failed to get the current time. Reason: {e}"
4671 );
4672
4673 #[cfg(feature = "logging")]
4674 error!(target: "stdout", "{}", &err_msg);
4675
4676 LlamaCoreError::Operation(err_msg)
4677 })?;
4678
4679 let chat_completion_chunk = ChatCompletionChunk {
4680 id,
4681 object: "chat.completion.chunk".to_string(),
4682 created: created.as_secs(),
4683 model: graph.name().to_owned(),
4684 system_fingerprint: "fp_44709d6fcb".to_string(),
4685 choices: vec![],
4686 usage,
4687 };
4688
4689 let chunk_str =
4691 serde_json::to_string(&chat_completion_chunk)
4692 .map_err(|e| {
4693 let err_msg = format!(
4694 "Failed to serialize chat completion chunk. Reason: {e}"
4695 );
4696
4697 #[cfg(feature = "logging")]
4698 error!(target: "stdout", "{}", &err_msg);
4699
4700 LlamaCoreError::Operation(err_msg)
4701 })?;
4702
4703 Ok(format!("data: {chunk_str}\n\n"))
4704 }
4705 StreamState::Done | StreamState::NoUsage => {
4706 *stream_state = StreamState::EndOfSequence;
4707
4708 Ok("data: [DONE]\n\n".to_string())
4709 }
4710 StreamState::EndOfSequence => {
4711 Ok("[GGML] End of sequence".to_string())
4712 }
4713 }
4714 }
4715 Err(wasmedge_wasi_nn::Error::BackendError(
4716 wasmedge_wasi_nn::BackendError::ContextFull,
4717 )) => {
4718 #[cfg(feature = "logging")]
4719 debug!(target: "stdout", "Context full");
4720
4721 match context_full_state {
4722 ContextFullState::Message => {
4723 match include_usage {
4724 true => {
4725 *context_full_state = ContextFullState::Usage
4726 }
4727 false => {
4728 *context_full_state = ContextFullState::Done
4729 }
4730 }
4731
4732 let created = SystemTime::now()
4733 .duration_since(std::time::UNIX_EPOCH)
4734 .map_err(|e| {
4735 let err_msg = format!(
4736 "Failed to get the current time. Reason: {e}"
4737 );
4738
4739 #[cfg(feature = "logging")]
4740 error!(target: "stdout", "{}", &err_msg);
4741
4742 LlamaCoreError::Operation(err_msg)
4743 })?;
4744
4745 let chat_completion_chunk = ChatCompletionChunk {
4746 id,
4747 object: "chat.completion.chunk".to_string(),
4748 created: created.as_secs(),
4749 model: graph.name().to_owned(),
4750 system_fingerprint: "fp_44709d6fcb".to_string(),
4751 choices: vec![ChatCompletionChunkChoice {
4752 index: 0,
4753 delta: ChatCompletionChunkChoiceDelta {
4754 role: ChatCompletionRole::Assistant,
4755 content: Some(
4756 "<|WASMEDGE-GGML-CONTEXT-FULL|>"
4757 .to_string(),
4758 ),
4759 tool_calls: vec![],
4760 },
4761 logprobs: None,
4762 finish_reason: Some(FinishReason::length),
4763 }],
4764 usage: None,
4765 };
4766
4767 let chunk_str =
4769 serde_json::to_string(&chat_completion_chunk)
4770 .map_err(|e| {
4771 let err_msg = format!(
4772 "Failed to serialize chat completion chunk. Reason: {e}"
4773 );
4774
4775 #[cfg(feature = "logging")]
4776 error!(target: "stdout", "{}", &err_msg);
4777
4778 LlamaCoreError::Operation(err_msg)
4779 })?;
4780
4781 Ok(format!("data: {chunk_str}\n\n"))
4782 }
4783 ContextFullState::Usage => {
4784 *context_full_state = ContextFullState::Done;
4785
4786 let token_info = get_token_info_by_graph(graph)?;
4788
4789 let usage = Some(Usage {
4790 prompt_tokens: token_info.prompt_tokens,
4791 completion_tokens: token_info.completion_tokens,
4792 total_tokens: token_info.prompt_tokens
4793 + token_info.completion_tokens,
4794 });
4795
4796 let created = SystemTime::now()
4797 .duration_since(std::time::UNIX_EPOCH)
4798 .map_err(|e| {
4799 let err_msg = format!(
4800 "Failed to get the current time. Reason: {e}"
4801 );
4802
4803 #[cfg(feature = "logging")]
4804 error!(target: "stdout", "{}", &err_msg);
4805
4806 LlamaCoreError::Operation(err_msg)
4807 })?;
4808
4809 let chat_completion_chunk = ChatCompletionChunk {
4810 id,
4811 object: "chat.completion.chunk".to_string(),
4812 created: created.as_secs(),
4813 model: graph.name().to_owned(),
4814 system_fingerprint: "fp_44709d6fcb".to_string(),
4815 choices: vec![],
4816 usage,
4817 };
4818
4819 let chunk_str =
4821 serde_json::to_string(&chat_completion_chunk)
4822 .map_err(|e| {
4823 let err_msg = format!(
4824 "Failed to serialize chat completion chunk. Reason: {e}"
4825 );
4826
4827 #[cfg(feature = "logging")]
4828 error!(target: "stdout", "{}", &err_msg);
4829
4830 LlamaCoreError::Operation(err_msg)
4831 })?;
4832
4833 Ok(format!("data: {chunk_str}\n\n"))
4834 }
4835 ContextFullState::Done => {
4836 *context_full_state = ContextFullState::EndOfSequence;
4837
4838 Ok("data: [DONE]\n\n".to_string())
4839 }
4840 ContextFullState::EndOfSequence => {
4841 Ok("[GGML] End of sequence".to_string())
4842 }
4843 }
4844 }
4845 Err(wasmedge_wasi_nn::Error::BackendError(
4846 wasmedge_wasi_nn::BackendError::PromptTooLong,
4847 )) => {
4848 #[cfg(feature = "logging")]
4849 debug!(target: "stdout", "Prompt too long");
4850
4851 match prompt_too_long_state {
4852 PromptTooLongState::Message => {
4853 match include_usage {
4854 true => {
4855 *prompt_too_long_state =
4856 PromptTooLongState::Usage
4857 }
4858 false => {
4859 *prompt_too_long_state =
4860 PromptTooLongState::Done
4861 }
4862 }
4863
4864 let created = SystemTime::now()
4865 .duration_since(std::time::UNIX_EPOCH)
4866 .map_err(|e| {
4867 let err_msg = format!(
4868 "Failed to get the current time. Reason: {e}"
4869 );
4870
4871 #[cfg(feature = "logging")]
4872 error!(target: "stdout", "{}", &err_msg);
4873
4874 LlamaCoreError::Operation(err_msg)
4875 })?;
4876
4877 let chat_completion_chunk = ChatCompletionChunk {
4878 id,
4879 object: "chat.completion.chunk".to_string(),
4880 created: created.as_secs(),
4881 model: graph.name().to_owned(),
4882 system_fingerprint: "fp_44709d6fcb".to_string(),
4883 choices: vec![ChatCompletionChunkChoice {
4884 index: 0,
4885 delta: ChatCompletionChunkChoiceDelta {
4886 role: ChatCompletionRole::Assistant,
4887 content: None,
4888 tool_calls: vec![],
4889 },
4890 logprobs: None,
4891 finish_reason: Some(FinishReason::length),
4892 }],
4893 usage: None,
4894 };
4895
4896 let chunk_str =
4898 serde_json::to_string(&chat_completion_chunk)
4899 .map_err(|e| {
4900 let err_msg = format!(
4901 "Failed to serialize chat completion chunk. Reason: {e}"
4902 );
4903
4904 #[cfg(feature = "logging")]
4905 error!(target: "stdout", "{}", &err_msg);
4906
4907 LlamaCoreError::Operation(err_msg)
4908 })?;
4909
4910 Ok(format!("data: {chunk_str}\n\n"))
4911 }
4912 PromptTooLongState::Usage => {
4913 *prompt_too_long_state = PromptTooLongState::Done;
4914
4915 let token_info = get_token_info_by_graph(graph)?;
4917
4918 let usage = Some(Usage {
4919 prompt_tokens: token_info.prompt_tokens,
4920 completion_tokens: token_info.completion_tokens,
4921 total_tokens: token_info.prompt_tokens
4922 + token_info.completion_tokens,
4923 });
4924
4925 let created = SystemTime::now()
4926 .duration_since(std::time::UNIX_EPOCH)
4927 .map_err(|e| {
4928 let err_msg = format!(
4929 "Failed to get the current time. Reason: {e}"
4930 );
4931
4932 #[cfg(feature = "logging")]
4933 error!(target: "stdout", "{}", &err_msg);
4934
4935 LlamaCoreError::Operation(err_msg)
4936 })?;
4937
4938 let chat_completion_chunk = ChatCompletionChunk {
4939 id,
4940 object: "chat.completion.chunk".to_string(),
4941 created: created.as_secs(),
4942 model: graph.name().to_owned(),
4943 system_fingerprint: "fp_44709d6fcb".to_string(),
4944 choices: vec![],
4945 usage,
4946 };
4947
4948 let chunk_str =
4950 serde_json::to_string(&chat_completion_chunk)
4951 .map_err(|e| {
4952 let err_msg = format!(
4953 "Failed to serialize chat completion chunk. Reason: {e}"
4954 );
4955
4956 #[cfg(feature = "logging")]
4957 error!(target: "stdout", "{}", &err_msg);
4958
4959 LlamaCoreError::Operation(err_msg)
4960 })?;
4961
4962 Ok(format!("data: {chunk_str}\n\n"))
4963 }
4964 PromptTooLongState::Done => {
4965 *prompt_too_long_state =
4966 PromptTooLongState::EndOfSequence;
4967
4968 Ok("data: [DONE]\n\n".to_string())
4969 }
4970 PromptTooLongState::EndOfSequence => {
4971 Ok("[GGML] End of sequence".to_string())
4972 }
4973 }
4974 }
4975 Err(e) => {
4976 let err_msg = format!(
4977 "Failed to compute the chat completion. Reason: {e}"
4978 );
4979
4980 #[cfg(feature = "logging")]
4981 error!(target: "stdout", "{}", &err_msg);
4982
4983 Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
4984 err_msg,
4985 )))
4986 }
4987 }
4988 }
4989 None => {
4990 let err_msg = "There is no model available in the chat graphs.";
4991
4992 #[cfg(feature = "logging")]
4993 error!(target: "stdout", "{}", &err_msg);
4994
4995 Err(LlamaCoreError::Operation(err_msg.into()))
4996 }
4997 }
4998 }
4999 }
5000 }
5001 None => {
5002 match chat_graphs.iter_mut().next() {
5003 Some((_, graph)) => {
5004 match graph.compute_single() {
5006 Ok(_) => {
5007 #[cfg(feature = "logging")]
5008 debug!(target: "stdout", "Compute the chat stream chunk successfully.");
5009
5010 match stream_state {
5011 StreamState::Usage | StreamState::NoUsage => {
5012 let output_buffer =
5014 get_output_buffer_single(graph, OUTPUT_TENSOR)?;
5015
5016 #[cfg(feature = "logging")]
5017 info!(target: "stdout", "retrieved the output buffer");
5018
5019 let output = match String::from_utf8(output_buffer.clone()) {
5021 Ok(token) => token,
5022 Err(_) => {
5023 let mutex = CACHED_UTF8_ENCODINGS
5024 .get_or_init(|| Mutex::new(Vec::new()));
5025 let mut cached_encodings = mutex.lock().map_err(|e| {
5026 let err_msg = format!(
"Failed to acquire the lock of `CACHED_UTF8_ENCODINGS`. Reason: {e}"
5028 );
5029
5030 #[cfg(feature = "logging")]
5031 error!(target: "stdout", "{}", &err_msg);
5032
5033 LlamaCoreError::Operation(err_msg)
5034 })?;
5035
5036 cached_encodings.extend_from_slice(&output_buffer[..]);
5037
5038 match String::from_utf8(cached_encodings.to_vec()) {
5039 Ok(token) => {
5040 cached_encodings.clear();
5042
5043 token
5044 }
5045 Err(e) => {
5046 if cached_encodings.len() > 4 {
let err_msg = format!("Failed to convert the cached bytes to a UTF-8 string: more than 4 bytes are buffered, which exceeds the maximum length of a single code point. {e}");
5049
5050 #[cfg(feature = "logging")]
5051 error!(target: "stdout", "{}", &err_msg);
5052
5053 #[cfg(feature = "logging")]
5054 error!(target: "stdout", "The cached buffer: {:?}", &cached_encodings[..]);
5055
5056 cached_encodings.clear();
5063
5064 String::from("")
5065 } else {
let warn_msg = format!("Failed to convert the cached bytes to a UTF-8 string; keeping them for the next chunk. {e}");
5067
5068 #[cfg(feature = "logging")]
5069 warn!(target: "stdout", "{}", &warn_msg);
5070
5071 String::from("")
5072 }
5073 }
5074 }
5075 }
5076 };
5077
5078 #[cfg(feature = "logging")]
5079 info!(target: "stdout", "decoded the output buffer");
5080
5081 let created = SystemTime::now()
5082 .duration_since(std::time::UNIX_EPOCH)
5083 .map_err(|e| {
5084 let err_msg = format!(
5085 "Failed to get the current time. Reason: {e}"
5086 );
5087
5088 #[cfg(feature = "logging")]
5089 error!(target: "stdout", "{}", &err_msg);
5090
5091 LlamaCoreError::Operation(err_msg)
5092 })?;
5093
5094 let chat_completion_chunk = ChatCompletionChunk {
5095 id,
5096 object: "chat.completion.chunk".to_string(),
5097 created: created.as_secs(),
5098 model: graph.name().to_owned(),
5099 system_fingerprint: "fp_44709d6fcb".to_string(),
5100 choices: vec![ChatCompletionChunkChoice {
5101 index: 0,
5102 delta: ChatCompletionChunkChoiceDelta {
5103 role: ChatCompletionRole::Assistant,
5104 content: Some(output),
5105 tool_calls: vec![],
5106 },
5107 logprobs: None,
5108 finish_reason: None,
5109 }],
5110 usage: None,
5111 };
5112
5113 #[cfg(feature = "logging")]
5114 info!(target: "stdout", "created chat completion chunk");
5115
5116 let chunk_str = serde_json::to_string(&chat_completion_chunk)
5118 .map_err(|e| {
5119 let err_msg = format!(
5120 "Failed to serialize chat completion chunk. Reason: {e}"
5121 );
5122
5123 #[cfg(feature = "logging")]
5124 error!(target: "stdout", "{}", &err_msg);
5125
5126 LlamaCoreError::Operation(err_msg)
5127 })?;
5128
5129 Ok(format!("data: {chunk_str}\n\n"))
5130 }
5131 StreamState::Done => {
5132 *stream_state = StreamState::EndOfSequence;
5133
5134 Ok("data: [DONE]\n\n".to_string())
5135 }
5136 StreamState::EndOfSequence => {
5137 Ok("[GGML] End of sequence".to_string())
5138 }
5139 }
5140 }
5141 Err(wasmedge_wasi_nn::Error::BackendError(
5142 wasmedge_wasi_nn::BackendError::EndOfSequence,
5143 )) => {
5144 #[cfg(feature = "logging")]
5145 debug!(target: "stdout", "End of sequence");
5146
5147 match stream_state {
5148 StreamState::Usage => {
5149 *stream_state = StreamState::Done;
5150
5151 let token_info = get_token_info_by_graph(graph)?;
5153
5154 let usage = Some(Usage {
5155 prompt_tokens: token_info.prompt_tokens,
5156 completion_tokens: token_info.completion_tokens,
5157 total_tokens: token_info.prompt_tokens
5158 + token_info.completion_tokens,
5159 });
5160
5161 #[cfg(feature = "logging")]
5162 info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
5163
5164 let created = SystemTime::now()
5165 .duration_since(std::time::UNIX_EPOCH)
5166 .map_err(|e| {
5167 let err_msg = format!(
5168 "Failed to get the current time. Reason: {e}"
5169 );
5170
5171 #[cfg(feature = "logging")]
5172 error!(target: "stdout", "{}", &err_msg);
5173
5174 LlamaCoreError::Operation(err_msg)
5175 })?;
5176
5177 let chat_completion_chunk = ChatCompletionChunk {
5178 id,
5179 object: "chat.completion.chunk".to_string(),
5180 created: created.as_secs(),
5181 model: graph.name().to_owned(),
5182 system_fingerprint: "fp_44709d6fcb".to_string(),
5183 choices: vec![],
5184 usage,
5185 };
5186
5187 let chunk_str = serde_json::to_string(&chat_completion_chunk)
5189 .map_err(|e| {
5190 let err_msg = format!(
5191 "Failed to serialize chat completion chunk. Reason: {e}"
5192 );
5193
5194 #[cfg(feature = "logging")]
5195 error!(target: "stdout", "{}", &err_msg);
5196
5197 LlamaCoreError::Operation(err_msg)
5198 })?;
5199
5200 Ok(format!("data: {chunk_str}\n\n"))
5201 }
5202 StreamState::Done | StreamState::NoUsage => {
5203 *stream_state = StreamState::EndOfSequence;
5204
5205 Ok("data: [DONE]\n\n".to_string())
5206 }
5207 StreamState::EndOfSequence => {
5208 Ok("[GGML] End of sequence".to_string())
5209 }
5210 }
5211 }
5212 Err(wasmedge_wasi_nn::Error::BackendError(
5213 wasmedge_wasi_nn::BackendError::ContextFull,
5214 )) => {
5215 #[cfg(feature = "logging")]
5216 debug!(target: "stdout", "Context full");
5217
5218 match context_full_state {
5219 ContextFullState::Message => {
5220 match include_usage {
5221 true => *context_full_state = ContextFullState::Usage,
5222 false => *context_full_state = ContextFullState::Done,
5223 }
5224
5225 let created = SystemTime::now()
5226 .duration_since(std::time::UNIX_EPOCH)
5227 .map_err(|e| {
5228 let err_msg = format!(
5229 "Failed to get the current time. Reason: {e}"
5230 );
5231
5232 #[cfg(feature = "logging")]
5233 error!(target: "stdout", "{}", &err_msg);
5234
5235 LlamaCoreError::Operation(err_msg)
5236 })?;
5237
5238 let chat_completion_chunk = ChatCompletionChunk {
5239 id,
5240 object: "chat.completion.chunk".to_string(),
5241 created: created.as_secs(),
5242 model: graph.name().to_owned(),
5243 system_fingerprint: "fp_44709d6fcb".to_string(),
5244 choices: vec![ChatCompletionChunkChoice {
5245 index: 0,
5246 delta: ChatCompletionChunkChoiceDelta {
5247 role: ChatCompletionRole::Assistant,
5248 content: Some(
5249 "<|WASMEDGE-GGML-CONTEXT-FULL|>".to_string(),
5250 ),
5251 tool_calls: vec![],
5252 },
5253 logprobs: None,
5254 finish_reason: Some(FinishReason::length),
5255 }],
5256 usage: None,
5257 };
5258
5259 let chunk_str = serde_json::to_string(&chat_completion_chunk)
5261 .map_err(|e| {
5262 let err_msg = format!(
5263 "Failed to serialize chat completion chunk. Reason: {e}"
5264 );
5265
5266 #[cfg(feature = "logging")]
5267 error!(target: "stdout", "{}", &err_msg);
5268
5269 LlamaCoreError::Operation(err_msg)
5270 })?;
5271
5272 Ok(format!("data: {chunk_str}\n\n"))
5273 }
5274 ContextFullState::Usage => {
5275 *context_full_state = ContextFullState::Done;
5276
5277 let token_info = get_token_info_by_graph(graph)?;
5279
5280 let usage = Some(Usage {
5281 prompt_tokens: token_info.prompt_tokens,
5282 completion_tokens: token_info.completion_tokens,
5283 total_tokens: token_info.prompt_tokens
5284 + token_info.completion_tokens,
5285 });
5286
5287 let created = SystemTime::now()
5288 .duration_since(std::time::UNIX_EPOCH)
5289 .map_err(|e| {
5290 let err_msg = format!(
5291 "Failed to get the current time. Reason: {e}"
5292 );
5293
5294 #[cfg(feature = "logging")]
5295 error!(target: "stdout", "{}", &err_msg);
5296
5297 LlamaCoreError::Operation(err_msg)
5298 })?;
5299
5300 let chat_completion_chunk = ChatCompletionChunk {
5301 id,
5302 object: "chat.completion.chunk".to_string(),
5303 created: created.as_secs(),
5304 model: graph.name().to_owned(),
5305 system_fingerprint: "fp_44709d6fcb".to_string(),
5306 choices: vec![],
5307 usage,
5308 };
5309
5310 let chunk_str = serde_json::to_string(&chat_completion_chunk)
5312 .map_err(|e| {
5313 let err_msg = format!(
5314 "Failed to serialize chat completion chunk. Reason: {e}"
5315 );
5316
5317 #[cfg(feature = "logging")]
5318 error!(target: "stdout", "{}", &err_msg);
5319
5320 LlamaCoreError::Operation(err_msg)
5321 })?;
5322
5323 Ok(format!("data: {chunk_str}\n\n"))
5324 }
5325 ContextFullState::Done => {
5326 *context_full_state = ContextFullState::EndOfSequence;
5327
5328 Ok("data: [DONE]\n\n".to_string())
5329 }
5330 ContextFullState::EndOfSequence => {
5331 Ok("[GGML] End of sequence".to_string())
5332 }
5333 }
5334 }
5335 Err(wasmedge_wasi_nn::Error::BackendError(
5336 wasmedge_wasi_nn::BackendError::PromptTooLong,
5337 )) => {
5338 #[cfg(feature = "logging")]
5339 debug!(target: "stdout", "Prompt too long");
5340
5341 match prompt_too_long_state {
5342 PromptTooLongState::Message => {
5343 match include_usage {
5344 true => *prompt_too_long_state = PromptTooLongState::Usage,
5345 false => *prompt_too_long_state = PromptTooLongState::Done,
5346 }
5347
5348 let created = SystemTime::now()
5349 .duration_since(std::time::UNIX_EPOCH)
5350 .map_err(|e| {
5351 let err_msg = format!(
5352 "Failed to get the current time. Reason: {e}"
5353 );
5354
5355 #[cfg(feature = "logging")]
5356 error!(target: "stdout", "{}", &err_msg);
5357
5358 LlamaCoreError::Operation(err_msg)
5359 })?;
5360
5361 let chat_completion_chunk = ChatCompletionChunk {
5362 id,
5363 object: "chat.completion.chunk".to_string(),
5364 created: created.as_secs(),
5365 model: graph.name().to_owned(),
5366 system_fingerprint: "fp_44709d6fcb".to_string(),
5367 choices: vec![ChatCompletionChunkChoice {
5368 index: 0,
5369 delta: ChatCompletionChunkChoiceDelta {
5370 role: ChatCompletionRole::Assistant,
5371 content: None,
5372 tool_calls: vec![],
5373 },
5374 logprobs: None,
5375 finish_reason: Some(FinishReason::length),
5376 }],
5377 usage: None,
5378 };
5379
5380 let chunk_str = serde_json::to_string(&chat_completion_chunk)
5382 .map_err(|e| {
5383 let err_msg = format!(
5384 "Failed to serialize chat completion chunk. Reason: {e}"
5385 );
5386
5387 #[cfg(feature = "logging")]
5388 error!(target: "stdout", "{}", &err_msg);
5389
5390 LlamaCoreError::Operation(err_msg)
5391 })?;
5392
5393 Ok(format!("data: {chunk_str}\n\n"))
5394 }
5395 PromptTooLongState::Usage => {
5396 *prompt_too_long_state = PromptTooLongState::Done;
5397
5398 let token_info = get_token_info_by_graph(graph)?;
5400
5401 let usage = Some(Usage {
5402 prompt_tokens: token_info.prompt_tokens,
5403 completion_tokens: token_info.completion_tokens,
5404 total_tokens: token_info.prompt_tokens
5405 + token_info.completion_tokens,
5406 });
5407
5408 let created = SystemTime::now()
5409 .duration_since(std::time::UNIX_EPOCH)
5410 .map_err(|e| {
5411 let err_msg = format!(
5412 "Failed to get the current time. Reason: {e}"
5413 );
5414
5415 #[cfg(feature = "logging")]
5416 error!(target: "stdout", "{}", &err_msg);
5417
5418 LlamaCoreError::Operation(err_msg)
5419 })?;
5420
5421 let chat_completion_chunk = ChatCompletionChunk {
5422 id,
5423 object: "chat.completion.chunk".to_string(),
5424 created: created.as_secs(),
5425 model: graph.name().to_owned(),
5426 system_fingerprint: "fp_44709d6fcb".to_string(),
5427 choices: vec![],
5428 usage,
5429 };
5430
5431 let chunk_str = serde_json::to_string(&chat_completion_chunk)
5433 .map_err(|e| {
5434 let err_msg = format!(
5435 "Failed to serialize chat completion chunk. Reason: {e}"
5436 );
5437
5438 #[cfg(feature = "logging")]
5439 error!(target: "stdout", "{}", &err_msg);
5440
5441 LlamaCoreError::Operation(err_msg)
5442 })?;
5443
5444 Ok(format!("data: {chunk_str}\n\n"))
5445 }
5446 PromptTooLongState::Done => {
5447 *prompt_too_long_state = PromptTooLongState::EndOfSequence;
5448
5449 Ok("data: [DONE]\n\n".to_string())
5450 }
5451 PromptTooLongState::EndOfSequence => {
5452 Ok("[GGML] End of sequence".to_string())
5453 }
5454 }
5455 }
5456 Err(e) => {
5457 let err_msg =
5458 format!("Failed to compute the chat completion. Reason: {e}");
5459
5460 #[cfg(feature = "logging")]
5461 error!(target: "stdout", "{}", &err_msg);
5462
5463 Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
5464 err_msg,
5465 )))
5466 }
5467 }
5468 }
5469 None => {
5470 let err_msg = "There is no model available in the chat graphs.";
5471
5472 #[cfg(feature = "logging")]
5473 error!(target: "stdout", "{}", &err_msg);
5474
5475 Err(LlamaCoreError::Operation(err_msg.into()))
5476 }
5477 }
5478 }
5479 };
5480
5481 #[cfg(feature = "logging")]
5482 info!(target: "stdout", "Return the chat stream chunk!");
5483
5484 res
5485}
5486
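/// Intermediate result of parsing a model completion for tool calls: the raw text, any
/// remaining natural-language content, and the extracted `ToolCall`s.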
5487#[allow(dead_code)]
5488#[derive(Debug)]
5489struct ParseResult {
5490 raw: String,
5491 content: Option<String>,
5492 tool_calls: Vec<ToolCall>,
5493}