devsper_providers/litellm.rs

use devsper_core::{LlmProvider, LlmRequest, LlmResponse, LlmRole, StopReason};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tracing::debug;

/// LiteLLM proxy provider — OpenAI-compatible, optional Bearer auth.
/// Expects model names prefixed with "litellm:" (e.g. "litellm:gpt-4o").
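///
/// # Example
///
/// A minimal construction sketch; the crate/module path below is an assumption
/// made for illustration and is not defined in this file.
///
/// ```ignore
/// use devsper_providers::litellm::LiteLlmProvider;
///
/// // Point at a local LiteLLM proxy; an empty key skips the Authorization header.
/// let provider = LiteLlmProvider::new("http://localhost:4000", "");
/// ```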
pub struct LiteLlmProvider {
    client: Client,
    base_url: String,
    api_key: String,
}

impl LiteLlmProvider {
    /// `base_url` is required (e.g. "http://localhost:4000").
    /// `api_key` may be empty — auth header is omitted when it is.
    pub fn new(base_url: impl Into<String>, api_key: impl Into<String>) -> Self {
        Self {
            client: Client::new(),
            base_url: base_url.into(),
            api_key: api_key.into(),
        }
    }
}

#[derive(Serialize)]
struct OaiRequest<'a> {
    model: &'a str,
    messages: Vec<OaiMessage<'a>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
}

#[derive(Serialize)]
struct OaiMessage<'a> {
    role: &'a str,
    content: &'a str,
}

#[derive(Deserialize)]
struct OaiResponse {
    choices: Vec<OaiChoice>,
    usage: OaiUsage,
    model: String,
}

#[derive(Deserialize)]
struct OaiChoice {
    message: OaiChoiceMessage,
    finish_reason: Option<String>,
}

#[derive(Deserialize)]
struct OaiChoiceMessage {
    content: Option<String>,
}

#[derive(Deserialize)]
struct OaiUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
}

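/// Maps core roles onto OpenAI-compatible role strings. Tool results are sent
/// as "user" messages because the request payload built here carries no tool
/// role or tool_call_id fields.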
fn role_str(role: &LlmRole) -> &'static str {
    match role {
        LlmRole::System => "system",
        LlmRole::User | LlmRole::Tool => "user",
        LlmRole::Assistant => "assistant",
    }
}

#[async_trait]
impl LlmProvider for LiteLlmProvider {
    async fn generate(&self, req: LlmRequest) -> Result<LlmResponse> {
        use tracing::Instrument;

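        // Span fields follow the OpenTelemetry gen_ai semantic conventions;
        // response-side fields start Empty and are recorded after the call.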
        let span = tracing::info_span!(
            "gen_ai.chat",
            "gen_ai.system" = self.name(),
            "gen_ai.operation.name" = "chat",
            "gen_ai.request.model" = req.model.as_str(),
            "gen_ai.request.max_tokens" = req.max_tokens,
            "gen_ai.response.model" = tracing::field::Empty,
            "gen_ai.usage.input_tokens" = tracing::field::Empty,
            "gen_ai.usage.output_tokens" = tracing::field::Empty,
        );

        // Strip "litellm:" prefix before sending to proxy
        let model = req.model.strip_prefix("litellm:").unwrap_or(&req.model);

        let messages: Vec<OaiMessage> = req
            .messages
            .iter()
            .map(|m| OaiMessage {
                role: role_str(&m.role),
                content: &m.content,
            })
            .collect();

        let body = OaiRequest {
            model,
            messages,
            max_tokens: req.max_tokens,
            temperature: req.temperature,
        };

        debug!(model = %model, provider = "litellm", "LiteLLM proxy request");

        let result = async {
            let mut request = self
                .client
                .post(format!("{}/v1/chat/completions", self.base_url))
                .header("Content-Type", "application/json");

            if !self.api_key.is_empty() {
                request = request.header("Authorization", format!("Bearer {}", self.api_key));
            }

            let resp = request.json(&body).send().await?;

            if !resp.status().is_success() {
                let status = resp.status();
                let text = resp.text().await.unwrap_or_default();
                return Err(anyhow!("litellm API error {status}: {text}"));
            }

            let data: OaiResponse = resp.json().await?;
            let choice = data
                .choices
                .into_iter()
                .next()
                .ok_or_else(|| anyhow!("No choices in response"))?;

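            // Map the proxy's finish_reason onto StopReason; unknown values
            // fall back to EndTurn.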
            let stop_reason = match choice.finish_reason.as_deref() {
                Some("tool_calls") => StopReason::ToolUse,
                Some("length") => StopReason::MaxTokens,
                Some("stop") | None => StopReason::EndTurn,
                _ => StopReason::EndTurn,
            };

            Ok(LlmResponse {
                content: choice.message.content.unwrap_or_default(),
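                // Tool call payloads are not parsed from the proxy response,
                // so callers always see an empty list here.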
                tool_calls: vec![],
                input_tokens: data.usage.prompt_tokens,
                output_tokens: data.usage.completion_tokens,
                model: data.model,
                stop_reason,
            })
        }
        .instrument(span.clone())
        .await;

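        // Backfill the response-side span fields once the request succeeds.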
        if let Ok(ref resp) = result {
            span.record("gen_ai.response.model", resp.model.as_str());
            span.record("gen_ai.usage.input_tokens", resp.input_tokens);
            span.record("gen_ai.usage.output_tokens", resp.output_tokens);
        }
        result
    }

    fn name(&self) -> &str {
        "litellm"
    }

    fn supports_model(&self, model: &str) -> bool {
        model.starts_with("litellm:")
    }
}