// nanobot/providers/litellm.rs

1use crate::providers::base::{LLMProvider, LLMResponse, ToolCallRequest};
2use crate::providers::openai::OpenAIProvider as OpenAICompatProvider;
3use anyhow::Result;
4use async_trait::async_trait;
5use litellm_rs::core::types::content::ContentPart;
6use litellm_rs::core::types::tools::{Tool, ToolChoice};
7use litellm_rs::{CompletionOptions, Message, MessageContent, MessageRole, completion};
8use serde_json::{Map, Value};
9use std::collections::HashMap;
10
/// Per-model parameter override, applied when the lowercased model name
/// contains `pattern` (see `LiteLLMProvider::apply_model_overrides`).
#[derive(Clone, Copy)]
struct ModelOverride {
    // Case-insensitive substring matched against the lowercased model name.
    pattern: &'static str,
    // Temperature to force for matching models; `None` leaves the caller's value.
    temperature: Option<f32>,
}
16
/// An extra environment variable a provider needs at setup time.
/// `{api_key}` and `{api_base}` placeholders in the template are expanded
/// before export (see `LiteLLMProvider::setup_env`).
#[derive(Clone, Copy)]
struct EnvExtra {
    // Environment variable name to export.
    key: &'static str,
    // Value template; `{api_key}` / `{api_base}` are substituted at setup.
    value_template: &'static str,
}
22
/// Static description of one LLM provider or gateway and how to detect,
/// route, and configure it. All instances live in the `PROVIDERS` table.
#[derive(Clone, Copy)]
struct ProviderSpec {
    // Canonical provider name, matched by `find_by_name`.
    name: &'static str,
    // Substrings matched against the lowercased model name in `find_by_model`.
    keywords: &'static [&'static str],
    // Env var that receives the API key (exported in `setup_env`).
    env_key: &'static str,
    // Prefix prepended to the model name for litellm-rs routing;
    // empty means the model name is passed through unchanged.
    litellm_prefix: &'static str,
    // Model-name prefixes that suppress prepending `litellm_prefix`.
    skip_prefixes: &'static [&'static str],
    // Gateways and local servers are skipped by `find_by_model` and are the
    // only specs `find_gateway` returns by name.
    is_gateway: bool,
    is_local: bool,
    // API-key prefix used for gateway auto-detection (empty disables).
    detect_by_key_prefix: &'static str,
    // Substring of the API base URL used for auto-detection (empty disables).
    detect_by_base_keyword: &'static str,
    // Base URL used when no explicit `api_base` is configured.
    default_api_base: &'static str,
    // When true, drop everything before the last '/' in the model name
    // before prefixing (see `resolve_model`).
    strip_model_prefix: bool,
    // Additional env vars to export for this provider.
    env_extras: &'static [EnvExtra],
    // Per-model parameter overrides for this provider.
    model_overrides: &'static [ModelOverride],
}
39
/// Static provider registry. Order matters: both `find_by_model` and the
/// auto-detection pass in `find_gateway` return the FIRST matching entry,
/// so gateways are listed before direct providers.
const PROVIDERS: &[ProviderSpec] = &[
    ProviderSpec {
        name: "openrouter",
        keywords: &["openrouter"],
        env_key: "OPENROUTER_API_KEY",
        litellm_prefix: "openrouter",
        skip_prefixes: &[],
        is_gateway: true,
        is_local: false,
        detect_by_key_prefix: "sk-or-",
        detect_by_base_keyword: "openrouter",
        default_api_base: "https://openrouter.ai/api/v1",
        strip_model_prefix: false,
        env_extras: &[],
        model_overrides: &[],
    },
    ProviderSpec {
        name: "aihubmix",
        keywords: &["aihubmix"],
        env_key: "OPENAI_API_KEY",
        litellm_prefix: "openai",
        skip_prefixes: &[],
        is_gateway: true,
        is_local: false,
        detect_by_key_prefix: "",
        detect_by_base_keyword: "aihubmix",
        default_api_base: "https://aihubmix.com/v1",
        strip_model_prefix: true,
        env_extras: &[],
        model_overrides: &[],
    },
    ProviderSpec {
        name: "siliconflow",
        keywords: &["siliconflow"],
        env_key: "OPENAI_API_KEY",
        litellm_prefix: "openai",
        skip_prefixes: &[],
        is_gateway: true,
        is_local: false,
        detect_by_key_prefix: "",
        detect_by_base_keyword: "siliconflow",
        default_api_base: "https://api.siliconflow.cn/v1",
        strip_model_prefix: false,
        env_extras: &[],
        model_overrides: &[],
    },
    ProviderSpec {
        name: "volcengine",
        keywords: &["volcengine", "volces", "ark"],
        env_key: "OPENAI_API_KEY",
        litellm_prefix: "volcengine",
        skip_prefixes: &[],
        is_gateway: true,
        is_local: false,
        detect_by_key_prefix: "",
        detect_by_base_keyword: "volces",
        default_api_base: "https://ark.cn-beijing.volces.com/api/v3",
        strip_model_prefix: false,
        env_extras: &[],
        model_overrides: &[],
    },
    ProviderSpec {
        name: "anthropic",
        keywords: &["anthropic", "claude"],
        env_key: "ANTHROPIC_API_KEY",
        litellm_prefix: "",
        skip_prefixes: &[],
        is_gateway: false,
        is_local: false,
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        default_api_base: "",
        strip_model_prefix: false,
        env_extras: &[],
        model_overrides: &[],
    },
    ProviderSpec {
        name: "openai",
        keywords: &["openai", "gpt"],
        env_key: "OPENAI_API_KEY",
        litellm_prefix: "",
        skip_prefixes: &[],
        is_gateway: false,
        is_local: false,
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        default_api_base: "",
        strip_model_prefix: false,
        env_extras: &[],
        model_overrides: &[],
    },
    ProviderSpec {
        name: "deepseek",
        keywords: &["deepseek"],
        env_key: "DEEPSEEK_API_KEY",
        litellm_prefix: "deepseek",
        skip_prefixes: &["deepseek/"],
        is_gateway: false,
        is_local: false,
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        default_api_base: "",
        strip_model_prefix: false,
        env_extras: &[],
        model_overrides: &[],
    },
    ProviderSpec {
        name: "gemini",
        keywords: &["gemini"],
        env_key: "GEMINI_API_KEY",
        litellm_prefix: "gemini",
        skip_prefixes: &["gemini/"],
        is_gateway: false,
        is_local: false,
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        default_api_base: "",
        strip_model_prefix: false,
        env_extras: &[],
        model_overrides: &[],
    },
    ProviderSpec {
        name: "zhipu",
        keywords: &["zhipu", "glm", "zai"],
        env_key: "ZAI_API_KEY",
        litellm_prefix: "zai",
        skip_prefixes: &["zhipu/", "zai/", "openrouter/", "hosted_vllm/"],
        is_gateway: false,
        is_local: false,
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        default_api_base: "",
        strip_model_prefix: false,
        // Some litellm backends read ZHIPUAI_API_KEY; mirror the key there.
        env_extras: &[EnvExtra {
            key: "ZHIPUAI_API_KEY",
            value_template: "{api_key}",
        }],
        model_overrides: &[],
    },
    ProviderSpec {
        name: "dashscope",
        keywords: &["qwen", "dashscope"],
        env_key: "DASHSCOPE_API_KEY",
        litellm_prefix: "dashscope",
        skip_prefixes: &["dashscope/", "openrouter/"],
        is_gateway: false,
        is_local: false,
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        default_api_base: "",
        strip_model_prefix: false,
        env_extras: &[],
        model_overrides: &[],
    },
    ProviderSpec {
        name: "moonshot",
        keywords: &["moonshot", "kimi"],
        env_key: "MOONSHOT_API_KEY",
        litellm_prefix: "moonshot",
        skip_prefixes: &["moonshot/", "openrouter/"],
        is_gateway: false,
        is_local: false,
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        default_api_base: "https://api.moonshot.ai/v1",
        strip_model_prefix: false,
        env_extras: &[EnvExtra {
            key: "MOONSHOT_API_BASE",
            value_template: "{api_base}",
        }],
        // kimi-k2.5 is pinned to temperature 1.0 regardless of caller input.
        model_overrides: &[ModelOverride {
            pattern: "kimi-k2.5",
            temperature: Some(1.0),
        }],
    },
    ProviderSpec {
        name: "minimax",
        keywords: &["minimax"],
        env_key: "MINIMAX_API_KEY",
        litellm_prefix: "minimax",
        skip_prefixes: &["minimax/", "openrouter/"],
        is_gateway: false,
        is_local: false,
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        default_api_base: "https://api.minimax.io/v1",
        strip_model_prefix: false,
        env_extras: &[],
        model_overrides: &[],
    },
    ProviderSpec {
        name: "vllm",
        keywords: &["vllm"],
        env_key: "HOSTED_VLLM_API_KEY",
        litellm_prefix: "hosted_vllm",
        skip_prefixes: &[],
        is_gateway: false,
        is_local: true,
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        default_api_base: "",
        strip_model_prefix: false,
        env_extras: &[],
        model_overrides: &[],
    },
    ProviderSpec {
        name: "groq",
        keywords: &["groq"],
        env_key: "GROQ_API_KEY",
        litellm_prefix: "groq",
        skip_prefixes: &["groq/"],
        is_gateway: false,
        is_local: false,
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        default_api_base: "",
        strip_model_prefix: false,
        env_extras: &[],
        model_overrides: &[],
    },
];
261
262fn find_by_name(name: &str) -> Option<&'static ProviderSpec> {
263    PROVIDERS.iter().find(|spec| spec.name == name)
264}
265
266fn find_by_model(model: &str) -> Option<&'static ProviderSpec> {
267    let model_lower = model.to_lowercase();
268    PROVIDERS.iter().find(|spec| {
269        !spec.is_gateway
270            && !spec.is_local
271            && spec.keywords.iter().any(|kw| model_lower.contains(kw))
272    })
273}
274
275fn find_gateway(
276    provider_name: Option<&str>,
277    api_key: Option<&str>,
278    api_base: Option<&str>,
279) -> Option<&'static ProviderSpec> {
280    if let Some(name) = provider_name
281        && let Some(spec) = find_by_name(name)
282        && (spec.is_gateway || spec.is_local)
283    {
284        return Some(spec);
285    }
286
287    PROVIDERS.iter().find(|spec| {
288        let key_matches = !spec.detect_by_key_prefix.is_empty()
289            && api_key.is_some_and(|k| k.starts_with(spec.detect_by_key_prefix));
290        let base_matches = !spec.detect_by_base_keyword.is_empty()
291            && api_base.is_some_and(|b| b.contains(spec.detect_by_base_keyword));
292        key_matches || base_matches
293    })
294}
295
/// LLM provider backed by litellm-rs, with an OpenAI-compatible HTTP
/// fast path for gateways and explicit API bases (see `chat`).
#[derive(Clone)]
pub struct LiteLLMProvider {
    // API key; empty string means "no key configured".
    api_key: String,
    // Explicit API base URL; overrides any spec default when set.
    api_base: Option<String>,
    // Model used when `chat` is called without an explicit model.
    default_model: String,
    // Extra HTTP headers forwarded on every request.
    extra_headers: HashMap<String, String>,
    // Detected gateway/local spec, if any (set once in `new`).
    gateway: Option<&'static ProviderSpec>,
}
304
305impl LiteLLMProvider {
306    pub fn new(
307        api_key: impl Into<String>,
308        api_base: Option<String>,
309        default_model: impl Into<String>,
310        extra_headers: Option<HashMap<String, String>>,
311        provider_name: Option<&str>,
312    ) -> Self {
313        let api_key = api_key.into();
314        let default_model = default_model.into();
315        let gateway = find_gateway(
316            provider_name,
317            if api_key.is_empty() {
318                None
319            } else {
320                Some(&api_key)
321            },
322            api_base.as_deref(),
323        );
324
325        let provider = Self {
326            api_key,
327            api_base,
328            default_model,
329            extra_headers: extra_headers.unwrap_or_default(),
330            gateway,
331        };
332
333        if !provider.api_key.is_empty() {
334            provider.setup_env(&provider.default_model);
335        }
336
337        provider
338    }
339
340    fn resolve_model(&self, model: &str) -> String {
341        if let Some(gateway) = self.gateway {
342            let normalized = if gateway.strip_model_prefix {
343                model.rsplit('/').next().unwrap_or(model)
344            } else {
345                model
346            };
347            if gateway.litellm_prefix.is_empty()
348                || normalized.starts_with(&format!("{}/", gateway.litellm_prefix))
349            {
350                return normalized.to_string();
351            }
352            return format!("{}/{}", gateway.litellm_prefix, normalized);
353        }
354
355        if let Some(spec) = find_by_model(model)
356            && !spec.litellm_prefix.is_empty()
357            && !spec
358                .skip_prefixes
359                .iter()
360                .any(|prefix| model.starts_with(prefix))
361        {
362            return format!("{}/{}", spec.litellm_prefix, model);
363        }
364
365        model.to_string()
366    }
367
368    fn apply_model_overrides(&self, model: &str, temperature: &mut f32) {
369        let model_lower = model.to_lowercase();
370        if let Some(spec) = find_by_model(model) {
371            for rule in spec.model_overrides {
372                if model_lower.contains(rule.pattern)
373                    && let Some(temp) = rule.temperature
374                {
375                    *temperature = temp;
376                    return;
377                }
378            }
379        }
380    }
381
382    fn effective_api_base(&self, model: &str) -> Option<String> {
383        if let Some(base) = &self.api_base {
384            return Some(base.clone());
385        }
386
387        if let Some(gateway) = self.gateway
388            && !gateway.default_api_base.is_empty()
389        {
390            return Some(gateway.default_api_base.to_string());
391        }
392
393        if let Some(spec) = find_by_model(model)
394            && !spec.default_api_base.is_empty()
395        {
396            return Some(spec.default_api_base.to_string());
397        }
398
399        None
400    }
401
402    fn setup_env(&self, model: &str) {
403        let Some(spec) = self.gateway.or_else(|| find_by_model(model)) else {
404            return;
405        };
406
407        if !spec.env_key.is_empty() {
408            Self::set_env_var(spec.env_key, &self.api_key, self.gateway.is_some());
409        }
410
411        let effective_base = self.api_base.as_deref().unwrap_or(spec.default_api_base);
412        for extra in spec.env_extras {
413            let value = extra
414                .value_template
415                .replace("{api_key}", &self.api_key)
416                .replace("{api_base}", effective_base);
417            Self::set_env_var(extra.key, &value, false);
418        }
419    }
420
421    fn use_openai_compat_path(&self, model: &str) -> bool {
422        if self.gateway.is_some() || self.api_base.is_some() {
423            return true;
424        }
425        matches!(find_by_model(model), Some(spec) if spec.name == "openai")
426    }
427
428    fn set_env_var(key: &str, value: &str, overwrite: bool) {
429        if key.is_empty() || value.is_empty() {
430            return;
431        }
432        if !overwrite && std::env::var_os(key).is_some() {
433            return;
434        }
435
436        // SAFETY: We only mutate process env during provider initialization,
437        // mirroring Python nanobot behavior for LiteLLM provider compatibility.
438        unsafe { std::env::set_var(key, value) };
439    }
440
441    fn convert_message(raw: &Value) -> Message {
442        if let Ok(message) = serde_json::from_value::<Message>(raw.clone()) {
443            return message;
444        }
445
446        let role = match raw.get("role").and_then(Value::as_str).unwrap_or("user") {
447            "system" => MessageRole::System,
448            "assistant" => MessageRole::Assistant,
449            "tool" => MessageRole::Tool,
450            "function" => MessageRole::Function,
451            _ => MessageRole::User,
452        };
453
454        let content = match raw.get("content") {
455            Some(Value::String(text)) => Some(MessageContent::Text(text.clone())),
456            Some(Value::Array(parts)) => {
457                serde_json::from_value::<MessageContent>(Value::Array(parts.clone())).ok()
458            }
459            _ => None,
460        };
461
462        let mut message = Message {
463            role,
464            content,
465            ..Default::default()
466        };
467
468        if let Some(name) = raw.get("name").and_then(Value::as_str) {
469            message.name = Some(name.to_string());
470        }
471        if let Some(tool_call_id) = raw.get("tool_call_id").and_then(Value::as_str) {
472            message.tool_call_id = Some(tool_call_id.to_string());
473        }
474        if let Some(tool_calls) = raw.get("tool_calls")
475            && let Ok(parsed) = serde_json::from_value(tool_calls.clone())
476        {
477            message.tool_calls = Some(parsed);
478        }
479        if let Some(function_call) = raw.get("function_call")
480            && let Ok(parsed) = serde_json::from_value(function_call.clone())
481        {
482            message.function_call = Some(parsed);
483        }
484
485        message
486    }
487
488    fn content_to_text(content: &MessageContent) -> String {
489        match content {
490            MessageContent::Text(text) => text.clone(),
491            MessageContent::Parts(parts) => {
492                let chunks = parts
493                    .iter()
494                    .filter_map(|part| match part {
495                        ContentPart::Text { text } => Some(text.clone()),
496                        ContentPart::ToolResult { content, .. } => Some(content.to_string()),
497                        _ => None,
498                    })
499                    .collect::<Vec<_>>();
500                chunks.join("\n")
501            }
502        }
503    }
504}
505
#[async_trait]
impl LLMProvider for LiteLLMProvider {
    /// Sends a chat completion request.
    ///
    /// Routing: when a gateway or explicit `api_base` is configured (or the
    /// model is detected as native OpenAI) the request goes through the
    /// OpenAI-compatible HTTP client; otherwise it goes through litellm-rs
    /// `completion`, retrying once with the raw (unprefixed) model name if
    /// the prefixed call fails.
    async fn chat(
        &self,
        messages: &[Value],
        tools: Option<&[Value]>,
        model: Option<&str>,
        max_tokens: u32,
        temperature: f32,
    ) -> Result<LLMResponse> {
        let selected_model = model.unwrap_or(&self.default_model);
        let mut effective_temperature = temperature;
        let resolved_model = self.resolve_model(selected_model);
        // Overrides are matched against the resolved (prefixed) name so they
        // apply on both the compat path and the litellm-rs path.
        self.apply_model_overrides(&resolved_model, &mut effective_temperature);

        if self.use_openai_compat_path(selected_model) {
            // Compat path sends the user-facing model name, not the
            // litellm-prefixed one.
            let provider = OpenAICompatProvider::new(
                self.api_key.clone(),
                self.effective_api_base(selected_model),
                selected_model.to_string(),
                Some(self.extra_headers.clone()),
            );
            return provider
                .chat(
                    messages,
                    tools,
                    Some(selected_model),
                    max_tokens,
                    effective_temperature,
                )
                .await;
        }

        let chat_messages = messages
            .iter()
            .map(Self::convert_message)
            .collect::<Vec<_>>();
        let mut options = CompletionOptions {
            max_tokens: Some(max_tokens),
            temperature: Some(effective_temperature),
            api_key: if self.api_key.is_empty() {
                None
            } else {
                Some(self.api_key.clone())
            },
            api_base: self.effective_api_base(selected_model),
            headers: if self.extra_headers.is_empty() {
                None
            } else {
                Some(self.extra_headers.clone())
            },
            ..Default::default()
        };

        if let Some(tool_defs) = tools {
            // Tools that fail to parse as litellm-rs `Tool` are silently
            // skipped here but still forwarded verbatim via extra_params below.
            let parsed_tools = tool_defs
                .iter()
                .filter_map(|item| serde_json::from_value::<Tool>(item.clone()).ok())
                .collect::<Vec<_>>();
            if !parsed_tools.is_empty() {
                options.tools = Some(parsed_tools);
                options.tool_choice = Some(ToolChoice::String("auto".to_string()));
            }

            // litellm-rs 0.3.1 conversion currently drops CompletionOptions.tools,
            // so also inject the raw tool definitions through extra_params.
            options
                .extra_params
                .insert("tools".to_string(), Value::Array(tool_defs.to_vec()));
            options
                .extra_params
                .insert("tool_choice".to_string(), Value::String("auto".to_string()));
        }

        let response = match completion(
            &resolved_model,
            chat_messages.clone(),
            Some(options.clone()),
        )
        .await
        {
            Ok(resp) => resp,
            Err(primary_err) => {
                // Fallback to raw model for better compatibility with pure
                // OpenAI-compatible endpoints.
                if resolved_model != selected_model {
                    completion(selected_model, chat_messages, Some(options))
                        .await
                        .map_err(|fallback_err| {
                            anyhow::anyhow!(
                                "failed to call litellm-rs completion: primary={primary_err}; fallback={fallback_err}"
                            )
                        })?
                } else {
                    return Err(anyhow::anyhow!(
                        "failed to call litellm-rs completion: {primary_err}"
                    ));
                }
            }
        };

        // No choices at all: return an empty "stop" response instead of erroring.
        let Some(choice) = response.choices.first() else {
            return Ok(LLMResponse {
                content: None,
                tool_calls: Vec::new(),
                finish_reason: "stop".to_string(),
                usage: Map::new(),
                reasoning_content: None,
            });
        };

        let content = choice.message.content.as_ref().map(Self::content_to_text);
        let reasoning_content = choice
            .message
            .thinking
            .as_ref()
            .and_then(|thinking| thinking.as_text())
            .map(ToOwned::to_owned);
        let tool_calls = choice
            .message
            .tool_calls
            .clone()
            .unwrap_or_default()
            .into_iter()
            .map(|call| {
                // Non-JSON tool-call arguments are preserved under a "raw" key
                // rather than being dropped.
                let arguments = serde_json::from_str::<Value>(&call.function.arguments)
                    .ok()
                    .and_then(|v| v.as_object().cloned())
                    .unwrap_or_else(|| {
                        let mut fallback = Map::new();
                        fallback.insert("raw".to_string(), Value::String(call.function.arguments));
                        fallback
                    });

                ToolCallRequest {
                    id: call.id,
                    name: call.function.name,
                    arguments,
                }
            })
            .collect::<Vec<_>>();

        // Finish reason is round-tripped through serde; anything that does not
        // serialize to a JSON string falls back to "stop".
        let finish_reason = choice
            .finish_reason
            .as_ref()
            .and_then(|reason| serde_json::to_value(reason).ok())
            .and_then(|v| v.as_str().map(ToOwned::to_owned))
            .unwrap_or_else(|| "stop".to_string());

        let usage = response
            .usage
            .and_then(|usage| serde_json::to_value(usage).ok())
            .and_then(|value| value.as_object().cloned())
            .unwrap_or_default();

        Ok(LLMResponse {
            content,
            tool_calls,
            finish_reason,
            usage,
            reasoning_content,
        })
    }

    /// Model used when `chat` is called without an explicit model.
    fn default_model(&self) -> &str {
        &self.default_model
    }
}
673
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn gateway_detects_by_provider_name_and_key_prefix() {
        // Explicit provider name resolves directly to the local vllm spec.
        let named = find_gateway(Some("vllm"), None, None).expect("expected vllm gateway");
        assert_eq!("vllm", named.name);

        // The "sk-or-" key prefix auto-detects OpenRouter.
        let keyed = find_gateway(None, Some("sk-or-test"), None).expect("expected openrouter");
        assert_eq!("openrouter", keyed.name);
    }

    #[test]
    fn resolve_model_applies_gateway_and_provider_rules() {
        // aihubmix strips the vendor prefix and re-prefixes with "openai/".
        let hub = LiteLLMProvider::new(
            "",
            Some("https://aihubmix.com/v1".to_string()),
            "anthropic/claude-3-7-sonnet",
            None,
            Some("aihubmix"),
        );
        assert_eq!(
            "openai/claude-3-7-sonnet",
            hub.resolve_model("anthropic/claude-3-7-sonnet")
        );

        // Without a gateway, the detected provider prefix is added once.
        let plain = LiteLLMProvider::new("", None, "qwen-plus", None, None);
        assert_eq!("dashscope/qwen-plus", plain.resolve_model("qwen-plus"));
        assert_eq!(
            "dashscope/qwen-plus",
            plain.resolve_model("dashscope/qwen-plus")
        );

        // Gateway prefix is applied to bare model names.
        let ark = LiteLLMProvider::new(
            "x",
            Some("https://ark.cn-beijing.volces.com/api/v3".to_string()),
            "doubao-seed-1-6-thinking-250715",
            None,
            Some("volcengine"),
        );
        assert_eq!(
            "volcengine/doubao-seed-1-6-thinking-250715",
            ark.resolve_model("doubao-seed-1-6-thinking-250715")
        );
    }

    #[test]
    fn model_override_applies_kimi_temperature_floor() {
        let provider = LiteLLMProvider::new("", None, "kimi-k2.5", None, None);
        let mut temperature = 0.2_f32;
        provider.apply_model_overrides("moonshot/kimi-k2.5", &mut temperature);
        assert!((temperature - 1.0).abs() < f32::EPSILON);
    }
}