// aster/providers/litellm.rs — LiteLLM provider implementation.
1use anyhow::Result;
2use async_trait::async_trait;
3use serde_json::{json, Value};
4use std::collections::HashMap;
5
6use super::api_client::{ApiClient, AuthMethod};
7use super::base::{ConfigKey, ModelInfo, Provider, ProviderMetadata, ProviderUsage};
8use super::embedding::EmbeddingCapable;
9use super::errors::ProviderError;
10use super::retry::ProviderRetry;
11use super::utils::{get_model, handle_response_openai_compat, ImageFormat, RequestLog};
12use crate::conversation::message::Message;
13
14use crate::model::ModelConfig;
15use rmcp::model::Tool;
16
/// Model used when the configuration does not name one.
pub const LITELLM_DEFAULT_MODEL: &str = "gpt-4o-mini";
/// LiteLLM documentation root, surfaced in the provider metadata.
pub const LITELLM_DOC_URL: &str = "https://docs.litellm.ai/docs/";

/// Provider backed by a LiteLLM proxy (OpenAI-compatible API surface).
#[derive(Debug, serde::Serialize)]
pub struct LiteLLMProvider {
    /// HTTP client preconfigured with host, auth method, custom headers and timeout.
    #[serde(skip)]
    api_client: ApiClient,
    /// Path of the chat-completions endpoint, relative to the host (e.g. `v1/chat/completions`).
    base_path: String,
    /// Active model configuration for this provider instance.
    model: ModelConfig,
    /// Provider display name, taken from `Self::metadata().name`.
    #[serde(skip)]
    name: String,
}

30impl LiteLLMProvider {
31    pub async fn from_env(model: ModelConfig) -> Result<Self> {
32        let config = crate::config::Config::global();
33        let secrets = config
34            .get_secrets("LITELLM_API_KEY", &["LITELLM_CUSTOM_HEADERS"])
35            .unwrap_or_default();
36        let api_key = secrets.get("LITELLM_API_KEY").cloned().unwrap_or_default();
37        let host: String = config
38            .get_param("LITELLM_HOST")
39            .unwrap_or_else(|_| "https://api.litellm.ai".to_string());
40        let base_path: String = config
41            .get_param("LITELLM_BASE_PATH")
42            .unwrap_or_else(|_| "v1/chat/completions".to_string());
43        let custom_headers: Option<HashMap<String, String>> = secrets
44            .get("LITELLM_CUSTOM_HEADERS")
45            .cloned()
46            .map(parse_custom_headers);
47        let timeout_secs: u64 = config.get_param("LITELLM_TIMEOUT").unwrap_or(600);
48
49        let auth = if api_key.is_empty() {
50            AuthMethod::Custom(Box::new(NoAuth))
51        } else {
52            AuthMethod::BearerToken(api_key)
53        };
54
55        let mut api_client =
56            ApiClient::with_timeout(host, auth, std::time::Duration::from_secs(timeout_secs))?;
57
58        if let Some(headers) = custom_headers {
59            let mut header_map = reqwest::header::HeaderMap::new();
60            for (key, value) in headers {
61                let header_name = reqwest::header::HeaderName::from_bytes(key.as_bytes())?;
62                let header_value = reqwest::header::HeaderValue::from_str(&value)?;
63                header_map.insert(header_name, header_value);
64            }
65            api_client = api_client.with_headers(header_map)?;
66        }
67
68        Ok(Self {
69            api_client,
70            base_path,
71            model,
72            name: Self::metadata().name,
73        })
74    }
75
76    async fn fetch_models(&self) -> Result<Vec<ModelInfo>, ProviderError> {
77        let response = self.api_client.response_get("model/info").await?;
78
79        if !response.status().is_success() {
80            return Err(ProviderError::RequestFailed(format!(
81                "Models endpoint returned status: {}",
82                response.status()
83            )));
84        }
85
86        let response_json: Value = response.json().await.map_err(|e| {
87            ProviderError::RequestFailed(format!("Failed to parse models response: {}", e))
88        })?;
89
90        let models_data = response_json["data"].as_array().ok_or_else(|| {
91            ProviderError::RequestFailed("Missing data field in models response".to_string())
92        })?;
93
94        let mut models = Vec::new();
95        for model_data in models_data {
96            if let Some(model_name) = model_data["model_name"].as_str() {
97                if model_name.contains("/*") {
98                    continue;
99                }
100
101                let model_info = &model_data["model_info"];
102                let context_length =
103                    model_info["max_input_tokens"].as_u64().unwrap_or(128000) as usize;
104                let supports_cache_control = model_info["supports_prompt_caching"].as_bool();
105
106                let mut model_info_obj = ModelInfo::new(model_name, context_length);
107                model_info_obj.supports_cache_control = supports_cache_control;
108                models.push(model_info_obj);
109            }
110        }
111
112        Ok(models)
113    }
114
115    async fn post(&self, payload: &Value) -> Result<Value, ProviderError> {
116        let response = self
117            .api_client
118            .response_post(&self.base_path, payload)
119            .await?;
120        handle_response_openai_compat(response).await
121    }
122}
123
// No authentication provider for LiteLLM when API key is not provided.
// Used as `AuthMethod::Custom` so the client still has a uniform auth path.
struct NoAuth;

#[async_trait]
impl super::api_client::AuthProvider for NoAuth {
    async fn get_auth_header(&self) -> Result<(String, String)> {
        // Return a dummy header that won't be used for actual authentication;
        // the proxy simply ignores it.
        Ok(("X-No-Auth".to_string(), "true".to_string()))
    }
}

135#[async_trait]
136impl Provider for LiteLLMProvider {
137    fn metadata() -> ProviderMetadata {
138        ProviderMetadata::new(
139            "litellm",
140            "LiteLLM",
141            "LiteLLM proxy supporting multiple models with automatic prompt caching",
142            LITELLM_DEFAULT_MODEL,
143            vec![],
144            LITELLM_DOC_URL,
145            vec![
146                ConfigKey::new("LITELLM_API_KEY", true, true, None),
147                ConfigKey::new("LITELLM_HOST", true, false, Some("http://localhost:4000")),
148                ConfigKey::new(
149                    "LITELLM_BASE_PATH",
150                    true,
151                    false,
152                    Some("v1/chat/completions"),
153                ),
154                ConfigKey::new("LITELLM_CUSTOM_HEADERS", false, true, None),
155                ConfigKey::new("LITELLM_TIMEOUT", false, false, Some("600")),
156            ],
157        )
158    }
159
160    fn get_name(&self) -> &str {
161        &self.name
162    }
163
164    fn get_model_config(&self) -> ModelConfig {
165        self.model.clone()
166    }
167
168    #[tracing::instrument(skip_all, name = "provider_complete")]
169    async fn complete_with_model(
170        &self,
171        model_config: &ModelConfig,
172        system: &str,
173        messages: &[Message],
174        tools: &[Tool],
175    ) -> Result<(Message, ProviderUsage), ProviderError> {
176        let mut payload = super::formats::openai::create_request(
177            model_config,
178            system,
179            messages,
180            tools,
181            &ImageFormat::OpenAi,
182            false,
183        )?;
184
185        if self.supports_cache_control().await {
186            payload = update_request_for_cache_control(&payload);
187        }
188
189        let response = self
190            .with_retry(|| async {
191                let payload_clone = payload.clone();
192                self.post(&payload_clone).await
193            })
194            .await?;
195
196        let message = super::formats::openai::response_to_message(&response)?;
197        let usage = super::formats::openai::get_usage(&response);
198        let response_model = get_model(&response);
199        let mut log = RequestLog::start(model_config, &payload)?;
200        log.write(&response, Some(&usage))?;
201        Ok((message, ProviderUsage::new(response_model, usage)))
202    }
203
204    fn supports_embeddings(&self) -> bool {
205        true
206    }
207
208    async fn supports_cache_control(&self) -> bool {
209        if let Ok(models) = self.fetch_models().await {
210            if let Some(model_info) = models.iter().find(|m| m.name == self.model.model_name) {
211                return model_info.supports_cache_control.unwrap_or(false);
212            }
213        }
214
215        self.model.model_name.to_lowercase().contains("claude")
216    }
217
218    async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> {
219        match self.fetch_models().await {
220            Ok(models) => {
221                let model_names: Vec<String> = models.into_iter().map(|m| m.name).collect();
222                Ok(Some(model_names))
223            }
224            Err(e) => {
225                tracing::warn!("Failed to fetch models from LiteLLM: {}", e);
226                Ok(None)
227            }
228        }
229    }
230}
231
232#[async_trait]
233impl EmbeddingCapable for LiteLLMProvider {
234    async fn create_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, anyhow::Error> {
235        let embedding_model = std::env::var("ASTER_EMBEDDING_MODEL")
236            .unwrap_or_else(|_| "text-embedding-3-small".to_string());
237
238        let payload = json!({
239            "input": texts,
240            "model": embedding_model,
241            "encoding_format": "float"
242        });
243
244        let response = self
245            .api_client
246            .response_post("v1/embeddings", &payload)
247            .await?;
248        let response_text = response.text().await?;
249        let response_json: Value = serde_json::from_str(&response_text)?;
250
251        let data = response_json["data"]
252            .as_array()
253            .ok_or_else(|| anyhow::anyhow!("Missing data field"))?;
254
255        let mut embeddings = Vec::new();
256        for item in data {
257            let embedding: Vec<f32> = item["embedding"]
258                .as_array()
259                .ok_or_else(|| anyhow::anyhow!("Missing embedding field"))?
260                .iter()
261                .map(|v| v.as_f64().unwrap_or(0.0) as f32)
262                .collect();
263            embeddings.push(embedding);
264        }
265
266        Ok(embeddings)
267    }
268}
269
270/// Updates the request payload to include cache control headers for automatic prompt caching
271/// Adds ephemeral cache control to the last 2 user messages, system message, and last tool
272pub fn update_request_for_cache_control(original_payload: &Value) -> Value {
273    let mut payload = original_payload.clone();
274
275    if let Some(messages_spec) = payload
276        .as_object_mut()
277        .and_then(|obj| obj.get_mut("messages"))
278        .and_then(|messages| messages.as_array_mut())
279    {
280        let mut user_count = 0;
281        for message in messages_spec.iter_mut().rev() {
282            if message.get("role") == Some(&json!("user")) {
283                if let Some(content) = message.get_mut("content") {
284                    if let Some(content_str) = content.as_str() {
285                        *content = json!([{
286                            "type": "text",
287                            "text": content_str,
288                            "cache_control": { "type": "ephemeral" }
289                        }]);
290                    }
291                }
292                user_count += 1;
293                if user_count >= 2 {
294                    break;
295                }
296            }
297        }
298
299        if let Some(system_message) = messages_spec
300            .iter_mut()
301            .find(|msg| msg.get("role") == Some(&json!("system")))
302        {
303            if let Some(content) = system_message.get_mut("content") {
304                if let Some(content_str) = content.as_str() {
305                    *system_message = json!({
306                        "role": "system",
307                        "content": [{
308                            "type": "text",
309                            "text": content_str,
310                            "cache_control": { "type": "ephemeral" }
311                        }]
312                    });
313                }
314            }
315        }
316    }
317
318    if let Some(tools_spec) = payload
319        .as_object_mut()
320        .and_then(|obj| obj.get_mut("tools"))
321        .and_then(|tools| tools.as_array_mut())
322    {
323        if let Some(last_tool) = tools_spec.last_mut() {
324            if let Some(function) = last_tool.get_mut("function") {
325                function
326                    .as_object_mut()
327                    .unwrap()
328                    .insert("cache_control".to_string(), json!({ "type": "ephemeral" }));
329            }
330        }
331    }
332    payload
333}
334
/// Parses newline-separated `Key: Value` pairs into a header map.
/// Lines without a `:` are ignored; keys and values are trimmed.
fn parse_custom_headers(headers_str: String) -> HashMap<String, String> {
    headers_str
        .lines()
        .filter_map(|line| line.split_once(':'))
        .map(|(key, value)| (key.trim().to_string(), value.trim().to_string()))
        .collect()
}