Skip to main content

converge_analytics/packs/classification/
types.rs

1use converge_pack::gate::GateResult as Result;
2use serde::{Deserialize, Serialize};
3
/// Input payload for a classification run: the records to classify plus the
/// model parameters applied to them.
///
/// Call `validate` before use; it checks that the records and weights agree
/// in dimension and that the threshold is in range.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClassificationInput {
    /// Feature vectors, one inner `Vec` per record. `validate` requires at
    /// least one record and that every record has exactly `weights.len()`
    /// entries.
    pub records: Vec<Vec<f64>>,
    /// One weight per feature. `validate` requires this to be non-empty.
    pub weights: Vec<f64>,
    /// Bias term. NOTE(review): not used anywhere in this file; presumably
    /// added to the weighted feature sum by the solver — confirm.
    pub bias: f64,
    /// Decision threshold. `validate` requires it to lie in `[0.0, 1.0]`.
    /// NOTE(review): presumably compared against the predicted probability —
    /// confirm against the solver.
    pub threshold: f64,
    /// Optional pair of label names. NOTE(review): which element is the
    /// positive label and which the negative is not visible here — confirm
    /// against the code that reads this field.
    pub labels: Option<(String, String)>,
}
12
13impl ClassificationInput {
14    pub fn validate(&self) -> Result<()> {
15        if self.records.is_empty() {
16            return Err(converge_pack::GateError::invalid_input(
17                "At least one record required",
18            ));
19        }
20        let dim = self.weights.len();
21        if dim == 0 {
22            return Err(converge_pack::GateError::invalid_input(
23                "At least one weight (feature) required",
24            ));
25        }
26        for (i, record) in self.records.iter().enumerate() {
27            if record.len() != dim {
28                return Err(converge_pack::GateError::invalid_input(format!(
29                    "Record {} has {} features, expected {}",
30                    i,
31                    record.len(),
32                    dim
33                )));
34            }
35        }
36        if !(0.0..=1.0).contains(&self.threshold) {
37            return Err(converge_pack::GateError::invalid_input(
38                "Threshold must be in [0.0, 1.0]",
39            ));
40        }
41        Ok(())
42    }
43}
44
/// The classification result for a single record.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClassifiedRecord {
    /// Index identifying the record. NOTE(review): presumably its position in
    /// the input `records` list — confirm against the code that builds this.
    pub index: usize,
    /// Probability assigned to the record. NOTE(review): presumably the
    /// probability of the positive class in `[0.0, 1.0]` — confirm.
    pub probability: f64,
    /// Label assigned to the record.
    pub label: String,
}
51
/// Aggregate result of a classification run: the per-record predictions plus
/// summary counts.
///
/// NOTE(review): nothing in this file enforces that the counts agree with
/// `predictions` (e.g. `positive_count + negative_count == total`); that is
/// presumably guaranteed by the producer — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClassificationOutput {
    /// One entry per classified record.
    pub predictions: Vec<ClassifiedRecord>,
    /// Number of records reported as positive by `summary`.
    pub positive_count: usize,
    /// Number of records reported as negative by `summary`.
    pub negative_count: usize,
    /// Total number of records reported by `summary`.
    pub total: usize,
}
59
60impl ClassificationOutput {
61    pub fn summary(&self) -> String {
62        format!(
63            "Classified {} records: {} positive, {} negative",
64            self.total, self.positive_count, self.negative_count,
65        )
66    }
67}