1use std::pin::Pin;
2
3use futures::Stream;
4use serde::{Deserialize, Serialize};
5
6use super::{LlmError, LlmProvider, Message, Response, ResponseChunk, Role, Usage};
7
8#[derive(Serialize)]
9struct OpenAiRequest {
10 model: String,
11 messages: Vec<OpenAiMessage>,
12 temperature: f64,
13 max_tokens: Option<i32>,
14 stream: bool,
15}
16
17#[derive(Serialize, Deserialize)]
18struct OpenAiMessage {
19 role: String,
20 content: String,
21}
22
23#[derive(Deserialize)]
24struct OpenAiResponse {
25 choices: Vec<OpenAiChoice>,
26 usage: Option<OpenAiUsage>,
27}
28
29#[derive(Deserialize)]
30struct OpenAiChoice {
31 message: OpenAiMessage,
32 #[allow(dead_code)]
33 finish_reason: Option<String>,
34}
35
36#[derive(Deserialize)]
37struct OpenAiStreamResponse {
38 choices: Vec<OpenAiStreamChoice>,
39}
40
41#[derive(Deserialize)]
42struct OpenAiStreamChoice {
43 delta: OpenAiDelta,
44 finish_reason: Option<String>,
45}
46
47#[derive(Deserialize)]
48struct OpenAiDelta {
49 #[serde(default)]
50 content: Option<String>,
51}
52
53#[derive(Deserialize)]
54struct OpenAiUsage {
55 prompt_tokens: u32,
56 completion_tokens: u32,
57 total_tokens: u32,
58}
59
60pub struct OpenAiProvider {
62 client: reqwest::Client,
63 base_url: String,
64 api_key: Option<String>,
65 model: String,
66 temperature: f64,
67 max_tokens: Option<i32>,
68}
69
70impl OpenAiProvider {
71 pub fn new(
72 base_url: &str,
73 api_key: Option<&str>,
74 model: &str,
75 temperature: f64,
76 max_tokens: Option<i32>,
77 ) -> Result<Self, LlmError> {
78 let client = reqwest::Client::builder()
79 .timeout(brain_core::timeouts::LLM_GENERATE)
80 .build()
81 .map_err(|e| {
82 LlmError::ProviderUnavailable(format!("Failed to create HTTP client: {e}"))
83 })?;
84
85 Ok(Self {
86 client,
87 base_url: base_url.trim_end_matches('/').to_string(),
88 api_key: api_key.map(|s| s.to_string()),
89 model: model.to_string(),
90 temperature,
91 max_tokens,
92 })
93 }
94
95 pub fn openai(api_key: &str, model: &str) -> Result<Self, LlmError> {
96 Self::new(
97 "https://api.openai.com/v1",
98 Some(api_key),
99 model,
100 0.7,
101 Some(4096),
102 )
103 }
104
105 pub fn openrouter(api_key: &str, model: &str) -> Result<Self, LlmError> {
106 Self::new(
107 "https://openrouter.ai/api/v1",
108 Some(api_key),
109 model,
110 0.7,
111 Some(4096),
112 )
113 }
114
115 fn convert_messages(messages: &[Message]) -> Vec<OpenAiMessage> {
116 messages
117 .iter()
118 .map(|m| OpenAiMessage {
119 role: match m.role {
120 Role::System => "system".to_string(),
121 Role::User => "user".to_string(),
122 Role::Assistant => "assistant".to_string(),
123 },
124 content: m.content.clone(),
125 })
126 .collect()
127 }
128
129 fn build_request(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
130 let mut builder = builder;
131 if let Some(key) = &self.api_key {
132 builder = builder.header("Authorization", format!("Bearer {}", key));
133 }
134 builder
135 }
136}
137
138#[async_trait::async_trait]
139impl LlmProvider for OpenAiProvider {
140 async fn generate(&self, messages: &[Message]) -> Result<Response, LlmError> {
141 let url = format!("{}/chat/completions", self.base_url);
142 let request = OpenAiRequest {
143 model: self.model.clone(),
144 messages: Self::convert_messages(messages),
145 temperature: self.temperature,
146 max_tokens: self.max_tokens,
147 stream: false,
148 };
149
150 let resp = self
151 .build_request(self.client.post(&url))
152 .json(&request)
153 .send()
154 .await?;
155
156 if !resp.status().is_success() {
157 let status = resp.status();
158 let body = resp.text().await.unwrap_or_default();
159 return Err(LlmError::Api {
160 status: status.as_u16(),
161 message: body,
162 });
163 }
164
165 let data: OpenAiResponse = resp.json().await?;
166 let content = data
167 .choices
168 .first()
169 .map(|c| c.message.content.clone())
170 .unwrap_or_default();
171
172 Ok(Response {
173 content,
174 usage: data.usage.map(|u| Usage {
175 prompt_tokens: u.prompt_tokens,
176 completion_tokens: u.completion_tokens,
177 total_tokens: u.total_tokens,
178 }),
179 })
180 }
181
182 async fn generate_stream(
183 &self,
184 messages: &[Message],
185 ) -> Result<Pin<Box<dyn Stream<Item = Result<ResponseChunk, LlmError>> + Send>>, LlmError> {
186 use futures::stream::try_unfold;
187
188 let url = format!("{}/chat/completions", self.base_url);
189 let request = OpenAiRequest {
190 model: self.model.clone(),
191 messages: Self::convert_messages(messages),
192 temperature: self.temperature,
193 max_tokens: self.max_tokens,
194 stream: true,
195 };
196
197 let resp = self
198 .build_request(self.client.post(&url))
199 .json(&request)
200 .send()
201 .await?;
202
203 if !resp.status().is_success() {
204 let status = resp.status();
205 let body = resp.text().await.unwrap_or_default();
206 return Err(LlmError::Api {
207 status: status.as_u16(),
208 message: body,
209 });
210 }
211
212 let byte_stream = resp.bytes_stream();
213 let stream = try_unfold(
214 (Box::pin(byte_stream), String::new()),
215 |(mut byte_stream, mut buf)| async move {
216 use futures::TryStreamExt;
217
218 loop {
219 if let Some(newline_pos) = buf.find('\n') {
220 let line: String = buf[..newline_pos].to_string();
221 buf = buf[newline_pos + 1..].to_string();
222
223 let line = line.trim();
224 if line.is_empty() {
225 continue;
226 }
227
228 if let Some(data) = line.strip_prefix("data: ") {
229 let data = data.trim();
230 if data == "[DONE]" {
231 return Ok(None);
232 }
233
234 match serde_json::from_str::<OpenAiStreamResponse>(data) {
235 Ok(resp) => {
236 if let Some(choice) = resp.choices.first() {
237 let content =
238 choice.delta.content.clone().unwrap_or_default();
239 let is_done = choice.finish_reason.is_some();
240 let chunk = ResponseChunk { content, is_done };
241 return Ok(Some((chunk, (byte_stream, buf))));
242 }
243 continue;
244 }
245 Err(e) => {
246 return Err(LlmError::InvalidFormat(format!(
247 "Failed to parse streaming response: {e}"
248 )));
249 }
250 }
251 }
252 continue;
253 }
254
255 match byte_stream.try_next().await {
256 Ok(Some(bytes)) => {
257 buf.push_str(&String::from_utf8_lossy(&bytes));
258 }
259 Ok(None) => return Ok(None),
260 Err(e) => return Err(LlmError::Http(e)),
261 }
262 }
263 },
264 );
265
266 Ok(Box::pin(stream))
267 }
268
269 async fn health_check(&self) -> bool {
270 let url = format!("{}/models", self.base_url);
271 match self.build_request(self.client.get(&url)).send().await {
272 Ok(resp) => resp.status().is_success(),
273 Err(_) => false,
274 }
275 }
276
277 fn name(&self) -> &str {
278 "openai"
279 }
280
281 fn model(&self) -> &str {
282 &self.model
283 }
284
285 async fn list_models(&self) -> Result<Vec<String>, LlmError> {
286 #[derive(Deserialize)]
287 struct ModelEntry {
288 id: String,
289 }
290 #[derive(Deserialize)]
291 struct Models {
292 data: Vec<ModelEntry>,
293 }
294
295 let url = format!("{}/models", self.base_url);
296 let resp = self.build_request(self.client.get(&url)).send().await?;
297 if !resp.status().is_success() {
298 let status = resp.status();
299 let body = resp.text().await.unwrap_or_default();
300 return Err(LlmError::Api {
301 status: status.as_u16(),
302 message: body,
303 });
304 }
305 let data: Models = resp.json().await?;
306 Ok(data.data.into_iter().map(|m| m.id).collect())
307 }
308}