use crate::{Rule, RuleAtom, RuleEngine};
use anyhow::Result;
use scirs2_core::metrics::Timer;
use std::collections::HashMap;
use tracing::{debug, info, warn};
lazy_static::lazy_static! {
    /// Timer started at the top of `ShaclRuleIntegration::validate_with_reasoning`
    /// (metric name `shacl_validation_direct`).
    static ref VALIDATION_DIRECT_TIMER: Timer =
        Timer::new(String::from("shacl_validation_direct"));
    /// Timer started around the pre-validation forward-chaining step
    /// (metric name `shacl_validation_pre_reasoning`).
    static ref VALIDATION_PRE_REASONING_TIMER: Timer =
        Timer::new(String::from("shacl_validation_pre_reasoning"));
}
/// Selects how [`ShaclRuleIntegration::validate_with_reasoning`] combines
/// rule-based reasoning with constraint validation.
///
/// `Eq` and `Hash` are derived so modes can be compared exhaustively and used
/// as map/set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ValidationMode {
    /// Validate the supplied data without pre-validation reasoning or repairs.
    Direct,
    /// Forward-chain over the supplied data first, then validate the inferred facts.
    PreReasoning,
    /// Validate the supplied data; if it does not conform, run registered repair rules.
    PostReasoning,
    /// Combine `PreReasoning` and `PostReasoning`.
    Full,
}
/// Severity attached to a constraint and to its validation results.
///
/// The derived `Ord` follows declaration order, so
/// `Info < Warning < Violation`. `Hash` is derived so severities can key maps
/// and sets (e.g. when grouping results).
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Severity {
    /// Informational finding.
    Info,
    /// Non-blocking finding.
    Warning,
    /// Constraint violation (the default severity for a new `ShapeConstraint`).
    Violation,
}
/// Outcome of checking a single constraint.
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// `true` when the constraint was satisfied.
    pub valid: bool,
    /// Severity reported for this result (copied from the constraint by the integration).
    pub severity: Severity,
    /// Constraint type that produced this result, e.g. `"sh:minCount"`.
    pub constraint_type: String,
    /// Optional focus node, attached via `with_focus_node`.
    pub focus_node: Option<String>,
    /// Optional offending value, attached via `with_value`.
    pub value: Option<String>,
    /// Human-readable description of the outcome.
    pub message: String,
}
impl ValidationResult {
    /// Builds a result with neither a focus node nor a value attached.
    pub fn new(valid: bool, severity: Severity, constraint_type: String, message: String) -> Self {
        let focus_node = None;
        let value = None;
        Self {
            valid,
            severity,
            constraint_type,
            focus_node,
            value,
            message,
        }
    }

    /// Returns this result with its focus node set to `node`.
    pub fn with_focus_node(self, node: String) -> Self {
        Self {
            focus_node: Some(node),
            ..self
        }
    }

    /// Returns this result with its offending value set to `value`.
    pub fn with_value(self, value: String) -> Self {
        Self {
            value: Some(value),
            ..self
        }
    }
}
/// Aggregated outcome of validating a set of constraints.
#[derive(Debug, Clone)]
pub struct ValidationReport {
    /// Becomes `false` as soon as any result with `valid == false` is added,
    /// regardless of that result's severity.
    pub conforms: bool,
    /// Individual results in the order they were added.
    pub results: Vec<ValidationResult>,
    /// Bookkeeping about the validation run.
    pub stats: ValidationStats,
}
impl ValidationReport {
pub fn new(conforms: bool) -> Self {
Self {
conforms,
results: Vec::new(),
stats: ValidationStats::default(),
}
}
pub fn add_result(&mut self, result: ValidationResult) {
self.conforms = self.conforms && result.valid;
self.results.push(result);
}
pub fn violation_count(&self) -> usize {
self.results
.iter()
.filter(|r| !r.valid && r.severity == Severity::Violation)
.count()
}
pub fn warning_count(&self) -> usize {
self.results
.iter()
.filter(|r| !r.valid && r.severity == Severity::Warning)
.count()
}
}
/// Bookkeeping attached to a [`ValidationReport`].
#[derive(Debug, Clone, Default)]
pub struct ValidationStats {
    /// Number of shapes covered by the run.
    pub shapes_validated: usize,
    /// Number of constraints evaluated in the run.
    pub constraints_checked: usize,
    /// Wall-clock duration of the run in milliseconds.
    pub validation_time_ms: u128,
    /// Count of rules applied during the run.
    /// NOTE(review): confirm which code path populates this.
    pub rules_applied: usize,
}
/// A single SHACL-style constraint to evaluate against a set of facts.
#[derive(Debug, Clone)]
pub struct ShapeConstraint {
    /// Caller-chosen identifier for the constraint.
    pub id: String,
    /// Constraint kind; `"sh:minCount"` and `"sh:maxCount"` have dedicated
    /// checks in `ShaclRuleIntegration`, and any other kind passes on non-empty data.
    pub constraint_type: String,
    /// Property path the constraint targets.
    /// NOTE(review): not read by the count checks in `ShaclRuleIntegration::check_constraint`.
    pub predicate: Option<String>,
    /// Constraint parameter as text; parsed as an integer by the count checks.
    pub expected: Option<String>,
    /// Severity reported when the constraint fails.
    pub severity: Severity,
}
impl ShapeConstraint {
    /// Creates a constraint with no predicate or expected value and
    /// `Severity::Violation` as its severity.
    pub fn new(id: String, constraint_type: String) -> Self {
        let severity = Severity::Violation;
        Self {
            severity,
            predicate: None,
            expected: None,
            id,
            constraint_type,
        }
    }
}
pub struct ShaclRuleIntegration {
engine: RuleEngine,
mode: ValidationMode,
shape_rules: HashMap<String, Vec<String>>,
repair_rules: HashMap<String, Vec<Rule>>,
validation_cache: HashMap<String, ValidationResult>,
stats: IntegrationStats,
inferred_cache: Option<Vec<RuleAtom>>,
data_hash: u64,
}
impl ShaclRuleIntegration {
pub fn new(engine: RuleEngine) -> Self {
Self {
engine,
mode: ValidationMode::Full,
shape_rules: HashMap::new(),
repair_rules: HashMap::new(),
validation_cache: HashMap::new(),
stats: IntegrationStats::default(),
inferred_cache: None,
data_hash: 0,
}
}
pub fn set_mode(&mut self, mode: ValidationMode) {
info!("Setting validation mode to {:?}", mode);
self.mode = mode;
}
pub fn get_mode(&self) -> &ValidationMode {
&self.mode
}
pub fn register_shape_rule(&mut self, shape_id: String, rule_name: String) {
debug!("Registering rule '{}' for shape '{}'", rule_name, shape_id);
self.shape_rules
.entry(shape_id)
.or_default()
.push(rule_name);
}
pub fn register_repair_rule(&mut self, constraint_type: String, rule: Rule) {
debug!(
"Registering repair rule '{}' for constraint '{}'",
rule.name, constraint_type
);
self.repair_rules
.entry(constraint_type)
.or_default()
.push(rule);
}
pub fn validate_with_reasoning(
&mut self,
constraints: &[ShapeConstraint],
data: &[RuleAtom],
) -> Result<ValidationReport> {
let _validation_timer = VALIDATION_DIRECT_TIMER.start();
self.stats.total_validations += 1;
let start = std::time::Instant::now();
let mut report = ValidationReport::new(true);
let data_to_validate = match self.mode {
ValidationMode::PreReasoning | ValidationMode::Full => {
self.apply_pre_reasoning(data)?
}
_ => {
if data.len() > 100 {
use crate::simd_ops::SimdMatcher;
let matcher = SimdMatcher::new();
let mut deduped = data.to_vec();
matcher.batch_deduplicate(&mut deduped);
deduped
} else {
data.to_vec()
}
}
};
for constraint in constraints {
let result = self.validate_constraint(constraint, &data_to_validate)?;
report.add_result(result.clone());
self.validation_cache.insert(constraint.id.clone(), result);
}
report.stats.validation_time_ms = start.elapsed().as_millis();
report.stats.shapes_validated = 1;
report.stats.constraints_checked = constraints.len();
if !report.conforms
&& (self.mode == ValidationMode::PostReasoning || self.mode == ValidationMode::Full)
{
self.apply_repairs(&mut report, data)?;
}
Ok(report)
}
fn apply_pre_reasoning(&mut self, data: &[RuleAtom]) -> Result<Vec<RuleAtom>> {
let _timer = VALIDATION_PRE_REASONING_TIMER.start();
debug!("Applying pre-validation reasoning");
let inferred = self.get_inferred_facts(data)?;
self.stats.pre_reasoning_applications += 1;
Ok(inferred)
}
fn validate_constraint(
&self,
constraint: &ShapeConstraint,
data: &[RuleAtom],
) -> Result<ValidationResult> {
if let Some(cached) = self.validation_cache.get(&constraint.id) {
debug!("Cache hit for constraint '{}'", constraint.id);
return Ok(cached.clone());
}
let valid = self.check_constraint(constraint, data);
let message = if valid {
format!("Constraint '{}' satisfied", constraint.constraint_type)
} else {
format!("Constraint '{}' violated", constraint.constraint_type)
};
Ok(ValidationResult::new(
valid,
constraint.severity.clone(),
constraint.constraint_type.clone(),
message,
))
}
fn check_constraint(&self, constraint: &ShapeConstraint, data: &[RuleAtom]) -> bool {
match constraint.constraint_type.as_str() {
"sh:minCount" => {
data.len()
>= constraint
.expected
.as_ref()
.and_then(|s| s.parse().ok())
.unwrap_or(1)
}
"sh:maxCount" => {
data.len()
<= constraint
.expected
.as_ref()
.and_then(|s| s.parse().ok())
.unwrap_or(100)
}
_ => !data.is_empty(),
}
}
fn apply_repairs(&mut self, report: &mut ValidationReport, data: &[RuleAtom]) -> Result<()> {
debug!("Applying constraint repairs");
let violations: Vec<_> = report
.results
.iter()
.filter(|r| !r.valid && r.severity == Severity::Violation)
.collect();
for violation in violations {
if let Some(repair_rules) = self.repair_rules.get(&violation.constraint_type) {
for rule in repair_rules {
self.engine.add_rule(rule.clone());
}
let repaired = self.engine.forward_chain(data)?;
warn!(
"Applied {} repair rules for constraint '{}'",
repair_rules.len(),
violation.constraint_type
);
self.stats.repairs_applied += repair_rules.len();
self.stats.post_reasoning_applications += 1;
self.engine.add_facts(repaired);
}
}
Ok(())
}
pub fn clear_cache(&mut self) {
debug!(
"Clearing validation cache ({} entries)",
self.validation_cache.len()
);
self.validation_cache.clear();
}
pub fn engine_mut(&mut self) -> &mut RuleEngine {
&mut self.engine
}
pub fn engine(&self) -> &RuleEngine {
&self.engine
}
pub fn get_stats(&self) -> &IntegrationStats {
&self.stats
}
pub fn reset_stats(&mut self) {
self.stats = IntegrationStats::default();
}
fn compute_data_hash(&self, data: &[RuleAtom]) -> u64 {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
data.len().hash(&mut hasher);
let sample_size = data.len().min(10);
for atom in data.iter().take(sample_size) {
format!("{:?}", atom).hash(&mut hasher);
}
if data.len() > sample_size {
for atom in data.iter().skip(data.len() - sample_size) {
format!("{:?}", atom).hash(&mut hasher);
}
}
hasher.finish()
}
pub fn invalidate_cache(&mut self) {
self.inferred_cache = None;
self.data_hash = 0;
debug!("Inference cache invalidated");
}
fn get_inferred_facts(&mut self, data: &[RuleAtom]) -> Result<Vec<RuleAtom>> {
let current_hash = self.compute_data_hash(data);
if let Some(ref cached) = self.inferred_cache {
if current_hash == self.data_hash {
debug!("Using cached inferred facts ({} facts)", cached.len());
self.stats.cache_hits += 1;
return Ok(cached.clone());
}
}
debug!("Cache miss - inferring facts");
self.stats.cache_misses += 1;
self.engine.add_facts(data.to_vec());
let inferred = self.engine.forward_chain(data)?;
self.inferred_cache = Some(inferred.clone());
self.data_hash = current_hash;
Ok(inferred)
}
}
/// Aggregate counters maintained by `ShaclRuleIntegration` across calls.
#[derive(Debug, Clone, Default)]
pub struct IntegrationStats {
    /// Number of `validate_with_reasoning` calls.
    pub total_validations: usize,
    /// Number of pre-validation forward-chaining steps performed.
    pub pre_reasoning_applications: usize,
    /// Number of post-validation repair passes performed.
    pub post_reasoning_applications: usize,
    /// Total number of repair rules applied.
    pub repairs_applied: usize,
    /// Hits on the inferred-facts cache.
    pub cache_hits: usize,
    /// Misses on the inferred-facts cache (each miss runs forward chaining).
    pub cache_misses: usize,
}
impl std::fmt::Display for IntegrationStats {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"Validations: {}, Pre-reasoning: {}, Post-reasoning: {}, Repairs: {}, Cache(hits/misses): {}/{}",
self.total_validations,
self.pre_reasoning_applications,
self.post_reasoning_applications,
self.repairs_applied,
self.cache_hits,
self.cache_misses
)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Term;

    /// A fresh integration around an empty engine.
    fn fresh_integration() -> ShaclRuleIntegration {
        ShaclRuleIntegration::new(RuleEngine::new())
    }

    /// Shorthand for a constant term in fixtures.
    fn constant(value: &str) -> Term {
        Term::Constant(value.to_string())
    }

    #[test]
    fn test_validation_result_creation() {
        let result = ValidationResult::new(
            true,
            Severity::Info,
            "sh:minCount".into(),
            "Min count satisfied".into(),
        );
        assert!(result.valid);
        assert_eq!(result.severity, Severity::Info);
    }

    #[test]
    fn test_validation_report() {
        let failing = ValidationResult::new(
            false,
            Severity::Violation,
            "sh:minCount".into(),
            "Min count not satisfied".into(),
        );
        let mut report = ValidationReport::new(true);
        report.add_result(failing);
        assert!(!report.conforms);
        assert_eq!(report.violation_count(), 1);
    }

    #[test]
    fn test_shacl_integration_creation() {
        let integration = fresh_integration();
        assert_eq!(*integration.get_mode(), ValidationMode::Full);
    }

    #[test]
    fn test_validation_mode_setting() {
        let mut integration = fresh_integration();
        integration.set_mode(ValidationMode::PreReasoning);
        assert_eq!(integration.get_mode(), &ValidationMode::PreReasoning);
    }

    #[test]
    fn test_shape_rule_registration() {
        let mut integration = fresh_integration();
        integration.register_shape_rule("PersonShape".into(), "age_validation".into());
        assert_eq!(integration.shape_rules.len(), 1);
    }

    #[test]
    fn test_repair_rule_registration() {
        let mut integration = fresh_integration();
        let repair_rule = Rule {
            name: "fix_mincount".into(),
            body: vec![],
            head: vec![],
        };
        integration.register_repair_rule("sh:minCount".into(), repair_rule);
        assert_eq!(integration.repair_rules.len(), 1);
    }

    #[test]
    fn test_validation_with_reasoning() -> Result<(), Box<dyn std::error::Error>> {
        let mut engine = RuleEngine::new();
        engine.add_fact(RuleAtom::Triple {
            subject: constant("john"),
            predicate: constant("type"),
            object: constant("Person"),
        });
        let mut integration = ShaclRuleIntegration::new(engine);
        let constraint = ShapeConstraint::new("c1".into(), "sh:minCount".into());
        let data = vec![RuleAtom::Triple {
            subject: constant("john"),
            predicate: constant("age"),
            object: Term::Literal("30".to_string()),
        }];
        let report = integration.validate_with_reasoning(&[constraint], &data)?;
        assert_eq!(report.results.len(), 1);
        Ok(())
    }

    #[test]
    fn test_cache_clearing() {
        let mut integration = fresh_integration();
        let entry = ValidationResult::new(true, Severity::Info, "test".into(), "test".into());
        integration.validation_cache.insert("test".into(), entry);
        assert_eq!(integration.validation_cache.len(), 1);
        integration.clear_cache();
        assert!(integration.validation_cache.is_empty());
    }
}