aether_ai/
openai.rs

1//! OpenAI provider implementation.
2//!
3//! Supports GPT-4, GPT-3.5-turbo, and other OpenAI models.
4
5use aether_core::{
6    AetherError, AiProvider, ProviderConfig, Result,
7    provider::{GenerationRequest, GenerationResponse},
8    SlotKind,
9};
10use async_trait::async_trait;
11use reqwest::Client;
12use serde::{Deserialize, Serialize};
13use tracing::{debug, instrument};
14
/// Chat completions endpoint used whenever `ProviderConfig::base_url` is `None`.
const OPENAI_API_URL: &str = "https://api.openai.com/v1/chat/completions";
16
17/// OpenAI provider for code generation.
18#[derive(Debug, Clone)]
19pub struct OpenAiProvider {
20    client: Client,
21    config: ProviderConfig,
22}
23
24/// OpenAI chat completion request.
25#[derive(Debug, Serialize)]
26struct ChatRequest {
27    model: String,
28    messages: Vec<ChatMessage>,
29    #[serde(skip_serializing_if = "Option::is_none")]
30    max_tokens: Option<u32>,
31    #[serde(skip_serializing_if = "Option::is_none")]
32    temperature: Option<f32>,
33    #[serde(skip_serializing_if = "Option::is_none")]
34    stream: Option<bool>,
35}
36
37/// Chat message.
38#[derive(Debug, Serialize, Deserialize)]
39struct ChatMessage {
40    role: String,
41    content: String,
42}
43
44/// OpenAI chat completion response.
45#[derive(Debug, Deserialize)]
46struct ChatResponse {
47    choices: Vec<ChatChoice>,
48    usage: Option<Usage>,
49}
50
51#[derive(Debug, Deserialize)]
52struct ChatChoice {
53    message: ChatMessage,
54}
55
56#[derive(Debug, Deserialize)]
57struct Usage {
58    total_tokens: u32,
59}
60
61/// OpenAI streaming response chunk.
62#[derive(Debug, Deserialize)]
63struct ChatStreamResponse {
64    choices: Vec<ChatStreamChoice>,
65}
66
67#[derive(Debug, Deserialize)]
68struct ChatStreamChoice {
69    delta: ChatStreamDelta,
70    #[allow(dead_code)]
71    finish_reason: Option<String>,
72}
73
74#[derive(Debug, Deserialize)]
75struct ChatStreamDelta {
76    content: Option<String>,
77}
78
79impl OpenAiProvider {
80    /// Create a new OpenAI provider with the given configuration.
81    pub fn new(config: ProviderConfig) -> Result<Self> {
82        let timeout = config.timeout_seconds.unwrap_or(60);
83        let client = Client::builder()
84            .timeout(std::time::Duration::from_secs(timeout))
85            .build()
86            .map_err(|e| AetherError::NetworkError(e.to_string()))?;
87
88        Ok(Self { client, config })
89    }
90
91    /// Create a provider from environment variables.
92    ///
93    /// Reads `OPENAI_API_KEY` and optionally `OPENAI_MODEL`.
94    pub fn from_env() -> Result<Self> {
95        let config = ProviderConfig::from_env()?;
96        Self::new(config)
97    }
98
99    /// Create a provider from environment with a specific model.
100    pub fn from_env_with_model(model: &str) -> Result<Self> {
101        let api_key = std::env::var("OPENAI_API_KEY")
102            .map_err(|_| AetherError::ConfigError("OPENAI_API_KEY not set".to_string()))?;
103
104        let config = ProviderConfig::new(api_key, model);
105        Self::new(config)
106    }
107
108    /// Build the system prompt for code generation.
109    fn build_system_prompt(&self, kind: &SlotKind, context: Option<&str>) -> String {
110        let base = "You are a code generation assistant. Generate only the requested code without explanations or markdown code blocks. Output raw code only.";
111
112        let kind_specific = match kind {
113            SlotKind::Html => "\nGenerate valid HTML5 markup.",
114            SlotKind::Css => "\nGenerate valid CSS styles.",
115            SlotKind::JavaScript => "\nGenerate valid JavaScript code.",
116            SlotKind::Function => "\nGenerate a complete function definition.",
117            SlotKind::Class => "\nGenerate a complete class/struct definition.",
118            SlotKind::Component => "\nGenerate a complete component with HTML, CSS, and JavaScript as needed.",
119            _ => "",
120        };
121
122        let context_part = context
123            .filter(|c| !c.is_empty())
124            .map(|c| format!("\n\nContext:\n{}", c))
125            .unwrap_or_default();
126
127        format!("{}{}{}", base, kind_specific, context_part)
128    }
129}
130
131use aether_core::provider::StreamResponse;
132use futures::stream::{BoxStream, StreamExt};
133
134#[async_trait]
135impl AiProvider for OpenAiProvider {
136    fn name(&self) -> &str {
137        "openai"
138    }
139
140    #[instrument(skip(self, request), fields(slot = %request.slot.name))]
141    async fn generate(&self, request: GenerationRequest) -> Result<GenerationResponse> {
142        debug!("Generating code with OpenAI for slot: {}", request.slot.name);
143
144        let api_key = self.config.resolve_api_key().await?;
145
146        let system_prompt = request.system_prompt.unwrap_or_else(|| {
147            self.build_system_prompt(&request.slot.kind, request.context.as_deref())
148        });
149
150        let messages = vec![
151            ChatMessage {
152                role: "system".to_string(),
153                content: system_prompt,
154            },
155            ChatMessage {
156                role: "user".to_string(),
157                content: request.slot.prompt.clone(),
158            },
159        ];
160
161        let temperature = request.slot.temperature.or(self.config.temperature);
162        let api_request = ChatRequest {
163            model: self.config.model.clone(),
164            messages,
165            max_tokens: self.config.max_tokens,
166            temperature,
167            stream: None,
168        };
169
170        let url = self.config.base_url.as_deref().unwrap_or(OPENAI_API_URL);
171
172        let response = self
173            .client
174            .post(url)
175            .header("Authorization", format!("Bearer {}", api_key))
176            .header("Content-Type", "application/json")
177            .json(&api_request)
178            .send()
179            .await
180            .map_err(|e| AetherError::NetworkError(e.to_string()))?;
181
182        if !response.status().is_success() {
183            let status = response.status();
184            let body = response.text().await.unwrap_or_default();
185            return Err(AetherError::ProviderError(format!(
186                "API error {}: {}",
187                status, body
188            )));
189        }
190
191        let chat_response: ChatResponse = response
192            .json()
193            .await
194            .map_err(|e| AetherError::ProviderError(e.to_string()))?;
195
196        let code = chat_response
197            .choices
198            .first()
199            .map(|c| c.message.content.clone())
200            .unwrap_or_default();
201
202        // Strip markdown code blocks if present
203        let code = strip_code_blocks(&code);
204
205        // Validate against slot constraints
206        if let Err(errors) = request.slot.validate(&code) {
207            debug!("Generated code failed validation: {:?}", errors);
208            // For now, we'll still return the code but log the warning
209        }
210
211        Ok(GenerationResponse {
212            code,
213            tokens_used: chat_response.usage.map(|u| u.total_tokens),
214            metadata: None,
215        })
216    }
217
218    fn generate_stream(
219        &self,
220        request: GenerationRequest,
221    ) -> BoxStream<'static, Result<StreamResponse>> {
222        let client = self.client.clone();
223        let config = self.config.clone();
224        let system_prompt = request.system_prompt.unwrap_or_else(|| {
225            self.build_system_prompt(&request.slot.kind, request.context.as_deref())
226        });
227        let user_prompt = request.slot.prompt.clone();
228        let url = config.base_url.as_deref().unwrap_or(OPENAI_API_URL).to_string();
229
230        let temperature = request.slot.temperature.or(config.temperature);
231        let api_request = ChatRequest {
232            model: config.model.clone(),
233            messages: vec![
234                ChatMessage {
235                    role: "system".to_string(),
236                    content: system_prompt,
237                },
238                ChatMessage {
239                    role: "user".to_string(),
240                    content: user_prompt,
241                },
242            ],
243            max_tokens: config.max_tokens,
244            temperature,
245            stream: Some(true),
246        };
247
248        let stream = async_stream::stream! {
249            let api_key = match config.resolve_api_key().await {
250                Ok(k) => k,
251                Err(e) => {
252                    yield Err(e);
253                    return;
254                }
255            };
256
257            let response = client
258                .post(&url)
259                .header("Authorization", format!("Bearer {}", api_key))
260                .header("Content-Type", "application/json")
261                .json(&api_request)
262                .send()
263                .await
264                .map_err(|e| aether_core::AetherError::NetworkError(e.to_string()));
265
266            let response = match response {
267                Ok(r) => r,
268                Err(e) => {
269                    yield Err(e);
270                    return;
271                }
272            };
273
274            if !response.status().is_success() {
275                let status = response.status();
276                let body = response.text().await.unwrap_or_default();
277                yield Err(aether_core::AetherError::ProviderError(format!(
278                    "API error {}: {}",
279                    status, body
280                )));
281                return;
282            }
283
284            let mut stream = response.bytes_stream();
285            
286            while let Some(chunk_result) = stream.next().await {
287                let chunk = match chunk_result {
288                    Ok(c) => c,
289                    Err(e) => {
290                        yield Err(aether_core::AetherError::NetworkError(e.to_string()));
291                        break;
292                    }
293                };
294
295                // OpenAI stream format is SSE: "data: {...}"
296                let text = String::from_utf8_lossy(&chunk);
297                for line in text.lines() {
298                    let line = line.trim();
299                    if line.is_empty() { continue; }
300                    if line == "data: [DONE]" { break; }
301                    
302                    if let Some(data) = line.strip_prefix("data: ") {
303                        if let Ok(stream_resp) = serde_json::from_str::<ChatStreamResponse>(data) {
304                            if let Some(choice) = stream_resp.choices.first() {
305                                if let Some(content) = &choice.delta.content {
306                                    yield Ok(StreamResponse {
307                                        delta: content.clone(),
308                                        metadata: None,
309                                    });
310                                }
311                            }
312                        }
313                    }
314                }
315            }
316        };
317
318        Box::pin(stream)
319    }
320
321    async fn health_check(&self) -> Result<bool> {
322        // Try a minimal API call
323        let response = self
324            .client
325            .get("https://api.openai.com/v1/models")
326            .header("Authorization", format!("Bearer {}", self.config.api_key))
327            .send()
328            .await
329            .map_err(|e| AetherError::NetworkError(e.to_string()))?;
330
331        Ok(response.status().is_success())
332    }
333}
334
/// Removes a surrounding Markdown code fence from generated code.
///
/// The input is trimmed first. If it then both starts and ends with a
/// triple-backtick fence, the fence markers are removed:
/// - the opening fence line, including any language tag such as `html`, is dropped;
/// - the line break just before the closing fence (`\n` or `\r\n`) is dropped.
///
/// Code that shares a line with the closing fence is kept rather than discarded.
/// Input that is not fenced, or is fenced on a single line, is returned trimmed
/// but otherwise unchanged.
fn strip_code_blocks(code: &str) -> String {
    let trimmed = code.trim();

    let inner = match trimmed.strip_prefix("```").and_then(|s| s.strip_suffix("```")) {
        Some(inner) => inner,
        None => return trimmed.to_string(),
    };

    // The first line of a fenced block is its info string (language tag).
    match inner.split_once('\n') {
        Some((_info, body)) => {
            // Drop only the line break that precedes the closing fence.
            let body = body.strip_suffix('\n').unwrap_or(body);
            body.strip_suffix('\r').unwrap_or(body).to_string()
        }
        None => trimmed.to_string(),
    }
}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// Fenced output is unwrapped; already-clean output passes through untouched.
    #[test]
    fn test_strip_code_blocks() {
        let cases = [
            ("```html\n<div>Hello</div>\n```", "<div>Hello</div>"),
            ("<div>Already clean</div>", "<div>Already clean</div>"),
        ];
        for (input, expected) in cases {
            assert_eq!(strip_code_blocks(input), expected, "input: {:?}", input);
        }
    }

    /// The HTML slot kind contributes its HTML5 hint to the default system prompt.
    #[test]
    fn test_system_prompt_generation() {
        let provider = OpenAiProvider::new(ProviderConfig::new("test-key", "gpt-4"))
            .expect("building a provider from a static config should succeed");

        let html_prompt = provider.build_system_prompt(&SlotKind::Html, None);
        assert!(html_prompt.contains("HTML5"), "prompt was: {}", html_prompt);
    }
}