#![allow(dead_code)]
use super::{AIClient, AIConfig, AIResponse, AIContext, ModelConfig};
use anyhow::Result;
use bytes::Bytes;
use futures_util::stream::StreamExt;
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
use serde::{Deserialize, Serialize};
use tracing::debug;
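/// Client for a LiteLLM proxy that speaks the OpenAI-compatible
/// `/chat/completions` API.
///
/// Usage sketch (assumes an [`AIConfig`] has already been populated elsewhere):
///
/// ```ignore
/// let client = LiteLLMClient::new(&config)?;
/// let response = client.complete("Explain this compiler error", None).await?;
/// println!("{}", response.content);
/// ```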
pub struct LiteLLMClient {
    client: reqwest::Client,
    api_key: String,
    base_url: String,
    model: String,
    max_tokens: u32,
    temperature: f32,
}
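/// Request body for the OpenAI-compatible `/chat/completions` endpoint;
/// `stream` toggles between one-shot and server-sent-event replies.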
#[derive(Debug, Serialize)]
struct ChatCompletionRequest {
    model: String,
    messages: Vec<Message>,
    max_tokens: u32,
    temperature: f32,
    stream: bool,
}
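/// A single chat message; this client sends "system" and "user" roles.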
#[derive(Debug, Serialize, Deserialize)]
struct Message {
    role: String,
    content: String,
}
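/// Response shape shared by full replies and streaming chunks: non-streaming
/// responses populate `Choice::message`, streamed chunks populate `Choice::delta`.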
#[derive(Debug, Deserialize)]
struct ChatCompletionResponse {
    id: String,
    object: String,
    created: i64,
    model: String,
    choices: Vec<Choice>,
    usage: Option<Usage>,
}
#[derive(Debug, Deserialize)]
struct Choice {
    index: i32,
    message: Option<Message>,
    delta: Option<Delta>,
    finish_reason: Option<String>,
}
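/// Incremental content fragment carried by a streaming chunk.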
#[derive(Debug, Deserialize)]
struct Delta {
    content: Option<String>,
}
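/// Token counts reported by the server when usage data is included in the response.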
#[derive(Debug, Deserialize)]
struct Usage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}
impl LiteLLMClient {
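    /// Creates a client from the shared [`AIConfig`], defaulting to a local LiteLLM
    /// proxy at `http://localhost:4000` when `api_base` is not set.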
    pub fn new(config: &AIConfig) -> Result<Self> {
        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(60))
            .build()?;
        let base_url = config.api_base.clone()
            .unwrap_or_else(|| "http://localhost:4000".to_string());
        Ok(Self {
            client,
            api_key: config.api_key.clone(),
            base_url,
            model: config.model.clone(),
            max_tokens: config.max_tokens,
            temperature: config.temperature,
        })
    }
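    /// Builds the JSON content-type header plus a bearer `Authorization` header
    /// when an API key is configured.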
    fn build_headers(&self) -> Result<HeaderMap> {
        let mut headers = HeaderMap::new();
        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
        if !self.api_key.is_empty() {
            let auth_header = format!("Bearer {}", self.api_key);
            headers.insert(AUTHORIZATION, HeaderValue::from_str(&auth_header)?);
        }
        Ok(headers)
    }
}
#[async_trait::async_trait]
impl AIClient for LiteLLMClient {
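    /// Sends a non-streaming chat completion and returns the first choice's content.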
    async fn complete(
        &self,
        prompt: &str,
        _context: Option<&AIContext>,
    ) -> Result<AIResponse> {
        debug!("Sending completion request to LiteLLM: {}", self.model);
        let url = format!("{}/chat/completions", self.base_url);
        let messages = vec![
            Message {
                role: "system".to_string(),
                content: "You are a helpful AI assistant for developers.".to_string(),
            },
            Message {
                role: "user".to_string(),
                content: prompt.to_string(),
            },
        ];
        let request = ChatCompletionRequest {
            model: self.model.clone(),
            messages,
            max_tokens: self.max_tokens,
            temperature: self.temperature,
            stream: false,
        };
        let response = self
            .client
            .post(&url)
            .headers(self.build_headers()?)
            .json(&request)
            .send()
            .await?;
        let status = response.status();
        if !status.is_success() {
            let error_text = response.text().await?;
            anyhow::bail!("LiteLLM API error ({}): {}", status, error_text);
        }
        let completion: ChatCompletionResponse = response.json().await?;
        // Non-streaming responses carry the full text in `choices[0].message`.
        let content = completion
            .choices
            .first()
            .and_then(|c| c.message.as_ref())
            .map(|m| m.content.clone())
            .unwrap_or_default();
        let tokens_used = completion.usage.as_ref().map(|u| u.total_tokens);
        Ok(AIResponse {
            content,
            tokens_used,
            model: completion.model,
            finish_reason: completion
                .choices
                .first()
                .and_then(|c| c.finish_reason.clone()),
            cost_estimate: None,
        })
    }
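    /// Streams a completion: a spawned task parses the server-sent `data:` lines and
    /// forwards each content delta through the returned channel until `[DONE]`.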
    async fn complete_stream(
        &self,
        prompt: &str,
        _context: Option<&AIContext>,
    ) -> Result<tokio::sync::mpsc::Receiver<Result<String>>> {
        let (tx, rx) = tokio::sync::mpsc::channel(100);
        let url = format!("{}/chat/completions", self.base_url);
        let api_key = self.api_key.clone();
        let model = self.model.clone();
        let max_tokens = self.max_tokens;
        let temperature = self.temperature;
        let prompt = prompt.to_string();
        let client = self.client.clone();
        tokio::spawn(async move {
            let messages = vec![
                Message {
                    role: "system".to_string(),
                    content: "You are a helpful AI assistant for developers.".to_string(),
                },
                Message {
                    role: "user".to_string(),
                    content: prompt,
                },
            ];
            let request = ChatCompletionRequest {
                model,
                messages,
                max_tokens,
                temperature,
                stream: true,
            };
            let mut headers = HeaderMap::new();
            headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
            if !api_key.is_empty() {
                // Report an unusable key through the channel instead of panicking in the task.
                match HeaderValue::from_str(&format!("Bearer {}", api_key)) {
                    Ok(value) => {
                        headers.insert(AUTHORIZATION, value);
                    }
                    Err(e) => {
                        let _ = tx
                            .send(Err(anyhow::anyhow!("Invalid API key header: {}", e)))
                            .await;
                        return;
                    }
                }
            }
            let response = match client
                .post(&url)
                .headers(headers)
                .json(&request)
                .send()
                .await
            {
                Ok(r) => r,
                Err(e) => {
                    let _ = tx.send(Err(e.into())).await;
                    return;
                }
            };
            if !response.status().is_success() {
                let error_text = match response.text().await {
                    Ok(t) => t,
                    Err(e) => format!("Failed to read error: {}", e),
                };
                let _ = tx
                    .send(Err(anyhow::anyhow!("API error: {}", error_text)))
                    .await;
                return;
            }
            let mut stream: futures_util::stream::BoxStream<'_, Result<Bytes, reqwest::Error>> =
                response.bytes_stream().boxed();
            // SSE events can be split across network chunks, so accumulate bytes
            // and only parse complete `data:` lines.
            let mut buffer = String::new();
            while let Some(chunk) = stream.next().await {
                match chunk {
                    Ok(bytes) => {
                        buffer.push_str(&String::from_utf8_lossy(&bytes));
                        while let Some(newline) = buffer.find('\n') {
                            let line = buffer[..newline].trim_end_matches('\r').to_string();
                            buffer.drain(..=newline);
                            if let Some(data) = line.strip_prefix("data: ") {
                                if data == "[DONE]" {
                                    return;
                                }
                                // Streaming chunks reuse the response shape, with the
                                // incremental text in `choices[0].delta.content`.
                                match serde_json::from_str::<ChatCompletionResponse>(data) {
                                    Ok(resp) => {
                                        if let Some(content) = resp
                                            .choices
                                            .first()
                                            .and_then(|c| c.delta.as_ref())
                                            .and_then(|d| d.content.as_ref())
                                        {
                                            if tx.send(Ok(content.clone())).await.is_err() {
                                                return;
                                            }
                                        }
                                    }
                                    Err(_) => continue,
                                }
                            }
                        }
                    }
                    Err(e) => {
                        let _ = tx
                            .send(Err(anyhow::anyhow!("Stream error: {}", e)))
                            .await;
                        return;
                    }
                }
            }
        });
        Ok(rx)
    }
    fn name(&self) -> &str {
        "LiteLLM"
    }
    fn model_info(&self) -> ModelConfig {
        ModelConfig {
            name: self.model.clone(),
            // Static placeholder values: the real context window and pricing depend on
            // whichever model the LiteLLM proxy routes to.
            context_window: 128000,
            supports_functions: true,
            supports_vision: self.model.contains("vision") || self.model.contains("gpt-4-turbo"),
            cost_per_1k_input: 0.0,
            cost_per_1k_output: 0.0,
        }
    }
}