use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use crate::AofResult;
/// Common interface implemented by every model backend.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait Model: Send + Sync {
    /// Runs `request` and resolves to the complete [`ModelResponse`].
    async fn generate(&self, request: &ModelRequest) -> AofResult<ModelResponse>;

    /// Runs `request` and resolves to a pinned, boxed, `Send` stream whose items are
    /// `AofResult<StreamChunk>` values.
    async fn generate_stream(
        &self,
        request: &ModelRequest,
    ) -> AofResult<Pin<Box<dyn futures::Stream<Item = AofResult<StreamChunk>> + Send>>>;

    /// Borrows the configuration held by this model.
    fn config(&self) -> &ModelConfig;

    /// Reports which provider this model belongs to.
    fn provider(&self) -> ModelProvider;

    /// Estimates how many tokens `text` contains.
    ///
    /// The provided implementation is a rough heuristic: the UTF-8 byte length of `text`
    /// divided by four, rounded down. Implementors may override it.
    fn count_tokens(&self, text: &str) -> usize {
        // Heuristic divisor used by the default estimate.
        const BYTES_PER_TOKEN: usize = 4;
        text.len() / BYTES_PER_TOKEN
    }
}
/// Identifies the provider behind a model.
///
/// Variants are (de)serialized as their lowercase names, e.g. `OpenAI` <-> `"openai"`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum ModelProvider {
    /// Serialized as `"anthropic"`.
    Anthropic,
    /// Serialized as `"openai"`.
    OpenAI,
    /// Serialized as `"google"`.
    Google,
    /// Serialized as `"groq"`.
    Groq,
    /// Serialized as `"bedrock"`.
    Bedrock,
    /// Serialized as `"azure"`.
    Azure,
    /// Serialized as `"ollama"`.
    Ollama,
    /// Serialized as `"custom"`.
    Custom,
}
/// Settings describing which model to use and how to reach it.
///
/// When deserializing, `temperature` falls back to `0.7`, `timeout_secs` to `60` and
/// `headers` to an empty map if they are absent. Keys that match no named field are
/// collected into `extra`.
///
/// NOTE(review): the derived `Debug` output includes `api_key` verbatim; take care when
/// logging values of this type.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ModelConfig {
    /// Model identifier, e.g. `"gpt-4"`.
    pub model: String,

    /// Provider this configuration targets.
    pub provider: ModelProvider,

    /// Optional API key; left out of serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub api_key: Option<String>,

    /// Optional endpoint; left out of serialized output when `None`.
    /// Presumably a base-URL override for the provider -- TODO confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub endpoint: Option<String>,

    /// Temperature setting; `0.7` when missing from the input.
    #[serde(default = "default_temperature")]
    pub temperature: f32,

    /// Optional token limit; left out of serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<usize>,

    /// Timeout (seconds, going by the field name); `60` when missing from the input.
    #[serde(default = "default_timeout")]
    pub timeout_secs: u64,

    /// Additional headers; an empty map when missing from the input.
    #[serde(default)]
    pub headers: HashMap<String, String>,

    /// Catch-all for any further keys, flattened into the same JSON object.
    #[serde(flatten)]
    pub extra: HashMap<String, serde_json::Value>,
}
/// Serde fallback used for `ModelConfig::temperature` when the field is absent.
fn default_temperature() -> f32 {
    const FALLBACK_TEMPERATURE: f32 = 0.7;
    FALLBACK_TEMPERATURE
}
/// Serde fallback used for `ModelConfig::timeout_secs` when the field is absent.
fn default_timeout() -> u64 {
    const FALLBACK_TIMEOUT_SECS: u64 = 60;
    FALLBACK_TIMEOUT_SECS
}
/// A request handed to [`Model::generate`] or [`Model::generate_stream`].
///
/// Optional fields are left out of serialized output when unset or empty; keys that match
/// no named field are collected into `extra`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ModelRequest {
    /// Messages that make up the request, in order.
    pub messages: Vec<RequestMessage>,

    /// Optional system text; left out of serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,

    /// Tool definitions; empty when missing from the input and left out of serialized
    /// output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,

    /// Optional temperature; left out of serialized output when `None`.
    /// Presumably overrides the configured temperature for this request -- TODO confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,

    /// Optional token limit; left out of serialized output when `None`.
    /// Presumably overrides the configured limit for this request -- TODO confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<usize>,

    /// Streaming flag; `false` when missing from the input.
    #[serde(default)]
    pub stream: bool,

    /// Catch-all for any further keys, flattened into the same JSON object.
    #[serde(flatten)]
    pub extra: HashMap<String, serde_json::Value>,
}
/// One message inside a [`ModelRequest`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RequestMessage {
    /// Role attached to this message.
    pub role: MessageRole,

    /// Text content of the message.
    pub content: String,

    /// Optional tool calls carried by the message; left out of serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<crate::ToolCall>>,

    /// Optional tool-call id; left out of serialized output when `None`.
    /// Presumably names the tool call this message answers -- TODO confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}
/// Role attached to a [`RequestMessage`].
///
/// Variants are (de)serialized as their lowercase names.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    /// Serialized as `"user"`.
    User,
    /// Serialized as `"assistant"`.
    Assistant,
    /// Serialized as `"system"`.
    System,
    /// Serialized as `"tool"`.
    Tool,
}
/// Description of a tool that can be listed in [`ModelRequest::tools`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Name of the tool.
    pub name: String,

    /// Free-text description of the tool.
    pub description: String,

    /// Arbitrary JSON value describing the tool's parameters.
    /// Presumably a JSON Schema object -- TODO confirm.
    pub parameters: serde_json::Value,
}
/// Complete result returned by [`Model::generate`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ModelResponse {
    /// Text content of the response.
    pub content: String,

    /// Tool calls carried by the response; empty when missing from the input and left out
    /// of serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<crate::ToolCall>,

    /// Reason reported for the end of generation.
    pub stop_reason: StopReason,

    /// Token counts reported for this call.
    pub usage: Usage,

    /// Catch-all for any further keys, flattened into the same JSON object.
    #[serde(flatten)]
    pub metadata: HashMap<String, serde_json::Value>,
}
/// Reason reported for the end of generation.
///
/// Variants are (de)serialized as their snake_case names.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// Serialized as `"end_turn"`.
    EndTurn,
    /// Serialized as `"max_tokens"`.
    MaxTokens,
    /// Serialized as `"stop_sequence"`.
    StopSequence,
    /// Serialized as `"tool_use"`.
    ToolUse,
    /// Serialized as `"content_filter"`.
    ContentFilter,
}
/// Token counts reported for a single model call.
///
/// `Default` yields zero for both counters. `PartialEq` and `Eq` are derived so two values
/// can be compared directly, in line with the other `Copy` types in this module.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct Usage {
    /// Number of input tokens.
    pub input_tokens: usize,

    /// Number of output tokens.
    pub output_tokens: usize,
}
/// One item yielded by the stream returned from [`Model::generate_stream`].
///
/// Serialized as an internally tagged object: the variant name, in snake_case, is stored
/// under the `"type"` key alongside the variant's own fields.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamChunk {
    /// A piece of text; tagged `"content_delta"`.
    ContentDelta {
        /// The text carried by this chunk.
        delta: String,
    },

    /// A tool call; tagged `"tool_call"`.
    ToolCall {
        /// The tool call carried by this chunk.
        tool_call: crate::ToolCall,
    },

    /// Usage and stop reason; tagged `"done"`.
    Done {
        /// Token counts carried by this chunk.
        usage: Usage,
        /// Stop reason carried by this chunk.
        stop_reason: StopReason,
    },
}
/// Shared, reference-counted handle to a dynamically dispatched [`Model`].
pub type ModelRef = Arc<dyn Model>;
#[cfg(test)]
mod tests {
    use super::*;

    /// Providers encode to their lowercase names and decode back from them.
    #[test]
    fn test_model_provider_serialization() {
        let encoded = serde_json::to_string(&ModelProvider::Anthropic).unwrap();
        assert_eq!(encoded, "\"anthropic\"");

        let decoded: ModelProvider = serde_json::from_str("\"openai\"").unwrap();
        assert_eq!(decoded, ModelProvider::OpenAI);
    }

    /// Settings missing from the input fall back to their serde defaults.
    #[test]
    fn test_model_config_defaults() {
        // Whitespace inside the raw literal is insignificant to the JSON parser.
        let input = r#"{
"model": "gpt-4",
"provider": "openai"
}"#;
        let parsed: ModelConfig = serde_json::from_str(input).unwrap();

        assert_eq!(parsed.model, "gpt-4");
        assert_eq!(parsed.provider, ModelProvider::OpenAI);
        assert_eq!(parsed.temperature, 0.7);
        assert_eq!(parsed.timeout_secs, 60);
        assert!(parsed.api_key.is_none());
    }

    /// A fully populated configuration survives a JSON round trip.
    #[test]
    fn test_model_config_full() {
        let custom_headers = HashMap::from([("X-Custom".to_string(), "value".to_string())]);
        let original = ModelConfig {
            model: "claude-3-5-sonnet".to_string(),
            provider: ModelProvider::Anthropic,
            api_key: Some("test_key".to_string()),
            endpoint: Some("https://api.anthropic.com".to_string()),
            temperature: 0.3,
            max_tokens: Some(4096),
            timeout_secs: 120,
            headers: custom_headers,
            extra: HashMap::new(),
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let round_tripped: ModelConfig = serde_json::from_str(&encoded).unwrap();

        assert_eq!(round_tripped.model, "claude-3-5-sonnet");
        assert_eq!(round_tripped.temperature, 0.3);
        assert_eq!(round_tripped.max_tokens, Some(4096));
    }

    /// Every role encodes to its lowercase name.
    #[test]
    fn test_message_role() {
        let expectations = [
            (MessageRole::User, "\"user\""),
            (MessageRole::Assistant, "\"assistant\""),
            (MessageRole::System, "\"system\""),
            (MessageRole::Tool, "\"tool\""),
        ];

        for (role, expected) in expectations {
            assert_eq!(serde_json::to_string(&role).unwrap(), expected);
        }
    }

    /// A request built by hand exposes the values it was constructed with.
    #[test]
    fn test_model_request() {
        let greeting = RequestMessage {
            role: MessageRole::User,
            content: "Hello".to_string(),
            tool_calls: None,
            tool_call_id: None,
        };
        let request = ModelRequest {
            messages: vec![greeting],
            system: Some("You are a helpful assistant.".to_string()),
            tools: vec![],
            temperature: Some(0.5),
            max_tokens: Some(1000),
            stream: false,
            extra: HashMap::new(),
        };

        assert_eq!(request.messages.len(), 1);
        assert_eq!(
            request.system.as_deref(),
            Some("You are a helpful assistant.")
        );
    }

    /// Stop reasons encode to snake_case names and decode back from them.
    #[test]
    fn test_stop_reason_serialization() {
        let encoded = serde_json::to_string(&StopReason::EndTurn).unwrap();
        assert_eq!(encoded, "\"end_turn\"");

        let decoded: StopReason = serde_json::from_str("\"max_tokens\"").unwrap();
        assert_eq!(decoded, StopReason::MaxTokens);
    }

    /// Explicit counters are stored as given; the default is all zeros.
    #[test]
    fn test_usage() {
        let explicit = Usage {
            input_tokens: 100,
            output_tokens: 50,
        };
        assert_eq!(explicit.input_tokens, 100);
        assert_eq!(explicit.output_tokens, 50);

        let zeroed = Usage::default();
        assert_eq!(zeroed.input_tokens, 0);
        assert_eq!(zeroed.output_tokens, 0);
    }

    /// A content delta encodes with its tag and its text.
    #[test]
    fn test_stream_chunk_content_delta() {
        let encoded = serde_json::to_string(&StreamChunk::ContentDelta {
            delta: "Hello".to_string(),
        })
        .unwrap();

        assert!(encoded.contains("content_delta"));
        assert!(encoded.contains("Hello"));
    }

    /// A final chunk encodes with its tag and its stop reason.
    #[test]
    fn test_stream_chunk_done() {
        let usage = Usage {
            input_tokens: 10,
            output_tokens: 20,
        };
        let encoded = serde_json::to_string(&StreamChunk::Done {
            usage,
            stop_reason: StopReason::EndTurn,
        })
        .unwrap();

        assert!(encoded.contains("done"));
        assert!(encoded.contains("end_turn"));
    }

    /// A response built by hand exposes the values it was constructed with.
    #[test]
    fn test_model_response() {
        let usage = Usage {
            input_tokens: 5,
            output_tokens: 3,
        };
        let response = ModelResponse {
            content: "Hello, world!".to_string(),
            tool_calls: vec![],
            stop_reason: StopReason::EndTurn,
            usage,
            metadata: HashMap::new(),
        };

        assert_eq!(response.content, "Hello, world!");
        assert!(response.tool_calls.is_empty());
        assert_eq!(response.stop_reason, StopReason::EndTurn);
    }
}