Skip to main content

zeph_config/
sanitizer.rs

// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use crate::providers::ProviderName;
5use serde::{Deserialize, Serialize};
6
7use crate::defaults::default_true;
8
// ---------------------------------------------------------------------------
// ContentIsolationConfig
// ---------------------------------------------------------------------------
12
/// Serde default for `max_content_size`: 64 KiB of untrusted content before truncation.
fn default_max_content_size() -> usize {
    64 * 1024
}
16
17/// Configuration for the embedding anomaly guard, nested under
18/// `[security.content_isolation.embedding_guard]`.
19#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
20pub struct EmbeddingGuardConfig {
21    /// Enable embedding-based anomaly detection (default: false — opt-in).
22    #[serde(default)]
23    pub enabled: bool,
24    /// Cosine distance threshold above which outputs are flagged as anomalous.
25    #[serde(
26        default = "default_embedding_threshold",
27        deserialize_with = "validate_embedding_threshold"
28    )]
29    pub threshold: f64,
30    /// Minimum clean samples before centroid-based detection activates.
31    /// Before this count, regex fallback is used instead.
32    #[serde(
33        default = "default_embedding_min_samples",
34        deserialize_with = "validate_min_samples"
35    )]
36    pub min_samples: usize,
37    /// EMA alpha floor for centroid updates after stabilization (n >= `min_samples`).
38    ///
39    /// Once the centroid has accumulated `min_samples` clean outputs, each new sample
40    /// can shift it by at most this fraction. Lower values make the centroid more
41    /// resistant to slow drift attacks but slower to adapt to legitimate distribution
42    /// changes. Default: 0.01 (1% per sample).
43    #[serde(default = "default_ema_floor")]
44    pub ema_floor: f32,
45}
46
47fn validate_embedding_threshold<'de, D>(deserializer: D) -> Result<f64, D::Error>
48where
49    D: serde::Deserializer<'de>,
50{
51    let value = <f64 as serde::Deserialize>::deserialize(deserializer)?;
52    if value.is_nan() || value.is_infinite() {
53        return Err(serde::de::Error::custom(
54            "embedding_guard.threshold must be a finite number",
55        ));
56    }
57    if !(value > 0.0 && value <= 1.0) {
58        return Err(serde::de::Error::custom(
59            "embedding_guard.threshold must be in (0.0, 1.0]",
60        ));
61    }
62    Ok(value)
63}
64
65fn validate_min_samples<'de, D>(deserializer: D) -> Result<usize, D::Error>
66where
67    D: serde::Deserializer<'de>,
68{
69    let value = <usize as serde::Deserialize>::deserialize(deserializer)?;
70    if value == 0 {
71        return Err(serde::de::Error::custom(
72            "embedding_guard.min_samples must be >= 1",
73        ));
74    }
75    Ok(value)
76}
77
/// Serde default for `embedding_guard.threshold`.
fn default_embedding_threshold() -> f64 {
    const THRESHOLD: f64 = 0.35;
    THRESHOLD
}
81
/// Serde default for `embedding_guard.min_samples`.
fn default_embedding_min_samples() -> usize {
    const MIN_SAMPLES: usize = 10;
    MIN_SAMPLES
}
85
/// Serde default for `embedding_guard.ema_floor` (1% per sample).
fn default_ema_floor() -> f32 {
    const FLOOR: f32 = 0.01;
    FLOOR
}
89
90impl Default for EmbeddingGuardConfig {
91    fn default() -> Self {
92        Self {
93            enabled: false,
94            threshold: default_embedding_threshold(),
95            min_samples: default_embedding_min_samples(),
96            ema_floor: default_ema_floor(),
97        }
98    }
99}
100
101/// Configuration for the content isolation pipeline, nested under
102/// `[security.content_isolation]` in the agent config file.
103#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
104#[allow(clippy::struct_excessive_bools)] // config struct — boolean flags are idiomatic for TOML-deserialized configuration
105pub struct ContentIsolationConfig {
106    /// When `false`, the sanitizer is a no-op: content passes through unchanged.
107    #[serde(default = "default_true")]
108    pub enabled: bool,
109
110    /// Maximum byte length of untrusted content before truncation.
111    #[serde(default = "default_max_content_size")]
112    pub max_content_size: usize,
113
114    /// When `true`, injection patterns detected in content are recorded as
115    /// flags and a warning is prepended to the spotlighting wrapper.
116    #[serde(default = "default_true")]
117    pub flag_injection_patterns: bool,
118
119    /// When `true`, untrusted content is wrapped in spotlighting XML delimiters
120    /// that instruct the LLM to treat the enclosed text as data, not instructions.
121    #[serde(default = "default_true")]
122    pub spotlight_untrusted: bool,
123
124    /// Quarantine summarizer configuration.
125    #[serde(default)]
126    pub quarantine: QuarantineConfig,
127
128    /// Embedding anomaly guard configuration.
129    #[serde(default)]
130    pub embedding_guard: EmbeddingGuardConfig,
131
132    /// When `true`, MCP tool results flowing through ACP-serving sessions receive
133    /// unconditional quarantine summarization and cross-boundary audit log entries.
134    /// This prevents confused-deputy attacks where untrusted MCP output influences
135    /// responses served to ACP clients (e.g. IDE integrations).
136    #[serde(default = "default_true")]
137    pub mcp_to_acp_boundary: bool,
138
139    /// NLI entailment check stage configuration.
140    #[serde(default)]
141    pub nli: NliConfig,
142
143    /// PAAC secret placeholder masking configuration.
144    #[serde(default)]
145    pub secret_masking: SecretMaskingConfig,
146}
147
148impl Default for ContentIsolationConfig {
149    fn default() -> Self {
150        Self {
151            enabled: true,
152            max_content_size: default_max_content_size(),
153            flag_injection_patterns: true,
154            spotlight_untrusted: true,
155            quarantine: QuarantineConfig::default(),
156            embedding_guard: EmbeddingGuardConfig::default(),
157            mcp_to_acp_boundary: true,
158            nli: NliConfig::default(),
159            secret_masking: SecretMaskingConfig::default(),
160        }
161    }
162}
163
164/// Configuration for the SONAR NLI entailment check stage, nested under
165/// `[security.content_isolation.nli]` in the agent config file.
166///
167/// When `enabled = false` (the default), the NLI stage is skipped entirely.
168#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
169pub struct NliConfig {
170    /// Enable NLI entailment-based injection detection (default: false — opt-in).
171    #[serde(default)]
172    pub enabled: bool,
173
174    /// Provider name from `[[llm.providers]]` to use for NLI inference.
175    ///
176    /// An empty [`ProviderName`] falls back to the default provider. Prefer a fast, cheap model.
177    #[serde(default)]
178    pub provider: ProviderName,
179
180    /// Entailment score threshold above which content is flagged (default: 0.75).
181    #[serde(default = "default_nli_threshold")]
182    pub threshold: f32,
183
184    /// Maximum milliseconds to wait for the NLI provider response (default: 5000).
185    #[serde(default = "default_nli_timeout_ms")]
186    pub timeout_ms: u64,
187
188    /// Maximum characters of content sent to the NLI provider (default: 2048).
189    #[serde(default = "default_nli_max_content_len")]
190    pub max_content_len: usize,
191}
192
/// Serde default for `nli.threshold`.
fn default_nli_threshold() -> f32 {
    const THRESHOLD: f32 = 0.75;
    THRESHOLD
}
196
/// Serde default for `nli.timeout_ms`: five seconds.
fn default_nli_timeout_ms() -> u64 {
    5 * 1000
}
200
/// Serde default for `nli.max_content_len`: 2 Ki characters.
fn default_nli_max_content_len() -> usize {
    2 * 1024
}
204
205impl Default for NliConfig {
206    fn default() -> Self {
207        Self {
208            enabled: false,
209            provider: ProviderName::default(),
210            threshold: default_nli_threshold(),
211            timeout_ms: default_nli_timeout_ms(),
212            max_content_len: default_nli_max_content_len(),
213        }
214    }
215}
216
217/// Configuration for PAAC secret placeholder masking, nested under
218/// `[security.secret_masking]` in the agent config file.
219///
220/// When `enabled = false` (the default), vault secrets are not masked.
221#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
222pub struct SecretMaskingConfig {
223    /// Enable secret placeholder masking (default: false — opt-in).
224    #[serde(default)]
225    pub enabled: bool,
226
227    /// Minimum secret byte length to be eligible for masking (default: 8).
228    ///
229    /// Secrets shorter than this value are not substituted to avoid false matches
230    /// on common short strings.
231    #[serde(default = "default_min_secret_len")]
232    pub min_secret_len: usize,
233}
234
/// Serde default for `secret_masking.min_secret_len`, in bytes.
fn default_min_secret_len() -> usize {
    const MIN_LEN: usize = 8;
    MIN_LEN
}
238
239impl Default for SecretMaskingConfig {
240    fn default() -> Self {
241        Self {
242            enabled: false,
243            min_secret_len: default_min_secret_len(),
244        }
245    }
246}
247
248/// Configuration for the quarantine summarizer, nested under
249/// `[security.content_isolation.quarantine]` in the agent config file.
250#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
251pub struct QuarantineConfig {
252    /// When `false`, quarantine summarization is disabled entirely.
253    #[serde(default)]
254    pub enabled: bool,
255
256    /// Source kinds to route through the quarantine LLM.
257    #[serde(default = "default_quarantine_sources")]
258    pub sources: Vec<String>,
259
260    /// Provider name passed to `create_named_provider`.
261    #[serde(default = "default_quarantine_model")]
262    pub model: String,
263}
264
/// Serde default for `quarantine.sources`: web scrapes and A2A messages.
fn default_quarantine_sources() -> Vec<String> {
    ["web_scrape", "a2a_message"]
        .into_iter()
        .map(String::from)
        .collect()
}
268
/// Serde default for `quarantine.model`.
fn default_quarantine_model() -> String {
    String::from("claude")
}
272
273impl Default for QuarantineConfig {
274    fn default() -> Self {
275        Self {
276            enabled: false,
277            sources: default_quarantine_sources(),
278            model: default_quarantine_model(),
279        }
280    }
281}
282
// ---------------------------------------------------------------------------
// ExfiltrationGuardConfig
// ---------------------------------------------------------------------------
286
287/// Configuration for exfiltration guards, nested under
288/// `[security.exfiltration_guard]` in the agent config file.
289#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
290pub struct ExfiltrationGuardConfig {
291    /// Strip external markdown images from LLM output to prevent pixel-tracking exfiltration.
292    #[serde(default = "default_true")]
293    pub block_markdown_images: bool,
294
295    /// Cross-reference tool call arguments against URLs seen in flagged untrusted content.
296    #[serde(default = "default_true")]
297    pub validate_tool_urls: bool,
298
299    /// Skip Qdrant embedding for messages that contained injection-flagged content.
300    #[serde(default = "default_true")]
301    pub guard_memory_writes: bool,
302}
303
304impl Default for ExfiltrationGuardConfig {
305    fn default() -> Self {
306        Self {
307            block_markdown_images: true,
308            validate_tool_urls: true,
309            guard_memory_writes: true,
310        }
311    }
312}
313
// ---------------------------------------------------------------------------
// MemoryWriteValidationConfig
// ---------------------------------------------------------------------------
317
/// Serde default for `memory_validation.max_content_bytes` (4 KiB).
fn default_max_content_bytes() -> usize {
    4 * 1024
}
321
/// Serde default for `memory_validation.max_entity_name_bytes`.
fn default_max_entity_name_bytes() -> usize {
    const MAX_NAME: usize = 256;
    MAX_NAME
}
325
/// Serde default for `memory_validation.min_entity_name_bytes`.
fn default_min_entity_name_bytes() -> usize {
    const MIN_NAME: usize = 3;
    MIN_NAME
}
329
/// Serde default for `memory_validation.max_fact_bytes` (1 KiB).
fn default_max_fact_bytes() -> usize {
    const MAX_FACT: usize = 1024;
    MAX_FACT
}
333
/// Serde default for `memory_validation.max_entities_per_extraction`.
fn default_max_entities() -> usize {
    const MAX_ENTITIES: usize = 50;
    MAX_ENTITIES
}
337
/// Serde default for `memory_validation.max_edges_per_extraction`.
fn default_max_edges() -> usize {
    const MAX_EDGES: usize = 100;
    MAX_EDGES
}
341
342/// Configuration for memory write validation, nested under `[security.memory_validation]`.
343///
344/// Enabled by default with conservative limits.
345#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
346pub struct MemoryWriteValidationConfig {
347    /// Master switch. When `false`, validation is a no-op.
348    #[serde(default = "default_true")]
349    pub enabled: bool,
350    /// Maximum byte length of content passed to `memory_save`.
351    #[serde(default = "default_max_content_bytes")]
352    pub max_content_bytes: usize,
353    /// Minimum byte length of an entity name in graph extraction.
354    #[serde(default = "default_min_entity_name_bytes")]
355    pub min_entity_name_bytes: usize,
356    /// Maximum byte length of a single entity name in graph extraction.
357    #[serde(default = "default_max_entity_name_bytes")]
358    pub max_entity_name_bytes: usize,
359    /// Maximum byte length of an edge fact string in graph extraction.
360    #[serde(default = "default_max_fact_bytes")]
361    pub max_fact_bytes: usize,
362    /// Maximum number of entities allowed per graph extraction result.
363    #[serde(default = "default_max_entities")]
364    pub max_entities_per_extraction: usize,
365    /// Maximum number of edges allowed per graph extraction result.
366    #[serde(default = "default_max_edges")]
367    pub max_edges_per_extraction: usize,
368    /// Forbidden substring patterns.
369    #[serde(default)]
370    pub forbidden_content_patterns: Vec<String>,
371}
372
373impl Default for MemoryWriteValidationConfig {
374    fn default() -> Self {
375        Self {
376            enabled: true,
377            max_content_bytes: default_max_content_bytes(),
378            min_entity_name_bytes: default_min_entity_name_bytes(),
379            max_entity_name_bytes: default_max_entity_name_bytes(),
380            max_fact_bytes: default_max_fact_bytes(),
381            max_entities_per_extraction: default_max_entities(),
382            max_edges_per_extraction: default_max_edges(),
383            forbidden_content_patterns: Vec::new(),
384        }
385    }
386}
387
// ---------------------------------------------------------------------------
// PiiFilterConfig
// ---------------------------------------------------------------------------
391
392/// A single user-defined PII pattern.
393#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
394pub struct CustomPiiPattern {
395    /// Human-readable name used in the replacement label.
396    pub name: String,
397    /// Regular expression pattern.
398    pub pattern: String,
399    /// Replacement text. Defaults to `[PII:custom]`.
400    #[serde(default = "default_custom_replacement")]
401    pub replacement: String,
402}
403
/// Serde default for `CustomPiiPattern::replacement`.
fn default_custom_replacement() -> String {
    String::from("[PII:custom]")
}
407
408/// Configuration for the PII filter, nested under `[security.pii_filter]` in the config file.
409///
410/// Disabled by default — opt-in to avoid unexpected data loss.
411#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
412#[allow(clippy::struct_excessive_bools)] // config struct — boolean flags are idiomatic for TOML-deserialized configuration
413pub struct PiiFilterConfig {
414    /// Master switch. When `false`, the filter is a no-op.
415    #[serde(default)]
416    pub enabled: bool,
417    /// Scrub email addresses.
418    #[serde(default = "default_true")]
419    pub filter_email: bool,
420    /// Scrub US phone numbers.
421    #[serde(default = "default_true")]
422    pub filter_phone: bool,
423    /// Scrub US Social Security Numbers.
424    #[serde(default = "default_true")]
425    pub filter_ssn: bool,
426    /// Scrub credit card numbers (16-digit patterns).
427    #[serde(default = "default_true")]
428    pub filter_credit_card: bool,
429    /// Custom regex patterns to add on top of the built-ins.
430    #[serde(default)]
431    pub custom_patterns: Vec<CustomPiiPattern>,
432}
433
434impl Default for PiiFilterConfig {
435    fn default() -> Self {
436        Self {
437            enabled: false,
438            filter_email: true,
439            filter_phone: true,
440            filter_ssn: true,
441            filter_credit_card: true,
442            custom_patterns: Vec::new(),
443        }
444    }
445}
446
// ---------------------------------------------------------------------------
// GuardrailConfig
// ---------------------------------------------------------------------------
450
451/// What happens when the guardrail flags input.
452#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
453#[serde(rename_all = "lowercase")]
454pub enum GuardrailAction {
455    /// Block the input and return an error message to the user.
456    #[default]
457    Block,
458    /// Allow the input but emit a warning message.
459    Warn,
460}
461
462/// Behavior on timeout or LLM error.
463#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
464#[serde(rename_all = "lowercase")]
465pub enum GuardrailFailStrategy {
466    /// Block input on timeout/error (safe default for security-sensitive deployments).
467    #[default]
468    Closed,
469    /// Allow input on timeout/error (for availability-sensitive deployments).
470    Open,
471}
472
473/// Configuration for the LLM-based guardrail, nested under `[security.guardrail]`.
474#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
475pub struct GuardrailConfig {
476    /// Enable the guardrail (default: false).
477    #[serde(default)]
478    pub enabled: bool,
479    /// Provider to use for guardrail classification (e.g. `"ollama"`, `"claude"`).
480    #[serde(default)]
481    pub provider: Option<String>,
482    /// Model to use (e.g. `"llama-guard-3:1b"`).
483    #[serde(default)]
484    pub model: Option<String>,
485    /// Timeout for each guardrail LLM call in milliseconds (default: 500).
486    #[serde(default = "default_guardrail_timeout_ms")]
487    pub timeout_ms: u64,
488    /// Action to take when a message is flagged (default: block).
489    #[serde(default)]
490    pub action: GuardrailAction,
491    /// What to do on timeout or LLM error (default: closed — block).
492    #[serde(default = "default_fail_strategy")]
493    pub fail_strategy: GuardrailFailStrategy,
494    /// When `true`, also scan tool outputs before they enter message history (default: false).
495    #[serde(default)]
496    pub scan_tool_output: bool,
497    /// Maximum number of characters to send to the guard model (default: 4096).
498    #[serde(default = "default_max_input_chars")]
499    pub max_input_chars: usize,
500}
/// Serde default for `guardrail.timeout_ms`: half a second.
fn default_guardrail_timeout_ms() -> u64 {
    const TIMEOUT_MS: u64 = 500;
    TIMEOUT_MS
}
/// Serde default for `guardrail.max_input_chars`.
fn default_max_input_chars() -> usize {
    4 * 1024
}
507fn default_fail_strategy() -> GuardrailFailStrategy {
508    GuardrailFailStrategy::Closed
509}
510impl Default for GuardrailConfig {
511    fn default() -> Self {
512        Self {
513            enabled: false,
514            provider: None,
515            model: None,
516            timeout_ms: default_guardrail_timeout_ms(),
517            action: GuardrailAction::default(),
518            fail_strategy: default_fail_strategy(),
519            scan_tool_output: false,
520            max_input_chars: default_max_input_chars(),
521        }
522    }
523}
524
// ---------------------------------------------------------------------------
// ResponseVerificationConfig
// ---------------------------------------------------------------------------
528
529/// Configuration for post-LLM response verification, nested under
530/// `[security.response_verification]` in the agent config file.
531///
532/// Scans LLM responses for injected instruction patterns before tool dispatch.
533/// This is defense-in-depth layer 3 (after input sanitization and pre-execution verification).
534#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
535pub struct ResponseVerificationConfig {
536    /// Enable post-LLM response verification (default: true).
537    #[serde(default = "default_true")]
538    pub enabled: bool,
539    /// Block tool dispatch when injection patterns are detected (default: false).
540    ///
541    /// When `false`, flagged responses are logged and shown in the TUI SEC panel
542    /// but still delivered. When `true`, the response is suppressed and the user
543    /// is notified.
544    #[serde(default)]
545    pub block_on_detection: bool,
546    /// Optional LLM provider for async deep verification of flagged responses.
547    ///
548    /// When set: suspicious responses are delivered immediately with a `[FLAGGED]`
549    /// annotation, and background LLM verification runs asynchronously. The verifier
550    /// receives a sanitized summary (via `QuarantinedSummarizer`) to prevent recursive
551    /// injection. Empty string = disabled (regex-only verification).
552    #[serde(default)]
553    pub verifier_provider: ProviderName,
554}
555
556impl Default for ResponseVerificationConfig {
557    fn default() -> Self {
558        Self {
559            enabled: true,
560            block_on_detection: false,
561            verifier_provider: ProviderName::default(),
562        }
563    }
564}
565
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn content_isolation_default_mcp_to_acp_boundary_true() {
        let cfg = ContentIsolationConfig::default();
        assert!(cfg.mcp_to_acp_boundary);
    }

    #[test]
    fn content_isolation_deserialize_mcp_to_acp_boundary_false() {
        let toml = r"
            mcp_to_acp_boundary = false
        ";
        let cfg: ContentIsolationConfig = toml::from_str(toml).unwrap();
        assert!(!cfg.mcp_to_acp_boundary);
    }

    #[test]
    fn content_isolation_deserialize_absent_defaults_true() {
        let cfg: ContentIsolationConfig = toml::from_str("").unwrap();
        assert!(cfg.mcp_to_acp_boundary);
    }

    /// Serde defaults and the hand-written `Default` impl must never drift apart.
    #[test]
    fn content_isolation_empty_table_matches_default() {
        let cfg: ContentIsolationConfig = toml::from_str("").unwrap();
        assert_eq!(cfg, ContentIsolationConfig::default());
    }

    fn de_guard(toml: &str) -> Result<EmbeddingGuardConfig, toml::de::Error> {
        toml::from_str(toml)
    }

    #[test]
    fn embedding_guard_empty_table_matches_default() {
        let cfg = de_guard("").unwrap();
        assert_eq!(cfg, EmbeddingGuardConfig::default());
    }

    #[test]
    fn threshold_valid() {
        let cfg = de_guard("threshold = 0.35\nmin_samples = 5").unwrap();
        assert!((cfg.threshold - 0.35).abs() < f64::EPSILON);
    }

    #[test]
    fn threshold_one_valid() {
        let cfg = de_guard("threshold = 1.0\nmin_samples = 1").unwrap();
        assert!((cfg.threshold - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn threshold_zero_rejected() {
        assert!(de_guard("threshold = 0.0\nmin_samples = 1").is_err());
    }

    #[test]
    fn threshold_above_one_rejected() {
        assert!(de_guard("threshold = 1.5\nmin_samples = 1").is_err());
    }

    #[test]
    fn threshold_negative_rejected() {
        assert!(de_guard("threshold = -0.1\nmin_samples = 1").is_err());
    }

    /// Covers the dedicated non-finite branch of `validate_embedding_threshold`.
    #[test]
    fn threshold_nan_rejected() {
        assert!(de_guard("threshold = nan\nmin_samples = 1").is_err());
    }

    #[test]
    fn threshold_infinite_rejected() {
        assert!(de_guard("threshold = inf\nmin_samples = 1").is_err());
        assert!(de_guard("threshold = -inf\nmin_samples = 1").is_err());
    }

    #[test]
    fn min_samples_zero_rejected() {
        assert!(de_guard("threshold = 0.35\nmin_samples = 0").is_err());
    }

    #[test]
    fn min_samples_one_valid() {
        let cfg = de_guard("threshold = 0.35\nmin_samples = 1").unwrap();
        assert_eq!(cfg.min_samples, 1);
    }
}
633
// ---------------------------------------------------------------------------
// CausalIpiConfig
// ---------------------------------------------------------------------------
637
/// Serde default for `causal_ipi.threshold`.
fn default_causal_threshold() -> f32 {
    const THRESHOLD: f32 = 0.7;
    THRESHOLD
}
641
642fn validate_causal_threshold<'de, D>(deserializer: D) -> Result<f32, D::Error>
643where
644    D: serde::Deserializer<'de>,
645{
646    let value = <f32 as serde::Deserialize>::deserialize(deserializer)?;
647    if value.is_nan() || value.is_infinite() {
648        return Err(serde::de::Error::custom(
649            "causal_ipi.threshold must be a finite number",
650        ));
651    }
652    if !(value > 0.0 && value <= 1.0) {
653        return Err(serde::de::Error::custom(
654            "causal_ipi.threshold must be in (0.0, 1.0]",
655        ));
656    }
657    Ok(value)
658}
659
/// Serde default for `causal_ipi.probe_max_tokens`.
fn default_probe_max_tokens() -> u32 {
    const MAX_TOKENS: u32 = 100;
    MAX_TOKENS
}
663
/// Serde default for `causal_ipi.probe_timeout_ms`: three seconds.
fn default_probe_timeout_ms() -> u64 {
    3 * 1000
}
667
668/// Temporal causal IPI analysis at tool-return boundaries.
669///
670/// When enabled, the agent generates behavioral probes before and after tool batch dispatch
671/// and compares them to detect behavioral deviation caused by injected instructions in
672/// tool outputs. Probes are per-batch (2 LLM calls total), not per individual tool.
673///
674/// Config section: `[security.causal_ipi]`
675#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
676pub struct CausalIpiConfig {
677    /// Master switch. Default: false (opt-in).
678    #[serde(default)]
679    pub enabled: bool,
680
681    /// Causal attribution score threshold for flagging. Range: (0.0, 1.0]. Default 0.7.
682    ///
683    /// Scores above this value trigger a WARN log, metric increment, and `SecurityEvent`.
684    /// Content is never blocked — this is an observation layer only.
685    #[serde(
686        default = "default_causal_threshold",
687        deserialize_with = "validate_causal_threshold"
688    )]
689    pub threshold: f32,
690
691    /// LLM provider name from `[[llm.providers]]` for probe calls.
692    ///
693    /// Should reference a fast/cheap provider — probes run on every tool batch return.
694    /// When `None`, falls back to the agent's default provider.
695    #[serde(default)]
696    pub provider: Option<String>,
697
698    /// Maximum tokens for each probe response. Limits cost per probe call. Default: 100.
699    ///
700    /// Two probes per batch = max `2 * probe_max_tokens` output tokens per tool batch.
701    #[serde(default = "default_probe_max_tokens")]
702    pub probe_max_tokens: u32,
703
704    /// Timeout in milliseconds for each individual probe LLM call. Default: 3000.
705    ///
706    /// On timeout: WARN log, skip causal analysis for the batch (never block).
707    #[serde(default = "default_probe_timeout_ms")]
708    pub probe_timeout_ms: u64,
709
710    /// Shadow memory configuration for cross-turn trajectory analysis.
711    #[serde(default)]
712    pub shadow_memory: ShadowMemoryConfig,
713}
714
715impl Default for CausalIpiConfig {
716    fn default() -> Self {
717        Self {
718            enabled: false,
719            threshold: default_causal_threshold(),
720            provider: None,
721            probe_max_tokens: default_probe_max_tokens(),
722            probe_timeout_ms: default_probe_timeout_ms(),
723            shadow_memory: ShadowMemoryConfig::default(),
724        }
725    }
726}
727
// ---------------------------------------------------------------------------
// ShadowMemoryConfig
// ---------------------------------------------------------------------------
731
/// Serde default for `shadow_memory.window_size`.
fn default_shadow_window() -> usize {
    const WINDOW: usize = 8;
    WINDOW
}
735
/// Serde default for `shadow_memory.max_events`.
fn default_shadow_max_events() -> usize {
    const MAX_EVENTS: usize = 64;
    MAX_EVENTS
}
739
/// Serde default for `shadow_memory.drift_threshold`.
fn default_shadow_drift_threshold() -> f32 {
    const DRIFT: f32 = 0.6;
    DRIFT
}
743
744fn validate_shadow_window<'de, D>(deserializer: D) -> Result<usize, D::Error>
745where
746    D: serde::Deserializer<'de>,
747{
748    let value = <usize as serde::Deserialize>::deserialize(deserializer)?;
749    if value == 0 {
750        return Err(serde::de::Error::custom(
751            "shadow_memory.window_size must be >= 1",
752        ));
753    }
754    Ok(value)
755}
756
757fn validate_shadow_max_events<'de, D>(deserializer: D) -> Result<usize, D::Error>
758where
759    D: serde::Deserializer<'de>,
760{
761    let value = <usize as serde::Deserialize>::deserialize(deserializer)?;
762    if value == 0 {
763        return Err(serde::de::Error::custom(
764            "shadow_memory.max_events must be >= 1",
765        ));
766    }
767    Ok(value)
768}
769
770fn validate_shadow_drift_threshold<'de, D>(deserializer: D) -> Result<f32, D::Error>
771where
772    D: serde::Deserializer<'de>,
773{
774    let value = <f32 as serde::Deserialize>::deserialize(deserializer)?;
775    if value.is_nan() || value.is_infinite() {
776        return Err(serde::de::Error::custom(
777            "shadow_memory.drift_threshold must be a finite number",
778        ));
779    }
780    if !(value > 0.0 && value <= 1.0) {
781        return Err(serde::de::Error::custom(
782            "shadow_memory.drift_threshold must be in (0.0, 1.0]",
783        ));
784    }
785    Ok(value)
786}
787
788/// Per-session append-only event store for cross-turn trajectory analysis.
789///
790/// Detects multi-turn attacks that distribute payload across several turns —
791/// invisible to the stateless [`CausalIpiConfig`] single-batch analysis.
792///
793/// Config section: `[security.causal_ipi.shadow_memory]`
794///
795/// # Examples
796///
797/// ```toml
798/// [security.causal_ipi.shadow_memory]
799/// enabled = true
800/// window_size = 8
801/// max_events = 64
802/// drift_threshold = 0.6
803/// ```
804#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
805pub struct ShadowMemoryConfig {
806    /// Enable shadow memory trajectory tracking. Default: false.
807    #[serde(default)]
808    pub enabled: bool,
809
810    /// Sliding window size for drift computation. Must be >= 1. Default: 8.
811    #[serde(
812        default = "default_shadow_window",
813        deserialize_with = "validate_shadow_window"
814    )]
815    pub window_size: usize,
816
817    /// Maximum events retained before oldest are evicted. Must be >= 1. Default: 64.
818    #[serde(
819        default = "default_shadow_max_events",
820        deserialize_with = "validate_shadow_max_events"
821    )]
822    pub max_events: usize,
823
824    /// Goal drift score threshold for flagging. Range: (0.0, 1.0]. Default: 0.6.
825    #[serde(
826        default = "default_shadow_drift_threshold",
827        deserialize_with = "validate_shadow_drift_threshold"
828    )]
829    pub drift_threshold: f32,
830}
831
832impl Default for ShadowMemoryConfig {
833    fn default() -> Self {
834        Self {
835            enabled: false,
836            window_size: default_shadow_window(),
837            max_events: default_shadow_max_events(),
838            drift_threshold: default_shadow_drift_threshold(),
839        }
840    }
841}
842
#[cfg(test)]
mod causal_ipi_tests {
    use super::*;

    #[test]
    fn causal_ipi_defaults() {
        let cfg = CausalIpiConfig::default();
        assert!(!cfg.enabled);
        assert!((cfg.threshold - 0.7).abs() < 1e-6);
        assert!(cfg.provider.is_none());
        assert_eq!(cfg.probe_max_tokens, 100);
        assert_eq!(cfg.probe_timeout_ms, 3000);
        assert_eq!(cfg.shadow_memory, ShadowMemoryConfig::default());
    }

    #[test]
    fn causal_ipi_deserialize_enabled() {
        let toml = r#"
            enabled = true
            threshold = 0.8
            provider = "fast"
            probe_max_tokens = 150
            probe_timeout_ms = 5000
        "#;
        let cfg: CausalIpiConfig = toml::from_str(toml).unwrap();
        assert!(cfg.enabled);
        assert!((cfg.threshold - 0.8).abs() < 1e-6);
        assert_eq!(cfg.provider.as_deref(), Some("fast"));
        assert_eq!(cfg.probe_max_tokens, 150);
        assert_eq!(cfg.probe_timeout_ms, 5000);
    }

    /// The nested table must reach `ShadowMemoryConfig`; unset keys keep their defaults.
    #[test]
    fn causal_ipi_deserialize_nested_shadow_memory() {
        let toml = r"
            enabled = true

            [shadow_memory]
            enabled = true
            window_size = 4
        ";
        let cfg: CausalIpiConfig = toml::from_str(toml).unwrap();
        assert!(cfg.shadow_memory.enabled);
        assert_eq!(cfg.shadow_memory.window_size, 4);
        assert_eq!(cfg.shadow_memory.max_events, 64);
    }

    #[test]
    fn causal_ipi_threshold_zero_rejected() {
        let result: Result<CausalIpiConfig, _> = toml::from_str("threshold = 0.0");
        assert!(result.is_err());
    }

    #[test]
    fn causal_ipi_threshold_negative_rejected() {
        let result: Result<CausalIpiConfig, _> = toml::from_str("threshold = -0.5");
        assert!(result.is_err());
    }

    /// Covers the dedicated non-finite branch of `validate_causal_threshold`.
    #[test]
    fn causal_ipi_threshold_nan_rejected() {
        let result: Result<CausalIpiConfig, _> = toml::from_str("threshold = nan");
        assert!(result.is_err());
    }

    #[test]
    fn causal_ipi_threshold_above_one_rejected() {
        let result: Result<CausalIpiConfig, _> = toml::from_str("threshold = 1.1");
        assert!(result.is_err());
    }

    #[test]
    fn causal_ipi_threshold_exactly_one_accepted() {
        let cfg: CausalIpiConfig = toml::from_str("threshold = 1.0").unwrap();
        assert!((cfg.threshold - 1.0).abs() < 1e-6);
    }
}
892
#[cfg(test)]
mod shadow_memory_config_tests {
    use super::*;

    /// Parses the body of a `[security.causal_ipi.shadow_memory]` table.
    fn parse(src: &str) -> Result<ShadowMemoryConfig, toml::de::Error> {
        toml::from_str(src)
    }

    #[test]
    fn shadow_memory_defaults() {
        let defaults = ShadowMemoryConfig::default();
        assert!(!defaults.enabled);
        assert_eq!(defaults.window_size, 8);
        assert_eq!(defaults.max_events, 64);
        assert!((defaults.drift_threshold - 0.6).abs() < 1e-6);
    }

    #[test]
    fn shadow_memory_window_zero_rejected() {
        assert!(parse("window_size = 0").is_err());
    }

    #[test]
    fn shadow_memory_max_events_zero_rejected() {
        assert!(parse("max_events = 0").is_err());
    }

    #[test]
    fn shadow_memory_drift_threshold_zero_rejected() {
        assert!(parse("drift_threshold = 0.0").is_err());
    }

    #[test]
    fn shadow_memory_drift_threshold_above_one_rejected() {
        assert!(parse("drift_threshold = 1.1").is_err());
    }

    #[test]
    fn shadow_memory_drift_threshold_exactly_one_accepted() {
        let parsed = parse("drift_threshold = 1.0").unwrap();
        assert!((parsed.drift_threshold - 1.0).abs() < 1e-6);
    }

    #[test]
    fn shadow_memory_full_deserialization() {
        let parsed = parse(
            r"
            enabled = true
            window_size = 4
            max_events = 32
            drift_threshold = 0.8
        ",
        )
        .unwrap();
        assert!(parsed.enabled);
        assert_eq!(parsed.window_size, 4);
        assert_eq!(parsed.max_events, 32);
        assert!((parsed.drift_threshold - 0.8).abs() < 1e-6);
    }
}