// aether_core/provider.rs

//! AI Provider trait and configuration.
//!
//! Defines the interface that AI backends must implement.

5use crate::{Result, Slot};
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use std::sync::Arc;
9
/// Configuration for an AI provider.
///
/// Built via [`ProviderConfig::new`] plus the `with_*` builder methods, or
/// loaded from the environment with [`ProviderConfig::from_env`].
///
/// NOTE(review): the `Debug` and `Serialize` derives expose `api_key` in log
/// output and serialized form — confirm that is acceptable, or consider a
/// manual `Debug` impl that redacts it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderConfig {
    /// API key for authentication (used directly unless `api_key_url` is set).
    pub api_key: String,

    /// Model identifier (e.g., "gpt-4", "claude-3").
    pub model: String,

    /// Base URL for the API.
    pub base_url: Option<String>,

    /// Maximum tokens to generate.
    pub max_tokens: Option<u32>,

    /// Temperature for generation (0.0 - 2.0); `with_temperature` clamps to this range.
    pub temperature: Option<f32>,

    /// Request timeout in seconds.
    pub timeout_seconds: Option<u64>,

    /// Optional URL to fetch the API key from (for stealth/security).
    /// When set, `resolve_api_key` fetches the key over HTTP instead of
    /// using the literal `api_key` field.
    pub api_key_url: Option<String>,
}
34
35impl ProviderConfig {
36    /// Create a new provider config with API key and model.
37    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
38        Self {
39            api_key: api_key.into(),
40            model: model.into(),
41            base_url: None,
42            max_tokens: None,
43            temperature: None,
44            timeout_seconds: None,
45            api_key_url: None,
46        }
47    }
48
49    /// Set an external URL to fetch the API key from.
50    pub fn with_api_key_url(mut self, url: impl Into<String>) -> Self {
51        self.api_key_url = Some(url.into());
52        self
53    }
54
55    /// Resolve the API key (literal or remote).
56    pub async fn resolve_api_key(&self) -> Result<String> {
57        if let Some(ref url) = self.api_key_url {
58            let resp = reqwest::get(url)
59                .await
60                .map_err(|e| crate::AetherError::NetworkError(format!("Failed to fetch API key: {}", e)))?;
61            
62            let key = resp
63                .text()
64                .await
65                .map_err(|e| crate::AetherError::NetworkError(format!("Failed to read API key body: {}", e)))?;
66            
67            Ok(key.trim().to_string())
68        } else {
69            Ok(self.api_key.clone())
70        }
71    }
72
73    /// Set the base URL.
74    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
75        self.base_url = Some(url.into());
76        self
77    }
78
79    /// Set max tokens.
80    pub fn with_max_tokens(mut self, tokens: u32) -> Self {
81        self.max_tokens = Some(tokens);
82        self
83    }
84
85    /// Set temperature.
86    pub fn with_temperature(mut self, temp: f32) -> Self {
87        self.temperature = Some(temp.clamp(0.0, 2.0));
88        self
89    }
90
91    /// Set timeout.
92    pub fn with_timeout(mut self, seconds: u64) -> Self {
93        self.timeout_seconds = Some(seconds);
94        self
95    }
96
97    /// Load config from environment variables.
98    ///
99    /// Expected variables:
100    /// - `AETHER_API_KEY` or `OPENAI_API_KEY`
101    /// - `AETHER_MODEL` (defaults to "gpt-4")
102    /// - `AETHER_BASE_URL` (optional)
103    pub fn from_env() -> Result<Self> {
104        let api_key = std::env::var("AETHER_API_KEY")
105            .or_else(|_| std::env::var("OPENAI_API_KEY"))
106            .map_err(|_| {
107                crate::AetherError::ConfigError(
108                    "AETHER_API_KEY or OPENAI_API_KEY must be set".to_string(),
109                )
110            })?;
111
112        let model = std::env::var("AETHER_MODEL").unwrap_or_else(|_| "gpt-5.2-thinking".to_string());
113
114        let mut config = Self::new(api_key, model);
115
116        if let Ok(url) = std::env::var("AETHER_BASE_URL") {
117            config = config.with_base_url(url);
118        }
119
120        Ok(config)
121    }
122}
123
/// Request for code generation.
///
/// Carries the [`Slot`] to fill plus optional per-request overrides; `None`
/// fields fall back to provider-level configuration (handling is up to each
/// `AiProvider` implementation).
#[derive(Debug, Clone)]
pub struct GenerationRequest {
    /// The slot to generate code for.
    pub slot: Slot,

    /// Additional context (e.g., surrounding code).
    pub context: Option<String>,

    /// System prompt override.
    pub system_prompt: Option<String>,

    /// Specific model override for this request.
    pub model: Option<String>,

    /// Maximum tokens for this request.
    pub max_tokens: Option<u32>,
}
142
143use futures::stream::BoxStream;
144
/// Response from code generation.
#[derive(Debug, Clone)]
pub struct GenerationResponse {
    /// The generated code.
    pub code: String,

    /// Tokens used for the request, when the backend reports usage.
    pub tokens_used: Option<u32>,

    /// Provider-specific generation metadata (free-form JSON).
    pub metadata: Option<serde_json::Value>,
}
157
/// A single chunk of a streaming response.
#[derive(Debug, Clone)]
pub struct StreamResponse {
    /// The new text fragment produced since the previous chunk.
    pub delta: String,

    /// Final metadata (only sent in the last chunk).
    pub metadata: Option<serde_json::Value>,
}
167
168/// Trait that AI providers must implement.
169///
170/// This trait defines the interface for generating code from slots.
171#[async_trait]
172pub trait AiProvider: Send + Sync {
173    /// Get the provider name.
174    fn name(&self) -> &str;
175
176    /// Generate code for a slot.
177    ///
178    /// # Arguments
179    ///
180    /// * `request` - The generation request containing slot info
181    ///
182    /// # Returns
183    ///
184    /// Generated code response or an error.
185    async fn generate(&self, request: GenerationRequest) -> Result<GenerationResponse>;
186
187    /// Generate a stream of code for a slot.
188    ///
189    /// # Arguments
190    ///
191    /// * `request` - The generation request containing slot info
192    ///
193    /// # Returns
194    ///
195    /// A pinned stream of chunks or an error.
196    fn generate_stream(
197        &self,
198        _request: GenerationRequest,
199    ) -> BoxStream<'static, Result<StreamResponse>> {
200        let name = self.name().to_string();
201        Box::pin(async_stream::stream! {
202            yield Err(crate::AetherError::ProviderError(format!(
203                "Streaming not implemented for provider: {}",
204                name
205            )));
206        })
207    }
208
209    /// Generate code for multiple slots in batch.
210    ///
211    /// Default implementation calls `generate` for each slot sequentially.
212    async fn generate_batch(
213        &self,
214        requests: Vec<GenerationRequest>,
215    ) -> Result<Vec<GenerationResponse>> {
216        let mut responses = Vec::with_capacity(requests.len());
217        for request in requests {
218            responses.push(self.generate(request).await?);
219        }
220        Ok(responses)
221    }
222
223    /// Check if the provider is available and configured correctly.
224    async fn health_check(&self) -> Result<bool> {
225        Ok(true)
226    }
227}
228
229#[async_trait]
230impl<T: AiProvider + ?Sized + Send + Sync> AiProvider for Arc<T> {
231    fn name(&self) -> &str {
232        (**self).name()
233    }
234
235    async fn generate(&self, request: GenerationRequest) -> Result<GenerationResponse> {
236        (**self).generate(request).await
237    }
238
239    fn generate_stream(
240        &self,
241        request: GenerationRequest,
242    ) -> BoxStream<'static, Result<StreamResponse>> {
243        (**self).generate_stream(request)
244    }
245}
246
247#[async_trait]
248impl<T: AiProvider + ?Sized + Send + Sync> AiProvider for Box<T> {
249    fn name(&self) -> &str {
250        (**self).name()
251    }
252
253    async fn generate(&self, request: GenerationRequest) -> Result<GenerationResponse> {
254        (**self).generate(request).await
255    }
256
257    fn generate_stream(
258        &self,
259        request: GenerationRequest,
260    ) -> BoxStream<'static, Result<StreamResponse>> {
261        (**self).generate_stream(request)
262    }
263}
264
/// A mock provider for testing.
///
/// Generation results are looked up in `responses` by slot name; slots with
/// no registered response receive a placeholder comment.
#[derive(Debug, Default)]
pub struct MockProvider {
    /// Responses to return (slot_name -> code).
    pub responses: std::collections::HashMap<String, String>,
}
271
272impl MockProvider {
273    /// Create a new mock provider.
274    pub fn new() -> Self {
275        Self::default()
276    }
277
278    /// Add a mock response.
279    pub fn with_response(mut self, slot: impl Into<String>, code: impl Into<String>) -> Self {
280        self.responses.insert(slot.into(), code.into());
281        self
282    }
283}
284
285#[async_trait]
286impl AiProvider for MockProvider {
287    fn name(&self) -> &str {
288        "mock"
289    }
290
291    async fn generate(&self, request: GenerationRequest) -> Result<GenerationResponse> {
292        let code = self
293            .responses
294            .get(&request.slot.name)
295            .cloned()
296            .unwrap_or_else(|| format!("// Generated code for: {}", request.slot.name));
297
298        Ok(GenerationResponse {
299            code,
300            tokens_used: Some(10),
301            metadata: None,
302        })
303    }
304
305    fn generate_stream(
306        &self,
307        request: GenerationRequest,
308    ) -> BoxStream<'static, Result<StreamResponse>> {
309        let code = self
310            .responses
311            .get(&request.slot.name)
312            .cloned()
313            .unwrap_or_else(|| format!("// Generated code for: {}", request.slot.name));
314
315        let words: Vec<String> = code.split_whitespace().map(|s| format!("{} ", s)).collect();
316        
317        let stream = async_stream::stream! {
318            for word in words {
319                yield Ok(StreamResponse {
320                    delta: word,
321                    metadata: None,
322                });
323            }
324        };
325        
326        Box::pin(stream)
327    }
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    /// A registered slot response should be echoed back verbatim by `generate`.
    #[tokio::test]
    async fn test_mock_provider() {
        let provider =
            MockProvider::new().with_response("button", "<button>Click me</button>");

        let request = GenerationRequest {
            slot: Slot::new("button", "Create a button"),
            context: None,
            system_prompt: None,
            model: None,
            max_tokens: None,
        };

        let result = provider.generate(request).await.unwrap();
        assert_eq!(result.code, "<button>Click me</button>");
    }
}
350}