mixtape_core/model.rs

//! Model traits and types
//!
//! This module defines the core model abstraction:
//! - `Model` trait for model metadata (name, token limits)
//! - Provider-specific traits (`BedrockModel`, `AnthropicModel`) for API IDs
//!
//! Models are simple structs that implement these traits. All API interaction
//! goes through the provider (e.g., `BedrockProvider`).
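//!
//! A minimal sketch of the pattern (the struct and limits are illustrative):
//!
//! ```ignore
//! struct MyModel;
//!
//! impl Model for MyModel {
//!     fn name(&self) -> &'static str {
//!         "My Model"
//!     }
//!
//!     fn max_context_tokens(&self) -> usize {
//!         200_000
//!     }
//!
//!     fn max_output_tokens(&self) -> usize {
//!         8192
//!     }
//!
//!     fn estimate_token_count(&self, text: &str) -> usize {
//!         text.len().div_ceil(4) // ~4 chars per token
//!     }
//! }
//! ```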

use crate::events::TokenUsage;
use crate::types::{ContentBlock, Message, StopReason, ToolDefinition};

/// Request parameters for model completion
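///
/// A minimal construction sketch (all values are illustrative):
///
/// ```ignore
/// let request = ModelRequest {
///     messages: vec![Message::user("Hello")],
///     system_prompt: Some("You are a helpful assistant.".to_string()),
///     max_tokens: 1024,
///     temperature: Some(0.7),
///     top_p: None,
///     tools: vec![],
/// };
/// ```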
#[derive(Debug, Clone)]
pub struct ModelRequest {
    pub messages: Vec<Message>,
    pub system_prompt: Option<String>,
    pub max_tokens: i32,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
    pub tools: Vec<ToolDefinition>,
}

/// Response from a model completion
#[derive(Debug, Clone)]
pub struct ModelResponse {
    /// The assistant's response message
    pub message: Message,
    /// Why the model stopped generating
    pub stop_reason: StopReason,
    /// Token usage statistics (if provided by the model)
    pub usage: Option<TokenUsage>,
}

/// Core model metadata trait
///
/// All models implement this to provide their capabilities.
/// This is provider-agnostic - the same model has the same
/// context window whether accessed via Bedrock or Anthropic.
pub trait Model: Send + Sync {
    /// Human-readable model name (e.g., "Claude Sonnet 4.5")
    fn name(&self) -> &'static str;

    /// Maximum input context tokens
    fn max_context_tokens(&self) -> usize;

    /// Maximum output tokens the model can generate
    fn max_output_tokens(&self) -> usize;
    /// Estimate token count for text
    ///
    /// Models implement this to provide token estimation. A simple heuristic
    /// (~4 characters per token) works reasonably well for most models;
    /// implementations with access to a real tokenizer can return exact
    /// counts instead.
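    ///
    /// A minimal sketch of the heuristic (mirroring `TestModel` in the tests
    /// below):
    ///
    /// ```ignore
    /// fn estimate_token_count(&self, text: &str) -> usize {
    ///     text.len().div_ceil(4) // ~4 chars per token, rounded up
    /// }
    /// ```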
    fn estimate_token_count(&self, text: &str) -> usize;

    /// Estimate tokens for a conversation
    ///
    /// Default implementation sums token estimates for all content blocks
    /// plus overhead for message structure.
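    ///
    /// For example, with the ~4-chars-per-token heuristic, a single "Hello"
    /// message costs 2 content tokens plus 4 tokens of role overhead, 6 total.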
    fn estimate_message_tokens(&self, messages: &[Message]) -> usize {
        let mut total = 0;
        for message in messages {
            // Role overhead (~4 tokens for role marker and structure)
            total += 4;
            // Content blocks
            for block in &message.content {
                total += self.estimate_content_block_tokens(block);
            }
        }
        total
    }

    /// Estimate tokens for a single content block
    fn estimate_content_block_tokens(&self, block: &ContentBlock) -> usize {
        match block {
            ContentBlock::Text(text) => self.estimate_token_count(text),
            ContentBlock::ToolUse(tool_use) => {
                // Tool name + ID + JSON input
                self.estimate_token_count(&tool_use.name)
                    + self.estimate_token_count(&tool_use.id)
                    + self.estimate_token_count(&tool_use.input.to_string())
                    + 10 // Structure overhead
            }
            ContentBlock::ToolResult(result) => {
                // Tool use ID + content
                self.estimate_token_count(&result.tool_use_id)
                    + match &result.content {
                        crate::tool::ToolResult::Text(t) => self.estimate_token_count(t.as_str()),
                        crate::tool::ToolResult::Json(v) => {
                            self.estimate_token_count(&v.to_string())
                        }
                        crate::tool::ToolResult::Image { data, .. } => {
                            // Images are typically ~1 token per 750 bytes
                            data.len() / 750 + 85 // Base overhead for image
                        }
                        crate::tool::ToolResult::Document { data, .. } => {
                            // Documents vary; rough estimate
                            data.len() / 500 + 50 // Base overhead for document
                        }
                    }
                    + 10 // Structure overhead
            }
            ContentBlock::Thinking {
                thinking,
                signature,
            } => {
                // Estimate tokens for thinking content
                self.estimate_token_count(thinking) + self.estimate_token_count(signature) + 10
            }
        }
    }
}

/// Cross-region inference profile configuration for Bedrock
///
/// Inference profiles enable cross-region load balancing for higher throughput
/// and improved reliability. When enabled, Bedrock automatically routes requests
/// to the optimal region within the specified geographic scope.
///
/// Some newer models (Claude 4/4.5, Nova 2 Lite) require inference profiles
/// and don't support direct single-region invocation.
///
/// See: <https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference.html>
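///
/// A quick sketch of how a profile rewrites the model ID sent to Bedrock
/// (model ID taken from the `BedrockModel` docs below):
///
/// ```ignore
/// let profile = InferenceProfile::Global;
/// assert_eq!(
///     profile.apply_to("anthropic.claude-sonnet-4-5-20250929-v1:0"),
///     "global.anthropic.claude-sonnet-4-5-20250929-v1:0"
/// );
/// ```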
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum InferenceProfile {
    /// No inference profile - single-region invocation (default)
    ///
    /// Requests go directly to the region configured in your AWS SDK.
    /// Use this for predictable routing and when data locality is important.
    #[default]
    None,

    /// US regions only (us-east-1, us-east-2, us-west-2, etc.)
    US,

    /// European regions only (eu-central-1, eu-west-1, eu-west-2, etc.)
    EU,

    /// Asia-Pacific regions (ap-northeast-1, ap-southeast-1, etc.)
    APAC,

    /// Global cross-region inference (all commercial AWS regions)
    ///
    /// Provides maximum throughput but may route to any region worldwide.
    Global,
}

impl InferenceProfile {
    /// Apply this inference profile to a base model ID
    ///
    /// Returns the full model ID to use in Bedrock API calls.
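    ///
    /// Examples (mirroring the unit tests below):
    ///
    /// ```ignore
    /// assert_eq!(InferenceProfile::US.apply_to("anthropic.claude-3"), "us.anthropic.claude-3");
    /// assert_eq!(InferenceProfile::None.apply_to("anthropic.claude-3"), "anthropic.claude-3");
    /// ```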
    pub fn apply_to(&self, base_model_id: &str) -> String {
        match self.prefix() {
            Some(prefix) => format!("{}.{}", prefix, base_model_id),
            None => base_model_id.to_string(),
        }
    }

    /// Get the prefix for this inference profile, if any
    fn prefix(&self) -> Option<&'static str> {
        match self {
            InferenceProfile::None => None,
            InferenceProfile::US => Some("us"),
            InferenceProfile::EU => Some("eu"),
            InferenceProfile::APAC => Some("apac"),
            InferenceProfile::Global => Some("global"),
        }
    }
}

/// Trait for models available on AWS Bedrock
///
/// Models implement this to be usable with `BedrockProvider`.
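///
/// A minimal sketch (the struct is illustrative and would also need a
/// `Model` impl):
///
/// ```ignore
/// struct ClaudeSonnet45;
///
/// impl BedrockModel for ClaudeSonnet45 {
///     fn bedrock_id(&self) -> &'static str {
///         "anthropic.claude-sonnet-4-5-20250929-v1:0"
///     }
///
///     // Claude 4.5 requires cross-region inference on Bedrock.
///     fn default_inference_profile(&self) -> InferenceProfile {
///         InferenceProfile::Global
///     }
/// }
/// ```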
pub trait BedrockModel: Model {
    /// The Bedrock model ID
    ///
    /// This is the full model identifier used in Bedrock API calls,
    /// e.g., "anthropic.claude-sonnet-4-5-20250929-v1:0"
    fn bedrock_id(&self) -> &'static str;

    /// The default inference profile for this model
    ///
    /// Models that require cross-region inference (Claude 4/4.5, Nova 2 Lite)
    /// should return `InferenceProfile::Global`. Other models default to
    /// `InferenceProfile::None` for single-region invocation.
    fn default_inference_profile(&self) -> InferenceProfile {
        InferenceProfile::None
    }
}

/// Trait for models available via Anthropic's direct API
///
/// Models implement this to be usable with a future `AnthropicProvider`.
pub trait AnthropicModel: Model {
    /// The Anthropic API model ID
    ///
    /// e.g., "claude-sonnet-4-5-20250929"
    fn anthropic_id(&self) -> &'static str;
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tool::{DocumentFormat, ImageFormat, ToolResult};
    use crate::types::{
        ContentBlock, Message, Role, ToolResultBlock, ToolResultStatus, ToolUseBlock,
    };

    /// Simple test model with predictable token estimation
    struct TestModel;

    impl Model for TestModel {
        fn name(&self) -> &'static str {
            "TestModel"
        }

        fn max_context_tokens(&self) -> usize {
            100_000
        }

        fn max_output_tokens(&self) -> usize {
            4096
        }

        fn estimate_token_count(&self, text: &str) -> usize {
            // Simple: ~4 chars per token, rounding up
            text.len().div_ceil(4)
        }
    }

    // ===== Token Estimation Tests =====

    #[test]
    fn test_estimate_message_tokens_empty() {
        let model = TestModel;
        let messages: Vec<Message> = vec![];
        assert_eq!(model.estimate_message_tokens(&messages), 0);
    }

    #[test]
    fn test_estimate_message_tokens_simple_text() {
        let model = TestModel;
        let messages = vec![Message::user("Hello world")]; // 11 chars = 3 tokens + 4 overhead = 7

        let tokens = model.estimate_message_tokens(&messages);
        assert_eq!(tokens, 7);
    }

    #[test]
    fn test_estimate_message_tokens_multiple_messages() {
        let model = TestModel;
        let messages = vec![
            Message::user("Hello"),         // 5 chars = 2 tokens + 4 overhead = 6
            Message::assistant("Hi there"), // 8 chars = 2 tokens + 4 overhead = 6
        ];

        let tokens = model.estimate_message_tokens(&messages);
        assert_eq!(tokens, 12);
    }

    #[test]
    fn test_estimate_content_block_tokens_text() {
        let model = TestModel;
        let block = ContentBlock::Text("test".to_string()); // 4 chars = 1 token
        assert_eq!(model.estimate_content_block_tokens(&block), 1);
    }

    #[test]
    fn test_estimate_content_block_tokens_text_empty() {
        let model = TestModel;
        let block = ContentBlock::Text(String::new());
        assert_eq!(model.estimate_content_block_tokens(&block), 0);
    }

    #[test]
    fn test_estimate_content_block_tokens_tool_use() {
        let model = TestModel;
        let block = ContentBlock::ToolUse(ToolUseBlock {
            id: "id12".to_string(),               // 4 chars = 1 token
            name: "search".to_string(),           // 6 chars = 2 tokens
            input: serde_json::json!({"q": "x"}), // ~10 chars = 3 tokens
        });

        // 1 + 2 + 3 + 10 (overhead) = 16
        let tokens = model.estimate_content_block_tokens(&block);
        assert!(tokens >= 10, "Should include overhead, got {}", tokens);
    }

    #[test]
    fn test_estimate_content_block_tokens_tool_result_text() {
        let model = TestModel;
        let block = ContentBlock::ToolResult(ToolResultBlock {
            tool_use_id: "id12".to_string(), // 4 chars = 1 token
            content: ToolResult::Text("result text".to_string()), // 11 chars = 3 tokens
            status: ToolResultStatus::Success,
        });

        // 1 + 3 + 10 (overhead) = 14
        let tokens = model.estimate_content_block_tokens(&block);
        assert!(tokens >= 10, "Should include overhead, got {}", tokens);
    }

    #[test]
    fn test_estimate_content_block_tokens_tool_result_json() {
        let model = TestModel;
        let block = ContentBlock::ToolResult(ToolResultBlock {
            tool_use_id: "id".to_string(),
            content: ToolResult::Json(serde_json::json!({"key": "value"})),
            status: ToolResultStatus::Success,
        });

        let tokens = model.estimate_content_block_tokens(&block);
        assert!(tokens >= 10, "Should include overhead, got {}", tokens);
    }

    #[test]
    fn test_estimate_content_block_tokens_image() {
        let model = TestModel;
        // 7500 bytes / 750 + 85 = 95 tokens
        let data = vec![0u8; 7500];
        let block = ContentBlock::ToolResult(ToolResultBlock {
            tool_use_id: "img".to_string(),
            content: ToolResult::Image {
                format: ImageFormat::Png,
                data,
            },
            status: ToolResultStatus::Success,
        });

        let tokens = model.estimate_content_block_tokens(&block);
        // 7500/750 + 85 = 10 + 85 = 95 + tool_use_id tokens + overhead
        assert!(
            tokens >= 95,
            "Expected at least 95 tokens for image, got {}",
            tokens
        );
    }

    #[test]
    fn test_estimate_content_block_tokens_document() {
        let model = TestModel;
        // 5000 bytes / 500 + 50 = 60 tokens
        let data = vec![0u8; 5000];
        let block = ContentBlock::ToolResult(ToolResultBlock {
            tool_use_id: "doc".to_string(),
            content: ToolResult::Document {
                format: DocumentFormat::Pdf,
                data,
                name: Some("test.pdf".to_string()),
            },
            status: ToolResultStatus::Success,
        });

        let tokens = model.estimate_content_block_tokens(&block);
        // 5000/500 + 50 = 10 + 50 = 60 + overhead
        assert!(
            tokens >= 60,
            "Expected at least 60 tokens for document, got {}",
            tokens
        );
    }

    #[test]
    fn test_estimate_content_block_tokens_thinking() {
        let model = TestModel;
        let block = ContentBlock::Thinking {
            thinking: "complex reasoning here".to_string(), // 22 chars = 6 tokens
            signature: "sig".to_string(),                   // 3 chars = 1 token
        };

        // 6 + 1 + 10 (overhead) = 17
        let tokens = model.estimate_content_block_tokens(&block);
        assert!(tokens >= 10, "Should include overhead, got {}", tokens);
    }

    #[test]
    fn test_estimate_message_with_multiple_content_blocks() {
        let model = TestModel;
        let messages = vec![Message {
            role: Role::Assistant,
            content: vec![
                ContentBlock::Text("Let me search".to_string()),
                ContentBlock::ToolUse(ToolUseBlock {
                    id: "1".to_string(),
                    name: "search".to_string(),
                    input: serde_json::json!({"q": "test"}),
                }),
            ],
        }];

        let tokens = model.estimate_message_tokens(&messages);
        // 4 (overhead) + text tokens + tool use tokens
        assert!(tokens > 4, "Should have content tokens plus overhead");
    }

    // ===== InferenceProfile Tests =====

    #[test]
    fn test_inference_profile_apply_none() {
        let profile = InferenceProfile::None;
        assert_eq!(profile.apply_to("anthropic.claude-3"), "anthropic.claude-3");
    }

    #[test]
    fn test_inference_profile_apply_us() {
        let profile = InferenceProfile::US;
        assert_eq!(
            profile.apply_to("anthropic.claude-3"),
            "us.anthropic.claude-3"
        );
    }

    #[test]
    fn test_inference_profile_apply_eu() {
        let profile = InferenceProfile::EU;
        assert_eq!(
            profile.apply_to("anthropic.claude-3"),
            "eu.anthropic.claude-3"
        );
    }

    #[test]
    fn test_inference_profile_apply_apac() {
        let profile = InferenceProfile::APAC;
        assert_eq!(profile.apply_to("model-id"), "apac.model-id");
    }

    #[test]
    fn test_inference_profile_apply_global() {
        let profile = InferenceProfile::Global;
        assert_eq!(profile.apply_to("model-id"), "global.model-id");
    }

    #[test]
    fn test_inference_profile_all_variants() {
        let cases = [
            (InferenceProfile::None, "model", "model"),
            (InferenceProfile::US, "model", "us.model"),
            (InferenceProfile::EU, "model", "eu.model"),
            (InferenceProfile::APAC, "model", "apac.model"),
            (InferenceProfile::Global, "model", "global.model"),
        ];

        for (profile, base, expected) in cases {
            assert_eq!(profile.apply_to(base), expected, "Failed for {:?}", profile);
        }
    }

    #[test]
    fn test_inference_profile_default() {
        let profile = InferenceProfile::default();
        assert_eq!(profile, InferenceProfile::None);
    }
}