use crate::core::error::{Error, Result, StorageError};
use crate::core::property::PropertyMap;
use crate::core::vector::cosine_similarity;
/// A validation rule that can be checked against a [`PropertyMap`].
///
/// Implementations in this file return `Ok(())` to accept the properties and
/// an `Err` to reject them.
pub trait SemanticRule {
/// Checks `props` against this rule.
fn validate(&self, props: &PropertyMap) -> Result<()>;
}
/// Holds a list of boxed [`SemanticRule`]s that are checked together.
pub struct Sentinel {
// Rules in insertion order; `validate` walks them front to back.
rules: Vec<Box<dyn SemanticRule>>,
}
impl Sentinel {
    /// Creates a sentinel with no rules registered.
    pub fn new() -> Self {
        let rules = Vec::new();
        Self { rules }
    }

    /// Appends `rule` to the end of the rule list.
    pub fn add_rule(&mut self, rule: Box<dyn SemanticRule>) {
        self.rules.push(rule);
    }

    /// Runs every registered rule against `props`, in insertion order.
    ///
    /// # Errors
    ///
    /// Returns the error of the first rule that rejects `props`; rules after
    /// it are not evaluated.
    pub fn validate(&self, props: &PropertyMap) -> Result<()> {
        // `try_for_each` short-circuits on the first `Err`, like `?` in a loop.
        self.rules.iter().try_for_each(|rule| rule.validate(props))
    }
}
impl Default for Sentinel {
fn default() -> Self {
Self::new()
}
}
/// Rule that rejects a vector property when it is too similar to any banned vector.
pub struct VectorBanRule {
/// Name of the property looked up in the `PropertyMap` being validated.
pub property_name: String,
/// Cosine-similarity cutoff; a similarity strictly greater than this is rejected.
pub threshold: f32,
// Reference vectors to compare against; filled via `add_banned_vector`.
banned_vectors: Vec<Vec<f32>>,
}
impl VectorBanRule {
pub fn new(property_name: impl Into<String>, threshold: f32) -> Self {
Self {
property_name: property_name.into(),
threshold,
banned_vectors: Vec::new(),
}
}
pub fn add_banned_vector(&mut self, vector: Vec<f32>) -> Result<()> {
const MAX_BANNED_VECTORS_PER_RULE: usize = 1_000;
if self.banned_vectors.len() >= MAX_BANNED_VECTORS_PER_RULE {
return Err(Error::Storage(StorageError::CapacityExceeded {
resource: "VectorBanRule.banned_vectors".to_string(),
current: self.banned_vectors.len(),
limit: MAX_BANNED_VECTORS_PER_RULE,
}));
}
self.banned_vectors.push(vector);
Ok(())
}
}
impl SemanticRule for VectorBanRule {
    /// Rejects `props` when its `property_name` vector is too close to a banned vector.
    ///
    /// The rule passes (`Ok(())`) when the property is absent or when
    /// `as_vector()` yields `None` for it.
    ///
    /// # Errors
    ///
    /// Returns an error when a computed similarity is non-finite, or when it
    /// is strictly greater than `threshold`. Banned vectors are checked in
    /// insertion order and the first failure is returned.
    fn validate(&self, props: &PropertyMap) -> Result<()> {
        let Some(value) = props.get(&self.property_name) else {
            return Ok(());
        };
        let Some(candidate) = value.as_vector() else {
            return Ok(());
        };

        for banned in &self.banned_vectors {
            // Comparisons that `cosine_similarity` reports as `Err` are skipped.
            // NOTE(review): presumably this covers incompatible vectors (e.g. a
            // length mismatch); confirm that failing open here is intended.
            let Ok(similarity) = cosine_similarity(candidate, banned) else {
                continue;
            };

            // A non-finite score cannot be compared against the threshold, so
            // it is reported as an error instead of being let through.
            if !similarity.is_finite() {
                return Err(Error::other(format!(
                    "Vector property '{}' similarity check resulted in non-finite value (NaN/Inf)",
                    self.property_name
                )));
            }

            if similarity > self.threshold {
                return Err(Error::other(format!(
                    "Vector property '{}' is too similar to a banned vector (similarity: {:.4} > {:.4})",
                    self.property_name, similarity, self.threshold
                )));
            }
        }

        Ok(())
    }
}
/// Rule that constrains a numeric property to an optional `[min, max]` range.
pub struct NumericRangeRule {
    /// Name of the property looked up in the map being validated.
    pub property_name: String,
    /// Inclusive lower bound; `None` means no lower bound.
    pub min: Option<f64>,
    /// Inclusive upper bound; `None` means no upper bound.
    pub max: Option<f64>,
}

impl NumericRangeRule {
    /// Creates an unbounded rule for `property_name`.
    ///
    /// Chain [`min`](Self::min) and/or [`max`](Self::max) to add bounds.
    pub fn new(property_name: impl Into<String>) -> Self {
        let property_name = property_name.into();
        Self {
            property_name,
            min: None,
            max: None,
        }
    }

    /// Returns the rule with its lower bound set to `min`, replacing any previous one.
    pub fn min(self, min: f64) -> Self {
        Self {
            min: Some(min),
            ..self
        }
    }

    /// Returns the rule with its upper bound set to `max`, replacing any previous one.
    pub fn max(self, max: f64) -> Self {
        Self {
            max: Some(max),
            ..self
        }
    }
}
impl SemanticRule for NumericRangeRule {
    /// Checks that the `property_name` value lies within the configured bounds.
    ///
    /// The rule passes (`Ok(())`) when the property is absent, or when neither
    /// `as_float()` nor `as_int()` yields a value for it. Integers are compared
    /// after conversion to `f64`.
    ///
    /// # Errors
    ///
    /// Returns an error when the value is NaN or infinite, is below `min`, or
    /// is above `max`. The messages name the violated bound but never include
    /// the offending value itself.
    fn validate(&self, props: &PropertyMap) -> Result<()> {
        let Some(value) = props.get(&self.property_name) else {
            return Ok(());
        };

        // Prefer the float view; fall back to the integer view only if needed.
        let number = match value.as_float() {
            Some(float) => float,
            None => match value.as_int() {
                Some(int) => int as f64,
                None => return Ok(()),
            },
        };

        // NaN compares false against every bound, so it must be rejected
        // explicitly or it would slip through both range checks.
        if !number.is_finite() {
            return Err(Error::other(format!(
                "Property '{}' value is not finite (NaN or Inf)",
                self.property_name
            )));
        }

        if let Some(min) = self.min.filter(|&lower| number < lower) {
            return Err(Error::other(format!(
                "Property '{}' value is less than minimum {}",
                self.property_name, min
            )));
        }

        if let Some(max) = self.max.filter(|&upper| number > upper) {
            return Err(Error::other(format!(
                "Property '{}' value is greater than maximum {}",
                self.property_name, max
            )));
        }

        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::property::PropertyMapBuilder;

    /// Identical and near-identical vectors are rejected; an orthogonal one passes.
    #[test]
    fn test_vector_ban_rule() {
        let mut ban = VectorBanRule::new("embedding", 0.9);
        ban.add_banned_vector(vec![1.0, 0.0]).unwrap();

        let identical = PropertyMapBuilder::new()
            .insert_vector("embedding", &[1.0, 0.0])
            .build();
        assert!(ban.validate(&identical).is_err());

        let nearly_identical = PropertyMapBuilder::new()
            .insert_vector("embedding", &[0.99, 0.14])
            .build();
        assert!(ban.validate(&nearly_identical).is_err());

        let orthogonal = PropertyMapBuilder::new()
            .insert_vector("embedding", &[0.0, 1.0])
            .build();
        assert!(ban.validate(&orthogonal).is_ok());
    }

    /// The 1001st banned vector is refused with `CapacityExceeded`.
    #[test]
    fn test_vector_ban_rule_capacity_limit() {
        let mut ban = VectorBanRule::new("embedding", 0.9);
        for _ in 0..1000 {
            ban.add_banned_vector(vec![1.0, 0.0]).unwrap();
        }

        let overflow = ban.add_banned_vector(vec![1.0, 0.0]);
        assert!(overflow.is_err());

        let Err(Error::Storage(StorageError::CapacityExceeded { limit, .. })) = overflow else {
            panic!("Expected CapacityExceeded error");
        };
        assert_eq!(limit, 1000);
    }

    /// A NaN component in the candidate vector is reported as a non-finite similarity.
    #[test]
    fn test_vector_ban_rule_nan_handling() {
        let mut ban = VectorBanRule::new("embedding", 0.9);
        ban.add_banned_vector(vec![1.0, 0.0]).unwrap();

        let with_nan = PropertyMapBuilder::new()
            .insert_vector("embedding", &[f32::NAN, 0.0])
            .build();

        let outcome = ban.validate(&with_nan);
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("non-finite"));
    }

    /// A NaN numeric value is rejected instead of slipping past the bounds.
    #[test]
    fn test_numeric_range_nan_handling() {
        let adult_only = NumericRangeRule::new("age").min(18.0);
        let with_nan = PropertyMapBuilder::new().insert("age", f64::NAN).build();

        let outcome = adult_only.validate(&with_nan);
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("not finite"));
    }

    /// The error names the bound but must not echo the offending value.
    #[test]
    fn test_numeric_range_error_privacy() {
        let salary_cap = NumericRangeRule::new("salary").max(50000.0);
        let over_cap = PropertyMapBuilder::new().insert("salary", 100000.0).build();

        let outcome = salary_cap.validate(&over_cap);
        assert!(outcome.is_err());

        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("greater than maximum 50000"));
        assert!(
            !message.contains("100000"),
            "Sensitive value leaked in error message"
        );
    }

    /// Both rule kinds work when combined behind a single `Sentinel`.
    #[test]
    fn test_sentinel_integration() {
        let mut ban = VectorBanRule::new("embedding", 0.8);
        ban.add_banned_vector(vec![1.0, 0.0]).unwrap();
        let adult_only = NumericRangeRule::new("age").min(18.0);

        let mut sentinel = Sentinel::new();
        sentinel.add_rule(Box::new(ban));
        sentinel.add_rule(Box::new(adult_only));

        let acceptable = PropertyMapBuilder::new()
            .insert_vector("embedding", &[0.0, 1.0])
            .insert("age", 25)
            .build();
        assert!(sentinel.validate(&acceptable).is_ok());

        let too_similar = PropertyMapBuilder::new()
            .insert_vector("embedding", &[0.9, 0.1])
            .insert("age", 25)
            .build();
        assert!(sentinel.validate(&too_similar).is_err());

        let too_young = PropertyMapBuilder::new()
            .insert_vector("embedding", &[0.0, 1.0])
            .insert("age", 16)
            .build();
        let outcome = sentinel.validate(&too_young);
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("less than minimum"));
    }
}