use crate::{RsllmError, RsllmResult, ChatMessage, ChatResponse, StreamChunk};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::fmt;
use std::str::FromStr;
use url::Url;
/// The LLM backends this module knows how to address.
///
/// Converted to and from lowercase names by the `Display` and `FromStr` impls in this
/// module; serde (de)serializes it using the variant names as written.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Provider {
    /// OpenAI's API; default base URL `https://api.openai.com/v1`.
    OpenAI,
    /// Anthropic's Claude API; default base URL `https://api.anthropic.com/v1`.
    Claude,
    /// A local Ollama server; default base URL `http://localhost:11434/api`, no API key required.
    Ollama,
}
impl Provider {
    /// Base URL of this provider's API, used when no override is configured.
    ///
    /// # Panics
    /// Not in practice: every arm is a hard-coded, well-formed URL.
    pub fn default_base_url(&self) -> Url {
        let raw = match self {
            Provider::OpenAI => "https://api.openai.com/v1",
            Provider::Claude => "https://api.anthropic.com/v1",
            Provider::Ollama => "http://localhost:11434/api",
        };
        raw.parse().expect("built-in provider base URLs are valid")
    }

    /// Well-known model identifiers for this provider.
    ///
    /// This is a static list compiled into the crate, not something queried from the API.
    pub fn default_models(&self) -> Vec<&'static str> {
        let models: &[&'static str] = match self {
            Provider::OpenAI => &[
                "gpt-4o",
                "gpt-4o-mini",
                "gpt-4-turbo",
                "gpt-4",
                "gpt-3.5-turbo",
                "gpt-3.5-turbo-instruct",
            ],
            Provider::Claude => &[
                "claude-3-5-sonnet-20241022",
                "claude-3-5-haiku-20241022",
                "claude-3-opus-20240229",
                "claude-3-sonnet-20240229",
                "claude-3-haiku-20240307",
            ],
            Provider::Ollama => &[
                "llama3.1",
                "llama3.1:70b",
                "llama3.1:405b",
                "mistral",
                "codellama",
                "vicuna",
            ],
        };
        models.to_vec()
    }

    /// Model the providers in this module fall back to when the caller names none.
    pub fn default_model(&self) -> &'static str {
        match self {
            Provider::OpenAI => "gpt-4o-mini",
            Provider::Claude => "claude-3-5-haiku-20241022",
            Provider::Ollama => "llama3.1",
        }
    }

    /// Whether streaming responses are supported; currently `true` for every variant.
    pub fn supports_streaming(&self) -> bool {
        match self {
            Provider::OpenAI | Provider::Claude | Provider::Ollama => true,
        }
    }

    /// Whether requests need an API key; only the local Ollama server does not.
    pub fn requires_auth(&self) -> bool {
        match self {
            Provider::OpenAI | Provider::Claude => true,
            Provider::Ollama => false,
        }
    }
}
impl fmt::Display for Provider {
    /// Writes the lowercase identifier for the provider (`openai`, `claude`, `ollama`),
    /// which `Provider::from_str` accepts back.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let id = match self {
            Provider::OpenAI => "openai",
            Provider::Claude => "claude",
            Provider::Ollama => "ollama",
        };
        f.write_str(id)
    }
}
impl FromStr for Provider {
    type Err = RsllmError;

    /// Parses a provider name case-insensitively.
    ///
    /// `gpt` is accepted as an alias for OpenAI and `anthropic` as an alias for Claude.
    ///
    /// # Errors
    /// Returns a configuration error quoting the original input when no provider matches.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let normalized = s.to_lowercase();
        let provider = match normalized.as_str() {
            "openai" | "gpt" => Provider::OpenAI,
            "anthropic" | "claude" => Provider::Claude,
            "ollama" => Provider::Ollama,
            _ => return Err(RsllmError::configuration(format!("Unknown provider: {}", s))),
        };
        Ok(provider)
    }
}
/// Serializable settings that select a provider and carry its connection options.
///
/// NOTE(review): the derived `Debug` prints `api_key` in clear text; take care when logging.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderConfig {
    /// Which backend to use.
    pub provider: Provider,
    /// API key; may be `None` for providers where `Provider::requires_auth` is `false`.
    pub api_key: Option<String>,
    /// Base URL override. Presumably passed to a provider constructor, and the constructors
    /// in this module fall back to `Provider::default_base_url` on `None`. Verify against callers.
    pub base_url: Option<Url>,
    /// Optional organization identifier; the OpenAI provider sends it as `OpenAI-Organization`.
    pub organization_id: Option<String>,
}
impl Default for ProviderConfig {
fn default() -> Self {
Self {
provider: Provider::OpenAI,
api_key: None,
base_url: None,
organization_id: None,
}
}
}
/// Common interface implemented by each LLM backend.
///
/// The `Send + Sync` supertraits allow a provider to be shared across threads and tasks.
#[async_trait]
pub trait LLMProvider: Send + Sync {
    /// Human-readable backend name (for example `"OpenAI"`).
    fn name(&self) -> &str;
    /// The [`Provider`] variant this implementation corresponds to.
    fn provider_type(&self) -> Provider;
    /// Model identifiers this provider advertises.
    fn supported_models(&self) -> Vec<String>;
    /// Probes the backend; implementations in this module return `Ok(true)` on a 2xx response.
    async fn health_check(&self) -> RsllmResult<bool>;
    /// Sends `messages` and waits for the complete reply.
    ///
    /// Implementations in this module use [`Provider::default_model`] when `model` is `None`
    /// and leave unset sampling parameters out of the request.
    async fn chat_completion(
        &self,
        messages: Vec<ChatMessage>,
        model: Option<&str>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
    ) -> RsllmResult<ChatResponse>;
    /// Like [`LLMProvider::chat_completion`], but yields the reply as a stream of chunks.
    ///
    /// NOTE(review): `model` is `Option<String>` here but `Option<&str>` in
    /// `chat_completion`; the mismatch is part of the public interface.
    async fn chat_completion_stream(
        &self,
        messages: Vec<ChatMessage>,
        model: Option<String>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
    ) -> RsllmResult<Box<dyn futures_util::Stream<Item = RsllmResult<StreamChunk>> + Send + Unpin>>;
}
/// [`LLMProvider`] implementation for the OpenAI HTTP API.
#[cfg(feature = "openai")]
pub struct OpenAIProvider {
    /// HTTP client; `OpenAIProvider::new` builds it with a 30-second timeout.
    client: reqwest::Client,
    /// Secret API key, sent as a `Bearer` token in the `Authorization` header.
    api_key: String,
    /// Base URL that endpoint paths are resolved against.
    base_url: Url,
    /// Optional value for the `OpenAI-Organization` header.
    organization_id: Option<String>,
}
#[cfg(feature = "openai")]
impl OpenAIProvider {
    /// Creates an OpenAI client.
    ///
    /// `base_url` defaults to `Provider::OpenAI.default_base_url()`
    /// (`https://api.openai.com/v1`). Requests time out after 30 seconds.
    ///
    /// # Errors
    /// Returns a configuration error in either of these cases:
    /// - `api_key` or `organization_id` contains bytes that are not allowed in an HTTP
    ///   header value, such as a trailing newline left over from reading a key file.
    /// - The HTTP client cannot be built.
    pub fn new(
        api_key: String,
        base_url: Option<Url>,
        organization_id: Option<String>,
    ) -> RsllmResult<Self> {
        // Validate the credentials once, here, so `build_headers` can treat the header
        // conversions as infallible instead of panicking on user input at request time.
        if reqwest::header::HeaderValue::from_str(&format!("Bearer {}", api_key)).is_err() {
            return Err(RsllmError::configuration(String::from(
                "OpenAI API key contains characters that are not allowed in an HTTP header",
            )));
        }
        if let Some(org_id) = &organization_id {
            if reqwest::header::HeaderValue::from_str(org_id).is_err() {
                return Err(RsllmError::configuration(String::from(
                    "OpenAI organization ID contains characters that are not allowed in an HTTP header",
                )));
            }
        }
        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(30))
            .build()
            .map_err(|e| RsllmError::configuration_with_source("Failed to create HTTP client", e))?;
        Ok(Self {
            client,
            api_key,
            base_url: base_url.unwrap_or_else(|| Provider::OpenAI.default_base_url()),
            organization_id,
        })
    }

    /// Headers sent with every request.
    ///
    /// These are bearer authentication, a JSON content type, and the
    /// `OpenAI-Organization` header when an organization ID is configured.
    fn build_headers(&self) -> reqwest::header::HeaderMap {
        use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};

        let mut headers = HeaderMap::new();
        let mut auth = HeaderValue::from_str(&format!("Bearer {}", self.api_key))
            .expect("API key is validated in OpenAIProvider::new");
        // Marks the key as sensitive so the header map's `Debug` output redacts it.
        auth.set_sensitive(true);
        headers.insert(AUTHORIZATION, auth);
        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
        if let Some(org_id) = &self.organization_id {
            let org = HeaderValue::from_str(org_id)
                .expect("organization ID is validated in OpenAIProvider::new");
            headers.insert("OpenAI-Organization", org);
        }
        headers
    }
}
/// Resolves `path` against an OpenAI base URL while keeping the base URL's own path prefix.
///
/// `Url::join` handles two cases badly for our purposes:
/// - It replaces the last segment of a base that has no trailing slash.
/// - An absolute path such as `/models` replaces the base path entirely.
///
/// Either way, joining directly against `https://api.openai.com/v1` loses `/v1`. Forcing a
/// trailing slash on the base and joining a relative path instead yields
/// `https://api.openai.com/v1/models`.
///
/// # Errors
/// Propagates the `Url::join` parse error, converted into `RsllmError`.
#[cfg(feature = "openai")]
fn openai_endpoint(base: &Url, path: &str) -> RsllmResult<Url> {
    let mut base = base.clone();
    if !base.path().ends_with('/') {
        let directory = format!("{}/", base.path());
        base.set_path(&directory);
    }
    Ok(base.join(path.trim_start_matches('/'))?)
}

#[cfg(feature = "openai")]
#[async_trait]
impl LLMProvider for OpenAIProvider {
    fn name(&self) -> &str {
        "OpenAI"
    }

    fn provider_type(&self) -> Provider {
        Provider::OpenAI
    }

    /// The static list from [`Provider::default_models`]; not fetched from the API.
    fn supported_models(&self) -> Vec<String> {
        Provider::OpenAI
            .default_models()
            .into_iter()
            .map(String::from)
            .collect()
    }

    /// Sends an authenticated `GET {base_url}/models` and reports whether it returned 2xx.
    ///
    /// # Errors
    /// Fails on URL resolution or transport errors. A non-2xx status yields `Ok(false)`.
    async fn health_check(&self) -> RsllmResult<bool> {
        let url = openai_endpoint(&self.base_url, "models")?;
        let response = self
            .client
            .get(url)
            .headers(self.build_headers())
            .send()
            .await?;
        Ok(response.status().is_success())
    }

    /// Sends a non-streaming `POST {base_url}/chat/completions` request.
    ///
    /// `model` falls back to [`Provider::default_model`]. `temperature` and `max_tokens`
    /// are only sent when set.
    ///
    /// # Errors
    /// Returns transport and decoding errors. A non-2xx status returns an API error that
    /// carries the response body.
    async fn chat_completion(
        &self,
        messages: Vec<ChatMessage>,
        model: Option<&str>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
    ) -> RsllmResult<ChatResponse> {
        let url = openai_endpoint(&self.base_url, "chat/completions")?;
        let model = model.unwrap_or(Provider::OpenAI.default_model());
        let mut request_body = serde_json::json!({
            "model": model,
            "messages": messages,
        });
        if let Some(temp) = temperature {
            request_body["temperature"] = temp.into();
        }
        if let Some(max_tokens) = max_tokens {
            request_body["max_tokens"] = max_tokens.into();
        }
        let response = self
            .client
            .post(url)
            .headers(self.build_headers())
            .json(&request_body)
            .send()
            .await?;
        let status = response.status();
        if !status.is_success() {
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(RsllmError::api(
                "OpenAI",
                format!("API request failed: {}", error_text),
                status.as_str(),
            ));
        }
        let response_data: serde_json::Value = response.json().await?;
        // A missing or non-string `content` field yields an empty reply rather than an error.
        let content = response_data["choices"][0]["message"]["content"]
            .as_str()
            .unwrap_or("")
            .to_string();
        Ok(ChatResponse::new(content, model).with_finish_reason("stop"))
    }

    /// Sends a streaming (`"stream": true`) chat completion and yields the returned deltas.
    ///
    /// The server-sent-event body is read in full before the stream is handed back.
    /// Incremental reads (`reqwest::Response::bytes_stream`) need reqwest's optional `stream`
    /// feature. Each `data:` event's `choices[0].delta.content` becomes one delta chunk, in
    /// order, and a final done chunk is appended.
    ///
    /// # Errors
    /// Any of the following fails the call:
    /// - a transport error;
    /// - a non-2xx status;
    /// - an `error` event in the stream;
    /// - an event payload that is not valid JSON.
    async fn chat_completion_stream(
        &self,
        messages: Vec<ChatMessage>,
        model: Option<String>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
    ) -> RsllmResult<Box<dyn futures_util::Stream<Item = RsllmResult<StreamChunk>> + Send + Unpin>> {
        use futures_util::stream;

        let url = openai_endpoint(&self.base_url, "chat/completions")?;
        let model_name = model.unwrap_or_else(|| Provider::OpenAI.default_model().to_string());
        let mut request_body = serde_json::json!({
            "model": &model_name,
            "messages": messages,
            "stream": true,
        });
        if let Some(temp) = temperature {
            request_body["temperature"] = temp.into();
        }
        if let Some(max_tokens) = max_tokens {
            request_body["max_tokens"] = max_tokens.into();
        }

        let response = self
            .client
            .post(url)
            .headers(self.build_headers())
            .json(&request_body)
            .send()
            .await?;
        let status = response.status();
        if !status.is_success() {
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(RsllmError::api(
                "OpenAI",
                format!("API request failed: {}", error_text),
                status.as_str(),
            ));
        }
        let body = response.text().await?;

        let mut chunks = Vec::new();
        for line in body.lines() {
            // Only `data:` lines carry events; blank separators and `:` comments are skipped.
            let payload = match line.strip_prefix("data:") {
                Some(payload) => payload.trim(),
                None => continue,
            };
            if payload == "[DONE]" {
                break;
            }
            let event: serde_json::Value = match serde_json::from_str(payload) {
                Ok(event) => event,
                Err(err) => {
                    return Err(RsllmError::api(
                        "OpenAI",
                        format!("Malformed stream event: {}", err),
                        status.as_str(),
                    ));
                }
            };
            if let Some(message) = event["error"]["message"].as_str() {
                return Err(RsllmError::api(
                    "OpenAI",
                    format!("API request failed: {}", message),
                    status.as_str(),
                ));
            }
            if let Some(text) = event["choices"][0]["delta"]["content"].as_str() {
                if !text.is_empty() {
                    chunks.push(StreamChunk::delta(text, &model_name));
                }
            }
        }
        chunks.push(StreamChunk::done(&model_name).with_finish_reason("stop"));

        let stream = stream::iter(
            chunks
                .into_iter()
                .map(|chunk| -> RsllmResult<StreamChunk> { Ok(chunk) }),
        );
        Ok(Box::new(stream))
    }
}
/// [`LLMProvider`] implementation for an Ollama server; requests carry no authentication.
#[cfg(feature = "ollama")]
pub struct OllamaProvider {
    /// HTTP client; `OllamaProvider::new` builds it with a 60-second timeout.
    client: reqwest::Client,
    /// Base URL that endpoint paths are resolved against.
    base_url: Url,
}
#[cfg(feature = "ollama")]
impl OllamaProvider {
    /// Creates an Ollama client.
    ///
    /// The client targets `base_url`, or `Provider::Ollama.default_base_url()`
    /// (`http://localhost:11434/api`) when that is `None`. Requests time out after
    /// 60 seconds.
    ///
    /// # Errors
    /// Returns a configuration error if the underlying HTTP client cannot be built.
    pub fn new(base_url: Option<Url>) -> RsllmResult<Self> {
        let request_timeout = std::time::Duration::from_secs(60);
        let http = reqwest::Client::builder()
            .timeout(request_timeout)
            .build()
            .map_err(|e| RsllmError::configuration_with_source("Failed to create HTTP client", e))?;
        let resolved_base = match base_url {
            Some(url) => url,
            None => Provider::Ollama.default_base_url(),
        };
        Ok(Self {
            client: http,
            base_url: resolved_base,
        })
    }
}
/// Resolves `path` against an Ollama base URL while keeping the base URL's own path prefix.
///
/// Joining `/chat` (absolute) or `chat` (relative to a base without a trailing slash)
/// directly onto `http://localhost:11434/api` drops the `/api` prefix. Forcing a trailing
/// slash and joining a relative path yields `http://localhost:11434/api/chat`.
///
/// # Errors
/// Propagates the `Url::join` parse error, converted into `RsllmError`.
#[cfg(feature = "ollama")]
fn ollama_endpoint(base: &Url, path: &str) -> RsllmResult<Url> {
    let mut base = base.clone();
    if !base.path().ends_with('/') {
        let directory = format!("{}/", base.path());
        base.set_path(&directory);
    }
    Ok(base.join(path.trim_start_matches('/'))?)
}

/// Builds Ollama's `options` object from the generic sampling parameters.
///
/// `temperature` maps to `options.temperature`. `max_tokens` maps to `options.num_predict`,
/// Ollama's cap on generated tokens. Returns `None` when neither is set, so the request
/// omits `options` entirely.
#[cfg(feature = "ollama")]
fn ollama_options(temperature: Option<f32>, max_tokens: Option<u32>) -> Option<serde_json::Value> {
    let mut options = serde_json::Map::new();
    if let Some(temp) = temperature {
        options.insert("temperature".to_string(), temp.into());
    }
    if let Some(max_tokens) = max_tokens {
        options.insert("num_predict".to_string(), max_tokens.into());
    }
    if options.is_empty() {
        None
    } else {
        Some(serde_json::Value::Object(options))
    }
}

#[cfg(feature = "ollama")]
#[async_trait]
impl LLMProvider for OllamaProvider {
    fn name(&self) -> &str {
        "Ollama"
    }

    fn provider_type(&self) -> Provider {
        Provider::Ollama
    }

    /// The static list from [`Provider::default_models`]; not fetched from the server.
    fn supported_models(&self) -> Vec<String> {
        Provider::Ollama
            .default_models()
            .into_iter()
            .map(String::from)
            .collect()
    }

    /// Sends `GET {base_url}/tags` and reports whether it returned 2xx.
    ///
    /// # Errors
    /// Fails on URL resolution or transport errors. A non-2xx status yields `Ok(false)`.
    async fn health_check(&self) -> RsllmResult<bool> {
        let url = ollama_endpoint(&self.base_url, "tags")?;
        let response = self.client.get(url).send().await?;
        Ok(response.status().is_success())
    }

    /// Sends a non-streaming `POST {base_url}/chat` request.
    ///
    /// `model` falls back to [`Provider::default_model`]. `temperature` and `max_tokens`
    /// are forwarded through Ollama's `options` object; see `ollama_options`.
    ///
    /// # Errors
    /// Returns transport and decoding errors. A non-2xx status returns an API error that
    /// carries the response body.
    async fn chat_completion(
        &self,
        messages: Vec<ChatMessage>,
        model: Option<&str>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
    ) -> RsllmResult<ChatResponse> {
        let url = ollama_endpoint(&self.base_url, "chat")?;
        let model = model.unwrap_or(Provider::Ollama.default_model());
        let mut request_body = serde_json::json!({
            "model": model,
            "messages": messages,
            "stream": false,
        });
        if let Some(options) = ollama_options(temperature, max_tokens) {
            request_body["options"] = options;
        }
        let response = self
            .client
            .post(url)
            .json(&request_body)
            .send()
            .await?;
        let status = response.status();
        if !status.is_success() {
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(RsllmError::api(
                "Ollama",
                format!("API request failed: {}", error_text),
                status.as_str(),
            ));
        }
        let response_data: serde_json::Value = response.json().await?;
        // A missing or non-string `content` field yields an empty reply rather than an error.
        let content = response_data["message"]["content"]
            .as_str()
            .unwrap_or("")
            .to_string();
        Ok(ChatResponse::new(content, model).with_finish_reason("stop"))
    }

    /// Sends a streaming (`"stream": true`) chat request and yields the returned deltas.
    ///
    /// Ollama answers with one JSON object per line. The body is read in full before the
    /// stream is handed back, because incremental reads (`reqwest::Response::bytes_stream`)
    /// need reqwest's optional `stream` feature. Each line's `message.content` becomes one
    /// delta chunk, in order. Reading stops at the object with `"done": true`, and a final
    /// done chunk is appended.
    ///
    /// # Errors
    /// Any of the following fails the call:
    /// - a transport error;
    /// - a non-2xx status;
    /// - a line carrying an `error` field;
    /// - a line that is not valid JSON.
    async fn chat_completion_stream(
        &self,
        messages: Vec<ChatMessage>,
        model: Option<String>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
    ) -> RsllmResult<Box<dyn futures_util::Stream<Item = RsllmResult<StreamChunk>> + Send + Unpin>> {
        use futures_util::stream;

        let url = ollama_endpoint(&self.base_url, "chat")?;
        let model_name = model.unwrap_or_else(|| Provider::Ollama.default_model().to_string());
        let mut request_body = serde_json::json!({
            "model": &model_name,
            "messages": messages,
            "stream": true,
        });
        if let Some(options) = ollama_options(temperature, max_tokens) {
            request_body["options"] = options;
        }

        let response = self
            .client
            .post(url)
            .json(&request_body)
            .send()
            .await?;
        let status = response.status();
        if !status.is_success() {
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(RsllmError::api(
                "Ollama",
                format!("API request failed: {}", error_text),
                status.as_str(),
            ));
        }
        let body = response.text().await?;

        let mut chunks = Vec::new();
        for line in body.lines() {
            let line = line.trim();
            if line.is_empty() {
                continue;
            }
            let event: serde_json::Value = match serde_json::from_str(line) {
                Ok(event) => event,
                Err(err) => {
                    return Err(RsllmError::api(
                        "Ollama",
                        format!("Malformed stream event: {}", err),
                        status.as_str(),
                    ));
                }
            };
            if let Some(message) = event["error"].as_str() {
                return Err(RsllmError::api(
                    "Ollama",
                    format!("API request failed: {}", message),
                    status.as_str(),
                ));
            }
            if let Some(text) = event["message"]["content"].as_str() {
                if !text.is_empty() {
                    chunks.push(StreamChunk::delta(text, &model_name));
                }
            }
            if event["done"].as_bool() == Some(true) {
                break;
            }
        }
        chunks.push(StreamChunk::done(&model_name).with_finish_reason("stop"));

        let stream = stream::iter(
            chunks
                .into_iter()
                .map(|chunk| -> RsllmResult<StreamChunk> { Ok(chunk) }),
        );
        Ok(Box::new(stream))
    }
}