// aether_core/provider.rs

//! AI provider trait and configuration.
//!
//! Defines the interface that AI backends must implement.
4
5use crate::{Result, Slot};
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use std::sync::Arc;
9
10/// Configuration for an AI provider.
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct ProviderConfig {
13    /// API key for authentication.
14    pub api_key: String,
15
16    /// Model identifier (e.g., "gpt-4", "claude-3").
17    pub model: String,
18
19    /// Base URL for the API.
20    pub base_url: Option<String>,
21
22    /// Maximum tokens to generate.
23    pub max_tokens: Option<u32>,
24
25    /// Temperature for generation (0.0 - 2.0).
26    pub temperature: Option<f32>,
27
28    /// Request timeout in seconds.
29    pub timeout_seconds: Option<u64>,
30
31    /// Optional URL to fetch the API key from (for stealth/security).
32    pub api_key_url: Option<String>,
33}
34
35impl ProviderConfig {
36    /// Create a new provider config with API key and model.
37    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
38        Self {
39            api_key: api_key.into(),
40            model: model.into(),
41            base_url: None,
42            max_tokens: None,
43            temperature: None,
44            timeout_seconds: None,
45            api_key_url: None,
46        }
47    }
48
49    /// Set an external URL to fetch the API key from.
50    pub fn with_api_key_url(mut self, url: impl Into<String>) -> Self {
51        self.api_key_url = Some(url.into());
52        self
53    }
54
55    /// Resolve the API key (literal or remote).
56    pub async fn resolve_api_key(&self) -> Result<String> {
57        if let Some(ref url) = self.api_key_url {
58            let resp = reqwest::get(url)
59                .await
60                .map_err(|e| crate::AetherError::NetworkError(format!("Failed to fetch API key: {}", e)))?;
61            
62            let key = resp
63                .text()
64                .await
65                .map_err(|e| crate::AetherError::NetworkError(format!("Failed to read API key body: {}", e)))?;
66            
67            Ok(key.trim().to_string())
68        } else {
69            Ok(self.api_key.clone())
70        }
71    }
72
73    /// Set the base URL.
74    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
75        self.base_url = Some(url.into());
76        self
77    }
78
79    /// Set max tokens.
80    pub fn with_max_tokens(mut self, tokens: u32) -> Self {
81        self.max_tokens = Some(tokens);
82        self
83    }
84
85    /// Set temperature.
86    pub fn with_temperature(mut self, temp: f32) -> Self {
87        self.temperature = Some(temp.clamp(0.0, 2.0));
88        self
89    }
90
91    /// Set timeout.
92    pub fn with_timeout(mut self, seconds: u64) -> Self {
93        self.timeout_seconds = Some(seconds);
94        self
95    }
96
97    /// Load config from environment variables.
98    ///
99    /// Expected variables:
100    /// - `AETHER_API_KEY` or `OPENAI_API_KEY`
101    /// - `AETHER_MODEL` (defaults to "gpt-4")
102    /// - `AETHER_BASE_URL` (optional)
103    pub fn from_env() -> Result<Self> {
104        let api_key = std::env::var("AETHER_API_KEY")
105            .or_else(|_| std::env::var("OPENAI_API_KEY"))
106            .map_err(|_| {
107                crate::AetherError::ConfigError(
108                    "AETHER_API_KEY or OPENAI_API_KEY must be set".to_string(),
109                )
110            })?;
111
112        let model = std::env::var("AETHER_MODEL").unwrap_or_else(|_| "gpt-5.2-thinking".to_string());
113
114        let mut config = Self::new(api_key, model);
115
116        if let Ok(url) = std::env::var("AETHER_BASE_URL") {
117            config = config.with_base_url(url);
118        }
119
120        Ok(config)
121    }
122}
123
124/// Request for code generation.
125#[derive(Debug, Clone)]
126pub struct GenerationRequest {
127    /// The slot to generate code for.
128    pub slot: Slot,
129
130    /// Additional context (e.g., surrounding code).
131    pub context: Option<String>,
132
133    /// System prompt override.
134    pub system_prompt: Option<String>,
135}
136
137use futures::stream::BoxStream;
138
139/// Response from code generation.
140#[derive(Debug, Clone)]
141pub struct GenerationResponse {
142    /// The generated code.
143    pub code: String,
144
145    /// Tokens used for the request.
146    pub tokens_used: Option<u32>,
147
148    /// Generation metadata.
149    pub metadata: Option<serde_json::Value>,
150}
151
152/// A single chunk of a streaming response.
153#[derive(Debug, Clone)]
154pub struct StreamResponse {
155    /// The new text chunk.
156    pub delta: String,
157
158    /// Final metadata (only sent in the last chunk).
159    pub metadata: Option<serde_json::Value>,
160}
161
162/// Trait that AI providers must implement.
163///
164/// This trait defines the interface for generating code from slots.
165#[async_trait]
166pub trait AiProvider: Send + Sync {
167    /// Get the provider name.
168    fn name(&self) -> &str;
169
170    /// Generate code for a slot.
171    ///
172    /// # Arguments
173    ///
174    /// * `request` - The generation request containing slot info
175    ///
176    /// # Returns
177    ///
178    /// Generated code response or an error.
179    async fn generate(&self, request: GenerationRequest) -> Result<GenerationResponse>;
180
181    /// Generate a stream of code for a slot.
182    ///
183    /// # Arguments
184    ///
185    /// * `request` - The generation request containing slot info
186    ///
187    /// # Returns
188    ///
189    /// A pinned stream of chunks or an error.
190    fn generate_stream(
191        &self,
192        _request: GenerationRequest,
193    ) -> BoxStream<'static, Result<StreamResponse>> {
194        let name = self.name().to_string();
195        Box::pin(async_stream::stream! {
196            yield Err(crate::AetherError::ProviderError(format!(
197                "Streaming not implemented for provider: {}",
198                name
199            )));
200        })
201    }
202
203    /// Generate code for multiple slots in batch.
204    ///
205    /// Default implementation calls `generate` for each slot sequentially.
206    async fn generate_batch(
207        &self,
208        requests: Vec<GenerationRequest>,
209    ) -> Result<Vec<GenerationResponse>> {
210        let mut responses = Vec::with_capacity(requests.len());
211        for request in requests {
212            responses.push(self.generate(request).await?);
213        }
214        Ok(responses)
215    }
216
217    /// Check if the provider is available and configured correctly.
218    async fn health_check(&self) -> Result<bool> {
219        Ok(true)
220    }
221}
222
223#[async_trait]
224impl<T: AiProvider + ?Sized + Send + Sync> AiProvider for Arc<T> {
225    fn name(&self) -> &str {
226        (**self).name()
227    }
228
229    async fn generate(&self, request: GenerationRequest) -> Result<GenerationResponse> {
230        (**self).generate(request).await
231    }
232
233    fn generate_stream(
234        &self,
235        request: GenerationRequest,
236    ) -> BoxStream<'static, Result<StreamResponse>> {
237        (**self).generate_stream(request)
238    }
239}
240
241#[async_trait]
242impl<T: AiProvider + ?Sized + Send + Sync> AiProvider for Box<T> {
243    fn name(&self) -> &str {
244        (**self).name()
245    }
246
247    async fn generate(&self, request: GenerationRequest) -> Result<GenerationResponse> {
248        (**self).generate(request).await
249    }
250
251    fn generate_stream(
252        &self,
253        request: GenerationRequest,
254    ) -> BoxStream<'static, Result<StreamResponse>> {
255        (**self).generate_stream(request)
256    }
257}
258
/// In-memory [`AiProvider`] for tests that serves canned code per slot name.
#[derive(Default, Debug)]
pub struct MockProvider {
    /// Canned code keyed by slot name.
    pub responses: std::collections::HashMap<String, String>,
}
265
266impl MockProvider {
267    /// Create a new mock provider.
268    pub fn new() -> Self {
269        Self::default()
270    }
271
272    /// Add a mock response.
273    pub fn with_response(mut self, slot: impl Into<String>, code: impl Into<String>) -> Self {
274        self.responses.insert(slot.into(), code.into());
275        self
276    }
277}
278
279#[async_trait]
280impl AiProvider for MockProvider {
281    fn name(&self) -> &str {
282        "mock"
283    }
284
285    async fn generate(&self, request: GenerationRequest) -> Result<GenerationResponse> {
286        let code = self
287            .responses
288            .get(&request.slot.name)
289            .cloned()
290            .unwrap_or_else(|| format!("// Generated code for: {}", request.slot.name));
291
292        Ok(GenerationResponse {
293            code,
294            tokens_used: Some(10),
295            metadata: None,
296        })
297    }
298
299    fn generate_stream(
300        &self,
301        request: GenerationRequest,
302    ) -> BoxStream<'static, Result<StreamResponse>> {
303        let code = self
304            .responses
305            .get(&request.slot.name)
306            .cloned()
307            .unwrap_or_else(|| format!("// Generated code for: {}", request.slot.name));
308
309        let words: Vec<String> = code.split_whitespace().map(|s| format!("{} ", s)).collect();
310        
311        let stream = async_stream::stream! {
312            for word in words {
313                yield Ok(StreamResponse {
314                    delta: word,
315                    metadata: None,
316                });
317            }
318        };
319        
320        Box::pin(stream)
321    }
322}
323
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    /// Build a request for the slot `name` with no extra context or prompt override.
    fn request_for(name: &'static str, description: &'static str) -> GenerationRequest {
        GenerationRequest {
            slot: Slot::new(name, description),
            context: None,
            system_prompt: None,
        }
    }

    #[tokio::test]
    async fn test_mock_provider() {
        let provider = MockProvider::new()
            .with_response("button", "<button>Click me</button>");

        let response = provider
            .generate(request_for("button", "Create a button"))
            .await
            .unwrap();
        assert_eq!(response.code, "<button>Click me</button>");
    }

    #[tokio::test]
    async fn test_mock_provider_falls_back_for_unknown_slot() {
        let provider = MockProvider::new();

        let response = provider
            .generate(request_for("card", "Create a card"))
            .await
            .unwrap();
        assert_eq!(response.code, "// Generated code for: card");
        assert_eq!(response.tokens_used, Some(10));
    }

    #[tokio::test]
    async fn test_mock_provider_stream_yields_every_word() {
        let provider = MockProvider::new()
            .with_response("button", "<button>Click me</button>");

        let chunks: Vec<_> = provider
            .generate_stream(request_for("button", "Create a button"))
            .collect()
            .await;
        let streamed: String = chunks
            .into_iter()
            .map(|chunk| chunk.unwrap().delta)
            .collect();

        assert_eq!(
            streamed.split_whitespace().collect::<Vec<_>>(),
            vec!["<button>Click", "me</button>"]
        );
    }

    #[tokio::test]
    async fn test_default_generate_batch_preserves_order() {
        let provider = MockProvider::new()
            .with_response("button", "<button>Click me</button>");

        let responses = provider
            .generate_batch(vec![
                request_for("button", "Create a button"),
                request_for("card", "Create a card"),
            ])
            .await
            .unwrap();

        let codes: Vec<&str> = responses.iter().map(|r| r.code.as_str()).collect();
        assert_eq!(
            codes,
            vec!["<button>Click me</button>", "// Generated code for: card"]
        );
    }

    #[tokio::test]
    async fn test_default_health_check_reports_healthy() {
        assert!(MockProvider::new().health_check().await.unwrap());
    }
}