Skip to main content

reddb_server/runtime/ai/
ner.rs

1//! LLM-based NER for AskPipeline Stage 1 — issue #123.
2//!
3//! Opt-in replacement for the heuristic `extract_tokens` regex pass in
4//! [`crate::runtime::ask_pipeline`]. Default backend stays heuristic;
5//! when an operator turns on `ai.ner.backend = "llm"` (config knob —
6//! flagged for follow-up wiring), `AskPipeline::extract_tokens` is
7//! routed through [`LlmNer::extract`] instead.
8//!
9//! Design notes
10//! ------------
11//! * **Pure module.** This file imports [`TokenSet`] / [`EffectiveScope`]
12//!   from `runtime::ask_pipeline` and `runtime::statement_frame` via
13//!   `use` only. It never edits those modules — registration is a
14//!   separate orchestrator-batch step.
15//! * **Auth gate.** Mirrors ADR 0008: every LLM call checks
16//!   `ai:ner:read` against an `AuthContext` trait defined locally.
17//!   Production callers will plug the auth store in; tests use
18//!   [`StubAuthContext`].
19//! * **Output sanitization.** LLM responses are JSON-parsed and every
20//!   token is run through the same defenses the rest of the engine
21//!   uses on untrusted strings: control-byte / CRLF / NUL / quote
22//!   injection rejection plus a secret-redactor pattern check (so a
23//!   hallucinated `sk_live_...` or `Bearer ...` never leaks into the
24//!   pipeline as a "literal").
25//! * **Token cap.** Bounded at `max_tokens_returned` per call (default
26//!   32). Excess returns [`NerError::ResponseExceedsTokenLimit`]
27//!   rather than silently truncating, so callers can record a metric.
28//! * **Stub variant.** [`NerProvider::Stub`] never touches the network
29//!   — every test below uses it.
30//!
31//! Inline `#[cfg(test)] mod tests` covers stub Empty/Echo/Canned,
32//! timeout simulation, malformed-response rejection, secret-in-response
33//! rejection, token-cap enforcement, auth-gate denial, and each
34//! [`HeuristicFallback`] variant.
35
36use std::time::Duration;
37
38use crate::serde_json as crate_json;
39use crate::serde_json::Value as JsonValue;
40
41use crate::runtime::ask_pipeline::{extract_tokens as heuristic_extract_tokens, TokenSet};
42use crate::runtime::statement_frame::EffectiveScope;
43
/// Default for `LlmNer::max_tokens_returned`: the most tokens (keywords
/// and literals combined) one LLM response may carry. Anything larger is
/// rejected with [`NerError::ResponseExceedsTokenLimit`], never truncated.
pub const DEFAULT_MAX_TOKENS: usize = 32;

/// Default LLM call timeout in milliseconds (5 s) — the Stage 1 share of
/// the AskPipeline latency budget allocated in PRD #118.
pub const DEFAULT_TIMEOUT_MS: u32 = 5000;

/// Capability an [`AuthContext`] must grant before any LLM call is made.
pub const NER_CAPABILITY: &str = "ai:ner:read";
54
55/// Provider abstraction. Network-bound variants share the same
56/// request/response shape (chat completions w/ a JSON-mode prompt);
57/// the `Stub` variant exists so tests and `disable_network` deploys
58/// can exercise the surface without going over the wire.
59#[derive(Debug, Clone)]
60pub enum NerProvider {
61    /// Calls an OpenAI-compatible chat endpoint (`/v1/chat/completions`)
62    /// with a prompt that asks for entity extraction in a structured
63    /// JSON shape. Hits the network.
64    OpenAiCompat { endpoint: String, model: String },
65    /// Anthropic native messages API (`/v1/messages`). Hits the network.
66    AnthropicNative { endpoint: String, model: String },
67    /// In-process stub returning a deterministic response. Used in
68    /// tests and when network calls are administratively disabled.
69    Stub(StubBehavior),
70}
71
72/// Behaviors the [`NerProvider::Stub`] variant can simulate.
73#[derive(Debug, Clone)]
74pub enum StubBehavior {
75    /// Always returns an empty [`TokenSet`].
76    Empty,
77    /// Returns the input echoed as a single keyword (lowercased,
78    /// trimmed). Useful for round-trip tests.
79    Echo,
80    /// Returns a fixed canned response — useful for snapshot tests
81    /// where the caller wants to assert downstream stage output.
82    Canned(TokenSet),
83    /// Sleeps `Duration` then returns success — used to drive the
84    /// timeout path under tests without waiting on real network I/O.
85    SlowDuration(Duration),
86    /// Returns a hand-crafted JSON string verbatim, pushed through the
87    /// same parse + sanitize pipeline as a real LLM response. Lets
88    /// adversarial-corpus tests stress the response sanitizer without
89    /// a real provider.
90    RawJson(String),
91}
92
/// Recovery policy [`LlmNer::extract`] applies when the LLM path fails.
///
/// [`NerError::AuthDenied`] is never subject to this policy; it always
/// reaches the caller unchanged.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HeuristicFallback {
    /// Recover with the existing regex-based `extract_tokens` heuristic.
    /// This is the recommended default: Stage 1 keeps answering while the
    /// LLM provider is degraded.
    UseHeuristic,
    /// Recover with an empty [`TokenSet`]. This is strict mode, for
    /// callers such as compliance audits that prefer no answer over a
    /// heuristic guess.
    EmptyOnFail,
    /// Return the [`NerError`] to the caller unchanged (caller-handles mode).
    Propagate,
}
107
/// Failure modes of [`LlmNer::extract`]. Variants line up with the metric
/// labels emitted by `runtime/ai_ner_failures_total`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NerError {
    /// The LLM call exceeded `timeout_ms`.
    NetworkTimeout,
    /// The provider answered with a non-2xx status. `body_excerpt` is a
    /// bounded prefix of the response body with control characters
    /// stripped, so it is safe to log.
    ProviderRejected { status: u16, body_excerpt: String },
    /// The response was not valid JSON, or its JSON did not have the
    /// expected shape. Also used for tokens rejected by the per-token
    /// sanitizer (empty, too long, control / quote characters).
    ResponseMalformed { reason: String },
    /// The response carried more than `max_tokens_returned` tokens
    /// (keywords and literals combined).
    ResponseExceedsTokenLimit { count: usize, max: usize },
    /// A returned token matched a secret-redactor pattern. `pattern`
    /// names which one:
    /// - `sk_prefix`, `rs_prefix`, `reddb_prefix`
    /// - `bearer`, `jwt`
    /// - `conn_string_credentials`
    ///
    /// SOC tooling can use it to alert on hallucinated leaks.
    SecretInResponse { pattern: String },
    /// The caller's [`AuthContext`] does not grant `ai:ner:read`.
    AuthDenied,
}
128
129impl std::fmt::Display for NerError {
130    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131        match self {
132            NerError::NetworkTimeout => write!(f, "ner: network timeout"),
133            NerError::ProviderRejected { status, .. } => {
134                write!(f, "ner: provider rejected (status={status})")
135            }
136            NerError::ResponseMalformed { reason } => {
137                write!(f, "ner: malformed response ({reason})")
138            }
139            NerError::ResponseExceedsTokenLimit { count, max } => {
140                write!(f, "ner: response exceeds token limit ({count} > {max})")
141            }
142            NerError::SecretInResponse { pattern } => {
143                write!(f, "ner: secret pattern in response ({pattern})")
144            }
145            NerError::AuthDenied => write!(f, "ner: auth denied (missing {NER_CAPABILITY})"),
146        }
147    }
148}
149
150impl std::error::Error for NerError {}
151
/// Capability gate consulted by [`LlmNer::extract`] before any provider
/// call. Production callers are expected to back it with the engine's
/// auth store; tests use [`StubAuthContext`].
///
/// It deliberately lives here rather than in `statement_frame`: this
/// module is its only consumer today and the trait is tiny. If a second
/// consumer appears, lift it out then.
pub trait AuthContext: std::fmt::Debug + Send + Sync {
    /// Returns `true` when the caller holds `capability`
    /// (for example `"ai:ner:read"`).
    fn has_capability(&self, capability: &str) -> bool;
}
161
162/// Test/embedded stub. Carries an explicit allowlist; matches by
163/// exact string. Default-empty means "no capabilities" — the strictest
164/// possible test setup.
165#[derive(Debug, Clone, Default)]
166pub struct StubAuthContext {
167    capabilities: Vec<String>,
168}
169
170impl StubAuthContext {
171    pub fn new(caps: impl IntoIterator<Item = impl Into<String>>) -> Self {
172        Self {
173            capabilities: caps.into_iter().map(Into::into).collect(),
174        }
175    }
176
177    pub fn allow_all() -> Self {
178        Self::new([NER_CAPABILITY])
179    }
180
181    pub fn deny_all() -> Self {
182        Self::default()
183    }
184}
185
186impl AuthContext for StubAuthContext {
187    fn has_capability(&self, capability: &str) -> bool {
188        self.capabilities.iter().any(|c| c == capability)
189    }
190}
191
192/// Top-level NER handle. Cheap to clone — providers carry only config
193/// strings; no live HTTP client is held (a fresh blocking client is
194/// built per call so timeout + cancellation are bounded).
195#[derive(Debug, Clone)]
196pub struct LlmNer {
197    pub provider: NerProvider,
198    pub fallback: HeuristicFallback,
199    pub timeout_ms: u32,
200    pub max_tokens_returned: usize,
201}
202
203impl LlmNer {
204    /// Convenience constructor with the documented defaults.
205    pub fn new(provider: NerProvider, fallback: HeuristicFallback) -> Self {
206        Self {
207            provider,
208            fallback,
209            timeout_ms: DEFAULT_TIMEOUT_MS,
210            max_tokens_returned: DEFAULT_MAX_TOKENS,
211        }
212    }
213
214    /// Stage 1 entrypoint. Mirrors the signature of
215    /// `ask_pipeline::extract_tokens` but returns `Result` so the
216    /// caller can decide whether to honor [`HeuristicFallback`].
217    ///
218    /// `auth` is passed separately rather than read off `scope` so
219    /// embedded callers can plug a different gate without faking an
220    /// `EffectiveScope`.
221    pub async fn extract(
222        &self,
223        question: &str,
224        scope: &EffectiveScope,
225        auth: &dyn AuthContext,
226    ) -> Result<TokenSet, NerError> {
227        // Auth gate first — never even attempt the call without the
228        // capability. This matches ADR 0008's "deny by default, log
229        // every denial" posture.
230        if !auth.has_capability(NER_CAPABILITY) {
231            return Err(NerError::AuthDenied);
232        }
233
234        // `scope` participates in the prompt-construction step (visible
235        // collections become a hint for the LLM). For the stub paths it
236        // doesn't matter; for the network paths we read it via
237        // `build_prompt`.
238        let result = match &self.provider {
239            NerProvider::Stub(behavior) => self.run_stub(behavior, question),
240            NerProvider::OpenAiCompat { endpoint, model } => {
241                self.run_openai_compat(endpoint, model, question, scope)
242                    .await
243            }
244            NerProvider::AnthropicNative { endpoint, model } => {
245                self.run_anthropic(endpoint, model, question, scope).await
246            }
247        };
248
249        match result {
250            Ok(tokens) => Ok(tokens),
251            Err(err) => self.handle_failure(err, question),
252        }
253    }
254
255    /// Apply [`HeuristicFallback`] policy.
256    fn handle_failure(&self, err: NerError, question: &str) -> Result<TokenSet, NerError> {
257        // Auth denials never fall back — that would defeat the gate.
258        if matches!(err, NerError::AuthDenied) {
259            return Err(err);
260        }
261        match self.fallback {
262            HeuristicFallback::UseHeuristic => Ok(heuristic_extract_tokens(question)),
263            HeuristicFallback::EmptyOnFail => Ok(TokenSet::default()),
264            HeuristicFallback::Propagate => Err(err),
265        }
266    }
267
268    /// Stub dispatcher.
269    fn run_stub(&self, behavior: &StubBehavior, question: &str) -> Result<TokenSet, NerError> {
270        match behavior {
271            StubBehavior::Empty => Ok(TokenSet::default()),
272            StubBehavior::Echo => {
273                let trimmed = question.trim().to_lowercase();
274                if trimmed.is_empty() {
275                    Ok(TokenSet::default())
276                } else {
277                    Ok(TokenSet {
278                        keywords: vec![trimmed],
279                        literals: vec![],
280                    })
281                }
282            }
283            StubBehavior::Canned(tokens) => Ok(tokens.clone()),
284            StubBehavior::SlowDuration(d) => {
285                // Synthesize the timeout deterministically — never
286                // actually sleeps `d`. We only check whether `d`
287                // exceeds the configured budget.
288                if d.as_millis() as u32 > self.timeout_ms {
289                    Err(NerError::NetworkTimeout)
290                } else {
291                    Ok(TokenSet::default())
292                }
293            }
294            StubBehavior::RawJson(raw) => parse_and_sanitize(raw, self.max_tokens_returned),
295        }
296    }
297
298    // --- Network paths -----------------------------------------------------
299    //
300    // The two network providers are intentionally thin: build a prompt,
301    // ship it, parse + sanitize the JSON body. Any provider-specific
302    // shaping lives in `build_prompt` / `extract_payload`.
303    //
304    // Network bodies stay behind `ai-ner-network` so the default
305    // server build does not pull an HTTP client unless explicitly
306    // requested.
307
308    #[cfg(feature = "ai-ner-network")]
309    async fn run_openai_compat(
310        &self,
311        endpoint: &str,
312        model: &str,
313        question: &str,
314        scope: &EffectiveScope,
315    ) -> Result<TokenSet, NerError> {
316        let body = crate::json!({
317            "model": model,
318            "response_format": crate::json!({ "type": "json_object" }),
319            "messages": vec![
320                crate::json!({ "role": "system", "content": NER_SYSTEM_PROMPT }),
321                crate::json!({ "role": "user", "content": build_prompt(question, scope) }),
322            ],
323        });
324        let raw = http_post_json(endpoint, &body, self.timeout_ms).await?;
325        let payload = extract_openai_payload(&raw)?;
326        parse_and_sanitize(&payload, self.max_tokens_returned)
327    }
328
329    #[cfg(not(feature = "ai-ner-network"))]
330    async fn run_openai_compat(
331        &self,
332        _endpoint: &str,
333        _model: &str,
334        _question: &str,
335        _scope: &EffectiveScope,
336    ) -> Result<TokenSet, NerError> {
337        // Without the network feature, report a NetworkTimeout so the
338        // fallback policy still exercises end-to-end.
339        Err(NerError::NetworkTimeout)
340    }
341
342    #[cfg(feature = "ai-ner-network")]
343    async fn run_anthropic(
344        &self,
345        endpoint: &str,
346        model: &str,
347        question: &str,
348        scope: &EffectiveScope,
349    ) -> Result<TokenSet, NerError> {
350        let body = crate::json!({
351            "model": model,
352            "max_tokens": 1024,
353            "system": NER_SYSTEM_PROMPT,
354            "messages": vec![
355                crate::json!({ "role": "user", "content": build_prompt(question, scope) }),
356            ],
357        });
358        let raw = http_post_json(endpoint, &body, self.timeout_ms).await?;
359        let payload = extract_anthropic_payload(&raw)?;
360        parse_and_sanitize(&payload, self.max_tokens_returned)
361    }
362
363    #[cfg(not(feature = "ai-ner-network"))]
364    async fn run_anthropic(
365        &self,
366        _endpoint: &str,
367        _model: &str,
368        _question: &str,
369        _scope: &EffectiveScope,
370    ) -> Result<TokenSet, NerError> {
371        Err(NerError::NetworkTimeout)
372    }
373}
374
/// System prompt shared by both network providers. Kept short on purpose:
/// every extra instruction is another way for the model to drift away
/// from the JSON shape `parse_and_sanitize` expects.
// Only the `ai-ner-network` code paths reference this constant; silence
// the dead-code lint in builds without that feature.
#[cfg_attr(not(feature = "ai-ner-network"), allow(dead_code))]
const NER_SYSTEM_PROMPT: &str = "\
You are an entity extraction service for a database query pipeline. \
Read the user's question and return a JSON object with two fields: \
'keywords' (array of lowercase content words, length >= 2) and \
'literals' (array of identifier-shaped tokens kept in original case). \
Return JSON only — no prose, no markdown.";
384
385/// Build the user-message prompt. We pin the visible-collection list
386/// so the LLM doesn't invent table names that aren't in scope.
387#[allow(dead_code)] // used only by network paths; suppress warning when feature off
388fn build_prompt(question: &str, scope: &EffectiveScope) -> String {
389    use crate::runtime::statement_frame::ReadFrame;
390    let visible: Vec<&str> = scope
391        .visible_collections()
392        .map(|set| set.iter().map(String::as_str).collect())
393        .unwrap_or_default();
394    format!(
395        "Question: {q}\nVisible collections: {v:?}\nReturn JSON only.",
396        q = question,
397        v = visible
398    )
399}
400
/// POSTs `body` as JSON to `endpoint` and returns the 2xx response body.
///
/// # Errors
///
/// * [`NerError::NetworkTimeout`] when the send or the body read exceeds
///   `timeout_ms`. The budget is enforced by the client's request
///   timeout.
/// * [`NerError::ProviderRejected`] for any non-2xx status. The status is
///   reported even if the error body itself cannot be read.
/// * [`NerError::ResponseMalformed`] for client-construction, transport
///   or body-read failures that are not timeouts.
#[cfg(feature = "ai-ner-network")]
async fn http_post_json(
    endpoint: &str,
    body: &crate_json::Value,
    timeout_ms: u32,
) -> Result<String, NerError> {
    // Maps a reqwest failure to the NER taxonomy. Timeouts keep their own
    // variant so the fallback policy / metrics can tell them apart.
    fn classify(stage: &str, e: reqwest::Error) -> NerError {
        if e.is_timeout() {
            NerError::NetworkTimeout
        } else {
            NerError::ResponseMalformed {
                reason: format!("{stage}: {e}"),
            }
        }
    }

    let client = reqwest::Client::builder()
        .timeout(Duration::from_millis(u64::from(timeout_ms)))
        .build()
        .map_err(|e| NerError::ResponseMalformed {
            reason: format!("client build: {e}"),
        })?;
    let resp = client
        .post(endpoint)
        .header("content-type", "application/json")
        .body(body.to_string_compact())
        .send()
        .await
        .map_err(|e| classify("transport", e))?;

    let status = resp.status().as_u16();
    if !(200..300).contains(&status) {
        // The status code is what matters here. Read the body only
        // best-effort for the excerpt, so a failed read cannot mask the
        // rejection.
        let text = resp.text().await.unwrap_or_default();
        return Err(NerError::ProviderRejected {
            status,
            body_excerpt: scrub_excerpt(&text),
        });
    }
    resp.text().await.map_err(|e| classify("body read", e))
}
440
/// Pulls the assistant message text (`choices[0].message.content`) out of
/// an OpenAI-compatible chat completions response.
///
/// # Errors
///
/// [`NerError::ResponseMalformed`] when the outer body is not JSON or the
/// expected path is absent or not a string.
#[cfg(feature = "ai-ner-network")]
fn extract_openai_payload(raw: &str) -> Result<String, NerError> {
    let envelope: JsonValue =
        crate_json::from_str(raw).map_err(|e| NerError::ResponseMalformed {
            reason: format!("outer json: {e}"),
        })?;
    let first_choice = envelope["choices"]
        .as_array()
        .and_then(|choices| choices.first());
    match first_choice.and_then(|choice| choice["message"]["content"].as_str()) {
        Some(content) => Ok(content.to_owned()),
        None => Err(NerError::ResponseMalformed {
            reason: "missing choices[0].message.content".into(),
        }),
    }
}
455
/// Pulls the model's text out of an Anthropic messages response.
///
/// It takes the first entry of `content` that carries a string `text`
/// field, rather than blindly taking `content[0]`. The array may begin
/// with non-text blocks (e.g. `thinking` or `tool_use`), which have no
/// `text` field.
///
/// # Errors
///
/// [`NerError::ResponseMalformed`] when the outer body is not JSON, or no
/// content block carries text.
#[cfg(feature = "ai-ner-network")]
fn extract_anthropic_payload(raw: &str) -> Result<String, NerError> {
    let v: JsonValue = crate_json::from_str(raw).map_err(|e| NerError::ResponseMalformed {
        reason: format!("outer json: {e}"),
    })?;
    v["content"]
        .as_array()
        .and_then(|blocks| blocks.iter().find_map(|block| block["text"].as_str()))
        .map(str::to_owned)
        .ok_or_else(|| NerError::ResponseMalformed {
            reason: "missing text block in content[]".into(),
        })
}
470
/// Makes a provider error body safe to log. It keeps at most the first
/// 256 bytes, cut on a UTF-8 character boundary, and drops every control
/// character (including `\n`, `\r`, `\t` and DEL).
///
/// This matches the "first 256 bytes" contract of
/// `NerError::ProviderRejected::body_excerpt`. The previous version kept
/// 256 *chars*, which could be up to 1 KiB of multibyte text.
// Only the `ai-ner-network` path calls this; silence dead-code otherwise.
#[cfg_attr(not(feature = "ai-ner-network"), allow(dead_code))]
fn scrub_excerpt(s: &str) -> String {
    const MAX_EXCERPT_BYTES: usize = 256;
    // Walk back to the nearest char boundary within the byte budget.
    // Index 0 is always a boundary, so the loop terminates.
    let mut end = s.len().min(MAX_EXCERPT_BYTES);
    while !s.is_char_boundary(end) {
        end -= 1;
    }
    s[..end].chars().filter(|c| !c.is_control()).collect()
}
480
481/// Core sanitizer. Parses `raw` as JSON, expects `{ keywords: [...],
482/// literals: [...] }`, and rejects anything that smells off.
483///
484/// All policy lives here so both the network and stub paths share the
485/// exact same defenses.
486fn parse_and_sanitize(raw: &str, max_tokens: usize) -> Result<TokenSet, NerError> {
487    let parsed: JsonValue = crate_json::from_str(raw).map_err(|e| NerError::ResponseMalformed {
488        reason: format!("json parse: {e}"),
489    })?;
490    let obj = parsed
491        .as_object()
492        .ok_or_else(|| NerError::ResponseMalformed {
493            reason: "expected JSON object at root".into(),
494        })?;
495
496    let keywords = collect_string_array(obj.get("keywords"), "keywords")?;
497    let literals = collect_string_array(obj.get("literals"), "literals")?;
498
499    let total = keywords.len() + literals.len();
500    if total > max_tokens {
501        return Err(NerError::ResponseExceedsTokenLimit {
502            count: total,
503            max: max_tokens,
504        });
505    }
506
507    for token in keywords.iter().chain(literals.iter()) {
508        validate_token(token)?;
509    }
510
511    Ok(TokenSet { keywords, literals })
512}
513
514/// Pull a `Vec<String>` out of a JSON value, with structural errors
515/// labeled by `field` so debugging the provider is easier.
516fn collect_string_array(v: Option<&JsonValue>, field: &str) -> Result<Vec<String>, NerError> {
517    let arr = match v {
518        Some(JsonValue::Array(a)) => a,
519        Some(JsonValue::Null) | None => return Ok(Vec::new()),
520        Some(other) => {
521            return Err(NerError::ResponseMalformed {
522                reason: format!("{field}: expected array, got {}", json_kind(other)),
523            });
524        }
525    };
526    let mut out = Vec::with_capacity(arr.len());
527    for (i, item) in arr.iter().enumerate() {
528        match item {
529            JsonValue::String(s) => out.push(s.clone()),
530            other => {
531                return Err(NerError::ResponseMalformed {
532                    reason: format!("{field}[{i}]: expected string, got {}", json_kind(other)),
533                });
534            }
535        }
536    }
537    Ok(out)
538}
539
540fn json_kind(v: &JsonValue) -> &'static str {
541    match v {
542        JsonValue::Null => "null",
543        JsonValue::Bool(_) => "bool",
544        JsonValue::Number(_) => "number",
545        JsonValue::String(_) => "string",
546        JsonValue::Array(_) => "array",
547        JsonValue::Object(_) => "object",
548    }
549}
550
551/// Single-token validation. Order matters: secret detection runs
552/// before structural checks so a hallucinated `sk_live_...` always
553/// reports as `SecretInResponse`, not as `ResponseMalformed`.
554fn validate_token(token: &str) -> Result<(), NerError> {
555    if let Some(pattern) = match_secret_pattern(token) {
556        return Err(NerError::SecretInResponse {
557            pattern: pattern.into(),
558        });
559    }
560    if token.is_empty() {
561        return Err(NerError::ResponseMalformed {
562            reason: "empty token".into(),
563        });
564    }
565    if token.len() > 256 {
566        return Err(NerError::ResponseMalformed {
567            reason: format!("token too long ({} bytes)", token.len()),
568        });
569    }
570    for (i, byte) in token.as_bytes().iter().enumerate() {
571        match byte {
572            // NUL, CR, LF — classic injection vectors.
573            0x00 => {
574                return Err(NerError::ResponseMalformed {
575                    reason: format!("NUL byte at offset {i}"),
576                });
577            }
578            b'\n' | b'\r' => {
579                return Err(NerError::ResponseMalformed {
580                    reason: format!("CR/LF at offset {i}"),
581                });
582            }
583            // Quote injection — keep the parsing surface simple.
584            b'"' | b'\'' | b'`' => {
585                return Err(NerError::ResponseMalformed {
586                    reason: format!("quote injection at offset {i}"),
587                });
588            }
589            // Other control bytes (anything < 0x20 except \t).
590            b if *b < 0x20 && *b != b'\t' => {
591                return Err(NerError::ResponseMalformed {
592                    reason: format!("control byte 0x{b:02x} at offset {i}"),
593                });
594            }
595            _ => {}
596        }
597    }
598    Ok(())
599}
600
/// Secret-redactor check. Returns the label of the first secret pattern
/// `token` matches, or `None`.
///
/// Mirrors the inline policy used by the audit / log boundary guards (see
/// `secret_redactor` in the audit pipeline). Prefix rules live in a
/// single const table so the CI lint can scan one canonical list.
///
/// Possible labels:
/// - `sk_prefix`, `rs_prefix`, `reddb_prefix`
/// - `bearer`, `jwt`
/// - `conn_string_credentials`
fn match_secret_pattern(token: &str) -> Option<&'static str> {
    // (prefix, label, case_insensitive). Only key *prefixes* appear here;
    // full secret-shaped strings never live in source. HTTP auth scheme
    // names are case-insensitive (RFC 7235 §2.1), so `BEARER x` must be
    // caught as well as `Bearer x` / `bearer x`.
    const PATTERNS: &[(&str, &str, bool)] = &[
        ("sk_", "sk_prefix", false),
        ("rs_", "rs_prefix", false),
        ("reddb_", "reddb_prefix", false),
        ("bearer ", "bearer", true),
    ];
    for &(prefix, label, case_insensitive) in PATTERNS {
        let hit = if case_insensitive {
            // `get` yields None when `token` is shorter than `prefix` or
            // the cut would split a UTF-8 character.
            matches!(
                token.get(..prefix.len()),
                Some(head) if head.eq_ignore_ascii_case(prefix)
            )
        } else {
            token.starts_with(prefix)
        };
        if hit {
            return Some(label);
        }
    }
    if looks_like_jwt(token) {
        return Some("jwt");
    }
    // Connection-string credentials: `scheme://user:pass@host`.
    if let Some(scheme_end) = token.find("://") {
        let rest = &token[scheme_end + 3..];
        if let Some(at) = rest.find('@') {
            if rest[..at].contains(':') {
                return Some("conn_string_credentials");
            }
        }
    }
    None
}

/// Cheap structural JWT check. The token must have exactly three
/// `.`-separated segments, each at least 4 bytes from the base64url
/// alphabet (`A-Z a-z 0-9 _ -`).
///
/// Dotted identifiers whose parts all meet that bar also match and are
/// reported as `jwt`; `public.users.email` is one example.
fn looks_like_jwt(token: &str) -> bool {
    let mut segments = 0usize;
    for segment in token.split('.') {
        segments += 1;
        let is_b64url = segment.len() >= 4
            && segment
                .bytes()
                .all(|b| b.is_ascii_alphanumeric() || b == b'_' || b == b'-');
        if segments > 3 || !is_b64url {
            return false;
        }
    }
    segments == 3
}
653
654// ============================================================================
655// Tests
656// ============================================================================
657
658#[cfg(test)]
659mod tests {
660    use super::*;
661    use crate::runtime::ask_pipeline::TokenSet;
662
663    /// Minimal `EffectiveScope` for tests. Built via the public
664    /// `Default`-equivalent path used in `ask_pipeline`'s own tests.
665    fn make_scope() -> EffectiveScope {
666        // EffectiveScope's fields are pub(crate), and this module
667        // lives in the same crate, so direct construction is fine.
668        // We mirror the `make_scope` pattern from ask_pipeline tests.
669        use crate::storage::transaction::snapshot::Snapshot;
670        use std::collections::HashSet;
671        EffectiveScope {
672            tenant: None,
673            identity: None,
674            snapshot: Snapshot {
675                xid: 0,
676                in_progress: HashSet::new(),
677            },
678            visible_collections: None,
679        }
680    }
681
682    fn allow() -> StubAuthContext {
683        StubAuthContext::allow_all()
684    }
685
686    fn deny() -> StubAuthContext {
687        StubAuthContext::deny_all()
688    }
689
690    // --- Stub variants -----------------------------------------------------
691
692    #[tokio::test]
693    async fn stub_empty_returns_empty_token_set() {
694        let ner = LlmNer::new(
695            NerProvider::Stub(StubBehavior::Empty),
696            HeuristicFallback::Propagate,
697        );
698        let out = ner
699            .extract("anything", &make_scope(), &allow())
700            .await
701            .unwrap();
702        assert!(out.is_empty());
703    }
704
705    #[tokio::test]
706    async fn stub_echo_returns_lowercased_keyword() {
707        let ner = LlmNer::new(
708            NerProvider::Stub(StubBehavior::Echo),
709            HeuristicFallback::Propagate,
710        );
711        let out = ner
712            .extract("  Hello WORLD  ", &make_scope(), &allow())
713            .await
714            .unwrap();
715        assert_eq!(out.keywords, vec!["hello world".to_string()]);
716        assert!(out.literals.is_empty());
717    }
718
719    #[tokio::test]
720    async fn stub_echo_empty_question_yields_empty_set() {
721        let ner = LlmNer::new(
722            NerProvider::Stub(StubBehavior::Echo),
723            HeuristicFallback::Propagate,
724        );
725        let out = ner.extract("   ", &make_scope(), &allow()).await.unwrap();
726        assert!(out.is_empty());
727    }
728
729    #[tokio::test]
730    async fn stub_canned_returns_provided_tokens() {
731        let canned = TokenSet {
732            keywords: vec!["passport".into()],
733            literals: vec!["FDD-1".into()],
734        };
735        let ner = LlmNer::new(
736            NerProvider::Stub(StubBehavior::Canned(canned.clone())),
737            HeuristicFallback::Propagate,
738        );
739        let out = ner.extract("q?", &make_scope(), &allow()).await.unwrap();
740        assert_eq!(out, canned);
741    }
742
743    // --- Timeout simulation ------------------------------------------------
744
745    #[tokio::test]
746    async fn slow_stub_within_budget_succeeds() {
747        let mut ner = LlmNer::new(
748            NerProvider::Stub(StubBehavior::SlowDuration(Duration::from_millis(10))),
749            HeuristicFallback::Propagate,
750        );
751        ner.timeout_ms = 100;
752        assert!(ner.extract("q?", &make_scope(), &allow()).await.is_ok());
753    }
754
755    #[tokio::test]
756    async fn slow_stub_over_budget_times_out_and_propagates() {
757        let mut ner = LlmNer::new(
758            NerProvider::Stub(StubBehavior::SlowDuration(Duration::from_millis(500))),
759            HeuristicFallback::Propagate,
760        );
761        ner.timeout_ms = 50;
762        let err = ner
763            .extract("q?", &make_scope(), &allow())
764            .await
765            .unwrap_err();
766        assert_eq!(err, NerError::NetworkTimeout);
767    }
768
769    // --- Malformed-response rejection -------------------------------------
770
771    #[tokio::test]
772    async fn malformed_not_json_is_rejected() {
773        let ner = LlmNer::new(
774            NerProvider::Stub(StubBehavior::RawJson("not-json".into())),
775            HeuristicFallback::Propagate,
776        );
777        let err = ner
778            .extract("q?", &make_scope(), &allow())
779            .await
780            .unwrap_err();
781        assert!(matches!(err, NerError::ResponseMalformed { .. }));
782    }
783
784    #[tokio::test]
785    async fn malformed_wrong_root_type_is_rejected() {
786        let ner = LlmNer::new(
787            NerProvider::Stub(StubBehavior::RawJson("[1,2,3]".into())),
788            HeuristicFallback::Propagate,
789        );
790        let err = ner
791            .extract("q?", &make_scope(), &allow())
792            .await
793            .unwrap_err();
794        assert!(matches!(err, NerError::ResponseMalformed { .. }));
795    }
796
797    #[tokio::test]
798    async fn malformed_keywords_not_array_is_rejected() {
799        let ner = LlmNer::new(
800            NerProvider::Stub(StubBehavior::RawJson(r#"{"keywords":"oops"}"#.into())),
801            HeuristicFallback::Propagate,
802        );
803        let err = ner
804            .extract("q?", &make_scope(), &allow())
805            .await
806            .unwrap_err();
807        assert!(matches!(err, NerError::ResponseMalformed { .. }));
808    }
809
810    // --- Adversarial corpus (control / quote / secret patterns) -----------
811
812    /// Build the adversarial corpus at runtime so no real-shaped
813    /// secret ever sits in source. Returns 12 payloads — each one
814    /// must trigger either `ResponseMalformed` or `SecretInResponse`.
815    fn adversarial_corpus() -> Vec<(&'static str, String)> {
816        // Constructed prefixes — keeps the lint scanners happy.
817        let sk_prefix = format!("{}{}", "sk_", "live_DEADBEEFcafe");
818        let rs_prefix = format!("{}{}", "rs_", "test_TOKENtoken");
819        let reddb_prefix = format!("{}{}", "reddb_", "internal_secret_X");
820        let bearer = format!("{}{}", "Bearer ", "ABC.DEF.GHI");
821        let jwt = format!("{}.{}.{}", "abcd1234", "wxyz5678", "qrst9012");
822        let conn = "postgres://user:pwd@host:5432/db".to_string();
823
824        vec![
825            (
826                "crlf_in_keyword",
827                "{\"keywords\":[\"foo\\r\\nbar\"]}".into(),
828            ),
829            (
830                "nul_in_literal",
831                "{\"literals\":[\"foo\\u0000bar\"]}".into(),
832            ),
833            ("dquote_injection", "{\"keywords\":[\"foo\\\"bar\"]}".into()),
834            ("squote_injection", "{\"keywords\":[\"foo'bar\"]}".into()),
835            ("backtick_injection", "{\"keywords\":[\"foo`bar\"]}".into()),
836            (
837                "control_byte_low",
838                "{\"keywords\":[\"foo\\u0007bar\"]}".into(),
839            ),
840            ("sk_live", format!(r#"{{"keywords":["{sk_prefix}"]}}"#)),
841            ("rs_test", format!(r#"{{"keywords":["{rs_prefix}"]}}"#)),
842            (
843                "reddb_internal",
844                format!(r#"{{"literals":["{reddb_prefix}"]}}"#),
845            ),
846            ("bearer_token", format!(r#"{{"keywords":["{bearer}"]}}"#)),
847            ("jwt_shape", format!(r#"{{"literals":["{jwt}"]}}"#)),
848            ("conn_string", format!(r#"{{"keywords":["{conn}"]}}"#)),
849        ]
850    }
851
852    #[tokio::test]
853    async fn adversarial_corpus_is_fully_rejected() {
854        let corpus = adversarial_corpus();
855        assert!(corpus.len() >= 10, "corpus must be ≥10 payloads");
856        for (label, raw) in corpus {
857            let ner = LlmNer::new(
858                NerProvider::Stub(StubBehavior::RawJson(raw)),
859                HeuristicFallback::Propagate,
860            );
861            let err = ner
862                .extract("q?", &make_scope(), &allow())
863                .await
864                .expect_err(&format!("payload {label} should have been rejected"));
865            assert!(
866                matches!(
867                    err,
868                    NerError::ResponseMalformed { .. } | NerError::SecretInResponse { .. }
869                ),
870                "payload {label}: unexpected error variant {err:?}"
871            );
872        }
873    }
874
875    #[tokio::test]
876    async fn secret_in_response_reports_pattern_label() {
877        let raw = format!(r#"{{"keywords":["{}{}"]}}"#, "sk_", "live_zzzz");
878        let ner = LlmNer::new(
879            NerProvider::Stub(StubBehavior::RawJson(raw)),
880            HeuristicFallback::Propagate,
881        );
882        match ner
883            .extract("q?", &make_scope(), &allow())
884            .await
885            .unwrap_err()
886        {
887            NerError::SecretInResponse { pattern } => assert_eq!(pattern, "sk_prefix"),
888            other => panic!("expected SecretInResponse, got {other:?}"),
889        }
890    }
891
892    // --- Token-cap enforcement --------------------------------------------
893
894    #[tokio::test]
895    async fn token_cap_excess_is_rejected() {
896        // Build a payload with 33 keywords (one over default 32).
897        let kws: Vec<String> = (0..33).map(|i| format!("kw{i}")).collect();
898        let raw = crate_json::json!({ "keywords": kws }).to_string();
899        let ner = LlmNer::new(
900            NerProvider::Stub(StubBehavior::RawJson(raw)),
901            HeuristicFallback::Propagate,
902        );
903        let err = ner
904            .extract("q?", &make_scope(), &allow())
905            .await
906            .unwrap_err();
907        match err {
908            NerError::ResponseExceedsTokenLimit { count, max } => {
909                assert_eq!(count, 33);
910                assert_eq!(max, DEFAULT_MAX_TOKENS);
911            }
912            other => panic!("expected ResponseExceedsTokenLimit, got {other:?}"),
913        }
914    }
915
916    #[tokio::test]
917    async fn token_cap_at_limit_succeeds() {
918        let kws: Vec<String> = (0..DEFAULT_MAX_TOKENS).map(|i| format!("kw{i}")).collect();
919        let raw = crate_json::json!({ "keywords": kws }).to_string();
920        let ner = LlmNer::new(
921            NerProvider::Stub(StubBehavior::RawJson(raw)),
922            HeuristicFallback::Propagate,
923        );
924        let out = ner.extract("q?", &make_scope(), &allow()).await.unwrap();
925        assert_eq!(out.keywords.len(), DEFAULT_MAX_TOKENS);
926    }
927
928    // --- Auth gate ---------------------------------------------------------
929
930    #[tokio::test]
931    async fn auth_gate_denies_without_capability() {
932        let ner = LlmNer::new(
933            NerProvider::Stub(StubBehavior::Empty),
934            HeuristicFallback::UseHeuristic,
935        );
936        let err = ner.extract("q?", &make_scope(), &deny()).await.unwrap_err();
937        assert_eq!(err, NerError::AuthDenied);
938    }
939
940    #[tokio::test]
941    async fn auth_gate_denial_does_not_fall_back() {
942        // Even with UseHeuristic, AuthDenied must propagate — falling
943        // back would silently bypass the gate.
944        let ner = LlmNer::new(
945            NerProvider::Stub(StubBehavior::Empty),
946            HeuristicFallback::UseHeuristic,
947        );
948        let err = ner
949            .extract("FDD-1", &make_scope(), &deny())
950            .await
951            .unwrap_err();
952        assert_eq!(err, NerError::AuthDenied);
953    }
954
955    // --- Fallback semantics ------------------------------------------------
956
957    #[tokio::test]
958    async fn fallback_use_heuristic_runs_extract_tokens() {
959        // RawJson with malformed payload → triggers failure → fallback.
960        let ner = LlmNer::new(
961            NerProvider::Stub(StubBehavior::RawJson("not-json".into())),
962            HeuristicFallback::UseHeuristic,
963        );
964        let out = ner
965            .extract("show order 987654321 details", &make_scope(), &allow())
966            .await
967            .unwrap();
968        // Heuristic recognizes long digit run as a literal.
969        assert!(out.literals.iter().any(|l| l == "987654321"));
970    }
971
972    #[tokio::test]
973    async fn fallback_empty_on_fail_returns_empty() {
974        let ner = LlmNer::new(
975            NerProvider::Stub(StubBehavior::RawJson("not-json".into())),
976            HeuristicFallback::EmptyOnFail,
977        );
978        let out = ner
979            .extract("show order 987654321 details", &make_scope(), &allow())
980            .await
981            .unwrap();
982        assert!(out.is_empty());
983    }
984
985    #[tokio::test]
986    async fn fallback_propagate_returns_error() {
987        let ner = LlmNer::new(
988            NerProvider::Stub(StubBehavior::RawJson("not-json".into())),
989            HeuristicFallback::Propagate,
990        );
991        let err = ner
992            .extract("show order 987654321 details", &make_scope(), &allow())
993            .await
994            .unwrap_err();
995        assert!(matches!(err, NerError::ResponseMalformed { .. }));
996    }
997
998    // --- Helpers -----------------------------------------------------------
999
1000    #[test]
1001    fn jwt_detector_matches_three_segments() {
1002        assert!(looks_like_jwt("abcd.efgh.ijkl"));
1003        assert!(!looks_like_jwt("abcd.efgh"));
1004        assert!(!looks_like_jwt("abc.def.ghi.jkl"));
1005        assert!(!looks_like_jwt("ab.cd.ef")); // segments too short
1006    }
1007
1008    #[test]
1009    fn scrub_excerpt_drops_control_bytes() {
1010        let s = format!("ok\x07bad\nstill");
1011        let cleaned = scrub_excerpt(&s);
1012        assert!(!cleaned.contains('\x07'));
1013        assert!(!cleaned.contains('\n'));
1014    }
1015
1016    #[test]
1017    fn validate_token_accepts_normal_strings() {
1018        assert!(validate_token("passport").is_ok());
1019        assert!(validate_token("FDD-12313").is_ok());
1020        assert!(validate_token("foo_bar.baz").is_ok());
1021    }
1022}