use crate::error::ProtoError;
use crate::v2::attachment::Attachment;
use crate::v2::tool::{Tool, ToolCallId, ToolUseInput};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
/// Author of a [`MessageV2`].
///
/// On the wire each variant is its lowercase name: `"system"`, `"user"`, or
/// `"assistant"`.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum RoleV2 {
    /// Serialized as `"system"`.
    System,
    /// Serialized as `"user"`.
    User,
    /// Serialized as `"assistant"`.
    Assistant,
}
/// One piece of a message's content.
///
/// Internally tagged on the wire: a `"type"` field carries the snake_case
/// variant name (`"text"`, `"image"`, `"audio"`, `"video"`, `"tool_use"`,
/// `"tool_result"`). Any other tag deserializes to [`ContentBlock::Unknown`],
/// which `RequestV2::resolve` rejects during validation.
#[derive(Debug, Clone, PartialEq)]
#[derive(Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
    /// Inline text.
    Text { text: String },
    /// Refers by id to an image entry of the request's `attachments`.
    Image { attachment_id: String },
    /// Refers by id to an audio entry of the request's `attachments`.
    Audio { attachment_id: String },
    /// Refers by id to a video entry of the request's `attachments`.
    Video { attachment_id: String },
    /// Tool-use block: a call id, the tool `name`, and its `input`.
    ToolUse {
        tool_call_id: ToolCallId,
        name: String,
        input: ToolUseInput,
    },
    /// Tool-result block: a call id plus nested content blocks.
    ToolResult {
        tool_call_id: ToolCallId,
        content: Vec<ContentBlock>,
    },
    /// Catch-all for unrecognized `"type"` tags.
    #[serde(other)]
    Unknown,
}
/// A single conversation turn: who said it and what it contains.
#[derive(Debug, Clone, PartialEq)]
#[derive(Serialize, Deserialize)]
pub struct MessageV2 {
    /// Author of this turn.
    pub role: RoleV2,
    /// Ordered content blocks; `RequestV2::resolve` requires this to be non-empty.
    pub content: Vec<ContentBlock>,
}
/// Unvalidated v2 request as received on the wire.
///
/// Call [`RequestV2::resolve`] to validate it and obtain a [`ResolvedV2`].
/// Every field except `messages` may be absent when deserializing, and
/// empty/unset optional fields are omitted when serializing.
#[derive(Debug, Clone, PartialEq, Default)]
#[derive(Serialize, Deserialize)]
pub struct RequestV2 {
    /// Request identifier; empty when absent.
    #[serde(default)]
    #[serde(skip_serializing_if = "String::is_empty")]
    pub id: String,
    /// Conversation turns; required on the wire (no serde default).
    pub messages: Vec<MessageV2>,
    /// Attachments referenced by id from media content blocks.
    #[serde(default)]
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub attachments: Vec<Attachment>,
    /// Tools declared for this request.
    #[serde(default)]
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    /// Optional sampling temperature; passed through unvalidated.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    /// Optional nucleus-sampling parameter; passed through unvalidated.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
    /// Optional top-k sampling parameter; passed through unvalidated.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
    /// Optional token limit; passed through unvalidated.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Optional streaming flag; passed through unvalidated.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
}
/// A [`RequestV2`] that has passed [`RequestV2::resolve`].
///
/// Fields are moved over unchanged from the originating request; this type
/// only records that validation succeeded.
#[derive(Debug, Clone)]
#[derive(PartialEq)]
pub struct ResolvedV2 {
    /// Request identifier (may be empty).
    pub id: String,
    /// Non-empty list of validated messages.
    pub messages: Vec<MessageV2>,
    /// Attachments with unique, non-empty ids and known kinds.
    pub attachments: Vec<Attachment>,
    /// Tools with unique names.
    pub tools: Vec<Tool>,
    /// Copied from [`RequestV2::temperature`].
    pub temperature: Option<f64>,
    /// Copied from [`RequestV2::top_p`].
    pub top_p: Option<f64>,
    /// Copied from [`RequestV2::top_k`].
    pub top_k: Option<u32>,
    /// Copied from [`RequestV2::max_tokens`].
    pub max_tokens: Option<u32>,
    /// Copied from [`RequestV2::stream`].
    pub stream: Option<bool>,
}
impl RequestV2 {
pub fn resolve(self) -> Result<ResolvedV2, ProtoError> {
if self.messages.is_empty() {
return Err(ProtoError::InvalidRequest(
"messages must not be empty".into(),
));
}
let mut attachments_by_id: HashMap<&str, &Attachment> = HashMap::new();
for att in &self.attachments {
if matches!(att, Attachment::Unknown) {
return Err(ProtoError::InvalidRequest(
"attachments contain an unknown kind".into(),
));
}
let id = att.id();
if id.is_empty() {
return Err(ProtoError::InvalidRequest(
"attachments must have non-empty id".into(),
));
}
if attachments_by_id.insert(id, att).is_some() {
return Err(ProtoError::InvalidRequest(format!(
"duplicate attachment id: {id}"
)));
}
}
let mut tool_names: HashSet<&str> = HashSet::new();
for tool in &self.tools {
if !tool_names.insert(tool.name.as_str()) {
return Err(ProtoError::InvalidRequest(format!(
"duplicate tool name: {}",
tool.name
)));
}
}
for (mi, msg) in self.messages.iter().enumerate() {
if msg.content.is_empty() {
return Err(ProtoError::InvalidRequest(format!(
"messages[{mi}].content must not be empty"
)));
}
validate_content_blocks(&msg.content, mi, &attachments_by_id, &tool_names)?;
}
Ok(ResolvedV2 {
id: self.id,
messages: self.messages,
attachments: self.attachments,
tools: self.tools,
temperature: self.temperature,
top_p: self.top_p,
top_k: self.top_k,
max_tokens: self.max_tokens,
stream: self.stream,
})
}
}
/// Validates the content blocks of `messages[msg_index]`.
///
/// Rules applied to each block (recursively for tool-result content):
/// - `Text`: always accepted;
/// - `Image` / `Audio` / `Video`: `attachment_id` must name an entry of
///   `attachments_by_id` whose kind matches the block;
/// - `ToolUse`: `name` must be one of `tool_names`;
/// - `ToolResult`: its nested `content` is validated with the same rules;
/// - `Unknown`: rejected.
///
/// # Errors
///
/// Returns [`ProtoError::InvalidRequest`] for the first offending block. The
/// message starts with the block's location, e.g. `messages[0].content[2]`, or
/// `messages[0].content[2].content[1]` for a block nested inside a tool result.
fn validate_content_blocks(
    blocks: &[ContentBlock],
    msg_index: usize,
    attachments_by_id: &HashMap<&str, &Attachment>,
    tool_names: &HashSet<&str>,
) -> Result<(), ProtoError> {
    validate_blocks_at(
        blocks,
        &format!("messages[{msg_index}].content"),
        attachments_by_id,
        tool_names,
    )
}

/// Validates `blocks`, reporting errors relative to `path` (the location of the slice itself).
fn validate_blocks_at(
    blocks: &[ContentBlock],
    path: &str,
    attachments_by_id: &HashMap<&str, &Attachment>,
    tool_names: &HashSet<&str>,
) -> Result<(), ProtoError> {
    for (bi, block) in blocks.iter().enumerate() {
        match block {
            ContentBlock::Text { .. } => {}
            ContentBlock::Image { attachment_id } => check_kind(
                path,
                bi,
                attachment_id,
                attachments_by_id,
                Attachment::is_image,
                "image",
            )?,
            ContentBlock::Audio { attachment_id } => check_kind(
                path,
                bi,
                attachment_id,
                attachments_by_id,
                Attachment::is_audio,
                "audio",
            )?,
            ContentBlock::Video { attachment_id } => check_kind(
                path,
                bi,
                attachment_id,
                attachments_by_id,
                Attachment::is_video,
                "video",
            )?,
            ContentBlock::ToolUse { name, .. } => {
                // A tool-use block must refer to a tool declared in the request.
                if !tool_names.contains(name.as_str()) {
                    return Err(ProtoError::InvalidRequest(format!(
                        "{path}[{bi}] references undeclared tool {name:?}"
                    )));
                }
            }
            ContentBlock::ToolResult { content, .. } => {
                // Address nested blocks under this block's own `content` field so
                // errors point at the actual offending block, not a sibling of
                // the tool result.
                validate_blocks_at(
                    content,
                    &format!("{path}[{bi}].content"),
                    attachments_by_id,
                    tool_names,
                )?;
            }
            ContentBlock::Unknown => {
                return Err(ProtoError::InvalidRequest(format!(
                    "{path}[{bi}] uses unknown content-block type"
                )));
            }
        }
    }
    Ok(())
}

/// Checks that `attachment_id` names a declared attachment accepted by `pred`.
///
/// `path` and `block_index` locate the referencing block in error messages
/// (rendered as `{path}[{block_index}]`); `expected` is the kind name used in
/// the mismatch message.
///
/// # Errors
///
/// [`ProtoError::InvalidRequest`] if the id is not declared, or if the declared
/// attachment fails `pred`.
fn check_kind(
    path: &str,
    block_index: usize,
    attachment_id: &str,
    attachments_by_id: &HashMap<&str, &Attachment>,
    pred: fn(&Attachment) -> bool,
    expected: &str,
) -> Result<(), ProtoError> {
    let att = attachments_by_id.get(attachment_id).ok_or_else(|| {
        ProtoError::InvalidRequest(format!(
            "{path}[{block_index}] references unknown attachment_id {attachment_id:?}"
        ))
    })?;
    if !pred(att) {
        return Err(ProtoError::InvalidRequest(format!(
            "{path}[{block_index}] block expects {expected} attachment but {attachment_id:?} is a different kind"
        )));
    }
    Ok(())
}