1use crate::{
6 AetherError, AiProvider, InjectionContext, Result, Template, SlotKind,
7 provider::{GenerationRequest, GenerationResponse},
8 config::AetherConfig,
9};
10use std::collections::HashMap;
11use std::sync::Arc;
12use tracing::{debug, info, instrument};
13use futures::stream::BoxStream;
14use crate::provider::StreamResponse;
15use crate::validation::{Validator, ValidationResult};
16use crate::cache::Cache;
17use crate::toon::Toon;
18pub use crate::observer::ObserverPtr;
19use std::hash::{Hash, Hasher};
20
/// Seedless 64-bit FNV-1a hasher.
///
/// Unlike `std`'s `RandomState`-based hashers it carries no random seed, so the
/// same sequence of bytes always produces the same value. Note that the default
/// `Hasher::write_u*`/`write_usize` methods feed integers in native byte order,
/// so hashes of integer-containing values can differ across endianness or
/// pointer width.
struct StableHasher(u64);

impl StableHasher {
    /// FNV-1a 64-bit offset basis (14695981039346656037).
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    /// FNV-1a 64-bit prime (1099511628211).
    const PRIME: u64 = 0x0000_0100_0000_01b3;

    /// Returns a hasher primed with the FNV offset basis.
    fn new() -> Self {
        StableHasher(Self::OFFSET_BASIS)
    }

    /// Hashes `t` from a fresh state and returns the 64-bit digest.
    fn hash<T: Hash>(t: &T) -> u64 {
        let mut hasher = Self::new();
        t.hash(&mut hasher);
        hasher.finish()
    }
}

impl Hasher for StableHasher {
    /// Returns the current FNV state as the digest.
    fn finish(&self) -> u64 {
        self.0
    }

    /// Folds each byte into the state: XOR the byte in, then multiply by the prime.
    fn write(&mut self, bytes: &[u8]) {
        self.0 = bytes
            .iter()
            .fold(self.0, |state, &byte| (state ^ u64::from(byte)).wrapping_mul(Self::PRIME));
    }
}
52
53struct WorkerContext<P: AiProvider + ?Sized + 'static> {
55 provider: Arc<P>,
56 validator: Option<Arc<dyn Validator>>,
57 cache: Option<Arc<dyn Cache>>,
58 observer: Option<ObserverPtr>,
59 config: AetherConfig,
60}
61
62impl<P: AiProvider + ?Sized + 'static> Clone for WorkerContext<P> {
63 fn clone(&self) -> Self {
64 Self {
65 provider: Arc::clone(&self.provider),
66 validator: self.validator.clone(),
67 cache: self.cache.clone(),
68 observer: self.observer.clone(),
69 config: self.config.clone(),
70 }
71 }
72}
73
74pub struct InjectionEngine<P: AiProvider + ?Sized> {
92 provider: Arc<P>,
94
95 validator: Option<Arc<dyn Validator>>,
97
98 cache: Option<Arc<dyn Cache>>,
100
101 config: AetherConfig,
103
104 global_context: InjectionContext,
106
107 observer: Option<ObserverPtr>,
109}
110
111#[derive(Debug, Clone, Default)]
114pub struct RenderSession {
115 pub results: HashMap<(u64, u64), String>,
117}
118
119impl RenderSession {
120 pub fn new() -> Self {
122 Self::default()
123 }
124
125 pub fn hash<T: Hash>(t: &T) -> u64 {
127 StableHasher::hash(t)
128 }
129}
130
131impl<P: AiProvider + ?Sized + 'static> InjectionEngine<P> {
132 pub fn new(provider: P) -> Self where P: Sized {
134 Self::with_config(provider, AetherConfig::default())
135 }
136
137 pub fn new_raw(provider: Arc<P>) -> Self {
139 Self {
140 provider,
141 validator: None,
142 cache: None,
143 config: AetherConfig::default(),
144 global_context: InjectionContext::default(),
145 observer: None,
146 }
147 }
148
149 pub fn with_config(provider: P, config: AetherConfig) -> Self where P: Sized {
151 Self::with_config_arc(Arc::new(provider), config)
152 }
153
154 pub fn with_config_arc(provider: Arc<P>, config: AetherConfig) -> Self {
156 let validator: Option<Arc<dyn Validator>> = if config.healing_enabled {
157 Some(Arc::new(crate::validation::MultiValidator::new()))
158 } else {
159 None
160 };
161
162 Self {
163 provider,
164 validator,
165 cache: None,
166 config,
167 global_context: InjectionContext::default(),
168 observer: None,
169 }
170 }
171
172 pub fn with_cache(mut self, cache: impl Cache + 'static) -> Self {
174 self.cache = Some(Arc::new(cache));
175 self
176 }
177
178 pub fn with_toon(mut self, enabled: bool) -> Self {
180 self.config.toon_enabled = enabled;
181 self
182 }
183
184 pub fn with_validator(mut self, validator: impl Validator + 'static) -> Self {
186 self.validator = Some(Arc::new(validator));
187 self
188 }
189
190 pub fn with_context(mut self, context: InjectionContext) -> Self {
192 self.global_context = context;
193 self
194 }
195
196 pub fn parallel(mut self, enabled: bool) -> Self {
198 self.config.parallel = enabled;
199 self
200 }
201
202 pub fn max_retries(mut self, retries: u32) -> Self {
204 self.config.max_retries = retries;
205 self
206 }
207
208 pub fn cache(&self) -> Option<Arc<dyn Cache>> {
210 self.cache.clone()
211 }
212
213 pub fn with_observer(mut self, observer: impl crate::observer::EngineObserver + 'static) -> Self {
215 self.observer = Some(Arc::new(observer));
216 self
217 }
218
219 #[instrument(skip(self, template), fields(template_name = %template.name))]
224 pub async fn render(&self, template: &Template) -> Result<String> {
225 info!("Rendering template: {}", template.name);
226
227 let injections = self.generate_all(template, None).await?;
228 template.render(&injections)
229 }
230
231 #[instrument(skip(self, template, context), fields(template_name = %template.name))]
233 pub async fn render_with_context(
234 &self,
235 template: &Template,
236 context: InjectionContext,
237 ) -> Result<String> {
238 info!("Rendering template with context: {}", template.name);
239
240 let injections = self.generate_all(template, Some(context)).await?;
241 template.render(&injections)
242 }
243
244 #[instrument(skip(self, template, session), fields(template_name = %template.name))]
249 pub async fn render_incremental(
250 &self,
251 template: &Template,
252 session: &mut RenderSession,
253 ) -> Result<String> {
254 info!("Incrementally rendering template: {}", template.name);
255
256 let context_hash = RenderSession::hash(&self.global_context);
257 let mut injections = HashMap::new();
258
259 for (name, slot) in &template.slots {
260 let slot_hash = RenderSession::hash(slot);
261 let key = (slot_hash, context_hash);
262
263 if let Some(cached) = session.results.get(&key) {
264 debug!("Incremental hit for slot: {}", name);
265 injections.insert(name.clone(), cached.clone());
266 } else {
267 debug!("Incremental miss for slot: {}", name);
268 let code = self.generate_slot(template, name).await?;
269 session.results.insert(key, code.clone());
270 injections.insert(name.clone(), code);
271 }
272 }
273
274 template.render(&injections)
275 }
276
277 async fn generate_all(
278 &self,
279 template: &Template,
280 extra_context: Option<InjectionContext>,
281 ) -> Result<HashMap<String, String>> {
282 let mut injections = HashMap::new();
283
284 let base_context = if let Some(ref ctx) = extra_context {
286 format!("{}\n{}", self.global_context.to_prompt(), ctx.to_prompt())
287 } else {
288 self.global_context.to_prompt()
289 };
290
291 let should_use_toon = self.config.toon_enabled || self.config.auto_toon_threshold
293 .map(|threshold| base_context.len() >= threshold)
294 .unwrap_or(false);
295
296 let mut context_prompt = if should_use_toon {
297 let context_value = serde_json::to_value(&self.global_context)
299 .map_err(|e| AetherError::ContextSerializationError(e.to_string()))?;
300 let toon_ctx = Toon::serialize(&context_value);
301
302 if let Some(ref obs) = self.observer {
303 let original_size = base_context.len();
304 let compressed_size = toon_ctx.len();
305 let saved = if original_size > compressed_size { original_size - compressed_size } else { 0 };
306
307 obs.on_metadata("global", "toon_compression_metrics", serde_json::json!({
308 "original_chars": original_size,
309 "compressed_chars": compressed_size,
310 "saved_chars": saved,
311 "ratio": (compressed_size as f64 / original_size.max(1) as f64)
312 }));
313 }
314
315 format!(
316 "{}\n{}\n\n{}",
317 self.config.prompt_toon_header,
318 toon_ctx,
319 self.config.prompt_toon_note
320 )
321 } else {
322 base_context
323 };
324
325 if self.validator.is_some() {
327 context_prompt.push_str(&self.config.prompt_tdd_notice);
328 }
329
330 let context_prompt = Arc::new(context_prompt);
331
332 if self.config.parallel {
333 injections = self
334 .generate_parallel(template, context_prompt)
335 .await?;
336 } else {
337 for (name, slot) in &template.slots {
338 debug!("Generating code for slot: {}", name);
339 let id = uuid::Uuid::new_v4().to_string();
340
341 let request = GenerationRequest {
342 max_tokens: slot.max_tokens,
343 model: slot.model.clone(),
344 slot: slot.clone(),
345 context: Some((*context_prompt).clone()),
346 system_prompt: None,
347 };
348
349 if let Some(ref obs) = self.observer {
350 obs.on_start(&id, &template.name, name, &request);
351 }
352
353 match self.generate_with_retry(request, &id).await {
354 Ok(response) => {
355 if let Some(ref obs) = self.observer {
356 obs.on_success(&id, &response);
357 }
358 injections.insert(name.clone(), response.code);
359 }
360 Err(e) => {
361 if let Some(ref obs) = self.observer {
362 obs.on_failure(&id, &e.to_string());
363 }
364 return Err(e);
365 }
366 }
367 }
368 }
369
370 Ok(injections)
371 }
372
373 async fn generate_parallel(
374 &self,
375 template: &Template,
376 context_prompt: Arc<String>,
377 ) -> Result<HashMap<String, String>> {
378 use tokio::task::JoinSet;
379
380 let mut join_set = JoinSet::new();
381
382 for (name, slot) in template.slots.clone() {
383 let context = Arc::clone(&context_prompt);
384 let worker_ctx = WorkerContext {
385 provider: Arc::clone(&self.provider),
386 validator: self.validator.clone(),
387 cache: self.cache.clone(),
388 observer: self.observer.clone(),
389 config: self.config.clone(),
390 };
391 let template_name = template.name.clone();
392
393 join_set.spawn(async move {
394 let id = uuid::Uuid::new_v4().to_string();
395 let request = GenerationRequest {
396 max_tokens: slot.max_tokens,
397 model: slot.model.clone(),
398 slot,
399 context: Some((*context).clone()),
400 system_prompt: None,
401 };
402
403 if let Some(ref obs) = worker_ctx.observer {
404 obs.on_start(&id, &template_name, &name, &request);
405 }
406
407 match Self::generate_with_healing_static(worker_ctx.clone(), request, &id).await {
408 Ok(response) => {
409 if let Some(ref obs) = worker_ctx.observer {
410 obs.on_success(&id, &response);
411 }
412 Ok::<_, AetherError>((name, response.code))
413 }
414 Err(e) => {
415 if let Some(ref obs) = worker_ctx.observer {
416 obs.on_failure(&id, &e.to_string());
417 }
418 Err(e)
419 }
420 }
421 });
422 }
423
424 let mut injections = HashMap::new();
425 while let Some(result) = join_set.join_next().await {
426 let (name, code) = result.map_err(|e| AetherError::InjectionError(e.to_string()))??;
427 injections.insert(name, code);
428 }
429
430 Ok(injections)
431 }
432
433 async fn generate_with_retry(&self, request: GenerationRequest, id: &str) -> Result<GenerationResponse> {
435 let worker_ctx = WorkerContext {
436 provider: Arc::clone(&self.provider),
437 validator: self.validator.clone(),
438 cache: self.cache.clone(),
439 observer: self.observer.clone(),
440 config: self.config.clone(),
441 };
442 Self::generate_with_healing_static(worker_ctx, request, id).await
443 }
444
445 async fn generate_with_healing_static(
447 ctx: WorkerContext<P>,
448 mut request: GenerationRequest,
449 id: &str,
450 ) -> Result<GenerationResponse> {
451 let cache_key = if ctx.cache.is_some() {
453 let mut s = StableHasher::new();
455 request.slot.prompt.hash(&mut s);
456 request.context.as_deref().unwrap_or("").hash(&mut s);
457 request.model.as_deref().unwrap_or("").hash(&mut s);
458 request.max_tokens.unwrap_or(0).hash(&mut s);
459 Some(format!("aether:cache:{:x}", s.finish()))
460 } else {
461 None
462 };
463
464 if let (Some(ref c), Some(ref key)) = (ctx.cache.as_ref(), &cache_key) {
465 if let Some(cached_code) = c.get(key) {
466 debug!("Cache hit for slot: {}", request.slot.name);
467 return Ok(GenerationResponse {
468 code: cached_code,
469 tokens_used: None,
470 metadata: Some(serde_json::json!({"cache": "hit"})),
471 });
472 }
473 }
474
475 let mut last_error = None;
476 let mut previous_code: Option<String> = None;
477
478 for attempt in 0..=ctx.config.max_retries {
479 let mut response = match ctx.provider.generate(request.clone()).await {
481 Ok(r) => r,
482 Err(e) => {
483 debug!("Generation attempt {} failed: {}", attempt + 1, e);
484 last_error = Some(e);
485 if attempt < ctx.config.max_retries {
486 tokio::time::sleep(std::time::Duration::from_millis(ctx.config.retry_backoff_ms * (attempt as u64 + 1))).await;
487 continue;
488 }
489 return Err(last_error.unwrap());
490 }
491 };
492
493 if let Some(prev) = &previous_code {
495 if prev == &response.code {
496 debug!("Self-healing: AI generated identical code for slot '{}', aborting.", request.slot.name);
497 return Err(AetherError::MaxRetriesExceeded {
498 slot: request.slot.name.clone(),
499 retries: attempt,
500 last_error: "AI stuck in loop (generated identical code)".to_string()
501 });
502 }
503 }
504 previous_code = Some(response.code.clone());
505
506 if let Some(ref val) = ctx.validator {
508 if let Ok(formatted) = val.format(&request.slot.kind, &response.code) {
510 response.code = formatted;
511 }
512
513 match val.validate_with_slot(&request.slot, &response.code)? {
515 ValidationResult::Valid => {
516 if let (Some(ref c), Some(ref key)) = (ctx.cache.as_ref(), &cache_key) {
518 c.set(key, response.code.clone());
519 }
520 return Ok(response);
521 },
522 ValidationResult::Invalid(err_msg) => {
523 info!("Self-healing: Validation failed for slot '{}', attempt {}. Error: {}",
524 request.slot.name, attempt + 1, err_msg);
525
526 if let Some(ref obs) = ctx.observer {
527 obs.on_healing_step(id, attempt + 1, &err_msg);
528 }
529
530 last_error = Some(AetherError::ValidationFailed {
531 slot: request.slot.name.clone(),
532 error: err_msg.clone()
533 });
534
535 if attempt < ctx.config.max_retries {
536 request.slot.prompt = format!(
538 "{}\n\n{}{}",
539 request.slot.prompt,
540 ctx.config.prompt_healing_feedback,
541 err_msg
542 );
543 continue;
544 }
545 }
546 }
547 } else {
548 if let (Some(ref c), Some(ref key)) = (ctx.cache.as_ref(), &cache_key) {
550 c.set(key, response.code.clone());
551 }
552 return Ok(response);
553 }
554 }
555
556 let final_err = AetherError::MaxRetriesExceeded {
557 slot: request.slot.name,
558 retries: ctx.config.max_retries,
559 last_error: last_error.map(|e| e.to_string()).unwrap_or_else(|| "Healing failed without specific error".to_string())
560 };
561 Err(final_err)
562 }
563
564 pub async fn generate_slot(&self, template: &Template, slot_name: &str) -> Result<String> {
566 let slot = template
567 .slots
568 .get(slot_name)
569 .ok_or_else(|| AetherError::SlotNotFound(slot_name.to_string()))?;
570
571 let request = GenerationRequest {
572 max_tokens: slot.max_tokens,
573 model: slot.model.clone(),
574 slot: slot.clone(),
575 context: Some(self.global_context.to_prompt()),
576 system_prompt: None,
577 };
578
579 let id = uuid::Uuid::new_v4().to_string();
580 if let Some(ref obs) = self.observer {
581 obs.on_start(&id, &template.name, slot_name, &request);
582 }
583
584 match self.generate_with_retry(request, &id).await {
585 Ok(response) => {
586 if let Some(ref obs) = self.observer {
587 obs.on_success(&id, &response);
588 }
589 Ok(response.code)
590 }
591 Err(e) => {
592 if let Some(ref obs) = self.observer {
593 obs.on_failure(&id, &e.to_string());
594 }
595 Err(e)
596 }
597 }
598 }
599
600 pub fn generate_slot_stream(
602 &self,
603 template: &Template,
604 slot_name: &str,
605 ) -> Result<BoxStream<'static, Result<StreamResponse>>> {
606 let slot = template
607 .slots
608 .get(slot_name)
609 .ok_or_else(|| AetherError::SlotNotFound(slot_name.to_string()))?;
610
611 let request = GenerationRequest {
612 max_tokens: slot.max_tokens,
613 model: slot.model.clone(),
614 slot: slot.clone(),
615 context: Some(self.global_context.to_prompt()),
616 system_prompt: None,
617 };
618
619 Ok(self.provider.generate_stream(request))
620 }
621
622 pub async fn inject_raw(&self, prompt: &str) -> Result<String> {
625 let template = Template::new("{{AI:gen}}")
626 .with_slot("gen", prompt);
627
628 self.render(&template).await
629 }
630}
631
/// Renders one AI-generated snippet for `$prompt` using `$provider`.
///
/// Builds a one-slot template and an [`InjectionEngine`] with the default
/// configuration. It evaluates to a future that yields the rendered
/// `Result<String>`, so `.await` it.
///
/// The future owns both the engine and the template. The previous version
/// returned `engine.render(&template)`, a future that borrowed locals dropped
/// at the end of the macro block, so no invocation could compile.
#[macro_export]
macro_rules! inject {
    ($prompt:expr, $provider:expr) => {{
        let template = $crate::Template::new("{{AI:generated}}")
            .with_slot("generated", $prompt);

        let engine = $crate::InjectionEngine::new($provider);
        async move { engine.render(&template).await }
    }};
}
651
/// Blocking counterpart of [`inject!`].
///
/// Renders `$prompt` with `$provider` on a freshly created Tokio runtime and
/// returns the `Result<String>` directly. The engine and template are locals
/// that outlive `block_on`, so this works independently of how `inject!`
/// builds its future.
///
/// # Panics
///
/// - Panics if the Tokio runtime cannot be created.
/// - Tokio also panics when `block_on` is called from within an async execution
///   context, so only use this from synchronous code.
#[macro_export]
macro_rules! inject_sync {
    ($prompt:expr, $provider:expr) => {{
        let runtime = tokio::runtime::Runtime::new()
            .expect("inject_sync!: failed to create a Tokio runtime");

        let template = $crate::Template::new("{{AI:generated}}")
            .with_slot("generated", $prompt);
        let engine = $crate::InjectionEngine::new($provider);

        runtime.block_on(engine.render(&template))
    }};
}
661
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::MockProvider;

    #[tokio::test]
    async fn test_engine_render() {
        let provider = MockProvider::new()
            .with_response("content", "<p>Hello World</p>");

        let engine = InjectionEngine::new(provider);

        let template = Template::new("<div>{{AI:content}}</div>")
            .with_slot("content", "Generate a paragraph");

        let result = engine.render(&template).await.unwrap();
        assert_eq!(result, "<div><p>Hello World</p></div>");
    }

    #[tokio::test]
    async fn test_engine_with_context() {
        let provider = MockProvider::new()
            .with_response("button", "<button class='btn'>Click</button>");

        let engine = InjectionEngine::new(provider)
            .with_context(InjectionContext::new().with_framework("react"));

        let template = Template::new("{{AI:button}}")
            .with_slot("button", "Create a button");

        let result = engine.render(&template).await.unwrap();
        assert!(result.contains("button"));
    }

    #[tokio::test]
    async fn test_parallel_generation() {
        let provider = MockProvider::new()
            .with_response("slot1", "code1")
            .with_response("slot2", "code2");

        let engine = InjectionEngine::new(provider).parallel(true);

        let template = Template::new("{{AI:slot1}} | {{AI:slot2}}");

        let result = engine.render(&template).await.unwrap();
        assert!(result.contains("code1"));
        assert!(result.contains("code2"));
    }

    #[tokio::test]
    async fn test_max_retries_exceeded() {
        let provider = MockProvider::new()
            .with_response("fail", "invalid code");

        // Rejects every candidate, forcing the healing loop to give up.
        struct FailingValidator;

        impl Validator for FailingValidator {
            fn validate(&self, _: &SlotKind, _: &str) -> Result<ValidationResult> {
                Ok(ValidationResult::Invalid("Always fails".to_string()))
            }
            fn format(&self, _: &SlotKind, code: &str) -> Result<String> {
                Ok(code.to_string())
            }
        }

        let engine = InjectionEngine::new(provider)
            .with_validator(FailingValidator)
            .max_retries(1);

        let template = Template::new("{{AI:fail}}");
        let result = engine.render(&template).await;

        // With one retry the loop ends on attempt index 1, either via the
        // identical-code guard or by exhausting its attempts; both report 1.
        match result {
            Err(AetherError::MaxRetriesExceeded { slot, retries, .. }) => {
                assert_eq!(slot, "fail");
                assert_eq!(retries, 1);
            }
            _ => panic!("Expected MaxRetriesExceeded error, got {:?}", result),
        }
    }

    #[tokio::test]
    async fn test_auto_toon_activation() {
        let provider = MockProvider::new()
            .with_response("slot", "code");

        // A 5-character threshold is below the length of the context prompt
        // built from this context, so TOON encoding switches on automatically.
        let config = AetherConfig::default().with_auto_toon_threshold(Some(5));
        let engine = InjectionEngine::with_config(provider, config)
            .with_context(InjectionContext::new().with_framework("very_long_framework_name"));

        let template = Template::new("{{AI:slot}}");
        let result = engine.render(&template).await.unwrap();

        // The TOON path must still deliver the provider's output for the slot.
        assert!(result.contains("code"));
    }
}