// agent_io/llm/openai_compatible/mod.rs

//! OpenAI-compatible API base implementation
//!
//! This module provides a base implementation for any LLM provider that uses
//! the OpenAI-compatible API format (Ollama, OpenRouter, DeepSeek, Groq, etc.)

mod request;
mod response;
mod types;

use async_trait::async_trait;
use derive_builder::Builder;
use futures::StreamExt;
use reqwest::Client;
use std::time::Duration;

use crate::llm::{
    BaseChatModel, ChatCompletion, ChatStream, LlmError, Message, ToolChoice, ToolDefinition,
};

use types::*;
21
/// OpenAI-compatible Chat Model base implementation.
///
/// Shared configuration for any provider that speaks the OpenAI
/// chat-completions wire format (Ollama, OpenRouter, DeepSeek, Groq, ...).
///
/// NOTE(review): `build_fn(skip)` disables derive_builder's generated `build`,
/// which is where `#[builder(default = "...")]` values are normally applied.
/// The defaults declared below are therefore documentation only — they must be
/// (and are) re-applied by hand in `ChatOpenAICompatibleBuilder::build`.
#[derive(Builder, Clone)]
#[builder(pattern = "owned", build_fn(skip))]
pub struct ChatOpenAICompatible {
    /// Model identifier sent verbatim in the request payload.
    #[builder(setter(into))]
    pub(super) model: String,
    /// API key (optional for some providers like Ollama, which need none).
    #[builder(setter(into), default = "None")]
    pub(super) api_key: Option<String>,
    /// Base URL for the API; a trailing `/` is tolerated (trimmed in `api_url`).
    #[builder(setter(into))]
    pub(super) base_url: String,
    /// Provider name for identification (used in error messages).
    #[builder(setter(into))]
    pub(super) provider: String,
    /// Temperature for sampling (manual default: 0.2).
    #[builder(default = "0.2")]
    pub(super) temperature: f32,
    /// Maximum completion tokens (manual default: Some(4096)).
    #[builder(default = "Some(4096)")]
    pub(super) max_completion_tokens: Option<u64>,
    /// HTTP client; always constructed internally via `build_client`.
    #[builder(setter(skip))]
    pub(super) client: Client,
    /// Context window size; always set internally via `default_context_window`.
    #[builder(setter(skip))]
    pub(super) context_window: u64,
    /// Whether to prefix the Authorization header value with "Bearer ".
    #[builder(default = "true")]
    pub(super) use_bearer_auth: bool,
}
54
55impl ChatOpenAICompatible {
56    /// Create a builder for configuration
57    pub fn builder() -> ChatOpenAICompatibleBuilder {
58        ChatOpenAICompatibleBuilder::default()
59    }
60
61    /// Build the HTTP client
62    fn build_client() -> Client {
63        Client::builder()
64            .timeout(Duration::from_secs(120))
65            .build()
66            .expect("Failed to create HTTP client")
67    }
68
69    /// Default context window
70    fn default_context_window() -> u64 {
71        128_000
72    }
73
74    /// Get the API URL
75    fn api_url(&self) -> String {
76        format!("{}/chat/completions", self.base_url.trim_end_matches('/'))
77    }
78}
79
80impl ChatOpenAICompatibleBuilder {
81    pub fn build(&self) -> Result<ChatOpenAICompatible, LlmError> {
82        let model = self
83            .model
84            .clone()
85            .ok_or_else(|| LlmError::Config("model is required".into()))?;
86        let base_url = self
87            .base_url
88            .clone()
89            .ok_or_else(|| LlmError::Config("base_url is required".into()))?;
90        let provider = self
91            .provider
92            .clone()
93            .ok_or_else(|| LlmError::Config("provider is required".into()))?;
94
95        Ok(ChatOpenAICompatible {
96            client: ChatOpenAICompatible::build_client(),
97            context_window: ChatOpenAICompatible::default_context_window(),
98            model,
99            api_key: self.api_key.clone().flatten(),
100            base_url,
101            provider,
102            temperature: self.temperature.unwrap_or(0.2),
103            max_completion_tokens: self.max_completion_tokens.flatten(),
104            use_bearer_auth: self.use_bearer_auth.unwrap_or(true),
105        })
106    }
107}
108
109#[async_trait]
110impl BaseChatModel for ChatOpenAICompatible {
111    fn model(&self) -> &str {
112        &self.model
113    }
114
115    fn provider(&self) -> &str {
116        &self.provider
117    }
118
119    fn context_window(&self) -> Option<u64> {
120        Some(self.context_window)
121    }
122
123    async fn invoke(
124        &self,
125        messages: Vec<Message>,
126        tools: Option<Vec<ToolDefinition>>,
127        tool_choice: Option<ToolChoice>,
128    ) -> Result<ChatCompletion, LlmError> {
129        let request = self.build_request(messages, tools, tool_choice, false)?;
130
131        let mut req = self
132            .client
133            .post(self.api_url())
134            .header("Content-Type", "application/json");
135
136        if let Some(ref api_key) = self.api_key {
137            if self.use_bearer_auth {
138                req = req.header("Authorization", format!("Bearer {}", api_key));
139            } else {
140                req = req.header("Authorization", api_key.clone());
141            }
142        }
143
144        let response = req.json(&request).send().await?;
145
146        if !response.status().is_success() {
147            let status = response.status();
148            let body = response.text().await.unwrap_or_default();
149            return Err(LlmError::Api(format!(
150                "{} API error ({}): {}",
151                self.provider, status, body
152            )));
153        }
154
155        let completion: OpenAICompatibleResponse = response.json().await?;
156        Ok(Self::parse_response(completion))
157    }
158
159    async fn invoke_stream(
160        &self,
161        messages: Vec<Message>,
162        tools: Option<Vec<ToolDefinition>>,
163        tool_choice: Option<ToolChoice>,
164    ) -> Result<ChatStream, LlmError> {
165        let request = self.build_request(messages, tools, tool_choice, true)?;
166
167        let mut req = self
168            .client
169            .post(self.api_url())
170            .header("Content-Type", "application/json");
171
172        if let Some(ref api_key) = self.api_key {
173            if self.use_bearer_auth {
174                req = req.header("Authorization", format!("Bearer {}", api_key));
175            } else {
176                req = req.header("Authorization", api_key.clone());
177            }
178        }
179
180        let response = req.json(&request).send().await?;
181
182        if !response.status().is_success() {
183            let status = response.status();
184            let body = response.text().await.unwrap_or_default();
185            return Err(LlmError::Api(format!(
186                "{} API error ({}): {}",
187                self.provider, status, body
188            )));
189        }
190
191        let stream = response.bytes_stream().filter_map(|result| async move {
192            match result {
193                Ok(bytes) => {
194                    let text = String::from_utf8_lossy(&bytes);
195                    Self::parse_stream_chunk(&text)
196                }
197                Err(e) => Some(Err(LlmError::Stream(e.to_string()))),
198            }
199        });
200
201        Ok(Box::pin(stream))
202    }
203
204    fn supports_vision(&self) -> bool {
205        // Most OpenAI-compatible providers support vision
206        true
207    }
208}