use crate::config::ProviderConfig;
use crate::error::{LiteLLMError, Result};
use crate::http::send_json;
use crate::providers::resolve_api_key;
use crate::stream::{parse_sse_stream, ChatStream};
use crate::types::{
    ChatRequest, ChatResponse, EmbeddingRequest, EmbeddingResponse, ImageData, ImageRequest,
    ImageResponse, Usage, VideoRequest, VideoResponse,
};
use base64::{engine::general_purpose, Engine as _};
use reqwest::multipart::Form;
use reqwest::Client;
use serde::Deserialize;
use serde_json::Value;
use std::time::Duration;
use tokio::time::sleep;

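/// Default maximum number of times a video generation job is polled before timing out.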
pub const DEFAULT_VIDEO_MAX_POLL_ATTEMPTS: u32 = 120;
/// Default delay, in seconds, between video generation status polls.
pub const DEFAULT_VIDEO_POLL_INTERVAL_SECS: u64 = 5;

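/// Subset of an OpenAI-compatible `/chat/completions` response; only the fields this
/// client actually reads are deserialized.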
#[derive(Debug, Deserialize)]
struct OpenAIChatResponse {
    id: Option<String>,
    choices: Vec<OpenAIChoice>,
    usage: Option<OpenAIUsage>,
}

#[derive(Debug, Deserialize)]
struct OpenAIChoice {
    message: OpenAIMessage,
}

#[derive(Debug, Deserialize)]
struct OpenAIMessage {
    content: Option<String>,
}

#[derive(Debug, Deserialize)]
struct OpenAIUsage {
    prompt_tokens: Option<u64>,
    completion_tokens: Option<u64>,
    total_tokens: Option<u64>,
    cost: Option<Value>,
    completion_tokens_details: Option<CompletionTokensDetails>,
}

#[derive(Debug, Deserialize)]
struct CompletionTokensDetails {
    reasoning_tokens: Option<u64>,
}

#[derive(Debug, Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbeddingItem>,
    usage: Option<OpenAIUsage>,
}

#[derive(Debug, Deserialize)]
struct OpenAIEmbeddingItem {
    embedding: Vec<f32>,
}

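/// Builds the JSON body for a `/chat/completions` request from a [`ChatRequest`].
///
/// `model` and `messages` are always included; `stream` is added only for streaming
/// calls, and every optional field is copied over only when the request sets it, so
/// unset options are omitted rather than sent as `null`. For example, a request that
/// only sets a temperature serializes roughly as:
///
/// ```json
/// { "model": "<model>", "messages": [ ... ], "temperature": 0.2 }
/// ```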
fn build_chat_body(req: &ChatRequest, stream: bool) -> Value {
    let mut body = serde_json::json!({
        "model": req.model,
        "messages": req.messages,
    });

    if stream {
        body["stream"] = serde_json::json!(true);
    }

    if let Some(temp) = req.temperature {
        body["temperature"] = serde_json::json!(temp);
    }
    if let Some(max_tokens) = req.max_tokens {
        body["max_tokens"] = serde_json::json!(max_tokens);
    }
    if let Some(ref fmt) = req.response_format {
        body["response_format"] = fmt.clone();
    }
    if let Some(max_completion_tokens) = req.max_completion_tokens {
        body["max_completion_tokens"] = serde_json::json!(max_completion_tokens);
    }
    if let Some(ref tools) = req.tools {
        body["tools"] = tools.clone();
    }
    if let Some(ref tool_choice) = req.tool_choice {
        body["tool_choice"] = tool_choice.clone();
    }
    if let Some(parallel) = req.parallel_tool_calls {
        body["parallel_tool_calls"] = serde_json::json!(parallel);
    }
    if let Some(ref stop) = req.stop {
        body["stop"] = stop.clone();
    }
    if let Some(top_p) = req.top_p {
        body["top_p"] = serde_json::json!(top_p);
    }
    if let Some(presence) = req.presence_penalty {
        body["presence_penalty"] = serde_json::json!(presence);
    }
    if let Some(frequency) = req.frequency_penalty {
        body["frequency_penalty"] = serde_json::json!(frequency);
    }
    if let Some(seed) = req.seed {
        body["seed"] = serde_json::json!(seed);
    }
    if let Some(ref user) = req.user {
        body["user"] = serde_json::json!(user);
    }
    if let Some(ref metadata) = req.metadata {
        body["metadata"] = metadata.clone();
    }
    if let Some(ref reasoning_effort) = req.reasoning_effort {
        body["reasoning_effort"] = reasoning_effort.clone();
    }
    if let Some(ref thinking) = req.thinking {
        body["thinking"] = thinking.clone();
    }

    body
}

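/// Sends a non-streaming chat completion request to `{base_url}/chat/completions`.
///
/// The bearer token (when one resolves) and any configured extra headers are attached,
/// and the response cost is taken from the usage payload or, failing that, from the
/// `x-litellm-response-cost` response header.
///
/// A minimal usage sketch, assuming `client`, `cfg`, and `req` are built elsewhere:
///
/// ```ignore
/// let resp = chat(&client, &cfg, req).await?;
/// println!("{}", resp.content);
/// ```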
pub async fn chat(client: &Client, cfg: &ProviderConfig, req: ChatRequest) -> Result<ChatResponse> {
    let base = cfg
        .base_url
        .clone()
        .ok_or_else(|| LiteLLMError::Config("base_url required".into()))?;
    let url = format!("{}/chat/completions", base.trim_end_matches('/'));
    let key = resolve_api_key(cfg)?;

    let body = build_chat_body(&req, false);

    let mut builder = client.post(url).json(&body);
    if let Some(key) = key {
        builder = builder.bearer_auth(key);
    }
    for (k, v) in &cfg.extra_headers {
        builder = builder.header(k, v);
    }

    let (parsed, headers) = send_json::<OpenAIChatResponse>(builder).await?;
    let content = parsed
        .choices
        .first()
        .and_then(|c| c.message.content.clone())
        .unwrap_or_default();
    let header_cost = headers
        .get("x-litellm-response-cost")
        .and_then(|v| v.to_str().ok())
        .and_then(|v| v.parse::<f64>().ok());
    let mut usage = map_usage(parsed.usage);
    if usage.cost_usd.is_none() {
        usage.cost_usd = header_cost;
    }

    Ok(ChatResponse {
        content,
        usage,
        response_id: parsed.id,
        header_cost,
        raw: None,
    })
}

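/// Sends a streaming chat completion request and returns the parsed SSE stream.
///
/// Non-2xx responses are read to completion and surfaced as an HTTP error instead of
/// being handed to the SSE parser.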
pub async fn chat_stream(
    client: &Client,
    cfg: &ProviderConfig,
    req: ChatRequest,
) -> Result<ChatStream> {
    let base = cfg
        .base_url
        .clone()
        .ok_or_else(|| LiteLLMError::Config("base_url required".into()))?;
    let url = format!("{}/chat/completions", base.trim_end_matches('/'));
    let key = resolve_api_key(cfg)?;

    let body = build_chat_body(&req, true);

    let mut builder = client.post(url).json(&body);
    if let Some(key) = key {
        builder = builder.bearer_auth(key);
    }
    for (k, v) in &cfg.extra_headers {
        builder = builder.header(k, v);
    }

    let resp = builder.send().await.map_err(LiteLLMError::from)?;
    let status = resp.status();
    if !status.is_success() {
        let text = resp.text().await.map_err(LiteLLMError::from)?;
        return Err(LiteLLMError::http(format!(
            "http {}: {}",
            status.as_u16(),
            text
        )));
    }

    Ok(parse_sse_stream(resp.bytes_stream()))
}

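/// Requests embeddings from `{base_url}/embeddings` and returns one vector per input item.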
pub async fn embeddings(
    client: &Client,
    cfg: &ProviderConfig,
    req: EmbeddingRequest,
) -> Result<EmbeddingResponse> {
    let base = cfg
        .base_url
        .clone()
        .ok_or_else(|| LiteLLMError::Config("base_url required".into()))?;
    let url = format!("{}/embeddings", base.trim_end_matches('/'));
    let key = resolve_api_key(cfg)?;

    let body = serde_json::json!({
        "model": req.model,
        "input": req.input,
    });

    let mut builder = client.post(url).json(&body);
    if let Some(key) = key {
        builder = builder.bearer_auth(key);
    }
    for (k, v) in &cfg.extra_headers {
        builder = builder.header(k, v);
    }

    let (parsed, _headers) = send_json::<OpenAIEmbeddingResponse>(builder).await?;
    let vectors = parsed.data.into_iter().map(|d| d.embedding).collect();

    Ok(EmbeddingResponse {
        vectors,
        usage: map_usage(parsed.usage),
        raw: None,
    })
}

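/// Requests image generation from `{base_url}/images/generations`.
///
/// Optional fields (`n`, `size`, `quality`, `background`) are forwarded only when set.
/// The response is parsed leniently: each returned item keeps whichever of `b64_json`,
/// `url`, and `revised_prompt` the provider supplied.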
pub async fn image_generation(
    client: &Client,
    cfg: &ProviderConfig,
    req: ImageRequest,
) -> Result<ImageResponse> {
    let base = cfg
        .base_url
        .clone()
        .ok_or_else(|| LiteLLMError::Config("base_url required".into()))?;
    let url = format!("{}/images/generations", base.trim_end_matches('/'));
    let key = resolve_api_key(cfg)?;

    let mut body = serde_json::json!({
        "model": req.model,
        "prompt": req.prompt,
    });
    if let Some(n) = req.n {
        body["n"] = serde_json::json!(n);
    }
    if let Some(ref size) = req.size {
        body["size"] = serde_json::json!(size);
    }
    if let Some(ref quality) = req.quality {
        body["quality"] = serde_json::json!(quality);
    }
    if let Some(ref background) = req.background {
        body["background"] = serde_json::json!(background);
    }

    let mut builder = client.post(url).json(&body);
    if let Some(key) = key {
        builder = builder.bearer_auth(key);
    }
    for (k, v) in &cfg.extra_headers {
        builder = builder.header(k, v);
    }

    let (parsed, _headers) = send_json::<Value>(builder).await?;
    let images = parsed
        .get("data")
        .and_then(|v| v.as_array())
        .map(|arr| {
            arr.iter()
                .map(|item| ImageData {
                    b64_json: item
                        .get("b64_json")
                        .and_then(|v| v.as_str())
                        .map(|s| s.to_string()),
                    url: item
                        .get("url")
                        .and_then(|v| v.as_str())
                        .map(|s| s.to_string()),
                    revised_prompt: item
                        .get("revised_prompt")
                        .and_then(|v| v.as_str())
                        .map(|s| s.to_string()),
                    mime_type: None,
                })
                .collect::<Vec<_>>()
        })
        .unwrap_or_default();

    Ok(ImageResponse {
        images,
        usage: Usage::default(),
        raw: None,
    })
}

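/// Polling behaviour for asynchronous video generation: `max_poll_attempts` bounds how
/// many times the job status is checked and `poll_interval_secs` is the delay between
/// checks. The defaults allow roughly ten minutes of polling.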
#[derive(Debug, Clone)]
pub struct VideoGenerationOptions {
    pub max_poll_attempts: u32,
    pub poll_interval_secs: u64,
}

impl Default for VideoGenerationOptions {
    fn default() -> Self {
        Self {
            max_poll_attempts: DEFAULT_VIDEO_MAX_POLL_ATTEMPTS,
            poll_interval_secs: DEFAULT_VIDEO_POLL_INTERVAL_SECS,
        }
    }
}

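/// Generates a video using the default polling options
/// ([`VideoGenerationOptions::default`]).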
pub async fn video_generation(
    client: &Client,
    cfg: &ProviderConfig,
    req: VideoRequest,
) -> Result<VideoResponse> {
    video_generation_with_options(client, cfg, req, VideoGenerationOptions::default()).await
}

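/// Generates a video with explicit polling options.
///
/// The job is submitted as multipart form data to `{base_url}/videos`, its status is
/// polled until it reports `completed` or `failed`, and the finished video is fetched
/// and returned as a base64 data URL.
///
/// A minimal sketch, assuming `client`, `cfg`, and `req` are built elsewhere:
///
/// ```ignore
/// let options = VideoGenerationOptions {
///     max_poll_attempts: 30,
///     poll_interval_secs: 2,
/// };
/// let video = video_generation_with_options(&client, &cfg, req, options).await?;
/// ```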
pub async fn video_generation_with_options(
    client: &Client,
    cfg: &ProviderConfig,
    req: VideoRequest,
    options: VideoGenerationOptions,
) -> Result<VideoResponse> {
    let base = cfg
        .base_url
        .clone()
        .ok_or_else(|| LiteLLMError::Config("base_url required".into()))?;
    let url = format!("{}/videos", base.trim_end_matches('/'));
    let key = resolve_api_key(cfg)?;

    let mut form = Form::new()
        .text("model", req.model)
        .text("prompt", req.prompt);
    if let Some(seconds) = req.seconds {
        form = form.text("seconds", seconds.to_string());
    }
    if let Some(size) = req.size {
        form = form.text("size", size);
    }

    let mut builder = client.post(url).multipart(form);
    if let Some(ref key) = key {
        builder = builder.bearer_auth(key.clone());
    }
    for (k, v) in &cfg.extra_headers {
        builder = builder.header(k, v);
    }

    let (parsed, _headers) = send_json::<Value>(builder).await?;
    let video_id = parsed
        .get("id")
        .and_then(|v| v.as_str())
        .ok_or_else(|| LiteLLMError::Parse("missing video id".into()))?;

    let status_url = format!("{}/videos/{}", base.trim_end_matches('/'), video_id);
    let poll_interval = Duration::from_secs(options.poll_interval_secs);

    for attempt in 0..options.max_poll_attempts {
        let mut status_builder = client.get(&status_url);
        if let Some(ref key) = key {
            status_builder = status_builder.bearer_auth(key.clone());
        }
        let (status_resp, _headers) = send_json::<Value>(status_builder).await?;
        let status = status_resp
            .get("status")
            .and_then(|v| v.as_str())
            .unwrap_or("unknown");

        match status {
            "completed" => {
                return fetch_video_content(client, &base, video_id, key.as_deref()).await;
            }
            "failed" => {
                let msg = status_resp
                    .get("error")
                    .and_then(|v| v.as_str())
                    .unwrap_or("video generation failed");
                return Err(LiteLLMError::http(msg.to_string()));
            }
            _ => {
                if attempt + 1 >= options.max_poll_attempts {
                    return Err(LiteLLMError::http(format!(
                        "video generation timed out after {} attempts",
                        options.max_poll_attempts
                    )));
                }
                sleep(poll_interval).await;
            }
        }
    }

    Err(LiteLLMError::http("video generation timed out".to_string()))
}

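/// Downloads the finished video from `{base_url}/videos/{id}/content` and returns it as
/// a `data:video/mp4;base64,...` URL in [`VideoResponse::video_url`].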
async fn fetch_video_content(
    client: &Client,
    base: &str,
    video_id: &str,
    key: Option<&str>,
) -> Result<VideoResponse> {
    let content_url = format!("{}/videos/{}/content", base.trim_end_matches('/'), video_id);
    let mut content_builder = client.get(&content_url);
    if let Some(key) = key {
        content_builder = content_builder.bearer_auth(key);
    }

    let bytes = content_builder
        .send()
        .await
        .map_err(LiteLLMError::from)?
        .bytes()
        .await
        .map_err(LiteLLMError::from)?;
    let b64 = general_purpose::STANDARD.encode(bytes);

    Ok(VideoResponse {
        video_url: Some(format!("data:video/mp4;base64,{b64}")),
        raw: None,
    })
}

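/// Converts the provider usage block into the crate's [`Usage`], pulling reasoning
/// tokens out of `completion_tokens_details` and the cost out of the `cost` field when
/// present.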
fn map_usage(usage: Option<OpenAIUsage>) -> Usage {
    usage.map_or_else(Usage::default, |u| Usage {
        prompt_tokens: u.prompt_tokens,
        completion_tokens: u.completion_tokens,
        thoughts_tokens: u.completion_tokens_details.and_then(|d| d.reasoning_tokens),
        total_tokens: u.total_tokens,
        cost_usd: parse_cost(u.cost.as_ref()),
    })
}

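/// Parses a cost value that providers may report either as a JSON number or as a
/// numeric string.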
fn parse_cost(value: Option<&Value>) -> Option<f64> {
    let v = value?;
    if let Some(n) = v.as_f64() {
        return Some(n);
    }
    if let Some(s) = v.as_str() {
        return s.parse::<f64>().ok();
    }
    None
}