mermaid_cli/providers/model/
openai_compat.rs1use std::collections::HashMap;
12use std::sync::Arc;
13
14use async_trait::async_trait;
15
16use crate::domain::ChatRequest;
17use crate::models::adapters::openai_compat::OpenAICompatAdapter;
18use crate::models::{
19 Model, ModelConfig, ModelError, ProviderProfile, ReasoningChunk, Result, StreamCallback,
20 StreamEvent as ModelStreamEvent,
21};
22
23use super::super::capabilities::Capabilities;
24use super::super::ctx::{FinalResponse, StreamContext, StreamEvent};
25use super::ModelProvider;
26
27pub struct OpenAICompatProvider {
28 adapter: OpenAICompatAdapter,
29 capabilities: Capabilities,
30}
31
32impl OpenAICompatProvider {
33 pub fn new(
34 profile: &'static ProviderProfile,
35 base_url: String,
36 api_key: String,
37 model_name: String,
38 extra_headers: HashMap<String, String>,
39 ) -> Result<Self> {
40 let adapter =
41 OpenAICompatAdapter::new(profile, base_url, api_key, model_name, extra_headers)?;
42 let capabilities = Capabilities::from_legacy(adapter.capabilities());
43 Ok(Self {
44 adapter,
45 capabilities,
46 })
47 }
48}
49
50#[async_trait]
51impl ModelProvider for OpenAICompatProvider {
52 fn capabilities(&self) -> &Capabilities {
53 &self.capabilities
54 }
55
56 async fn chat(&self, request: ChatRequest, ctx: StreamContext) -> Result<FinalResponse> {
57 let config = build_model_config(&request);
58 let relay_tx = super::stream_bridge::ordered_relay(ctx.sink.clone());
59 let callback = forward_callback(relay_tx);
60 let chat_fut = self
61 .adapter
62 .chat(&request.messages, &config, Some(callback));
63
64 let response = tokio::select! {
65 biased;
66 _ = ctx.token.cancelled() => {
67 return Err(ModelError::Cancelled);
68 },
69 r = chat_fut => r?,
70 };
71
72 let usage = response.usage.clone();
73 let _ = ctx
74 .sink
75 .send(StreamEvent::Done {
76 usage: usage.clone(),
77 thinking_signature: None,
78 })
79 .await;
80
81 Ok(FinalResponse {
82 usage,
83 thinking_signature: None,
84 tool_calls: response.tool_calls.unwrap_or_default(),
85 })
86 }
87}
88
89fn build_model_config(request: &ChatRequest) -> ModelConfig {
90 ModelConfig {
91 model: request.model_id.clone(),
92 temperature: request.temperature,
93 max_tokens: request.max_tokens,
94 reasoning: request.reasoning,
95 system_prompt: Some(request.system_prompt.clone()),
96 dynamic_system_suffix: request.instructions.clone(),
97 tools: request.tools.iter().map(|t| t.to_openai_json()).collect(),
98 ..Default::default()
99 }
100}
101
102fn forward_callback(sink: tokio::sync::mpsc::UnboundedSender<StreamEvent>) -> StreamCallback {
103 Arc::new(move |event: ModelStreamEvent| {
104 let mapped = match event {
105 ModelStreamEvent::Text(s) => StreamEvent::Text(s),
106 ModelStreamEvent::Reasoning(chunk) => StreamEvent::Reasoning(ReasoningChunk {
107 text: chunk.text,
108 signature: chunk.signature,
109 }),
110 ModelStreamEvent::ToolCall(tc) => StreamEvent::ToolCall(tc),
111 ModelStreamEvent::Done { tokens } => StreamEvent::Done {
112 usage: if tokens > 0 {
113 Some(crate::models::TokenUsage::provider(0, tokens, tokens))
114 } else {
115 None
116 },
117 thinking_signature: None,
118 },
119 };
120 let _ = sink.send(mapped);
121 })
122}
123
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks every field that `build_model_config` copies from the request,
    /// so a regression in any single mapping fails the test.
    #[test]
    fn build_model_config_maps_fields() {
        let req = ChatRequest {
            model_id: "groq/llama-3.3-70b-versatile".to_string(),
            messages: vec![],
            system_prompt: "sys".to_string(),
            instructions: None,
            reasoning: crate::models::ReasoningLevel::Medium,
            temperature: 0.7,
            max_tokens: 4096,
            tools: vec![],
        };
        let cfg = build_model_config(&req);

        assert_eq!(cfg.model, "groq/llama-3.3-70b-versatile");
        // The value is copied verbatim, so exact float equality is intended.
        assert_eq!(cfg.temperature, 0.7);
        assert_eq!(cfg.max_tokens, 4096);
        // `matches!` avoids requiring `PartialEq`/`Debug` on the reasoning type.
        assert!(matches!(
            cfg.reasoning,
            crate::models::ReasoningLevel::Medium
        ));
        assert_eq!(cfg.system_prompt.as_deref(), Some("sys"));
        assert!(cfg.dynamic_system_suffix.is_none());
    }
}