use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ModelProvider {
OpenAI,
Anthropic,
Groq,
Together,
Fireworks,
Local,
}
impl std::fmt::Display for ModelProvider {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ModelProvider::OpenAI => write!(f, "openai"),
ModelProvider::Anthropic => write!(f, "anthropic"),
ModelProvider::Groq => write!(f, "groq"),
ModelProvider::Together => write!(f, "together"),
ModelProvider::Fireworks => write!(f, "fireworks"),
ModelProvider::Local => write!(f, "local"),
}
}
}
impl std::str::FromStr for ModelProvider {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"openai" => Ok(ModelProvider::OpenAI),
"anthropic" => Ok(ModelProvider::Anthropic),
"groq" => Ok(ModelProvider::Groq),
"together" => Ok(ModelProvider::Together),
"fireworks" => Ok(ModelProvider::Fireworks),
"local" => Ok(ModelProvider::Local),
_ => Err(format!("Unknown provider: {}", s)),
}
}
}
impl ModelProvider {
pub fn base_url(&self) -> &'static str {
match self {
ModelProvider::OpenAI => "https://api.openai.com/v1",
ModelProvider::Anthropic => "https://api.anthropic.com/v1",
ModelProvider::Groq => "https://api.groq.com/openai/v1",
ModelProvider::Together => "https://api.together.xyz/v1",
ModelProvider::Fireworks => "https://api.fireworks.ai/inference/v1",
ModelProvider::Local => "",
}
}
pub fn api_key_env_var(&self) -> &'static str {
match self {
ModelProvider::OpenAI => "OPENAI_API_KEY",
ModelProvider::Anthropic => "ANTHROPIC_API_KEY",
ModelProvider::Groq => "GROQ_API_KEY",
ModelProvider::Together => "TOGETHER_API_KEY",
ModelProvider::Fireworks => "FIREWORKS_API_KEY",
ModelProvider::Local => "",
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
pub id: String,
pub name: String,
pub provider: ModelProvider,
pub context_window: u32,
pub max_output_tokens: u32,
pub input_cost_per_million: f64,
pub output_cost_per_million: f64,
pub supports_tools: bool,
pub supports_vision: bool,
pub supports_json_mode: bool,
}
impl ModelInfo {
pub fn for_model(model_id: &str) -> Option<Self> {
match model_id {
"gpt-4o" => Some(Self {
id: "gpt-4o".to_string(),
name: "GPT-4o".to_string(),
provider: ModelProvider::OpenAI,
context_window: 128000,
max_output_tokens: 4096,
input_cost_per_million: 5.0,
output_cost_per_million: 15.0,
supports_tools: true,
supports_vision: true,
supports_json_mode: true,
}),
"gpt-4o-mini" => Some(Self {
id: "gpt-4o-mini".to_string(),
name: "GPT-4o Mini".to_string(),
provider: ModelProvider::OpenAI,
context_window: 128000,
max_output_tokens: 16384,
input_cost_per_million: 0.15,
output_cost_per_million: 0.6,
supports_tools: true,
supports_vision: true,
supports_json_mode: true,
}),
"gpt-4-turbo" => Some(Self {
id: "gpt-4-turbo".to_string(),
name: "GPT-4 Turbo".to_string(),
provider: ModelProvider::OpenAI,
context_window: 128000,
max_output_tokens: 4096,
input_cost_per_million: 10.0,
output_cost_per_million: 30.0,
supports_tools: true,
supports_vision: true,
supports_json_mode: true,
}),
"claude-3-5-sonnet-20241022" | "claude-3-5-sonnet" => Some(Self {
id: "claude-3-5-sonnet-20241022".to_string(),
name: "Claude 3.5 Sonnet".to_string(),
provider: ModelProvider::Anthropic,
context_window: 200000,
max_output_tokens: 8192,
input_cost_per_million: 3.0,
output_cost_per_million: 15.0,
supports_tools: true,
supports_vision: true,
supports_json_mode: false,
}),
"claude-3-opus-20240229" | "claude-3-opus" => Some(Self {
id: "claude-3-opus-20240229".to_string(),
name: "Claude 3 Opus".to_string(),
provider: ModelProvider::Anthropic,
context_window: 200000,
max_output_tokens: 4096,
input_cost_per_million: 15.0,
output_cost_per_million: 75.0,
supports_tools: true,
supports_vision: true,
supports_json_mode: false,
}),
"claude-3-haiku-20240307" | "claude-3-haiku" => Some(Self {
id: "claude-3-haiku-20240307".to_string(),
name: "Claude 3 Haiku".to_string(),
provider: ModelProvider::Anthropic,
context_window: 200000,
max_output_tokens: 4096,
input_cost_per_million: 0.25,
output_cost_per_million: 1.25,
supports_tools: true,
supports_vision: true,
supports_json_mode: false,
}),
"llama-3.1-70b-versatile" => Some(Self {
id: "llama-3.1-70b-versatile".to_string(),
name: "Llama 3.1 70B".to_string(),
provider: ModelProvider::Groq,
context_window: 131072,
max_output_tokens: 8192,
input_cost_per_million: 0.59,
output_cost_per_million: 0.79,
supports_tools: true,
supports_vision: false,
supports_json_mode: true,
}),
"llama-3.1-8b-instant" => Some(Self {
id: "llama-3.1-8b-instant".to_string(),
name: "Llama 3.1 8B".to_string(),
provider: ModelProvider::Groq,
context_window: 131072,
max_output_tokens: 8192,
input_cost_per_million: 0.05,
output_cost_per_million: 0.08,
supports_tools: true,
supports_vision: false,
supports_json_mode: true,
}),
"mixtral-8x7b-32768" => Some(Self {
id: "mixtral-8x7b-32768".to_string(),
name: "Mixtral 8x7B".to_string(),
provider: ModelProvider::Groq,
context_window: 32768,
max_output_tokens: 8192,
input_cost_per_million: 0.24,
output_cost_per_million: 0.24,
supports_tools: true,
supports_vision: false,
supports_json_mode: true,
}),
_ => None,
}
}
pub fn estimate_cost(&self, prompt_tokens: u32, completion_tokens: u32) -> f64 {
let input_cost = (prompt_tokens as f64 / 1_000_000.0) * self.input_cost_per_million;
let output_cost = (completion_tokens as f64 / 1_000_000.0) * self.output_cost_per_million;
input_cost + output_cost
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelList {
pub object: String,
pub data: Vec<ModelListEntry>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelListEntry {
pub id: String,
pub object: String,
pub created: i64,
pub owned_by: String,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_provider_from_str() {
assert_eq!(
"openai".parse::<ModelProvider>().unwrap(),
ModelProvider::OpenAI
);
assert_eq!(
"anthropic".parse::<ModelProvider>().unwrap(),
ModelProvider::Anthropic
);
assert_eq!(
"groq".parse::<ModelProvider>().unwrap(),
ModelProvider::Groq
);
}
#[test]
fn test_model_info() {
let info = ModelInfo::for_model("gpt-4o-mini").unwrap();
assert_eq!(info.provider, ModelProvider::OpenAI);
assert!(info.supports_tools);
assert!(info.supports_vision);
}
#[test]
fn test_cost_estimation() {
let info = ModelInfo::for_model("gpt-4o-mini").unwrap();
let cost = info.estimate_cost(1000, 100);
assert!((cost - 0.00021).abs() < 0.00001);
}
}