1use crate::{
6 AetherError, AiProvider, InjectionContext, Result, Template,
7 provider::{GenerationRequest, GenerationResponse},
8 config::AetherConfig,
9};
10use std::collections::HashMap;
11use std::sync::Arc;
12use tracing::{debug, info, instrument};
13use futures::stream::BoxStream;
14use crate::provider::StreamResponse;
15use crate::validation::{Validator, ValidationResult};
16use crate::cache::Cache;
17use crate::toon::Toon;
18pub use crate::observer::ObserverPtr;
19use std::hash::{Hash, Hasher};
20use std::collections::hash_map::DefaultHasher;
21
22pub struct InjectionEngine<P: AiProvider> {
40 provider: Arc<P>,
42
43 validator: Option<Arc<dyn Validator>>,
45
46 cache: Option<Arc<dyn Cache>>,
48
49 use_toon: bool,
51
52 auto_toon_threshold: Option<usize>,
54
55 global_context: InjectionContext,
57
58 parallel: bool,
60
61 max_retries: u32,
63
64 observer: Option<ObserverPtr>,
66}
67
/// Memo of previously generated slots, used by `InjectionEngine::render_incremental`.
///
/// Entries are keyed by a `(slot_hash, context_hash)` pair computed with
/// [`RenderSession::hash`]; values are the generated code. Reusing one session
/// across renders lets unchanged slots be served without calling the provider.
#[derive(Debug, Clone, Default)]
pub struct RenderSession {
    /// Generated code keyed by `(slot_hash, context_hash)`.
    pub results: HashMap<(u64, u64), String>,
}

impl RenderSession {
    /// Creates an empty session.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Hashes `t` with the standard library's `DefaultHasher`.
    ///
    /// Accepts unsized values such as `str` or slices as well as sized ones.
    /// Every `DefaultHasher::new()` instance hashes identically, so keys are
    /// stable within a process (all an in-memory session needs). The algorithm
    /// is unspecified across Rust releases, so these values must not be
    /// persisted or compared between builds.
    pub fn hash<T: Hash + ?Sized>(t: &T) -> u64 {
        let mut hasher = DefaultHasher::new();
        t.hash(&mut hasher);
        hasher.finish()
    }
}
89
90impl<P: AiProvider + 'static> InjectionEngine<P> {
91 pub fn new(provider: P) -> Self {
93 Self::with_config(provider, AetherConfig::default())
94 }
95
96 pub fn new_raw(provider: Arc<P>) -> Self {
98 Self {
99 provider,
100 validator: None,
101 cache: None,
102 use_toon: false,
103 auto_toon_threshold: None,
104 global_context: InjectionContext::default(),
105 parallel: false,
106 max_retries: 0,
107 observer: None,
108 }
109 }
110
111 pub fn with_config(provider: P, config: AetherConfig) -> Self {
113 let validator: Option<Arc<dyn Validator>> = if config.healing_enabled {
114 Some(Arc::new(crate::validation::MultiValidator::new()))
115 } else {
116 None
117 };
118
119 Self {
120 provider: Arc::new(provider),
121 validator,
122 cache: None,
123 use_toon: config.toon_enabled,
124 auto_toon_threshold: config.auto_toon_threshold,
125 global_context: InjectionContext::default(),
126 parallel: config.parallel,
127 max_retries: config.max_retries,
128 observer: None,
129 }
130 }
131
132 pub fn with_cache(mut self, cache: impl Cache + 'static) -> Self {
134 self.cache = Some(Arc::new(cache));
135 self
136 }
137
138 pub fn with_toon(mut self, enabled: bool) -> Self {
140 self.use_toon = enabled;
141 self
142 }
143
144 pub fn with_validator(mut self, validator: impl Validator + 'static) -> Self {
146 self.validator = Some(Arc::new(validator));
147 self
148 }
149
150 pub fn with_context(mut self, context: InjectionContext) -> Self {
152 self.global_context = context;
153 self
154 }
155
156 pub fn parallel(mut self, enabled: bool) -> Self {
158 self.parallel = enabled;
159 self
160 }
161
162 pub fn max_retries(mut self, retries: u32) -> Self {
164 self.max_retries = retries;
165 self
166 }
167
168 pub fn with_observer(mut self, observer: impl crate::observer::EngineObserver + 'static) -> Self {
170 self.observer = Some(Arc::new(observer));
171 self
172 }
173
174 #[instrument(skip(self, template), fields(template_name = %template.name))]
179 pub async fn render(&self, template: &Template) -> Result<String> {
180 info!("Rendering template: {}", template.name);
181
182 let injections = self.generate_all(template, None).await?;
183 template.render(&injections)
184 }
185
186 #[instrument(skip(self, template, context), fields(template_name = %template.name))]
188 pub async fn render_with_context(
189 &self,
190 template: &Template,
191 context: InjectionContext,
192 ) -> Result<String> {
193 info!("Rendering template with context: {}", template.name);
194
195 let injections = self.generate_all(template, Some(context)).await?;
196 template.render(&injections)
197 }
198
199 #[instrument(skip(self, template, session), fields(template_name = %template.name))]
204 pub async fn render_incremental(
205 &self,
206 template: &Template,
207 session: &mut RenderSession,
208 ) -> Result<String> {
209 info!("Incrementally rendering template: {}", template.name);
210
211 let context_hash = RenderSession::hash(&self.global_context);
212 let mut injections = HashMap::new();
213
214 for (name, slot) in &template.slots {
215 let slot_hash = RenderSession::hash(slot);
216 let key = (slot_hash, context_hash);
217
218 if let Some(cached) = session.results.get(&key) {
219 debug!("Incremental hit for slot: {}", name);
220 injections.insert(name.clone(), cached.clone());
221 } else {
222 debug!("Incremental miss for slot: {}", name);
223 let code = self.generate_slot(template, name).await?;
224 session.results.insert(key, code.clone());
225 injections.insert(name.clone(), code);
226 }
227 }
228
229 template.render(&injections)
230 }
231
232 async fn generate_all(
233 &self,
234 template: &Template,
235 extra_context: Option<InjectionContext>,
236 ) -> Result<HashMap<String, String>> {
237 let mut injections = HashMap::new();
238
239 let base_context = if let Some(ref ctx) = extra_context {
241 format!("{}\n{}", self.global_context.to_prompt(), ctx.to_prompt())
242 } else {
243 self.global_context.to_prompt()
244 };
245
246 let should_use_toon = self.use_toon || self.auto_toon_threshold
248 .map(|threshold| base_context.len() >= threshold)
249 .unwrap_or(false);
250
251 let mut context_prompt = if should_use_toon {
252 let context_value = serde_json::to_value(&self.global_context)
254 .map_err(|e| AetherError::ContextSerializationError(e.to_string()))?;
255 let toon_ctx = Toon::serialize(&context_value);
256 format!(
257 "[CONTEXT:TOON]\n{}\n\n[TOON Protocol Note]\nTOON is a compact key:value mapping protocol. Each line represents 'key: value'. Use this context to inform your code generation, respecting the framework, language, and architectural constraints defined within.",
258 toon_ctx
259 )
260 } else {
261 base_context
262 };
263
264 if self.validator.is_some() {
266 context_prompt.push_str("\n\nIMPORTANT: The system is running in TDD (Test-Driven Development) mode. ");
267 context_prompt.push_str("Your code will be validated against compiler checks and functional tests. ");
268 context_prompt.push_str("If possible, include unit tests in your response to help self-verify. ");
269 context_prompt.push_str("If validation fails, you will receive feedback to fix the code.");
270 }
271
272 let context_prompt = Arc::new(context_prompt);
273
274 if self.parallel {
275 injections = self
276 .generate_parallel(template, context_prompt)
277 .await?;
278 } else {
279 for (name, slot) in &template.slots {
280 debug!("Generating code for slot: {}", name);
281 let id = uuid::Uuid::new_v4().to_string();
282
283 let request = GenerationRequest {
284 slot: slot.clone(),
285 context: Some((*context_prompt).clone()),
286 system_prompt: None,
287 };
288
289 if let Some(ref obs) = self.observer {
290 obs.on_start(&id, &template.name, name, &request);
291 }
292
293 match self.generate_with_retry(request, &id).await {
294 Ok(response) => {
295 if let Some(ref obs) = self.observer {
296 obs.on_success(&id, &response);
297 }
298 injections.insert(name.clone(), response.code);
299 }
300 Err(e) => {
301 if let Some(ref obs) = self.observer {
302 obs.on_failure(&id, &e.to_string());
303 }
304 return Err(e);
305 }
306 }
307 }
308 }
309
310 Ok(injections)
311 }
312
313 async fn generate_parallel(
314 &self,
315 template: &Template,
316 context_prompt: Arc<String>,
317 ) -> Result<HashMap<String, String>> {
318 use tokio::task::JoinSet;
319
320 let mut join_set = JoinSet::new();
321
322 for (name, slot) in template.slots.clone() {
323 let provider = Arc::clone(&self.provider);
324 let validator = self.validator.clone();
325 let cache = self.cache.clone();
326 let context = Arc::clone(&context_prompt);
327 let max_retries = self.max_retries;
328 let template_name = template.name.clone();
329 let observer = self.observer.clone();
330
331 join_set.spawn(async move {
332 let id = uuid::Uuid::new_v4().to_string();
333 let request = GenerationRequest {
334 slot,
335 context: Some((*context).clone()),
336 system_prompt: None,
337 };
338
339 if let Some(ref obs) = observer {
340 obs.on_start(&id, &template_name, &name, &request);
341 }
342
343 match Self::generate_with_healing_static(&provider, &validator, &cache, &observer, request, max_retries, &id).await {
344 Ok(response) => {
345 if let Some(ref obs) = observer {
346 obs.on_success(&id, &response);
347 }
348 Ok::<_, AetherError>((name, response.code))
349 }
350 Err(e) => {
351 if let Some(ref obs) = observer {
352 obs.on_failure(&id, &e.to_string());
353 }
354 Err(e)
355 }
356 }
357 });
358 }
359
360 let mut injections = HashMap::new();
361 while let Some(result) = join_set.join_next().await {
362 let (name, code) = result.map_err(|e| AetherError::InjectionError(e.to_string()))??;
363 injections.insert(name, code);
364 }
365
366 Ok(injections)
367 }
368
369 async fn generate_with_retry(&self, request: GenerationRequest, id: &str) -> Result<GenerationResponse> {
371 Self::generate_with_healing_static(&self.provider, &self.validator, &self.cache, &self.observer, request, self.max_retries, id).await
372 }
373
374 async fn generate_with_healing_static(
376 provider: &Arc<P>,
377 validator: &Option<Arc<dyn Validator>>,
378 cache: &Option<Arc<dyn Cache>>,
379 observer: &Option<ObserverPtr>,
380 mut request: GenerationRequest,
381 max_retries: u32,
382 id: &str,
383 ) -> Result<GenerationResponse> {
384 let cache_key = if cache.is_some() {
386 Some(format!("{}:{}", request.slot.prompt, request.context.as_deref().unwrap_or("")))
387 } else {
388 None
389 };
390
391 if let (Some(ref c), Some(ref key)) = (cache, &cache_key) {
392 if let Some(cached_code) = c.get(key) {
393 debug!("Cache hit for slot: {}", request.slot.name);
394 return Ok(GenerationResponse {
395 code: cached_code,
396 tokens_used: None,
397 metadata: Some(serde_json::json!({"cache": "hit"})),
398 });
399 }
400 }
401
402 let mut last_error = None;
403
404 for attempt in 0..=max_retries {
405 let mut response = match provider.generate(request.clone()).await {
407 Ok(r) => r,
408 Err(e) => {
409 debug!("Generation attempt {} failed: {}", attempt + 1, e);
410 last_error = Some(e);
411 if attempt < max_retries {
412 tokio::time::sleep(std::time::Duration::from_millis(100 * (attempt as u64 + 1))).await;
413 continue;
414 }
415 return Err(last_error.unwrap());
416 }
417 };
418
419 if let Some(ref val) = validator {
421 if let Ok(formatted) = val.format(&request.slot.kind, &response.code) {
423 response.code = formatted;
424 }
425
426 match val.validate_with_slot(&request.slot, &response.code)? {
428 ValidationResult::Valid => {
429 if let (Some(ref c), Some(ref key)) = (cache, &cache_key) {
431 c.set(key, response.code.clone());
432 }
433 return Ok(response);
434 },
435 ValidationResult::Invalid(err_msg) => {
436 info!("Self-healing: Validation failed for slot '{}', attempt {}. Error: {}",
437 request.slot.name, attempt + 1, err_msg);
438
439 if let Some(ref obs) = observer {
440 obs.on_healing_step(id, attempt + 1, &err_msg);
441 }
442
443 last_error = Some(AetherError::ValidationFailed {
444 slot: request.slot.name.clone(),
445 error: err_msg.clone()
446 });
447
448 if attempt < max_retries {
449 request.slot.prompt = format!(
451 "{}\n\n[SELF-HEALING FEEDBACK]\nYour previous output had validation errors. Please fix them and output ONLY the corrected code.\nERROR:\n{}",
452 request.slot.prompt,
453 err_msg
454 );
455 continue;
456 }
457 }
458 }
459 } else {
460 if let (Some(ref c), Some(ref key)) = (cache, &cache_key) {
462 c.set(key, response.code.clone());
463 }
464 return Ok(response);
465 }
466 }
467
468 Err(last_error.unwrap_or_else(|| AetherError::MaxRetriesExceeded {
469 slot: request.slot.name,
470 retries: max_retries,
471 last_error: "Healing failed without specific error".to_string()
472 }))
473 }
474
475 pub async fn generate_slot(&self, template: &Template, slot_name: &str) -> Result<String> {
477 let slot = template
478 .slots
479 .get(slot_name)
480 .ok_or_else(|| AetherError::SlotNotFound(slot_name.to_string()))?;
481
482 let request = GenerationRequest {
483 slot: slot.clone(),
484 context: Some(self.global_context.to_prompt()),
485 system_prompt: None,
486 };
487
488 let id = uuid::Uuid::new_v4().to_string();
489 if let Some(ref obs) = self.observer {
490 obs.on_start(&id, &template.name, slot_name, &request);
491 }
492
493 match self.generate_with_retry(request, &id).await {
494 Ok(response) => {
495 if let Some(ref obs) = self.observer {
496 obs.on_success(&id, &response);
497 }
498 Ok(response.code)
499 }
500 Err(e) => {
501 if let Some(ref obs) = self.observer {
502 obs.on_failure(&id, &e.to_string());
503 }
504 Err(e)
505 }
506 }
507 }
508
509 pub fn generate_slot_stream(
511 &self,
512 template: &Template,
513 slot_name: &str,
514 ) -> Result<BoxStream<'static, Result<StreamResponse>>> {
515 let slot = template
516 .slots
517 .get(slot_name)
518 .ok_or_else(|| AetherError::SlotNotFound(slot_name.to_string()))?;
519
520 let request = GenerationRequest {
521 slot: slot.clone(),
522 context: Some(self.global_context.to_prompt()),
523 system_prompt: None,
524 };
525
526 Ok(self.provider.generate_stream(request))
527 }
528
529 pub async fn inject_raw(&self, prompt: &str) -> Result<String> {
532 let template = Template::new("{{AI:gen}}")
533 .with_slot("gen", prompt);
534
535 self.render(&template).await
536 }
537}
538
/// Generates code for a single prompt with a one-off engine.
///
/// Expands to a future resolving to the crate's `Result<String>`; `.await` it.
/// `$prompt` and `$provider` are evaluated eagerly. The template and engine
/// built from them are moved into the returned future. Previously the
/// expansion returned a future borrowing block-local values, which could not
/// compile.
///
/// ```ignore
/// let code = inject!("Generate a paragraph", provider).await?;
/// ```
#[macro_export]
macro_rules! inject {
    ($prompt:expr, $provider:expr) => {{
        let template = $crate::Template::new("{{AI:generated}}")
            .with_slot("generated", $prompt);

        let engine = $crate::InjectionEngine::new($provider);

        async move { engine.render(&template).await }
    }};
}
558
/// Blocking variant of `inject!`: creates a fresh Tokio runtime and blocks the
/// current thread until generation finishes.
///
/// Requires the caller to depend on `tokio` with a runtime feature enabled.
///
/// # Panics
///
/// Panics if the Tokio runtime cannot be created. Also panics if invoked from
/// within an asynchronous execution context, because Tokio's
/// `Runtime::block_on` panics there; use `inject!(..).await` instead.
#[macro_export]
macro_rules! inject_sync {
    ($prompt:expr, $provider:expr) => {{
        ::tokio::runtime::Runtime::new()
            .expect("inject_sync!: failed to create a Tokio runtime")
            .block_on($crate::inject!($prompt, $provider))
    }};
}
568
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::MockProvider;
    // `SlotKind` appears in `Validator::validate`'s signature but is not imported
    // by the parent module, so `use super::*` does not bring it into scope.
    // NOTE(review): assumes `SlotKind` is re-exported at the crate root like
    // `Slot` and `Template` — adjust the path if it lives elsewhere.
    use crate::SlotKind;

    #[tokio::test]
    async fn test_engine_render() {
        let provider = MockProvider::new()
            .with_response("content", "<p>Hello World</p>");

        let engine = InjectionEngine::new(provider);

        let template = Template::new("<div>{{AI:content}}</div>")
            .with_slot("content", "Generate a paragraph");

        let result = engine.render(&template).await.unwrap();
        assert_eq!(result, "<div><p>Hello World</p></div>");
    }

    #[tokio::test]
    async fn test_engine_with_context() {
        let provider = MockProvider::new()
            .with_response("button", "<button class='btn'>Click</button>");

        let engine = InjectionEngine::new(provider)
            .with_context(InjectionContext::new().with_framework("react"));

        let template = Template::new("{{AI:button}}")
            .with_slot("button", "Create a button");

        let result = engine.render(&template).await.unwrap();
        assert!(result.contains("button"));
    }

    #[tokio::test]
    async fn test_parallel_generation() {
        let provider = MockProvider::new()
            .with_response("slot1", "code1")
            .with_response("slot2", "code2");

        let engine = InjectionEngine::new(provider).parallel(true);

        let template = Template::new("{{AI:slot1}} | {{AI:slot2}}");

        let result = engine.render(&template).await.unwrap();
        assert!(result.contains("code1"));
        assert!(result.contains("code2"));
    }

    #[tokio::test]
    async fn test_max_retries_exceeded() {
        let provider = MockProvider::new()
            .with_response("fail", "invalid code");

        // Rejects every output so the healing loop always exhausts its retries.
        struct FailingValidator;

        impl Validator for FailingValidator {
            fn validate(&self, _: &SlotKind, _: &str) -> Result<ValidationResult> {
                Ok(ValidationResult::Invalid("Always fails".to_string()))
            }
        }

        let engine = InjectionEngine::new(provider)
            .with_validator(FailingValidator)
            .max_retries(1);

        let template = Template::new("{{AI:fail}}");
        let result = engine.render(&template).await;

        match result {
            Err(AetherError::MaxRetriesExceeded { slot, retries, .. }) => {
                assert_eq!(slot, "fail");
                assert_eq!(retries, 1);
            }
            _ => panic!("Expected MaxRetriesExceeded error, got {:?}", result),
        }
    }

    #[tokio::test]
    async fn test_auto_toon_activation() {
        let provider = MockProvider::new()
            .with_response("slot", "code");

        // A 5-byte threshold is exceeded by any non-trivial context, so TOON
        // encoding is switched on automatically; rendering must still succeed.
        let config = AetherConfig::default().with_auto_toon_threshold(Some(5));
        let engine = InjectionEngine::with_config(provider, config)
            .with_context(InjectionContext::new().with_framework("very_long_framework_name"));

        let template = Template::new("{{AI:slot}}");
        let _ = engine.render(&template).await.unwrap();
    }

    #[test]
    fn test_render_session_hash_is_deterministic() {
        assert_eq!(RenderSession::hash(&"slot"), RenderSession::hash(&"slot"));
        assert_ne!(RenderSession::hash(&"slot"), RenderSession::hash(&"other"));
        assert!(RenderSession::new().results.is_empty());
    }
}