// aether_core/engine.rs

//! Injection Engine - The main orchestrator for AI code injection.
//!
//! This module provides the high-level API for rendering templates with AI-generated code.

use std::collections::HashMap;
use std::sync::Arc;

use futures::stream::BoxStream;
use tracing::{debug, info, instrument};

use crate::cache::Cache;
use crate::config::AetherConfig;
use crate::provider::{GenerationRequest, GenerationResponse, StreamResponse};
use crate::toon::Toon;
use crate::validation::{ValidationResult, Validator};
use crate::{AetherError, AiProvider, InjectionContext, Result, Template};

19/// The main engine for AI code injection.
20///
21/// # Example
22///
23/// ```rust,ignore
24/// use aether_core::{InjectionEngine, Template, AetherConfig};
25/// use aether_ai::OpenAiProvider;
26///
27/// let provider = OpenAiProvider::from_env()?;
28/// 
29/// // Using config
30/// let config = AetherConfig::from_env();
31/// let engine = InjectionEngine::with_config(provider, config);
32///
33/// // Or simple
34/// let engine = InjectionEngine::new(provider);
35/// ```
36pub struct InjectionEngine<P: AiProvider> {
37    /// The AI provider for code generation.
38    provider: Arc<P>,
39    
40    /// Optional validator for self-healing.
41    validator: Option<Arc<dyn Validator>>,
42
43    /// Optional cache for performance/cost optimization.
44    cache: Option<Arc<dyn Cache>>,
45
46    /// Whether to use TOON format for context injection.
47    use_toon: bool,
48
49    /// Auto TOON threshold (characters).
50    auto_toon_threshold: Option<usize>,
51
52    /// Global context applied to all generations.
53    global_context: InjectionContext,
54
55    /// Whether to run generations in parallel.
56    parallel: bool,
57
58    /// Maximum retries for failed generations.
59    max_retries: u32,
60}
61
62impl<P: AiProvider + 'static> InjectionEngine<P> {
63    /// Create a new injection engine with the given provider and default config.
64    pub fn new(provider: P) -> Self {
65        Self::with_config(provider, AetherConfig::default())
66    }
67
68    /// Create a new injection engine with the given provider and config.
69    pub fn with_config(provider: P, config: AetherConfig) -> Self {
70        Self {
71            provider: Arc::new(provider),
72            validator: None,
73            cache: None,
74            use_toon: config.toon_enabled,
75            auto_toon_threshold: config.auto_toon_threshold,
76            global_context: InjectionContext::default(),
77            parallel: config.parallel,
78            max_retries: config.max_retries,
79        }
80    }
81
82    /// Set the cache for performance optimization.
83    pub fn with_cache(mut self, cache: impl Cache + 'static) -> Self {
84        self.cache = Some(Arc::new(cache));
85        self
86    }
87
88    /// Enable or disable TOON format for context.
89    pub fn with_toon(mut self, enabled: bool) -> Self {
90        self.use_toon = enabled;
91        self
92    }
93
94    /// Set the validator for self-healing.
95    pub fn with_validator(mut self, validator: impl Validator + 'static) -> Self {
96        self.validator = Some(Arc::new(validator));
97        self
98    }
99
100    /// Set the global context.
101    pub fn with_context(mut self, context: InjectionContext) -> Self {
102        self.global_context = context;
103        self
104    }
105
106    /// Enable or disable parallel generation.
107    pub fn parallel(mut self, enabled: bool) -> Self {
108        self.parallel = enabled;
109        self
110    }
111
112    /// Set maximum retries for failed generations.
113    pub fn max_retries(mut self, retries: u32) -> Self {
114        self.max_retries = retries;
115        self
116    }
117
118    /// Render a template with AI-generated code.
119    ///
120    /// This method will generate code for all slots in the template
121    /// and return the final rendered content.
122    #[instrument(skip(self, template), fields(template_name = %template.name))]
123    pub async fn render(&self, template: &Template) -> Result<String> {
124        info!("Rendering template: {}", template.name);
125
126        let injections = self.generate_all(template, None).await?;
127        template.render(&injections)
128    }
129
130    /// Render a template with additional context.
131    #[instrument(skip(self, template, context), fields(template_name = %template.name))]
132    pub async fn render_with_context(
133        &self,
134        template: &Template,
135        context: InjectionContext,
136    ) -> Result<String> {
137        info!("Rendering template with context: {}", template.name);
138
139        let injections = self.generate_all(template, Some(context)).await?;
140        template.render(&injections)
141    }
142
143    async fn generate_all(
144        &self,
145        template: &Template,
146        extra_context: Option<InjectionContext>,
147    ) -> Result<HashMap<String, String>> {
148        let mut injections = HashMap::new();
149
150        // Build base context first to check length
151        let base_context = if let Some(ref ctx) = extra_context {
152            format!("{}\n{}", self.global_context.to_prompt(), ctx.to_prompt())
153        } else {
154            self.global_context.to_prompt()
155        };
156
157        // Determine if TOON should be used (explicit or auto-threshold)
158        let should_use_toon = self.use_toon || self.auto_toon_threshold
159            .map(|threshold| base_context.len() >= threshold)
160            .unwrap_or(false);
161
162        let context_prompt = if should_use_toon {
163            // TOON optimization - compress context
164            let toon_ctx = Toon::serialize(&serde_json::to_value(&self.global_context).unwrap_or(serde_json::json!({})));
165            format!("[CONTEXT:TOON]\n{}\n[Note: TOON is a compact key:value format to save tokens]\n", toon_ctx)
166        } else {
167            base_context
168        };
169
170        let mut context_prompt = context_prompt;
171        // If self-healing is enabled, encourage AI to write unit tests
172        if self.validator.is_some() {
173            context_prompt.push_str("\n\nIMPORTANT: Include a Rust `#[cfg(test)] mod tests { ... }` block with at least one unit test to verify your code. The system will automatically run these tests to validate your output.");
174        }
175
176        if self.parallel {
177            injections = self
178                .generate_parallel(template, &context_prompt)
179                .await?;
180        } else {
181            for (name, slot) in &template.slots {
182                debug!("Generating code for slot: {}", name);
183
184                let request = GenerationRequest {
185                    slot: slot.clone(),
186                    context: Some(context_prompt.clone()),
187                    system_prompt: None,
188                };
189
190                let response = self.generate_with_retry(request).await?;
191                injections.insert(name.clone(), response.code);
192            }
193        }
194
195        Ok(injections)
196    }
197
198    /// Generate code for all slots in parallel.
199    async fn generate_parallel(
200        &self,
201        template: &Template,
202        context_prompt: &str,
203    ) -> Result<HashMap<String, String>> {
204        use tokio::task::JoinSet;
205
206        let mut join_set = JoinSet::new();
207
208        for (name, slot) in template.slots.clone() {
209            let provider = Arc::clone(&self.provider);
210            let validator = self.validator.clone();
211            let cache = self.cache.clone();
212            let context = context_prompt.to_string();
213            let max_retries = self.max_retries;
214
215            join_set.spawn(async move {
216                let request = GenerationRequest {
217                    slot,
218                    context: Some(context),
219                    system_prompt: None,
220                };
221
222                let response = Self::generate_with_healing_static(&provider, &validator, &cache, request, max_retries).await?;
223                Ok::<_, AetherError>((name, response.code))
224            });
225        }
226
227        let mut injections = HashMap::new();
228        while let Some(result) = join_set.join_next().await {
229            let (name, code) = result.map_err(|e| AetherError::InjectionError(e.to_string()))??;
230            injections.insert(name, code);
231        }
232
233        Ok(injections)
234    }
235
236    /// Generate with self-healing logic.
237    async fn generate_with_retry(&self, request: GenerationRequest) -> Result<GenerationResponse> {
238        Self::generate_with_healing_static(&self.provider, &self.validator, &self.cache, request, self.max_retries).await
239    }
240
241    /// Static version of generate with self-healing support.
242    async fn generate_with_healing_static(
243        provider: &Arc<P>,
244        validator: &Option<Arc<dyn Validator>>,
245        cache: &Option<Arc<dyn Cache>>,
246        mut request: GenerationRequest,
247        max_retries: u32,
248    ) -> Result<GenerationResponse> {
249        // 0. Check cache first
250        let cache_key = if cache.is_some() {
251            Some(format!("{}:{}", request.slot.prompt, request.context.as_deref().unwrap_or("")))
252        } else {
253            None
254        };
255
256        if let (Some(ref c), Some(ref key)) = (cache, &cache_key) {
257            if let Some(cached_code) = c.get(key) {
258                debug!("Cache hit for slot: {}", request.slot.name);
259                return Ok(GenerationResponse {
260                    code: cached_code,
261                    tokens_used: None,
262                    metadata: Some(serde_json::json!({"cache": "hit"})),
263                });
264            }
265        }
266
267        let mut last_error = None;
268
269        for attempt in 0..=max_retries {
270            // 1. Generate code
271            let mut response = match provider.generate(request.clone()).await {
272                Ok(r) => r,
273                Err(e) => {
274                    debug!("Generation attempt {} failed: {}", attempt + 1, e);
275                    last_error = Some(e);
276                    if attempt < max_retries {
277                        tokio::time::sleep(std::time::Duration::from_millis(100 * (attempt as u64 + 1))).await;
278                        continue;
279                    }
280                    return Err(last_error.unwrap());
281                }
282            };
283
284            // 2. Validate and Heal if validator is present
285            if let Some(ref val) = validator {
286                // Apply formatting (Linter compliance)
287                if let Ok(formatted) = val.format(&request.slot.kind, &response.code) {
288                    response.code = formatted;
289                }
290
291                match val.validate(&request.slot.kind, &response.code)? {
292                    ValidationResult::Valid => {
293                        // Success! Cache if enabled
294                        if let (Some(ref c), Some(ref key)) = (cache, &cache_key) {
295                            c.set(key, response.code.clone());
296                        }
297                        return Ok(response);
298                    },
299                    ValidationResult::Invalid(err_msg) => {
300                        info!("Self-healing: Validation failed for slot '{}', attempt {}. Error: {}", 
301                            request.slot.name, attempt + 1, err_msg);
302                        
303                        last_error = Some(AetherError::ProviderError(err_msg.clone()));
304
305                        if attempt < max_retries {
306                            // Feedback Loop: Add error to prompt for next attempt
307                            request.slot.prompt = format!(
308                                "{}\n\n[SELF-HEALING FEEDBACK]\nYour previous output had validation errors. Please fix them and output ONLY the corrected code.\nERROR:\n{}",
309                                request.slot.prompt,
310                                err_msg
311                            );
312                            continue;
313                        }
314                    }
315                }
316            } else {
317                // No validator, just cache and return
318                if let (Some(ref c), Some(ref key)) = (cache, &cache_key) {
319                    c.set(key, response.code.clone());
320                }
321                return Ok(response);
322            }
323        }
324
325        Err(last_error.unwrap_or_else(|| AetherError::ProviderError("Healing failed".to_string())))
326    }
327
328    /// Generate code for a single slot.
329    pub async fn generate_slot(&self, template: &Template, slot_name: &str) -> Result<String> {
330        let slot = template
331            .slots
332            .get(slot_name)
333            .ok_or_else(|| AetherError::SlotNotFound(slot_name.to_string()))?;
334
335        let request = GenerationRequest {
336            slot: slot.clone(),
337            context: Some(self.global_context.to_prompt()),
338            system_prompt: None,
339        };
340
341        let response = self.generate_with_retry(request).await?;
342        Ok(response.code)
343    }
344
345    /// Generate code for a single slot as a stream.
346    pub fn generate_slot_stream(
347        &self,
348        template: &Template,
349        slot_name: &str,
350    ) -> Result<BoxStream<'static, Result<StreamResponse>>> {
351        let slot = template
352            .slots
353            .get(slot_name)
354            .ok_or_else(|| AetherError::SlotNotFound(slot_name.to_string()))?;
355
356        let request = GenerationRequest {
357            slot: slot.clone(),
358            context: Some(self.global_context.to_prompt()),
359            system_prompt: None,
360        };
361
362        Ok(self.provider.generate_stream(request))
363    }
364}
/// Convenience macro for one-line AI code injection.
///
/// Expands to an `async` block that builds a single-slot template from the
/// prompt, renders it with a fresh [`InjectionEngine`], and yields the
/// rendered `Result<String>`.
///
/// # Example
///
/// ```rust,ignore
/// let code = aether_core::inject!("Create a login form with email and password", OpenAiProvider::from_env()?).await?;
/// ```
#[macro_export]
macro_rules! inject {
    ($prompt:expr, $provider:expr) => {{
        // Note: the unused `Slot` import was removed (it triggered an
        // unused-import warning at every expansion).
        use $crate::{InjectionEngine, Template};

        // BUGFIX: expand to an async block that OWNS the engine and template.
        // The previous expansion returned `engine.render(&template)` directly,
        // a future borrowing block-local values that were dropped at the end
        // of the block, which could not compile at call sites.
        async {
            let template = Template::new("{{AI:generated}}").with_slot("generated", $prompt);
            let engine = InjectionEngine::new($provider);
            engine.render(&template).await
        }
    }};
}
/// Convenience macro for synchronous one-line injection (blocking).
///
/// # Panics
///
/// Panics if a Tokio runtime cannot be created, or if invoked from within an
/// existing Tokio runtime (nested `block_on` is not allowed by Tokio).
#[macro_export]
macro_rules! inject_sync {
    ($prompt:expr, $provider:expr) => {{
        tokio::runtime::Runtime::new()
            .expect("inject_sync!: failed to create Tokio runtime")
            .block_on($crate::inject!($prompt, $provider))
    }};
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::MockProvider;

    #[tokio::test]
    async fn test_engine_render() {
        // The mock output should be spliced into the surrounding markup.
        let mock = MockProvider::new().with_response("content", "<p>Hello World</p>");
        let engine = InjectionEngine::new(mock);

        let template = Template::new("<div>{{AI:content}}</div>")
            .with_slot("content", "Generate a paragraph");

        let rendered = engine.render(&template).await.unwrap();
        assert_eq!(rendered, "<div><p>Hello World</p></div>");
    }

    #[tokio::test]
    async fn test_engine_with_context() {
        // A configured global context must not prevent normal rendering.
        let mock =
            MockProvider::new().with_response("button", "<button class='btn'>Click</button>");
        let engine = InjectionEngine::new(mock)
            .with_context(InjectionContext::new().with_framework("react"));

        let template = Template::new("{{AI:button}}").with_slot("button", "Create a button");

        let rendered = engine.render(&template).await.unwrap();
        assert!(rendered.contains("button"));
    }

    #[tokio::test]
    async fn test_parallel_generation() {
        // Both slots should be filled even when generated concurrently.
        let mock = MockProvider::new()
            .with_response("slot1", "code1")
            .with_response("slot2", "code2");
        let engine = InjectionEngine::new(mock).parallel(true);

        let template = Template::new("{{AI:slot1}} | {{AI:slot2}}");

        let rendered = engine.render(&template).await.unwrap();
        assert!(rendered.contains("code1"));
        assert!(rendered.contains("code2"));
    }
}