use crate::error::{QcError, QcIssue, QcResult, Severity};
use std::collections::HashMap;
/// A single quality rule: what to check (`rule_type`), how a violation is
/// reported (`severity`, `category`, `name`, `description`), its evaluation
/// order (`priority`), and whether it is active (`enabled`).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct QualityRule {
/// Identifier attached to reported issues via `QcIssue::with_rule_id`.
pub id: String,
/// Human-readable name, passed to `QcIssue::new` when the rule is violated.
pub name: String,
/// Free-text description; violation messages are formatted as
/// `"<description>: Rule violated"`.
pub description: String,
/// Grouping used by `RuleSet::get_rules_by_category`; its lowercased `Debug`
/// name is also passed to `QcIssue::new`.
pub category: RuleCategory,
/// Severity passed to `QcIssue::new` for violations of this rule.
pub severity: Severity,
/// Ordering key: `RuleSet::get_enabled_rules` returns higher values first.
pub priority: i32,
/// The check this rule performs.
pub rule_type: RuleType,
/// Extra settings. NOTE(review): not read by `RulesEngine` in this file.
pub config: RuleConfig,
/// When `false`, `RulesEngine::execute_rule` skips the rule entirely.
pub enabled: bool,
}
/// Category used to group [`QualityRule`]s, e.g. for
/// `RuleSet::get_rules_by_category` and `RulesEngine::execute_category`.
///
/// `General` is the category `RuleBuilder::new` assigns by default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum RuleCategory {
Raster,
Vector,
Metadata,
Topology,
Attribution,
General,
}
/// The kind of check a [`QualityRule`] performs.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum RuleType {
/// Compares the value of `field` against `value` with `operator` (field value
/// on the left). `RulesEngine::execute_rule` reports a violation when the
/// comparison does not hold or when `field` is missing from the data.
Threshold {
field: String,
operator: ComparisonOperator,
value: f64,
},
/// Requires the value of `field` to lie within `[min, max]` (both bounds
/// inclusive). A missing `field` is reported as a violation.
Range {
field: String,
min: f64,
max: f64,
},
/// Restricts `field` to one of `allowed_values`.
/// NOTE(review): `RulesEngine::execute_rule` does not evaluate this variant
/// and never reports a violation for it.
Enumeration {
field: String,
allowed_values: Vec<String>,
},
/// Requires `field` to match `pattern` (pattern syntax is not defined in
/// this file). NOTE(review): not evaluated by `RulesEngine::execute_rule`.
Pattern {
field: String,
pattern: String,
},
/// Refers to a named check function. NOTE(review): not evaluated by
/// `RulesEngine::execute_rule`. `RuleBuilder::new` starts every rule with
/// `function_name` set to `"default"`.
Custom {
function_name: String,
},
}
/// Operator used by [`RuleType::Threshold`], applied as
/// `field_value <op> threshold` (the measured value is the left operand).
///
/// `Equal` and `NotEqual` are evaluated by `RulesEngine::compare_values` with an
/// absolute tolerance of `f64::EPSILON` rather than exact equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum ComparisonOperator {
Equal,
NotEqual,
GreaterThan,
GreaterThanOrEqual,
LessThan,
LessThanOrEqual,
}
/// Extra, rule-specific settings carried alongside a [`QualityRule`].
///
/// NOTE(review): none of these fields are read by `RulesEngine` in this file;
/// confirm whether other modules consume them.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct RuleConfig {
/// Free-form string key/value parameters.
pub parameters: HashMap<String, String>,
/// Optional pass threshold (its semantics are not defined in this file).
pub pass_threshold: Option<f64>,
/// Optional fail threshold (its semantics are not defined in this file).
pub fail_threshold: Option<f64>,
}
/// A named, versioned collection of [`QualityRule`]s that can be loaded from
/// and saved to TOML (`RuleSet::from_toml_file` / `RuleSet::to_toml_file`).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct RuleSet {
/// Name of the rule set.
pub name: String,
/// Free-text description of the rule set.
pub description: String,
/// Version string; `RuleSet::new` initialises it to `"1.0"`.
pub version: String,
/// Rules in insertion order (`RuleSet::add_rule` appends).
pub rules: Vec<QualityRule>,
}
impl RuleSet {
    /// Creates an empty rule set with the given name and description.
    ///
    /// The version starts at `"1.0"`.
    #[must_use]
    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
        let name = name.into();
        let description = description.into();
        Self {
            name,
            description,
            version: String::from("1.0"),
            rules: Vec::new(),
        }
    }

    /// Appends `rule` to the end of the set. No de-duplication by id is performed.
    pub fn add_rule(&mut self, rule: QualityRule) {
        self.rules.push(rule);
    }

    /// Reads the file at `path` and deserializes it as a TOML rule set.
    ///
    /// # Errors
    ///
    /// Returns `QcError::Io` if the file cannot be read. Any TOML
    /// deserialization error is propagated through `?`.
    pub fn from_toml_file(path: impl AsRef<std::path::Path>) -> QcResult<Self> {
        let text = std::fs::read_to_string(path).map_err(QcError::Io)?;
        let parsed: Self = toml::from_str(&text)?;
        Ok(parsed)
    }

    /// Serializes the rule set as pretty-printed TOML and writes it to `path`.
    ///
    /// `std::fs::write` replaces any existing file contents.
    ///
    /// # Errors
    ///
    /// Returns `QcError::InvalidConfiguration` if serialization fails, and
    /// `QcError::Io` if the file cannot be written.
    pub fn to_toml_file(&self, path: impl AsRef<std::path::Path>) -> QcResult<()> {
        let serialized = toml::to_string_pretty(self).map_err(|err| {
            QcError::InvalidConfiguration(format!("Failed to serialize rule set: {}", err))
        })?;
        std::fs::write(path, serialized).map_err(QcError::Io)?;
        Ok(())
    }

    /// Returns references to the enabled rules, highest `priority` first.
    ///
    /// The sort is stable, so rules with equal priority keep their insertion order.
    #[must_use]
    pub fn get_enabled_rules(&self) -> Vec<&QualityRule> {
        let mut active: Vec<&QualityRule> = self.rules.iter().filter(|rule| rule.enabled).collect();
        // Comparator arguments are swapped to get descending priority.
        active.sort_by(|left, right| right.priority.cmp(&left.priority));
        active
    }

    /// Returns every rule (enabled or not) whose category equals `category`,
    /// in insertion order.
    #[must_use]
    pub fn get_rules_by_category(&self, category: RuleCategory) -> Vec<&QualityRule> {
        let mut matching = Vec::new();
        for rule in &self.rules {
            if rule.category == category {
                matching.push(rule);
            }
        }
        matching
    }
}
/// Evaluates the rules of a [`RuleSet`] against a map from field name to
/// numeric (`f64`) value, producing `QcIssue`s for violated rules.
pub struct RulesEngine {
rule_set: RuleSet,
}
impl RulesEngine {
    /// Creates an engine that evaluates the rules in `rule_set`.
    #[must_use]
    pub fn new(rule_set: RuleSet) -> Self {
        Self { rule_set }
    }

    /// Loads a rule set from the TOML file at `path` and builds an engine from it.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`RuleSet::from_toml_file`].
    pub fn from_toml_file(path: impl AsRef<std::path::Path>) -> QcResult<Self> {
        let rule_set = RuleSet::from_toml_file(path)?;
        Ok(Self::new(rule_set))
    }

    /// Evaluates a single rule against `data` (field name -> measured value).
    ///
    /// Returns `Ok(None)` when the rule is disabled or satisfied. Returns
    /// `Ok(Some(issue))` when it is violated. The issue carries the rule's
    /// severity, its lowercased category name, its name, and the message
    /// `"<description>: Rule violated"`, and is tagged with the rule id.
    ///
    /// For `Threshold` and `Range` rules:
    /// * a `field` missing from `data` is a violation;
    /// * a `NaN` value lies in no range, so `Range` rules report it. It also
    ///   fails every threshold operator except `NotEqual`.
    ///
    /// `Enumeration`, `Pattern` and `Custom` rules are not evaluated against the
    /// numeric data and never report a violation.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`.
    pub fn execute_rule(
        &self,
        rule: &QualityRule,
        data: &HashMap<String, f64>,
    ) -> QcResult<Option<QcIssue>> {
        if !rule.enabled {
            return Ok(None);
        }
        let violated = match &rule.rule_type {
            RuleType::Threshold {
                field,
                operator,
                value,
            } => match data.get(field) {
                Some(&actual) => !self.compare_values(actual, *value, *operator),
                // A missing measurement cannot satisfy the rule.
                None => true,
            },
            RuleType::Range { field, min, max } => match data.get(field) {
                // `RangeInclusive::contains` is false for NaN. The previous
                // `v < min || v > max` form let NaN pass silently.
                Some(actual) => !(*min..=*max).contains(actual),
                None => true,
            },
            // These rule types describe non-numeric checks and cannot be
            // evaluated against the numeric `data` map.
            RuleType::Enumeration { .. } | RuleType::Pattern { .. } | RuleType::Custom { .. } => {
                false
            }
        };
        if !violated {
            return Ok(None);
        }
        let issue = QcIssue::new(
            rule.severity,
            format!("{:?}", rule.category).to_lowercase(),
            &rule.name,
            format!("{}: Rule violated", rule.description),
        )
        .with_rule_id(&rule.id);
        Ok(Some(issue))
    }

    /// Evaluates every enabled rule, highest priority first, and collects the
    /// resulting issues in evaluation order.
    ///
    /// # Errors
    ///
    /// Stops at and returns the first error produced by [`Self::execute_rule`].
    pub fn execute_all(&self, data: &HashMap<String, f64>) -> QcResult<Vec<QcIssue>> {
        self.collect_issues(self.rule_set.get_enabled_rules(), data)
    }

    /// Like [`Self::execute_all`], but only for rules in `category`.
    ///
    /// Rules are evaluated in the same order as `execute_all`: enabled only,
    /// highest priority first.
    ///
    /// # Errors
    ///
    /// Stops at and returns the first error produced by [`Self::execute_rule`].
    pub fn execute_category(
        &self,
        category: RuleCategory,
        data: &HashMap<String, f64>,
    ) -> QcResult<Vec<QcIssue>> {
        let rules = self
            .rule_set
            .get_enabled_rules()
            .into_iter()
            .filter(|rule| rule.category == category);
        self.collect_issues(rules, data)
    }

    /// Returns the rule set this engine evaluates.
    #[must_use]
    pub const fn rule_set(&self) -> &RuleSet {
        &self.rule_set
    }

    /// Runs `rules` in the given order and gathers the issues they produce.
    fn collect_issues<'a, I>(&self, rules: I, data: &HashMap<String, f64>) -> QcResult<Vec<QcIssue>>
    where
        I: IntoIterator<Item = &'a QualityRule>,
    {
        let mut issues = Vec::new();
        for rule in rules {
            if let Some(issue) = self.execute_rule(rule, data)? {
                issues.push(issue);
            }
        }
        Ok(issues)
    }

    /// Evaluates `a <op> b`.
    ///
    /// `NotEqual` is defined as the exact negation of `Equal`, so NaN operands
    /// are "not equal" to everything, matching IEEE 754 `!=`.
    fn compare_values(&self, a: f64, b: f64, op: ComparisonOperator) -> bool {
        match op {
            ComparisonOperator::Equal => Self::approx_eq(a, b),
            ComparisonOperator::NotEqual => !Self::approx_eq(a, b),
            ComparisonOperator::GreaterThan => a > b,
            ComparisonOperator::GreaterThanOrEqual => a >= b,
            ComparisonOperator::LessThan => a < b,
            ComparisonOperator::LessThanOrEqual => a <= b,
        }
    }

    /// Equality with an absolute tolerance of `f64::EPSILON`.
    ///
    /// The exact `a == b` check comes first so that equal infinities compare
    /// equal (`inf - inf` is NaN, which fails the tolerance test). NaN is never
    /// equal to anything.
    #[allow(clippy::float_cmp)] // exact equality is intentional, see above
    fn approx_eq(a: f64, b: f64) -> bool {
        a == b || (a - b).abs() < f64::EPSILON
    }
}
/// Fluent builder for [`QualityRule`]: start with `RuleBuilder::new`, chain
/// setters, and finish with `build`.
pub struct RuleBuilder {
rule: QualityRule,
}
impl RuleBuilder {
#[must_use]
pub fn new(id: impl Into<String>, name: impl Into<String>) -> Self {
Self {
rule: QualityRule {
id: id.into(),
name: name.into(),
description: String::new(),
category: RuleCategory::General,
severity: Severity::Warning,
priority: 0,
rule_type: RuleType::Custom {
function_name: "default".to_string(),
},
config: RuleConfig::default(),
enabled: true,
},
}
}
#[must_use]
pub fn description(mut self, description: impl Into<String>) -> Self {
self.rule.description = description.into();
self
}
#[must_use]
pub const fn category(mut self, category: RuleCategory) -> Self {
self.rule.category = category;
self
}
#[must_use]
pub const fn severity(mut self, severity: Severity) -> Self {
self.rule.severity = severity;
self
}
#[must_use]
pub const fn priority(mut self, priority: i32) -> Self {
self.rule.priority = priority;
self
}
#[must_use]
pub fn threshold(
mut self,
field: impl Into<String>,
operator: ComparisonOperator,
value: f64,
) -> Self {
self.rule.rule_type = RuleType::Threshold {
field: field.into(),
operator,
value,
};
self
}
#[must_use]
pub fn range(mut self, field: impl Into<String>, min: f64, max: f64) -> Self {
self.rule.rule_type = RuleType::Range {
field: field.into(),
min,
max,
};
self
}
#[must_use]
pub fn build(self) -> QualityRule {
self.rule
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every builder setter should be reflected verbatim in the built rule.
    #[test]
    fn test_rule_builder() {
        let built = RuleBuilder::new("TEST-001", "Test Rule")
            .description("Test description")
            .category(RuleCategory::Raster)
            .severity(Severity::Major)
            .priority(10)
            .threshold("field1", ComparisonOperator::GreaterThan, 100.0)
            .build();

        assert_eq!(
            (built.id.as_str(), built.name.as_str()),
            ("TEST-001", "Test Rule")
        );
        assert_eq!(built.category, RuleCategory::Raster);
        assert_eq!(built.severity, Severity::Major);
        assert_eq!(built.priority, 10);
    }

    /// Adding a rule grows the set by exactly one.
    #[test]
    fn test_rule_set() {
        let mut set = RuleSet::new("Test Rules", "Test rule set");
        set.add_rule(
            RuleBuilder::new("R001", "Rule 1")
                .threshold("value", ComparisonOperator::LessThan, 50.0)
                .build(),
        );
        assert_eq!(set.rules.len(), 1);
    }

    /// A value above a `<=` threshold must produce exactly one issue.
    #[test]
    fn test_rules_engine() {
        let mut set = RuleSet::new("Test", "Test");
        set.add_rule(
            RuleBuilder::new("R001", "Max Value Check")
                .threshold("max_value", ComparisonOperator::LessThanOrEqual, 100.0)
                .severity(Severity::Major)
                .build(),
        );
        let engine = RulesEngine::new(set);
        let data = HashMap::from([("max_value".to_string(), 150.0)]);

        let outcome = engine.execute_all(&data);
        assert!(outcome.is_ok());
        let issues = outcome.unwrap_or_default();
        assert_eq!(issues.len(), 1);
    }

    /// Table of (lhs, rhs, operator) triples that must all hold.
    #[test]
    fn test_comparison_operators() {
        let engine = RulesEngine::new(RuleSet::new("Test", "Test"));
        let cases = [
            (10.0, 5.0, ComparisonOperator::GreaterThan),
            (5.0, 10.0, ComparisonOperator::LessThan),
            (10.0, 10.0, ComparisonOperator::Equal),
            (10.0, 5.0, ComparisonOperator::NotEqual),
        ];
        for (lhs, rhs, op) in cases {
            assert!(
                engine.compare_values(lhs, rhs, op),
                "{} {:?} {} should hold",
                lhs,
                op,
                rhs
            );
        }
    }
}