Skip to main content

converge_analytics/packs/classification/
mod.rs

1mod solver;
2mod types;
3
4pub use solver::*;
5pub use types::*;
6
7use converge_optimization::packs::{
8    InvariantDef, InvariantResult, Pack, PackSolveResult, default_gate_evaluation,
9};
10use converge_pack::gate::GateResult as Result;
11use converge_pack::gate::{KernelTraceLink, ProblemSpec, PromotionGate, ProposedPlan};
12
/// Stateless pack that runs a logistic classifier over a [`ClassificationInput`] and
/// gates the resulting plan on probability-validity and class-balance invariants.
///
/// The pack carries no data, so it is trivially `Copy` and `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct ClassificationPack;
14
15impl Pack for ClassificationPack {
16    fn name(&self) -> &'static str {
17        "classification"
18    }
19
20    fn version(&self) -> &'static str {
21        "1.0.0"
22    }
23
24    fn validate_inputs(&self, inputs: &serde_json::Value) -> Result<()> {
25        let input: ClassificationInput = serde_json::from_value(inputs.clone())
26            .map_err(|e| converge_pack::GateError::invalid_input(format!("Invalid input: {e}")))?;
27        input.validate()
28    }
29
30    fn invariants(&self) -> &[InvariantDef] {
31        static INVARIANTS: std::sync::LazyLock<Vec<InvariantDef>> =
32            std::sync::LazyLock::new(|| {
33                vec![
34                    InvariantDef::critical(
35                        "valid-probabilities",
36                        "All probabilities must be in [0, 1]",
37                    ),
38                    InvariantDef::advisory(
39                        "class-imbalance",
40                        "One class has > 90% of predictions — model may be degenerate",
41                    ),
42                ]
43            });
44        &INVARIANTS
45    }
46
47    fn solve(&self, spec: &ProblemSpec) -> Result<PackSolveResult> {
48        let input: ClassificationInput = spec.inputs_as()?;
49        input.validate()?;
50
51        let solver = LogisticClassifier;
52        let (output, report) = solver.solve(&input, spec)?;
53
54        let trace = KernelTraceLink::audit_only(format!("trace-{}", spec.problem_id));
55
56        let avg_decisiveness: f64 = output
57            .predictions
58            .iter()
59            .map(|p| (p.probability - 0.5).abs() * 2.0)
60            .sum::<f64>()
61            / output.total as f64;
62        let confidence = avg_decisiveness.clamp(0.3, 0.95);
63
64        let plan = ProposedPlan::from_payload(
65            format!("plan-{}", spec.problem_id),
66            self.name(),
67            output.summary(),
68            &output,
69            confidence,
70            trace,
71        )?;
72
73        Ok(PackSolveResult::new(plan, report))
74    }
75
76    fn check_invariants(&self, plan: &ProposedPlan) -> Result<Vec<InvariantResult>> {
77        let output: ClassificationOutput = serde_json::from_value(plan.plan.clone())
78            .map_err(|e| converge_pack::GateError::invalid_input(e.to_string()))?;
79
80        let mut results = vec![];
81
82        let all_valid = output
83            .predictions
84            .iter()
85            .all(|p| (0.0..=1.0).contains(&p.probability));
86
87        if all_valid {
88            results.push(InvariantResult::pass("valid-probabilities"));
89        } else {
90            results.push(InvariantResult::fail(
91                "valid-probabilities",
92                converge_pack::gate::Violation::new(
93                    "valid-probabilities",
94                    1.0,
95                    "Probability outside [0, 1] range",
96                ),
97            ));
98        }
99
100        let majority = output.positive_count.max(output.negative_count) as f64;
101        let ratio = majority / output.total as f64;
102        if ratio > 0.9 {
103            results.push(InvariantResult::fail(
104                "class-imbalance",
105                converge_pack::gate::Violation::new(
106                    "class-imbalance",
107                    ratio,
108                    format!("{:.0}% of predictions in one class", ratio * 100.0),
109                ),
110            ));
111        } else {
112            results.push(InvariantResult::pass("class-imbalance"));
113        }
114
115        Ok(results)
116    }
117
118    fn evaluate_gate(
119        &self,
120        _plan: &ProposedPlan,
121        invariant_results: &[InvariantResult],
122    ) -> PromotionGate {
123        default_gate_evaluation(invariant_results, self.invariants())
124    }
125}