aether_core/provider.rs

//! AI Provider trait and configuration.
//!
//! Defines the interface that AI backends must implement.

use crate::{Result, Slot};
use async_trait::async_trait;
use futures::stream::BoxStream;
use serde::{Deserialize, Serialize};

/// Configuration for an AI provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderConfig {
    /// API key for authentication.
    pub api_key: String,

    /// Model identifier (e.g., "gpt-4", "claude-3").
    pub model: String,

    /// Base URL for the API.
    pub base_url: Option<String>,

    /// Maximum tokens to generate.
    pub max_tokens: Option<u32>,

    /// Temperature for generation (0.0 - 2.0).
    pub temperature: Option<f32>,

    /// Request timeout in seconds.
    pub timeout_seconds: Option<u64>,

    /// Optional URL to fetch the API key from, so the key need not be
    /// stored in local configuration.
    pub api_key_url: Option<String>,
}

impl ProviderConfig {
    /// Create a new provider config with API key and model.
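    ///
    /// # Examples
    ///
    /// A minimal sketch of the builder chain; the key, model, and URL here
    /// are placeholders:
    ///
    /// ```no_run
    /// use aether_core::provider::ProviderConfig;
    ///
    /// let config = ProviderConfig::new("sk-placeholder", "gpt-4")
    ///     .with_base_url("https://api.example.com/v1")
    ///     .with_max_tokens(1024)
    ///     .with_temperature(0.7)
    ///     .with_timeout(30);
    /// ```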
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: model.into(),
            base_url: None,
            max_tokens: None,
            temperature: None,
            timeout_seconds: None,
            api_key_url: None,
        }
    }

    /// Set an external URL to fetch the API key from.
    pub fn with_api_key_url(mut self, url: impl Into<String>) -> Self {
        self.api_key_url = Some(url.into());
        self
    }

    /// Resolve the API key (literal or remote).
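    ///
    /// # Examples
    ///
    /// A minimal sketch of remote resolution; the URL is a placeholder and
    /// the call needs an async runtime:
    ///
    /// ```no_run
    /// # use aether_core::provider::ProviderConfig;
    /// # async fn demo() -> aether_core::Result<()> {
    /// let config = ProviderConfig::new("", "gpt-4")
    ///     .with_api_key_url("https://secrets.example.com/aether-key");
    /// let _key = config.resolve_api_key().await?;
    /// # Ok(())
    /// # }
    /// ```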
    pub async fn resolve_api_key(&self) -> Result<String> {
        if let Some(ref url) = self.api_key_url {
            let resp = reqwest::get(url).await.map_err(|e| {
                crate::AetherError::NetworkError(format!("Failed to fetch API key: {}", e))
            })?;

            let key = resp.text().await.map_err(|e| {
                crate::AetherError::NetworkError(format!("Failed to read API key body: {}", e))
            })?;

            Ok(key.trim().to_string())
        } else {
            Ok(self.api_key.clone())
        }
    }

    /// Set the base URL.
    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
        self.base_url = Some(url.into());
        self
    }

    /// Set max tokens.
    pub fn with_max_tokens(mut self, tokens: u32) -> Self {
        self.max_tokens = Some(tokens);
        self
    }

    /// Set temperature (clamped to the valid 0.0 - 2.0 range).
    pub fn with_temperature(mut self, temp: f32) -> Self {
        self.temperature = Some(temp.clamp(0.0, 2.0));
        self
    }

    /// Set the request timeout in seconds.
    pub fn with_timeout(mut self, seconds: u64) -> Self {
        self.timeout_seconds = Some(seconds);
        self
    }

    /// Load config from environment variables.
    ///
    /// Expected variables:
    /// - `AETHER_API_KEY` or `OPENAI_API_KEY`
    /// - `AETHER_MODEL` (defaults to "gpt-4")
    /// - `AETHER_BASE_URL` (optional)
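    ///
    /// # Examples
    ///
    /// A minimal sketch; the key value is a placeholder:
    ///
    /// ```no_run
    /// # use aether_core::provider::ProviderConfig;
    /// std::env::set_var("AETHER_API_KEY", "sk-placeholder");
    /// let config = ProviderConfig::from_env().expect("key was just set");
    /// assert_eq!(config.model, "gpt-4");
    /// ```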
    pub fn from_env() -> Result<Self> {
        let api_key = std::env::var("AETHER_API_KEY")
            .or_else(|_| std::env::var("OPENAI_API_KEY"))
            .map_err(|_| {
                crate::AetherError::ConfigError(
                    "AETHER_API_KEY or OPENAI_API_KEY must be set".to_string(),
                )
            })?;

        let model = std::env::var("AETHER_MODEL").unwrap_or_else(|_| "gpt-4".to_string());

        let mut config = Self::new(api_key, model);

        if let Ok(url) = std::env::var("AETHER_BASE_URL") {
            config = config.with_base_url(url);
        }

        Ok(config)
    }
}

/// Request for code generation.
#[derive(Debug, Clone)]
pub struct GenerationRequest {
    /// The slot to generate code for.
    pub slot: Slot,

    /// Additional context (e.g., surrounding code).
    pub context: Option<String>,

    /// System prompt override.
    pub system_prompt: Option<String>,
}

/// Response from code generation.
#[derive(Debug, Clone)]
pub struct GenerationResponse {
    /// The generated code.
    pub code: String,

    /// Tokens used for the request.
    pub tokens_used: Option<u32>,

    /// Generation metadata.
    pub metadata: Option<serde_json::Value>,
}

/// A single chunk of a streaming response.
#[derive(Debug, Clone)]
pub struct StreamResponse {
    /// The new text chunk.
    pub delta: String,

    /// Final metadata (only sent in the last chunk).
    pub metadata: Option<serde_json::Value>,
}

/// Trait that AI providers must implement.
///
/// This trait defines the interface for generating code from slots.
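///
/// # Implementing
///
/// A minimal sketch of an implementor; `EchoProvider` is illustrative and
/// not part of this crate:
///
/// ```no_run
/// use aether_core::Result;
/// use aether_core::provider::{AiProvider, GenerationRequest, GenerationResponse};
/// use async_trait::async_trait;
///
/// struct EchoProvider;
///
/// #[async_trait]
/// impl AiProvider for EchoProvider {
///     fn name(&self) -> &str {
///         "echo"
///     }
///
///     async fn generate(&self, request: GenerationRequest) -> Result<GenerationResponse> {
///         // Echo the slot name back instead of calling a real model.
///         Ok(GenerationResponse {
///             code: format!("// echo: {}", request.slot.name),
///             tokens_used: None,
///             metadata: None,
///         })
///     }
/// }
/// ```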
#[async_trait]
pub trait AiProvider: Send + Sync {
    /// Get the provider name.
    fn name(&self) -> &str;

    /// Generate code for a slot.
    ///
    /// # Arguments
    ///
    /// * `request` - The generation request containing slot info
    ///
    /// # Returns
    ///
    /// Generated code response or an error.
    async fn generate(&self, request: GenerationRequest) -> Result<GenerationResponse>;

    /// Generate a stream of code for a slot.
    ///
    /// The default implementation yields a single error indicating that
    /// streaming is not supported by the provider.
    ///
    /// # Arguments
    ///
    /// * `request` - The generation request containing slot info
    ///
    /// # Returns
    ///
    /// A pinned stream of chunks or an error.
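    ///
    /// A minimal sketch of consuming the stream, assuming the `futures`
    /// crate for `StreamExt`:
    ///
    /// ```no_run
    /// # use aether_core::provider::{AiProvider, GenerationRequest};
    /// # async fn demo(provider: &dyn AiProvider, request: GenerationRequest) {
    /// use futures::StreamExt;
    ///
    /// let mut stream = provider.generate_stream(request);
    /// while let Some(chunk) = stream.next().await {
    ///     match chunk {
    ///         Ok(resp) => print!("{}", resp.delta),
    ///         Err(e) => eprintln!("stream error: {}", e),
    ///     }
    /// }
    /// # }
    /// ```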
    fn generate_stream(
        &self,
        _request: GenerationRequest,
    ) -> BoxStream<'static, Result<StreamResponse>> {
        let name = self.name().to_string();
        Box::pin(async_stream::stream! {
            yield Err(crate::AetherError::ProviderError(format!(
                "Streaming not implemented for provider: {}",
                name
            )));
        })
    }

    /// Generate code for multiple slots in batch.
    ///
    /// Default implementation calls `generate` for each slot sequentially.
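    ///
    /// A minimal sketch using the mock provider below; `Slot::new` takes a
    /// name and a prompt, as in this module's tests:
    ///
    /// ```no_run
    /// # use aether_core::Slot;
    /// # use aether_core::provider::{AiProvider, GenerationRequest, MockProvider};
    /// # async fn demo() {
    /// let provider = MockProvider::new();
    /// let requests = vec![
    ///     GenerationRequest {
    ///         slot: Slot::new("a", "First slot"),
    ///         context: None,
    ///         system_prompt: None,
    ///     },
    ///     GenerationRequest {
    ///         slot: Slot::new("b", "Second slot"),
    ///         context: None,
    ///         system_prompt: None,
    ///     },
    /// ];
    /// let responses = provider.generate_batch(requests).await.unwrap();
    /// assert_eq!(responses.len(), 2);
    /// # }
    /// ```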
    async fn generate_batch(
        &self,
        requests: Vec<GenerationRequest>,
    ) -> Result<Vec<GenerationResponse>> {
        let mut responses = Vec::with_capacity(requests.len());
        for request in requests {
            responses.push(self.generate(request).await?);
        }
        Ok(responses)
    }

    /// Check if the provider is available and configured correctly.
    ///
    /// The default implementation always reports healthy.
    async fn health_check(&self) -> Result<bool> {
        Ok(true)
    }
}

/// A mock provider for testing.
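///
/// # Examples
///
/// A minimal sketch; the slot name and markup are placeholders:
///
/// ```no_run
/// use aether_core::provider::MockProvider;
///
/// let provider = MockProvider::new()
///     .with_response("header", "<h1>Title</h1>");
/// ```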
#[derive(Debug, Default)]
pub struct MockProvider {
    /// Responses to return (slot_name -> code).
    pub responses: std::collections::HashMap<String, String>,
}

impl MockProvider {
    /// Create a new mock provider.
    pub fn new() -> Self {
        Self::default()
    }

    /// Add a mock response.
    pub fn with_response(mut self, slot: impl Into<String>, code: impl Into<String>) -> Self {
        self.responses.insert(slot.into(), code.into());
        self
    }
}

#[async_trait]
impl AiProvider for MockProvider {
    fn name(&self) -> &str {
        "mock"
    }

    async fn generate(&self, request: GenerationRequest) -> Result<GenerationResponse> {
        let code = self
            .responses
            .get(&request.slot.name)
            .cloned()
            .unwrap_or_else(|| format!("// Generated code for: {}", request.slot.name));

        Ok(GenerationResponse {
            code,
            tokens_used: Some(10),
            metadata: None,
        })
    }

    fn generate_stream(
        &self,
        request: GenerationRequest,
    ) -> BoxStream<'static, Result<StreamResponse>> {
        let code = self
            .responses
            .get(&request.slot.name)
            .cloned()
            .unwrap_or_else(|| format!("// Generated code for: {}", request.slot.name));

        let words: Vec<String> = code.split_whitespace().map(|s| format!("{} ", s)).collect();

        let stream = async_stream::stream! {
            for word in words {
                yield Ok(StreamResponse {
                    delta: word,
                    metadata: None,
                });
            }
        };

        Box::pin(stream)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_mock_provider() {
        let provider = MockProvider::new()
            .with_response("button", "<button>Click me</button>");

        let request = GenerationRequest {
            slot: Slot::new("button", "Create a button"),
            context: None,
            system_prompt: None,
        };

        let response = provider.generate(request).await.unwrap();
        assert_eq!(response.code, "<button>Click me</button>");
    }
}