Skip to main content

mermaid_cli/providers/model/
openai_compat.rs

//! OpenAI-compatible provider — wraps
//! `models::adapters::openai_compat::OpenAICompatAdapter`.
//!
//! This provider covers the OpenAI long-tail: OpenRouter, Groq,
//! Fireworks, Together, custom vLLM endpoints, plus the user-defined
//! entries in `[providers.*]`. The adapter looks up a
//! `ProviderProfile` (registry entry) and applies per-provider
//! reasoning shapes (flat `reasoning_effort` vs nested `reasoning:
//! {effort}`). This wrapper only forwards requests and stream events.
10
11use std::collections::HashMap;
12use std::sync::Arc;
13
14use async_trait::async_trait;
15
16use crate::domain::ChatRequest;
17use crate::models::adapters::openai_compat::OpenAICompatAdapter;
18use crate::models::{
19    Model, ModelConfig, ModelError, ProviderProfile, ReasoningChunk, Result, StreamCallback,
20    StreamEvent as ModelStreamEvent,
21};
22
23use super::super::capabilities::Capabilities;
24use super::super::ctx::{FinalResponse, StreamContext, StreamEvent};
25use super::ModelProvider;
26
27pub struct OpenAICompatProvider {
28    adapter: OpenAICompatAdapter,
29    capabilities: Capabilities,
30}
31
32impl OpenAICompatProvider {
33    pub fn new(
34        profile: &'static ProviderProfile,
35        base_url: String,
36        api_key: String,
37        model_name: String,
38        extra_headers: HashMap<String, String>,
39    ) -> Result<Self> {
40        let adapter =
41            OpenAICompatAdapter::new(profile, base_url, api_key, model_name, extra_headers)?;
42        let capabilities = Capabilities::from_legacy(adapter.capabilities());
43        Ok(Self {
44            adapter,
45            capabilities,
46        })
47    }
48}
49
50#[async_trait]
51impl ModelProvider for OpenAICompatProvider {
52    fn capabilities(&self) -> &Capabilities {
53        &self.capabilities
54    }
55
56    async fn chat(&self, request: ChatRequest, ctx: StreamContext) -> Result<FinalResponse> {
57        let config = build_model_config(&request);
58        let relay_tx = super::stream_bridge::ordered_relay(ctx.sink.clone());
59        let callback = forward_callback(relay_tx);
60        let chat_fut = self
61            .adapter
62            .chat(&request.messages, &config, Some(callback));
63
64        let response = tokio::select! {
65            biased;
66            _ = ctx.token.cancelled() => {
67                return Err(ModelError::Cancelled);
68            },
69            r = chat_fut => r?,
70        };
71
72        let usage = response.usage.clone();
73        let _ = ctx
74            .sink
75            .send(StreamEvent::Done {
76                usage: usage.clone(),
77                thinking_signature: None,
78            })
79            .await;
80
81        Ok(FinalResponse {
82            usage,
83            thinking_signature: None,
84            tool_calls: response.tool_calls.unwrap_or_default(),
85        })
86    }
87}
88
89fn build_model_config(request: &ChatRequest) -> ModelConfig {
90    ModelConfig {
91        model: request.model_id.clone(),
92        temperature: request.temperature,
93        max_tokens: request.max_tokens,
94        reasoning: request.reasoning,
95        system_prompt: Some(request.system_prompt.clone()),
96        dynamic_system_suffix: request.instructions.clone(),
97        tools: request.tools.iter().map(|t| t.to_openai_json()).collect(),
98        ..Default::default()
99    }
100}
101
102fn forward_callback(sink: tokio::sync::mpsc::UnboundedSender<StreamEvent>) -> StreamCallback {
103    Arc::new(move |event: ModelStreamEvent| {
104        let mapped = match event {
105            ModelStreamEvent::Text(s) => StreamEvent::Text(s),
106            ModelStreamEvent::Reasoning(chunk) => StreamEvent::Reasoning(ReasoningChunk {
107                text: chunk.text,
108                signature: chunk.signature,
109            }),
110            ModelStreamEvent::ToolCall(tc) => StreamEvent::ToolCall(tc),
111            ModelStreamEvent::Done { tokens } => StreamEvent::Done {
112                usage: if tokens > 0 {
113                    Some(crate::models::TokenUsage::provider(0, tokens, tokens))
114                } else {
115                    None
116                },
117                thinking_signature: None,
118            },
119        };
120        let _ = sink.send(mapped);
121    })
122}
123
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks every field that `build_model_config` copies out of the
    /// request, not just the model id, so a dropped or mis-wired field fails
    /// here.
    #[test]
    fn build_model_config_maps_fields() {
        let req = ChatRequest {
            model_id: "groq/llama-3.3-70b-versatile".to_string(),
            messages: vec![],
            system_prompt: "sys".to_string(),
            instructions: None,
            reasoning: crate::models::ReasoningLevel::Medium,
            temperature: 0.7,
            max_tokens: 4096,
            tools: vec![],
        };
        let cfg = build_model_config(&req);

        assert_eq!(cfg.model, "groq/llama-3.3-70b-versatile");
        // The values are copied verbatim, so exact equality is the right check.
        assert_eq!(cfg.temperature, req.temperature);
        assert_eq!(cfg.max_tokens, req.max_tokens);
        // `matches!` avoids requiring `PartialEq`/`Debug` on the reasoning type.
        assert!(matches!(
            cfg.reasoning,
            crate::models::ReasoningLevel::Medium
        ));
        assert_eq!(cfg.system_prompt.as_deref(), Some("sys"));
        assert!(cfg.dynamic_system_suffix.is_none());
        assert!(cfg.tools.is_empty());
    }
}