1use aether_core::{
6 AetherError, AiProvider, ProviderConfig, Result,
7 provider::{GenerationRequest, GenerationResponse},
8 SlotKind,
9};
10use async_trait::async_trait;
11use reqwest::Client;
12use serde::{Deserialize, Serialize};
13use tracing::{debug, instrument};
14
/// Default chat completions endpoint, used when the provider configuration
/// does not supply its own `base_url`.
const OPENAI_API_URL: &str = "https://api.openai.com/v1/chat/completions";
16
17#[derive(Debug, Clone)]
19pub struct OpenAiProvider {
20 client: Client,
21 config: ProviderConfig,
22}
23
24#[derive(Debug, Serialize)]
26struct ChatRequest {
27 model: String,
28 messages: Vec<ChatMessage>,
29 #[serde(skip_serializing_if = "Option::is_none")]
30 max_tokens: Option<u32>,
31 #[serde(skip_serializing_if = "Option::is_none")]
32 temperature: Option<f32>,
33 #[serde(skip_serializing_if = "Option::is_none")]
34 stream: Option<bool>,
35}
36
37#[derive(Debug, Serialize, Deserialize)]
39struct ChatMessage {
40 role: String,
41 content: String,
42}
43
44#[derive(Debug, Deserialize)]
46struct ChatResponse {
47 choices: Vec<ChatChoice>,
48 usage: Option<Usage>,
49}
50
51#[derive(Debug, Deserialize)]
52struct ChatChoice {
53 message: ChatMessage,
54}
55
56#[derive(Debug, Deserialize)]
57struct Usage {
58 total_tokens: u32,
59}
60
61#[derive(Debug, Deserialize)]
63struct ChatStreamResponse {
64 choices: Vec<ChatStreamChoice>,
65}
66
67#[derive(Debug, Deserialize)]
68struct ChatStreamChoice {
69 delta: ChatStreamDelta,
70 #[allow(dead_code)]
71 finish_reason: Option<String>,
72}
73
74#[derive(Debug, Deserialize)]
75struct ChatStreamDelta {
76 content: Option<String>,
77}
78
79impl OpenAiProvider {
80 pub fn new(config: ProviderConfig) -> Result<Self> {
82 let timeout = config.timeout_seconds.unwrap_or(60);
83 let client = Client::builder()
84 .timeout(std::time::Duration::from_secs(timeout))
85 .build()
86 .map_err(|e| AetherError::NetworkError(e.to_string()))?;
87
88 Ok(Self { client, config })
89 }
90
91 pub fn from_env() -> Result<Self> {
95 let config = ProviderConfig::from_env()?;
96 Self::new(config)
97 }
98
99 pub fn from_env_with_model(model: &str) -> Result<Self> {
101 let api_key = std::env::var("OPENAI_API_KEY")
102 .map_err(|_| AetherError::ConfigError("OPENAI_API_KEY not set".to_string()))?;
103
104 let config = ProviderConfig::new(api_key, model);
105 Self::new(config)
106 }
107
108 fn build_system_prompt(&self, kind: &SlotKind, context: Option<&str>) -> String {
110 let base = "You are a code generation assistant. Generate only the requested code without explanations or markdown code blocks. Output raw code only.";
111
112 let kind_specific = match kind {
113 SlotKind::Html => "\nGenerate valid HTML5 markup.",
114 SlotKind::Css => "\nGenerate valid CSS styles.",
115 SlotKind::JavaScript => "\nGenerate valid JavaScript code.",
116 SlotKind::Function => "\nGenerate a complete function definition.",
117 SlotKind::Class => "\nGenerate a complete class/struct definition.",
118 SlotKind::Component => "\nGenerate a complete component with HTML, CSS, and JavaScript as needed.",
119 _ => "",
120 };
121
122 let context_part = context
123 .filter(|c| !c.is_empty())
124 .map(|c| format!("\n\nContext:\n{}", c))
125 .unwrap_or_default();
126
127 format!("{}{}{}", base, kind_specific, context_part)
128 }
129}
130
131use aether_core::provider::StreamResponse;
132use futures::stream::{BoxStream, StreamExt};
133
134#[async_trait]
135impl AiProvider for OpenAiProvider {
136 fn name(&self) -> &str {
137 "openai"
138 }
139
140 #[instrument(skip(self, request), fields(slot = %request.slot.name))]
141 async fn generate(&self, request: GenerationRequest) -> Result<GenerationResponse> {
142 debug!("Generating code with OpenAI for slot: {}", request.slot.name);
143
144 let api_key = self.config.resolve_api_key().await?;
145
146 let system_prompt = request.system_prompt.unwrap_or_else(|| {
147 self.build_system_prompt(&request.slot.kind, request.context.as_deref())
148 });
149
150 let messages = vec![
151 ChatMessage {
152 role: "system".to_string(),
153 content: system_prompt,
154 },
155 ChatMessage {
156 role: "user".to_string(),
157 content: request.slot.prompt.clone(),
158 },
159 ];
160
161 let temperature = request.slot.temperature.or(self.config.temperature);
162 let api_request = ChatRequest {
163 model: self.config.model.clone(),
164 messages,
165 max_tokens: self.config.max_tokens,
166 temperature,
167 stream: None,
168 };
169
170 let url = self.config.base_url.as_deref().unwrap_or(OPENAI_API_URL);
171
172 let response = self
173 .client
174 .post(url)
175 .header("Authorization", format!("Bearer {}", api_key))
176 .header("Content-Type", "application/json")
177 .json(&api_request)
178 .send()
179 .await
180 .map_err(|e| AetherError::NetworkError(e.to_string()))?;
181
182 if !response.status().is_success() {
183 let status = response.status();
184 let body = response.text().await.unwrap_or_default();
185 return Err(AetherError::ProviderError(format!(
186 "API error {}: {}",
187 status, body
188 )));
189 }
190
191 let chat_response: ChatResponse = response
192 .json()
193 .await
194 .map_err(|e| AetherError::ProviderError(e.to_string()))?;
195
196 let code = chat_response
197 .choices
198 .first()
199 .map(|c| c.message.content.clone())
200 .unwrap_or_default();
201
202 let code = strip_code_blocks(&code);
204
205 if let Err(errors) = request.slot.validate(&code) {
207 debug!("Generated code failed validation: {:?}", errors);
208 }
210
211 Ok(GenerationResponse {
212 code,
213 tokens_used: chat_response.usage.map(|u| u.total_tokens),
214 metadata: None,
215 })
216 }
217
218 fn generate_stream(
219 &self,
220 request: GenerationRequest,
221 ) -> BoxStream<'static, Result<StreamResponse>> {
222 let client = self.client.clone();
223 let config = self.config.clone();
224 let system_prompt = request.system_prompt.unwrap_or_else(|| {
225 self.build_system_prompt(&request.slot.kind, request.context.as_deref())
226 });
227 let user_prompt = request.slot.prompt.clone();
228 let url = config.base_url.as_deref().unwrap_or(OPENAI_API_URL).to_string();
229
230 let temperature = request.slot.temperature.or(config.temperature);
231 let api_request = ChatRequest {
232 model: config.model.clone(),
233 messages: vec![
234 ChatMessage {
235 role: "system".to_string(),
236 content: system_prompt,
237 },
238 ChatMessage {
239 role: "user".to_string(),
240 content: user_prompt,
241 },
242 ],
243 max_tokens: config.max_tokens,
244 temperature,
245 stream: Some(true),
246 };
247
248 let stream = async_stream::stream! {
249 let api_key = match config.resolve_api_key().await {
250 Ok(k) => k,
251 Err(e) => {
252 yield Err(e);
253 return;
254 }
255 };
256
257 let response = client
258 .post(&url)
259 .header("Authorization", format!("Bearer {}", api_key))
260 .header("Content-Type", "application/json")
261 .json(&api_request)
262 .send()
263 .await
264 .map_err(|e| aether_core::AetherError::NetworkError(e.to_string()));
265
266 let response = match response {
267 Ok(r) => r,
268 Err(e) => {
269 yield Err(e);
270 return;
271 }
272 };
273
274 if !response.status().is_success() {
275 let status = response.status();
276 let body = response.text().await.unwrap_or_default();
277 yield Err(aether_core::AetherError::ProviderError(format!(
278 "API error {}: {}",
279 status, body
280 )));
281 return;
282 }
283
284 let mut stream = response.bytes_stream();
285
286 while let Some(chunk_result) = stream.next().await {
287 let chunk = match chunk_result {
288 Ok(c) => c,
289 Err(e) => {
290 yield Err(aether_core::AetherError::NetworkError(e.to_string()));
291 break;
292 }
293 };
294
295 let text = String::from_utf8_lossy(&chunk);
297 for line in text.lines() {
298 let line = line.trim();
299 if line.is_empty() { continue; }
300 if line == "data: [DONE]" { break; }
301
302 if let Some(data) = line.strip_prefix("data: ") {
303 if let Ok(stream_resp) = serde_json::from_str::<ChatStreamResponse>(data) {
304 if let Some(choice) = stream_resp.choices.first() {
305 if let Some(content) = &choice.delta.content {
306 yield Ok(StreamResponse {
307 delta: content.clone(),
308 metadata: None,
309 });
310 }
311 }
312 }
313 }
314 }
315 }
316 };
317
318 Box::pin(stream)
319 }
320
321 async fn health_check(&self) -> Result<bool> {
322 let response = self
324 .client
325 .get("https://api.openai.com/v1/models")
326 .header("Authorization", format!("Bearer {}", self.config.api_key))
327 .send()
328 .await
329 .map_err(|e| AetherError::NetworkError(e.to_string()))?;
330
331 Ok(response.status().is_success())
332 }
333}
334
/// Removes a surrounding markdown code fence from `code`.
///
/// The input is trimmed first. If it opens with ```` ``` ```` (optionally
/// followed by a language tag on the same line) and closes with ```` ``` ````,
/// the opening line and the closing marker are removed and the enclosed lines
/// are returned joined with `\n`. Code that shares a line with the closing
/// marker is kept. Any other input, including a fence written on a single
/// line, is returned trimmed but otherwise unchanged.
fn strip_code_blocks(code: &str) -> String {
    let trimmed = code.trim();

    let fenced = trimmed
        .strip_prefix("```")
        .and_then(|rest| rest.strip_suffix("```"))
        // Spaces or tabs that merely indent the closing fence are not code.
        .map(|inner| inner.trim_end_matches(|c: char| c == ' ' || c == '\t'))
        // Everything up to the first newline is the opening fence line.
        .and_then(|inner| inner.split_once('\n'));

    match fenced {
        Some((_info, body)) => body.lines().collect::<Vec<_>>().join("\n"),
        None => trimmed.to_string(),
    }
}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// Fenced input loses its fence; unfenced input passes through untouched.
    #[test]
    fn test_strip_code_blocks() {
        let cases = [
            ("```html\n<div>Hello</div>\n```", "<div>Hello</div>"),
            ("<div>Already clean</div>", "<div>Already clean</div>"),
        ];

        for &(input, expected) in &cases {
            assert_eq!(strip_code_blocks(input), expected);
        }
    }

    /// The default system prompt for an HTML slot mentions HTML5.
    #[test]
    fn test_system_prompt_generation() {
        let provider = OpenAiProvider::new(ProviderConfig::new("test-key", "gpt-4")).unwrap();

        let system_prompt = provider.build_system_prompt(&SlotKind::Html, None);

        assert!(system_prompt.contains("HTML5"));
    }
}