pub mod anthropic;
pub mod openai;
pub use anthropic::AnthropicProvider;
pub use openai::OpenAIProvider;
use crate::context::{Message, MessageRole};
use crate::error::AgentError;
use async_trait::async_trait;
use futures::Stream;
use serde::{Deserialize, Serialize};
use std::pin::Pin;
/// Generation settings shared by all LLM providers.
///
/// Built fluently via [`ProviderConfig::new`] and the `with_*` builder
/// methods. Optional sampling fields are omitted from serialized output
/// when unset (`skip_serializing_if`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderConfig {
    /// Model identifier sent to the provider (e.g. `"gpt-5"`).
    pub model: String,
    /// Sampling temperature; deserializes to `0.7` when the field is absent.
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    /// Maximum number of tokens to generate, if capped.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Nucleus-sampling probability mass, if set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Frequency penalty, if set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    /// Presence penalty, if set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
}
/// Serde fallback used when `temperature` is missing during deserialization;
/// also the initial value set by [`ProviderConfig::new`].
fn default_temperature() -> f32 {
    const FALLBACK: f32 = 0.7;
    FALLBACK
}
impl ProviderConfig {
    /// Creates a config for `model`; temperature starts at the serde
    /// default (0.7) and every optional sampling parameter starts unset.
    pub fn new(model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            temperature: default_temperature(),
            max_tokens: None,
            top_p: None,
            frequency_penalty: None,
            presence_penalty: None,
        }
    }

    /// Sets the sampling temperature, clamped to `[0.0, 2.0]`.
    pub fn with_temperature(self, temperature: f32) -> Self {
        Self {
            temperature: temperature.clamp(0.0, 2.0),
            ..self
        }
    }

    /// Caps the number of tokens the provider may generate.
    pub fn with_max_tokens(self, max_tokens: u32) -> Self {
        Self {
            max_tokens: Some(max_tokens),
            ..self
        }
    }

    /// Sets the nucleus-sampling parameter, clamped to `[0.0, 1.0]`.
    pub fn with_top_p(self, top_p: f32) -> Self {
        Self {
            top_p: Some(top_p.clamp(0.0, 1.0)),
            ..self
        }
    }

    /// Sets the frequency penalty, clamped to `[-2.0, 2.0]`.
    pub fn with_frequency_penalty(self, penalty: f32) -> Self {
        Self {
            frequency_penalty: Some(penalty.clamp(-2.0, 2.0)),
            ..self
        }
    }

    /// Sets the presence penalty, clamped to `[-2.0, 2.0]`.
    pub fn with_presence_penalty(self, penalty: f32) -> Self {
        Self {
            presence_penalty: Some(penalty.clamp(-2.0, 2.0)),
            ..self
        }
    }
}
/// One item yielded by [`LLMProvider::stream`]: a text fragment on success,
/// or the error that terminated the stream.
pub type StreamChunk = std::result::Result<String, AgentError>;
/// Common interface implemented by every LLM backend
/// (e.g. [`AnthropicProvider`], [`OpenAIProvider`]).
#[async_trait]
pub trait LLMProvider: Send + Sync {
    /// Sends `messages` to the model and returns the full response text.
    async fn complete(&self, messages: Vec<Message>) -> std::result::Result<String, AgentError>;
    /// Sends `messages` to the model and returns a stream of incremental
    /// text chunks (each chunk may itself be an error — see [`StreamChunk`]).
    async fn stream(
        &self,
        messages: Vec<Message>,
    ) -> std::result::Result<Pin<Box<dyn Stream<Item = StreamChunk> + Send>>, AgentError>;
    /// Short identifier for this provider implementation.
    fn name(&self) -> &str;
    /// The configuration this provider is using.
    fn config(&self) -> &ProviderConfig;
}
/// Flattens `messages` into owned `(role, content)` pairs, preserving order.
pub fn messages_to_provider_format(messages: &[Message]) -> Vec<(MessageRole, String)> {
    let mut pairs = Vec::with_capacity(messages.len());
    for message in messages {
        pairs.push((message.role, message.content.clone()));
    }
    pairs
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_config_creation() {
        let cfg = ProviderConfig::new("gpt-5");
        assert_eq!(cfg.model, "gpt-5");
        assert_eq!(cfg.temperature, 0.7);
        assert!(cfg.max_tokens.is_none());
    }

    #[test]
    fn test_provider_config_with_temperature() {
        let cfg = ProviderConfig::new("gpt-5").with_temperature(0.5);
        assert_eq!(cfg.temperature, 0.5);
    }

    #[test]
    fn test_provider_config_temperature_clamping() {
        // Out-of-range values are clamped to the documented [0.0, 2.0] range.
        let below = ProviderConfig::new("gpt-5").with_temperature(-0.5);
        let above = ProviderConfig::new("gpt-5").with_temperature(3.0);
        assert_eq!(below.temperature, 0.0);
        assert_eq!(above.temperature, 2.0);
    }

    #[test]
    fn test_provider_config_with_max_tokens() {
        let cfg = ProviderConfig::new("gpt-5").with_max_tokens(1000);
        assert_eq!(cfg.max_tokens, Some(1000));
    }

    #[test]
    fn test_provider_config_with_top_p() {
        let cfg = ProviderConfig::new("gpt-5").with_top_p(0.9);
        assert_eq!(cfg.top_p, Some(0.9));
    }

    #[test]
    fn test_provider_config_top_p_clamping() {
        // Out-of-range values are clamped to the documented [0.0, 1.0] range.
        let below = ProviderConfig::new("gpt-5").with_top_p(-0.1);
        let above = ProviderConfig::new("gpt-5").with_top_p(1.5);
        assert_eq!(below.top_p, Some(0.0));
        assert_eq!(above.top_p, Some(1.0));
    }

    #[test]
    fn test_provider_config_with_penalties() {
        let cfg = ProviderConfig::new("gpt-5")
            .with_frequency_penalty(0.5)
            .with_presence_penalty(0.3);
        assert_eq!(cfg.frequency_penalty, Some(0.5));
        assert_eq!(cfg.presence_penalty, Some(0.3));
    }

    #[test]
    fn test_provider_config_penalty_clamping() {
        // Both penalties are clamped to the documented [-2.0, 2.0] range.
        let cfg = ProviderConfig::new("gpt-5")
            .with_frequency_penalty(-3.0)
            .with_presence_penalty(3.0);
        assert_eq!(cfg.frequency_penalty, Some(-2.0));
        assert_eq!(cfg.presence_penalty, Some(2.0));
    }

    #[test]
    fn test_provider_config_serialization() {
        // JSON round-trip must preserve the explicitly-set fields.
        let original = ProviderConfig::new("gpt-5")
            .with_temperature(0.8)
            .with_max_tokens(500);
        let json = serde_json::to_string(&original).unwrap();
        let round_tripped: ProviderConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(original.model, round_tripped.model);
        assert_eq!(original.temperature, round_tripped.temperature);
        assert_eq!(original.max_tokens, round_tripped.max_tokens);
    }

    #[test]
    fn test_messages_to_provider_format() {
        let history = vec![
            Message::system("You are a helpful assistant"),
            Message::user("Hello"),
            Message::assistant("Hi there!"),
        ];
        let converted = messages_to_provider_format(&history);
        assert_eq!(converted.len(), 3);
        assert_eq!(
            converted[0],
            (MessageRole::System, "You are a helpful assistant".to_string())
        );
        assert_eq!(converted[1], (MessageRole::User, "Hello".to_string()));
        assert_eq!(
            converted[2],
            (MessageRole::Assistant, "Hi there!".to_string())
        );
    }
}