use std::collections::BTreeMap;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::protocol::ProviderProtocol;
/// Key-ordered map from string keys to arbitrary JSON values.
///
/// Used in this file as the type of the `metadata` and `vendor_extensions`
/// fields.
pub type VendorExtensions = BTreeMap<String, Value>;
/// A request to a language model.
///
/// Conversation content may be supplied through `input`, through `messages`,
/// or both; the `normalized_*` methods on this type pick between the two.
/// Every field except `model` may be absent when deserializing.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LlmRequest {
/// Name of the model the request is addressed to.
pub model: String,
/// Explicit instruction text. When set, `normalized_instructions` returns
/// it unchanged. Omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub instructions: Option<String>,
/// Request content as a list of items. When non-empty, `normalized_input`
/// returns it as-is. Omitted from serialized output when empty.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub input: Vec<RequestItem>,
/// Request content as a list of messages. When non-empty,
/// `normalized_messages` returns it as-is. Omitted when empty.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub messages: Vec<Message>,
/// Requested capabilities (tools, structured output, ...). Omitted from
/// serialized output when `CapabilitySet::is_empty` returns `true`.
#[serde(default, skip_serializing_if = "CapabilitySet::is_empty")]
pub capabilities: CapabilitySet,
/// Generation parameters. Omitted from serialized output when
/// `GenerationConfig::is_default` returns `true`.
#[serde(default, skip_serializing_if = "GenerationConfig::is_default")]
pub generation: GenerationConfig,
/// Free-form key/value data attached to the request. Omitted when empty.
/// NOTE(review): how this is consumed is not visible in this file.
#[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
pub metadata: VendorExtensions,
/// Additional key/value data. Omitted when empty.
/// NOTE(review): presumably provider-specific passthrough - confirm.
#[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
pub vendor_extensions: VendorExtensions,
}
impl LlmRequest {
pub fn normalized_input(&self) -> Vec<RequestItem> {
if !self.input.is_empty() {
return self.input.clone();
}
self.messages
.iter()
.cloned()
.map(RequestItem::from)
.collect()
}
pub fn normalized_messages(&self) -> Vec<Message> {
if !self.messages.is_empty() {
return self.messages.clone();
}
self.input
.iter()
.filter_map(RequestItem::as_message)
.cloned()
.collect()
}
pub fn normalized_instructions(&self) -> Option<String> {
if self.instructions.is_some() {
return self.instructions.clone();
}
let folded = self
.normalized_messages()
.into_iter()
.filter(|message| matches!(message.role, MessageRole::System | MessageRole::Developer))
.map(|message| message.plain_text())
.filter(|text| !text.is_empty())
.collect::<Vec<_>>()
.join("\n\n");
if folded.is_empty() {
None
} else {
Some(folded)
}
}
pub fn estimated_prompt_tokens(&self) -> u32 {
let mut chars = self
.normalized_input()
.iter()
.map(RequestItem::estimated_chars)
.sum::<usize>();
if let Some(instructions) = self.normalized_instructions() {
chars += instructions.len();
}
(chars / 4).max(1) as u32
}
pub fn estimated_tokens(&self) -> u32 {
self.estimated_prompt_tokens() + self.generation.max_output_tokens.unwrap_or(1024)
}
}
/// One entry of `LlmRequest::input`.
///
/// Serialized with an internal `type` tag whose value is the snake_case
/// variant name (`message` or `tool_result`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum RequestItem {
/// A conversation message, stored under the `message` key.
Message { message: Message },
/// The result of a tool invocation. The fields of `ToolResultPart` are
/// flattened into the same object as the `type` tag.
ToolResult {
#[serde(flatten)]
result: ToolResultPart,
},
}
impl RequestItem {
fn estimated_chars(&self) -> usize {
match self {
Self::Message { message } => message.estimated_chars(),
Self::ToolResult { result } => result.output.to_string().len(),
}
}
pub(crate) fn as_message(&self) -> Option<&Message> {
match self {
Self::Message { message } => Some(message),
Self::ToolResult { .. } => None,
}
}
}
impl From<Message> for RequestItem {
fn from(message: Message) -> Self {
Self::Message { message }
}
}
/// A single conversation message: a role plus an ordered list of content
/// parts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
/// Who the message is attributed to.
pub role: MessageRole,
/// Content parts, in order. `plain_text` concatenates the textual ones.
pub parts: Vec<MessagePart>,
/// Optional raw string form of the message. Omitted from serialized output
/// when `None`.
/// NOTE(review): presumably the unparsed provider payload - confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub raw_message: Option<String>,
/// Additional key/value data. Omitted from serialized output when empty.
#[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
pub vendor_extensions: VendorExtensions,
}
impl Message {
pub fn text(role: MessageRole, text: impl Into<String>) -> Self {
Self {
role,
parts: vec![MessagePart::Text { text: text.into() }],
raw_message: None,
vendor_extensions: VendorExtensions::new(),
}
}
pub fn plain_text(&self) -> String {
self.parts
.iter()
.filter_map(MessagePart::plain_text)
.collect::<Vec<_>>()
.join("")
}
fn estimated_chars(&self) -> usize {
self.parts.iter().map(MessagePart::estimated_chars).sum()
}
}
/// Role attributed to a [`Message`].
///
/// Serialized as the snake_case variant name (`developer`, `system`, `user`,
/// `assistant`, `tool`). `LlmRequest::normalized_instructions` folds the text
/// of `System` and `Developer` messages into the instruction string.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MessageRole {
Developer,
System,
User,
Assistant,
Tool,
}
/// One piece of content inside a [`Message`].
///
/// Serialized with an internal `type` tag whose value is the snake_case
/// variant name (for example `text`, `image_url`, `tool_call`). Optional
/// fields are omitted from serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum MessagePart {
/// Plain text.
Text {
text: String,
},
/// An image referenced by URL.
/// NOTE(review): the accepted values of `detail` are not visible here.
ImageUrl {
url: String,
#[serde(skip_serializing_if = "Option::is_none")]
detail: Option<String>,
},
/// An inline image. `data` is presumably base64 text - confirm.
ImageBase64 {
data: String,
#[serde(skip_serializing_if = "Option::is_none")]
media_type: Option<String>,
},
/// Inline audio. When `transcript` is present it is what `plain_text`
/// returns and what the size estimate is based on.
Audio {
data: String,
#[serde(skip_serializing_if = "Option::is_none")]
media_type: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
transcript: Option<String>,
},
/// A file, identified by `file_id` and/or carried inline in `data`; every
/// field is optional.
File {
#[serde(skip_serializing_if = "Option::is_none")]
file_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
media_type: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
data: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
filename: Option<String>,
},
/// An arbitrary JSON value.
Json {
value: Value,
},
/// A tool invocation; the `ToolCallPart` fields are flattened next to the
/// `type` tag.
ToolCall {
#[serde(flatten)]
call: ToolCallPart,
},
/// A tool result; the `ToolResultPart` fields are flattened next to the
/// `type` tag.
ToolResult {
#[serde(flatten)]
result: ToolResultPart,
},
/// Reasoning text; included by `plain_text`.
Reasoning {
text: String,
},
/// Refusal text; included by `plain_text`.
Refusal {
text: String,
},
}
impl MessagePart {
pub(crate) fn plain_text(&self) -> Option<&str> {
match self {
Self::Text { text } | Self::Reasoning { text } | Self::Refusal { text } => {
Some(text.as_str())
}
Self::Audio {
transcript: Some(text),
..
} => Some(text.as_str()),
_ => None,
}
}
fn estimated_chars(&self) -> usize {
match self {
Self::Text { text } => text.len(),
Self::ImageUrl { url, .. } => url.len(),
Self::ImageBase64 { data, .. } => data.len() / 8,
Self::Audio {
data, transcript, ..
} => transcript
.as_ref()
.map_or(data.len() / 8, std::string::String::len),
Self::File { data, filename, .. } => data.as_ref().map_or_else(
|| filename.as_ref().map_or(0, std::string::String::len),
|d| d.len() / 8,
),
Self::Json { value } => value.to_string().len(),
Self::ToolCall { call } => call.arguments.to_string().len() + call.name.len(),
Self::ToolResult { result } => result.output.to_string().len(),
Self::Reasoning { text } | Self::Refusal { text } => text.len(),
}
}
}
/// A tool invocation: which tool to call and with what arguments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallPart {
/// Identifier of this call.
/// NOTE(review): presumably echoed by the matching `ToolResultPart::call_id`
/// - confirm against the callers.
pub call_id: String,
/// Name of the tool being called.
pub name: String,
/// Call arguments as an arbitrary JSON value.
pub arguments: Value,
}
/// The outcome of a tool invocation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResultPart {
/// Identifier of the call this result belongs to.
pub call_id: String,
/// Optional tool name. Omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
/// Tool output as an arbitrary JSON value.
pub output: Value,
/// Whether the result is an error. Defaults to `false` when absent and is
/// omitted from serialized output when `false`.
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
pub is_error: bool,
}
/// Optional capabilities attached to a request.
///
/// Every field is optional on input; empty lists and `None` values are
/// omitted from serialized output. `is_empty` reports whether nothing is set.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CapabilitySet {
/// Caller-defined tools.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub tools: Vec<ToolDefinition>,
/// Structured-output (schema) configuration.
#[serde(skip_serializing_if = "Option::is_none")]
pub structured_output: Option<StructuredOutputConfig>,
/// Reasoning configuration.
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning: Option<ReasoningCapability>,
/// Output modalities.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub modalities: Vec<OutputModality>,
/// Safety settings.
#[serde(skip_serializing_if = "Option::is_none")]
pub safety: Option<SafetySettings>,
/// Cache settings.
#[serde(skip_serializing_if = "Option::is_none")]
pub cache: Option<CacheSettings>,
/// Built-in tools (see `BuiltinTool`).
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub builtin_tools: Vec<BuiltinTool>,
/// Additional key/value data.
#[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
pub vendor_extensions: VendorExtensions,
}
impl CapabilitySet {
pub fn is_empty(&self) -> bool {
self.tools.is_empty()
&& self.structured_output.is_none()
&& self.reasoning.is_none()
&& self.modalities.is_empty()
&& self.safety.is_none()
&& self.cache.is_none()
&& self.builtin_tools.is_empty()
&& self.vendor_extensions.is_empty()
}
}
/// Definition of a caller-supplied tool.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
/// Tool name.
pub name: String,
/// Optional description. Omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
/// Schema of the tool input as an arbitrary JSON value.
/// NOTE(review): presumably a JSON Schema document - confirm.
pub input_schema: Value,
/// Strictness flag. Defaults to `false` when absent and is omitted from
/// serialized output when `false`.
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
pub strict: bool,
/// Additional key/value data. Omitted from serialized output when empty.
#[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
pub vendor_extensions: VendorExtensions,
}
/// A built-in tool, as opposed to a caller-defined `ToolDefinition`.
///
/// Serialized with an internal `type` tag whose value is the snake_case
/// variant name (for example `web_search`, `code_execution`, `mcp`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum BuiltinTool {
WebSearch,
FileSearch,
CodeExecution,
ComputerUse,
UrlContext,
Maps,
/// MCP tool with an optional server label (omitted when `None`).
Mcp {
#[serde(skip_serializing_if = "Option::is_none")]
server_label: Option<String>,
},
/// A tool identified only by `name`, with an arbitrary JSON `payload`.
/// The payload defaults to JSON null when absent and is omitted from
/// serialized output when null.
Vendor {
name: String,
#[serde(default, skip_serializing_if = "Value::is_null")]
payload: Value,
},
}
/// Configuration for schema-constrained output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StructuredOutputConfig {
/// Optional name. Omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
/// The output schema as an arbitrary JSON value.
/// NOTE(review): presumably a JSON Schema document - confirm.
pub schema: Value,
/// Strictness flag. Defaults to `false` when absent and is omitted from
/// serialized output when `false`.
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
pub strict: bool,
}
/// Reasoning configuration.
///
/// NOTE(review): the accepted values of `effort` and `summary` are free-form
/// strings here; valid values are not visible in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReasoningCapability {
/// Optional effort setting. Omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub effort: Option<String>,
/// Optional summary setting. Omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub summary: Option<String>,
/// Additional key/value data. Omitted from serialized output when empty.
#[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
pub vendor_extensions: VendorExtensions,
}
/// Kind of output, used in `CapabilitySet::modalities`.
///
/// Serialized as the snake_case variant name (`text`, `image`, `audio`,
/// `json`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum OutputModality {
Text,
Image,
Audio,
Json,
}
/// Safety settings attached to a request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SafetySettings {
/// Optional policy string. Omitted from serialized output when `None`.
/// NOTE(review): the accepted values are not visible in this file.
#[serde(skip_serializing_if = "Option::is_none")]
pub policy: Option<String>,
/// Additional key/value data. Omitted from serialized output when empty.
#[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
pub vendor_extensions: VendorExtensions,
}
/// Cache settings attached to a request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheSettings {
/// Whether caching is enabled. Defaults to `true` (via `default_true`)
/// when the field is absent on deserialization; always serialized.
#[serde(default = "default_true")]
pub enabled: bool,
/// Additional key/value data. Omitted from serialized output when empty.
#[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
pub vendor_extensions: VendorExtensions,
}
/// Serde default provider that yields `true`.
///
/// Referenced by `#[serde(default = "default_true")]` on
/// `CacheSettings::enabled`.
fn default_true() -> bool {
    true
}
/// Generation parameters for a request.
///
/// Every field is optional on input; `None` values and empty collections are
/// omitted from serialized output. `is_default` reports whether nothing is
/// set.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GenerationConfig {
/// Output token limit. `LlmRequest::estimated_tokens` assumes 1024 when
/// this is `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_output_tokens: Option<u32>,
/// Sampling temperature.
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
/// `top_p` sampling parameter.
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
/// `top_k` sampling parameter.
#[serde(skip_serializing_if = "Option::is_none")]
pub top_k: Option<u32>,
/// Stop sequences.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub stop_sequences: Vec<String>,
/// Presence penalty.
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f32>,
/// Frequency penalty.
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f32>,
/// Seed value.
#[serde(skip_serializing_if = "Option::is_none")]
pub seed: Option<u64>,
/// Additional key/value data.
#[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
pub vendor_extensions: VendorExtensions,
}
impl GenerationConfig {
    /// Returns `true` when no generation parameter is set: every optional
    /// value is `None` and both `stop_sequences` and `vendor_extensions` are
    /// empty.
    ///
    /// Used as the `skip_serializing_if` predicate for
    /// `LlmRequest::generation`.
    pub fn is_default(&self) -> bool {
        // One flag per field; the config is default only if none is raised.
        let overrides = [
            self.max_output_tokens.is_some(),
            self.temperature.is_some(),
            self.top_p.is_some(),
            self.top_k.is_some(),
            !self.stop_sequences.is_empty(),
            self.presence_penalty.is_some(),
            self.frequency_penalty.is_some(),
            self.seed.is_some(),
            !self.vendor_extensions.is_empty(),
        ];
        !overrides.contains(&true)
    }
}
/// A complete response from a language model.
///
/// `from_message` builds one from a single message, filling `output`,
/// `messages` and `content_text` from it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmResponse {
/// Response content as a list of items. Omitted from serialized output
/// when empty.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub output: Vec<ResponseItem>,
/// Response content as a list of messages. Omitted when empty.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub messages: Vec<Message>,
/// Plain-text form of the response. Omitted when empty.
#[serde(default, skip_serializing_if = "String::is_empty")]
pub content_text: String,
/// Token accounting for this response.
pub usage: TokenUsage,
/// Name of the model.
pub model: String,
/// Protocol of the provider (type defined in `crate::protocol`).
pub provider_protocol: ProviderProtocol,
/// Why generation ended, when known. Omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub finish_reason: Option<FinishReason>,
/// Optional response identifier. Omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub response_id: Option<String>,
/// Additional key/value data. Omitted when empty.
#[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
pub vendor_extensions: VendorExtensions,
}
impl LlmResponse {
pub fn from_message(
provider_protocol: ProviderProtocol,
model: impl Into<String>,
message: Message,
usage: TokenUsage,
) -> Self {
let content_text = message.plain_text();
Self {
output: vec![ResponseItem::Message {
message: message.clone(),
}],
messages: vec![message],
content_text,
usage,
model: model.into(),
provider_protocol,
finish_reason: None,
response_id: None,
vendor_extensions: VendorExtensions::new(),
}
}
}
/// One entry of `LlmResponse::output`.
///
/// Serialized with an internal `type` tag whose value is the snake_case
/// variant name (`message`, `tool_call`, `tool_result`, `reasoning`,
/// `refusal`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseItem {
/// A message, stored under the `message` key.
Message {
message: Message,
},
/// A tool invocation; the `ToolCallPart` fields are flattened next to the
/// `type` tag.
ToolCall {
#[serde(flatten)]
call: ToolCallPart,
},
/// A tool result; the `ToolResultPart` fields are flattened next to the
/// `type` tag.
ToolResult {
#[serde(flatten)]
result: ToolResultPart,
},
/// Reasoning text.
Reasoning {
text: String,
},
/// Refusal text.
Refusal {
text: String,
},
}
/// Why generation ended.
///
/// Variant names are serialized in snake_case. No `tag` attribute is set, so
/// serde's default enum representation applies.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
Stop,
Length,
ToolCall,
ContentFilter,
Cancelled,
Error,
/// Any other reason, carried as a free-form string.
Other(String),
}
/// Token accounting for a response.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct TokenUsage {
/// Tokens attributed to the prompt.
pub prompt_tokens: u32,
/// Tokens attributed to the completion.
pub completion_tokens: u32,
/// Explicit total, when one is available. When `None`, `total()` falls
/// back to the sum of the two counts above. Omitted from serialized output
/// when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub total_tokens: Option<u32>,
}
impl TokenUsage {
    /// Returns the total token count.
    ///
    /// Uses `total_tokens` when it is set. Otherwise returns
    /// `prompt_tokens + completion_tokens`, saturating at `u32::MAX`.
    ///
    /// The fallback sum is computed lazily and with saturating arithmetic, so
    /// oversized counts can neither panic (debug builds) nor wrap (release
    /// builds), and are never evaluated when an explicit total is present.
    pub fn total(&self) -> u32 {
        self.total_tokens
            .unwrap_or_else(|| self.prompt_tokens.saturating_add(self.completion_tokens))
    }
}
/// One event of a streamed response.
///
/// Serialized with an internal `type` tag whose value is the snake_case
/// variant name (for example `response_started`, `text_delta`, `completed`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum LlmStreamEvent {
/// Start of a response: optional response id (omitted when `None`), model
/// name and provider protocol.
ResponseStarted {
#[serde(skip_serializing_if = "Option::is_none")]
response_id: Option<String>,
model: String,
provider_protocol: ProviderProtocol,
},
/// A new output item.
OutputItemAdded {
item: ResponseItem,
},
/// A new content part.
ContentPartAdded {
part: MessagePart,
},
/// A fragment of text.
TextDelta {
delta: String,
},
/// A fragment of a tool call, identified by `call_id` and tool `name`.
/// NOTE(review): `delta` is presumably a fragment of the serialized
/// arguments - confirm against the producers.
ToolCallDelta {
call_id: String,
name: String,
delta: String,
},
/// A tool result. Unlike the `ToolResult` variants of `RequestItem`,
/// `MessagePart` and `ResponseItem`, the payload is nested under the
/// `result` key rather than flattened.
ToolResult {
result: ToolResultPart,
},
/// A fragment of reasoning text.
ReasoningDelta {
delta: String,
},
/// Token usage information.
Usage {
usage: TokenUsage,
},
/// The complete response.
Completed {
response: LlmResponse,
},
/// An error, described by a message string.
Error {
message: String,
},
}