use crate::coercion::CoercionFlag;
use crate::message::Message;
use serde::{Deserialize, Serialize};
/// Describes the coercions applied when a response was "healed".
///
/// A [`CompletionResponse`] carries this in its `healing_metadata` field; the
/// presence of that field is what [`CompletionResponse::was_healed`] reports.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct HealingMetadata {
    /// Coercions applied during healing; [`HealingMetadata::has_major_coercions`]
    /// reports whether any of them is major.
    pub flags: Vec<CoercionFlag>,
    /// Confidence in the healed result. [`HealingMetadata::new`] clamps it into
    /// `[0.0, 1.0]`; deserialization and direct struct construction do not.
    pub confidence: f32,
    /// The error that prompted healing (e.g. a parse error message — presumably
    /// the error text from the failed first attempt; confirm with callers).
    pub original_error: String,
}
impl HealingMetadata {
    /// Creates healing metadata with `confidence` normalized into `[0.0, 1.0]`.
    ///
    /// Values outside the range are clamped to the nearest bound. `NaN` is
    /// mapped to `0.0` (no confidence). Leaving it alone has three problems:
    /// - `f32::clamp` returns `NaN` unchanged, which would break the range
    ///   invariant.
    /// - Every [`HealingMetadata::is_confident`] check would fail.
    /// - serde_json encodes non-finite floats as JSON `null`, which cannot be
    ///   deserialized back into an `f32`.
    pub fn new(flags: Vec<CoercionFlag>, confidence: f32, original_error: String) -> Self {
        let confidence = if confidence.is_nan() {
            0.0
        } else {
            confidence.clamp(0.0, 1.0)
        };
        Self {
            flags,
            confidence,
            original_error,
        }
    }

    /// Returns `true` if any recorded coercion reports itself as major
    /// (via `CoercionFlag::is_major`).
    pub fn has_major_coercions(&self) -> bool {
        self.flags.iter().any(|f| f.is_major())
    }

    /// Returns `true` if the confidence is at least `threshold` (inclusive).
    pub fn is_confident(&self, threshold: f32) -> bool {
        self.confidence >= threshold
    }
}
/// A complete completion response; [`CompletionChunk`] is the incremental
/// (streaming) form.
///
/// Optional fields are omitted from serialized output when `None`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Response identifier.
    pub id: String,
    /// Model name (e.g. `"gpt-4"`).
    pub model: String,
    /// Generated choices; [`CompletionResponse::content`] reads the first one.
    pub choices: Vec<CompletionChoice>,
    /// Token accounting for this response.
    pub usage: Usage,
    /// Creation timestamp (presumably Unix seconds — confirm with producers).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub created: Option<i64>,
    /// Provider name (e.g. `"openai"`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider: Option<String>,
    /// Present only when the response was healed.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub healing_metadata: Option<HealingMetadata>,
}
impl CompletionResponse {
    /// Text of the first choice's message, or `None` if there are no choices.
    pub fn content(&self) -> Option<&str> {
        let choice = self.first_choice()?;
        Some(choice.message.content_text())
    }

    /// The first entry of `choices`, if any.
    pub fn first_choice(&self) -> Option<&CompletionChoice> {
        match self.choices.as_slice() {
            [head, ..] => Some(head),
            [] => None,
        }
    }

    /// Whether healing metadata is attached to this response.
    pub fn was_healed(&self) -> bool {
        self.healing_metadata.is_some()
    }

    /// Healing confidence, or `1.0` for responses that were not healed.
    pub fn confidence(&self) -> f32 {
        self.healing_metadata
            .as_ref()
            .map_or(1.0, |meta| meta.confidence)
    }
}
/// One generated alternative within a [`CompletionResponse`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CompletionChoice {
    /// Position of this choice in the response.
    pub index: u32,
    /// The generated message.
    pub message: Message,
    /// Why generation stopped for this choice.
    pub finish_reason: FinishReason,
    /// Log-probability data kept as raw, uninterpreted JSON; omitted from
    /// serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<serde_json::Value>,
}
/// Why generation stopped for a choice.
///
/// Serialized as snake_case strings. Deserializing any other string fails,
/// because there is no catch-all variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    /// Wire name `"stop"`.
    Stop,
    /// Wire name `"length"`.
    Length,
    /// Wire name `"content_filter"`.
    ContentFilter,
    /// Wire name `"tool_calls"`.
    ToolCalls,
}
impl FinishReason {
pub fn as_str(self) -> &'static str {
match self {
Self::Stop => "stop",
Self::Length => "length",
Self::ContentFilter => "content_filter",
Self::ToolCalls => "tool_calls",
}
}
}
/// Token counts reported for a completion.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Usage {
    /// Tokens attributed to the prompt.
    pub prompt_tokens: u32,
    /// Tokens attributed to the completion.
    pub completion_tokens: u32,
    /// Total tokens. [`Usage::new`] computes it from the two counts above;
    /// when deserialized, the payload value is taken as-is, unchecked.
    pub total_tokens: u32,
    /// Optional reasoning token count.
    /// - Deserialization also accepts the key `thinking_tokens` as an alias.
    /// - It is always serialized as `reasoning_tokens`.
    /// - Defaults to `None` when absent and is omitted from output when `None`.
    #[serde(
        skip_serializing_if = "Option::is_none",
        default,
        alias = "thinking_tokens"
    )]
    pub reasoning_tokens: Option<u32>,
}
impl Usage {
    /// Creates usage counts from prompt and completion token counts.
    ///
    /// `total_tokens` is the saturating sum of the two. Extreme inputs yield
    /// `u32::MAX` instead of panicking (debug builds) or wrapping around to a
    /// small, misleading total (release builds). `reasoning_tokens` starts as
    /// `None`.
    pub fn new(prompt_tokens: u32, completion_tokens: u32) -> Self {
        Self {
            prompt_tokens,
            completion_tokens,
            total_tokens: prompt_tokens.saturating_add(completion_tokens),
            reasoning_tokens: None,
        }
    }
}
/// One incremental piece of a streamed completion.
///
/// Optional fields are omitted from serialized output when `None`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CompletionChunk {
    /// Response identifier.
    pub id: String,
    /// Model name.
    pub model: String,
    /// Per-choice deltas carried by this chunk.
    pub choices: Vec<ChoiceDelta>,
    /// Creation timestamp (presumably Unix seconds — confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub created: Option<i64>,
    /// Token usage, if this chunk includes it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<Usage>,
}
/// The change to a single choice carried by a [`CompletionChunk`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ChoiceDelta {
    /// Index of the choice this delta applies to.
    pub index: u32,
    /// Incremental message content.
    pub delta: MessageDelta,
    /// Finish reason, if reported in this chunk (presumably only on the
    /// choice's final chunk — confirm); omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<FinishReason>,
}
/// Incremental message fields within a streamed choice.
///
/// Every field is optional and omitted from serialized output when `None`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MessageDelta {
    /// Message role, if present in this chunk.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<crate::message::Role>,
    /// Text content fragment.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    /// Reasoning content fragment.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_content: Option<String>,
    /// Tool-call fragments.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallDelta>>,
}
/// A fragment of a tool call inside a [`MessageDelta`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolCallDelta {
    /// Position of the tool call this fragment belongs to (presumably used to
    /// merge fragments across chunks — confirm).
    pub index: u32,
    /// Tool-call identifier, if present in this fragment.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Tool type; serialized under the JSON key `type`.
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    pub tool_type: Option<crate::tool::ToolType>,
    /// Function name/arguments fragment.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function: Option<ToolCallFunctionDelta>,
}
/// Function portion of a [`ToolCallDelta`]; both fields are omitted from
/// serialized output when `None`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolCallFunctionDelta {
    /// Function name, if present in this fragment.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Argument text fragment (presumably partial JSON to be concatenated
    /// across chunks — confirm with the stream consumer).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub arguments: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::Role;

    /// Baseline response with the fixed id/model used throughout these tests
    /// and all optional fields unset.
    fn response_fixture(choices: Vec<CompletionChoice>, usage: Usage) -> CompletionResponse {
        CompletionResponse {
            id: "resp_123".to_string(),
            model: "gpt-4".to_string(),
            choices,
            usage,
            created: None,
            provider: None,
            healing_metadata: None,
        }
    }

    /// The single assistant "Hello!" choice shared by several tests.
    fn hello_choice() -> CompletionChoice {
        CompletionChoice {
            index: 0,
            message: Message::assistant("Hello!"),
            finish_reason: FinishReason::Stop,
            logprobs: None,
        }
    }

    /// Serializes `value` to a JSON string and parses it back.
    fn round_trip<T>(value: &T) -> T
    where
        T: Serialize + serde::de::DeserializeOwned,
    {
        let encoded = serde_json::to_string(value).expect("serialization should succeed");
        serde_json::from_str(&encoded).expect("deserialization should succeed")
    }

    #[test]
    fn test_completion_response_content() {
        let response = CompletionResponse {
            created: Some(1_234_567_890),
            provider: Some("openai".to_string()),
            ..response_fixture(vec![hello_choice()], Usage::new(10, 5))
        };

        assert_eq!(response.content(), Some("Hello!"));
        let first = response.first_choice().expect("fixture has one choice");
        assert_eq!(first.index, 0);
        assert!(!response.was_healed());
        assert_eq!(response.confidence(), 1.0);
    }

    #[test]
    fn test_completion_response_empty_choices() {
        let response = response_fixture(Vec::new(), Usage::new(10, 0));

        assert!(response.content().is_none());
        assert!(response.first_choice().is_none());
    }

    #[test]
    fn test_usage_calculation() {
        let expected = Usage {
            prompt_tokens: 100,
            completion_tokens: 50,
            total_tokens: 150,
            reasoning_tokens: None,
        };
        assert_eq!(Usage::new(100, 50), expected);
    }

    #[test]
    fn test_usage_deserializes_thinking_tokens_alias() {
        let raw = r#"{"prompt_tokens":10,"completion_tokens":5,"total_tokens":15,"thinking_tokens":3}"#;
        let usage: Usage = serde_json::from_str(raw).expect("valid usage JSON");
        assert_eq!(usage.reasoning_tokens, Some(3));
    }

    #[test]
    fn test_usage_serializes_reasoning_tokens_name() {
        let usage = Usage {
            reasoning_tokens: Some(3),
            ..Usage::new(10, 5)
        };
        let json = serde_json::to_value(usage).expect("usage serializes");

        assert_eq!(json["reasoning_tokens"].as_u64(), Some(3));
        assert!(json.get("thinking_tokens").is_none());
    }

    #[test]
    fn test_finish_reason_serialization() {
        let cases = [
            (FinishReason::Stop, "\"stop\""),
            (FinishReason::Length, "\"length\""),
            (FinishReason::ContentFilter, "\"content_filter\""),
            (FinishReason::ToolCalls, "\"tool_calls\""),
        ];
        for &(reason, expected) in &cases {
            let json = serde_json::to_string(&reason).unwrap();
            assert_eq!(json, expected, "serialized form of {:?}", reason);
        }
    }

    #[test]
    fn test_response_serialization() {
        let response = response_fixture(vec![hello_choice()], Usage::new(10, 5));
        assert_eq!(round_trip(&response), response);
    }

    #[test]
    fn test_streaming_chunk() {
        let delta = MessageDelta {
            role: Some(Role::Assistant),
            content: Some("Hello".to_string()),
            reasoning_content: None,
            tool_calls: None,
        };
        let chunk = CompletionChunk {
            id: "resp_123".to_string(),
            model: "gpt-4".to_string(),
            choices: vec![ChoiceDelta {
                index: 0,
                delta,
                finish_reason: None,
            }],
            created: Some(1_234_567_890),
            usage: None,
        };
        assert_eq!(round_trip(&chunk), chunk);
    }

    #[test]
    fn test_message_delta() {
        let delta = MessageDelta {
            role: Some(Role::Assistant),
            content: Some("Hi".to_string()),
            reasoning_content: None,
            tool_calls: None,
        };
        let json = serde_json::to_value(&delta).unwrap();

        assert_eq!(json["role"].as_str(), Some("assistant"));
        assert_eq!(json["content"].as_str(), Some("Hi"));
    }

    #[test]
    fn test_optional_fields_not_serialized() {
        let response = response_fixture(Vec::new(), Usage::new(10, 5));
        let json = serde_json::to_value(&response).unwrap();

        for key in ["created", "provider", "healing_metadata"].iter() {
            assert!(json.get(*key).is_none(), "`{}` should be omitted when None", key);
        }
    }

    #[test]
    fn test_healing_metadata() {
        let minor = HealingMetadata::new(
            vec![CoercionFlag::StrippedMarkdown],
            0.9,
            "Parse error".to_string(),
        );
        assert_eq!(minor.confidence, 0.9);
        assert!(!minor.has_major_coercions());
        assert!(minor.is_confident(0.8));
        assert!(!minor.is_confident(0.95));

        let major = HealingMetadata::new(
            vec![CoercionFlag::TruncatedJson],
            0.7,
            "Parse error".to_string(),
        );
        assert!(major.has_major_coercions());
    }

    #[test]
    fn test_healing_metadata_confidence_clamped() {
        for &(input, expected) in &[(1.5_f32, 1.0_f32), (-0.5, 0.0)] {
            let metadata = HealingMetadata::new(vec![], input, "error".to_string());
            assert_eq!(metadata.confidence, expected);
        }
    }

    #[test]
    fn test_response_with_healing_metadata() {
        let metadata = HealingMetadata::new(
            vec![
                CoercionFlag::StrippedMarkdown,
                CoercionFlag::FixedTrailingComma,
            ],
            0.85,
            "JSON parse error".to_string(),
        );
        let response = CompletionResponse {
            healing_metadata: Some(metadata),
            ..response_fixture(Vec::new(), Usage::new(10, 5))
        };
        let parsed = round_trip(&response);

        assert_eq!(parsed, response);
        for candidate in [&response, &parsed].iter() {
            assert!(candidate.was_healed());
            assert_eq!(candidate.confidence(), 0.85);
        }
    }
}