Skip to main content

sparrow/provider/
openai_compat.rs

1use async_trait::async_trait;
2use futures::stream::{self, StreamExt};
3use reqwest::Client;
4use serde_json::json;
5use std::collections::HashMap;
6
7use super::{Brain, BrainEvent, BrainRequest, BrainStream, ContentBlock, LatencyClass, ModelCaps};
8
9/// OpenAI-compatible adapter. Covers OpenAI, Groq, NVIDIA NIM, Together, Cerebras,
10/// OpenRouter, NovitaAI, Nous Portal, HuggingFace, Ollama, and custom endpoints.
11pub struct OpenAICompatAdapter {
12    model: String,
13    api_key: String,
14    base_url: String,
15    client: Client,
16    caps: ModelCaps,
17}
18
19impl OpenAICompatAdapter {
20    pub fn new(model: &str, api_key: impl Into<String>, base_url: &str) -> Self {
21        let model = model.to_string();
22        Self {
23            model,
24            api_key: api_key.into(),
25            base_url: base_url.to_string(),
26            client: Client::new(),
27            caps: ModelCaps::default(),
28        }
29    }
30
31    pub fn with_caps(mut self, caps: ModelCaps) -> Self {
32        self.caps = caps;
33        self
34    }
35
36    /// Create an Ollama adapter (OpenAI-compatible API on localhost)
37    pub fn ollama(model: &str, base_url: &str) -> Self {
38        // Ollama doesn't require an API key
39        Self::new(model, "ollama", base_url).with_caps(ModelCaps {
40            context_window: 32_768,
41            max_output: 8_000,
42            tools: true,
43            vision: false,
44            cost_input_per_mtok: 0.0,
45            cost_output_per_mtok: 0.0,
46            latency: LatencyClass::Medium,
47        })
48    }
49}
50
51fn build_chat_body(model: &str, req: &BrainRequest) -> serde_json::Value {
52    let mut messages: Vec<serde_json::Value> = Vec::new();
53
54    // Add system message
55    if let Some(sys) = &req.system {
56        messages.push(json!({
57            "role": "system",
58            "content": sys,
59        }));
60    }
61
62    // Convert messages
63    for msg in &req.messages {
64        if msg.role == "system" {
65            messages.push(json!({
66                "role": "system",
67                "content": msg.content.iter()
68                    .filter_map(|b| match b {
69                        ContentBlock::Text { text } => Some(text.clone()),
70                        _ => None,
71                    })
72                    .collect::<Vec<_>>()
73                    .join("\n"),
74            }));
75            continue;
76        }
77
78        let mut content: Vec<serde_json::Value> = Vec::new();
79        let mut tool_calls: Vec<serde_json::Value> = Vec::new();
80        let mut reasoning_buf = String::new();
81        let mut emitted_tool_result = false;
82
83        for block in &msg.content {
84            match block {
85                ContentBlock::Text { text } => {
86                    content.push(json!({"type": "text", "text": text}));
87                }
88                ContentBlock::Image { source } => {
89                    content.push(json!({
90                        "type": "image_url",
91                        "image_url": {
92                            "url": image_source_url(source),
93                        }
94                    }));
95                }
96                ContentBlock::Reasoning { text } => {
97                    // DeepSeek / Moonshot / Qwen "thinking mode" require the
98                    // model's previous reasoning_content to be echoed back
99                    // on the next turn or the API rejects with 400. We aggregate
100                    // all reasoning blocks of this message and ship them as a
101                    // single `reasoning_content` field.
102                    if !reasoning_buf.is_empty() {
103                        reasoning_buf.push('\n');
104                    }
105                    reasoning_buf.push_str(text);
106                }
107                ContentBlock::ToolUse { id, name, input } => {
108                    tool_calls.push(json!({
109                        "id": id,
110                        "type": "function",
111                        "function": {
112                            "name": name,
113                            "arguments": serde_json::to_string(input).unwrap_or_default(),
114                        }
115                    }));
116                }
117                ContentBlock::ToolResult {
118                    tool_use_id,
119                    content: tool_content,
120                    ..
121                } => {
122                    let text = tool_content
123                        .iter()
124                        .filter_map(|b| match b {
125                            ContentBlock::Text { text } => Some(text.clone()),
126                            _ => None,
127                        })
128                        .collect::<Vec<_>>()
129                        .join("\n");
130                    messages.push(json!({
131                        "role": "tool",
132                        "tool_call_id": tool_use_id,
133                        "content": text,
134                    }));
135                    emitted_tool_result = true;
136                    continue; // tool results are separate messages
137                }
138            }
139        }
140
141        if emitted_tool_result && content.is_empty() && tool_calls.is_empty() {
142            continue;
143        }
144
145        let mut msg_json = json!({ "role": msg.role });
146
147        if !tool_calls.is_empty() {
148            msg_json["tool_calls"] = json!(tool_calls);
149        }
150        if !content.is_empty() {
151            if content.len() == 1 && content[0]["type"] == "text" {
152                msg_json["content"] = json!(content[0]["text"]);
153            } else {
154                msg_json["content"] = json!(content);
155            }
156        }
157        if !reasoning_buf.is_empty() && msg.role == "assistant" {
158            msg_json["reasoning_content"] = json!(reasoning_buf);
159        }
160
161        messages.push(msg_json);
162    }
163
164    // Build tools
165    let tools: Vec<serde_json::Value> = req
166        .tools
167        .iter()
168        .map(|t| {
169            json!({
170                "type": "function",
171                "function": {
172                    "name": t.name,
173                    "description": t.description,
174                    "parameters": t.input_schema,
175                }
176            })
177        })
178        .collect();
179
180    let mut body = json!({
181        "model": model,
182        "messages": messages,
183        "stream": true,
184        "stream_options": {
185            "include_usage": true
186        },
187        "temperature": req.temperature,
188    });
189
190    if req.max_tokens > 0 {
191        body["max_tokens"] = json!(req.max_tokens);
192    }
193    if !tools.is_empty() {
194        body["tools"] = json!(tools);
195    }
196    if !req.stop.is_empty() {
197        body["stop"] = json!(req.stop);
198    }
199    if req.cache.enabled {
200        if let Some(key) = &req.cache.key {
201            body["prompt_cache_key"] = json!(key);
202        }
203        body["prompt_cache_retention"] = json!(req.cache.ttl.openai_retention());
204    }
205
206    body
207}
208
209fn image_source_url(source: &super::ImageSource) -> String {
210    match source {
211        super::ImageSource::Base64 { media_type, data } => {
212            format!("data:{};base64,{}", media_type, data)
213        }
214        super::ImageSource::Url { url } => url.clone(),
215    }
216}
217
218#[async_trait]
219impl Brain for OpenAICompatAdapter {
220    fn id(&self) -> &str {
221        &self.model
222    }
223
224    fn caps(&self) -> ModelCaps {
225        self.caps.clone()
226    }
227
228    async fn complete(&self, req: BrainRequest) -> anyhow::Result<BrainStream> {
229        let body = build_chat_body(&self.model, &req);
230
231        let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
232
233        let response = self
234            .client
235            .post(&url)
236            .header("Authorization", format!("Bearer {}", self.api_key))
237            .json(&body)
238            .send()
239            .await?;
240
241        if !response.status().is_success() {
242            let status = response.status().as_u16();
243            let body = response.text().await.unwrap_or_default();
244            return Err(anyhow::anyhow!(
245                "OpenAI-compatible API error {}: {}",
246                status,
247                body
248            ));
249        }
250
251        #[derive(Default)]
252        struct ToolCallState {
253            id: String,
254            started: bool,
255        }
256
257        let stream = response.bytes_stream();
258
259        // SSE state: tool-call accumulator + line buffer that survives chunk
260        // boundaries. Without the buffer, a JSON event split across two TCP
261        // chunks was parsed in halves and silently dropped — producing the
262        // "à rebours" → "àours" mangling.
263        struct SseState {
264            tools: HashMap<u64, ToolCallState>,
265            lines: super::sse_buffer::LineBuffer,
266            /// Accumulated assistant `content` text for this completion. Used
267            /// to recover tool calls a provider emitted as inline XML/DSML
268            /// markup inside `content` rather than as native `tool_calls`
269            /// (see provider::tool_markup).
270            content_buf: String,
271            /// True once we've decided the content is inline tool-call markup
272            /// and should be suppressed from the visible text stream.
273            suppress_text: bool,
274        }
275
276        let event_stream = stream
277            .scan(
278                SseState {
279                    tools: HashMap::new(),
280                    lines: super::sse_buffer::LineBuffer::new(),
281                    content_buf: String::new(),
282                    suppress_text: false,
283                },
284                |state, chunk| {
285                    let events: Vec<BrainEvent> = match chunk {
286                        Ok(bytes) => {
287                            let lines = state.lines.push(&bytes);
288                            let tool_state = &mut state.tools;
289                            let mut parsed = Vec::new();
290                            for line in lines {
291                                let line = line.trim();
292                                if line.is_empty() || !line.starts_with("data: ") {
293                                    continue;
294                                }
295                                let data = &line[6..];
296                                if data == "[DONE]" {
297                                    continue;
298                                }
299                                let event: serde_json::Value = match serde_json::from_str(data) {
300                                    Ok(v) => v,
301                                    Err(e) => {
302                                        tracing::debug!(
303                                            "JSON parse error: {} — data: {}",
304                                            e,
305                                            &data[..data.len().min(200)]
306                                        );
307                                        continue;
308                                    }
309                                };
310
311                                if let Some(choices) = event["choices"].as_array() {
312                                    for choice in choices {
313                                        if let Some(delta) = choice["delta"].as_object() {
314                                            if let Some(text) =
315                                                delta.get("content").and_then(|v| v.as_str())
316                                            {
317                                                if !text.is_empty() {
318                                                    state.content_buf.push_str(text);
319                                                    // If this completion's content turns
320                                                    // out to be inline tool-call markup
321                                                    // (DeepSeek DSML / Anthropic-style
322                                                    // <invoke>), suppress it from the
323                                                    // visible text stream — it'll be
324                                                    // converted to real tool calls at
325                                                    // finish_reason.
326                                                    if !state.suppress_text
327                                                        && super::tool_markup::looks_like_tool_markup(
328                                                            &state.content_buf,
329                                                        )
330                                                    {
331                                                        state.suppress_text = true;
332                                                    }
333                                                    if !state.suppress_text {
334                                                        parsed.push(BrainEvent::TextDelta(
335                                                            text.to_string(),
336                                                        ));
337                                                    }
338                                                }
339                                            }
340                                            // DeepSeek / Moonshot thinking-mode emit
341                                            // reasoning trace alongside content. Capture
342                                            // it as a dedicated event so the engine can
343                                            // echo it back on the next turn (required
344                                            // by DeepSeek's contract).
345                                            // Several providers report this under
346                                            // different keys; check the known aliases.
347                                            for key in [
348                                                "reasoning_content",
349                                                "reasoning",
350                                                "thinking",
351                                                "thought",
352                                            ] {
353                                                if let Some(rtext) =
354                                                    delta.get(key).and_then(|v| v.as_str())
355                                                {
356                                                    if !rtext.is_empty() {
357                                                        parsed.push(BrainEvent::ReasoningDelta(
358                                                            rtext.to_string(),
359                                                        ));
360                                                    }
361                                                }
362                                            }
363                                        }
364                                        // Some providers (non-streaming chunk at end of
365                                        // turn) bundle the reasoning under
366                                        // `message.reasoning_content` rather than
367                                        // streaming it through `delta`. Cover that path
368                                        // too — duplicate captures are harmless because
369                                        // the engine joins them.
370                                        if let Some(msg_obj) =
371                                            choice.get("message").and_then(|v| v.as_object())
372                                        {
373                                            for key in
374                                                ["reasoning_content", "reasoning", "thinking"]
375                                            {
376                                                if let Some(rtext) =
377                                                    msg_obj.get(key).and_then(|v| v.as_str())
378                                                {
379                                                    if !rtext.is_empty() {
380                                                        parsed.push(BrainEvent::ReasoningDelta(
381                                                            rtext.to_string(),
382                                                        ));
383                                                    }
384                                                }
385                                            }
386                                        }
387                                        if let Some(delta) = choice["delta"].as_object() {
388                                            // (Re-open the original tool_calls block.)
389                                            let _ = delta; // keep this branch syntactically anchored
390                                            if let Some(tool_calls) =
391                                                delta.get("tool_calls").and_then(|v| v.as_array())
392                                            {
393                                                for tc in tool_calls {
394                                                    let idx = tc
395                                                        .get("index")
396                                                        .and_then(|v| v.as_u64())
397                                                        .unwrap_or(0);
398                                                    let id = tc
399                                                        .get("id")
400                                                        .and_then(|v| v.as_str())
401                                                        .map(|s| s.to_string());
402                                                    let state = tool_state.entry(idx).or_default();
403                                                    if let Some(id) = id {
404                                                        state.id = id;
405                                                    }
406                                                    if let Some(func) = tc
407                                                        .get("function")
408                                                        .and_then(|v| v.as_object())
409                                                    {
410                                                        if let Some(name) = func
411                                                            .get("name")
412                                                            .and_then(|v| v.as_str())
413                                                        {
414                                                            if !state.started {
415                                                                if state.id.is_empty() {
416                                                                    state.id = format!(
417                                                                        "tool-call-{}",
418                                                                        idx
419                                                                    );
420                                                                }
421                                                                state.started = true;
422                                                                parsed.push(
423                                                                    BrainEvent::ToolUseStart {
424                                                                        id: state.id.clone(),
425                                                                        name: name.to_string(),
426                                                                    },
427                                                                );
428                                                            }
429                                                        }
430                                                        if let Some(args) = func
431                                                            .get("arguments")
432                                                            .and_then(|v| v.as_str())
433                                                        {
434                                                            if !state.id.is_empty()
435                                                                && !args.is_empty()
436                                                            {
437                                                                parsed.push(
438                                                                    BrainEvent::ToolUseDelta {
439                                                                        id: state.id.clone(),
440                                                                        json: args.to_string(),
441                                                                    },
442                                                                );
443                                                            }
444                                                        }
445                                                    }
446                                                }
447                                            }
448                                        }
449
450                                        if let Some(reason) =
451                                            choice.get("finish_reason").and_then(|v| v.as_str())
452                                        {
453                                            if !reason.is_empty() && reason != "null" {
454                                                let stop = match reason {
455                                                    "stop" => {
456                                                        // Recover tool calls a provider
457                                                        // emitted as inline XML/DSML
458                                                        // markup in `content` (with
459                                                        // finish_reason "stop") instead
460                                                        // of native tool_calls. Without
461                                                        // this the call leaks as raw
462                                                        // text and never runs.
463                                                        let calls = if super::tool_markup::looks_like_tool_markup(
464                                                            &state.content_buf,
465                                                        ) {
466                                                            super::tool_markup::extract_tool_calls(
467                                                                &state.content_buf,
468                                                            )
469                                                        } else {
470                                                            Vec::new()
471                                                        };
472                                                        if calls.is_empty() {
473                                                            crate::event::StopReason::EndTurn
474                                                        } else {
475                                                            for (i, call) in
476                                                                calls.into_iter().enumerate()
477                                                            {
478                                                                let id = format!(
479                                                                    "markup-call-{}",
480                                                                    i
481                                                                );
482                                                                parsed.push(
483                                                                    BrainEvent::ToolUseStart {
484                                                                        id: id.clone(),
485                                                                        name: call.name,
486                                                                    },
487                                                                );
488                                                                parsed.push(
489                                                                    BrainEvent::ToolUseDelta {
490                                                                        id: id.clone(),
491                                                                        json: call
492                                                                            .args
493                                                                            .to_string(),
494                                                                    },
495                                                                );
496                                                                parsed.push(
497                                                                    BrainEvent::ToolUseEnd { id },
498                                                                );
499                                                            }
500                                                            crate::event::StopReason::ToolUse
501                                                        }
502                                                    }
503                                                    "length" => crate::event::StopReason::MaxTokens,
504                                                    "tool_calls" => {
505                                                        for (_, state) in tool_state.drain() {
506                                                            if !state.id.is_empty() {
507                                                                parsed.push(
508                                                                    BrainEvent::ToolUseEnd {
509                                                                        id: state.id,
510                                                                    },
511                                                                );
512                                                            }
513                                                        }
514                                                        crate::event::StopReason::ToolUse
515                                                    }
516                                                    s => crate::event::StopReason::StopSequence(
517                                                        s.to_string(),
518                                                    ),
519                                                };
520                                                parsed.push(BrainEvent::Done(stop));
521                                            }
522                                        }
523                                    }
524                                }
525
526                                if let Some(usage) = event.get("usage").and_then(|u| u.as_object())
527                                {
528                                    // Use .get() — indexing a serde_json::Map with [] panics on a
529                                    // missing key, and some providers (e.g. MiniMax) omit fields.
530                                    parsed.push(BrainEvent::Usage(crate::event::TokenUsage {
531                                        input: usage
532                                            .get("prompt_tokens")
533                                            .and_then(|v| v.as_u64())
534                                            .unwrap_or(0),
535                                        output: usage
536                                            .get("completion_tokens")
537                                            .and_then(|v| v.as_u64())
538                                            .unwrap_or(0),
539                                    }));
540                                }
541                            }
542                            parsed
543                        }
544                        Err(e) => vec![BrainEvent::Error(format!("stream error: {}", e))],
545                    };
546                    futures::future::ready(Some(stream::iter(events)))
547                },
548            )
549            .flatten();
550
551        Ok(Box::pin(event_stream))
552    }
553}
554
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::{ImageSource, Msg, PromptCacheConfig, PromptCacheTtl};

    /// Shorthand for a plain text content block.
    fn text_block(text: &str) -> ContentBlock {
        ContentBlock::Text { text: text.into() }
    }

    /// A request carrying exactly one message with the given role and
    /// blocks; every other field comes from `BrainRequest::default()`.
    fn one_message(role: &str, content: Vec<ContentBlock>) -> BrainRequest {
        BrainRequest {
            messages: vec![Msg {
                role: role.into(),
                content,
            }],
            ..BrainRequest::default()
        }
    }

    #[test]
    fn openai_chat_body_adds_prompt_cache_controls() {
        let req = BrainRequest {
            system: Some("stable sparrow system".into()),
            cache: PromptCacheConfig {
                enabled: true,
                ttl: PromptCacheTtl::OneHour,
                key: Some("sparrow-repo-abc".into()),
            },
            ..one_message("user", vec![text_block("dynamic task")])
        };

        let body = build_chat_body("gpt-test", &req);
        assert_eq!(body["prompt_cache_key"], "sparrow-repo-abc");
        assert_eq!(body["prompt_cache_retention"], "in_memory");
    }

    #[test]
    fn openai_chat_body_serializes_image_blocks() {
        let image = ContentBlock::Image {
            source: ImageSource::Base64 {
                media_type: "image/png".into(),
                data: "iVBORw0KGgo=".into(),
            },
        };
        let req = one_message("user", vec![text_block("what is in this image?"), image]);

        let body = build_chat_body("gpt-test", &req);
        let parts = &body["messages"][0]["content"];
        assert_eq!(parts[0]["type"], "text");
        assert_eq!(parts[1]["type"], "image_url");
        assert_eq!(
            parts[1]["image_url"]["url"],
            "data:image/png;base64,iVBORw0KGgo="
        );
    }

    #[test]
    fn openai_chat_body_reinjects_assistant_reasoning_content() {
        let req = one_message(
            "assistant",
            vec![
                ContentBlock::Reasoning {
                    text: "opaque provider reasoning".into(),
                },
                text_block("visible answer"),
            ],
        );

        let body = build_chat_body("deepseek-test", &req);
        let message = &body["messages"][0];
        assert_eq!(message["content"], "visible answer");
        assert_eq!(message["reasoning_content"], "opaque provider reasoning");
    }

    #[test]
    fn multi_tool_turn_is_one_assistant_message_with_reasoning() {
        // Regression for the v0.5.5 fix: a single model turn that emits N tool
        // calls must serialize as ONE assistant message carrying
        // reasoning_content + a tool_calls array of length N. Splitting it into
        // one message per tool dropped reasoning_content from the 2nd+ calls,
        // which DeepSeek/Qwen/Moonshot thinking-mode rejects with HTTP 400 and
        // which aborted multi-file tasks half-way.
        fn fs_write_call(id: &str, path: &str) -> ContentBlock {
            ContentBlock::ToolUse {
                id: id.into(),
                name: "fs_write".into(),
                input: serde_json::json!({ "path": path }),
            }
        }

        let req = one_message(
            "assistant",
            vec![
                ContentBlock::Reasoning {
                    text: "thinking about two files".into(),
                },
                fs_write_call("call_0", "reverse.py"),
                fs_write_call("call_1", "test_reverse.py"),
            ],
        );

        let body = build_chat_body("deepseek-test", &req);
        let messages = body["messages"].as_array().unwrap();

        // Exactly one assistant message, carrying the reasoning.
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0]["reasoning_content"], "thinking about two files");

        // Both tool calls share a single tool_calls array.
        let calls = messages[0]["tool_calls"].as_array().unwrap();
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0]["id"], "call_0");
        assert_eq!(calls[1]["id"], "call_1");
        assert_eq!(calls[0]["function"]["name"], "fs_write");
    }
}