1use crate::{
6 AetherError, AiProvider, InjectionContext, Result, Template,
7 provider::{GenerationRequest, GenerationResponse},
8};
9use std::collections::HashMap;
10use std::sync::Arc;
11use tracing::{debug, info, instrument};
12use futures::stream::BoxStream;
13use crate::provider::StreamResponse;
14use crate::validation::{Validator, ValidationResult};
15use crate::cache::Cache;
16use crate::toon::Toon;
17
18pub struct InjectionEngine<P: AiProvider> {
35 provider: Arc<P>,
37
38 validator: Option<Arc<dyn Validator>>,
40
41 cache: Option<Arc<dyn Cache>>,
43
44 use_toon: bool,
46
47 global_context: InjectionContext,
49
50 parallel: bool,
52
53 max_retries: u32,
55}
56
57impl<P: AiProvider + 'static> InjectionEngine<P> {
58 pub fn new(provider: P) -> Self {
60 Self {
61 provider: Arc::new(provider),
62 validator: None,
63 cache: None,
64 use_toon: false,
65 global_context: InjectionContext::default(),
66 parallel: true,
67 max_retries: 2,
68 }
69 }
70
71 pub fn with_cache(mut self, cache: impl Cache + 'static) -> Self {
73 self.cache = Some(Arc::new(cache));
74 self
75 }
76
77 pub fn with_toon(mut self, enabled: bool) -> Self {
79 self.use_toon = enabled;
80 self
81 }
82
83 pub fn with_validator(mut self, validator: impl Validator + 'static) -> Self {
85 self.validator = Some(Arc::new(validator));
86 self
87 }
88
89 pub fn with_context(mut self, context: InjectionContext) -> Self {
91 self.global_context = context;
92 self
93 }
94
95 pub fn parallel(mut self, enabled: bool) -> Self {
97 self.parallel = enabled;
98 self
99 }
100
101 pub fn max_retries(mut self, retries: u32) -> Self {
103 self.max_retries = retries;
104 self
105 }
106
107 #[instrument(skip(self, template), fields(template_name = %template.name))]
112 pub async fn render(&self, template: &Template) -> Result<String> {
113 info!("Rendering template: {}", template.name);
114
115 let injections = self.generate_all(template, None).await?;
116 template.render(&injections)
117 }
118
119 #[instrument(skip(self, template, context), fields(template_name = %template.name))]
121 pub async fn render_with_context(
122 &self,
123 template: &Template,
124 context: InjectionContext,
125 ) -> Result<String> {
126 info!("Rendering template with context: {}", template.name);
127
128 let injections = self.generate_all(template, Some(context)).await?;
129 template.render(&injections)
130 }
131
132 async fn generate_all(
135 &self,
136 template: &Template,
137 extra_context: Option<InjectionContext>,
138 ) -> Result<HashMap<String, String>> {
139 let mut injections = HashMap::new();
140
141 let context_prompt = if self.use_toon {
142 let toon_ctx = Toon::serialize(&serde_json::to_value(&self.global_context).unwrap_or(serde_json::json!({})));
144 format!("[CONTEXT:TOON]\n{}\n", toon_ctx)
145 } else if let Some(ref ctx) = extra_context {
146 format!("{}\n{}", self.global_context.to_prompt(), ctx.to_prompt())
147 } else {
148 self.global_context.to_prompt()
149 };
150
151 let mut context_prompt = context_prompt;
152 if self.validator.is_some() {
154 context_prompt.push_str("\n\nIMPORTANT: Include a Rust `#[cfg(test)] mod tests { ... }` block with at least one unit test to verify your code. The system will automatically run these tests to validate your output.");
155 }
156
157 if self.parallel {
158 injections = self
159 .generate_parallel(template, &context_prompt)
160 .await?;
161 } else {
162 for (name, slot) in &template.slots {
163 debug!("Generating code for slot: {}", name);
164
165 let request = GenerationRequest {
166 slot: slot.clone(),
167 context: Some(context_prompt.clone()),
168 system_prompt: None,
169 };
170
171 let response = self.generate_with_retry(request).await?;
172 injections.insert(name.clone(), response.code);
173 }
174 }
175
176 Ok(injections)
177 }
178
179 async fn generate_parallel(
181 &self,
182 template: &Template,
183 context_prompt: &str,
184 ) -> Result<HashMap<String, String>> {
185 use tokio::task::JoinSet;
186
187 let mut join_set = JoinSet::new();
188
189 for (name, slot) in template.slots.clone() {
190 let provider = Arc::clone(&self.provider);
191 let validator = self.validator.clone();
192 let cache = self.cache.clone();
193 let context = context_prompt.to_string();
194 let max_retries = self.max_retries;
195
196 join_set.spawn(async move {
197 let request = GenerationRequest {
198 slot,
199 context: Some(context),
200 system_prompt: None,
201 };
202
203 let response = Self::generate_with_healing_static(&provider, &validator, &cache, request, max_retries).await?;
204 Ok::<_, AetherError>((name, response.code))
205 });
206 }
207
208 let mut injections = HashMap::new();
209 while let Some(result) = join_set.join_next().await {
210 let (name, code) = result.map_err(|e| AetherError::InjectionError(e.to_string()))??;
211 injections.insert(name, code);
212 }
213
214 Ok(injections)
215 }
216
217 async fn generate_with_retry(&self, request: GenerationRequest) -> Result<GenerationResponse> {
219 Self::generate_with_healing_static(&self.provider, &self.validator, &self.cache, request, self.max_retries).await
220 }
221
222 async fn generate_with_healing_static(
224 provider: &Arc<P>,
225 validator: &Option<Arc<dyn Validator>>,
226 cache: &Option<Arc<dyn Cache>>,
227 mut request: GenerationRequest,
228 max_retries: u32,
229 ) -> Result<GenerationResponse> {
230 let cache_key = if cache.is_some() {
232 Some(format!("{}:{}", request.slot.prompt, request.context.as_deref().unwrap_or("")))
233 } else {
234 None
235 };
236
237 if let (Some(ref c), Some(ref key)) = (cache, &cache_key) {
238 if let Some(cached_code) = c.get(key) {
239 debug!("Cache hit for slot: {}", request.slot.name);
240 return Ok(GenerationResponse {
241 code: cached_code,
242 tokens_used: None,
243 metadata: Some(serde_json::json!({"cache": "hit"})),
244 });
245 }
246 }
247
248 let mut last_error = None;
249
250 for attempt in 0..=max_retries {
251 let mut response = match provider.generate(request.clone()).await {
253 Ok(r) => r,
254 Err(e) => {
255 debug!("Generation attempt {} failed: {}", attempt + 1, e);
256 last_error = Some(e);
257 if attempt < max_retries {
258 tokio::time::sleep(std::time::Duration::from_millis(100 * (attempt as u64 + 1))).await;
259 continue;
260 }
261 return Err(last_error.unwrap());
262 }
263 };
264
265 if let Some(ref val) = validator {
267 if let Ok(formatted) = val.format(&request.slot.kind, &response.code) {
269 response.code = formatted;
270 }
271
272 match val.validate(&request.slot.kind, &response.code)? {
273 ValidationResult::Valid => {
274 if let (Some(ref c), Some(ref key)) = (cache, &cache_key) {
276 c.set(key, response.code.clone());
277 }
278 return Ok(response);
279 },
280 ValidationResult::Invalid(err_msg) => {
281 info!("Self-healing: Validation failed for slot '{}', attempt {}. Error: {}",
282 request.slot.name, attempt + 1, err_msg);
283
284 last_error = Some(AetherError::ProviderError(err_msg.clone()));
285
286 if attempt < max_retries {
287 request.slot.prompt = format!(
289 "{}\n\n[SELF-HEALING FEEDBACK]\nYour previous output had validation errors. Please fix them and output ONLY the corrected code.\nERROR:\n{}",
290 request.slot.prompt,
291 err_msg
292 );
293 continue;
294 }
295 }
296 }
297 } else {
298 if let (Some(ref c), Some(ref key)) = (cache, &cache_key) {
300 c.set(key, response.code.clone());
301 }
302 return Ok(response);
303 }
304 }
305
306 Err(last_error.unwrap_or_else(|| AetherError::ProviderError("Healing failed".to_string())))
307 }
308
309 pub async fn generate_slot(&self, template: &Template, slot_name: &str) -> Result<String> {
311 let slot = template
312 .slots
313 .get(slot_name)
314 .ok_or_else(|| AetherError::SlotNotFound(slot_name.to_string()))?;
315
316 let request = GenerationRequest {
317 slot: slot.clone(),
318 context: Some(self.global_context.to_prompt()),
319 system_prompt: None,
320 };
321
322 let response = self.generate_with_retry(request).await?;
323 Ok(response.code)
324 }
325
326 pub fn generate_slot_stream(
328 &self,
329 template: &Template,
330 slot_name: &str,
331 ) -> Result<BoxStream<'static, Result<StreamResponse>>> {
332 let slot = template
333 .slots
334 .get(slot_name)
335 .ok_or_else(|| AetherError::SlotNotFound(slot_name.to_string()))?;
336
337 let request = GenerationRequest {
338 slot: slot.clone(),
339 context: Some(self.global_context.to_prompt()),
340 system_prompt: None,
341 };
342
343 Ok(self.provider.generate_stream(request))
344 }
345}
346
#[macro_export]
/// Generates code for a single prompt using a fresh [`InjectionEngine`] around `$provider`.
///
/// Expands to a future that resolves to `Result<String>`, so it must be `.await`ed:
///
/// ```ignore
/// let code = inject!("Generate a greeting function", provider).await?;
/// ```
///
/// `$prompt` and `$provider` are evaluated eagerly when the macro expands. The template
/// and engine are then moved into the returned future.
macro_rules! inject {
    ($prompt:expr, $provider:expr) => {{
        let template = $crate::Template::new("{{AI:generated}}").with_slot("generated", $prompt);
        let engine = $crate::InjectionEngine::new($provider);
        // `render` borrows both the engine and the template for as long as it runs.
        // Returning `engine.render(&template)` directly would borrow locals of this
        // block after they are dropped, so move them into the future instead.
        async move { engine.render(&template).await }
    }};
}
366
#[macro_export]
/// Blocking variant of `inject!`: generates code for a single prompt on a new Tokio
/// runtime and returns `Result<String>`.
///
/// The calling crate must depend on `tokio` directly, because the expansion names it.
///
/// # Panics
/// - If the Tokio runtime cannot be created.
/// - If invoked from inside an asynchronous Tokio context, because `Runtime::block_on`
///   panics there. Use `inject!(...).await` in async code instead.
macro_rules! inject_sync {
    ($prompt:expr, $provider:expr) => {{
        let runtime = tokio::runtime::Runtime::new()
            .expect("inject_sync!: failed to create a Tokio runtime");
        let template = $crate::Template::new("{{AI:generated}}").with_slot("generated", $prompt);
        let engine = $crate::InjectionEngine::new($provider);
        // `block_on` drives the borrowing future to completion while `engine` and
        // `template` are still alive, and returns an owned result.
        runtime.block_on(engine.render(&template))
    }};
}
376
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::MockProvider;

    #[tokio::test]
    async fn test_engine_render() {
        let provider = MockProvider::new()
            .with_response("content", "<p>Hello World</p>");

        let engine = InjectionEngine::new(provider);

        let template = Template::new("<div>{{AI:content}}</div>")
            .with_slot("content", "Generate a paragraph");

        let result = engine.render(&template).await.unwrap();
        assert_eq!(result, "<div><p>Hello World</p></div>");
    }

    #[tokio::test]
    async fn test_engine_with_context() {
        let provider = MockProvider::new()
            .with_response("button", "<button class='btn'>Click</button>");

        let engine = InjectionEngine::new(provider)
            .with_context(InjectionContext::new().with_framework("react"));

        let template = Template::new("{{AI:button}}")
            .with_slot("button", "Create a button");

        let result = engine.render(&template).await.unwrap();
        assert!(result.contains("button"));
    }

    #[tokio::test]
    async fn test_parallel_generation() {
        let provider = MockProvider::new()
            .with_response("slot1", "code1")
            .with_response("slot2", "code2");

        let engine = InjectionEngine::new(provider).parallel(true);

        let template = Template::new("{{AI:slot1}} | {{AI:slot2}}");

        let result = engine.render(&template).await.unwrap();
        assert!(result.contains("code1"));
        assert!(result.contains("code2"));
    }

    // `parallel(true)` is the default, so the sequential path needs its own test.
    #[tokio::test]
    async fn test_sequential_generation() {
        let provider = MockProvider::new()
            .with_response("slot1", "code1")
            .with_response("slot2", "code2");

        let engine = InjectionEngine::new(provider).parallel(false);

        let template = Template::new("{{AI:slot1}} | {{AI:slot2}}");

        let result = engine.render(&template).await.unwrap();
        assert!(result.contains("code1"));
        assert!(result.contains("code2"));
    }

    #[tokio::test]
    async fn test_generate_slot_missing() {
        let engine = InjectionEngine::new(MockProvider::new());

        let template = Template::new("{{AI:present}}");

        let err = engine.generate_slot(&template, "absent").await.unwrap_err();
        assert!(matches!(err, AetherError::SlotNotFound(ref name) if name == "absent"));
    }
}