use crate::ViolationComputable;
use scirs2_core::ndarray::Array1;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Human-readable account of why a point violates a set of constraints.
///
/// The analyzers in this module populate it incrementally. It can be rendered as
/// plain text with [`ViolationExplanation::to_report`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ViolationExplanation {
    /// One-line summary printed at the top of the report.
    pub description: String,
    /// Violation amount keyed by constraint index.
    pub violated_constraints: HashMap<usize, f32>,
    /// `(dimension, contribution)` pairs.
    ///
    /// `add_responsible_dimension` keeps these sorted by contribution, largest first.
    pub responsible_dimensions: Vec<(usize, f32)>,
    /// Free-form remediation hints, in insertion order.
    pub suggestions: Vec<String>,
    /// A feasible point near the analysed one, if one was found.
    pub nearest_feasible: Option<Array1<f32>>,
}

impl ViolationExplanation {
    /// Creates an empty explanation with the given summary line.
    pub fn new(description: impl Into<String>) -> Self {
        Self {
            description: description.into(),
            violated_constraints: HashMap::new(),
            responsible_dimensions: Vec::new(),
            suggestions: Vec::new(),
            nearest_feasible: None,
        }
    }

    /// Comparator ordering values from largest to smallest, with NaN last.
    ///
    /// `partial_cmp(..).unwrap_or(Equal)` is not a total order once NaN appears,
    /// and `slice::sort_by` may panic on non-total comparators. This comparator
    /// agrees with `partial_cmp` for all non-NaN values.
    fn descending(a: f32, b: f32) -> std::cmp::Ordering {
        b.partial_cmp(&a)
            .unwrap_or_else(|| a.is_nan().cmp(&b.is_nan()))
    }

    /// Records the violation amount for constraint `index`, replacing any previous value.
    pub fn add_violated_constraint(&mut self, index: usize, violation: f32) {
        self.violated_constraints.insert(index, violation);
    }

    /// Adds a dimension and re-sorts so the largest contributions come first.
    ///
    /// The sort is stable, so equal contributions keep their insertion order.
    pub fn add_responsible_dimension(&mut self, dim: usize, contribution: f32) {
        self.responsible_dimensions.push((dim, contribution));
        self.responsible_dimensions
            .sort_by(|a, b| Self::descending(a.1, b.1));
    }

    /// Appends a remediation hint.
    pub fn add_suggestion(&mut self, suggestion: impl Into<String>) {
        self.suggestions.push(suggestion.into());
    }

    /// Stores a feasible point near the analysed one.
    pub fn set_nearest_feasible(&mut self, point: Array1<f32>) {
        self.nearest_feasible = Some(point);
    }

    /// Renders the explanation as a multi-section plain-text report.
    ///
    /// Sections that have no data are omitted.
    /// - Violated constraints are listed largest first, with ties broken by index.
    /// - At most the first five responsible dimensions are shown.
    pub fn to_report(&self) -> String {
        let mut report = format!("=== Violation Explanation ===\n{}\n\n", self.description);

        if !self.violated_constraints.is_empty() {
            report.push_str("Violated Constraints:\n");
            let mut violations: Vec<(usize, f32)> = self
                .violated_constraints
                .iter()
                .map(|(&idx, &viol)| (idx, viol))
                .collect();
            // HashMap iteration order is unspecified; tie-break on the index so the
            // report is identical across runs.
            violations.sort_by(|a, b| Self::descending(a.1, b.1).then(a.0.cmp(&b.0)));
            for (idx, viol) in violations {
                report.push_str(&format!(
                    " - Constraint #{}: violation = {:.4}\n",
                    idx, viol
                ));
            }
            report.push('\n');
        }

        if !self.responsible_dimensions.is_empty() {
            report.push_str("Most Responsible Dimensions:\n");
            for (dim, contrib) in self.responsible_dimensions.iter().take(5) {
                report.push_str(&format!(
                    " - Dimension {}: contribution = {:.4}\n",
                    dim, contrib
                ));
            }
            report.push('\n');
        }

        if !self.suggestions.is_empty() {
            report.push_str("Suggestions:\n");
            for suggestion in &self.suggestions {
                report.push_str(&format!(" - {}\n", suggestion));
            }
            report.push('\n');
        }

        if let Some(nearest) = &self.nearest_feasible {
            // `to_vec` works for any memory layout and prints as a plain list.
            // `as_slice()` would print as `Some([..])`, or `None` for strided arrays.
            report.push_str(&format!(
                "Nearest Feasible Point: {:?}\n",
                nearest.to_vec()
            ));
        }
        report
    }
}
/// Reports which constraints a point violates and attaches a per-constraint fix hint.
///
/// Despite the name, no subset minimisation is performed.
/// [`MinimalViolatingSubsetFinder::find_mvs`] returns every constraint whose
/// `check` fails. Each of those is, on its own, a violating subset of size one.
#[derive(Debug, Clone)]
pub struct MinimalViolatingSubsetFinder<C: ViolationComputable> {
    /// Constraints to evaluate, addressed by index.
    constraints: Vec<C>,
    /// Display name per constraint; `names[i]` labels `constraints[i]`.
    names: Vec<String>,
}

impl<C: ViolationComputable + Clone> MinimalViolatingSubsetFinder<C> {
    /// Creates a finder over `constraints`, labelled by `names`.
    ///
    /// # Panics
    /// Panics if `constraints` and `names` differ in length.
    pub fn new(constraints: Vec<C>, names: Vec<String>) -> Self {
        assert_eq!(constraints.len(), names.len());
        Self { constraints, names }
    }

    /// Indices, in ascending order, of the constraints whose `check` fails at `x`.
    fn violated_indices(&self, x: &[f32]) -> Vec<usize> {
        self.constraints
            .iter()
            .enumerate()
            .filter(|(_, c)| !c.check(x))
            .map(|(i, _)| i)
            .collect()
    }

    /// Returns the ascending indices of every constraint violated by `point`.
    ///
    /// The point is copied with `to_vec` rather than borrowed via `as_slice`.
    /// Arrays that are not in standard (contiguous) layout are therefore evaluated
    /// on their real values, not as an empty slice.
    pub fn find_mvs(&self, point: &Array1<f32>) -> Vec<usize> {
        self.violated_indices(&point.to_vec())
    }

    /// Builds an explanation listing each violated constraint with its violation amount.
    ///
    /// Adds one "Fix constraint" suggestion per violated constraint, in index order.
    pub fn explain(&self, point: &Array1<f32>) -> ViolationExplanation {
        let x = point.to_vec();
        let violated = self.violated_indices(&x);
        let mut explanation = ViolationExplanation::new(format!(
            "Point violates {} out of {} constraints",
            violated.len(),
            self.constraints.len()
        ));
        for idx in violated {
            let violation = self.constraints[idx].violation(&x);
            explanation.add_violated_constraint(idx, violation);
            explanation.add_suggestion(format!(
                "Fix constraint '{}' (violation: {:.4})",
                self.names[idx], violation
            ));
        }
        explanation
    }

    /// Number of constraints this finder evaluates.
    pub fn num_constraints(&self) -> usize {
        self.constraints.len()
    }
}
/// Ranks input dimensions by how sensitive the total constraint violation is to them.
#[derive(Debug, Clone)]
pub struct ViolationAttributionAnalyzer<C: ViolationComputable> {
    /// Constraints whose positive violations are summed.
    constraints: Vec<C>,
    /// Display name per dimension; dimensions without a name are reported as `unknown`.
    feature_names: Vec<String>,
}

impl<C: ViolationComputable + Clone> ViolationAttributionAnalyzer<C> {
    /// Shift applied to one coordinate at a time when estimating sensitivity.
    const PERTURBATION: f32 = 0.01;
    /// Total violation below which no attribution is computed.
    const ZERO_VIOLATION: f32 = 1e-8;

    /// Creates an analyzer over `constraints`, labelling dimensions with `feature_names`.
    pub fn new(constraints: Vec<C>, feature_names: Vec<String>) -> Self {
        Self {
            constraints,
            feature_names,
        }
    }

    /// Sum of the positive parts of all constraint violations at `x`.
    fn total_violation(&self, x: &[f32]) -> f32 {
        self.constraints
            .iter()
            .map(|c| c.violation(x).max(0.0))
            .sum()
    }

    /// Comparator ordering values largest first, with NaN last.
    ///
    /// Unlike `partial_cmp(..).unwrap_or(Equal)`, this is a total order, which
    /// `slice::sort_by` requires to avoid a potential panic.
    fn descending(a: f32, b: f32) -> std::cmp::Ordering {
        b.partial_cmp(&a)
            .unwrap_or_else(|| a.is_nan().cmp(&b.is_nan()))
    }

    /// Returns the total violation at `probe` and the per-dimension sensitivities.
    ///
    /// `probe` is perturbed in place and each coordinate is restored before
    /// returning, so it ends unchanged.
    fn analyze(&self, probe: &mut [f32]) -> (f32, Vec<(usize, f32)>) {
        let base = self.total_violation(probe);
        if base < Self::ZERO_VIOLATION {
            return (base, Vec::new());
        }
        let eps = Self::PERTURBATION;
        let mut scores = Vec::with_capacity(probe.len());
        for dim in 0..probe.len() {
            let original = probe[dim];
            probe[dim] = original + eps;
            let plus = self.total_violation(probe);
            probe[dim] = original - eps;
            let minus = self.total_violation(probe);
            probe[dim] = original;
            let sensitivity = ((plus - base).abs() + (minus - base).abs()) / (2.0 * eps);
            scores.push((dim, sensitivity));
        }
        scores.sort_by(|a, b| Self::descending(a.1, b.1));
        (base, scores)
    }

    /// Sensitivity of the summed positive violation to each dimension, largest first.
    ///
    /// Let `V` be the summed positive violation. For each dimension `d`, the score is
    /// `(|V(x + εe_d) - V(x)| + |V(x - εe_d) - V(x)|) / (2ε)` with `ε = 0.01`.
    /// Returns an empty vector when `V(x) < 1e-8`.
    ///
    /// The point is copied with `to_vec`, so non-contiguous arrays are handled correctly.
    pub fn attribute_violations(&self, point: &Array1<f32>) -> Vec<(usize, f32)> {
        let mut probe = point.to_vec();
        self.analyze(&mut probe).1
    }

    /// Builds an explanation ranking every dimension by sensitivity.
    ///
    /// Also adds an "Adjust feature" suggestion for each of the top three dimensions.
    pub fn explain(&self, point: &Array1<f32>) -> ViolationExplanation {
        let mut probe = point.to_vec();
        let (total_violation, attributions) = self.analyze(&mut probe);
        let mut explanation =
            ViolationExplanation::new(format!("Total violation: {:.4}", total_violation));
        for &(dim, attr) in &attributions {
            explanation.add_responsible_dimension(dim, attr);
        }
        for &(dim, _) in attributions.iter().take(3) {
            let feature_name = self
                .feature_names
                .get(dim)
                .map_or("unknown", |s| s.as_str());
            explanation.add_suggestion(format!(
                "Adjust feature '{}' (dimension {})",
                feature_name, dim
            ));
        }
        explanation
    }
}
#[derive(Debug, Clone)]
pub struct CounterfactualAnalyzer<C: ViolationComputable> {
constraints: Vec<C>,
max_iterations: usize,
step_size: f32,
}
impl<C: ViolationComputable + Clone> CounterfactualAnalyzer<C> {
pub fn new(constraints: Vec<C>, max_iterations: usize, step_size: f32) -> Self {
Self {
constraints,
max_iterations,
step_size,
}
}
pub fn find_nearest_feasible(&self, point: &Array1<f32>) -> Option<Array1<f32>> {
let mut current = point.clone();
let epsilon = 1e-4;
for _ in 0..self.max_iterations {
let current_slice = current.as_slice().unwrap_or(&[]);
let is_feasible = self.constraints.iter().all(|c| c.check(current_slice));
if is_feasible {
return Some(current);
}
let total_violation: f32 = self
.constraints
.iter()
.map(|c| c.violation(current_slice).max(0.0))
.sum();
if total_violation < 1e-6 {
return Some(current);
}
for dim in 0..current.len() {
let mut perturbed = current.clone();
perturbed[dim] += epsilon;
let viol_plus: f32 = self
.constraints
.iter()
.map(|c| c.violation(perturbed.as_slice().unwrap_or(&[])).max(0.0))
.sum();
let grad = (viol_plus - total_violation) / epsilon;
current[dim] -= self.step_size * grad;
}
}
None
}
pub fn explain(&self, point: &Array1<f32>) -> ViolationExplanation {
let point_slice = point.as_slice().unwrap_or(&[]);
let total_violation: f32 = self
.constraints
.iter()
.map(|c| c.violation(point_slice).max(0.0))
.sum();
let mut explanation = ViolationExplanation::new(format!(
"Searching for nearest feasible point (current violation: {:.4})",
total_violation
));
if let Some(nearest) = self.find_nearest_feasible(point) {
let distance = point
.iter()
.zip(nearest.iter())
.map(|(a, b)| (a - b).powi(2))
.sum::<f32>()
.sqrt();
explanation.set_nearest_feasible(nearest.clone());
explanation.add_suggestion(format!(
"Move to nearest feasible point (distance: {:.4})",
distance
));
for dim in 0..point.len() {
let change = (point[dim] - nearest[dim]).abs();
if change > 1e-4 {
explanation.add_responsible_dimension(dim, change);
}
}
} else {
explanation
.add_suggestion("Could not find feasible point in search radius".to_string());
}
explanation
}
}
#[derive(Debug, Clone)]
pub struct ViolationExplainer<C: ViolationComputable> {
mvs_finder: Option<MinimalViolatingSubsetFinder<C>>,
attribution_analyzer: Option<ViolationAttributionAnalyzer<C>>,
counterfactual_analyzer: Option<CounterfactualAnalyzer<C>>,
}
impl<C: ViolationComputable + Clone> ViolationExplainer<C> {
pub fn new() -> Self {
Self {
mvs_finder: None,
attribution_analyzer: None,
counterfactual_analyzer: None,
}
}
pub fn with_mvs_finder(mut self, constraints: Vec<C>, names: Vec<String>) -> Self {
self.mvs_finder = Some(MinimalViolatingSubsetFinder::new(constraints, names));
self
}
pub fn with_attribution_analyzer(
mut self,
constraints: Vec<C>,
feature_names: Vec<String>,
) -> Self {
self.attribution_analyzer = Some(ViolationAttributionAnalyzer::new(
constraints,
feature_names,
));
self
}
pub fn with_counterfactual_analyzer(
mut self,
constraints: Vec<C>,
max_iterations: usize,
step_size: f32,
) -> Self {
self.counterfactual_analyzer = Some(CounterfactualAnalyzer::new(
constraints,
max_iterations,
step_size,
));
self
}
pub fn explain(&self, point: &Array1<f32>) -> ViolationExplanation {
let mut explanation = ViolationExplanation::new("Comprehensive Violation Analysis");
if let Some(ref mvs) = self.mvs_finder {
let mvs_exp = mvs.explain(point);
explanation.violated_constraints = mvs_exp.violated_constraints;
explanation.suggestions.extend(mvs_exp.suggestions);
}
if let Some(ref attr) = self.attribution_analyzer {
let attr_exp = attr.explain(point);
explanation.responsible_dimensions = attr_exp.responsible_dimensions;
explanation.suggestions.extend(attr_exp.suggestions);
}
if let Some(ref cf) = self.counterfactual_analyzer {
let cf_exp = cf.explain(point);
explanation.nearest_feasible = cf_exp.nearest_feasible;
explanation.suggestions.extend(cf_exp.suggestions);
}
explanation
}
}
impl<C: ViolationComputable + Clone> Default for ViolationExplainer<C> {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::LinearConstraint;

    #[test]
    fn test_violation_explanation() {
        let mut exp = ViolationExplanation::new("Test violation");
        exp.add_violated_constraint(0, 1.5);
        exp.add_responsible_dimension(0, 0.8);
        exp.add_suggestion("Reduce dimension 0");
        let report = exp.to_report();
        assert!(report.contains("Test violation"));
        assert!(report.contains("Constraint #0: violation = 1.5000"));
        assert!(report.contains("Dimension 0: contribution = 0.8000"));
        assert!(report.contains(" - Reduce dimension 0"));
    }

    #[test]
    fn test_responsible_dimensions_sorted_descending() {
        let mut exp = ViolationExplanation::new("ordering");
        exp.add_responsible_dimension(0, 0.2);
        exp.add_responsible_dimension(1, 0.9);
        exp.add_responsible_dimension(2, 0.5);
        let order: Vec<usize> = exp.responsible_dimensions.iter().map(|&(d, _)| d).collect();
        assert_eq!(order, vec![1, 2, 0]);
    }

    #[test]
    fn test_mvs_finder() {
        let constraints = vec![
            LinearConstraint::less_eq(vec![1.0], 5.0),
            LinearConstraint::less_eq(vec![1.0], 3.0),
        ];
        let names = vec!["c1".to_string(), "c2".to_string()];
        let finder = MinimalViolatingSubsetFinder::new(constraints, names);

        // 7 exceeds both upper bounds; 4 exceeds only the second; 1 satisfies both.
        assert_eq!(finder.find_mvs(&Array1::from_vec(vec![7.0])), vec![0, 1]);
        assert_eq!(finder.find_mvs(&Array1::from_vec(vec![4.0])), vec![1]);
        assert!(finder.find_mvs(&Array1::from_vec(vec![1.0])).is_empty());
        assert_eq!(finder.num_constraints(), 2);

        let exp = finder.explain(&Array1::from_vec(vec![7.0]));
        assert_eq!(exp.violated_constraints.len(), 2);
        assert_eq!(exp.suggestions.len(), 2);
        assert!(exp.suggestions[0].starts_with("Fix constraint 'c1'"));
    }

    #[test]
    #[should_panic]
    fn test_mvs_finder_rejects_mismatched_names() {
        let _ = MinimalViolatingSubsetFinder::new(
            vec![LinearConstraint::less_eq(vec![1.0], 5.0)],
            Vec::new(),
        );
    }

    #[test]
    fn test_attribution_analyzer() {
        let constraints = vec![LinearConstraint::less_eq(vec![1.0, 1.0], 5.0)];
        let features = vec!["x".to_string(), "y".to_string()];
        let analyzer = ViolationAttributionAnalyzer::new(constraints, features);

        let point = Array1::from_vec(vec![4.0, 3.0]);
        let exp = analyzer.explain(&point);
        assert_eq!(exp.responsible_dimensions.len(), 2);
        assert_eq!(exp.suggestions.len(), 2);

        // A feasible point has nothing to attribute.
        let feasible = Array1::from_vec(vec![1.0, 1.0]);
        assert!(analyzer.attribute_violations(&feasible).is_empty());
    }

    #[test]
    fn test_counterfactual_analyzer() {
        let constraints = vec![LinearConstraint::less_eq(vec![1.0], 5.0)];
        let analyzer = CounterfactualAnalyzer::new(constraints, 100, 0.1);
        let point = Array1::from_vec(vec![10.0]);
        let nearest = analyzer.find_nearest_feasible(&point);
        assert!(nearest.is_some());
        if let Some(feasible) = nearest {
            assert!(feasible[0].is_finite());
            assert!(feasible[0] <= 5.0 + 1e-2);
        }
    }

    #[test]
    fn test_unified_explainer() {
        let constraints = vec![LinearConstraint::less_eq(vec![1.0], 5.0)];
        let names = vec!["x_max".to_string()];
        let features = vec!["x".to_string()];
        let explainer = ViolationExplainer::new()
            .with_mvs_finder(constraints.clone(), names)
            .with_attribution_analyzer(constraints.clone(), features)
            .with_counterfactual_analyzer(constraints, 100, 0.1);
        let point = Array1::from_vec(vec![10.0]);
        let exp = explainer.explain(&point);
        assert!(exp.violated_constraints.contains_key(&0));
        assert_eq!(exp.responsible_dimensions.len(), 1);
        assert!(exp.nearest_feasible.is_some());
        // One suggestion from each of the three analyzers.
        assert_eq!(exp.suggestions.len(), 3);
    }

    #[test]
    fn test_default_explainer_is_empty() {
        let explainer = ViolationExplainer::<LinearConstraint>::default();
        let exp = explainer.explain(&Array1::from_vec(vec![10.0]));
        assert_eq!(exp.description, "Comprehensive Violation Analysis");
        assert!(exp.violated_constraints.is_empty());
        assert!(exp.responsible_dimensions.is_empty());
        assert!(exp.suggestions.is_empty());
        assert!(exp.nearest_feasible.is_none());
    }
}