use crate::providers::ProviderName;
use serde::{Deserialize, Serialize};
use crate::defaults::default_true;
/// Serde default for `ContentIsolationConfig::max_content_size` (64 KiB worth of units).
fn default_max_content_size() -> usize {
    64 * 1024
}
/// Settings for the embedding-based guard nested under
/// `ContentIsolationConfig::embedding_guard`.
///
/// Every field has a serde default, so missing keys fall back to the same
/// values the `Default` impl produces. `threshold` and `min_samples` are
/// range-checked while deserializing and reject invalid input with a serde
/// error; `ema_floor` is not validated.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct EmbeddingGuardConfig {
/// Turns the guard on. Defaults to `false`.
#[serde(default)]
pub enabled: bool,
/// Must be finite and within `(0.0, 1.0]` (checked by
/// `validate_embedding_threshold`). Defaults to `0.35`.
// NOTE(review): whether this is a similarity or a distance cut-off is decided
// by the consumer of this config — confirm at the use site.
#[serde(
default = "default_embedding_threshold",
deserialize_with = "validate_embedding_threshold"
)]
pub threshold: f64,
/// Must be at least 1 (checked by `validate_min_samples`). Defaults to `10`.
#[serde(
default = "default_embedding_min_samples",
deserialize_with = "validate_min_samples"
)]
pub min_samples: usize,
/// Lower bound for an EMA value, going by the name — TODO confirm semantics.
/// Not range-checked on deserialization. Defaults to `0.01`.
#[serde(default = "default_ema_floor")]
pub ema_floor: f32,
}
fn validate_embedding_threshold<'de, D>(deserializer: D) -> Result<f64, D::Error>
where
D: serde::Deserializer<'de>,
{
let value = <f64 as serde::Deserialize>::deserialize(deserializer)?;
if value.is_nan() || value.is_infinite() {
return Err(serde::de::Error::custom(
"embedding_guard.threshold must be a finite number",
));
}
if !(value > 0.0 && value <= 1.0) {
return Err(serde::de::Error::custom(
"embedding_guard.threshold must be in (0.0, 1.0]",
));
}
Ok(value)
}
fn validate_min_samples<'de, D>(deserializer: D) -> Result<usize, D::Error>
where
D: serde::Deserializer<'de>,
{
let value = <usize as serde::Deserialize>::deserialize(deserializer)?;
if value == 0 {
return Err(serde::de::Error::custom(
"embedding_guard.min_samples must be >= 1",
));
}
Ok(value)
}
/// Serde default for `EmbeddingGuardConfig::threshold`.
fn default_embedding_threshold() -> f64 {
    0.35_f64
}

/// Serde default for `EmbeddingGuardConfig::min_samples`.
fn default_embedding_min_samples() -> usize {
    10_usize
}

/// Serde default for `EmbeddingGuardConfig::ema_floor`.
fn default_ema_floor() -> f32 {
    0.01_f32
}
impl Default for EmbeddingGuardConfig {
    /// Disabled guard using the same threshold, sample count and EMA floor
    /// as the serde default functions.
    fn default() -> Self {
        let threshold = default_embedding_threshold();
        let min_samples = default_embedding_min_samples();
        let ema_floor = default_ema_floor();
        Self {
            enabled: false,
            threshold,
            min_samples,
            ema_floor,
        }
    }
}
/// Top-level content-isolation settings.
///
/// Every field has a serde default, so an empty table deserializes
/// successfully. The nested `quarantine` and `embedding_guard` sections fall
/// back to their own `Default` impls when absent.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
// Each boolean below is a separate config toggle with its own serde default;
// folding them into an enum or bitflags would complicate the config format.
#[allow(clippy::struct_excessive_bools)]
pub struct ContentIsolationConfig {
/// Master switch. Defaults to `true`.
#[serde(default = "default_true")]
pub enabled: bool,
/// Content size cap. Defaults to 65 536 (`default_max_content_size`).
// NOTE(review): unit is presumably bytes — confirm at the enforcement site.
#[serde(default = "default_max_content_size")]
pub max_content_size: usize,
/// Defaults to `true`.
#[serde(default = "default_true")]
pub flag_injection_patterns: bool,
/// Defaults to `true`.
#[serde(default = "default_true")]
pub spotlight_untrusted: bool,
/// Quarantine sub-section; uses `QuarantineConfig::default()` when absent.
#[serde(default)]
pub quarantine: QuarantineConfig,
/// Embedding-guard sub-section; uses `EmbeddingGuardConfig::default()` when absent.
#[serde(default)]
pub embedding_guard: EmbeddingGuardConfig,
/// Defaults to `true`. Can be turned off explicitly in config.
#[serde(default = "default_true")]
pub mcp_to_acp_boundary: bool,
}
impl Default for ContentIsolationConfig {
    /// Isolation enabled with every boolean toggle on, the default size cap,
    /// and the nested quarantine / embedding-guard sections at their defaults.
    fn default() -> Self {
        let quarantine = QuarantineConfig::default();
        let embedding_guard = EmbeddingGuardConfig::default();
        Self {
            quarantine,
            embedding_guard,
            max_content_size: default_max_content_size(),
            enabled: true,
            flag_injection_patterns: true,
            spotlight_untrusted: true,
            mcp_to_acp_boundary: true,
        }
    }
}
/// Quarantine settings nested under `ContentIsolationConfig::quarantine`.
///
/// All fields have serde defaults matching `QuarantineConfig::default()`.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct QuarantineConfig {
/// Turns quarantine on. Defaults to `false`.
#[serde(default)]
pub enabled: bool,
/// Source labels to quarantine. Defaults to `["web_scrape", "a2a_message"]`.
// NOTE(review): the set of recognised labels is defined by the consumer — confirm.
#[serde(default = "default_quarantine_sources")]
pub sources: Vec<String>,
/// Model identifier. Defaults to `"claude"`.
#[serde(default = "default_quarantine_model")]
pub model: String,
}
/// Serde default for `QuarantineConfig::sources`.
fn default_quarantine_sources() -> Vec<String> {
    ["web_scrape", "a2a_message"]
        .iter()
        .copied()
        .map(String::from)
        .collect()
}

/// Serde default for `QuarantineConfig::model`.
fn default_quarantine_model() -> String {
    String::from("claude")
}
impl Default for QuarantineConfig {
    /// Quarantine off; sources and model come from the serde default functions.
    fn default() -> Self {
        Self {
            model: default_quarantine_model(),
            sources: default_quarantine_sources(),
            enabled: false,
        }
    }
}
/// Exfiltration-guard toggles. Each one defaults to `true` both via serde
/// (`default_true`) and via `ExfiltrationGuardConfig::default()`.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct ExfiltrationGuardConfig {
/// Defaults to `true`.
#[serde(default = "default_true")]
pub block_markdown_images: bool,
/// Defaults to `true`.
#[serde(default = "default_true")]
pub validate_tool_urls: bool,
/// Defaults to `true`.
#[serde(default = "default_true")]
pub guard_memory_writes: bool,
}
impl Default for ExfiltrationGuardConfig {
    /// Every exfiltration guard switched on.
    fn default() -> Self {
        Self {
            guard_memory_writes: true,
            validate_tool_urls: true,
            block_markdown_images: true,
        }
    }
}
/// Serde default for `MemoryWriteValidationConfig::max_content_bytes`.
fn default_max_content_bytes() -> usize {
    4 * 1024
}

/// Serde default for `MemoryWriteValidationConfig::max_entity_name_bytes`.
fn default_max_entity_name_bytes() -> usize {
    256_usize
}

/// Serde default for `MemoryWriteValidationConfig::min_entity_name_bytes`.
fn default_min_entity_name_bytes() -> usize {
    3_usize
}

/// Serde default for `MemoryWriteValidationConfig::max_fact_bytes`.
fn default_max_fact_bytes() -> usize {
    1024_usize
}

/// Serde default for `MemoryWriteValidationConfig::max_entities_per_extraction`.
fn default_max_entities() -> usize {
    50_usize
}

/// Serde default for `MemoryWriteValidationConfig::max_edges_per_extraction`.
fn default_max_edges() -> usize {
    100_usize
}
/// Configuration for memory-write validation.
///
/// Every field has a serde default matching
/// `MemoryWriteValidationConfig::default()`.
// NOTE(review): nothing here enforces `min_entity_name_bytes <=
// max_entity_name_bytes`; a config that inverts them deserializes fine.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct MemoryWriteValidationConfig {
/// Turns validation on. Defaults to `true`.
#[serde(default = "default_true")]
pub enabled: bool,
/// Defaults to 4096.
#[serde(default = "default_max_content_bytes")]
pub max_content_bytes: usize,
/// Defaults to 3.
#[serde(default = "default_min_entity_name_bytes")]
pub min_entity_name_bytes: usize,
/// Defaults to 256.
#[serde(default = "default_max_entity_name_bytes")]
pub max_entity_name_bytes: usize,
/// Defaults to 1024.
#[serde(default = "default_max_fact_bytes")]
pub max_fact_bytes: usize,
/// Defaults to 50.
#[serde(default = "default_max_entities")]
pub max_entities_per_extraction: usize,
/// Defaults to 100.
#[serde(default = "default_max_edges")]
pub max_edges_per_extraction: usize,
/// Defaults to empty.
// NOTE(review): presumably regex or substring patterns — confirm how the
// validator interprets them.
#[serde(default)]
pub forbidden_content_patterns: Vec<String>,
}
impl Default for MemoryWriteValidationConfig {
    /// Validation on, every limit at its serde default, no forbidden patterns.
    fn default() -> Self {
        Self {
            enabled: true,
            forbidden_content_patterns: Vec::default(),
            max_content_bytes: default_max_content_bytes(),
            min_entity_name_bytes: default_min_entity_name_bytes(),
            max_entity_name_bytes: default_max_entity_name_bytes(),
            max_fact_bytes: default_max_fact_bytes(),
            max_entities_per_extraction: default_max_entities(),
            max_edges_per_extraction: default_max_edges(),
        }
    }
}
/// A user-supplied PII pattern entry.
///
/// `name` and `pattern` have no serde default and are therefore required when
/// deserializing; `replacement` is optional.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct CustomPiiPattern {
/// Identifier for this pattern.
pub name: String,
/// Pattern text. NOTE(review): presumably a regex — confirm where it is compiled.
pub pattern: String,
/// Replacement text. Defaults to `"[PII:custom]"`.
#[serde(default = "default_custom_replacement")]
pub replacement: String,
}
/// Serde default for `CustomPiiPattern::replacement`.
fn default_custom_replacement() -> String {
    String::from("[PII:custom]")
}
/// PII filter settings.
///
/// The filter as a whole is off by default, while each built-in category
/// toggle defaults to `true`. Serde defaults match `PiiFilterConfig::default()`.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
// Five independent per-category toggles read straight from config; an enum or
// bitflags would not make the config format simpler.
#[allow(clippy::struct_excessive_bools)]
pub struct PiiFilterConfig {
/// Master switch. Defaults to `false`.
#[serde(default)]
pub enabled: bool,
/// Defaults to `true`.
#[serde(default = "default_true")]
pub filter_email: bool,
/// Defaults to `true`.
#[serde(default = "default_true")]
pub filter_phone: bool,
/// Defaults to `true`.
#[serde(default = "default_true")]
pub filter_ssn: bool,
/// Defaults to `true`.
#[serde(default = "default_true")]
pub filter_credit_card: bool,
/// Extra user-defined patterns. Defaults to empty.
#[serde(default)]
pub custom_patterns: Vec<CustomPiiPattern>,
}
impl Default for PiiFilterConfig {
    /// Filter disabled; all four built-in categories toggled on; no custom patterns.
    fn default() -> Self {
        let builtin_category = true;
        Self {
            enabled: false,
            filter_email: builtin_category,
            filter_phone: builtin_category,
            filter_ssn: builtin_category,
            filter_credit_card: builtin_category,
            custom_patterns: Vec::default(),
        }
    }
}
/// Guardrail action mode, (de)serialized in lowercase (`"block"` / `"warn"`).
// NOTE(review): the concrete effect of each mode is implemented by the
// guardrail consumer — confirm there.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum GuardrailAction {
/// `"block"`; the default variant.
#[default]
Block,
/// `"warn"`.
Warn,
}
/// Guardrail failure strategy, (de)serialized in lowercase (`"closed"` / `"open"`).
// NOTE(review): presumably selects behaviour when the guardrail itself errors
// or times out (fail closed vs. fail open) — confirm at the call site.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum GuardrailFailStrategy {
/// `"closed"`; the default variant.
#[default]
Closed,
/// `"open"`.
Open,
}
/// Guardrail settings. Every field has a serde default matching
/// `GuardrailConfig::default()`.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct GuardrailConfig {
/// Master switch. Defaults to `false`.
#[serde(default)]
pub enabled: bool,
/// Optional provider name; `None` when absent.
#[serde(default)]
pub provider: Option<String>,
/// Optional model name; `None` when absent.
#[serde(default)]
pub model: Option<String>,
/// Timeout in milliseconds (per the `_ms` suffix). Defaults to 500.
#[serde(default = "default_guardrail_timeout_ms")]
pub timeout_ms: u64,
/// Defaults to `GuardrailAction::Block`.
#[serde(default)]
pub action: GuardrailAction,
/// Defaults to `GuardrailFailStrategy::Closed`.
#[serde(default = "default_fail_strategy")]
pub fail_strategy: GuardrailFailStrategy,
/// Defaults to `false`.
#[serde(default)]
pub scan_tool_output: bool,
/// Input cap in characters (per the name). Defaults to 4096.
#[serde(default = "default_max_input_chars")]
pub max_input_chars: usize,
}
/// Serde default for `GuardrailConfig::timeout_ms`.
fn default_guardrail_timeout_ms() -> u64 {
    500_u64
}

/// Serde default for `GuardrailConfig::max_input_chars`.
fn default_max_input_chars() -> usize {
    4 * 1024
}
/// Serde default for `GuardrailConfig::fail_strategy`.
///
/// Returns `Closed` explicitly. This matches the enum's `#[default]` variant,
/// but spelling it out keeps the fail-closed choice visible at this layer.
fn default_fail_strategy() -> GuardrailFailStrategy {
GuardrailFailStrategy::Closed
}
impl Default for GuardrailConfig {
    /// Disabled guardrail with no provider/model override. Timeout, fail
    /// strategy and input cap come from the serde default functions.
    fn default() -> Self {
        Self {
            enabled: false,
            scan_tool_output: false,
            provider: None,
            model: None,
            action: GuardrailAction::Block,
            fail_strategy: default_fail_strategy(),
            timeout_ms: default_guardrail_timeout_ms(),
            max_input_chars: default_max_input_chars(),
        }
    }
}
/// Response-verification settings. Serde defaults match
/// `ResponseVerificationConfig::default()`.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct ResponseVerificationConfig {
/// Master switch. Defaults to `true`.
#[serde(default = "default_true")]
pub enabled: bool,
/// Defaults to `false`.
#[serde(default)]
pub block_on_detection: bool,
/// Defaults to `ProviderName::default()`.
#[serde(default)]
pub verifier_provider: ProviderName,
}
impl Default for ResponseVerificationConfig {
    /// Verification on, without blocking on detection, using the default
    /// verifier provider.
    fn default() -> Self {
        let verifier_provider = ProviderName::default();
        Self {
            verifier_provider,
            block_on_detection: false,
            enabled: true,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn content_isolation_default_mcp_to_acp_boundary_true() {
        let cfg = ContentIsolationConfig::default();
        assert!(cfg.mcp_to_acp_boundary);
    }

    #[test]
    fn content_isolation_deserialize_mcp_to_acp_boundary_false() {
        let toml = r"
mcp_to_acp_boundary = false
";
        let cfg: ContentIsolationConfig = toml::from_str(toml).unwrap();
        assert!(!cfg.mcp_to_acp_boundary);
    }

    #[test]
    fn content_isolation_deserialize_absent_defaults_true() {
        let cfg: ContentIsolationConfig = toml::from_str("").unwrap();
        assert!(cfg.mcp_to_acp_boundary);
    }

    /// Serde field defaults and the hand-written `Default` impls must agree;
    /// an empty table exercises every serde default at once, including nested
    /// sections.
    #[test]
    fn content_isolation_empty_table_matches_default() {
        let cfg: ContentIsolationConfig = toml::from_str("").unwrap();
        assert_eq!(cfg, ContentIsolationConfig::default());
    }

    fn de_guard(toml: &str) -> Result<EmbeddingGuardConfig, toml::de::Error> {
        toml::from_str(toml)
    }

    #[test]
    fn embedding_guard_empty_table_matches_default() {
        assert_eq!(de_guard("").unwrap(), EmbeddingGuardConfig::default());
    }

    #[test]
    fn threshold_valid() {
        let cfg = de_guard("threshold = 0.35\nmin_samples = 5").unwrap();
        assert!((cfg.threshold - 0.35).abs() < f64::EPSILON);
    }

    #[test]
    fn threshold_one_valid() {
        let cfg = de_guard("threshold = 1.0\nmin_samples = 1").unwrap();
        assert!((cfg.threshold - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn threshold_zero_rejected() {
        assert!(de_guard("threshold = 0.0\nmin_samples = 1").is_err());
    }

    #[test]
    fn threshold_above_one_rejected() {
        assert!(de_guard("threshold = 1.5\nmin_samples = 1").is_err());
    }

    #[test]
    fn threshold_negative_rejected() {
        assert!(de_guard("threshold = -0.1\nmin_samples = 1").is_err());
    }

    /// Covers the validator's dedicated non-finite branch (TOML supports the
    /// special float literals `nan`, `inf` and `-inf`).
    #[test]
    fn threshold_nan_rejected() {
        assert!(de_guard("threshold = nan\nmin_samples = 1").is_err());
    }

    #[test]
    fn threshold_infinite_rejected() {
        assert!(de_guard("threshold = inf\nmin_samples = 1").is_err());
        assert!(de_guard("threshold = -inf\nmin_samples = 1").is_err());
    }

    #[test]
    fn min_samples_zero_rejected() {
        assert!(de_guard("threshold = 0.35\nmin_samples = 0").is_err());
    }

    #[test]
    fn min_samples_one_valid() {
        let cfg = de_guard("threshold = 0.35\nmin_samples = 1").unwrap();
        assert_eq!(cfg.min_samples, 1);
    }
}
/// Serde default for `CausalIpiConfig::threshold`.
fn default_causal_threshold() -> f32 {
    0.7_f32
}
fn validate_causal_threshold<'de, D>(deserializer: D) -> Result<f32, D::Error>
where
D: serde::Deserializer<'de>,
{
let value = <f32 as serde::Deserialize>::deserialize(deserializer)?;
if value.is_nan() || value.is_infinite() {
return Err(serde::de::Error::custom(
"causal_ipi.threshold must be a finite number",
));
}
if !(value > 0.0 && value <= 1.0) {
return Err(serde::de::Error::custom(
"causal_ipi.threshold must be in (0.0, 1.0]",
));
}
Ok(value)
}
/// Serde default for `CausalIpiConfig::probe_max_tokens`.
fn default_probe_max_tokens() -> u32 {
    100_u32
}

/// Serde default for `CausalIpiConfig::probe_timeout_ms`.
fn default_probe_timeout_ms() -> u64 {
    3_000_u64
}
/// Causal indirect-prompt-injection (IPI) probe settings.
///
/// Every field has a serde default matching `CausalIpiConfig::default()`.
/// `threshold` is range-checked while deserializing.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct CausalIpiConfig {
/// Master switch. Defaults to `false`.
#[serde(default)]
pub enabled: bool,
/// Must be finite and within `(0.0, 1.0]` (checked by
/// `validate_causal_threshold`). Defaults to `0.7`.
#[serde(
default = "default_causal_threshold",
deserialize_with = "validate_causal_threshold"
)]
pub threshold: f32,
/// Optional provider name; `None` when absent.
#[serde(default)]
pub provider: Option<String>,
/// Defaults to 100.
#[serde(default = "default_probe_max_tokens")]
pub probe_max_tokens: u32,
/// Probe timeout in milliseconds (per the `_ms` suffix). Defaults to 3000.
#[serde(default = "default_probe_timeout_ms")]
pub probe_timeout_ms: u64,
}
impl Default for CausalIpiConfig {
    /// Disabled, no provider override, with threshold and probe limits taken
    /// from the serde default functions.
    fn default() -> Self {
        Self {
            threshold: default_causal_threshold(),
            probe_max_tokens: default_probe_max_tokens(),
            probe_timeout_ms: default_probe_timeout_ms(),
            enabled: false,
            provider: None,
        }
    }
}
#[cfg(test)]
mod causal_ipi_tests {
    use super::*;

    #[test]
    fn causal_ipi_defaults() {
        let cfg = CausalIpiConfig::default();
        assert!(!cfg.enabled);
        assert!((cfg.threshold - 0.7).abs() < 1e-6);
        assert!(cfg.provider.is_none());
        assert_eq!(cfg.probe_max_tokens, 100);
        assert_eq!(cfg.probe_timeout_ms, 3000);
    }

    /// Serde field defaults must agree with the hand-written `Default` impl.
    #[test]
    fn causal_ipi_empty_table_matches_default() {
        let cfg: CausalIpiConfig = toml::from_str("").unwrap();
        assert_eq!(cfg, CausalIpiConfig::default());
    }

    #[test]
    fn causal_ipi_deserialize_enabled() {
        let toml = r#"
enabled = true
threshold = 0.8
provider = "fast"
probe_max_tokens = 150
probe_timeout_ms = 5000
"#;
        let cfg: CausalIpiConfig = toml::from_str(toml).unwrap();
        assert!(cfg.enabled);
        assert!((cfg.threshold - 0.8).abs() < 1e-6);
        assert_eq!(cfg.provider.as_deref(), Some("fast"));
        assert_eq!(cfg.probe_max_tokens, 150);
        assert_eq!(cfg.probe_timeout_ms, 5000);
    }

    #[test]
    fn causal_ipi_threshold_zero_rejected() {
        let result: Result<CausalIpiConfig, _> = toml::from_str("threshold = 0.0");
        assert!(result.is_err());
    }

    #[test]
    fn causal_ipi_threshold_negative_rejected() {
        let result: Result<CausalIpiConfig, _> = toml::from_str("threshold = -0.5");
        assert!(result.is_err());
    }

    /// Covers the validator's dedicated non-finite branch.
    #[test]
    fn causal_ipi_threshold_non_finite_rejected() {
        let nan: Result<CausalIpiConfig, _> = toml::from_str("threshold = nan");
        assert!(nan.is_err());
        let inf: Result<CausalIpiConfig, _> = toml::from_str("threshold = inf");
        assert!(inf.is_err());
    }

    #[test]
    fn causal_ipi_threshold_above_one_rejected() {
        let result: Result<CausalIpiConfig, _> = toml::from_str("threshold = 1.1");
        assert!(result.is_err());
    }

    #[test]
    fn causal_ipi_threshold_exactly_one_accepted() {
        let cfg: CausalIpiConfig = toml::from_str("threshold = 1.0").unwrap();
        assert!((cfg.threshold - 1.0).abs() < 1e-6);
    }
}