ghostflow_ml/
rbf_network.rs

1//! RBF Network - Radial Basis Function Neural Network
2
3use ghostflow_core::Tensor;
4use rand::prelude::*;
5
/// RBF Network for classification and regression.
///
/// The model is a weighted sum of Gaussian radial basis functions plus a
/// bias term: `bias + sum_j w_j * exp(-gamma * ||x - c_j||^2)`.
#[derive(Debug, Clone)]
pub struct RBFNetwork {
    /// Requested number of RBF centers. Fitting uses at most one center per
    /// training sample, so fewer may actually be created.
    pub n_centers: usize,
    /// Width parameter of the Gaussian kernel `exp(-gamma * squared_distance)`.
    pub gamma: f32,
    /// Iteration limit.
    /// NOTE(review): not read by the fitting code visible in this file (the
    /// k-means refinement runs a fixed number of passes) - confirm intended use.
    pub max_iter: usize,
    /// Convergence tolerance.
    /// NOTE(review): not read by the fitting code visible in this file -
    /// confirm intended use.
    pub tol: f32,
    /// Fitted centers, one `Vec` of feature values per center; `None` until fitted.
    centers_: Option<Vec<Vec<f32>>>,
    /// Fitted output weight of each center; `None` until fitted.
    weights_: Option<Vec<f32>>,
    /// Fitted bias (constant output term); `0.0` until fitted.
    bias_: f32,
    /// Number of feature columns seen by the last fit; `0` until fitted.
    n_features_: usize,
}
17
18impl RBFNetwork {
19    pub fn new(n_centers: usize) -> Self {
20        RBFNetwork {
21            n_centers,
22            gamma: 1.0,
23            max_iter: 100,
24            tol: 1e-4,
25            centers_: None,
26            weights_: None,
27            bias_: 0.0,
28            n_features_: 0,
29        }
30    }
31
32    pub fn gamma(mut self, g: f32) -> Self {
33        self.gamma = g;
34        self
35    }
36
37    fn rbf_kernel(&self, x: &[f32], center: &[f32]) -> f32 {
38        let sq_dist: f32 = x.iter().zip(center.iter())
39            .map(|(&a, &b)| (a - b).powi(2)).sum();
40        (-self.gamma * sq_dist).exp()
41    }
42
43    fn kmeans_init(&self, x: &[f32], n_samples: usize, n_features: usize) -> Vec<Vec<f32>> {
44        let mut rng = thread_rng();
45        let n_centers = self.n_centers.min(n_samples);
46
47        // K-means++ initialization
48        let mut centers = Vec::with_capacity(n_centers);
49        
50        // First center: random
51        let first_idx = rng.gen_range(0..n_samples);
52        centers.push(x[first_idx * n_features..(first_idx + 1) * n_features].to_vec());
53
54        // Remaining centers
55        for _ in 1..n_centers {
56            let distances: Vec<f32> = (0..n_samples)
57                .map(|i| {
58                    let xi = &x[i * n_features..(i + 1) * n_features];
59                    centers.iter()
60                        .map(|c| {
61                            xi.iter().zip(c.iter())
62                                .map(|(&a, &b)| (a - b).powi(2)).sum::<f32>()
63                        })
64                        .fold(f32::MAX, f32::min)
65                })
66                .collect();
67
68            let total: f32 = distances.iter().sum();
69            if total < 1e-10 {
70                let idx = rng.gen_range(0..n_samples);
71                centers.push(x[idx * n_features..(idx + 1) * n_features].to_vec());
72                continue;
73            }
74
75            // Sample proportional to distance squared
76            let threshold = rng.gen::<f32>() * total;
77            let mut cumsum = 0.0f32;
78            for (i, &d) in distances.iter().enumerate() {
79                cumsum += d;
80                if cumsum >= threshold {
81                    centers.push(x[i * n_features..(i + 1) * n_features].to_vec());
82                    break;
83                }
84            }
85        }
86
87        // Run a few k-means iterations
88        for _ in 0..10 {
89            // Assign points to nearest center
90            let mut assignments = vec![0usize; n_samples];
91            for i in 0..n_samples {
92                let xi = &x[i * n_features..(i + 1) * n_features];
93                let mut min_dist = f32::MAX;
94                for (j, center) in centers.iter().enumerate() {
95                    let dist: f32 = xi.iter().zip(center.iter())
96                        .map(|(&a, &b)| (a - b).powi(2)).sum();
97                    if dist < min_dist {
98                        min_dist = dist;
99                        assignments[i] = j;
100                    }
101                }
102            }
103
104            // Update centers
105            let mut new_centers = vec![vec![0.0f32; n_features]; n_centers];
106            let mut counts = vec![0usize; n_centers];
107
108            for i in 0..n_samples {
109                let c = assignments[i];
110                counts[c] += 1;
111                for j in 0..n_features {
112                    new_centers[c][j] += x[i * n_features + j];
113                }
114            }
115
116            for c in 0..n_centers {
117                if counts[c] > 0 {
118                    for j in 0..n_features {
119                        new_centers[c][j] /= counts[c] as f32;
120                    }
121                } else {
122                    // Keep old center if no points assigned
123                    new_centers[c] = centers[c].clone();
124                }
125            }
126
127            centers = new_centers;
128        }
129
130        centers
131    }
132
133    pub fn fit(&mut self, x: &Tensor, y: &Tensor) {
134        let x_data = x.data_f32();
135        let y_data = y.data_f32();
136        let n_samples = x.dims()[0];
137        let n_features = x.dims()[1];
138
139        self.n_features_ = n_features;
140
141        // Initialize centers using k-means
142        let centers = self.kmeans_init(&x_data, n_samples, n_features);
143        let n_centers = centers.len();
144
145        // Compute RBF activations
146        let mut phi = vec![0.0f32; n_samples * (n_centers + 1)]; // +1 for bias
147        for i in 0..n_samples {
148            let xi = &x_data[i * n_features..(i + 1) * n_features];
149            for (j, center) in centers.iter().enumerate() {
150                phi[i * (n_centers + 1) + j] = self.rbf_kernel(xi, center);
151            }
152            phi[i * (n_centers + 1) + n_centers] = 1.0; // Bias term
153        }
154
155        // Solve least squares: (Phi^T Phi + lambda I) w = Phi^T y
156        let lambda = 1e-6f32;
157        let m = n_centers + 1;
158
159        let mut ata = vec![0.0f32; m * m];
160        let mut aty = vec![0.0f32; m];
161
162        for i in 0..m {
163            for k in 0..n_samples {
164                aty[i] += phi[k * m + i] * y_data[k];
165            }
166            for j in 0..m {
167                for k in 0..n_samples {
168                    ata[i * m + j] += phi[k * m + i] * phi[k * m + j];
169                }
170            }
171            ata[i * m + i] += lambda;
172        }
173
174        // Solve linear system
175        let weights = solve_linear_system(&ata, &aty, m);
176
177        self.centers_ = Some(centers);
178        self.weights_ = Some(weights[..n_centers].to_vec());
179        self.bias_ = weights[n_centers];
180    }
181
182    pub fn predict(&self, x: &Tensor) -> Tensor {
183        let x_data = x.data_f32();
184        let n_samples = x.dims()[0];
185        let n_features = x.dims()[1];
186
187        let centers = self.centers_.as_ref().expect("Model not fitted");
188        let weights = self.weights_.as_ref().unwrap();
189
190        let predictions: Vec<f32> = (0..n_samples)
191            .map(|i| {
192                let xi = &x_data[i * n_features..(i + 1) * n_features];
193                let mut pred = self.bias_;
194                for (j, center) in centers.iter().enumerate() {
195                    pred += weights[j] * self.rbf_kernel(xi, center);
196                }
197                pred
198            })
199            .collect();
200
201        Tensor::from_slice(&predictions, &[n_samples]).unwrap()
202    }
203
204    pub fn predict_proba(&self, x: &Tensor) -> Tensor {
205        let predictions = self.predict(x);
206        let pred_data = predictions.data_f32();
207        
208        // Apply sigmoid for probability
209        let proba: Vec<f32> = pred_data.iter()
210            .map(|&p| 1.0 / (1.0 + (-p).exp()))
211            .collect();
212
213        Tensor::from_slice(&proba, &[proba.len()]).unwrap()
214    }
215}
216
/// RBF Classifier
///
/// Classifies by fitting [`RBFNetwork`]s on 0/1 indicator targets: a single
/// network when the training labels contain exactly two classes, otherwise
/// one network per class (one-vs-rest).
pub struct RBFClassifier {
    /// Number of RBF centers requested for each underlying network.
    pub n_centers: usize,
    /// Kernel width passed to each underlying network.
    pub gamma: f32,
    /// Fitted networks; empty until `fit` is called.
    networks_: Vec<RBFNetwork>,
    /// Sorted, de-duplicated class labels (training targets cast to `i32`).
    classes_: Vec<i32>,
}
224
225impl RBFClassifier {
226    pub fn new(n_centers: usize) -> Self {
227        RBFClassifier {
228            n_centers,
229            gamma: 1.0,
230            networks_: Vec::new(),
231            classes_: Vec::new(),
232        }
233    }
234
235    pub fn gamma(mut self, g: f32) -> Self {
236        self.gamma = g;
237        self
238    }
239
240    pub fn fit(&mut self, x: &Tensor, y: &Tensor) {
241        let y_data = y.data_f32();
242
243        // Find unique classes
244        let mut classes: Vec<i32> = y_data.iter().map(|&v| v as i32).collect();
245        classes.sort();
246        classes.dedup();
247
248        if classes.len() == 2 {
249            // Binary classification
250            let mut network = RBFNetwork::new(self.n_centers);
251            network.gamma = self.gamma;
252            
253            // Convert labels to 0/1
254            let y_binary: Vec<f32> = y_data.iter()
255                .map(|&v| if v as i32 == classes[1] { 1.0 } else { 0.0 })
256                .collect();
257            let y_tensor = Tensor::from_slice(&y_binary, &[y_binary.len()]).unwrap();
258            
259            network.fit(x, &y_tensor);
260            self.networks_ = vec![network];
261        } else {
262            // One-vs-rest for multiclass
263            for &class in &classes {
264                let mut network = RBFNetwork::new(self.n_centers);
265                network.gamma = self.gamma;
266                
267                let y_binary: Vec<f32> = y_data.iter()
268                    .map(|&v| if v as i32 == class { 1.0 } else { 0.0 })
269                    .collect();
270                let y_tensor = Tensor::from_slice(&y_binary, &[y_binary.len()]).unwrap();
271                
272                network.fit(x, &y_tensor);
273                self.networks_.push(network);
274            }
275        }
276
277        self.classes_ = classes;
278    }
279
280    pub fn predict(&self, x: &Tensor) -> Tensor {
281        let n_samples = x.dims()[0];
282
283        if self.networks_.len() == 1 {
284            // Binary classification
285            let proba = self.networks_[0].predict_proba(x);
286            let proba_data = proba.data_f32();
287            
288            let predictions: Vec<f32> = proba_data.iter()
289                .map(|&p| if p >= 0.5 { self.classes_[1] as f32 } else { self.classes_[0] as f32 })
290                .collect();
291            
292            Tensor::from_slice(&predictions, &[n_samples]).unwrap()
293        } else {
294            // Multiclass: pick class with highest score
295            let scores: Vec<Vec<f32>> = self.networks_.iter()
296                .map(|net| net.predict(x).data_f32().clone())
297                .collect();
298
299            let predictions: Vec<f32> = (0..n_samples)
300                .map(|i| {
301                    let mut max_score = f32::NEG_INFINITY;
302                    let mut max_class = self.classes_[0];
303                    for (j, class) in self.classes_.iter().enumerate() {
304                        if scores[j][i] > max_score {
305                            max_score = scores[j][i];
306                            max_class = *class;
307                        }
308                    }
309                    max_class as f32
310                })
311                .collect();
312
313            Tensor::from_slice(&predictions, &[n_samples]).unwrap()
314        }
315    }
316
317    pub fn predict_proba(&self, x: &Tensor) -> Tensor {
318        let n_samples = x.dims()[0];
319        let n_classes = self.classes_.len();
320
321        if self.networks_.len() == 1 {
322            // Binary classification
323            let proba = self.networks_[0].predict_proba(x);
324            let proba_data = proba.data_f32();
325            
326            let result: Vec<f32> = proba_data.iter()
327                .flat_map(|&p| vec![1.0 - p, p])
328                .collect();
329            
330            Tensor::from_slice(&result, &[n_samples, 2]).unwrap()
331        } else {
332            // Multiclass: softmax over scores
333            let scores: Vec<Vec<f32>> = self.networks_.iter()
334                .map(|net| net.predict(x).data_f32().clone())
335                .collect();
336
337            let mut result = vec![0.0f32; n_samples * n_classes];
338            for i in 0..n_samples {
339                let max_score: f32 = (0..n_classes).map(|j| scores[j][i]).fold(f32::NEG_INFINITY, f32::max);
340                let exp_sum: f32 = (0..n_classes).map(|j| (scores[j][i] - max_score).exp()).sum();
341                
342                for j in 0..n_classes {
343                    result[i * n_classes + j] = (scores[j][i] - max_score).exp() / exp_sum;
344                }
345            }
346
347            Tensor::from_slice(&result, &[n_samples, n_classes]).unwrap()
348        }
349    }
350}
351
/// Solves the dense `n x n` system `a * x = b` by Gauss-Jordan elimination
/// with partial (row) pivoting.
///
/// `a` is row-major with `n * n` entries and `b` has `n` entries. A column
/// whose best pivot has magnitude below `1e-10` is skipped without being
/// normalised or eliminated, so the value returned for that unknown is
/// whatever remains in the right-hand side of its row.
fn solve_linear_system(a: &[f32], b: &[f32], n: usize) -> Vec<f32> {
    // Augmented matrix [a | b], row-major, `width` entries per row.
    let width = n + 1;
    let mut aug = vec![0.0f32; n * width];
    for (r, row) in aug.chunks_exact_mut(width).enumerate() {
        row[..n].copy_from_slice(&a[r * n..(r + 1) * n]);
        row[n] = b[r];
    }

    for col in 0..n {
        // Partial pivoting: first row at or below the diagonal holding the
        // largest magnitude in this column.
        let mut pivot_row = col;
        let mut best = aug[col * width + col].abs();
        for r in (col + 1)..n {
            let candidate = aug[r * width + col].abs();
            if candidate > best {
                best = candidate;
                pivot_row = r;
            }
        }

        if pivot_row != col {
            for j in 0..width {
                aug.swap(col * width + j, pivot_row * width + j);
            }
        }

        let base = col * width;
        let pivot = aug[base + col];
        if pivot.abs() < 1e-10 {
            continue;
        }

        // Scale the pivot row so the diagonal entry becomes 1.
        for value in aug[base + col..base + width].iter_mut() {
            *value /= pivot;
        }

        // Snapshot the scaled pivot row so other rows can be updated in place.
        let pivot_tail: Vec<f32> = aug[base + col..base + width].to_vec();
        for (r, row) in aug.chunks_exact_mut(width).enumerate() {
            if r == col {
                continue;
            }
            let factor = row[col];
            for (value, &p) in row[col..].iter_mut().zip(&pivot_tail) {
                *value -= factor * p;
            }
        }
    }

    // The last column now holds the solution.
    aug.chunks_exact(width).map(|row| row[n]).collect()
}
394
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the four-point XOR dataset: features of shape `[4, 2]` and
    /// targets of shape `[4]`.
    fn xor_dataset() -> (Tensor, Tensor) {
        let features = [0.0f32, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let targets = [0.0f32, 1.0, 1.0, 0.0];
        (
            Tensor::from_slice(&features, &[4, 2]).unwrap(),
            Tensor::from_slice(&targets, &[4]).unwrap(),
        )
    }

    /// Fitting and predicting with the regression network yields one output
    /// per sample.
    #[test]
    fn test_rbf_network_regression() {
        let (x, y) = xor_dataset();

        let mut rbf = RBFNetwork::new(4).gamma(1.0);
        rbf.fit(&x, &y);
        let pred = rbf.predict(&x);

        assert_eq!(pred.dims(), &[4]);
    }

    /// Fitting and predicting with the classifier yields one label per sample.
    #[test]
    fn test_rbf_classifier() {
        let (x, y) = xor_dataset();

        let mut clf = RBFClassifier::new(4).gamma(1.0);
        clf.fit(&x, &y);
        let pred = clf.predict(&x);

        assert_eq!(pred.dims(), &[4]);
    }
}
431
432