Skip to main content

litellm_rust/providers/
openai_compat.rs

1use crate::config::ProviderConfig;
2use crate::error::{LiteLLMError, Result};
3use crate::http::send_json;
4use crate::providers::resolve_api_key;
5use crate::stream::{parse_sse_stream, ChatStream};
6use crate::types::{
7    ChatRequest, ChatResponse, EmbeddingRequest, EmbeddingResponse, ImageData, ImageRequest,
8    ImageResponse, Usage, VideoRequest, VideoResponse,
9};
10use base64::{engine::general_purpose, Engine as _};
11use reqwest::multipart::Form;
12use reqwest::Client;
13use serde::Deserialize;
14use serde_json::Value;
15use std::time::Duration;
16use tokio::time::sleep;
17
/// Default cap on the number of status checks made while waiting for a video job.
///
/// Combined with [`DEFAULT_VIDEO_POLL_INTERVAL_SECS`] this gives a wait budget of
/// roughly ten minutes (120 checks spaced 5 seconds apart).
pub const DEFAULT_VIDEO_MAX_POLL_ATTEMPTS: u32 = 120;
/// Default pause, in seconds, between two consecutive video status checks.
pub const DEFAULT_VIDEO_POLL_INTERVAL_SECS: u64 = 5;
22
23#[derive(Debug, Deserialize)]
24struct OpenAIChatResponse {
25    id: Option<String>,
26    choices: Vec<OpenAIChoice>,
27    usage: Option<OpenAIUsage>,
28}
29
30#[derive(Debug, Deserialize)]
31struct OpenAIChoice {
32    message: OpenAIMessage,
33}
34
35#[derive(Debug, Deserialize)]
36struct OpenAIMessage {
37    content: Option<String>,
38}
39
40#[derive(Debug, Deserialize)]
41struct OpenAIUsage {
42    prompt_tokens: Option<u64>,
43    completion_tokens: Option<u64>,
44    total_tokens: Option<u64>,
45    cost: Option<Value>,
46    completion_tokens_details: Option<CompletionTokensDetails>,
47}
48
49#[derive(Debug, Deserialize)]
50struct CompletionTokensDetails {
51    reasoning_tokens: Option<u64>,
52}
53
54#[derive(Debug, Deserialize)]
55struct OpenAIEmbeddingResponse {
56    data: Vec<OpenAIEmbeddingItem>,
57    usage: Option<OpenAIUsage>,
58}
59
60#[derive(Debug, Deserialize)]
61struct OpenAIEmbeddingItem {
62    embedding: Vec<f32>,
63}
64
65/// Build the chat request body from a ChatRequest.
66///
67/// This is shared between streaming and non-streaming chat calls.
68fn build_chat_body(req: &ChatRequest, stream: bool) -> Value {
69    let mut body = serde_json::json!({
70        "model": req.model,
71        "messages": req.messages,
72    });
73
74    if stream {
75        body["stream"] = serde_json::json!(true);
76    }
77
78    if let Some(temp) = req.temperature {
79        body["temperature"] = serde_json::json!(temp);
80    }
81    if let Some(max_tokens) = req.max_tokens {
82        body["max_tokens"] = serde_json::json!(max_tokens);
83    }
84    if let Some(ref fmt) = req.response_format {
85        body["response_format"] = fmt.clone();
86    }
87    if let Some(max_completion_tokens) = req.max_completion_tokens {
88        body["max_completion_tokens"] = serde_json::json!(max_completion_tokens);
89    }
90    if let Some(ref tools) = req.tools {
91        body["tools"] = tools.clone();
92    }
93    if let Some(ref tool_choice) = req.tool_choice {
94        body["tool_choice"] = tool_choice.clone();
95    }
96    if let Some(parallel) = req.parallel_tool_calls {
97        body["parallel_tool_calls"] = serde_json::json!(parallel);
98    }
99    if let Some(ref stop) = req.stop {
100        body["stop"] = stop.clone();
101    }
102    if let Some(top_p) = req.top_p {
103        body["top_p"] = serde_json::json!(top_p);
104    }
105    if let Some(presence) = req.presence_penalty {
106        body["presence_penalty"] = serde_json::json!(presence);
107    }
108    if let Some(frequency) = req.frequency_penalty {
109        body["frequency_penalty"] = serde_json::json!(frequency);
110    }
111    if let Some(seed) = req.seed {
112        body["seed"] = serde_json::json!(seed);
113    }
114    if let Some(ref user) = req.user {
115        body["user"] = serde_json::json!(user);
116    }
117    if let Some(ref metadata) = req.metadata {
118        body["metadata"] = metadata.clone();
119    }
120    if let Some(ref reasoning_effort) = req.reasoning_effort {
121        body["reasoning_effort"] = reasoning_effort.clone();
122    }
123    if let Some(ref thinking) = req.thinking {
124        body["thinking"] = thinking.clone();
125    }
126
127    body
128}
129
130pub async fn chat(client: &Client, cfg: &ProviderConfig, req: ChatRequest) -> Result<ChatResponse> {
131    let base = cfg
132        .base_url
133        .clone()
134        .ok_or_else(|| LiteLLMError::Config("base_url required".into()))?;
135    let url = format!("{}/chat/completions", base.trim_end_matches('/'));
136    let key = resolve_api_key(cfg)?;
137
138    let body = build_chat_body(&req, false);
139
140    let mut builder = client.post(url).json(&body);
141    if let Some(key) = key {
142        builder = builder.bearer_auth(key);
143    }
144    for (k, v) in &cfg.extra_headers {
145        builder = builder.header(k, v);
146    }
147
148    let (parsed, headers) = send_json::<OpenAIChatResponse>(builder).await?;
149    let content = parsed
150        .choices
151        .first()
152        .and_then(|c| c.message.content.clone())
153        .unwrap_or_default();
154    let header_cost = headers
155        .get("x-litellm-response-cost")
156        .and_then(|v| v.to_str().ok())
157        .and_then(|v| v.parse::<f64>().ok());
158    let mut usage = map_usage(parsed.usage);
159    if usage.cost_usd.is_none() {
160        usage.cost_usd = header_cost;
161    }
162
163    Ok(ChatResponse {
164        content,
165        usage,
166        response_id: parsed.id,
167        header_cost,
168        raw: None,
169    })
170}
171
172pub async fn chat_stream(
173    client: &Client,
174    cfg: &ProviderConfig,
175    req: ChatRequest,
176) -> Result<ChatStream> {
177    let base = cfg
178        .base_url
179        .clone()
180        .ok_or_else(|| LiteLLMError::Config("base_url required".into()))?;
181    let url = format!("{}/chat/completions", base.trim_end_matches('/'));
182    let key = resolve_api_key(cfg)?;
183
184    let body = build_chat_body(&req, true);
185
186    let mut builder = client.post(url).json(&body);
187    if let Some(key) = key {
188        builder = builder.bearer_auth(key);
189    }
190    for (k, v) in &cfg.extra_headers {
191        builder = builder.header(k, v);
192    }
193
194    let resp = builder.send().await.map_err(LiteLLMError::from)?;
195    let status = resp.status();
196    if !status.is_success() {
197        let text = resp.text().await.map_err(LiteLLMError::from)?;
198        return Err(LiteLLMError::http(format!(
199            "http {}: {}",
200            status.as_u16(),
201            text
202        )));
203    }
204
205    Ok(parse_sse_stream(resp.bytes_stream()))
206}
207
208pub async fn embeddings(
209    client: &Client,
210    cfg: &ProviderConfig,
211    req: EmbeddingRequest,
212) -> Result<EmbeddingResponse> {
213    let base = cfg
214        .base_url
215        .clone()
216        .ok_or_else(|| LiteLLMError::Config("base_url required".into()))?;
217    let url = format!("{}/embeddings", base.trim_end_matches('/'));
218    let key = resolve_api_key(cfg)?;
219
220    let body = serde_json::json!({
221        "model": req.model,
222        "input": req.input,
223    });
224
225    let mut builder = client.post(url).json(&body);
226    if let Some(key) = key {
227        builder = builder.bearer_auth(key);
228    }
229    for (k, v) in &cfg.extra_headers {
230        builder = builder.header(k, v);
231    }
232
233    let (parsed, _headers) = send_json::<OpenAIEmbeddingResponse>(builder).await?;
234    let vectors = parsed.data.into_iter().map(|d| d.embedding).collect();
235
236    Ok(EmbeddingResponse {
237        vectors,
238        usage: map_usage(parsed.usage),
239        raw: None,
240    })
241}
242
243pub async fn image_generation(
244    client: &Client,
245    cfg: &ProviderConfig,
246    req: ImageRequest,
247) -> Result<ImageResponse> {
248    let base = cfg
249        .base_url
250        .clone()
251        .ok_or_else(|| LiteLLMError::Config("base_url required".into()))?;
252    let url = format!("{}/images/generations", base.trim_end_matches('/'));
253    let key = resolve_api_key(cfg)?;
254
255    let mut body = serde_json::json!({
256        "model": req.model,
257        "prompt": req.prompt,
258    });
259    if let Some(n) = req.n {
260        body["n"] = serde_json::json!(n);
261    }
262    if let Some(ref size) = req.size {
263        body["size"] = serde_json::json!(size);
264    }
265    if let Some(ref quality) = req.quality {
266        body["quality"] = serde_json::json!(quality);
267    }
268    if let Some(ref background) = req.background {
269        body["background"] = serde_json::json!(background);
270    }
271
272    let mut builder = client.post(url).json(&body);
273    if let Some(key) = key {
274        builder = builder.bearer_auth(key);
275    }
276    for (k, v) in &cfg.extra_headers {
277        builder = builder.header(k, v);
278    }
279
280    let (parsed, _headers) = send_json::<Value>(builder).await?;
281    let images = parsed
282        .get("data")
283        .and_then(|v| v.as_array())
284        .map(|arr| {
285            arr.iter()
286                .map(|item| ImageData {
287                    b64_json: item
288                        .get("b64_json")
289                        .and_then(|v| v.as_str())
290                        .map(|s| s.to_string()),
291                    url: item
292                        .get("url")
293                        .and_then(|v| v.as_str())
294                        .map(|s| s.to_string()),
295                    revised_prompt: item
296                        .get("revised_prompt")
297                        .and_then(|v| v.as_str())
298                        .map(|s| s.to_string()),
299                    mime_type: None,
300                })
301                .collect::<Vec<_>>()
302        })
303        .unwrap_or_default();
304
305    Ok(ImageResponse {
306        images,
307        usage: Usage::default(),
308        raw: None,
309    })
310}
311
/// Polling budget used while waiting for a video generation job to finish.
#[derive(Clone, Debug)]
pub struct VideoGenerationOptions {
    /// Upper bound on the number of status checks before giving up.
    pub max_poll_attempts: u32,
    /// Pause between two consecutive status checks, in seconds.
    pub poll_interval_secs: u64,
}
320
321impl Default for VideoGenerationOptions {
322    fn default() -> Self {
323        Self {
324            max_poll_attempts: DEFAULT_VIDEO_MAX_POLL_ATTEMPTS,
325            poll_interval_secs: DEFAULT_VIDEO_POLL_INTERVAL_SECS,
326        }
327    }
328}
329
330pub async fn video_generation(
331    client: &Client,
332    cfg: &ProviderConfig,
333    req: VideoRequest,
334) -> Result<VideoResponse> {
335    video_generation_with_options(client, cfg, req, VideoGenerationOptions::default()).await
336}
337
338pub async fn video_generation_with_options(
339    client: &Client,
340    cfg: &ProviderConfig,
341    req: VideoRequest,
342    options: VideoGenerationOptions,
343) -> Result<VideoResponse> {
344    let base = cfg
345        .base_url
346        .clone()
347        .ok_or_else(|| LiteLLMError::Config("base_url required".into()))?;
348    let url = format!("{}/videos", base.trim_end_matches('/'));
349    let key = resolve_api_key(cfg)?;
350
351    let mut form = Form::new()
352        .text("model", req.model)
353        .text("prompt", req.prompt);
354    if let Some(seconds) = req.seconds {
355        form = form.text("seconds", seconds.to_string());
356    }
357    if let Some(size) = req.size {
358        form = form.text("size", size);
359    }
360
361    let mut builder = client.post(url).multipart(form);
362    if let Some(ref key) = key {
363        builder = builder.bearer_auth(key.clone());
364    }
365    for (k, v) in &cfg.extra_headers {
366        builder = builder.header(k, v);
367    }
368
369    let (parsed, _headers) = send_json::<Value>(builder).await?;
370    let video_id = parsed
371        .get("id")
372        .and_then(|v| v.as_str())
373        .ok_or_else(|| LiteLLMError::Parse("missing video id".into()))?;
374
375    let status_url = format!("{}/videos/{}", base.trim_end_matches('/'), video_id);
376    let poll_interval = Duration::from_secs(options.poll_interval_secs);
377
378    for attempt in 0..options.max_poll_attempts {
379        let mut status_builder = client.get(&status_url);
380        if let Some(ref key) = key {
381            status_builder = status_builder.bearer_auth(key.clone());
382        }
383        let (status_resp, _headers) = send_json::<Value>(status_builder).await?;
384        let status = status_resp
385            .get("status")
386            .and_then(|v| v.as_str())
387            .unwrap_or("unknown");
388
389        match status {
390            "completed" => {
391                return fetch_video_content(client, &base, video_id, key.as_deref()).await;
392            }
393            "failed" => {
394                let msg = status_resp
395                    .get("error")
396                    .and_then(|v| v.as_str())
397                    .unwrap_or("video generation failed");
398                return Err(LiteLLMError::http(msg.to_string()));
399            }
400            _ => {
401                if attempt + 1 >= options.max_poll_attempts {
402                    return Err(LiteLLMError::http(format!(
403                        "video generation timed out after {} attempts",
404                        options.max_poll_attempts
405                    )));
406                }
407                sleep(poll_interval).await;
408            }
409        }
410    }
411
412    Err(LiteLLMError::http("video generation timed out"))
413}
414
415async fn fetch_video_content(
416    client: &Client,
417    base: &str,
418    video_id: &str,
419    key: Option<&str>,
420) -> Result<VideoResponse> {
421    let content_url = format!("{}/videos/{}/content", base.trim_end_matches('/'), video_id);
422    let mut content_builder = client.get(&content_url);
423    if let Some(key) = key {
424        content_builder = content_builder.bearer_auth(key);
425    }
426
427    let bytes = content_builder
428        .send()
429        .await
430        .map_err(LiteLLMError::from)?
431        .bytes()
432        .await
433        .map_err(LiteLLMError::from)?;
434    let b64 = general_purpose::STANDARD.encode(bytes);
435
436    Ok(VideoResponse {
437        video_url: Some(format!("data:video/mp4;base64,{b64}")),
438        raw: None,
439    })
440}
441
442fn map_usage(usage: Option<OpenAIUsage>) -> Usage {
443    usage.map_or_else(Usage::default, |u| Usage {
444        prompt_tokens: u.prompt_tokens,
445        completion_tokens: u.completion_tokens,
446        thoughts_tokens: u.completion_tokens_details.and_then(|d| d.reasoning_tokens),
447        total_tokens: u.total_tokens,
448        cost_usd: parse_cost(u.cost.as_ref()),
449    })
450}
451
452fn parse_cost(value: Option<&Value>) -> Option<f64> {
453    let v = value?;
454    if let Some(n) = v.as_f64() {
455        return Some(n);
456    }
457    if let Some(s) = v.as_str() {
458        return s.parse::<f64>().ok();
459    }
460    None
461}