Skip to main content

agent_io/llm/
openai_compatible.rs

1//! OpenAI-compatible API base implementation
2//!
3//! This module provides a base implementation for any LLM provider that uses
4//! the OpenAI-compatible API format (Ollama, OpenRouter, DeepSeek, Groq, etc.)
5
6use async_trait::async_trait;
7use derive_builder::Builder;
8use futures::StreamExt;
9use reqwest::Client;
10use serde::{Deserialize, Serialize};
11use std::time::Duration;
12
13use crate::llm::{
14    BaseChatModel, ChatCompletion, ChatStream, LlmError, Message, StopReason, ToolChoice,
15    ToolDefinition, Usage,
16};
17
/// OpenAI-compatible Chat Model base implementation.
///
/// A single concrete client usable against any provider that exposes the
/// OpenAI `/chat/completions` wire format (Ollama, OpenRouter, DeepSeek,
/// Groq, ...). Construct via [`ChatOpenAICompatible::builder`].
///
/// NOTE(review): `build_fn(skip)` disables the derive-generated `build()`,
/// so the `default = "..."` strings on the fields below are not applied by
/// the derive; the effective defaults are the ones re-stated in the
/// hand-written `ChatOpenAICompatibleBuilder::build`.
#[derive(Builder, Clone)]
#[builder(pattern = "owned", build_fn(skip))]
pub struct ChatOpenAICompatible {
    /// Model identifier, sent verbatim as the request's `model` field.
    #[builder(setter(into))]
    model: String,
    /// API key (optional for some providers like Ollama, which take none).
    #[builder(setter(into), default = "None")]
    api_key: Option<String>,
    /// Base URL for the API; a trailing slash is tolerated (trimmed when the
    /// endpoint URL is assembled).
    #[builder(setter(into))]
    base_url: String,
    /// Provider name for identification; also appears in error messages.
    #[builder(setter(into))]
    provider: String,
    /// Temperature for sampling.
    #[builder(default = "0.2")]
    temperature: f32,
    /// Maximum completion tokens; `None` omits `max_tokens` from the request.
    #[builder(default = "Some(4096)")]
    max_completion_tokens: Option<u64>,
    /// HTTP client, created internally — not settable on the builder.
    #[builder(setter(skip))]
    client: Client,
    /// Context window size in tokens, set internally to a default.
    #[builder(setter(skip))]
    context_window: u64,
    /// Whether to prefix the key with `"Bearer "` in the Authorization header.
    #[builder(default = "true")]
    use_bearer_auth: bool,
}
50
51impl ChatOpenAICompatible {
52    /// Create a builder for configuration
53    pub fn builder() -> ChatOpenAICompatibleBuilder {
54        ChatOpenAICompatibleBuilder::default()
55    }
56
57    /// Build the HTTP client
58    fn build_client() -> Client {
59        Client::builder()
60            .timeout(Duration::from_secs(120))
61            .build()
62            .expect("Failed to create HTTP client")
63    }
64
65    /// Default context window
66    fn default_context_window() -> u64 {
67        128_000
68    }
69
70    /// Get the API URL
71    fn api_url(&self) -> String {
72        format!("{}/chat/completions", self.base_url.trim_end_matches('/'))
73    }
74
75    /// Build request
76    fn build_request(
77        &self,
78        messages: Vec<Message>,
79        tools: Option<Vec<ToolDefinition>>,
80        tool_choice: Option<ToolChoice>,
81        stream: bool,
82    ) -> Result<OpenAICompatibleRequest, LlmError> {
83        let openai_messages: Vec<OpenAICompatibleMessage> =
84            messages.into_iter().map(Self::convert_message).collect();
85
86        let openai_tools = tools.map(|ts| {
87            ts.into_iter()
88                .map(|t| OpenAICompatibleTool {
89                    tool_type: "function".to_string(),
90                    function: OpenAICompatibleFunction {
91                        name: t.name,
92                        description: t.description,
93                        parameters: t.parameters,
94                    },
95                })
96                .collect()
97        });
98
99        let tool_choice_value = tool_choice.map(|tc| match tc {
100            ToolChoice::Auto => serde_json::json!("auto"),
101            ToolChoice::Required => serde_json::json!("required"),
102            ToolChoice::None => serde_json::json!("none"),
103            ToolChoice::Named(name) => {
104                serde_json::json!({"type": "function", "function": {"name": name}})
105            }
106        });
107
108        Ok(OpenAICompatibleRequest {
109            model: self.model.clone(),
110            messages: openai_messages,
111            tools: openai_tools,
112            tool_choice: tool_choice_value,
113            temperature: Some(self.temperature),
114            max_tokens: self.max_completion_tokens,
115            stream: if stream { Some(true) } else { None },
116        })
117    }
118
119    fn convert_message(message: Message) -> OpenAICompatibleMessage {
120        match message {
121            Message::User(u) => {
122                let content = if u.content.len() == 1 && u.content[0].is_text() {
123                    serde_json::json!(u.content[0].as_text().unwrap())
124                } else {
125                    serde_json::json!(u.content)
126                };
127                OpenAICompatibleMessage {
128                    role: "user".to_string(),
129                    content: Some(content),
130                    name: u.name,
131                    tool_calls: None,
132                    tool_call_id: None,
133                }
134            }
135            Message::Assistant(a) => OpenAICompatibleMessage {
136                role: "assistant".to_string(),
137                content: a.content.map(|c| serde_json::json!(c)),
138                name: None,
139                tool_calls: if a.tool_calls.is_empty() {
140                    None
141                } else {
142                    Some(a.tool_calls)
143                },
144                tool_call_id: None,
145            },
146            Message::System(s) => OpenAICompatibleMessage {
147                role: "system".to_string(),
148                content: Some(serde_json::json!(s.content)),
149                name: None,
150                tool_calls: None,
151                tool_call_id: None,
152            },
153            Message::Developer(d) => OpenAICompatibleMessage {
154                role: "developer".to_string(),
155                content: Some(serde_json::json!(d.content)),
156                name: None,
157                tool_calls: None,
158                tool_call_id: None,
159            },
160            Message::Tool(t) => OpenAICompatibleMessage {
161                role: "tool".to_string(),
162                content: Some(serde_json::json!(t.content)),
163                name: None,
164                tool_calls: None,
165                tool_call_id: Some(t.tool_call_id),
166            },
167        }
168    }
169
170    fn parse_response(response: OpenAICompatibleResponse) -> ChatCompletion {
171        let stop_reason = response
172            .choices
173            .first()
174            .and_then(|c| c.finish_reason.as_ref())
175            .and_then(|r| match r.as_str() {
176                "stop" => Some(StopReason::EndTurn),
177                "tool_calls" => Some(StopReason::ToolUse),
178                "length" => Some(StopReason::MaxTokens),
179                _ => None,
180            });
181
182        let choice = response.choices.into_iter().next();
183
184        let (content, tool_calls) = choice
185            .map(|c| (c.message.content, c.message.tool_calls.unwrap_or_default()))
186            .unwrap_or((None, Vec::new()));
187
188        let usage = response.usage.map(|u| Usage {
189            prompt_tokens: u.prompt_tokens,
190            completion_tokens: u.completion_tokens,
191            total_tokens: u.total_tokens,
192            ..Default::default()
193        });
194
195        ChatCompletion {
196            content,
197            thinking: None,
198            redacted_thinking: None,
199            tool_calls,
200            usage,
201            stop_reason,
202        }
203    }
204
205    fn parse_stream_chunk(text: &str) -> Option<Result<ChatCompletion, LlmError>> {
206        for line in text.lines() {
207            let line = line.trim();
208            if line.is_empty() || !line.starts_with("data:") {
209                continue;
210            }
211
212            let data = line.strip_prefix("data:").unwrap().trim();
213            if data == "[DONE]" {
214                return None;
215            }
216
217            let chunk: serde_json::Value = match serde_json::from_str(data) {
218                Ok(v) => v,
219                Err(_) => continue,
220            };
221
222            let delta = chunk
223                .get("choices")
224                .and_then(|c| c.as_array())
225                .and_then(|a| a.first())
226                .and_then(|c| c.get("delta"));
227
228            if let Some(delta) = delta {
229                let content = delta
230                    .get("content")
231                    .and_then(|c| c.as_str())
232                    .map(|s| s.to_string());
233
234                let tool_calls: Vec<crate::llm::ToolCall> = delta
235                    .get("tool_calls")
236                    .and_then(|tc| tc.as_array())
237                    .map(|arr| {
238                        arr.iter()
239                            .filter_map(|tc| {
240                                let id = tc.get("id")?.as_str()?.to_string();
241                                let func = tc.get("function")?;
242                                let name = func.get("name")?.as_str()?.to_string();
243                                let arguments = func.get("arguments")?.as_str()?.to_string();
244                                Some(crate::llm::ToolCall::new(id, name, arguments))
245                            })
246                            .collect()
247                    })
248                    .unwrap_or_default();
249
250                if content.is_some() || !tool_calls.is_empty() {
251                    return Some(Ok(ChatCompletion {
252                        content,
253                        thinking: None,
254                        redacted_thinking: None,
255                        tool_calls,
256                        usage: None,
257                        stop_reason: None,
258                    }));
259                }
260            }
261        }
262
263        None
264    }
265}
266
267impl ChatOpenAICompatibleBuilder {
268    pub fn build(&self) -> Result<ChatOpenAICompatible, LlmError> {
269        let model = self
270            .model
271            .clone()
272            .ok_or_else(|| LlmError::Config("model is required".into()))?;
273        let base_url = self
274            .base_url
275            .clone()
276            .ok_or_else(|| LlmError::Config("base_url is required".into()))?;
277        let provider = self
278            .provider
279            .clone()
280            .ok_or_else(|| LlmError::Config("provider is required".into()))?;
281
282        Ok(ChatOpenAICompatible {
283            client: ChatOpenAICompatible::build_client(),
284            context_window: ChatOpenAICompatible::default_context_window(),
285            model,
286            api_key: self.api_key.clone().flatten(),
287            base_url,
288            provider,
289            temperature: self.temperature.unwrap_or(0.2),
290            max_completion_tokens: self.max_completion_tokens.flatten(),
291            use_bearer_auth: self.use_bearer_auth.unwrap_or(true),
292        })
293    }
294}
295
296#[async_trait]
297impl BaseChatModel for ChatOpenAICompatible {
298    fn model(&self) -> &str {
299        &self.model
300    }
301
302    fn provider(&self) -> &str {
303        &self.provider
304    }
305
306    fn context_window(&self) -> Option<u64> {
307        Some(self.context_window)
308    }
309
310    async fn invoke(
311        &self,
312        messages: Vec<Message>,
313        tools: Option<Vec<ToolDefinition>>,
314        tool_choice: Option<ToolChoice>,
315    ) -> Result<ChatCompletion, LlmError> {
316        let request = self.build_request(messages, tools, tool_choice, false)?;
317
318        let mut req = self
319            .client
320            .post(self.api_url())
321            .header("Content-Type", "application/json");
322
323        if let Some(ref api_key) = self.api_key {
324            if self.use_bearer_auth {
325                req = req.header("Authorization", format!("Bearer {}", api_key));
326            } else {
327                req = req.header("Authorization", api_key.clone());
328            }
329        }
330
331        let response = req.json(&request).send().await?;
332
333        if !response.status().is_success() {
334            let status = response.status();
335            let body = response.text().await.unwrap_or_default();
336            return Err(LlmError::Api(format!(
337                "{} API error ({}): {}",
338                self.provider, status, body
339            )));
340        }
341
342        let completion: OpenAICompatibleResponse = response.json().await?;
343        Ok(Self::parse_response(completion))
344    }
345
346    async fn invoke_stream(
347        &self,
348        messages: Vec<Message>,
349        tools: Option<Vec<ToolDefinition>>,
350        tool_choice: Option<ToolChoice>,
351    ) -> Result<ChatStream, LlmError> {
352        let request = self.build_request(messages, tools, tool_choice, true)?;
353
354        let mut req = self
355            .client
356            .post(self.api_url())
357            .header("Content-Type", "application/json");
358
359        if let Some(ref api_key) = self.api_key {
360            if self.use_bearer_auth {
361                req = req.header("Authorization", format!("Bearer {}", api_key));
362            } else {
363                req = req.header("Authorization", api_key.clone());
364            }
365        }
366
367        let response = req.json(&request).send().await?;
368
369        if !response.status().is_success() {
370            let status = response.status();
371            let body = response.text().await.unwrap_or_default();
372            return Err(LlmError::Api(format!(
373                "{} API error ({}): {}",
374                self.provider, status, body
375            )));
376        }
377
378        let stream = response.bytes_stream().filter_map(|result| async move {
379            match result {
380                Ok(bytes) => {
381                    let text = String::from_utf8_lossy(&bytes);
382                    Self::parse_stream_chunk(&text)
383                }
384                Err(e) => Some(Err(LlmError::Stream(e.to_string()))),
385            }
386        });
387
388        Ok(Box::pin(stream))
389    }
390
391    fn supports_vision(&self) -> bool {
392        // Most OpenAI-compatible providers support vision
393        true
394    }
395}
396
397// =============================================================================
398// Request/Response Types
399// =============================================================================
400
/// JSON body for `POST {base_url}/chat/completions` (OpenAI wire format).
///
/// Optional fields are omitted from the JSON entirely (not sent as `null`)
/// for maximum provider compatibility.
#[derive(Serialize)]
struct OpenAICompatibleRequest {
    model: String,
    messages: Vec<OpenAICompatibleMessage>,
    /// Tool definitions in `{"type":"function","function":{...}}` form.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<OpenAICompatibleTool>>,
    /// Either a plain string ("auto"/"required"/"none") or a named-tool object.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    /// Only ever `Some(true)` (streaming) or absent (non-streaming).
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
}
416
/// One message in the outgoing `messages` array (OpenAI wire format).
#[derive(Serialize)]
struct OpenAICompatibleMessage {
    /// "user", "assistant", "system", "developer", or "tool".
    role: String,
    /// Either a JSON string or a content-part array, depending on the message.
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<serde_json::Value>,
    /// Optional participant name (only set for user messages here).
    #[serde(skip_serializing_if = "Option::is_none")]
    name: Option<String>,
    /// Assistant-issued tool calls; absent when empty.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<crate::llm::ToolCall>>,
    /// For role "tool": the id of the call this message responds to.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}
429
/// Tool wrapper in OpenAI format: `{"type": "function", "function": {...}}`.
#[derive(Serialize)]
struct OpenAICompatibleTool {
    /// Always "function" in this implementation.
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenAICompatibleFunction,
}
436
/// Function descriptor inside a tool definition.
#[derive(Serialize)]
struct OpenAICompatibleFunction {
    name: String,
    description: String,
    /// JSON Schema object describing the function's parameters.
    parameters: serde_json::Map<String, serde_json::Value>,
}
443
/// Top-level non-streaming response body from `/chat/completions`.
#[derive(Deserialize)]
struct OpenAICompatibleResponse {
    choices: Vec<OpenAICompatibleChoice>,
    /// Token accounting; some providers omit it, hence `default`.
    #[serde(default)]
    usage: Option<OpenAICompatibleUsage>,
}
450
/// One completion choice; only the first is consumed by `parse_response`.
#[derive(Deserialize)]
struct OpenAICompatibleChoice {
    message: OpenAICompatibleMessageResponse,
    /// "stop", "tool_calls", "length", or a provider-specific value.
    finish_reason: Option<String>,
}
456
/// Assistant message inside a response choice.
#[derive(Deserialize)]
struct OpenAICompatibleMessageResponse {
    content: Option<String>,
    tool_calls: Option<Vec<crate::llm::ToolCall>>,
}
462
/// Token usage block of a response (OpenAI field names).
#[derive(Deserialize)]
struct OpenAICompatibleUsage {
    prompt_tokens: u64,
    completion_tokens: u64,
    total_tokens: u64,
}