llama_core/chat.rs

//! Define APIs for chat completion.

use crate::{
    error,
    metadata::ggml::GgmlMetadata,
    running_mode,
    utils::{
        gen_chat_id, get_output_buffer, get_output_buffer_single, get_token_info_by_graph,
        get_token_info_by_graph_name, set_tensor_data_u8,
    },
    Graph, RunningMode, CACHED_UTF8_ENCODINGS, CHAT_GRAPHS, OUTPUT_TENSOR,
};
use chat_prompts::{BuildChatPrompt, ChatPrompt, PromptTemplateType};
use either::{Either, Left, Right};
use endpoints::{
    chat::{
        ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionChunkChoiceDelta,
        ChatCompletionObject, ChatCompletionObjectChoice, ChatCompletionObjectMessage,
        ChatCompletionRequest, ChatCompletionRequestMessage, ChatCompletionRole,
        ChatCompletionUserMessageContent, ContentPart, Function, ToolCall, ToolCallForChunk,
        ToolChoice,
    },
    common::{FinishReason, Usage},
};
use error::{BackendError, LlamaCoreError};
use std::{
    collections::VecDeque,
    pin::Pin,
    sync::{
        atomic::{AtomicBool, Ordering},
        Mutex, OnceLock,
    },
    task::{Context, Poll, Waker},
    time::SystemTime,
};

// Define a global waker queue for storing waiting ChatStreams
static CHAT_STREAM_WAKER_QUEUE: OnceLock<Mutex<VecDeque<Waker>>> = OnceLock::new();

// Define a global atomic boolean indicating whether there is an active ChatStream
static CHAT_STREAM_ACTIVE: AtomicBool = AtomicBool::new(false);

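// A minimal illustrative sketch (not used by this module) of the coordination
// pattern the two globals above support: a task tries to become the single
// active stream, and otherwise parks its waker in the queue so it can be
// woken once the active stream finishes. The real logic presumably lives in
// `ChatStream`'s poll/drop implementations later in this file; this helper is
// an assumption for illustration only.
#[allow(dead_code)]
fn try_acquire_stream_slot(waker: &Waker) -> bool {
    // Flip the flag from `false` to `true`; only one caller can win the race.
    if CHAT_STREAM_ACTIVE
        .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
        .is_ok()
    {
        return true;
    }

    // Another stream is active: park this task's waker until that stream is done.
    let queue = CHAT_STREAM_WAKER_QUEUE.get_or_init(|| Mutex::new(VecDeque::new()));
    if let Ok(mut wakers) = queue.lock() {
        wakers.push_back(waker.clone());
    }
    false
}
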
/// Processes a chat-completion request and returns either a stream of ChatCompletionChunk instances or a ChatCompletionObject instance, together with a flag indicating whether the response includes tool calls.
pub async fn chat(
    chat_request: &mut ChatCompletionRequest,
) -> Result<
    (
        Either<impl futures::TryStream<Ok = String, Error = LlamaCoreError>, ChatCompletionObject>,
        bool,
    ),
    LlamaCoreError,
> {
    #[cfg(feature = "logging")]
    {
        debug!(target: "stdout", "tool choice: {:?}", chat_request.tool_choice.as_ref());
        debug!(target: "stdout", "tools: {:?}", chat_request.tools.as_ref());
        debug!(target: "stdout", "stream mode: {:?}", chat_request.stream);
    }

    let result = match chat_request.stream {
        Some(true) => match chat_stream(chat_request).await {
            Ok((stream, include_tool_calls)) => Ok((Left(stream), include_tool_calls)),
            Err(e) => Err(e),
        },
        Some(false) | None => match chat_once(chat_request).await {
            Ok((chat_completion_object, include_tool_calls)) => {
                Ok((Right(chat_completion_object), include_tool_calls))
            }
            Err(e) => Err(e),
        },
    };

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Reset the model metadata");

    result
}

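// Illustrative usage sketch (an assumption, not part of this module's API):
// how a caller might drive `chat` and handle both the streaming and the
// non-streaming results. `TryStreamExt` and `pin_mut!` come from the
// `futures` crate used above; serializing the non-stream object assumes
// `ChatCompletionObject` derives `serde::Serialize`, as the chunk types here do.
#[allow(dead_code)]
async fn example_drive_chat(
    chat_request: &mut ChatCompletionRequest,
) -> Result<(), LlamaCoreError> {
    use futures::TryStreamExt;

    match chat(chat_request).await? {
        (Left(stream), _include_tool_calls) => {
            // Stream mode: each item is an SSE-framed string, e.g. "data: {...}\n\n".
            futures::pin_mut!(stream);
            while let Some(frame) = stream.try_next().await? {
                println!("{frame}");
            }
        }
        (Right(object), _include_tool_calls) => {
            // Non-stream mode: one complete chat completion object.
            let body = serde_json::to_string(&object)
                .map_err(|e| LlamaCoreError::Operation(e.to_string()))?;
            println!("{body}");
        }
    }

    Ok(())
}
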
async fn chat_stream(
    chat_request: &mut ChatCompletionRequest,
) -> Result<
    (
        impl futures::TryStream<Ok = String, Error = LlamaCoreError>,
        bool,
    ),
    LlamaCoreError,
> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Processing chat completion request in stream mode");

    let running_mode = running_mode()?;
    if !running_mode.contains(RunningMode::CHAT) && !running_mode.contains(RunningMode::RAG) {
        let err_msg = "Chat completion is only supported in the chat or RAG mode.";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{err_msg}");

        return Err(LlamaCoreError::Operation(err_msg.to_string()));
    }

    let model_name = chat_request.model.clone();
    let id = match &chat_request.user {
        Some(id) => id.clone(),
        None => gen_chat_id(),
    };
    #[cfg(feature = "logging")]
    info!(target: "stdout", "user: {}", &id);

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Check model metadata");

    // update metadata
    let mut metadata = check_model_metadata(chat_request)?;

    // parse the `include_usage` option
    let include_usage = match chat_request.stream_options {
        Some(ref stream_options) => stream_options.include_usage.unwrap_or_default(),
        None => metadata.include_usage,
    };
    #[cfg(feature = "logging")]
    info!(target: "stdout", "include_usage: {include_usage}");

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Build the chat prompt");

    // build prompt
    let (prompt, available_completion_tokens, tool_use) =
        build_prompt(model_name.as_ref(), chat_request)?;

    #[cfg(feature = "logging")]
    {
        info!(target: "stdout", "prompt:\n{}", &prompt);
        info!(target: "stdout", "available_completion_tokens: {available_completion_tokens}");
        info!(target: "stdout", "tool_use: {tool_use}");
    }

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Update the n_predict");

    // update metadata n_predict
    update_n_predict(chat_request, &mut metadata, available_completion_tokens)?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Feed the prompt to the model");

    // set prompt
    set_prompt(chat_request.model.as_ref(), &prompt)?;

    let stream = match tool_use {
        false => (ChatStream::new(model_name, id, include_usage, None), false),
        true => {
            let chat_graphs = match CHAT_GRAPHS.get() {
                Some(chat_graphs) => chat_graphs,
                None => {
                    let err_msg = "Failed to get the underlying value of `CHAT_GRAPHS`.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    return Err(LlamaCoreError::Operation(err_msg.into()));
                }
            };

            let mut chat_graphs = chat_graphs.lock().map_err(|e| {
                let err_msg = format!("Failed to acquire the lock of `CHAT_GRAPHS`. {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            match model_name {
                Some(model_name) => match chat_graphs.contains_key(&model_name) {
                    true => {
                        let graph = chat_graphs.get_mut(&model_name).unwrap();
                        chat_stream_for_tool(graph, id, include_usage)?
                    }
                    false => match chat_graphs.iter_mut().next() {
                        Some((_, graph)) => chat_stream_for_tool(graph, id, include_usage)?,
                        None => {
                            let err_msg = "There is no model available in the chat graphs.";

                            #[cfg(feature = "logging")]
                            error!(target: "stdout", "{}", &err_msg);

                            return Err(LlamaCoreError::Operation(err_msg.into()));
                        }
                    },
                },
                None => match chat_graphs.iter_mut().next() {
                    Some((_, graph)) => chat_stream_for_tool(graph, id, include_usage)?,
                    None => {
                        let err_msg = "There is no model available in the chat graphs.";

                        #[cfg(feature = "logging")]
                        error!(target: "stdout", "{}", &err_msg);

                        return Err(LlamaCoreError::Operation(err_msg.into()));
                    }
                },
            }
        }
    };

    #[cfg(feature = "logging")]
    info!(target: "stdout", "End of the chat completion stream.");

    Ok(stream)
}

fn chat_stream_for_tool(
    graph: &mut Graph<GgmlMetadata>,
    id: impl Into<String>,
    include_usage: bool,
) -> Result<(ChatStream, bool), LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Handling chat request with available tools using the model named {}.", graph.name());

    let id = id.into();

    match graph.compute() {
        Ok(_) => {
            // Retrieve the output.
            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
                let err_msg = format!(
                    "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "raw generation:\n{output}");

            // post-process
            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
                LlamaCoreError::Operation(format!("Failed to post-process the output. {e}"))
            })?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "post-processed generation:\n{}", &message);

            // retrieve the number of prompt and completion tokens
            let token_info = get_token_info_by_graph(graph)?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);

            let usage = Some(Usage {
                prompt_tokens: token_info.prompt_tokens,
                completion_tokens: token_info.completion_tokens,
                total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
            });

            let created = SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map_err(|e| {
                    let err_msg = format!("Failed to get the current time. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

            if !matches!(
                graph.metadata.prompt_template,
                PromptTemplateType::MistralTool
                    | PromptTemplateType::ChatMLTool
                    | PromptTemplateType::GroqLlama3Tool
                    | PromptTemplateType::Llama3Tool
                    | PromptTemplateType::InternLM2Tool
                    | PromptTemplateType::NemotronTool
                    | PromptTemplateType::FunctionaryV32
                    | PromptTemplateType::FunctionaryV31
                    | PromptTemplateType::MistralSmallTool
                    | PromptTemplateType::Llama4Chat
                    | PromptTemplateType::Qwen3NoThink
                    | PromptTemplateType::Smol3NoThink
                    | PromptTemplateType::Gemma3
                    | PromptTemplateType::GptOss
                    | PromptTemplateType::Qwen3Agent
                    | PromptTemplateType::SeedOssNoThink
                    | PromptTemplateType::SeedOssThink
            ) {
                let err_msg = format!("Unsupported prompt template: {}. Tool use is only supported by the 'mistral-tool', 'chatml-tool', 'groq-llama3-tool', 'llama-3-tool', 'internlm-2-tool', 'nemotron-tool', 'functionary-31', 'functionary-32', 'mistral-small-tool', 'llama-4-chat', 'qwen3-no-think', 'smol-3-no-think', 'gemma-3', 'gpt-oss', 'qwen3-agent', 'seed-oss-no-think', and 'seed-oss-think' prompt templates.", graph.metadata.prompt_template);

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                return Err(LlamaCoreError::Operation(err_msg));
            }

            let parsed_result = parse_tool_calls(&message, graph.metadata.prompt_template)?;

            let content = if parsed_result.tool_calls.is_empty() {
                Some(parsed_result.raw.clone())
            } else {
                parsed_result.content.clone()
            };

            let (tool_calls, include_tool_calls) = match parsed_result.tool_calls.is_empty() {
                false => {
                    let tool_calls: Vec<ToolCallForChunk> = parsed_result
                        .tool_calls
                        .into_iter()
                        .enumerate()
                        .map(|(index, tool_call)| ToolCallForChunk {
                            index,
                            id: tool_call.id,
                            ty: tool_call.ty,
                            function: tool_call.function,
                        })
                        .collect();
                    (tool_calls, true)
                }
                true => (vec![], false),
            };

            // tool_calls chunk
            let tool_call_chunk = {
                let chat_completion_chunk = ChatCompletionChunk {
                    id: id.clone(),
                    object: "chat.completion.chunk".to_string(),
                    created: created.as_secs(),
                    model: graph.name().to_owned(),
                    system_fingerprint: "fp_44709d6fcb".to_string(),
                    choices: vec![ChatCompletionChunkChoice {
                        index: 0,
                        delta: ChatCompletionChunkChoiceDelta {
                            role: ChatCompletionRole::Assistant,
                            content,
                            tool_calls,
                        },
                        logprobs: None,
                        finish_reason: None,
                    }],
                    usage: None,
                };
                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
                    let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                format!("data: {chunk_str}\n\n")
            };

            // token usage chunk
            let usage_chunk = {
                let chat_completion_chunk = ChatCompletionChunk {
                    id: id.clone(),
                    object: "chat.completion.chunk".to_string(),
                    created: created.as_secs(),
                    model: graph.name().to_owned(),
                    system_fingerprint: "fp_44709d6fcb".to_string(),
                    choices: vec![],
                    usage,
                };
                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
                    let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                format!("data: {chunk_str}\n\n")
            };

            // ending chunk
            let ending_chunk = "data: [DONE]\n\n".to_string();

            let chunks = vec![tool_call_chunk, usage_chunk, ending_chunk];

            let stream = ChatStream::new(
                Some(graph.name().to_owned()),
                id,
                include_usage,
                Some(chunks),
            );

            Ok((stream, include_tool_calls))
        }
        Err(wasmedge_wasi_nn::Error::BackendError(wasmedge_wasi_nn::BackendError::ContextFull)) => {
            // Retrieve the output.
            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
                let err_msg = format!(
                    "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            // post-process
            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
                let err_msg = format!("Failed to post-process the output. {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            // retrieve the number of prompt and completion tokens
            let token_info = get_token_info_by_graph(graph)?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);

            let usage = Some(Usage {
                prompt_tokens: token_info.prompt_tokens,
                completion_tokens: token_info.completion_tokens,
                total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
            });

            let created = SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map_err(|e| {
                    let err_msg = format!("Failed to get the current time. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

            // context full chunk
            let context_full_chunk = {
                let chat_completion_chunk = ChatCompletionChunk {
                    id: id.clone(),
                    object: "chat.completion.chunk".to_string(),
                    created: created.as_secs(),
                    model: graph.name().to_owned(),
                    system_fingerprint: "fp_44709d6fcb".to_string(),
                    choices: vec![ChatCompletionChunkChoice {
                        index: 0,
                        delta: ChatCompletionChunkChoiceDelta {
                            role: ChatCompletionRole::Assistant,
                            content: Some(message),
                            tool_calls: vec![],
                        },
                        logprobs: None,
                        finish_reason: Some(FinishReason::length),
                    }],
                    usage: None,
                };

                // serialize chat completion chunk
                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
                    let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                format!("data: {chunk_str}\n\n")
            };

            // usage chunk
            let usage_chunk = {
                let chat_completion_chunk = ChatCompletionChunk {
                    id: id.clone(),
                    object: "chat.completion.chunk".to_string(),
                    created: created.as_secs(),
                    model: graph.name().to_owned(),
                    system_fingerprint: "fp_44709d6fcb".to_string(),
                    choices: vec![],
                    usage,
                };

                // serialize chat completion chunk
                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
                    let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                format!("data: {chunk_str}\n\n")
            };

            // ending chunk
            let ending_chunk = "data: [DONE]\n\n".to_string();

            let chunks = vec![context_full_chunk, usage_chunk, ending_chunk];

            let stream = ChatStream::new(
                Some(graph.name().to_owned()),
                id,
                include_usage,
                Some(chunks),
            );

            Ok((stream, false))
        }
        Err(wasmedge_wasi_nn::Error::BackendError(
            wasmedge_wasi_nn::BackendError::PromptTooLong,
        )) => {
            #[cfg(feature = "logging")]
            warn!(target: "stdout", "The prompt is too long. Please reduce the length of your input and try again.");

            // Retrieve the output.
            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
                let err_msg = format!(
                    "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            // post-process
            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
                let err_msg = format!("Failed to post-process the output. {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            // retrieve the number of prompt and completion tokens
            let token_info = get_token_info_by_graph(graph)?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);

            let usage = Some(Usage {
                prompt_tokens: token_info.prompt_tokens,
                completion_tokens: token_info.completion_tokens,
                total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
            });

            let created = SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map_err(|e| {
                    let err_msg = format!("Failed to get the current time. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

            // prompt too long chunk
            let prompt_too_long_chunk = {
                let chat_completion_chunk = ChatCompletionChunk {
                    id: id.clone(),
                    object: "chat.completion.chunk".to_string(),
                    created: created.as_secs(),
                    model: graph.name().to_owned(),
                    system_fingerprint: "fp_44709d6fcb".to_string(),
                    choices: vec![ChatCompletionChunkChoice {
                        index: 0,
                        delta: ChatCompletionChunkChoiceDelta {
                            role: ChatCompletionRole::Assistant,
                            content: Some(message),
                            tool_calls: vec![],
                        },
                        logprobs: None,
                        finish_reason: Some(FinishReason::length),
                    }],
                    usage: None,
                };

                // serialize chat completion chunk
                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
                    let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                format!("data: {chunk_str}\n\n")
            };

            // usage chunk
            let usage_chunk = {
                let chat_completion_chunk = ChatCompletionChunk {
                    id: id.clone(),
                    object: "chat.completion.chunk".to_string(),
                    created: created.as_secs(),
                    model: graph.name().to_owned(),
                    system_fingerprint: "fp_44709d6fcb".to_string(),
                    choices: vec![],
                    usage,
                };

                // serialize chat completion chunk
                let chunk_str = serde_json::to_string(&chat_completion_chunk).map_err(|e| {
                    let err_msg = format!("Failed to serialize chat completion chunk. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                format!("data: {chunk_str}\n\n")
            };

            // ending chunk
            let ending_chunk = "data: [DONE]\n\n".to_string();

            let chunks = vec![prompt_too_long_chunk, usage_chunk, ending_chunk];

            let stream = ChatStream::new(
                Some(graph.name().to_owned()),
                id,
                include_usage,
                Some(chunks),
            );

            Ok((stream, false))
        }
        Err(e) => {
            let err_msg = format!("Failed to compute the chat completion. Reason: {e}");

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            Err(LlamaCoreError::Backend(BackendError::Compute(err_msg)))
        }
    }
}

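// Illustrative sketch of how a consumer might decode the SSE-framed strings
// produced by the streams above: each frame is `data: <json>\n\n`, and the
// stream terminates with `data: [DONE]\n\n`. Decoding into `serde_json::Value`
// avoids assuming the chunk types derive `Deserialize`; this helper is an
// assumption for illustration, not part of this module.
#[allow(dead_code)]
fn example_decode_sse_frame(frame: &str) -> Option<serde_json::Value> {
    // Strip the SSE `data: ` prefix; anything else is not a data frame.
    let payload = frame.strip_prefix("data: ")?.trim_end();

    // The final frame carries the literal `[DONE]` marker instead of JSON.
    if payload == "[DONE]" {
        return None;
    }

    serde_json::from_str(payload).ok()
}
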
async fn chat_once(
    chat_request: &mut ChatCompletionRequest,
) -> Result<(ChatCompletionObject, bool), LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Processing chat completion request in non-stream mode");

    let running_mode = running_mode()?;
    if !running_mode.contains(RunningMode::CHAT) && !running_mode.contains(RunningMode::RAG) {
        let err_msg = "Chat completion is only supported in the chat or RAG mode.";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{err_msg}");

        return Err(LlamaCoreError::Operation(err_msg.to_string()));
    }

    let model_name = chat_request.model.clone();
    let id = match &chat_request.user {
        Some(id) => id.clone(),
        None => gen_chat_id(),
    };

    #[cfg(feature = "logging")]
    info!(target: "stdout", "user: {}", &id);

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Check model metadata");

    // update metadata
    let mut metadata = check_model_metadata(chat_request)?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Build the chat prompt");

    // build prompt
    let (prompt, available_completion_tokens, tool_use) =
        build_prompt(model_name.as_ref(), chat_request)?;

    #[cfg(feature = "logging")]
    {
        info!(target: "stdout", "prompt:\n{}", &prompt);
        info!(target: "stdout", "available_completion_tokens: {available_completion_tokens}");
        info!(target: "stdout", "tool_use: {tool_use}");
    }

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Update n_predict");

    // update metadata n_predict
    update_n_predict(chat_request, &mut metadata, available_completion_tokens)?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Feed the prompt to the model");

    // feed the prompt to the model
    set_prompt(model_name.as_ref(), &prompt)?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Compute chat completion.");

    // compute
    let res = compute(model_name.as_ref(), id, tool_use);

    #[cfg(feature = "logging")]
    info!(target: "stdout", "End of the chat completion");

    // reset the model metadata
    reset_model_metadata(model_name.as_ref())?;

    res
}

fn compute(
    model_name: Option<&String>,
    id: impl Into<String>,
    tool_use: bool,
) -> Result<(ChatCompletionObject, bool), LlamaCoreError> {
    let chat_graphs = match CHAT_GRAPHS.get() {
        Some(chat_graphs) => chat_graphs,
        None => {
            let err_msg = "Failed to get the underlying value of `CHAT_GRAPHS`.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            return Err(LlamaCoreError::Operation(err_msg.into()));
        }
    };

    let mut chat_graphs = chat_graphs.lock().map_err(|e| {
        let err_msg = format!("Failed to acquire the lock of `CHAT_GRAPHS`. {e}");

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    match model_name {
        Some(model_name) => match chat_graphs.contains_key(model_name) {
            true => {
                let graph = chat_graphs.get_mut(model_name).unwrap();
                compute_by_graph(graph, id, tool_use)
            }
            false => match chat_graphs.iter_mut().next() {
                Some((_, graph)) => compute_by_graph(graph, id, tool_use),
                None => {
                    let err_msg = "There is no model available in the chat graphs.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    Err(LlamaCoreError::Operation(err_msg.into()))
                }
            },
        },
        None => match chat_graphs.iter_mut().next() {
            Some((_, graph)) => compute_by_graph(graph, id, tool_use),
            None => {
                let err_msg = "There is no model available in the chat graphs.";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                Err(LlamaCoreError::Operation(err_msg.into()))
            }
        },
    }
}

fn compute_by_graph(
    graph: &mut Graph<GgmlMetadata>,
    id: impl Into<String>,
    tool_use: bool,
) -> Result<(ChatCompletionObject, bool), LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Computing chat completion with the model named {}.", graph.name());

    match graph.compute() {
        Ok(_) => {
            // Retrieve the output.
            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
                let err_msg = format!(
                    "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "raw generation: {output}");

            // post-process
            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
                LlamaCoreError::Operation(format!("Failed to post-process the output. {e}"))
            })?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "post-processed generation:\n{}", &message);

            // retrieve the number of prompt and completion tokens
            let token_info = get_token_info_by_graph(graph)?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);

            let created = SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map_err(|e| {
                    let err_msg = format!("Failed to get the current time. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

            match tool_use {
                true => {
                    if !matches!(
                        graph.metadata.prompt_template,
                        PromptTemplateType::MistralTool
                            | PromptTemplateType::ChatMLTool
                            | PromptTemplateType::GroqLlama3Tool
                            | PromptTemplateType::Llama3Tool
                            | PromptTemplateType::InternLM2Tool
                            | PromptTemplateType::NemotronTool
                            | PromptTemplateType::FunctionaryV32
                            | PromptTemplateType::FunctionaryV31
                            | PromptTemplateType::MistralSmallTool
                            | PromptTemplateType::Llama4Chat
                            | PromptTemplateType::Qwen3NoThink
                            | PromptTemplateType::Smol3NoThink
                            | PromptTemplateType::Gemma3
                            | PromptTemplateType::GptOss
                            | PromptTemplateType::Qwen3Agent
                            | PromptTemplateType::SeedOssNoThink
                            | PromptTemplateType::SeedOssThink
                    ) {
                        let err_msg = format!("Unsupported prompt template: {}. Tool use is only supported by the 'mistral-tool', 'chatml-tool', 'groq-llama3-tool', 'llama-3-tool', 'internlm-2-tool', 'nemotron-tool', 'functionary-31', 'functionary-32', 'mistral-small-tool', 'llama-4-chat', 'qwen3-no-think', 'smol-3-no-think', 'gemma-3', 'gpt-oss', 'qwen3-agent', 'seed-oss-no-think', and 'seed-oss-think' prompt templates.", graph.metadata.prompt_template);

                        #[cfg(feature = "logging")]
                        error!(target: "stdout", "{}", &err_msg);

                        return Err(LlamaCoreError::Operation(err_msg));
                    }

                    let parsed_result = parse_tool_calls(&message, graph.metadata.prompt_template)?;

                    let (finish_reason, content, include_tool_calls) =
                        if parsed_result.tool_calls.is_empty() {
                            (FinishReason::stop, Some(parsed_result.raw.clone()), false)
                        } else if graph.metadata.prompt_template != PromptTemplateType::Qwen3Agent {
                            (
                                FinishReason::tool_calls,
                                Some(parsed_result.raw.clone()),
                                true,
                            )
                        } else {
                            (
                                FinishReason::tool_calls,
                                parsed_result.content.clone(),
                                true,
                            )
                        };

                    let res = ChatCompletionObject {
                        id: id.into(),
                        object: String::from("chat.completion"),
                        created: created.as_secs(),
                        model: graph.name().to_owned(),
                        choices: vec![ChatCompletionObjectChoice {
                            index: 0,
                            message: ChatCompletionObjectMessage {
                                role: ChatCompletionRole::Assistant,
                                content,
                                tool_calls: parsed_result.tool_calls,
                                function_call: None,
                            },
                            finish_reason,
                            logprobs: None,
                        }],
                        usage: Usage {
                            prompt_tokens: token_info.prompt_tokens,
                            completion_tokens: token_info.completion_tokens,
                            total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
                        },
                    };

                    // create ChatCompletionResponse
                    Ok((res, include_tool_calls))
                }
                false => {
                    // create ChatCompletionResponse
                    let res = ChatCompletionObject {
                        id: id.into(),
                        object: String::from("chat.completion"),
                        created: created.as_secs(),
                        model: graph.name().to_owned(),
                        choices: vec![ChatCompletionObjectChoice {
                            index: 0,
                            message: ChatCompletionObjectMessage {
                                role: ChatCompletionRole::Assistant,
                                content: Some(message),
                                tool_calls: vec![],
                                function_call: None,
                            },
                            finish_reason: FinishReason::stop,
                            logprobs: None,
                        }],
                        usage: Usage {
                            prompt_tokens: token_info.prompt_tokens,
                            completion_tokens: token_info.completion_tokens,
                            total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
                        },
                    };

                    Ok((res, false))
                }
            }
        }
        Err(wasmedge_wasi_nn::Error::BackendError(wasmedge_wasi_nn::BackendError::ContextFull)) => {
            // Retrieve the output.
            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
                let err_msg = format!(
                    "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            // post-process
            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
                let err_msg = format!("Failed to post-process the output. {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            // retrieve the number of prompt and completion tokens
            let token_info = get_token_info_by_graph(graph)?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);

            let created = SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map_err(|e| {
                    let err_msg = format!("Failed to get the current time. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

            // create ChatCompletionResponse
            let res = ChatCompletionObject {
                id: id.into(),
                object: String::from("chat.completion"),
                created: created.as_secs(),
                model: graph.name().to_owned(),
                choices: vec![ChatCompletionObjectChoice {
                    index: 0,
                    message: ChatCompletionObjectMessage {
                        role: ChatCompletionRole::Assistant,
                        content: Some(message),
                        tool_calls: vec![],
                        function_call: None,
                    },
                    finish_reason: FinishReason::length,
                    logprobs: None,
                }],
                usage: Usage {
                    prompt_tokens: token_info.prompt_tokens,
                    completion_tokens: token_info.completion_tokens,
                    total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
                },
            };

            Ok((res, false))
        }
        Err(wasmedge_wasi_nn::Error::BackendError(
            wasmedge_wasi_nn::BackendError::PromptTooLong,
        )) => {
            #[cfg(feature = "logging")]
            warn!(target: "stdout", "The prompt is too long. Please reduce the length of your input and try again.");

            // Retrieve the output.
            let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
            let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
                let err_msg = format!(
                    "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            // post-process
            let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
                let err_msg = format!("Failed to post-process the output. {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            // retrieve the number of prompt and completion tokens
            let token_info = get_token_info_by_graph(graph)?;

            #[cfg(feature = "logging")]
            info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);

            let usage = Usage {
                prompt_tokens: token_info.prompt_tokens,
                completion_tokens: token_info.completion_tokens,
                total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
            };

            let created = SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map_err(|e| {
                    let err_msg = format!("Failed to get the current time. Reason: {e}");

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

            // create ChatCompletionResponse
            let res = ChatCompletionObject {
                id: id.into(),
                object: String::from("chat.completion"),
                created: created.as_secs(),
                model: graph.name().to_owned(),
                choices: vec![ChatCompletionObjectChoice {
                    index: 0,
                    message: ChatCompletionObjectMessage {
                        role: ChatCompletionRole::Assistant,
                        content: Some(message),
                        tool_calls: vec![],
                        function_call: None,
                    },
                    finish_reason: FinishReason::length,
                    logprobs: None,
                }],
                usage,
            };

            Ok((res, false))
        }
        Err(e) => {
            let err_msg = format!("Failed to compute the chat completion. Reason: {e}");

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            Err(LlamaCoreError::Backend(BackendError::Compute(err_msg)))
        }
    }
}

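// Illustrative sketch: what a Mistral-tool-style raw generation looks like and
// how `parse_tool_calls` (below) handles it. The sample string is hypothetical
// model output, not a recorded generation.
#[allow(dead_code)]
fn example_parse_mistral_tool_calls() -> Result<(), LlamaCoreError> {
    let raw = r#"[{"name": "get_weather", "arguments": {"city": "Paris"}}]"#;

    let parsed = parse_tool_calls(raw, PromptTemplateType::MistralTool)?;

    // One tool call is extracted; its arguments are kept as a JSON string.
    assert_eq!(parsed.tool_calls.len(), 1);
    assert_eq!(parsed.tool_calls[0].function.name, "get_weather");

    Ok(())
}
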
1083fn parse_tool_calls(
1084    input: &str,
1085    prompt_template: PromptTemplateType,
1086) -> Result<ParseResult, LlamaCoreError> {
1087    match prompt_template {
1088        PromptTemplateType::MistralTool => match regex::Regex::new(r"\[\{.*?\}\]") {
1089            Ok(re) => {
1090                let mut values: Vec<serde_json::Value> = vec![];
1091                for cap in re.captures_iter(input) {
1092                    let matched = &cap[0];
1093
1094                    #[cfg(feature = "logging")]
1095                    info!(target: "stdout", "captured: {matched}");
1096
1097                    match serde_json::from_str::<Vec<serde_json::Value>>(matched) {
1098                        Ok(group) => values.extend(group),
1099                        Err(e) => {
1100                            let err_msg =
1101                                format!("Failed to deserialize generated tool calls. Reason: {e}");
1102
1103                            #[cfg(feature = "logging")]
1104                            error!(target: "stdout", "{}", &err_msg);
1105
1106                            return Err(LlamaCoreError::Operation(err_msg));
1107                        }
1108                    }
1109                }
1110
1111                let mut tool_calls: Vec<ToolCall> = vec![];
1112                for value in values.iter() {
1113                    let name = match value.get("name") {
1114                        Some(name) => name.to_string().replace("\"", ""),
1115                        None => {
1116                            let err_msg = format!(
1117                                "Failed to get the name of the function. Tool call: {value:?}"
1118                            );
1119
1120                            #[cfg(feature = "logging")]
1121                            error!(target: "stdout", "{}", &err_msg);
1122
1123                            return Err(LlamaCoreError::Operation(err_msg));
1124                        }
1125                    };
1126
1127                    let arguments = match value.get("arguments") {
1128                        Some(arguments) => arguments.to_string(),
1129                        None => {
1130                            let err_msg = format!(
1131                                "Failed to get the arguments of the function. Tool call: {value:?}"
1132                            );
1133
1134                            #[cfg(feature = "logging")]
1135                            error!(target: "stdout", "{}", &err_msg);
1136
1137                            return Err(LlamaCoreError::Operation(err_msg));
1138                        }
1139                    };
1140
1141                    let function = Function { name, arguments };
1142
1143                    let tool_call = ToolCall {
1144                        id: "call_abc123".to_string(),
1145                        ty: "function".to_string(),
1146                        function,
1147                    };
1148
1149                    tool_calls.push(tool_call);
1150                }
1151
1152                let parsed = ParseResult {
1153                    raw: input.to_owned(),
1154                    content: None,
1155                    tool_calls,
1156                };
1157
1158                #[cfg(feature = "logging")]
1159                info!(target: "stdout", "parsed result: {parsed:?}");
1160
1161                Ok(parsed)
1162            }
1163            Err(e) => {
1164                let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1165
1166                #[cfg(feature = "logging")]
1167                error!(target: "stdout", "{}", &err_msg);
1168
1169                Err(LlamaCoreError::Operation(err_msg))
1170            }
1171        },
1172        PromptTemplateType::ChatMLTool => {
1173            match regex::Regex::new(r"<tool_call>(.*?)</tool_call>") {
1174                Ok(re) => {
1175                    let mut values: Vec<serde_json::Value> = vec![];
1176                    for cap in re.captures_iter(input) {
1177                        let matched = cap[1].replace("\\n", ""); // Remove "\\n" from the captured group
1178
1179                        #[cfg(feature = "logging")]
1180                        info!(target: "stdout", "captured: {}", &matched);
1181
1182                        match serde_json::from_str::<serde_json::Value>(&matched) {
1183                            Ok(value) => values.push(value),
1184                            Err(e) => {
1185                                let err_msg = format!(
1186                                    "Failed to deserialize generated tool calls. Reason: {e}"
1187                                );
1188
1189                                #[cfg(feature = "logging")]
1190                                error!(target: "stdout", "{}", &err_msg);
1191
1192                                return Err(LlamaCoreError::Operation(err_msg));
1193                            }
1194                        }
1195                    }
1196
1197                    let mut tool_calls: Vec<ToolCall> = vec![];
1198                    for value in values.iter() {
1199                        let name = match value.get("name") {
1200                            Some(name) => name.to_string().replace("\"", ""),
1201                            None => {
1202                                let err_msg = format!(
1203                                    "Failed to get the name of the function. Tool call: {value:?}"
1204                                );
1205
1206                                #[cfg(feature = "logging")]
1207                                error!(target: "stdout", "{}", &err_msg);
1208
1209                                return Err(LlamaCoreError::Operation(err_msg));
1210                            }
1211                        };
1212
1213                        let arguments = match value.get("arguments") {
1214                            Some(arguments) => arguments.to_string(),
1215                            None => {
1216                                let err_msg = format!(
1217                                    "Failed to get the arguments of the function. Tool call: {value:?}"
1218                                );
1219
1220                                #[cfg(feature = "logging")]
1221                                error!(target: "stdout", "{}", &err_msg);
1222
1223                                return Err(LlamaCoreError::Operation(err_msg));
1224                            }
1225                        };
1226
1227                        let function = Function { name, arguments };
1228
1229                        let tool_call = ToolCall {
1230                            id: "call_abc123".to_string(),
1231                            ty: "function".to_string(),
1232                            function,
1233                        };
1234
1235                        tool_calls.push(tool_call);
1236                    }
1237
1238                    let parsed = ParseResult {
1239                        raw: input.to_owned(),
1240                        content: None,
1241                        tool_calls,
1242                    };
1243
1244                    #[cfg(feature = "logging")]
1245                    info!(target: "stdout", "parsed result: {parsed:?}");
1246
1247                    Ok(parsed)
1248                }
1249                Err(e) => {
1250                    let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1251
1252                    #[cfg(feature = "logging")]
1253                    error!(target: "stdout", "{}", &err_msg);
1254
1255                    Err(LlamaCoreError::Operation(err_msg))
1256                }
1257            }
1258        }
1259        PromptTemplateType::GroqLlama3Tool => {
1260            #[cfg(feature = "logging")]
1261            info!(target: "stdout", "raw input: {input}");
1262
1263            match regex::Regex::new(r"(?s)<tool_call>((.|\r|\n)*?)</tool_call>") {
1264                Ok(re) => {
1265                    let mut values: Vec<serde_json::Value> = vec![];
1266                    for cap in re.captures_iter(input) {
1267                        let matched = cap[1].trim();
1268
1269                        #[cfg(feature = "logging")]
1270                        info!(target: "stdout", "captured: {matched}");
1271
1272                        match serde_json::from_str::<serde_json::Value>(matched) {
1273                            Ok(value) => values.push(value),
1274                            Err(e) => {
1275                                let err_msg = format!(
1276                                    "Failed to deserialize generated tool calls. Reason: {e}"
1277                                );
1278
1279                                #[cfg(feature = "logging")]
1280                                error!(target: "stdout", "{}", &err_msg);
1281
1282                                return Err(LlamaCoreError::Operation(err_msg));
1283                            }
1284                        }
1285                    }
1286
1287                    let mut tool_calls: Vec<ToolCall> = vec![];
1288                    for value in values.iter() {
1289                        let name = match value.get("name") {
1290                            Some(name) => name.to_string().replace("\"", ""),
1291                            None => {
1292                                let err_msg = format!(
1293                                    "Failed to get the name of the function. Tool call: {value:?}"
1294                                );
1295
1296                                #[cfg(feature = "logging")]
1297                                error!(target: "stdout", "{}", &err_msg);
1298
1299                                return Err(LlamaCoreError::Operation(err_msg));
1300                            }
1301                        };
1302
1303                        let arguments = match value.get("arguments") {
1304                            Some(arguments) => {
1305                                if arguments.is_string() {
1306                                    arguments.as_str().unwrap().to_string()
1307                                } else if arguments.is_object() {
1308                                    let map = arguments.as_object().unwrap();
1309
1310                                    #[cfg(feature = "logging")]
1311                                    info!(target: "stdout", "func arguments: {map:?}");
1312
1313                                    serde_json::to_string(map).unwrap()
1314                                } else {
1315                                    serde_json::to_string(arguments).unwrap()
1316                                }
1317                            }
1318                            None => {
1319                                let err_msg = format!(
1320                                    "Failed to get the arguments of the function. Tool call: {value:?}"
1321                                );
1322
1323                                #[cfg(feature = "logging")]
1324                                error!(target: "stdout", "{}", &err_msg);
1325
1326                                return Err(LlamaCoreError::Operation(err_msg));
1327                            }
1328                        };
1329
1330                        let function = Function { name, arguments };
1331
1332                        let tool_call = ToolCall {
1333                            id: "call_abc123".to_string(),
1334                            ty: "function".to_string(),
1335                            function,
1336                        };
1337
1338                        tool_calls.push(tool_call);
1339                    }
1340
1341                    let parsed = if tool_calls.is_empty() {
1342                        ParseResult {
1343                            raw: input.to_owned(),
1344                            content: Some(input.to_owned()),
1345                            tool_calls: vec![],
1346                        }
1347                    } else {
1348                        ParseResult {
1349                            raw: input.to_owned(),
1350                            content: None,
1351                            tool_calls,
1352                        }
1353                    };
1354
1355                    #[cfg(feature = "logging")]
1356                    info!(target: "stdout", "parsed result: {parsed:?}");
1357
1358                    Ok(parsed)
1359                }
1360                Err(e) => {
1361                    let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1362
1363                    #[cfg(feature = "logging")]
1364                    error!(target: "stdout", "{}", &err_msg);
1365
1366                    Err(LlamaCoreError::Operation(err_msg))
1367                }
1368            }
1369        }
1370        PromptTemplateType::Llama3Tool => {
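            // expected model output: the entire response is a single JSON object whose keys are
            // "name" and "parameters" (not "arguments"), e.g. (illustrative):
            //   {"name": "get_weather", "parameters": {"city": "Tokyo"}}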
1371            #[cfg(feature = "logging")]
1372            info!(target: "stdout", "raw input: {input}");
1373
1374            let re = match regex::Regex::new(r"^\{(.|\r|\n)*\}$") {
1375                Ok(re) => re,
1376                Err(e) => {
1377                    let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1378
1379                    #[cfg(feature = "logging")]
1380                    error!(target: "stdout", "{}", &err_msg);
1381
1382                    return Err(LlamaCoreError::Operation(err_msg));
1383                }
1384            };
1385
1386            if re.is_match(input) {
1387                match serde_json::from_str::<serde_json::Value>(input) {
1388                    Ok(value) => {
1389                        let values: Vec<serde_json::Value> = vec![value];
1390
1391                        let mut tool_calls: Vec<ToolCall> = vec![];
1392                        for value in values.iter() {
1393                            let name = match value.get("name") {
1394                                Some(name) => name.to_string().replace("\"", ""),
1395                                None => {
1396                                    let err_msg = format!(
1397                                        "Failed to get the name of the function. Tool call: {value:?}"
1398                                    );
1399
1400                                    #[cfg(feature = "logging")]
1401                                    error!(target: "stdout", "{}", &err_msg);
1402
1403                                    return Err(LlamaCoreError::Operation(err_msg));
1404                                }
1405                            };
1406
1407                            let arguments = match value.get("parameters") {
1408                                Some(arguments) => arguments.to_string(),
1409                                None => {
1410                                    let err_msg = format!(
1411                                        "Failed to get the arguments of the function. Tool call: {value:?}"
1412                                    );
1413
1414                                    #[cfg(feature = "logging")]
1415                                    error!(target: "stdout", "{}", &err_msg);
1416
1417                                    return Err(LlamaCoreError::Operation(err_msg));
1418                                }
1419                            };
1420
1421                            let function = Function { name, arguments };
1422
1423                            let tool_call = ToolCall {
1424                                id: "call_abc123".to_string(),
1425                                ty: "function".to_string(),
1426                                function,
1427                            };
1428
1429                            tool_calls.push(tool_call);
1430                        }
1431
1432                        let parsed = ParseResult {
1433                            raw: input.to_owned(),
1434                            content: None,
1435                            tool_calls,
1436                        };
1437
1438                        #[cfg(feature = "logging")]
1439                        info!(target: "stdout", "parsed result: {parsed:?}");
1440
1441                        Ok(parsed)
1442                    }
1443                    Err(e) => {
1444                        let err_msg =
1445                            format!("Failed to deserialize generated tool calls. Reason: {e}");
1446
1447                        #[cfg(feature = "logging")]
1448                        error!(target: "stdout", "{}", &err_msg);
1449
1450                        Err(LlamaCoreError::Operation(err_msg))
1451                    }
1452                }
1453            } else {
1454                let parsed = ParseResult {
1455                    raw: input.to_owned(),
1456                    content: None,
1457                    tool_calls: vec![],
1458                };
1459
1460                #[cfg(feature = "logging")]
1461                info!(target: "stdout", "parsed result: {parsed:?}");
1462
1463                Ok(parsed)
1464            }
1465        }
1466        PromptTemplateType::InternLM2Tool => {
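            // expected model output: plain text interleaved with tool-call blocks of the form
            //   <|action_start|><|plugin|>{"name": "...", "parameters": {...}}<|action_end|>
            // text outside the blocks is collected as the message content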
1467            #[cfg(feature = "logging")]
1468            info!(target: "stdout", "raw input: {input}");
1469
1470            let blocks: Vec<&str> = input.trim().split("<|action_start|><|plugin|>").collect();
1471
1472            #[cfg(feature = "logging")]
1473            info!(target: "stdout", "blocks: {blocks:?}");
1474
1475            let mut tool_calls: Vec<ToolCall> = vec![];
1476            let mut content = String::new();
1477            for block in blocks {
1478                let block = block.trim();
1479                if !block.is_empty() {
1480                    if block.ends_with("<|action_end|>") {
1481                        let value = block.trim().trim_end_matches("<|action_end|>");
1482
1483                        #[cfg(feature = "logging")]
1484                        info!(target: "stdout", "tool call: {value}");
1485
1486                        match serde_json::from_str::<serde_json::Value>(value) {
1487                            Ok(value) => {
1488                                let name = match value.get("name") {
1489                                    Some(name) => name.to_string().replace("\"", ""),
1490                                    None => {
1491                                        let err_msg = format!(
1492                                            "Failed to get the name of the function. Tool call: {value:?}"
1493                                        );
1494
1495                                        #[cfg(feature = "logging")]
1496                                        error!(target: "stdout", "{}", &err_msg);
1497
1498                                        return Err(LlamaCoreError::Operation(err_msg));
1499                                    }
1500                                };
1501
1502                                let arguments = match value.get("parameters") {
1503                                    Some(arguments) => arguments.to_string(),
1504                                    None => {
1505                                        let err_msg = format!(
1506                                            "Failed to get the arguments of the function. Tool call: {value:?}"
1507                                        );
1508
1509                                        #[cfg(feature = "logging")]
1510                                        error!(target: "stdout", "{}", &err_msg);
1511
1512                                        return Err(LlamaCoreError::Operation(err_msg));
1513                                    }
1514                                };
1515
1516                                let function = Function { name, arguments };
1517
1518                                let tool_call = ToolCall {
1519                                    id: "call_abc123".to_string(),
1520                                    ty: "function".to_string(),
1521                                    function,
1522                                };
1523
1524                                tool_calls.push(tool_call);
1525                            }
1526                            Err(e) => {
1527                                let err_msg = format!(
1528                                    "Failed to deserialize generated tool calls. Reason: {e}"
1529                                );
1530
1531                                #[cfg(feature = "logging")]
1532                                error!(target: "stdout", "{}", &err_msg);
1533
1534                                return Err(LlamaCoreError::Operation(err_msg));
1535                            }
1536                        }
1537                    } else {
1538                        content.push_str(block);
1539                        content.push('\n');
1540                    }
1541                }
1542            }
1543
1544            let parsed = match content.is_empty() {
1545                true => ParseResult {
1546                    raw: input.to_owned(),
1547                    content: None,
1548                    tool_calls,
1549                },
1550                false => ParseResult {
1551                    raw: input.to_owned(),
1552                    content: Some(content.trim().to_owned()),
1553                    tool_calls,
1554                },
1555            };
1556
1557            #[cfg(feature = "logging")]
1558            info!(target: "stdout", "parsed result: {parsed:?}");
1559
1560            Ok(parsed)
1561        }
1562        PromptTemplateType::NemotronTool => {
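            // expected model output (inferred from the regex below); the payload is illustrative:
            //   <toolcall> {"name": "get_weather", "arguments": {"city": "Tokyo"}} </toolcall>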
1563            #[cfg(feature = "logging")]
1564            info!(target: "stdout", "raw input: {input}");
1565
1566            match regex::Regex::new(r"(?s)<toolcall>\s*(.*?)\s*</toolcall>") {
1567                Ok(re) => {
1568                    let mut values: Vec<serde_json::Value> = vec![];
1569                    for cap in re.captures_iter(input) {
1570                        #[cfg(feature = "logging")]
1571                        info!(target: "stdout", "captured: {}", &cap[0]);
1572
1573                        #[cfg(feature = "logging")]
1574                        info!(target: "stdout", "extracted: {}", &cap[1]);
1575
1576                        let matched = cap[1].trim();
1577
1578                        #[cfg(feature = "logging")]
1579                        info!(target: "stdout", "trimmed: {matched}");
1580
1581                        match serde_json::from_str::<serde_json::Value>(matched) {
1582                            Ok(value) => values.push(value),
1583                            Err(e) => {
1584                                let err_msg = format!(
1585                                    "Failed to deserialize generated tool calls. Reason: {e}"
1586                                );
1587
1588                                #[cfg(feature = "logging")]
1589                                error!(target: "stdout", "{}", &err_msg);
1590
1591                                return Err(LlamaCoreError::Operation(err_msg));
1592                            }
1593                        }
1594                    }
1595
1596                    let mut tool_calls: Vec<ToolCall> = vec![];
1597                    for value in values.iter() {
1598                        let name = match value.get("name") {
1599                            Some(name) => name.to_string().replace("\"", ""),
1600                            None => {
1601                                let err_msg = format!(
1602                                    "Failed to get the name of the function. Tool call: {value:?}"
1603                                );
1604
1605                                #[cfg(feature = "logging")]
1606                                error!(target: "stdout", "{}", &err_msg);
1607
1608                                return Err(LlamaCoreError::Operation(err_msg));
1609                            }
1610                        };
1611
1612                        let arguments = match value.get("arguments") {
1613                            Some(arguments) => arguments.to_string(),
1614                            None => {
1615                                let err_msg = format!(
1616                                    "Failed to get the arguments of the function. Tool call: {value:?}"
1617                                );
1618
1619                                #[cfg(feature = "logging")]
1620                                error!(target: "stdout", "{}", &err_msg);
1621
1622                                return Err(LlamaCoreError::Operation(err_msg));
1623                            }
1624                        };
1625
1626                        let function = Function { name, arguments };
1627
1628                        let tool_call = ToolCall {
1629                            id: "call_abc123".to_string(),
1630                            ty: "function".to_string(),
1631                            function,
1632                        };
1633
1634                        tool_calls.push(tool_call);
1635                    }
1636
1637                    let parsed = ParseResult {
1638                        raw: input.to_owned(),
1639                        content: None,
1640                        tool_calls,
1641                    };
1642
1643                    #[cfg(feature = "logging")]
1644                    info!(target: "stdout", "parsed result: {parsed:?}");
1645
1646                    Ok(parsed)
1647                }
1648                Err(e) => {
1649                    let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1650
1651                    #[cfg(feature = "logging")]
1652                    error!(target: "stdout", "{}", &err_msg);
1653
1654                    Err(LlamaCoreError::Operation(err_msg))
1655                }
1656            }
1657        }
1658        PromptTemplateType::FunctionaryV32 => {
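            // expected model output (inferred from the regex below); the payload is illustrative:
            //   >>>get_weather{"city": "Tokyo"}<|eot_id|>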
1659            #[cfg(feature = "logging")]
1660            info!(target: "stdout", "raw input: {input}");
1661
1662            match regex::Regex::new(r">>>\s*(\w+)\s*(\{.*\})<\|eot_id\|>") { // capture the braces so arguments form a valid JSON object string, matching FunctionaryV31
1663                Ok(re) => {
1664                    let mut tool_calls: Vec<ToolCall> = vec![];
1665                    for cap in re.captures_iter(input) {
1666                        #[cfg(feature = "logging")]
1667                        info!(target: "stdout", "func_name: {}", &cap[1]);
1668
1669                        #[cfg(feature = "logging")]
1670                        info!(target: "stdout", "arguments: {}", &cap[2]);
1671
1672                        let tool_call = ToolCall {
1673                            id: "call_abc123".to_string(),
1674                            ty: "function".to_string(),
1675                            function: Function {
1676                                name: cap[1].to_string(),
1677                                arguments: cap[2].to_string(),
1678                            },
1679                        };
1680
1681                        tool_calls.push(tool_call);
1682                    }
1683
1684                    let parsed = ParseResult {
1685                        raw: input.to_owned(),
1686                        content: None,
1687                        tool_calls,
1688                    };
1689
1690                    #[cfg(feature = "logging")]
1691                    info!(target: "stdout", "parsed result: {parsed:?}");
1692
1693                    Ok(parsed)
1694                }
1695                Err(e) => {
1696                    let warn_msg = format!("Failed to create a regex pattern. Reason: {e}");
1697
1698                    #[cfg(feature = "logging")]
1699                    warn!(target: "stdout", "{}", &warn_msg);
1700
1701                    Ok(ParseResult {
1702                        raw: input.to_owned(),
1703                        content: None,
1704                        tool_calls: vec![],
1705                    })
1706                }
1707            }
1708        }
1709        PromptTemplateType::FunctionaryV31 => {
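            // expected model output (inferred from the regex below); the payload is illustrative:
            //   <function=get_weather>{"city": "Tokyo"}</function>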
1710            #[cfg(feature = "logging")]
1711            info!(target: "stdout", "raw input: {input}");
1712
1713            match regex::Regex::new(r"<function=(\w+)>\s*(\{.*?\})</function>") {
1714                Ok(re) => {
1715                    let mut tool_calls: Vec<ToolCall> = vec![];
1716                    for cap in re.captures_iter(input) {
1717                        #[cfg(feature = "logging")]
1718                        info!(target: "stdout", "func_name: {}", &cap[1]);
1719
1720                        #[cfg(feature = "logging")]
1721                        info!(target: "stdout", "arguments: {}", &cap[2]);
1722
1723                        let tool_call = ToolCall {
1724                            id: "call_abc123".to_string(),
1725                            ty: "function".to_string(),
1726                            function: Function {
1727                                name: cap[1].to_string(),
1728                                arguments: cap[2].to_string(),
1729                            },
1730                        };
1731
1732                        tool_calls.push(tool_call);
1733                    }
1734
1735                    let parsed = ParseResult {
1736                        raw: input.to_owned(),
1737                        content: None,
1738                        tool_calls,
1739                    };
1740
1741                    #[cfg(feature = "logging")]
1742                    info!(target: "stdout", "parsed result: {parsed:?}");
1743
1744                    Ok(parsed)
1745                }
1746                Err(e) => {
1747                    let warn_msg = format!("Failed to create a regex pattern. Reason: {e}");
1748
1749                    #[cfg(feature = "logging")]
1750                    warn!(target: "stdout", "{}", &warn_msg);
1751
1752                    Ok(ParseResult {
1753                        raw: input.to_owned(),
1754                        content: None,
1755                        tool_calls: vec![],
1756                    })
1757                }
1758            }
1759        }
1760        PromptTemplateType::MistralSmallTool => {
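            // expected model output: a JSON array after the [TOOL_CALLS] marker; each element is
            // either {"function": {"name": ..., "arguments": ...}} or a flat
            // {"name": ..., "arguments": ...} object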
1761            #[cfg(feature = "logging")]
1762            info!(target: "stdout", "raw input: {input}");
1763
1764            match regex::Regex::new(r"\[TOOL_CALLS\]\s*(\[(.*?)\])") {
1765                Ok(re) => {
1766                    let mut values: Vec<serde_json::Value> = vec![];
1767                    if let Some(cap) = re.captures(input) {
1768                        let matched = cap[1].trim();
1769
1770                        #[cfg(feature = "logging")]
1771                        info!(target: "stdout", "captured: {matched}");
1772
1773                        match serde_json::from_str::<Vec<serde_json::Value>>(matched) {
1774                            Ok(vals) => values = vals,
1775                            Err(e) => {
1776                                let err_msg = format!(
1777                                    "Failed to deserialize generated tool calls. Reason: {e}"
1778                                );
1779
1780                                #[cfg(feature = "logging")]
1781                                error!(target: "stdout", "{}", &err_msg);
1782
1783                                return Err(LlamaCoreError::Operation(err_msg));
1784                            }
1785                        }
1786                    }
1787
1788                    let mut tool_calls: Vec<ToolCall> = vec![];
1789                    for value in values.iter() {
1790                        if let Some(object_map) = value.as_object() {
1791                            if object_map.contains_key("function") {
1792                                let mut function = Function {
1793                                    name: String::new(),
1794                                    arguments: String::new(),
1795                                };
1796
1797                                let value = object_map.get("function").unwrap();
1798                                let func_map = value.as_object().unwrap();
1799                                if func_map.contains_key("name") {
1800                                    let func_name = func_map.get("name").unwrap().as_str().unwrap();
1801                                    #[cfg(feature = "logging")] info!(target: "stdout", "func name: {func_name:?}");
1802
1803                                    function.name = func_name.to_string();
1804                                }
1805                                if func_map.contains_key("arguments") {
1806                                    let args = func_map.get("arguments").unwrap();
1807                                    let arguments = args.to_string();
1808                                    #[cfg(feature = "logging")] info!(target: "stdout", "func arguments: {arguments:?}");
1809
1810                                    function.arguments = arguments;
1811                                }
1812
1813                                let tool_call = ToolCall {
1814                                    id: "call_abc123".to_string(),
1815                                    ty: "function".to_string(),
1816                                    function,
1817                                };
1818
1819                                tool_calls.push(tool_call);
1820                            } else if object_map.contains_key("name") {
1821                                let mut function = Function {
1822                                    name: String::new(),
1823                                    arguments: String::new(),
1824                                };
1825
1826                                let name = object_map.get("name").unwrap().as_str().unwrap();
1827                                #[cfg(feature = "logging")] info!(target: "stdout", "func name: {name:?}");
1828                                function.name = name.to_string();
1829
1830                                if object_map.contains_key("arguments") {
1831                                    let args = object_map.get("arguments").unwrap();
1832                                    let arguments = args.to_string();
1833                                    #[cfg(feature = "logging")] info!(target: "stdout", "func arguments: {arguments:?}");
1834
1835                                    function.arguments = arguments;
1836                                }
1837
1838                                let tool_call = ToolCall {
1839                                    id: "call_abc123".to_string(),
1840                                    ty: "function".to_string(),
1841                                    function,
1842                                };
1843
1844                                tool_calls.push(tool_call);
1845                            }
1846                        }
1847                    }
1848
1849                    let parsed = ParseResult {
1850                        raw: input.to_owned(),
1851                        content: None,
1852                        tool_calls,
1853                    };
1854
1855                    #[cfg(feature = "logging")]
1856                    info!(target: "stdout", "parsed result: {parsed:?}");
1857
1858                    Ok(parsed)
1859                }
1860                Err(e) => {
1861                    let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1862
1863                    #[cfg(feature = "logging")]
1864                    error!(target: "stdout", "{}", &err_msg);
1865
1866                    Err(LlamaCoreError::Operation(err_msg))
1867                }
1868            }
1869        }
1870        PromptTemplateType::Llama4Chat => {
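            // expected model output: the entire response is a single JSON object with a "name"
            // key and an optional "parameters" key, e.g. (illustrative):
            //   {"name": "get_weather", "parameters": {"city": "Tokyo"}}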
1871            #[cfg(feature = "logging")]
1872            info!(target: "stdout", "raw input: {input:?}");
1873
1874            let mut tool_calls: Vec<ToolCall> = vec![];
1875            if let Ok(value) = serde_json::from_str::<serde_json::Value>(input) {
1876                match value.as_object() {
1877                    Some(object_map) => {
1878                        #[cfg(feature = "logging")]
1879                        debug!(target: "stdout", "object_map: {object_map:?}");
1880
1881                        // parse function name
1882                        if object_map.contains_key("name") {
1883                            let name = object_map.get("name").unwrap().as_str().unwrap();
1884
1885                            #[cfg(feature = "logging")]
1886                            debug!(target: "stdout", "name: {name:?}");
1887
1888                            let mut function = Function {
1889                                name: name.to_string(),
1890                                arguments: String::new(),
1891                            };
1892
1893                            // parse function arguments
1894                            if object_map.contains_key("parameters") {
1895                                let args = object_map.get("parameters").unwrap();
1896                                let arguments = args.to_string();
1897
1898                                #[cfg(feature = "logging")]
1899                                debug!(target: "stdout", "arguments: {:?}", &arguments);
1900
1901                                function.arguments = arguments;
1902                            }
1903
1904                            tool_calls.push(ToolCall {
1905                                id: "call_abc123".to_string(),
1906                                ty: "function".to_string(),
1907                                function,
1908                            });
1909                        } else {
1910                            let err_msg = format!(
1911                                "Failed to get the name of the function. raw input: {input:?}"
1912                            );
1913
1914                            #[cfg(feature = "logging")]
1915                            error!(target: "stdout", "{}", &err_msg);
1916
1917                            return Err(LlamaCoreError::Operation(err_msg));
1918                        }
1919                    }
1920                    None => {
1921                        let err_msg = format!("The generated tool call is not a JSON object. JSON: {input}");
1922
1923                        #[cfg(feature = "logging")]
1924                        error!(target: "stdout", "{}", &err_msg);
1925
1926                        return Err(LlamaCoreError::Operation(err_msg));
1927                    }
1928                }
1929            }
1930
1931            let parsed = ParseResult {
1932                raw: input.to_owned(),
1933                content: None,
1934                tool_calls,
1935            };
1936
1937            #[cfg(feature = "logging")]
1938            info!(target: "stdout", "parsed result: {parsed:?}");
1939
1940            Ok(parsed)
1941        }
1942        PromptTemplateType::Qwen3NoThink | PromptTemplateType::Smol3NoThink => {
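            // expected model output (inferred from the regex below); literal "\n" escape
            // sequences around the payload are stripped before deserialization:
            //   <tool_call> {"name": "get_weather", "arguments": {"city": "Tokyo"}} </tool_call>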
1943            #[cfg(feature = "logging")]
1944            info!(target: "stdout", "raw input: {input:?}");
1945
1946            match regex::Regex::new(r"(?s)<tool_call>((.|\r|\n)*?)</tool_call>") {
1947                Ok(re) => {
1948                    let mut values: Vec<serde_json::Value> = vec![];
1949                    for cap in re.captures_iter(input) {
1950                        let mut matched = cap[1].trim();
1951
1952                        if matched.starts_with("\\n") {
1953                            matched = matched.trim_start_matches("\\n");
1954                        }
1955
1956                        if matched.ends_with("\\n") {
1957                            matched = matched.trim_end_matches("\\n");
1958                        }
1959
1960                        #[cfg(feature = "logging")]
1961                        info!(target: "stdout", "captured: {matched:#?}");
1962
1963                        if !matched.is_empty() {
1964                            match serde_json::from_str::<serde_json::Value>(matched) {
1965                                Ok(value) => values.push(value),
1966                                Err(e) => {
1967                                    let err_msg = format!(
1968                                    "Failed to deserialize generated tool calls: {matched:#?}. Reason: {e}"
1969                                );
1970
1971                                    #[cfg(feature = "logging")]
1972                                    error!(target: "stdout", "{}", &err_msg);
1973
1974                                    return Err(LlamaCoreError::Operation(err_msg));
1975                                }
1976                            }
1977                        }
1978                    }
1979
1980                    let mut tool_calls: Vec<ToolCall> = vec![];
1981                    for value in values.iter() {
1982                        let name = match value.get("name") {
1983                            Some(name) => name.to_string().replace("\"", ""),
1984                            None => {
1985                                let err_msg = format!(
1986                                    "Failed to get the name of the function. Tool call: {value:?}"
1987                                );
1988
1989                                #[cfg(feature = "logging")]
1990                                error!(target: "stdout", "{}", &err_msg);
1991
1992                                return Err(LlamaCoreError::Operation(err_msg));
1993                            }
1994                        };
1995
1996                        let arguments = match value.get("arguments") {
1997                            Some(arguments) => {
1998                                if arguments.is_string() {
1999                                    arguments.as_str().unwrap().to_string()
2000                                } else if arguments.is_object() {
2001                                    let map = arguments.as_object().unwrap();
2002
2003                                    #[cfg(feature = "logging")]
2004                                    info!(target: "stdout", "func arguments: {map:?}");
2005
2006                                    serde_json::to_string(map).unwrap()
2007                                } else {
2008                                    serde_json::to_string(arguments).unwrap()
2009                                }
2010                            }
2011                            None => {
2012                                let err_msg = format!(
2013                                    "Failed to get the arguments of the function. Tool call: {value:?}"
2014                                );
2015
2016                                #[cfg(feature = "logging")]
2017                                error!(target: "stdout", "{}", &err_msg);
2018
2019                                return Err(LlamaCoreError::Operation(err_msg));
2020                            }
2021                        };
2022
2023                        let function = Function { name, arguments };
2024
2025                        let tool_call = ToolCall {
2026                            id: "call_abc123".to_string(),
2027                            ty: "function".to_string(),
2028                            function,
2029                        };
2030
2031                        tool_calls.push(tool_call);
2032                    }
2033
2034                    let parsed = if tool_calls.is_empty() {
2035                        ParseResult {
2036                            raw: input.to_owned(),
2037                            content: Some(input.to_owned()),
2038                            tool_calls: vec![],
2039                        }
2040                    } else {
2041                        ParseResult {
2042                            raw: input.to_owned(),
2043                            content: None,
2044                            tool_calls,
2045                        }
2046                    };
2047
2048                    #[cfg(feature = "logging")]
2049                    info!(target: "stdout", "parsed result: {parsed:?}");
2050
2051                    Ok(parsed)
2052                }
2053                Err(e) => {
2054                    let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
2055
2056                    #[cfg(feature = "logging")]
2057                    error!(target: "stdout", "{}", &err_msg);
2058
2059                    Err(LlamaCoreError::Operation(err_msg))
2060                }
2061            }
2062        }
2063        PromptTemplateType::Gemma3 => {
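            // expected model output: one or more fenced blocks of the form
            //   ```json
            //   {"name": "...", "arguments": {...}}
            //   ```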
2064            #[cfg(feature = "logging")]
2065            info!(target: "stdout", "raw input: {input:?}");
2066
2067            match regex::Regex::new(r"(?s)```json\s*(.*?)\s*```") {
2068                Ok(re) => {
2069                    let mut values: Vec<serde_json::Value> = vec![];
2070                    for cap in re.captures_iter(input) {
2071                        let mut matched = cap[1].trim();
2072
2073                        if matched.starts_with("\\n") {
2074                            matched = matched.trim_start_matches("\\n");
2075                        }
2076
2077                        if matched.ends_with("\\n") {
2078                            matched = matched.trim_end_matches("\\n");
2079                        }
2080
2081                        #[cfg(feature = "logging")]
2082                        info!(target: "stdout", "captured: {matched:#?}");
2083
2084                        if !matched.is_empty() {
2085                            match serde_json::from_str::<serde_json::Value>(matched) {
2086                                Ok(value) => values.push(value),
2087                                Err(e) => {
2088                                    let err_msg = format!(
2089                                    "Failed to deserialize generated tool calls: {matched:#?}. Reason: {e}"
2090                                );
2091
2092                                    #[cfg(feature = "logging")]
2093                                    error!(target: "stdout", "{}", &err_msg);
2094
2095                                    return Err(LlamaCoreError::Operation(err_msg));
2096                                }
2097                            }
2098                        }
2099                    }
2100
2101                    let mut tool_calls: Vec<ToolCall> = vec![];
2102                    for value in values.iter() {
2103                        let name = match value.get("name") {
2104                            Some(name) => name.to_string().replace("\"", ""),
2105                            None => {
2106                                let err_msg = format!(
2107                                    "Failed to get the name of the function. Tool call: {value:?}"
2108                                );
2109
2110                                #[cfg(feature = "logging")]
2111                                error!(target: "stdout", "{}", &err_msg);
2112
2113                                return Err(LlamaCoreError::Operation(err_msg));
2114                            }
2115                        };
2116
2117                        let arguments = match value.get("arguments") {
2118                            Some(arguments) => {
2119                                if arguments.is_string() {
2120                                    arguments.as_str().unwrap().to_string()
2121                                } else if arguments.is_object() {
2122                                    let map = arguments.as_object().unwrap();
2123
2124                                    #[cfg(feature = "logging")]
2125                                    info!(target: "stdout", "func arguments: {map:?}");
2126
2127                                    serde_json::to_string(map).unwrap()
2128                                } else {
2129                                    serde_json::to_string(arguments).unwrap()
2130                                }
2131                            }
2132                            None => {
2133                                let err_msg = format!(
2134                                    "Failed to get the arguments of the function. Tool call: {value:?}"
2135                                );
2136
2137                                #[cfg(feature = "logging")]
2138                                error!(target: "stdout", "{}", &err_msg);
2139
2140                                return Err(LlamaCoreError::Operation(err_msg));
2141                            }
2142                        };
2143
2144                        let function = Function { name, arguments };
2145
2146                        let tool_call = ToolCall {
2147                            id: "call_abc123".to_string(),
2148                            ty: "function".to_string(),
2149                            function,
2150                        };
2151
2152                        tool_calls.push(tool_call);
2153                    }
2154
2155                    let parsed = if tool_calls.is_empty() {
2156                        ParseResult {
2157                            raw: input.to_owned(),
2158                            content: Some(input.to_owned()),
2159                            tool_calls: vec![],
2160                        }
2161                    } else {
2162                        ParseResult {
2163                            raw: input.to_owned(),
2164                            content: None,
2165                            tool_calls,
2166                        }
2167                    };
2168
2169                    #[cfg(feature = "logging")]
2170                    info!(target: "stdout", "parsed result: {parsed:?}");
2171
2172                    Ok(parsed)
2173                }
2174                Err(e) => {
2175                    let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
2176
2177                    #[cfg(feature = "logging")]
2178                    error!(target: "stdout", "{}", &err_msg);
2179
2180                    Err(LlamaCoreError::Operation(err_msg))
2181                }
2182            }
2183        }
2184        PromptTemplateType::GptOss => {
2185            #[cfg(feature = "logging")]
2186            info!(target: "stdout", "raw input: {input:?}");
2187
2188            // Match strings ending with: <|channel|>commentary to=functions.xxxxx <|constrain|>json<|message|>yyyyy<|call|>
2189            match regex::Regex::new(
2190                r"<\|channel\|>commentary to=functions\.([^<\s]+)\s*<\|constrain\|>json<\|message\|>([^<]*)<\|call\|>$",
2191            ) {
2192                Ok(re) => {
2193                    if let Some(cap) = re.captures(input) {
2194                        let function_name = cap[1].trim();
2195                        let arguments = cap[2].trim();
2196
2197                        #[cfg(feature = "logging")]
2198                        info!(target: "stdout", "extracted function_name: {function_name}, arguments: {arguments}");
2199
2200                        let function = Function {
2201                            name: function_name.to_string(),
2202                            arguments: arguments.to_string(),
2203                        };
2204
2205                        let tool_call = ToolCall {
2206                            id: "call_abc123".to_string(),
2207                            ty: "function".to_string(),
2208                            function,
2209                        };
2210
2211                        let parsed = ParseResult {
2212                            raw: input.to_owned(),
2213                            content: None,
2214                            tool_calls: vec![tool_call],
2215                        };
2216
2217                        #[cfg(feature = "logging")]
2218                        info!(target: "stdout", "parsed result: {parsed:?}");
2219
2220                        Ok(parsed)
2221                    } else {
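                        // fall back to fenced ```json blocks carrying {"name", "arguments"} objects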
2222                        match regex::Regex::new(r"(?s)```json\s*(.*?)\s*```") {
2223                            Ok(re) => {
2224                                let mut values: Vec<serde_json::Value> = vec![];
2225                                for cap in re.captures_iter(input) {
2226                                    let mut matched = cap[1].trim();
2227
2228                                    if matched.starts_with("\\n") {
2229                                        matched = matched.trim_start_matches("\\n");
2230                                    }
2231
2232                                    if matched.ends_with("\\n") {
2233                                        matched = matched.trim_end_matches("\\n");
2234                                    }
2235
2236                                    #[cfg(feature = "logging")]
2237                                    info!(target: "stdout", "captured: {matched:#?}");
2238
2239                                    if !matched.is_empty() {
2240                                        match serde_json::from_str::<serde_json::Value>(matched) {
2241                                            Ok(value) => values.push(value),
2242                                            Err(e) => {
2243                                                let err_msg = format!(
2244                                                "Failed to deserialize generated tool calls: {matched:#?}. Reason: {e}"
2245                                            );
2246
2247                                                #[cfg(feature = "logging")]
2248                                                error!(target: "stdout", "{}", &err_msg);
2249
2250                                                return Err(LlamaCoreError::Operation(err_msg));
2251                                            }
2252                                        }
2253                                    }
2254                                }
2255
2256                                let mut tool_calls: Vec<ToolCall> = vec![];
2257                                for value in values.iter() {
2258                                    let name = match value.get("name") {
2259                                        Some(name) => name.to_string().replace("\"", ""),
2260                                        None => {
2261                                            let err_msg = format!(
2262                                                "Failed to get the name of the function. Tool call: {value:?}"
2263                                            );
2264
2265                                            #[cfg(feature = "logging")]
2266                                            error!(target: "stdout", "{}", &err_msg);
2267
2268                                            return Err(LlamaCoreError::Operation(err_msg));
2269                                        }
2270                                    };
2271
2272                                    let arguments = match value.get("arguments") {
2273                                        Some(arguments) => {
2274                                            if arguments.is_string() {
2275                                                arguments.as_str().unwrap().to_string()
2276                                            } else if arguments.is_object() {
2277                                                let map = arguments.as_object().unwrap();
2278
2279                                                #[cfg(feature = "logging")]
2280                                                info!(target: "stdout", "func arguments: {map:?}");
2281
2282                                                serde_json::to_string(map).unwrap()
2283                                            } else {
2284                                                serde_json::to_string(arguments).unwrap()
2285                                            }
2286                                        }
2287                                        None => {
2288                                            let err_msg = format!(
2289                                                "Failed to get the arguments of the function. Tool call: {value:?}"
2290                                            );
2291
2292                                            #[cfg(feature = "logging")]
2293                                            error!(target: "stdout", "{}", &err_msg);
2294
2295                                            return Err(LlamaCoreError::Operation(err_msg));
2296                                        }
2297                                    };
2298
2299                                    let function = Function { name, arguments };
2300
2301                                    let tool_call = ToolCall {
2302                                        id: "call_abc123".to_string(),
2303                                        ty: "function".to_string(),
2304                                        function,
2305                                    };
2306
2307                                    tool_calls.push(tool_call);
2308                                }
2309
2310                                let parsed = if tool_calls.is_empty() {
2311                                    ParseResult {
2312                                        raw: input.to_owned(),
2313                                        content: Some(input.to_owned()),
2314                                        tool_calls: vec![],
2315                                    }
2316                                } else {
2317                                    ParseResult {
2318                                        raw: input.to_owned(),
2319                                        content: None,
2320                                        tool_calls,
2321                                    }
2322                                };
2323
2324                                #[cfg(feature = "logging")]
2325                                info!(target: "stdout", "parsed result: {parsed:?}");
2326
2327                                Ok(parsed)
2328                            }
2329                            Err(e) => {
2330                                let err_msg =
2331                                    format!("Failed to create a regex pattern. Reason: {e}");
2332
2333                                #[cfg(feature = "logging")]
2334                                error!(target: "stdout", "{}", &err_msg);
2335
2336                                Err(LlamaCoreError::Operation(err_msg))
2337                            }
2338                        }
2339                    }
2340                }
2341                Err(e) => {
2342                    let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
2343
2344                    #[cfg(feature = "logging")]
2345                    error!(target: "stdout", "{}", &err_msg);
2346
2347                    Err(LlamaCoreError::Operation(err_msg))
2348                }
2349            }
2350        }
2351        PromptTemplateType::Qwen3Agent => {
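            // expected model output: an <action>{"name": ..., "arguments": ...}</action> tag;
            // input without an <action> tag is treated as a final answer and, if needed,
            // wrapped in <final_answer>...</final_answer>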
2352            #[cfg(feature = "logging")]
2353            info!(target: "stdout", "Raw input to tool call parser: {input:?}");
2354
2355            // detect <action>...</action> tags; the regex literal is valid, so unwrap() cannot fail (note: without the (?s) flag, the JSON payload must sit on a single line)
2356            match regex::Regex::new(r"<action>(.*?)</action>")
2357                .unwrap()
2358                .captures(input)
2359            {
2360                Some(captures) => {
2361                    let action = captures.get(1).unwrap().as_str();
2362
2363                    #[cfg(feature = "logging")]
2364                    info!(target: "stdout", "Action: {action}");
2365
2366                    match serde_json::from_str::<serde_json::Value>(action) {
2367                        Ok(value) => {
2368                            let name = match value.get("name") {
2369                                Some(name) => name.to_string().replace("\"", ""),
2370                                None => {
2371                                    let err_msg = format!(
2372                                        "Failed to get the name of the function. Tool call: {value:?}"
2373                                    );
2374
2375                                    #[cfg(feature = "logging")]
2376                                    error!(target: "stdout", "{}", &err_msg);
2377
2378                                    return Err(LlamaCoreError::Operation(err_msg));
2379                                }
2380                            };
2381
2382                            let arguments = match value.get("arguments") {
2383                                Some(arguments) => {
2384                                    if arguments.is_string() {
2385                                        arguments.as_str().unwrap().to_string()
2386                                    } else if arguments.is_object() {
2387                                        let map = arguments.as_object().unwrap();
2388
2389                                        #[cfg(feature = "logging")]
2390                                        info!(target: "stdout", "func arguments: {map:?}");
2391
2392                                        serde_json::to_string(map).unwrap()
2393                                    } else {
2394                                        serde_json::to_string(arguments).unwrap()
2395                                    }
2396                                }
2397                                None => {
2398                                    let err_msg = format!(
2399                                        "Failed to get the arguments of the function. Tool call: {value:?}"
2400                                    );
2401
2402                                    #[cfg(feature = "logging")]
2403                                    error!(target: "stdout", "{}", &err_msg);
2404
2405                                    return Err(LlamaCoreError::Operation(err_msg));
2406                                }
2407                            };
2408
2409                            let function = Function { name, arguments };
2410
2411                            let tool_call = ToolCall {
2412                                id: "call_abc123".to_string(),
2413                                ty: "function".to_string(),
2414                                function,
2415                            };
2416
2417                            Ok(ParseResult {
2418                                raw: input.to_owned(),
2419                                content: Some(input.to_owned()),
2420                                tool_calls: vec![tool_call],
2421                            })
2422                        }
2423                        Err(e) => {
2424                            let err_msg = format!(
2425                            "Failed to deserialize generated tool calls: {action:#?}. Reason: {e}"
2426                        );
2427
2428                            #[cfg(feature = "logging")]
2429                            error!(target: "stdout", "{}", &err_msg);
2430
2431                            Err(LlamaCoreError::Operation(err_msg))
2432                        }
2433                    }
2434                }
2435                None => match input.contains("<final_answer>") {
2436                    true => Ok(ParseResult {
2437                        raw: input.to_owned(),
2438                        content: Some(input.to_owned()),
2439                        tool_calls: vec![],
2440                    }),
2441                    false => {
2442                        let content = format!("<final_answer>{}</final_answer>", input.trim());
2443
2444                        Ok(ParseResult {
2445                            raw: input.to_owned(),
2446                            content: Some(content),
2447                            tool_calls: vec![],
2448                        })
2449                    }
2450                },
2451            }
2452        }
2453        PromptTemplateType::SeedOssNoThink | PromptTemplateType::SeedOssThink => {
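            // expected model output: a fenced ```json block whose body is a JSON object with
            // "name" and "arguments" keys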
2454            #[cfg(feature = "logging")]
2455            info!(target: "stdout", "Raw input to tool call parser: {input:?}");
2456
2457            match regex::Regex::new(r"```json\n([\s\S]*?)\n```") { // anchor on the closing fence so multi-line JSON bodies are captured in full
2458                Ok(re) => {
2459                    let mut values: Vec<serde_json::Value> = vec![];
2460                    for cap in re.captures_iter(input) {
2461                        let mut matched = cap[1].trim();
2462
2463                        if matched.starts_with("\\n") {
2464                            matched = matched.trim_start_matches("\\n");
2465                        }
2466
2467                        if matched.ends_with("\\n") {
2468                            matched = matched.trim_end_matches("\\n");
2469                        }
2470
2471                        #[cfg(feature = "logging")]
2472                        info!(target: "stdout", "captured: {matched:#?}");
2473
2474                        if !matched.is_empty() {
2475                            match serde_json::from_str::<serde_json::Value>(matched) {
2476                                Ok(value) => values.push(value),
2477                                Err(e) => {
2478                                    let err_msg = format!(
2479                                    "Failed to deserialize generated tool calls: {matched:#?}. Reason: {e}"
2480                                );
2481
2482                                    #[cfg(feature = "logging")]
2483                                    error!(target: "stdout", "{}", &err_msg);
2484
2485                                    return Err(LlamaCoreError::Operation(err_msg));
2486                                }
2487                            }
2488                        }
2489                    }
2490
2491                    let mut tool_calls: Vec<ToolCall> = vec![];
2492                    for value in values.iter() {
2493                        let name = match value.get("name") {
2494                            Some(name) => name.to_string().replace("\"", ""),
2495                            None => {
2496                                let err_msg = format!(
2497                                    "Failed to get the name of the function. Tool call: {value:?}"
2498                                );
2499
2500                                #[cfg(feature = "logging")]
2501                                error!(target: "stdout", "{}", &err_msg);
2502
2503                                return Err(LlamaCoreError::Operation(err_msg));
2504                            }
2505                        };
2506
2507                        let arguments = match value.get("arguments") {
2508                            Some(arguments) => {
2509                                if arguments.is_string() {
2510                                    arguments.as_str().unwrap().to_string()
2511                                } else if arguments.is_object() {
2512                                    let map = arguments.as_object().unwrap();
2513
2514                                    #[cfg(feature = "logging")]
2515                                    info!(target: "stdout", "func arguments: {map:?}");
2516
2517                                    serde_json::to_string(map).unwrap()
2518                                } else {
2519                                    serde_json::to_string(arguments).unwrap()
2520                                }
2521                            }
2522                            None => {
2523                                let err_msg = format!(
2524                                    "Failed to get the arguments of the function. Tool call: {value:?}"
2525                                );
2526
2527                                #[cfg(feature = "logging")]
2528                                error!(target: "stdout", "{}", &err_msg);
2529
2530                                return Err(LlamaCoreError::Operation(err_msg));
2531                            }
2532                        };
2533
2534                        let function = Function { name, arguments };
2535
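                        // Note: a fixed placeholder id is assigned to every
                        // parsed tool call.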
2536                        let tool_call = ToolCall {
2537                            id: "call_abc123".to_string(),
2538                            ty: "function".to_string(),
2539                            function,
2540                        };
2541
2542                        tool_calls.push(tool_call);
2543                    }
2544
2545                    let parsed = if tool_calls.is_empty() {
2546                        ParseResult {
2547                            raw: input.to_owned(),
2548                            content: Some(input.to_owned()),
2549                            tool_calls: vec![],
2550                        }
2551                    } else {
2552                        ParseResult {
2553                            raw: input.to_owned(),
2554                            content: None,
2555                            tool_calls,
2556                        }
2557                    };
2558
2559                    #[cfg(feature = "logging")]
2560                    info!(target: "stdout", "parsed result: {parsed:?}");
2561
2562                    Ok(parsed)
2563                }
2564                Err(e) => {
2565                    let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
2566
2567                    #[cfg(feature = "logging")]
2568                    error!(target: "stdout", "{}", &err_msg);
2569
2570                    Err(LlamaCoreError::Operation(err_msg))
2571                }
2572            }
2573        }
2574        _ => {
2575            let err_msg = format!(
2576                "Tool use is only supported for the following prompt templates: {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, and {}.",
2577                PromptTemplateType::MistralTool,
2578                PromptTemplateType::ChatMLTool,
2579                PromptTemplateType::GroqLlama3Tool,
2580                PromptTemplateType::Llama3Tool,
2581                PromptTemplateType::InternLM2Tool,
2582                PromptTemplateType::NemotronTool,
2583                PromptTemplateType::FunctionaryV32,
2584                PromptTemplateType::MistralSmallTool,
2585                PromptTemplateType::Llama4Chat,
2586                PromptTemplateType::Qwen3NoThink,
2587                PromptTemplateType::Smol3NoThink,
2588                PromptTemplateType::Gemma3,
2589                PromptTemplateType::GptOss,
2590                PromptTemplateType::Qwen3Agent,
2591                PromptTemplateType::SeedOssNoThink,
2592                PromptTemplateType::SeedOssThink
2593            );
2594
2595            #[cfg(feature = "logging")]
2596            error!(target: "stdout", "{}", &err_msg);
2597
2598            Err(LlamaCoreError::Operation(err_msg))
2599        }
2600    }
2601}
2602
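/// Check the chat request against the cached model metadata and apply any
/// per-request overrides (image input, sampling options, embeddings flag),
/// updating the target graph only when something actually changed.
///
/// A minimal sketch of the override behavior (not a doctest; `chat_request`
/// is assumed to be a prepared `ChatCompletionRequest`):
///
/// ```ignore
/// // Suppose the request sets a temperature that differs from the cached
/// // metadata: that field is rewritten and the graph is updated once.
/// chat_request.temperature = Some(0.2);
/// let metadata = check_model_metadata(&chat_request)?;
/// assert_eq!(metadata.temperature, 0.2);
/// ```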
2603fn check_model_metadata(
2604    chat_request: &ChatCompletionRequest,
2605) -> Result<GgmlMetadata, LlamaCoreError> {
2606    let mut should_update = false;
2607    let mut metadata = get_model_metadata(chat_request.model.as_ref())?;
2608
2609    // check if necessary to update `image`
2610    if metadata.prompt_template.is_image_supported() {
2611        if let Some(ChatCompletionRequestMessage::User(user_message)) = chat_request.messages.last()
2612        {
2613            if let ChatCompletionUserMessageContent::Parts(parts) = user_message.content() {
2614                for part in parts {
2615                    if let ContentPart::Image(image_part) = part {
2616                        let image = image_part.image();
2617
2618                        if image.is_url() {
2619                            let err_msg = "The image is provided in URL format. Only base64 format is supported.".to_string();
2620
2621                            #[cfg(feature = "logging")]
2622                            error!(target: "stdout", "{}", &err_msg);
2623
2624                            return Err(LlamaCoreError::Operation(err_msg));
2625                        } else {
2626                            #[cfg(feature = "logging")]
2627                            info!(target: "stdout", "The image is provided in base64 format.");
2628
2629                            // TODO: currently only a single image is supported
2630
2631                            break;
2632                        }
2633                    }
2634                }
2635            }
2636        }
2637    }
2638
2639    // check if necessary to update temperature
2640    if let Some(temp) = chat_request.temperature {
2641        if metadata.temperature != temp {
2642            // update temperature
2643            metadata.temperature = temp;
2644
2645            if !should_update {
2646                should_update = true;
2647            }
2648        }
2649    }
2650
2651    // check if necessary to update top_p
2652    if let Some(top_p) = chat_request.top_p {
2653        if metadata.top_p != top_p {
2654            // update top_p
2655            metadata.top_p = top_p;
2656
2657            if !should_update {
2658                should_update = true;
2659            }
2660        }
2661    }
2662
2663    // check if necessary to update frequency_penalty
2664    if let Some(frequency_penalty) = chat_request.frequency_penalty {
2665        if metadata.frequency_penalty != frequency_penalty {
2666            // update frequency_penalty
2667            metadata.frequency_penalty = frequency_penalty;
2668
2669            if !should_update {
2670                should_update = true;
2671            }
2672        }
2673    }
2674
2675    // check if necessary to update presence_penalty
2676    if let Some(presence_penalty) = chat_request.presence_penalty {
2677        if metadata.presence_penalty != presence_penalty {
2678            // update presence_penalty
2679            metadata.presence_penalty = presence_penalty;
2680
2681            if !should_update {
2682                should_update = true;
2683            }
2684        }
2685    }
2686
2687    // check if the `embedding` option is disabled
2688    if metadata.embeddings {
2689        metadata.embeddings = false;
2690
2691        if !should_update {
2692            should_update = true;
2693        }
2694    }
2695
2696    if should_update {
2697        #[cfg(feature = "logging")]
2698        info!(target: "stdout", "Update the model metadata.");
2699
2700        // update the target graph with the new metadata
2701        update_model_metadata(chat_request.model.as_ref(), &metadata)?;
2702    }
2703
2704    Ok(metadata)
2705}
2706
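/// Resolve the effective `n_predict` for this request. From high to low
/// priority: `chat_request.max_completion_tokens`, then the remaining context
/// budget (`available_completion_tokens`), then the preconfigured `n_predict`.
///
/// A hedged sketch of the fallback path (not a doctest):
///
/// ```ignore
/// // With no request-level cap and a preconfigured n_predict of -1,
/// // n_predict falls back to the remaining context budget:
/// let mut metadata = get_model_metadata(None)?;
/// metadata.n_predict = -1;
/// update_n_predict(&chat_request, &mut metadata, 1024)?;
/// assert_eq!(metadata.n_predict, 1024);
/// ```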
2707fn update_n_predict(
2708    chat_request: &ChatCompletionRequest,
2709    metadata: &mut GgmlMetadata,
2710    available_completion_tokens: u64,
2711) -> Result<(), LlamaCoreError> {
2712    let mut should_update = false;
2713
2714    #[cfg(feature = "logging")]
2715    info!(target: "stdout", "n_predict: {}", metadata.n_predict);
2716
2717    // From high to low priority
2718    // 1. chat_request.max_completion_tokens
2719    // 2. available_completion_tokens
2720    // 3. n_predict
2721
2722    if let Some(max_completion_tokens) = chat_request.max_completion_tokens {
2723        if metadata.n_predict != max_completion_tokens {
2724            #[cfg(feature = "logging")]
2725            info!(target: "stdout", "Update n_predict with max_completion_tokens from {} to {}", metadata.n_predict, max_completion_tokens);
2726
2727            metadata.n_predict = max_completion_tokens;
2728
2729            if !should_update {
2730                should_update = true;
2731            }
2732        }
2733    }
2734
2735    // TODO: remove this condition after [Issue #3958 on WasmEdge](https://github.com/WasmEdge/WasmEdge/issues/3958) is fixed
2736    if metadata.n_predict == -2 {
2737        #[cfg(feature = "logging")]
2738        info!(target: "stdout", "Update n_predict with available_completion_tokens from {} to {}", metadata.n_predict, available_completion_tokens);
2739
2740        // update n_predict
2741        metadata.n_predict = available_completion_tokens as i32;
2742
2743        if !should_update {
2744            should_update = true;
2745        }
2746    }
2747
2748    if metadata.n_predict == -1
2749        || (metadata.n_predict > 0 && metadata.n_predict < available_completion_tokens as i32)
2750        || (metadata.n_predict < 0 && metadata.n_predict != -2)
2751    // TODO: remove this condition after [Issue #3958 on WasmEdge](https://github.com/WasmEdge/WasmEdge/issues/3958) is fixed
2752    {
2753        #[cfg(feature = "logging")]
2754        info!(target: "stdout", "Update n_predict with available_completion_tokens from {} to {}", metadata.n_predict, available_completion_tokens);
2755
2756        // update n_predict
2757        metadata.n_predict = available_completion_tokens as i32;
2758
2759        if !should_update {
2760            should_update = true;
2761        }
2762    }
2763
2764    if should_update {
2765        #[cfg(feature = "logging")]
2766        info!(target: "stdout", "Update the model metadata.");
2767
2768        // update the target graph with the new metadata
2769        update_model_metadata(chat_request.model.as_ref(), metadata)?;
2770    }
2771
2772    Ok(())
2773}
2774
2775/// Post-process the generated output according to the prompt template type.
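///
/// # Example
///
/// A minimal sketch of the stop-token stripping (not a doctest; the variants
/// come from the `chat_prompts` crate used by this module):
///
/// ```ignore
/// // Llama 3 style output ends with `<|eot_id|>`, which is stripped:
/// let s = post_process("Hello!<|eot_id|>", &PromptTemplateType::Llama3Chat)?;
/// assert_eq!(s, "Hello!");
/// ```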
2776fn post_process(
2777    output: impl AsRef<str>,
2778    template_ty: &PromptTemplateType,
2779) -> Result<String, String> {
2780    let output = if *template_ty == PromptTemplateType::Baichuan2 {
2781        if output.as_ref().contains("用户:") {
2782            output.as_ref().trim_end_matches("用户:").trim().to_owned()
2783        } else {
2784            output.as_ref().trim().to_owned()
2785        }
2786    } else if *template_ty == PromptTemplateType::OpenChat {
2787        if output.as_ref().contains("<|end_of_turn|>") {
2788            output
2789                .as_ref()
2790                .trim_end_matches("<|end_of_turn|>")
2791                .trim()
2792                .to_owned()
2793        } else {
2794            output.as_ref().trim().to_owned()
2795        }
2796    } else if *template_ty == PromptTemplateType::GemmaInstruct
2797        || *template_ty == PromptTemplateType::Gemma3
2798    {
2799        let s = output.as_ref().trim();
2800        if s.ends_with("<end_of_turn>") {
2801            s.trim_end_matches("<end_of_turn>").trim().to_owned()
2802        } else {
2803            s.to_owned()
2804        }
2805    } else if *template_ty == PromptTemplateType::ChatML
2806        || *template_ty == PromptTemplateType::ChatMLTool
2807        || *template_ty == PromptTemplateType::InternLM2Tool
2808        || *template_ty == PromptTemplateType::MiniCPMV
2809    {
2810        let mut s = output.as_ref().trim();
2811        if s.ends_with("<|endoftext|>") {
2812            s = s.trim_end_matches("<|endoftext|>").trim();
2813        }
2814
2815        if s.starts_with(":") {
2816            s = s.trim_start_matches(":").trim();
2817        }
2818
2819        // handle Qwen3 empty think tags
2820        let x = {
2821            let pat = r#"<think>
2822
2823</think>
2824"#;
2825            if s.contains(pat) {
2826                let x = s.replace(pat, "");
2827                if x.starts_with("()") {
2828                    x.trim_start_matches("()").to_owned()
2829                } else {
2830                    x.to_owned()
2831                }
2832            } else {
2833                s.to_owned()
2834            }
2835        };
2836        s = x.trim();
2837
2838        if s.contains("<|im_start|>") && s.contains("<|im_end|>") {
2839            let idx_start = s.find("<|im_start|>").unwrap();
2840            let idx_end = s.find("<|im_end|>").unwrap();
2841
2842            match idx_start <= idx_end {
2843                true => s.split("<|im_start|>").collect::<Vec<_>>()[0]
2844                    .trim()
2845                    .to_owned(),
2846                false => s.split("<|im_end|>").collect::<Vec<_>>()[0]
2847                    .trim()
2848                    .to_owned(),
2849            }
2850        } else if s.contains("<|im_start|>") {
2851            s.split("<|im_start|>").collect::<Vec<_>>()[0]
2852                .trim()
2853                .to_owned()
2854        } else if s.contains("<|im_end|>") {
2855            let output = s.trim_end_matches("<|im_end|>").trim();
2856            if output.starts_with(": ") {
2857                output.trim_start_matches(": ").to_owned()
2858            } else {
2859                output.to_owned()
2860            }
2861        } else {
2862            s.to_owned()
2863        }
2864    } else if *template_ty == PromptTemplateType::Zephyr
2865        || *template_ty == PromptTemplateType::MistralLite
2866        || *template_ty == PromptTemplateType::MistralTool
2867        || *template_ty == PromptTemplateType::MistralInstruct
2868        || *template_ty == PromptTemplateType::MistralSmallChat
2869        || *template_ty == PromptTemplateType::MistralSmallTool
2870        || *template_ty == PromptTemplateType::BreezeInstruct
2871    {
2872        if output.as_ref().contains("</s><") {
2873            output.as_ref().trim_end_matches("</s><").trim().to_owned()
2874        } else if output.as_ref().contains("</s>") {
2875            output
2876                .as_ref()
2877                .strip_suffix("</s>")
2878                .unwrap()
2879                .trim()
2880                .to_owned()
2881        } else {
2882            output.as_ref().trim().to_owned()
2883        }
2884    } else if *template_ty == PromptTemplateType::DeepseekChat {
2885        if output.as_ref().contains("<|end_of_sentence|>") {
2886            output
2887                .as_ref()
2888                .trim_end_matches("<|end_of_sentence|>")
2889                .trim()
2890                .replace("<|end_of_sentence|>", " ")
2891                .trim()
2892                .to_owned()
2893        } else {
2894            output.as_ref().trim().to_owned()
2895        }
2896    } else if *template_ty == PromptTemplateType::HumanAssistant {
2897        if output.as_ref().contains("Human:") {
2898            output.as_ref().trim_end_matches("Human:").trim().to_owned()
2899        } else {
2900            output.as_ref().trim().to_owned()
2901        }
2902    } else if *template_ty == PromptTemplateType::SolarInstruct {
2903        let s = output.as_ref().trim();
2904
2905        if s.starts_with("### Answer") {
2906            let s = s.trim_start_matches("###").trim();
2907
2908            if s.starts_with("Answer:\n") {
2909                s.replace("Answer:\n", "Answer: ")
2910            } else {
2911                s.to_owned()
2912            }
2913        } else {
2914            s.to_owned()
2915        }
2916    } else if *template_ty == PromptTemplateType::Llama2Chat
2917        || *template_ty == PromptTemplateType::NemotronTool
2918        || *template_ty == PromptTemplateType::NemotronChat
2919    {
2920        let s = output.as_ref().trim();
2921        if s.ends_with("</s>") {
2922            s.trim_end_matches("</s>").trim().to_owned()
2923        } else {
2924            s.to_owned()
2925        }
2926    } else if *template_ty == PromptTemplateType::Llama3Chat
2927        || *template_ty == PromptTemplateType::GroqLlama3Tool
2928        || *template_ty == PromptTemplateType::Llama3Tool
2929        || *template_ty == PromptTemplateType::FunctionaryV32
2930    {
2931        let s = output.as_ref().trim();
2932        if s.ends_with("<|eot_id|>") {
2933            s.trim_end_matches("<|eot_id|>").trim().to_owned()
2934        } else {
2935            s.to_owned()
2936        }
2937    } else if *template_ty == PromptTemplateType::Phi3Chat {
2938        let s = output.as_ref().trim();
2939        if s.ends_with("<|end|>") {
2940            s.trim_end_matches("<|end|>").trim().to_owned()
2941        } else {
2942            s.to_owned()
2943        }
2944    } else if *template_ty == PromptTemplateType::Phi4Chat {
2945        let mut s = output.as_ref().trim();
2946
2947        if s.starts_with("think>") {
2948            s = s.trim_start_matches("think>").trim();
2949        }
2950
2951        if s.ends_with("<|im_end|>") {
2952            s.trim_end_matches("<|im_end|>").trim().to_owned()
2953        } else if s.ends_with("<|end|>") {
2954            s.trim_end_matches("<|end|>").trim().to_owned()
2955        } else {
2956            s.to_owned()
2957        }
2958    } else if *template_ty == PromptTemplateType::FunctionaryV31 {
2959        let mut s = output.as_ref().trim();
2960        if s.ends_with("<|eot_id|>") {
2961            s = s.trim_end_matches("<|eot_id|>").trim();
2962        }
2963        if s.ends_with("<|eom_id|>") {
2964            s = s.trim_end_matches("<|eom_id|>").trim();
2965        }
2966        s.to_owned()
2967    } else if *template_ty == PromptTemplateType::MoxinChat
2968        || *template_ty == PromptTemplateType::MoxinInstruct
2969    {
2970        let s = output.as_ref().trim();
2971        if s.ends_with("</s>") {
2972            s.trim_end_matches("</s>").trim().to_owned()
2973        } else if s.ends_with("[INST]") {
2974            s.trim_end_matches("[INST]").trim().to_owned()
2975        } else {
2976            s.to_owned()
2977        }
2978    } else if *template_ty == PromptTemplateType::Falcon3 {
2979        let s = output.as_ref().trim();
2980        if s.ends_with("<|endoftext|>") {
2981            s.trim_end_matches("<|endoftext|>").trim().to_owned()
2982        } else {
2983            s.to_owned()
2984        }
2985    } else if *template_ty == PromptTemplateType::Megrez {
2986        let s = output.as_ref().trim();
2987        if s.ends_with("<|turn_end|>") {
2988            s.trim_end_matches("<|turn_end|>").trim().to_owned()
2989        } else {
2990            s.to_owned()
2991        }
2992    } else if *template_ty == PromptTemplateType::Qwen2vl
2993        || *template_ty == PromptTemplateType::Qwen3NoThink
2994        || *template_ty == PromptTemplateType::ChatMLThink
2995    {
2996        let mut s = output.as_ref().trim();
2997
2998        if s.starts_with(":") {
2999            s = s.trim_start_matches(":").trim();
3000        }
3001
3002        if s.starts_with("</think>") {
3003            s = s.trim_start_matches("</think>").trim();
3004        }
3005
3006        if s.ends_with("<|im_end|>") {
3007            s.trim_end_matches("<|im_end|>").trim().to_owned()
3008        } else {
3009            s.to_owned()
3010        }
3011    } else if *template_ty == PromptTemplateType::VicunaLlava {
3012        let s = output.as_ref().trim();
3013        if s.ends_with("</s>") {
3014            s.trim_end_matches("</s>").trim().to_owned()
3015        } else {
3016            s.to_owned()
3017        }
3018    } else if *template_ty == PromptTemplateType::ExaoneDeepChat
3019        || *template_ty == PromptTemplateType::ExaoneChat
3020    {
3021        let mut s = output.as_ref().trim();
3022
3023        if s.ends_with("[|endofturn|]") {
3024            s = s.trim_end_matches("[|endofturn|]").trim();
3025        }
3026
3027        s.to_owned()
3028    } else if *template_ty == PromptTemplateType::Llama4Chat {
3029        let mut s = output.as_ref().trim();
3030
3031        if s.ends_with("<|eot|>") {
3032            s = s.trim_end_matches("<|eot|>").trim();
3033        }
3034
3035        s.to_owned()
3036    } else if *template_ty == PromptTemplateType::Smolvl {
3037        let mut s = output.as_ref().trim();
3038
3039        if s.starts_with(":") {
3040            s = s.trim_start_matches(":").trim();
3041        }
3042
3043        if s.ends_with("<end_of_utterance>") {
3044            s = s.trim_end_matches("<end_of_utterance>").trim();
3045        }
3046
3047        if s.contains("<end_of_utterance>:") {
3048            let parts = s.split("<end_of_utterance>:").collect::<Vec<_>>();
3049            parts.last().unwrap().trim().to_owned()
3050        } else {
3051            s.to_owned()
3052        }
3053    } else if *template_ty == PromptTemplateType::Smol3NoThink {
3054        let mut s = output.as_ref().trim();
3055
3056        if s.ends_with("<|im_end|>") {
3057            s = s.trim_end_matches("<|im_end|>").trim();
3058        }
3059
3060        let re = regex::Regex::new(r"(?s)^<think>.*?</think>\s*").unwrap();
3061        re.replace(s, "").to_string()
3062    } else if *template_ty == PromptTemplateType::GptOss {
3063        let s = output.as_ref().trim();
3064
3065        let re =
3066            regex::Regex::new(r"(?s).*<\|channel\|>final<\|message\|>(.*?)<\|return\|>$").unwrap();
3067
3068        if let Some(caps) = re.captures(s) {
3069            let extracted = &caps[1];
3070            extracted.to_owned()
3071        } else {
3072            s.to_owned()
3073        }
3074    } else if *template_ty == PromptTemplateType::Qwen3Agent {
3075        let mut s = output.as_ref().trim();
3076
3077        if s.starts_with(":") {
3078            s = s.trim_start_matches(":").trim();
3079        }
3080
3081        if s.starts_with("</think>") {
3082            s = s.trim_start_matches("</think>").trim();
3083        }
3084
3085        if s.ends_with("<|im_end|>") {
3086            s = s.trim_end_matches("<|im_end|>").trim();
3087        }
3088
3089        if s.contains("<final_answer>") && !s.contains("</final_answer>") {
3090            format!("{s}</final_answer>")
3091        } else {
3092            s.to_owned()
3093        }
3094    } else if *template_ty == PromptTemplateType::SeedOssNoThink {
3095        let s = output.as_ref().trim();
3096
3097        let re = regex::Regex::new(r"(?s)</seed:think>\s*(.*?)\s*<seed:eos>").unwrap();
3098
3099        if let Some(caps) = re.captures(s) {
3100            let extracted = &caps[1];
3101            extracted.to_owned()
3102        } else {
3103            s.to_owned()
3104        }
3105    } else {
3106        output.as_ref().trim().to_owned()
3107    };
3108
3109    Ok(output)
3110}
3111
3112/// Build the chat prompt from the chat messages.
3113///
3114/// # Arguments
3115///
3116/// * `model_name`: The name of the model.
3117///
3118/// * `chat_request`: The chat request.
3119///
3120/// # Returns
3121///
3122/// A tuple containing the prompt, the number of available tokens for completions, and a boolean indicating whether tools are used.
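///
/// # Prompt-size handling
///
/// The prompt is rebuilt in a loop: whenever the measured prompt tokens exceed
/// the prompt budget (80% of the context size), the oldest non-system messages
/// are pruned and the prompt is rebuilt. A sketch of the budget arithmetic:
///
/// ```ignore
/// let ctx_size: u64 = 8192;
/// let max_prompt_tokens = ctx_size * 4 / 5; // integer division
/// assert_eq!(max_prompt_tokens, 6553);
/// ```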
3123fn build_prompt(
3124    model_name: Option<&String>,
3125    chat_request: &mut ChatCompletionRequest,
3126) -> Result<(String, u64, bool), LlamaCoreError> {
3127    let metadata = get_model_metadata(model_name)?;
3128    let ctx_size = metadata.ctx_size as u64;
3129    let chat_prompt = ChatPrompt::from(metadata.prompt_template);
3130
3131    // compute max prompt tokens, which is 80% of the context size
3132    let max_prompt_tokens = ctx_size * 4 / 5;
3133
3134    loop {
3135        // ! DO NOT REMOVE
3136        {
3137            // // build prompt
3138            // let prompt = match chat_prompt.build(&mut chat_request.messages) {
3139            //     Ok(prompt) => prompt,
3140            //     Err(e) => {
3141            //         let err_msg = format!("Fail to build chat prompts. Reason: {}", e);
3142
3143            //         #[cfg(feature = "logging")]
3144            //         error!(target: "stdout", "{}", &err_msg);
3145
3146            //         return Err(LlamaCoreError::Operation(err_msg));
3147            //     }
3148            // };
3149        }
3150
3151        if chat_request.messages.is_empty() {
3152            let err_msg = "The messages in the chat request are empty.";
3153
3154            #[cfg(feature = "logging")]
3155            error!(target: "stdout", "{err_msg}");
3156
3157            return Err(LlamaCoreError::Operation(err_msg.to_owned()));
3158        }
3159
3160        #[cfg(feature = "logging")]
3161        {
3162            let mut role_chain = String::new();
3163            for (idx, message) in chat_request.messages.iter().enumerate() {
3164                if idx == chat_request.messages.len() - 1 {
3165                    role_chain.push_str(&format!("{}", message.role()));
3166                } else {
3167                    role_chain.push_str(&format!("{} -> ", message.role()));
3168                }
3169            }
3170            info!(target: "stdout", "Role chain: {role_chain}");
3171        }
3172
3173        let (prompt, tool_use) = match chat_request.tool_choice.as_ref() {
3174            Some(tool_choice) => match tool_choice {
3175                ToolChoice::None => {
3176                    match chat_prompt.build_with_tools(&mut chat_request.messages, Some(&[])) {
3177                        Ok(prompt) => (prompt, false),
3178                        Err(e) => {
3179                            let err_msg = format!("Fail to build chat prompts. Reason: {e}");
3180
3181                            #[cfg(feature = "logging")]
3182                            error!(target: "stdout", "{}", &err_msg);
3183
3184                            return Err(LlamaCoreError::Operation(err_msg));
3185                        }
3186                    }
3187                }
3188                _ => match chat_request.tools.as_ref() {
3189                    Some(tools) => match chat_prompt
3190                        .build_with_tools(&mut chat_request.messages, Some(tools.as_slice()))
3191                    {
3192                        Ok(prompt) => (prompt, true),
3193                        Err(e) => {
3194                            let err_msg = format!("Fail to build chat prompts. Reason: {e}");
3195
3196                            #[cfg(feature = "logging")]
3197                            error!(target: "stdout", "{}", &err_msg);
3198
3199                            return Err(LlamaCoreError::Operation(err_msg));
3200                        }
3201                    },
3202                    None => {
3203                        #[cfg(feature = "logging")]
3204                        warn!(target: "stdout", "A tool choice was specified but no tools were provided; building the prompt without tools.");
3205
3206                        match chat_prompt.build_with_tools(&mut chat_request.messages, None) {
3207                            Ok(prompt) => (prompt, false),
3208                            Err(e) => {
3209                                let err_msg = format!("Fail to build chat prompts. Reason: {e}");
3210
3211                                #[cfg(feature = "logging")]
3212                                error!(target: "stdout", "{}", &err_msg);
3213
3214                                return Err(LlamaCoreError::Operation(err_msg));
3215                            }
3216                        }
3217                    }
3218                },
3219            },
3220            None => match chat_prompt.build_with_tools(&mut chat_request.messages, None) {
3221                Ok(prompt) => (prompt, false),
3222                Err(e) => {
3223                    let err_msg = format!("Fail to build chat prompts. Reason: {e}");
3224
3225                    #[cfg(feature = "logging")]
3226                    error!(target: "stdout", "{}", &err_msg);
3227
3228                    return Err(LlamaCoreError::Operation(err_msg));
3229                }
3230            },
3231        };
3232        #[cfg(feature = "logging")]
3233        info!(target: "stdout", "Try to set prompt: {prompt}");
3234
3235        // set prompt
3236        set_prompt(model_name, &prompt)?;
3237
3238        // Retrieve the number of prompt tokens.
3239        let token_info = get_token_info_by_graph_name(model_name)?;
3240
3241        match token_info.prompt_tokens > max_prompt_tokens {
3242            true => {
3243                match chat_request.messages[0].role() {
3244                    ChatCompletionRole::System => {
3245                        // corner case: context size is too small, `system -> user -> assistant -> tool` cannot be trimmed.
3246                        if chat_request.messages.len() == 4
3247                            && chat_request.messages[1].role() == ChatCompletionRole::User
3248                            && chat_request.messages[2].role() == ChatCompletionRole::Assistant
3249                            && chat_request.messages[3].role() == ChatCompletionRole::Tool
3250                        {
3251                            let err_msg = format!(
3252                                "The number of prompt tokens ({}) is greater than the max prompt tokens ({}). Please increase the context size.",
3253                                token_info.prompt_tokens, max_prompt_tokens
3254                            );
3255
3256                            #[cfg(feature = "logging")]
3257                            error!(target: "stdout", "{}", &err_msg);
3258
3259                            return Err(LlamaCoreError::Operation(err_msg));
3260                        }
3261
3262                        if chat_request.messages.len() > 2 {
3263                            #[cfg(feature = "logging")]
3264                            info!(target: "stdout", "Prune chat history: current length {}", chat_request.messages.len());
3265
3266                            // remove user_1 if it exists
3267                            // For example, `system -> user_1 -> ... -> user_2 -> ... -> user_latest` will be converted to `system -> ... -> user_2 -> ... -> user_latest`
3268                            if chat_request.messages[1].role() == ChatCompletionRole::User {
3269                                let user_message = chat_request.messages.remove(1);
3270
3271                                #[cfg(feature = "logging")]
3272                                info!(target: "stdout", "Remove a user message from the chat history: {user_message:?}");
3273                            }
3274
3275                            // remove all messages until the message is of `user`
3276                            // For example, `system -> ... -> user_2 -> ... -> user_latest` will be converted to `system -> user_2 -> ... -> user_latest`
3277                            while chat_request.messages[1].role() != ChatCompletionRole::User {
3278                                let message = chat_request.messages.remove(1);
3279
3280                                #[cfg(feature = "logging")]
3281                                info!(target: "stdout", "Remove a {} message from the chat history: {:?}", message.role(), message);
3282
3283                                if chat_request.messages.len() == 1 {
3284                                    let err_msg = format!("No user message remains in the chat history after pruning; the last removed message was a {} message.", message.role());
3285
3286                                    #[cfg(feature = "logging")]
3287                                    error!(target: "stdout", "{err_msg}");
3288
3289                                    return Err(LlamaCoreError::Operation(err_msg));
3290                                }
3291                            }
3292                        } else if token_info.prompt_tokens > ctx_size {
3293                            let err_msg = format!(
3294                                "The number of prompt tokens ({}) is greater than the context size ({}). Please increase the context size, or simplify the input message.",
3295                                token_info.prompt_tokens, ctx_size
3296                            );
3297
3298                            #[cfg(feature = "logging")]
3299                            error!(target: "stdout", "{}", &err_msg);
3300
3301                            return Err(LlamaCoreError::Operation(err_msg));
3302                        } else {
3303                            return Ok((prompt, ctx_size - token_info.prompt_tokens, tool_use));
3304                        }
3305                    }
3306                    ChatCompletionRole::User => {
3307                        // corner case: context size is too small, `user -> assistant -> tool` cannot be trimmed.
3308                        if chat_request.messages.len() == 3
3309                            && chat_request.messages[0].role() == ChatCompletionRole::User
3310                            && chat_request.messages[1].role() == ChatCompletionRole::Assistant
3311                            && chat_request.messages[2].role() == ChatCompletionRole::Tool
3312                        {
3313                            let err_msg = format!(
3314                                "The number of prompt tokens ({}) is greater than the max prompt tokens ({}). Please increase the context size.",
3315                                token_info.prompt_tokens, max_prompt_tokens
3316                            );
3317
3318                            #[cfg(feature = "logging")]
3319                            error!(target: "stdout", "{}", &err_msg);
3320
3321                            return Err(LlamaCoreError::Operation(err_msg));
3322                        }
3323
3324                        if chat_request.messages.len() > 1 {
3325                            // user_1 -> ... -> user_2 -> ... -> user_latest
3326
3327                            // remove user_1 if it exists
3328                            // For example, `user_1 -> ... -> user_2 -> ... -> user_latest` will be converted to `... -> user_2 -> ... -> user_latest`
3329                            if chat_request.messages[0].role() == ChatCompletionRole::User {
3330                                let user_message = chat_request.messages.remove(0);
3331
3332                                #[cfg(feature = "logging")]
3333                                info!(target: "stdout", "Remove a user message from the chat history: {user_message:?}");
3334                            }
3335
3336                            // remove all messages until the message is of `user`
3337                            // For example, `... -> user_2 -> ... -> user_latest` will be converted to `user_2 -> ... -> user_latest`
3338                            while chat_request.messages[0].role() != ChatCompletionRole::User {
3339                                let message = chat_request.messages.remove(0);
3340
3341                                #[cfg(feature = "logging")]
3342                                info!(target: "stdout", "Remove a {} message from the chat history: {:?}", message.role(), message);
3343
3344                                if chat_request.messages.is_empty() {
3345                                    let err_msg = format!("No user message remains in the chat history after pruning; the last removed message was a {} message.", message.role());
3346
3347                                    #[cfg(feature = "logging")]
3348                                    error!(target: "stdout", "{err_msg}");
3349
3350                                    return Err(LlamaCoreError::Operation(err_msg));
3351                                }
3352                            }
3353                        } else if token_info.prompt_tokens > ctx_size {
3354                            let err_msg = format!(
3355                                "The number of prompt tokens ({}) is greater than the context size ({}). Please increase the context size, or simplify the input message.",
3356                                token_info.prompt_tokens, ctx_size
3357                            );
3358
3359                            #[cfg(feature = "logging")]
3360                            error!(target: "stdout", "{}", &err_msg);
3361
3362                            return Err(LlamaCoreError::Operation(err_msg));
3363                        } else {
3364                            return Ok((prompt, ctx_size - token_info.prompt_tokens, tool_use));
3365                        }
3366                    }
3367                    _ => {
3368                        #[cfg(feature = "logging")]
3369                        info!(target: "stdout", "remove a {} message from the message queue", chat_request.messages[0].role());
3370
3371                        chat_request.messages.remove(0);
3372                    }
3373                }
3374
3375                continue;
3376            }
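            // Within budget: report the remaining completion window
            // conservatively, measured against the prompt budget rather than
            // the actual prompt length.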
3377            false => return Ok((prompt, ctx_size - max_prompt_tokens, tool_use)),
3378        }
3379    }
3380}
3381
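/// Write the UTF-8 bytes of `prompt` into input tensor 0 of the target graph,
/// falling back to the first registered graph when `model_name` is `None` or
/// unknown.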
3382fn set_prompt(model_name: Option<&String>, prompt: impl AsRef<str>) -> Result<(), LlamaCoreError> {
3383    let chat_graphs = match CHAT_GRAPHS.get() {
3384        Some(chat_graphs) => chat_graphs,
3385        None => {
3386            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3387
3388            #[cfg(feature = "logging")]
3389            error!(target: "stdout", "{}", &err_msg);
3390
3391            return Err(LlamaCoreError::Operation(err_msg.into()));
3392        }
3393    };
3394
3395    let mut chat_graphs = chat_graphs.lock().map_err(|e| {
3396        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3397
3398        #[cfg(feature = "logging")]
3399        error!(target: "stdout", "{}", &err_msg);
3400
3401        LlamaCoreError::Operation(err_msg)
3402    })?;
3403
3404    match model_name {
3405        Some(model_name) => {
3406            #[cfg(feature = "logging")]
3407            info!(target: "stdout", "Set prompt to the chat model named {model_name}");
3408
3409            match chat_graphs.contains_key(model_name) {
3410                true => {
3411                    let graph = chat_graphs.get_mut(model_name).unwrap();
3412                    let tensor_data = prompt.as_ref().as_bytes().to_vec();
3413                    set_tensor_data_u8(graph, 0, &tensor_data)
3414                }
3415                false => match chat_graphs.iter_mut().next() {
3416                    Some((_, graph)) => {
3417                        let tensor_data = prompt.as_ref().as_bytes().to_vec();
3418                        set_tensor_data_u8(graph, 0, &tensor_data)
3419                    }
3420                    None => {
3421                        let err_msg = "There is no model available in the chat graphs.";
3422
3423                        #[cfg(feature = "logging")]
3424                        error!(target: "stdout", "{}", &err_msg);
3425
3426                        Err(LlamaCoreError::Operation(err_msg.into()))
3427                    }
3428                },
3429            }
3430        }
3431        None => {
3432            #[cfg(feature = "logging")]
3433            info!(target: "stdout", "Set prompt to the default chat model.");
3434
3435            match chat_graphs.iter_mut().next() {
3436                Some((_, graph)) => {
3437                    let tensor_data = prompt.as_ref().as_bytes().to_vec();
3438                    set_tensor_data_u8(graph, 0, &tensor_data)
3439                }
3440                None => {
3441                    let err_msg = "There is no model available in the chat graphs while trying to set prompt to the default model.";
3442
3443                    #[cfg(feature = "logging")]
3444                    error!(target: "stdout", "{err_msg}");
3445
3446                    Err(LlamaCoreError::Operation(err_msg.into()))
3447                }
3448            }
3449        }
3450    }
3451}
3452
3453/// Get a copy of the metadata of the model.
3454fn get_model_metadata(model_name: Option<&String>) -> Result<GgmlMetadata, LlamaCoreError> {
3455    let chat_graphs = match CHAT_GRAPHS.get() {
3456        Some(chat_graphs) => chat_graphs,
3457        None => {
3458            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3459
3460            #[cfg(feature = "logging")]
3461            error!(target: "stdout", "{err_msg}");
3462
3463            return Err(LlamaCoreError::Operation(err_msg.into()));
3464        }
3465    };
3466
3467    let chat_graphs = chat_graphs.lock().map_err(|e| {
3468        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3469
3470        #[cfg(feature = "logging")]
3471        error!(target: "stdout", "{}", &err_msg);
3472
3473        LlamaCoreError::Operation(err_msg)
3474    })?;
3475
3476    match model_name {
3477        Some(model_name) => match chat_graphs.contains_key(model_name) {
3478            true => {
3479                let graph = chat_graphs.get(model_name).unwrap();
3480                Ok(graph.metadata.clone())
3481            }
3482            false => match chat_graphs.iter().next() {
3483                Some((_, graph)) => Ok(graph.metadata.clone()),
3484                None => {
3485                    let err_msg = "There is no model available in the chat graphs.";
3486
3487                    #[cfg(feature = "logging")]
3488                    error!(target: "stdout", "{}", &err_msg);
3489
3490                    Err(LlamaCoreError::Operation(err_msg.into()))
3491                }
3492            },
3493        },
3494        None => match chat_graphs.iter().next() {
3495            Some((_, graph)) => Ok(graph.metadata.clone()),
3496            None => {
3497                let err_msg = "There is no model available in the chat graphs.";
3498
3499                #[cfg(feature = "logging")]
3500                error!(target: "stdout", "{err_msg}");
3501
3502                Err(LlamaCoreError::Operation(err_msg.into()))
3503            }
3504        },
3505    }
3506}
3507
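/// Serialize `metadata` to JSON and write it into tensor slot 1 of the target
/// graph (slot 0 carries the prompt; see `set_prompt`).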
3508fn update_model_metadata(
3509    model_name: Option<&String>,
3510    metadata: &GgmlMetadata,
3511) -> Result<(), LlamaCoreError> {
3512    let config = match serde_json::to_string(metadata) {
3513        Ok(config) => config,
3514        Err(e) => {
3515            let err_msg = format!("Fail to serialize metadata to a JSON string. {e}");
3516
3517            #[cfg(feature = "logging")]
3518            error!(target: "stdout", "{}", &err_msg);
3519
3520            return Err(LlamaCoreError::Operation(err_msg));
3521        }
3522    };
3523
3524    let chat_graphs = match CHAT_GRAPHS.get() {
3525        Some(chat_graphs) => chat_graphs,
3526        None => {
3527            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3528
3529            #[cfg(feature = "logging")]
3530            error!(target: "stdout", "{err_msg}");
3531
3532            return Err(LlamaCoreError::Operation(err_msg.into()));
3533        }
3534    };
3535
3536    let mut chat_graphs = chat_graphs.lock().map_err(|e| {
3537        let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. Reason: {e}");
3538
3539        #[cfg(feature = "logging")]
3540        error!(target: "stdout", "{}", &err_msg);
3541
3542        LlamaCoreError::Operation(err_msg)
3543    })?;
3544
3545    match model_name {
3546        Some(model_name) => {
3547            match chat_graphs.contains_key(model_name) {
3548                true => {
3549                    let graph = chat_graphs.get_mut(model_name).unwrap();
3550                    // update metadata
3551                    set_tensor_data_u8(graph, 1, config.as_bytes())
3552                }
3553                false => match chat_graphs.iter_mut().next() {
3554                    Some((_, graph)) => {
3555                        // update metadata
3556                        set_tensor_data_u8(graph, 1, config.as_bytes())
3557                    }
3558                    None => {
3559                        let err_msg = "There is no model available in the chat graphs.";
3560
3561                        #[cfg(feature = "logging")]
3562                        error!(target: "stdout", "{}", &err_msg);
3563
3564                        Err(LlamaCoreError::Operation(err_msg.into()))
3565                    }
3566                },
3567            }
3568        }
3569        None => {
3570            match chat_graphs.iter_mut().next() {
3571                Some((_, graph)) => {
3572                    // update metadata
3573                    set_tensor_data_u8(graph, 1, config.as_bytes())
3574                }
3575                None => {
3576                    let err_msg = "There is no model available in the chat graphs.";
3577
3578                    #[cfg(feature = "logging")]
3579                    error!(target: "stdout", "{err_msg}");
3580
3581                    Err(LlamaCoreError::Operation(err_msg.into()))
3582                }
3583            }
3584        }
3585    }
3586}
3587
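/// Re-apply the metadata cached on the graph, undoing any per-request
/// overrides made by `check_model_metadata` and `update_n_predict`.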
3588fn reset_model_metadata(model_name: Option<&String>) -> Result<(), LlamaCoreError> {
3589    // get metadata
3590    let metadata = get_model_metadata(model_name)?;
3591
3592    // update model with the original metadata
3593    update_model_metadata(model_name, &metadata)
3594}
3595
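// State machines for the streamed response: each enum tracks how far a stream
// has progressed through its terminal chunk sequence (message/usage chunks,
// the done marker, then end-of-sequence).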
3596#[derive(Debug, Clone, Copy, PartialEq, Eq)]
3597enum ContextFullState {
3598    Message,
3599    Usage,
3600    Done,
3601    EndOfSequence,
3602}
3603
3604#[derive(Debug, Clone, Copy, PartialEq, Eq)]
3605enum StreamState {
3606    Usage,
3607    NoUsage,
3608    Done,
3609    EndOfSequence,
3610}
3611
3612#[derive(Debug, Clone, Copy, PartialEq, Eq)]
3613enum PromptTooLongState {
3614    Message,
3615    Usage,
3616    Done,
3617    EndOfSequence,
3618}
3619
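/// A single-flight stream over the model output.
///
/// Only one `ChatStream` may drive the backend at a time: `CHAT_STREAM_ACTIVE`
/// serves as a global lock, and streams that fail to acquire it park their
/// `Waker` in `CHAT_STREAM_WAKER_QUEUE` until the active stream is dropped.
///
/// A hedged sketch of consuming the stream (not a doctest; `stream` is assumed
/// to be the stream returned by `chat_stream`):
///
/// ```ignore
/// use futures::StreamExt;
///
/// while let Some(chunk) = stream.next().await {
///     // Each item is one serialized response chunk or an error.
///     print!("{}", chunk?);
/// }
/// ```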
3620struct ChatStream {
3621    id: String,
3622    model: Option<String>,
3623    include_usage: bool,
3624    context_full_state: ContextFullState,
3625    prompt_too_long_state: PromptTooLongState,
3626    stream_state: StreamState,
3627    cache: Option<VecDeque<String>>,
3628    is_waiting: bool,
3629    has_lock: bool,
3630}
3631impl ChatStream {
3632    fn new(
3633        model: Option<String>,
3634        id: String,
3635        include_usage: bool,
3636        cache: Option<Vec<String>>,
3637    ) -> Self {
3638        // Try to acquire lock
3639        let has_lock = CHAT_STREAM_ACTIVE
3640            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
3641            .is_ok();
3642
3643        #[cfg(feature = "logging")]
3644        if !has_lock {
3645            info!(target: "stdout", "Lock acquisition failed in ChatStream::new, creating with waiting status");
3646        }
3647
3648        ChatStream {
3649            id,
3650            model,
3651            include_usage,
3652            context_full_state: ContextFullState::Message,
3653            prompt_too_long_state: PromptTooLongState::Message,
3654            stream_state: if include_usage {
3655                StreamState::Usage
3656            } else {
3657                StreamState::NoUsage
3658            },
3659            cache: cache.map(VecDeque::from),
3660            is_waiting: !has_lock,
3661            has_lock,
3662        }
3663    }
3664
3665    // Try to acquire lock, returns whether successful
3666    fn try_acquire_lock(&mut self) -> bool {
3667        if self.has_lock {
3668            return true;
3669        }
3670
3671        let acquired = CHAT_STREAM_ACTIVE
3672            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
3673            .is_ok();
3674
3675        if acquired {
3676            self.has_lock = true;
3677            self.is_waiting = false;
3678        }
3679
3680        acquired
3681    }
3682}
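// Dropping a `ChatStream` finalizes the backend context via `finish_single`,
// restores the cached model metadata, and, when this stream held the global
// lock, releases it and wakes the next parked stream.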
3683impl Drop for ChatStream {
3684    fn drop(&mut self) {
3685        // Cleanup is only needed if this stream holds the lock or was actually used
3686        if self.has_lock || (self.cache.is_none() && !self.is_waiting) {
3687            #[cfg(feature = "logging")]
3688            info!(target: "stdout", "Cleaning up context for ChatStream {}", &self.id);
3689
3690            match &self.model {
3691                Some(model_name) => {
3692                    match CHAT_GRAPHS.get() {
3693                        Some(chat_graphs) => {
3694                            match chat_graphs.lock() {
3695                                Ok(mut chat_graphs) => match chat_graphs.contains_key(model_name) {
3696                                    true => {
3697                                        let graph = chat_graphs.get_mut(model_name).unwrap();
3698
3699                                        // clean up the context
3700                                        if let Err(e) = graph.finish_single() {
3701                                            let err_msg = format!(
3702                                                "Failed to clean up the context. Reason: {e}"
3703                                            );
3704
3705                                            #[cfg(feature = "logging")]
3706                                            error!(target: "stdout", "{}", &err_msg);
3707
3708                                            #[cfg(not(feature = "logging"))]
3709                                            println!(
3710                                                "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3711                                                &err_msg
3712                                            );
3713                                        }
3714                                    }
3715                                    false => match chat_graphs.iter_mut().next() {
3716                                        Some((_, graph)) => {
3717                                            // clean up the context
3718                                            if let Err(e) = graph.finish_single() {
3719                                                let err_msg = format!(
3720                                                    "Failed to clean up the context. Reason: {e}"
3721                                                );
3722
3723                                                #[cfg(feature = "logging")]
3724                                                error!(target: "stdout", "{}", &err_msg);
3725
3726                                                #[cfg(not(feature = "logging"))]
3727                                                println!(
3728                                                    "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3729                                                    &err_msg
3730                                                );
3731                                            }
3732                                        }
3733                                        None => {
3734                                            let err_msg =
3735                                                "There is no model available in the chat graphs.";
3736
3737                                            #[cfg(feature = "logging")]
3738                                            error!(target: "stdout", "{}", &err_msg);
3739
3740                                            #[cfg(not(feature = "logging"))]
3741                                            println!(
3742                                                "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3743                                                &err_msg
3744                                            );
3745                                        }
3746                                    },
3747                                },
3748                                Err(e) => {
3749                                    let err_msg =
3750                                        format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3751
3752                                    #[cfg(feature = "logging")]
3753                                    error!(target: "stdout", "{}", &err_msg);
3754
3755                                    #[cfg(not(feature = "logging"))]
3756                                    println!(
3757                                        "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3758                                        &err_msg
3759                                    );
3760                                }
3761                            }
3762                        }
3763                        None => {
3764                            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3765
3766                            #[cfg(feature = "logging")]
3767                            error!(target: "stdout", "{}", &err_msg);
3768
3769                            #[cfg(not(feature = "logging"))]
3770                            println!(
3771                                "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3772                                &err_msg
3773                            );
3774                        }
3775                    };
3776                }
3777                None => {
3778                    match CHAT_GRAPHS.get() {
3779                        Some(chat_graphs) => {
3780                            match chat_graphs.lock() {
3781                                Ok(mut chat_graphs) => match chat_graphs.iter_mut().next() {
3782                                    Some((_, graph)) => {
3783                                        // clean up the context
3784                                        if let Err(e) = graph.finish_single() {
3785                                            let err_msg = format!(
3786                                                "Failed to clean up the context. Reason: {e}"
3787                                            );
3788
3789                                            #[cfg(feature = "logging")]
3790                                            error!(target: "stdout", "{}", &err_msg);
3791
3792                                            #[cfg(not(feature = "logging"))]
3793                                            println!(
3794                                                "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3795                                                &err_msg
3796                                            );
3797                                        }
3798                                    }
3799                                    None => {
3800                                        let err_msg =
3801                                            "There is no model available in the chat graphs.";
3802
3803                                        #[cfg(feature = "logging")]
3804                                        error!(target: "stdout", "{err_msg}");
3805
3806                                        #[cfg(not(feature = "logging"))]
3807                                        println!(
3808                                            "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3809                                            err_msg
3810                                        );
3811                                    }
3812                                },
3813                                Err(e) => {
3814                                    let err_msg =
3815                                        format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3816
3817                                    #[cfg(feature = "logging")]
3818                                    error!(target: "stdout", "{}", &err_msg);
3819
3820                                    #[cfg(not(feature = "logging"))]
3821                                    println!(
3822                                        "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3823                                        &err_msg
3824                                    );
3825                                }
3826                            }
3827                        }
3828                        None => {
3829                            let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3830
3831                            #[cfg(feature = "logging")]
3832                            error!(target: "stdout", "{}", &err_msg);
3833
3834                            #[cfg(not(feature = "logging"))]
3835                            println!(
3836                                "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3837                                &err_msg
3838                            );
3839                        }
3840                    };
3841                }
3842            }
3843
3844            #[cfg(feature = "logging")]
3845            info!(target: "stdout", "Model context cleanup done!");
3846        }
3847
3848        // reset the model metadata
3849        if let Err(e) = reset_model_metadata(self.model.as_ref()) {
3850            let err_msg = format!("Failed to reset model metadata. Reason: {e}");
3851
3852            #[cfg(feature = "logging")]
3853            error!(target: "stdout", "{}", &err_msg);
3854
3855            #[cfg(not(feature = "logging"))]
3856            println!("[ERROR][llama_core] {}", &err_msg);
3857        }
3858        #[cfg(feature = "logging")]
3859        info!(target: "stdout", "Model metadata reset done!");
3860
3861        // When dropping a ChatStream that held the lock, check if there are waiting streams
3862        if self.has_lock {
3863            // Reset the atomic flag
3864            CHAT_STREAM_ACTIVE.store(false, Ordering::SeqCst);
3865
3866            #[cfg(feature = "logging")]
3867            info!(target: "stdout", "Lock from ChatStream {} released", &self.id);
3868
3869            // Wake up waiting streams
3870            if let Ok(mut queue) = get_chat_stream_waker_queue().lock() {
3871                if let Some(waker) = queue.pop_front() {
3872                    #[cfg(feature = "logging")]
3873                    info!(target: "stdout", "Waking up a waiting ChatStream");
3874
3875                    waker.wake();
3876                }
3877            }
3878        }
3879    }
3880}
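
// The `Stream` implementation below cooperates with the `Drop` handler above:
// at most one `ChatStream` drives the model at a time (tracked by the
// `CHAT_STREAM_ACTIVE` flag), and a stream that cannot take the flag parks its
// `Waker` in `CHAT_STREAM_WAKER_QUEUE` until the active stream is dropped.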
3881impl futures::Stream for ChatStream {
3882    type Item = Result<String, LlamaCoreError>;
3883
3884    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
3885        let this = self.get_mut();
3886
3887        // If this is a waiting stream, try to acquire the lock
3888        if this.is_waiting {
3889            if !this.try_acquire_lock() {
3890                // Store the waker to be notified when the lock becomes available
3891                if let Ok(mut queue) = get_chat_stream_waker_queue().lock() {
3892                    // Remove any previous instance of this waker
3893                    queue.retain(|w| !w.will_wake(cx.waker()));
3894                    // Add this waker to the queue
3895                    queue.push_back(cx.waker().clone());
3896
3897                    #[cfg(feature = "logging")]
3898                    debug!(target: "stdout", "ChatStream {} is waiting for lock, added waker to queue", &this.id);
3899                }
3900
3901                return Poll::Pending;
3902            }
3903
3904            #[cfg(feature = "logging")]
3905            info!(target: "stdout", "ChatStream {} acquired lock and is now active", &this.id);
3906            // If we got here, we successfully acquired the lock and can proceed
3907        }
3908
3909        // Ensure we still have the lock
3910        if !this.has_lock && !this.try_acquire_lock() {
3911            // Lost the lock, need to wait
3912            this.is_waiting = true;
3913
3914            // Register waker to be notified when lock is available
3915            if let Ok(mut queue) = get_chat_stream_waker_queue().lock() {
3916                queue.retain(|w| !w.will_wake(cx.waker()));
3917                queue.push_back(cx.waker().clone());
3918            }
3919
3920            return Poll::Pending;
3921        }
3922
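        // From this point the stream holds the exclusive lock: either compute
        // the next chunk, or drain items buffered in `cache`.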
3923        if this.cache.is_none() {
3924            let res = compute_stream(
3925                this.model.clone(),
3926                this.id.clone(),
3927                this.include_usage,
3928                &mut this.prompt_too_long_state,
3929                &mut this.context_full_state,
3930                &mut this.stream_state,
3931            );
3932
3933            match res {
3934                Ok(x) => {
3935                    #[cfg(feature = "logging")]
3936                    info!(target: "stdout", "next item for ChatStream {}: {}", &this.id, &x);
3937
3938                    if x != "[GGML] End of sequence" && !x.is_empty() {
3939                        Poll::Ready(Some(Ok(x)))
3940                    } else {
3941                        // stopped
3942                        Poll::Ready(None)
3943                    }
3944                }
3945                Err(e) => Poll::Ready(Some(Err(e))),
3946            }
3947        } else {
3948            let x = this.cache.as_mut().unwrap().pop_front();
3949
3950            #[cfg(feature = "logging")]
3951            info!(target: "stdout", "Get the next item from the cache for ChatStream {}: {:?}", &this.id, &x);
3952
3953            match x {
3954                Some(x) => Poll::Ready(Some(Ok(x))),
3955                None => Poll::Ready(None),
3956            }
3957        }
3958    }
3959}
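
// A minimal consumption sketch, assuming a `ChatStream` returned by the
// streaming path above and an async runtime (e.g. tokio); the names here are
// illustrative only:
//
//     use futures::StreamExt;
//
//     while let Some(item) = stream.next().await {
//         match item {
//             // Each frame is an SSE payload: "data: {...}\n\n" or "data: [DONE]\n\n".
//             Ok(frame) => print!("{frame}"),
//             Err(e) => eprintln!("stream error: {e}"),
//         }
//     }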
3960
3961/// Helper function to get or initialize the waker queue for waiting ChatStreams
3962fn get_chat_stream_waker_queue() -> &'static Mutex<VecDeque<Waker>> {
3963    CHAT_STREAM_WAKER_QUEUE.get_or_init(|| {
3964        #[cfg(feature = "logging")]
3965        info!(target: "stdout", "Initializing ChatStream waker queue");
3966        Mutex::new(VecDeque::new())
3967    })
3968}
3969
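/// Computes the next SSE payload for the stream identified by `id`.
///
/// The three mutable state arguments are small state machines that control how
/// the stream winds down for each termination cause (normal end of sequence,
/// context full, prompt too long). Once any of them reaches `EndOfSequence`,
/// the in-band sentinel `"[GGML] End of sequence"` is returned, which
/// `poll_next` translates into the end of the stream.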
3970fn compute_stream(
3971    model_name: Option<String>,
3972    id: String,
3973    include_usage: bool,
3974    prompt_too_long_state: &mut PromptTooLongState,
3975    context_full_state: &mut ContextFullState,
3976    stream_state: &mut StreamState,
3977) -> Result<String, LlamaCoreError> {
3978    #[cfg(feature = "logging")]
3979    info!(target: "stdout", "Computing stream chunk for ChatStream {}", &id);
3980
3981    #[cfg(feature = "logging")]
3982    debug!(target: "stdout", "prompt_too_long_state: {:?}", *prompt_too_long_state);
3983    #[cfg(feature = "logging")]
3984    debug!(target: "stdout", "context_full_state: {:?}", *context_full_state);
3985    #[cfg(feature = "logging")]
3986    debug!(target: "stdout", "stream_state: {:?}", *stream_state);
3987
3988    if *prompt_too_long_state == PromptTooLongState::EndOfSequence
3989        || *context_full_state == ContextFullState::EndOfSequence
3990        || *stream_state == StreamState::EndOfSequence
3991    {
3992        #[cfg(feature = "logging")]
3993        info!(target: "stdout", "Return the chat stream chunk!");
3994
3995        return Ok("[GGML] End of sequence".to_string());
3996    }
3997
3998    let chat_graphs = match CHAT_GRAPHS.get() {
3999        Some(chat_graphs) => chat_graphs,
4000        None => {
4001            let err_msg = "Failed to get the underlying value of `CHAT_GRAPHS`.";
4002
4003            #[cfg(feature = "logging")]
4004            error!(target: "stdout", "{}", &err_msg);
4005
4006            return Err(LlamaCoreError::Operation(err_msg.into()));
4007        }
4008    };
4009
4010    // We're already holding the ChatStream lock, so we know we have exclusive access to the graph
4011    let mut chat_graphs = chat_graphs.lock().map_err(|e| {
4012        let err_msg = format!("Failed to acquire the lock of `CHAT_GRAPHS`. {e}");
4013
4014        #[cfg(feature = "logging")]
4015        error!(target: "stdout", "{}", &err_msg);
4016
4017        LlamaCoreError::Operation(err_msg)
4018    })?;
4019
4020    // Get the graph based on model name
4021    let res = match &model_name {
4022        Some(model_name) => {
4023            match chat_graphs.contains_key(model_name) {
4024                true => {
4025                    let graph = chat_graphs.get_mut(model_name).unwrap();
4026                    // compute
4027                    match graph.compute_single() {
4028                        Ok(_) => {
4029                            #[cfg(feature = "logging")]
4030                            debug!(target: "stdout", "Compute the chat stream chunk successfully.");
4031
4032                            // Process according to state
4033                            match stream_state {
4034                                StreamState::Usage | StreamState::NoUsage => {
4035                                    // Retrieve the output
4036                                    let output_buffer =
4037                                        get_output_buffer_single(graph, OUTPUT_TENSOR)?;
4038
4039                                    #[cfg(feature = "logging")]
4040                                    info!(target: "stdout", "retrieved the output buffer");
4041
4042                                    // decode the output buffer to a utf8 string
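                                    // A chunk may end in the middle of a multi-byte UTF-8
                                    // character; such bytes are stashed in
                                    // `CACHED_UTF8_ENCODINGS` and retried together with the
                                    // next chunk.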
4043                                    let output = match String::from_utf8(output_buffer.clone()) {
4044                                        Ok(token) => token,
4045                                        Err(_) => {
4046                                            let mutex = CACHED_UTF8_ENCODINGS
4047                                                .get_or_init(|| Mutex::new(Vec::new()));
4048                                            let mut cached_encodings = mutex.lock().map_err(|e| {
4049                                                let err_msg = format!(
4050                                                    "Failed to acquire the lock of `CACHED_UTF8_ENCODINGS`. Reason: {e}"
4051                                                );
4052
4053                                                #[cfg(feature = "logging")]
4054                                                error!(target: "stdout", "{}", &err_msg);
4055
4056
4057                                                LlamaCoreError::Operation(err_msg)
4058                                            })?;
4059
4060                                            // cache the bytes for future decoding
4061                                            cached_encodings.extend_from_slice(&output_buffer[..]);
4062
4063                                            match String::from_utf8(cached_encodings.to_vec()) {
4064                                                Ok(token) => {
4065                                                    // clear CACHED_UTF8_ENCODINGS
4066                                                    cached_encodings.clear();
4067
4068                                                    token
4069                                                }
4070                                                Err(e) => {
4071                                                    // TODO: temporary guard against unbounded growth of the cached UTF-8 bytes.
4072                                                    if cached_encodings.len() > 4 {
4073                                                        let err_msg = format!("Failed to convert a vector of bytes to string: the cached UTF-8 buffer exceeds 4 bytes. {e}");
4074
4075                                                        #[cfg(feature = "logging")]
4076                                                        error!(target: "stdout", "{}", &err_msg);
4077
4078                                                        #[cfg(feature = "logging")]
4079                                                        error!(target: "stdout", "The cached buffer: {:?}", &cached_encodings[..]);
4080
4081                                                        // let token = String::from_utf8_lossy(
4082                                                        //     &cached_encodings,
4083                                                        // )
4084                                                        // .to_string();
4085
4086                                                        // clear CACHED_UTF8_ENCODINGS
4087                                                        cached_encodings.clear();
4088
4089                                                        String::from("")
4090                                                    } else {
4091                                                        let warn_msg = format!("Failed to convert a vector of bytes to string. {e}");
4092
4093                                                        #[cfg(feature = "logging")]
4094                                                        warn!(target: "stdout", "{}", &warn_msg);
4095
4096                                                        String::from("")
4097                                                    }
4098                                                }
4099                                            }
4100                                        }
4101                                    };
4102
4103                                    #[cfg(feature = "logging")]
4104                                    info!(target: "stdout", "decoded the output buffer");
4105
4106                                    let created = SystemTime::now()
4107                                        .duration_since(std::time::UNIX_EPOCH)
4108                                        .map_err(|e| {
4109                                            let err_msg = format!(
4110                                                "Failed to get the current time. Reason: {e}"
4111                                            );
4112
4113                                            #[cfg(feature = "logging")]
4114                                            error!(target: "stdout", "{}", &err_msg);
4115
4116                                            LlamaCoreError::Operation(err_msg)
4117                                        })?;
4118
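                                    // Note: the `system_fingerprint` below is a hardcoded
                                    // placeholder, not a value reported by the backend.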
4119                                    let chat_completion_chunk = ChatCompletionChunk {
4120                                        id,
4121                                        object: "chat.completion.chunk".to_string(),
4122                                        created: created.as_secs(),
4123                                        model: graph.name().to_owned(),
4124                                        system_fingerprint: "fp_44709d6fcb".to_string(),
4125                                        choices: vec![ChatCompletionChunkChoice {
4126                                            index: 0,
4127                                            delta: ChatCompletionChunkChoiceDelta {
4128                                                role: ChatCompletionRole::Assistant,
4129                                                content: Some(output),
4130                                                tool_calls: vec![],
4131                                            },
4132                                            logprobs: None,
4133                                            finish_reason: None,
4134                                        }],
4135                                        usage: None,
4136                                    };
4137
4138                                    #[cfg(feature = "logging")]
4139                                    info!(target: "stdout", "created chat completion chunk");
4140
4141                                    // serialize chat completion chunk
4142                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
4143                                        .map_err(|e| {
4144                                        let err_msg = format!(
4145                                            "Failed to serialize chat completion chunk. Reason: {e}"
4146                                        );
4147
4148                                        #[cfg(feature = "logging")]
4149                                        error!(target: "stdout", "{}", &err_msg);
4150
4151                                        LlamaCoreError::Operation(err_msg)
4152                                    })?;
4153
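                                    // Each yielded item is one complete SSE frame: a
                                    // `data:` line terminated by a blank line.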
4154                                    Ok(format!("data: {chunk_str}\n\n"))
4155                                }
4156                                StreamState::Done => {
4157                                    *stream_state = StreamState::EndOfSequence;
4158
4159                                    Ok("data: [DONE]\n\n".to_string())
4160                                }
4161                                StreamState::EndOfSequence => {
4162                                    Ok("[GGML] End of sequence".to_string())
4163                                }
4164                            }
4165                        }
4166                        Err(wasmedge_wasi_nn::Error::BackendError(
4167                            wasmedge_wasi_nn::BackendError::EndOfSequence,
4168                        )) => {
4169                            #[cfg(feature = "logging")]
4170                            debug!(target: "stdout", "End of sequence");
4171
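                            // Generation finished. If the client requested usage
                            // (`include_usage`), emit a final usage-only chunk with
                            // empty `choices` before the terminating `data: [DONE]`
                            // frame.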
4172                            match stream_state {
4173                                StreamState::Usage => {
4174                                    *stream_state = StreamState::Done;
4175
4176                                    // retrieve the number of prompt and completion tokens
4177                                    let token_info = get_token_info_by_graph(graph)?;
4178
4179                                    let usage = Some(Usage {
4180                                        prompt_tokens: token_info.prompt_tokens,
4181                                        completion_tokens: token_info.completion_tokens,
4182                                        total_tokens: token_info.prompt_tokens
4183                                            + token_info.completion_tokens,
4184                                    });
4185
4186                                    #[cfg(feature = "logging")]
4187                                    info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
4188
4189                                    let created = SystemTime::now()
4190                                        .duration_since(std::time::UNIX_EPOCH)
4191                                        .map_err(|e| {
4192                                            let err_msg = format!(
4193                                                "Failed to get the current time. Reason: {e}"
4194                                            );
4195
4196                                            #[cfg(feature = "logging")]
4197                                            error!(target: "stdout", "{}", &err_msg);
4198
4199                                            LlamaCoreError::Operation(err_msg)
4200                                        })?;
4201
4202                                    let chat_completion_chunk = ChatCompletionChunk {
4203                                        id,
4204                                        object: "chat.completion.chunk".to_string(),
4205                                        created: created.as_secs(),
4206                                        model: graph.name().to_owned(),
4207                                        system_fingerprint: "fp_44709d6fcb".to_string(),
4208                                        choices: vec![],
4209                                        usage,
4210                                    };
4211
4212                                    // serialize chat completion chunk
4213                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
4214                                        .map_err(|e| {
4215                                        let err_msg = format!(
4216                                            "Failed to serialize chat completion chunk. Reason: {e}"
4217                                        );
4218
4219                                        #[cfg(feature = "logging")]
4220                                        error!(target: "stdout", "{}", &err_msg);
4221
4222                                        LlamaCoreError::Operation(err_msg)
4223                                    })?;
4224
4225                                    Ok(format!("data: {chunk_str}\n\n"))
4226                                }
4227                                StreamState::Done | StreamState::NoUsage => {
4228                                    *stream_state = StreamState::EndOfSequence;
4229
4230                                    Ok("data: [DONE]\n\n".to_string())
4231                                }
4232                                StreamState::EndOfSequence => {
4233                                    Ok("[GGML] End of sequence".to_string())
4234                                }
4235                            }
4236                        }
4237                        Err(wasmedge_wasi_nn::Error::BackendError(
4238                            wasmedge_wasi_nn::BackendError::ContextFull,
4239                        )) => {
4240                            #[cfg(feature = "logging")]
4241                            debug!(target: "stdout", "Context full");
4242
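                            // The context window filled up mid-generation: first emit a
                            // chunk carrying the `<|WASMEDGE-GGML-CONTEXT-FULL|>` sentinel
                            // with `finish_reason: length`, then usage (if requested),
                            // then `data: [DONE]`.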
4243                            match context_full_state {
4244                                ContextFullState::Message => {
4245                                    match include_usage {
4246                                        true => *context_full_state = ContextFullState::Usage,
4247                                        false => *context_full_state = ContextFullState::Done,
4248                                    }
4249
4250                                    let created = SystemTime::now()
4251                                        .duration_since(std::time::UNIX_EPOCH)
4252                                        .map_err(|e| {
4253                                            let err_msg = format!(
4254                                                "Failed to get the current time. Reason: {e}"
4255                                            );
4256
4257                                            #[cfg(feature = "logging")]
4258                                            error!(target: "stdout", "{}", &err_msg);
4259
4260                                            LlamaCoreError::Operation(err_msg)
4261                                        })?;
4262
4263                                    let chat_completion_chunk = ChatCompletionChunk {
4264                                        id,
4265                                        object: "chat.completion.chunk".to_string(),
4266                                        created: created.as_secs(),
4267                                        model: graph.name().to_owned(),
4268                                        system_fingerprint: "fp_44709d6fcb".to_string(),
4269                                        choices: vec![ChatCompletionChunkChoice {
4270                                            index: 0,
4271                                            delta: ChatCompletionChunkChoiceDelta {
4272                                                role: ChatCompletionRole::Assistant,
4273                                                content: Some(
4274                                                    "<|WASMEDGE-GGML-CONTEXT-FULL|>".to_string(),
4275                                                ),
4276                                                tool_calls: vec![],
4277                                            },
4278                                            logprobs: None,
4279                                            finish_reason: Some(FinishReason::length),
4280                                        }],
4281                                        usage: None,
4282                                    };
4283
4284                                    // serialize chat completion chunk
4285                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
4286                                        .map_err(|e| {
4287                                        let err_msg = format!(
4288                                            "Failed to serialize chat completion chunk. Reason: {e}"
4289                                        );
4290
4291                                        #[cfg(feature = "logging")]
4292                                        error!(target: "stdout", "{}", &err_msg);
4293
4294                                        LlamaCoreError::Operation(err_msg)
4295                                    })?;
4296
4297                                    Ok(format!("data: {chunk_str}\n\n"))
4298                                }
4299                                ContextFullState::Usage => {
4300                                    *context_full_state = ContextFullState::Done;
4301
4302                                    // retrieve the number of prompt and completion tokens
4303                                    let token_info = get_token_info_by_graph(graph)?;
4304
4305                                    let usage = Some(Usage {
4306                                        prompt_tokens: token_info.prompt_tokens,
4307                                        completion_tokens: token_info.completion_tokens,
4308                                        total_tokens: token_info.prompt_tokens
4309                                            + token_info.completion_tokens,
4310                                    });
4311
4312                                    let created = SystemTime::now()
4313                                        .duration_since(std::time::UNIX_EPOCH)
4314                                        .map_err(|e| {
4315                                            let err_msg = format!(
4316                                                "Failed to get the current time. Reason: {e}"
4317                                            );
4318
4319                                            #[cfg(feature = "logging")]
4320                                            error!(target: "stdout", "{}", &err_msg);
4321
4322                                            LlamaCoreError::Operation(err_msg)
4323                                        })?;
4324
4325                                    let chat_completion_chunk = ChatCompletionChunk {
4326                                        id,
4327                                        object: "chat.completion.chunk".to_string(),
4328                                        created: created.as_secs(),
4329                                        model: graph.name().to_owned(),
4330                                        system_fingerprint: "fp_44709d6fcb".to_string(),
4331                                        choices: vec![],
4332                                        usage,
4333                                    };
4334
4335                                    // serialize chat completion chunk
4336                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
4337                                        .map_err(|e| {
4338                                        let err_msg = format!(
4339                                            "Failed to serialize chat completion chunk. Reason: {e}"
4340                                        );
4341
4342                                        #[cfg(feature = "logging")]
4343                                        error!(target: "stdout", "{}", &err_msg);
4344
4345                                        LlamaCoreError::Operation(err_msg)
4346                                    })?;
4347
4348                                    Ok(format!("data: {chunk_str}\n\n"))
4349                                }
4350                                ContextFullState::Done => {
4351                                    *context_full_state = ContextFullState::EndOfSequence;
4352
4353                                    Ok("data: [DONE]\n\n".to_string())
4354                                }
4355                                ContextFullState::EndOfSequence => {
4356                                    Ok("[GGML] End of sequence".to_string())
4357                                }
4358                            }
4359                        }
4360                        Err(wasmedge_wasi_nn::Error::BackendError(
4361                            wasmedge_wasi_nn::BackendError::PromptTooLong,
4362                        )) => {
4363                            #[cfg(feature = "logging")]
4364                            debug!(target: "stdout", "Prompt too long");
4365
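                            // The prompt alone exceeded the context window: emit an
                            // empty delta with `finish_reason: length`, then usage (if
                            // requested), then `data: [DONE]`.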
4366                            match prompt_too_long_state {
4367                                PromptTooLongState::Message => {
4368                                    match include_usage {
4369                                        true => *prompt_too_long_state = PromptTooLongState::Usage,
4370                                        false => *prompt_too_long_state = PromptTooLongState::Done,
4371                                    }
4372
4373                                    let created = SystemTime::now()
4374                                        .duration_since(std::time::UNIX_EPOCH)
4375                                        .map_err(|e| {
4376                                            let err_msg = format!(
4377                                                "Failed to get the current time. Reason: {e}"
4378                                            );
4379
4380                                            #[cfg(feature = "logging")]
4381                                            error!(target: "stdout", "{}", &err_msg);
4382
4383                                            LlamaCoreError::Operation(err_msg)
4384                                        })?;
4385
4386                                    let chat_completion_chunk = ChatCompletionChunk {
4387                                        id,
4388                                        object: "chat.completion.chunk".to_string(),
4389                                        created: created.as_secs(),
4390                                        model: graph.name().to_owned(),
4391                                        system_fingerprint: "fp_44709d6fcb".to_string(),
4392                                        choices: vec![ChatCompletionChunkChoice {
4393                                            index: 0,
4394                                            delta: ChatCompletionChunkChoiceDelta {
4395                                                role: ChatCompletionRole::Assistant,
4396                                                content: None,
4397                                                tool_calls: vec![],
4398                                            },
4399                                            logprobs: None,
4400                                            finish_reason: Some(FinishReason::length),
4401                                        }],
4402                                        usage: None,
4403                                    };
4404
4405                                    // serialize chat completion chunk
4406                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
4407                                        .map_err(|e| {
4408                                        let err_msg = format!(
4409                                            "Failed to serialize chat completion chunk. Reason: {e}"
4410                                        );
4411
4412                                        #[cfg(feature = "logging")]
4413                                        error!(target: "stdout", "{}", &err_msg);
4414
4415                                        LlamaCoreError::Operation(err_msg)
4416                                    })?;
4417
4418                                    Ok(format!("data: {chunk_str}\n\n"))
4419                                }
4420                                PromptTooLongState::Usage => {
4421                                    *prompt_too_long_state = PromptTooLongState::Done;
4422
4423                                    // retrieve the number of prompt and completion tokens
4424                                    let token_info = get_token_info_by_graph(graph)?;
4425
4426                                    let usage = Some(Usage {
4427                                        prompt_tokens: token_info.prompt_tokens,
4428                                        completion_tokens: token_info.completion_tokens,
4429                                        total_tokens: token_info.prompt_tokens
4430                                            + token_info.completion_tokens,
4431                                    });
4432
4433                                    let created = SystemTime::now()
4434                                        .duration_since(std::time::UNIX_EPOCH)
4435                                        .map_err(|e| {
4436                                            let err_msg = format!(
4437                                                "Failed to get the current time. Reason: {e}"
4438                                            );
4439
4440                                            #[cfg(feature = "logging")]
4441                                            error!(target: "stdout", "{}", &err_msg);
4442
4443                                            LlamaCoreError::Operation(err_msg)
4444                                        })?;
4445
4446                                    let chat_completion_chunk = ChatCompletionChunk {
4447                                        id,
4448                                        object: "chat.completion.chunk".to_string(),
4449                                        created: created.as_secs(),
4450                                        model: graph.name().to_owned(),
4451                                        system_fingerprint: "fp_44709d6fcb".to_string(),
4452                                        choices: vec![],
4453                                        usage,
4454                                    };
4455
4456                                    // serialize chat completion chunk
4457                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
4458                                        .map_err(|e| {
4459                                        let err_msg = format!(
4460                                            "Failed to serialize chat completion chunk. Reason: {e}"
4461                                        );
4462
4463                                        #[cfg(feature = "logging")]
4464                                        error!(target: "stdout", "{}", &err_msg);
4465
4466                                        LlamaCoreError::Operation(err_msg)
4467                                    })?;
4468
4469                                    Ok(format!("data: {chunk_str}\n\n"))
4470                                }
4471                                PromptTooLongState::Done => {
4472                                    *prompt_too_long_state = PromptTooLongState::EndOfSequence;
4473
4474                                    Ok("data: [DONE]\n\n".to_string())
4475                                }
4476                                PromptTooLongState::EndOfSequence => {
4477                                    Ok("[GGML] End of sequence".to_string())
4478                                }
4479                            }
4480                        }
4481                        Err(e) => {
4482                            let err_msg =
4483                                format!("Failed to compute the chat completion. Reason: {e}");
4484
4485                            #[cfg(feature = "logging")]
4486                            error!(target: "stdout", "{}", &err_msg);
4487
4488                            Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
4489                                err_msg,
4490                            )))
4491                        }
4492                    }
4493                }
4494                false => {
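                    // No model name was given: fall back to the first graph in
                    // the map. The handling below mirrors the named-model branch
                    // above.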
4495                    match chat_graphs.iter_mut().next() {
4496                        Some((_, graph)) => {
4497                            // compute
4498                            match graph.compute_single() {
4499                                Ok(_) => {
4500                                    #[cfg(feature = "logging")]
4501                                    debug!(target: "stdout", "Compute the chat stream chunk successfully.");
4502
4503                                    match stream_state {
4504                                        StreamState::Usage | StreamState::NoUsage => {
4505                                            // Retrieve the output
4506                                            let output_buffer =
4507                                                get_output_buffer_single(graph, OUTPUT_TENSOR)?;
4508
4509                                            #[cfg(feature = "logging")]
4510                                            info!(target: "stdout", "retrieved the output buffer");
4511
4512                                            // decode the output buffer to a utf8 string
4513                                            let output = match String::from_utf8(
4514                                                output_buffer.clone(),
4515                                            ) {
4516                                                Ok(token) => token,
4517                                                Err(_) => {
4518                                                    let mutex = CACHED_UTF8_ENCODINGS
4519                                                        .get_or_init(|| Mutex::new(Vec::new()));
4520                                                    let mut cached_encodings = mutex.lock().map_err(|e| {
4521                                                        let err_msg = format!(
4522                                                            "Failed to acquire the lock of `CACHED_UTF8_ENCODINGS`. Reason: {e}"
4523                                                        );
4524
4525                                                        #[cfg(feature = "logging")]
4526                                                        error!(target: "stdout", "{}", &err_msg);
4527
4528
4529                                                        LlamaCoreError::Operation(err_msg)
4530                                                    })?;
4531
4532                                                    // cache the bytes for future decoding
4533                                                    cached_encodings
4534                                                        .extend_from_slice(&output_buffer[..]);
4535
4536                                                    match String::from_utf8(
4537                                                        cached_encodings.to_vec(),
4538                                                    ) {
4539                                                        Ok(token) => {
4540                                                            // clear encodings
4541                                                            cached_encodings.clear();
4542
4543                                                            token
4544                                                        }
4545                                                        Err(e) => {
4546                                                            // TODO: temporary guard against unbounded growth of the cached UTF-8 bytes.
4547                                                            if cached_encodings.len() > 4 {
4548                                                                let err_msg = format!("Failed to convert a vector of bytes to string: the cached UTF-8 buffer exceeds 4 bytes. {e}");
4549
4550                                                                #[cfg(feature = "logging")]
4551                                                                error!(target: "stdout", "{}", &err_msg);
4552
4553                                                                #[cfg(feature = "logging")]
4554                                                                error!(target: "stdout", "The cached buffer: {:?}", &cached_encodings[..]);
4555
4556                                                                // let token =
4557                                                                //     String::from_utf8_lossy(
4558                                                                //         &cached_encodings,
4559                                                                //     )
4560                                                                //     .to_string();
4561
4562                                                                // clear CACHED_UTF8_ENCODINGS
4563                                                                cached_encodings.clear();
4564
4565                                                                String::from("")
4566                                                            } else {
4567                                                                let warn_msg = format!("Failed to convert a vector of bytes to string. {e}");
4568
4569                                                                #[cfg(feature = "logging")]
4570                                                                warn!(target: "stdout", "{}", &warn_msg);
4571
4572                                                                String::from("")
4573                                                            }
4574                                                        }
4575                                                    }
4576                                                }
4577                                            };
4578
4579                                            #[cfg(feature = "logging")]
4580                                            info!(target: "stdout", "decoded the output buffer");
4581
4582                                            let created = SystemTime::now()
4583                                                .duration_since(std::time::UNIX_EPOCH)
4584                                                .map_err(|e| {
4585                                                    let err_msg = format!(
4586                                                "Failed to get the current time. Reason: {e}"
4587                                            );
4588
4589                                                    #[cfg(feature = "logging")]
4590                                                    error!(target: "stdout", "{}", &err_msg);
4591
4592                                                    LlamaCoreError::Operation(err_msg)
4593                                                })?;
4594
4595                                            let chat_completion_chunk = ChatCompletionChunk {
4596                                                id,
4597                                                object: "chat.completion.chunk".to_string(),
4598                                                created: created.as_secs(),
4599                                                model: graph.name().to_owned(),
4600                                                system_fingerprint: "fp_44709d6fcb".to_string(),
4601                                                choices: vec![ChatCompletionChunkChoice {
4602                                                    index: 0,
4603                                                    delta: ChatCompletionChunkChoiceDelta {
4604                                                        role: ChatCompletionRole::Assistant,
4605                                                        content: Some(output),
4606                                                        tool_calls: vec![],
4607                                                    },
4608                                                    logprobs: None,
4609                                                    finish_reason: None,
4610                                                }],
4611                                                usage: None,
4612                                            };
4613
4614                                            #[cfg(feature = "logging")]
4615                                            info!(target: "stdout", "created chat completion chunk");
4616
4617                                            // serialize chat completion chunk
4618                                            let chunk_str =
4619                                                serde_json::to_string(&chat_completion_chunk)
4620                                                    .map_err(|e| {
4621                                                        let err_msg = format!(
4622                                            "Failed to serialize chat completion chunk. Reason: {e}"
4623                                        );
4624
4625                                                        #[cfg(feature = "logging")]
4626                                                        error!(target: "stdout", "{}", &err_msg);
4627
4628                                                        LlamaCoreError::Operation(err_msg)
4629                                                    })?;
4630
4631                                            Ok(format!("data: {chunk_str}\n\n"))
4632                                        }
4633                                        StreamState::Done => {
4634                                            *stream_state = StreamState::EndOfSequence;
4635
4636                                            Ok("data: [DONE]\n\n".to_string())
4637                                        }
4638                                        StreamState::EndOfSequence => {
4639                                            Ok("[GGML] End of sequence".to_string())
4640                                        }
4641                                    }
4642                                }
4643                                Err(wasmedge_wasi_nn::Error::BackendError(
4644                                    wasmedge_wasi_nn::BackendError::EndOfSequence,
4645                                )) => {
4646                                    #[cfg(feature = "logging")]
4647                                    debug!(target: "stdout", "End of sequence");
4648
4649                                    match stream_state {
4650                                        StreamState::Usage => {
4651                                            *stream_state = StreamState::Done;
4652
4653                                            // retrieve the number of prompt and completion tokens
4654                                            let token_info = get_token_info_by_graph(graph)?;
4655
4656                                            let usage = Some(Usage {
4657                                                prompt_tokens: token_info.prompt_tokens,
4658                                                completion_tokens: token_info.completion_tokens,
4659                                                total_tokens: token_info.prompt_tokens
4660                                                    + token_info.completion_tokens,
4661                                            });
4662
4663                                            #[cfg(feature = "logging")]
4664                                            info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
4665
4666                                            let created = SystemTime::now()
4667                                                .duration_since(std::time::UNIX_EPOCH)
4668                                                .map_err(|e| {
4669                                                    let err_msg = format!(
4670                                                "Failed to get the current time. Reason: {e}"
4671                                            );
4672
4673                                                    #[cfg(feature = "logging")]
4674                                                    error!(target: "stdout", "{}", &err_msg);
4675
4676                                                    LlamaCoreError::Operation(err_msg)
4677                                                })?;
4678
4679                                            let chat_completion_chunk = ChatCompletionChunk {
4680                                                id,
4681                                                object: "chat.completion.chunk".to_string(),
4682                                                created: created.as_secs(),
4683                                                model: graph.name().to_owned(),
4684                                                system_fingerprint: "fp_44709d6fcb".to_string(),
4685                                                choices: vec![],
4686                                                usage,
4687                                            };
4688
4689                                            // serialize chat completion chunk
4690                                            let chunk_str =
4691                                                serde_json::to_string(&chat_completion_chunk)
4692                                                    .map_err(|e| {
4693                                                        let err_msg = format!(
4694                                            "Failed to serialize chat completion chunk. Reason: {e}"
4695                                        );
4696
4697                                                        #[cfg(feature = "logging")]
4698                                                        error!(target: "stdout", "{}", &err_msg);
4699
4700                                                        LlamaCoreError::Operation(err_msg)
4701                                                    })?;
4702
4703                                            Ok(format!("data: {chunk_str}\n\n"))
4704                                        }
4705                                        StreamState::Done | StreamState::NoUsage => {
4706                                            *stream_state = StreamState::EndOfSequence;
4707
4708                                            Ok("data: [DONE]\n\n".to_string())
4709                                        }
4710                                        StreamState::EndOfSequence => {
4711                                            Ok("[GGML] End of sequence".to_string())
4712                                        }
4713                                    }
4714                                }
4715                                Err(wasmedge_wasi_nn::Error::BackendError(
4716                                    wasmedge_wasi_nn::BackendError::ContextFull,
4717                                )) => {
4718                                    #[cfg(feature = "logging")]
4719                                    debug!(target: "stdout", "Context full");
4720
4721                                    match context_full_state {
4722                                        ContextFullState::Message => {
4723                                            match include_usage {
4724                                                true => {
4725                                                    *context_full_state = ContextFullState::Usage
4726                                                }
4727                                                false => {
4728                                                    *context_full_state = ContextFullState::Done
4729                                                }
4730                                            }
4731
4732                                            let created = SystemTime::now()
4733                                                .duration_since(std::time::UNIX_EPOCH)
4734                                                .map_err(|e| {
4735                                                    let err_msg = format!(
4736                                                "Failed to get the current time. Reason: {e}"
4737                                            );
4738
4739                                                    #[cfg(feature = "logging")]
4740                                                    error!(target: "stdout", "{}", &err_msg);
4741
4742                                                    LlamaCoreError::Operation(err_msg)
4743                                                })?;
4744
4745                                            let chat_completion_chunk = ChatCompletionChunk {
4746                                                id,
4747                                                object: "chat.completion.chunk".to_string(),
4748                                                created: created.as_secs(),
4749                                                model: graph.name().to_owned(),
4750                                                system_fingerprint: "fp_44709d6fcb".to_string(),
4751                                                choices: vec![ChatCompletionChunkChoice {
4752                                                    index: 0,
4753                                                    delta: ChatCompletionChunkChoiceDelta {
4754                                                        role: ChatCompletionRole::Assistant,
4755                                                        content: Some(
4756                                                            "<|WASMEDGE-GGML-CONTEXT-FULL|>"
4757                                                                .to_string(),
4758                                                        ),
4759                                                        tool_calls: vec![],
4760                                                    },
4761                                                    logprobs: None,
4762                                                    finish_reason: Some(FinishReason::length),
4763                                                }],
4764                                                usage: None,
4765                                            };
4766
4767                                            // serialize chat completion chunk
4768                                            let chunk_str =
4769                                                serde_json::to_string(&chat_completion_chunk)
4770                                                    .map_err(|e| {
4771                                                        let err_msg = format!(
4772                                            "Failed to serialize chat completion chunk. Reason: {e}"
4773                                        );
4774
4775                                                        #[cfg(feature = "logging")]
4776                                                        error!(target: "stdout", "{}", &err_msg);
4777
4778                                                        LlamaCoreError::Operation(err_msg)
4779                                                    })?;
4780
4781                                            Ok(format!("data: {chunk_str}\n\n"))
4782                                        }
                                        ContextFullState::Usage => {
                                            *context_full_state = ContextFullState::Done;

                                            // retrieve the number of prompt and completion tokens
                                            let token_info = get_token_info_by_graph(graph)?;

                                            let usage = Some(Usage {
                                                prompt_tokens: token_info.prompt_tokens,
                                                completion_tokens: token_info.completion_tokens,
                                                total_tokens: token_info.prompt_tokens
                                                    + token_info.completion_tokens,
                                            });

                                            let created = SystemTime::now()
                                                .duration_since(std::time::UNIX_EPOCH)
                                                .map_err(|e| {
                                                    let err_msg = format!(
                                                        "Failed to get the current time. Reason: {e}"
                                                    );

                                                    #[cfg(feature = "logging")]
                                                    error!(target: "stdout", "{}", &err_msg);

                                                    LlamaCoreError::Operation(err_msg)
                                                })?;

                                            let chat_completion_chunk = ChatCompletionChunk {
                                                id,
                                                object: "chat.completion.chunk".to_string(),
                                                created: created.as_secs(),
                                                model: graph.name().to_owned(),
                                                system_fingerprint: "fp_44709d6fcb".to_string(),
                                                choices: vec![],
                                                usage,
                                            };

                                            // serialize chat completion chunk
                                            let chunk_str =
                                                serde_json::to_string(&chat_completion_chunk)
                                                    .map_err(|e| {
                                                        let err_msg = format!(
                                                            "Failed to serialize chat completion chunk. Reason: {e}"
                                                        );

                                                        #[cfg(feature = "logging")]
                                                        error!(target: "stdout", "{}", &err_msg);

                                                        LlamaCoreError::Operation(err_msg)
                                                    })?;

                                            Ok(format!("data: {chunk_str}\n\n"))
                                        }
                                        ContextFullState::Done => {
                                            *context_full_state = ContextFullState::EndOfSequence;

                                            Ok("data: [DONE]\n\n".to_string())
                                        }
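                                        // Not an SSE frame: "[GGML] End of sequence" is an
                                        // internal sentinel, which the stream driver is expected
                                        // to treat as stream exhaustion rather than forward to
                                        // the client.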
                                        ContextFullState::EndOfSequence => {
                                            Ok("[GGML] End of sequence".to_string())
                                        }
                                    }
                                }
                                Err(wasmedge_wasi_nn::Error::BackendError(
                                    wasmedge_wasi_nn::BackendError::PromptTooLong,
                                )) => {
                                    #[cfg(feature = "logging")]
                                    debug!(target: "stdout", "Prompt too long");

                                    match prompt_too_long_state {
                                        PromptTooLongState::Message => {
                                            match include_usage {
                                                true => {
                                                    *prompt_too_long_state =
                                                        PromptTooLongState::Usage
                                                }
                                                false => {
                                                    *prompt_too_long_state =
                                                        PromptTooLongState::Done
                                                }
                                            }

                                            let created = SystemTime::now()
                                                .duration_since(std::time::UNIX_EPOCH)
                                                .map_err(|e| {
                                                    let err_msg = format!(
                                                        "Failed to get the current time. Reason: {e}"
                                                    );

                                                    #[cfg(feature = "logging")]
                                                    error!(target: "stdout", "{}", &err_msg);

                                                    LlamaCoreError::Operation(err_msg)
                                                })?;

                                            let chat_completion_chunk = ChatCompletionChunk {
                                                id,
                                                object: "chat.completion.chunk".to_string(),
                                                created: created.as_secs(),
                                                model: graph.name().to_owned(),
                                                system_fingerprint: "fp_44709d6fcb".to_string(),
                                                choices: vec![ChatCompletionChunkChoice {
                                                    index: 0,
                                                    delta: ChatCompletionChunkChoiceDelta {
                                                        role: ChatCompletionRole::Assistant,
                                                        content: None,
                                                        tool_calls: vec![],
                                                    },
                                                    logprobs: None,
                                                    finish_reason: Some(FinishReason::length),
                                                }],
                                                usage: None,
                                            };

                                            // serialize chat completion chunk
                                            let chunk_str =
                                                serde_json::to_string(&chat_completion_chunk)
                                                    .map_err(|e| {
                                                        let err_msg = format!(
                                                            "Failed to serialize chat completion chunk. Reason: {e}"
                                                        );

                                                        #[cfg(feature = "logging")]
                                                        error!(target: "stdout", "{}", &err_msg);

                                                        LlamaCoreError::Operation(err_msg)
                                                    })?;

                                            Ok(format!("data: {chunk_str}\n\n"))
                                        }
                                        PromptTooLongState::Usage => {
                                            *prompt_too_long_state = PromptTooLongState::Done;

                                            // retrieve the number of prompt and completion tokens
                                            let token_info = get_token_info_by_graph(graph)?;

                                            let usage = Some(Usage {
                                                prompt_tokens: token_info.prompt_tokens,
                                                completion_tokens: token_info.completion_tokens,
                                                total_tokens: token_info.prompt_tokens
                                                    + token_info.completion_tokens,
                                            });

                                            let created = SystemTime::now()
                                                .duration_since(std::time::UNIX_EPOCH)
                                                .map_err(|e| {
                                                    let err_msg = format!(
                                                        "Failed to get the current time. Reason: {e}"
                                                    );

                                                    #[cfg(feature = "logging")]
                                                    error!(target: "stdout", "{}", &err_msg);

                                                    LlamaCoreError::Operation(err_msg)
                                                })?;

                                            let chat_completion_chunk = ChatCompletionChunk {
                                                id,
                                                object: "chat.completion.chunk".to_string(),
                                                created: created.as_secs(),
                                                model: graph.name().to_owned(),
                                                system_fingerprint: "fp_44709d6fcb".to_string(),
                                                choices: vec![],
                                                usage,
                                            };

                                            // serialize chat completion chunk
                                            let chunk_str =
                                                serde_json::to_string(&chat_completion_chunk)
                                                    .map_err(|e| {
                                                        let err_msg = format!(
                                                            "Failed to serialize chat completion chunk. Reason: {e}"
                                                        );

                                                        #[cfg(feature = "logging")]
                                                        error!(target: "stdout", "{}", &err_msg);

                                                        LlamaCoreError::Operation(err_msg)
                                                    })?;

                                            Ok(format!("data: {chunk_str}\n\n"))
                                        }
                                        PromptTooLongState::Done => {
                                            *prompt_too_long_state =
                                                PromptTooLongState::EndOfSequence;

                                            Ok("data: [DONE]\n\n".to_string())
                                        }
                                        PromptTooLongState::EndOfSequence => {
                                            Ok("[GGML] End of sequence".to_string())
                                        }
                                    }
                                }
                                Err(e) => {
                                    let err_msg = format!(
                                        "Failed to compute the chat completion. Reason: {e}"
                                    );

                                    #[cfg(feature = "logging")]
                                    error!(target: "stdout", "{}", &err_msg);

                                    Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
                                        err_msg,
                                    )))
                                }
                            }
                        None => {
                            let err_msg = "There is no model available in the chat graphs.";

                            #[cfg(feature = "logging")]
                            error!(target: "stdout", "{}", &err_msg);

                            Err(LlamaCoreError::Operation(err_msg.into()))
                        }
                    }
                }
            }
        }
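        // No model name was specified in the request: fall back to the
        // first graph registered in the chat graphs.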
        None => {
            match chat_graphs.iter_mut().next() {
                Some((_, graph)) => {
                    // compute
                    match graph.compute_single() {
                        Ok(_) => {
                            #[cfg(feature = "logging")]
                            debug!(target: "stdout", "Compute the chat stream chunk successfully.");

                            match stream_state {
                                StreamState::Usage | StreamState::NoUsage => {
                                    // Retrieve the output
                                    let output_buffer =
                                        get_output_buffer_single(graph, OUTPUT_TENSOR)?;

                                    #[cfg(feature = "logging")]
                                    info!(target: "stdout", "retrieved the output buffer");

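                                    // A token boundary can fall inside a multi-byte UTF-8 code
                                    // point, so a single output buffer may not decode on its own.
                                    // On failure, the bytes are accumulated in
                                    // CACHED_UTF8_ENCODINGS and retried together with the next
                                    // chunk; a UTF-8 code point is at most 4 bytes, so a cache
                                    // that grows beyond that cannot be a split code point.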
                                    // decode the output buffer to a utf8 string
                                    let output = match String::from_utf8(output_buffer.clone()) {
                                        Ok(token) => token,
                                        Err(_) => {
                                            let mutex = CACHED_UTF8_ENCODINGS
                                                .get_or_init(|| Mutex::new(Vec::new()));
                                            let mut cached_encodings = mutex.lock().map_err(|e| {
                                                let err_msg = format!(
                                                    "Failed to acquire the lock of `CACHED_UTF8_ENCODINGS`. Reason: {e}"
                                                );

                                                #[cfg(feature = "logging")]
                                                error!(target: "stdout", "{}", &err_msg);

                                                LlamaCoreError::Operation(err_msg)
                                            })?;

                                            cached_encodings.extend_from_slice(&output_buffer[..]);

                                            match String::from_utf8(cached_encodings.to_vec()) {
                                                Ok(token) => {
                                                    // clear encodings
                                                    cached_encodings.clear();

                                                    token
                                                }
                                                Err(e) => {
                                                    // TODO: temporary guard against unbounded growth of the cached encodings.
                                                    if cached_encodings.len() > 4 {
                                                        let err_msg = format!("Failed to convert a vector of bytes to a string. The length of the cached UTF-8 bytes exceeds 4. {e}");

                                                        #[cfg(feature = "logging")]
                                                        error!(target: "stdout", "{}", &err_msg);

                                                        #[cfg(feature = "logging")]
                                                        error!(target: "stdout", "The cached buffer: {:?}", &cached_encodings[..]);

                                                        // clear CACHED_UTF8_ENCODINGS
                                                        cached_encodings.clear();

                                                        String::from("")
                                                    } else {
                                                        let warn_msg = format!("Failed to convert a vector of bytes to a string. {e}");

                                                        #[cfg(feature = "logging")]
                                                        warn!(target: "stdout", "{}", &warn_msg);

                                                        String::from("")
                                                    }
                                                }
                                            }
                                        }
                                    };

                                    #[cfg(feature = "logging")]
                                    info!(target: "stdout", "decoded the output buffer");

                                    let created = SystemTime::now()
                                        .duration_since(std::time::UNIX_EPOCH)
                                        .map_err(|e| {
                                            let err_msg = format!(
                                                "Failed to get the current time. Reason: {e}"
                                            );

                                            #[cfg(feature = "logging")]
                                            error!(target: "stdout", "{}", &err_msg);

                                            LlamaCoreError::Operation(err_msg)
                                        })?;

                                    let chat_completion_chunk = ChatCompletionChunk {
                                        id,
                                        object: "chat.completion.chunk".to_string(),
                                        created: created.as_secs(),
                                        model: graph.name().to_owned(),
                                        system_fingerprint: "fp_44709d6fcb".to_string(),
                                        choices: vec![ChatCompletionChunkChoice {
                                            index: 0,
                                            delta: ChatCompletionChunkChoiceDelta {
                                                role: ChatCompletionRole::Assistant,
                                                content: Some(output),
                                                tool_calls: vec![],
                                            },
                                            logprobs: None,
                                            finish_reason: None,
                                        }],
                                        usage: None,
                                    };

                                    #[cfg(feature = "logging")]
                                    info!(target: "stdout", "created chat completion chunk");

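                                    // Every yielded item is one server-sent-events frame:
                                    // a `data: ` prefix, the JSON-encoded chunk, and a blank
                                    // line as the frame terminator.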
                                    // serialize chat completion chunk
                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
                                        .map_err(|e| {
                                            let err_msg = format!(
                                                "Failed to serialize chat completion chunk. Reason: {e}"
                                            );

                                            #[cfg(feature = "logging")]
                                            error!(target: "stdout", "{}", &err_msg);

                                            LlamaCoreError::Operation(err_msg)
                                        })?;

                                    Ok(format!("data: {chunk_str}\n\n"))
                                }
                                StreamState::Done => {
                                    *stream_state = StreamState::EndOfSequence;

                                    Ok("data: [DONE]\n\n".to_string())
                                }
                                StreamState::EndOfSequence => {
                                    Ok("[GGML] End of sequence".to_string())
                                }
                            }
                        }
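                        // The backend signals normal completion of generation as a
                        // `BackendError::EndOfSequence` rather than through `Ok`, so the
                        // usage/[DONE] epilogue is driven from this error arm.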
                        Err(wasmedge_wasi_nn::Error::BackendError(
                            wasmedge_wasi_nn::BackendError::EndOfSequence,
                        )) => {
                            #[cfg(feature = "logging")]
                            debug!(target: "stdout", "End of sequence");

                            match stream_state {
                                StreamState::Usage => {
                                    *stream_state = StreamState::Done;

                                    // retrieve the number of prompt and completion tokens
                                    let token_info = get_token_info_by_graph(graph)?;

                                    let usage = Some(Usage {
                                        prompt_tokens: token_info.prompt_tokens,
                                        completion_tokens: token_info.completion_tokens,
                                        total_tokens: token_info.prompt_tokens
                                            + token_info.completion_tokens,
                                    });

                                    #[cfg(feature = "logging")]
                                    info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);

                                    let created = SystemTime::now()
                                        .duration_since(std::time::UNIX_EPOCH)
                                        .map_err(|e| {
                                            let err_msg = format!(
                                                "Failed to get the current time. Reason: {e}"
                                            );

                                            #[cfg(feature = "logging")]
                                            error!(target: "stdout", "{}", &err_msg);

                                            LlamaCoreError::Operation(err_msg)
                                        })?;

                                    let chat_completion_chunk = ChatCompletionChunk {
                                        id,
                                        object: "chat.completion.chunk".to_string(),
                                        created: created.as_secs(),
                                        model: graph.name().to_owned(),
                                        system_fingerprint: "fp_44709d6fcb".to_string(),
                                        choices: vec![],
                                        usage,
                                    };

                                    // serialize chat completion chunk
                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
                                        .map_err(|e| {
                                            let err_msg = format!(
                                                "Failed to serialize chat completion chunk. Reason: {e}"
                                            );

                                            #[cfg(feature = "logging")]
                                            error!(target: "stdout", "{}", &err_msg);

                                            LlamaCoreError::Operation(err_msg)
                                        })?;

                                    Ok(format!("data: {chunk_str}\n\n"))
                                }
                                StreamState::Done | StreamState::NoUsage => {
                                    *stream_state = StreamState::EndOfSequence;

                                    Ok("data: [DONE]\n\n".to_string())
                                }
                                StreamState::EndOfSequence => {
                                    Ok("[GGML] End of sequence".to_string())
                                }
                            }
                        }
                        Err(wasmedge_wasi_nn::Error::BackendError(
                            wasmedge_wasi_nn::BackendError::ContextFull,
                        )) => {
                            #[cfg(feature = "logging")]
                            debug!(target: "stdout", "Context full");

                            match context_full_state {
                                ContextFullState::Message => {
                                    match include_usage {
                                        true => *context_full_state = ContextFullState::Usage,
                                        false => *context_full_state = ContextFullState::Done,
                                    }

                                    let created = SystemTime::now()
                                        .duration_since(std::time::UNIX_EPOCH)
                                        .map_err(|e| {
                                            let err_msg = format!(
                                                "Failed to get the current time. Reason: {e}"
                                            );

                                            #[cfg(feature = "logging")]
                                            error!(target: "stdout", "{}", &err_msg);

                                            LlamaCoreError::Operation(err_msg)
                                        })?;

                                    let chat_completion_chunk = ChatCompletionChunk {
                                        id,
                                        object: "chat.completion.chunk".to_string(),
                                        created: created.as_secs(),
                                        model: graph.name().to_owned(),
                                        system_fingerprint: "fp_44709d6fcb".to_string(),
                                        choices: vec![ChatCompletionChunkChoice {
                                            index: 0,
                                            delta: ChatCompletionChunkChoiceDelta {
                                                role: ChatCompletionRole::Assistant,
                                                content: Some(
                                                    "<|WASMEDGE-GGML-CONTEXT-FULL|>".to_string(),
                                                ),
                                                tool_calls: vec![],
                                            },
                                            logprobs: None,
                                            finish_reason: Some(FinishReason::length),
                                        }],
                                        usage: None,
                                    };

                                    // serialize chat completion chunk
                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
                                        .map_err(|e| {
                                            let err_msg = format!(
                                                "Failed to serialize chat completion chunk. Reason: {e}"
                                            );

                                            #[cfg(feature = "logging")]
                                            error!(target: "stdout", "{}", &err_msg);

                                            LlamaCoreError::Operation(err_msg)
                                        })?;

                                    Ok(format!("data: {chunk_str}\n\n"))
                                }
                                ContextFullState::Usage => {
                                    *context_full_state = ContextFullState::Done;

                                    // retrieve the number of prompt and completion tokens
                                    let token_info = get_token_info_by_graph(graph)?;

                                    let usage = Some(Usage {
                                        prompt_tokens: token_info.prompt_tokens,
                                        completion_tokens: token_info.completion_tokens,
                                        total_tokens: token_info.prompt_tokens
                                            + token_info.completion_tokens,
                                    });

                                    let created = SystemTime::now()
                                        .duration_since(std::time::UNIX_EPOCH)
                                        .map_err(|e| {
                                            let err_msg = format!(
                                                "Failed to get the current time. Reason: {e}"
                                            );

                                            #[cfg(feature = "logging")]
                                            error!(target: "stdout", "{}", &err_msg);

                                            LlamaCoreError::Operation(err_msg)
                                        })?;

                                    let chat_completion_chunk = ChatCompletionChunk {
                                        id,
                                        object: "chat.completion.chunk".to_string(),
                                        created: created.as_secs(),
                                        model: graph.name().to_owned(),
                                        system_fingerprint: "fp_44709d6fcb".to_string(),
                                        choices: vec![],
                                        usage,
                                    };

                                    // serialize chat completion chunk
                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
                                        .map_err(|e| {
                                            let err_msg = format!(
                                                "Failed to serialize chat completion chunk. Reason: {e}"
                                            );

                                            #[cfg(feature = "logging")]
                                            error!(target: "stdout", "{}", &err_msg);

                                            LlamaCoreError::Operation(err_msg)
                                        })?;

                                    Ok(format!("data: {chunk_str}\n\n"))
                                }
                                ContextFullState::Done => {
                                    *context_full_state = ContextFullState::EndOfSequence;

                                    Ok("data: [DONE]\n\n".to_string())
                                }
                                ContextFullState::EndOfSequence => {
                                    Ok("[GGML] End of sequence".to_string())
                                }
                            }
                        }
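                        // Unlike `ContextFull` (context exhausted during generation),
                        // `PromptTooLong` means the prompt itself did not fit, so the
                        // chunk carries no sentinel content, only the `length` finish
                        // reason.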
                        Err(wasmedge_wasi_nn::Error::BackendError(
                            wasmedge_wasi_nn::BackendError::PromptTooLong,
                        )) => {
                            #[cfg(feature = "logging")]
                            debug!(target: "stdout", "Prompt too long");

                            match prompt_too_long_state {
                                PromptTooLongState::Message => {
                                    match include_usage {
                                        true => *prompt_too_long_state = PromptTooLongState::Usage,
                                        false => *prompt_too_long_state = PromptTooLongState::Done,
                                    }

                                    let created = SystemTime::now()
                                        .duration_since(std::time::UNIX_EPOCH)
                                        .map_err(|e| {
                                            let err_msg = format!(
                                                "Failed to get the current time. Reason: {e}"
                                            );

                                            #[cfg(feature = "logging")]
                                            error!(target: "stdout", "{}", &err_msg);

                                            LlamaCoreError::Operation(err_msg)
                                        })?;

                                    let chat_completion_chunk = ChatCompletionChunk {
                                        id,
                                        object: "chat.completion.chunk".to_string(),
                                        created: created.as_secs(),
                                        model: graph.name().to_owned(),
                                        system_fingerprint: "fp_44709d6fcb".to_string(),
                                        choices: vec![ChatCompletionChunkChoice {
                                            index: 0,
                                            delta: ChatCompletionChunkChoiceDelta {
                                                role: ChatCompletionRole::Assistant,
                                                content: None,
                                                tool_calls: vec![],
                                            },
                                            logprobs: None,
                                            finish_reason: Some(FinishReason::length),
                                        }],
                                        usage: None,
                                    };

                                    // serialize chat completion chunk
                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
                                        .map_err(|e| {
                                            let err_msg = format!(
                                                "Failed to serialize chat completion chunk. Reason: {e}"
                                            );

                                            #[cfg(feature = "logging")]
                                            error!(target: "stdout", "{}", &err_msg);

                                            LlamaCoreError::Operation(err_msg)
                                        })?;

                                    Ok(format!("data: {chunk_str}\n\n"))
                                }
                                PromptTooLongState::Usage => {
                                    *prompt_too_long_state = PromptTooLongState::Done;

                                    // retrieve the number of prompt and completion tokens
                                    let token_info = get_token_info_by_graph(graph)?;

                                    let usage = Some(Usage {
                                        prompt_tokens: token_info.prompt_tokens,
                                        completion_tokens: token_info.completion_tokens,
                                        total_tokens: token_info.prompt_tokens
                                            + token_info.completion_tokens,
                                    });

                                    let created = SystemTime::now()
                                        .duration_since(std::time::UNIX_EPOCH)
                                        .map_err(|e| {
                                            let err_msg = format!(
                                                "Failed to get the current time. Reason: {e}"
                                            );

                                            #[cfg(feature = "logging")]
                                            error!(target: "stdout", "{}", &err_msg);

                                            LlamaCoreError::Operation(err_msg)
                                        })?;

                                    let chat_completion_chunk = ChatCompletionChunk {
                                        id,
                                        object: "chat.completion.chunk".to_string(),
                                        created: created.as_secs(),
                                        model: graph.name().to_owned(),
                                        system_fingerprint: "fp_44709d6fcb".to_string(),
                                        choices: vec![],
                                        usage,
                                    };

                                    // serialize chat completion chunk
                                    let chunk_str = serde_json::to_string(&chat_completion_chunk)
                                        .map_err(|e| {
                                            let err_msg = format!(
                                                "Failed to serialize chat completion chunk. Reason: {e}"
                                            );

                                            #[cfg(feature = "logging")]
                                            error!(target: "stdout", "{}", &err_msg);

                                            LlamaCoreError::Operation(err_msg)
                                        })?;

                                    Ok(format!("data: {chunk_str}\n\n"))
                                }
                                PromptTooLongState::Done => {
                                    *prompt_too_long_state = PromptTooLongState::EndOfSequence;

                                    Ok("data: [DONE]\n\n".to_string())
                                }
                                PromptTooLongState::EndOfSequence => {
                                    Ok("[GGML] End of sequence".to_string())
                                }
                            }
                        }
                        Err(e) => {
                            let err_msg =
                                format!("Failed to compute the chat completion. Reason: {e}");

                            #[cfg(feature = "logging")]
                            error!(target: "stdout", "{}", &err_msg);

                            Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
                                err_msg,
                            )))
                        }
                    }
                }
                None => {
                    let err_msg = "There is no model available in the chat graphs.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    Err(LlamaCoreError::Operation(err_msg.into()))
                }
            }
        }
    };

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Return the chat stream chunk!");

    res
}

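/// Result of parsing a raw model generation: the raw text, the extracted
/// assistant content (if any), and any tool calls recovered from it.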
#[allow(dead_code)]
#[derive(Debug)]
struct ParseResult {
    raw: String,
    content: Option<String>,
    tool_calls: Vec<ToolCall>,
}