use crate::{MessageRole, ToolCall};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A complete (non-streaming) chat response.
///
/// Build one with [`ChatResponse::new`] and the `with_*` methods. The type
/// round-trips through `serde`; note that `timestamp` is encoded as whole
/// Unix seconds, so sub-second precision is not preserved by serialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatResponse {
    /// Response text.
    pub content: String,
    /// Model name as supplied to [`ChatResponse::new`].
    pub model: String,
    /// Token usage, when attached.
    pub usage: Option<Usage>,
    /// Finish reason string; [`ChatResponse::is_finished`] and
    /// [`ChatResponse::is_truncated`] interpret specific values of it.
    pub finish_reason: Option<String>,
    /// Tool calls attached to the response, if any.
    pub tool_calls: Option<Vec<ToolCall>>,
    /// Free-form key/value metadata.
    pub metadata: HashMap<String, serde_json::Value>,
    /// Creation time; [`ChatResponse::new`] sets it to the current UTC time.
    ///
    /// `default` is required here: a field with a custom (`with = ...`)
    /// deserializer is otherwise mandatory in the input, unlike the plain
    /// `Option` fields, so a payload without `timestamp` would be rejected.
    #[serde(default, with = "chrono::serde::ts_seconds_option")]
    pub timestamp: Option<chrono::DateTime<chrono::Utc>>,
    /// Optional response identifier.
    pub id: Option<String>,
}
impl ChatResponse {
    /// Builds a response holding `content` and `model`.
    ///
    /// Every optional field starts as `None`, `metadata` starts empty and
    /// `timestamp` is set to the current UTC time.
    pub fn new(content: impl Into<String>, model: impl Into<String>) -> Self {
        let content = content.into();
        let model = model.into();
        Self {
            content,
            model,
            usage: None,
            finish_reason: None,
            tool_calls: None,
            metadata: HashMap::new(),
            timestamp: Some(chrono::Utc::now()),
            id: None,
        }
    }

    /// Returns the response with `usage` set, replacing any previous value.
    pub fn with_usage(self, usage: Usage) -> Self {
        Self {
            usage: Some(usage),
            ..self
        }
    }

    /// Returns the response with `finish_reason` set to `reason`.
    pub fn with_finish_reason(self, reason: impl Into<String>) -> Self {
        Self {
            finish_reason: Some(reason.into()),
            ..self
        }
    }

    /// Returns the response with `tool_calls` set, replacing any previous list.
    pub fn with_tool_calls(self, tool_calls: Vec<ToolCall>) -> Self {
        Self {
            tool_calls: Some(tool_calls),
            ..self
        }
    }

    /// Returns the response with `value` stored under `key` in `metadata`.
    ///
    /// An existing entry for the same key is overwritten.
    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        let key = key.into();
        self.metadata.insert(key, value);
        self
    }

    /// Returns the response with `id` set.
    pub fn with_id(self, id: impl Into<String>) -> Self {
        Self {
            id: Some(id.into()),
            ..self
        }
    }

    /// Returns `true` when `tool_calls` is present and non-empty.
    pub fn has_tool_calls(&self) -> bool {
        matches!(self.tool_calls.as_deref(), Some(calls) if !calls.is_empty())
    }

    /// Returns `true` when `finish_reason` is `"stop"`, `"end_turn"` or
    /// `"tool_calls"`; any other value, or no value, yields `false`.
    pub fn is_finished(&self) -> bool {
        const FINISHED: [&str; 3] = ["stop", "end_turn", "tool_calls"];
        self.finish_reason
            .as_deref()
            .map_or(false, |reason| FINISHED.contains(&reason))
    }

    /// Returns `true` when `finish_reason` is `"length"` or `"max_tokens"`.
    pub fn is_truncated(&self) -> bool {
        const TRUNCATED: [&str; 2] = ["length", "max_tokens"];
        self.finish_reason
            .as_deref()
            .map_or(false, |reason| TRUNCATED.contains(&reason))
    }

    /// Length of `content` in bytes (not characters).
    pub fn content_length(&self) -> usize {
        self.content.len()
    }
}
/// A complete (non-streaming) text-completion response.
///
/// Build one with [`CompletionResponse::new`] and the `with_*` methods.
/// `timestamp` is serialized as whole Unix seconds, so sub-second precision
/// is not preserved by a serialize/deserialize round trip.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Completion text.
    pub text: String,
    /// Model name as supplied to [`CompletionResponse::new`].
    pub model: String,
    /// Token usage, when attached.
    pub usage: Option<Usage>,
    /// Finish reason string, when attached.
    pub finish_reason: Option<String>,
    /// Log-probability data, when attached.
    pub logprobs: Option<LogProbs>,
    /// Free-form key/value metadata.
    pub metadata: HashMap<String, serde_json::Value>,
    /// Creation time; [`CompletionResponse::new`] sets it to the current UTC
    /// time.
    ///
    /// `default` is required here: a field with a custom (`with = ...`)
    /// deserializer is otherwise mandatory in the input, unlike the plain
    /// `Option` fields, so a payload without `timestamp` would be rejected.
    #[serde(default, with = "chrono::serde::ts_seconds_option")]
    pub timestamp: Option<chrono::DateTime<chrono::Utc>>,
    /// Optional response identifier.
    pub id: Option<String>,
}
impl CompletionResponse {
    /// Builds a completion holding `text` and `model`.
    ///
    /// Every optional field starts as `None`, `metadata` starts empty and
    /// `timestamp` is set to the current UTC time.
    pub fn new(text: impl Into<String>, model: impl Into<String>) -> Self {
        let text = text.into();
        let model = model.into();
        Self {
            text,
            model,
            usage: None,
            finish_reason: None,
            logprobs: None,
            metadata: HashMap::new(),
            timestamp: Some(chrono::Utc::now()),
            id: None,
        }
    }

    /// Returns the completion with `usage` set, replacing any previous value.
    pub fn with_usage(self, usage: Usage) -> Self {
        Self {
            usage: Some(usage),
            ..self
        }
    }

    /// Returns the completion with `finish_reason` set to `reason`.
    pub fn with_finish_reason(self, reason: impl Into<String>) -> Self {
        Self {
            finish_reason: Some(reason.into()),
            ..self
        }
    }

    /// Returns the completion with `logprobs` set.
    pub fn with_logprobs(self, logprobs: LogProbs) -> Self {
        Self {
            logprobs: Some(logprobs),
            ..self
        }
    }

    /// Returns the completion with `value` stored under `key` in `metadata`.
    ///
    /// An existing entry for the same key is overwritten.
    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        let key = key.into();
        self.metadata.insert(key, value);
        self
    }

    /// Returns the completion with `id` set.
    pub fn with_id(self, id: impl Into<String>) -> Self {
        Self {
            id: Some(id.into()),
            ..self
        }
    }
}
/// One chunk of a streamed response.
///
/// Build one with [`StreamChunk::new`], [`StreamChunk::delta`] or
/// [`StreamChunk::done`]. `timestamp` is serialized as whole Unix seconds,
/// so sub-second precision is not preserved by serialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChunk {
    /// Text carried by this chunk; empty for chunks built with
    /// [`StreamChunk::done`].
    pub content: String,
    /// `true` for chunks built with [`StreamChunk::delta`].
    pub is_delta: bool,
    /// `true` for chunks built with [`StreamChunk::done`].
    pub is_done: bool,
    /// Model name as supplied to the constructor.
    pub model: String,
    /// Message role, when attached.
    pub role: Option<MessageRole>,
    /// Tool-call deltas carried by this chunk, if any.
    pub tool_calls_delta: Option<Vec<ToolCallDelta>>,
    /// Finish reason string, when attached.
    pub finish_reason: Option<String>,
    /// Token usage, when attached.
    pub usage: Option<Usage>,
    /// Free-form key/value metadata.
    pub metadata: HashMap<String, serde_json::Value>,
    /// Creation time; [`StreamChunk::new`] sets it to the current UTC time.
    ///
    /// `default` is required here: a field with a custom (`with = ...`)
    /// deserializer is otherwise mandatory in the input, unlike the plain
    /// `Option` fields, so a payload without `timestamp` would be rejected.
    #[serde(default, with = "chrono::serde::ts_seconds_option")]
    pub timestamp: Option<chrono::DateTime<chrono::Utc>>,
}
impl StreamChunk {
    /// Builds a chunk with explicit `is_delta` / `is_done` flags.
    ///
    /// Every optional field starts as `None`, `metadata` starts empty and
    /// `timestamp` is set to the current UTC time.
    pub fn new(
        content: impl Into<String>,
        model: impl Into<String>,
        is_delta: bool,
        is_done: bool,
    ) -> Self {
        let content = content.into();
        let model = model.into();
        Self {
            content,
            is_delta,
            is_done,
            model,
            role: None,
            tool_calls_delta: None,
            finish_reason: None,
            usage: None,
            metadata: HashMap::new(),
            timestamp: Some(chrono::Utc::now()),
        }
    }

    /// Builds a content chunk: `is_delta = true`, `is_done = false`.
    pub fn delta(content: impl Into<String>, model: impl Into<String>) -> Self {
        let (is_delta, is_done) = (true, false);
        Self::new(content, model, is_delta, is_done)
    }

    /// Builds an end-of-stream chunk: empty content, `is_delta = false`,
    /// `is_done = true`.
    pub fn done(model: impl Into<String>) -> Self {
        let (is_delta, is_done) = (false, true);
        Self::new("", model, is_delta, is_done)
    }

    /// Returns the chunk with `role` set.
    pub fn with_role(self, role: MessageRole) -> Self {
        Self {
            role: Some(role),
            ..self
        }
    }

    /// Returns the chunk with `tool_calls_delta` set to `delta`.
    pub fn with_tool_calls_delta(self, delta: Vec<ToolCallDelta>) -> Self {
        Self {
            tool_calls_delta: Some(delta),
            ..self
        }
    }

    /// Returns the chunk with `finish_reason` set to `reason`.
    pub fn with_finish_reason(self, reason: impl Into<String>) -> Self {
        Self {
            finish_reason: Some(reason.into()),
            ..self
        }
    }

    /// Returns the chunk with `usage` set, replacing any previous value.
    pub fn with_usage(self, usage: Usage) -> Self {
        Self {
            usage: Some(usage),
            ..self
        }
    }

    /// Returns the chunk with `value` stored under `key` in `metadata`.
    ///
    /// An existing entry for the same key is overwritten.
    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        let key = key.into();
        self.metadata.insert(key, value);
        self
    }

    /// Returns `true` when `content` is not the empty string.
    pub fn has_content(&self) -> bool {
        !self.content.is_empty()
    }

    /// Returns `true` when `tool_calls_delta` is present and non-empty.
    pub fn has_tool_calls(&self) -> bool {
        matches!(self.tool_calls_delta.as_deref(), Some(deltas) if !deltas.is_empty())
    }
}
/// Partial tool-call data carried in `StreamChunk::tool_calls_delta`.
///
/// Every field except `index` is optional, so a delta may carry only some
/// of the pieces.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallDelta {
/// NOTE(review): presumably the position of the tool call this delta
/// belongs to — confirm against the producer of these deltas.
pub index: u32,
/// Tool-call identifier, when present in this delta.
pub id: Option<String>,
/// Call type; serialized and deserialized under the key `type`
/// (`type` is a Rust keyword, hence the rename).
#[serde(rename = "type")]
pub call_type: Option<String>,
/// Function name/arguments portion of the delta, when present.
pub function: Option<ToolFunctionDelta>,
}
/// Function portion of a [`ToolCallDelta`]; both fields are optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolFunctionDelta {
/// Function name, when present in this delta.
pub name: Option<String>,
/// Arguments text, when present in this delta.
/// NOTE(review): presumably a fragment to be concatenated across deltas
/// — confirm against the consumer.
pub arguments: Option<String>,
}
/// Token counts attached to a response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
/// Number of prompt tokens.
pub prompt_tokens: u32,
/// Number of completion tokens.
pub completion_tokens: u32,
/// Total token count; `Usage::new` derives it from `prompt_tokens` and
/// `completion_tokens`. The field is public, so other code may set it
/// independently.
pub total_tokens: u32,
/// Cached token count; `effective_prompt_tokens` subtracts it from
/// `prompt_tokens`, treating `None` as 0.
pub cached_tokens: Option<u32>,
/// Reasoning token count; stored only, no method in this file reads it.
pub reasoning_tokens: Option<u32>,
}
impl Usage {
    /// Creates a usage record from prompt and completion token counts.
    ///
    /// `total_tokens` is their sum, saturating at `u32::MAX` instead of
    /// overflowing. `cached_tokens` and `reasoning_tokens` start as `None`.
    pub fn new(prompt_tokens: u32, completion_tokens: u32) -> Self {
        Self {
            prompt_tokens,
            completion_tokens,
            total_tokens: prompt_tokens.saturating_add(completion_tokens),
            cached_tokens: None,
            reasoning_tokens: None,
        }
    }

    /// Returns the record with `cached_tokens` set.
    pub fn with_cached_tokens(mut self, cached_tokens: u32) -> Self {
        self.cached_tokens = Some(cached_tokens);
        self
    }

    /// Returns the record with `reasoning_tokens` set.
    pub fn with_reasoning_tokens(mut self, reasoning_tokens: u32) -> Self {
        self.reasoning_tokens = Some(reasoning_tokens);
        self
    }

    /// Prompt tokens minus cached tokens (`None` counts as 0).
    ///
    /// The subtraction saturates at 0: the fields are public and
    /// deserializable, so `cached_tokens` may exceed `prompt_tokens`, and a
    /// plain `-` would panic in debug builds and wrap in release builds.
    pub fn effective_prompt_tokens(&self) -> u32 {
        let cached = self.cached_tokens.unwrap_or(0);
        self.prompt_tokens.saturating_sub(cached)
    }

    /// Effective prompt tokens plus completion tokens, saturating at
    /// `u32::MAX`.
    pub fn total_cost(&self) -> u32 {
        self.effective_prompt_tokens()
            .saturating_add(self.completion_tokens)
    }
}
/// Log-probability data attached to a `CompletionResponse`.
///
/// NOTE(review): the three vectors are presumably parallel (one entry per
/// token) — nothing in this file enforces equal lengths; confirm with the
/// code that fills them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogProbs {
/// Log probabilities; an entry may be `None`.
pub token_logprobs: Vec<Option<f64>>,
/// Maps from string keys to log-probability values; an entry may be `None`.
/// NOTE(review): keys are presumably candidate tokens — confirm.
pub top_logprobs: Vec<Option<HashMap<String, f64>>>,
/// Offsets into the text.
/// NOTE(review): unit (bytes vs. characters) is not established here — confirm.
pub text_offset: Vec<usize>,
}
/// A batch of embedding vectors returned for a request.
///
/// Build one with [`EmbeddingResponse::new`]. `timestamp` is serialized as
/// whole Unix seconds, so sub-second precision is not preserved by
/// serialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    /// Embedding vectors, one inner `Vec<f32>` per embedding.
    pub embeddings: Vec<Vec<f32>>,
    /// Model name as supplied to [`EmbeddingResponse::new`].
    pub model: String,
    /// Token usage, when attached.
    pub usage: Option<Usage>,
    /// Free-form key/value metadata.
    pub metadata: HashMap<String, serde_json::Value>,
    /// Creation time; [`EmbeddingResponse::new`] sets it to the current UTC
    /// time.
    ///
    /// `default` is required here: a field with a custom (`with = ...`)
    /// deserializer is otherwise mandatory in the input, unlike the plain
    /// `Option` fields, so a payload without `timestamp` would be rejected.
    #[serde(default, with = "chrono::serde::ts_seconds_option")]
    pub timestamp: Option<chrono::DateTime<chrono::Utc>>,
}
impl EmbeddingResponse {
    /// Builds a response holding `embeddings` and `model`.
    ///
    /// `usage` starts as `None`, `metadata` starts empty and `timestamp` is
    /// set to the current UTC time.
    pub fn new(embeddings: Vec<Vec<f32>>, model: impl Into<String>) -> Self {
        let model = model.into();
        Self {
            embeddings,
            model,
            usage: None,
            metadata: HashMap::new(),
            timestamp: Some(chrono::Utc::now()),
        }
    }

    /// Returns the response with `usage` set, replacing any previous value.
    pub fn with_usage(self, usage: Usage) -> Self {
        Self {
            usage: Some(usage),
            ..self
        }
    }

    /// Returns the response with `value` stored under `key` in `metadata`.
    ///
    /// An existing entry for the same key is overwritten.
    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        let key = key.into();
        self.metadata.insert(key, value);
        self
    }

    /// Number of embedding vectors in the response.
    pub fn count(&self) -> usize {
        self.embeddings.len()
    }

    /// Length of the first embedding vector, or `None` when there are no
    /// embeddings. Only the first vector is inspected; the others are not
    /// checked for matching length.
    pub fn dimension(&self) -> Option<usize> {
        self.embeddings.first().map(Vec::len)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder chain fills the expected fields and `"stop"` counts as a
    /// finished response.
    #[test]
    fn test_chat_response_creation() {
        let response = ChatResponse::new("Hello!", "gpt-4")
            .with_finish_reason("stop")
            .with_usage(Usage::new(10, 5));

        assert_eq!(response.content.as_str(), "Hello!");
        assert_eq!(response.model.as_str(), "gpt-4");
        assert_eq!(response.finish_reason.as_deref(), Some("stop"));
        assert!(response.usage.is_some());
        assert!(response.is_finished());
    }

    /// `delta` produces a non-final content chunk and `with_role` attaches
    /// the role.
    #[test]
    fn test_stream_chunk() {
        let chunk = StreamChunk::delta("Hello", "gpt-4").with_role(MessageRole::Assistant);

        assert_eq!(chunk.content.as_str(), "Hello");
        assert!(chunk.is_delta);
        assert!(!chunk.is_done);
        assert_eq!(chunk.role, Some(MessageRole::Assistant));
        assert!(chunk.has_content());
    }

    /// Totals and cache-adjusted figures follow from the raw counts.
    #[test]
    fn test_usage_calculation() {
        let usage = Usage::new(100, 50).with_cached_tokens(20);

        assert_eq!(usage.total_tokens, 150);
        assert_eq!(usage.effective_prompt_tokens(), 80);
        assert_eq!(usage.total_cost(), 130);
    }
}