use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use super::traits::ProviderError;
/// Source of the credential string (API key/token) used to authenticate requests.
///
/// Implementors must be `Debug + Send + Sync` so they can be stored as
/// `Arc<dyn CredentialProvider>` (see `ModelConfig::credentials`) and shared across tasks.
#[async_trait::async_trait]
pub trait CredentialProvider: std::fmt::Debug + Send + Sync {
    /// Returns the credential to use for the next request.
    async fn current(&self) -> Result<String, ProviderError>;
    /// Discards any cached credential so a later `current` call can obtain a new one.
    ///
    /// Presumably invoked after an authentication failure — TODO confirm with callers.
    /// The default implementation does nothing and returns `Ok(())`, which suits
    /// providers with nothing to refresh.
    async fn invalidate(&self) -> Result<(), ProviderError> {
        Ok(())
    }
}
/// A `CredentialProvider` that always returns the same, fixed key.
#[derive(Clone)]
pub struct StaticCredentialProvider {
    key: String,
}

impl StaticCredentialProvider {
    /// Wraps `key` so it can be used wherever a credential provider is expected.
    pub fn new(key: impl Into<String>) -> Self {
        Self { key: key.into() }
    }
}

// Hand-written instead of derived so the secret never shows up in `{:?}` output
// (logs, panic messages, or the `Debug` of structs that embed this provider).
// An empty key is still shown as `""` to make "no key configured" diagnosable.
impl std::fmt::Debug for StaticCredentialProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let shown = if self.key.is_empty() { "" } else { "<redacted>" };
        f.debug_struct("StaticCredentialProvider")
            .field("key", &shown)
            .finish()
    }
}
#[async_trait::async_trait]
impl CredentialProvider for StaticCredentialProvider {
    /// Hands back a fresh copy of the fixed key; this never fails.
    ///
    /// `invalidate` keeps the trait's no-op default since there is nothing to refresh.
    async fn current(&self) -> Result<String, ProviderError> {
        let key = self.key.to_owned();
        Ok(key)
    }
}
/// Wire protocol spoken by a model endpoint.
///
/// The serialized (serde) name of every variant equals its `Display` output,
/// e.g. `anthropic_messages` or `openai_completions`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ApiProtocol {
    AnthropicMessages,
    // `rename_all = "snake_case"` alone inserts an underscore before every capital,
    // turning `OpenAi…` into `open_ai_…` and disagreeing with `Display` (and with the
    // `"openai"` provider tag). Pin the canonical names explicitly; the old
    // snake_case spellings stay accepted as aliases so existing configs still load.
    #[serde(rename = "openai_completions", alias = "open_ai_completions")]
    OpenAiCompletions,
    #[serde(rename = "openai_responses", alias = "open_ai_responses")]
    OpenAiResponses,
    #[serde(rename = "azure_openai_responses", alias = "azure_open_ai_responses")]
    AzureOpenAiResponses,
    GoogleGenerativeAi,
    GoogleVertex,
    BedrockConverseStream,
}
impl std::fmt::Display for ApiProtocol {
    /// Writes the protocol's snake_case identifier, e.g. `anthropic_messages`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::AnthropicMessages => "anthropic_messages",
            Self::OpenAiCompletions => "openai_completions",
            Self::OpenAiResponses => "openai_responses",
            Self::AzureOpenAiResponses => "azure_openai_responses",
            Self::GoogleGenerativeAi => "google_generative_ai",
            Self::GoogleVertex => "google_vertex",
            Self::BedrockConverseStream => "bedrock_converse_stream",
        };
        f.write_str(label)
    }
}
/// Token pricing for a model, expressed as a price per one million tokens.
///
/// `Default` (derived) yields all-zero prices, i.e. "free / unknown".
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CostConfig {
    /// Price per million input tokens.
    pub input_per_million: f64,
    /// Price per million output tokens.
    pub output_per_million: f64,
    /// Price per million cache-read tokens; 0.0 when omitted from config.
    #[serde(default)]
    pub cache_read_per_million: f64,
    /// Price per million cache-write tokens; 0.0 when omitted from config.
    #[serde(default)]
    pub cache_write_per_million: f64,
}
/// Which field name carries the output-token limit in requests — presumably the
/// JSON key sent to the endpoint; confirm with the request builder.
///
/// Serialized (snake_case) as `max_tokens` or `max_completion_tokens`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum MaxTokensField {
    /// Default choice.
    #[default]
    MaxTokens,
    MaxCompletionTokens,
}
/// Dialect for reasoning ("thinking") content on OpenAI-compatible endpoints —
/// presumably selects how reasoning is requested/parsed; the consumer is not
/// visible here, so confirm.
///
/// NOTE(review): `rename_all = "snake_case"` makes serde insert `_` before every
/// capital, so `OpenAi` is spelled `open_ai` and `OpenRouter` `open_router` in
/// config, unlike the `"openai"` / `"openrouter"` provider strings used by the
/// `ModelConfig` presets. Confirm this spelling is intended.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum ThinkingFormat {
    /// Default dialect.
    #[default]
    OpenAi,
    Xai,
    Qwen,
    OpenRouter,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAiCompat {
pub supports_store: bool,
pub supports_developer_role: bool,
pub supports_reasoning_effort: bool,
pub supports_usage_in_streaming: bool,
pub max_tokens_field: MaxTokensField,
pub requires_tool_result_name: bool,
pub requires_assistant_after_tool_result: bool,
pub thinking_format: ThinkingFormat,
}
impl Default for OpenAiCompat {
fn default() -> Self {
Self {
supports_store: false,
supports_developer_role: false,
supports_reasoning_effort: false,
supports_usage_in_streaming: true, max_tokens_field: MaxTokensField::MaxTokens,
requires_tool_result_name: false,
requires_assistant_after_tool_result: false,
thinking_format: ThinkingFormat::OpenAi,
}
}
}
impl OpenAiCompat {
    /// OpenAI preset: store, developer role and reasoning effort enabled, and the
    /// limit is sent as `max_completion_tokens`.
    pub fn openai() -> Self {
        Self {
            max_tokens_field: MaxTokensField::MaxCompletionTokens,
            supports_store: true,
            supports_developer_role: true,
            supports_reasoning_effort: true,
            supports_usage_in_streaming: true,
            ..Self::default()
        }
    }

    /// xAI preset: baseline flags with the xAI thinking format.
    pub fn xai() -> Self {
        Self {
            thinking_format: ThinkingFormat::Xai,
            supports_usage_in_streaming: true,
            ..Self::default()
        }
    }

    /// Groq preset: baseline flags with streaming usage explicitly on.
    pub fn groq() -> Self {
        Self {
            supports_usage_in_streaming: true,
            ..Self::default()
        }
    }

    /// Cerebras preset: identical to the baseline.
    pub fn cerebras() -> Self {
        Default::default()
    }

    /// OpenRouter preset: developer role enabled, `max_tokens` limit field, and
    /// the OpenRouter thinking format.
    pub fn openrouter() -> Self {
        Self {
            thinking_format: ThinkingFormat::OpenRouter,
            max_tokens_field: MaxTokensField::MaxTokens,
            supports_developer_role: true,
            supports_usage_in_streaming: true,
            ..Self::default()
        }
    }

    /// Mistral preset: baseline flags with the `max_tokens` limit field.
    pub fn mistral() -> Self {
        Self {
            max_tokens_field: MaxTokensField::MaxTokens,
            supports_usage_in_streaming: true,
            ..Self::default()
        }
    }

    /// DeepSeek preset: baseline flags with the `max_completion_tokens` limit field.
    pub fn deepseek() -> Self {
        Self {
            max_tokens_field: MaxTokensField::MaxCompletionTokens,
            supports_usage_in_streaming: true,
            ..Self::default()
        }
    }
}
/// Everything needed to address one model on one provider endpoint.
///
/// `Debug` is implemented by hand right after this struct so that `api_key`
/// is never printed.
#[derive(Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Model identifier.
    pub id: String,
    /// Human-readable display name.
    pub name: String,
    /// Wire protocol used to talk to the endpoint.
    pub api: ApiProtocol,
    /// Provider tag, e.g. `"anthropic"`, `"openai"`, `"local"`.
    pub provider: String,
    /// Endpoint base URL.
    pub base_url: String,
    /// Static API key; empty when omitted from config. Used only when
    /// `credentials` is `None` (see `resolve_api_key`).
    #[serde(default)]
    pub api_key: String,
    /// Reasoning flag — presumably "model supports reasoning output"; TODO confirm.
    pub reasoning: bool,
    /// Context window size, in tokens.
    pub context_window: u32,
    /// Output-token limit.
    pub max_tokens: u32,
    /// Pricing; all zeros when omitted from config.
    #[serde(default)]
    pub cost: CostConfig,
    /// Extra headers — presumably sent with each request; confirm with the client.
    #[serde(default)]
    pub headers: HashMap<String, String>,
    /// OpenAI-compatibility flags; `None` when omitted from config.
    #[serde(default)]
    pub compat: Option<OpenAiCompat>,
    /// Dynamic credential source that takes precedence over `api_key`; never serialized.
    #[serde(skip)]
    pub credentials: Option<Arc<dyn CredentialProvider>>,
}

// Hand-written instead of derived: the derived impl would print `api_key` in clear
// text into logs and panic messages. An empty key is still shown as `""` so a
// missing key stays diagnosable.
// NOTE(review): `headers` is printed as-is; if auth headers are ever placed there,
// consider redacting their values too.
impl std::fmt::Debug for ModelConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let api_key = if self.api_key.is_empty() { "" } else { "<redacted>" };
        f.debug_struct("ModelConfig")
            .field("id", &self.id)
            .field("name", &self.name)
            .field("api", &self.api)
            .field("provider", &self.provider)
            .field("base_url", &self.base_url)
            .field("api_key", &api_key)
            .field("reasoning", &self.reasoning)
            .field("context_window", &self.context_window)
            .field("max_tokens", &self.max_tokens)
            .field("cost", &self.cost)
            .field("headers", &self.headers)
            .field("compat", &self.compat)
            .field("credentials", &self.credentials)
            .finish()
    }
}
impl ModelConfig {
pub fn anthropic(
id: impl Into<String>, name: impl Into<String>, api_key: impl Into<String>, ) -> Self {
Self {
id: id.into(),
name: name.into(),
api: ApiProtocol::AnthropicMessages,
provider: "anthropic".into(),
base_url: "https://api.anthropic.com".into(),
api_key: api_key.into(),
reasoning: false,
context_window: 200_000,
max_tokens: 8192,
cost: CostConfig::default(),
headers: HashMap::new(),
compat: None, credentials: None,
}
}
pub fn openai(
id: impl Into<String>, name: impl Into<String>, api_key: impl Into<String>, ) -> Self {
Self {
id: id.into(),
name: name.into(),
api: ApiProtocol::OpenAiCompletions,
provider: "openai".into(),
base_url: "https://api.openai.com/v1".into(),
api_key: api_key.into(),
reasoning: false,
context_window: 128_000,
max_tokens: 4096,
cost: CostConfig::default(),
headers: HashMap::new(),
compat: Some(OpenAiCompat::openai()), credentials: None,
}
}
pub fn local(
base_url: impl Into<String>, model_id: impl Into<String>, api_key: impl Into<String>, ) -> Self {
Self {
id: model_id.into(),
name: "Local Model".into(),
api: ApiProtocol::OpenAiCompletions,
provider: "local".into(),
base_url: base_url.into(), api_key: api_key.into(),
reasoning: false,
context_window: 128_000,
max_tokens: 4096,
cost: CostConfig::default(),
headers: HashMap::new(),
compat: Some(OpenAiCompat::default()), credentials: None,
}
}
pub fn google(
id: impl Into<String>, name: impl Into<String>, api_key: impl Into<String>, ) -> Self {
Self {
id: id.into(),
name: name.into(),
api: ApiProtocol::GoogleGenerativeAi,
provider: "google".into(),
base_url: "https://generativelanguage.googleapis.com".into(),
api_key: api_key.into(),
reasoning: false,
context_window: 1_000_000,
max_tokens: 8192,
cost: CostConfig::default(),
headers: HashMap::new(),
compat: None, credentials: None,
}
}
pub fn openrouter(
model_id: impl Into<String>, api_key: impl Into<String>, ) -> Self {
let id = model_id.into();
Self {
name: id.clone(),
id,
api: ApiProtocol::OpenAiCompletions,
provider: "openrouter".into(),
base_url: "https://openrouter.ai/api/v1".into(),
api_key: api_key.into(),
reasoning: false,
context_window: 200_000, max_tokens: 4096,
cost: CostConfig::default(),
headers: HashMap::new(),
compat: Some(OpenAiCompat::openrouter()),
credentials: None,
}
}
pub fn with_credentials(mut self, creds: Arc<dyn CredentialProvider>) -> Self {
self.credentials = Some(creds);
self
}
pub async fn resolve_api_key(&self) -> Result<String, ProviderError> {
match &self.credentials {
Some(c) => c.current().await,
None => Ok(self.api_key.clone()),
}
}
pub async fn invalidate_credentials(&self) -> Result<(), ProviderError> {
match &self.credentials {
Some(c) => c.invalidate().await,
None => Ok(()),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_config_anthropic() {
        let cfg =
            ModelConfig::anthropic("claude-sonnet-4-20250514", "Claude Sonnet 4", "sk-ant-key");
        assert_eq!(cfg.api, ApiProtocol::AnthropicMessages);
        assert_eq!(cfg.provider, "anthropic");
        assert_eq!(cfg.api_key, "sk-ant-key");
        assert!(cfg.compat.is_none());
    }

    #[test]
    fn test_model_config_openai() {
        let cfg = ModelConfig::openai("gpt-4o", "GPT-4o", "sk-key");
        assert_eq!(cfg.api, ApiProtocol::OpenAiCompletions);
        let flags = cfg.compat.unwrap();
        assert!(flags.supports_store);
        assert!(flags.supports_developer_role);
        assert_eq!(flags.max_tokens_field, MaxTokensField::MaxCompletionTokens);
    }

    #[test]
    fn test_openai_compat_variants() {
        let xai = OpenAiCompat::xai();
        assert_eq!(xai.thinking_format, ThinkingFormat::Xai);
        assert!(!xai.supports_store);

        let groq = OpenAiCompat::groq();
        assert!(groq.supports_usage_in_streaming);
        assert!(!groq.supports_store);

        assert_eq!(
            OpenAiCompat::deepseek().max_tokens_field,
            MaxTokensField::MaxCompletionTokens
        );
    }

    #[test]
    fn test_api_protocol_display() {
        let cases = [
            (ApiProtocol::AnthropicMessages, "anthropic_messages"),
            (ApiProtocol::OpenAiCompletions, "openai_completions"),
            (ApiProtocol::GoogleGenerativeAi, "google_generative_ai"),
        ];
        for (protocol, expected) in cases {
            assert_eq!(protocol.to_string(), expected);
        }
    }

    #[test]
    fn test_cost_config_default() {
        let CostConfig {
            input_per_million,
            output_per_million,
            ..
        } = CostConfig::default();
        assert_eq!(input_per_million, 0.0);
        assert_eq!(output_per_million, 0.0);
    }
}