use serde::{Deserialize, Serialize};
use crate::json::Json;
/// A chat-style LLM request: the conversation plus optional model name,
/// generation parameters and tool configuration.
///
/// Every `Option` field is omitted from serialized output when it is `None`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AnnotatedLlmRequest {
/// Conversation messages. The accessor methods on this type scan the list
/// front-to-back (`system_prompt`) and back-to-front (`last_user_message`),
/// i.e. the last element is treated as the most recent message.
pub messages: Vec<Message>,
/// Model identifier, if the request names one.
#[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
/// Generation controls, nested under the `"params"` key.
#[serde(skip_serializing_if = "Option::is_none")]
pub params: Option<GenerationParams>,
/// Tool definitions sent with the request.
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<ToolDefinition>>,
/// Tool-selection policy sent with the request.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<ToolChoice>,
/// Every top-level key not matched by a field above. `#[serde(flatten)]`
/// collects such keys here on deserialize and writes them back inline (not
/// under an `"extra"` key) on serialize, so keys this struct does not model
/// are preserved across a round trip.
#[serde(flatten)]
pub extra: serde_json::Map<String, Json>,
}
/// One conversation message.
///
/// Internally tagged: the JSON object carries a `"role"` field whose value
/// (`"system"`, `"user"`, `"assistant"` or `"tool"`, from `rename_all =
/// "lowercase"`) selects the variant. The remaining keys are the variant's fields.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
/// A system instruction message.
System {
content: MessageContent,
/// Optional participant name; omitted from output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
},
/// A message authored by the user.
User {
content: MessageContent,
/// Optional participant name; omitted from output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
},
/// A message produced by the assistant.
Assistant {
/// Optional here, unlike the other roles (presumably because an
/// assistant turn may consist only of `tool_calls` — confirm).
#[serde(skip_serializing_if = "Option::is_none")]
content: Option<MessageContent>,
/// Tool invocations requested by the assistant, if any.
#[serde(skip_serializing_if = "Option::is_none")]
tool_calls: Option<Vec<ToolCall>>,
/// Optional participant name; omitted from output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
},
/// The output of a tool invocation.
Tool {
content: MessageContent,
/// NOTE(review): presumably the `id` of the `ToolCall` this message
/// answers; nothing in this file links or validates the two.
tool_call_id: String,
},
}
/// Message body: either a plain JSON string or an array of content parts.
///
/// `#[serde(untagged)]` makes deserialization try the variants in declaration
/// order (`Text` first, then `Parts`). Serialization writes the inner value
/// with no wrapper, so variant order matters and should not be changed casually.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum MessageContent {
/// A JSON string.
Text(String),
/// A JSON array of typed parts.
Parts(Vec<ContentPart>),
}
/// One element of a multi-part message body, tagged by its `"type"` field
/// (snake_case), e.g. `{"type": "text", "text": "..."}`.
///
/// NOTE(review): only text parts are modeled. A part whose `"type"` is
/// anything other than `"text"` fails to deserialize, which fails the whole
/// message. Adding a variant also requires updating code that destructures
/// this enum irrefutably.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPart {
/// A plain-text part.
Text {
text: String,
},
}
/// A tool invocation requested by the assistant.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolCall {
/// Identifier for this call.
pub id: String,
/// Kind of call, (de)serialized under the JSON key `"type"`. Kept as a free
/// string; presumably `"function"` in practice, but not validated here.
#[serde(rename = "type")]
pub call_type: String,
/// The function name and its arguments.
pub function: FunctionCall,
}
/// The function targeted by a [`ToolCall`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FunctionCall {
/// Name of the function to call.
pub name: String,
/// Arguments kept as an unparsed string (presumably JSON-encoded — this
/// type does not parse or validate it).
pub arguments: String,
}
/// A tool offered to the model in [`AnnotatedLlmRequest::tools`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolDefinition {
/// Kind of tool, (de)serialized under the JSON key `"type"`; free string,
/// not validated here.
#[serde(rename = "type")]
pub tool_type: String,
/// The function's name, description and parameter description.
pub function: FunctionDefinition,
}
/// Declaration of a callable function exposed through a [`ToolDefinition`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FunctionDefinition {
/// Function name.
pub name: String,
/// Human-readable description; omitted from output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
/// Parameter description as arbitrary JSON (presumably a JSON Schema
/// object — not validated here); omitted from output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub parameters: Option<Json>,
}
/// Tool-selection policy.
///
/// `Auto`, `None` and `Required` (de)serialize as the bare strings `"auto"`,
/// `"none"` and `"required"`. `Specific` is marked `#[serde(untagged)]`, so it
/// is written as the [`ToolChoiceFunction`] object itself. Input that matches
/// none of the strings is tried against that object form. Variant-level
/// `untagged` requires serde 1.0.181 or newer.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
Auto,
None,
Required,
/// Force a particular function, given as an object.
#[serde(untagged)]
Specific(ToolChoiceFunction),
}
/// Object form of [`ToolChoice::Specific`]: names the function to force.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolChoiceFunction {
/// Kind of choice, (de)serialized under the JSON key `"type"`; free string,
/// not validated here.
#[serde(rename = "type")]
pub choice_type: String,
/// The function being selected.
pub function: ToolChoiceFunctionName,
}
/// Wrapper carrying just the name of the function chosen in a [`ToolChoiceFunction`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolChoiceFunctionName {
/// Function name.
pub name: String,
}
/// Optional generation controls. Every field is omitted from serialized output
/// when `None`, and `Default` yields all fields `None`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub struct GenerationParams {
/// Sampling temperature; no range is enforced here.
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f64>,
/// Upper bound on generated tokens.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u64>,
/// Nucleus-sampling cutoff; no range is enforced here.
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f64>,
/// Stop sequences. Only a JSON array of strings deserializes into this
/// field; a bare string does not.
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
}
impl AnnotatedLlmRequest {
pub fn system_prompt(&self) -> Option<&str> {
self.messages.iter().find_map(|m| match m {
Message::System { content, .. } => match content {
MessageContent::Text(s) => Some(s.as_str()),
MessageContent::Parts(parts) => parts
.iter()
.map(|p| {
let ContentPart::Text { text } = p;
text.as_str()
})
.next(),
},
_ => None,
})
}
pub fn last_user_message(&self) -> Option<&str> {
self.messages.iter().rev().find_map(|m| match m {
Message::User { content, .. } => match content {
MessageContent::Text(s) => Some(s.as_str()),
MessageContent::Parts(parts) => parts
.iter()
.map(|p| {
let ContentPart::Text { text } = p;
text.as_str()
})
.next(),
},
_ => None,
})
}
pub fn has_tool_calls(&self) -> bool {
self.messages.iter().any(|m| {
matches!(
m,
Message::Assistant { tool_calls: Some(calls), .. } if !calls.is_empty()
)
})
}
}
// Unit tests for this module are kept in a separate file (path given below,
// relative to this source file) and compiled only in test builds.
#[cfg(test)]
#[path = "../../tests/unit/codec/request_tests.rs"]
mod tests;