1use crate::providers::base::{LLMProvider, LLMResponse, ToolCallRequest};
2use crate::providers::openai::OpenAIProvider as OpenAICompatProvider;
3use anyhow::Result;
4use async_trait::async_trait;
5use litellm_rs::core::types::content::ContentPart;
6use litellm_rs::core::types::tools::{Tool, ToolChoice};
7use litellm_rs::{CompletionOptions, Message, MessageContent, MessageRole, completion};
8use serde_json::{Map, Value};
9use std::collections::HashMap;
10
/// A per-model parameter override, matched by substring against the
/// lowercased model name (see `apply_model_overrides`).
#[derive(Clone, Copy)]
struct ModelOverride {
    // Substring searched for in the lowercased model name.
    pattern: &'static str,
    // Temperature to force when the pattern matches; `None` leaves the
    // caller-supplied temperature untouched.
    temperature: Option<f32>,
}
16
/// An additional environment variable a provider needs at request time.
/// The value is a template expanded by `setup_env`.
#[derive(Clone, Copy)]
struct EnvExtra {
    // Environment variable name to export.
    key: &'static str,
    // Value template; `{api_key}` and `{api_base}` placeholders are
    // substituted in `setup_env`.
    value_template: &'static str,
}
22
/// Static description of one upstream provider or gateway. Drives model-name
/// prefixing (`resolve_model`), environment setup (`setup_env`), and
/// gateway auto-detection (`find_gateway`).
#[derive(Clone, Copy)]
struct ProviderSpec {
    // Canonical provider name, matched exactly by `find_by_name`.
    name: &'static str,
    // Substrings matched against the lowercased model name in `find_by_model`.
    keywords: &'static [&'static str],
    // Environment variable that receives the API key ("" = none).
    env_key: &'static str,
    // Prefix prepended to model names for litellm routing ("" = none).
    litellm_prefix: &'static str,
    // Model-name prefixes that suppress re-prefixing in `resolve_model`.
    skip_prefixes: &'static [&'static str],
    // True for aggregator gateways (OpenRouter, AiHubMix, ...).
    is_gateway: bool,
    // True for locally hosted backends (vLLM).
    is_local: bool,
    // API-key prefix identifying this gateway ("" = no key-based detection).
    detect_by_key_prefix: &'static str,
    // Substring of the API base URL identifying this gateway ("" = none).
    detect_by_base_keyword: &'static str,
    // Fallback API base used when the caller supplies none ("" = unset).
    default_api_base: &'static str,
    // When true, keep only the last '/'-separated segment of the model name
    // before applying the gateway prefix.
    strip_model_prefix: bool,
    // Extra env vars to export, with `{api_key}`/`{api_base}` templates.
    env_extras: &'static [EnvExtra],
    // Per-model parameter overrides (e.g. forced temperature).
    model_overrides: &'static [ModelOverride],
}
39
/// Static registry of known providers and gateways. Order matters:
/// `find_by_model` and `find_gateway` return the FIRST matching entry.
const PROVIDERS: &[ProviderSpec] = &[
    // OpenRouter: gateway, detected by "sk-or-" key prefix or base URL.
    ProviderSpec {
        name: "openrouter",
        keywords: &["openrouter"],
        env_key: "OPENROUTER_API_KEY",
        litellm_prefix: "openrouter",
        skip_prefixes: &[],
        is_gateway: true,
        is_local: false,
        detect_by_key_prefix: "sk-or-",
        detect_by_base_keyword: "openrouter",
        default_api_base: "https://openrouter.ai/api/v1",
        strip_model_prefix: false,
        env_extras: &[],
        model_overrides: &[],
    },
    // AiHubMix: OpenAI-compatible gateway; strips any "vendor/" prefix from
    // the model name and routes via "openai/".
    ProviderSpec {
        name: "aihubmix",
        keywords: &["aihubmix"],
        env_key: "OPENAI_API_KEY",
        litellm_prefix: "openai",
        skip_prefixes: &[],
        is_gateway: true,
        is_local: false,
        detect_by_key_prefix: "",
        detect_by_base_keyword: "aihubmix",
        default_api_base: "https://aihubmix.com/v1",
        strip_model_prefix: true,
        env_extras: &[],
        model_overrides: &[],
    },
    // SiliconFlow: OpenAI-compatible gateway, detected by base URL only.
    ProviderSpec {
        name: "siliconflow",
        keywords: &["siliconflow"],
        env_key: "OPENAI_API_KEY",
        litellm_prefix: "openai",
        skip_prefixes: &[],
        is_gateway: true,
        is_local: false,
        detect_by_key_prefix: "",
        detect_by_base_keyword: "siliconflow",
        default_api_base: "https://api.siliconflow.cn/v1",
        strip_model_prefix: false,
        env_extras: &[],
        model_overrides: &[],
    },
    // Volcengine Ark gateway.
    ProviderSpec {
        name: "volcengine",
        keywords: &["volcengine", "volces", "ark"],
        env_key: "OPENAI_API_KEY",
        litellm_prefix: "volcengine",
        skip_prefixes: &[],
        is_gateway: true,
        is_local: false,
        detect_by_key_prefix: "",
        detect_by_base_keyword: "volces",
        default_api_base: "https://ark.cn-beijing.volces.com/api/v3",
        strip_model_prefix: false,
        env_extras: &[],
        model_overrides: &[],
    },
    // Anthropic: direct provider, no litellm prefixing.
    ProviderSpec {
        name: "anthropic",
        keywords: &["anthropic", "claude"],
        env_key: "ANTHROPIC_API_KEY",
        litellm_prefix: "",
        skip_prefixes: &[],
        is_gateway: false,
        is_local: false,
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        default_api_base: "",
        strip_model_prefix: false,
        env_extras: &[],
        model_overrides: &[],
    },
    // OpenAI: direct provider; `use_openai_compat_path` routes these models
    // through the OpenAI-compatible provider instead of litellm-rs.
    ProviderSpec {
        name: "openai",
        keywords: &["openai", "gpt"],
        env_key: "OPENAI_API_KEY",
        litellm_prefix: "",
        skip_prefixes: &[],
        is_gateway: false,
        is_local: false,
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        default_api_base: "",
        strip_model_prefix: false,
        env_extras: &[],
        model_overrides: &[],
    },
    ProviderSpec {
        name: "deepseek",
        keywords: &["deepseek"],
        env_key: "DEEPSEEK_API_KEY",
        litellm_prefix: "deepseek",
        skip_prefixes: &["deepseek/"],
        is_gateway: false,
        is_local: false,
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        default_api_base: "",
        strip_model_prefix: false,
        env_extras: &[],
        model_overrides: &[],
    },
    ProviderSpec {
        name: "gemini",
        keywords: &["gemini"],
        env_key: "GEMINI_API_KEY",
        litellm_prefix: "gemini",
        skip_prefixes: &["gemini/"],
        is_gateway: false,
        is_local: false,
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        default_api_base: "",
        strip_model_prefix: false,
        env_extras: &[],
        model_overrides: &[],
    },
    // Zhipu/Z.ai: also mirrors the key into ZHIPUAI_API_KEY.
    ProviderSpec {
        name: "zhipu",
        keywords: &["zhipu", "glm", "zai"],
        env_key: "ZAI_API_KEY",
        litellm_prefix: "zai",
        skip_prefixes: &["zhipu/", "zai/", "openrouter/", "hosted_vllm/"],
        is_gateway: false,
        is_local: false,
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        default_api_base: "",
        strip_model_prefix: false,
        env_extras: &[EnvExtra {
            key: "ZHIPUAI_API_KEY",
            value_template: "{api_key}",
        }],
        model_overrides: &[],
    },
    // Alibaba DashScope (Qwen models).
    ProviderSpec {
        name: "dashscope",
        keywords: &["qwen", "dashscope"],
        env_key: "DASHSCOPE_API_KEY",
        litellm_prefix: "dashscope",
        skip_prefixes: &["dashscope/", "openrouter/"],
        is_gateway: false,
        is_local: false,
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        default_api_base: "",
        strip_model_prefix: false,
        env_extras: &[],
        model_overrides: &[],
    },
    // Moonshot (Kimi): exports MOONSHOT_API_BASE; kimi-k2.5 forces temp 1.0.
    ProviderSpec {
        name: "moonshot",
        keywords: &["moonshot", "kimi"],
        env_key: "MOONSHOT_API_KEY",
        litellm_prefix: "moonshot",
        skip_prefixes: &["moonshot/", "openrouter/"],
        is_gateway: false,
        is_local: false,
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        default_api_base: "https://api.moonshot.ai/v1",
        strip_model_prefix: false,
        env_extras: &[EnvExtra {
            key: "MOONSHOT_API_BASE",
            value_template: "{api_base}",
        }],
        model_overrides: &[ModelOverride {
            pattern: "kimi-k2.5",
            temperature: Some(1.0),
        }],
    },
    ProviderSpec {
        name: "minimax",
        keywords: &["minimax"],
        env_key: "MINIMAX_API_KEY",
        litellm_prefix: "minimax",
        skip_prefixes: &["minimax/", "openrouter/"],
        is_gateway: false,
        is_local: false,
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        default_api_base: "https://api.minimax.io/v1",
        strip_model_prefix: false,
        env_extras: &[],
        model_overrides: &[],
    },
    // Hosted vLLM: local backend, selected only by explicit provider name.
    ProviderSpec {
        name: "vllm",
        keywords: &["vllm"],
        env_key: "HOSTED_VLLM_API_KEY",
        litellm_prefix: "hosted_vllm",
        skip_prefixes: &[],
        is_gateway: false,
        is_local: true,
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        default_api_base: "",
        strip_model_prefix: false,
        env_extras: &[],
        model_overrides: &[],
    },
    ProviderSpec {
        name: "groq",
        keywords: &["groq"],
        env_key: "GROQ_API_KEY",
        litellm_prefix: "groq",
        skip_prefixes: &["groq/"],
        is_gateway: false,
        is_local: false,
        detect_by_key_prefix: "",
        detect_by_base_keyword: "",
        default_api_base: "",
        strip_model_prefix: false,
        env_extras: &[],
        model_overrides: &[],
    },
];
261
262fn find_by_name(name: &str) -> Option<&'static ProviderSpec> {
263 PROVIDERS.iter().find(|spec| spec.name == name)
264}
265
266fn find_by_model(model: &str) -> Option<&'static ProviderSpec> {
267 let model_lower = model.to_lowercase();
268 PROVIDERS.iter().find(|spec| {
269 !spec.is_gateway
270 && !spec.is_local
271 && spec.keywords.iter().any(|kw| model_lower.contains(kw))
272 })
273}
274
275fn find_gateway(
276 provider_name: Option<&str>,
277 api_key: Option<&str>,
278 api_base: Option<&str>,
279) -> Option<&'static ProviderSpec> {
280 if let Some(name) = provider_name
281 && let Some(spec) = find_by_name(name)
282 && (spec.is_gateway || spec.is_local)
283 {
284 return Some(spec);
285 }
286
287 PROVIDERS.iter().find(|spec| {
288 let key_matches = !spec.detect_by_key_prefix.is_empty()
289 && api_key.is_some_and(|k| k.starts_with(spec.detect_by_key_prefix));
290 let base_matches = !spec.detect_by_base_keyword.is_empty()
291 && api_base.is_some_and(|b| b.contains(spec.detect_by_base_keyword));
292 key_matches || base_matches
293 })
294}
295
/// LLM provider backed by litellm-rs, with an OpenAI-compatible fallback
/// path for gateways, explicit API bases, and OpenAI models.
#[derive(Clone)]
pub struct LiteLLMProvider {
    api_key: String,
    // Explicit API base; when `None`, gateway/provider defaults apply.
    api_base: Option<String>,
    default_model: String,
    // Extra HTTP headers forwarded with each request.
    extra_headers: HashMap<String, String>,
    // Gateway/local spec detected at construction time, if any.
    gateway: Option<&'static ProviderSpec>,
}
304
impl LiteLLMProvider {
    /// Builds a provider, detecting any gateway/local backend from
    /// `provider_name`, the API key prefix, or the API base URL, and — when
    /// an API key is present — eagerly exporting provider env vars for the
    /// default model.
    pub fn new(
        api_key: impl Into<String>,
        api_base: Option<String>,
        default_model: impl Into<String>,
        extra_headers: Option<HashMap<String, String>>,
        provider_name: Option<&str>,
    ) -> Self {
        let api_key = api_key.into();
        let default_model = default_model.into();
        let gateway = find_gateway(
            provider_name,
            // An empty key must not participate in prefix detection.
            if api_key.is_empty() {
                None
            } else {
                Some(&api_key)
            },
            api_base.as_deref(),
        );

        let provider = Self {
            api_key,
            api_base,
            default_model,
            extra_headers: extra_headers.unwrap_or_default(),
            gateway,
        };

        if !provider.api_key.is_empty() {
            provider.setup_env(&provider.default_model);
        }

        provider
    }

    /// Maps a user-facing model name to the litellm model id.
    ///
    /// Gateway path: optionally keep only the last '/'-segment of the model
    /// (`strip_model_prefix`), then prepend the gateway's litellm prefix
    /// unless it is empty or already present. Direct path: prepend the
    /// keyword-matched provider's prefix unless one of its `skip_prefixes`
    /// is already there.
    fn resolve_model(&self, model: &str) -> String {
        if let Some(gateway) = self.gateway {
            let normalized = if gateway.strip_model_prefix {
                model.rsplit('/').next().unwrap_or(model)
            } else {
                model
            };
            if gateway.litellm_prefix.is_empty()
                || normalized.starts_with(&format!("{}/", gateway.litellm_prefix))
            {
                return normalized.to_string();
            }
            return format!("{}/{}", gateway.litellm_prefix, normalized);
        }

        if let Some(spec) = find_by_model(model)
            && !spec.litellm_prefix.is_empty()
            && !spec
                .skip_prefixes
                .iter()
                .any(|prefix| model.starts_with(prefix))
        {
            return format!("{}/{}", spec.litellm_prefix, model);
        }

        model.to_string()
    }

    /// Applies the first matching per-model override (e.g. Kimi k2.5 forces
    /// temperature 1.0). Overrides with `temperature: None` are skipped.
    fn apply_model_overrides(&self, model: &str, temperature: &mut f32) {
        let model_lower = model.to_lowercase();
        if let Some(spec) = find_by_model(model) {
            for rule in spec.model_overrides {
                if model_lower.contains(rule.pattern)
                    && let Some(temp) = rule.temperature
                {
                    *temperature = temp;
                    return;
                }
            }
        }
    }

    /// Resolves the API base: explicit base, then the gateway default, then
    /// the keyword-matched provider default.
    fn effective_api_base(&self, model: &str) -> Option<String> {
        if let Some(base) = &self.api_base {
            return Some(base.clone());
        }

        if let Some(gateway) = self.gateway
            && !gateway.default_api_base.is_empty()
        {
            return Some(gateway.default_api_base.to_string());
        }

        if let Some(spec) = find_by_model(model)
            && !spec.default_api_base.is_empty()
        {
            return Some(spec.default_api_base.to_string());
        }

        None
    }

    /// Exports the provider's env vars (API key plus any templated extras).
    /// The key is overwritten only for gateways; extras never overwrite an
    /// existing variable.
    fn setup_env(&self, model: &str) {
        let Some(spec) = self.gateway.or_else(|| find_by_model(model)) else {
            return;
        };

        if !spec.env_key.is_empty() {
            Self::set_env_var(spec.env_key, &self.api_key, self.gateway.is_some());
        }

        let effective_base = self.api_base.as_deref().unwrap_or(spec.default_api_base);
        for extra in spec.env_extras {
            let value = extra
                .value_template
                .replace("{api_key}", &self.api_key)
                .replace("{api_base}", effective_base);
            Self::set_env_var(extra.key, &value, false);
        }
    }

    /// True when the request should go through the OpenAI-compatible
    /// provider instead of litellm-rs: any gateway, any explicit API base,
    /// or a model matched to the "openai" spec.
    fn use_openai_compat_path(&self, model: &str) -> bool {
        if self.gateway.is_some() || self.api_base.is_some() {
            return true;
        }
        matches!(find_by_model(model), Some(spec) if spec.name == "openai")
    }

    /// Sets a process env var, skipping empty keys/values and — unless
    /// `overwrite` — variables that are already set.
    fn set_env_var(key: &str, value: &str, overwrite: bool) {
        if key.is_empty() || value.is_empty() {
            return;
        }
        if !overwrite && std::env::var_os(key).is_some() {
            return;
        }

        // SAFETY: `set_var` mutates process-global state; this runs during
        // provider construction, which is assumed to happen before other
        // threads read the environment — TODO confirm at call sites.
        unsafe { std::env::set_var(key, value) };
    }

    /// Best-effort conversion of a raw JSON chat message into a litellm
    /// `Message`: tries direct deserialization first, then reconstructs the
    /// message field by field with permissive defaults (unknown role → user).
    fn convert_message(raw: &Value) -> Message {
        if let Ok(message) = serde_json::from_value::<Message>(raw.clone()) {
            return message;
        }

        let role = match raw.get("role").and_then(Value::as_str).unwrap_or("user") {
            "system" => MessageRole::System,
            "assistant" => MessageRole::Assistant,
            "tool" => MessageRole::Tool,
            "function" => MessageRole::Function,
            _ => MessageRole::User,
        };

        // String content is taken verbatim; array content is deserialized as
        // multi-part content, silently dropped on failure.
        let content = match raw.get("content") {
            Some(Value::String(text)) => Some(MessageContent::Text(text.clone())),
            Some(Value::Array(parts)) => {
                serde_json::from_value::<MessageContent>(Value::Array(parts.clone())).ok()
            }
            _ => None,
        };

        let mut message = Message {
            role,
            content,
            ..Default::default()
        };

        if let Some(name) = raw.get("name").and_then(Value::as_str) {
            message.name = Some(name.to_string());
        }
        if let Some(tool_call_id) = raw.get("tool_call_id").and_then(Value::as_str) {
            message.tool_call_id = Some(tool_call_id.to_string());
        }
        if let Some(tool_calls) = raw.get("tool_calls")
            && let Ok(parsed) = serde_json::from_value(tool_calls.clone())
        {
            message.tool_calls = Some(parsed);
        }
        if let Some(function_call) = raw.get("function_call")
            && let Ok(parsed) = serde_json::from_value(function_call.clone())
        {
            message.function_call = Some(parsed);
        }

        message
    }

    /// Flattens message content to plain text. Multi-part content keeps text
    /// and tool-result parts (newline-joined) and drops everything else.
    fn content_to_text(content: &MessageContent) -> String {
        match content {
            MessageContent::Text(text) => text.clone(),
            MessageContent::Parts(parts) => {
                let chunks = parts
                    .iter()
                    .filter_map(|part| match part {
                        ContentPart::Text { text } => Some(text.clone()),
                        ContentPart::ToolResult { content, .. } => Some(content.to_string()),
                        _ => None,
                    })
                    .collect::<Vec<_>>();
                chunks.join("\n")
            }
        }
    }
}
505
#[async_trait]
impl LLMProvider for LiteLLMProvider {
    /// Runs one chat completion.
    ///
    /// Gateways, explicit API bases, and OpenAI models are delegated to the
    /// OpenAI-compatible provider; everything else goes through litellm-rs
    /// with the resolved (prefixed) model name, retrying once with the
    /// unresolved name if the first call fails.
    async fn chat(
        &self,
        messages: &[Value],
        tools: Option<&[Value]>,
        model: Option<&str>,
        max_tokens: u32,
        temperature: f32,
    ) -> Result<LLMResponse> {
        let selected_model = model.unwrap_or(&self.default_model);
        let mut effective_temperature = temperature;
        let resolved_model = self.resolve_model(selected_model);
        // NOTE(review): overrides are matched against the litellm-resolved
        // name but applied on both code paths below.
        self.apply_model_overrides(&resolved_model, &mut effective_temperature);

        // OpenAI-compatible path: delegate with the UNresolved model name.
        if self.use_openai_compat_path(selected_model) {
            let provider = OpenAICompatProvider::new(
                self.api_key.clone(),
                self.effective_api_base(selected_model),
                selected_model.to_string(),
                Some(self.extra_headers.clone()),
            );
            return provider
                .chat(
                    messages,
                    tools,
                    Some(selected_model),
                    max_tokens,
                    effective_temperature,
                )
                .await;
        }

        let chat_messages = messages
            .iter()
            .map(Self::convert_message)
            .collect::<Vec<_>>();
        let mut options = CompletionOptions {
            max_tokens: Some(max_tokens),
            temperature: Some(effective_temperature),
            api_key: if self.api_key.is_empty() {
                None
            } else {
                Some(self.api_key.clone())
            },
            api_base: self.effective_api_base(selected_model),
            headers: if self.extra_headers.is_empty() {
                None
            } else {
                Some(self.extra_headers.clone())
            },
            ..Default::default()
        };

        if let Some(tool_defs) = tools {
            // Tool defs that fail to parse as `Tool` are silently dropped.
            let parsed_tools = tool_defs
                .iter()
                .filter_map(|item| serde_json::from_value::<Tool>(item.clone()).ok())
                .collect::<Vec<_>>();
            if !parsed_tools.is_empty() {
                options.tools = Some(parsed_tools);
                options.tool_choice = Some(ToolChoice::String("auto".to_string()));
            }

            // Raw tool JSON is also mirrored into extra_params — presumably
            // for backends that read untyped params; TODO confirm.
            options
                .extra_params
                .insert("tools".to_string(), Value::Array(tool_defs.to_vec()));
            options
                .extra_params
                .insert("tool_choice".to_string(), Value::String("auto".to_string()));
        }

        // Try the resolved (prefixed) model first; on failure, retry once
        // with the original name when they differ, reporting both errors.
        let response = match completion(
            &resolved_model,
            chat_messages.clone(),
            Some(options.clone()),
        )
        .await
        {
            Ok(resp) => resp,
            Err(primary_err) => {
                if resolved_model != selected_model {
                    completion(selected_model, chat_messages, Some(options))
                        .await
                        .map_err(|fallback_err| {
                            anyhow::anyhow!(
                                "failed to call litellm-rs completion: primary={primary_err}; fallback={fallback_err}"
                            )
                        })?
                } else {
                    return Err(anyhow::anyhow!(
                        "failed to call litellm-rs completion: {primary_err}"
                    ));
                }
            }
        };

        // No choices: return an empty "stop" response rather than erroring.
        let Some(choice) = response.choices.first() else {
            return Ok(LLMResponse {
                content: None,
                tool_calls: Vec::new(),
                finish_reason: "stop".to_string(),
                usage: Map::new(),
                reasoning_content: None,
            });
        };

        let content = choice.message.content.as_ref().map(Self::content_to_text);
        let reasoning_content = choice
            .message
            .thinking
            .as_ref()
            .and_then(|thinking| thinking.as_text())
            .map(ToOwned::to_owned);
        let tool_calls = choice
            .message
            .tool_calls
            .clone()
            .unwrap_or_default()
            .into_iter()
            .map(|call| {
                // Non-JSON (or non-object) arguments are preserved under a
                // "raw" key instead of being discarded.
                let arguments = serde_json::from_str::<Value>(&call.function.arguments)
                    .ok()
                    .and_then(|v| v.as_object().cloned())
                    .unwrap_or_else(|| {
                        let mut fallback = Map::new();
                        fallback.insert("raw".to_string(), Value::String(call.function.arguments));
                        fallback
                    });

                ToolCallRequest {
                    id: call.id,
                    name: call.function.name,
                    arguments,
                }
            })
            .collect::<Vec<_>>();

        // Round-trip the finish-reason enum through serde to get its string
        // form; defaults to "stop".
        let finish_reason = choice
            .finish_reason
            .as_ref()
            .and_then(|reason| serde_json::to_value(reason).ok())
            .and_then(|v| v.as_str().map(ToOwned::to_owned))
            .unwrap_or_else(|| "stop".to_string());

        let usage = response
            .usage
            .and_then(|usage| serde_json::to_value(usage).ok())
            .and_then(|value| value.as_object().cloned())
            .unwrap_or_default();

        Ok(LLMResponse {
            content,
            tool_calls,
            finish_reason,
            usage,
            reasoning_content,
        })
    }

    fn default_model(&self) -> &str {
        &self.default_model
    }
}
673
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn gateway_detects_by_provider_name_and_key_prefix() {
        // An explicit provider name naming a gateway/local spec wins.
        let named = find_gateway(Some("vllm"), None, None).expect("expected vllm gateway");
        assert_eq!(named.name, "vllm");

        // Otherwise the "sk-or-" key prefix identifies OpenRouter.
        let keyed = find_gateway(None, Some("sk-or-test"), None).expect("expected openrouter");
        assert_eq!(keyed.name, "openrouter");
    }

    #[test]
    fn resolve_model_applies_gateway_and_provider_rules() {
        // AiHubMix strips the vendor prefix, then routes via "openai/".
        let hub = LiteLLMProvider::new(
            "",
            Some("https://aihubmix.com/v1".to_string()),
            "anthropic/claude-3-7-sonnet",
            None,
            Some("aihubmix"),
        );
        assert_eq!(
            hub.resolve_model("anthropic/claude-3-7-sonnet"),
            "openai/claude-3-7-sonnet"
        );

        // Direct providers get a keyword-derived prefix, exactly once.
        let direct = LiteLLMProvider::new("", None, "qwen-plus", None, None);
        assert_eq!(direct.resolve_model("qwen-plus"), "dashscope/qwen-plus");
        assert_eq!(
            direct.resolve_model("dashscope/qwen-plus"),
            "dashscope/qwen-plus"
        );

        // A named gateway prefixes even opaque model ids.
        let ark = LiteLLMProvider::new(
            "x",
            Some("https://ark.cn-beijing.volces.com/api/v3".to_string()),
            "doubao-seed-1-6-thinking-250715",
            None,
            Some("volcengine"),
        );
        assert_eq!(
            ark.resolve_model("doubao-seed-1-6-thinking-250715"),
            "volcengine/doubao-seed-1-6-thinking-250715"
        );
    }

    #[test]
    fn model_override_applies_kimi_temperature_floor() {
        let provider = LiteLLMProvider::new("", None, "kimi-k2.5", None, None);
        let mut temperature = 0.2;
        provider.apply_model_overrides("moonshot/kimi-k2.5", &mut temperature);
        assert!((temperature - 1.0).abs() < f32::EPSILON);
    }
}