// aether_core/engine.rs

1//! Injection Engine - The main orchestrator for AI code injection.
2//!
3//! This module provides the high-level API for rendering templates with AI-generated code.
4
5use crate::{
6    AetherError, AiProvider, InjectionContext, Result, Template, SlotKind,
7    provider::{GenerationRequest, GenerationResponse},
8    config::AetherConfig,
9};
10use std::collections::HashMap;
11use std::sync::Arc;
12use tracing::{debug, info, instrument};
13use futures::stream::BoxStream;
14use crate::provider::StreamResponse;
15use crate::validation::{Validator, ValidationResult};
16use crate::cache::Cache;
17use crate::toon::Toon;
18pub use crate::observer::ObserverPtr;
19use std::hash::{Hash, Hasher};
20
21// ============================================================
22// Internal Types
23// ============================================================
24
/// A 64-bit FNV-1a hasher with a fixed seed, so hash values survive process
/// restarts (unlike `std`'s randomly-seeded `RandomState`).
///
/// Used for `RenderSession` and provider-cache keys, where a restart must not
/// invalidate previously computed hashes.
///
/// NOTE(review): integer hashing goes through `Hasher`'s default `write_uXX`
/// methods, which feed native-endian bytes — stable per platform, but not
/// guaranteed identical across architectures of different endianness.
struct StableHasher(u64);

impl StableHasher {
    /// Start from the standard FNV-1a 64-bit offset basis.
    fn new() -> Self {
        StableHasher(0xcbf2_9ce4_8422_2325)
    }

    /// Hash `t` once with a fresh hasher and return the final digest.
    fn hash<T: Hash>(t: &T) -> u64 {
        let mut hasher = Self::new();
        t.hash(&mut hasher);
        hasher.finish()
    }
}

impl Hasher for StableHasher {
    fn finish(&self) -> u64 {
        self.0
    }

    fn write(&mut self, bytes: &[u8]) {
        // Classic FNV-1a step per byte: xor in, then multiply by the prime.
        self.0 = bytes.iter().fold(self.0, |acc, &b| {
            (acc ^ u64::from(b)).wrapping_mul(0x0000_0100_0000_01b3)
        });
    }
}
52
/// Context passed to a generation worker.
///
/// Bundles everything a spawned generation task needs (provider plus the
/// optional validator/cache/observer and the config) so the task can run
/// without borrowing the engine. `P` may be unsized (e.g. `dyn AiProvider`).
struct WorkerContext<P: AiProvider + ?Sized + 'static> {
    /// AI provider used to generate code.
    provider: Arc<P>,
    /// Validator for self-healing; `None` disables validation.
    validator: Option<Arc<dyn Validator>>,
    /// Response cache; `None` disables caching.
    cache: Option<Arc<dyn Cache>>,
    /// Observer notified of generation lifecycle events.
    observer: Option<ObserverPtr>,
    /// Engine configuration (retries, prompts, TOON settings, ...).
    config: AetherConfig,
}
61
// Manual `Clone` impl: `#[derive(Clone)]` would add a `P: Clone` bound even
// though the provider sits behind an `Arc`, which is always cheaply clonable.
impl<P: AiProvider + ?Sized + 'static> Clone for WorkerContext<P> {
    fn clone(&self) -> Self {
        Self {
            provider: Arc::clone(&self.provider),
            validator: self.validator.clone(),
            cache: self.cache.clone(),
            observer: self.observer.clone(),
            config: self.config.clone(),
        }
    }
}
73
/// The main engine for AI code injection.
///
/// Orchestrates template rendering: builds the context prompt, asks the
/// provider to fill each slot, optionally validates/heals the result, and
/// caches successful generations. `P` may be unsized (e.g. `dyn AiProvider`).
///
/// # Example
///
/// ```rust,ignore
/// use aether_core::{InjectionEngine, Template, AetherConfig};
/// use aether_ai::OpenAiProvider;
///
/// let provider = OpenAiProvider::from_env()?;
///
/// // Using config
/// let config = AetherConfig::from_env();
/// let engine = InjectionEngine::with_config(provider, config);
///
/// // Or simple
/// let engine = InjectionEngine::new(provider);
/// ```
pub struct InjectionEngine<P: AiProvider + ?Sized> {
    /// The AI provider for code generation.
    provider: Arc<P>,

    /// Optional validator for self-healing; installed automatically by
    /// `with_config*` when `config.healing_enabled` is set.
    validator: Option<Arc<dyn Validator>>,

    /// Optional cache for performance/cost optimization.
    cache: Option<Arc<dyn Cache>>,

    /// Engine configuration (retries, prompts, TOON, parallelism).
    config: AetherConfig,

    /// Global context applied to all generations.
    global_context: InjectionContext,

    /// Optional observer for tracking generation lifecycle events.
    observer: Option<ObserverPtr>,
}
110
/// A session for tracking incremental rendering state.
/// Holds fingerprints of slots and context to identify changes.
///
/// Used by `render_incremental`: a slot is regenerated only when its
/// (slot hash, context hash) key is absent from `results`.
#[derive(Debug, Clone, Default)]
pub struct RenderSession {
    /// Cached results indexed by (SlotHash, ContextHash).
    // Keys come from `RenderSession::hash` (FNV-1a), so they are stable
    // across process restarts and can be persisted alongside the session.
    pub results: HashMap<(u64, u64), String>,
}
118
119impl RenderSession {
120    /// Create a new empty render session.
121    pub fn new() -> Self {
122        Self::default()
123    }
124
125    /// Calculate a stable hash for any hashable object.
126    pub fn hash<T: Hash>(t: &T) -> u64 {
127        StableHasher::hash(t)
128    }
129}
130
131impl<P: AiProvider + ?Sized + 'static> InjectionEngine<P> {
132    /// Create a new injection engine with the given provider and default config.
133    pub fn new(provider: P) -> Self where P: Sized {
134        Self::with_config(provider, AetherConfig::default())
135    }
136
137    /// Internal: Create a raw engine without full config for script-based calls.
138    pub fn new_raw(provider: Arc<P>) -> Self {
139        Self {
140            provider,
141            validator: None,
142            cache: None,
143            config: AetherConfig::default(),
144            global_context: InjectionContext::default(),
145            observer: None,
146        }
147    }
148
149    /// Create a new injection engine with the given provider and config.
150    pub fn with_config(provider: P, config: AetherConfig) -> Self where P: Sized {
151        Self::with_config_arc(Arc::new(provider), config)
152    }
153
154    /// Create a new injection engine with the given provider (Arc) and config.
155    pub fn with_config_arc(provider: Arc<P>, config: AetherConfig) -> Self {
156        let validator: Option<Arc<dyn Validator>> = if config.healing_enabled {
157            Some(Arc::new(crate::validation::MultiValidator::new()))
158        } else {
159            None
160        };
161
162        Self {
163            provider,
164            validator,
165            cache: None,
166            config,
167            global_context: InjectionContext::default(),
168            observer: None,
169        }
170    }
171
172    /// Set the cache for performance optimization.
173    pub fn with_cache(mut self, cache: impl Cache + 'static) -> Self {
174        self.cache = Some(Arc::new(cache));
175        self
176    }
177
178    /// Enable or disable TOON format for context.
179    pub fn with_toon(mut self, enabled: bool) -> Self {
180        self.config.toon_enabled = enabled;
181        self
182    }
183
184    /// Set the validator for self-healing.
185    pub fn with_validator(mut self, validator: impl Validator + 'static) -> Self {
186        self.validator = Some(Arc::new(validator));
187        self
188    }
189
190    /// Set the global context.
191    pub fn with_context(mut self, context: InjectionContext) -> Self {
192        self.global_context = context;
193        self
194    }
195
196    /// Enable or disable parallel generation.
197    pub fn parallel(mut self, enabled: bool) -> Self {
198        self.config.parallel = enabled;
199        self
200    }
201
202    /// Set maximum retries for failed generations.
203    pub fn max_retries(mut self, retries: u32) -> Self {
204        self.config.max_retries = retries;
205        self
206    }
207
208    /// Get the current cache.
209    pub fn cache(&self) -> Option<Arc<dyn Cache>> {
210        self.cache.clone()
211    }
212
213    /// Set an observer for tracking events.
214    pub fn with_observer(mut self, observer: impl crate::observer::EngineObserver + 'static) -> Self {
215        self.observer = Some(Arc::new(observer));
216        self
217    }
218
219    /// Render a template with AI-generated code.
220    ///
221    /// This method will generate code for all slots in the template
222    /// and return the final rendered content.
223    #[instrument(skip(self, template), fields(template_name = %template.name))]
224    pub async fn render(&self, template: &Template) -> Result<String> {
225        info!("Rendering template: {}", template.name);
226
227        let injections = self.generate_all(template, None).await?;
228        template.render(&injections)
229    }
230
231    /// Render a template with additional context.
232    #[instrument(skip(self, template, context), fields(template_name = %template.name))]
233    pub async fn render_with_context(
234        &self,
235        template: &Template,
236        context: InjectionContext,
237    ) -> Result<String> {
238        info!("Rendering template with context: {}", template.name);
239
240        let injections = self.generate_all(template, Some(context)).await?;
241        template.render(&injections)
242    }
243
244    /// Render a template incrementally using a session.
245    /// 
246    /// This will only generate code for slots that have changed 
247    /// based on their definition and the current context.
248    #[instrument(skip(self, template, session), fields(template_name = %template.name))]
249    pub async fn render_incremental(
250        &self,
251        template: &Template,
252        session: &mut RenderSession,
253    ) -> Result<String> {
254        info!("Incrementally rendering template: {}", template.name);
255        
256        let context_hash = RenderSession::hash(&self.global_context);
257        let mut injections = HashMap::new();
258        
259        for (name, slot) in &template.slots {
260            let slot_hash = RenderSession::hash(slot);
261            let key = (slot_hash, context_hash);
262            
263            if let Some(cached) = session.results.get(&key) {
264                debug!("Incremental hit for slot: {}", name);
265                injections.insert(name.clone(), cached.clone());
266            } else {
267                debug!("Incremental miss for slot: {}", name);
268                let code = self.generate_slot(template, name).await?;
269                session.results.insert(key, code.clone());
270                injections.insert(name.clone(), code);
271            }
272        }
273        
274        template.render(&injections)
275    }
276
277    async fn generate_all(
278        &self,
279        template: &Template,
280        extra_context: Option<InjectionContext>,
281    ) -> Result<HashMap<String, String>> {
282        let mut injections = HashMap::new();
283
284        // Build base context first to check length
285        let base_context = if let Some(ref ctx) = extra_context {
286            format!("{}\n{}", self.global_context.to_prompt(), ctx.to_prompt())
287        } else {
288            self.global_context.to_prompt()
289        };
290
291        // Determine if TOON should be used (explicit or auto-threshold)
292        let should_use_toon = self.config.toon_enabled || self.config.auto_toon_threshold
293            .map(|threshold| base_context.len() >= threshold)
294            .unwrap_or(false);
295
296        let mut context_prompt = if should_use_toon {
297            // TOON optimization - compress context
298            let context_value = serde_json::to_value(&self.global_context)
299                .map_err(|e| AetherError::ContextSerializationError(e.to_string()))?;
300            let toon_ctx = Toon::serialize(&context_value);
301            
302            if let Some(ref obs) = self.observer {
303                let original_size = base_context.len();
304                let compressed_size = toon_ctx.len();
305                let saved = if original_size > compressed_size { original_size - compressed_size } else { 0 };
306                
307                obs.on_metadata("global", "toon_compression_metrics", serde_json::json!({
308                    "original_chars": original_size,
309                    "compressed_chars": compressed_size,
310                    "saved_chars": saved,
311                    "ratio": (compressed_size as f64 / original_size.max(1) as f64)
312                }));
313            }
314
315            format!(
316                "{}\n{}\n\n{}",
317                self.config.prompt_toon_header,
318                toon_ctx,
319                self.config.prompt_toon_note
320            )
321        } else {
322            base_context
323        };
324
325        // If self-healing is enabled, encourage AI to pass tests
326        if self.validator.is_some() {
327            context_prompt.push_str(&self.config.prompt_tdd_notice);
328        }
329        
330        let context_prompt = Arc::new(context_prompt);
331
332        if self.config.parallel {
333            injections = self
334                .generate_parallel(template, context_prompt)
335                .await?;
336        } else {
337            for (name, slot) in &template.slots {
338                debug!("Generating code for slot: {}", name);
339                let id = uuid::Uuid::new_v4().to_string();
340
341                let request = GenerationRequest {
342                    max_tokens: slot.max_tokens,
343                    model: slot.model.clone(),
344                    slot: slot.clone(),
345                    context: Some((*context_prompt).clone()),
346                    system_prompt: None,
347                };
348
349                if let Some(ref obs) = self.observer {
350                    obs.on_start(&id, &template.name, name, &request);
351                }
352
353                match self.generate_with_retry(request, &id).await {
354                    Ok(response) => {
355                        if let Some(ref obs) = self.observer {
356                            obs.on_success(&id, &response);
357                        }
358                        injections.insert(name.clone(), response.code);
359                    }
360                    Err(e) => {
361                        if let Some(ref obs) = self.observer {
362                            obs.on_failure(&id, &e.to_string());
363                        }
364                        return Err(e);
365                    }
366                }
367            }
368        }
369
370        Ok(injections)
371    }
372
373    async fn generate_parallel(
374        &self,
375        template: &Template,
376        context_prompt: Arc<String>,
377    ) -> Result<HashMap<String, String>> {
378        use tokio::task::JoinSet;
379
380        let mut join_set = JoinSet::new();
381
382        for (name, slot) in template.slots.clone() {
383            let context = Arc::clone(&context_prompt);
384            let worker_ctx = WorkerContext {
385                provider: Arc::clone(&self.provider),
386                validator: self.validator.clone(),
387                cache: self.cache.clone(),
388                observer: self.observer.clone(),
389                config: self.config.clone(),
390            };
391            let template_name = template.name.clone();
392
393            join_set.spawn(async move {
394                let id = uuid::Uuid::new_v4().to_string();
395                let request = GenerationRequest {
396                    max_tokens: slot.max_tokens,
397                    model: slot.model.clone(),
398                    slot,
399                    context: Some((*context).clone()),
400                    system_prompt: None,
401                };
402
403                if let Some(ref obs) = worker_ctx.observer {
404                    obs.on_start(&id, &template_name, &name, &request);
405                }
406
407                match Self::generate_with_healing_static(worker_ctx.clone(), request, &id).await {
408                    Ok(response) => {
409                        if let Some(ref obs) = worker_ctx.observer {
410                            obs.on_success(&id, &response);
411                        }
412                        Ok::<_, AetherError>((name, response.code))
413                    }
414                    Err(e) => {
415                        if let Some(ref obs) = worker_ctx.observer {
416                            obs.on_failure(&id, &e.to_string());
417                        }
418                        Err(e)
419                    }
420                }
421            });
422        }
423
424        let mut injections = HashMap::new();
425        while let Some(result) = join_set.join_next().await {
426            let (name, code) = result.map_err(|e| AetherError::InjectionError(e.to_string()))??;
427            injections.insert(name, code);
428        }
429
430        Ok(injections)
431    }
432
433    /// Generate with self-healing logic.
434    async fn generate_with_retry(&self, request: GenerationRequest, id: &str) -> Result<GenerationResponse> {
435        let worker_ctx = WorkerContext {
436            provider: Arc::clone(&self.provider),
437            validator: self.validator.clone(),
438            cache: self.cache.clone(),
439            observer: self.observer.clone(),
440            config: self.config.clone(),
441        };
442        Self::generate_with_healing_static(worker_ctx, request, id).await
443    }
444
    /// Static version of generate with self-healing support.
    ///
    /// Flow:
    /// 1. If a cache is configured, look the request up and return any hit.
    /// 2. Otherwise loop `config.max_retries + 1` times: generate, abort if
    ///    the AI repeats itself verbatim, format, validate, and on validation
    ///    failure append the error to the prompt for the next attempt.
    /// 3. Accepted code (validated, or unvalidated when no validator is set)
    ///    is written back to the cache before returning.
    async fn generate_with_healing_static(
        ctx: WorkerContext<P>,
        mut request: GenerationRequest,
        id: &str,
    ) -> Result<GenerationResponse> {
        // 0. Check cache first.
        // The key covers prompt, context, model, and max_tokens.
        // NOTE(review): slot kind and system_prompt are NOT hashed into the
        // key — confirm they can never differ for otherwise-equal requests.
        let cache_key = if ctx.cache.is_some() {
            // Use stable hash for cache key to optimize memory and maintain consistency
            let mut s = StableHasher::new();
            request.slot.prompt.hash(&mut s);
            request.context.as_deref().unwrap_or("").hash(&mut s);
            request.model.as_deref().unwrap_or("").hash(&mut s);
            request.max_tokens.unwrap_or(0).hash(&mut s);
            Some(format!("aether:cache:{:x}", s.finish()))
        } else {
            None
        };

        if let (Some(ref c), Some(ref key)) = (ctx.cache.as_ref(), &cache_key) {
            if let Some(cached_code) = c.get(key) {
                debug!("Cache hit for slot: {}", request.slot.name);
                // Cached responses carry no token usage; metadata marks the hit.
                return Ok(GenerationResponse {
                    code: cached_code,
                    tokens_used: None,
                    metadata: Some(serde_json::json!({"cache": "hit"})),
                });
            }
        }

        let mut last_error = None;
        // Raw code from the previous attempt, used for loop detection below.
        let mut previous_code: Option<String> = None;

        for attempt in 0..=ctx.config.max_retries {
            // 1. Generate code
            let mut response = match ctx.provider.generate(request.clone()).await {
                Ok(r) => r,
                Err(e) => {
                    debug!("Generation attempt {} failed: {}", attempt + 1, e);
                    last_error = Some(e);
                    if attempt < ctx.config.max_retries {
                        // Linear backoff: retry_backoff_ms scaled by the attempt number.
                        tokio::time::sleep(std::time::Duration::from_millis(ctx.config.retry_backoff_ms * (attempt as u64 + 1))).await;
                        continue;
                    }
                    return Err(last_error.unwrap());
                }
            };

            // Detect infinite loops (AI generating exact same failing code)
            if let Some(prev) = &previous_code {
                if prev == &response.code {
                    debug!("Self-healing: AI generated identical code for slot '{}', aborting.", request.slot.name);
                    return Err(AetherError::MaxRetriesExceeded {
                        slot: request.slot.name.clone(),
                        retries: attempt,
                        last_error: "AI stuck in loop (generated identical code)".to_string()
                    });
                }
            }
            // Stored pre-formatting, so comparisons are raw-vs-raw across attempts.
            previous_code = Some(response.code.clone());

            // 2. Validate and Heal if validator is present
            if let Some(ref val) = ctx.validator {
                // Apply formatting (Linter compliance); formatting errors are ignored.
                if let Ok(formatted) = val.format(&request.slot.kind, &response.code) {
                    response.code = formatted;
                }
                
                // Use validate_with_slot to support TDD harnesses.
                // NOTE(review): the `?` means a validator-internal error aborts
                // healing immediately rather than counting as a failed attempt.
                match val.validate_with_slot(&request.slot, &response.code)? {
                    ValidationResult::Valid => {
                        // Success! Cache if enabled
                        if let (Some(ref c), Some(ref key)) = (ctx.cache.as_ref(), &cache_key) {
                            c.set(key, response.code.clone());
                        }
                        return Ok(response);
                    },
                    ValidationResult::Invalid(err_msg) => {
                        info!("Self-healing: Validation failed for slot '{}', attempt {}. Error: {}", 
                            request.slot.name, attempt + 1, err_msg);
                        
                        if let Some(ref obs) = ctx.observer {
                            obs.on_healing_step(id, attempt + 1, &err_msg);
                        }

                        last_error = Some(AetherError::ValidationFailed { 
                            slot: request.slot.name.clone(), 
                            error: err_msg.clone() 
                        });

                        if attempt < ctx.config.max_retries {
                            // Feedback Loop: Add error to prompt for next attempt
                            request.slot.prompt = format!(
                                "{}\n\n{}{}",
                                request.slot.prompt,
                                ctx.config.prompt_healing_feedback,
                                err_msg
                            );
                            continue;
                        }
                    }
                }
            } else {
                // No validator, just cache and return
                if let (Some(ref c), Some(ref key)) = (ctx.cache.as_ref(), &cache_key) {
                    c.set(key, response.code.clone());
                }
                return Ok(response);
            }
        }

        // All attempts exhausted: surface the last provider/validation error.
        let final_err = AetherError::MaxRetriesExceeded { 
            slot: request.slot.name, 
            retries: ctx.config.max_retries, 
            last_error: last_error.map(|e| e.to_string()).unwrap_or_else(|| "Healing failed without specific error".to_string())
        };
        Err(final_err)
    }
563
564    /// Generate code for a single slot.
565    pub async fn generate_slot(&self, template: &Template, slot_name: &str) -> Result<String> {
566        let slot = template
567            .slots
568            .get(slot_name)
569            .ok_or_else(|| AetherError::SlotNotFound(slot_name.to_string()))?;
570
571        let request = GenerationRequest {
572            max_tokens: slot.max_tokens,
573            model: slot.model.clone(),
574            slot: slot.clone(),
575            context: Some(self.global_context.to_prompt()),
576            system_prompt: None,
577        };
578
579        let id = uuid::Uuid::new_v4().to_string();
580        if let Some(ref obs) = self.observer {
581            obs.on_start(&id, &template.name, slot_name, &request);
582        }
583
584        match self.generate_with_retry(request, &id).await {
585            Ok(response) => {
586                if let Some(ref obs) = self.observer {
587                    obs.on_success(&id, &response);
588                }
589                Ok(response.code)
590            }
591            Err(e) => {
592                if let Some(ref obs) = self.observer {
593                    obs.on_failure(&id, &e.to_string());
594                }
595                Err(e)
596            }
597        }
598    }
599
600    /// Generate code for a single slot as a stream.
601    pub fn generate_slot_stream(
602        &self,
603        template: &Template,
604        slot_name: &str,
605    ) -> Result<BoxStream<'static, Result<StreamResponse>>> {
606        let slot = template
607            .slots
608            .get(slot_name)
609            .ok_or_else(|| AetherError::SlotNotFound(slot_name.to_string()))?;
610
611        let request = GenerationRequest {
612            max_tokens: slot.max_tokens,
613            model: slot.model.clone(),
614            slot: slot.clone(),
615            context: Some(self.global_context.to_prompt()),
616            system_prompt: None,
617        };
618
619        Ok(self.provider.generate_stream(request))
620    }
621
622    /// Inject a raw prompt and get the code back directly.
623    /// Used primarily by the script runtime.
624    pub async fn inject_raw(&self, prompt: &str) -> Result<String> {
625        let template = Template::new("{{AI:gen}}")
626            .with_slot("gen", prompt);
627        
628        self.render(&template).await
629    }
630}
631
/// Convenience macro for one-line AI code injection.
///
/// Expands to an un-awaited future (`engine.render(...)`); `.await` it or
/// block on it (see [`inject_sync!`]).
///
/// # Example
///
/// ```rust,ignore
/// let code = aether_core::inject!("Create a login form with email and password", OpenAiProvider::from_env()?);
/// ```
#[macro_export]
macro_rules! inject {
    ($prompt:expr, $provider:expr) => {{
        // `Slot` was previously imported here but never named in the
        // expansion, producing an unused-import warning at every call site.
        use $crate::{InjectionEngine, Template};

        let template = Template::new("{{AI:generated}}")
            .with_slot("generated", $prompt);

        let engine = InjectionEngine::new($provider);
        engine.render(&template)
    }};
}
651
/// Convenience macro for synchronous one-line injection (blocking).
///
/// Builds a fresh Tokio runtime per invocation and blocks on [`inject!`].
///
/// # Panics
///
/// Panics if the Tokio runtime cannot be created, or if invoked from within
/// an existing async runtime (`block_on` inside an async context panics).
#[macro_export]
macro_rules! inject_sync {
    ($prompt:expr, $provider:expr) => {{
        // NOTE: relies on `tokio` being a dependency of the calling crate.
        tokio::runtime::Runtime::new()
            .expect("inject_sync!: failed to build Tokio runtime")
            .block_on($crate::inject!($prompt, $provider))
    }};
}
661
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::MockProvider;

    // Happy path, sequential: one slot is generated and spliced into the template.
    #[tokio::test]
    async fn test_engine_render() {
        let provider = MockProvider::new()
            .with_response("content", "<p>Hello World</p>");

        let engine = InjectionEngine::new(provider);

        let template = Template::new("<div>{{AI:content}}</div>")
            .with_slot("content", "Generate a paragraph");

        let result = engine.render(&template).await.unwrap();
        assert_eq!(result, "<div><p>Hello World</p></div>");
    }

    // Rendering still succeeds when a global context is attached to the engine.
    #[tokio::test]
    async fn test_engine_with_context() {
        let provider = MockProvider::new()
            .with_response("button", "<button class='btn'>Click</button>");

        let engine = InjectionEngine::new(provider)
            .with_context(InjectionContext::new().with_framework("react"));

        let template = Template::new("{{AI:button}}")
            .with_slot("button", "Create a button");

        let result = engine.render(&template).await.unwrap();
        assert!(result.contains("button"));
    }

    // Both slots are filled when the parallel generation path is enabled.
    #[tokio::test]
    async fn test_parallel_generation() {
        let provider = MockProvider::new()
            .with_response("slot1", "code1")
            .with_response("slot2", "code2");

        let engine = InjectionEngine::new(provider).parallel(true);

        let template = Template::new("{{AI:slot1}} | {{AI:slot2}}");

        let result = engine.render(&template).await.unwrap();
        assert!(result.contains("code1"));
        assert!(result.contains("code2"));
    }

    // A validator that always rejects must surface MaxRetriesExceeded with the
    // configured retry count once the budget is exhausted.
    #[tokio::test]
    async fn test_max_retries_exceeded() {
        let provider = MockProvider::new()
            .with_response("fail", "invalid code");

        // Use a validator that always fails
        struct FailingValidator;
        impl Validator for FailingValidator {
            fn validate(&self, _: &SlotKind, _: &str) -> Result<ValidationResult> {
                Ok(ValidationResult::Invalid("Always fails".to_string()))
            }
            fn format(&self, _: &SlotKind, code: &str) -> Result<String> {
                Ok(code.to_string())
            }
        }

        let engine = InjectionEngine::new(provider)
            .with_validator(FailingValidator)
            .max_retries(1);

        let template = Template::new("{{AI:fail}}");
        let result = engine.render(&template).await;

        match result {
            Err(AetherError::MaxRetriesExceeded { slot, retries, .. }) => {
                assert_eq!(slot, "fail");
                assert_eq!(retries, 1);
            }
            _ => panic!("Expected MaxRetriesExceeded error, got {:?}", result),
        }
    }

    // The auto-TOON threshold should trigger compression without breaking rendering.
    #[tokio::test]
    async fn test_auto_toon_activation() {
        let provider = MockProvider::new()
            .with_response("slot", "code");

        // Set a very low threshold to force TOON
        let config = AetherConfig::default().with_auto_toon_threshold(Some(5));
        let engine = InjectionEngine::with_config(provider, config)
            .with_context(InjectionContext::new().with_framework("very_long_framework_name"));

        let template = Template::new("{{AI:slot}}");
        let _ = engine.render(&template).await.unwrap();
        
        // Internal check: toon should be used because context length > 5
        // Since we can't easily check internal state, we verify it runs without error
    }
}