Skip to main content

hh_cli/provider/
openai_compatible.rs

1use crate::core::{
2    Message, MessageAttachment, Provider, ProviderRequest, ProviderResponse, ProviderStreamEvent,
3    Role, ToolCall,
4};
5use crate::provider::StreamedToolCall;
6use anyhow::{Context, bail};
7use async_trait::async_trait;
8use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue};
9use serde_json::{Value, json};
10use std::env;
11
/// JSON keys under which OpenAI-compatible servers may place model reasoning text.
///
/// Both the non-streaming parser (`extract_thinking`) and the streaming parser
/// (`apply_stream_chunk`) look these keys up on the message/delta object, in this order.
const THINKING_FIELDS: [&str; 3] = [
    "reasoning",
    "thinking",
    "reasoning_content",
];
13
14pub struct OpenAiCompatibleProvider {
15    base_url: String,
16    model: String,
17    api_key_env: String,
18    client: reqwest::Client,
19}
20
21impl OpenAiCompatibleProvider {
22    pub fn new(base_url: String, model: String, api_key_env: String) -> Self {
23        Self {
24            base_url,
25            model,
26            api_key_env,
27            client: reqwest::Client::new(),
28        }
29    }
30
31    fn endpoint(&self) -> String {
32        format!("{}/chat/completions", self.base_url.trim_end_matches('/'))
33    }
34
35    fn auth_headers(&self) -> anyhow::Result<HeaderMap> {
36        let api_key = env::var(&self.api_key_env)
37            .with_context(|| format!("missing API key env var {}", self.api_key_env))?;
38        let mut headers = HeaderMap::new();
39        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
40        headers.insert(
41            AUTHORIZATION,
42            HeaderValue::from_str(&format!("Bearer {}", api_key))?,
43        );
44        Ok(headers)
45    }
46
47    fn request_body(
48        &self,
49        req: &ProviderRequest,
50        stream: bool,
51        image_url_as_object: bool,
52        image_data_format: ImageDataFormat,
53        include_tools: bool,
54    ) -> Value {
55        let requested_model = if req.model.is_empty() {
56            self.model.as_str()
57        } else {
58            req.model.as_str()
59        };
60
61        let tools = if include_tools {
62            req.tools
63                .iter()
64                .map(|t| {
65                    json!({
66                        "type": "function",
67                        "function": {
68                            "name": t.name,
69                            "description": t.description,
70                            "parameters": t.parameters,
71                        }
72                    })
73                })
74                .collect::<Vec<_>>()
75        } else {
76            Vec::new()
77        };
78
79        let messages = req
80            .messages
81            .iter()
82            .map(|message| message_to_wire(message, image_url_as_object, image_data_format))
83            .collect::<Vec<_>>();
84
85        let mut body = json!({
86            "model": requested_model,
87            "messages": messages,
88            "stream": stream,
89        });
90        if include_tools && !tools.is_empty() {
91            body["tools"] = json!(tools);
92            body["tool_choice"] = json!("auto");
93        }
94        body
95    }
96
97    fn parse_chat_response(value: &Value) -> anyhow::Result<ProviderResponse> {
98        let choice = value
99            .get("choices")
100            .and_then(|v| v.as_array())
101            .and_then(|a| a.first())
102            .context("provider response missing choices[0]")?;
103
104        let message = choice
105            .get("message")
106            .and_then(|m| m.as_object())
107            .context("provider response missing message")?;
108
109        let content = message
110            .get("content")
111            .and_then(|c| c.as_str())
112            .unwrap_or_default()
113            .to_string();
114
115        let thinking = extract_thinking(message);
116        let tool_calls = parse_tool_calls(message)?;
117        let context_tokens = parse_context_tokens(value);
118
119        Ok(ProviderResponse {
120            assistant_message: Message {
121                role: Role::Assistant,
122                content,
123                attachments: Vec::new(),
124                tool_call_id: None,
125            },
126            done: tool_calls.is_empty(),
127            tool_calls,
128            thinking,
129            context_tokens,
130        })
131    }
132
133    async fn send_request(
134        &self,
135        req: &ProviderRequest,
136        stream: bool,
137        error_context: &str,
138    ) -> anyhow::Result<reqwest::Response> {
139        let primary_body = self.request_body(req, stream, true, ImageDataFormat::DataUrl, true);
140
141        let primary = self
142            .client
143            .post(self.endpoint())
144            .headers(self.auth_headers()?)
145            .json(&primary_body)
146            .send()
147            .await
148            .with_context(|| error_context.to_string())?;
149
150        if primary.status().is_success() {
151            return Ok(primary);
152        }
153
154        let primary_status = primary.status();
155        let primary_error = primary.text().await.unwrap_or_default();
156        if has_image_attachments(req)
157            && should_retry_for_image_payload(primary_status, &primary_error)
158        {
159            let no_tools_body =
160                self.request_body(req, stream, true, ImageDataFormat::DataUrl, false);
161            let fallback_no_tools = self
162                .client
163                .post(self.endpoint())
164                .headers(self.auth_headers()?)
165                .json(&no_tools_body)
166                .send()
167                .await
168                .with_context(|| format!("{} (fallback: no tools)", error_context))?;
169
170            if fallback_no_tools.status().is_success() {
171                return Ok(fallback_no_tools);
172            }
173
174            let no_tools_status = fallback_no_tools.status();
175            let no_tools_error = fallback_no_tools.text().await.unwrap_or_default();
176
177            if should_retry_for_image_payload(no_tools_status, &no_tools_error) {
178                let raw_base64_body =
179                    self.request_body(req, stream, true, ImageDataFormat::RawBase64, false);
180                let fallback_raw_base64 = self
181                    .client
182                    .post(self.endpoint())
183                    .headers(self.auth_headers()?)
184                    .json(&raw_base64_body)
185                    .send()
186                    .await
187                    .with_context(|| format!("{} (fallback: raw base64)", error_context))?;
188
189                if fallback_raw_base64.status().is_success() {
190                    return Ok(fallback_raw_base64);
191                }
192
193                let raw_base64_status = fallback_raw_base64.status();
194                let raw_base64_error = fallback_raw_base64.text().await.unwrap_or_default();
195
196                let string_image_body =
197                    self.request_body(req, stream, false, ImageDataFormat::DataUrl, false);
198                let fallback_string_image = self
199                    .client
200                    .post(self.endpoint())
201                    .headers(self.auth_headers()?)
202                    .json(&string_image_body)
203                    .send()
204                    .await
205                    .with_context(|| format!("{} (fallback: string image_url)", error_context))?;
206
207                if fallback_string_image.status().is_success() {
208                    return Ok(fallback_string_image);
209                }
210
211                let string_status = fallback_string_image.status();
212                let string_error = fallback_string_image.text().await.unwrap_or_default();
213                bail!(
214                    "provider error {}: {} (fallback_no_tools {}: {}) (fallback_raw_base64 {}: {}) (fallback_string_image_url {}: {})",
215                    primary_status,
216                    primary_error,
217                    no_tools_status,
218                    no_tools_error,
219                    raw_base64_status,
220                    raw_base64_error,
221                    string_status,
222                    string_error
223                );
224            }
225
226            bail!(
227                "provider error {}: {} (fallback_no_tools {}: {})",
228                primary_status,
229                primary_error,
230                no_tools_status,
231                no_tools_error
232            );
233        }
234
235        bail!("provider error {}: {}", primary_status, primary_error)
236    }
237
238    async fn complete_stream_inner<F>(
239        &self,
240        req: &ProviderRequest,
241        mut on_event: F,
242    ) -> anyhow::Result<ProviderResponse>
243    where
244        F: FnMut(ProviderStreamEvent) + Send,
245    {
246        let response = self
247            .send_request(req, true, "provider stream request failed")
248            .await?;
249
250        let mut assistant = String::new();
251        let mut thinking = String::new();
252        let mut partial_calls: Vec<StreamedToolCall> = Vec::new();
253        let mut stream_done = false;
254        let mut context_tokens = None;
255
256        let mut buffer = String::new();
257        let mut resp = response;
258        while !stream_done && let Some(chunk) = resp.chunk().await.context("stream read failed")? {
259            let txt = String::from_utf8_lossy(&chunk);
260            buffer.push_str(&txt);
261
262            while let Some(pos) = buffer.find('\n') {
263                let line = buffer[..pos].trim_end_matches('\r').to_string();
264                buffer.drain(..=pos);
265
266                match parse_stream_line(&line) {
267                    Some(StreamLine::Done) => {
268                        stream_done = true;
269                        break;
270                    }
271                    Some(StreamLine::Payload(value)) => {
272                        if let Some(tokens) = parse_context_tokens(&value) {
273                            context_tokens = Some(tokens);
274                        }
275                        apply_stream_chunk(
276                            &value,
277                            &mut assistant,
278                            &mut thinking,
279                            &mut partial_calls,
280                            &mut on_event,
281                        )
282                    }
283                    None => continue,
284                }
285            }
286        }
287
288        if !stream_done {
289            match parse_stream_line(buffer.trim()) {
290                Some(StreamLine::Payload(value)) => {
291                    if let Some(tokens) = parse_context_tokens(&value) {
292                        context_tokens = Some(tokens);
293                    }
294                    apply_stream_chunk(
295                        &value,
296                        &mut assistant,
297                        &mut thinking,
298                        &mut partial_calls,
299                        &mut on_event,
300                    )
301                }
302                Some(StreamLine::Done) | None => {}
303            }
304        }
305
306        let tool_calls = partial_calls
307            .into_iter()
308            .filter(|c| !c.name.is_empty())
309            .map(StreamedToolCall::into_tool_call)
310            .collect::<Vec<_>>();
311
312        Ok(ProviderResponse {
313            assistant_message: Message {
314                role: Role::Assistant,
315                content: assistant,
316                attachments: Vec::new(),
317                tool_call_id: None,
318            },
319            done: tool_calls.is_empty(),
320            tool_calls,
321            thinking: if thinking.is_empty() {
322                None
323            } else {
324                Some(thinking)
325            },
326            context_tokens,
327        })
328    }
329}
330
331fn emit_response_stream_events<F>(response: &ProviderResponse, on_event: &mut F)
332where
333    F: FnMut(ProviderStreamEvent) + Send,
334{
335    if let Some(thinking) = &response.thinking {
336        on_event(ProviderStreamEvent::ThinkingDelta(thinking.clone()));
337    }
338    if !response.assistant_message.content.is_empty() {
339        on_event(ProviderStreamEvent::AssistantDelta(
340            response.assistant_message.content.clone(),
341        ));
342    }
343}
344
345enum StreamLine {
346    Done,
347    Payload(Value),
348}
349
/// How image bytes are encoded in an `image_url` content part.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ImageDataFormat {
    /// `data:<media_type>;base64,<data>` URL.
    DataUrl,
    /// Bare base64 payload with no `data:` prefix; used as a retry when the data-URL form
    /// is rejected.
    RawBase64,
}
355
356fn message_to_wire(
357    message: &Message,
358    image_url_as_object: bool,
359    image_data_format: ImageDataFormat,
360) -> Value {
361    let content = if message.attachments.is_empty() {
362        json!(message.content)
363    } else {
364        let mut parts = Vec::new();
365        if !message.content.is_empty() {
366            parts.push(json!({
367                "type": "text",
368                "text": message.content,
369            }));
370        }
371
372        for attachment in &message.attachments {
373            match attachment {
374                MessageAttachment::Image {
375                    media_type,
376                    data_base64,
377                } => {
378                    let image_payload = match image_data_format {
379                        ImageDataFormat::DataUrl => {
380                            format!("data:{};base64,{}", media_type, data_base64)
381                        }
382                        ImageDataFormat::RawBase64 => data_base64.clone(),
383                    };
384                    if image_url_as_object {
385                        parts.push(json!({
386                            "type": "image_url",
387                            "image_url": {
388                                "url": image_payload,
389                            }
390                        }));
391                    } else {
392                        parts.push(json!({
393                            "type": "image_url",
394                            "image_url": image_payload,
395                        }));
396                    }
397                }
398            }
399        }
400
401        json!(parts)
402    };
403
404    let mut wire = json!({
405        "role": role_to_wire(&message.role),
406        "content": content,
407    });
408    if let Some(id) = &message.tool_call_id {
409        wire["tool_call_id"] = json!(id);
410    }
411    wire
412}
413
414fn has_image_attachments(req: &ProviderRequest) -> bool {
415    req.messages
416        .iter()
417        .any(|message| !message.attachments.is_empty())
418}
419
420fn should_retry_for_image_payload(status: reqwest::StatusCode, body: &str) -> bool {
421    if !status.is_client_error() {
422        return false;
423    }
424    let lower = body.to_ascii_lowercase();
425    lower.contains("invalid api parameter")
426        || lower.contains("invalid parameter")
427        || lower.contains("image_url")
428        || lower.contains("invalid type")
429}
430
431fn role_to_wire(role: &Role) -> &'static str {
432    match role {
433        Role::System => "system",
434        Role::User => "user",
435        Role::Assistant => "assistant",
436        Role::Tool => "tool",
437    }
438}
439
440fn parse_stream_line(line: &str) -> Option<StreamLine> {
441    let line = line.trim();
442    if line.is_empty() || !line.starts_with("data:") {
443        return None;
444    }
445
446    let payload = line.trim_start_matches("data:").trim();
447    if payload == "[DONE]" {
448        return Some(StreamLine::Done);
449    }
450
451    serde_json::from_str(payload).ok().map(StreamLine::Payload)
452}
453
454#[async_trait]
455impl Provider for OpenAiCompatibleProvider {
456    async fn complete(&self, req: ProviderRequest) -> anyhow::Result<ProviderResponse> {
457        let response = self
458            .send_request(&req, false, "provider request failed")
459            .await?;
460
461        let value: Value = response.json().await.context("invalid provider JSON")?;
462        Self::parse_chat_response(&value)
463    }
464
465    async fn complete_stream<F>(
466        &self,
467        req: ProviderRequest,
468        mut on_event: F,
469    ) -> anyhow::Result<ProviderResponse>
470    where
471        F: FnMut(ProviderStreamEvent) + Send,
472    {
473        match self.complete_stream_inner(&req, &mut on_event).await {
474            Ok(response) => Ok(response),
475            Err(_) => {
476                let response = self.complete(req).await?;
477                emit_response_stream_events(&response, &mut on_event);
478                Ok(response)
479            }
480        }
481    }
482}
483
484fn parse_tool_calls(message: &serde_json::Map<String, Value>) -> anyhow::Result<Vec<ToolCall>> {
485    let mut tool_calls = Vec::new();
486    if let Some(calls) = message.get("tool_calls").and_then(|v| v.as_array()) {
487        for call in calls {
488            let id = call
489                .get("id")
490                .and_then(|v| v.as_str())
491                .unwrap_or_default()
492                .to_string();
493            let function = call
494                .get("function")
495                .and_then(|v| v.as_object())
496                .context("tool call missing function")?;
497            let name = function
498                .get("name")
499                .and_then(|v| v.as_str())
500                .unwrap_or_default()
501                .to_string();
502            let args_raw = function
503                .get("arguments")
504                .and_then(|v| v.as_str())
505                .unwrap_or("{}");
506            let arguments: Value = serde_json::from_str(args_raw).unwrap_or_else(|_| json!({}));
507            tool_calls.push(ToolCall {
508                id,
509                name,
510                arguments,
511            });
512        }
513    }
514    Ok(tool_calls)
515}
516
517fn extract_thinking(message: &serde_json::Map<String, Value>) -> Option<String> {
518    THINKING_FIELDS.iter().find_map(|k| {
519        message
520            .get(*k)
521            .and_then(|v| v.as_str())
522            .filter(|v| !v.is_empty())
523            .map(ToString::to_string)
524    })
525}
526
527fn parse_context_tokens(payload: &Value) -> Option<usize> {
528    let usage = payload.get("usage")?.as_object()?;
529    usage
530        .get("prompt_tokens")
531        .or_else(|| usage.get("input_tokens"))
532        .or_else(|| usage.get("total_tokens"))
533        .and_then(|value| value.as_u64())
534        .map(|value| value as usize)
535}
536
537fn apply_stream_chunk<F>(
538    value: &Value,
539    assistant: &mut String,
540    thinking: &mut String,
541    partial_calls: &mut Vec<StreamedToolCall>,
542    on_event: &mut F,
543) where
544    F: FnMut(ProviderStreamEvent) + Send,
545{
546    let Some(choice) = value
547        .get("choices")
548        .and_then(|v| v.as_array())
549        .and_then(|a| a.first())
550    else {
551        return;
552    };
553
554    let Some(delta) = choice.get("delta").and_then(|v| v.as_object()) else {
555        return;
556    };
557
558    if let Some(content) = delta.get("content").and_then(|v| v.as_str()) {
559        assistant.push_str(content);
560        on_event(ProviderStreamEvent::AssistantDelta(content.to_string()));
561    }
562
563    for key in THINKING_FIELDS {
564        if let Some(text) = delta.get(key).and_then(|v| v.as_str()) {
565            thinking.push_str(text);
566            on_event(ProviderStreamEvent::ThinkingDelta(text.to_string()));
567        }
568    }
569
570    if let Some(tool_calls) = delta.get("tool_calls").and_then(|v| v.as_array()) {
571        for call in tool_calls {
572            let index = call.get("index").and_then(|v| v.as_u64()).unwrap_or(0) as usize;
573            while partial_calls.len() <= index {
574                partial_calls.push(StreamedToolCall::default());
575            }
576
577            let entry = &mut partial_calls[index];
578            if let Some(id) = call.get("id").and_then(|v| v.as_str()) {
579                entry.id = id.to_string();
580            }
581            if let Some(function) = call.get("function").and_then(|v| v.as_object()) {
582                if let Some(name) = function.get("name").and_then(|v| v.as_str()) {
583                    entry.name = name.to_string();
584                }
585                if let Some(args_piece) = function.get("arguments").and_then(|v| v.as_str()) {
586                    entry.arguments_json.push_str(args_piece);
587                }
588            }
589        }
590    }
591}