Skip to main content

apcore_toolkit/
ai_enhancer.rs

1// AI-driven metadata enhancement using local SLMs.
2//
3// Uses an OpenAI-compatible local API (e.g., Ollama, vLLM, LM Studio) to fill
4// metadata gaps that static analysis cannot resolve.
5//
6// All AI-generated fields are tagged with `x-generated-by: slm` in the module's
7// metadata for auditability.
8
9use std::env;
10use std::time::Duration;
11
12use serde_json::{json, Value};
13use thiserror::Error;
14use tracing::warn;
15
16use apcore::module::ModuleAnnotations;
17
18use crate::types::ScannedModule;
19
/// Default OpenAI-compatible base URL (the port a local Ollama listens on).
const DEFAULT_ENDPOINT: &str = "http://localhost:11434/v1";
/// Default model name sent with every chat-completions request.
// NOTE(review): confirm this tag resolves on the target server (Ollama
// publishes Qwen3 0.6B as `qwen3:0.6b`); keep in sync with the other SDKs.
const DEFAULT_MODEL: &str = "qwen:0.6b";
/// Minimum confidence score required before an SLM suggestion is applied.
const DEFAULT_THRESHOLD: f64 = 0.7;
/// Number of modules grouped into one processing chunk.
const DEFAULT_BATCH_SIZE: usize = 5;
/// Per-request HTTP timeout, in seconds.
const DEFAULT_TIMEOUT: u64 = 30;
25
26/// Derive the list of annotation field names the SLM may assign confidence
27/// scores for by serializing a default `ModuleAnnotations` and inspecting
28/// its object keys.
29///
30/// This mirrors the dynamic-template approach used by the Python SDK
31/// (`dataclasses.fields(ModuleAnnotations)`) and the TypeScript SDK
32/// (`Object.entries(DEFAULT_ANNOTATIONS)`). Using runtime reflection keeps
33/// the SLM prompt template automatically in sync when upstream
34/// `apcore::module::ModuleAnnotations` gains or loses fields, eliminating
35/// the drift risk of a hardcoded list.
36///
37/// The `extra` open-extension map is excluded (matches TS behavior at
38/// `apcore-toolkit-typescript/src/ai-enhancer.ts`).
39fn annotation_field_names() -> Vec<String> {
40    match serde_json::to_value(ModuleAnnotations::default()) {
41        Ok(Value::Object(map)) => map
42            .into_iter()
43            .map(|(k, _)| k)
44            .filter(|k| k != "extra")
45            .collect(),
46        _ => Vec::new(),
47    }
48}
49
50/// Errors returned by [`AIEnhancer`] operations.
51#[derive(Debug, Error)]
52pub enum AIEnhancerError {
53    /// Invalid configuration value.
54    #[error("invalid config: {0}")]
55    Config(String),
56    /// Failed to reach the SLM endpoint.
57    #[error("connection failed: {0}")]
58    Connection(String),
59    /// SLM returned an unparseable response.
60    #[error("bad response: {0}")]
61    Response(String),
62    /// Failed to build the HTTP agent.
63    #[error("agent build failed: {0}")]
64    AgentBuild(String),
65}
66
67/// Protocol for pluggable metadata enhancement.
68///
69/// # Blocking / async compatibility
70///
71/// `enhance` is a synchronous method. The bundled [`AIEnhancer`] performs
72/// blocking HTTP requests via `ureq`, so each call may park the current
73/// thread for up to `APCORE_AI_TIMEOUT` seconds (default 30) per module.
74/// **Do not call `enhance` directly from an async task** on a Tokio (or
75/// other async) runtime — it will block a runtime worker thread and can
76/// stall the scheduler under concurrent load.
77///
78/// From an async context, wrap the call in
79/// [`tokio::task::spawn_blocking`]:
80///
81/// ```ignore
82/// let enhanced = tokio::task::spawn_blocking(move || enhancer.enhance(modules)).await?;
83/// ```
84///
85/// Enhancement is a one-shot scanning-phase operation (not per-request),
86/// so this is typically invoked once during framework adapter bootstrap.
87pub trait Enhancer {
88    /// Enhance a list of ScannedModules by filling metadata gaps.
89    ///
90    /// Synchronous and potentially long-running. See the trait-level doc
91    /// comment for guidance on invoking from async contexts.
92    fn enhance(&self, modules: Vec<ScannedModule>) -> Vec<ScannedModule>;
93}
94
95/// Enhances ScannedModule metadata using a local SLM.
96///
97/// Configuration is read from environment variables or constructor parameters:
98/// - `APCORE_AI_ENABLED`: Enable enhancement (default: false).
99/// - `APCORE_AI_ENDPOINT`: OpenAI-compatible API URL.
100/// - `APCORE_AI_MODEL`: Model name.
101/// - `APCORE_AI_THRESHOLD`: Confidence threshold (0.0–1.0).
102/// - `APCORE_AI_BATCH_SIZE`: Modules per API call.
103/// - `APCORE_AI_TIMEOUT`: Timeout in seconds per API call.
104#[derive(Debug)]
105pub struct AIEnhancer {
106    pub endpoint: String,
107    pub model: String,
108    pub threshold: f64,
109    pub batch_size: usize,
110    pub timeout: u64,
111    // Reused across all call_llm() invocations to avoid rebuilding config per call.
112    agent: ureq::Agent,
113}
114
115impl AIEnhancer {
116    /// Create a new AIEnhancer with optional overrides.
117    ///
118    /// Falls back to environment variables, then defaults.
119    pub fn new(
120        endpoint: Option<String>,
121        model: Option<String>,
122        threshold: Option<f64>,
123        batch_size: Option<usize>,
124        timeout: Option<u64>,
125    ) -> Result<Self, AIEnhancerError> {
126        let endpoint = endpoint.unwrap_or_else(|| {
127            env::var("APCORE_AI_ENDPOINT").unwrap_or_else(|_| DEFAULT_ENDPOINT.into())
128        });
129        validate_endpoint_scheme(&endpoint)?;
130        let model = model.unwrap_or_else(|| {
131            env::var("APCORE_AI_MODEL").unwrap_or_else(|_| DEFAULT_MODEL.into())
132        });
133        let threshold =
134            threshold.unwrap_or_else(|| parse_float_env("APCORE_AI_THRESHOLD", DEFAULT_THRESHOLD));
135        let batch_size = batch_size
136            .unwrap_or_else(|| parse_usize_env("APCORE_AI_BATCH_SIZE", DEFAULT_BATCH_SIZE));
137        let timeout =
138            timeout.unwrap_or_else(|| parse_u64_env("APCORE_AI_TIMEOUT", DEFAULT_TIMEOUT));
139
140        if !(0.0..=1.0).contains(&threshold) {
141            return Err(AIEnhancerError::Config(
142                "APCORE_AI_THRESHOLD must be between 0.0 and 1.0".into(),
143            ));
144        }
145        if batch_size == 0 {
146            return Err(AIEnhancerError::Config(
147                "APCORE_AI_BATCH_SIZE must be a positive integer".into(),
148            ));
149        }
150        if timeout == 0 {
151            return Err(AIEnhancerError::Config(
152                "APCORE_AI_TIMEOUT must be a positive integer".into(),
153            ));
154        }
155
156        let agent = ureq::Agent::config_builder()
157            .timeout_global(Some(Duration::from_secs(timeout)))
158            .build()
159            .new_agent();
160
161        Ok(Self {
162            endpoint,
163            model,
164            threshold,
165            batch_size,
166            timeout,
167            agent,
168        })
169    }
170
171    /// Check whether AI enhancement is enabled via environment.
172    pub fn is_enabled() -> bool {
173        env::var("APCORE_AI_ENABLED")
174            .map(|v| matches!(v.to_lowercase().as_str(), "true" | "1" | "yes"))
175            .unwrap_or(false)
176    }
177
178    /// Identify which metadata fields are missing or at defaults.
179    fn identify_gaps(&self, module: &ScannedModule) -> Vec<String> {
180        let mut gaps: Vec<String> = Vec::new();
181
182        if module.description.is_empty() || module.description == module.module_id {
183            gaps.push("description".into());
184        }
185        if module.documentation.is_none() {
186            gaps.push("documentation".into());
187        }
188        if module.annotations.is_none()
189            || module
190                .annotations
191                .as_ref()
192                .is_some_and(is_default_annotations)
193        {
194            gaps.push("annotations".into());
195        }
196        if module
197            .input_schema
198            .get("properties")
199            .and_then(|p| p.as_object())
200            .map(|o| o.is_empty())
201            .unwrap_or(true)
202        {
203            gaps.push("input_schema".into());
204        }
205
206        gaps
207    }
208
209    /// Build a structured prompt for the SLM.
210    fn build_prompt(&self, module: &ScannedModule, gaps: &[String]) -> String {
211        let mut parts = vec![
212            "You are analyzing a function to generate metadata for an AI-perceivable module system.".into(),
213            String::new(),
214            format!("Module ID: {}", module.module_id),
215            format!("Target: {}", module.target),
216        ];
217
218        if !module.description.is_empty() {
219            parts.push(format!("Current description: {}", module.description));
220        }
221
222        parts.push(String::new());
223        parts.push("Please provide the following missing metadata as JSON:".into());
224        parts.push("{".into());
225
226        for gap in gaps {
227            match gap.as_str() {
228                "description" => {
229                    parts.push(
230                        r#"  "description": "<≤200 chars, what this function does>","#.into(),
231                    );
232                }
233                "documentation" => {
234                    parts.push(r#"  "documentation": "<detailed Markdown explanation>","#.into());
235                }
236                "annotations" => {
237                    parts.push(r#"  "annotations": {"#.into());
238                    parts.push(r#"    "readonly": <true if no side effects>,"#.into());
239                    parts.push(r#"    "destructive": <true if deletes/overwrites data>,"#.into());
240                    parts.push(r#"    "idempotent": <true if safe to retry>,"#.into());
241                    parts.push(r#"    "requires_approval": <true if dangerous operation>,"#.into());
242                    parts.push(r#"    "open_world": <true if calls external systems>,"#.into());
243                    parts
244                        .push(r#"    "streaming": <true if yields results incrementally>,"#.into());
245                    parts.push(r#"    "cacheable": <true if results can be cached>,"#.into());
246                    parts.push(r#"    "cache_ttl": <seconds, 0 for no expiry>,"#.into());
247                    parts.push(r#"    "cache_key_fields": <list of input field names for cache key, or null for all>,"#.into());
248                    parts.push(r#"    "paginated": <true if supports pagination>,"#.into());
249                    parts
250                        .push(r#"    "pagination_style": <"cursor" or "offset" or "page">"#.into());
251                    parts.push("  },".into());
252                }
253                "input_schema" => {
254                    parts.push(
255                        r#"  "input_schema": <JSON Schema object for function parameters>,"#.into(),
256                    );
257                }
258                _ => {}
259            }
260        }
261
262        let confidence_keys: serde_json::Value = annotation_field_names()
263            .into_iter()
264            .map(|field| (field, serde_json::json!(0.0)))
265            .collect::<serde_json::Map<_, _>>()
266            .into();
267        let confidence_str =
268            serde_json::to_string_pretty(&confidence_keys).unwrap_or_else(|_| "{}".into());
269        parts.push(format!(r#"  "confidence": {confidence_str}"#));
270        parts.push("}".into());
271        parts.push(String::new());
272        parts.push("Respond with ONLY valid JSON, no markdown fences or explanation.".into());
273
274        parts.join("\n")
275    }
276
277    /// Call the OpenAI-compatible API and return the response text.
278    fn call_llm(&self, prompt: &str) -> Result<String, AIEnhancerError> {
279        let url = format!("{}/chat/completions", self.endpoint.trim_end_matches('/'));
280        let payload = json!({
281            "model": self.model,
282            "messages": [{"role": "user", "content": prompt}],
283            "temperature": 0.1,
284        });
285
286        let body: Value = self
287            .agent
288            .post(&url)
289            .header("Content-Type", "application/json")
290            .send_json(&payload)
291            .map_err(|e| AIEnhancerError::Connection(format!("Failed to reach SLM at {url}: {e}")))?
292            .body_mut()
293            .read_json()
294            .map_err(|e| AIEnhancerError::Response(format!("Failed to parse SLM response: {e}")))?;
295
296        body["choices"][0]["message"]["content"]
297            .as_str()
298            .map(|s| s.to_string())
299            .ok_or_else(|| AIEnhancerError::Response("Unexpected API response structure".into()))
300    }
301
302    /// Parse the SLM response as JSON, stripping markdown fences if present.
303    fn parse_response(response: &str) -> Result<Value, AIEnhancerError> {
304        let mut text = response.trim().to_string();
305
306        // Strip markdown code fences if the response is more than one line
307        if text.starts_with("```") {
308            let lines: Vec<&str> = text.split('\n').collect();
309            if lines.len() > 1 {
310                let start = if lines[0].starts_with("```") { 1 } else { 0 };
311                let end = if lines.last().map(|l| l.trim()) == Some("```") {
312                    lines.len() - 1
313                } else {
314                    lines.len()
315                };
316                text = lines[start..end].join("\n");
317            }
318        }
319
320        serde_json::from_str(&text)
321            .map_err(|e| AIEnhancerError::Response(format!("SLM returned invalid JSON: {e}")))
322    }
323
324    /// Enhance a single module by calling the SLM.
325    fn enhance_module(
326        &self,
327        module: &ScannedModule,
328        gaps: &[String],
329    ) -> Result<ScannedModule, AIEnhancerError> {
330        let prompt = self.build_prompt(module, gaps);
331        let response = self.call_llm(&prompt)?;
332        let parsed = Self::parse_response(&response)?;
333
334        let mut result = module.clone();
335        let mut confidence: serde_json::Map<String, Value> = serde_json::Map::new();
336
337        // Apply description
338        if gaps.iter().any(|g| g == "description") {
339            if let Some(desc) = parsed.get("description").and_then(|v| v.as_str()) {
340                let conf = parsed
341                    .get("confidence")
342                    .and_then(|c| c.get("description"))
343                    .and_then(|v| v.as_f64())
344                    .unwrap_or(0.0);
345                confidence.insert("description".into(), json!(conf));
346                if conf >= self.threshold {
347                    result.description = clamp_str(desc, 500, &module.module_id, "description");
348                } else {
349                    result.warnings.push(format!(
350                        "Low confidence ({conf:.2}) for description — skipped. Review manually."
351                    ));
352                }
353            }
354        }
355
356        // Apply documentation
357        if gaps.iter().any(|g| g == "documentation") {
358            if let Some(doc) = parsed.get("documentation").and_then(|v| v.as_str()) {
359                let conf = parsed
360                    .get("confidence")
361                    .and_then(|c| c.get("documentation"))
362                    .and_then(|v| v.as_f64())
363                    .unwrap_or(0.0);
364                confidence.insert("documentation".into(), json!(conf));
365                if conf >= self.threshold {
366                    result.documentation = Some(strip_ansi(&clamp_str(
367                        doc,
368                        2000,
369                        &module.module_id,
370                        "documentation",
371                    )));
372                } else {
373                    result.warnings.push(format!(
374                        "Low confidence ({conf:.2}) for documentation — skipped. Review manually."
375                    ));
376                }
377            }
378        }
379
380        // Apply annotations if above threshold (per-field confidence)
381        if gaps.iter().any(|g| g == "annotations") {
382            if let Some(ann_data) = parsed.get("annotations").and_then(|v| v.as_object()) {
383                let ann_conf = parsed
384                    .get("confidence")
385                    .and_then(|v| v.as_object())
386                    .cloned()
387                    .unwrap_or_default();
388                let mut base = module.annotations.clone().unwrap_or_default();
389                let mut any_accepted = false;
390
391                // Iterate boolean fields supplied by the SLM directly.
392                // `set_bool_annotation` validates each field's existence
393                // on `ModuleAnnotations` via a serde round-trip, so the
394                // set of known bool fields lives in one place — the
395                // upstream struct — and new fields added upstream are
396                // picked up automatically.
397                for (field, field_val) in ann_data.iter() {
398                    let Some(bool_val) = field_val.as_bool() else {
399                        continue;
400                    };
401                    let field_conf = get_annotation_confidence(&ann_conf, field);
402                    confidence.insert(format!("annotations.{field}"), json!(field_conf));
403                    if field_conf >= self.threshold {
404                        if set_bool_annotation(&mut base, field, bool_val) {
405                            any_accepted = true;
406                        } else {
407                            result.warnings.push(format!(
408                                "SLM returned unknown bool annotation '{field}' — ignored."
409                            ));
410                        }
411                    } else {
412                        result.warnings.push(format!(
413                            "Low confidence ({field_conf:.2}) for annotations.{field} — skipped. Review manually."
414                        ));
415                    }
416                }
417
418                // Integer fields: cache_ttl
419                if let Some(val) = ann_data.get("cache_ttl").and_then(|v| v.as_u64()) {
420                    let field_conf = get_annotation_confidence(&ann_conf, "cache_ttl");
421                    confidence.insert("annotations.cache_ttl".into(), json!(field_conf));
422                    if field_conf >= self.threshold {
423                        base.cache_ttl = val;
424                        any_accepted = true;
425                    } else {
426                        result.warnings.push(format!(
427                            "Low confidence ({field_conf:.2}) for annotations.cache_ttl — skipped. Review manually."
428                        ));
429                    }
430                }
431
432                // String fields: pagination_style
433                if let Some(val) = ann_data.get("pagination_style").and_then(|v| v.as_str()) {
434                    let field_conf = get_annotation_confidence(&ann_conf, "pagination_style");
435                    confidence.insert("annotations.pagination_style".into(), json!(field_conf));
436                    if field_conf >= self.threshold {
437                        base.pagination_style = val.to_string();
438                        any_accepted = true;
439                    } else {
440                        result.warnings.push(format!(
441                            "Low confidence ({field_conf:.2}) for annotations.pagination_style — skipped. Review manually."
442                        ));
443                    }
444                }
445
446                // List fields: cache_key_fields
447                if let Some(arr) = ann_data.get("cache_key_fields").and_then(|v| v.as_array()) {
448                    let field_conf = get_annotation_confidence(&ann_conf, "cache_key_fields");
449                    confidence.insert("annotations.cache_key_fields".into(), json!(field_conf));
450                    if field_conf >= self.threshold {
451                        let keys: Vec<String> = arr
452                            .iter()
453                            .filter_map(|v| v.as_str().map(|s| s.to_string()))
454                            .collect();
455                        base.cache_key_fields = Some(keys);
456                        any_accepted = true;
457                    } else {
458                        result.warnings.push(format!(
459                            "Low confidence ({field_conf:.2}) for annotations.cache_key_fields — skipped. Review manually."
460                        ));
461                    }
462                }
463
464                if any_accepted {
465                    result.annotations = Some(base);
466                }
467            }
468        }
469
470        // Apply input_schema if above threshold
471        if gaps.iter().any(|g| g == "input_schema") {
472            if let Some(schema) = parsed.get("input_schema") {
473                let conf = parsed
474                    .get("confidence")
475                    .and_then(|c| c.get("input_schema"))
476                    .and_then(|v| v.as_f64())
477                    .unwrap_or(0.0);
478                confidence.insert("input_schema".into(), json!(conf));
479                if conf >= self.threshold {
480                    result.input_schema = schema.clone();
481                } else {
482                    result.warnings.push(format!(
483                        "Low confidence ({conf:.2}) for input_schema — skipped. Review manually."
484                    ));
485                }
486            }
487        }
488
489        // Tag AI-generated fields
490        if !confidence.is_empty() {
491            result
492                .metadata
493                .insert("x-generated-by".into(), Value::String("slm".into()));
494            result
495                .metadata
496                .insert("x-ai-confidence".into(), Value::Object(confidence));
497        }
498
499        Ok(result)
500    }
501}
502
503impl Enhancer for AIEnhancer {
504    fn enhance(&self, modules: Vec<ScannedModule>) -> Vec<ScannedModule> {
505        let mut results: Vec<ScannedModule> = Vec::with_capacity(modules.len());
506
507        let mut pending: Vec<(usize, Vec<String>)> = Vec::new();
508        for (idx, module) in modules.iter().enumerate() {
509            let gaps = self.identify_gaps(module);
510            results.push(module.clone());
511            if !gaps.is_empty() {
512                pending.push((idx, gaps));
513            }
514        }
515
516        // Note: batch_size controls chunk size for rate-limiting/progress tracking only.
517        // Processing is synchronous and sequential within each chunk — this differs from
518        // TypeScript which dispatches all pending modules concurrently via Promise.allSettled.
519        // Parallel dispatch (via rayon) is a planned enhancement.
520        for batch in pending.chunks(self.batch_size) {
521            for (idx, gaps) in batch {
522                match self.enhance_module(&modules[*idx], gaps) {
523                    Ok(enhanced) => results[*idx] = enhanced,
524                    Err(e) => {
525                        warn!("AI enhancement failed for {}: {e}", modules[*idx].module_id);
526                    }
527                }
528            }
529        }
530
531        results
532    }
533}
534
535/// Check whether annotations are at their default values.
536///
537/// Uses `serde_json` round-trip equality so the comparison automatically
538/// covers any new field added to `apcore::module::ModuleAnnotations` upstream
539/// (including the `extra` extension map). `ModuleAnnotations` does not
540/// implement `PartialEq`, so direct `==` is unavailable.
541fn is_default_annotations(ann: &ModuleAnnotations) -> bool {
542    match (
543        serde_json::to_value(ann),
544        serde_json::to_value(ModuleAnnotations::default()),
545    ) {
546        (Ok(a), Ok(b)) => a == b,
547        _ => false,
548    }
549}
550
551/// Get confidence for an annotation field, checking both `annotations.<field>` and `<field>` keys.
552fn get_annotation_confidence(conf: &serde_json::Map<String, Value>, field: &str) -> f64 {
553    conf.get(&format!("annotations.{field}"))
554        .or_else(|| conf.get(field))
555        .and_then(|v| v.as_f64())
556        .unwrap_or(0.0)
557}
558
559/// Set a boolean field on `ModuleAnnotations` by name via a serde
560/// round-trip. Returns `true` if the field exists on the struct and is a
561/// boolean; `false` if the field is unknown, non-boolean, or the
562/// round-trip fails. Using serde rather than a hardcoded match removes
563/// the two-list drift risk — new bool fields added to
564/// `apcore::module::ModuleAnnotations` upstream are picked up
565/// automatically.
566fn set_bool_annotation(ann: &mut ModuleAnnotations, field: &str, value: bool) -> bool {
567    let mut serialized = match serde_json::to_value(&ann) {
568        Ok(v) => v,
569        Err(e) => {
570            warn!("set_bool_annotation: serialize failed: {e}");
571            return false;
572        }
573    };
574    let Some(obj) = serialized.as_object_mut() else {
575        return false;
576    };
577    match obj.get(field) {
578        Some(Value::Bool(_)) => {
579            obj.insert(field.to_string(), Value::Bool(value));
580        }
581        // Field absent, or present but not a bool — reject rather than
582        // fabricate a new key (serde would happily accept unknown keys
583        // via `#[serde(extra)]` on ModuleAnnotations, but misclassifying
584        // a non-bool field as bool would corrupt the struct).
585        _ => return false,
586    }
587    match serde_json::from_value::<ModuleAnnotations>(serialized) {
588        Ok(new_ann) => {
589            *ann = new_ann;
590            true
591        }
592        Err(e) => {
593            warn!("set_bool_annotation: deserialize failed: {e}");
594            false
595        }
596    }
597}
598
599/// Validate that an endpoint URL uses an HTTP(S) scheme.
600///
601/// Matches the construction-time scheme validation performed by the Python
602/// and TypeScript SDKs (see `apcore-toolkit-python/src/apcore_toolkit/ai_enhancer.py`
603/// and `apcore-toolkit-typescript/src/ai-enhancer.ts`). Rejecting non-HTTP
604/// schemes (e.g. `file://`, `ftp://`) at construction prevents misleading
605/// connection errors later inside `call_llm` and removes a small but real
606/// SSRF-adjacent surface.
607fn validate_endpoint_scheme(endpoint: &str) -> Result<(), AIEnhancerError> {
608    // Manual parsing: we cannot rely on the `url` crate because it would
609    // add a transitive dependency when the optional `http-proxy` feature
610    // is disabled. The check is intentionally simple — extract the part
611    // before "://" and compare case-insensitively against the allowed set.
612    let Some(scheme_end) = endpoint.find("://") else {
613        return Err(AIEnhancerError::Config(format!(
614            "Invalid endpoint URL (missing scheme): {endpoint}"
615        )));
616    };
617    let scheme = &endpoint[..scheme_end];
618    if scheme.is_empty() {
619        return Err(AIEnhancerError::Config(format!(
620            "Invalid endpoint URL (empty scheme): {endpoint}"
621        )));
622    }
623    let scheme_lower = scheme.to_ascii_lowercase();
624    if scheme_lower != "http" && scheme_lower != "https" {
625        return Err(AIEnhancerError::Config(format!(
626            "Invalid endpoint URL scheme: {scheme}"
627        )));
628    }
629    Ok(())
630}
631
632fn parse_float_env(name: &str, default: f64) -> f64 {
633    match env::var(name) {
634        Ok(v) => v.parse().unwrap_or_else(|_| {
635            warn!(env_var = name, value = %v, "unparseable float env var — using default {default}");
636            default
637        }),
638        Err(_) => default,
639    }
640}
641
642fn parse_usize_env(name: &str, default: usize) -> usize {
643    match env::var(name) {
644        Ok(v) => v.parse().unwrap_or_else(|_| {
645            warn!(env_var = name, value = %v, "unparseable usize env var — using default {default}");
646            default
647        }),
648        Err(_) => default,
649    }
650}
651
652fn parse_u64_env(name: &str, default: u64) -> u64 {
653    match env::var(name) {
654        Ok(v) => v.parse().unwrap_or_else(|_| {
655            warn!(env_var = name, value = %v, "unparseable u64 env var — using default {default}");
656            default
657        }),
658        Err(_) => default,
659    }
660}
661
662/// Clamp an SLM-supplied string to `max_chars` bytes, warning if truncated.
663fn clamp_str(s: &str, max_chars: usize, module_id: &str, field: &str) -> String {
664    if s.len() <= max_chars {
665        return s.to_string();
666    }
667    // Truncate at a char boundary.
668    let truncated = &s[..s
669        .char_indices()
670        .take_while(|(i, _)| *i < max_chars)
671        .last()
672        .map(|(i, c)| i + c.len_utf8())
673        .unwrap_or(max_chars)];
674    tracing::warn!(
675        module_id = %module_id,
676        field = %field,
677        original_len = s.len(),
678        clamped_len = truncated.len(),
679        "SLM-supplied field truncated to prevent oversized output"
680    );
681    truncated.to_string()
682}
683
/// Remove ANSI / ECMA-48 CSI escape sequences from `s`.
///
/// A CSI sequence has three parts, and the whole sequence is dropped:
/// - the introducer `ESC [`;
/// - any number of parameter bytes (`0x30..=0x3F`) and intermediate bytes
///   (`0x20..=0x2F`);
/// - one final byte in `0x40..=0x7E`, e.g. `m` in `ESC[31m` or `~` in
///   `ESC[3~`.
///
/// If a character that cannot belong to the sequence appears before a final
/// byte, the sequence ends there and that character is kept, so malformed
/// input cannot swallow the text that follows. An `ESC` not followed by `[`
/// is kept unchanged.
fn strip_ansi(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    let mut chars = s.chars().peekable();
    while let Some(c) = chars.next() {
        if c != '\x1b' || chars.peek() != Some(&'[') {
            out.push(c);
            continue;
        }
        chars.next(); // consume '['
        // Parameter and intermediate bytes together span 0x20..=0x3F.
        while chars
            .next_if(|&ch| ('\x20'..='\x3f').contains(&ch))
            .is_some()
        {}
        // Consume the final byte, if the sequence has one.
        chars.next_if(|&ch| ('\x40'..='\x7e').contains(&ch));
    }
    out
}
702
703#[cfg(test)]
704mod tests {
705    use super::*;
706    use apcore::module::ModuleAnnotations;
707    use serde_json::json;
708
709    #[test]
710    fn test_ai_enhancer_new_defaults() {
711        let enhancer = AIEnhancer::new(None, None, None, None, None).unwrap();
712        assert_eq!(enhancer.endpoint, DEFAULT_ENDPOINT);
713        assert_eq!(enhancer.model, DEFAULT_MODEL);
714        assert!((enhancer.threshold - DEFAULT_THRESHOLD).abs() < f64::EPSILON);
715        assert_eq!(enhancer.batch_size, DEFAULT_BATCH_SIZE);
716        assert_eq!(enhancer.timeout, DEFAULT_TIMEOUT);
717    }
718
719    #[test]
720    fn test_ai_enhancer_new_with_overrides() {
721        let enhancer = AIEnhancer::new(
722            Some("http://custom:8080".into()),
723            Some("llama3".into()),
724            Some(0.5),
725            Some(10),
726            Some(60),
727        )
728        .unwrap();
729        assert_eq!(enhancer.endpoint, "http://custom:8080");
730        assert_eq!(enhancer.model, "llama3");
731        assert!((enhancer.threshold - 0.5).abs() < f64::EPSILON);
732    }
733
734    #[test]
735    fn test_ai_enhancer_threshold_validation() {
736        let result = AIEnhancer::new(None, None, Some(1.5), None, None);
737        assert!(result.is_err());
738    }
739
740    #[test]
741    fn test_ai_enhancer_batch_size_validation() {
742        let result = AIEnhancer::new(None, None, None, Some(0), None);
743        assert!(result.is_err());
744    }
745
746    #[test]
747    fn test_identify_gaps_complete_module() {
748        let enhancer = AIEnhancer::new(None, None, None, None, None).unwrap();
749        let mut module = ScannedModule::new(
750            "test".into(),
751            "A real description".into(),
752            json!({"type": "object", "properties": {"x": {"type": "string"}}}),
753            json!({}),
754            vec![],
755            "app:func".into(),
756        );
757        module.documentation = Some("Full docs".into());
758        module.annotations = Some(ModuleAnnotations {
759            readonly: true,
760            ..Default::default()
761        });
762        let gaps = enhancer.identify_gaps(&module);
763        assert!(gaps.is_empty());
764    }
765
766    #[test]
767    fn test_identify_gaps_missing_fields() {
768        let enhancer = AIEnhancer::new(None, None, None, None, None).unwrap();
769        let module = ScannedModule::new(
770            "test".into(),
771            String::new(),
772            json!({"type": "object"}),
773            json!({}),
774            vec![],
775            "app:func".into(),
776        );
777        let gaps = enhancer.identify_gaps(&module);
778        assert!(gaps.iter().any(|g| g == "description"));
779        assert!(gaps.iter().any(|g| g == "documentation"));
780        assert!(gaps.iter().any(|g| g == "annotations"));
781        assert!(gaps.iter().any(|g| g == "input_schema"));
782    }
783
784    #[test]
785    fn test_parse_response_valid_json() {
786        let response = r#"{"description": "hello", "confidence": {"description": 0.9}}"#;
787        let result = AIEnhancer::parse_response(response).unwrap();
788        assert_eq!(result["description"], "hello");
789    }
790
791    #[test]
792    fn test_parse_response_with_fences() {
793        let response = "```json\n{\"key\": \"value\"}\n```";
794        let result = AIEnhancer::parse_response(response).unwrap();
795        assert_eq!(result["key"], "value");
796    }
797
798    #[test]
799    fn test_parse_response_invalid() {
800        let result = AIEnhancer::parse_response("not json");
801        assert!(result.is_err());
802    }
803
804    #[test]
805    fn test_is_enabled_default() {
806        // Assuming env var is not set in test environment
807        env::remove_var("APCORE_AI_ENABLED");
808        assert!(!AIEnhancer::is_enabled());
809    }
810
811    #[test]
812    fn test_build_prompt_contains_module_info() {
813        let enhancer = AIEnhancer::new(None, None, None, None, None).unwrap();
814        let module = ScannedModule::new(
815            "users.get".into(),
816            "Get user".into(),
817            json!({}),
818            json!({}),
819            vec![],
820            "app:get_user".into(),
821        );
822        let prompt = enhancer.build_prompt(&module, &["description".into()]);
823        assert!(prompt.contains("users.get"));
824        assert!(prompt.contains("app:get_user"));
825        assert!(prompt.contains("description"));
826    }
827
828    #[test]
829    fn test_identify_gaps_description_equals_module_id() {
830        let enhancer = AIEnhancer::new(None, None, None, None, None).unwrap();
831        let module = ScannedModule::new(
832            "my_module".into(),
833            "my_module".into(), // description == module_id
834            json!({"type": "object", "properties": {"x": {"type": "string"}}}),
835            json!({}),
836            vec![],
837            "app:func".into(),
838        );
839        let gaps = enhancer.identify_gaps(&module);
840        assert!(
841            gaps.iter().any(|g| g == "description"),
842            "description matching module_id should be identified as a gap"
843        );
844    }
845
846    #[test]
847    fn test_ai_enhancer_timeout_validation() {
848        let result = AIEnhancer::new(None, None, None, None, Some(0));
849        assert!(result.is_err());
850        let err = result.unwrap_err();
851        assert!(err
852            .to_string()
853            .contains("APCORE_AI_TIMEOUT must be a positive integer"));
854    }
855
856    // All is_enabled tests are combined into one to prevent env var races
857    // when tests run in parallel (env vars are process-global).
858    #[test]
859    fn test_is_enabled_variants() {
860        use std::sync::Mutex;
861        static ENV_LOCK: Mutex<()> = Mutex::new(());
862        let _guard = ENV_LOCK.lock().unwrap();
863
864        // Default (unset) → disabled
865        unsafe { env::remove_var("APCORE_AI_ENABLED") };
866        assert!(!AIEnhancer::is_enabled(), "should be disabled by default");
867
868        // "true" → enabled
869        unsafe { env::set_var("APCORE_AI_ENABLED", "true") };
870        assert!(AIEnhancer::is_enabled(), "\"true\" should enable");
871
872        // "yes" → enabled
873        unsafe { env::set_var("APCORE_AI_ENABLED", "yes") };
874        assert!(AIEnhancer::is_enabled(), "\"yes\" should enable");
875
876        // "1" → enabled
877        unsafe { env::set_var("APCORE_AI_ENABLED", "1") };
878        assert!(AIEnhancer::is_enabled(), "\"1\" should enable");
879
880        // "false" → disabled
881        unsafe { env::set_var("APCORE_AI_ENABLED", "false") };
882        assert!(!AIEnhancer::is_enabled(), "\"false\" should disable");
883
884        // Cleanup
885        unsafe { env::remove_var("APCORE_AI_ENABLED") };
886    }
887
888    #[test]
889    fn test_parse_response_strips_json_fence() {
890        let response = "```json\n{\"description\": \"hello world\"}\n```";
891        let result = AIEnhancer::parse_response(response).unwrap();
892        assert_eq!(result["description"], "hello world");
893    }
894
895    #[test]
896    fn test_build_prompt_requests_annotations() {
897        let enhancer = AIEnhancer::new(None, None, None, None, None).unwrap();
898        let module = ScannedModule::new(
899            "test".into(),
900            "desc".into(),
901            json!({}),
902            json!({}),
903            vec![],
904            "app:func".into(),
905        );
906        let prompt = enhancer.build_prompt(&module, &["annotations".into()]);
907        assert!(
908            prompt.contains("readonly"),
909            "prompt should mention annotations fields"
910        );
911        assert!(prompt.contains("destructive"));
912        assert!(prompt.contains("idempotent"));
913    }
914
915    #[test]
916    fn test_build_prompt_requests_input_schema() {
917        let enhancer = AIEnhancer::new(None, None, None, None, None).unwrap();
918        let module = ScannedModule::new(
919            "test".into(),
920            "desc".into(),
921            json!({}),
922            json!({}),
923            vec![],
924            "app:func".into(),
925        );
926        let prompt = enhancer.build_prompt(&module, &["input_schema".into()]);
927        assert!(
928            prompt.contains("input_schema"),
929            "prompt should mention input_schema"
930        );
931        assert!(prompt.contains("JSON Schema"));
932    }
933
934    #[test]
935    fn test_build_prompt_requests_documentation() {
936        let enhancer = AIEnhancer::new(None, None, None, None, None).unwrap();
937        let module = ScannedModule::new(
938            "test".into(),
939            "desc".into(),
940            json!({}),
941            json!({}),
942            vec![],
943            "app:func".into(),
944        );
945        let prompt = enhancer.build_prompt(&module, &["documentation".into()]);
946        assert!(
947            prompt.contains("documentation"),
948            "prompt should mention documentation"
949        );
950        assert!(prompt.contains("Markdown"));
951    }
952
953    #[test]
954    fn test_parse_response_single_line_fence_does_not_panic() {
955        // A single-line ``` response used to panic with lines[1..0].
956        let response = "```";
957        let result = AIEnhancer::parse_response(response);
958        assert!(result.is_err(), "single-line fence is not valid JSON");
959    }
960
961    #[test]
962    fn test_parse_response_backtick_only_line_treated_as_json() {
963        // Regression: must not panic, must return an error gracefully.
964        let response = "```\n```";
965        let result = AIEnhancer::parse_response(response);
966        // Empty string after stripping is invalid JSON.
967        assert!(result.is_err());
968    }
969
970    // ---- set_bool_annotation (serde round-trip, D4-1 regression guards) ----
971
972    #[test]
973    fn test_set_bool_annotation_readonly() {
974        let mut ann = ModuleAnnotations::default();
975        assert!(set_bool_annotation(&mut ann, "readonly", true));
976        assert!(ann.readonly);
977    }
978
979    #[test]
980    fn test_set_bool_annotation_destructive() {
981        let mut ann = ModuleAnnotations::default();
982        assert!(set_bool_annotation(&mut ann, "destructive", true));
983        assert!(ann.destructive);
984    }
985
986    #[test]
987    fn test_set_bool_annotation_unknown_field_rejected() {
988        let mut ann = ModuleAnnotations::default();
989        assert!(!set_bool_annotation(
990            &mut ann,
991            "nonexistent_field_xyz",
992            true
993        ));
994        // Annotations unchanged.
995        assert!(is_default_annotations(&ann));
996    }
997
998    #[test]
999    fn test_set_bool_annotation_non_bool_field_rejected() {
1000        let mut ann = ModuleAnnotations::default();
1001        // `cache_ttl` is an integer field on ModuleAnnotations.
1002        // Round-trip rejects setting it to a bool.
1003        assert!(!set_bool_annotation(&mut ann, "cache_ttl", true));
1004        assert_eq!(ann.cache_ttl, 0); // unchanged default
1005    }
1006
1007    #[test]
1008    fn test_set_bool_annotation_preserves_other_fields() {
1009        let mut ann = ModuleAnnotations {
1010            destructive: true,
1011            cache_ttl: 99,
1012            ..Default::default()
1013        };
1014        assert!(set_bool_annotation(&mut ann, "readonly", true));
1015        // Original fields survive the serde round-trip.
1016        assert!(ann.readonly);
1017        assert!(ann.destructive);
1018        assert_eq!(ann.cache_ttl, 99);
1019    }
1020
1021    #[test]
1022    fn test_clamp_str_under_limit() {
1023        let s = "hello";
1024        assert_eq!(clamp_str(s, 500, "mod", "desc"), "hello");
1025    }
1026
1027    #[test]
1028    fn test_clamp_str_over_limit_truncates() {
1029        let s = "a".repeat(600);
1030        let result = clamp_str(&s, 500, "mod", "desc");
1031        assert_eq!(result.len(), 500);
1032    }
1033
1034    #[test]
1035    fn test_clamp_str_unicode_boundary() {
1036        // "é" is 2 bytes — ensure we don't split it
1037        let s = "é".repeat(300); // 600 bytes
1038        let result = clamp_str(&s, 500, "mod", "desc");
1039        assert!(result.len() <= 500);
1040        assert!(std::str::from_utf8(result.as_bytes()).is_ok());
1041    }
1042
1043    #[test]
1044    fn test_strip_ansi_no_sequences() {
1045        assert_eq!(strip_ansi("hello world"), "hello world");
1046    }
1047
1048    #[test]
1049    fn test_strip_ansi_removes_color_codes() {
1050        let input = "\x1b[31mred text\x1b[0m";
1051        assert_eq!(strip_ansi(input), "red text");
1052    }
1053
1054    #[test]
1055    fn test_strip_ansi_mixed_content() {
1056        let input = "normal \x1b[1mbold\x1b[0m text";
1057        assert_eq!(strip_ansi(input), "normal bold text");
1058    }
1059
1060    // ---- Endpoint scheme validation (A-D-003 parity with Python/TS) ----
1061
1062    #[test]
1063    fn test_ai_enhancer_rejects_file_scheme() {
1064        let result = AIEnhancer::new(Some("file:///etc/passwd".into()), None, None, None, None);
1065        assert!(result.is_err(), "file:// scheme must be rejected");
1066        let err = result.unwrap_err().to_string();
1067        assert!(
1068            err.contains("Invalid endpoint URL scheme"),
1069            "error should call out invalid scheme, got: {err}"
1070        );
1071    }
1072
1073    #[test]
1074    fn test_ai_enhancer_rejects_ftp_scheme() {
1075        let result = AIEnhancer::new(Some("ftp://example.com".into()), None, None, None, None);
1076        assert!(result.is_err(), "ftp:// scheme must be rejected");
1077    }
1078
1079    #[test]
1080    fn test_ai_enhancer_rejects_missing_scheme() {
1081        let result = AIEnhancer::new(Some("localhost:11434".into()), None, None, None, None);
1082        assert!(result.is_err(), "URL without scheme must be rejected");
1083    }
1084
1085    #[test]
1086    fn test_ai_enhancer_accepts_http_scheme() {
1087        let result = AIEnhancer::new(
1088            Some("http://localhost:11434/v1".into()),
1089            None,
1090            None,
1091            None,
1092            None,
1093        );
1094        assert!(result.is_ok(), "http:// must be accepted");
1095    }
1096
1097    #[test]
1098    fn test_ai_enhancer_accepts_https_scheme() {
1099        let result = AIEnhancer::new(
1100            Some("https://api.example.com/v1".into()),
1101            None,
1102            None,
1103            None,
1104            None,
1105        );
1106        assert!(result.is_ok(), "https:// must be accepted");
1107    }
1108
1109    // ---- Dynamic annotation field discovery (A-D-002 parity with Python/TS) ----
1110
1111    #[test]
1112    fn test_annotation_field_names_match_struct() {
1113        let names = annotation_field_names();
1114        // Must include real ModuleAnnotations fields.
1115        assert!(names.iter().any(|n| n == "readonly"));
1116        assert!(names.iter().any(|n| n == "destructive"));
1117        assert!(names.iter().any(|n| n == "idempotent"));
1118        assert!(names.iter().any(|n| n == "cacheable"));
1119        assert!(names.iter().any(|n| n == "cache_ttl"));
1120        assert!(names.iter().any(|n| n == "paginated"));
1121        // Must NOT include phantom fields from the stale const list.
1122        assert!(!names.iter().any(|n| n == "tags"));
1123        assert!(!names.iter().any(|n| n == "version"));
1124        assert!(!names.iter().any(|n| n == "category"));
1125        assert!(!names.iter().any(|n| n == "requires_confirmation"));
1126        assert!(!names.iter().any(|n| n == "long_running"));
1127        // Must NOT include the open-extension map key.
1128        assert!(!names.iter().any(|n| n == "extra"));
1129        // No duplicates.
1130        let mut sorted = names.clone();
1131        sorted.sort();
1132        sorted.dedup();
1133        assert_eq!(sorted.len(), names.len(), "field names must be unique");
1134    }
1135}