use crate::types::{ClusteringResult, DataMatrix, DistanceMetric};
use serde::{Deserialize, Serialize};
/// Parameters for a k-means clustering run.
///
/// Construct with [`KMeansInput::new`] and adjust with the `with_*` builder methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KMeansInput {
    /// Data to cluster.
    pub data: DataMatrix,
    /// Requested number of clusters (`k`).
    pub k: usize,
    /// Iteration cap ([`KMeansInput::new`] uses 100).
    pub max_iterations: u32,
    /// Solver tolerance ([`KMeansInput::new`] uses `1e-4`); presumably the
    /// convergence threshold — confirm against the k-means implementation.
    pub tolerance: f64,
}

impl KMeansInput {
    /// Iteration cap applied by `new`.
    const DEFAULT_MAX_ITERATIONS: u32 = 100;
    /// Tolerance applied by `new`.
    const DEFAULT_TOLERANCE: f64 = 1e-4;

    /// Creates an input for `k` clusters over `data`, with
    /// `max_iterations = 100` and `tolerance = 1e-4`.
    pub fn new(data: DataMatrix, k: usize) -> Self {
        Self {
            max_iterations: Self::DEFAULT_MAX_ITERATIONS,
            tolerance: Self::DEFAULT_TOLERANCE,
            data,
            k,
        }
    }

    /// Returns the input with `max_iterations` replaced.
    pub fn with_max_iterations(self, max_iterations: u32) -> Self {
        Self {
            max_iterations,
            ..self
        }
    }

    /// Returns the input with `tolerance` replaced.
    pub fn with_tolerance(self, tolerance: f64) -> Self {
        Self { tolerance, ..self }
    }
}
/// Output record for a k-means clustering run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KMeansOutput {
/// Clustering result (type defined in `crate::types`).
pub result: ClusteringResult,
/// Elapsed compute time; the `_us` suffix suggests microseconds — confirm with the producer.
pub compute_time_us: u64,
}
/// Parameters for a DBSCAN clustering run.
///
/// Construct with [`DBSCANInput::new`]; the metric can be changed with
/// [`DBSCANInput::with_metric`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DBSCANInput {
    /// Data to cluster.
    pub data: DataMatrix,
    /// The DBSCAN `eps` neighborhood radius.
    pub eps: f64,
    /// The DBSCAN `min_samples` threshold (whether the point itself is counted
    /// depends on the implementation — confirm).
    pub min_samples: usize,
    /// Distance metric; [`DBSCANInput::new`] selects [`DistanceMetric::Euclidean`].
    pub metric: DistanceMetric,
}

impl DBSCANInput {
    /// Creates an input with the given `eps` and `min_samples`, using the
    /// Euclidean metric.
    pub fn new(data: DataMatrix, eps: f64, min_samples: usize) -> Self {
        let metric = DistanceMetric::Euclidean;
        Self {
            data,
            eps,
            min_samples,
            metric,
        }
    }

    /// Returns the input with `metric` replaced.
    pub fn with_metric(self, metric: DistanceMetric) -> Self {
        Self { metric, ..self }
    }
}
/// Output record for a DBSCAN clustering run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DBSCANOutput {
/// Clustering result (type defined in `crate::types`); how noise points are
/// represented is not visible here — confirm with the producer.
pub result: ClusteringResult,
/// Elapsed compute time; the `_us` suffix suggests microseconds — confirm with the producer.
pub compute_time_us: u64,
}
/// Linkage criterion for hierarchical (agglomerative) clustering.
///
/// Variant names follow standard terminology. The exact merge rule is applied by
/// the clustering implementation, which is not part of this module. Variant order
/// is kept stable because index-based serde formats encode it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Linkage {
    /// Single linkage (conventionally: closest pair of points across clusters).
    Single,
    /// Complete linkage (conventionally: farthest pair of points across clusters).
    /// This is the default chosen by [`HierarchicalInput::new`].
    Complete,
    /// Average linkage (conventionally: mean pairwise distance across clusters).
    Average,
    /// Ward linkage (conventionally: smallest increase in within-cluster variance).
    Ward,
}
/// Parameters for a hierarchical clustering run.
///
/// Construct with [`HierarchicalInput::new`] and adjust with
/// [`HierarchicalInput::with_linkage`] / [`HierarchicalInput::with_metric`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HierarchicalInput {
    /// Data to cluster.
    pub data: DataMatrix,
    /// Requested number of clusters.
    pub n_clusters: usize,
    /// Linkage criterion; [`HierarchicalInput::new`] selects [`Linkage::Complete`].
    pub linkage: Linkage,
    /// Distance metric; [`HierarchicalInput::new`] selects [`DistanceMetric::Euclidean`].
    pub metric: DistanceMetric,
}

impl HierarchicalInput {
    /// Creates an input requesting `n_clusters` clusters, using complete linkage
    /// and the Euclidean metric.
    pub fn new(data: DataMatrix, n_clusters: usize) -> Self {
        let (linkage, metric) = (Linkage::Complete, DistanceMetric::Euclidean);
        Self {
            data,
            n_clusters,
            linkage,
            metric,
        }
    }

    /// Returns the input with `linkage` replaced.
    pub fn with_linkage(self, linkage: Linkage) -> Self {
        Self { linkage, ..self }
    }

    /// Returns the input with `metric` replaced.
    pub fn with_metric(self, metric: DistanceMetric) -> Self {
        Self { metric, ..self }
    }
}
/// Output record for a hierarchical clustering run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HierarchicalOutput {
/// Clustering result (type defined in `crate::types`).
pub result: ClusteringResult,
/// Elapsed compute time; the `_us` suffix suggests microseconds — confirm with the producer.
pub compute_time_us: u64,
}
/// Parameters for Isolation Forest anomaly detection.
///
/// Construct with [`IsolationForestInput::new`] and adjust with the `with_*`
/// builder methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IsolationForestInput {
    /// Data to score.
    pub data: DataMatrix,
    /// Number of trees ([`IsolationForestInput::new`] uses 100).
    pub n_trees: usize,
    /// Contamination setting ([`IsolationForestInput::new`] uses `0.1`);
    /// presumably the expected outlier fraction — confirm against the detector.
    pub contamination: f64,
}

impl IsolationForestInput {
    /// Tree count applied by `new`.
    const DEFAULT_N_TREES: usize = 100;
    /// Contamination applied by `new`.
    const DEFAULT_CONTAMINATION: f64 = 0.1;

    /// Creates an input over `data` with 100 trees and contamination `0.1`.
    pub fn new(data: DataMatrix) -> Self {
        Self {
            data,
            n_trees: Self::DEFAULT_N_TREES,
            contamination: Self::DEFAULT_CONTAMINATION,
        }
    }

    /// Returns the input with `n_trees` replaced.
    pub fn with_n_trees(self, n_trees: usize) -> Self {
        Self { n_trees, ..self }
    }

    /// Returns the input with `contamination` replaced.
    pub fn with_contamination(self, contamination: f64) -> Self {
        Self {
            contamination,
            ..self
        }
    }
}
/// Output record for an anomaly-detection run.
///
/// NOTE(review): which detectors produce this type is not visible here; it sits
/// alongside `IsolationForestInput` and `LOFInput`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnomalyOutput {
/// Per-sample anomaly scores; presumably one entry per input row — confirm.
pub scores: Vec<f64>,
/// Per-sample labels; the encoding (e.g. -1/1 vs 0/1) is not defined here — confirm with the producer.
pub labels: Vec<i32>,
/// Score cutoff; presumably used to derive `labels` from `scores` — confirm.
pub threshold: f64,
/// Elapsed compute time; the `_us` suffix suggests microseconds — confirm with the producer.
pub compute_time_us: u64,
}
/// Parameters for Local Outlier Factor (LOF) anomaly detection.
///
/// Construct with [`LOFInput::new`] and adjust with the `with_*` builder methods.
/// The builder set mirrors the other inputs in this module: `contamination`
/// matches [`IsolationForestInput`] and `metric` matches [`DBSCANInput`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LOFInput {
    /// Data to score.
    pub data: DataMatrix,
    /// Neighborhood size ([`LOFInput::new`] uses 20).
    pub n_neighbors: usize,
    /// Contamination setting ([`LOFInput::new`] uses `0.1`); presumably the
    /// expected outlier fraction — confirm against the detector.
    pub contamination: f64,
    /// Distance metric; [`LOFInput::new`] selects [`DistanceMetric::Euclidean`].
    pub metric: DistanceMetric,
}
impl LOFInput {
    /// Creates an input over `data` with 20 neighbors, contamination `0.1`
    /// and the Euclidean metric.
    pub fn new(data: DataMatrix) -> Self {
        Self {
            data,
            n_neighbors: 20,
            contamination: 0.1,
            metric: DistanceMetric::Euclidean,
        }
    }

    /// Returns the input with `n_neighbors` replaced.
    pub fn with_n_neighbors(mut self, n_neighbors: usize) -> Self {
        self.n_neighbors = n_neighbors;
        self
    }

    /// Returns the input with `contamination` replaced.
    ///
    /// Previously this field could only be changed by direct assignment.
    pub fn with_contamination(mut self, contamination: f64) -> Self {
        self.contamination = contamination;
        self
    }

    /// Returns the input with `metric` replaced.
    ///
    /// Previously this field could only be changed by direct assignment.
    pub fn with_metric(mut self, metric: DistanceMetric) -> Self {
        self.metric = metric;
        self
    }
}
/// Parameters for a linear regression fit.
///
/// Construct with [`RegressionInput::new`]. Request a ridge penalty with
/// [`RegressionInput::with_ridge`], and toggle the intercept with
/// [`RegressionInput::with_fit_intercept`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegressionInput {
    /// Feature matrix.
    pub x: DataMatrix,
    /// Target values; presumably one per row of `x` — not checked here.
    pub y: Vec<f64>,
    /// Whether to fit an intercept term ([`RegressionInput::new`] uses `true`).
    pub fit_intercept: bool,
    /// Ridge penalty strength. It is `None` unless set via
    /// [`RegressionInput::with_ridge`], presumably meaning no penalty.
    pub alpha: Option<f64>,
}
impl RegressionInput {
    /// Creates an input with `fit_intercept = true` and no ridge penalty.
    pub fn new(x: DataMatrix, y: Vec<f64>) -> Self {
        Self {
            x,
            y,
            fit_intercept: true,
            alpha: None,
        }
    }

    /// Returns the input with `alpha` set to `Some(alpha)`.
    pub fn with_ridge(mut self, alpha: f64) -> Self {
        self.alpha = Some(alpha);
        self
    }

    /// Returns the input with `fit_intercept` replaced.
    ///
    /// Previously the intercept default could only be changed by direct field
    /// assignment.
    pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
        self.fit_intercept = fit_intercept;
        self
    }
}
/// Output record for a linear regression fit.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegressionOutput {
/// Fitted coefficients; presumably one per column of the input `x` — confirm.
pub coefficients: Vec<f64>,
/// Fitted intercept; presumably `None` when `fit_intercept` was false — confirm with the producer.
pub intercept: Option<f64>,
/// Coefficient of determination (R²) as reported by the fitter.
pub r_squared: f64,
/// Elapsed compute time; the `_us` suffix suggests microseconds — confirm with the producer.
pub compute_time_us: u64,
}
#[cfg(test)]
mod tests {
    //! Builder/default checks for the input types in this module.
    use super::*;

    #[test]
    fn test_kmeans_input_builder() {
        let data = DataMatrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
        let input = KMeansInput::new(data, 2)
            .with_max_iterations(50)
            .with_tolerance(1e-6);
        assert_eq!(input.k, 2);
        assert_eq!(input.max_iterations, 50);
        // Exact comparison is fine: the value is stored, not computed.
        assert_eq!(input.tolerance, 1e-6);
    }

    #[test]
    fn test_kmeans_input_defaults() {
        let data = DataMatrix::from_rows(&[&[1.0, 2.0]]);
        let input = KMeansInput::new(data, 3);
        assert_eq!(input.k, 3);
        assert_eq!(input.max_iterations, 100);
        assert_eq!(input.tolerance, 1e-4);
    }

    #[test]
    fn test_dbscan_input_builder() {
        let data = DataMatrix::from_rows(&[&[1.0, 2.0]]);
        let input = DBSCANInput::new(data, 0.5, 3).with_metric(DistanceMetric::Manhattan);
        assert_eq!(input.eps, 0.5);
        assert_eq!(input.min_samples, 3);
        // `matches!` avoids requiring `PartialEq` on `DistanceMetric`.
        assert!(matches!(input.metric, DistanceMetric::Manhattan));
    }

    #[test]
    fn test_dbscan_input_default_metric() {
        let data = DataMatrix::from_rows(&[&[1.0, 2.0]]);
        let input = DBSCANInput::new(data, 0.5, 3);
        assert!(matches!(input.metric, DistanceMetric::Euclidean));
    }

    #[test]
    fn test_hierarchical_input_builder() {
        let data = DataMatrix::from_rows(&[&[1.0, 2.0]]);
        let input = HierarchicalInput::new(data, 4);
        assert_eq!(input.n_clusters, 4);
        assert_eq!(input.linkage, Linkage::Complete);
        assert!(matches!(input.metric, DistanceMetric::Euclidean));

        let input = input
            .with_linkage(Linkage::Ward)
            .with_metric(DistanceMetric::Manhattan);
        assert_eq!(input.linkage, Linkage::Ward);
        assert!(matches!(input.metric, DistanceMetric::Manhattan));
    }

    #[test]
    fn test_isolation_forest_input_builder() {
        let data = DataMatrix::from_rows(&[&[1.0, 2.0]]);
        let input = IsolationForestInput::new(data);
        assert_eq!(input.n_trees, 100);
        assert_eq!(input.contamination, 0.1);

        let input = input.with_n_trees(50).with_contamination(0.05);
        assert_eq!(input.n_trees, 50);
        assert_eq!(input.contamination, 0.05);
    }

    #[test]
    fn test_lof_input_builder() {
        let data = DataMatrix::from_rows(&[&[1.0, 2.0]]);
        let input = LOFInput::new(data);
        assert_eq!(input.n_neighbors, 20);
        assert_eq!(input.contamination, 0.1);
        assert!(matches!(input.metric, DistanceMetric::Euclidean));

        let input = input.with_n_neighbors(5);
        assert_eq!(input.n_neighbors, 5);
    }

    #[test]
    fn test_regression_input_builder() {
        let data = DataMatrix::from_rows(&[&[1.0, 2.0]]);
        let input = RegressionInput::new(data, vec![1.0]);
        assert!(input.fit_intercept);
        assert_eq!(input.alpha, None);
        assert_eq!(input.y, vec![1.0]);

        let input = input.with_ridge(0.5);
        assert_eq!(input.alpha, Some(0.5));
    }
}