use serde::{Deserialize, Serialize};
use crate::providers::{LLMResponse, StreamEvent, Usage as ZeptoUsage};
use crate::session::{Message, Role};
/// Incoming body of an OpenAI-compatible `/v1/chat/completions` request.
///
/// Optional fields use `#[serde(default)]` so clients may omit them entirely.
#[derive(Debug, Deserialize)]
pub struct ChatCompletionRequest {
/// Model identifier requested by the client.
pub model: String,
/// Conversation history, oldest first.
pub messages: Vec<ChatMessage>,
/// When `Some(true)`, the caller wants an SSE streaming response.
#[serde(default)]
pub stream: Option<bool>,
/// Client-requested completion-token cap, if any.
#[serde(default)]
pub max_tokens: Option<u32>,
/// Client-requested sampling temperature, if any.
#[serde(default)]
pub temperature: Option<f32>,
}
/// A single chat message in OpenAI wire format: a role string plus text content.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
/// One of "system", "user" or "assistant" (see `messages_from_openai`).
pub role: String,
/// Plain-text message body.
pub content: String,
}
/// Non-streaming response envelope, mirroring OpenAI's `chat.completion` object.
#[derive(Debug, Serialize)]
pub struct ChatCompletionResponse {
/// Identifier of the form `chatcmpl-<uuid>` (see `completion_id`).
pub id: String,
/// Always the literal "chat.completion".
pub object: &'static str,
/// Unix timestamp (seconds) at which the response was built.
pub created: u64,
/// Echo of the model name the caller requested.
pub model: String,
/// Exactly one choice is produced by `response_from_llm`.
pub choices: Vec<Choice>,
/// Token accounting; zeroed when the provider reported no usage.
pub usage: UsageResponse,
}
/// One completion alternative inside a non-streaming response.
#[derive(Debug, Serialize)]
pub struct Choice {
/// Position within the `choices` array (always 0 here).
pub index: u32,
/// The assistant message carrying the generated text.
pub message: ChatMessage,
/// Termination reason; this adapter always emits "stop".
pub finish_reason: String,
}
/// Streaming response frame, mirroring OpenAI's `chat.completion.chunk` object.
#[derive(Debug, Serialize)]
pub struct ChatCompletionChunk {
/// Shared across all chunks of one streamed completion.
pub id: String,
/// Always the literal "chat.completion.chunk".
pub object: &'static str,
/// Unix timestamp (seconds); identical for every chunk of a stream.
pub created: u64,
/// Echo of the model name the caller requested.
pub model: String,
/// Exactly one choice per chunk in this adapter.
pub choices: Vec<ChunkChoice>,
}
/// One choice inside a streaming chunk.
#[derive(Debug, Serialize)]
pub struct ChunkChoice {
/// Position within the `choices` array (always 0 here).
pub index: u32,
/// Incremental payload for this frame.
pub delta: Delta,
/// Deliberately NOT skipped when `None`: serializes as `null`,
/// matching the OpenAI streaming wire format.
pub finish_reason: Option<String>,
}
/// Incremental message content for a streaming chunk. Absent fields are
/// omitted from the JSON entirely (unlike `ChunkChoice::finish_reason`).
#[derive(Debug, Serialize)]
pub struct Delta {
/// Set only on the first chunk of a stream ("assistant").
#[serde(skip_serializing_if = "Option::is_none")]
pub role: Option<String>,
/// Set on intermediate chunks carrying generated text.
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
}
/// Token accounting in OpenAI wire shape (see `usage_from_zepto`).
#[derive(Debug, Serialize)]
pub struct UsageResponse {
/// Tokens consumed by the prompt.
pub prompt_tokens: u32,
/// Tokens produced in the completion.
pub completion_tokens: u32,
/// Sum of prompt and completion tokens.
pub total_tokens: u32,
}
/// Response envelope for the `/v1/models` listing endpoint.
#[derive(Debug, Serialize)]
pub struct ModelsResponse {
/// Always the literal "list".
pub object: &'static str,
/// Available models.
pub data: Vec<ModelObject>,
}
/// One entry in the `/v1/models` listing.
#[derive(Debug, Serialize)]
pub struct ModelObject {
/// Model identifier clients pass back in `ChatCompletionRequest::model`.
pub id: String,
/// Always the literal "model".
pub object: &'static str,
/// Unix timestamp (seconds) reported as the model's creation time.
pub created: u64,
/// Owning organization string.
pub owned_by: String,
}
/// Converts OpenAI-style chat messages into internal session `Message`s.
///
/// Only the "system", "user" and "assistant" roles are accepted; any other
/// role (e.g. "function", "tool") aborts the whole conversion with an error
/// string naming the offending role.
pub fn messages_from_openai(msgs: &[ChatMessage]) -> Result<Vec<Message>, String> {
    let mut converted = Vec::with_capacity(msgs.len());
    for m in msgs {
        let role = match m.role.as_str() {
            "system" => Role::System,
            "user" => Role::User,
            "assistant" => Role::Assistant,
            other => return Err(format!("unsupported message role: {other}")),
        };
        // Mirror the plain-text content into `content_parts` as a single
        // text part, as the session layer expects both representations.
        converted.push(Message {
            role,
            content: m.content.clone(),
            content_parts: vec![crate::session::ContentPart::Text {
                text: m.content.clone(),
            }],
            tool_calls: None,
            tool_call_id: None,
        });
    }
    Ok(converted)
}
pub fn response_from_llm(llm: &LLMResponse, model: &str) -> ChatCompletionResponse {
let now = unix_now();
let usage = llm
.usage
.as_ref()
.map(usage_from_zepto)
.unwrap_or(UsageResponse {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
});
ChatCompletionResponse {
id: completion_id(),
object: "chat.completion",
created: now,
model: model.to_string(),
choices: vec![Choice {
index: 0,
message: ChatMessage {
role: "assistant".to_string(),
content: llm.content.clone(),
},
finish_reason: "stop".to_string(),
}],
usage,
}
}
pub fn first_chunk(model: &str, id: &str, created: u64) -> ChatCompletionChunk {
ChatCompletionChunk {
id: id.to_string(),
object: "chat.completion.chunk",
created,
model: model.to_string(),
choices: vec![ChunkChoice {
index: 0,
delta: Delta {
role: Some("assistant".to_string()),
content: None,
},
finish_reason: None,
}],
}
}
pub fn delta_chunk(text: &str, model: &str, id: &str, created: u64) -> ChatCompletionChunk {
ChatCompletionChunk {
id: id.to_string(),
object: "chat.completion.chunk",
created,
model: model.to_string(),
choices: vec![ChunkChoice {
index: 0,
delta: Delta {
role: None,
content: Some(text.to_string()),
},
finish_reason: None,
}],
}
}
pub fn done_chunk(model: &str, id: &str, created: u64) -> ChatCompletionChunk {
ChatCompletionChunk {
id: id.to_string(),
object: "chat.completion.chunk",
created,
model: model.to_string(),
choices: vec![ChunkChoice {
index: 0,
delta: Delta {
role: None,
content: None,
},
finish_reason: Some("stop".to_string()),
}],
}
}
pub fn chunk_from_stream_event(
event: &StreamEvent,
model: &str,
id: &str,
created: u64,
) -> Option<ChatCompletionChunk> {
match event {
StreamEvent::Delta(text) => Some(delta_chunk(text, model, id, created)),
StreamEvent::Done { .. } => Some(done_chunk(model, id, created)),
StreamEvent::Error(_) => {
None
}
StreamEvent::ToolCalls(_) => {
None
}
}
}
/// Generates a fresh completion identifier of the form `chatcmpl-<uuid-v4>`.
fn completion_id() -> String {
    let unique = uuid::Uuid::new_v4();
    format!("chatcmpl-{unique}")
}
/// Current wall-clock time as whole seconds since the Unix epoch.
/// Returns 0 if the system clock reads earlier than the epoch.
fn unix_now() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_secs())
        .unwrap_or(0)
}
/// Translates provider-side token accounting into the OpenAI wire shape.
fn usage_from_zepto(u: &ZeptoUsage) -> UsageResponse {
    let (prompt_tokens, completion_tokens, total_tokens) =
        (u.prompt_tokens, u.completion_tokens, u.total_tokens);
    UsageResponse {
        prompt_tokens,
        completion_tokens,
        total_tokens,
    }
}
// Unit tests pinning the OpenAI wire-format contract: role mapping, envelope
// constants ("chat.completion", "chat.completion.chunk", "stop"), usage
// defaulting, and exact JSON key/value serialization.
#[cfg(test)]
mod tests {
use super::*;
use crate::providers::Usage;
// --- messages_from_openai: role mapping and rejection ---
#[test]
fn test_messages_from_openai_empty() {
let msgs = messages_from_openai(&[]).unwrap();
assert!(msgs.is_empty());
}
#[test]
fn test_messages_from_openai_maps_roles() {
let openai_msgs = vec![
ChatMessage {
role: "system".into(),
content: "You are helpful.".into(),
},
ChatMessage {
role: "user".into(),
content: "Hello".into(),
},
ChatMessage {
role: "assistant".into(),
content: "Hi!".into(),
},
];
let msgs = messages_from_openai(&openai_msgs).unwrap();
assert_eq!(msgs.len(), 3);
assert_eq!(msgs[0].role, Role::System);
assert_eq!(msgs[0].content, "You are helpful.");
assert_eq!(msgs[1].role, Role::User);
assert_eq!(msgs[2].role, Role::Assistant);
}
#[test]
fn test_messages_from_openai_unknown_role_returns_error() {
let openai_msgs = vec![ChatMessage {
role: "function".into(),
content: "result".into(),
}];
let result = messages_from_openai(&openai_msgs);
assert!(result.is_err());
// Error string must name the rejected role.
assert!(result.unwrap_err().contains("function"));
}
// --- response_from_llm: envelope shape and usage defaulting ---
#[test]
fn test_response_from_llm_basic() {
let llm = LLMResponse::text("Hello, world!");
let resp = response_from_llm(&llm, "test-model");
assert_eq!(resp.object, "chat.completion");
assert_eq!(resp.model, "test-model");
assert_eq!(resp.choices.len(), 1);
assert_eq!(resp.choices[0].message.role, "assistant");
assert_eq!(resp.choices[0].message.content, "Hello, world!");
assert_eq!(resp.choices[0].finish_reason, "stop");
assert!(resp.id.starts_with("chatcmpl-"));
}
#[test]
fn test_response_from_llm_with_usage() {
let llm = LLMResponse::text("ok").with_usage(Usage::new(10, 20));
let resp = response_from_llm(&llm, "m");
assert_eq!(resp.usage.prompt_tokens, 10);
assert_eq!(resp.usage.completion_tokens, 20);
assert_eq!(resp.usage.total_tokens, 30);
}
#[test]
fn test_response_from_llm_without_usage_zeroes() {
let llm = LLMResponse::text("ok");
let resp = response_from_llm(&llm, "m");
assert_eq!(resp.usage.prompt_tokens, 0);
assert_eq!(resp.usage.total_tokens, 0);
}
// --- streaming chunk constructors: first / delta / done ---
#[test]
fn test_first_chunk_has_role() {
let c = first_chunk("m", "id-1", 1000);
assert_eq!(c.object, "chat.completion.chunk");
assert_eq!(c.choices[0].delta.role.as_deref(), Some("assistant"));
assert!(c.choices[0].delta.content.is_none());
assert!(c.choices[0].finish_reason.is_none());
}
#[test]
fn test_delta_chunk_has_content() {
let c = delta_chunk("hello", "m", "id-1", 1000);
assert!(c.choices[0].delta.role.is_none());
assert_eq!(c.choices[0].delta.content.as_deref(), Some("hello"));
assert!(c.choices[0].finish_reason.is_none());
}
#[test]
fn test_done_chunk_has_stop_reason() {
let c = done_chunk("m", "id-1", 1000);
assert!(c.choices[0].delta.role.is_none());
assert!(c.choices[0].delta.content.is_none());
assert_eq!(c.choices[0].finish_reason.as_deref(), Some("stop"));
}
// --- chunk_from_stream_event: which events map to chunks ---
#[test]
fn test_chunk_from_delta_event() {
let event = StreamEvent::Delta("hi".into());
let chunk = chunk_from_stream_event(&event, "m", "id", 1);
assert!(chunk.is_some());
let c = chunk.unwrap();
assert_eq!(c.choices[0].delta.content.as_deref(), Some("hi"));
}
#[test]
fn test_chunk_from_done_event() {
let event = StreamEvent::Done {
content: "full".into(),
usage: None,
};
let chunk = chunk_from_stream_event(&event, "m", "id", 1);
assert!(chunk.is_some());
let c = chunk.unwrap();
assert_eq!(c.choices[0].finish_reason.as_deref(), Some("stop"));
}
#[test]
fn test_chunk_from_error_event_is_none() {
let event = StreamEvent::Error(crate::error::ZeptoError::Provider("fail".into()));
let chunk = chunk_from_stream_event(&event, "m", "id", 1);
assert!(chunk.is_none());
}
#[test]
fn test_chunk_from_tool_calls_event_is_none() {
let event = StreamEvent::ToolCalls(vec![]);
let chunk = chunk_from_stream_event(&event, "m", "id", 1);
assert!(chunk.is_none());
}
// --- serde round-trips: exact JSON keys/values on the wire ---
#[test]
fn test_chat_completion_response_serializes() {
let llm = LLMResponse::text("ok");
let resp = response_from_llm(&llm, "m");
let json = serde_json::to_string(&resp).unwrap();
assert!(json.contains("\"object\":\"chat.completion\""));
assert!(json.contains("\"finish_reason\":\"stop\""));
}
#[test]
fn test_chat_completion_chunk_serializes() {
let c = delta_chunk("token", "m", "id", 42);
let json = serde_json::to_string(&c).unwrap();
assert!(json.contains("\"object\":\"chat.completion.chunk\""));
assert!(json.contains("\"content\":\"token\""));
}
#[test]
fn test_models_response_serializes() {
let resp = ModelsResponse {
object: "list",
data: vec![ModelObject {
id: "gpt-4o".into(),
object: "model",
created: 1000,
owned_by: "zeptoclaw".into(),
}],
};
let json = serde_json::to_string(&resp).unwrap();
assert!(json.contains("\"object\":\"list\""));
assert!(json.contains("\"id\":\"gpt-4o\""));
}
#[test]
fn test_chat_completion_request_deserializes() {
let json = r#"{
"model": "gpt-4o",
"messages": [{"role": "user", "content": "hi"}],
"stream": true,
"max_tokens": 100,
"temperature": 0.7
}"#;
let req: ChatCompletionRequest = serde_json::from_str(json).unwrap();
assert_eq!(req.model, "gpt-4o");
assert_eq!(req.messages.len(), 1);
assert_eq!(req.stream, Some(true));
assert_eq!(req.max_tokens, Some(100));
assert!((req.temperature.unwrap() - 0.7).abs() < f32::EPSILON);
}
#[test]
fn test_chat_completion_request_minimal() {
// Optional fields default to None when omitted (#[serde(default)]).
let json = r#"{"model": "m", "messages": []}"#;
let req: ChatCompletionRequest = serde_json::from_str(json).unwrap();
assert!(req.stream.is_none());
assert!(req.max_tokens.is_none());
assert!(req.temperature.is_none());
}
// --- small helpers ---
#[test]
fn test_completion_id_format() {
let id = completion_id();
assert!(id.starts_with("chatcmpl-"));
assert!(id.len() > "chatcmpl-".len());
}
#[test]
fn test_unix_now_is_reasonable() {
// 1_704_067_200 is 2024-01-01T00:00:00Z; any sane clock is past it.
let now = unix_now();
assert!(now > 1_704_067_200);
}
#[test]
fn test_usage_from_zepto() {
let zu = crate::providers::Usage::new(5, 10);
let u = usage_from_zepto(&zu);
assert_eq!(u.prompt_tokens, 5);
assert_eq!(u.completion_tokens, 10);
assert_eq!(u.total_tokens, 15);
}
}