// devsper_providers/litellm.rs

use devsper_core::{LlmProvider, LlmRequest, LlmResponse, LlmRole, StopReason};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tracing::debug;

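/// Provider that talks to a LiteLLM proxy over its OpenAI-compatible
/// `/v1/chat/completions` endpoint. Models are addressed with a
/// `litellm:` prefix, which is stripped before the request is forwarded.
///
/// A sketch of typical construction (the URL and key are illustrative;
/// 4000 is the LiteLLM proxy's default port):
///
/// ```ignore
/// let provider = LiteLlmProvider::new("http://localhost:4000", "sk-...");
/// ```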
pub struct LiteLlmProvider {
    client: Client,
    base_url: String,
    api_key: String,
}

impl LiteLlmProvider {
    pub fn new(base_url: impl Into<String>, api_key: impl Into<String>) -> Self {
        Self {
            client: Client::new(),
            base_url: base_url.into(),
            api_key: api_key.into(),
        }
    }
}

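// Minimal wire types for the OpenAI-compatible chat completions format the
// LiteLLM proxy exposes. Only the fields this provider reads are modeled;
// extra fields in the proxy's JSON are ignored by serde during
// deserialization.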
#[derive(Serialize)]
struct OaiRequest<'a> {
    model: &'a str,
    messages: Vec<OaiMessage<'a>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
}

#[derive(Serialize)]
struct OaiMessage<'a> {
    role: &'a str,
    content: &'a str,
}

#[derive(Deserialize)]
struct OaiResponse {
    choices: Vec<OaiChoice>,
    usage: OaiUsage,
    model: String,
}

#[derive(Deserialize)]
struct OaiChoice {
    message: OaiChoiceMessage,
    finish_reason: Option<String>,
}

#[derive(Deserialize)]
struct OaiChoiceMessage {
    content: Option<String>,
}

#[derive(Deserialize)]
struct OaiUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
}

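// Note: tool results are flattened into user messages below rather than sent
// with the OpenAI `tool` role, which would also require a `tool_call_id`.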
fn role_str(role: &LlmRole) -> &'static str {
    match role {
        LlmRole::System => "system",
        LlmRole::User | LlmRole::Tool => "user",
        LlmRole::Assistant => "assistant",
    }
}

#[async_trait]
impl LlmProvider for LiteLlmProvider {
    async fn generate(&self, req: LlmRequest) -> Result<LlmResponse> {
        use tracing::Instrument;

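        // Span attributes follow the OpenTelemetry GenAI semantic conventions;
        // the response and usage fields start Empty and are recorded once the
        // request has completed.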
        let span = tracing::info_span!(
            "gen_ai.chat",
            "gen_ai.system" = self.name(),
            "gen_ai.operation.name" = "chat",
            "gen_ai.request.model" = req.model.as_str(),
            "gen_ai.request.max_tokens" = req.max_tokens,
            "gen_ai.response.model" = tracing::field::Empty,
            "gen_ai.usage.input_tokens" = tracing::field::Empty,
            "gen_ai.usage.output_tokens" = tracing::field::Empty,
        );

        // Callers address models as "litellm:<model>"; the proxy expects the bare id.
        let model = req.model.strip_prefix("litellm:").unwrap_or(&req.model);

        let messages: Vec<OaiMessage> = req
            .messages
            .iter()
            .map(|m| OaiMessage {
                role: role_str(&m.role),
                content: &m.content,
            })
            .collect();

        let body = OaiRequest {
            model,
            messages,
            max_tokens: req.max_tokens,
            temperature: req.temperature,
        };

        debug!(model = %model, provider = "litellm", "LiteLLM proxy request");

        let result = async {
            let mut request = self
                .client
                .post(format!("{}/v1/chat/completions", self.base_url))
                .header("Content-Type", "application/json");

            // The Authorization header is only attached when a key is
            // configured, so the provider also works against an unauthenticated proxy.
            if !self.api_key.is_empty() {
                request = request.header("Authorization", format!("Bearer {}", self.api_key));
            }

            let resp = request.json(&body).send().await?;

            if !resp.status().is_success() {
                let status = resp.status();
                let text = resp.text().await.unwrap_or_default();
                return Err(anyhow!("litellm API error {status}: {text}"));
            }

            let data: OaiResponse = resp.json().await?;
            let choice = data
                .choices
                .into_iter()
                .next()
                .ok_or_else(|| anyhow!("No choices in response"))?;

            let stop_reason = match choice.finish_reason.as_deref() {
                Some("tool_calls") => StopReason::ToolUse,
                Some("length") => StopReason::MaxTokens,
                Some("stop") | None => StopReason::EndTurn,
                // Unrecognized finish reasons also fall back to a normal end of turn.
                _ => StopReason::EndTurn,
            };

            Ok(LlmResponse {
                content: choice.message.content.unwrap_or_default(),
                tool_calls: vec![],
                input_tokens: data.usage.prompt_tokens,
                output_tokens: data.usage.completion_tokens,
                model: data.model,
                stop_reason,
            })
        }
        .instrument(span.clone())
        .await;

        if let Ok(ref resp) = result {
            span.record("gen_ai.response.model", resp.model.as_str());
            span.record("gen_ai.usage.input_tokens", resp.input_tokens);
            span.record("gen_ai.usage.output_tokens", resp.output_tokens);
        }
        result
    }

    fn name(&self) -> &str {
        "litellm"
    }

    fn supports_model(&self, model: &str) -> bool {
        model.starts_with("litellm:")
    }
}
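
// A minimal test sketch covering the pure helpers above; it exercises no
// network I/O. The model id "gpt-4o" and the localhost URL are illustrative
// values, not anything this module pins down.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn roles_map_to_openai_strings() {
        assert_eq!(role_str(&LlmRole::System), "system");
        assert_eq!(role_str(&LlmRole::User), "user");
        assert_eq!(role_str(&LlmRole::Tool), "user");
        assert_eq!(role_str(&LlmRole::Assistant), "assistant");
    }

    #[test]
    fn only_prefixed_models_are_supported() {
        let provider = LiteLlmProvider::new("http://localhost:4000", "");
        assert!(provider.supports_model("litellm:gpt-4o"));
        assert!(!provider.supports_model("gpt-4o"));
    }
}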