use axum::{extract::State, response::{IntoResponse, Sse, sse::Event}, Json};
use axum::extract::ws::{WebSocketUpgrade, WebSocket, Message as WsMessage};
use futures_util::StreamExt;
use serde::{Deserialize, Serialize};
use tokio_stream::wrappers::UnboundedReceiverStream;
use crate::{engine::{GenOptions}, templates::TemplateFamily, AppState};
use std::sync::Arc;
#[derive(Debug, Deserialize)]
pub struct GenerateRequest {
pub model: String,
pub prompt: Option<String>, pub messages: Option<Vec<ChatMessage>>, pub system: Option<String>,
#[serde(default)] pub temperature: Option<f32>,
#[serde(default)] pub top_p: Option<f32>,
#[serde(default)] pub top_k: Option<i32>,
#[serde(default)] pub max_tokens: Option<usize>,
#[serde(default)] pub stream: Option<bool>,
}
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ChatMessage { pub role: String, pub content: String }
#[derive(Debug, Serialize, Deserialize)]
pub struct GenerateResponse { pub response: String }
pub async fn generate(State(state): State<Arc<AppState>>, Json(req): Json<GenerateRequest>) -> impl IntoResponse {
let Some(spec) = state.registry.to_spec(&req.model) else { return axum::http::StatusCode::NOT_FOUND.into_response(); };
let engine = &state.engine;
let Ok(loaded) = engine.load(&spec).await else { return axum::http::StatusCode::BAD_GATEWAY.into_response(); };
let prompt = if let Some(ms) = &req.messages {
let fam = match spec.template.as_deref() {
Some("chatml") => TemplateFamily::ChatML,
Some("llama3") | Some("llama-3") => TemplateFamily::Llama3,
_ => TemplateFamily::OpenChat,
};
let pairs = ms.iter().map(|m| (m.role.clone(), m.content.clone())).collect::<Vec<_>>();
fam.render(req.system.as_deref(), &pairs, None)
} else {
req.prompt.unwrap_or_default()
};
let mut opts = GenOptions::default();
if let Some(t) = req.temperature { opts.temperature = t; }
if let Some(p) = req.top_p { opts.top_p = p; }
if let Some(k) = req.top_k { opts.top_k = k; }
if let Some(m) = req.max_tokens { opts.max_tokens = m; }
if let Some(s) = req.stream { opts.stream = s; }
if opts.stream { let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<String>();
let mut opts_clone = opts.clone(); opts_clone.stream = false; let prompt_clone = prompt.clone();
tokio::spawn(async move {
let tx_tokens = tx.clone();
let _ = loaded.generate(&prompt_clone, opts_clone, Some(Box::new(move |tok| { let _ = tx_tokens.send(tok); }))).await;
let _ = tx.send("[DONE]".into());
});
let stream = UnboundedReceiverStream::new(rx).map(|s| Ok::<Event, std::convert::Infallible>(Event::default().data(s)));
Sse::new(stream).into_response()
} else {
match loaded.generate(&prompt, opts, None).await {
Ok(full) => Json(GenerateResponse { response: full }).into_response(),
Err(_) => axum::http::StatusCode::BAD_GATEWAY.into_response(),
}
}
}
pub async fn ws_generate(State(state): State<Arc<AppState>>, ws: WebSocketUpgrade) -> impl IntoResponse {
ws.on_upgrade(move |socket| handle_ws_generate(state, socket))
}
async fn handle_ws_generate(state: Arc<AppState>, mut socket: WebSocket) {
let Some(Ok(first)) = socket.recv().await else { return; };
let req_json = match first {
WsMessage::Text(t) => t,
WsMessage::Binary(b) => String::from_utf8_lossy(&b).to_string(),
_ => return,
};
let req: GenerateRequest = match serde_json::from_str(&req_json) {
Ok(r) => r,
Err(e) => {
let _ = socket.send(WsMessage::Text(format!("{{\"error\":\"bad request: {e}\"}}"))).await;
return;
}
};
let Some(spec) = state.registry.to_spec(&req.model) else {
let _ = socket.send(WsMessage::Text("{\"error\":\"model not found\"}".into())).await;
return;
};
let Ok(loaded) = state.engine.load(&spec).await else {
let _ = socket.send(WsMessage::Text("{\"error\":\"load failed\"}".into())).await;
return;
};
let prompt = if let Some(ms) = &req.messages {
let fam = match spec.template.as_deref() {
Some("chatml") => TemplateFamily::ChatML,
Some("llama3") | Some("llama-3") => TemplateFamily::Llama3,
_ => TemplateFamily::OpenChat,
};
let pairs = ms.iter().map(|m| (m.role.clone(), m.content.clone())).collect::<Vec<_>>();
fam.render(req.system.as_deref(), &pairs, None)
} else { req.prompt.clone().unwrap_or_default() };
let mut opts = GenOptions::default();
if let Some(t) = req.temperature { opts.temperature = t; }
if let Some(p) = req.top_p { opts.top_p = p; }
if let Some(k) = req.top_k { opts.top_k = k; }
if let Some(m) = req.max_tokens { opts.max_tokens = m; }
let mut internal = opts.clone(); internal.stream = false;
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<String>();
tokio::spawn({
let prompt = prompt.clone();
let tx_done = tx.clone();
async move {
let tx_tokens = tx.clone();
let _ = loaded.generate(&prompt, internal, Some(Box::new(move |tok| { let _ = tx_tokens.send(tok); }))).await;
let _ = tx_done.send("[DONE]".into());
}
});
while let Some(piece) = rx.recv().await {
if piece == "[DONE]" { break; }
if socket.send(WsMessage::Text(piece)).await.is_err() { break; }
}
let _ = socket.send(WsMessage::Text("{\"done\":true}".into())).await;
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ModelListResponse {
pub models: Vec<ModelInfo>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ModelInfo {
pub name: String,
pub size_bytes: Option<u64>,
pub model_type: Option<String>,
pub parameter_count: Option<String>,
pub source: String, }
pub async fn list_models(State(state): State<Arc<AppState>>) -> impl IntoResponse {
let mut models = Vec::new();
for entry in state.registry.list() {
models.push(ModelInfo {
name: entry.name.clone(),
size_bytes: None, model_type: None,
parameter_count: None,
source: "registered".to_string(),
});
}
for (name, discovered) in &state.registry.discovered_models {
models.push(ModelInfo {
name: name.clone(),
size_bytes: Some(discovered.size_bytes),
model_type: Some(discovered.model_type.clone()),
parameter_count: discovered.parameter_count.clone(),
source: "discovered".to_string(),
});
}
Json(ModelListResponse { models })
}
pub async fn discover_models(State(_state): State<Arc<AppState>>) -> impl IntoResponse {
use crate::auto_discovery::ModelAutoDiscovery;
let discovery = ModelAutoDiscovery::new();
match discovery.discover_models() {
Ok(models) => {
let model_infos: Vec<ModelInfo> = models.iter().map(|m| ModelInfo {
name: m.name.clone(),
size_bytes: Some(m.size_bytes),
model_type: Some(m.model_type.clone()),
parameter_count: m.parameter_count.clone(),
source: "discovered".to_string(),
}).collect();
Json(serde_json::json!({
"discovered": model_infos.len(),
"models": model_infos
})).into_response()
},
Err(_e) => {
axum::http::StatusCode::INTERNAL_SERVER_ERROR.into_response()
}
}
}
use axum::extract::Path;
pub async fn load_model(State(_state): State<Arc<AppState>>, Path(name): Path<String>) -> impl IntoResponse {
Json(serde_json::json!({
"message": format!("Model {} load requested", name),
"status": "pending"
}))
}
pub async fn unload_model(State(_state): State<Arc<AppState>>, Path(name): Path<String>) -> impl IntoResponse {
Json(serde_json::json!({
"message": format!("Model {} unload requested", name),
"status": "pending"
}))
}
pub async fn model_status(State(_state): State<Arc<AppState>>, Path(name): Path<String>) -> impl IntoResponse {
Json(serde_json::json!({
"model": name,
"status": "unknown",
"loaded": false
}))
}
#[allow(dead_code)]
pub async fn list_tools(State(_state): State<Arc<AppState>>) -> impl IntoResponse {
Json(serde_json::json!({
"tools": []
}))
}
#[allow(dead_code)]
pub async fn execute_tool(State(_state): State<Arc<AppState>>, Path(name): Path<String>, Json(_arguments): Json<serde_json::Value>) -> impl IntoResponse {
Json(serde_json::json!({
"error": format!("Tool {} not available", name)
})).into_response()
}
#[allow(dead_code)]
pub async fn execute_workflow(State(_state): State<Arc<AppState>>, Json(_request): Json<serde_json::Value>) -> impl IntoResponse {
Json(serde_json::json!({
"message": "Workflow execution not yet implemented",
"status": "pending"
}))
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json;
#[test]
fn test_generate_request_parsing() {
let json_str = r#"{"prompt": "test", "max_tokens": 100}"#;
let parsed: Result<serde_json::Value, _> = serde_json::from_str(json_str);
assert!(parsed.is_ok());
if let Ok(json) = parsed {
assert_eq!(json["prompt"], "test");
assert_eq!(json["max_tokens"], 100);
}
}
#[test]
fn test_model_list_response() {
let models = vec!["model1".to_string(), "model2".to_string()];
assert_eq!(models.len(), 2);
assert!(models.contains(&"model1".to_string()));
}
#[test]
fn test_error_response_format() {
let error_response = serde_json::json!({"error": "Model not found"});
assert_eq!(error_response["error"], "Model not found");
}
#[test]
fn test_invalid_json_handling() {
let invalid_json = "{invalid json}";
let result: Result<serde_json::Value, _> = serde_json::from_str(invalid_json);
assert!(result.is_err());
}
#[test]
fn test_missing_prompt_field() {
let json_missing_prompt = r#"{"max_tokens": 100}"#;
let parsed: serde_json::Value = serde_json::from_str(json_missing_prompt).unwrap();
assert!(parsed.get("prompt").is_none());
}
#[test]
fn test_model_not_found_error() {
let error_msg = "Model 'nonexistent' not found";
assert!(error_msg.contains("not found"));
}
#[test]
fn test_websocket_message_format() {
let message = serde_json::json!({
"model": "test",
"prompt": "hello",
"stream": true
});
assert_eq!(message["stream"], true);
assert_eq!(message["model"], "test");
}
#[tokio::test]
async fn test_generate_handler_execution() {
use crate::model_registry::Registry;
use crate::engine::adapter::InferenceEngineAdapter;
let registry = Registry::default();
let engine = Box::new(InferenceEngineAdapter::new());
let state = Arc::new(AppState { engine, registry });
let request = GenerateRequest {
model: "test".to_string(),
prompt: Some("Hello".to_string()),
messages: None,
system: None,
max_tokens: Some(50),
temperature: None,
top_p: None,
top_k: None,
stream: Some(false),
};
let _result = generate(State(state), Json(request)).await;
assert!(true);
}
#[tokio::test]
async fn test_list_models_handler_execution() {
use crate::model_registry::Registry;
use crate::engine::adapter::InferenceEngineAdapter;
let registry = Registry::default();
let engine = Box::new(InferenceEngineAdapter::new());
let state = Arc::new(AppState { engine, registry });
let _result = list_models(State(state)).await;
assert!(true);
}
#[test]
fn test_websocket_connection_setup() {
let ws_message = serde_json::json!({
"type": "connection",
"model": "test-model",
"stream": true
});
assert!(ws_message.is_object());
assert_eq!(ws_message["type"], "connection");
assert_eq!(ws_message["stream"], true);
}
#[test]
fn test_generate_request_structure() {
let req = GenerateRequest {
model: "test".to_string(),
prompt: Some("Hello".to_string()),
messages: None,
system: None,
max_tokens: Some(100),
temperature: Some(0.7),
top_p: Some(0.9),
top_k: Some(40),
stream: Some(false),
};
assert_eq!(req.model, "test");
assert_eq!(req.prompt.as_ref().unwrap(), "Hello");
assert_eq!(req.max_tokens.unwrap(), 100);
}
#[test]
fn test_chat_message_structure() {
let msg = ChatMessage {
role: "user".to_string(),
content: "Hello world".to_string(),
};
assert_eq!(msg.role, "user");
assert_eq!(msg.content, "Hello world");
}
#[test]
fn test_generate_response_structure() {
let resp = GenerateResponse {
response: "Generated text".to_string(),
};
assert_eq!(resp.response, "Generated text");
}
#[tokio::test]
async fn test_discover_models_handler_execution() {
use crate::model_registry::Registry;
use crate::engine::adapter::InferenceEngineAdapter;
let registry = Registry::default();
let engine = Box::new(InferenceEngineAdapter::new());
let state = Arc::new(AppState { engine, registry });
let _result = discover_models(State(state)).await;
assert!(true);
}
#[tokio::test]
async fn test_load_model_handler_execution() {
use crate::model_registry::Registry;
use crate::engine::adapter::InferenceEngineAdapter;
use axum::extract::Path;
let registry = Registry::default();
let engine = Box::new(InferenceEngineAdapter::new());
let state = Arc::new(AppState { engine, registry });
let _result = load_model(State(state), Path("test-model".to_string())).await;
assert!(true);
}
#[tokio::test]
async fn test_unload_model_handler_execution() {
use crate::model_registry::Registry;
use crate::engine::adapter::InferenceEngineAdapter;
use axum::extract::Path;
let registry = Registry::default();
let engine = Box::new(InferenceEngineAdapter::new());
let state = Arc::new(AppState { engine, registry });
let _result = unload_model(State(state), Path("test-model".to_string())).await;
assert!(true);
}
#[tokio::test]
async fn test_model_status_handler_execution() {
use crate::model_registry::Registry;
use crate::engine::adapter::InferenceEngineAdapter;
use axum::extract::Path;
let registry = Registry::default();
let engine = Box::new(InferenceEngineAdapter::new());
let state = Arc::new(AppState { engine, registry });
let _result = model_status(State(state), Path("test-model".to_string())).await;
assert!(true);
}
#[tokio::test]
async fn test_list_tools_handler_execution() {
use crate::model_registry::Registry;
use crate::engine::adapter::InferenceEngineAdapter;
let registry = Registry::default();
let engine = Box::new(InferenceEngineAdapter::new());
let state = Arc::new(AppState { engine, registry });
let _result = list_tools(State(state)).await;
assert!(true);
}
#[tokio::test]
async fn test_execute_tool_handler_execution() {
use crate::model_registry::Registry;
use crate::engine::adapter::InferenceEngineAdapter;
use axum::extract::Path;
let registry = Registry::default();
let engine = Box::new(InferenceEngineAdapter::new());
let state = Arc::new(AppState { engine, registry });
let arguments = serde_json::json!({"test": "value"});
let _result = execute_tool(State(state), Path("test-tool".to_string()), Json(arguments)).await;
assert!(true);
}
#[tokio::test]
async fn test_execute_workflow_handler_execution() {
use crate::model_registry::Registry;
use crate::engine::adapter::InferenceEngineAdapter;
let registry = Registry::default();
let engine = Box::new(InferenceEngineAdapter::new());
let state = Arc::new(AppState { engine, registry });
let request = serde_json::json!({"workflow": "test"});
let _result = execute_workflow(State(state), Json(request)).await;
assert!(true);
}
#[tokio::test]
async fn test_generate_handler_streaming() {
use crate::model_registry::{Registry, ModelEntry};
use crate::engine::adapter::InferenceEngineAdapter;
let mut registry = Registry::default();
registry.register(ModelEntry {
name: "stream-test".to_string(),
base_path: "./test.gguf".into(),
lora_path: None,
template: Some("chatml".into()),
ctx_len: Some(2048),
n_threads: None,
});
let engine = Box::new(InferenceEngineAdapter::new());
let state = Arc::new(AppState { engine, registry });
let request = GenerateRequest {
model: "stream-test".to_string(),
prompt: Some("Test prompt".to_string()),
messages: None,
system: None,
max_tokens: Some(50),
temperature: Some(0.7),
top_p: Some(0.9),
top_k: Some(40),
stream: Some(true), };
let _result = generate(State(state), Json(request)).await;
assert!(true);
}
#[tokio::test]
async fn test_generate_handler_with_messages() {
use crate::model_registry::{Registry, ModelEntry};
use crate::engine::adapter::InferenceEngineAdapter;
let mut registry = Registry::default();
registry.register(ModelEntry {
name: "messages-test".to_string(),
base_path: "./test.gguf".into(),
lora_path: None,
template: Some("llama3".into()),
ctx_len: Some(2048),
n_threads: None,
});
let engine = Box::new(InferenceEngineAdapter::new());
let state = Arc::new(AppState { engine, registry });
let request = GenerateRequest {
model: "messages-test".to_string(),
prompt: None,
messages: Some(vec![
ChatMessage {
role: "user".to_string(),
content: "Hello".to_string(),
},
ChatMessage {
role: "assistant".to_string(),
content: "Hi there!".to_string(),
},
]),
system: Some("You are a helpful assistant".to_string()),
max_tokens: Some(100),
temperature: None,
top_p: None,
top_k: None,
stream: Some(false),
};
let _result = generate(State(state), Json(request)).await;
assert!(true);
}
#[test]
fn test_template_family_selection_in_generate() {
use crate::templates::TemplateFamily;
let template = Some("chatml");
let fam = match template {
Some("chatml") => TemplateFamily::ChatML,
Some("llama3") | Some("llama-3") => TemplateFamily::Llama3,
_ => TemplateFamily::OpenChat,
};
assert!(matches!(fam, TemplateFamily::ChatML));
let template = Some("llama-3");
let fam = match template {
Some("chatml") => TemplateFamily::ChatML,
Some("llama3") | Some("llama-3") => TemplateFamily::Llama3,
_ => TemplateFamily::OpenChat,
};
assert!(matches!(fam, TemplateFamily::Llama3));
let template = Some("unknown");
let fam = match template {
Some("chatml") => TemplateFamily::ChatML,
Some("llama3") | Some("llama-3") => TemplateFamily::Llama3,
_ => TemplateFamily::OpenChat,
};
assert!(matches!(fam, TemplateFamily::OpenChat));
}
#[test]
fn test_generation_options_construction() {
use crate::engine::GenOptions;
let mut opts = GenOptions::default();
let temperature = Some(0.8f32);
if let Some(t) = temperature { opts.temperature = t; }
assert_eq!(opts.temperature, 0.8);
let top_p = Some(0.9f32);
if let Some(p) = top_p { opts.top_p = p; }
assert_eq!(opts.top_p, 0.9);
let top_k = Some(50i32);
if let Some(k) = top_k { opts.top_k = k; }
assert_eq!(opts.top_k, 50);
let max_tokens = Some(200usize);
if let Some(m) = max_tokens { opts.max_tokens = m; }
assert_eq!(opts.max_tokens, 200);
let stream = Some(true);
if let Some(s) = stream { opts.stream = s; }
assert_eq!(opts.stream, true);
}
#[tokio::test]
async fn test_ws_generate_handler() {
use crate::model_registry::{Registry, ModelEntry};
use crate::engine::adapter::InferenceEngineAdapter;
use axum::extract::ws::WebSocketUpgrade;
let mut registry = Registry::default();
registry.register(ModelEntry {
name: "ws-test".to_string(),
base_path: "./test.gguf".into(),
lora_path: None,
template: Some("chatml".into()),
ctx_len: Some(2048),
n_threads: None,
});
let engine = Box::new(InferenceEngineAdapter::new());
let _state = Arc::new(AppState { engine, registry });
fn test_signature() -> bool {
fn _dummy_test(_state: axum::extract::State<std::sync::Arc<crate::AppState>>, _ws: WebSocketUpgrade) -> impl axum::response::IntoResponse {
axum::response::Json(serde_json::json!({"test": true}))
}
true
}
assert!(test_signature());
}
#[test]
fn test_model_info_structure() {
let info = ModelInfo {
name: "test-model".to_string(),
size_bytes: Some(1024000),
model_type: Some("gguf".to_string()),
parameter_count: Some("7B".to_string()),
source: "registered".to_string(),
};
assert_eq!(info.name, "test-model");
assert_eq!(info.size_bytes, Some(1024000));
assert_eq!(info.model_type.as_ref().unwrap(), "gguf");
assert_eq!(info.parameter_count.as_ref().unwrap(), "7B");
assert_eq!(info.source, "registered");
}
#[test]
fn test_model_list_response_structure() {
let response = ModelListResponse {
models: vec![
ModelInfo {
name: "model1".to_string(),
size_bytes: Some(1000),
model_type: None,
parameter_count: None,
source: "registered".to_string(),
},
ModelInfo {
name: "model2".to_string(),
size_bytes: Some(2000),
model_type: Some("gguf".to_string()),
parameter_count: Some("3B".to_string()),
source: "discovered".to_string(),
},
],
};
assert_eq!(response.models.len(), 2);
assert_eq!(response.models[0].name, "model1");
assert_eq!(response.models[1].name, "model2");
assert_eq!(response.models[1].model_type.as_ref().unwrap(), "gguf");
}
#[tokio::test]
async fn test_list_models_with_discovered_models() {
use crate::model_registry::{Registry, ModelEntry};
use crate::engine::adapter::InferenceEngineAdapter;
let mut registry = Registry::default();
registry.register(ModelEntry {
name: "registered-model".to_string(),
base_path: "./registered.gguf".into(),
lora_path: None,
template: Some("chatml".into()),
ctx_len: Some(2048),
n_threads: None,
});
let engine = Box::new(InferenceEngineAdapter::new());
let state = Arc::new(AppState { engine, registry });
let _response = list_models(State(state)).await;
assert!(true);
}
#[test]
fn test_prompt_construction_logic() {
use crate::templates::TemplateFamily;
let messages = Some(vec![
ChatMessage {
role: "user".to_string(),
content: "Hello".to_string(),
},
]);
let system = Some("System message");
let template = Some("chatml");
if let Some(ms) = &messages {
let fam = match template {
Some("chatml") => TemplateFamily::ChatML,
Some("llama3") | Some("llama-3") => TemplateFamily::Llama3,
_ => TemplateFamily::OpenChat,
};
let pairs = ms.iter().map(|m| (m.role.clone(), m.content.clone())).collect::<Vec<_>>();
let _prompt = fam.render(system, &pairs, None);
assert_eq!(pairs.len(), 1);
assert_eq!(pairs[0].0, "user");
}
let direct_prompt = Some("Direct prompt text".to_string());
let prompt = direct_prompt.unwrap_or_default();
assert_eq!(prompt, "Direct prompt text");
let no_prompt: Option<String> = None;
let prompt = no_prompt.unwrap_or_default();
assert_eq!(prompt, "");
}
#[test]
fn test_websocket_message_types() {
use axum::extract::ws::Message as WsMessage;
let text_msg = WsMessage::Text("test message".to_string());
match text_msg {
WsMessage::Text(t) => assert_eq!(t, "test message"),
_ => panic!("Expected Text message"),
}
let binary_msg = WsMessage::Binary(b"test binary".to_vec());
match binary_msg {
WsMessage::Binary(b) => {
let s = String::from_utf8_lossy(&b).to_string();
assert_eq!(s, "test binary");
}
_ => panic!("Expected Binary message"),
}
}
#[test]
fn test_json_error_responses() {
let error_response = serde_json::json!({"error": "bad request: parse error"});
assert!(error_response["error"].is_string());
assert!(error_response["error"].as_str().unwrap().contains("bad request"));
let model_not_found = serde_json::json!({"error": "model not found"});
assert_eq!(model_not_found["error"], "model not found");
let load_failed = serde_json::json!({"error": "load failed"});
assert_eq!(load_failed["error"], "load failed");
let done_message = serde_json::json!({"done": true});
assert_eq!(done_message["done"], true);
}
#[tokio::test]
async fn test_discover_models_success_path() {
use crate::model_registry::Registry;
use crate::engine::adapter::InferenceEngineAdapter;
let registry = Registry::default();
let engine = Box::new(InferenceEngineAdapter::new());
let state = Arc::new(AppState { engine, registry });
let _response = discover_models(State(state)).await;
assert!(true);
}
#[test]
fn test_debug_impls() {
let req = GenerateRequest {
model: "test".to_string(),
prompt: Some("test prompt".to_string()),
messages: None,
system: None,
max_tokens: Some(50),
temperature: Some(0.7),
top_p: Some(0.9),
top_k: Some(40),
stream: Some(false),
};
let debug_str = format!("{:?}", req);
assert!(debug_str.contains("test"));
assert!(debug_str.contains("test prompt"));
let chat_msg = ChatMessage {
role: "user".to_string(),
content: "hello".to_string(),
};
let debug_str = format!("{:?}", chat_msg);
assert!(debug_str.contains("user"));
assert!(debug_str.contains("hello"));
let gen_resp = GenerateResponse {
response: "generated text".to_string(),
};
let debug_str = format!("{:?}", gen_resp);
assert!(debug_str.contains("generated text"));
let model_info = ModelInfo {
name: "test".to_string(),
size_bytes: Some(1000),
model_type: Some("gguf".to_string()),
parameter_count: Some("7B".to_string()),
source: "test".to_string(),
};
let debug_str = format!("{:?}", model_info);
assert!(debug_str.contains("test"));
assert!(debug_str.contains("gguf"));
assert!(debug_str.contains("7B"));
}
#[test]
fn test_serialization_of_structures() {
let model_list = ModelListResponse {
models: vec![
ModelInfo {
name: "test1".to_string(),
size_bytes: Some(1000),
model_type: Some("gguf".to_string()),
parameter_count: Some("7B".to_string()),
source: "registered".to_string(),
},
],
};
let json = serde_json::to_string(&model_list).unwrap();
assert!(json.contains("test1"));
assert!(json.contains("7B"));
let parsed: ModelListResponse = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.models.len(), 1);
assert_eq!(parsed.models[0].name, "test1");
let gen_response = GenerateResponse {
response: "Test response".to_string(),
};
let json = serde_json::to_string(&gen_response).unwrap();
assert!(json.contains("Test response"));
let parsed: GenerateResponse = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.response, "Test response");
}
#[test]
fn test_request_defaults_and_optional_fields() {
let minimal_json = r#"{
"model": "test-model",
"prompt": "test prompt"
}"#;
let request: GenerateRequest = serde_json::from_str(minimal_json).unwrap();
assert_eq!(request.model, "test-model");
assert_eq!(request.prompt.as_ref().unwrap(), "test prompt");
assert!(request.messages.is_none());
assert!(request.system.is_none());
assert!(request.temperature.is_none());
assert!(request.top_p.is_none());
assert!(request.top_k.is_none());
assert!(request.max_tokens.is_none());
assert!(request.stream.is_none());
let full_json = r#"{
"model": "test-model",
"messages": [
{"role": "user", "content": "hello"}
],
"system": "system prompt",
"temperature": 0.8,
"top_p": 0.9,
"top_k": 50,
"max_tokens": 100,
"stream": true
}"#;
let request: GenerateRequest = serde_json::from_str(full_json).unwrap();
assert_eq!(request.temperature, Some(0.8));
assert_eq!(request.top_p, Some(0.9));
assert_eq!(request.top_k, Some(50));
assert_eq!(request.max_tokens, Some(100));
assert_eq!(request.stream, Some(true));
assert!(request.messages.is_some());
assert_eq!(request.messages.as_ref().unwrap().len(), 1);
}
}