rustkernel_ml/messages.rs

//! Ring message types for ML kernels.
//!
//! This module defines the request/response message types used to communicate
//! with GPU-native persistent actors that run machine learning algorithms.

use crate::types::{ClusteringResult, DataMatrix, DistanceMetric};
use serde::{Deserialize, Serialize};

// ============================================================================
// K-Means Messages
// ============================================================================

/// K-Means clustering input for batch execution.
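///
/// # Example
///
/// Illustrative builder usage (a sketch; `DataMatrix::from_rows` is used as in
/// the tests at the bottom of this module). Defaults are 100 iterations and a
/// tolerance of 1e-4.
///
/// ```ignore
/// let data = DataMatrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
/// let input = KMeansInput::new(data, 2)
///     .with_max_iterations(50)
///     .with_tolerance(1e-6);
/// ```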
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KMeansInput {
    /// Input data matrix (n_samples x n_features).
    pub data: DataMatrix,
    /// Number of clusters.
    pub k: usize,
    /// Maximum number of iterations.
    pub max_iterations: u32,
    /// Convergence tolerance for centroid movement.
    pub tolerance: f64,
}

impl KMeansInput {
    /// Create a new K-Means input.
    pub fn new(data: DataMatrix, k: usize) -> Self {
        Self {
            data,
            k,
            max_iterations: 100,
            tolerance: 1e-4,
        }
    }

    /// Set maximum iterations.
    pub fn with_max_iterations(mut self, max_iterations: u32) -> Self {
        self.max_iterations = max_iterations;
        self
    }

    /// Set convergence tolerance.
    pub fn with_tolerance(mut self, tolerance: f64) -> Self {
        self.tolerance = tolerance;
        self
    }
}

/// K-Means clustering output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KMeansOutput {
    /// The clustering result.
    pub result: ClusteringResult,
    /// Computation time in microseconds.
    pub compute_time_us: u64,
}

// ============================================================================
// DBSCAN Messages
// ============================================================================

/// DBSCAN clustering input for batch execution.
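///
/// # Example
///
/// Illustrative builder usage (a sketch; the metric defaults to Euclidean):
///
/// ```ignore
/// let data = DataMatrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
/// let input = DBSCANInput::new(data, 0.5, 3)
///     .with_metric(DistanceMetric::Manhattan);
/// ```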
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DBSCANInput {
    /// Input data matrix.
    pub data: DataMatrix,
    /// Maximum distance for neighborhood (epsilon).
    pub eps: f64,
    /// Minimum points to form a dense region.
    pub min_samples: usize,
    /// Distance metric to use.
    pub metric: DistanceMetric,
}

impl DBSCANInput {
    /// Create a new DBSCAN input.
    pub fn new(data: DataMatrix, eps: f64, min_samples: usize) -> Self {
        Self {
            data,
            eps,
            min_samples,
            metric: DistanceMetric::Euclidean,
        }
    }

    /// Set the distance metric.
    pub fn with_metric(mut self, metric: DistanceMetric) -> Self {
        self.metric = metric;
        self
    }
}

/// DBSCAN clustering output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DBSCANOutput {
    /// The clustering result.
    pub result: ClusteringResult,
    /// Computation time in microseconds.
    pub compute_time_us: u64,
}

// ============================================================================
// Hierarchical Clustering Messages
// ============================================================================

/// Linkage method for hierarchical clustering.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum Linkage {
    /// Single linkage (minimum distance).
    Single,
    /// Complete linkage (maximum distance).
    Complete,
    /// Average linkage (UPGMA).
    Average,
    /// Ward's method (minimize variance).
    Ward,
}

/// Hierarchical clustering input for batch execution.
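///
/// # Example
///
/// Illustrative builder usage (a sketch; defaults are complete linkage and the
/// Euclidean metric):
///
/// ```ignore
/// let data = DataMatrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
/// let input = HierarchicalInput::new(data, 2)
///     .with_linkage(Linkage::Ward)
///     .with_metric(DistanceMetric::Manhattan);
/// ```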
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HierarchicalInput {
    /// Input data matrix.
    pub data: DataMatrix,
    /// Number of clusters to form.
    pub n_clusters: usize,
    /// Linkage method.
    pub linkage: Linkage,
    /// Distance metric.
    pub metric: DistanceMetric,
}

impl HierarchicalInput {
    /// Create a new hierarchical clustering input.
    pub fn new(data: DataMatrix, n_clusters: usize) -> Self {
        Self {
            data,
            n_clusters,
            linkage: Linkage::Complete,
            metric: DistanceMetric::Euclidean,
        }
    }

    /// Set the linkage method.
    pub fn with_linkage(mut self, linkage: Linkage) -> Self {
        self.linkage = linkage;
        self
    }

    /// Set the distance metric.
    pub fn with_metric(mut self, metric: DistanceMetric) -> Self {
        self.metric = metric;
        self
    }
}

/// Hierarchical clustering output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HierarchicalOutput {
    /// The clustering result.
    pub result: ClusteringResult,
    /// Computation time in microseconds.
    pub compute_time_us: u64,
}

// ============================================================================
// Anomaly Detection Messages
// ============================================================================

/// Isolation Forest input for batch execution.
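///
/// # Example
///
/// Illustrative builder usage (a sketch; defaults are 100 trees and a
/// contamination of 0.1):
///
/// ```ignore
/// let data = DataMatrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
/// let input = IsolationForestInput::new(data)
///     .with_n_trees(200)
///     .with_contamination(0.05);
/// ```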
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IsolationForestInput {
    /// Input data matrix.
    pub data: DataMatrix,
    /// Number of trees in the ensemble.
    pub n_trees: usize,
    /// Contamination proportion (expected fraction of outliers).
    pub contamination: f64,
}

impl IsolationForestInput {
    /// Create a new Isolation Forest input.
    pub fn new(data: DataMatrix) -> Self {
        Self {
            data,
            n_trees: 100,
            contamination: 0.1,
        }
    }

    /// Set the number of trees.
    pub fn with_n_trees(mut self, n_trees: usize) -> Self {
        self.n_trees = n_trees;
        self
    }

    /// Set the contamination proportion.
    pub fn with_contamination(mut self, contamination: f64) -> Self {
        self.contamination = contamination;
        self
    }
}

/// Anomaly detection output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnomalyOutput {
    /// Anomaly scores for each sample (higher = more anomalous).
    pub scores: Vec<f64>,
    /// Labels (1 = anomaly, 0 = normal) based on threshold.
    pub labels: Vec<i32>,
    /// The threshold used for classification.
    pub threshold: f64,
    /// Computation time in microseconds.
    pub compute_time_us: u64,
}

/// Local Outlier Factor input for batch execution.
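///
/// # Example
///
/// Illustrative builder usage (a sketch; defaults are 20 neighbors, a
/// contamination of 0.1, and the Euclidean metric):
///
/// ```ignore
/// let data = DataMatrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
/// let input = LOFInput::new(data).with_n_neighbors(15);
/// ```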
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LOFInput {
    /// Input data matrix.
    pub data: DataMatrix,
    /// Number of neighbors.
    pub n_neighbors: usize,
    /// Contamination proportion.
    pub contamination: f64,
    /// Distance metric.
    pub metric: DistanceMetric,
}

impl LOFInput {
    /// Create a new LOF input.
    pub fn new(data: DataMatrix) -> Self {
        Self {
            data,
            n_neighbors: 20,
            contamination: 0.1,
            metric: DistanceMetric::Euclidean,
        }
    }

    /// Set the number of neighbors.
    pub fn with_n_neighbors(mut self, n_neighbors: usize) -> Self {
        self.n_neighbors = n_neighbors;
        self
    }
}

// ============================================================================
// Regression Messages
// ============================================================================

/// Regression input for batch execution.
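///
/// # Example
///
/// Illustrative builder usage (a sketch; the intercept is fit by default, and
/// `with_ridge` enables L2 regularization by setting `alpha`):
///
/// ```ignore
/// let x = DataMatrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
/// let input = RegressionInput::new(x, vec![1.0, 2.0]).with_ridge(0.5);
/// ```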
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegressionInput {
    /// Feature matrix X (n_samples x n_features).
    pub x: DataMatrix,
    /// Target vector y (n_samples).
    pub y: Vec<f64>,
    /// Whether to fit an intercept.
    pub fit_intercept: bool,
    /// Regularization parameter (for Ridge regression).
    pub alpha: Option<f64>,
}

impl RegressionInput {
    /// Create a new regression input.
    pub fn new(x: DataMatrix, y: Vec<f64>) -> Self {
        Self {
            x,
            y,
            fit_intercept: true,
            alpha: None,
        }
    }

    /// Enable Ridge regularization with the given alpha.
    pub fn with_ridge(mut self, alpha: f64) -> Self {
        self.alpha = Some(alpha);
        self
    }
}

/// Regression output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegressionOutput {
    /// Coefficient vector.
    pub coefficients: Vec<f64>,
    /// Intercept (if fit_intercept was true).
    pub intercept: Option<f64>,
    /// R-squared score.
    pub r_squared: f64,
    /// Computation time in microseconds.
    pub compute_time_us: u64,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kmeans_input_builder() {
        let data = DataMatrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
        let input = KMeansInput::new(data, 2)
            .with_max_iterations(50)
            .with_tolerance(1e-6);
        assert_eq!(input.k, 2);
        assert_eq!(input.max_iterations, 50);
    }

    #[test]
    fn test_dbscan_input_builder() {
        let data = DataMatrix::from_rows(&[&[1.0, 2.0]]);
        let input = DBSCANInput::new(data, 0.5, 3).with_metric(DistanceMetric::Manhattan);
        assert_eq!(input.eps, 0.5);
        assert_eq!(input.min_samples, 3);
    }
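
    // Additional illustrative builder checks. These are sketches that only
    // exercise the defaults and setters defined above in this module.
    #[test]
    fn test_hierarchical_input_builder() {
        let data = DataMatrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
        let input = HierarchicalInput::new(data, 2).with_linkage(Linkage::Ward);
        assert_eq!(input.n_clusters, 2);
        assert_eq!(input.linkage, Linkage::Ward);
    }

    #[test]
    fn test_isolation_forest_input_builder() {
        let data = DataMatrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
        let input = IsolationForestInput::new(data)
            .with_n_trees(200)
            .with_contamination(0.05);
        assert_eq!(input.n_trees, 200);
        assert_eq!(input.contamination, 0.05);
    }

    #[test]
    fn test_lof_input_builder() {
        let data = DataMatrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
        let input = LOFInput::new(data).with_n_neighbors(15);
        assert_eq!(input.n_neighbors, 15);
        assert_eq!(input.contamination, 0.1);
    }

    #[test]
    fn test_regression_input_builder() {
        let x = DataMatrix::from_rows(&[&[1.0], &[2.0]]);
        let input = RegressionInput::new(x, vec![1.0, 2.0]).with_ridge(0.5);
        assert!(input.fit_intercept);
        assert_eq!(input.alpha, Some(0.5));
    }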
}