agent_io/llm/openai_compatible/mod.rs

mod request;
mod response;
mod types;

use async_trait::async_trait;
use derive_builder::Builder;
use futures::StreamExt;
use reqwest::Client;
use std::time::Duration;

use crate::llm::{
    BaseChatModel, ChatCompletion, ChatStream, LlmError, Message, ToolChoice, ToolDefinition,
};

use types::*;

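/// Chat model backed by any provider that exposes an OpenAI-compatible
/// `/chat/completions` endpoint.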
#[derive(Builder, Clone)]
#[builder(pattern = "owned", build_fn(skip))]
pub struct ChatOpenAICompatible {
    #[builder(setter(into))]
    pub(super) model: String,
    #[builder(setter(into), default = "None")]
    pub(super) api_key: Option<String>,
    #[builder(setter(into))]
    pub(super) base_url: String,
    #[builder(setter(into))]
    pub(super) provider: String,
    #[builder(default = "0.2")]
    pub(super) temperature: f32,
    #[builder(default = "Some(4096)")]
    pub(super) max_completion_tokens: Option<u64>,
    #[builder(setter(skip))]
    pub(super) client: Client,
    #[builder(setter(skip))]
    pub(super) context_window: u64,
    #[builder(default = "true")]
    pub(super) use_bearer_auth: bool,
}

impl ChatOpenAICompatible {
    pub fn builder() -> ChatOpenAICompatibleBuilder {
        ChatOpenAICompatibleBuilder::default()
    }

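    // reqwest::Client keeps an internal connection pool, so a single instance
    // is created in `build()` and reused for every request.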
    fn build_client() -> Client {
        Client::builder()
            .timeout(Duration::from_secs(120))
            .build()
            .expect("Failed to create HTTP client")
    }

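    // Callers cannot override the context window: the field's builder setter
    // is skipped, so this default is always used.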
    fn default_context_window() -> u64 {
        128_000
    }

    fn api_url(&self) -> String {
        format!("{}/chat/completions", self.base_url.trim_end_matches('/'))
    }
}

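// `build_fn(skip)` above disables derive_builder's generated `build`, so this
// hand-written version can report missing required fields as
// `LlmError::Config` instead of derive_builder's own error type.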
impl ChatOpenAICompatibleBuilder {
    pub fn build(&self) -> Result<ChatOpenAICompatible, LlmError> {
        let model = self
            .model
            .clone()
            .ok_or_else(|| LlmError::Config("model is required".into()))?;
        let base_url = self
            .base_url
            .clone()
            .ok_or_else(|| LlmError::Config("base_url is required".into()))?;
        let provider = self
            .provider
            .clone()
            .ok_or_else(|| LlmError::Config("provider is required".into()))?;

        Ok(ChatOpenAICompatible {
            client: ChatOpenAICompatible::build_client(),
            context_window: ChatOpenAICompatible::default_context_window(),
            model,
            api_key: self.api_key.clone().flatten(),
            base_url,
            provider,
            temperature: self.temperature.unwrap_or(0.2),
            max_completion_tokens: self.max_completion_tokens.flatten(),
            use_bearer_auth: self.use_bearer_auth.unwrap_or(true),
        })
    }
}

#[async_trait]
impl BaseChatModel for ChatOpenAICompatible {
    fn model(&self) -> &str {
        &self.model
    }

    fn provider(&self) -> &str {
        &self.provider
    }

    fn context_window(&self) -> Option<u64> {
        Some(self.context_window)
    }

    async fn invoke(
        &self,
        messages: Vec<Message>,
        tools: Option<Vec<ToolDefinition>>,
        tool_choice: Option<ToolChoice>,
    ) -> Result<ChatCompletion, LlmError> {
        let request = self.build_request(messages, tools, tool_choice, false)?;

        let mut req = self
            .client
            .post(self.api_url())
            .header("Content-Type", "application/json");

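        // `use_bearer_auth` selects between the standard "Bearer <key>" scheme
        // and passing the key verbatim in the Authorization header.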
        if let Some(ref api_key) = self.api_key {
            if self.use_bearer_auth {
                req = req.header("Authorization", format!("Bearer {}", api_key));
            } else {
                req = req.header("Authorization", api_key.clone());
            }
        }

        let response = req.json(&request).send().await?;

        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            return Err(LlmError::Api(format!(
                "{} API error ({}): {}",
                self.provider, status, body
            )));
        }

        let completion: OpenAICompatibleResponse = response.json().await?;
        Ok(Self::parse_response(completion))
    }

    async fn invoke_stream(
        &self,
        messages: Vec<Message>,
        tools: Option<Vec<ToolDefinition>>,
        tool_choice: Option<ToolChoice>,
    ) -> Result<ChatStream, LlmError> {
        let request = self.build_request(messages, tools, tool_choice, true)?;

        let mut req = self
            .client
            .post(self.api_url())
            .header("Content-Type", "application/json");

        if let Some(ref api_key) = self.api_key {
            if self.use_bearer_auth {
                req = req.header("Authorization", format!("Bearer {}", api_key));
            } else {
                req = req.header("Authorization", api_key.clone());
            }
        }

        let response = req.json(&request).send().await?;

        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            return Err(LlmError::Api(format!(
                "{} API error ({}): {}",
                self.provider, status, body
            )));
        }

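        // Each body chunk is decoded (lossily) as UTF-8 and handed to
        // `parse_stream_chunk`; chunks that yield no item are dropped from the
        // stream, and transport errors become `LlmError::Stream`.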
        let stream = response.bytes_stream().filter_map(|result| async move {
            match result {
                Ok(bytes) => {
                    let text = String::from_utf8_lossy(&bytes);
                    Self::parse_stream_chunk(&text)
                }
                Err(e) => Some(Err(LlmError::Stream(e.to_string()))),
            }
        });

        Ok(Box::pin(stream))
    }

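    // Vision support is assumed for every model served through this client;
    // per-model capabilities are not tracked here.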
    fn supports_vision(&self) -> bool {
        true
    }
}
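
// A minimal usage sketch (illustrative only; the model name, base URL,
// provider label, and environment variable below are placeholders, not values
// this crate prescribes):
//
// async fn example(messages: Vec<Message>) -> Result<ChatCompletion, LlmError> {
//     let llm = ChatOpenAICompatible::builder()
//         .model("gpt-4o-mini")
//         .base_url("https://api.example.com/v1")
//         .provider("openai-compatible")
//         .api_key(std::env::var("OPENAI_API_KEY").ok())
//         .build()?;
//     llm.invoke(messages, None, None).await
// }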