use converge_pack::gate::{ObjectiveSpec, ProblemSpec};
use prism::fuzzy::{
ActivatedRule, DefuzzMethod, Domain, FuzzyInferenceEngine, FuzzyInferenceInput,
FuzzyInferenceOutput, FuzzySet, LinguisticVariable, MembershipFunction, SugenoInferenceEngine,
SugenoInferenceInput, TsukamotoInferenceEngine, TsukamotoInferenceInput, defuzzify_mamdani,
weighted_average,
};
use prism::packs::anomaly_detection::{AnomalyDetectionInput, ZScoreSolver};
use prism::packs::classification::{ClassificationInput, LogisticClassifier};
use prism::packs::descriptive_stats::{DescriptiveStatsInput, DescriptiveStatsSolver};
use prism::packs::forecasting::{ExponentialSmoothingSolver, ForecastingInput};
use prism::packs::naive_bayes::{ClassDef, GaussianNaiveBayes, GaussianParams, NaiveBayesInput};
use prism::packs::ranking::{RankItem, RankingInput, WeightedScoringSolver};
use prism::packs::regression::{LinearRegressionSolver, RegressionInput};
use prism::packs::segmentation::{KMeansSolver, SegmentationInput};
use prism::packs::similarity::{
DistanceMetric, PairwiseSimilaritySolver, SimilarityInput, SimilarityItem,
};
use prism::packs::trend_detection::{
MovingAverageTrendSolver, TrendDetectionInput, TrendDirection,
};
/// Shared `ProblemSpec` used by every reference test: built from `"ref-test"` / `"test"`
/// with a single objective that maximises `"accuracy"`. Panics if the builder rejects it.
fn spec() -> ProblemSpec {
    let builder = ProblemSpec::builder("ref-test", "test");
    builder
        .objective(ObjectiveSpec::maximize("accuracy"))
        .build()
        .unwrap()
}
/// Nine 10s plus a single 100.
///
/// mean = 190 / 10 = 19; population σ = √((9·81 + 81²) / 10) = √729 = 27.
/// Only the last value (z = 81 / 27 = 3) clears the 2.0 threshold.
#[test]
fn zscore_hand_computed() {
    let mut values = vec![10.0; 9];
    values.push(100.0);
    let input = AnomalyDetectionInput {
        values,
        threshold: 2.0,
        labels: None,
    };
    let (output, _) = ZScoreSolver.solve(&input, &spec()).unwrap();
    let tol = 1e-9;

    let mean = output.mean;
    assert!((mean - 19.0).abs() < tol, "mean should be 19.0, got {}", mean);
    let sd = output.std_dev;
    assert!((sd - 27.0).abs() < tol, "stddev should be 27.0, got {}", sd);

    assert_eq!(output.anomaly_count, 1, "only 100.0 is an anomaly");
    let outlier = &output.anomalies[0];
    assert_eq!(outlier.index, 9);
    assert!((outlier.z_score - 3.0).abs() < tol);
}
/// Textbook sample [2, 4, 4, 4, 5, 5, 7, 9].
///
/// mean 40 / 8 = 5; median (4 + 5) / 2 = 4.5; population variance 32 / 8 = 4, so σ = 2;
/// min 2, max 9, range 7.
#[test]
fn descriptive_stats_hand_computed() {
    let sample = vec![2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
    let input = DescriptiveStatsInput {
        values: sample,
        percentiles: vec![25.0, 50.0, 75.0],
    };
    let (stats, _) = DescriptiveStatsSolver.solve(&input, &spec()).unwrap();
    let tol = 1e-9;

    assert_eq!(stats.count, 8);
    assert!((stats.mean - 5.0).abs() < tol, "mean = 5.0, got {}", stats.mean);
    assert!(
        (stats.median - 4.5).abs() < tol,
        "median = 4.5, got {}",
        stats.median
    );
    assert!(
        (stats.variance - 4.0).abs() < tol,
        "variance = 4.0, got {}",
        stats.variance
    );
    assert!(
        (stats.std_dev - 2.0).abs() < tol,
        "stddev = 2.0, got {}",
        stats.std_dev
    );
    assert!((stats.min - 2.0).abs() < tol);
    assert!((stats.max - 9.0).abs() < tol);
    assert!((stats.range - 7.0).abs() < tol);
}
/// Weights [2, 3] with bias 1.
///
/// [1, 1] → 2 + 3 + 1 = 6; [2, 0] → 4 + 1 = 5; [0, 3] → 9 + 1 = 10; mean prediction 7.
#[test]
fn linear_regression_exact() {
    let input = RegressionInput {
        records: vec![vec![1.0, 1.0], vec![2.0, 0.0], vec![0.0, 3.0]],
        weights: vec![2.0, 3.0],
        bias: 1.0,
    };
    let (output, _) = LinearRegressionSolver.solve(&input, &spec()).unwrap();
    assert_eq!(output.total, 3);

    let tol = 1e-9;
    for (row, &want) in [6.0, 5.0, 10.0].iter().enumerate() {
        assert!((output.predictions[row].value - want).abs() < tol);
    }
    assert!(
        (output.mean_prediction - 7.0).abs() < tol,
        "mean = (6+5+10)/3 = 7.0"
    );
}
/// Logistic output is sigmoid(w·x + b).
///
/// With w = [3, 0] and b = −1.5, record [1, 0] gives z = 1.5 and record [0, 0] gives
/// z = −1.5. At threshold 0.5, one record is positive and one is negative.
#[test]
fn logistic_classification_sigmoid() {
    let input = ClassificationInput {
        records: vec![vec![1.0, 0.0], vec![0.0, 0.0]],
        weights: vec![3.0, 0.0],
        bias: -1.5,
        threshold: 0.5,
        labels: None,
    };
    let (output, _) = LogisticClassifier.solve(&input, &spec()).unwrap();

    let sigmoid = |z: f64| 1.0 / (1.0 + (-z).exp());
    let expected_p1 = sigmoid(1.5);
    let expected_p0 = sigmoid(-1.5);

    assert_eq!(output.positive_count, 1);
    assert_eq!(output.negative_count, 1);

    let first = output.predictions[0].probability;
    assert!(
        (first - expected_p1).abs() < 1e-6,
        "sigmoid(1.5) ≈ {expected_p1}, got {}",
        first
    );
    let second = output.predictions[1].probability;
    assert!(
        (second - expected_p0).abs() < 1e-6,
        "sigmoid(-1.5) ≈ {expected_p0}, got {}",
        second
    );
}
/// Mamdani rule with an AND antecedent, evaluated by hand.
///
/// authenticity 0.84 on a right shoulder 0.6 → 0.9 gives (0.84 − 0.6) / 0.3 = 0.8.
/// novelty 0.25 on a left shoulder 0.2 → 0.6 gives (0.6 − 0.25) / 0.4 = 0.875.
/// The rule fires at 0.8 (the smaller of the two), so `satisfaction.high` and the overall
/// confidence are both 0.8.
#[test]
fn fuzzy_inference_hand_computed() {
    let raw = serde_json::json!({
        "inputs": {
            "authenticity": 0.84,
            "novelty": 0.25
        },
        "variables": [
            {
                "name": "authenticity",
                "sets": [
                    {"name": "high", "function": {"kind": "right_shoulder", "start": 0.6, "end": 0.9}}
                ]
            },
            {
                "name": "novelty",
                "sets": [
                    {"name": "low", "function": {"kind": "left_shoulder", "start": 0.2, "end": 0.6}}
                ]
            },
            {
                "name": "satisfaction",
                "sets": [
                    {"name": "high", "function": {"kind": "right_shoulder", "start": 0.6, "end": 0.9}}
                ]
            }
        ],
        "rules": [
            {
                "id": "authentic-simple-fit",
                "if": {
                    "op": "and",
                    "terms": [
                        {"op": "is", "variable": "authenticity", "set": "high"},
                        {"op": "is", "variable": "novelty", "set": "low"}
                    ]
                },
                "then": {"variable": "satisfaction", "set": "high"}
            }
        ]
    });
    let input: FuzzyInferenceInput = serde_json::from_value(raw).unwrap();
    let (output, _) = FuzzyInferenceEngine.solve(&input, &spec()).unwrap();
    let tol = 1e-9;

    let authenticity_high = output.input_memberships["authenticity"]["high"];
    assert!((authenticity_high - 0.8).abs() < tol);
    let novelty_low = output.input_memberships["novelty"]["low"];
    assert!((novelty_low - 0.875).abs() < tol);
    let satisfaction_high = output.memberships["satisfaction.high"];
    assert!((satisfaction_high - 0.8).abs() < tol);
    assert_eq!(output.activated_rules[0].id, "authentic-simple-fit");
    assert!((output.confidence - 0.8).abs() < tol);
}
/// Piecewise-linear memberships at hand-picked points.
///
/// Triangle 40–60–80 gives 0.5 at 50, 1 at the peak and 0.5 at 70. The right shoulder
/// 60 → 80 gives 0.5 midway and saturates at 1 past its end.
#[test]
fn fuzzy_membership_functions_are_hand_computable() {
    let warm = MembershipFunction::Triangular {
        min: 40.0,
        peak: 60.0,
        max: 80.0,
    };
    let hot = MembershipFunction::RightShoulder {
        start: 60.0,
        end: 80.0,
    };
    let cases = [
        (&warm, 50.0, 0.5),
        (&warm, 60.0, 1.0),
        (&warm, 70.0, 0.5),
        (&hot, 70.0, 0.5),
        (&hot, 85.0, 1.0),
    ];
    for (membership, x, degree) in cases {
        assert!((membership.evaluate(x) - degree).abs() < 1e-9);
    }
}
/// Gaussian membership `exp(-(x - c)² / (2σ²))` at hand-picked points.
///
/// The membership is 1 at the centre, `e^(-1/2)` one σ away on either side, and `e^(-2)`
/// two σ away. The 2σ probes matter: at exactly 1σ a linear exponent `exp(-|x - c| / (2σ))`
/// would produce the same `e^(-1/2)`. Only a point at another distance pins the quadratic.
#[test]
fn fuzzy_gaussian_membership_is_hand_computable() {
    // Expected membership exactly one σ from the centre: exp(-1/2) ≈ 0.6065.
    let at_one_sigma = (-0.5_f64).exp();
    // Expected membership exactly two σ from the centre: exp(-4/2) = exp(-2) ≈ 0.1353.
    let at_two_sigma = (-2.0_f64).exp();

    let g = MembershipFunction::Gaussian {
        center: 0.0,
        sigma: 1.0,
    };
    assert!((g.evaluate(0.0) - 1.0).abs() < 1e-9);
    assert!((g.evaluate(1.0) - at_one_sigma).abs() < 1e-9);
    assert!((g.evaluate(-1.0) - at_one_sigma).abs() < 1e-9);
    assert!((g.evaluate(2.0) - at_two_sigma).abs() < 1e-9);

    let g2 = MembershipFunction::Gaussian {
        center: 5.0,
        sigma: 2.0,
    };
    assert!((g2.evaluate(5.0) - 1.0).abs() < 1e-9);
    assert!((g2.evaluate(7.0) - at_one_sigma).abs() < 1e-9);
    assert!((g2.evaluate(3.0) - at_one_sigma).abs() < 1e-9);
    assert!((g2.evaluate(9.0) - at_two_sigma).abs() < 1e-9);
}
/// The weights here sum to 1, so the expected value is 0.6·10 + 0.4·20 = 14.
/// An empty slice and all-zero weights must both yield `None`.
#[test]
fn weighted_average_hand_computed() {
    let pairs = [(0.6, 10.0), (0.4, 20.0)];
    let blended = weighted_average(&pairs);
    assert!((blended.unwrap() - 14.0).abs() < 1e-9);

    assert!(weighted_average(&[]).is_none());

    let zero_weights = [(0.0, 5.0), (0.0, 10.0)];
    assert!(weighted_average(&zero_weights).is_none());
}
/// Fixture for the Mamdani defuzzification tests.
///
/// Output variable `y` has one symmetric triangle set `mid` (0–5–10). The inference
/// result has a single rule, `rule-1`, that fired `y.mid` at full strength (1.0).
fn symmetric_triangle_output() -> (FuzzyInferenceOutput, Vec<LinguisticVariable>) {
    let mid = FuzzySet {
        name: "mid".into(),
        function: MembershipFunction::Triangular {
            min: 0.0,
            peak: 5.0,
            max: 10.0,
        },
    };
    let variables = vec![LinguisticVariable {
        name: "y".into(),
        sets: vec![mid],
    }];

    let fired = ActivatedRule {
        id: "rule-1".into(),
        antecedent_strength: 1.0,
        weight: 1.0,
        strength: 1.0,
        consequent: "y.mid".into(),
    };
    let output = FuzzyInferenceOutput {
        input_memberships: std::collections::BTreeMap::new(),
        memberships: std::collections::BTreeMap::from([("y.mid".to_string(), 1.0)]),
        activated_rules: vec![fired],
        confidence: 1.0,
        total_rules: 1,
    };
    (output, variables)
}
/// A symmetric triangle's centroid lies on its peak: ≈ 5.0 over [0, 10] with 1000 steps.
#[test]
fn defuzzify_mamdani_centroid_symmetric_triangle() {
    let (fuzzy_output, vars) = symmetric_triangle_output();
    let domain = Domain {
        min: 0.0,
        max: 10.0,
        steps: 1000,
    };
    let centroid = defuzzify_mamdani(&fuzzy_output, &vars, "y", domain, DefuzzMethod::Centroid)
        .expect("centroid defined");
    assert!(
        (centroid - 5.0).abs() < 1e-6,
        "centroid ≈ 5.0, got {centroid}"
    );
}
/// Mean-of-maxima on a symmetric triangle also lands on the peak, 5.0.
#[test]
fn defuzzify_mamdani_mom_symmetric_triangle() {
    let (fuzzy_output, vars) = symmetric_triangle_output();
    let domain = Domain {
        min: 0.0,
        max: 10.0,
        steps: 1000,
    };
    let mom = defuzzify_mamdani(&fuzzy_output, &vars, "y", domain, DefuzzMethod::MeanOfMaxima);
    assert!((mom.unwrap() - 5.0).abs() < 1e-9);
}
/// With no activated rules there is no aggregated output set, so centroid defuzzification
/// yields `None`.
#[test]
fn defuzzify_mamdani_no_rules_fired_returns_none() {
    let triangle = MembershipFunction::Triangular {
        min: 0.0,
        peak: 5.0,
        max: 10.0,
    };
    let variables = vec![LinguisticVariable {
        name: "y".into(),
        sets: vec![FuzzySet {
            name: "mid".into(),
            function: triangle,
        }],
    }];
    let idle = FuzzyInferenceOutput {
        input_memberships: Default::default(),
        memberships: Default::default(),
        activated_rules: Vec::new(),
        confidence: 0.0,
        total_rules: 0,
    };
    let domain = Domain {
        min: 0.0,
        max: 10.0,
        steps: 100,
    };
    let outcome = defuzzify_mamdani(&idle, &variables, "y", domain, DefuzzMethod::Centroid);
    assert!(outcome.is_none());
}
/// Sugeno inference with two constant consequents.
///
/// x = 5 sits halfway along both shoulders (2 → 8), so `low` and `high` each fire at 0.5.
/// Output is (0.5·0 + 0.5·10) / (0.5 + 0.5) = 5 and confidence is 0.5.
#[test]
fn sugeno_constant_consequents_hand_computed() {
    let raw = serde_json::json!({
        "inputs": { "x": 5.0 },
        "variables": [
            {
                "name": "x",
                "sets": [
                    {"name": "low", "function": {"kind": "left_shoulder", "start": 2.0, "end": 8.0}},
                    {"name": "high", "function": {"kind": "right_shoulder", "start": 2.0, "end": 8.0}}
                ]
            }
        ],
        "rules": [
            {"id": "low-rule",
             "if": {"op": "is", "variable": "x", "set": "low"},
             "then": {"kind": "constant", "value": 0.0}},
            {"id": "high-rule",
             "if": {"op": "is", "variable": "x", "set": "high"},
             "then": {"kind": "constant", "value": 10.0}}
        ]
    });
    let input: SugenoInferenceInput = serde_json::from_value(raw).unwrap();
    let (output, _) = SugenoInferenceEngine.solve(&input, &spec()).unwrap();
    let tol = 1e-9;

    let x_sets = &output.input_memberships["x"];
    assert!((x_sets["low"] - 0.5).abs() < tol);
    assert!((x_sets["high"] - 0.5).abs() < tol);
    assert_eq!(output.activated_rules.len(), 2);
    let crisp = output.output.unwrap();
    assert!((crisp - 5.0).abs() < tol);
    assert!((output.confidence - 0.5).abs() < tol);
}
/// Sugeno inference with a linear consequent.
///
/// x = 4 is the peak of the 0–4–8 triangle (μ = 1), so the single rule fires fully. Its
/// linear consequent 2 + 0.5·x = 4 becomes the output.
#[test]
fn sugeno_linear_consequent_hand_computed() {
    let raw = serde_json::json!({
        "inputs": { "x": 4.0 },
        "variables": [
            {
                "name": "x",
                "sets": [
                    {"name": "mid", "function": {"kind": "triangular", "min": 0.0, "peak": 4.0, "max": 8.0}}
                ]
            }
        ],
        "rules": [
            {
                "id": "linear-rule",
                "if": {"op": "is", "variable": "x", "set": "mid"},
                "then": {"kind": "linear", "intercept": 2.0, "coefficients": {"x": 0.5}}
            }
        ]
    });
    let input: SugenoInferenceInput = serde_json::from_value(raw).unwrap();
    let (output, _) = SugenoInferenceEngine.solve(&input, &spec()).unwrap();
    let tol = 1e-9;

    let mid = output.input_memberships["x"]["mid"];
    assert!((mid - 1.0).abs() < tol);
    assert_eq!(output.activated_rules.len(), 1);
    let crisp = output.output.unwrap();
    assert!((crisp - 4.0).abs() < tol);
    assert!((output.confidence - 1.0).abs() < tol);

    let only_rule = &output.activated_rules[0];
    assert!((only_rule.consequent_value - 4.0).abs() < tol);
}
/// Shoulders are monotonic. Triangles, trapezoids and Gaussians rise and then fall,
/// so they are not.
#[test]
fn fuzzy_membership_monotonic_classification() {
    let monotonic = [
        MembershipFunction::LeftShoulder {
            start: 0.0,
            end: 1.0,
        },
        MembershipFunction::RightShoulder {
            start: 0.0,
            end: 1.0,
        },
    ];
    let non_monotonic = [
        MembershipFunction::Triangular {
            min: 0.0,
            peak: 0.5,
            max: 1.0,
        },
        MembershipFunction::Trapezoidal {
            min: 0.0,
            lower_peak: 0.3,
            upper_peak: 0.7,
            max: 1.0,
        },
        MembershipFunction::Gaussian {
            center: 0.0,
            sigma: 1.0,
        },
    ];
    for shape in &monotonic {
        assert!(shape.is_monotonic());
    }
    for shape in &non_monotonic {
        assert!(!shape.is_monotonic());
    }
}
/// Inverses of monotonic shoulders.
///
/// Shoulder inverses map a degree back to x on [0, 10]. Non-monotonic shapes have no
/// inverse and return an error. Degrees outside [0, 1] and NaN are also rejected.
#[test]
fn fuzzy_membership_inverse_for_shoulders() {
    let left = MembershipFunction::LeftShoulder {
        start: 0.0,
        end: 10.0,
    };
    let right = MembershipFunction::RightShoulder {
        start: 0.0,
        end: 10.0,
    };
    // (shape, membership degree, expected x)
    let cases = [
        (&left, 1.0, 0.0),
        (&left, 0.0, 10.0),
        (&left, 0.5, 5.0),
        (&right, 0.0, 0.0),
        (&right, 1.0, 10.0),
        (&right, 0.5, 5.0),
    ];
    for (shape, degree, x) in cases {
        assert!((shape.inverse(degree).unwrap() - x).abs() < 1e-9);
    }

    let triangle = MembershipFunction::Triangular {
        min: 0.0,
        peak: 5.0,
        max: 10.0,
    };
    assert!(triangle.inverse(0.5).is_err());
    let gaussian = MembershipFunction::Gaussian {
        center: 0.0,
        sigma: 1.0,
    };
    assert!(gaussian.inverse(0.5).is_err());

    for invalid in [1.5, -0.1, f64::NAN] {
        assert!(left.inverse(invalid).is_err());
    }
}
/// Tsukamoto inference with monotonic (shoulder) consequents.
///
/// x = 3 on shoulders spanning 0 → 10 gives μ_low = 0.7 and μ_high = 0.3. Each rule's
/// consequent is its output shoulder inverted at the firing strength:
/// small_y⁻¹(0.7) = 100 − 70 = 30 and big_y⁻¹(0.3) = 50 + 0.3·200 = 110.
/// Output is (0.7·30 + 0.3·110) / (0.7 + 0.3) = 54; confidence is 0.7.
#[test]
fn tsukamoto_hand_computed() {
    let raw = serde_json::json!({
        "inputs": { "x": 3.0 },
        "variables": [
            {
                "name": "x",
                "sets": [
                    {"name": "low", "function": {"kind": "left_shoulder", "start": 0.0, "end": 10.0}},
                    {"name": "high", "function": {"kind": "right_shoulder", "start": 0.0, "end": 10.0}}
                ]
            },
            {
                "name": "y",
                "sets": [
                    {"name": "small_y", "function": {"kind": "left_shoulder", "start": 0.0, "end": 100.0}},
                    {"name": "big_y", "function": {"kind": "right_shoulder", "start": 50.0, "end": 250.0}}
                ]
            }
        ],
        "rules": [
            {"id": "low-rule",
             "if": {"op": "is", "variable": "x", "set": "low"},
             "then": {"variable": "y", "set": "small_y"}},
            {"id": "high-rule",
             "if": {"op": "is", "variable": "x", "set": "high"},
             "then": {"variable": "y", "set": "big_y"}}
        ]
    });
    let input: TsukamotoInferenceInput = serde_json::from_value(raw).unwrap();
    let (output, _) = TsukamotoInferenceEngine.solve(&input, &spec()).unwrap();
    let tol = 1e-9;

    let x_sets = &output.input_memberships["x"];
    assert!((x_sets["low"] - 0.7).abs() < tol);
    assert!((x_sets["high"] - 0.3).abs() < tol);
    assert_eq!(output.activated_rules.len(), 2);

    let rule = |id: &str| {
        output
            .activated_rules
            .iter()
            .find(|r| r.id == id)
            .unwrap()
    };
    let low = rule("low-rule");
    let high = rule("high-rule");

    assert!((low.firing_strength - 0.7).abs() < tol);
    assert!((low.consequent_value - 30.0).abs() < tol);
    assert!((high.firing_strength - 0.3).abs() < tol);
    assert!((high.consequent_value - 110.0).abs() < tol);
    assert!((output.output.unwrap() - 54.0).abs() < tol);
    assert!((output.confidence - 0.7).abs() < tol);
}
/// Tsukamoto needs invertible (monotonic) consequent sets. A triangular consequent must
/// make `validate()` fail.
#[test]
fn tsukamoto_rejects_non_monotonic_consequent() {
    let raw = serde_json::json!({
        "inputs": { "x": 0.5 },
        "variables": [
            {
                "name": "x",
                "sets": [
                    {"name": "high", "function": {"kind": "right_shoulder", "start": 0.0, "end": 1.0}}
                ]
            },
            {
                "name": "y",
                "sets": [
                    {"name": "mid", "function": {"kind": "triangular", "min": 0.0, "peak": 0.5, "max": 1.0}}
                ]
            }
        ],
        "rules": [
            {
                "if": {"op": "is", "variable": "x", "set": "high"},
                "then": {"variable": "y", "set": "mid"}
            }
        ]
    });
    let input: TsukamotoInferenceInput = serde_json::from_value(raw).unwrap();
    let verdict = input.validate();
    assert!(verdict.is_err());
}
/// x = 0.5 lies below the start (5.0) of the only shoulder, so no rule fires. There is
/// no crisp output, no activated rule and zero confidence.
#[test]
fn sugeno_no_rules_fired_returns_none() {
    let raw = serde_json::json!({
        "inputs": { "x": 0.5 },
        "variables": [
            {
                "name": "x",
                "sets": [
                    {"name": "high", "function": {"kind": "right_shoulder", "start": 5.0, "end": 8.0}}
                ]
            }
        ],
        "rules": [
            {"if": {"op": "is", "variable": "x", "set": "high"},
             "then": {"kind": "constant", "value": 10.0}}
        ]
    });
    let input: SugenoInferenceInput = serde_json::from_value(raw).unwrap();
    let (output, _) = SugenoInferenceEngine.solve(&input, &spec()).unwrap();

    let crisp = output.output;
    assert!(crisp.is_none());
    assert_eq!(output.activated_rules.len(), 0);
    assert_eq!(output.confidence, 0.0);
}
/// Cosine similarity on axis-aligned vectors.
///
/// A and B are identical (cos = 1); A and C are orthogonal (cos = 0). Three items yield
/// three unordered pairs.
#[test]
fn cosine_similarity_hand_computed() {
    let items = vec![
        ("A", vec![1.0, 0.0, 0.0]),
        ("B", vec![1.0, 0.0, 0.0]),
        ("C", vec![0.0, 1.0, 0.0]),
    ]
    .into_iter()
    .map(|(id, features)| SimilarityItem {
        id: id.to_string(),
        features,
    })
    .collect();
    let input = SimilarityInput {
        items,
        metric: DistanceMetric::Cosine,
        top_k: None,
    };
    let (output, _) = PairwiseSimilaritySolver.solve(&input, &spec()).unwrap();
    assert_eq!(output.total_pairs, 3);

    // Pairs are unordered, so accept either orientation.
    let pair = |x: &str, y: &str| {
        output
            .pairs
            .iter()
            .find(|p| (p.id_a == x && p.id_b == y) || (p.id_a == y && p.id_b == x))
    };

    let ab = pair("A", "B").expect("A-B pair must exist");
    assert!(
        (ab.score - 1.0).abs() < 1e-9,
        "identical vectors → cos = 1.0, got {}",
        ab.score
    );
    let ac = pair("A", "C").expect("A-C pair must exist");
    assert!(
        ac.score.abs() < 1e-9,
        "orthogonal vectors → cos = 0.0, got {}",
        ac.score
    );
}
/// Vectors (1, 1) and (1, 0) are 45° apart, so cos = 1 / √2.
///
/// With two items there is exactly one pair. Asserting `total_pairs == 1` (as the other
/// similarity tests do) guarantees that `pairs[0]` is the A–B pair rather than an
/// unchecked assumption.
#[test]
fn cosine_similarity_45_degrees() {
    let input = SimilarityInput {
        items: vec![
            SimilarityItem {
                id: "A".to_string(),
                features: vec![1.0, 1.0],
            },
            SimilarityItem {
                id: "B".to_string(),
                features: vec![1.0, 0.0],
            },
        ],
        metric: DistanceMetric::Cosine,
        top_k: None,
    };
    let (output, _) = PairwiseSimilaritySolver.solve(&input, &spec()).unwrap();
    assert_eq!(output.total_pairs, 1, "two items form exactly one pair");
    let expected = 1.0 / 2.0_f64.sqrt();
    assert!(
        (output.pairs[0].score - expected).abs() < 1e-6,
        "cos(45°) = 1/√2 ≈ {expected}, got {}",
        output.pairs[0].score
    );
}
/// Simple exponential smoothing, α = 0.5, hand trace (level seeded with the first value).
///
/// s₀ = 100; s₁ = 0.5·110 + 0.5·100 = 105; s₂ = 0.5·120 + 0.5·105 = 112.5.
/// The one-step forecast equals the final level, 112.5.
#[test]
fn exponential_smoothing_hand_traced() {
    let input = ForecastingInput {
        values: vec![100.0, 110.0, 120.0],
        horizon: 1,
        alpha: 0.5,
    };
    let (output, _) = ExponentialSmoothingSolver.solve(&input, &spec()).unwrap();
    assert_eq!(output.predictions.len(), 1);

    let forecast = output.predictions[0].value;
    assert!(
        (forecast - 112.5).abs() < 1e-6,
        "SES forecast = 112.5, got {}",
        forecast
    );
}
/// Two tight, well-separated triangles of points.
///
/// {(0,0), (1,0), (0,1)} and {(10,10), (11,10), (10,11)}: with k = 2 each triangle must form
/// its own cluster. There must be exactly two centroids, each near a triangle's mean
/// ((1/3, 1/3) or (31/3, 31/3)).
#[test]
fn kmeans_two_separated_clusters() {
    let input = SegmentationInput {
        records: vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![10.0, 10.0],
            vec![11.0, 10.0],
            vec![10.0, 11.0],
        ],
        k: 2,
        max_iterations: 100,
        seed: Some(42),
    };
    let (output, _) = KMeansSolver.solve(&input, &spec()).unwrap();
    assert_eq!(output.assignments[0], output.assignments[1]);
    assert_eq!(output.assignments[1], output.assignments[2]);
    assert_eq!(output.assignments[3], output.assignments[4]);
    assert_eq!(output.assignments[4], output.assignments[5]);
    assert_ne!(
        output.assignments[0], output.assignments[3],
        "the two clusters must be distinct"
    );
    // Without this, the proximity loop below would pass vacuously on an empty centroid list.
    assert_eq!(
        output.centroids.len(),
        2,
        "k = 2 must yield exactly two centroids"
    );
    for centroid in &output.centroids {
        let near_origin = centroid[0] < 2.0 && centroid[1] < 2.0;
        let near_ten = centroid[0] > 8.0 && centroid[1] > 8.0;
        assert!(
            near_origin || near_ten,
            "centroid {:?} should be near (0.33, 0.33) or (10.33, 10.33)",
            centroid
        );
    }
}
/// Weighted scoring with weights 0.7 / 0.3, both criteria higher-is-better.
///
/// With each criterion min–max scaled to [0, 1]:
/// A = (1, 0) → 0.7; C = (0.5, 0.5) → 0.5; B = (0, 1) → 0.3. Expected order is A, C, B.
#[test]
fn ranking_hand_computed() {
    let items = vec![
        ("A", vec![100.0, 10.0]),
        ("B", vec![50.0, 90.0]),
        ("C", vec![75.0, 50.0]),
    ]
    .into_iter()
    .map(|(id, scores)| RankItem {
        id: id.to_string(),
        scores,
    })
    .collect();
    let input = RankingInput {
        items,
        weights: vec![0.7, 0.3],
        higher_is_better: vec![true, true],
        top_k: None,
    };
    let (output, _) = WeightedScoringSolver.solve(&input, &spec()).unwrap();

    let expected = [("A", 0.70), ("C", 0.50), ("B", 0.30)];
    for (rank, &(id, _)) in expected.iter().enumerate() {
        assert_eq!(output.ranked[rank].id, id);
    }
    for (rank, &(_, score)) in expected.iter().enumerate() {
        assert!((output.ranked[rank].composite_score - score).abs() < 1e-6);
    }
}
/// Eighteen 50s plus one 150 and one −50.
///
/// mean = 1000 / 20 = 50; population variance = (100² + 100²) / 20 = 1000, so σ = √1000.
/// Both outliers have |z| = 100 / √1000 = √10 ≈ 3.162 > 3. Besides the anomaly count, the
/// test pins which indices are flagged (18 and 19) and their |z|. Otherwise a solver that
/// flagged two wrong values would still pass.
#[test]
fn zscore_20_values_multiple_anomalies() {
    let mut values = vec![50.0; 18];
    values.push(150.0);
    values.push(-50.0);
    let input = AnomalyDetectionInput {
        values,
        threshold: 3.0,
        labels: None,
    };
    let (output, _) = ZScoreSolver.solve(&input, &spec()).unwrap();
    assert!((output.mean - 50.0).abs() < 1e-9, "mean = 50.0");
    assert!(
        (output.std_dev - 1000.0_f64.sqrt()).abs() < 1e-6,
        "stddev = √1000 ≈ 31.623"
    );
    assert_eq!(
        output.anomaly_count, 2,
        "both 150 and -50 are anomalies at threshold 3.0"
    );

    // Sort so the check does not depend on the order the solver reports anomalies in.
    let mut flagged: Vec<_> = output.anomalies.iter().map(|a| a.index).collect();
    flagged.sort_unstable();
    assert_eq!(
        flagged,
        vec![18, 19],
        "the outliers sit at indices 18 and 19"
    );
    // |z| is used so the check holds whether or not the solver reports signed z-scores.
    for anomaly in &output.anomalies {
        assert!(
            (anomaly.z_score.abs() - 10.0_f64.sqrt()).abs() < 1e-6,
            "|z| = 100/√1000 = √10, got {}",
            anomaly.z_score
        );
    }
}
/// Descriptive statistics of 1, 2, …, 15.
///
/// mean = median = 8. Population variance = (n² − 1) / 12 = 224 / 12 = 280 / 15, so
/// σ = √(280 / 15). The σ check mirrors the one in `descriptive_stats_hand_computed`; the
/// original version of this test left `std_dev` unchecked.
#[test]
fn descriptive_stats_15_values() {
    let input = DescriptiveStatsInput {
        values: (1..=15).map(|x| x as f64).collect(),
        percentiles: vec![25.0, 50.0, 75.0],
    };
    let (output, _) = DescriptiveStatsSolver.solve(&input, &spec()).unwrap();
    assert_eq!(output.count, 15);
    assert!((output.mean - 8.0).abs() < 1e-9);
    assert!(
        (output.median - 8.0).abs() < 1e-9,
        "odd count → middle = 8.0"
    );
    assert!(
        (output.variance - 280.0 / 15.0).abs() < 1e-9,
        "variance = 280/15"
    );
    assert!(
        (output.std_dev - (280.0_f64 / 15.0).sqrt()).abs() < 1e-9,
        "stddev = √(280/15) ≈ 4.3205, got {}",
        output.std_dev
    );
    assert!((output.min - 1.0).abs() < 1e-9);
    assert!((output.max - 15.0).abs() < 1e-9);
    assert!((output.range - 14.0).abs() < 1e-9);
}
/// Five-feature dot products with weights [1, 2, 3, 4, 5] and bias 10.
///
/// Row 0: 1+2+3+4+5 + 10 = 25. Row 1: 2+3+5 + 10 = 20. Row 2: 6+8 + 10 = 24.
/// Row 3: 1+4+9 + 10 = 24. Mean prediction is 93 / 4 = 23.25.
#[test]
fn linear_regression_5d() {
    let input = RegressionInput {
        records: vec![
            vec![1.0, 1.0, 1.0, 1.0, 1.0],
            vec![2.0, 0.0, 1.0, 0.0, 1.0],
            vec![0.0, 3.0, 0.0, 2.0, 0.0],
            vec![1.0, 2.0, 3.0, 0.0, 0.0],
        ],
        weights: vec![1.0, 2.0, 3.0, 4.0, 5.0],
        bias: 10.0,
    };
    let (output, _) = LinearRegressionSolver.solve(&input, &spec()).unwrap();
    assert_eq!(output.total, 4);

    let tol = 1e-9;
    for (row, &want) in [25.0, 20.0, 24.0, 24.0].iter().enumerate() {
        assert!((output.predictions[row].value - want).abs() < tol);
    }
    assert!((output.mean_prediction - 23.25).abs() < tol);
}
/// Logistic classification around the decision boundary.
///
/// With w = 10 and b = −5, z = 10x − 5 and the boundary sits at x = 0.5 (z = 0).
/// x = 0.5, 0.7, 1.0 count as positive; x = 0.0, 0.3 count as negative. Since
/// z(0.7) = −z(0.3), their probabilities sum to 1, and x = 0.5 gives exactly 0.5.
#[test]
fn logistic_classification_boundary() {
    let input = ClassificationInput {
        records: vec![vec![0.5], vec![1.0], vec![0.0], vec![0.3], vec![0.7]],
        weights: vec![10.0],
        bias: -5.0,
        threshold: 0.5,
        labels: None,
    };
    let (output, _) = LogisticClassifier.solve(&input, &spec()).unwrap();
    assert_eq!(output.positive_count, 3, "x=0.5, 0.7, 1.0 are positive");
    assert_eq!(output.negative_count, 2, "x=0.0, 0.3 are negative");

    let prob = |record: usize| output.predictions[record].probability;
    let p_07 = prob(4);
    let p_03 = prob(3);
    assert!(
        (p_07 + p_03 - 1.0).abs() < 1e-9,
        "sigmoid symmetry: p(0.7) + p(0.3) = 1.0"
    );
    assert!((prob(0) - 0.5).abs() < 1e-9, "sigmoid(0) = 0.5 exactly");
}
/// Cosine similarity in five dimensions.
///
/// Items are A = e₁, B = e₂, C = e₁ + e₂ and D = (1, 1, 1, 1, 1). Four items yield six
/// pairs: A·B = 0; A·C = B·C = 1/√2; A·D = B·D = 1/√5; C·D = 2/(√2·√5) = 2/√10.
#[test]
fn cosine_similarity_5d() {
    let items = vec![
        ("A", vec![1.0, 0.0, 0.0, 0.0, 0.0]),
        ("B", vec![0.0, 1.0, 0.0, 0.0, 0.0]),
        ("C", vec![1.0, 1.0, 0.0, 0.0, 0.0]),
        ("D", vec![1.0, 1.0, 1.0, 1.0, 1.0]),
    ]
    .into_iter()
    .map(|(id, features)| SimilarityItem {
        id: id.to_string(),
        features,
    })
    .collect();
    let input = SimilarityInput {
        items,
        metric: DistanceMetric::Cosine,
        top_k: None,
    };
    let (output, _) = PairwiseSimilaritySolver.solve(&input, &spec()).unwrap();
    assert_eq!(output.total_pairs, 6);

    // Score of the unordered pair {a, b}; panics with the pair name if it is missing.
    let score = |a: &str, b: &str| -> f64 {
        let hit = output
            .pairs
            .iter()
            .find(|p| (p.id_a == a && p.id_b == b) || (p.id_a == b && p.id_b == a));
        hit.unwrap_or_else(|| panic!("pair {a}-{b} not found")).score
    };
    let inv_sqrt = |x: f64| 1.0 / x.sqrt();

    assert!(score("A", "B").abs() < 1e-9, "orthogonal → 0");
    assert!(
        (score("A", "C") - inv_sqrt(2.0)).abs() < 1e-6,
        "45° → 1/√2"
    );
    assert!(
        (score("A", "D") - inv_sqrt(5.0)).abs() < 1e-6,
        "A·D/norms → 1/√5"
    );
    assert!((score("B", "C") - inv_sqrt(2.0)).abs() < 1e-6);
    assert!((score("B", "D") - inv_sqrt(5.0)).abs() < 1e-6);
    assert!(
        (score("C", "D") - 2.0 / 10.0_f64.sqrt()).abs() < 1e-6,
        "C·D = 2/(√2·√5)"
    );
}
/// SES with α = 0.3 over eight values, hand trace (level seeded with the first value).
///
/// 100 → 106 → 101.2 → 109.84 → 109.888 → 105.4216 → 105.29512 → 108.206584.
/// The one-step forecast equals the final level.
#[test]
fn exponential_smoothing_8_values() {
    let input = ForecastingInput {
        values: vec![100.0, 120.0, 90.0, 130.0, 110.0, 95.0, 105.0, 115.0],
        horizon: 1,
        alpha: 0.3,
    };
    let (output, _) = ExponentialSmoothingSolver.solve(&input, &spec()).unwrap();
    assert_eq!(output.predictions.len(), 1);

    let forecast = output.predictions[0].value;
    assert!(
        (forecast - 108.206584).abs() < 1e-4,
        "8-step SES forecast ≈ 108.2066, got {}",
        forecast
    );
}
/// Four tight triples of points near (0,0,0), (100,0,0), (0,100,0) and (0,0,100).
///
/// The test tolerates unlucky initialisations. It only requires that at least one of the
/// seeds 0..20 groups each triple together and gives the four triples four distinct labels.
#[test]
fn kmeans_4_clusters_3d() {
    let input = SegmentationInput {
        records: vec![
            vec![0.0, 0.0, 0.0],
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![100.0, 0.0, 0.0],
            vec![101.0, 0.0, 0.0],
            vec![100.0, 1.0, 0.0],
            vec![0.0, 100.0, 0.0],
            vec![1.0, 100.0, 0.0],
            vec![0.0, 101.0, 0.0],
            vec![0.0, 0.0, 100.0],
            vec![1.0, 0.0, 100.0],
            vec![0.0, 1.0, 100.0],
        ],
        k: 4,
        max_iterations: 100,
        seed: None,
    };

    // `any` stops at the first seed that recovers the expected clustering.
    let found = (0..20).any(|seed| {
        let attempt = SegmentationInput {
            seed: Some(seed),
            ..input.clone()
        };
        let (output, _) = KMeansSolver.solve(&attempt, &spec()).unwrap();
        let labels_of = &output.assignments;
        let triple_starts = [0, 3, 6, 9];

        let triples_intact = triple_starts.iter().all(|&s| {
            labels_of[s] == labels_of[s + 1] && labels_of[s + 1] == labels_of[s + 2]
        });
        let distinct: std::collections::HashSet<_> =
            triple_starts.iter().map(|&s| labels_of[s]).collect();

        triples_intact && distinct.len() == 4
    });

    assert!(
        found,
        "k-means must find 4 clusters in at least one of 20 seed attempts"
    );
}
/// Weighted scoring (0.5 / 0.3 / 0.2) where the second criterion is lower-is-better.
///
/// With each criterion min–max scaled (the second one inverted):
/// A = 0.375 + 0.225 + 0.1 = 0.700; E = 0.5 + 0 + 0.15 = 0.650;
/// C = 0.25 + 0.075 + 0.2 = 0.525; B = 0.125 + 0.3 + 0 = 0.425; D = 0 + 0.15 + 0.05 = 0.200.
#[test]
fn ranking_5_items_mixed_directions() {
    let items = vec![
        ("A", vec![90.0, 20.0, 80.0]),
        ("B", vec![70.0, 10.0, 60.0]),
        ("C", vec![80.0, 40.0, 100.0]),
        ("D", vec![60.0, 30.0, 70.0]),
        ("E", vec![100.0, 50.0, 90.0]),
    ]
    .into_iter()
    .map(|(id, scores)| RankItem {
        id: id.to_string(),
        scores,
    })
    .collect();
    let input = RankingInput {
        items,
        weights: vec![0.5, 0.3, 0.2],
        higher_is_better: vec![true, false, true],
        top_k: None,
    };
    let (output, _) = WeightedScoringSolver.solve(&input, &spec()).unwrap();

    let expected = [
        ("A", 0.700),
        ("E", 0.650),
        ("C", 0.525),
        ("B", 0.425),
        ("D", 0.200),
    ];
    for (rank, &(id, _)) in expected.iter().enumerate() {
        assert_eq!(output.ranked[rank].id, id);
    }
    for (rank, &(_, score)) in expected.iter().enumerate() {
        assert!((output.ranked[rank].composite_score - score).abs() < 1e-6);
    }
}
/// 10, 20, …, 50 rises by exactly 10 per step.
///
/// Expect one Rising segment over indices 0..=4 with slope 10, an overall slope of 10 and
/// no changepoints.
#[test]
fn trend_pure_rising_hand_computed() {
    let input = TrendDetectionInput {
        values: vec![10.0, 20.0, 30.0, 40.0, 50.0],
        window: 3,
        sensitivity: 0.5,
    };
    let (output, _) = MovingAverageTrendSolver.solve(&input, &spec()).unwrap();
    assert_eq!(output.overall_direction, TrendDirection::Rising);
    assert!(
        (output.overall_slope - 10.0).abs() < 1e-9,
        "overall slope = 10.0"
    );
    assert_eq!(output.segments.len(), 1);

    let only = &output.segments[0];
    assert_eq!(only.direction, TrendDirection::Rising);
    assert!((only.slope - 10.0).abs() < 1e-9);
    assert_eq!(only.start, 0);
    assert_eq!(only.end, 4);
    assert!(output.changepoints.is_empty());
}
/// Mirror image of `trend_pure_rising_hand_computed`: 50, 40, …, 10 falls by exactly 10 per
/// step.
///
/// Expect one Falling segment over indices 0..=4 with slope −10, an overall slope of −10 and
/// no changepoints. The segment-boundary assertions match the rising test, which the
/// original falling test omitted.
#[test]
fn trend_pure_falling_hand_computed() {
    let input = TrendDetectionInput {
        values: vec![50.0, 40.0, 30.0, 20.0, 10.0],
        window: 3,
        sensitivity: 0.5,
    };
    let (output, _) = MovingAverageTrendSolver.solve(&input, &spec()).unwrap();
    assert_eq!(output.overall_direction, TrendDirection::Falling);
    assert!((output.overall_slope - (-10.0)).abs() < 1e-9);
    assert_eq!(output.segments.len(), 1);
    assert_eq!(output.segments[0].direction, TrendDirection::Falling);
    assert!((output.segments[0].slope - (-10.0)).abs() < 1e-9);
    assert_eq!(output.segments[0].start, 0);
    assert_eq!(output.segments[0].end, 4);
    assert!(output.changepoints.is_empty());
}
/// A constant series has zero slope.
///
/// Expect one Stable segment, overall direction Stable and no changepoints.
#[test]
fn trend_stable_hand_computed() {
    let input = TrendDetectionInput {
        values: vec![5.0; 5],
        window: 3,
        sensitivity: 1.0,
    };
    let (output, _) = MovingAverageTrendSolver.solve(&input, &spec()).unwrap();
    assert_eq!(output.overall_direction, TrendDirection::Stable);
    assert!(output.overall_slope.abs() < 1e-9);
    assert_eq!(output.segments.len(), 1);

    let only = &output.segments[0];
    assert_eq!(only.direction, TrendDirection::Stable);
    assert!(output.changepoints.is_empty());
}
/// 10 → 30 → 10 is symmetric, so the overall slope is 0 (reported as Stable).
///
/// With window 2 the solver must still split the series into a +10 rise over indices 0..=3
/// and a −10 fall over 3..=4. There is a single changepoint at index 3, with magnitude 20.
#[test]
fn trend_rise_then_fall_hand_computed() {
    let input = TrendDetectionInput {
        values: vec![10.0, 20.0, 30.0, 20.0, 10.0],
        window: 2,
        sensitivity: 0.9,
    };
    let (output, _) = MovingAverageTrendSolver.solve(&input, &spec()).unwrap();
    assert_eq!(output.overall_direction, TrendDirection::Stable);
    assert!(
        output.overall_slope.abs() < 1e-9,
        "symmetric series → slope = 0"
    );
    assert_eq!(output.segments.len(), 2, "one rising + one falling segment");

    let rise = &output.segments[0];
    assert_eq!(rise.direction, TrendDirection::Rising);
    assert!((rise.slope - 10.0).abs() < 1e-9);
    assert_eq!(rise.start, 0);
    assert_eq!(rise.end, 3);

    let fall = &output.segments[1];
    assert_eq!(fall.direction, TrendDirection::Falling);
    assert!((fall.slope + 10.0).abs() < 1e-9);
    assert_eq!(fall.start, 3);
    assert_eq!(fall.end, 4);

    assert_eq!(output.changepoints.len(), 1);
    let pivot = &output.changepoints[0];
    assert_eq!(pivot.index, 3);
    assert!((pivot.magnitude - 20.0).abs() < 1e-9);
}
/// Two unit-variance classes with equal priors, centred at (0, 0) and (3, 3), observed at
/// (1, 1).
///
/// Per feature, the log-likelihood favours A by ((1 − 3)² − (1 − 0)²) / 2 = 1.5. Over two
/// features that totals 3, so P(A|x) = 1 / (1 + e^(−3)) ≈ 0.9526. The posteriors must sum
/// to 1.
#[test]
fn naive_bayes_two_class_hand_computed() {
    let unit = |mean| GaussianParams { mean, std_dev: 1.0 };
    let input = NaiveBayesInput {
        classes: vec![
            ClassDef {
                name: "A".into(),
                prior: 0.5,
                feature_params: vec![unit(0.0), unit(0.0)],
            },
            ClassDef {
                name: "B".into(),
                prior: 0.5,
                feature_params: vec![unit(3.0), unit(3.0)],
            },
        ],
        features: vec![1.0, 1.0],
    };
    let (output, _) = GaussianNaiveBayes.solve(&input, &spec()).unwrap();
    assert_eq!(output.predicted, "A");

    let confidence = output.confidence;
    assert!(
        (confidence - 0.9526).abs() < 1e-3,
        "P(A|x) ≈ 0.9526, got {}",
        confidence
    );

    let sum: f64 = output.probabilities.iter().map(|p| p.probability).sum();
    assert!(
        (sum - 1.0).abs() < 1e-9,
        "probabilities must sum to 1.0, got {sum}"
    );

    let class_b = output
        .probabilities
        .iter()
        .find(|p| p.class == "B")
        .expect("class B must appear");
    assert!(
        (class_b.probability + confidence - 1.0).abs() < 1e-9,
        "P(A) + P(B) = 1.0"
    );
}
/// Equal priors and equal σ: the class whose mean is closer to the observation wins.
///
/// At x = 0.6 the log-posterior gap is ((0.6 − 0)² − (0.6 − 1)²) / 2 = (0.36 − 0.16) / 2 = 0.1
/// in favour of "near". So P(near|x) = 1 / (1 + e^(−0.1)) ≈ 0.5250. The exact value is
/// checked, not just "> 0.5", so a solver that merely picks the right winner with a wrong
/// posterior fails.
#[test]
fn naive_bayes_symmetric_priors_favors_closer_mean() {
    let input = NaiveBayesInput {
        classes: vec![
            ClassDef {
                name: "far".into(),
                prior: 0.5,
                feature_params: vec![GaussianParams {
                    mean: 0.0,
                    std_dev: 1.0,
                }],
            },
            ClassDef {
                name: "near".into(),
                prior: 0.5,
                feature_params: vec![GaussianParams {
                    mean: 1.0,
                    std_dev: 1.0,
                }],
            },
        ],
        features: vec![0.6],
    };
    let (output, _) = GaussianNaiveBayes.solve(&input, &spec()).unwrap();
    assert_eq!(output.predicted, "near");
    assert!(
        output.confidence > 0.5,
        "near should win with > 50% probability"
    );
    let expected = 1.0 / (1.0 + (-0.1_f64).exp());
    assert!(
        (output.confidence - expected).abs() < 1e-9,
        "P(near|x) = 1/(1+e^-0.1) ≈ {expected}, got {}",
        output.confidence
    );
}
/// x = 0.5 is equidistant from both unit-σ means (0 and 1), so the likelihoods are
/// identical and the posterior reduces to the prior.
///
/// "strong-prior" must win with P = 0.9 / (0.9 + 0.1) = 0.9 exactly. The confidence check
/// pins that the prior is applied with its full weight, not merely that it breaks the tie.
#[test]
fn naive_bayes_prior_breaks_equidistant_tie() {
    let input = NaiveBayesInput {
        classes: vec![
            ClassDef {
                name: "strong-prior".into(),
                prior: 0.9,
                feature_params: vec![GaussianParams {
                    mean: 0.0,
                    std_dev: 1.0,
                }],
            },
            ClassDef {
                name: "weak-prior".into(),
                prior: 0.1,
                feature_params: vec![GaussianParams {
                    mean: 1.0,
                    std_dev: 1.0,
                }],
            },
        ],
        features: vec![0.5],
    };
    let (output, _) = GaussianNaiveBayes.solve(&input, &spec()).unwrap();
    assert_eq!(output.predicted, "strong-prior");
    assert!(
        (output.confidence - 0.9).abs() < 1e-9,
        "equal likelihoods → posterior equals the 0.9 prior, got {}",
        output.confidence
    );
}