#![allow(dead_code)]
use crate::{api::ChatMessage, AppState};
use axum::{extract::State, response::IntoResponse, Json};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
#[derive(Debug, Deserialize)]
pub struct ChatCompletionRequest {
pub model: String,
pub messages: Vec<ChatMessage>,
#[serde(default)]
pub stream: Option<bool>,
#[serde(default)]
pub temperature: Option<f32>,
#[serde(default)]
pub max_tokens: Option<usize>,
#[serde(default)]
pub top_p: Option<f32>,
}
#[derive(Debug, Serialize)]
pub struct ChatCompletionResponse {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<Choice>,
pub usage: Usage,
}
#[derive(Debug, Serialize)]
pub struct Choice {
pub index: usize,
pub message: ChatMessage,
pub finish_reason: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Usage {
pub prompt_tokens: usize,
pub completion_tokens: usize,
pub total_tokens: usize,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ChatCompletionChunk {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<ChunkChoice>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ChunkChoice {
pub index: usize,
pub delta: Delta,
pub finish_reason: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Delta {
pub content: Option<String>,
pub role: Option<String>,
}
#[derive(Debug, Serialize)]
pub struct ModelsResponse {
pub object: String,
pub data: Vec<Model>,
}
#[derive(Debug, Serialize)]
pub struct Model {
pub id: String,
pub object: String,
pub created: u64,
pub owned_by: String,
}
pub async fn models(State(state): State<Arc<AppState>>) -> impl IntoResponse {
let models = state
.registry
.list_all_available()
.into_iter()
.map(|name| Model {
id: name,
object: "model".to_string(),
created: 0, owned_by: "shimmy".to_string(),
})
.collect();
Json(ModelsResponse {
object: "list".to_string(),
data: models,
})
}
pub async fn chat_completions(
State(state): State<Arc<AppState>>,
Json(req): Json<ChatCompletionRequest>,
) -> impl IntoResponse {
use axum::http::StatusCode;
let Some(spec) = state.registry.to_spec(&req.model) else {
tracing::warn!("Model '{}' not found in registry", req.model);
return StatusCode::NOT_FOUND.into_response();
};
tracing::debug!("Found model spec for '{}': {:?}", req.model, spec);
let engine = &state.engine;
let loaded = match engine.load(&spec).await {
Ok(loaded) => loaded,
Err(e) => {
tracing::error!("Failed to load model '{}': {:?}", req.model, e);
return StatusCode::BAD_GATEWAY.into_response();
}
};
let fam = match spec.template.as_deref() {
Some("chatml") => crate::templates::TemplateFamily::ChatML,
Some("llama3") | Some("llama-3") => crate::templates::TemplateFamily::Llama3,
_ => {
if req.model.to_lowercase().contains("qwen")
|| req.model.to_lowercase().contains("chatglm")
{
crate::templates::TemplateFamily::ChatML
} else if req.model.to_lowercase().contains("llama") {
crate::templates::TemplateFamily::Llama3
} else {
crate::templates::TemplateFamily::OpenChat
}
}
};
let pairs = req
.messages
.iter()
.map(|m| (m.role.clone(), m.content.clone()))
.collect::<Vec<_>>();
let last_user_message = req
.messages
.iter()
.filter(|m| m.role == "user")
.last()
.map(|m| m.content.as_str());
let history: Vec<_> = if last_user_message.is_some() {
req.messages
.iter()
.take(req.messages.len().saturating_sub(1))
.map(|m| (m.role.clone(), m.content.clone()))
.collect()
} else {
pairs.clone()
};
let prompt = fam.render(None, &history, last_user_message);
let mut opts = crate::engine::GenOptions::default();
if let Some(t) = req.temperature {
opts.temperature = t;
}
if let Some(p) = req.top_p {
opts.top_p = p;
}
if let Some(m) = req.max_tokens {
opts.max_tokens = m;
}
if let Some(s) = req.stream {
opts.stream = s;
}
if opts.stream {
use axum::response::sse::{Event, Sse};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;
let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<String>();
let mut opts_clone = opts.clone();
opts_clone.stream = false;
let prompt_clone = prompt.clone();
let model_clone = req.model.clone();
let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let id = format!("chatcmpl-{}", uuid::Uuid::new_v4().simple());
tokio::spawn(async move {
let tx_tokens = tx.clone();
let id_for_tokens = id.clone();
let model_for_tokens = model_clone.clone();
let id_for_final = id.clone();
let model_for_final = model_clone.clone();
let initial_chunk = ChatCompletionChunk {
id: id_for_tokens.clone(),
object: "chat.completion.chunk".to_string(),
created: timestamp,
model: model_for_tokens.clone(),
choices: vec![ChunkChoice {
index: 0,
delta: Delta {
role: Some("assistant".to_string()),
content: None,
},
finish_reason: None,
}],
};
let _ = tx_tokens.send(format!(
"data: {}\n\n",
serde_json::to_string(&initial_chunk).unwrap()
));
let _ = loaded
.generate(
&prompt_clone,
opts_clone,
Some(Box::new(move |tok| {
let chunk = ChatCompletionChunk {
id: id_for_tokens.clone(),
object: "chat.completion.chunk".to_string(),
created: timestamp,
model: model_for_tokens.clone(),
choices: vec![ChunkChoice {
index: 0,
delta: Delta {
role: None,
content: Some(tok),
},
finish_reason: None,
}],
};
let _ = tx_tokens.send(
serde_json::to_string(&chunk).unwrap()
);
})),
)
.await;
let final_chunk = ChatCompletionChunk {
id: id_for_final,
object: "chat.completion.chunk".to_string(),
created: timestamp,
model: model_for_final,
choices: vec![ChunkChoice {
index: 0,
delta: Delta {
role: None,
content: None,
},
finish_reason: Some("stop".to_string()),
}],
};
let _ = tx.send(
serde_json::to_string(&final_chunk).unwrap()
);
let _ = tx.send("[DONE]".to_string());
});
let stream = UnboundedReceiverStream::new(rx)
.map(|s| Ok::<Event, std::convert::Infallible>(Event::default().data(s)));
Sse::new(stream).into_response()
} else {
match loaded.generate(&prompt, opts, None).await {
Ok(content) => {
tracing::debug!(
"Generated response for model '{}': {} chars",
req.model,
content.len()
);
let response = ChatCompletionResponse {
id: format!("chatcmpl-{}", uuid::Uuid::new_v4().simple()),
object: "chat.completion".to_string(),
created: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
model: req.model,
choices: vec![Choice {
index: 0,
message: ChatMessage {
role: "assistant".to_string(),
content,
},
finish_reason: Some("stop".to_string()),
}],
usage: Usage {
prompt_tokens: 0, completion_tokens: 0,
total_tokens: 0,
},
};
Json(response).into_response()
}
Err(e) => {
tracing::error!(
"Failed to generate response for model '{}': {:?}",
req.model,
e
);
StatusCode::BAD_GATEWAY.into_response()
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::engine::adapter::InferenceEngineAdapter;
use crate::model_registry::Registry;
use crate::AppState;
use axum::{extract::State, Json};
use std::sync::Arc;
#[tokio::test]
async fn test_chat_completions_handler_execution() {
let registry = Registry::default();
let engine = Box::new(InferenceEngineAdapter::new());
let state = Arc::new(AppState { engine, registry });
let request = ChatCompletionRequest {
model: "test".to_string(),
messages: vec![],
temperature: None,
max_tokens: None,
top_p: None,
stream: Some(false),
};
let _result = chat_completions(State(state), Json(request)).await;
assert!(true);
}
#[tokio::test]
async fn test_models_handler_execution() {
let registry = Registry::default();
let engine = Box::new(InferenceEngineAdapter::new());
let state = Arc::new(AppState { engine, registry });
let _result = models(State(state)).await;
assert!(true);
}
#[test]
fn test_chat_completion_response_creation() {
let response = ChatCompletionResponse {
id: "test-id".to_string(),
object: "chat.completion".to_string(),
created: 1234567890,
model: "test-model".to_string(),
choices: vec![Choice {
index: 0,
message: ChatMessage {
role: "assistant".to_string(),
content: "Hello world".to_string(),
},
finish_reason: Some("stop".to_string()),
}],
usage: Usage {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15,
},
};
assert_eq!(response.id, "test-id");
assert_eq!(response.choices.len(), 1);
assert_eq!(response.choices[0].message.content, "Hello world");
}
#[test]
fn test_chunk_choice_creation() {
let choice = ChunkChoice {
index: 0,
delta: Delta {
role: Some("assistant".to_string()),
content: Some("token".to_string()),
},
finish_reason: None,
};
assert_eq!(choice.index, 0);
assert_eq!(choice.delta.content.unwrap(), "token");
}
#[tokio::test]
async fn test_chat_completions_model_not_found() {
let registry = Registry::default();
let engine = Box::new(InferenceEngineAdapter::new());
let state = Arc::new(AppState { engine, registry });
let request = ChatCompletionRequest {
model: "nonexistent-model".to_string(),
messages: vec![ChatMessage {
role: "user".to_string(),
content: "Hello".to_string(),
}],
stream: Some(false),
temperature: None,
max_tokens: None,
top_p: None,
};
let _response = chat_completions(State(state), Json(request)).await;
assert!(true); }
#[tokio::test]
async fn test_chat_completions_streaming_request() {
use crate::model_registry::ModelEntry;
let mut registry = Registry::default();
registry.register(ModelEntry {
name: "test-streaming".to_string(),
base_path: "./test.safetensors".into(),
lora_path: None,
template: Some("chatml".into()),
ctx_len: Some(2048),
n_threads: None,
});
let engine = Box::new(InferenceEngineAdapter::new());
let state = Arc::new(AppState { engine, registry });
let request = ChatCompletionRequest {
model: "test-streaming".to_string(),
messages: vec![ChatMessage {
role: "user".to_string(),
content: "Hello".to_string(),
}],
stream: Some(true), temperature: Some(0.7),
max_tokens: Some(100),
top_p: Some(0.9),
};
let _response = chat_completions(State(state), Json(request)).await;
assert!(true);
}
#[tokio::test]
async fn test_chat_completions_non_streaming_request() {
use crate::model_registry::ModelEntry;
let mut registry = Registry::default();
registry.register(ModelEntry {
name: "test-non-streaming".to_string(),
base_path: "./test.safetensors".into(),
lora_path: None,
template: Some("llama3".into()),
ctx_len: Some(2048),
n_threads: None,
});
let engine = Box::new(InferenceEngineAdapter::new());
let state = Arc::new(AppState { engine, registry });
let request = ChatCompletionRequest {
model: "test-non-streaming".to_string(),
messages: vec![
ChatMessage {
role: "user".to_string(),
content: "Hello".to_string(),
},
ChatMessage {
role: "assistant".to_string(),
content: "Hi there!".to_string(),
},
],
stream: Some(false), temperature: Some(0.5),
max_tokens: Some(50),
top_p: Some(0.8),
};
let _response = chat_completions(State(state), Json(request)).await;
assert!(true);
}
#[test]
fn test_template_family_selection() {
use crate::templates::TemplateFamily;
let spec_chatml = crate::engine::ModelSpec {
name: "test-chatml".to_string(),
base_path: "./test.safetensors".into(),
lora_path: None,
template: Some("chatml".to_string()),
ctx_len: 2048,
n_threads: None,
};
let fam = match spec_chatml.template.as_deref() {
Some("chatml") => TemplateFamily::ChatML,
Some("llama3") | Some("llama-3") => TemplateFamily::Llama3,
_ => TemplateFamily::OpenChat,
};
assert!(matches!(fam, TemplateFamily::ChatML));
let spec_llama3 = crate::engine::ModelSpec {
name: "test-llama3".to_string(),
base_path: "./test.safetensors".into(),
lora_path: None,
template: Some("llama3".to_string()),
ctx_len: 2048,
n_threads: None,
};
let fam = match spec_llama3.template.as_deref() {
Some("chatml") => TemplateFamily::ChatML,
Some("llama3") | Some("llama-3") => TemplateFamily::Llama3,
_ => TemplateFamily::OpenChat,
};
assert!(matches!(fam, TemplateFamily::Llama3));
let spec_default = crate::engine::ModelSpec {
name: "test-default".to_string(),
base_path: "./test.safetensors".into(),
lora_path: None,
template: Some("unknown".to_string()),
ctx_len: 2048,
n_threads: None,
};
let fam = match spec_default.template.as_deref() {
Some("chatml") => TemplateFamily::ChatML,
Some("llama3") | Some("llama-3") => TemplateFamily::Llama3,
_ => TemplateFamily::OpenChat,
};
assert!(matches!(fam, TemplateFamily::OpenChat));
}
#[test]
fn test_generation_options_setting() {
let mut opts = crate::engine::GenOptions::default();
let temp = Some(0.8f32);
if let Some(t) = temp {
opts.temperature = t;
}
assert_eq!(opts.temperature, 0.8);
let top_p = Some(0.9f32);
if let Some(p) = top_p {
opts.top_p = p;
}
assert_eq!(opts.top_p, 0.9);
let max_tokens = Some(150usize);
if let Some(m) = max_tokens {
opts.max_tokens = m;
}
assert_eq!(opts.max_tokens, 150);
let stream = Some(true);
if let Some(s) = stream {
opts.stream = s;
}
assert_eq!(opts.stream, true);
}
#[test]
fn test_chat_completion_chunk_serialization() {
let chunk = ChatCompletionChunk {
id: "chatcmpl-test123".to_string(),
object: "chat.completion.chunk".to_string(),
created: 1234567890,
model: "test-model".to_string(),
choices: vec![ChunkChoice {
index: 0,
delta: Delta {
role: Some("assistant".to_string()),
content: Some("Hello".to_string()),
},
finish_reason: None,
}],
};
let json = serde_json::to_string(&chunk).unwrap();
assert!(json.contains("chatcmpl-test123"));
assert!(json.contains("chat.completion.chunk"));
assert!(json.contains("Hello"));
let parsed: ChatCompletionChunk = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.id, "chatcmpl-test123");
assert_eq!(parsed.choices[0].delta.content.as_ref().unwrap(), "Hello");
}
#[test]
fn test_delta_with_role_only() {
let delta = Delta {
role: Some("assistant".to_string()),
content: None,
};
assert_eq!(delta.role.as_ref().unwrap(), "assistant");
assert!(delta.content.is_none());
}
#[test]
fn test_delta_with_content_only() {
let delta = Delta {
role: None,
content: Some("token".to_string()),
};
assert!(delta.role.is_none());
assert_eq!(delta.content.as_ref().unwrap(), "token");
}
#[test]
fn test_usage_structure() {
let usage = Usage {
prompt_tokens: 10,
completion_tokens: 20,
total_tokens: 30,
};
assert_eq!(usage.prompt_tokens, 10);
assert_eq!(usage.completion_tokens, 20);
assert_eq!(usage.total_tokens, 30);
let json = serde_json::to_string(&usage).unwrap();
let parsed: Usage = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.total_tokens, 30);
}
#[test]
fn test_models_response_structure() {
let models_response = ModelsResponse {
object: "list".to_string(),
data: vec![
Model {
id: "model1".to_string(),
object: "model".to_string(),
created: 1234567890,
owned_by: "shimmy".to_string(),
},
Model {
id: "model2".to_string(),
object: "model".to_string(),
created: 1234567890,
owned_by: "shimmy".to_string(),
},
],
};
assert_eq!(models_response.data.len(), 2);
assert_eq!(models_response.data[0].id, "model1");
assert_eq!(models_response.data[1].id, "model2");
}
#[test]
fn test_chat_completion_request_defaults() {
let json_str = r#"{
"model": "test-model",
"messages": [
{"role": "user", "content": "Hello"}
]
}"#;
let request: ChatCompletionRequest = serde_json::from_str(json_str).unwrap();
assert_eq!(request.model, "test-model");
assert_eq!(request.messages.len(), 1);
assert!(request.stream.is_none());
assert!(request.temperature.is_none());
assert!(request.max_tokens.is_none());
assert!(request.top_p.is_none());
}
#[test]
fn test_chat_completion_request_with_all_fields() {
let json_str = r#"{
"model": "test-model",
"messages": [
{"role": "user", "content": "Hello"}
],
"stream": true,
"temperature": 0.7,
"max_tokens": 100,
"top_p": 0.9
}"#;
let request: ChatCompletionRequest = serde_json::from_str(json_str).unwrap();
assert_eq!(request.model, "test-model");
assert_eq!(request.stream, Some(true));
assert_eq!(request.temperature, Some(0.7));
assert_eq!(request.max_tokens, Some(100));
assert_eq!(request.top_p, Some(0.9));
}
#[test]
fn test_finish_reason_values() {
let choice = Choice {
index: 0,
message: ChatMessage {
role: "assistant".to_string(),
content: "Response".to_string(),
},
finish_reason: Some("stop".to_string()),
};
assert_eq!(choice.finish_reason.as_ref().unwrap(), "stop");
let chunk_choice = ChunkChoice {
index: 0,
delta: Delta {
role: None,
content: None,
},
finish_reason: Some("length".to_string()),
};
assert_eq!(chunk_choice.finish_reason.as_ref().unwrap(), "length");
}
#[test]
fn test_message_pairs_conversion() {
let messages = vec![
ChatMessage {
role: "user".to_string(),
content: "Hello".to_string(),
},
ChatMessage {
role: "assistant".to_string(),
content: "Hi there!".to_string(),
},
];
let pairs: Vec<(String, String)> = messages
.iter()
.map(|m| (m.role.clone(), m.content.clone()))
.collect();
assert_eq!(pairs.len(), 2);
assert_eq!(pairs[0].0, "user");
assert_eq!(pairs[0].1, "Hello");
assert_eq!(pairs[1].0, "assistant");
assert_eq!(pairs[1].1, "Hi there!");
}
#[tokio::test]
async fn test_models_endpoint_with_registered_models() {
use crate::model_registry::ModelEntry;
let mut registry = Registry::default();
registry.register(ModelEntry {
name: "registered-model".to_string(),
base_path: "./test1.gguf".into(),
lora_path: None,
template: Some("chatml".into()),
ctx_len: Some(2048),
n_threads: None,
});
registry.register(ModelEntry {
name: "another-model".to_string(),
base_path: "./test2.gguf".into(),
lora_path: None,
template: Some("llama3".into()),
ctx_len: Some(4096),
n_threads: None,
});
let engine = Box::new(InferenceEngineAdapter::new());
let state = Arc::new(AppState { engine, registry });
let _response = models(State(state)).await;
assert!(true); }
}