#[allow(unused_imports)]
use crate::{
error,
metadata::ggml::GgmlMetadata,
running_mode,
utils::{
gen_chat_id, gen_response_id, get_output_buffer, get_output_buffer_single,
get_token_info_by_graph, get_token_info_by_graph_name, set_tensor_data_u8,
},
Graph, RunningMode, CHAT_GRAPHS, OUTPUT_TENSOR,
};
use chat_prompts::{BuildChatPrompt, ChatPrompt, PromptTemplateType};
use either::{Either, Left, Right};
use endpoints::{
chat::{ChatCompletionRequestMessage, ChatCompletionRole, ChatCompletionUserMessageContent},
responses::{
items::{ResponseOutputItem, ResponseOutputItemOutputMessageContent},
response_object::{
Input, InputItem, InputMessageContent, InputTokensDetails, OutputTokensDetails,
RequestOfModelResponse, ResponseObject, ToolChoice, Usage,
},
},
};
use error::{BackendError, LlamaCoreError};
use std::{
collections::{HashMap, VecDeque},
pin::Pin,
sync::{
atomic::{AtomicBool, Ordering},
Mutex, OnceLock,
},
task::{Context, Poll, Waker},
time::SystemTime,
};
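/// Wakers of `ChatStream`s that are waiting for the currently active stream to finish.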
static CHAT_STREAM_WAKER_QUEUE: OnceLock<Mutex<VecDeque<Waker>>> = OnceLock::new();
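/// Whether a `ChatStream` currently holds the exclusive right to run inference.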
static CHAT_STREAM_ACTIVE: AtomicBool = AtomicBool::new(false);
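/// Process a chat completion request, dispatching to stream or non-stream handling
/// based on the `stream` field of the request.
///
/// Returns `Left(stream)` in stream mode or `Right(response)` otherwise, paired with a
/// flag indicating whether the output contains tool calls.
///
/// A minimal usage sketch (assuming a tokio runtime and an already populated
/// `RequestOfModelResponse`):
///
/// ```ignore
/// let (result, _include_tool_calls) = chat(&mut request).await?;
/// match result {
///     Left(stream) => { /* forward the chunks to the client, e.g. as SSE events */ }
///     Right(response) => { /* serialize the `ResponseObject` to JSON */ }
/// }
/// ```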
pub async fn chat(
chat_request: &mut RequestOfModelResponse,
) -> Result<
(
Either<impl futures::TryStream<Ok = String, Error = LlamaCoreError>, ResponseObject>,
bool,
),
LlamaCoreError,
> {
#[cfg(feature = "logging")]
{
debug!(target: "stdout", "tool choice: {:?}", &chat_request.tool_choice);
debug!(target: "stdout", "tools: {:?}", chat_request.tools.as_ref());
debug!(target: "stdout", "stream mode: {:?}", chat_request.stream);
}
let result = match chat_request.stream {
Some(true) => match chat_stream(chat_request).await {
Ok((stream, include_tool_calls)) => Ok((Left(stream), include_tool_calls)),
Err(e) => Err(e),
},
Some(false) | None => match chat_once(chat_request).await {
Ok((chat_completion_object, include_tool_calls)) => {
Ok((Right(chat_completion_object), include_tool_calls))
}
Err(e) => Err(e),
},
};
#[cfg(feature = "logging")]
info!(target: "stdout", "Reset the model metadata");
result
}
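/// Process a chat completion request in stream mode, returning the token stream and a
/// flag indicating whether the output contains tool calls.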
async fn chat_stream(
chat_request: &mut RequestOfModelResponse,
) -> Result<
(
impl futures::TryStream<Ok = String, Error = LlamaCoreError>,
bool,
),
LlamaCoreError,
> {
#[cfg(feature = "logging")]
info!(target: "stdout", "Process chat completion request in the stream mode");
let running_mode = running_mode()?;
if !running_mode.contains(RunningMode::CHAT) && !running_mode.contains(RunningMode::RAG) {
let err_msg = "The chat completion is only supported in the chat or rag mode.";
#[cfg(feature = "logging")]
error!(target: "stdout", "{err_msg}");
return Err(LlamaCoreError::Operation(err_msg.to_string()));
}
let model_name = chat_request.model.clone();
#[cfg(feature = "logging")]
info!(target: "stdout", "Check model metadata");
let mut metadata = check_model_metadata(chat_request)?;
#[cfg(feature = "logging")]
info!(target: "stdout", "Build the chat prompt");
    let (prompt, available_completion_tokens, tool_use) =
        build_prompt(model_name.as_ref(), chat_request)?;
#[cfg(feature = "logging")]
{
info!(target: "stdout", "prompt:\n{}", &prompt);
info!(target: "stdout", "available_completion_tokens: {avaible_completion_tokens}");
info!(target: "stdout", "tool_use: {tool_use}");
}
#[cfg(feature = "logging")]
info!(target: "stdout", "Update the n_predict");
    update_n_predict(chat_request, &mut metadata, available_completion_tokens)?;
#[cfg(feature = "logging")]
info!(target: "stdout", "Feed the prompt to the model");
set_prompt(chat_request.model.as_ref(), &prompt)?;
let stream = match tool_use {
false => (ChatStream::new(model_name, None), false),
true => {
todo!("Implement the streaming with tool use")
}
};
#[cfg(feature = "logging")]
info!(target: "stdout", "End of the chat completion stream.");
Ok(stream)
}
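/// Process a chat completion request in non-stream mode, returning the complete
/// `ResponseObject` and a flag indicating whether the output contains tool calls.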
async fn chat_once(
chat_request: &mut RequestOfModelResponse,
) -> Result<(ResponseObject, bool), LlamaCoreError> {
#[cfg(feature = "logging")]
info!(target: "stdout", "Processing chat completion request in non-stream mode");
let running_mode = running_mode()?;
if !running_mode.contains(RunningMode::CHAT) && !running_mode.contains(RunningMode::RAG) {
let err_msg = "The chat completion is only supported in the chat or rag mode.";
#[cfg(feature = "logging")]
error!(target: "stdout", "{err_msg}");
return Err(LlamaCoreError::Operation(err_msg.to_string()));
}
let model_name = chat_request.model.clone();
#[cfg(feature = "logging")]
info!(target: "stdout", "Check model metadata");
let mut metadata = check_model_metadata(chat_request)?;
#[cfg(feature = "logging")]
info!(target: "stdout", "Build the chat prompt");
    let (prompt, available_completion_tokens, tool_use) =
        build_prompt(model_name.as_ref(), chat_request)?;
#[cfg(feature = "logging")]
{
info!(target: "stdout", "prompt:\n{}", &prompt);
info!(target: "stdout", "available_completion_tokens: {avaible_completion_tokens}");
info!(target: "stdout", "tool_use: {tool_use}");
}
#[cfg(feature = "logging")]
info!(target: "stdout", "Update n_predict");
    update_n_predict(chat_request, &mut metadata, available_completion_tokens)?;
#[cfg(feature = "logging")]
info!(target: "stdout", "Feed the prompt to the model");
set_prompt(model_name.as_ref(), &prompt)?;
#[cfg(feature = "logging")]
info!(target: "stdout", "Compute chat completion.");
let res = compute(chat_request, model_name.as_ref(), tool_use);
#[cfg(feature = "logging")]
info!(target: "stdout", "End of the chat completion");
reset_model_metadata(model_name.as_ref())?;
res
}
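/// Run the computation on the model named `model_name`, falling back to the first
/// available graph when the name is missing or unknown.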
fn compute(
chat_request: &mut RequestOfModelResponse,
model_name: Option<&String>,
tool_use: bool,
) -> Result<(ResponseObject, bool), LlamaCoreError> {
let chat_graphs = match CHAT_GRAPHS.get() {
Some(chat_graphs) => chat_graphs,
None => {
let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
return Err(LlamaCoreError::Operation(err_msg.into()));
}
};
let mut chat_graphs = chat_graphs.lock().map_err(|e| {
let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
LlamaCoreError::Operation(err_msg)
})?;
match model_name {
Some(model_name) => match chat_graphs.contains_key(model_name) {
true => {
let graph = chat_graphs.get_mut(model_name).unwrap();
compute_by_graph(chat_request, graph, tool_use)
}
false => match chat_graphs.iter_mut().next() {
Some((_, graph)) => compute_by_graph(chat_request, graph, tool_use),
None => {
let err_msg = "There is no model available in the chat graphs.";
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
Err(LlamaCoreError::Operation(err_msg.into()))
}
},
},
None => match chat_graphs.iter_mut().next() {
Some((_, graph)) => compute_by_graph(chat_request, graph, tool_use),
None => {
let err_msg = "There is no model available in the chat graphs.";
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
Err(LlamaCoreError::Operation(err_msg.into()))
}
},
}
}
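/// Run a single inference pass on `graph` and convert the raw output into a
/// `ResponseObject`. The boolean in the returned tuple indicates whether the output
/// contains tool calls.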
fn compute_by_graph(
chat_request: &mut RequestOfModelResponse,
graph: &mut Graph<GgmlMetadata>,
tool_use: bool,
) -> Result<(ResponseObject, bool), LlamaCoreError> {
#[cfg(feature = "logging")]
info!(target: "stdout", "Compute chat completion by the model named {}.", graph.name());
match graph.compute() {
Ok(_) => {
let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
let err_msg = format!(
"Failed to decode the buffer of the inference result to a utf-8 string. {e}"
);
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
LlamaCoreError::Operation(err_msg)
})?;
#[cfg(feature = "logging")]
info!(target: "stdout", "raw generation: {output}");
let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
LlamaCoreError::Operation(format!("Failed to post-process the output. {e}"))
})?;
#[cfg(feature = "logging")]
info!(target: "stdout", "post-processed generation:\n{}", &message);
let token_info = get_token_info_by_graph(graph)?;
#[cfg(feature = "logging")]
info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
let created = SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map_err(|e| {
let err_msg = format!("Failed to get the current time. Reason: {e}");
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
LlamaCoreError::Operation(err_msg)
})?;
match tool_use {
true => {
if graph.metadata.prompt_template != PromptTemplateType::MistralTool
&& graph.metadata.prompt_template != PromptTemplateType::ChatMLTool
&& graph.metadata.prompt_template != PromptTemplateType::GroqLlama3Tool
&& graph.metadata.prompt_template != PromptTemplateType::Llama3Tool
&& graph.metadata.prompt_template != PromptTemplateType::InternLM2Tool
&& graph.metadata.prompt_template != PromptTemplateType::NemotronTool
&& graph.metadata.prompt_template != PromptTemplateType::FunctionaryV32
&& graph.metadata.prompt_template != PromptTemplateType::FunctionaryV31
&& graph.metadata.prompt_template != PromptTemplateType::MistralSmallTool
&& graph.metadata.prompt_template != PromptTemplateType::Llama4Chat
&& graph.metadata.prompt_template != PromptTemplateType::Qwen3NoThink
&& graph.metadata.prompt_template != PromptTemplateType::Smol3NoThink
&& graph.metadata.prompt_template != PromptTemplateType::Gemma3
&& graph.metadata.prompt_template != PromptTemplateType::GptOss
&& graph.metadata.prompt_template != PromptTemplateType::Qwen3Agent
&& graph.metadata.prompt_template != PromptTemplateType::SeedOssNoThink
&& graph.metadata.prompt_template != PromptTemplateType::SeedOssThink
{
let err_msg = format!("Unsupported prompt template: {}. The tool use is only supported for 'mistral-tool', 'chatml-tool', 'groq-llama3-tool', 'llama-3-tool', 'internlm-2-tool', 'nemotron-tool', 'functionary-31', 'functionary-32', 'mistral-small-tool', 'llama-4-chat', 'qwen3-no-think', 'smol-3-no-think', 'gemma-3', 'gpt-oss', 'qwen3-agent', 'seed-oss-no-think', and 'seed-oss-think' prompt templates.", graph.metadata.prompt_template);
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
return Err(LlamaCoreError::Operation(err_msg));
}
unimplemented!("Return ResponseObject with tool calls")
}
false => {
let message = ResponseOutputItemOutputMessageContent::OutputText {
annotations: vec![],
text: message.to_string(),
ty: "output_text".to_string(),
logprobs: None,
};
let output_message = ResponseOutputItem::OutputMessage {
content: vec![message],
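                        // NOTE: fixed placeholder message id; it is not generated per response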
id: "msg_67ccd3acc8d48190a77525dc6de64b4104becb25c45c1d41".to_string(),
role: ChatCompletionRole::Assistant.to_string(),
status: "completed".to_string(),
ty: "message".to_string(),
};
let temperature = match &chat_request.temperature {
Some(t) => *t,
None => graph.metadata.temperature,
};
let top_p = match &chat_request.top_p {
Some(t) => *t,
None => graph.metadata.top_p,
};
let res = ResponseObject {
background: false,
conversation: None,
created_at: created.as_secs(),
error: None,
id: gen_response_id(),
incomplete_details: None,
instructions: None,
max_output_tokens: None,
max_tool_calls: None,
metadata: HashMap::new(),
model: graph.name().to_owned(),
object: "response".to_string(),
output: vec![output_message],
parallel_tool_calls: true,
previous_response_id: None,
safety_identifier: None,
status: "completed".to_string(),
temperature,
tool_choice: chat_request.tool_choice.clone(),
tools: chat_request.tools.clone(),
top_p,
truncation: None,
usage: Usage {
input_tokens: token_info.prompt_tokens,
input_tokens_details: InputTokensDetails { cached_tokens: 0 },
output_tokens: token_info.completion_tokens,
output_tokens_details: OutputTokensDetails {
reasoning_tokens: 0,
},
total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
},
};
Ok((res, false))
}
}
}
Err(wasmedge_wasi_nn::Error::BackendError(wasmedge_wasi_nn::BackendError::ContextFull)) => {
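            // The context window is full: salvage the partially generated output and
            // return it as a completed response instead of failing the request.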
let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
let err_msg = format!(
"Failed to decode the buffer of the inference result to a utf-8 string. {e}"
);
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
LlamaCoreError::Operation(err_msg)
})?;
let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
let err_msg = format!("Failed to post-process the output. {e}");
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
LlamaCoreError::Operation(err_msg)
})?;
let token_info = get_token_info_by_graph(graph)?;
#[cfg(feature = "logging")]
info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
let created = SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map_err(|e| {
let err_msg = format!("Failed to get the current time. Reason: {e}");
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
LlamaCoreError::Operation(err_msg)
})?;
let message = ResponseOutputItemOutputMessageContent::OutputText {
annotations: vec![],
text: message.to_string(),
ty: "output_text".to_string(),
logprobs: None,
};
let output_message = ResponseOutputItem::OutputMessage {
content: vec![message],
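                // NOTE: fixed placeholder message id; it is not generated per response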
id: "msg_67ccd3acc8d48190a77525dc6de64b4104becb25c45c1d41".to_string(),
role: ChatCompletionRole::Assistant.to_string(),
status: "completed".to_string(),
ty: "message".to_string(),
};
let temperature = match &chat_request.temperature {
Some(t) => *t,
None => graph.metadata.temperature,
};
let top_p = match &chat_request.top_p {
Some(t) => *t,
None => graph.metadata.top_p,
};
let res = ResponseObject {
background: false,
conversation: None,
created_at: created.as_secs(),
error: None,
id: gen_response_id(),
incomplete_details: None,
instructions: None,
max_output_tokens: None,
max_tool_calls: None,
metadata: HashMap::new(),
model: graph.name().to_owned(),
object: "response".to_string(),
output: vec![output_message],
parallel_tool_calls: true,
previous_response_id: None,
safety_identifier: None,
status: "completed".to_string(),
temperature,
tool_choice: chat_request.tool_choice.clone(),
tools: chat_request.tools.clone(),
top_p,
truncation: None,
usage: Usage {
input_tokens: token_info.prompt_tokens,
input_tokens_details: InputTokensDetails { cached_tokens: 0 },
output_tokens: token_info.completion_tokens,
output_tokens_details: OutputTokensDetails {
reasoning_tokens: 0,
},
total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
},
};
Ok((res, false))
}
Err(wasmedge_wasi_nn::Error::BackendError(
wasmedge_wasi_nn::BackendError::PromptTooLong,
)) => {
#[cfg(feature = "logging")]
warn!(target: "stdout", "The prompt is too long. Please reduce the length of your input and try again.");
let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
let err_msg = format!(
"Failed to decode the buffer of the inference result to a utf-8 string. {e}"
);
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
LlamaCoreError::Operation(err_msg)
})?;
let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
let err_msg = format!("Failed to post-process the output. {e}");
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
LlamaCoreError::Operation(err_msg)
})?;
let token_info = get_token_info_by_graph(graph)?;
#[cfg(feature = "logging")]
info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
let created = SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map_err(|e| {
let err_msg = format!("Failed to get the current time. Reason: {e}");
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
LlamaCoreError::Operation(err_msg)
})?;
let message = ResponseOutputItemOutputMessageContent::OutputText {
annotations: vec![],
text: message.to_string(),
ty: "output_text".to_string(),
logprobs: None,
};
let output_message = ResponseOutputItem::OutputMessage {
content: vec![message],
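                // NOTE: fixed placeholder message id; it is not generated per response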
id: "msg_67ccd3acc8d48190a77525dc6de64b4104becb25c45c1d41".to_string(),
role: ChatCompletionRole::Assistant.to_string(),
status: "completed".to_string(),
ty: "message".to_string(),
};
let temperature = match &chat_request.temperature {
Some(t) => *t,
None => graph.metadata.temperature,
};
let top_p = match &chat_request.top_p {
Some(t) => *t,
None => graph.metadata.top_p,
};
let res = ResponseObject {
background: false,
conversation: None,
created_at: created.as_secs(),
error: None,
id: gen_response_id(),
incomplete_details: None,
instructions: None,
max_output_tokens: None,
max_tool_calls: None,
metadata: HashMap::new(),
model: graph.name().to_owned(),
object: "response".to_string(),
output: vec![output_message],
parallel_tool_calls: true,
previous_response_id: None,
safety_identifier: None,
status: "completed".to_string(),
temperature,
tool_choice: chat_request.tool_choice.clone(),
tools: chat_request.tools.clone(),
top_p,
truncation: None,
usage: Usage {
input_tokens: token_info.prompt_tokens,
input_tokens_details: InputTokensDetails { cached_tokens: 0 },
output_tokens: token_info.completion_tokens,
output_tokens_details: OutputTokensDetails {
reasoning_tokens: 0,
},
total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
},
};
Ok((res, false))
}
Err(e) => {
let err_msg = format!("Failed to compute the chat completion. Reason: {e}");
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
Err(LlamaCoreError::Backend(BackendError::Compute(err_msg)))
}
}
}
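/// Merge the request parameters (`temperature`, `top_p`) into the model metadata,
/// disable the embeddings mode, and push the updated metadata to the model if anything
/// changed.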
fn check_model_metadata(
chat_request: &RequestOfModelResponse,
) -> Result<GgmlMetadata, LlamaCoreError> {
let mut should_update = false;
let mut metadata = get_model_metadata(chat_request.model.as_ref())?;
if let Some(temp) = chat_request.temperature {
if metadata.temperature != temp {
metadata.temperature = temp;
            should_update = true;
}
}
if let Some(top_p) = chat_request.top_p {
if metadata.top_p != top_p {
metadata.top_p = top_p;
            should_update = true;
}
}
if metadata.embeddings {
metadata.embeddings = false;
        should_update = true;
}
if should_update {
#[cfg(feature = "logging")]
info!(target: "stdout", "Update the model metadata.");
update_model_metadata(chat_request.model.as_ref(), &metadata)?;
}
Ok(metadata)
}
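/// Adjust `n_predict` from the request's `max_output_tokens` and the number of tokens
/// still available for the completion, pushing the updated metadata to the model if
/// needed.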
fn update_n_predict(
chat_request: &RequestOfModelResponse,
metadata: &mut GgmlMetadata,
available_completion_tokens: u64,
) -> Result<(), LlamaCoreError> {
let mut should_update = false;
#[cfg(feature = "logging")]
info!(target: "stdout", "n_predict: {}", metadata.n_predict);
if let Some(max_output_tokens) = chat_request.max_output_tokens {
if metadata.n_predict != max_output_tokens {
#[cfg(feature = "logging")]
info!(target: "stdout", "Update n_predict with max_output_tokens from {} to {}", metadata.n_predict, max_output_tokens);
metadata.n_predict = max_output_tokens;
            should_update = true;
}
}
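    // `n_predict` special values follow llama.cpp: -1 means no limit and -2 means fill
    // the remaining context.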
if metadata.n_predict == -2 {
#[cfg(feature = "logging")]
info!(target: "stdout", "Update n_predict with available_completion_tokens from {} to {}", metadata.n_predict, available_completion_tokens);
metadata.n_predict = available_completion_tokens as i32;
        should_update = true;
}
if metadata.n_predict == -1
|| (metadata.n_predict > 0 && metadata.n_predict < available_completion_tokens as i32)
|| (metadata.n_predict < 0 && metadata.n_predict != -2)
{
#[cfg(feature = "logging")]
info!(target: "stdout", "Update n_predict with available_completion_tokens from {} to {}", metadata.n_predict, available_completion_tokens);
metadata.n_predict = available_completion_tokens as i32;
        should_update = true;
}
if should_update {
#[cfg(feature = "logging")]
info!(target: "stdout", "Update the model metadata.");
update_model_metadata(chat_request.model.as_ref(), metadata)?;
}
Ok(())
}
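/// Strip template-specific stop words, control tokens, and reasoning markers from the
/// raw model output.
///
/// A small illustration (the exact tokens removed depend on the prompt template):
///
/// ```ignore
/// let cleaned = post_process("Hello!<|im_end|>", &PromptTemplateType::ChatML)?;
/// assert_eq!(cleaned, "Hello!");
/// ```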
fn post_process(
output: impl AsRef<str>,
template_ty: &PromptTemplateType,
) -> Result<String, String> {
let output = if *template_ty == PromptTemplateType::Baichuan2 {
if output.as_ref().contains("用户:") {
output.as_ref().trim_end_matches("用户:").trim().to_owned()
} else {
output.as_ref().trim().to_owned()
}
} else if *template_ty == PromptTemplateType::OpenChat {
if output.as_ref().contains("<|end_of_turn|>") {
output
.as_ref()
.trim_end_matches("<|end_of_turn|>")
.trim()
.to_owned()
} else {
output.as_ref().trim().to_owned()
}
} else if *template_ty == PromptTemplateType::GemmaInstruct
|| *template_ty == PromptTemplateType::Gemma3
{
let s = output.as_ref().trim();
if s.ends_with("<end_of_turn>") {
s.trim_end_matches("<end_of_turn>").trim().to_owned()
} else {
s.to_owned()
}
} else if *template_ty == PromptTemplateType::ChatML
|| *template_ty == PromptTemplateType::ChatMLTool
|| *template_ty == PromptTemplateType::InternLM2Tool
|| *template_ty == PromptTemplateType::MiniCPMV
{
let mut s = output.as_ref().trim();
if s.ends_with("<|endoftext|>") {
s = s.trim_end_matches("<|endoftext|>").trim();
}
if s.starts_with(":") {
s = s.trim_start_matches(":").trim();
}
let x = {
let pat = r#"<think>
</think>
"#;
if s.contains(pat) {
let x = s.replace(pat, "");
if x.starts_with("()") {
x.trim_start_matches("()").to_owned()
} else {
x.to_owned()
}
} else {
s.to_owned()
}
};
s = x.trim();
if s.contains("<|im_start|>") && s.contains("<|im_end|>") {
let idx_start = s.find("<|im_start|>").unwrap();
let idx_end = s.find("<|im_end|>").unwrap();
match idx_start <= idx_end {
true => s.split("<|im_start|>").collect::<Vec<_>>()[0]
.trim()
.to_owned(),
false => s.split("<|im_end|>").collect::<Vec<_>>()[0]
.trim()
.to_owned(),
}
} else if s.contains("<|im_start|>") {
s.split("<|im_start|>").collect::<Vec<_>>()[0]
.trim()
.to_owned()
} else if s.contains("<|im_end|>") {
let output = s.trim_end_matches("<|im_end|>").trim();
if output.starts_with(": ") {
output.trim_start_matches(": ").to_owned()
} else {
output.to_owned()
}
} else {
s.to_owned()
}
} else if *template_ty == PromptTemplateType::Zephyr
|| *template_ty == PromptTemplateType::MistralLite
|| *template_ty == PromptTemplateType::MistralTool
|| *template_ty == PromptTemplateType::MistralInstruct
|| *template_ty == PromptTemplateType::MistralSmallChat
|| *template_ty == PromptTemplateType::MistralSmallTool
|| *template_ty == PromptTemplateType::BreezeInstruct
{
if output.as_ref().contains("</s><") {
output.as_ref().trim_end_matches("</s><").trim().to_owned()
} else if output.as_ref().contains("</s>") {
output
.as_ref()
.strip_suffix("</s>")
.unwrap()
.trim()
.to_owned()
} else {
output.as_ref().trim().to_owned()
}
} else if *template_ty == PromptTemplateType::DeepseekChat {
if output.as_ref().contains("<|end_of_sentence|>") {
output
.as_ref()
.trim_end_matches("<|end_of_sentence|>")
.trim()
.replace("<|end_of_sentence|>", " ")
.trim()
.to_owned()
} else {
output.as_ref().trim().to_owned()
}
} else if *template_ty == PromptTemplateType::HumanAssistant {
if output.as_ref().contains("Human:") {
output.as_ref().trim_end_matches("Human:").trim().to_owned()
} else {
output.as_ref().trim().to_owned()
}
} else if *template_ty == PromptTemplateType::SolarInstruct {
let s = output.as_ref().trim();
if s.starts_with("### Answer") {
let s = s.trim_start_matches("###").trim();
if s.starts_with("Answer:\n") {
s.replace("Answer:\n", "Answer: ")
} else {
s.to_owned()
}
} else {
s.to_owned()
}
} else if *template_ty == PromptTemplateType::Llama2Chat
|| *template_ty == PromptTemplateType::NemotronTool
|| *template_ty == PromptTemplateType::NemotronChat
{
let s = output.as_ref().trim();
if s.ends_with("</s>") {
s.trim_end_matches("</s>").trim().to_owned()
} else {
s.to_owned()
}
} else if *template_ty == PromptTemplateType::Llama3Chat
|| *template_ty == PromptTemplateType::GroqLlama3Tool
|| *template_ty == PromptTemplateType::Llama3Tool
|| *template_ty == PromptTemplateType::FunctionaryV32
{
let s = output.as_ref().trim();
if s.ends_with("<|eot_id|>") {
s.trim_end_matches("<|eot_id|>").trim().to_owned()
} else {
s.to_owned()
}
} else if *template_ty == PromptTemplateType::Phi3Chat {
let s = output.as_ref().trim();
if s.ends_with("<|end|>") {
s.trim_end_matches("<|end|>").trim().to_owned()
} else {
s.to_owned()
}
} else if *template_ty == PromptTemplateType::Phi4Chat {
let mut s = output.as_ref().trim();
if s.starts_with("think>") {
s = s.trim_start_matches("think>").trim();
}
if s.ends_with("<|im_end|>") {
s.trim_end_matches("<|im_end|>").trim().to_owned()
} else if s.ends_with("<|end|>") {
s.trim_end_matches("<|end|>").trim().to_owned()
} else {
s.to_owned()
}
} else if *template_ty == PromptTemplateType::FunctionaryV31 {
let mut s = output.as_ref().trim();
if s.ends_with("<|eot_id|>") {
s = s.trim_end_matches("<|eot_id|>").trim();
}
if s.ends_with("<|eom_id|>") {
s = s.trim_end_matches("<|eom_id|>").trim();
}
s.to_owned()
} else if *template_ty == PromptTemplateType::MoxinChat
|| *template_ty == PromptTemplateType::MoxinInstruct
{
let s = output.as_ref().trim();
if s.ends_with("</s>") {
s.trim_end_matches("</s>").trim().to_owned()
} else if s.ends_with("[INST]") {
s.trim_end_matches("[INST]").trim().to_owned()
} else {
s.to_owned()
}
} else if *template_ty == PromptTemplateType::Falcon3 {
let s = output.as_ref().trim();
if s.ends_with("<|endoftext|>") {
s.trim_end_matches("<|endoftext|>").trim().to_owned()
} else {
s.to_owned()
}
} else if *template_ty == PromptTemplateType::Megrez {
let s = output.as_ref().trim();
if s.ends_with("<|turn_end|>") {
s.trim_end_matches("<|turn_end|>").trim().to_owned()
} else {
s.to_owned()
}
} else if *template_ty == PromptTemplateType::Qwen2vl
|| *template_ty == PromptTemplateType::Qwen3NoThink
|| *template_ty == PromptTemplateType::ChatMLThink
{
let mut s = output.as_ref().trim();
if s.starts_with(":") {
s = s.trim_start_matches(":").trim();
}
if s.starts_with("</think>") {
s = s.trim_start_matches("</think>").trim();
}
if s.ends_with("<|im_end|>") {
s.trim_end_matches("<|im_end|>").trim().to_owned()
} else {
s.to_owned()
}
} else if *template_ty == PromptTemplateType::VicunaLlava {
let s = output.as_ref().trim();
if s.ends_with("</s>") {
s.trim_end_matches("</s>").trim().to_owned()
} else {
s.to_owned()
}
} else if *template_ty == PromptTemplateType::ExaoneDeepChat
|| *template_ty == PromptTemplateType::ExaoneChat
{
let mut s = output.as_ref().trim();
if s.ends_with("[|endofturn|]") {
s = s.trim_end_matches("[|endofturn|]").trim();
}
s.to_owned()
} else if *template_ty == PromptTemplateType::Llama4Chat {
let mut s = output.as_ref().trim();
if s.ends_with("<|eot|>") {
s = s.trim_end_matches("<|eot|>").trim();
}
s.to_owned()
} else if *template_ty == PromptTemplateType::Smolvl {
let mut s = output.as_ref().trim();
if s.starts_with(":") {
s = s.trim_start_matches(":").trim();
}
if s.ends_with("<end_of_utterance>") {
s = s.trim_end_matches("<end_of_utterance>").trim();
}
if s.contains("<end_of_utterance>:") {
let parts = s.split("<end_of_utterance>:").collect::<Vec<_>>();
parts.last().unwrap().trim().to_owned()
} else {
s.to_owned()
}
} else if *template_ty == PromptTemplateType::Smol3NoThink {
let mut s = output.as_ref().trim();
if s.ends_with("<|im_end|>") {
s = s.trim_end_matches("<|im_end|>").trim();
}
let re = regex::Regex::new(r"(?s)^<think>.*?</think>\s*").unwrap();
re.replace(s, "").to_string()
} else if *template_ty == PromptTemplateType::GptOss {
let s = output.as_ref().trim();
let re =
regex::Regex::new(r"(?s).*<\|channel\|>final<\|message\|>(.*?)<\|return\|>$").unwrap();
if let Some(caps) = re.captures(s) {
let extracted = &caps[1];
extracted.to_owned()
} else {
s.to_owned()
}
} else if *template_ty == PromptTemplateType::Qwen3Agent {
let mut s = output.as_ref().trim();
if s.starts_with(":") {
s = s.trim_start_matches(":").trim();
}
if s.starts_with("</think>") {
s = s.trim_start_matches("</think>").trim();
}
if s.ends_with("<|im_end|>") {
s = s.trim_end_matches("<|im_end|>").trim();
}
if s.contains("<final_answer>") && !s.contains("</final_answer>") {
format!("{s}</final_answer>")
} else {
s.to_owned()
}
} else if *template_ty == PromptTemplateType::SeedOssNoThink {
let s = output.as_ref().trim();
let re = regex::Regex::new(r"(?s)</seed:think>\s*(.*?)\s*<seed:eos>").unwrap();
if let Some(caps) = re.captures(s) {
let extracted = &caps[1];
extracted.to_owned()
} else {
s.to_owned()
}
} else {
output.as_ref().trim().to_owned()
};
Ok(output)
}
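/// Build the chat prompt from the request input.
///
/// The prompt is rebuilt in a loop, pruning the oldest conversation messages until the
/// prompt fits within 4/5 of the context size. Returns the prompt, the number of tokens
/// available for the completion, and whether tools are in use.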
fn build_prompt(
model_name: Option<&String>,
chat_request: &mut RequestOfModelResponse,
) -> Result<(String, u64, bool), LlamaCoreError> {
let metadata = get_model_metadata(model_name)?;
let ctx_size = metadata.ctx_size as u64;
let chat_prompt = ChatPrompt::from(metadata.prompt_template);
let max_prompt_tokens = ctx_size * 4 / 5;
    let input = match chat_request.input.as_ref() {
        Some(input) => input,
        None => {
            let err_msg = "The `input` field of the request is empty.";
            #[cfg(feature = "logging")]
            error!(target: "stdout", "{err_msg}");
            return Err(LlamaCoreError::Operation(err_msg.to_owned()));
        }
    };
    let mut chat_completions_messages = to_chat_messages(input)?;
#[cfg(feature = "logging")]
debug!(target: "stdout", "converted chat messages: {chat_completions_messages:?}");
loop {
let (prompt, tool_use) = match &chat_request.tool_choice {
ToolChoice::None => {
match chat_prompt.build_with_tools(&mut chat_completions_messages, Some(&[])) {
Ok(prompt) => (prompt, false),
Err(e) => {
let err_msg = format!("Fail to build chat prompts. Reason: {e}");
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
return Err(LlamaCoreError::Operation(err_msg));
}
}
}
_ => todo!("Other tool choices are not supported yet."),
};
#[cfg(feature = "logging")]
info!(target: "stdout", "Try to set prompt: {prompt}");
set_prompt(model_name, &prompt)?;
let token_info = get_token_info_by_graph_name(model_name)?;
match token_info.prompt_tokens > max_prompt_tokens {
true => {
match chat_completions_messages[0].role() {
ChatCompletionRole::System => {
if chat_completions_messages.len() == 4
&& chat_completions_messages[1].role() == ChatCompletionRole::User
&& chat_completions_messages[2].role() == ChatCompletionRole::Assistant
&& chat_completions_messages[3].role() == ChatCompletionRole::Tool
{
let err_msg = format!(
"The number of prompt tokens ({}) is greater than the max prompt tokens ({}). Please increase the context size.",
token_info.prompt_tokens, max_prompt_tokens
);
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
return Err(LlamaCoreError::Operation(err_msg));
}
if chat_completions_messages.len() > 2 {
#[cfg(feature = "logging")]
info!(target: "stdout", "Prune chat history: current length {}", chat_completions_messages.len());
if chat_completions_messages[1].role() == ChatCompletionRole::User {
let user_message = chat_completions_messages.remove(1);
#[cfg(feature = "logging")]
info!(target: "stdout", "Remove a user message from the chat history: {user_message:?}");
}
while chat_completions_messages[1].role() != ChatCompletionRole::User {
let message = chat_completions_messages.remove(1);
#[cfg(feature = "logging")]
info!(target: "stdout", "Remove a {} message from the chat history: {:?}", message.role(), message);
if chat_completions_messages.len() == 1 {
let err_msg = format!("The last message in the chat history should be a user message, but found a {} message.", message.role());
#[cfg(feature = "logging")]
error!(target: "stdout", "{err_msg}");
return Err(LlamaCoreError::Operation(err_msg));
}
}
} else if token_info.prompt_tokens > ctx_size {
let err_msg = format!(
"The number of prompt tokens ({}) is greater than the context size ({}). Please increase the context size, or simplify the input message.",
token_info.prompt_tokens, ctx_size
);
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
return Err(LlamaCoreError::Operation(err_msg));
} else {
return Ok((prompt, ctx_size - token_info.prompt_tokens, tool_use));
}
}
ChatCompletionRole::User => {
                        if chat_completions_messages.len() == 3
                            && chat_completions_messages[1].role() == ChatCompletionRole::Assistant
                            && chat_completions_messages[2].role() == ChatCompletionRole::Tool
                        {
let err_msg = format!(
"The number of prompt tokens ({}) is greater than the max prompt tokens ({}). Please increase the context size.",
token_info.prompt_tokens, max_prompt_tokens
);
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
return Err(LlamaCoreError::Operation(err_msg));
}
if chat_completions_messages.len() > 1 {
if chat_completions_messages[0].role() == ChatCompletionRole::User {
let user_message = chat_completions_messages.remove(0);
#[cfg(feature = "logging")]
info!(target: "stdout", "Remove a user message from the chat history: {user_message:?}");
}
while chat_completions_messages[0].role() != ChatCompletionRole::User {
let message = chat_completions_messages.remove(0);
#[cfg(feature = "logging")]
info!(target: "stdout", "Remove a {} message from the chat history: {:?}", message.role(), message);
if chat_completions_messages.is_empty() {
let err_msg = format!("The last message in the chat history should be a user message, but found a {} message.", message.role());
#[cfg(feature = "logging")]
error!(target: "stdout", "{err_msg}");
return Err(LlamaCoreError::Operation(err_msg));
}
}
} else if token_info.prompt_tokens > ctx_size {
let err_msg = format!(
"The number of prompt tokens ({}) is greater than the context size ({}). Please increase the context size, or simplify the input message.",
token_info.prompt_tokens, ctx_size
);
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
return Err(LlamaCoreError::Operation(err_msg));
} else {
return Ok((prompt, ctx_size - token_info.prompt_tokens, tool_use));
}
}
_ => {
#[cfg(feature = "logging")]
info!(target: "stdout", "remove a {} message from the message queue", chat_completions_messages[0].role());
chat_completions_messages.remove(0);
}
}
continue;
}
false => return Ok((prompt, ctx_size - max_prompt_tokens, tool_use)),
}
}
}
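/// Write `prompt` into the input tensor of the target model (or the first available
/// one).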
fn set_prompt(model_name: Option<&String>, prompt: impl AsRef<str>) -> Result<(), LlamaCoreError> {
let chat_graphs = match CHAT_GRAPHS.get() {
Some(chat_graphs) => chat_graphs,
None => {
let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
return Err(LlamaCoreError::Operation(err_msg.into()));
}
};
let mut chat_graphs = chat_graphs.lock().map_err(|e| {
let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
LlamaCoreError::Operation(err_msg)
})?;
match model_name {
Some(model_name) => {
#[cfg(feature = "logging")]
info!(target: "stdout", "Set prompt to the chat model named {model_name}");
match chat_graphs.contains_key(model_name) {
true => {
let graph = chat_graphs.get_mut(model_name).unwrap();
let tensor_data = prompt.as_ref().as_bytes().to_vec();
set_tensor_data_u8(graph, 0, &tensor_data)
}
false => match chat_graphs.iter_mut().next() {
Some((_, graph)) => {
let tensor_data = prompt.as_ref().as_bytes().to_vec();
set_tensor_data_u8(graph, 0, &tensor_data)
}
None => {
let err_msg = "There is no model available in the chat graphs.";
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
Err(LlamaCoreError::Operation(err_msg.into()))
}
},
}
}
None => {
#[cfg(feature = "logging")]
info!(target: "stdout", "Set prompt to the default chat model.");
match chat_graphs.iter_mut().next() {
Some((_, graph)) => {
let tensor_data = prompt.as_ref().as_bytes().to_vec();
set_tensor_data_u8(graph, 0, &tensor_data)
}
None => {
let err_msg = "There is no model available in the chat graphs while trying to set prompt to the default model.";
#[cfg(feature = "logging")]
error!(target: "stdout", "{err_msg}");
Err(LlamaCoreError::Operation(err_msg.into()))
}
}
}
}
}
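/// Return a copy of the metadata of the target model (or the first available one).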
fn get_model_metadata(model_name: Option<&String>) -> Result<GgmlMetadata, LlamaCoreError> {
let chat_graphs = match CHAT_GRAPHS.get() {
Some(chat_graphs) => chat_graphs,
None => {
let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
#[cfg(feature = "logging")]
error!(target: "stdout", "{err_msg}");
return Err(LlamaCoreError::Operation(err_msg.into()));
}
};
let chat_graphs = chat_graphs.lock().map_err(|e| {
let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
LlamaCoreError::Operation(err_msg)
})?;
match model_name {
Some(model_name) => match chat_graphs.contains_key(model_name) {
true => {
let graph = chat_graphs.get(model_name).unwrap();
Ok(graph.metadata.clone())
}
false => match chat_graphs.iter().next() {
Some((_, graph)) => Ok(graph.metadata.clone()),
None => {
let err_msg = "There is no model available in the chat graphs.";
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
Err(LlamaCoreError::Operation(err_msg.into()))
}
},
},
None => match chat_graphs.iter().next() {
Some((_, graph)) => Ok(graph.metadata.clone()),
None => {
let err_msg = "There is no model available in the chat graphs.";
#[cfg(feature = "logging")]
error!(target: "stdout", "{err_msg}");
Err(LlamaCoreError::Operation(err_msg.into()))
}
},
}
}
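/// Serialize `metadata` to JSON and write it into the config tensor of the target model
/// (or the first available one).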
fn update_model_metadata(
model_name: Option<&String>,
metadata: &GgmlMetadata,
) -> Result<(), LlamaCoreError> {
let config = match serde_json::to_string(metadata) {
Ok(config) => config,
Err(e) => {
let err_msg = format!("Fail to serialize metadata to a JSON string. {e}");
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
return Err(LlamaCoreError::Operation(err_msg));
}
};
let chat_graphs = match CHAT_GRAPHS.get() {
Some(chat_graphs) => chat_graphs,
None => {
let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
#[cfg(feature = "logging")]
error!(target: "stdout", "{err_msg}");
return Err(LlamaCoreError::Operation(err_msg.into()));
}
};
let mut chat_graphs = chat_graphs.lock().map_err(|e| {
let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. Reason: {e}");
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
LlamaCoreError::Operation(err_msg)
})?;
match model_name {
Some(model_name) => {
match chat_graphs.contains_key(model_name) {
true => {
let graph = chat_graphs.get_mut(model_name).unwrap();
set_tensor_data_u8(graph, 1, config.as_bytes())
}
false => match chat_graphs.iter_mut().next() {
Some((_, graph)) => {
set_tensor_data_u8(graph, 1, config.as_bytes())
}
None => {
let err_msg = "There is no model available in the chat graphs.";
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
Err(LlamaCoreError::Operation(err_msg.into()))
}
},
}
}
None => {
match chat_graphs.iter_mut().next() {
Some((_, graph)) => {
set_tensor_data_u8(graph, 1, config.as_bytes())
}
None => {
let err_msg = "There is no model available in the chat graphs.";
#[cfg(feature = "logging")]
error!(target: "stdout", "{err_msg}");
Err(LlamaCoreError::Operation(err_msg.into()))
}
}
}
}
}
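/// Restore the model metadata cached in the graph, undoing any per-request overrides.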
fn reset_model_metadata(model_name: Option<&String>) -> Result<(), LlamaCoreError> {
let metadata = get_model_metadata(model_name)?;
update_model_metadata(model_name, &metadata)
}
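/// States for draining the remaining stream events after a context-full error.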
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[allow(dead_code)]
enum ContextFullState {
Message,
Usage,
Done,
EndOfSequence,
}
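/// States of the regular chat completion stream.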
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[allow(dead_code)]
enum StreamState {
Usage,
NoUsage,
Done,
EndOfSequence,
}
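/// States for draining the remaining stream events after a prompt-too-long error.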
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[allow(dead_code)]
enum PromptTooLongState {
Message,
Usage,
Done,
EndOfSequence,
}
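/// A stream of chat completion chunks.
///
/// Only one `ChatStream` may run inference at a time: creation tries to set the global
/// `CHAT_STREAM_ACTIVE` flag, and a stream that fails to set it parks its waker in
/// `CHAT_STREAM_WAKER_QUEUE` until the active stream is dropped.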
struct ChatStream {
model: Option<String>,
context_full_state: ContextFullState,
prompt_too_long_state: PromptTooLongState,
stream_state: StreamState,
cache: Option<VecDeque<String>>,
is_waiting: bool,
has_lock: bool,
}
impl ChatStream {
    fn new(model: Option<String>, cache: Option<Vec<String>>) -> Self {
let has_lock = CHAT_STREAM_ACTIVE
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
.is_ok();
#[cfg(feature = "logging")]
if !has_lock {
info!(target: "stdout", "Lock acquisition failed in ChatStream::new, creating with waiting status");
}
ChatStream {
model,
context_full_state: ContextFullState::Message,
prompt_too_long_state: PromptTooLongState::Message,
stream_state: StreamState::Usage,
cache: cache.map(VecDeque::from),
is_waiting: !has_lock,
has_lock,
}
}
fn try_acquire_lock(&mut self) -> bool {
if self.has_lock {
return true;
}
let acquired = CHAT_STREAM_ACTIVE
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
.is_ok();
if acquired {
self.has_lock = true;
self.is_waiting = false;
}
acquired
}
}
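// Dropping a `ChatStream` finishes the compute context, restores the model metadata,
// releases the global stream lock, and wakes the next waiting stream, if any.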
impl Drop for ChatStream {
fn drop(&mut self) {
if self.has_lock || (self.cache.is_none() && !self.is_waiting) {
#[cfg(feature = "logging")]
info!(target: "stdout", "Cleaning up context for ChatStream");
match &self.model {
Some(model_name) => {
match CHAT_GRAPHS.get() {
Some(chat_graphs) => {
match chat_graphs.lock() {
Ok(mut chat_graphs) => match chat_graphs.contains_key(model_name) {
true => {
let graph = chat_graphs.get_mut(model_name).unwrap();
if let Err(e) = graph.finish_single() {
let err_msg = format!(
"Failed to clean up the context. Reason: {e}"
);
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
#[cfg(not(feature = "logging"))]
println!(
"[ERROR][llama_core] Failed to clean up the context. Reason: {}",
&err_msg
);
}
}
false => match chat_graphs.iter_mut().next() {
Some((_, graph)) => {
if let Err(e) = graph.finish_single() {
let err_msg = format!(
"Failed to clean up the context. Reason: {e}"
);
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
#[cfg(not(feature = "logging"))]
println!(
"[ERROR][llama_core] Failed to clean up the context. Reason: {}",
&err_msg
);
}
}
None => {
let err_msg =
"There is no model available in the chat graphs.";
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
#[cfg(not(feature = "logging"))]
println!(
"[ERROR][llama_core] Failed to clean up the context. Reason: {}",
&err_msg
);
}
},
},
Err(e) => {
                                    let err_msg =
                                        format!("Failed to acquire the lock of `CHAT_GRAPHS`. {e}");
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
#[cfg(not(feature = "logging"))]
println!(
"[ERROR][llama_core] Failed to clean up the context. Reason: {}",
&err_msg
);
}
}
}
None => {
let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
#[cfg(not(feature = "logging"))]
println!(
"[ERROR][llama_core] Failed to clean up the context. Reason: {}",
&err_msg
);
}
};
}
None => {
match CHAT_GRAPHS.get() {
Some(chat_graphs) => {
match chat_graphs.lock() {
Ok(mut chat_graphs) => match chat_graphs.iter_mut().next() {
Some((_, graph)) => {
if let Err(e) = graph.finish_single() {
let err_msg = format!(
"Failed to clean up the context. Reason: {e}"
);
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
#[cfg(not(feature = "logging"))]
println!(
"[ERROR][llama_core] Failed to clean up the context. Reason: {}",
&err_msg
);
}
}
None => {
let err_msg =
"There is no model available in the chat graphs.";
#[cfg(feature = "logging")]
error!(target: "stdout", "{err_msg}");
#[cfg(not(feature = "logging"))]
println!(
"[ERROR][llama_core] Failed to clean up the context. Reason: {}",
err_msg
);
}
},
Err(e) => {
                                    let err_msg =
                                        format!("Failed to acquire the lock of `CHAT_GRAPHS`. {e}");
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
#[cfg(not(feature = "logging"))]
println!(
"[ERROR][llama_core] Failed to clean up the context. Reason: {}",
&err_msg
);
}
}
}
None => {
let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
#[cfg(not(feature = "logging"))]
println!(
"[ERROR][llama_core] Failed to clean up the context. Reason: {}",
&err_msg
);
}
};
}
}
#[cfg(feature = "logging")]
info!(target: "stdout", "Model context cleanup done!");
}
if let Err(e) = reset_model_metadata(self.model.as_ref()) {
let err_msg = format!("Fail to reset model metadata. Reason: {e}");
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
#[cfg(not(feature = "logging"))]
println!("[ERROR][llama_core] {}", &err_msg);
}
#[cfg(feature = "logging")]
info!(target: "stdout", "Model metadata reset done!");
if self.has_lock {
CHAT_STREAM_ACTIVE.store(false, Ordering::SeqCst);
#[cfg(feature = "logging")]
info!(target: "stdout", "Lock from ChatStream released");
if let Ok(mut queue) = get_chat_stream_waker_queue().lock() {
if let Some(waker) = queue.pop_front() {
#[cfg(feature = "logging")]
info!(target: "stdout", "Waking up a waiting ChatStream");
waker.wake();
}
}
}
}
}
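// Polling first (re)acquires the global stream lock; while another stream holds it, the
// waker is parked in the queue and `Poll::Pending` is returned.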
impl futures::Stream for ChatStream {
type Item = Result<String, LlamaCoreError>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
if this.is_waiting {
if !this.try_acquire_lock() {
if let Ok(mut queue) = get_chat_stream_waker_queue().lock() {
queue.retain(|w| !w.will_wake(cx.waker()));
queue.push_back(cx.waker().clone());
#[cfg(feature = "logging")]
debug!(target: "stdout", "ChatStream is waiting for lock, added waker to queue");
}
return Poll::Pending;
}
#[cfg(feature = "logging")]
info!(target: "stdout", "ChatStream acquired lock and is now active");
}
if !this.has_lock && !this.try_acquire_lock() {
this.is_waiting = true;
if let Ok(mut queue) = get_chat_stream_waker_queue().lock() {
queue.retain(|w| !w.will_wake(cx.waker()));
queue.push_back(cx.waker().clone());
}
return Poll::Pending;
}
if let Some(cache) = &mut this.cache {
let x = cache.pop_front();
#[cfg(feature = "logging")]
info!(target: "stdout", "Get the next item from the cache for ChatStream: {:?}", &x);
match x {
Some(x) => Poll::Ready(Some(Ok(x))),
None => Poll::Ready(None),
}
} else {
let res = compute_stream(
this.model.clone(),
&mut this.prompt_too_long_state,
&mut this.context_full_state,
&mut this.stream_state,
);
match res {
Ok(x) => {
#[cfg(feature = "logging")]
info!(target: "stdout", "next item for ChatStream: {}", &x);
if x != "[GGML] End of sequence" && !x.is_empty() {
Poll::Ready(Some(Ok(x)))
} else {
Poll::Ready(None)
}
}
Err(e) => Poll::Ready(Some(Err(e))),
}
}
}
}
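/// Get or initialize the global waker queue shared by waiting `ChatStream`s.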
fn get_chat_stream_waker_queue() -> &'static Mutex<VecDeque<Waker>> {
CHAT_STREAM_WAKER_QUEUE.get_or_init(|| {
#[cfg(feature = "logging")]
info!(target: "stdout", "Initializing ChatStream waker queue");
Mutex::new(VecDeque::new())
})
}
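/// Compute the next chunk of the active chat completion stream.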
#[allow(unused_variables)]
fn compute_stream(
model_name: Option<String>,
prompt_too_long_state: &mut PromptTooLongState,
context_full_state: &mut ContextFullState,
stream_state: &mut StreamState,
) -> Result<String, LlamaCoreError> {
    todo!("compute_stream is not implemented yet");
}
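/// Convert the `input` field of a Responses API request into chat completion messages.
/// Plain text becomes a single user message; unsupported input items are skipped with a
/// warning.
///
/// ```ignore
/// let messages = to_chat_messages(&Input::Text("Hello".to_string()))?;
/// assert_eq!(messages.len(), 1);
/// ```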
fn to_chat_messages(input: &Input) -> Result<Vec<ChatCompletionRequestMessage>, LlamaCoreError> {
match input {
Input::Text(text) => {
let content = ChatCompletionUserMessageContent::Text(text.clone());
let user_message = ChatCompletionRequestMessage::new_user_message(content, None);
Ok(vec![user_message])
}
Input::InputItemList(items) => {
let mut messages = Vec::new();
for item in items {
match item {
InputItem::InputMessage { content, role, .. } => {
let message = input_message_to_chat_message(content, role)?;
messages.push(message);
}
_ => {
#[cfg(feature = "logging")]
warn!(target: "stdout", "Skipping unsupported InputItem variant");
}
}
}
Ok(messages)
}
}
}
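/// Map an input message with the given `role` to the corresponding chat completion
/// message. Supported roles are `user`, `assistant`, `system`, and `developer`.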
fn input_message_to_chat_message(
content: &InputMessageContent,
role: &str,
) -> Result<ChatCompletionRequestMessage, LlamaCoreError> {
match role {
"user" => {
let content = input_message_content_to_chat_message_content(content)?;
let content = ChatCompletionUserMessageContent::Text(content);
Ok(ChatCompletionRequestMessage::new_user_message(
content, None,
))
}
"assistant" => {
let content = input_message_content_to_chat_message_content(content)?;
Ok(ChatCompletionRequestMessage::new_assistant_message(
Some(content),
None,
None,
))
}
"system" => {
let content = input_message_content_to_chat_message_content(content)?;
Ok(ChatCompletionRequestMessage::new_system_message(
content, None,
))
}
"developer" => {
let content = input_message_content_to_chat_message_content(content)?;
Ok(ChatCompletionRequestMessage::new_developer_message(
content, None,
))
}
_ => {
let error_msg = format!("Unsupported role: {}", role);
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &error_msg);
Err(LlamaCoreError::Operation(error_msg))
}
}
}
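/// Extract plain text from an input message content. Item content lists are not
/// supported.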
fn input_message_content_to_chat_message_content(
content: &InputMessageContent,
) -> Result<String, LlamaCoreError> {
match content {
InputMessageContent::Text(text) => Ok(text.clone()),
InputMessageContent::InputItemContentList(_items) => {
let error_msg = "Not support InputMessageContent::InputItemContentList";
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", error_msg);
Err(LlamaCoreError::Operation(error_msg.into()))
}
}
}