1use async_trait::async_trait;
7use derive_builder::Builder;
8use futures::StreamExt;
9use reqwest::Client;
10use serde::{Deserialize, Serialize};
11use std::time::Duration;
12
13use crate::llm::{
14 BaseChatModel, ChatCompletion, ChatStream, LlmError, Message, StopReason, ToolChoice,
15 ToolDefinition, Usage,
16};
17
/// Chat client for any OpenAI-compatible chat-completions API; requests are
/// POSTed to `{base_url}/chat/completions`.
///
/// Construct via [`ChatOpenAICompatible::builder`]. Note that
/// `build_fn(skip)` disables derive_builder's generated `build()`, so the
/// `default = "..."` attributes below are NOT applied automatically — the
/// hand-written `ChatOpenAICompatibleBuilder::build` must mirror them.
#[derive(Builder, Clone)]
#[builder(pattern = "owned", build_fn(skip))]
pub struct ChatOpenAICompatible {
    /// Model identifier sent verbatim in the request body.
    #[builder(setter(into))]
    model: String,
    /// Optional API key; when `None`, no Authorization header is sent.
    #[builder(setter(into), default = "None")]
    api_key: Option<String>,
    /// API root, e.g. `https://api.example.com/v1` (trailing `/` tolerated).
    #[builder(setter(into))]
    base_url: String,
    /// Human-readable provider name, used in error messages.
    #[builder(setter(into))]
    provider: String,
    /// Sampling temperature forwarded to the API.
    #[builder(default = "0.2")]
    temperature: f32,
    /// Completion-token cap, serialized as `max_tokens`; `None` omits it.
    #[builder(default = "Some(4096)")]
    max_completion_tokens: Option<u64>,
    /// Shared HTTP client; created by the hand-written `build()`.
    #[builder(setter(skip))]
    client: Client,
    /// Assumed context window in tokens (not queried from the provider).
    #[builder(setter(skip))]
    context_window: u64,
    /// `true` → `Authorization: Bearer <key>`; `false` → raw key header.
    #[builder(default = "true")]
    use_bearer_auth: bool,
}
50
51impl ChatOpenAICompatible {
52 pub fn builder() -> ChatOpenAICompatibleBuilder {
54 ChatOpenAICompatibleBuilder::default()
55 }
56
57 fn build_client() -> Client {
59 Client::builder()
60 .timeout(Duration::from_secs(120))
61 .build()
62 .expect("Failed to create HTTP client")
63 }
64
65 fn default_context_window() -> u64 {
67 128_000
68 }
69
70 fn api_url(&self) -> String {
72 format!("{}/chat/completions", self.base_url.trim_end_matches('/'))
73 }
74
75 fn build_request(
77 &self,
78 messages: Vec<Message>,
79 tools: Option<Vec<ToolDefinition>>,
80 tool_choice: Option<ToolChoice>,
81 stream: bool,
82 ) -> Result<OpenAICompatibleRequest, LlmError> {
83 let openai_messages: Vec<OpenAICompatibleMessage> =
84 messages.into_iter().map(Self::convert_message).collect();
85
86 let openai_tools = tools.map(|ts| {
87 ts.into_iter()
88 .map(|t| OpenAICompatibleTool {
89 tool_type: "function".to_string(),
90 function: OpenAICompatibleFunction {
91 name: t.name,
92 description: t.description,
93 parameters: t.parameters,
94 },
95 })
96 .collect()
97 });
98
99 let tool_choice_value = tool_choice.map(|tc| match tc {
100 ToolChoice::Auto => serde_json::json!("auto"),
101 ToolChoice::Required => serde_json::json!("required"),
102 ToolChoice::None => serde_json::json!("none"),
103 ToolChoice::Named(name) => {
104 serde_json::json!({"type": "function", "function": {"name": name}})
105 }
106 });
107
108 Ok(OpenAICompatibleRequest {
109 model: self.model.clone(),
110 messages: openai_messages,
111 tools: openai_tools,
112 tool_choice: tool_choice_value,
113 temperature: Some(self.temperature),
114 max_tokens: self.max_completion_tokens,
115 stream: if stream { Some(true) } else { None },
116 })
117 }
118
119 fn convert_message(message: Message) -> OpenAICompatibleMessage {
120 match message {
121 Message::User(u) => {
122 let content = if u.content.len() == 1 && u.content[0].is_text() {
123 serde_json::json!(u.content[0].as_text().unwrap())
124 } else {
125 serde_json::json!(u.content)
126 };
127 OpenAICompatibleMessage {
128 role: "user".to_string(),
129 content: Some(content),
130 name: u.name,
131 tool_calls: None,
132 tool_call_id: None,
133 }
134 }
135 Message::Assistant(a) => OpenAICompatibleMessage {
136 role: "assistant".to_string(),
137 content: a.content.map(|c| serde_json::json!(c)),
138 name: None,
139 tool_calls: if a.tool_calls.is_empty() {
140 None
141 } else {
142 Some(a.tool_calls)
143 },
144 tool_call_id: None,
145 },
146 Message::System(s) => OpenAICompatibleMessage {
147 role: "system".to_string(),
148 content: Some(serde_json::json!(s.content)),
149 name: None,
150 tool_calls: None,
151 tool_call_id: None,
152 },
153 Message::Developer(d) => OpenAICompatibleMessage {
154 role: "developer".to_string(),
155 content: Some(serde_json::json!(d.content)),
156 name: None,
157 tool_calls: None,
158 tool_call_id: None,
159 },
160 Message::Tool(t) => OpenAICompatibleMessage {
161 role: "tool".to_string(),
162 content: Some(serde_json::json!(t.content)),
163 name: None,
164 tool_calls: None,
165 tool_call_id: Some(t.tool_call_id),
166 },
167 }
168 }
169
170 fn parse_response(response: OpenAICompatibleResponse) -> ChatCompletion {
171 let stop_reason = response
172 .choices
173 .first()
174 .and_then(|c| c.finish_reason.as_ref())
175 .and_then(|r| match r.as_str() {
176 "stop" => Some(StopReason::EndTurn),
177 "tool_calls" => Some(StopReason::ToolUse),
178 "length" => Some(StopReason::MaxTokens),
179 _ => None,
180 });
181
182 let choice = response.choices.into_iter().next();
183
184 let (content, tool_calls) = choice
185 .map(|c| (c.message.content, c.message.tool_calls.unwrap_or_default()))
186 .unwrap_or((None, Vec::new()));
187
188 let usage = response.usage.map(|u| Usage {
189 prompt_tokens: u.prompt_tokens,
190 completion_tokens: u.completion_tokens,
191 total_tokens: u.total_tokens,
192 ..Default::default()
193 });
194
195 ChatCompletion {
196 content,
197 thinking: None,
198 redacted_thinking: None,
199 tool_calls,
200 usage,
201 stop_reason,
202 }
203 }
204
205 fn parse_stream_chunk(text: &str) -> Option<Result<ChatCompletion, LlmError>> {
206 for line in text.lines() {
207 let line = line.trim();
208 if line.is_empty() || !line.starts_with("data:") {
209 continue;
210 }
211
212 let data = line.strip_prefix("data:").unwrap().trim();
213 if data == "[DONE]" {
214 return None;
215 }
216
217 let chunk: serde_json::Value = match serde_json::from_str(data) {
218 Ok(v) => v,
219 Err(_) => continue,
220 };
221
222 let delta = chunk
223 .get("choices")
224 .and_then(|c| c.as_array())
225 .and_then(|a| a.first())
226 .and_then(|c| c.get("delta"));
227
228 if let Some(delta) = delta {
229 let content = delta
230 .get("content")
231 .and_then(|c| c.as_str())
232 .map(|s| s.to_string());
233
234 let tool_calls: Vec<crate::llm::ToolCall> = delta
235 .get("tool_calls")
236 .and_then(|tc| tc.as_array())
237 .map(|arr| {
238 arr.iter()
239 .filter_map(|tc| {
240 let id = tc.get("id")?.as_str()?.to_string();
241 let func = tc.get("function")?;
242 let name = func.get("name")?.as_str()?.to_string();
243 let arguments = func.get("arguments")?.as_str()?.to_string();
244 Some(crate::llm::ToolCall::new(id, name, arguments))
245 })
246 .collect()
247 })
248 .unwrap_or_default();
249
250 if content.is_some() || !tool_calls.is_empty() {
251 return Some(Ok(ChatCompletion {
252 content,
253 thinking: None,
254 redacted_thinking: None,
255 tool_calls,
256 usage: None,
257 stop_reason: None,
258 }));
259 }
260 }
261 }
262
263 None
264 }
265}
266
267impl ChatOpenAICompatibleBuilder {
268 pub fn build(&self) -> Result<ChatOpenAICompatible, LlmError> {
269 let model = self
270 .model
271 .clone()
272 .ok_or_else(|| LlmError::Config("model is required".into()))?;
273 let base_url = self
274 .base_url
275 .clone()
276 .ok_or_else(|| LlmError::Config("base_url is required".into()))?;
277 let provider = self
278 .provider
279 .clone()
280 .ok_or_else(|| LlmError::Config("provider is required".into()))?;
281
282 Ok(ChatOpenAICompatible {
283 client: ChatOpenAICompatible::build_client(),
284 context_window: ChatOpenAICompatible::default_context_window(),
285 model,
286 api_key: self.api_key.clone().flatten(),
287 base_url,
288 provider,
289 temperature: self.temperature.unwrap_or(0.2),
290 max_completion_tokens: self.max_completion_tokens.flatten(),
291 use_bearer_auth: self.use_bearer_auth.unwrap_or(true),
292 })
293 }
294}
295
#[async_trait]
impl BaseChatModel for ChatOpenAICompatible {
    /// Configured model identifier.
    fn model(&self) -> &str {
        &self.model
    }

    /// Configured provider name (used for error reporting).
    fn provider(&self) -> &str {
        &self.provider
    }

    /// Assumed context window in tokens; always `Some` for this client.
    fn context_window(&self) -> Option<u64> {
        Some(self.context_window)
    }

    /// Sends a non-streaming chat-completions request and parses the reply.
    ///
    /// Non-2xx responses are turned into `LlmError::Api` carrying the
    /// provider name, HTTP status, and response body.
    async fn invoke(
        &self,
        messages: Vec<Message>,
        tools: Option<Vec<ToolDefinition>>,
        tool_choice: Option<ToolChoice>,
    ) -> Result<ChatCompletion, LlmError> {
        let request = self.build_request(messages, tools, tool_choice, false)?;

        let mut req = self
            .client
            .post(self.api_url())
            .header("Content-Type", "application/json");

        // NOTE(review): this auth block is duplicated in invoke_stream;
        // consider extracting a helper on ChatOpenAICompatible.
        if let Some(ref api_key) = self.api_key {
            if self.use_bearer_auth {
                req = req.header("Authorization", format!("Bearer {}", api_key));
            } else {
                // Some providers expect the raw key with no Bearer prefix.
                req = req.header("Authorization", api_key.clone());
            }
        }

        let response = req.json(&request).send().await?;

        if !response.status().is_success() {
            let status = response.status();
            // Best-effort body capture for diagnostics; ignore read errors.
            let body = response.text().await.unwrap_or_default();
            return Err(LlmError::Api(format!(
                "{} API error ({}): {}",
                self.provider, status, body
            )));
        }

        let completion: OpenAICompatibleResponse = response.json().await?;
        Ok(Self::parse_response(completion))
    }

    /// Sends a streaming chat-completions request and returns a stream of
    /// incremental `ChatCompletion` deltas parsed from the SSE bytes.
    ///
    /// Transport errors surface as `LlmError::Stream` items; chunks that
    /// parse to no delta are filtered out.
    async fn invoke_stream(
        &self,
        messages: Vec<Message>,
        tools: Option<Vec<ToolDefinition>>,
        tool_choice: Option<ToolChoice>,
    ) -> Result<ChatStream, LlmError> {
        let request = self.build_request(messages, tools, tool_choice, true)?;

        let mut req = self
            .client
            .post(self.api_url())
            .header("Content-Type", "application/json");

        if let Some(ref api_key) = self.api_key {
            if self.use_bearer_auth {
                req = req.header("Authorization", format!("Bearer {}", api_key));
            } else {
                req = req.header("Authorization", api_key.clone());
            }
        }

        let response = req.json(&request).send().await?;

        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            return Err(LlmError::Api(format!(
                "{} API error ({}): {}",
                self.provider, status, body
            )));
        }

        // NOTE(review): each network chunk is parsed independently — an SSE
        // line split across two chunks will not parse; confirm providers
        // don't split events, or add cross-chunk line buffering here.
        let stream = response.bytes_stream().filter_map(|result| async move {
            match result {
                Ok(bytes) => {
                    // lossy conversion: invalid UTF-8 becomes U+FFFD rather
                    // than aborting the stream.
                    let text = String::from_utf8_lossy(&bytes);
                    Self::parse_stream_chunk(&text)
                }
                Err(e) => Some(Err(LlmError::Stream(e.to_string()))),
            }
        });

        Ok(Box::pin(stream))
    }

    /// NOTE(review): unconditionally advertises vision support regardless
    /// of the configured model — verify this is intended for all providers.
    fn supports_vision(&self) -> bool {
        true
    }
}
396
/// Wire format of a chat-completions request body.
///
/// Optional fields are omitted from the JSON entirely (not sent as null)
/// for maximum cross-provider compatibility.
#[derive(Serialize)]
struct OpenAICompatibleRequest {
    model: String,
    messages: Vec<OpenAICompatibleMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<OpenAICompatibleTool>>,
    // Either a JSON string ("auto"/"required"/"none") or a function object.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    // Populated from `max_completion_tokens` but serialized as `max_tokens`.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    // Only ever Some(true); omitted for non-streaming requests.
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
}
416
/// Wire format of a single chat message in the request.
#[derive(Serialize)]
struct OpenAICompatibleMessage {
    // "user" | "assistant" | "system" | "developer" | "tool".
    role: String,
    // Plain JSON string for simple text, or an array of content parts.
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    name: Option<String>,
    // Only set on assistant messages that carry tool calls.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<crate::llm::ToolCall>>,
    // Only set on tool-role messages, linking the result to its call.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}
429
/// Wire format of a tool definition; `type` is always "function".
#[derive(Serialize)]
struct OpenAICompatibleTool {
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenAICompatibleFunction,
}
436
/// Function payload of a tool definition: name, description, and a JSON
/// Schema object describing the parameters.
#[derive(Serialize)]
struct OpenAICompatibleFunction {
    name: String,
    description: String,
    parameters: serde_json::Map<String, serde_json::Value>,
}
443
/// Wire format of a non-streaming chat-completions response.
#[derive(Deserialize)]
struct OpenAICompatibleResponse {
    choices: Vec<OpenAICompatibleChoice>,
    // `default` tolerates providers that omit the usage object entirely.
    #[serde(default)]
    usage: Option<OpenAICompatibleUsage>,
}
450
/// One completion choice; only the first is consumed by `parse_response`.
#[derive(Deserialize)]
struct OpenAICompatibleChoice {
    message: OpenAICompatibleMessageResponse,
    // e.g. "stop", "tool_calls", "length"; may be absent.
    finish_reason: Option<String>,
}
456
/// Assistant message payload of a response choice.
#[derive(Deserialize)]
struct OpenAICompatibleMessageResponse {
    content: Option<String>,
    tool_calls: Option<Vec<crate::llm::ToolCall>>,
}
462
/// Token accounting reported by the provider.
#[derive(Deserialize)]
struct OpenAICompatibleUsage {
    prompt_tokens: u64,
    completion_tokens: u64,
    total_tokens: u64,
}