use serde::{Deserialize, Serialize};
use crate::json::Json;
/// An LLM chat request as handled by this codec.
///
/// Every `Option` field is omitted from serialized output when it is `None`.
/// Any top-level key that does not match a named field is collected into
/// [`extra`](Self::extra) on deserialization and written back out on
/// serialization, so unrecognised keys survive a round trip.
///
/// NOTE(review): `params` is a nested object under the `params` key (it is not
/// `#[serde(flatten)]`). Top-level keys such as `temperature` or `top_p` are
/// therefore NOT routed into [`GenerationParams`]; they land in `extra` instead.
/// Confirm this is the intended wire shape.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AnnotatedLlmRequest {
/// Conversation messages, in order. Required: there is no `#[serde(default)]`,
/// so a payload without `messages` fails to deserialize.
pub messages: Vec<Message>,
/// Model identifier, if specified.
#[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
/// Generation parameters, (de)serialized as a nested `params` object.
#[serde(skip_serializing_if = "Option::is_none")]
pub params: Option<GenerationParams>,
/// Tool definitions; see [`ToolDefinition`].
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<ToolDefinition>>,
/// Tool selection policy; see [`ToolChoice`].
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<ToolChoice>,
#[serde(skip_serializing_if = "Option::is_none")]
pub store: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub previous_response_id: Option<String>,
/// Kept as raw JSON; not validated here.
#[serde(skip_serializing_if = "Option::is_none")]
pub truncation: Option<Json>,
/// Kept as raw JSON; not validated here.
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning: Option<Json>,
/// Kept as raw JSON; not validated here.
#[serde(skip_serializing_if = "Option::is_none")]
pub include: Option<Json>,
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
/// Kept as raw JSON; not validated here.
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<Json>,
#[serde(skip_serializing_if = "Option::is_none")]
pub service_tier: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub parallel_tool_calls: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_output_tokens: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tool_calls: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_logprobs: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stream: Option<bool>,
/// Catch-all for top-level keys not matched by any named field above.
///
/// Because of `#[serde(flatten)]`, entries here are emitted at the top level
/// alongside the named fields. If a key that a named field also serializes is
/// inserted here by hand, the output will contain that key twice.
#[serde(flatten)]
pub extra: serde_json::Map<String, Json>,
}
/// A single chat message, internally tagged by its `role` field.
///
/// The tag value is the lowercase variant name: `"system"`, `"user"`,
/// `"assistant"`, or `"tool"`. A message with any other role fails to
/// deserialize, because there is no catch-all variant.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
/// Message with role `system`.
System {
content: MessageContent,
/// Optional name; omitted from output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
},
/// Message with role `user`.
User {
content: MessageContent,
/// Optional name; omitted from output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
},
/// Message with role `assistant`. Unlike the other roles, `content` may be
/// absent or `null` (presumably for turns that only carry tool calls).
Assistant {
#[serde(skip_serializing_if = "Option::is_none")]
content: Option<MessageContent>,
/// Tool invocations; `Some(vec![])` is treated as "no tool calls" by
/// `AnnotatedLlmRequest::has_tool_calls`.
#[serde(skip_serializing_if = "Option::is_none")]
tool_calls: Option<Vec<ToolCall>>,
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
},
/// Message with role `tool`.
Tool {
content: MessageContent,
/// Presumably the `id` of the [`ToolCall`] this message answers — not
/// checked here.
tool_call_id: String,
},
}
/// Message body: either a plain string or a list of typed content parts.
///
/// The enum is untagged. Deserialization tries `Text` first, then `Parts`.
/// Serialization writes the inner value directly, as a bare string or an array.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum MessageContent {
Text(String),
Parts(Vec<ContentPart>),
}
/// One element of [`MessageContent::Parts`], tagged by a `type` field.
///
/// Tag values are `"text"` and `"image_url"` (snake_case variant names). Any
/// other part type fails to deserialize.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPart {
/// A text fragment.
Text {
text: String,
},
/// An image reference, nested under an `image_url` key.
ImageUrl {
image_url: OpenAiImageUrl,
},
}
/// Payload of a [`ContentPart::ImageUrl`] part.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct OpenAiImageUrl {
/// Image location. NOTE(review): this could be a remote URL or a data URL,
/// depending on the caller; it is not validated here.
pub url: String,
/// Optional detail hint, omitted from output when `None`. Presumably one of
/// OpenAI's `low`/`high`/`auto`; not validated.
#[serde(skip_serializing_if = "Option::is_none")]
pub detail: Option<String>,
}
/// A tool invocation attached to an assistant message.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolCall {
/// Call identifier.
pub id: String,
/// Serialized under the key `type`. Presumably `"function"`; not validated.
#[serde(rename = "type")]
pub call_type: String,
pub function: FunctionCall,
}
/// The function name and arguments of a [`ToolCall`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FunctionCall {
pub name: String,
/// Kept as a raw string and not parsed here. Presumably a JSON-encoded
/// object, as in the OpenAI API — confirm.
pub arguments: String,
}
/// A tool made available to the model.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolDefinition {
/// Serialized under the key `type`. Presumably `"function"`; not validated.
#[serde(rename = "type")]
pub tool_type: String,
pub function: FunctionDefinition,
}
/// Signature of a function-type tool.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FunctionDefinition {
pub name: String,
/// Optional; omitted from output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
/// Arbitrary JSON, kept as-is. Presumably a JSON Schema for the arguments;
/// not validated here.
#[serde(skip_serializing_if = "Option::is_none")]
pub parameters: Option<Json>,
}
/// Policy for how the model should pick among the offered tools.
///
/// The unit variants serialize as the lowercase strings `"auto"`, `"none"`,
/// and `"required"`. `Specific` is `#[serde(untagged)]`: it serializes as the
/// bare [`ToolChoiceFunction`] object, and on deserialization it is tried only
/// after the string forms fail to match.
///
/// NOTE(review): if `ToolChoice::*` is glob-imported, the variant named `None`
/// shadows `Option::None` in that scope.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
Auto,
None,
Required,
#[serde(untagged)]
Specific(ToolChoiceFunction),
}
/// An explicit tool selection, used by [`ToolChoice::Specific`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolChoiceFunction {
/// Serialized under the key `type`. Presumably `"function"`; not validated.
#[serde(rename = "type")]
pub choice_type: String,
pub function: ToolChoiceFunctionName,
}
/// Name of the function selected by a [`ToolChoiceFunction`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolChoiceFunctionName {
pub name: String,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub struct GenerationParams {
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
}
impl AnnotatedLlmRequest {
pub fn system_prompt(&self) -> Option<&str> {
self.messages.iter().find_map(|m| match m {
Message::System { content, .. } => match content {
MessageContent::Text(s) => Some(s.as_str()),
MessageContent::Parts(parts) => parts.iter().find_map(|p| match p {
ContentPart::Text { text } => Some(text.as_str()),
ContentPart::ImageUrl { .. } => None,
}),
},
_ => None,
})
}
pub fn last_user_message(&self) -> Option<&str> {
self.messages.iter().rev().find_map(|m| match m {
Message::User { content, .. } => match content {
MessageContent::Text(s) => Some(s.as_str()),
MessageContent::Parts(parts) => parts.iter().find_map(|p| match p {
ContentPart::Text { text } => Some(text.as_str()),
ContentPart::ImageUrl { .. } => None,
}),
},
_ => None,
})
}
pub fn has_tool_calls(&self) -> bool {
self.messages.iter().any(|m| {
matches!(
m,
Message::Assistant { tool_calls: Some(calls), .. } if !calls.is_empty()
)
})
}
}
// Unit tests live in a separate file. For this non-inline module, the `#[path]`
// is resolved relative to the directory containing this source file.
#[cfg(test)]
#[path = "../../tests/unit/codec/request_tests.rs"]
mod tests;