1use serde::{Deserialize, Serialize};
2
3use crate::prompt_block::PromptBlock;
4
/// Role attached to a [`ChatMessage`].
///
/// Because of `rename_all = "lowercase"`, each variant is (de)serialized as its
/// lowercase name: `system`, `user`, `assistant`, `tool`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ChatRole {
    /// Role set by [`ChatMessage::system`].
    System,
    /// Role set by [`ChatMessage::user`].
    User,
    /// Role set by [`ChatMessage::assistant`] and
    /// [`ChatMessage::assistant_tool_calls`].
    Assistant,
    /// Role set by [`ChatMessage::tool_result`].
    Tool,
}
13
/// A single message in a chat conversation.
///
/// The optional fields are left out of the serialized form when they carry no
/// data (`None` or an empty vector).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Who authored the message.
    pub role: ChatRole,
    /// Text body of the message.
    pub content: String,
    /// Identifier of the tool call this message answers. Only
    /// [`ChatMessage::tool_result`] sets it; skipped when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Name stored alongside `tool_call_id` by [`ChatMessage::tool_result`];
    /// skipped when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Tool calls carried by the message (filled by
    /// [`ChatMessage::assistant_tool_calls`]). Defaults to empty when the field
    /// is missing on input and is skipped on output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ToolCall>,
    /// Attachments carried by the message. Every constructor in this file
    /// leaves it empty. Defaults to empty when missing on input and is skipped
    /// on output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub attachments: Vec<Attachment>,
}
36
/// A piece of non-text content attached to a [`ChatMessage`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Attachment {
    /// Free-form kind tag; the constructors in this file always use `"image"`.
    pub kind: String,
    /// MIME type supplied by the caller; this file does not validate it.
    pub mime_type: String,
    /// Where the payload lives. `flatten` merges the serialized
    /// [`AttachmentData`] into this struct's own map instead of nesting it
    /// under a `data` key.
    #[serde(flatten)]
    pub data: AttachmentData,
}
49
/// Payload location of an [`Attachment`].
///
/// No `tag`/`untagged` attribute is present, so serde's default enum
/// representation applies; `rename_all = "snake_case"` names the variants
/// `base64`, `url` and `path` on the wire.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AttachmentData {
    /// Payload held inline as a base64 string (what
    /// [`Attachment::materialize`] produces).
    Base64 { base64: String },
    /// Payload referenced by URL; the string is stored as given.
    Url { url: String },
    /// Payload referenced by a filesystem path; [`Attachment::materialize`]
    /// reads that file and replaces this variant with `Base64`.
    Path { path: String },
}
61
62impl Attachment {
63 pub fn image_path(mime_type: impl Into<String>, path: impl Into<String>) -> Self {
64 Self {
65 kind: "image".into(),
66 mime_type: mime_type.into(),
67 data: AttachmentData::Path { path: path.into() },
68 }
69 }
70 pub fn image_url(mime_type: impl Into<String>, url: impl Into<String>) -> Self {
71 Self {
72 kind: "image".into(),
73 mime_type: mime_type.into(),
74 data: AttachmentData::Url { url: url.into() },
75 }
76 }
77 pub fn image_base64(mime_type: impl Into<String>, base64: impl Into<String>) -> Self {
78 Self {
79 kind: "image".into(),
80 mime_type: mime_type.into(),
81 data: AttachmentData::Base64 {
82 base64: base64.into(),
83 },
84 }
85 }
86
87 pub fn materialize(&mut self) -> anyhow::Result<()> {
91 if let AttachmentData::Path { path } = &self.data {
92 use base64::Engine;
93 let bytes =
94 std::fs::read(path).map_err(|e| anyhow::anyhow!("read attachment {path}: {e}"))?;
95 let encoded = base64::engine::general_purpose::STANDARD.encode(bytes);
96 self.data = AttachmentData::Base64 { base64: encoded };
97 }
98 Ok(())
99 }
100}
101
102impl ChatMessage {
103 pub fn system(content: impl Into<String>) -> Self {
104 Self {
105 role: ChatRole::System,
106 content: content.into(),
107 tool_call_id: None,
108 name: None,
109 tool_calls: Vec::new(),
110 attachments: Vec::new(),
111 }
112 }
113
114 pub fn user(content: impl Into<String>) -> Self {
115 Self {
116 role: ChatRole::User,
117 content: content.into(),
118 tool_call_id: None,
119 name: None,
120 tool_calls: Vec::new(),
121 attachments: Vec::new(),
122 }
123 }
124
125 pub fn assistant(content: impl Into<String>) -> Self {
126 Self {
127 role: ChatRole::Assistant,
128 content: content.into(),
129 tool_call_id: None,
130 name: None,
131 tool_calls: Vec::new(),
132 attachments: Vec::new(),
133 }
134 }
135
136 pub fn tool_result(
137 tool_call_id: impl Into<String>,
138 name: impl Into<String>,
139 content: impl Into<String>,
140 ) -> Self {
141 Self {
142 role: ChatRole::Tool,
143 content: content.into(),
144 tool_call_id: Some(tool_call_id.into()),
145 name: Some(name.into()),
146 tool_calls: Vec::new(),
147 attachments: Vec::new(),
148 }
149 }
150
151 pub fn assistant_tool_calls(calls: Vec<ToolCall>, text: impl Into<String>) -> Self {
156 Self {
157 role: ChatRole::Assistant,
158 content: text.into(),
159 tool_call_id: None,
160 name: None,
161 tool_calls: calls,
162 attachments: Vec::new(),
163 }
164 }
165}
166
/// Tool-selection setting carried in [`ChatRequest::tool_choice`].
///
/// NOTE(review): what each variant means on the wire is decided by the code
/// that consumes the request, which is not visible in this file; the
/// per-variant notes below are assumptions drawn from the names — confirm.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum ToolChoice {
    /// The `Default` value (`#[default]`), and what [`ChatRequest::new`] sets.
    #[default]
    Auto,
    /// Presumably: the model must call some tool — confirm.
    Any,
    /// Presumably: tool calling is disabled — confirm.
    None,
    /// Presumably: the model must call the tool with this name — confirm.
    Specific(String),
}
180
/// Parameters of one chat-completion request.
///
/// [`ChatRequest::new`] fills every field except `model` and `messages` with a
/// default; callers adjust the public fields afterwards as needed.
#[derive(Debug, Clone)]
pub struct ChatRequest {
    /// Model identifier, stored as given.
    pub model: String,
    /// Conversation messages, in the order supplied by the caller.
    pub messages: Vec<ChatMessage>,
    /// Tool definitions sent with the request; empty by default.
    pub tools: Vec<ToolDef>,
    /// Token limit for the request; `new` sets 4096.
    /// NOTE(review): presumably a cap on generated tokens — confirm.
    pub max_tokens: u32,
    /// Sampling temperature; `new` sets 0.7.
    pub temperature: f32,
    /// Optional system prompt text; `None` by default.
    pub system_prompt: Option<String>,
    /// Stop sequences; empty by default.
    pub stop_sequences: Vec<String>,
    /// Tool-selection setting; `ToolChoice::Auto` by default.
    pub tool_choice: ToolChoice,
    /// Structured system prompt blocks; empty by default.
    /// NOTE(review): how these interact with `system_prompt` is decided by
    /// the request consumer, not visible here — confirm.
    pub system_blocks: Vec<PromptBlock>,
    /// `false` by default.
    /// NOTE(review): presumably asks the backend to cache the tool
    /// definitions — confirm against the consumer of this flag.
    pub cache_tools: bool,
}
208
209impl ChatRequest {
210 pub fn new(model: impl Into<String>, messages: Vec<ChatMessage>) -> Self {
211 Self {
212 model: model.into(),
213 messages,
214 tools: vec![],
215 max_tokens: 4096,
216 temperature: 0.7,
217 system_prompt: None,
218 tool_choice: ToolChoice::Auto,
219 stop_sequences: Vec::new(),
220 system_blocks: Vec::new(),
221 cache_tools: false,
222 }
223 }
224}
225
/// Result of one chat-completion request.
#[derive(Debug, Clone)]
pub struct ChatResponse {
    /// What the model produced: text or a list of tool calls.
    pub content: ResponseContent,
    /// Prompt/completion token counts for the request.
    pub usage: TokenUsage,
    /// Why generation ended.
    pub finish_reason: FinishReason,
    /// Cache-related token accounting, when present.
    /// NOTE(review): presumably `None` for backends that report no cache
    /// statistics — confirm.
    pub cache_usage: Option<CacheUsage>,
}
236
/// Token counters used to measure prompt-cache effectiveness.
///
/// NOTE(review): the field names look like a provider's usage counters; their
/// exact meaning is defined by whatever fills this struct — confirm.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CacheUsage {
    /// Input tokens counted as read from the cache.
    pub cache_read_input_tokens: u32,
    /// Input tokens counted as written to the cache.
    pub cache_creation_input_tokens: u32,
    /// Remaining input tokens (third term of the `hit_ratio` denominator).
    pub input_tokens: u32,
    /// Output tokens; not used by `hit_ratio`.
    pub output_tokens: u32,
}

impl CacheUsage {
    /// Fraction of input tokens that were read from the cache.
    ///
    /// Computed as `cache_read_input_tokens / (cache_read_input_tokens +
    /// cache_creation_input_tokens + input_tokens)`. Returns `0.0` when all
    /// three counters are zero.
    pub fn hit_ratio(&self) -> f32 {
        // Sum in u64: three u32 counters can exceed u32::MAX, which would
        // panic in debug builds and wrap (possibly to zero) in release builds.
        let total = u64::from(self.cache_read_input_tokens)
            + u64::from(self.cache_creation_input_tokens)
            + u64::from(self.input_tokens);
        if total == 0 {
            return 0.0;
        }
        self.cache_read_input_tokens as f32 / total as f32
    }
}
268
/// Payload of a [`ChatResponse`]: either text or tool calls, never both.
#[derive(Debug, Clone)]
pub enum ResponseContent {
    /// Plain text output.
    Text(String),
    /// One or more tool calls requested by the model.
    ToolCalls(Vec<ToolCall>),
}
274
/// A single tool invocation requested by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Identifier of this call.
    /// NOTE(review): presumably echoed back in `ChatMessage::tool_call_id` of
    /// the matching tool result — confirm.
    pub id: String,
    /// Name of the tool being called.
    pub name: String,
    /// Call arguments as an arbitrary JSON value.
    pub arguments: serde_json::Value,
}
281
/// Token counts reported for a request; both default to zero.
#[derive(Debug, Clone, Default)]
pub struct TokenUsage {
    /// Tokens counted for the prompt (input side).
    pub prompt_tokens: u32,
    /// Tokens counted for the completion (output side).
    pub completion_tokens: u32,
}
287
/// Why a response ended.
///
/// NOTE(review): the mapping from provider-specific reasons to these variants
/// lives outside this file; the per-variant notes are assumptions drawn from
/// the names — confirm.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FinishReason {
    /// Presumably: the model finished on its own or hit a stop sequence.
    Stop,
    /// Presumably: the model stopped in order to call a tool.
    ToolUse,
    /// Presumably: the token limit was reached.
    Length,
    /// Any other reason, carried as a string.
    Other(String),
}
295
/// Definition of a tool that can be offered to the model in
/// [`ChatRequest::tools`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDef {
    /// Tool name. See [`ToolDef::fit_name`] for building names that respect
    /// [`ToolDef::MAX_NAME_LEN`].
    pub name: String,
    /// Description of the tool.
    pub description: String,
    /// Parameter description as an arbitrary JSON value.
    /// NOTE(review): presumably a JSON Schema object — confirm.
    pub parameters: serde_json::Value,
}
302
303impl ToolDef {
304 pub const MAX_NAME_LEN: usize = 64;
310
311 pub fn fit_name(prefix: &str, id: &str, tool: &str) -> String {
318 let full = format!("{prefix}{id}_{tool}");
319 if full.len() <= Self::MAX_NAME_LEN {
320 return full;
321 }
322 use sha2::{Digest, Sha256};
323 let mut h = Sha256::new();
324 h.update(id.as_bytes());
325 h.update([0u8]);
326 h.update(tool.as_bytes());
327 let digest = h.finalize();
328 let hash = &hex::encode(digest)[..6];
329
330 let fixed = prefix.len() + id.len() + 1 + 1 + 6;
332 if fixed <= Self::MAX_NAME_LEN {
333 let budget = Self::MAX_NAME_LEN - fixed;
334 let head: String = tool.chars().take(budget).collect();
335 return format!("{prefix}{id}_{head}_{hash}");
336 }
337 let id_budget = Self::MAX_NAME_LEN.saturating_sub(prefix.len() + 1 + 6);
339 let id_head: String = id.chars().take(id_budget).collect();
340 format!("{prefix}{id_head}_{hash}")
341 }
342}
343
#[cfg(test)]
mod fit_name_tests {
    use super::ToolDef;

    /// A name that already fits is returned untouched.
    #[test]
    fn passthrough_when_short() {
        let fitted = ToolDef::fit_name("ext_", "echo", "say");
        assert_eq!(fitted, "ext_echo_say");
    }

    /// An over-long tool part is cut so the result fills the limit exactly.
    #[test]
    fn hashes_overflow_tool_and_fits() {
        let long_tool = "long_".repeat(30);
        let fitted = ToolDef::fit_name("ext_", "mybot", &long_tool);
        assert!(fitted.starts_with("ext_mybot_"));
        assert_eq!(fitted.len(), ToolDef::MAX_NAME_LEN);
    }

    /// Tools that share a long common head still get distinct names.
    #[test]
    fn different_inputs_yield_different_hashes() {
        let tool_x = format!("process_data_batch_{}", "x".repeat(60));
        let tool_y = format!("process_data_batch_{}", "y".repeat(60));
        let name_x = ToolDef::fit_name("ext_", "mybot", &tool_x);
        let name_y = ToolDef::fit_name("ext_", "mybot", &tool_y);
        assert_ne!(name_x, name_y);
        assert_eq!(name_x.len(), ToolDef::MAX_NAME_LEN);
        assert_eq!(name_y.len(), ToolDef::MAX_NAME_LEN);
    }

    /// An id too long for the budget is shortened instead of overflowing.
    #[test]
    fn handles_id_that_busts_budget() {
        let huge_id = "x".repeat(60);
        let fitted = ToolDef::fit_name("ext_", &huge_id, "t");
        assert!(fitted.starts_with("ext_"));
        assert!(fitted.len() <= ToolDef::MAX_NAME_LEN);
    }

    /// The same inputs always produce the same name.
    #[test]
    fn is_deterministic() {
        let long_tool = "a".repeat(80);
        let first = ToolDef::fit_name("mcp_", "server", &long_tool);
        let second = ToolDef::fit_name("mcp_", "server", &long_tool);
        assert_eq!(first, second);
    }
}