1use std::pin::Pin;
2
3use futures::Stream;
4use serde::{Deserialize, Serialize};
5
6use super::{
7 build_http_client, ensure_ok, LlmError, LlmProvider, Message, ProposedToolCall, Response,
8 ResponseChunk, ToolDef, Usage,
9};
10
11#[derive(Serialize)]
12struct OpenAiRequest {
13 model: String,
14 messages: Vec<OpenAiMessage>,
15 temperature: f64,
16 max_tokens: Option<i32>,
17 stream: bool,
18 #[serde(skip_serializing_if = "Option::is_none")]
22 tools: Option<Vec<OpenAiTool>>,
23 #[serde(skip_serializing_if = "Option::is_none")]
26 tool_choice: Option<&'static str>,
27}
28
29#[derive(Serialize, Deserialize, Default)]
30struct OpenAiMessage {
31 role: String,
32 #[serde(default, skip_serializing_if = "Option::is_none")]
35 content: Option<String>,
36 #[serde(default, skip_serializing_if = "Option::is_none")]
37 tool_calls: Option<Vec<OpenAiToolCall>>,
38 #[serde(default, skip_serializing_if = "Option::is_none")]
41 tool_call_id: Option<String>,
42}
43
44#[derive(Serialize)]
46struct OpenAiTool {
47 #[serde(rename = "type")]
48 kind: &'static str,
49 function: OpenAiFunctionDef,
50}
51
52#[derive(Serialize)]
53struct OpenAiFunctionDef {
54 name: String,
55 description: String,
56 parameters: serde_json::Value,
57}
58
59#[derive(Serialize, Deserialize)]
64struct OpenAiToolCall {
65 #[serde(default)]
66 id: Option<String>,
67 #[serde(rename = "type", default = "function_kind")]
68 kind: String,
69 function: OpenAiFunctionCall,
70}
71
/// Returns the tool-call category `"function"`.
///
/// Serves as the serde default for `OpenAiToolCall::kind` and as the value
/// written when a proposed call is converted to wire format.
fn function_kind() -> String {
    String::from("function")
}
75
76#[derive(Serialize, Deserialize)]
77struct OpenAiFunctionCall {
78 name: String,
79 #[serde(default)]
80 arguments: String,
81}
82
83#[derive(Deserialize)]
84struct OpenAiResponse {
85 choices: Vec<OpenAiChoice>,
86 usage: Option<OpenAiUsage>,
87}
88
89#[derive(Deserialize)]
90struct OpenAiChoice {
91 message: OpenAiMessage,
92 #[allow(dead_code)]
93 finish_reason: Option<String>,
94}
95
96#[derive(Deserialize)]
97struct OpenAiStreamResponse {
98 choices: Vec<OpenAiStreamChoice>,
99}
100
101#[derive(Deserialize)]
102struct OpenAiStreamChoice {
103 delta: OpenAiDelta,
104 finish_reason: Option<String>,
105}
106
107#[derive(Deserialize)]
108struct OpenAiDelta {
109 #[serde(default)]
110 content: Option<String>,
111}
112
113#[derive(Deserialize)]
114struct OpenAiUsage {
115 prompt_tokens: u32,
116 completion_tokens: u32,
117 total_tokens: u32,
118}
119
120pub struct OpenAiProvider {
122 client: reqwest::Client,
123 base_url: String,
124 api_key: Option<String>,
125 model: String,
126 temperature: f64,
127 max_tokens: Option<i32>,
128}
129
130impl OpenAiProvider {
131 pub fn new(
132 base_url: &str,
133 api_key: Option<&str>,
134 model: &str,
135 temperature: f64,
136 max_tokens: Option<i32>,
137 ) -> Result<Self, LlmError> {
138 let client = build_http_client(brain::timeouts::LLM_GENERATE)?;
139 Ok(Self {
140 client,
141 base_url: base_url.trim_end_matches('/').to_string(),
142 api_key: api_key.map(|s| s.to_string()),
143 model: model.to_string(),
144 temperature,
145 max_tokens,
146 })
147 }
148
149 pub fn openai(api_key: &str, model: &str) -> Result<Self, LlmError> {
150 Self::new(
151 "https://api.openai.com/v1",
152 Some(api_key),
153 model,
154 0.7,
155 Some(4096),
156 )
157 }
158
159 pub fn openrouter(api_key: &str, model: &str) -> Result<Self, LlmError> {
160 Self::new(
161 "https://openrouter.ai/api/v1",
162 Some(api_key),
163 model,
164 0.7,
165 Some(4096),
166 )
167 }
168
169 fn convert_messages(messages: &[Message]) -> Vec<OpenAiMessage> {
170 messages.iter().map(Self::convert_message).collect()
171 }
172
173 fn convert_message(m: &Message) -> OpenAiMessage {
178 let role = m.role.as_wire_str().to_string();
179 if !m.tool_calls.is_empty() {
180 return OpenAiMessage {
181 role,
182 content: (!m.content.is_empty()).then(|| m.content.clone()),
183 tool_calls: Some(m.tool_calls.iter().map(convert_proposed_call).collect()),
184 tool_call_id: None,
185 };
186 }
187 OpenAiMessage {
188 role,
189 content: Some(m.content.clone()),
190 tool_calls: None,
191 tool_call_id: m.tool_call_id.clone(),
192 }
193 }
194
195 fn convert_tools(tools: &[ToolDef]) -> Vec<OpenAiTool> {
198 tools
199 .iter()
200 .map(|t| OpenAiTool {
201 kind: "function",
202 function: OpenAiFunctionDef {
203 name: t.name.clone(),
204 description: t.description.clone(),
205 parameters: t.parameters.clone(),
206 },
207 })
208 .collect()
209 }
210
211 fn extract_tool_calls(message: &OpenAiMessage) -> Vec<ProposedToolCall> {
216 message
217 .tool_calls
218 .iter()
219 .flatten()
220 .map(|tc| ProposedToolCall {
221 id: tc.id.clone(),
222 name: tc.function.name.clone(),
223 arguments: parse_arguments(&tc.function.arguments),
224 })
225 .collect()
226 }
227
228 fn build_request(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
229 let mut builder = builder;
230 if let Some(key) = &self.api_key {
231 builder = builder.header("Authorization", format!("Bearer {}", key));
232 }
233 builder
234 }
235}
236
237#[async_trait::async_trait]
238impl LlmProvider for OpenAiProvider {
239 async fn generate(&self, messages: &[Message]) -> Result<Response, LlmError> {
240 let url = format!("{}/chat/completions", self.base_url);
241 let request = OpenAiRequest {
242 model: self.model.clone(),
243 messages: Self::convert_messages(messages),
244 temperature: self.temperature,
245 max_tokens: self.max_tokens,
246 stream: false,
247 tools: None,
248 tool_choice: None,
249 };
250
251 let resp = self
252 .build_request(self.client.post(&url))
253 .json(&request)
254 .send()
255 .await?;
256 let resp = ensure_ok(resp).await?;
257
258 let data: OpenAiResponse = resp.json().await?;
259 let content = data
260 .choices
261 .first()
262 .and_then(|c| c.message.content.clone())
263 .unwrap_or_default();
264
265 Ok(Response::text(content, convert_usage(data.usage)))
266 }
267
268 async fn generate_with_tools(
269 &self,
270 messages: &[Message],
271 tools: &[ToolDef],
272 ) -> Result<Response, LlmError> {
273 if tools.is_empty() {
275 return self.generate(messages).await;
276 }
277
278 let url = format!("{}/chat/completions", self.base_url);
279 let request = OpenAiRequest {
280 model: self.model.clone(),
281 messages: Self::convert_messages(messages),
282 temperature: self.temperature,
283 max_tokens: self.max_tokens,
284 stream: false,
285 tools: Some(Self::convert_tools(tools)),
286 tool_choice: Some("auto"),
288 };
289
290 let resp = self
291 .build_request(self.client.post(&url))
292 .json(&request)
293 .send()
294 .await?;
295 let resp = ensure_ok(resp).await?;
296
297 let data: OpenAiResponse = resp.json().await?;
298 let (content, tool_calls) = match data.choices.first() {
299 Some(choice) => (
300 choice.message.content.clone().unwrap_or_default(),
301 Self::extract_tool_calls(&choice.message),
302 ),
303 None => (String::new(), Vec::new()),
304 };
305
306 Ok(Response {
307 content,
308 usage: convert_usage(data.usage),
309 tool_calls,
310 })
311 }
312
313 async fn generate_stream(
314 &self,
315 messages: &[Message],
316 ) -> Result<Pin<Box<dyn Stream<Item = Result<ResponseChunk, LlmError>> + Send>>, LlmError> {
317 use futures::stream::try_unfold;
318
319 let url = format!("{}/chat/completions", self.base_url);
320 let request = OpenAiRequest {
321 model: self.model.clone(),
322 messages: Self::convert_messages(messages),
323 temperature: self.temperature,
324 max_tokens: self.max_tokens,
325 stream: true,
326 tools: None,
327 tool_choice: None,
328 };
329
330 let resp = self
331 .build_request(self.client.post(&url))
332 .json(&request)
333 .send()
334 .await?;
335 let resp = ensure_ok(resp).await?;
336
337 let byte_stream = resp.bytes_stream();
338 let stream = try_unfold(
339 (Box::pin(byte_stream), String::new()),
340 |(mut byte_stream, mut buf)| async move {
341 use futures::TryStreamExt;
342
343 loop {
344 if let Some(newline_pos) = buf.find('\n') {
345 let line: String = buf[..newline_pos].to_string();
346 buf = buf[newline_pos + 1..].to_string();
347
348 let line = line.trim();
349 if line.is_empty() {
350 continue;
351 }
352
353 if let Some(data) = line.strip_prefix("data: ") {
354 let data = data.trim();
355 if data == "[DONE]" {
356 return Ok(None);
357 }
358
359 match serde_json::from_str::<OpenAiStreamResponse>(data) {
360 Ok(resp) => {
361 if let Some(choice) = resp.choices.first() {
362 let content =
363 choice.delta.content.clone().unwrap_or_default();
364 let is_done = choice.finish_reason.is_some();
365 let chunk = ResponseChunk { content, is_done };
366 return Ok(Some((chunk, (byte_stream, buf))));
367 }
368 continue;
369 }
370 Err(e) => {
371 return Err(LlmError::InvalidFormat(format!(
372 "Failed to parse streaming response: {e}"
373 )));
374 }
375 }
376 }
377 continue;
378 }
379
380 match byte_stream.try_next().await {
381 Ok(Some(bytes)) => {
382 buf.push_str(&String::from_utf8_lossy(&bytes));
383 }
384 Ok(None) => return Ok(None),
385 Err(e) => return Err(LlmError::Http(e)),
386 }
387 }
388 },
389 );
390
391 Ok(Box::pin(stream))
392 }
393
394 async fn health_check(&self) -> bool {
395 let url = format!("{}/models", self.base_url);
396 match self.build_request(self.client.get(&url)).send().await {
397 Ok(resp) => resp.status().is_success(),
398 Err(_) => false,
399 }
400 }
401
402 fn name(&self) -> &str {
403 "openai"
404 }
405
406 fn model(&self) -> &str {
407 &self.model
408 }
409
410 async fn list_models(&self) -> Result<Vec<String>, LlmError> {
411 #[derive(Deserialize)]
412 struct ModelEntry {
413 id: String,
414 }
415 #[derive(Deserialize)]
416 struct Models {
417 data: Vec<ModelEntry>,
418 }
419
420 let url = format!("{}/models", self.base_url);
421 let resp = self.build_request(self.client.get(&url)).send().await?;
422 let resp = ensure_ok(resp).await?;
423 let data: Models = resp.json().await?;
424 Ok(data.data.into_iter().map(|m| m.id).collect())
425 }
426
427 async fn fetch_context_window(&self) -> Option<usize> {
428 #[derive(Deserialize)]
431 struct ModelDetail {
432 id: String,
433 #[serde(default)]
434 context_length: Option<usize>,
435 }
436 #[derive(Deserialize)]
437 struct ModelsResponse {
438 data: Vec<ModelDetail>,
439 }
440
441 let from_api = (async {
442 let url = format!("{}/models", self.base_url);
443 let resp = self
444 .build_request(self.client.get(&url))
445 .send()
446 .await
447 .ok()?;
448 let resp = ensure_ok(resp).await.ok()?;
449 let data: ModelsResponse = resp.json().await.ok()?;
450 let active = self.model();
451 for model in &data.data {
453 if model.id == active {
454 return model.context_length;
455 }
456 }
457 for model in &data.data {
460 if model.id.ends_with(active) || model.id.contains(active) {
461 return model.context_length;
462 }
463 }
464 None
465 })
466 .await;
467 if from_api.is_some() {
468 return from_api;
469 }
470
471 super::known_context_window(self.model())
473 }
474}
475
476fn convert_usage(usage: Option<OpenAiUsage>) -> Option<Usage> {
478 usage.map(|u| Usage {
479 prompt_tokens: u.prompt_tokens,
480 completion_tokens: u.completion_tokens,
481 total_tokens: u.total_tokens,
482 })
483}
484
485fn convert_proposed_call(call: &ProposedToolCall) -> OpenAiToolCall {
489 OpenAiToolCall {
490 id: call.id.clone(),
491 kind: function_kind(),
492 function: OpenAiFunctionCall {
493 name: call.name.clone(),
494 arguments: serde_json::to_string(&call.arguments).unwrap_or_else(|_| "{}".to_string()),
495 },
496 }
497}
498
499fn parse_arguments(raw: &str) -> serde_json::Value {
503 let trimmed = raw.trim();
504 if trimmed.is_empty() {
505 return serde_json::json!({});
506 }
507 serde_json::from_str(trimmed).unwrap_or_else(|_| serde_json::json!({}))
508}