use crate::events::TokenUsage;
use crate::types::{ContentBlock, Message, StopReason, ToolDefinition};
/// A single inference request sent to a model backend.
#[derive(Debug, Clone)]
pub struct ModelRequest {
    /// Conversation history, in order, to send to the model.
    pub messages: Vec<Message>,
    /// Optional system prompt prepended to the conversation.
    pub system_prompt: Option<String>,
    /// Maximum number of tokens the model may generate for this request.
    // NOTE(review): this is `i32` while the `Model` trait reports limits as
    // `usize` — presumably to match a provider SDK type; confirm.
    pub max_tokens: i32,
    /// Sampling temperature; `None` uses the backend default.
    pub temperature: Option<f32>,
    /// Nucleus-sampling cutoff; `None` uses the backend default.
    pub top_p: Option<f32>,
    /// Tool definitions made available to the model for this request.
    pub tools: Vec<ToolDefinition>,
}
/// The result of a completed model invocation.
#[derive(Debug, Clone)]
pub struct ModelResponse {
    /// The assistant message produced by the model.
    pub message: Message,
    /// Why generation stopped (e.g. end of turn, tool use, length limit).
    pub stop_reason: StopReason,
    /// Token accounting for the call, when the backend reports it.
    pub usage: Option<TokenUsage>,
}
/// Abstraction over a language model: identity, size limits, and heuristic
/// token estimation used for context-window budgeting.
pub trait Model: Send + Sync {
    /// Human-readable model name.
    fn name(&self) -> &'static str;

    /// Maximum number of tokens the model accepts as input context.
    fn max_context_tokens(&self) -> usize;

    /// Maximum number of tokens the model can emit in one response.
    fn max_output_tokens(&self) -> usize;

    /// Heuristic token count for a piece of plain text.
    fn estimate_token_count(&self, text: &str) -> usize;

    /// Estimate the total token cost of `messages`: the sum of each
    /// message's content blocks plus a flat 4-token per-message overhead.
    fn estimate_message_tokens(&self, messages: &[Message]) -> usize {
        messages
            .iter()
            .map(|message| {
                let content: usize = message
                    .content
                    .iter()
                    .map(|block| self.estimate_content_block_tokens(block))
                    .sum();
                // Per-message structural overhead (role markers, delimiters).
                content + 4
            })
            .sum()
    }

    /// Estimate the token cost of a single content block. Text is estimated
    /// directly; structured blocks add a fixed overhead, and binary payloads
    /// use rough bytes-per-token ratios rather than a real tokenizer.
    fn estimate_content_block_tokens(&self, block: &ContentBlock) -> usize {
        // Flat overhead added to every non-text (structured) block.
        const STRUCTURAL_OVERHEAD: usize = 10;

        match block {
            ContentBlock::Text(text) => self.estimate_token_count(text),
            ContentBlock::ToolUse(tool_use) => {
                let name_tokens = self.estimate_token_count(&tool_use.name);
                let id_tokens = self.estimate_token_count(&tool_use.id);
                let input_tokens = self.estimate_token_count(&tool_use.input.to_string());
                name_tokens + id_tokens + input_tokens + STRUCTURAL_OVERHEAD
            }
            ContentBlock::ToolResult(result) => {
                let content_tokens = match &result.content {
                    crate::tool::ToolResult::Text(text) => self.estimate_token_count(text),
                    crate::tool::ToolResult::Json(value) => {
                        self.estimate_token_count(&value.to_string())
                    }
                    // Heuristic ratios for binary payloads plus a flat base
                    // cost — approximate, not an exact tokenizer.
                    crate::tool::ToolResult::Image { data, .. } => data.len() / 750 + 85,
                    crate::tool::ToolResult::Document { data, .. } => data.len() / 500 + 50,
                };
                self.estimate_token_count(&result.tool_use_id)
                    + content_tokens
                    + STRUCTURAL_OVERHEAD
            }
            ContentBlock::Thinking {
                thinking,
                signature,
            } => {
                self.estimate_token_count(thinking)
                    + self.estimate_token_count(signature)
                    + STRUCTURAL_OVERHEAD
            }
        }
    }
}
/// Regional inference-profile routing prefix applied to a base model id
/// (e.g. Bedrock cross-region inference profiles).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum InferenceProfile {
    /// No routing prefix; the base model id is used as-is.
    #[default]
    None,
    /// United States regional profile ("us" prefix).
    US,
    /// European Union regional profile ("eu" prefix).
    EU,
    /// Asia-Pacific regional profile ("apac" prefix).
    APAC,
    /// Global profile ("global" prefix).
    Global,
}
impl InferenceProfile {
pub fn apply_to(&self, base_model_id: &str) -> String {
match self.prefix() {
Some(prefix) => format!("{}.{}", prefix, base_model_id),
None => base_model_id.to_string(),
}
}
fn prefix(&self) -> Option<&'static str> {
match self {
InferenceProfile::None => None,
InferenceProfile::US => Some("us"),
InferenceProfile::EU => Some("eu"),
InferenceProfile::APAC => Some("apac"),
InferenceProfile::Global => Some("global"),
}
}
}
/// A [`Model`] that can be invoked through Amazon Bedrock.
pub trait BedrockModel: Model {
    /// The Bedrock model identifier (before any inference-profile prefix).
    fn bedrock_id(&self) -> &'static str;

    /// The inference profile to use when none is configured explicitly.
    /// Defaults to [`InferenceProfile::None`] (no regional prefix).
    fn default_inference_profile(&self) -> InferenceProfile {
        InferenceProfile::None
    }
}
/// A [`Model`] that can be invoked through the Anthropic API.
pub trait AnthropicModel: Model {
    /// The Anthropic API model identifier.
    fn anthropic_id(&self) -> &'static str;
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tool::{DocumentFormat, ImageFormat, ToolResult};
    use crate::types::{
        ContentBlock, Message, Role, ToolResultBlock, ToolResultStatus, ToolUseBlock,
    };

    /// Minimal `Model` used as a fixture: estimates 1 token per 4 bytes,
    /// rounded up.
    struct TestModel;

    impl Model for TestModel {
        fn name(&self) -> &'static str {
            "TestModel"
        }

        fn max_context_tokens(&self) -> usize {
            100_000
        }

        fn max_output_tokens(&self) -> usize {
            4096
        }

        fn estimate_token_count(&self, text: &str) -> usize {
            text.len().div_ceil(4)
        }
    }

    #[test]
    fn test_estimate_message_tokens_empty() {
        let no_messages: Vec<Message> = Vec::new();
        assert_eq!(TestModel.estimate_message_tokens(&no_messages), 0);
    }

    #[test]
    fn test_estimate_message_tokens_simple_text() {
        // "Hello world" is 11 bytes → ceil(11/4) = 3, plus 4 overhead = 7.
        let conversation = vec![Message::user("Hello world")];
        assert_eq!(TestModel.estimate_message_tokens(&conversation), 7);
    }

    #[test]
    fn test_estimate_message_tokens_multiple_messages() {
        // "Hello" → 2 + 4; "Hi there" → 2 + 4; total 12.
        let conversation = vec![Message::user("Hello"), Message::assistant("Hi there")];
        assert_eq!(TestModel.estimate_message_tokens(&conversation), 12);
    }

    #[test]
    fn test_estimate_content_block_tokens_text() {
        let text_block = ContentBlock::Text("test".to_string());
        assert_eq!(TestModel.estimate_content_block_tokens(&text_block), 1);
    }

    #[test]
    fn test_estimate_content_block_tokens_text_empty() {
        let empty_block = ContentBlock::Text(String::new());
        assert_eq!(TestModel.estimate_content_block_tokens(&empty_block), 0);
    }

    #[test]
    fn test_estimate_content_block_tokens_tool_use() {
        let tool_use = ContentBlock::ToolUse(ToolUseBlock {
            id: "id12".to_string(),
            name: "search".to_string(),
            input: serde_json::json!({"q": "x"}),
        });
        let tokens = TestModel.estimate_content_block_tokens(&tool_use);
        assert!(tokens >= 10, "Should include overhead, got {}", tokens);
    }

    #[test]
    fn test_estimate_content_block_tokens_tool_result_text() {
        let tool_result = ContentBlock::ToolResult(ToolResultBlock {
            tool_use_id: "id12".to_string(),
            content: ToolResult::Text("result text".to_string()),
            status: ToolResultStatus::Success,
        });
        let tokens = TestModel.estimate_content_block_tokens(&tool_result);
        assert!(tokens >= 10, "Should include overhead, got {}", tokens);
    }

    #[test]
    fn test_estimate_content_block_tokens_tool_result_json() {
        let tool_result = ContentBlock::ToolResult(ToolResultBlock {
            tool_use_id: "id".to_string(),
            content: ToolResult::Json(serde_json::json!({"key": "value"})),
            status: ToolResultStatus::Success,
        });
        let tokens = TestModel.estimate_content_block_tokens(&tool_result);
        assert!(tokens >= 10, "Should include overhead, got {}", tokens);
    }

    #[test]
    fn test_estimate_content_block_tokens_image() {
        // 7500 bytes / 750 + 85 base = 95 tokens minimum for the payload.
        let payload = vec![0u8; 7500];
        let image_result = ContentBlock::ToolResult(ToolResultBlock {
            tool_use_id: "img".to_string(),
            content: ToolResult::Image {
                format: ImageFormat::Png,
                data: payload,
            },
            status: ToolResultStatus::Success,
        });
        let tokens = TestModel.estimate_content_block_tokens(&image_result);
        assert!(
            tokens >= 95,
            "Expected at least 95 tokens for image, got {}",
            tokens
        );
    }

    #[test]
    fn test_estimate_content_block_tokens_document() {
        // 5000 bytes / 500 + 50 base = 60 tokens minimum for the payload.
        let payload = vec![0u8; 5000];
        let document_result = ContentBlock::ToolResult(ToolResultBlock {
            tool_use_id: "doc".to_string(),
            content: ToolResult::Document {
                format: DocumentFormat::Pdf,
                data: payload,
                name: Some("test.pdf".to_string()),
            },
            status: ToolResultStatus::Success,
        });
        let tokens = TestModel.estimate_content_block_tokens(&document_result);
        assert!(
            tokens >= 60,
            "Expected at least 60 tokens for document, got {}",
            tokens
        );
    }

    #[test]
    fn test_estimate_content_block_tokens_thinking() {
        let thinking_block = ContentBlock::Thinking {
            thinking: "complex reasoning here".to_string(),
            signature: "sig".to_string(),
        };
        let tokens = TestModel.estimate_content_block_tokens(&thinking_block);
        assert!(tokens >= 10, "Should include overhead, got {}", tokens);
    }

    #[test]
    fn test_estimate_message_with_multiple_content_blocks() {
        let conversation = vec![Message {
            role: Role::Assistant,
            content: vec![
                ContentBlock::Text("Let me search".to_string()),
                ContentBlock::ToolUse(ToolUseBlock {
                    id: "1".to_string(),
                    name: "search".to_string(),
                    input: serde_json::json!({"q": "test"}),
                }),
            ],
        }];
        let tokens = TestModel.estimate_message_tokens(&conversation);
        assert!(tokens > 4, "Should have content tokens plus overhead");
    }

    #[test]
    fn test_inference_profile_apply_none() {
        assert_eq!(
            InferenceProfile::None.apply_to("anthropic.claude-3"),
            "anthropic.claude-3"
        );
    }

    #[test]
    fn test_inference_profile_apply_us() {
        assert_eq!(
            InferenceProfile::US.apply_to("anthropic.claude-3"),
            "us.anthropic.claude-3"
        );
    }

    #[test]
    fn test_inference_profile_apply_eu() {
        assert_eq!(
            InferenceProfile::EU.apply_to("anthropic.claude-3"),
            "eu.anthropic.claude-3"
        );
    }

    #[test]
    fn test_inference_profile_apply_apac() {
        assert_eq!(InferenceProfile::APAC.apply_to("model-id"), "apac.model-id");
    }

    #[test]
    fn test_inference_profile_apply_global() {
        assert_eq!(
            InferenceProfile::Global.apply_to("model-id"),
            "global.model-id"
        );
    }

    #[test]
    fn test_inference_profile_all_variants() {
        // Exhaustive table of (profile, base id, expected full id).
        let cases = [
            (InferenceProfile::None, "model", "model"),
            (InferenceProfile::US, "model", "us.model"),
            (InferenceProfile::EU, "model", "eu.model"),
            (InferenceProfile::APAC, "model", "apac.model"),
            (InferenceProfile::Global, "model", "global.model"),
        ];
        for (profile, base, expected) in cases {
            assert_eq!(profile.apply_to(base), expected, "Failed for {:?}", profile);
        }
    }

    #[test]
    fn test_inference_profile_default() {
        assert_eq!(InferenceProfile::default(), InferenceProfile::None);
    }
}