use crate::events::TokenUsage;
use crate::types::{ContentBlock, Message, StopReason, ToolDefinition};

/// A provider-agnostic request to a language model.
#[derive(Debug, Clone)]
pub struct ModelRequest {
    /// The conversation history to send.
    pub messages: Vec<Message>,
    /// Optional system prompt prepended to the conversation.
    pub system_prompt: Option<String>,
    /// Maximum number of tokens the model may generate.
    pub max_tokens: i32,
    /// Optional sampling temperature.
    pub temperature: Option<f32>,
    /// Optional nucleus-sampling (top-p) parameter.
    pub top_p: Option<f32>,
    /// Tool definitions made available to the model.
    pub tools: Vec<ToolDefinition>,
}
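
// Illustrative sketch only: one way to build a minimal `ModelRequest`. The
// helper name and the `max_tokens` value are assumptions for illustration,
// not part of the crate's API; `Message::user` is the convenience
// constructor exercised in the tests below.
#[allow(dead_code)]
fn example_request(prompt: &str) -> ModelRequest {
    ModelRequest {
        messages: vec![Message::user(prompt)],
        system_prompt: None,
        max_tokens: 1024,
        temperature: None,
        top_p: None,
        tools: Vec::new(),
    }
}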

/// A completed response from a language model.
#[derive(Debug, Clone)]
pub struct ModelResponse {
    /// The assistant message produced by the model.
    pub message: Message,
    /// The reason the model stopped generating.
    pub stop_reason: StopReason,
    /// Token usage reported by the provider, if available.
    pub usage: Option<TokenUsage>,
}

/// Core abstraction over a language model. Implementations supply identity
/// and limits; the token-estimation helpers have default implementations
/// built on `estimate_token_count`.
pub trait Model: Send + Sync {
    /// Human-readable model name.
    fn name(&self) -> &'static str;

    /// Maximum number of tokens the model accepts as input context.
    fn max_context_tokens(&self) -> usize;

    /// Maximum number of tokens the model can generate in one response.
    fn max_output_tokens(&self) -> usize;

    /// Estimate how many tokens a piece of text will consume.
    fn estimate_token_count(&self, text: &str) -> usize;

    /// Estimate the total token count of a slice of messages, including a
    /// fixed per-message overhead for role and framing.
    fn estimate_message_tokens(&self, messages: &[Message]) -> usize {
        let mut total = 0;
        for message in messages {
            // Fixed overhead per message for role and structural framing.
            total += 4;
            for block in &message.content {
                total += self.estimate_content_block_tokens(block);
            }
        }
        total
    }

    /// Estimate the token count of a single content block. Binary payloads
    /// (images, documents) are estimated from their byte length; structured
    /// blocks add a small fixed overhead for framing.
    fn estimate_content_block_tokens(&self, block: &ContentBlock) -> usize {
        match block {
            ContentBlock::Text(text) => self.estimate_token_count(text),
            ContentBlock::ToolUse(tool_use) => {
                self.estimate_token_count(&tool_use.name)
                    + self.estimate_token_count(&tool_use.id)
                    + self.estimate_token_count(&tool_use.input.to_string())
                    + 10
            }
            ContentBlock::ToolResult(result) => {
                self.estimate_token_count(&result.tool_use_id)
                    + match &result.content {
                        crate::tool::ToolResult::Text(t) => self.estimate_token_count(t.as_str()),
                        crate::tool::ToolResult::Json(v) => {
                            self.estimate_token_count(&v.to_string())
                        }
                        crate::tool::ToolResult::Image { data, .. } => data.len() / 750 + 85,
                        crate::tool::ToolResult::Document { data, .. } => data.len() / 500 + 50,
                    }
                    + 10
            }
            ContentBlock::Thinking {
                thinking,
                signature,
            } => {
                self.estimate_token_count(thinking) + self.estimate_token_count(signature) + 10
            }
        }
    }
}
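
// Illustrative sketch only: how a caller might combine the estimation
// methods above to check whether a conversation still fits the context
// window before sending a request. The helper name `fits_in_context` and
// the choice to reserve the full output budget are assumptions, not part
// of the crate's API.
#[allow(dead_code)]
fn fits_in_context(model: &dyn Model, messages: &[Message]) -> bool {
    let input_estimate = model.estimate_message_tokens(messages);
    // Leave room for the largest possible response.
    input_estimate + model.max_output_tokens() <= model.max_context_tokens()
}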

/// An AWS Bedrock cross-region inference profile. Applying a profile
/// prefixes the base model ID with a region group (for example `us.`),
/// routing requests through Bedrock's cross-region inference endpoints.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum InferenceProfile {
    /// Invoke the base model ID directly, with no regional prefix.
    #[default]
    None,

    /// United States cross-region profile (`us.` prefix).
    US,

    /// European Union cross-region profile (`eu.` prefix).
    EU,

    /// Asia-Pacific cross-region profile (`apac.` prefix).
    APAC,

    /// Global cross-region profile (`global.` prefix).
    Global,
}

impl InferenceProfile {
    /// Apply this profile to a base model ID, returning the ID to invoke.
    pub fn apply_to(&self, base_model_id: &str) -> String {
        match self.prefix() {
            Some(prefix) => format!("{}.{}", prefix, base_model_id),
            None => base_model_id.to_string(),
        }
    }

    /// The region prefix for this profile, if any.
    fn prefix(&self) -> Option<&'static str> {
        match self {
            InferenceProfile::None => None,
            InferenceProfile::US => Some("us"),
            InferenceProfile::EU => Some("eu"),
            InferenceProfile::APAC => Some("apac"),
            InferenceProfile::Global => Some("global"),
        }
    }
}

/// A model that can be invoked through AWS Bedrock.
pub trait BedrockModel: Model {
    /// The base Bedrock model ID, without any inference-profile prefix.
    fn bedrock_id(&self) -> &'static str;

    /// The inference profile applied to this model's ID by default.
    fn default_inference_profile(&self) -> InferenceProfile {
        InferenceProfile::None
    }
}
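
// Illustrative sketch only: the intended composition of the Bedrock pieces.
// The invocable model ID is the base ID with the model's default regional
// prefix applied. The helper name `resolved_bedrock_id` is an assumption,
// not part of the crate's API.
#[allow(dead_code)]
fn resolved_bedrock_id(model: &dyn BedrockModel) -> String {
    model.default_inference_profile().apply_to(model.bedrock_id())
}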

/// A model that can be invoked through the Anthropic API.
pub trait AnthropicModel: Model {
    /// The Anthropic API model ID.
    fn anthropic_id(&self) -> &'static str;
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tool::{DocumentFormat, ImageFormat, ToolResult};
    use crate::types::{
        ContentBlock, Message, Role, ToolResultBlock, ToolResultStatus, ToolUseBlock,
    };

    /// Minimal model whose token count is a chars/4 heuristic, for
    /// exercising the default trait methods.
    struct TestModel;

    impl Model for TestModel {
        fn name(&self) -> &'static str {
            "TestModel"
        }

        fn max_context_tokens(&self) -> usize {
            100_000
        }

        fn max_output_tokens(&self) -> usize {
            4096
        }

        fn estimate_token_count(&self, text: &str) -> usize {
            // One token per 4 characters, rounded up.
            text.len().div_ceil(4)
        }
    }

    #[test]
    fn test_estimate_message_tokens_empty() {
        let model = TestModel;
        let messages: Vec<Message> = vec![];
        assert_eq!(model.estimate_message_tokens(&messages), 0);
    }
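
    // Added for illustration: the default `estimate_message_tokens` charges
    // a fixed 4-token overhead per message, so a message whose only text is
    // empty still costs 4 tokens under TestModel's chars/4 heuristic.
    #[test]
    fn test_estimate_message_tokens_overhead_only() {
        let model = TestModel;
        let messages = vec![Message::user("")];
        assert_eq!(model.estimate_message_tokens(&messages), 4);
    }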

    #[test]
    fn test_estimate_message_tokens_simple_text() {
        let model = TestModel;
        let messages = vec![Message::user("Hello world")];
        // 4 per-message overhead + ceil(11 / 4) = 3 text tokens.
        let tokens = model.estimate_message_tokens(&messages);
        assert_eq!(tokens, 7);
    }

    #[test]
    fn test_estimate_message_tokens_multiple_messages() {
        let model = TestModel;
        let messages = vec![
            Message::user("Hello"),
            Message::assistant("Hi there"),
        ];

        let tokens = model.estimate_message_tokens(&messages);
        // (4 + 2) for "Hello" plus (4 + 2) for "Hi there".
        assert_eq!(tokens, 12);
    }

    #[test]
    fn test_estimate_content_block_tokens_text() {
        let model = TestModel;
        let block = ContentBlock::Text("test".to_string());
        assert_eq!(model.estimate_content_block_tokens(&block), 1);
    }

    #[test]
    fn test_estimate_content_block_tokens_text_empty() {
        let model = TestModel;
        let block = ContentBlock::Text(String::new());
        assert_eq!(model.estimate_content_block_tokens(&block), 0);
    }

    #[test]
    fn test_estimate_content_block_tokens_tool_use() {
        let model = TestModel;
        let block = ContentBlock::ToolUse(ToolUseBlock {
            id: "id12".to_string(),
            name: "search".to_string(),
            input: serde_json::json!({"q": "x"}),
        });

        let tokens = model.estimate_content_block_tokens(&block);
        assert!(tokens >= 10, "Should include overhead, got {}", tokens);
    }

    #[test]
    fn test_estimate_content_block_tokens_tool_result_text() {
        let model = TestModel;
        let block = ContentBlock::ToolResult(ToolResultBlock {
            tool_use_id: "id12".to_string(),
            content: ToolResult::Text("result text".to_string()),
            status: ToolResultStatus::Success,
        });

        let tokens = model.estimate_content_block_tokens(&block);
        assert!(tokens >= 10, "Should include overhead, got {}", tokens);
    }

    #[test]
    fn test_estimate_content_block_tokens_tool_result_json() {
        let model = TestModel;
        let block = ContentBlock::ToolResult(ToolResultBlock {
            tool_use_id: "id".to_string(),
            content: ToolResult::Json(serde_json::json!({"key": "value"})),
            status: ToolResultStatus::Success,
        });

        let tokens = model.estimate_content_block_tokens(&block);
        assert!(tokens >= 10, "Should include overhead, got {}", tokens);
    }

    #[test]
    fn test_estimate_content_block_tokens_image() {
        let model = TestModel;
        // 7500 bytes / 750 + 85 = 95 tokens from the image payload alone.
        let data = vec![0u8; 7500];
        let block = ContentBlock::ToolResult(ToolResultBlock {
            tool_use_id: "img".to_string(),
            content: ToolResult::Image {
                format: ImageFormat::Png,
                data,
            },
            status: ToolResultStatus::Success,
        });

        let tokens = model.estimate_content_block_tokens(&block);
        assert!(
            tokens >= 95,
            "Expected at least 95 tokens for image, got {}",
            tokens
        );
    }

    #[test]
    fn test_estimate_content_block_tokens_document() {
        let model = TestModel;
        // 5000 bytes / 500 + 50 = 60 tokens from the document payload alone.
        let data = vec![0u8; 5000];
        let block = ContentBlock::ToolResult(ToolResultBlock {
            tool_use_id: "doc".to_string(),
            content: ToolResult::Document {
                format: DocumentFormat::Pdf,
                data,
                name: Some("test.pdf".to_string()),
            },
            status: ToolResultStatus::Success,
        });

        let tokens = model.estimate_content_block_tokens(&block);
        assert!(
            tokens >= 60,
            "Expected at least 60 tokens for document, got {}",
            tokens
        );
    }

    #[test]
    fn test_estimate_content_block_tokens_thinking() {
        let model = TestModel;
        let block = ContentBlock::Thinking {
            thinking: "complex reasoning here".to_string(),
            signature: "sig".to_string(),
        };

        let tokens = model.estimate_content_block_tokens(&block);
        assert!(tokens >= 10, "Should include overhead, got {}", tokens);
    }

    #[test]
    fn test_estimate_message_with_multiple_content_blocks() {
        let model = TestModel;
        let messages = vec![Message {
            role: Role::Assistant,
            content: vec![
                ContentBlock::Text("Let me search".to_string()),
                ContentBlock::ToolUse(ToolUseBlock {
                    id: "1".to_string(),
                    name: "search".to_string(),
                    input: serde_json::json!({"q": "test"}),
                }),
            ],
        }];

        let tokens = model.estimate_message_tokens(&messages);
        assert!(tokens > 4, "Should have content tokens plus overhead");
    }

    #[test]
    fn test_inference_profile_apply_none() {
        let profile = InferenceProfile::None;
        assert_eq!(profile.apply_to("anthropic.claude-3"), "anthropic.claude-3");
    }

    #[test]
    fn test_inference_profile_apply_us() {
        let profile = InferenceProfile::US;
        assert_eq!(
            profile.apply_to("anthropic.claude-3"),
            "us.anthropic.claude-3"
        );
    }

    #[test]
    fn test_inference_profile_apply_eu() {
        let profile = InferenceProfile::EU;
        assert_eq!(
            profile.apply_to("anthropic.claude-3"),
            "eu.anthropic.claude-3"
        );
    }

    #[test]
    fn test_inference_profile_apply_apac() {
        let profile = InferenceProfile::APAC;
        assert_eq!(profile.apply_to("model-id"), "apac.model-id");
    }

    #[test]
    fn test_inference_profile_apply_global() {
        let profile = InferenceProfile::Global;
        assert_eq!(profile.apply_to("model-id"), "global.model-id");
    }

    #[test]
    fn test_inference_profile_all_variants() {
        let cases = [
            (InferenceProfile::None, "model", "model"),
            (InferenceProfile::US, "model", "us.model"),
            (InferenceProfile::EU, "model", "eu.model"),
            (InferenceProfile::APAC, "model", "apac.model"),
            (InferenceProfile::Global, "model", "global.model"),
        ];

        for (profile, base, expected) in cases {
            assert_eq!(profile.apply_to(base), expected, "Failed for {:?}", profile);
        }
    }

    #[test]
    fn test_inference_profile_default() {
        let profile = InferenceProfile::default();
        assert_eq!(profile, InferenceProfile::None);
    }
}