use serde::{Deserialize, Serialize};
use super::types::{FuzzyInferenceOutput, LinguisticVariable};
/// Strategy used by [`defuzzify_mamdani`] to reduce the aggregated output fuzzy set to a
/// single crisp value.
///
/// Serialized as an internally tagged object whose `method` field carries the snake_case
/// variant name, e.g. `{"method": "mean_of_maxima"}`.
#[derive(Clone, Copy, Debug, Deserialize, Serialize)]
#[serde(tag = "method")]
#[serde(rename_all = "snake_case")]
pub enum DefuzzMethod {
    /// Membership-weighted mean of the sampled positions (centre of gravity).
    Centroid,
    /// First sampled position at which the running membership sum reaches half the total.
    Bisector,
    /// Mean of the sampled positions whose membership equals the peak (within `1e-9`).
    MeanOfMaxima,
    /// Sampled position of peak membership; on ties the last (largest) position wins.
    Height,
}
/// Sampling grid over the output universe used by [`defuzzify_mamdani`]:
/// `steps + 1` evenly spaced points from `min` to `max` inclusive.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct Domain {
    /// Lower bound of the sampled interval.
    pub min: f64,
    /// Upper bound of the sampled interval; must be strictly greater than `min`.
    pub max: f64,
    /// Number of equal-width intervals; the grid therefore has `steps + 1` points.
    pub steps: usize,
}

impl Domain {
    /// Creates a domain without validating it. An invalid domain makes
    /// [`defuzzify_mamdani`] return `None`.
    #[must_use]
    pub fn new(min: f64, max: f64, steps: usize) -> Self {
        Self { min, max, steps }
    }

    /// Whether the domain can be sampled.
    ///
    /// The domain is valid when both bounds are finite, `min < max` and `steps > 0`. In
    /// addition, the span `max - min` must be finite and the step width `span / steps`
    /// must be strictly positive.
    fn is_valid(&self) -> bool {
        if !(self.min.is_finite() && self.max.is_finite() && self.min < self.max && self.steps > 0)
        {
            return false;
        }
        // Finite bounds can still produce an infinite span (e.g. `-f64::MAX..f64::MAX`),
        // which turns every sample position into NaN/inf. A subnormal span divided by many
        // steps can also underflow to a zero step width, collapsing every sample onto `min`.
        let span = self.max - self.min;
        span.is_finite() && span / (self.steps as f64) > 0.0
    }
}
/// Defuzzifies the aggregated Mamdani output for `output_variable` into a crisp value.
///
/// Each entry of `output.memberships` whose key has the form
/// `"{output_variable}.{set_name}"` and names a set of that variable contributes the set's
/// membership function clipped at the entry's strength (`min`). The clipped sets are
/// combined point-wise with `max`. The result is sampled at `domain.steps + 1` evenly spaced
/// points from `domain.min`, and `method` reduces the samples to a single value.
///
/// Returns `None` when:
/// - `domain` is invalid (non-finite bounds, `min >= max`, zero steps) or its sample
///   spacing is not a positive finite number;
/// - `variables` contains no variable named `output_variable`;
/// - no consequent of that variable has a strength greater than zero;
/// - the aggregated membership is zero at every sample point (for every method).
pub fn defuzzify_mamdani(
    output: &FuzzyInferenceOutput,
    variables: &[LinguisticVariable],
    output_variable: &str,
    domain: Domain,
    method: DefuzzMethod,
) -> Option<f64> {
    if !domain.is_valid() {
        return None;
    }
    // Finite bounds can still overflow `max - min` to infinity, and a tiny span can
    // underflow the spacing to zero; either would yield NaN or degenerate samples.
    let dx = (domain.max - domain.min) / (domain.steps as f64);
    if !(dx.is_finite() && dx > 0.0) {
        return None;
    }
    let variable = variables.iter().find(|v| v.name == output_variable)?;
    let consequents = active_consequents(output, variable, output_variable);
    if consequents.is_empty() {
        return None;
    }
    let samples = sample_aggregate(&consequents, domain.min, dx, domain.steps);
    match method {
        DefuzzMethod::Centroid => defuzz_centroid(&samples),
        DefuzzMethod::Bisector => defuzz_bisector(&samples),
        DefuzzMethod::MeanOfMaxima => defuzz_mean_of_maxima(&samples),
        DefuzzMethod::Height => defuzz_height(&samples),
    }
}

/// Collects `(set, strength)` for every `"{output_variable}.{set}"` membership that names a
/// set of `variable` and has a strictly positive strength.
fn active_consequents<'a>(
    output: &FuzzyInferenceOutput,
    variable: &'a LinguisticVariable,
    output_variable: &str,
) -> Vec<(&'a super::types::FuzzySet, f64)> {
    let prefix = format!("{output_variable}.");
    output
        .memberships
        .iter()
        .filter_map(|(key, strength)| {
            let set_name = key.strip_prefix(&prefix)?;
            let set = variable.sets.iter().find(|s| s.name == set_name)?;
            Some((set, strength.value()))
        })
        .filter(|(_, strength)| *strength > 0.0)
        .collect()
}

/// Samples the strength-clipped, max-aggregated consequents at `x_i = start + i * dx` for
/// `i` in `0..=steps`. Returns `(x, membership)` pairs in increasing `x` order.
fn sample_aggregate(
    consequents: &[(&super::types::FuzzySet, f64)],
    start: f64,
    dx: f64,
    steps: usize,
) -> Vec<(f64, f64)> {
    (0..=steps)
        .map(|i| {
            let x = start + (i as f64) * dx;
            let mu = consequents
                .iter()
                .map(|(set, strength)| set.function.evaluate(x).value().min(*strength))
                .fold(0.0_f64, f64::max);
            (x, mu)
        })
        .collect()
}

/// Largest sampled membership, or `0.0` for an empty/all-zero sample set.
fn peak_membership(samples: &[(f64, f64)]) -> f64 {
    samples.iter().map(|&(_, mu)| mu).fold(0.0_f64, f64::max)
}

/// Membership-weighted mean of the sample positions; `None` if all memberships are zero.
fn defuzz_centroid(samples: &[(f64, f64)]) -> Option<f64> {
    let num: f64 = samples.iter().map(|(x, mu)| x * mu).sum();
    let den: f64 = samples.iter().map(|(_, mu)| *mu).sum();
    if den == 0.0 { None } else { Some(num / den) }
}

/// First sample position at which the running membership sum reaches half of the total;
/// `None` if all memberships are zero.
fn defuzz_bisector(samples: &[(f64, f64)]) -> Option<f64> {
    let total: f64 = samples.iter().map(|(_, mu)| *mu).sum();
    if total == 0.0 {
        return None;
    }
    let half = total / 2.0;
    let mut acc = 0.0;
    for &(x, mu) in samples {
        acc += mu;
        if acc >= half {
            return Some(x);
        }
    }
    // The running sum ends at `total >= half`, so this is only a guard against rounding.
    samples.last().map(|&(x, _)| x)
}

/// Mean of the positions whose membership equals the peak (within `1e-9`);
/// `None` if all memberships are zero.
fn defuzz_mean_of_maxima(samples: &[(f64, f64)]) -> Option<f64> {
    let max_mu = peak_membership(samples);
    if max_mu == 0.0 {
        return None;
    }
    let xs: Vec<f64> = samples
        .iter()
        .filter(|(_, mu)| (mu - max_mu).abs() < 1e-9)
        .map(|(x, _)| *x)
        .collect();
    if xs.is_empty() {
        None
    } else {
        Some(xs.iter().sum::<f64>() / (xs.len() as f64))
    }
}

/// Position of the peak membership. On ties the last (largest) position wins, matching
/// `Iterator::max_by`'s last-maximum rule. Returns `None` if all memberships are zero,
/// consistent with the other methods; previously this case returned the final sample
/// position.
fn defuzz_height(samples: &[(f64, f64)]) -> Option<f64> {
    let max_mu = peak_membership(samples);
    if max_mu == 0.0 {
        return None;
    }
    samples
        .iter()
        .rev()
        .find(|&&(_, mu)| mu == max_mu)
        .map(|&(x, _)| x)
}
/// Returns the strength-weighted average `Σ(strength · value) / Σ strength` of `rules`,
/// given as `(strength, value)` pairs.
///
/// Returns `None` in any of these cases:
/// - `rules` is empty;
/// - the strengths sum to zero;
/// - either sum is not finite (NaN or infinite inputs);
/// - the quotient is not finite. This happens, for example, when mixed-sign strengths
///   nearly cancel and the division overflows.
#[must_use]
pub fn weighted_average(rules: &[(f64, f64)]) -> Option<f64> {
    let den: f64 = rules.iter().map(|&(strength, _)| strength).sum();
    if den == 0.0 || !den.is_finite() {
        return None;
    }
    let num: f64 = rules.iter().map(|&(strength, value)| strength * value).sum();
    if !num.is_finite() {
        return None;
    }
    // Finite operands do not guarantee a finite quotient when `den` is tiny.
    let avg = num / den;
    avg.is_finite().then_some(avg)
}
#[cfg(test)]
mod tests {
    use super::{DefuzzMethod, Domain, defuzzify_mamdani, weighted_average};
    use crate::fuzzy::{
        ActivatedRule, FuzzyInferenceOutput, FuzzySet, LinguisticVariable, MembershipDegree,
        MembershipFunction,
    };
    use std::collections::BTreeMap;

    /// Builds an inference output containing a single membership `key` at `strength`,
    /// backed by one activated rule with unit weight.
    fn make_output(key: &str, strength: f64) -> FuzzyInferenceOutput {
        let md = MembershipDegree::new(strength);
        let mut memberships = BTreeMap::new();
        memberships.insert(key.to_string(), md);
        FuzzyInferenceOutput {
            input_memberships: BTreeMap::new(),
            memberships,
            activated_rules: vec![ActivatedRule {
                id: "r1".to_string(),
                antecedent_strength: md,
                weight: MembershipDegree::one(),
                strength: md,
                consequent: key.to_string(),
            }],
            confidence: md,
            total_rules: 1,
        }
    }

    /// A single output variable `out` with one triangular set `mid` over [0, 100], peak 50.
    fn sym_triangle_vars() -> Vec<LinguisticVariable> {
        vec![LinguisticVariable {
            name: "out".to_string(),
            sets: vec![FuzzySet {
                name: "mid".to_string(),
                function: MembershipFunction::Triangular {
                    min: 0.0,
                    peak: 50.0,
                    max: 100.0,
                },
            }],
        }]
    }

    /// Defuzzifies `out.mid` fired at `strength` against [`sym_triangle_vars`].
    fn run(strength: f64, domain: Domain, method: DefuzzMethod) -> Option<f64> {
        defuzzify_mamdani(
            &make_output("out.mid", strength),
            &sym_triangle_vars(),
            "out",
            domain,
            method,
        )
    }

    /// Fine-grained domain (step 0.1) used by the accuracy tests.
    fn fine_domain() -> Domain {
        Domain::new(0.0, 100.0, 1000)
    }

    /// Result for the unclipped (strength 1.0) symmetric triangle.
    fn sym_result(method: DefuzzMethod) -> f64 {
        run(1.0, fine_domain(), method).unwrap()
    }

    /// Result for the symmetric triangle clipped at 0.5, i.e. a plateau over [25, 75].
    fn clipped_result(method: DefuzzMethod) -> f64 {
        run(0.5, fine_domain(), method).unwrap()
    }

    #[test]
    fn domain_invalid_min_ge_max_returns_none() {
        assert!(run(1.0, Domain::new(100.0, 0.0, 100), DefuzzMethod::Centroid).is_none());
    }

    #[test]
    fn domain_invalid_zero_steps_returns_none() {
        assert!(run(1.0, Domain::new(0.0, 100.0, 0), DefuzzMethod::Centroid).is_none());
    }

    #[test]
    fn domain_invalid_non_finite_returns_none() {
        assert!(run(1.0, Domain::new(f64::NAN, 100.0, 100), DefuzzMethod::Centroid).is_none());
    }

    #[test]
    fn unknown_output_variable_returns_none() {
        let output = make_output("out.mid", 0.8);
        let result = defuzzify_mamdani(
            &output,
            &sym_triangle_vars(),
            "nonexistent",
            Domain::new(0.0, 100.0, 100),
            DefuzzMethod::Centroid,
        );
        assert!(result.is_none());
    }

    #[test]
    fn zero_strength_consequent_returns_none() {
        assert!(run(0.0, Domain::new(0.0, 100.0, 100), DefuzzMethod::Centroid).is_none());
    }

    #[test]
    fn centroid_symmetric_triangle_returns_center() {
        assert!((sym_result(DefuzzMethod::Centroid) - 50.0).abs() < 1.0);
    }

    #[test]
    fn bisector_symmetric_triangle_returns_center() {
        assert!((sym_result(DefuzzMethod::Bisector) - 50.0).abs() < 1.0);
    }

    #[test]
    fn mean_of_maxima_symmetric_triangle_returns_center() {
        assert!((sym_result(DefuzzMethod::MeanOfMaxima) - 50.0).abs() < 1.0);
    }

    #[test]
    fn height_symmetric_triangle_returns_center() {
        assert!((sym_result(DefuzzMethod::Height) - 50.0).abs() < 1.0);
    }

    #[test]
    fn centroid_clipped_symmetric_triangle_returns_center() {
        assert!((clipped_result(DefuzzMethod::Centroid) - 50.0).abs() < 1.0);
    }

    #[test]
    fn bisector_clipped_symmetric_triangle_returns_center() {
        assert!((clipped_result(DefuzzMethod::Bisector) - 50.0).abs() < 1.0);
    }

    #[test]
    fn mean_of_maxima_clipped_plateau_returns_center() {
        assert!((clipped_result(DefuzzMethod::MeanOfMaxima) - 50.0).abs() < 1.0);
    }

    #[test]
    fn weighted_average_single_rule() {
        assert!((weighted_average(&[(1.0, 42.0)]).unwrap() - 42.0).abs() < 1e-10);
    }

    #[test]
    fn weighted_average_two_equal_rules() {
        assert!((weighted_average(&[(0.5, 10.0), (0.5, 20.0)]).unwrap() - 15.0).abs() < 1e-10);
    }

    #[test]
    fn weighted_average_empty_returns_none() {
        assert!(weighted_average(&[]).is_none());
    }

    #[test]
    fn weighted_average_zero_den_returns_none() {
        assert!(weighted_average(&[(0.0, 10.0)]).is_none());
    }

    #[test]
    fn weighted_average_non_finite_inputs_return_none() {
        assert!(weighted_average(&[(f64::NAN, 10.0)]).is_none());
        assert!(weighted_average(&[(1.0, f64::INFINITY)]).is_none());
    }
}