pub mod backend;
use std::collections::BTreeSet;
use anyhow::Result;
use crate::model::{Cluster, ClusterType};
/// Multi-label binarizer that maps sets of cluster-type names to indicator
/// rows over a fixed, ordered list of class labels, and back again.
pub struct TypeBinarizer {
    /// Ordered class labels; position `i` corresponds to column `i` of every
    /// encoded row.
    pub classes: Vec<String>,
}

impl TypeBinarizer {
    /// Creates a binarizer over the given ordered class labels.
    pub fn new(classes: Vec<String>) -> Self {
        TypeBinarizer { classes }
    }

    /// Encodes each [`ClusterType`] as one row of `1.0`/`0.0` indicators, one
    /// column per entry of `classes` (in order).
    ///
    /// Names in a type that are not listed in `classes` have no column and are
    /// therefore ignored.
    pub fn transform(&self, types: &[ClusterType]) -> Vec<Vec<f64>> {
        let mut encoded = Vec::with_capacity(types.len());
        for cluster_type in types {
            let mut row = Vec::with_capacity(self.classes.len());
            for label in &self.classes {
                let present = cluster_type.names.contains(label);
                row.push(if present { 1.0 } else { 0.0 });
            }
            encoded.push(row);
        }
        encoded
    }

    /// Decodes indicator rows back into [`ClusterType`]s: a `true` in column
    /// `i` adds `classes[i]` to that row's name set.
    ///
    /// Rows are paired with `classes` positionally, so columns beyond
    /// `classes.len()` (or classes beyond a short row's length) are ignored.
    pub fn inverse_transform(&self, matrix: &[Vec<bool>]) -> Vec<ClusterType> {
        let mut decoded = Vec::with_capacity(matrix.len());
        for flags in matrix {
            let mut names = BTreeSet::new();
            for (label, &on) in self.classes.iter().zip(flags.iter()) {
                if on {
                    names.insert(label.clone());
                }
            }
            decoded.push(ClusterType { names });
        }
        decoded
    }
}
/// Pluggable model backend used by `TypeClassifier` to turn feature rows into
/// per-class probability rows.
///
/// Implementors must be `Send + Sync`, so a boxed model can be shared across
/// threads.
pub trait RandomForestModel
where
    Self: Send + Sync,
{
    /// Fits the model to `features` (one row per sample) and `targets`.
    ///
    /// Takes `&mut self`, so implementations may update internal state.
    /// NOTE(review): presumably `targets` holds one indicator row per sample
    /// (e.g. from `TypeBinarizer::transform`) — confirm with implementors.
    fn fit(&mut self, features: &[Vec<f64>], targets: &[Vec<f64>]) -> Result<()>;

    /// Returns probability rows for `features`.
    ///
    /// Callers pair the output positionally with the input rows and with the
    /// class labels, so implementations are expected to return one row per
    /// input row and one column per class.
    fn predict_proba(&self, features: &[Vec<f64>]) -> Result<Vec<Vec<f64>>>;
}
/// Multi-label classifier that assigns cluster types from each cluster's
/// domain composition.
///
/// Feature rows come from `Cluster::domain_composition` over `domains`. The
/// wrapped [`RandomForestModel`] turns them into per-class probabilities,
/// which are thresholded and decoded through [`TypeBinarizer`].
pub struct TypeClassifier {
    /// Encodes/decodes the ordered class labels the model predicts.
    pub binarizer: TypeBinarizer,
    /// Domain vocabulary passed to `Cluster::domain_composition`.
    pub domains: Vec<String>,
    /// Underlying model; `None` until [`TypeClassifier::set_model`] is called.
    model: Option<Box<dyn RandomForestModel>>,
}

impl TypeClassifier {
    /// Probability that a class must strictly exceed to be included in a
    /// cluster's predicted type when using [`TypeClassifier::predict_types`].
    pub const DEFAULT_THRESHOLD: f64 = 0.5;

    /// Creates a classifier over `classes` with no domains and no model.
    pub fn new(classes: Vec<String>) -> Self {
        Self {
            binarizer: TypeBinarizer::new(classes),
            domains: Vec::new(),
            model: None,
        }
    }

    /// Installs (or replaces) the model used for prediction.
    pub fn set_model(&mut self, model: Box<dyn RandomForestModel>) {
        self.model = Some(model);
    }

    /// Replaces the domain vocabulary used to build feature rows.
    pub fn set_domains(&mut self, domains: Vec<String>) {
        self.domains = domains;
    }

    /// Predicts types for `clusters` in place using
    /// [`TypeClassifier::DEFAULT_THRESHOLD`].
    ///
    /// # Errors
    /// Same as [`TypeClassifier::predict_types_with_threshold`].
    pub fn predict_types(&self, clusters: &mut [Cluster]) -> Result<()> {
        self.predict_types_with_threshold(clusters, Self::DEFAULT_THRESHOLD)
    }

    /// Predicts types for `clusters` in place.
    ///
    /// For each cluster it sets two fields:
    /// - `cluster_type`: the set of classes whose probability is strictly
    ///   greater than `threshold`.
    /// - `type_probabilities`: the `(class, probability)` pairs in class order.
    ///
    /// An empty `clusters` slice is a no-op, and the model is not called. The
    /// missing-model check still runs first.
    ///
    /// # Errors
    /// Returns an error when:
    /// - no model has been set;
    /// - the model's `predict_proba` fails;
    /// - the model returns a number of rows different from `clusters.len()`;
    /// - any row's length differs from the number of classes.
    ///
    /// All shape checks run before any cluster is modified, so on error the
    /// clusters are left untouched.
    pub fn predict_types_with_threshold(
        &self,
        clusters: &mut [Cluster],
        threshold: f64,
    ) -> Result<()> {
        let model = self
            .model
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("TypeClassifier model not set"))?;

        // Nothing to label; also avoids handing the backend an empty batch.
        if clusters.is_empty() {
            return Ok(());
        }

        let comps: Vec<Vec<f64>> = clusters
            .iter()
            .map(|c| c.domain_composition(Some(&self.domains), true, false, true))
            .collect();
        let probas = model.predict_proba(&comps)?;

        // Validate shapes up front: `zip` would otherwise silently skip
        // trailing clusters (leaving stale predictions) or truncate classes.
        anyhow::ensure!(
            probas.len() == clusters.len(),
            "model returned {} probability rows for {} clusters",
            probas.len(),
            clusters.len()
        );
        let n_classes = self.binarizer.classes.len();
        if let Some((idx, row)) = probas
            .iter()
            .enumerate()
            .find(|(_, row)| row.len() != n_classes)
        {
            anyhow::bail!(
                "probability row {} has {} entries, expected {} (one per class)",
                idx,
                row.len(),
                n_classes
            );
        }

        for (cluster, proba) in clusters.iter_mut().zip(&probas) {
            let type_flags: Vec<bool> = proba.iter().map(|&p| p > threshold).collect();
            cluster.cluster_type = self
                .binarizer
                .inverse_transform(&[type_flags])
                .into_iter()
                .next();
            cluster.type_probabilities = self
                .binarizer
                .classes
                .iter()
                .zip(proba)
                .map(|(cls, &p)| (cls.clone(), p))
                .collect();
        }
        Ok(())
    }
}