use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use super::parts::{BuiltinToolCallPart, FilePart, TextPart, ThinkingPart, ToolCallPart};
use crate::usage::RequestUsage;
/// A single response from a model: an ordered list of parts plus optional metadata.
///
/// Every `Option` field is omitted from serialized output when it is `None`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ModelResponse {
    /// Ordered content parts (text, tool calls, thinking, files, built-in tool calls).
    pub parts: Vec<ModelResponsePart>,
    /// Name of the model, if recorded.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model_name: Option<String>,
    /// Timestamp of the response; `ModelResponse::new` initializes it to
    /// `Utc::now()`, i.e. the construction time.
    pub timestamp: DateTime<Utc>,
    /// Finish reason, if known.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<FinishReason>,
    /// Usage information (`crate::usage::RequestUsage`), if available.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<RequestUsage>,
    /// Provider-specific identifier, if any.
    /// NOTE(review): presumably the vendor's response id — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub vendor_id: Option<String>,
    /// Arbitrary provider-specific JSON, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub vendor_details: Option<serde_json::Value>,
    /// Discriminator string. Constructors set it to `"response"`, and
    /// deserialization falls back to `"response"` (via `default_response_kind`)
    /// when the field is absent from the input.
    #[serde(default = "default_response_kind")]
    pub kind: String,
}
/// Serde default for [`ModelResponse::kind`]: the literal discriminator `"response"`.
fn default_response_kind() -> String {
    String::from("response")
}
impl ModelResponse {
#[must_use]
pub fn new() -> Self {
Self {
parts: Vec::new(),
model_name: None,
timestamp: Utc::now(),
finish_reason: None,
usage: None,
vendor_id: None,
vendor_details: None,
kind: "response".to_string(),
}
}
#[must_use]
pub fn with_parts(parts: Vec<ModelResponsePart>) -> Self {
Self {
parts,
..Self::new()
}
}
#[must_use]
pub fn text(content: impl Into<String>) -> Self {
Self::with_parts(vec![ModelResponsePart::Text(TextPart::new(content))])
}
pub fn add_part(&mut self, part: ModelResponsePart) {
self.parts.push(part);
}
#[must_use]
pub fn with_model_name(mut self, name: impl Into<String>) -> Self {
self.model_name = Some(name.into());
self
}
#[must_use]
pub fn with_finish_reason(mut self, reason: FinishReason) -> Self {
self.finish_reason = Some(reason);
self
}
#[must_use]
pub fn with_usage(mut self, usage: RequestUsage) -> Self {
self.usage = Some(usage);
self
}
#[must_use]
pub fn with_vendor_id(mut self, id: impl Into<String>) -> Self {
self.vendor_id = Some(id.into());
self
}
#[must_use]
pub fn with_vendor_details(mut self, details: serde_json::Value) -> Self {
self.vendor_details = Some(details);
self
}
pub fn text_parts(&self) -> impl Iterator<Item = &TextPart> {
self.parts.iter().filter_map(|p| match p {
ModelResponsePart::Text(t) => Some(t),
_ => None,
})
}
pub fn tool_call_parts(&self) -> impl Iterator<Item = &ToolCallPart> {
self.parts.iter().filter_map(|p| match p {
ModelResponsePart::ToolCall(t) => Some(t),
_ => None,
})
}
pub fn thinking_parts(&self) -> impl Iterator<Item = &ThinkingPart> {
self.parts.iter().filter_map(|p| match p {
ModelResponsePart::Thinking(t) => Some(t),
_ => None,
})
}
pub fn file_parts(&self) -> impl Iterator<Item = &FilePart> {
self.parts.iter().filter_map(|p| match p {
ModelResponsePart::File(f) => Some(f),
_ => None,
})
}
#[deprecated(note = "Use text_parts() iterator instead")]
pub fn text_parts_vec(&self) -> Vec<&TextPart> {
self.text_parts().collect()
}
#[deprecated(note = "Use tool_call_parts() iterator instead")]
pub fn tool_call_parts_vec(&self) -> Vec<&ToolCallPart> {
self.tool_call_parts().collect()
}
#[deprecated(note = "Use thinking_parts() iterator instead")]
pub fn thinking_parts_vec(&self) -> Vec<&ThinkingPart> {
self.thinking_parts().collect()
}
#[deprecated(note = "Use file_parts() iterator instead")]
pub fn file_parts_vec(&self) -> Vec<&FilePart> {
self.file_parts().collect()
}
#[must_use]
pub fn has_files(&self) -> bool {
self.parts
.iter()
.any(|p| matches!(p, ModelResponsePart::File(_)))
}
pub fn builtin_tool_call_parts(&self) -> impl Iterator<Item = &BuiltinToolCallPart> {
self.parts.iter().filter_map(|p| match p {
ModelResponsePart::BuiltinToolCall(b) => Some(b),
_ => None,
})
}
#[deprecated(note = "Use builtin_tool_call_parts() iterator instead")]
pub fn builtin_tool_call_parts_vec(&self) -> Vec<&BuiltinToolCallPart> {
self.builtin_tool_call_parts().collect()
}
#[must_use]
pub fn has_builtin_tool_calls(&self) -> bool {
self.parts
.iter()
.any(|p| matches!(p, ModelResponsePart::BuiltinToolCall(_)))
}
#[must_use]
pub fn text_content(&self) -> String {
self.text_parts()
.map(|p| p.content.as_str())
.collect::<Vec<_>>()
.join("")
}
#[must_use]
pub fn has_tool_calls(&self) -> bool {
self.parts
.iter()
.any(|p| matches!(p, ModelResponsePart::ToolCall(_)))
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.parts.is_empty()
}
#[must_use]
pub fn len(&self) -> usize {
self.parts.len()
}
}
impl Default for ModelResponse {
fn default() -> Self {
Self::new()
}
}
impl FromIterator<ModelResponsePart> for ModelResponse {
fn from_iter<T: IntoIterator<Item = ModelResponsePart>>(iter: T) -> Self {
Self::with_parts(iter.into_iter().collect())
}
}
/// One piece of a [`ModelResponse`].
///
/// Serialized as an internally tagged object: the variant name is stored in a
/// `part_kind` field using kebab-case (`"text"`, `"tool-call"`, `"thinking"`,
/// `"file"`, `"builtin-tool-call"`) next to the wrapped part's own fields.
///
/// NOTE(review): `part_kind()` returns each part type's `PART_KIND` constant,
/// which is defined elsewhere — confirm those values match the serde tags above.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "part_kind", rename_all = "kebab-case")]
pub enum ModelResponsePart {
    /// A text part (`part_kind = "text"`).
    Text(TextPart),
    /// A tool call part (`part_kind = "tool-call"`).
    ToolCall(ToolCallPart),
    /// A thinking part (`part_kind = "thinking"`).
    Thinking(ThinkingPart),
    /// A file part (`part_kind = "file"`).
    File(FilePart),
    /// A built-in tool call part (`part_kind = "builtin-tool-call"`).
    BuiltinToolCall(BuiltinToolCallPart),
}
impl ModelResponsePart {
#[must_use]
pub fn text(content: impl Into<String>) -> Self {
Self::Text(TextPart::new(content))
}
#[must_use]
pub fn tool_call(
tool_name: impl Into<String>,
args: impl Into<super::parts::ToolCallArgs>,
) -> Self {
Self::ToolCall(ToolCallPart::new(tool_name, args))
}
#[must_use]
pub fn thinking(content: impl Into<String>) -> Self {
Self::Thinking(ThinkingPart::new(content))
}
#[must_use]
pub fn file(data: Vec<u8>, media_type: impl Into<String>) -> Self {
Self::File(FilePart::from_bytes(data, media_type))
}
#[must_use]
pub fn builtin_tool_call(
tool_name: impl Into<String>,
args: impl Into<super::parts::ToolCallArgs>,
) -> Self {
Self::BuiltinToolCall(BuiltinToolCallPart::new(tool_name, args))
}
#[must_use]
pub fn part_kind(&self) -> &'static str {
match self {
Self::Text(_) => TextPart::PART_KIND,
Self::ToolCall(_) => ToolCallPart::PART_KIND,
Self::Thinking(_) => ThinkingPart::PART_KIND,
Self::File(_) => FilePart::PART_KIND,
Self::BuiltinToolCall(_) => BuiltinToolCallPart::PART_KIND,
}
}
#[must_use]
pub fn is_text(&self) -> bool {
matches!(self, Self::Text(_))
}
#[must_use]
pub fn is_tool_call(&self) -> bool {
matches!(self, Self::ToolCall(_))
}
#[must_use]
pub fn is_thinking(&self) -> bool {
matches!(self, Self::Thinking(_))
}
#[must_use]
pub fn is_file(&self) -> bool {
matches!(self, Self::File(_))
}
#[must_use]
pub fn is_builtin_tool_call(&self) -> bool {
matches!(self, Self::BuiltinToolCall(_))
}
}
impl From<TextPart> for ModelResponsePart {
fn from(p: TextPart) -> Self {
Self::Text(p)
}
}
impl From<ToolCallPart> for ModelResponsePart {
fn from(p: ToolCallPart) -> Self {
Self::ToolCall(p)
}
}
impl From<ThinkingPart> for ModelResponsePart {
fn from(p: ThinkingPart) -> Self {
Self::Thinking(p)
}
}
impl From<FilePart> for ModelResponsePart {
fn from(p: FilePart) -> Self {
Self::File(p)
}
}
impl From<BuiltinToolCallPart> for ModelResponsePart {
fn from(p: BuiltinToolCallPart) -> Self {
Self::BuiltinToolCall(p)
}
}
/// Reason a model response ended.
///
/// Serialized in snake_case: `"stop"`, `"length"`, `"content_filter"`,
/// `"tool_call"`, `"error"`, `"end_turn"`, `"stop_sequence"`. The `Display`
/// impl emits the same strings.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    /// Serialized as `"stop"`; counted as complete by `is_complete`.
    Stop,
    /// Serialized as `"length"`; the only variant `is_truncated` accepts.
    Length,
    /// Serialized as `"content_filter"`.
    ContentFilter,
    /// Serialized as `"tool_call"`; the only variant `is_tool_call` accepts.
    ToolCall,
    /// Serialized as `"error"`.
    Error,
    /// Serialized as `"end_turn"`; counted as complete by `is_complete`.
    EndTurn,
    /// Serialized as `"stop_sequence"`; counted as complete by `is_complete`.
    StopSequence,
}
impl FinishReason {
    /// `true` when generation ended normally: `Stop`, `EndTurn` or `StopSequence`.
    #[must_use]
    pub fn is_complete(&self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::Stop | Self::EndTurn | Self::StopSequence => true,
            Self::Length | Self::ContentFilter | Self::ToolCall | Self::Error => false,
        }
    }

    /// `true` only for `Length`.
    #[must_use]
    pub fn is_truncated(&self) -> bool {
        *self == Self::Length
    }

    /// `true` only for `ToolCall`.
    #[must_use]
    pub fn is_tool_call(&self) -> bool {
        *self == Self::ToolCall
    }
}
/// Formats the reason as its snake_case name (the same string serde uses).
///
/// Formatter flags are honored like they are for `str`: width, fill and
/// alignment pad the label, and precision truncates it.
impl std::fmt::Display for FinishReason {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Stop => "stop",
            Self::Length => "length",
            Self::ContentFilter => "content_filter",
            Self::ToolCall => "tool_call",
            Self::Error => "error",
            Self::EndTurn => "end_turn",
            Self::StopSequence => "stop_sequence",
        };
        // `write!(f, "...")` ignores width/fill/alignment; `pad` applies them,
        // so e.g. `format!("{:>12}", reason)` lines up in tables and logs.
        f.pad(label)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_response_new() {
        let response = ModelResponse::new();
        assert!(response.is_empty());
        assert_eq!(response.len(), 0);
        assert!(!response.has_tool_calls());
        assert_eq!(response.kind, "response");
    }

    #[test]
    fn test_model_response_text() {
        let response = ModelResponse::text("Hello, world!");
        assert_eq!(response.len(), 1);
        assert_eq!(response.text_content(), "Hello, world!");
    }

    /// `text_content` concatenates text parts in order and skips other parts;
    /// also exercises the `FromIterator` impl.
    #[test]
    fn test_text_content_concatenates_in_order() {
        let response: ModelResponse = vec![
            ModelResponsePart::text("Hello, "),
            ModelResponsePart::thinking("ignored"),
            ModelResponsePart::text("world!"),
        ]
        .into_iter()
        .collect();
        assert_eq!(response.len(), 3);
        assert_eq!(response.text_parts().count(), 2);
        assert_eq!(response.text_content(), "Hello, world!");
        assert_eq!(ModelResponse::new().text_content(), "");
    }

    #[test]
    fn test_model_response_with_tool_calls() {
        let response = ModelResponse::with_parts(vec![
            ModelResponsePart::text("Let me check the weather."),
            ModelResponsePart::tool_call("get_weather", serde_json::json!({"city": "NYC"})),
        ]);
        assert!(response.has_tool_calls());
        assert!(!response.has_builtin_tool_calls());
        assert_eq!(response.tool_call_parts().count(), 1);
        assert_eq!(response.text_parts().count(), 1);
    }

    #[test]
    fn test_model_response_with_builtin_tool_calls() {
        let part =
            ModelResponsePart::builtin_tool_call("web_search", serde_json::json!({"q": "rust"}));
        assert!(part.is_builtin_tool_call());
        assert!(!part.is_tool_call());
        let response = ModelResponse::with_parts(vec![part]);
        assert!(response.has_builtin_tool_calls());
        assert!(!response.has_tool_calls());
        assert_eq!(response.builtin_tool_call_parts().count(), 1);
    }

    #[test]
    fn test_finish_reason() {
        assert!(FinishReason::Stop.is_complete());
        assert!(FinishReason::EndTurn.is_complete());
        assert!(FinishReason::StopSequence.is_complete());
        assert!(!FinishReason::Length.is_complete());
        assert!(!FinishReason::ToolCall.is_complete());
        assert!(FinishReason::Length.is_truncated());
        assert!(!FinishReason::Stop.is_truncated());
        assert!(FinishReason::ToolCall.is_tool_call());
        assert!(!FinishReason::Stop.is_tool_call());
    }

    /// `Display` and serde must agree on the wire name of every variant.
    #[test]
    fn test_finish_reason_display_matches_serde() {
        let all = [
            FinishReason::Stop,
            FinishReason::Length,
            FinishReason::ContentFilter,
            FinishReason::ToolCall,
            FinishReason::Error,
            FinishReason::EndTurn,
            FinishReason::StopSequence,
        ];
        for reason in all {
            let json = serde_json::to_string(&reason).unwrap();
            assert_eq!(json, format!("\"{reason}\""));
        }
    }

    #[test]
    fn test_serde_roundtrip() {
        let response = ModelResponse::with_parts(vec![
            ModelResponsePart::text("Hello"),
            ModelResponsePart::thinking("Thinking..."),
        ])
        .with_model_name("gpt-4")
        .with_finish_reason(FinishReason::Stop);
        let json = serde_json::to_string(&response).unwrap();
        let parsed: ModelResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(response.len(), parsed.len());
        assert_eq!(response.model_name, parsed.model_name);
        assert_eq!(parsed.finish_reason, Some(FinishReason::Stop));
        assert_eq!(parsed.text_content(), "Hello");
        assert_eq!(parsed.thinking_parts().count(), 1);
        assert_eq!(parsed.kind, "response");
    }

    /// A payload without `kind` (or any optional field) still deserializes.
    #[test]
    fn test_kind_defaults_when_missing() {
        let parsed: ModelResponse =
            serde_json::from_str(r#"{"parts":[],"timestamp":"2024-01-01T00:00:00Z"}"#).unwrap();
        assert_eq!(parsed.kind, "response");
        assert!(parsed.is_empty());
        assert!(parsed.model_name.is_none());
        assert!(parsed.finish_reason.is_none());
    }
}