aether_core/engine.rs

//! Injection Engine - The main orchestrator for AI code injection.
//!
//! This module provides the high-level API for rendering templates with AI-generated code.

use std::collections::HashMap;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::Arc;

use futures::stream::BoxStream;
use tracing::{debug, info, instrument};

use crate::{
    AetherError, AiProvider, InjectionContext, Result, Template,
    cache::Cache,
    config::AetherConfig,
    provider::{GenerationRequest, GenerationResponse, StreamResponse},
    toon::Toon,
    validation::{ValidationResult, Validator},
};

pub use crate::observer::ObserverPtr;

/// The main engine for AI code injection.
///
/// # Example
///
/// ```rust,ignore
/// use aether_core::{InjectionEngine, Template, AetherConfig};
/// use aether_ai::OpenAiProvider;
///
/// let provider = OpenAiProvider::from_env()?;
///
/// // Using an explicit config
/// let config = AetherConfig::from_env();
/// let engine = InjectionEngine::with_config(provider, config);
///
/// // Or simply with the defaults
/// let engine = InjectionEngine::new(provider);
/// ```
pub struct InjectionEngine<P: AiProvider> {
    /// The AI provider for code generation.
    provider: Arc<P>,

    /// Optional validator for self-healing.
    validator: Option<Arc<dyn Validator>>,

    /// Optional cache for performance/cost optimization.
    cache: Option<Arc<dyn Cache>>,

    /// Whether to use TOON format for context injection.
    use_toon: bool,

    /// Auto TOON threshold (characters).
    auto_toon_threshold: Option<usize>,

    /// Global context applied to all generations.
    global_context: InjectionContext,

    /// Whether to run generations in parallel.
    parallel: bool,

    /// Maximum retries for failed generations.
    max_retries: u32,

    /// Optional observer for tracking events.
    observer: Option<ObserverPtr>,
}

/// A session for tracking incremental rendering state.
/// Holds fingerprints of slots and context to identify changes.
#[derive(Debug, Clone, Default)]
pub struct RenderSession {
    /// Cached results indexed by (SlotHash, ContextHash)
    pub results: HashMap<(u64, u64), String>,
}

impl RenderSession {
    /// Create a new empty render session.
    pub fn new() -> Self {
        Self::default()
    }

    /// Calculate a hash for any hashable value (deterministic within a single run,
    /// but not guaranteed stable across Rust releases, since it uses `DefaultHasher`).
    pub fn hash<T: Hash>(t: &T) -> u64 {
        let mut s = DefaultHasher::new();
        t.hash(&mut s);
        s.finish()
    }
}

impl<P: AiProvider + 'static> InjectionEngine<P> {
    /// Create a new injection engine with the given provider and default config.
    pub fn new(provider: P) -> Self {
        Self::with_config(provider, AetherConfig::default())
    }

    /// Internal: Create a raw engine without full config for script-based calls.
    pub fn new_raw(provider: Arc<P>) -> Self {
        Self {
            provider,
            validator: None,
            cache: None,
            use_toon: false,
            auto_toon_threshold: None,
            global_context: InjectionContext::default(),
            parallel: false,
            max_retries: 0,
            observer: None,
        }
    }

    /// Create a new injection engine with the given provider and config.
    pub fn with_config(provider: P, config: AetherConfig) -> Self {
        let validator: Option<Arc<dyn Validator>> = if config.healing_enabled {
            Some(Arc::new(crate::validation::MultiValidator::new()))
        } else {
            None
        };

        Self {
            provider: Arc::new(provider),
            validator,
            cache: None,
            use_toon: config.toon_enabled,
            auto_toon_threshold: config.auto_toon_threshold,
            global_context: InjectionContext::default(),
            parallel: config.parallel,
            max_retries: config.max_retries,
            observer: None,
        }
    }

    /// Set the cache for performance optimization.
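    ///
    /// # Example
    ///
    /// A minimal sketch; `MyCache` is a hypothetical type standing in for any
    /// implementation of the `Cache` trait:
    ///
    /// ```rust,ignore
    /// let engine = InjectionEngine::new(provider).with_cache(MyCache::default());
    /// ```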
    pub fn with_cache(mut self, cache: impl Cache + 'static) -> Self {
        self.cache = Some(Arc::new(cache));
        self
    }

    /// Enable or disable TOON format for context.
    pub fn with_toon(mut self, enabled: bool) -> Self {
        self.use_toon = enabled;
        self
    }

    /// Set the validator for self-healing.
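    ///
    /// # Example
    ///
    /// A sketch of enabling self-healing with the built-in `MultiValidator` and a
    /// retry budget (the path shown assumes the `validation` module is publicly reachable):
    ///
    /// ```rust,ignore
    /// use aether_core::validation::MultiValidator;
    ///
    /// let engine = InjectionEngine::new(provider)
    ///     .with_validator(MultiValidator::new())
    ///     .max_retries(2);
    /// ```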
    pub fn with_validator(mut self, validator: impl Validator + 'static) -> Self {
        self.validator = Some(Arc::new(validator));
        self
    }

    /// Set the global context.
    pub fn with_context(mut self, context: InjectionContext) -> Self {
        self.global_context = context;
        self
    }

    /// Enable or disable parallel generation.
    pub fn parallel(mut self, enabled: bool) -> Self {
        self.parallel = enabled;
        self
    }

    /// Set maximum retries for failed generations.
    pub fn max_retries(mut self, retries: u32) -> Self {
        self.max_retries = retries;
        self
    }

    /// Set an observer for tracking events.
    pub fn with_observer(mut self, observer: impl crate::observer::EngineObserver + 'static) -> Self {
        self.observer = Some(Arc::new(observer));
        self
    }

    /// Render a template with AI-generated code.
    ///
    /// This method will generate code for all slots in the template
    /// and return the final rendered content.
    #[instrument(skip(self, template), fields(template_name = %template.name))]
    pub async fn render(&self, template: &Template) -> Result<String> {
        info!("Rendering template: {}", template.name);

        let injections = self.generate_all(template, None).await?;
        template.render(&injections)
    }

    /// Render a template with additional context.
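    ///
    /// # Example
    ///
    /// A sketch of layering a per-render context on top of the global one
    /// (the builder calls mirror those used in the tests below):
    ///
    /// ```rust,ignore
    /// let context = InjectionContext::new().with_framework("react");
    /// let output = engine.render_with_context(&template, context).await?;
    /// ```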
    #[instrument(skip(self, template, context), fields(template_name = %template.name))]
    pub async fn render_with_context(
        &self,
        template: &Template,
        context: InjectionContext,
    ) -> Result<String> {
        info!("Rendering template with context: {}", template.name);

        let injections = self.generate_all(template, Some(context)).await?;
        template.render(&injections)
    }

    /// Render a template incrementally using a session.
    ///
    /// This will only generate code for slots that have changed
    /// based on their definition and the current context.
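    ///
    /// # Example
    ///
    /// A sketch of reusing a `RenderSession` across renders so unchanged slots
    /// are served from the session instead of being regenerated:
    ///
    /// ```rust,ignore
    /// let mut session = RenderSession::new();
    /// let first = engine.render_incremental(&template, &mut session).await?;
    /// // Rendering again with the same slots and context reuses the cached results.
    /// let second = engine.render_incremental(&template, &mut session).await?;
    /// assert_eq!(first, second);
    /// ```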
    #[instrument(skip(self, template, session), fields(template_name = %template.name))]
    pub async fn render_incremental(
        &self,
        template: &Template,
        session: &mut RenderSession,
    ) -> Result<String> {
        info!("Incrementally rendering template: {}", template.name);

        let context_hash = RenderSession::hash(&self.global_context);
        let mut injections = HashMap::new();

        for (name, slot) in &template.slots {
            let slot_hash = RenderSession::hash(slot);
            let key = (slot_hash, context_hash);

            if let Some(cached) = session.results.get(&key) {
                debug!("Incremental hit for slot: {}", name);
                injections.insert(name.clone(), cached.clone());
            } else {
                debug!("Incremental miss for slot: {}", name);
                let code = self.generate_slot(template, name).await?;
                session.results.insert(key, code.clone());
                injections.insert(name.clone(), code);
            }
        }

        template.render(&injections)
    }

    async fn generate_all(
        &self,
        template: &Template,
        extra_context: Option<InjectionContext>,
    ) -> Result<HashMap<String, String>> {
        let mut injections = HashMap::new();

        // Build base context first to check length
        let base_context = if let Some(ref ctx) = extra_context {
            format!("{}\n{}", self.global_context.to_prompt(), ctx.to_prompt())
        } else {
            self.global_context.to_prompt()
        };

        // Determine if TOON should be used (explicit or auto-threshold)
        let should_use_toon = self.use_toon || self.auto_toon_threshold
            .map(|threshold| base_context.len() >= threshold)
            .unwrap_or(false);

        let mut context_prompt = if should_use_toon {
            // TOON optimization - compress context
            let context_value = serde_json::to_value(&self.global_context)
                .map_err(|e| AetherError::ContextSerializationError(e.to_string()))?;
            let toon_ctx = Toon::serialize(&context_value);
            format!(
                "[CONTEXT:TOON]\n{}\n\n[TOON Protocol Note]\nTOON is a compact key:value mapping protocol. Each line represents 'key: value'. Use this context to inform your code generation, respecting the framework, language, and architectural constraints defined within.",
                toon_ctx
            )
        } else {
            base_context
        };

        // If self-healing is enabled, encourage AI to pass tests
        if self.validator.is_some() {
            context_prompt.push_str("\n\nIMPORTANT: The system is running in TDD (Test-Driven Development) mode. ");
            context_prompt.push_str("Your code will be validated against compiler checks and functional tests. ");
            context_prompt.push_str("If possible, include unit tests in your response to help self-verify. ");
            context_prompt.push_str("If validation fails, you will receive feedback to fix the code.");
        }

        let context_prompt = Arc::new(context_prompt);

        if self.parallel {
            injections = self
                .generate_parallel(template, context_prompt)
                .await?;
        } else {
            for (name, slot) in &template.slots {
                debug!("Generating code for slot: {}", name);
                let id = uuid::Uuid::new_v4().to_string();

                let request = GenerationRequest {
                    slot: slot.clone(),
                    context: Some((*context_prompt).clone()),
                    system_prompt: None,
                };

                if let Some(ref obs) = self.observer {
                    obs.on_start(&id, &template.name, name, &request);
                }

                match self.generate_with_retry(request, &id).await {
                    Ok(response) => {
                        if let Some(ref obs) = self.observer {
                            obs.on_success(&id, &response);
                        }
                        injections.insert(name.clone(), response.code);
                    }
                    Err(e) => {
                        if let Some(ref obs) = self.observer {
                            obs.on_failure(&id, &e.to_string());
                        }
                        return Err(e);
                    }
                }
            }
        }

        Ok(injections)
    }

    async fn generate_parallel(
        &self,
        template: &Template,
        context_prompt: Arc<String>,
    ) -> Result<HashMap<String, String>> {
        use tokio::task::JoinSet;

        let mut join_set = JoinSet::new();

        for (name, slot) in template.slots.clone() {
            let provider = Arc::clone(&self.provider);
            let validator = self.validator.clone();
            let cache = self.cache.clone();
            let context = Arc::clone(&context_prompt);
            let max_retries = self.max_retries;
            let template_name = template.name.clone();
            let observer = self.observer.clone();

            join_set.spawn(async move {
                let id = uuid::Uuid::new_v4().to_string();
                let request = GenerationRequest {
                    slot,
                    context: Some((*context).clone()),
                    system_prompt: None,
                };

                if let Some(ref obs) = observer {
                    obs.on_start(&id, &template_name, &name, &request);
                }

                match Self::generate_with_healing_static(&provider, &validator, &cache, &observer, request, max_retries, &id).await {
                    Ok(response) => {
                        if let Some(ref obs) = observer {
                            obs.on_success(&id, &response);
                        }
                        Ok::<_, AetherError>((name, response.code))
                    }
                    Err(e) => {
                        if let Some(ref obs) = observer {
                            obs.on_failure(&id, &e.to_string());
                        }
                        Err(e)
                    }
                }
            });
        }

        let mut injections = HashMap::new();
        while let Some(result) = join_set.join_next().await {
            let (name, code) = result.map_err(|e| AetherError::InjectionError(e.to_string()))??;
            injections.insert(name, code);
        }

        Ok(injections)
    }

    /// Generate with self-healing logic.
    async fn generate_with_retry(&self, request: GenerationRequest, id: &str) -> Result<GenerationResponse> {
        Self::generate_with_healing_static(&self.provider, &self.validator, &self.cache, &self.observer, request, self.max_retries, id).await
    }

    /// Static version of generate with self-healing support.
    async fn generate_with_healing_static(
        provider: &Arc<P>,
        validator: &Option<Arc<dyn Validator>>,
        cache: &Option<Arc<dyn Cache>>,
        observer: &Option<ObserverPtr>,
        mut request: GenerationRequest,
        max_retries: u32,
        id: &str,
    ) -> Result<GenerationResponse> {
        // 0. Check cache first
        let cache_key = if cache.is_some() {
            Some(format!("{}:{}", request.slot.prompt, request.context.as_deref().unwrap_or("")))
        } else {
            None
        };

        if let (Some(ref c), Some(ref key)) = (cache, &cache_key) {
            if let Some(cached_code) = c.get(key) {
                debug!("Cache hit for slot: {}", request.slot.name);
                return Ok(GenerationResponse {
                    code: cached_code,
                    tokens_used: None,
                    metadata: Some(serde_json::json!({"cache": "hit"})),
                });
            }
        }

        let mut last_error = None;

        for attempt in 0..=max_retries {
            // 1. Generate code
            let mut response = match provider.generate(request.clone()).await {
                Ok(r) => r,
                Err(e) => {
                    debug!("Generation attempt {} failed: {}", attempt + 1, e);
                    last_error = Some(e);
                    if attempt < max_retries {
                        tokio::time::sleep(std::time::Duration::from_millis(100 * (attempt as u64 + 1))).await;
                        continue;
                    }
                    return Err(last_error.unwrap());
                }
            };

            // 2. Validate and Heal if validator is present
            if let Some(ref val) = validator {
                // Apply formatting (Linter compliance)
                if let Ok(formatted) = val.format(&request.slot.kind, &response.code) {
                    response.code = formatted;
                }

                // Use validate_with_slot to support TDD harnesses
                match val.validate_with_slot(&request.slot, &response.code)? {
                    ValidationResult::Valid => {
                        // Success! Cache if enabled
                        if let (Some(ref c), Some(ref key)) = (cache, &cache_key) {
                            c.set(key, response.code.clone());
                        }
                        return Ok(response);
                    },
                    ValidationResult::Invalid(err_msg) => {
                        info!("Self-healing: Validation failed for slot '{}', attempt {}. Error: {}",
                            request.slot.name, attempt + 1, err_msg);

                        if let Some(ref obs) = observer {
                            obs.on_healing_step(id, attempt + 1, &err_msg);
                        }

                        last_error = Some(AetherError::ValidationFailed {
                            slot: request.slot.name.clone(),
                            error: err_msg.clone()
                        });

                        if attempt < max_retries {
                            // Feedback Loop: Add error to prompt for next attempt
                            request.slot.prompt = format!(
                                "{}\n\n[SELF-HEALING FEEDBACK]\nYour previous output had validation errors. Please fix them and output ONLY the corrected code.\nERROR:\n{}",
                                request.slot.prompt,
                                err_msg
                            );
                            continue;
                        }
                    }
                }
            } else {
                // No validator, just cache and return
                if let (Some(ref c), Some(ref key)) = (cache, &cache_key) {
                    c.set(key, response.code.clone());
                }
                return Ok(response);
            }
        }

        // All retries exhausted: surface the failure as MaxRetriesExceeded,
        // carrying the last recorded error message.
        let last_message = match last_error {
            Some(AetherError::ValidationFailed { error, .. }) => error,
            Some(e) => e.to_string(),
            None => "Healing failed without specific error".to_string(),
        };
        Err(AetherError::MaxRetriesExceeded {
            slot: request.slot.name,
            retries: max_retries,
            last_error: last_message,
        })
    }

    /// Generate code for a single slot.
    pub async fn generate_slot(&self, template: &Template, slot_name: &str) -> Result<String> {
        let slot = template
            .slots
            .get(slot_name)
            .ok_or_else(|| AetherError::SlotNotFound(slot_name.to_string()))?;

        let request = GenerationRequest {
            slot: slot.clone(),
            context: Some(self.global_context.to_prompt()),
            system_prompt: None,
        };

        let id = uuid::Uuid::new_v4().to_string();
        if let Some(ref obs) = self.observer {
            obs.on_start(&id, &template.name, slot_name, &request);
        }

        match self.generate_with_retry(request, &id).await {
            Ok(response) => {
                if let Some(ref obs) = self.observer {
                    obs.on_success(&id, &response);
                }
                Ok(response.code)
            }
            Err(e) => {
                if let Some(ref obs) = self.observer {
                    obs.on_failure(&id, &e.to_string());
                }
                Err(e)
            }
        }
    }

    /// Generate code for a single slot as a stream.
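    ///
    /// # Example
    ///
    /// A sketch of consuming the stream with `futures::StreamExt`; the exact
    /// shape of each `StreamResponse` chunk depends on the provider:
    ///
    /// ```rust,ignore
    /// use futures::StreamExt;
    ///
    /// let mut stream = engine.generate_slot_stream(&template, "content")?;
    /// while let Some(chunk) = stream.next().await {
    ///     let chunk = chunk?;
    ///     // handle the partial response chunk here
    /// }
    /// ```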
    pub fn generate_slot_stream(
        &self,
        template: &Template,
        slot_name: &str,
    ) -> Result<BoxStream<'static, Result<StreamResponse>>> {
        let slot = template
            .slots
            .get(slot_name)
            .ok_or_else(|| AetherError::SlotNotFound(slot_name.to_string()))?;

        let request = GenerationRequest {
            slot: slot.clone(),
            context: Some(self.global_context.to_prompt()),
            system_prompt: None,
        };

        Ok(self.provider.generate_stream(request))
    }

    /// Inject a raw prompt and get the code back directly.
    /// Used primarily by the script runtime.
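    ///
    /// # Example
    ///
    /// A minimal sketch:
    ///
    /// ```rust,ignore
    /// let code = engine.inject_raw("Create a debounce helper in TypeScript").await?;
    /// ```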
    pub async fn inject_raw(&self, prompt: &str) -> Result<String> {
        let template = Template::new("{{AI:gen}}")
            .with_slot("gen", prompt);

        self.render(&template).await
    }
}

/// Convenience macro for one-line AI code injection.
///
/// Expands to a future; `.await` it, or use `inject_sync!` from blocking code.
///
/// # Example
///
/// ```rust,ignore
/// let code = aether_core::inject!("Create a login form with email and password", OpenAiProvider::from_env()?).await?;
/// ```
#[macro_export]
macro_rules! inject {
    ($prompt:expr, $provider:expr) => {{
        use $crate::{InjectionEngine, Template};

        // Build everything inside the async block so the returned future owns
        // the engine and template instead of borrowing block locals.
        async move {
            let template = Template::new("{{AI:generated}}")
                .with_slot("generated", $prompt);

            let engine = InjectionEngine::new($provider);
            engine.render(&template).await
        }
    }};
}

/// Convenience macro for synchronous one-line injection (blocking).
#[macro_export]
macro_rules! inject_sync {
    ($prompt:expr, $provider:expr) => {{
        tokio::runtime::Runtime::new()
            .unwrap()
            .block_on($crate::inject!($prompt, $provider))
    }};
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::MockProvider;
    // `SlotKind` is referenced by the validator stub in `test_max_retries_exceeded`;
    // it is assumed to be re-exported at the crate root alongside `Slot` and `Template`.
    use crate::SlotKind;

    #[tokio::test]
    async fn test_engine_render() {
        let provider = MockProvider::new()
            .with_response("content", "<p>Hello World</p>");

        let engine = InjectionEngine::new(provider);

        let template = Template::new("<div>{{AI:content}}</div>")
            .with_slot("content", "Generate a paragraph");

        let result = engine.render(&template).await.unwrap();
        assert_eq!(result, "<div><p>Hello World</p></div>");
    }

    #[tokio::test]
    async fn test_engine_with_context() {
        let provider = MockProvider::new()
            .with_response("button", "<button class='btn'>Click</button>");

        let engine = InjectionEngine::new(provider)
            .with_context(InjectionContext::new().with_framework("react"));

        let template = Template::new("{{AI:button}}")
            .with_slot("button", "Create a button");

        let result = engine.render(&template).await.unwrap();
        assert!(result.contains("button"));
    }

    #[tokio::test]
    async fn test_parallel_generation() {
        let provider = MockProvider::new()
            .with_response("slot1", "code1")
            .with_response("slot2", "code2");

        let engine = InjectionEngine::new(provider).parallel(true);

        let template = Template::new("{{AI:slot1}} | {{AI:slot2}}");

        let result = engine.render(&template).await.unwrap();
        assert!(result.contains("code1"));
        assert!(result.contains("code2"));
    }

    #[tokio::test]
    async fn test_max_retries_exceeded() {
        let provider = MockProvider::new()
            .with_response("fail", "invalid code");

        // Use a validator that always fails
        struct FailingValidator;
        impl Validator for FailingValidator {
            fn validate(&self, _: &SlotKind, _: &str) -> Result<ValidationResult> {
                Ok(ValidationResult::Invalid("Always fails".to_string()))
            }
        }

        let engine = InjectionEngine::new(provider)
            .with_validator(FailingValidator)
            .max_retries(1);

        let template = Template::new("{{AI:fail}}");
        let result = engine.render(&template).await;

        match result {
            Err(AetherError::MaxRetriesExceeded { slot, retries, .. }) => {
                assert_eq!(slot, "fail");
                assert_eq!(retries, 1);
            }
            _ => panic!("Expected MaxRetriesExceeded error, got {:?}", result),
        }
    }

    #[tokio::test]
    async fn test_auto_toon_activation() {
        let provider = MockProvider::new()
            .with_response("slot", "code");

        // Set a very low threshold to force TOON
        let config = AetherConfig::default().with_auto_toon_threshold(Some(5));
        let engine = InjectionEngine::with_config(provider, config)
            .with_context(InjectionContext::new().with_framework("very_long_framework_name"));

        let template = Template::new("{{AI:slot}}");
        let _ = engine.render(&template).await.unwrap();

        // Internal check: toon should be used because context length > 5
        // Since we can't easily check internal state, we verify it runs without error
    }
}