// aether_core/engine.rs
1//! Injection Engine - The main orchestrator for AI code injection.
2//!
3//! This module provides the high-level API for rendering templates with AI-generated code.
4
5use crate::{
6    AetherError, AiProvider, InjectionContext, Result, Template,
7    provider::{GenerationRequest, GenerationResponse},
8};
9use std::collections::HashMap;
10use std::sync::Arc;
11use tracing::{debug, info, instrument};
12use futures::stream::BoxStream;
13use crate::provider::StreamResponse;
14use crate::validation::{Validator, ValidationResult};
15use crate::cache::Cache;
16use crate::toon::Toon;
17
18/// The main engine for AI code injection.
19///
20/// # Example
21///
22/// ```rust,ignore
23/// use aether_core::{InjectionEngine, Template};
24/// use aether_ai::OpenAiProvider;
25///
26/// let provider = OpenAiProvider::from_env()?;
27/// let engine = InjectionEngine::new(provider);
28///
29/// let template = Template::new("<div>{{AI:content}}</div>")
30///     .with_slot("content", "Generate a welcome message");
31///
32/// let result = engine.render(&template).await?;
33/// ```
34pub struct InjectionEngine<P: AiProvider> {
35    /// The AI provider for code generation.
36    provider: Arc<P>,
37    
38    /// Optional validator for self-healing.
39    validator: Option<Arc<dyn Validator>>,
40
41    /// Optional cache for performance/cost optimization.
42    cache: Option<Arc<dyn Cache>>,
43
44    /// Whether to use TOON format for context injection.
45    use_toon: bool,
46
47    /// Global context applied to all generations.
48    global_context: InjectionContext,
49
50    /// Whether to run generations in parallel.
51    parallel: bool,
52
53    /// Maximum retries for failed generations.
54    max_retries: u32,
55}
56
57impl<P: AiProvider + 'static> InjectionEngine<P> {
58    /// Create a new injection engine with the given provider.
59    pub fn new(provider: P) -> Self {
60        Self {
61            provider: Arc::new(provider),
62            validator: None,
63            cache: None,
64            use_toon: false,
65            global_context: InjectionContext::default(),
66            parallel: true,
67            max_retries: 2,
68        }
69    }
70
71    /// Set the cache for performance optimization.
72    pub fn with_cache(mut self, cache: impl Cache + 'static) -> Self {
73        self.cache = Some(Arc::new(cache));
74        self
75    }
76
77    /// Enable or disable TOON format for context.
78    pub fn with_toon(mut self, enabled: bool) -> Self {
79        self.use_toon = enabled;
80        self
81    }
82
83    /// Set the validator for self-healing.
84    pub fn with_validator(mut self, validator: impl Validator + 'static) -> Self {
85        self.validator = Some(Arc::new(validator));
86        self
87    }
88
89    /// Set the global context.
90    pub fn with_context(mut self, context: InjectionContext) -> Self {
91        self.global_context = context;
92        self
93    }
94
95    /// Enable or disable parallel generation.
96    pub fn parallel(mut self, enabled: bool) -> Self {
97        self.parallel = enabled;
98        self
99    }
100
101    /// Set maximum retries for failed generations.
102    pub fn max_retries(mut self, retries: u32) -> Self {
103        self.max_retries = retries;
104        self
105    }
106
107    /// Render a template with AI-generated code.
108    ///
109    /// This method will generate code for all slots in the template
110    /// and return the final rendered content.
111    #[instrument(skip(self, template), fields(template_name = %template.name))]
112    pub async fn render(&self, template: &Template) -> Result<String> {
113        info!("Rendering template: {}", template.name);
114
115        let injections = self.generate_all(template, None).await?;
116        template.render(&injections)
117    }
118
119    /// Render a template with additional context.
120    #[instrument(skip(self, template, context), fields(template_name = %template.name))]
121    pub async fn render_with_context(
122        &self,
123        template: &Template,
124        context: InjectionContext,
125    ) -> Result<String> {
126        info!("Rendering template with context: {}", template.name);
127
128        let injections = self.generate_all(template, Some(context)).await?;
129        template.render(&injections)
130    }
131
132    /// Generate code for all slots in a template.
133    /// Generate code for all slots in a template.
134    async fn generate_all(
135        &self,
136        template: &Template,
137        extra_context: Option<InjectionContext>,
138    ) -> Result<HashMap<String, String>> {
139        let mut injections = HashMap::new();
140
141        let context_prompt = if self.use_toon {
142            // TOON optimization
143            let toon_ctx = Toon::serialize(&serde_json::to_value(&self.global_context).unwrap_or(serde_json::json!({})));
144            format!("[CONTEXT:TOON]\n{}\n", toon_ctx)
145        } else if let Some(ref ctx) = extra_context {
146            format!("{}\n{}", self.global_context.to_prompt(), ctx.to_prompt())
147        } else {
148            self.global_context.to_prompt()
149        };
150
151        let mut context_prompt = context_prompt;
152        // If self-healing is enabled, encourage AI to write unit tests
153        if self.validator.is_some() {
154            context_prompt.push_str("\n\nIMPORTANT: Include a Rust `#[cfg(test)] mod tests { ... }` block with at least one unit test to verify your code. The system will automatically run these tests to validate your output.");
155        }
156
157        if self.parallel {
158            injections = self
159                .generate_parallel(template, &context_prompt)
160                .await?;
161        } else {
162            for (name, slot) in &template.slots {
163                debug!("Generating code for slot: {}", name);
164
165                let request = GenerationRequest {
166                    slot: slot.clone(),
167                    context: Some(context_prompt.clone()),
168                    system_prompt: None,
169                };
170
171                let response = self.generate_with_retry(request).await?;
172                injections.insert(name.clone(), response.code);
173            }
174        }
175
176        Ok(injections)
177    }
178
179    /// Generate code for all slots in parallel.
180    async fn generate_parallel(
181        &self,
182        template: &Template,
183        context_prompt: &str,
184    ) -> Result<HashMap<String, String>> {
185        use tokio::task::JoinSet;
186
187        let mut join_set = JoinSet::new();
188
189        for (name, slot) in template.slots.clone() {
190            let provider = Arc::clone(&self.provider);
191            let validator = self.validator.clone();
192            let cache = self.cache.clone();
193            let context = context_prompt.to_string();
194            let max_retries = self.max_retries;
195
196            join_set.spawn(async move {
197                let request = GenerationRequest {
198                    slot,
199                    context: Some(context),
200                    system_prompt: None,
201                };
202
203                let response = Self::generate_with_healing_static(&provider, &validator, &cache, request, max_retries).await?;
204                Ok::<_, AetherError>((name, response.code))
205            });
206        }
207
208        let mut injections = HashMap::new();
209        while let Some(result) = join_set.join_next().await {
210            let (name, code) = result.map_err(|e| AetherError::InjectionError(e.to_string()))??;
211            injections.insert(name, code);
212        }
213
214        Ok(injections)
215    }
216
217    /// Generate with self-healing logic.
218    async fn generate_with_retry(&self, request: GenerationRequest) -> Result<GenerationResponse> {
219        Self::generate_with_healing_static(&self.provider, &self.validator, &self.cache, request, self.max_retries).await
220    }
221
222    /// Static version of generate with self-healing support.
223    async fn generate_with_healing_static(
224        provider: &Arc<P>,
225        validator: &Option<Arc<dyn Validator>>,
226        cache: &Option<Arc<dyn Cache>>,
227        mut request: GenerationRequest,
228        max_retries: u32,
229    ) -> Result<GenerationResponse> {
230        // 0. Check cache first
231        let cache_key = if cache.is_some() {
232            Some(format!("{}:{}", request.slot.prompt, request.context.as_deref().unwrap_or("")))
233        } else {
234            None
235        };
236
237        if let (Some(ref c), Some(ref key)) = (cache, &cache_key) {
238            if let Some(cached_code) = c.get(key) {
239                debug!("Cache hit for slot: {}", request.slot.name);
240                return Ok(GenerationResponse {
241                    code: cached_code,
242                    tokens_used: None,
243                    metadata: Some(serde_json::json!({"cache": "hit"})),
244                });
245            }
246        }
247
248        let mut last_error = None;
249
250        for attempt in 0..=max_retries {
251            // 1. Generate code
252            let mut response = match provider.generate(request.clone()).await {
253                Ok(r) => r,
254                Err(e) => {
255                    debug!("Generation attempt {} failed: {}", attempt + 1, e);
256                    last_error = Some(e);
257                    if attempt < max_retries {
258                        tokio::time::sleep(std::time::Duration::from_millis(100 * (attempt as u64 + 1))).await;
259                        continue;
260                    }
261                    return Err(last_error.unwrap());
262                }
263            };
264
265            // 2. Validate and Heal if validator is present
266            if let Some(ref val) = validator {
267                // Apply formatting (Linter compliance)
268                if let Ok(formatted) = val.format(&request.slot.kind, &response.code) {
269                    response.code = formatted;
270                }
271
272                match val.validate(&request.slot.kind, &response.code)? {
273                    ValidationResult::Valid => {
274                        // Success! Cache if enabled
275                        if let (Some(ref c), Some(ref key)) = (cache, &cache_key) {
276                            c.set(key, response.code.clone());
277                        }
278                        return Ok(response);
279                    },
280                    ValidationResult::Invalid(err_msg) => {
281                        info!("Self-healing: Validation failed for slot '{}', attempt {}. Error: {}", 
282                            request.slot.name, attempt + 1, err_msg);
283                        
284                        last_error = Some(AetherError::ProviderError(err_msg.clone()));
285
286                        if attempt < max_retries {
287                            // Feedback Loop: Add error to prompt for next attempt
288                            request.slot.prompt = format!(
289                                "{}\n\n[SELF-HEALING FEEDBACK]\nYour previous output had validation errors. Please fix them and output ONLY the corrected code.\nERROR:\n{}",
290                                request.slot.prompt,
291                                err_msg
292                            );
293                            continue;
294                        }
295                    }
296                }
297            } else {
298                // No validator, just cache and return
299                if let (Some(ref c), Some(ref key)) = (cache, &cache_key) {
300                    c.set(key, response.code.clone());
301                }
302                return Ok(response);
303            }
304        }
305
306        Err(last_error.unwrap_or_else(|| AetherError::ProviderError("Healing failed".to_string())))
307    }
308
309    /// Generate code for a single slot.
310    pub async fn generate_slot(&self, template: &Template, slot_name: &str) -> Result<String> {
311        let slot = template
312            .slots
313            .get(slot_name)
314            .ok_or_else(|| AetherError::SlotNotFound(slot_name.to_string()))?;
315
316        let request = GenerationRequest {
317            slot: slot.clone(),
318            context: Some(self.global_context.to_prompt()),
319            system_prompt: None,
320        };
321
322        let response = self.generate_with_retry(request).await?;
323        Ok(response.code)
324    }
325
326    /// Generate code for a single slot as a stream.
327    pub fn generate_slot_stream(
328        &self,
329        template: &Template,
330        slot_name: &str,
331    ) -> Result<BoxStream<'static, Result<StreamResponse>>> {
332        let slot = template
333            .slots
334            .get(slot_name)
335            .ok_or_else(|| AetherError::SlotNotFound(slot_name.to_string()))?;
336
337        let request = GenerationRequest {
338            slot: slot.clone(),
339            context: Some(self.global_context.to_prompt()),
340            system_prompt: None,
341        };
342
343        Ok(self.provider.generate_stream(request))
344    }
345}
346
/// Convenience macro for one-line AI code injection.
///
/// Expands to a future that resolves to the generated code, so the result
/// must be `.await`ed (or driven to completion with a runtime's `block_on`).
///
/// # Example
///
/// ```rust,ignore
/// let code = aether_core::inject!("Create a login form with email and password", OpenAiProvider::from_env()?).await?;
/// ```
#[macro_export]
macro_rules! inject {
    ($prompt:expr, $provider:expr) => {{
        // The async block owns both the template and the engine, so the
        // returned future does not borrow locals that drop at the end of
        // this block (the previous expansion failed to compile with E0597).
        async move {
            let template = $crate::Template::new("{{AI:generated}}")
                .with_slot("generated", $prompt);
            let engine = $crate::InjectionEngine::new($provider);
            engine.render(&template).await
        }
    }};
}

/// Convenience macro for synchronous one-line injection (blocking).
///
/// Builds a fresh Tokio runtime and blocks on the generation future.
/// Do not invoke from inside an existing async runtime: constructing a
/// runtime and calling `block_on` there panics.
#[macro_export]
macro_rules! inject_sync {
    ($prompt:expr, $provider:expr) => {{
        tokio::runtime::Runtime::new()
            .expect("inject_sync!: failed to build Tokio runtime")
            .block_on($crate::inject!($prompt, $provider))
    }};
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::MockProvider;

    /// A single slot's generated code is substituted into the template.
    #[tokio::test]
    async fn test_engine_render() {
        let mock = MockProvider::new().with_response("content", "<p>Hello World</p>");
        let engine = InjectionEngine::new(mock);

        let template = Template::new("<div>{{AI:content}}</div>")
            .with_slot("content", "Generate a paragraph");

        let rendered = engine.render(&template).await.unwrap();
        assert_eq!(rendered, "<div><p>Hello World</p></div>");
    }

    /// A configured global context does not prevent slot substitution.
    #[tokio::test]
    async fn test_engine_with_context() {
        let mock =
            MockProvider::new().with_response("button", "<button class='btn'>Click</button>");
        let engine = InjectionEngine::new(mock)
            .with_context(InjectionContext::new().with_framework("react"));

        let template = Template::new("{{AI:button}}").with_slot("button", "Create a button");

        let rendered = engine.render(&template).await.unwrap();
        assert!(rendered.contains("button"));
    }

    /// Parallel mode fills every slot in the template.
    #[tokio::test]
    async fn test_parallel_generation() {
        let mock = MockProvider::new()
            .with_response("slot1", "code1")
            .with_response("slot2", "code2");
        let engine = InjectionEngine::new(mock).parallel(true);

        let template = Template::new("{{AI:slot1}} | {{AI:slot2}}");

        let rendered = engine.render(&template).await.unwrap();
        assert!(rendered.contains("code1"));
        assert!(rendered.contains("code2"));
    }
}