1use std::time::Duration;
37
38use crate::serde_json as crate_json;
39use crate::serde_json::Value as JsonValue;
40
41use crate::runtime::ask_pipeline::{extract_tokens as heuristic_extract_tokens, TokenSet};
42use crate::runtime::statement_frame::EffectiveScope;
43
44pub const DEFAULT_MAX_TOKENS: usize = 32;
47
48pub const DEFAULT_TIMEOUT_MS: u32 = 5_000;
51
52pub const NER_CAPABILITY: &str = "ai:ner:read";
54
55#[derive(Debug, Clone)]
60pub enum NerProvider {
61 OpenAiCompat { endpoint: String, model: String },
65 AnthropicNative { endpoint: String, model: String },
67 Stub(StubBehavior),
70}
71
72#[derive(Debug, Clone)]
74pub enum StubBehavior {
75 Empty,
77 Echo,
80 Canned(TokenSet),
83 SlowDuration(Duration),
86 RawJson(String),
91}
92
93#[derive(Debug, Clone, Copy, PartialEq, Eq)]
95pub enum HeuristicFallback {
96 UseHeuristic,
100 EmptyOnFail,
104 Propagate,
106}
107
108#[derive(Debug, Clone, PartialEq, Eq)]
111pub enum NerError {
112 NetworkTimeout,
114 ProviderRejected { status: u16, body_excerpt: String },
117 ResponseMalformed { reason: String },
119 ResponseExceedsTokenLimit { count: usize, max: usize },
121 SecretInResponse { pattern: String },
125 AuthDenied,
127}
128
129impl std::fmt::Display for NerError {
130 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131 match self {
132 NerError::NetworkTimeout => write!(f, "ner: network timeout"),
133 NerError::ProviderRejected { status, .. } => {
134 write!(f, "ner: provider rejected (status={status})")
135 }
136 NerError::ResponseMalformed { reason } => {
137 write!(f, "ner: malformed response ({reason})")
138 }
139 NerError::ResponseExceedsTokenLimit { count, max } => {
140 write!(f, "ner: response exceeds token limit ({count} > {max})")
141 }
142 NerError::SecretInResponse { pattern } => {
143 write!(f, "ner: secret pattern in response ({pattern})")
144 }
145 NerError::AuthDenied => write!(f, "ner: auth denied (missing {NER_CAPABILITY})"),
146 }
147 }
148}
149
150impl std::error::Error for NerError {}
151
152pub trait AuthContext: std::fmt::Debug + Send + Sync {
159 fn has_capability(&self, capability: &str) -> bool;
160}
161
162#[derive(Debug, Clone, Default)]
166pub struct StubAuthContext {
167 capabilities: Vec<String>,
168}
169
170impl StubAuthContext {
171 pub fn new(caps: impl IntoIterator<Item = impl Into<String>>) -> Self {
172 Self {
173 capabilities: caps.into_iter().map(Into::into).collect(),
174 }
175 }
176
177 pub fn allow_all() -> Self {
178 Self::new([NER_CAPABILITY])
179 }
180
181 pub fn deny_all() -> Self {
182 Self::default()
183 }
184}
185
186impl AuthContext for StubAuthContext {
187 fn has_capability(&self, capability: &str) -> bool {
188 self.capabilities.iter().any(|c| c == capability)
189 }
190}
191
192#[derive(Debug, Clone)]
196pub struct LlmNer {
197 pub provider: NerProvider,
198 pub fallback: HeuristicFallback,
199 pub timeout_ms: u32,
200 pub max_tokens_returned: usize,
201}
202
203impl LlmNer {
204 pub fn new(provider: NerProvider, fallback: HeuristicFallback) -> Self {
206 Self {
207 provider,
208 fallback,
209 timeout_ms: DEFAULT_TIMEOUT_MS,
210 max_tokens_returned: DEFAULT_MAX_TOKENS,
211 }
212 }
213
214 pub async fn extract(
222 &self,
223 question: &str,
224 scope: &EffectiveScope,
225 auth: &dyn AuthContext,
226 ) -> Result<TokenSet, NerError> {
227 if !auth.has_capability(NER_CAPABILITY) {
231 return Err(NerError::AuthDenied);
232 }
233
234 let result = match &self.provider {
239 NerProvider::Stub(behavior) => self.run_stub(behavior, question),
240 NerProvider::OpenAiCompat { endpoint, model } => {
241 self.run_openai_compat(endpoint, model, question, scope)
242 .await
243 }
244 NerProvider::AnthropicNative { endpoint, model } => {
245 self.run_anthropic(endpoint, model, question, scope).await
246 }
247 };
248
249 match result {
250 Ok(tokens) => Ok(tokens),
251 Err(err) => self.handle_failure(err, question),
252 }
253 }
254
255 fn handle_failure(&self, err: NerError, question: &str) -> Result<TokenSet, NerError> {
257 if matches!(err, NerError::AuthDenied) {
259 return Err(err);
260 }
261 match self.fallback {
262 HeuristicFallback::UseHeuristic => Ok(heuristic_extract_tokens(question)),
263 HeuristicFallback::EmptyOnFail => Ok(TokenSet::default()),
264 HeuristicFallback::Propagate => Err(err),
265 }
266 }
267
268 fn run_stub(&self, behavior: &StubBehavior, question: &str) -> Result<TokenSet, NerError> {
270 match behavior {
271 StubBehavior::Empty => Ok(TokenSet::default()),
272 StubBehavior::Echo => {
273 let trimmed = question.trim().to_lowercase();
274 if trimmed.is_empty() {
275 Ok(TokenSet::default())
276 } else {
277 Ok(TokenSet {
278 keywords: vec![trimmed],
279 literals: vec![],
280 })
281 }
282 }
283 StubBehavior::Canned(tokens) => Ok(tokens.clone()),
284 StubBehavior::SlowDuration(d) => {
285 if d.as_millis() as u32 > self.timeout_ms {
289 Err(NerError::NetworkTimeout)
290 } else {
291 Ok(TokenSet::default())
292 }
293 }
294 StubBehavior::RawJson(raw) => parse_and_sanitize(raw, self.max_tokens_returned),
295 }
296 }
297
298 #[cfg(feature = "ai-ner-network")]
309 async fn run_openai_compat(
310 &self,
311 endpoint: &str,
312 model: &str,
313 question: &str,
314 scope: &EffectiveScope,
315 ) -> Result<TokenSet, NerError> {
316 let body = crate::json!({
317 "model": model,
318 "response_format": crate::json!({ "type": "json_object" }),
319 "messages": vec![
320 crate::json!({ "role": "system", "content": NER_SYSTEM_PROMPT }),
321 crate::json!({ "role": "user", "content": build_prompt(question, scope) }),
322 ],
323 });
324 let raw = http_post_json(endpoint, &body, self.timeout_ms).await?;
325 let payload = extract_openai_payload(&raw)?;
326 parse_and_sanitize(&payload, self.max_tokens_returned)
327 }
328
329 #[cfg(not(feature = "ai-ner-network"))]
330 async fn run_openai_compat(
331 &self,
332 _endpoint: &str,
333 _model: &str,
334 _question: &str,
335 _scope: &EffectiveScope,
336 ) -> Result<TokenSet, NerError> {
337 Err(NerError::NetworkTimeout)
340 }
341
342 #[cfg(feature = "ai-ner-network")]
343 async fn run_anthropic(
344 &self,
345 endpoint: &str,
346 model: &str,
347 question: &str,
348 scope: &EffectiveScope,
349 ) -> Result<TokenSet, NerError> {
350 let body = crate::json!({
351 "model": model,
352 "max_tokens": 1024,
353 "system": NER_SYSTEM_PROMPT,
354 "messages": vec![
355 crate::json!({ "role": "user", "content": build_prompt(question, scope) }),
356 ],
357 });
358 let raw = http_post_json(endpoint, &body, self.timeout_ms).await?;
359 let payload = extract_anthropic_payload(&raw)?;
360 parse_and_sanitize(&payload, self.max_tokens_returned)
361 }
362
363 #[cfg(not(feature = "ai-ner-network"))]
364 async fn run_anthropic(
365 &self,
366 _endpoint: &str,
367 _model: &str,
368 _question: &str,
369 _scope: &EffectiveScope,
370 ) -> Result<TokenSet, NerError> {
371 Err(NerError::NetworkTimeout)
372 }
373}
374
375const NER_SYSTEM_PROMPT: &str = "\
379You are an entity extraction service for a database query pipeline. \
380Read the user's question and return a JSON object with two fields: \
381'keywords' (array of lowercase content words, length >= 2) and \
382'literals' (array of identifier-shaped tokens kept in original case). \
383Return JSON only — no prose, no markdown.";
384
385#[allow(dead_code)] fn build_prompt(question: &str, scope: &EffectiveScope) -> String {
389 use crate::runtime::statement_frame::ReadFrame;
390 let visible: Vec<&str> = scope
391 .visible_collections()
392 .map(|set| set.iter().map(String::as_str).collect())
393 .unwrap_or_default();
394 format!(
395 "Question: {q}\nVisible collections: {v:?}\nReturn JSON only.",
396 q = question,
397 v = visible
398 )
399}
400
401#[cfg(feature = "ai-ner-network")]
402async fn http_post_json(
403 endpoint: &str,
404 body: &crate_json::Value,
405 timeout_ms: u32,
406) -> Result<String, NerError> {
407 let client = reqwest::Client::builder()
408 .timeout(Duration::from_millis(timeout_ms as u64))
409 .build()
410 .map_err(|e| NerError::ResponseMalformed {
411 reason: format!("client build: {e}"),
412 })?;
413 let resp = client
414 .post(endpoint)
415 .header("content-type", "application/json")
416 .body(body.to_string_compact())
417 .send()
418 .await
419 .map_err(|e| {
420 if e.is_timeout() {
421 NerError::NetworkTimeout
422 } else {
423 NerError::ResponseMalformed {
424 reason: format!("transport: {e}"),
425 }
426 }
427 })?;
428 let status = resp.status().as_u16();
429 let text = resp.text().await.map_err(|e| NerError::ResponseMalformed {
430 reason: format!("body read: {e}"),
431 })?;
432 if !(200..300).contains(&status) {
433 return Err(NerError::ProviderRejected {
434 status,
435 body_excerpt: scrub_excerpt(&text),
436 });
437 }
438 Ok(text)
439}
440
441#[cfg(feature = "ai-ner-network")]
442fn extract_openai_payload(raw: &str) -> Result<String, NerError> {
443 let v: JsonValue = crate_json::from_str(raw).map_err(|e| NerError::ResponseMalformed {
444 reason: format!("outer json: {e}"),
445 })?;
446 v["choices"]
447 .as_array()
448 .and_then(|choices| choices.first())
449 .and_then(|choice| choice["message"]["content"].as_str())
450 .map(str::to_owned)
451 .ok_or_else(|| NerError::ResponseMalformed {
452 reason: "missing choices[0].message.content".into(),
453 })
454}
455
456#[cfg(feature = "ai-ner-network")]
457fn extract_anthropic_payload(raw: &str) -> Result<String, NerError> {
458 let v: JsonValue = crate_json::from_str(raw).map_err(|e| NerError::ResponseMalformed {
459 reason: format!("outer json: {e}"),
460 })?;
461 v["content"]
462 .as_array()
463 .and_then(|content| content.first())
464 .and_then(|item| item["text"].as_str())
465 .map(str::to_owned)
466 .ok_or_else(|| NerError::ResponseMalformed {
467 reason: "missing content[0].text".into(),
468 })
469}
470
471#[allow(dead_code)] fn scrub_excerpt(s: &str) -> String {
473 let trimmed: String = s
474 .chars()
475 .take(256)
476 .filter(|c| !c.is_control() || *c == ' ')
477 .collect();
478 trimmed
479}
480
481fn parse_and_sanitize(raw: &str, max_tokens: usize) -> Result<TokenSet, NerError> {
487 let parsed: JsonValue = crate_json::from_str(raw).map_err(|e| NerError::ResponseMalformed {
488 reason: format!("json parse: {e}"),
489 })?;
490 let obj = parsed
491 .as_object()
492 .ok_or_else(|| NerError::ResponseMalformed {
493 reason: "expected JSON object at root".into(),
494 })?;
495
496 let keywords = collect_string_array(obj.get("keywords"), "keywords")?;
497 let literals = collect_string_array(obj.get("literals"), "literals")?;
498
499 let total = keywords.len() + literals.len();
500 if total > max_tokens {
501 return Err(NerError::ResponseExceedsTokenLimit {
502 count: total,
503 max: max_tokens,
504 });
505 }
506
507 for token in keywords.iter().chain(literals.iter()) {
508 validate_token(token)?;
509 }
510
511 Ok(TokenSet { keywords, literals })
512}
513
514fn collect_string_array(v: Option<&JsonValue>, field: &str) -> Result<Vec<String>, NerError> {
517 let arr = match v {
518 Some(JsonValue::Array(a)) => a,
519 Some(JsonValue::Null) | None => return Ok(Vec::new()),
520 Some(other) => {
521 return Err(NerError::ResponseMalformed {
522 reason: format!("{field}: expected array, got {}", json_kind(other)),
523 });
524 }
525 };
526 let mut out = Vec::with_capacity(arr.len());
527 for (i, item) in arr.iter().enumerate() {
528 match item {
529 JsonValue::String(s) => out.push(s.clone()),
530 other => {
531 return Err(NerError::ResponseMalformed {
532 reason: format!("{field}[{i}]: expected string, got {}", json_kind(other)),
533 });
534 }
535 }
536 }
537 Ok(out)
538}
539
540fn json_kind(v: &JsonValue) -> &'static str {
541 match v {
542 JsonValue::Null => "null",
543 JsonValue::Bool(_) => "bool",
544 JsonValue::Number(_) => "number",
545 JsonValue::String(_) => "string",
546 JsonValue::Array(_) => "array",
547 JsonValue::Object(_) => "object",
548 }
549}
550
551fn validate_token(token: &str) -> Result<(), NerError> {
555 if let Some(pattern) = match_secret_pattern(token) {
556 return Err(NerError::SecretInResponse {
557 pattern: pattern.into(),
558 });
559 }
560 if token.is_empty() {
561 return Err(NerError::ResponseMalformed {
562 reason: "empty token".into(),
563 });
564 }
565 if token.len() > 256 {
566 return Err(NerError::ResponseMalformed {
567 reason: format!("token too long ({} bytes)", token.len()),
568 });
569 }
570 for (i, byte) in token.as_bytes().iter().enumerate() {
571 match byte {
572 0x00 => {
574 return Err(NerError::ResponseMalformed {
575 reason: format!("NUL byte at offset {i}"),
576 });
577 }
578 b'\n' | b'\r' => {
579 return Err(NerError::ResponseMalformed {
580 reason: format!("CR/LF at offset {i}"),
581 });
582 }
583 b'"' | b'\'' | b'`' => {
585 return Err(NerError::ResponseMalformed {
586 reason: format!("quote injection at offset {i}"),
587 });
588 }
589 b if *b < 0x20 && *b != b'\t' => {
591 return Err(NerError::ResponseMalformed {
592 reason: format!("control byte 0x{b:02x} at offset {i}"),
593 });
594 }
595 _ => {}
596 }
597 }
598 Ok(())
599}
600
601fn match_secret_pattern(token: &str) -> Option<&'static str> {
606 const PATTERNS: &[(&str, &str)] = &[
609 ("sk_", "sk_prefix"),
610 ("rs_", "rs_prefix"),
611 ("reddb_", "reddb_prefix"),
612 ("Bearer ", "bearer"),
613 ("bearer ", "bearer"),
614 ];
615 for (prefix, label) in PATTERNS {
616 if token.starts_with(prefix) {
617 return Some(label);
618 }
619 }
620 if looks_like_jwt(token) {
625 return Some("jwt");
626 }
627 if token.contains("://") && token.contains(':') && token.contains('@') {
629 if let Some(scheme_end) = token.find("://") {
630 let rest = &token[scheme_end + 3..];
631 if let Some(at) = rest.find('@') {
632 let userpass = &rest[..at];
633 if userpass.contains(':') {
634 return Some("conn_string_credentials");
635 }
636 }
637 }
638 }
639 None
640}
641
642fn looks_like_jwt(token: &str) -> bool {
643 let parts: Vec<&str> = token.split('.').collect();
644 if parts.len() != 3 {
645 return false;
646 }
647 parts.iter().all(|p| {
648 p.len() >= 4
649 && p.bytes()
650 .all(|b| b.is_ascii_alphanumeric() || b == b'_' || b == b'-')
651 })
652}
653
654#[cfg(test)]
659mod tests {
660 use super::*;
661 use crate::runtime::ask_pipeline::TokenSet;
662
663 fn make_scope() -> EffectiveScope {
666 use crate::storage::transaction::snapshot::Snapshot;
670 use std::collections::HashSet;
671 EffectiveScope {
672 tenant: None,
673 identity: None,
674 snapshot: Snapshot {
675 xid: 0,
676 in_progress: HashSet::new(),
677 },
678 visible_collections: None,
679 }
680 }
681
682 fn allow() -> StubAuthContext {
683 StubAuthContext::allow_all()
684 }
685
686 fn deny() -> StubAuthContext {
687 StubAuthContext::deny_all()
688 }
689
690 #[tokio::test]
693 async fn stub_empty_returns_empty_token_set() {
694 let ner = LlmNer::new(
695 NerProvider::Stub(StubBehavior::Empty),
696 HeuristicFallback::Propagate,
697 );
698 let out = ner
699 .extract("anything", &make_scope(), &allow())
700 .await
701 .unwrap();
702 assert!(out.is_empty());
703 }
704
705 #[tokio::test]
706 async fn stub_echo_returns_lowercased_keyword() {
707 let ner = LlmNer::new(
708 NerProvider::Stub(StubBehavior::Echo),
709 HeuristicFallback::Propagate,
710 );
711 let out = ner
712 .extract(" Hello WORLD ", &make_scope(), &allow())
713 .await
714 .unwrap();
715 assert_eq!(out.keywords, vec!["hello world".to_string()]);
716 assert!(out.literals.is_empty());
717 }
718
719 #[tokio::test]
720 async fn stub_echo_empty_question_yields_empty_set() {
721 let ner = LlmNer::new(
722 NerProvider::Stub(StubBehavior::Echo),
723 HeuristicFallback::Propagate,
724 );
725 let out = ner.extract(" ", &make_scope(), &allow()).await.unwrap();
726 assert!(out.is_empty());
727 }
728
729 #[tokio::test]
730 async fn stub_canned_returns_provided_tokens() {
731 let canned = TokenSet {
732 keywords: vec!["passport".into()],
733 literals: vec!["FDD-1".into()],
734 };
735 let ner = LlmNer::new(
736 NerProvider::Stub(StubBehavior::Canned(canned.clone())),
737 HeuristicFallback::Propagate,
738 );
739 let out = ner.extract("q?", &make_scope(), &allow()).await.unwrap();
740 assert_eq!(out, canned);
741 }
742
743 #[tokio::test]
746 async fn slow_stub_within_budget_succeeds() {
747 let mut ner = LlmNer::new(
748 NerProvider::Stub(StubBehavior::SlowDuration(Duration::from_millis(10))),
749 HeuristicFallback::Propagate,
750 );
751 ner.timeout_ms = 100;
752 assert!(ner.extract("q?", &make_scope(), &allow()).await.is_ok());
753 }
754
755 #[tokio::test]
756 async fn slow_stub_over_budget_times_out_and_propagates() {
757 let mut ner = LlmNer::new(
758 NerProvider::Stub(StubBehavior::SlowDuration(Duration::from_millis(500))),
759 HeuristicFallback::Propagate,
760 );
761 ner.timeout_ms = 50;
762 let err = ner
763 .extract("q?", &make_scope(), &allow())
764 .await
765 .unwrap_err();
766 assert_eq!(err, NerError::NetworkTimeout);
767 }
768
769 #[tokio::test]
772 async fn malformed_not_json_is_rejected() {
773 let ner = LlmNer::new(
774 NerProvider::Stub(StubBehavior::RawJson("not-json".into())),
775 HeuristicFallback::Propagate,
776 );
777 let err = ner
778 .extract("q?", &make_scope(), &allow())
779 .await
780 .unwrap_err();
781 assert!(matches!(err, NerError::ResponseMalformed { .. }));
782 }
783
784 #[tokio::test]
785 async fn malformed_wrong_root_type_is_rejected() {
786 let ner = LlmNer::new(
787 NerProvider::Stub(StubBehavior::RawJson("[1,2,3]".into())),
788 HeuristicFallback::Propagate,
789 );
790 let err = ner
791 .extract("q?", &make_scope(), &allow())
792 .await
793 .unwrap_err();
794 assert!(matches!(err, NerError::ResponseMalformed { .. }));
795 }
796
797 #[tokio::test]
798 async fn malformed_keywords_not_array_is_rejected() {
799 let ner = LlmNer::new(
800 NerProvider::Stub(StubBehavior::RawJson(r#"{"keywords":"oops"}"#.into())),
801 HeuristicFallback::Propagate,
802 );
803 let err = ner
804 .extract("q?", &make_scope(), &allow())
805 .await
806 .unwrap_err();
807 assert!(matches!(err, NerError::ResponseMalformed { .. }));
808 }
809
810 fn adversarial_corpus() -> Vec<(&'static str, String)> {
816 let sk_prefix = format!("{}{}", "sk_", "live_DEADBEEFcafe");
818 let rs_prefix = format!("{}{}", "rs_", "test_TOKENtoken");
819 let reddb_prefix = format!("{}{}", "reddb_", "internal_secret_X");
820 let bearer = format!("{}{}", "Bearer ", "ABC.DEF.GHI");
821 let jwt = format!("{}.{}.{}", "abcd1234", "wxyz5678", "qrst9012");
822 let conn = "postgres://user:pwd@host:5432/db".to_string();
823
824 vec![
825 (
826 "crlf_in_keyword",
827 "{\"keywords\":[\"foo\\r\\nbar\"]}".into(),
828 ),
829 (
830 "nul_in_literal",
831 "{\"literals\":[\"foo\\u0000bar\"]}".into(),
832 ),
833 ("dquote_injection", "{\"keywords\":[\"foo\\\"bar\"]}".into()),
834 ("squote_injection", "{\"keywords\":[\"foo'bar\"]}".into()),
835 ("backtick_injection", "{\"keywords\":[\"foo`bar\"]}".into()),
836 (
837 "control_byte_low",
838 "{\"keywords\":[\"foo\\u0007bar\"]}".into(),
839 ),
840 ("sk_live", format!(r#"{{"keywords":["{sk_prefix}"]}}"#)),
841 ("rs_test", format!(r#"{{"keywords":["{rs_prefix}"]}}"#)),
842 (
843 "reddb_internal",
844 format!(r#"{{"literals":["{reddb_prefix}"]}}"#),
845 ),
846 ("bearer_token", format!(r#"{{"keywords":["{bearer}"]}}"#)),
847 ("jwt_shape", format!(r#"{{"literals":["{jwt}"]}}"#)),
848 ("conn_string", format!(r#"{{"keywords":["{conn}"]}}"#)),
849 ]
850 }
851
852 #[tokio::test]
853 async fn adversarial_corpus_is_fully_rejected() {
854 let corpus = adversarial_corpus();
855 assert!(corpus.len() >= 10, "corpus must be ≥10 payloads");
856 for (label, raw) in corpus {
857 let ner = LlmNer::new(
858 NerProvider::Stub(StubBehavior::RawJson(raw)),
859 HeuristicFallback::Propagate,
860 );
861 let err = ner
862 .extract("q?", &make_scope(), &allow())
863 .await
864 .expect_err(&format!("payload {label} should have been rejected"));
865 assert!(
866 matches!(
867 err,
868 NerError::ResponseMalformed { .. } | NerError::SecretInResponse { .. }
869 ),
870 "payload {label}: unexpected error variant {err:?}"
871 );
872 }
873 }
874
875 #[tokio::test]
876 async fn secret_in_response_reports_pattern_label() {
877 let raw = format!(r#"{{"keywords":["{}{}"]}}"#, "sk_", "live_zzzz");
878 let ner = LlmNer::new(
879 NerProvider::Stub(StubBehavior::RawJson(raw)),
880 HeuristicFallback::Propagate,
881 );
882 match ner
883 .extract("q?", &make_scope(), &allow())
884 .await
885 .unwrap_err()
886 {
887 NerError::SecretInResponse { pattern } => assert_eq!(pattern, "sk_prefix"),
888 other => panic!("expected SecretInResponse, got {other:?}"),
889 }
890 }
891
892 #[tokio::test]
895 async fn token_cap_excess_is_rejected() {
896 let kws: Vec<String> = (0..33).map(|i| format!("kw{i}")).collect();
898 let raw = crate_json::json!({ "keywords": kws }).to_string();
899 let ner = LlmNer::new(
900 NerProvider::Stub(StubBehavior::RawJson(raw)),
901 HeuristicFallback::Propagate,
902 );
903 let err = ner
904 .extract("q?", &make_scope(), &allow())
905 .await
906 .unwrap_err();
907 match err {
908 NerError::ResponseExceedsTokenLimit { count, max } => {
909 assert_eq!(count, 33);
910 assert_eq!(max, DEFAULT_MAX_TOKENS);
911 }
912 other => panic!("expected ResponseExceedsTokenLimit, got {other:?}"),
913 }
914 }
915
916 #[tokio::test]
917 async fn token_cap_at_limit_succeeds() {
918 let kws: Vec<String> = (0..DEFAULT_MAX_TOKENS).map(|i| format!("kw{i}")).collect();
919 let raw = crate_json::json!({ "keywords": kws }).to_string();
920 let ner = LlmNer::new(
921 NerProvider::Stub(StubBehavior::RawJson(raw)),
922 HeuristicFallback::Propagate,
923 );
924 let out = ner.extract("q?", &make_scope(), &allow()).await.unwrap();
925 assert_eq!(out.keywords.len(), DEFAULT_MAX_TOKENS);
926 }
927
928 #[tokio::test]
931 async fn auth_gate_denies_without_capability() {
932 let ner = LlmNer::new(
933 NerProvider::Stub(StubBehavior::Empty),
934 HeuristicFallback::UseHeuristic,
935 );
936 let err = ner.extract("q?", &make_scope(), &deny()).await.unwrap_err();
937 assert_eq!(err, NerError::AuthDenied);
938 }
939
940 #[tokio::test]
941 async fn auth_gate_denial_does_not_fall_back() {
942 let ner = LlmNer::new(
945 NerProvider::Stub(StubBehavior::Empty),
946 HeuristicFallback::UseHeuristic,
947 );
948 let err = ner
949 .extract("FDD-1", &make_scope(), &deny())
950 .await
951 .unwrap_err();
952 assert_eq!(err, NerError::AuthDenied);
953 }
954
955 #[tokio::test]
958 async fn fallback_use_heuristic_runs_extract_tokens() {
959 let ner = LlmNer::new(
961 NerProvider::Stub(StubBehavior::RawJson("not-json".into())),
962 HeuristicFallback::UseHeuristic,
963 );
964 let out = ner
965 .extract("show order 987654321 details", &make_scope(), &allow())
966 .await
967 .unwrap();
968 assert!(out.literals.iter().any(|l| l == "987654321"));
970 }
971
972 #[tokio::test]
973 async fn fallback_empty_on_fail_returns_empty() {
974 let ner = LlmNer::new(
975 NerProvider::Stub(StubBehavior::RawJson("not-json".into())),
976 HeuristicFallback::EmptyOnFail,
977 );
978 let out = ner
979 .extract("show order 987654321 details", &make_scope(), &allow())
980 .await
981 .unwrap();
982 assert!(out.is_empty());
983 }
984
985 #[tokio::test]
986 async fn fallback_propagate_returns_error() {
987 let ner = LlmNer::new(
988 NerProvider::Stub(StubBehavior::RawJson("not-json".into())),
989 HeuristicFallback::Propagate,
990 );
991 let err = ner
992 .extract("show order 987654321 details", &make_scope(), &allow())
993 .await
994 .unwrap_err();
995 assert!(matches!(err, NerError::ResponseMalformed { .. }));
996 }
997
998 #[test]
1001 fn jwt_detector_matches_three_segments() {
1002 assert!(looks_like_jwt("abcd.efgh.ijkl"));
1003 assert!(!looks_like_jwt("abcd.efgh"));
1004 assert!(!looks_like_jwt("abc.def.ghi.jkl"));
1005 assert!(!looks_like_jwt("ab.cd.ef")); }
1007
1008 #[test]
1009 fn scrub_excerpt_drops_control_bytes() {
1010 let s = format!("ok\x07bad\nstill");
1011 let cleaned = scrub_excerpt(&s);
1012 assert!(!cleaned.contains('\x07'));
1013 assert!(!cleaned.contains('\n'));
1014 }
1015
1016 #[test]
1017 fn validate_token_accepts_normal_strings() {
1018 assert!(validate_token("passport").is_ok());
1019 assert!(validate_token("FDD-12313").is_ok());
1020 assert!(validate_token("foo_bar.baz").is_ok());
1021 }
1022}