1use crate::{
6 AetherError, AiProvider, InjectionContext, Result, Template,
7 provider::{GenerationRequest, GenerationResponse},
8 config::AetherConfig,
9};
10use std::collections::HashMap;
11use std::sync::Arc;
12use tracing::{debug, info, instrument};
13use futures::stream::BoxStream;
14use crate::provider::StreamResponse;
15use crate::validation::{Validator, ValidationResult};
16use crate::cache::Cache;
17use crate::toon::Toon;
18
19pub struct InjectionEngine<P: AiProvider> {
37 provider: Arc<P>,
39
40 validator: Option<Arc<dyn Validator>>,
42
43 cache: Option<Arc<dyn Cache>>,
45
46 use_toon: bool,
48
49 auto_toon_threshold: Option<usize>,
51
52 global_context: InjectionContext,
54
55 parallel: bool,
57
58 max_retries: u32,
60}
61
62impl<P: AiProvider + 'static> InjectionEngine<P> {
63 pub fn new(provider: P) -> Self {
65 Self::with_config(provider, AetherConfig::default())
66 }
67
68 pub fn with_config(provider: P, config: AetherConfig) -> Self {
70 Self {
71 provider: Arc::new(provider),
72 validator: None,
73 cache: None,
74 use_toon: config.toon_enabled,
75 auto_toon_threshold: config.auto_toon_threshold,
76 global_context: InjectionContext::default(),
77 parallel: config.parallel,
78 max_retries: config.max_retries,
79 }
80 }
81
82 pub fn with_cache(mut self, cache: impl Cache + 'static) -> Self {
84 self.cache = Some(Arc::new(cache));
85 self
86 }
87
88 pub fn with_toon(mut self, enabled: bool) -> Self {
90 self.use_toon = enabled;
91 self
92 }
93
94 pub fn with_validator(mut self, validator: impl Validator + 'static) -> Self {
96 self.validator = Some(Arc::new(validator));
97 self
98 }
99
100 pub fn with_context(mut self, context: InjectionContext) -> Self {
102 self.global_context = context;
103 self
104 }
105
106 pub fn parallel(mut self, enabled: bool) -> Self {
108 self.parallel = enabled;
109 self
110 }
111
112 pub fn max_retries(mut self, retries: u32) -> Self {
114 self.max_retries = retries;
115 self
116 }
117
118 #[instrument(skip(self, template), fields(template_name = %template.name))]
123 pub async fn render(&self, template: &Template) -> Result<String> {
124 info!("Rendering template: {}", template.name);
125
126 let injections = self.generate_all(template, None).await?;
127 template.render(&injections)
128 }
129
130 #[instrument(skip(self, template, context), fields(template_name = %template.name))]
132 pub async fn render_with_context(
133 &self,
134 template: &Template,
135 context: InjectionContext,
136 ) -> Result<String> {
137 info!("Rendering template with context: {}", template.name);
138
139 let injections = self.generate_all(template, Some(context)).await?;
140 template.render(&injections)
141 }
142
143 async fn generate_all(
144 &self,
145 template: &Template,
146 extra_context: Option<InjectionContext>,
147 ) -> Result<HashMap<String, String>> {
148 let mut injections = HashMap::new();
149
150 let base_context = if let Some(ref ctx) = extra_context {
152 format!("{}\n{}", self.global_context.to_prompt(), ctx.to_prompt())
153 } else {
154 self.global_context.to_prompt()
155 };
156
157 let should_use_toon = self.use_toon || self.auto_toon_threshold
159 .map(|threshold| base_context.len() >= threshold)
160 .unwrap_or(false);
161
162 let context_prompt = if should_use_toon {
163 let toon_ctx = Toon::serialize(&serde_json::to_value(&self.global_context).unwrap_or(serde_json::json!({})));
165 format!("[CONTEXT:TOON]\n{}\n[Note: TOON is a compact key:value format to save tokens]\n", toon_ctx)
166 } else {
167 base_context
168 };
169
170 let mut context_prompt = context_prompt;
171 if self.validator.is_some() {
173 context_prompt.push_str("\n\nIMPORTANT: Include a Rust `#[cfg(test)] mod tests { ... }` block with at least one unit test to verify your code. The system will automatically run these tests to validate your output.");
174 }
175
176 if self.parallel {
177 injections = self
178 .generate_parallel(template, &context_prompt)
179 .await?;
180 } else {
181 for (name, slot) in &template.slots {
182 debug!("Generating code for slot: {}", name);
183
184 let request = GenerationRequest {
185 slot: slot.clone(),
186 context: Some(context_prompt.clone()),
187 system_prompt: None,
188 };
189
190 let response = self.generate_with_retry(request).await?;
191 injections.insert(name.clone(), response.code);
192 }
193 }
194
195 Ok(injections)
196 }
197
198 async fn generate_parallel(
200 &self,
201 template: &Template,
202 context_prompt: &str,
203 ) -> Result<HashMap<String, String>> {
204 use tokio::task::JoinSet;
205
206 let mut join_set = JoinSet::new();
207
208 for (name, slot) in template.slots.clone() {
209 let provider = Arc::clone(&self.provider);
210 let validator = self.validator.clone();
211 let cache = self.cache.clone();
212 let context = context_prompt.to_string();
213 let max_retries = self.max_retries;
214
215 join_set.spawn(async move {
216 let request = GenerationRequest {
217 slot,
218 context: Some(context),
219 system_prompt: None,
220 };
221
222 let response = Self::generate_with_healing_static(&provider, &validator, &cache, request, max_retries).await?;
223 Ok::<_, AetherError>((name, response.code))
224 });
225 }
226
227 let mut injections = HashMap::new();
228 while let Some(result) = join_set.join_next().await {
229 let (name, code) = result.map_err(|e| AetherError::InjectionError(e.to_string()))??;
230 injections.insert(name, code);
231 }
232
233 Ok(injections)
234 }
235
236 async fn generate_with_retry(&self, request: GenerationRequest) -> Result<GenerationResponse> {
238 Self::generate_with_healing_static(&self.provider, &self.validator, &self.cache, request, self.max_retries).await
239 }
240
241 async fn generate_with_healing_static(
243 provider: &Arc<P>,
244 validator: &Option<Arc<dyn Validator>>,
245 cache: &Option<Arc<dyn Cache>>,
246 mut request: GenerationRequest,
247 max_retries: u32,
248 ) -> Result<GenerationResponse> {
249 let cache_key = if cache.is_some() {
251 Some(format!("{}:{}", request.slot.prompt, request.context.as_deref().unwrap_or("")))
252 } else {
253 None
254 };
255
256 if let (Some(ref c), Some(ref key)) = (cache, &cache_key) {
257 if let Some(cached_code) = c.get(key) {
258 debug!("Cache hit for slot: {}", request.slot.name);
259 return Ok(GenerationResponse {
260 code: cached_code,
261 tokens_used: None,
262 metadata: Some(serde_json::json!({"cache": "hit"})),
263 });
264 }
265 }
266
267 let mut last_error = None;
268
269 for attempt in 0..=max_retries {
270 let mut response = match provider.generate(request.clone()).await {
272 Ok(r) => r,
273 Err(e) => {
274 debug!("Generation attempt {} failed: {}", attempt + 1, e);
275 last_error = Some(e);
276 if attempt < max_retries {
277 tokio::time::sleep(std::time::Duration::from_millis(100 * (attempt as u64 + 1))).await;
278 continue;
279 }
280 return Err(last_error.unwrap());
281 }
282 };
283
284 if let Some(ref val) = validator {
286 if let Ok(formatted) = val.format(&request.slot.kind, &response.code) {
288 response.code = formatted;
289 }
290
291 match val.validate(&request.slot.kind, &response.code)? {
292 ValidationResult::Valid => {
293 if let (Some(ref c), Some(ref key)) = (cache, &cache_key) {
295 c.set(key, response.code.clone());
296 }
297 return Ok(response);
298 },
299 ValidationResult::Invalid(err_msg) => {
300 info!("Self-healing: Validation failed for slot '{}', attempt {}. Error: {}",
301 request.slot.name, attempt + 1, err_msg);
302
303 last_error = Some(AetherError::ProviderError(err_msg.clone()));
304
305 if attempt < max_retries {
306 request.slot.prompt = format!(
308 "{}\n\n[SELF-HEALING FEEDBACK]\nYour previous output had validation errors. Please fix them and output ONLY the corrected code.\nERROR:\n{}",
309 request.slot.prompt,
310 err_msg
311 );
312 continue;
313 }
314 }
315 }
316 } else {
317 if let (Some(ref c), Some(ref key)) = (cache, &cache_key) {
319 c.set(key, response.code.clone());
320 }
321 return Ok(response);
322 }
323 }
324
325 Err(last_error.unwrap_or_else(|| AetherError::ProviderError("Healing failed".to_string())))
326 }
327
328 pub async fn generate_slot(&self, template: &Template, slot_name: &str) -> Result<String> {
330 let slot = template
331 .slots
332 .get(slot_name)
333 .ok_or_else(|| AetherError::SlotNotFound(slot_name.to_string()))?;
334
335 let request = GenerationRequest {
336 slot: slot.clone(),
337 context: Some(self.global_context.to_prompt()),
338 system_prompt: None,
339 };
340
341 let response = self.generate_with_retry(request).await?;
342 Ok(response.code)
343 }
344
345 pub fn generate_slot_stream(
347 &self,
348 template: &Template,
349 slot_name: &str,
350 ) -> Result<BoxStream<'static, Result<StreamResponse>>> {
351 let slot = template
352 .slots
353 .get(slot_name)
354 .ok_or_else(|| AetherError::SlotNotFound(slot_name.to_string()))?;
355
356 let request = GenerationRequest {
357 slot: slot.clone(),
358 context: Some(self.global_context.to_prompt()),
359 system_prompt: None,
360 };
361
362 Ok(self.provider.generate_stream(request))
363 }
364}
365
/// Generates code for a single prompt with a throw-away engine.
///
/// `inject!(prompt, provider)` builds a one-slot template and an
/// [`InjectionEngine`] around `provider` immediately, and expands to a future
/// that renders the template when awaited. The future owns the template and
/// the engine, and resolves to the same `Result<String>` as
/// `InjectionEngine::render`.
///
/// The template and engine are moved into an `async move` block because the
/// future returned by `render` borrows both; returning that future directly
/// from the block that declares them does not compile.
#[macro_export]
macro_rules! inject {
    ($prompt:expr, $provider:expr) => {{
        // Evaluate the caller's expressions eagerly, in the caller's context.
        let template = $crate::Template::new("{{AI:generated}}")
            .with_slot("generated", $prompt);
        let engine = $crate::InjectionEngine::new($provider);

        async move { engine.render(&template).await }
    }};
}
385
/// Blocking counterpart of [`inject!`]: creates a Tokio runtime, runs the
/// generation to completion on it and evaluates to the resulting
/// `Result<String>`.
///
/// A failure to create the runtime is reported as
/// `AetherError::InjectionError` instead of panicking.
///
/// The calling crate must depend on `tokio`, since the expansion names it
/// directly.
///
/// NOTE(review): per Tokio's documentation `Runtime::block_on` panics when it
/// is called from inside an asynchronous context, so this macro is meant for
/// synchronous callers only.
#[macro_export]
macro_rules! inject_sync {
    ($prompt:expr, $provider:expr) => {{
        match ::tokio::runtime::Runtime::new() {
            Ok(runtime) => runtime.block_on($crate::inject!($prompt, $provider)),
            Err(err) => Err($crate::AetherError::InjectionError(err.to_string())),
        }
    }};
}
395
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::MockProvider;

    /// A single slot is replaced by the provider's canned response.
    #[tokio::test]
    async fn test_engine_render() {
        let template = Template::new("<div>{{AI:content}}</div>")
            .with_slot("content", "Generate a paragraph");
        let engine = InjectionEngine::new(
            MockProvider::new().with_response("content", "<p>Hello World</p>"),
        );

        let rendered = engine.render(&template).await.unwrap();

        assert_eq!(rendered, "<div><p>Hello World</p></div>");
    }

    /// Rendering still works when the engine carries a global context.
    #[tokio::test]
    async fn test_engine_with_context() {
        let template = Template::new("{{AI:button}}")
            .with_slot("button", "Create a button");
        let context = InjectionContext::new().with_framework("react");
        let engine = InjectionEngine::new(
            MockProvider::new().with_response("button", "<button class='btn'>Click</button>"),
        )
        .with_context(context);

        let rendered = engine.render(&template).await.unwrap();

        assert!(rendered.contains("button"));
    }

    /// With parallel generation enabled, every slot ends up in the output.
    #[tokio::test]
    async fn test_parallel_generation() {
        let template = Template::new("{{AI:slot1}} | {{AI:slot2}}");
        let mock = MockProvider::new()
            .with_response("slot1", "code1")
            .with_response("slot2", "code2");
        let engine = InjectionEngine::new(mock).parallel(true);

        let rendered = engine.render(&template).await.unwrap();

        assert!(rendered.contains("code1"));
        assert!(rendered.contains("code2"));
    }
}