use anyhow::{Context, Result};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// One turn of a conversation exchanged with a chat model.
///
/// Values are serialized unchanged into provider request bodies, so the field
/// names double as the JSON keys (`role`, `content`) seen by the remote API.
/// `PartialEq`/`Eq` are derived so conversations can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Message {
    /// Speaker of this turn. The value `"system"` is given special treatment
    /// by the Anthropic provider, which lifts such turns out of the message list.
    pub role: String,
    /// Text of this turn.
    pub content: String,
}
/// Per-call overrides for a chat request.
///
/// Every field is optional; a field left as `None` makes the provider fall
/// back to the value it was constructed with. `Default` yields "override
/// nothing", which enables `ChatOptions { temperature: Some(0.2), ..Default::default() }`.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct ChatOptions {
    /// Sampling temperature for this call.
    pub temperature: Option<f32>,
    /// Upper bound on the number of generated tokens for this call.
    pub max_tokens: Option<u32>,
    /// Streaming flag. None of the providers defined alongside this type read it.
    pub stream: Option<bool>,
}
/// Outcome of a single, non-streaming chat call.
#[derive(Debug, Clone)]
pub struct ChatResponse {
/// Reply text extracted from the provider's JSON response.
pub content: String,
/// Token accounting for the call; `None` when the provider implementation
/// does not report it.
pub usage: Option<Usage>,
}
/// Token accounting reported for one chat call.
///
/// `Default` is the all-zero record; `PartialEq`/`Eq` allow direct comparison
/// of usage records (for example in tests or when aggregating).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Usage {
    /// Tokens consumed by the request (the prompt / input side).
    pub prompt_tokens: u32,
    /// Tokens produced in the reply (the completion / output side).
    pub completion_tokens: u32,
    /// Combined token count for the call.
    pub total_tokens: u32,
}
/// Common interface implemented by every chat backend in this module.
///
/// The `Send + Sync` supertraits let a boxed provider be shared between
/// threads and tasks.
#[async_trait]
pub trait LLMProvider: Send + Sync {
    /// Sends `messages` to the backend and resolves to the complete reply.
    ///
    /// `options` carries optional per-call overrides; how missing values are
    /// filled in is decided by the implementor.
    async fn chat(
        &self,
        messages: &[Message],
        options: Option<ChatOptions>,
    ) -> Result<ChatResponse>;

    /// Streaming counterpart of [`LLMProvider::chat`].
    ///
    /// The provided implementation ignores both arguments and always returns
    /// an error; backends that can stream are expected to override it.
    async fn chat_stream(
        &self,
        _messages: &[Message],
        _options: Option<ChatOptions>,
    ) -> Result<()> {
        Err(anyhow::anyhow!("Streaming not supported by this provider"))
    }

    /// Short identifier of the backend.
    fn get_provider_name(&self) -> &str;

    /// Name of the model this provider instance talks to.
    fn get_model_name(&self) -> &str;
}
/// Builds the provider selected by `provider` and returns it as a trait object.
///
/// The name is matched case-insensitively against `"openai"`, `"anthropic"`,
/// `"google"` / `"gemini"` and `"ollama"`. Not every argument is used by every
/// backend: `base_url` is only forwarded to OpenAI and Ollama, and `api_key`
/// is not forwarded to Ollama.
///
/// # Errors
///
/// Returns an error naming `provider` (as given by the caller) when it matches
/// none of the supported backends.
pub async fn create_llm_provider(
    provider: &str,
    api_key: String,
    model: String,
    base_url: Option<String>,
    temperature: Option<f32>,
    max_tokens: Option<u32>,
) -> Result<Box<dyn LLMProvider>> {
    let kind = provider.to_lowercase();

    let backend: Box<dyn LLMProvider> = match kind.as_str() {
        "openai" => {
            let openai = OpenAIProvider::new(api_key, model, base_url, temperature, max_tokens);
            Box::new(openai)
        }
        "anthropic" => {
            let anthropic = AnthropicProvider::new(api_key, model, temperature, max_tokens);
            Box::new(anthropic)
        }
        "google" | "gemini" => {
            let gemini = GeminiProvider::new(api_key, model, temperature, max_tokens);
            Box::new(gemini)
        }
        "ollama" => {
            let ollama = OllamaProvider::new(model, base_url, temperature, max_tokens);
            Box::new(ollama)
        }
        _ => return Err(anyhow::anyhow!("Unsupported LLM provider: {}", provider)),
    };

    Ok(backend)
}
/// Chat provider that talks to an OpenAI-style `/chat/completions` HTTP endpoint.
// `api_key` is stored but not read again after construction (the constructor
// installs it as a default header on `client`), hence the lint suppression.
#[allow(dead_code)]
pub struct OpenAIProvider {
/// API key supplied at construction time.
api_key: String,
/// Model identifier sent as the `model` field of every request.
model: String,
/// Endpoint prefix; requests are posted to `{base_url}/chat/completions`.
base_url: String,
/// Sampling temperature used when a call does not supply its own.
temperature: f32,
/// Completion token limit used when a call does not supply its own.
max_tokens: u32,
/// HTTP client carrying the default `Authorization` and `Content-Type` headers.
client: reqwest::Client,
}
impl OpenAIProvider {
    /// Creates a provider for an OpenAI-compatible endpoint.
    ///
    /// * `api_key` - sent as `Authorization: Bearer <key>` on every request.
    ///   Surrounding whitespace is ignored.
    /// * `model` - model identifier placed in each request body.
    /// * `base_url` - endpoint prefix; trailing slashes are removed. `None` or
    ///   an empty value selects `https://api.openai.com/v1`.
    /// * `temperature` - default sampling temperature (`0.7` when `None`).
    /// * `max_tokens` - default completion limit (`1000` when `None`).
    ///
    /// # Panics
    ///
    /// Panics if `api_key` still contains bytes that are illegal in an HTTP
    /// header value after trimming, or if the HTTP client cannot be built.
    pub fn new(
        api_key: String,
        model: String,
        base_url: Option<String>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
    ) -> Self {
        // Keys loaded from files or environment variables frequently carry a
        // trailing newline, which is not a legal header byte; trim it away
        // instead of panicking on it.
        let mut authorization =
            reqwest::header::HeaderValue::from_str(&format!("Bearer {}", api_key.trim()))
                .expect("OpenAI API key contains characters that are not valid in an HTTP header");
        // Keep the credential out of `Debug` renderings of the header map.
        authorization.set_sensitive(true);

        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert(reqwest::header::AUTHORIZATION, authorization);
        headers.insert(
            reqwest::header::CONTENT_TYPE,
            reqwest::header::HeaderValue::from_static("application/json"),
        );

        let client = reqwest::Client::builder()
            .default_headers(headers)
            .build()
            .expect("failed to build the HTTP client for the OpenAI provider");

        // Requests are sent to `{base_url}/chat/completions`; dropping trailing
        // slashes avoids a `//` in the final URL, and an empty value falls back
        // to the public endpoint rather than producing a relative URL.
        let base_url = base_url
            .map(|url| url.trim_end_matches('/').to_string())
            .filter(|url| !url.is_empty())
            .unwrap_or_else(|| "https://api.openai.com/v1".to_string());

        Self {
            api_key,
            model,
            base_url,
            temperature: temperature.unwrap_or(0.7),
            max_tokens: max_tokens.unwrap_or(1000),
            client,
        }
    }
}
#[async_trait]
impl LLMProvider for OpenAIProvider {
    /// Posts `messages` to `{base_url}/chat/completions` and returns the text
    /// of the first choice.
    ///
    /// Values missing from `options` (or the whole argument being `None`) fall
    /// back to the temperature and token limit the provider was built with.
    /// `options.stream` is not consulted.
    ///
    /// # Errors
    ///
    /// Fails when the request cannot be sent, the server answers with a
    /// non-success status, the body is not JSON, or the JSON lacks
    /// `choices[0].message.content` as a string. The underlying
    /// `reqwest::Error` stays reachable through the error chain.
    async fn chat(
        &self,
        messages: &[Message],
        options: Option<ChatOptions>,
    ) -> Result<ChatResponse> {
        let temperature = options
            .as_ref()
            .and_then(|o| o.temperature)
            .unwrap_or(self.temperature);
        let max_tokens = options
            .as_ref()
            .and_then(|o| o.max_tokens)
            .unwrap_or(self.max_tokens);

        let payload = serde_json::json!({
            "model": self.model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        });

        let response = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .json(&payload)
            .send()
            .await
            .context("Failed to send chat request to OpenAI")?
            .error_for_status()
            .context("OpenAI chat request was rejected")?;

        let json: Value = response
            .json()
            .await
            .context("Failed to decode OpenAI response body as JSON")?;

        let content = json["choices"][0]["message"]["content"]
            .as_str()
            .context("Invalid OpenAI response format")?
            .to_string();

        // Missing counters read as 0; oversized ones saturate instead of
        // wrapping the way a bare `as u32` cast would.
        let count = |field: &Value| field.as_u64().unwrap_or(0).min(u64::from(u32::MAX)) as u32;

        // A `null` (or otherwise non-object) `usage` means "not reported" and
        // must not be turned into a record of zeros.
        let usage = json
            .get("usage")
            .filter(|u| u.is_object())
            .map(|u| Usage {
                prompt_tokens: count(&u["prompt_tokens"]),
                completion_tokens: count(&u["completion_tokens"]),
                total_tokens: count(&u["total_tokens"]),
            });

        Ok(ChatResponse { content, usage })
    }

    /// Always `"openai"`.
    fn get_provider_name(&self) -> &str {
        "openai"
    }

    /// The model identifier given at construction time.
    fn get_model_name(&self) -> &str {
        &self.model
    }
}
/// Chat provider that talks to the Anthropic Messages HTTP API.
// `api_key` is stored but not read again after construction (the constructor
// installs it as a default header on `client`), hence the lint suppression.
#[allow(dead_code)]
pub struct AnthropicProvider {
/// API key supplied at construction time.
api_key: String,
/// Model identifier sent as the `model` field of every request.
model: String,
/// Sampling temperature used when a call does not supply its own.
temperature: f32,
/// Output token limit used when a call does not supply its own.
max_tokens: u32,
/// HTTP client carrying the default `x-api-key`, `anthropic-version` and
/// `Content-Type` headers.
client: reqwest::Client,
}
impl AnthropicProvider {
    /// Creates a provider for the Anthropic Messages API.
    ///
    /// * `api_key` - sent as the `x-api-key` header on every request.
    ///   Surrounding whitespace is ignored.
    /// * `model` - model identifier placed in each request body.
    /// * `temperature` - default sampling temperature (`0.7` when `None`).
    /// * `max_tokens` - default output limit (`1000` when `None`).
    ///
    /// # Panics
    ///
    /// Panics if `api_key` still contains bytes that are illegal in an HTTP
    /// header value after trimming, or if the HTTP client cannot be built.
    pub fn new(
        api_key: String,
        model: String,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
    ) -> Self {
        // Keys loaded from files or environment variables frequently carry a
        // trailing newline, which is not a legal header byte; trim it away
        // instead of panicking on it.
        let mut key_header = reqwest::header::HeaderValue::from_str(api_key.trim())
            .expect("Anthropic API key contains characters that are not valid in an HTTP header");
        // Keep the credential out of `Debug` renderings of the header map.
        key_header.set_sensitive(true);

        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert("x-api-key", key_header);
        headers.insert(
            "anthropic-version",
            reqwest::header::HeaderValue::from_static("2023-06-01"),
        );
        headers.insert(
            reqwest::header::CONTENT_TYPE,
            reqwest::header::HeaderValue::from_static("application/json"),
        );

        let client = reqwest::Client::builder()
            .default_headers(headers)
            .build()
            .expect("failed to build the HTTP client for the Anthropic provider");

        Self {
            api_key,
            model,
            temperature: temperature.unwrap_or(0.7),
            max_tokens: max_tokens.unwrap_or(1000),
            client,
        }
    }
}
#[async_trait]
impl LLMProvider for AnthropicProvider {
async fn chat(
&self,
messages: &[Message],
options: Option<ChatOptions>,
) -> Result<ChatResponse> {
let opts = options.unwrap_or(ChatOptions {
temperature: Some(self.temperature),
max_tokens: Some(self.max_tokens),
stream: None,
});
let (system_msg, chat_messages): (Option<&Message>, Vec<&Message>) = {
let sys = messages.iter().find(|m| m.role == "system");
let rest: Vec<&Message> = messages.iter().filter(|m| m.role != "system").collect();
(sys, rest)
};
let mut payload = serde_json::json!({
"model": self.model,
"messages": chat_messages,
"temperature": opts.temperature.unwrap_or(self.temperature),
"max_tokens": opts.max_tokens.unwrap_or(self.max_tokens),
});
if let Some(sys) = system_msg {
payload["system"] = serde_json::json!(sys.content);
}
let response = self
.client
.post("https://api.anthropic.com/v1/messages")
.json(&payload)
.send()
.await?
.error_for_status()?;
let json: Value = response.json().await?;
let content = json["content"][0]["text"]
.as_str()
.context("Invalid Anthropic response format")?
.to_string();
let usage = json.get("usage").map(|u| Usage {
prompt_tokens: u["input_tokens"].as_u64().unwrap_or(0) as u32,
completion_tokens: u["output_tokens"].as_u64().unwrap_or(0) as u32,
total_tokens: (u["input_tokens"].as_u64().unwrap_or(0)
+ u["output_tokens"].as_u64().unwrap_or(0)) as u32,
});
Ok(ChatResponse { content, usage })
}
fn get_provider_name(&self) -> &str {
"anthropic"
}
fn get_model_name(&self) -> &str {
&self.model
}
}
/// Chat provider that talks to the Google Generative Language
/// (`generateContent`) HTTP API.
pub struct GeminiProvider {
/// API key attached to every request.
api_key: String,
/// Model identifier; it is interpolated into the request URL.
model: String,
/// Sampling temperature used when a call does not supply its own.
temperature: f32,
/// Output token limit used when a call does not supply its own.
max_tokens: u32,
/// HTTP client used for all requests (no default headers configured).
client: reqwest::Client,
}
impl GeminiProvider {
    /// Creates a provider for the Gemini `generateContent` API.
    ///
    /// * `api_key` - credential attached to every request.
    /// * `model` - model identifier used in the request URL.
    /// * `temperature` - default sampling temperature (`0.7` when `None`).
    /// * `max_tokens` - default output limit (`1000` when `None`).
    pub fn new(
        api_key: String,
        model: String,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
    ) -> Self {
        let temperature = temperature.unwrap_or(0.7);
        let max_tokens = max_tokens.unwrap_or(1000);
        let client = reqwest::Client::new();

        Self {
            api_key,
            model,
            temperature,
            max_tokens,
            client,
        }
    }
}
#[async_trait]
impl LLMProvider for GeminiProvider {
async fn chat(
&self,
messages: &[Message],
options: Option<ChatOptions>,
) -> Result<ChatResponse> {
let opts = options.unwrap_or(ChatOptions {
temperature: Some(self.temperature),
max_tokens: Some(self.max_tokens),
stream: None,
});
let text = messages
.iter()
.map(|m| format!("{}: {}", m.role, m.content))
.collect::<Vec<_>>()
.join("\n\n");
let payload = serde_json::json!({
"contents": [{
"parts": [{
"text": text
}]
}],
"generationConfig": {
"temperature": opts.temperature.unwrap_or(self.temperature),
"maxOutputTokens": opts.max_tokens.unwrap_or(self.max_tokens),
}
});
let url = format!(
"https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
self.model, self.api_key
);
let response = self
.client
.post(&url)
.json(&payload)
.send()
.await?
.error_for_status()?;
let json: Value = response.json().await?;
let content = json["candidates"][0]["content"]["parts"][0]["text"]
.as_str()
.context("Invalid Gemini response format")?
.to_string();
Ok(ChatResponse {
content,
usage: None,
})
}
fn get_provider_name(&self) -> &str {
"google"
}
fn get_model_name(&self) -> &str {
&self.model
}
}
/// Chat provider that talks to an Ollama server's `/api/chat` HTTP endpoint.
pub struct OllamaProvider {
/// Model identifier sent as the `model` field of every request.
model: String,
/// Server address; requests are posted to `{base_url}/api/chat`.
base_url: String,
/// Sampling temperature used when a call does not supply its own.
temperature: f32,
/// Generation limit (sent as `num_predict`) used when a call does not supply its own.
max_tokens: u32,
/// HTTP client used for all requests (no default headers configured).
client: reqwest::Client,
}
impl OllamaProvider {
    /// Creates a provider for an Ollama server.
    ///
    /// * `model` - model identifier placed in each request body.
    /// * `base_url` - server address; trailing slashes are removed. `None` or
    ///   an empty value selects `http://localhost:11434`.
    /// * `temperature` - default sampling temperature (`0.7` when `None`).
    /// * `max_tokens` - default generation limit (`1000` when `None`).
    pub fn new(
        model: String,
        base_url: Option<String>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
    ) -> Self {
        // Requests are sent to `{base_url}/api/chat`; dropping trailing slashes
        // avoids a `//` in the final URL, and an empty value falls back to the
        // local default rather than producing a relative URL.
        let base_url = base_url
            .map(|url| url.trim_end_matches('/').to_string())
            .filter(|url| !url.is_empty())
            .unwrap_or_else(|| "http://localhost:11434".to_string());

        Self {
            model,
            base_url,
            temperature: temperature.unwrap_or(0.7),
            max_tokens: max_tokens.unwrap_or(1000),
            client: reqwest::Client::new(),
        }
    }
}
#[async_trait]
impl LLMProvider for OllamaProvider {
async fn chat(
&self,
messages: &[Message],
options: Option<ChatOptions>,
) -> Result<ChatResponse> {
let opts = options.unwrap_or(ChatOptions {
temperature: Some(self.temperature),
max_tokens: Some(self.max_tokens),
stream: None,
});
let payload = serde_json::json!({
"model": self.model,
"messages": messages,
"options": {
"temperature": opts.temperature.unwrap_or(self.temperature),
"num_predict": opts.max_tokens.unwrap_or(self.max_tokens),
},
"stream": false,
});
let response = self
.client
.post(format!("{}/api/chat", self.base_url))
.json(&payload)
.send()
.await?
.error_for_status()?;
let json: Value = response.json().await?;
let content = json["message"]["content"]
.as_str()
.context("Invalid Ollama response format")?
.to_string();
Ok(ChatResponse {
content,
usage: None,
})
}
fn get_provider_name(&self) -> &str {
"ollama"
}
fn get_model_name(&self) -> &str {
&self.model
}
}