1use crate::core::{
2 Message, MessageAttachment, Provider, ProviderRequest, ProviderResponse, ProviderStreamEvent,
3 Role, ToolCall,
4};
5use crate::provider::StreamedToolCall;
6use anyhow::{Context, bail};
7use async_trait::async_trait;
8use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue};
9use serde_json::{Value, json};
10use std::env;
11
/// Keys probed, in this order, on response messages and stream deltas for
/// reasoning ("thinking") text.
const THINKING_FIELDS: [&str; 3] = ["reasoning", "thinking", "reasoning_content"];
13
/// Chat-completions client for providers exposing an OpenAI-style HTTP API.
pub struct OpenAiCompatibleProvider {
    /// Base URL of the API; `/chat/completions` is appended to form the endpoint.
    base_url: String,
    /// Model used when a request does not name one.
    model: String,
    /// Name of the environment variable that holds the API key.
    api_key_env: String,
    /// HTTP client used for every request.
    client: reqwest::Client,
}
20
21impl OpenAiCompatibleProvider {
22 pub fn new(base_url: String, model: String, api_key_env: String) -> Self {
23 Self {
24 base_url,
25 model,
26 api_key_env,
27 client: reqwest::Client::new(),
28 }
29 }
30
31 fn endpoint(&self) -> String {
32 format!("{}/chat/completions", self.base_url.trim_end_matches('/'))
33 }
34
35 fn auth_headers(&self) -> anyhow::Result<HeaderMap> {
36 let api_key = env::var(&self.api_key_env)
37 .with_context(|| format!("missing API key env var {}", self.api_key_env))?;
38 let mut headers = HeaderMap::new();
39 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
40 headers.insert(
41 AUTHORIZATION,
42 HeaderValue::from_str(&format!("Bearer {}", api_key))?,
43 );
44 Ok(headers)
45 }
46
47 fn request_body(
48 &self,
49 req: &ProviderRequest,
50 stream: bool,
51 image_url_as_object: bool,
52 image_data_format: ImageDataFormat,
53 include_tools: bool,
54 ) -> Value {
55 let requested_model = if req.model.is_empty() {
56 self.model.as_str()
57 } else {
58 req.model.as_str()
59 };
60
61 let tools = if include_tools {
62 req.tools
63 .iter()
64 .map(|t| {
65 json!({
66 "type": "function",
67 "function": {
68 "name": t.name,
69 "description": t.description,
70 "parameters": t.parameters,
71 }
72 })
73 })
74 .collect::<Vec<_>>()
75 } else {
76 Vec::new()
77 };
78
79 let messages = req
80 .messages
81 .iter()
82 .map(|message| message_to_wire(message, image_url_as_object, image_data_format))
83 .collect::<Vec<_>>();
84
85 let mut body = json!({
86 "model": requested_model,
87 "messages": messages,
88 "stream": stream,
89 });
90 if include_tools && !tools.is_empty() {
91 body["tools"] = json!(tools);
92 body["tool_choice"] = json!("auto");
93 }
94 body
95 }
96
97 fn parse_chat_response(value: &Value) -> anyhow::Result<ProviderResponse> {
98 let choice = value
99 .get("choices")
100 .and_then(|v| v.as_array())
101 .and_then(|a| a.first())
102 .context("provider response missing choices[0]")?;
103
104 let message = choice
105 .get("message")
106 .and_then(|m| m.as_object())
107 .context("provider response missing message")?;
108
109 let content = message
110 .get("content")
111 .and_then(|c| c.as_str())
112 .unwrap_or_default()
113 .to_string();
114
115 let thinking = extract_thinking(message);
116 let tool_calls = parse_tool_calls(message)?;
117 let context_tokens = parse_context_tokens(value);
118
119 Ok(ProviderResponse {
120 assistant_message: Message {
121 role: Role::Assistant,
122 content,
123 attachments: Vec::new(),
124 tool_call_id: None,
125 },
126 done: tool_calls.is_empty(),
127 tool_calls,
128 thinking,
129 context_tokens,
130 })
131 }
132
133 async fn send_request(
134 &self,
135 req: &ProviderRequest,
136 stream: bool,
137 error_context: &str,
138 ) -> anyhow::Result<reqwest::Response> {
139 let primary_body = self.request_body(req, stream, true, ImageDataFormat::DataUrl, true);
140
141 let primary = self
142 .client
143 .post(self.endpoint())
144 .headers(self.auth_headers()?)
145 .json(&primary_body)
146 .send()
147 .await
148 .with_context(|| error_context.to_string())?;
149
150 if primary.status().is_success() {
151 return Ok(primary);
152 }
153
154 let primary_status = primary.status();
155 let primary_error = primary.text().await.unwrap_or_default();
156 if has_image_attachments(req)
157 && should_retry_for_image_payload(primary_status, &primary_error)
158 {
159 let no_tools_body =
160 self.request_body(req, stream, true, ImageDataFormat::DataUrl, false);
161 let fallback_no_tools = self
162 .client
163 .post(self.endpoint())
164 .headers(self.auth_headers()?)
165 .json(&no_tools_body)
166 .send()
167 .await
168 .with_context(|| format!("{} (fallback: no tools)", error_context))?;
169
170 if fallback_no_tools.status().is_success() {
171 return Ok(fallback_no_tools);
172 }
173
174 let no_tools_status = fallback_no_tools.status();
175 let no_tools_error = fallback_no_tools.text().await.unwrap_or_default();
176
177 if should_retry_for_image_payload(no_tools_status, &no_tools_error) {
178 let raw_base64_body =
179 self.request_body(req, stream, true, ImageDataFormat::RawBase64, false);
180 let fallback_raw_base64 = self
181 .client
182 .post(self.endpoint())
183 .headers(self.auth_headers()?)
184 .json(&raw_base64_body)
185 .send()
186 .await
187 .with_context(|| format!("{} (fallback: raw base64)", error_context))?;
188
189 if fallback_raw_base64.status().is_success() {
190 return Ok(fallback_raw_base64);
191 }
192
193 let raw_base64_status = fallback_raw_base64.status();
194 let raw_base64_error = fallback_raw_base64.text().await.unwrap_or_default();
195
196 let string_image_body =
197 self.request_body(req, stream, false, ImageDataFormat::DataUrl, false);
198 let fallback_string_image = self
199 .client
200 .post(self.endpoint())
201 .headers(self.auth_headers()?)
202 .json(&string_image_body)
203 .send()
204 .await
205 .with_context(|| format!("{} (fallback: string image_url)", error_context))?;
206
207 if fallback_string_image.status().is_success() {
208 return Ok(fallback_string_image);
209 }
210
211 let string_status = fallback_string_image.status();
212 let string_error = fallback_string_image.text().await.unwrap_or_default();
213 bail!(
214 "provider error {}: {} (fallback_no_tools {}: {}) (fallback_raw_base64 {}: {}) (fallback_string_image_url {}: {})",
215 primary_status,
216 primary_error,
217 no_tools_status,
218 no_tools_error,
219 raw_base64_status,
220 raw_base64_error,
221 string_status,
222 string_error
223 );
224 }
225
226 bail!(
227 "provider error {}: {} (fallback_no_tools {}: {})",
228 primary_status,
229 primary_error,
230 no_tools_status,
231 no_tools_error
232 );
233 }
234
235 bail!("provider error {}: {}", primary_status, primary_error)
236 }
237
238 async fn complete_stream_inner<F>(
239 &self,
240 req: &ProviderRequest,
241 mut on_event: F,
242 ) -> anyhow::Result<ProviderResponse>
243 where
244 F: FnMut(ProviderStreamEvent) + Send,
245 {
246 let response = self
247 .send_request(req, true, "provider stream request failed")
248 .await?;
249
250 let mut assistant = String::new();
251 let mut thinking = String::new();
252 let mut partial_calls: Vec<StreamedToolCall> = Vec::new();
253 let mut stream_done = false;
254 let mut context_tokens = None;
255
256 let mut buffer = String::new();
257 let mut resp = response;
258 while !stream_done && let Some(chunk) = resp.chunk().await.context("stream read failed")? {
259 let txt = String::from_utf8_lossy(&chunk);
260 buffer.push_str(&txt);
261
262 while let Some(pos) = buffer.find('\n') {
263 let line = buffer[..pos].trim_end_matches('\r').to_string();
264 buffer.drain(..=pos);
265
266 match parse_stream_line(&line) {
267 Some(StreamLine::Done) => {
268 stream_done = true;
269 break;
270 }
271 Some(StreamLine::Payload(value)) => {
272 if let Some(tokens) = parse_context_tokens(&value) {
273 context_tokens = Some(tokens);
274 }
275 apply_stream_chunk(
276 &value,
277 &mut assistant,
278 &mut thinking,
279 &mut partial_calls,
280 &mut on_event,
281 )
282 }
283 None => continue,
284 }
285 }
286 }
287
288 if !stream_done {
289 match parse_stream_line(buffer.trim()) {
290 Some(StreamLine::Payload(value)) => {
291 if let Some(tokens) = parse_context_tokens(&value) {
292 context_tokens = Some(tokens);
293 }
294 apply_stream_chunk(
295 &value,
296 &mut assistant,
297 &mut thinking,
298 &mut partial_calls,
299 &mut on_event,
300 )
301 }
302 Some(StreamLine::Done) | None => {}
303 }
304 }
305
306 let tool_calls = partial_calls
307 .into_iter()
308 .filter(|c| !c.name.is_empty())
309 .map(StreamedToolCall::into_tool_call)
310 .collect::<Vec<_>>();
311
312 Ok(ProviderResponse {
313 assistant_message: Message {
314 role: Role::Assistant,
315 content: assistant,
316 attachments: Vec::new(),
317 tool_call_id: None,
318 },
319 done: tool_calls.is_empty(),
320 tool_calls,
321 thinking: if thinking.is_empty() {
322 None
323 } else {
324 Some(thinking)
325 },
326 context_tokens,
327 })
328 }
329}
330
331fn emit_response_stream_events<F>(response: &ProviderResponse, on_event: &mut F)
332where
333 F: FnMut(ProviderStreamEvent) + Send,
334{
335 if let Some(thinking) = &response.thinking {
336 on_event(ProviderStreamEvent::ThinkingDelta(thinking.clone()));
337 }
338 if !response.assistant_message.content.is_empty() {
339 on_event(ProviderStreamEvent::AssistantDelta(
340 response.assistant_message.content.clone(),
341 ));
342 }
343}
344
/// Result of parsing one `data:` line of a streaming response.
enum StreamLine {
    /// The `[DONE]` sentinel was seen.
    Done,
    /// A JSON payload that parsed successfully.
    Payload(Value),
}
349
/// How the bytes of an image attachment are written into the request.
#[derive(Clone, Copy)]
enum ImageDataFormat {
    /// `data:<media_type>;base64,<data>`.
    DataUrl,
    /// The base64 string alone, without any prefix.
    RawBase64,
}
355
356fn message_to_wire(
357 message: &Message,
358 image_url_as_object: bool,
359 image_data_format: ImageDataFormat,
360) -> Value {
361 let content = if message.attachments.is_empty() {
362 json!(message.content)
363 } else {
364 let mut parts = Vec::new();
365 if !message.content.is_empty() {
366 parts.push(json!({
367 "type": "text",
368 "text": message.content,
369 }));
370 }
371
372 for attachment in &message.attachments {
373 match attachment {
374 MessageAttachment::Image {
375 media_type,
376 data_base64,
377 } => {
378 let image_payload = match image_data_format {
379 ImageDataFormat::DataUrl => {
380 format!("data:{};base64,{}", media_type, data_base64)
381 }
382 ImageDataFormat::RawBase64 => data_base64.clone(),
383 };
384 if image_url_as_object {
385 parts.push(json!({
386 "type": "image_url",
387 "image_url": {
388 "url": image_payload,
389 }
390 }));
391 } else {
392 parts.push(json!({
393 "type": "image_url",
394 "image_url": image_payload,
395 }));
396 }
397 }
398 }
399 }
400
401 json!(parts)
402 };
403
404 let mut wire = json!({
405 "role": role_to_wire(&message.role),
406 "content": content,
407 });
408 if let Some(id) = &message.tool_call_id {
409 wire["tool_call_id"] = json!(id);
410 }
411 wire
412}
413
414fn has_image_attachments(req: &ProviderRequest) -> bool {
415 req.messages
416 .iter()
417 .any(|message| !message.attachments.is_empty())
418}
419
420fn should_retry_for_image_payload(status: reqwest::StatusCode, body: &str) -> bool {
421 if !status.is_client_error() {
422 return false;
423 }
424 let lower = body.to_ascii_lowercase();
425 lower.contains("invalid api parameter")
426 || lower.contains("invalid parameter")
427 || lower.contains("image_url")
428 || lower.contains("invalid type")
429}
430
431fn role_to_wire(role: &Role) -> &'static str {
432 match role {
433 Role::System => "system",
434 Role::User => "user",
435 Role::Assistant => "assistant",
436 Role::Tool => "tool",
437 }
438}
439
440fn parse_stream_line(line: &str) -> Option<StreamLine> {
441 let line = line.trim();
442 if line.is_empty() || !line.starts_with("data:") {
443 return None;
444 }
445
446 let payload = line.trim_start_matches("data:").trim();
447 if payload == "[DONE]" {
448 return Some(StreamLine::Done);
449 }
450
451 serde_json::from_str(payload).ok().map(StreamLine::Payload)
452}
453
454#[async_trait]
455impl Provider for OpenAiCompatibleProvider {
456 async fn complete(&self, req: ProviderRequest) -> anyhow::Result<ProviderResponse> {
457 let response = self
458 .send_request(&req, false, "provider request failed")
459 .await?;
460
461 let value: Value = response.json().await.context("invalid provider JSON")?;
462 Self::parse_chat_response(&value)
463 }
464
465 async fn complete_stream<F>(
466 &self,
467 req: ProviderRequest,
468 mut on_event: F,
469 ) -> anyhow::Result<ProviderResponse>
470 where
471 F: FnMut(ProviderStreamEvent) + Send,
472 {
473 match self.complete_stream_inner(&req, &mut on_event).await {
474 Ok(response) => Ok(response),
475 Err(_) => {
476 let response = self.complete(req).await?;
477 emit_response_stream_events(&response, &mut on_event);
478 Ok(response)
479 }
480 }
481 }
482}
483
484fn parse_tool_calls(message: &serde_json::Map<String, Value>) -> anyhow::Result<Vec<ToolCall>> {
485 let mut tool_calls = Vec::new();
486 if let Some(calls) = message.get("tool_calls").and_then(|v| v.as_array()) {
487 for call in calls {
488 let id = call
489 .get("id")
490 .and_then(|v| v.as_str())
491 .unwrap_or_default()
492 .to_string();
493 let function = call
494 .get("function")
495 .and_then(|v| v.as_object())
496 .context("tool call missing function")?;
497 let name = function
498 .get("name")
499 .and_then(|v| v.as_str())
500 .unwrap_or_default()
501 .to_string();
502 let args_raw = function
503 .get("arguments")
504 .and_then(|v| v.as_str())
505 .unwrap_or("{}");
506 let arguments: Value = serde_json::from_str(args_raw).unwrap_or_else(|_| json!({}));
507 tool_calls.push(ToolCall {
508 id,
509 name,
510 arguments,
511 });
512 }
513 }
514 Ok(tool_calls)
515}
516
517fn extract_thinking(message: &serde_json::Map<String, Value>) -> Option<String> {
518 THINKING_FIELDS.iter().find_map(|k| {
519 message
520 .get(*k)
521 .and_then(|v| v.as_str())
522 .filter(|v| !v.is_empty())
523 .map(ToString::to_string)
524 })
525}
526
527fn parse_context_tokens(payload: &Value) -> Option<usize> {
528 let usage = payload.get("usage")?.as_object()?;
529 usage
530 .get("prompt_tokens")
531 .or_else(|| usage.get("input_tokens"))
532 .or_else(|| usage.get("total_tokens"))
533 .and_then(|value| value.as_u64())
534 .map(|value| value as usize)
535}
536
537fn apply_stream_chunk<F>(
538 value: &Value,
539 assistant: &mut String,
540 thinking: &mut String,
541 partial_calls: &mut Vec<StreamedToolCall>,
542 on_event: &mut F,
543) where
544 F: FnMut(ProviderStreamEvent) + Send,
545{
546 let Some(choice) = value
547 .get("choices")
548 .and_then(|v| v.as_array())
549 .and_then(|a| a.first())
550 else {
551 return;
552 };
553
554 let Some(delta) = choice.get("delta").and_then(|v| v.as_object()) else {
555 return;
556 };
557
558 if let Some(content) = delta.get("content").and_then(|v| v.as_str()) {
559 assistant.push_str(content);
560 on_event(ProviderStreamEvent::AssistantDelta(content.to_string()));
561 }
562
563 for key in THINKING_FIELDS {
564 if let Some(text) = delta.get(key).and_then(|v| v.as_str()) {
565 thinking.push_str(text);
566 on_event(ProviderStreamEvent::ThinkingDelta(text.to_string()));
567 }
568 }
569
570 if let Some(tool_calls) = delta.get("tool_calls").and_then(|v| v.as_array()) {
571 for call in tool_calls {
572 let index = call.get("index").and_then(|v| v.as_u64()).unwrap_or(0) as usize;
573 while partial_calls.len() <= index {
574 partial_calls.push(StreamedToolCall::default());
575 }
576
577 let entry = &mut partial_calls[index];
578 if let Some(id) = call.get("id").and_then(|v| v.as_str()) {
579 entry.id = id.to_string();
580 }
581 if let Some(function) = call.get("function").and_then(|v| v.as_object()) {
582 if let Some(name) = function.get("name").and_then(|v| v.as_str()) {
583 entry.name = name.to_string();
584 }
585 if let Some(args_piece) = function.get("arguments").and_then(|v| v.as_str()) {
586 entry.arguments_json.push_str(args_piece);
587 }
588 }
589 }
590 }
591}