Skip to main content

apcore_toolkit/
ai_enhancer.rs

// AI-driven metadata enhancement using local SLMs.
//
// Uses an OpenAI-compatible local API (e.g., Ollama, vLLM, LM Studio) to fill
// metadata gaps that static analysis cannot resolve.
//
// Every module the SLM was consulted for is tagged with `x-generated-by: slm`
// in its metadata, with per-field confidence scores under `x-ai-confidence`,
// for auditability.
8
9use std::env;
10use std::time::Duration;
11
12use serde_json::{json, Value};
13use thiserror::Error;
14use tracing::warn;
15
16use apcore::module::ModuleAnnotations;
17
18use crate::types::ScannedModule;
19
/// Fallback OpenAI-compatible base URL, used when neither the constructor
/// argument nor `APCORE_AI_ENDPOINT` is set.
const DEFAULT_ENDPOINT: &str = "http://localhost:11434/v1";
/// Fallback model name (overridden by the constructor or `APCORE_AI_MODEL`).
const DEFAULT_MODEL: &str = "qwen:0.6b";
/// Fallback minimum confidence (0.0–1.0) a suggestion needs to be applied.
const DEFAULT_THRESHOLD: f64 = 0.7;
/// Fallback value for `AIEnhancer::batch_size`.
const DEFAULT_BATCH_SIZE: usize = 5;
/// Fallback per-request timeout, in seconds.
const DEFAULT_TIMEOUT: u64 = 30;

/// Every field the SLM is asked to report a confidence score for.
///
/// These names become the keys of the `"confidence"` object in the prompt
/// template. They must therefore match the fields the prompt actually
/// requests:
/// - the top-level `description`, `documentation` and `input_schema`;
/// - every key of the prompt's `"annotations"` object.
///
/// If a field is missing here, the model is never shown a confidence slot for
/// it and will likely omit the score. The score then defaults to 0.0, so the
/// suggestion can never pass a positive threshold.
///
/// Keep in sync with the annotation template in `build_prompt` and with
/// `apcore::module::ModuleAnnotations` when upstream adds fields.
const ANNOTATION_FIELDS: &[&str] = &[
    "description",
    "documentation",
    "input_schema",
    "readonly",
    "destructive",
    "idempotent",
    "requires_approval",
    "open_world",
    "streaming",
    "cacheable",
    "cache_ttl",
    "cache_key_fields",
    "paginated",
    "pagination_style",
];
42
/// Errors returned by [`AIEnhancer`] operations.
///
/// Note: `AIEnhancer`'s [`Enhancer::enhance`] implementation does not
/// propagate these. Per-module failures are logged and the original module is
/// kept.
#[derive(Debug, Error)]
pub enum AIEnhancerError {
    /// Invalid configuration value (out-of-range threshold, zero batch size or
    /// zero timeout), reported by `AIEnhancer::new`.
    #[error("invalid config: {0}")]
    Config(String),
    /// Failed to reach the SLM endpoint: the HTTP request itself returned an
    /// error.
    #[error("connection failed: {0}")]
    Connection(String),
    /// The SLM returned an unparseable response: the body was not JSON, lacked
    /// `choices[0].message.content`, or the content was not valid JSON.
    #[error("bad response: {0}")]
    Response(String),
    /// Failed to build the HTTP agent.
    ///
    /// NOTE(review): not constructed anywhere in this file. It is presumably
    /// kept for API compatibility; confirm before removing.
    #[error("agent build failed: {0}")]
    AgentBuild(String),
}
59
/// Protocol for pluggable metadata enhancement.
///
/// Implementors take ownership of a batch of scanned modules and return the
/// (possibly modified) modules.
///
/// # Blocking / async compatibility
///
/// `enhance` is a synchronous method. The bundled [`AIEnhancer`] performs
/// blocking HTTP requests via `ureq`, so each call may park the current
/// thread for up to `APCORE_AI_TIMEOUT` seconds (default 30) per module.
/// **Do not call `enhance` directly from an async task** on a Tokio (or
/// other async) runtime — it will block a runtime worker thread and can
/// stall the scheduler under concurrent load.
///
/// From an async context, wrap the call in
/// [`tokio::task::spawn_blocking`]:
///
/// ```ignore
/// let enhanced = tokio::task::spawn_blocking(move || enhancer.enhance(modules)).await?;
/// ```
///
/// Enhancement is a one-shot scanning-phase operation (not per-request),
/// so this is typically invoked once during framework adapter bootstrap.
pub trait Enhancer {
    /// Enhance a list of ScannedModules by filling metadata gaps.
    ///
    /// Synchronous and potentially long-running. See the trait-level doc
    /// comment for guidance on invoking from async contexts. The signature is
    /// infallible, so implementations must handle their own errors, e.g. by
    /// returning the affected module unchanged.
    fn enhance(&self, modules: Vec<ScannedModule>) -> Vec<ScannedModule>;
}
87
/// Enhances ScannedModule metadata using a local SLM.
///
/// Configuration is read from constructor parameters first, then environment
/// variables, then built-in defaults:
/// - `APCORE_AI_ENABLED`: Enable enhancement (default: false). Only read by
///   [`AIEnhancer::is_enabled`]; `new` does not consult it.
/// - `APCORE_AI_ENDPOINT`: OpenAI-compatible API URL.
/// - `APCORE_AI_MODEL`: Model name.
/// - `APCORE_AI_THRESHOLD`: Confidence threshold (0.0–1.0).
/// - `APCORE_AI_BATCH_SIZE`: Intended as "modules per API call".
///   NOTE(review): each module is currently sent in its own request; confirm
///   the intended semantics.
/// - `APCORE_AI_TIMEOUT`: Timeout in seconds per API call.
#[derive(Debug)]
pub struct AIEnhancer {
    /// Base URL of the OpenAI-compatible API; `/chat/completions` is appended
    /// (after trimming trailing slashes) for each request.
    pub endpoint: String,
    /// Model name sent in the `model` field of each chat-completions request.
    pub model: String,
    /// Minimum confidence a suggestion needs (`>=`) to be applied; validated to
    /// lie in `0.0..=1.0` by `new`.
    pub threshold: f64,
    /// Validated to be non-zero by `new`. See the type-level note on
    /// `APCORE_AI_BATCH_SIZE`.
    pub batch_size: usize,
    /// Per-request timeout in seconds. Only applied when the agent is built in
    /// `new`; changing this field afterwards does not affect `agent`.
    pub timeout: u64,
    // Reused across all call_llm() invocations to avoid rebuilding config per call.
    agent: ureq::Agent,
}
107
108impl AIEnhancer {
109    /// Create a new AIEnhancer with optional overrides.
110    ///
111    /// Falls back to environment variables, then defaults.
112    pub fn new(
113        endpoint: Option<String>,
114        model: Option<String>,
115        threshold: Option<f64>,
116        batch_size: Option<usize>,
117        timeout: Option<u64>,
118    ) -> Result<Self, AIEnhancerError> {
119        let endpoint = endpoint.unwrap_or_else(|| {
120            env::var("APCORE_AI_ENDPOINT").unwrap_or_else(|_| DEFAULT_ENDPOINT.into())
121        });
122        let model = model.unwrap_or_else(|| {
123            env::var("APCORE_AI_MODEL").unwrap_or_else(|_| DEFAULT_MODEL.into())
124        });
125        let threshold =
126            threshold.unwrap_or_else(|| parse_float_env("APCORE_AI_THRESHOLD", DEFAULT_THRESHOLD));
127        let batch_size = batch_size
128            .unwrap_or_else(|| parse_usize_env("APCORE_AI_BATCH_SIZE", DEFAULT_BATCH_SIZE));
129        let timeout =
130            timeout.unwrap_or_else(|| parse_u64_env("APCORE_AI_TIMEOUT", DEFAULT_TIMEOUT));
131
132        if !(0.0..=1.0).contains(&threshold) {
133            return Err(AIEnhancerError::Config(
134                "APCORE_AI_THRESHOLD must be between 0.0 and 1.0".into(),
135            ));
136        }
137        if batch_size == 0 {
138            return Err(AIEnhancerError::Config(
139                "APCORE_AI_BATCH_SIZE must be a positive integer".into(),
140            ));
141        }
142        if timeout == 0 {
143            return Err(AIEnhancerError::Config(
144                "APCORE_AI_TIMEOUT must be a positive integer".into(),
145            ));
146        }
147
148        let agent = ureq::Agent::config_builder()
149            .timeout_global(Some(Duration::from_secs(timeout)))
150            .build()
151            .new_agent();
152
153        Ok(Self {
154            endpoint,
155            model,
156            threshold,
157            batch_size,
158            timeout,
159            agent,
160        })
161    }
162
163    /// Check whether AI enhancement is enabled via environment.
164    pub fn is_enabled() -> bool {
165        env::var("APCORE_AI_ENABLED")
166            .map(|v| matches!(v.to_lowercase().as_str(), "true" | "1" | "yes"))
167            .unwrap_or(false)
168    }
169
170    /// Identify which metadata fields are missing or at defaults.
171    fn identify_gaps(&self, module: &ScannedModule) -> Vec<String> {
172        let mut gaps: Vec<String> = Vec::new();
173
174        if module.description.is_empty() || module.description == module.module_id {
175            gaps.push("description".into());
176        }
177        if module.documentation.is_none() {
178            gaps.push("documentation".into());
179        }
180        if module.annotations.is_none()
181            || module
182                .annotations
183                .as_ref()
184                .is_some_and(is_default_annotations)
185        {
186            gaps.push("annotations".into());
187        }
188        if module
189            .input_schema
190            .get("properties")
191            .and_then(|p| p.as_object())
192            .map(|o| o.is_empty())
193            .unwrap_or(true)
194        {
195            gaps.push("input_schema".into());
196        }
197
198        gaps
199    }
200
201    /// Build a structured prompt for the SLM.
202    fn build_prompt(&self, module: &ScannedModule, gaps: &[String]) -> String {
203        let mut parts = vec![
204            "You are analyzing a function to generate metadata for an AI-perceivable module system.".into(),
205            String::new(),
206            format!("Module ID: {}", module.module_id),
207            format!("Target: {}", module.target),
208        ];
209
210        if !module.description.is_empty() {
211            parts.push(format!("Current description: {}", module.description));
212        }
213
214        parts.push(String::new());
215        parts.push("Please provide the following missing metadata as JSON:".into());
216        parts.push("{".into());
217
218        for gap in gaps {
219            match gap.as_str() {
220                "description" => {
221                    parts.push(
222                        r#"  "description": "<≤200 chars, what this function does>","#.into(),
223                    );
224                }
225                "documentation" => {
226                    parts.push(r#"  "documentation": "<detailed Markdown explanation>","#.into());
227                }
228                "annotations" => {
229                    parts.push(r#"  "annotations": {"#.into());
230                    parts.push(r#"    "readonly": <true if no side effects>,"#.into());
231                    parts.push(r#"    "destructive": <true if deletes/overwrites data>,"#.into());
232                    parts.push(r#"    "idempotent": <true if safe to retry>,"#.into());
233                    parts.push(r#"    "requires_approval": <true if dangerous operation>,"#.into());
234                    parts.push(r#"    "open_world": <true if calls external systems>,"#.into());
235                    parts
236                        .push(r#"    "streaming": <true if yields results incrementally>,"#.into());
237                    parts.push(r#"    "cacheable": <true if results can be cached>,"#.into());
238                    parts.push(r#"    "cache_ttl": <seconds, 0 for no expiry>,"#.into());
239                    parts.push(r#"    "cache_key_fields": <list of input field names for cache key, or null for all>,"#.into());
240                    parts.push(r#"    "paginated": <true if supports pagination>,"#.into());
241                    parts
242                        .push(r#"    "pagination_style": <"cursor" or "offset" or "page">"#.into());
243                    parts.push("  },".into());
244                }
245                "input_schema" => {
246                    parts.push(
247                        r#"  "input_schema": <JSON Schema object for function parameters>,"#.into(),
248                    );
249                }
250                _ => {}
251            }
252        }
253
254        let confidence_keys: serde_json::Value = ANNOTATION_FIELDS
255            .iter()
256            .map(|&field| (field.to_string(), serde_json::json!(0.0)))
257            .collect::<serde_json::Map<_, _>>()
258            .into();
259        let confidence_str =
260            serde_json::to_string_pretty(&confidence_keys).unwrap_or_else(|_| "{}".into());
261        parts.push(format!(r#"  "confidence": {confidence_str}"#));
262        parts.push("}".into());
263        parts.push(String::new());
264        parts.push("Respond with ONLY valid JSON, no markdown fences or explanation.".into());
265
266        parts.join("\n")
267    }
268
269    /// Call the OpenAI-compatible API and return the response text.
270    fn call_llm(&self, prompt: &str) -> Result<String, AIEnhancerError> {
271        let url = format!("{}/chat/completions", self.endpoint.trim_end_matches('/'));
272        let payload = json!({
273            "model": self.model,
274            "messages": [{"role": "user", "content": prompt}],
275            "temperature": 0.1,
276        });
277
278        let body: Value = self
279            .agent
280            .post(&url)
281            .header("Content-Type", "application/json")
282            .send_json(&payload)
283            .map_err(|e| AIEnhancerError::Connection(format!("Failed to reach SLM at {url}: {e}")))?
284            .body_mut()
285            .read_json()
286            .map_err(|e| AIEnhancerError::Response(format!("Failed to parse SLM response: {e}")))?;
287
288        body["choices"][0]["message"]["content"]
289            .as_str()
290            .map(|s| s.to_string())
291            .ok_or_else(|| AIEnhancerError::Response("Unexpected API response structure".into()))
292    }
293
294    /// Parse the SLM response as JSON, stripping markdown fences if present.
295    fn parse_response(response: &str) -> Result<Value, AIEnhancerError> {
296        let mut text = response.trim().to_string();
297
298        // Strip markdown code fences if the response is more than one line
299        if text.starts_with("```") {
300            let lines: Vec<&str> = text.split('\n').collect();
301            if lines.len() > 1 {
302                let start = if lines[0].starts_with("```") { 1 } else { 0 };
303                let end = if lines.last().map(|l| l.trim()) == Some("```") {
304                    lines.len() - 1
305                } else {
306                    lines.len()
307                };
308                text = lines[start..end].join("\n");
309            }
310        }
311
312        serde_json::from_str(&text)
313            .map_err(|e| AIEnhancerError::Response(format!("SLM returned invalid JSON: {e}")))
314    }
315
316    /// Enhance a single module by calling the SLM.
317    fn enhance_module(
318        &self,
319        module: &ScannedModule,
320        gaps: &[String],
321    ) -> Result<ScannedModule, AIEnhancerError> {
322        let prompt = self.build_prompt(module, gaps);
323        let response = self.call_llm(&prompt)?;
324        let parsed = Self::parse_response(&response)?;
325
326        let mut result = module.clone();
327        let mut confidence: serde_json::Map<String, Value> = serde_json::Map::new();
328
329        // Apply description
330        if gaps.iter().any(|g| g == "description") {
331            if let Some(desc) = parsed.get("description").and_then(|v| v.as_str()) {
332                let conf = parsed
333                    .get("confidence")
334                    .and_then(|c| c.get("description"))
335                    .and_then(|v| v.as_f64())
336                    .unwrap_or(0.0);
337                confidence.insert("description".into(), json!(conf));
338                if conf >= self.threshold {
339                    result.description = clamp_str(desc, 500, &module.module_id, "description");
340                } else {
341                    result.warnings.push(format!(
342                        "Low confidence ({conf:.2}) for description — skipped. Review manually."
343                    ));
344                }
345            }
346        }
347
348        // Apply documentation
349        if gaps.iter().any(|g| g == "documentation") {
350            if let Some(doc) = parsed.get("documentation").and_then(|v| v.as_str()) {
351                let conf = parsed
352                    .get("confidence")
353                    .and_then(|c| c.get("documentation"))
354                    .and_then(|v| v.as_f64())
355                    .unwrap_or(0.0);
356                confidence.insert("documentation".into(), json!(conf));
357                if conf >= self.threshold {
358                    result.documentation = Some(strip_ansi(&clamp_str(
359                        doc,
360                        2000,
361                        &module.module_id,
362                        "documentation",
363                    )));
364                } else {
365                    result.warnings.push(format!(
366                        "Low confidence ({conf:.2}) for documentation — skipped. Review manually."
367                    ));
368                }
369            }
370        }
371
372        // Apply annotations if above threshold (per-field confidence)
373        if gaps.iter().any(|g| g == "annotations") {
374            if let Some(ann_data) = parsed.get("annotations").and_then(|v| v.as_object()) {
375                let ann_conf = parsed
376                    .get("confidence")
377                    .and_then(|v| v.as_object())
378                    .cloned()
379                    .unwrap_or_default();
380                let mut base = module.annotations.clone().unwrap_or_default();
381                let mut any_accepted = false;
382
383                // Iterate boolean fields supplied by the SLM directly.
384                // `set_bool_annotation` validates each field's existence
385                // on `ModuleAnnotations` via a serde round-trip, so the
386                // set of known bool fields lives in one place — the
387                // upstream struct — and new fields added upstream are
388                // picked up automatically.
389                for (field, field_val) in ann_data.iter() {
390                    let Some(bool_val) = field_val.as_bool() else {
391                        continue;
392                    };
393                    let field_conf = get_annotation_confidence(&ann_conf, field);
394                    confidence.insert(format!("annotations.{field}"), json!(field_conf));
395                    if field_conf >= self.threshold {
396                        if set_bool_annotation(&mut base, field, bool_val) {
397                            any_accepted = true;
398                        } else {
399                            result.warnings.push(format!(
400                                "SLM returned unknown bool annotation '{field}' — ignored."
401                            ));
402                        }
403                    } else {
404                        result.warnings.push(format!(
405                            "Low confidence ({field_conf:.2}) for annotations.{field} — skipped. Review manually."
406                        ));
407                    }
408                }
409
410                // Integer fields: cache_ttl
411                if let Some(val) = ann_data.get("cache_ttl").and_then(|v| v.as_u64()) {
412                    let field_conf = get_annotation_confidence(&ann_conf, "cache_ttl");
413                    confidence.insert("annotations.cache_ttl".into(), json!(field_conf));
414                    if field_conf >= self.threshold {
415                        base.cache_ttl = val;
416                        any_accepted = true;
417                    } else {
418                        result.warnings.push(format!(
419                            "Low confidence ({field_conf:.2}) for annotations.cache_ttl — skipped. Review manually."
420                        ));
421                    }
422                }
423
424                // String fields: pagination_style
425                if let Some(val) = ann_data.get("pagination_style").and_then(|v| v.as_str()) {
426                    let field_conf = get_annotation_confidence(&ann_conf, "pagination_style");
427                    confidence.insert("annotations.pagination_style".into(), json!(field_conf));
428                    if field_conf >= self.threshold {
429                        base.pagination_style = val.to_string();
430                        any_accepted = true;
431                    } else {
432                        result.warnings.push(format!(
433                            "Low confidence ({field_conf:.2}) for annotations.pagination_style — skipped. Review manually."
434                        ));
435                    }
436                }
437
438                // List fields: cache_key_fields
439                if let Some(arr) = ann_data.get("cache_key_fields").and_then(|v| v.as_array()) {
440                    let field_conf = get_annotation_confidence(&ann_conf, "cache_key_fields");
441                    confidence.insert("annotations.cache_key_fields".into(), json!(field_conf));
442                    if field_conf >= self.threshold {
443                        let keys: Vec<String> = arr
444                            .iter()
445                            .filter_map(|v| v.as_str().map(|s| s.to_string()))
446                            .collect();
447                        base.cache_key_fields = Some(keys);
448                        any_accepted = true;
449                    } else {
450                        result.warnings.push(format!(
451                            "Low confidence ({field_conf:.2}) for annotations.cache_key_fields — skipped. Review manually."
452                        ));
453                    }
454                }
455
456                if any_accepted {
457                    result.annotations = Some(base);
458                }
459            }
460        }
461
462        // Apply input_schema if above threshold
463        if gaps.iter().any(|g| g == "input_schema") {
464            if let Some(schema) = parsed.get("input_schema") {
465                let conf = parsed
466                    .get("confidence")
467                    .and_then(|c| c.get("input_schema"))
468                    .and_then(|v| v.as_f64())
469                    .unwrap_or(0.0);
470                confidence.insert("input_schema".into(), json!(conf));
471                if conf >= self.threshold {
472                    result.input_schema = schema.clone();
473                } else {
474                    result.warnings.push(format!(
475                        "Low confidence ({conf:.2}) for input_schema — skipped. Review manually."
476                    ));
477                }
478            }
479        }
480
481        // Tag AI-generated fields
482        if !confidence.is_empty() {
483            result
484                .metadata
485                .insert("x-generated-by".into(), Value::String("slm".into()));
486            result
487                .metadata
488                .insert("x-ai-confidence".into(), Value::Object(confidence));
489        }
490
491        Ok(result)
492    }
493}
494
495impl Enhancer for AIEnhancer {
496    fn enhance(&self, modules: Vec<ScannedModule>) -> Vec<ScannedModule> {
497        let mut results: Vec<ScannedModule> = Vec::with_capacity(modules.len());
498
499        let mut pending: Vec<(usize, Vec<String>)> = Vec::new();
500        for (idx, module) in modules.iter().enumerate() {
501            let gaps = self.identify_gaps(module);
502            results.push(module.clone());
503            if !gaps.is_empty() {
504                pending.push((idx, gaps));
505            }
506        }
507
508        for batch in pending.chunks(self.batch_size) {
509            for (idx, gaps) in batch {
510                match self.enhance_module(&modules[*idx], gaps) {
511                    Ok(enhanced) => results[*idx] = enhanced,
512                    Err(e) => {
513                        warn!("AI enhancement failed for {}: {e}", modules[*idx].module_id);
514                    }
515                }
516            }
517        }
518
519        results
520    }
521}
522
523/// Check whether annotations are at their default values.
524///
525/// Uses `serde_json` round-trip equality so the comparison automatically
526/// covers any new field added to `apcore::module::ModuleAnnotations` upstream
527/// (including the `extra` extension map). `ModuleAnnotations` does not
528/// implement `PartialEq`, so direct `==` is unavailable.
529fn is_default_annotations(ann: &ModuleAnnotations) -> bool {
530    match (
531        serde_json::to_value(ann),
532        serde_json::to_value(ModuleAnnotations::default()),
533    ) {
534        (Ok(a), Ok(b)) => a == b,
535        _ => false,
536    }
537}
538
539/// Get confidence for an annotation field, checking both `annotations.<field>` and `<field>` keys.
540fn get_annotation_confidence(conf: &serde_json::Map<String, Value>, field: &str) -> f64 {
541    conf.get(&format!("annotations.{field}"))
542        .or_else(|| conf.get(field))
543        .and_then(|v| v.as_f64())
544        .unwrap_or(0.0)
545}
546
547/// Set a boolean field on `ModuleAnnotations` by name via a serde
548/// round-trip. Returns `true` if the field exists on the struct and is a
549/// boolean; `false` if the field is unknown, non-boolean, or the
550/// round-trip fails. Using serde rather than a hardcoded match removes
551/// the two-list drift risk — new bool fields added to
552/// `apcore::module::ModuleAnnotations` upstream are picked up
553/// automatically.
554fn set_bool_annotation(ann: &mut ModuleAnnotations, field: &str, value: bool) -> bool {
555    let mut serialized = match serde_json::to_value(&ann) {
556        Ok(v) => v,
557        Err(e) => {
558            warn!("set_bool_annotation: serialize failed: {e}");
559            return false;
560        }
561    };
562    let Some(obj) = serialized.as_object_mut() else {
563        return false;
564    };
565    match obj.get(field) {
566        Some(Value::Bool(_)) => {
567            obj.insert(field.to_string(), Value::Bool(value));
568        }
569        // Field absent, or present but not a bool — reject rather than
570        // fabricate a new key (serde would happily accept unknown keys
571        // via `#[serde(extra)]` on ModuleAnnotations, but misclassifying
572        // a non-bool field as bool would corrupt the struct).
573        _ => return false,
574    }
575    match serde_json::from_value::<ModuleAnnotations>(serialized) {
576        Ok(new_ann) => {
577            *ann = new_ann;
578            true
579        }
580        Err(e) => {
581            warn!("set_bool_annotation: deserialize failed: {e}");
582            false
583        }
584    }
585}
586
587fn parse_float_env(name: &str, default: f64) -> f64 {
588    match env::var(name) {
589        Ok(v) => v.parse().unwrap_or_else(|_| {
590            warn!(env_var = name, value = %v, "unparseable float env var — using default {default}");
591            default
592        }),
593        Err(_) => default,
594    }
595}
596
597fn parse_usize_env(name: &str, default: usize) -> usize {
598    match env::var(name) {
599        Ok(v) => v.parse().unwrap_or_else(|_| {
600            warn!(env_var = name, value = %v, "unparseable usize env var — using default {default}");
601            default
602        }),
603        Err(_) => default,
604    }
605}
606
607fn parse_u64_env(name: &str, default: u64) -> u64 {
608    match env::var(name) {
609        Ok(v) => v.parse().unwrap_or_else(|_| {
610            warn!(env_var = name, value = %v, "unparseable u64 env var — using default {default}");
611            default
612        }),
613        Err(_) => default,
614    }
615}
616
617/// Clamp an SLM-supplied string to `max_chars` bytes, warning if truncated.
618fn clamp_str(s: &str, max_chars: usize, module_id: &str, field: &str) -> String {
619    if s.len() <= max_chars {
620        return s.to_string();
621    }
622    // Truncate at a char boundary.
623    let truncated = &s[..s
624        .char_indices()
625        .take_while(|(i, _)| *i < max_chars)
626        .last()
627        .map(|(i, c)| i + c.len_utf8())
628        .unwrap_or(max_chars)];
629    tracing::warn!(
630        module_id = %module_id,
631        field = %field,
632        original_len = s.len(),
633        clamped_len = truncated.len(),
634        "SLM-supplied field truncated to prevent oversized output"
635    );
636    truncated.to_string()
637}
638
/// Strip ANSI / ECMA-48 terminal escape sequences from a string.
///
/// SLM output is untrusted and may be printed to a terminal, so this removes:
/// - CSI sequences: `ESC [`, then parameter/intermediate bytes
///   (`0x20..=0x3F`), then one final byte (`0x40..=0x7E`). Examples are
///   colours (`ESC[31m`) and key codes (`ESC[2~`). A malformed CSI ends at the
///   first character outside those ranges, and that character is kept.
/// - OSC sequences: `ESC ]` up to BEL or the next `ESC`, e.g. window-title and
///   hyperlink escapes.
/// - String Terminators (`ESC \`) and any other stray `ESC`. The character
///   after a stray `ESC` is kept as ordinary text.
fn strip_ansi(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    let mut chars = s.chars().peekable();
    while let Some(c) = chars.next() {
        if c != '\x1b' {
            out.push(c);
            continue;
        }
        match chars.peek().copied() {
            Some('[') => {
                chars.next();
                while let Some(&next) = chars.peek() {
                    match next {
                        // Parameter / intermediate bytes: part of the sequence.
                        '\u{20}'..='\u{3f}' => {
                            chars.next();
                        }
                        // Final byte: consume it and end the sequence.
                        '\u{40}'..='\u{7e}' => {
                            chars.next();
                            break;
                        }
                        // Malformed: end here; the outer loop keeps `next`.
                        _ => break,
                    }
                }
            }
            Some(']') => {
                chars.next();
                // Stop before an ESC so the outer loop can drop a following
                // ST (`ESC \`) or process a new sequence. BEL is consumed.
                while let Some(&next) = chars.peek() {
                    if next == '\x1b' {
                        break;
                    }
                    chars.next();
                    if next == '\x07' {
                        break;
                    }
                }
            }
            Some('\\') => {
                // String Terminator (`ESC \`): drop it.
                chars.next();
            }
            // Stray ESC (including at end of input): drop just the ESC.
            _ => {}
        }
    }
    out
}
657
658#[cfg(test)]
659mod tests {
660    use super::*;
661    use apcore::module::ModuleAnnotations;
662    use serde_json::json;
663
664    #[test]
665    fn test_ai_enhancer_new_defaults() {
666        let enhancer = AIEnhancer::new(None, None, None, None, None).unwrap();
667        assert_eq!(enhancer.endpoint, DEFAULT_ENDPOINT);
668        assert_eq!(enhancer.model, DEFAULT_MODEL);
669        assert!((enhancer.threshold - DEFAULT_THRESHOLD).abs() < f64::EPSILON);
670        assert_eq!(enhancer.batch_size, DEFAULT_BATCH_SIZE);
671        assert_eq!(enhancer.timeout, DEFAULT_TIMEOUT);
672    }
673
674    #[test]
675    fn test_ai_enhancer_new_with_overrides() {
676        let enhancer = AIEnhancer::new(
677            Some("http://custom:8080".into()),
678            Some("llama3".into()),
679            Some(0.5),
680            Some(10),
681            Some(60),
682        )
683        .unwrap();
684        assert_eq!(enhancer.endpoint, "http://custom:8080");
685        assert_eq!(enhancer.model, "llama3");
686        assert!((enhancer.threshold - 0.5).abs() < f64::EPSILON);
687    }
688
689    #[test]
690    fn test_ai_enhancer_threshold_validation() {
691        let result = AIEnhancer::new(None, None, Some(1.5), None, None);
692        assert!(result.is_err());
693    }
694
695    #[test]
696    fn test_ai_enhancer_batch_size_validation() {
697        let result = AIEnhancer::new(None, None, None, Some(0), None);
698        assert!(result.is_err());
699    }
700
701    #[test]
702    fn test_identify_gaps_complete_module() {
703        let enhancer = AIEnhancer::new(None, None, None, None, None).unwrap();
704        let mut module = ScannedModule::new(
705            "test".into(),
706            "A real description".into(),
707            json!({"type": "object", "properties": {"x": {"type": "string"}}}),
708            json!({}),
709            vec![],
710            "app:func".into(),
711        );
712        module.documentation = Some("Full docs".into());
713        module.annotations = Some(ModuleAnnotations {
714            readonly: true,
715            ..Default::default()
716        });
717        let gaps = enhancer.identify_gaps(&module);
718        assert!(gaps.is_empty());
719    }
720
721    #[test]
722    fn test_identify_gaps_missing_fields() {
723        let enhancer = AIEnhancer::new(None, None, None, None, None).unwrap();
724        let module = ScannedModule::new(
725            "test".into(),
726            String::new(),
727            json!({"type": "object"}),
728            json!({}),
729            vec![],
730            "app:func".into(),
731        );
732        let gaps = enhancer.identify_gaps(&module);
733        assert!(gaps.iter().any(|g| g == "description"));
734        assert!(gaps.iter().any(|g| g == "documentation"));
735        assert!(gaps.iter().any(|g| g == "annotations"));
736        assert!(gaps.iter().any(|g| g == "input_schema"));
737    }
738
739    #[test]
740    fn test_parse_response_valid_json() {
741        let response = r#"{"description": "hello", "confidence": {"description": 0.9}}"#;
742        let result = AIEnhancer::parse_response(response).unwrap();
743        assert_eq!(result["description"], "hello");
744    }
745
746    #[test]
747    fn test_parse_response_with_fences() {
748        let response = "```json\n{\"key\": \"value\"}\n```";
749        let result = AIEnhancer::parse_response(response).unwrap();
750        assert_eq!(result["key"], "value");
751    }
752
753    #[test]
754    fn test_parse_response_invalid() {
755        let result = AIEnhancer::parse_response("not json");
756        assert!(result.is_err());
757    }
758
759    #[test]
760    fn test_is_enabled_default() {
761        // Assuming env var is not set in test environment
762        env::remove_var("APCORE_AI_ENABLED");
763        assert!(!AIEnhancer::is_enabled());
764    }
765
766    #[test]
767    fn test_build_prompt_contains_module_info() {
768        let enhancer = AIEnhancer::new(None, None, None, None, None).unwrap();
769        let module = ScannedModule::new(
770            "users.get".into(),
771            "Get user".into(),
772            json!({}),
773            json!({}),
774            vec![],
775            "app:get_user".into(),
776        );
777        let prompt = enhancer.build_prompt(&module, &["description".into()]);
778        assert!(prompt.contains("users.get"));
779        assert!(prompt.contains("app:get_user"));
780        assert!(prompt.contains("description"));
781    }
782
783    #[test]
784    fn test_identify_gaps_description_equals_module_id() {
785        let enhancer = AIEnhancer::new(None, None, None, None, None).unwrap();
786        let module = ScannedModule::new(
787            "my_module".into(),
788            "my_module".into(), // description == module_id
789            json!({"type": "object", "properties": {"x": {"type": "string"}}}),
790            json!({}),
791            vec![],
792            "app:func".into(),
793        );
794        let gaps = enhancer.identify_gaps(&module);
795        assert!(
796            gaps.iter().any(|g| g == "description"),
797            "description matching module_id should be identified as a gap"
798        );
799    }
800
801    #[test]
802    fn test_ai_enhancer_timeout_validation() {
803        let result = AIEnhancer::new(None, None, None, None, Some(0));
804        assert!(result.is_err());
805        let err = result.unwrap_err();
806        assert!(err
807            .to_string()
808            .contains("APCORE_AI_TIMEOUT must be a positive integer"));
809    }
810
811    // All is_enabled tests are combined into one to prevent env var races
812    // when tests run in parallel (env vars are process-global).
813    #[test]
814    fn test_is_enabled_variants() {
815        use std::sync::Mutex;
816        static ENV_LOCK: Mutex<()> = Mutex::new(());
817        let _guard = ENV_LOCK.lock().unwrap();
818
819        // Default (unset) → disabled
820        unsafe { env::remove_var("APCORE_AI_ENABLED") };
821        assert!(!AIEnhancer::is_enabled(), "should be disabled by default");
822
823        // "true" → enabled
824        unsafe { env::set_var("APCORE_AI_ENABLED", "true") };
825        assert!(AIEnhancer::is_enabled(), "\"true\" should enable");
826
827        // "yes" → enabled
828        unsafe { env::set_var("APCORE_AI_ENABLED", "yes") };
829        assert!(AIEnhancer::is_enabled(), "\"yes\" should enable");
830
831        // "1" → enabled
832        unsafe { env::set_var("APCORE_AI_ENABLED", "1") };
833        assert!(AIEnhancer::is_enabled(), "\"1\" should enable");
834
835        // "false" → disabled
836        unsafe { env::set_var("APCORE_AI_ENABLED", "false") };
837        assert!(!AIEnhancer::is_enabled(), "\"false\" should disable");
838
839        // Cleanup
840        unsafe { env::remove_var("APCORE_AI_ENABLED") };
841    }
842
843    #[test]
844    fn test_parse_response_strips_json_fence() {
845        let response = "```json\n{\"description\": \"hello world\"}\n```";
846        let result = AIEnhancer::parse_response(response).unwrap();
847        assert_eq!(result["description"], "hello world");
848    }
849
850    #[test]
851    fn test_build_prompt_requests_annotations() {
852        let enhancer = AIEnhancer::new(None, None, None, None, None).unwrap();
853        let module = ScannedModule::new(
854            "test".into(),
855            "desc".into(),
856            json!({}),
857            json!({}),
858            vec![],
859            "app:func".into(),
860        );
861        let prompt = enhancer.build_prompt(&module, &["annotations".into()]);
862        assert!(
863            prompt.contains("readonly"),
864            "prompt should mention annotations fields"
865        );
866        assert!(prompt.contains("destructive"));
867        assert!(prompt.contains("idempotent"));
868    }
869
870    #[test]
871    fn test_build_prompt_requests_input_schema() {
872        let enhancer = AIEnhancer::new(None, None, None, None, None).unwrap();
873        let module = ScannedModule::new(
874            "test".into(),
875            "desc".into(),
876            json!({}),
877            json!({}),
878            vec![],
879            "app:func".into(),
880        );
881        let prompt = enhancer.build_prompt(&module, &["input_schema".into()]);
882        assert!(
883            prompt.contains("input_schema"),
884            "prompt should mention input_schema"
885        );
886        assert!(prompt.contains("JSON Schema"));
887    }
888
889    #[test]
890    fn test_build_prompt_requests_documentation() {
891        let enhancer = AIEnhancer::new(None, None, None, None, None).unwrap();
892        let module = ScannedModule::new(
893            "test".into(),
894            "desc".into(),
895            json!({}),
896            json!({}),
897            vec![],
898            "app:func".into(),
899        );
900        let prompt = enhancer.build_prompt(&module, &["documentation".into()]);
901        assert!(
902            prompt.contains("documentation"),
903            "prompt should mention documentation"
904        );
905        assert!(prompt.contains("Markdown"));
906    }
907
908    #[test]
909    fn test_parse_response_single_line_fence_does_not_panic() {
910        // A single-line ``` response used to panic with lines[1..0].
911        let response = "```";
912        let result = AIEnhancer::parse_response(response);
913        assert!(result.is_err(), "single-line fence is not valid JSON");
914    }
915
916    #[test]
917    fn test_parse_response_backtick_only_line_treated_as_json() {
918        // Regression: must not panic, must return an error gracefully.
919        let response = "```\n```";
920        let result = AIEnhancer::parse_response(response);
921        // Empty string after stripping is invalid JSON.
922        assert!(result.is_err());
923    }
924
925    // ---- set_bool_annotation (serde round-trip, D4-1 regression guards) ----
926
927    #[test]
928    fn test_set_bool_annotation_readonly() {
929        let mut ann = ModuleAnnotations::default();
930        assert!(set_bool_annotation(&mut ann, "readonly", true));
931        assert!(ann.readonly);
932    }
933
934    #[test]
935    fn test_set_bool_annotation_destructive() {
936        let mut ann = ModuleAnnotations::default();
937        assert!(set_bool_annotation(&mut ann, "destructive", true));
938        assert!(ann.destructive);
939    }
940
941    #[test]
942    fn test_set_bool_annotation_unknown_field_rejected() {
943        let mut ann = ModuleAnnotations::default();
944        assert!(!set_bool_annotation(
945            &mut ann,
946            "nonexistent_field_xyz",
947            true
948        ));
949        // Annotations unchanged.
950        assert!(is_default_annotations(&ann));
951    }
952
953    #[test]
954    fn test_set_bool_annotation_non_bool_field_rejected() {
955        let mut ann = ModuleAnnotations::default();
956        // `cache_ttl` is an integer field on ModuleAnnotations.
957        // Round-trip rejects setting it to a bool.
958        assert!(!set_bool_annotation(&mut ann, "cache_ttl", true));
959        assert_eq!(ann.cache_ttl, 0); // unchanged default
960    }
961
962    #[test]
963    fn test_set_bool_annotation_preserves_other_fields() {
964        let mut ann = ModuleAnnotations {
965            destructive: true,
966            cache_ttl: 99,
967            ..Default::default()
968        };
969        assert!(set_bool_annotation(&mut ann, "readonly", true));
970        // Original fields survive the serde round-trip.
971        assert!(ann.readonly);
972        assert!(ann.destructive);
973        assert_eq!(ann.cache_ttl, 99);
974    }
975
976    #[test]
977    fn test_clamp_str_under_limit() {
978        let s = "hello";
979        assert_eq!(clamp_str(s, 500, "mod", "desc"), "hello");
980    }
981
982    #[test]
983    fn test_clamp_str_over_limit_truncates() {
984        let s = "a".repeat(600);
985        let result = clamp_str(&s, 500, "mod", "desc");
986        assert_eq!(result.len(), 500);
987    }
988
989    #[test]
990    fn test_clamp_str_unicode_boundary() {
991        // "é" is 2 bytes — ensure we don't split it
992        let s = "é".repeat(300); // 600 bytes
993        let result = clamp_str(&s, 500, "mod", "desc");
994        assert!(result.len() <= 500);
995        assert!(std::str::from_utf8(result.as_bytes()).is_ok());
996    }
997
998    #[test]
999    fn test_strip_ansi_no_sequences() {
1000        assert_eq!(strip_ansi("hello world"), "hello world");
1001    }
1002
1003    #[test]
1004    fn test_strip_ansi_removes_color_codes() {
1005        let input = "\x1b[31mred text\x1b[0m";
1006        assert_eq!(strip_ansi(input), "red text");
1007    }
1008
1009    #[test]
1010    fn test_strip_ansi_mixed_content() {
1011        let input = "normal \x1b[1mbold\x1b[0m text";
1012        assert_eq!(strip_ansi(input), "normal bold text");
1013    }
1014}