1use anyhow::Result;
2use async_trait::async_trait;
3use serde_json::{json, Value};
4use std::collections::HashMap;
5
6use super::api_client::{ApiClient, AuthMethod};
7use super::base::{ConfigKey, ModelInfo, Provider, ProviderMetadata, ProviderUsage};
8use super::embedding::EmbeddingCapable;
9use super::errors::ProviderError;
10use super::retry::ProviderRetry;
11use super::utils::{get_model, handle_response_openai_compat, ImageFormat, RequestLog};
12use crate::conversation::message::Message;
13
14use crate::model::ModelConfig;
15use rmcp::model::Tool;
16
/// Model used when the configuration does not specify one.
pub const LITELLM_DEFAULT_MODEL: &str = "gpt-4o-mini";
/// Documentation link surfaced to users configuring this provider.
pub const LITELLM_DOC_URL: &str = "https://docs.litellm.ai/docs/";
19
/// Provider backed by a LiteLLM proxy exposing an OpenAI-compatible API.
#[derive(Debug, serde::Serialize)]
pub struct LiteLLMProvider {
    // HTTP client pre-configured with host, auth method, and any custom
    // headers; excluded from serialization.
    #[serde(skip)]
    api_client: ApiClient,
    // Path of the chat-completions endpoint, relative to the configured host.
    base_path: String,
    // Model configuration used for completion requests.
    model: ModelConfig,
    // Display name taken from `Self::metadata().name`; not serialized.
    #[serde(skip)]
    name: String,
}
29
30impl LiteLLMProvider {
31 pub async fn from_env(model: ModelConfig) -> Result<Self> {
32 let config = crate::config::Config::global();
33 let secrets = config
34 .get_secrets("LITELLM_API_KEY", &["LITELLM_CUSTOM_HEADERS"])
35 .unwrap_or_default();
36 let api_key = secrets.get("LITELLM_API_KEY").cloned().unwrap_or_default();
37 let host: String = config
38 .get_param("LITELLM_HOST")
39 .unwrap_or_else(|_| "https://api.litellm.ai".to_string());
40 let base_path: String = config
41 .get_param("LITELLM_BASE_PATH")
42 .unwrap_or_else(|_| "v1/chat/completions".to_string());
43 let custom_headers: Option<HashMap<String, String>> = secrets
44 .get("LITELLM_CUSTOM_HEADERS")
45 .cloned()
46 .map(parse_custom_headers);
47 let timeout_secs: u64 = config.get_param("LITELLM_TIMEOUT").unwrap_or(600);
48
49 let auth = if api_key.is_empty() {
50 AuthMethod::Custom(Box::new(NoAuth))
51 } else {
52 AuthMethod::BearerToken(api_key)
53 };
54
55 let mut api_client =
56 ApiClient::with_timeout(host, auth, std::time::Duration::from_secs(timeout_secs))?;
57
58 if let Some(headers) = custom_headers {
59 let mut header_map = reqwest::header::HeaderMap::new();
60 for (key, value) in headers {
61 let header_name = reqwest::header::HeaderName::from_bytes(key.as_bytes())?;
62 let header_value = reqwest::header::HeaderValue::from_str(&value)?;
63 header_map.insert(header_name, header_value);
64 }
65 api_client = api_client.with_headers(header_map)?;
66 }
67
68 Ok(Self {
69 api_client,
70 base_path,
71 model,
72 name: Self::metadata().name,
73 })
74 }
75
76 async fn fetch_models(&self) -> Result<Vec<ModelInfo>, ProviderError> {
77 let response = self.api_client.response_get("model/info").await?;
78
79 if !response.status().is_success() {
80 return Err(ProviderError::RequestFailed(format!(
81 "Models endpoint returned status: {}",
82 response.status()
83 )));
84 }
85
86 let response_json: Value = response.json().await.map_err(|e| {
87 ProviderError::RequestFailed(format!("Failed to parse models response: {}", e))
88 })?;
89
90 let models_data = response_json["data"].as_array().ok_or_else(|| {
91 ProviderError::RequestFailed("Missing data field in models response".to_string())
92 })?;
93
94 let mut models = Vec::new();
95 for model_data in models_data {
96 if let Some(model_name) = model_data["model_name"].as_str() {
97 if model_name.contains("/*") {
98 continue;
99 }
100
101 let model_info = &model_data["model_info"];
102 let context_length =
103 model_info["max_input_tokens"].as_u64().unwrap_or(128000) as usize;
104 let supports_cache_control = model_info["supports_prompt_caching"].as_bool();
105
106 let mut model_info_obj = ModelInfo::new(model_name, context_length);
107 model_info_obj.supports_cache_control = supports_cache_control;
108 models.push(model_info_obj);
109 }
110 }
111
112 Ok(models)
113 }
114
115 async fn post(&self, payload: &Value) -> Result<Value, ProviderError> {
116 let response = self
117 .api_client
118 .response_post(&self.base_path, payload)
119 .await?;
120 handle_response_openai_compat(response).await
121 }
122}
123
124struct NoAuth;
126
127#[async_trait]
128impl super::api_client::AuthProvider for NoAuth {
129 async fn get_auth_header(&self) -> Result<(String, String)> {
130 Ok(("X-No-Auth".to_string(), "true".to_string()))
132 }
133}
134
#[async_trait]
impl Provider for LiteLLMProvider {
    /// Static provider metadata: id, display name, default model, and the
    /// configuration keys surfaced to the user.
    fn metadata() -> ProviderMetadata {
        ProviderMetadata::new(
            "litellm",
            "LiteLLM",
            "LiteLLM proxy supporting multiple models with automatic prompt caching",
            LITELLM_DEFAULT_MODEL,
            vec![],
            LITELLM_DOC_URL,
            vec![
                ConfigKey::new("LITELLM_API_KEY", true, true, None),
                ConfigKey::new("LITELLM_HOST", true, false, Some("http://localhost:4000")),
                ConfigKey::new(
                    "LITELLM_BASE_PATH",
                    true,
                    false,
                    Some("v1/chat/completions"),
                ),
                ConfigKey::new("LITELLM_CUSTOM_HEADERS", false, true, None),
                ConfigKey::new("LITELLM_TIMEOUT", false, false, Some("600")),
            ],
        )
    }

    fn get_name(&self) -> &str {
        &self.name
    }

    fn get_model_config(&self) -> ModelConfig {
        self.model.clone()
    }

    /// Sends one chat completion: builds an OpenAI-format request, optionally
    /// annotates it for prompt caching, posts it with retries, and converts
    /// the response into a `Message` plus usage statistics.
    #[tracing::instrument(skip_all, name = "provider_complete")]
    async fn complete_with_model(
        &self,
        model_config: &ModelConfig,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<(Message, ProviderUsage), ProviderError> {
        let mut payload = super::formats::openai::create_request(
            model_config,
            system,
            messages,
            tools,
            &ImageFormat::OpenAi,
            false,
        )?;

        // Add cache_control markers only when the target model supports
        // prompt caching. NOTE(review): `supports_cache_control` queries the
        // proxy's model list, so this may perform a network round-trip per
        // completion — consider caching the capability.
        if self.supports_cache_control().await {
            payload = update_request_for_cache_control(&payload);
        }

        let response = self
            .with_retry(|| async {
                // Clone so each retry attempt reuses the same request body.
                let payload_clone = payload.clone();
                self.post(&payload_clone).await
            })
            .await?;

        let message = super::formats::openai::response_to_message(&response)?;
        let usage = super::formats::openai::get_usage(&response);
        let response_model = get_model(&response);
        // Record the request/response pair (with usage) for debugging.
        let mut log = RequestLog::start(model_config, &payload)?;
        log.write(&response, Some(&usage))?;
        Ok((message, ProviderUsage::new(response_model, usage)))
    }

    fn supports_embeddings(&self) -> bool {
        true
    }

    /// Whether the configured model supports prompt caching: prefer the
    /// proxy-reported flag; fall back to a name heuristic (Claude models)
    /// when the model list is unavailable or the model is not found.
    async fn supports_cache_control(&self) -> bool {
        if let Ok(models) = self.fetch_models().await {
            if let Some(model_info) = models.iter().find(|m| m.name == self.model.model_name) {
                return model_info.supports_cache_control.unwrap_or(false);
            }
        }

        self.model.model_name.to_lowercase().contains("claude")
    }

    /// Lists model names known to the proxy. Fetch failures are logged and
    /// reported as `Ok(None)` rather than an error, so callers can fall back
    /// to manual model entry.
    async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> {
        match self.fetch_models().await {
            Ok(models) => {
                let model_names: Vec<String> = models.into_iter().map(|m| m.name).collect();
                Ok(Some(model_names))
            }
            Err(e) => {
                tracing::warn!("Failed to fetch models from LiteLLM: {}", e);
                Ok(None)
            }
        }
    }
}
231
232#[async_trait]
233impl EmbeddingCapable for LiteLLMProvider {
234 async fn create_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, anyhow::Error> {
235 let embedding_model = std::env::var("ASTER_EMBEDDING_MODEL")
236 .unwrap_or_else(|_| "text-embedding-3-small".to_string());
237
238 let payload = json!({
239 "input": texts,
240 "model": embedding_model,
241 "encoding_format": "float"
242 });
243
244 let response = self
245 .api_client
246 .response_post("v1/embeddings", &payload)
247 .await?;
248 let response_text = response.text().await?;
249 let response_json: Value = serde_json::from_str(&response_text)?;
250
251 let data = response_json["data"]
252 .as_array()
253 .ok_or_else(|| anyhow::anyhow!("Missing data field"))?;
254
255 let mut embeddings = Vec::new();
256 for item in data {
257 let embedding: Vec<f32> = item["embedding"]
258 .as_array()
259 .ok_or_else(|| anyhow::anyhow!("Missing embedding field"))?
260 .iter()
261 .map(|v| v.as_f64().unwrap_or(0.0) as f32)
262 .collect();
263 embeddings.push(embedding);
264 }
265
266 Ok(embeddings)
267 }
268}
269
270pub fn update_request_for_cache_control(original_payload: &Value) -> Value {
273 let mut payload = original_payload.clone();
274
275 if let Some(messages_spec) = payload
276 .as_object_mut()
277 .and_then(|obj| obj.get_mut("messages"))
278 .and_then(|messages| messages.as_array_mut())
279 {
280 let mut user_count = 0;
281 for message in messages_spec.iter_mut().rev() {
282 if message.get("role") == Some(&json!("user")) {
283 if let Some(content) = message.get_mut("content") {
284 if let Some(content_str) = content.as_str() {
285 *content = json!([{
286 "type": "text",
287 "text": content_str,
288 "cache_control": { "type": "ephemeral" }
289 }]);
290 }
291 }
292 user_count += 1;
293 if user_count >= 2 {
294 break;
295 }
296 }
297 }
298
299 if let Some(system_message) = messages_spec
300 .iter_mut()
301 .find(|msg| msg.get("role") == Some(&json!("system")))
302 {
303 if let Some(content) = system_message.get_mut("content") {
304 if let Some(content_str) = content.as_str() {
305 *system_message = json!({
306 "role": "system",
307 "content": [{
308 "type": "text",
309 "text": content_str,
310 "cache_control": { "type": "ephemeral" }
311 }]
312 });
313 }
314 }
315 }
316 }
317
318 if let Some(tools_spec) = payload
319 .as_object_mut()
320 .and_then(|obj| obj.get_mut("tools"))
321 .and_then(|tools| tools.as_array_mut())
322 {
323 if let Some(last_tool) = tools_spec.last_mut() {
324 if let Some(function) = last_tool.get_mut("function") {
325 function
326 .as_object_mut()
327 .unwrap()
328 .insert("cache_control".to_string(), json!({ "type": "ephemeral" }));
329 }
330 }
331 }
332 payload
333}
334
/// Parses newline-separated `Key: Value` pairs into a header map.
///
/// Splits each line on the first `:` only (so values may contain colons),
/// trims whitespace around both parts, and silently skips lines without a
/// separator. Later duplicates of a key overwrite earlier ones.
fn parse_custom_headers(headers_str: String) -> HashMap<String, String> {
    headers_str
        .lines()
        .filter_map(|line| line.split_once(':'))
        .map(|(name, value)| (name.trim().to_string(), value.trim().to_string()))
        .collect()
}