1use ghostflow_core::Tensor;
4use rand::prelude::*;
5
/// Radial-basis-function network: Gaussian kernels around a set of centers,
/// combined by a linear output layer with a bias term.
///
/// The public fields are hyperparameters. The fields with a trailing
/// underscore hold the state written by `fit`.
#[derive(Debug, Clone)]
pub struct RBFNetwork {
    /// Requested number of RBF centers; fitting caps it at the number of
    /// training samples.
    pub n_centers: usize,
    /// Kernel coefficient used as `exp(-gamma * squared_distance)`.
    pub gamma: f32,
    /// Iteration budget.
    /// NOTE(review): not read by any code visible in this file.
    pub max_iter: usize,
    /// Convergence tolerance.
    /// NOTE(review): not read by any code visible in this file.
    pub tol: f32,
    /// Centers selected during fitting, one `Vec` of feature values per
    /// center; `None` until the model has been fitted.
    centers_: Option<Vec<Vec<f32>>>,
    /// Output-layer weight for each center; `None` until fitted.
    weights_: Option<Vec<f32>>,
    /// Output-layer bias (0.0 until fitted).
    bias_: f32,
    /// Number of feature columns in the training input (0 until fitted).
    n_features_: usize,
}
17
18impl RBFNetwork {
19 pub fn new(n_centers: usize) -> Self {
20 RBFNetwork {
21 n_centers,
22 gamma: 1.0,
23 max_iter: 100,
24 tol: 1e-4,
25 centers_: None,
26 weights_: None,
27 bias_: 0.0,
28 n_features_: 0,
29 }
30 }
31
32 pub fn gamma(mut self, g: f32) -> Self {
33 self.gamma = g;
34 self
35 }
36
37 fn rbf_kernel(&self, x: &[f32], center: &[f32]) -> f32 {
38 let sq_dist: f32 = x.iter().zip(center.iter())
39 .map(|(&a, &b)| (a - b).powi(2)).sum();
40 (-self.gamma * sq_dist).exp()
41 }
42
43 fn kmeans_init(&self, x: &[f32], n_samples: usize, n_features: usize) -> Vec<Vec<f32>> {
44 let mut rng = thread_rng();
45 let n_centers = self.n_centers.min(n_samples);
46
47 let mut centers = Vec::with_capacity(n_centers);
49
50 let first_idx = rng.gen_range(0..n_samples);
52 centers.push(x[first_idx * n_features..(first_idx + 1) * n_features].to_vec());
53
54 for _ in 1..n_centers {
56 let distances: Vec<f32> = (0..n_samples)
57 .map(|i| {
58 let xi = &x[i * n_features..(i + 1) * n_features];
59 centers.iter()
60 .map(|c| {
61 xi.iter().zip(c.iter())
62 .map(|(&a, &b)| (a - b).powi(2)).sum::<f32>()
63 })
64 .fold(f32::MAX, f32::min)
65 })
66 .collect();
67
68 let total: f32 = distances.iter().sum();
69 if total < 1e-10 {
70 let idx = rng.gen_range(0..n_samples);
71 centers.push(x[idx * n_features..(idx + 1) * n_features].to_vec());
72 continue;
73 }
74
75 let threshold = rng.gen::<f32>() * total;
77 let mut cumsum = 0.0f32;
78 for (i, &d) in distances.iter().enumerate() {
79 cumsum += d;
80 if cumsum >= threshold {
81 centers.push(x[i * n_features..(i + 1) * n_features].to_vec());
82 break;
83 }
84 }
85 }
86
87 for _ in 0..10 {
89 let mut assignments = vec![0usize; n_samples];
91 for i in 0..n_samples {
92 let xi = &x[i * n_features..(i + 1) * n_features];
93 let mut min_dist = f32::MAX;
94 for (j, center) in centers.iter().enumerate() {
95 let dist: f32 = xi.iter().zip(center.iter())
96 .map(|(&a, &b)| (a - b).powi(2)).sum();
97 if dist < min_dist {
98 min_dist = dist;
99 assignments[i] = j;
100 }
101 }
102 }
103
104 let mut new_centers = vec![vec![0.0f32; n_features]; n_centers];
106 let mut counts = vec![0usize; n_centers];
107
108 for i in 0..n_samples {
109 let c = assignments[i];
110 counts[c] += 1;
111 for j in 0..n_features {
112 new_centers[c][j] += x[i * n_features + j];
113 }
114 }
115
116 for c in 0..n_centers {
117 if counts[c] > 0 {
118 for j in 0..n_features {
119 new_centers[c][j] /= counts[c] as f32;
120 }
121 } else {
122 new_centers[c] = centers[c].clone();
124 }
125 }
126
127 centers = new_centers;
128 }
129
130 centers
131 }
132
133 pub fn fit(&mut self, x: &Tensor, y: &Tensor) {
134 let x_data = x.data_f32();
135 let y_data = y.data_f32();
136 let n_samples = x.dims()[0];
137 let n_features = x.dims()[1];
138
139 self.n_features_ = n_features;
140
141 let centers = self.kmeans_init(&x_data, n_samples, n_features);
143 let n_centers = centers.len();
144
145 let mut phi = vec![0.0f32; n_samples * (n_centers + 1)]; for i in 0..n_samples {
148 let xi = &x_data[i * n_features..(i + 1) * n_features];
149 for (j, center) in centers.iter().enumerate() {
150 phi[i * (n_centers + 1) + j] = self.rbf_kernel(xi, center);
151 }
152 phi[i * (n_centers + 1) + n_centers] = 1.0; }
154
155 let lambda = 1e-6f32;
157 let m = n_centers + 1;
158
159 let mut ata = vec![0.0f32; m * m];
160 let mut aty = vec![0.0f32; m];
161
162 for i in 0..m {
163 for k in 0..n_samples {
164 aty[i] += phi[k * m + i] * y_data[k];
165 }
166 for j in 0..m {
167 for k in 0..n_samples {
168 ata[i * m + j] += phi[k * m + i] * phi[k * m + j];
169 }
170 }
171 ata[i * m + i] += lambda;
172 }
173
174 let weights = solve_linear_system(&ata, &aty, m);
176
177 self.centers_ = Some(centers);
178 self.weights_ = Some(weights[..n_centers].to_vec());
179 self.bias_ = weights[n_centers];
180 }
181
182 pub fn predict(&self, x: &Tensor) -> Tensor {
183 let x_data = x.data_f32();
184 let n_samples = x.dims()[0];
185 let n_features = x.dims()[1];
186
187 let centers = self.centers_.as_ref().expect("Model not fitted");
188 let weights = self.weights_.as_ref().unwrap();
189
190 let predictions: Vec<f32> = (0..n_samples)
191 .map(|i| {
192 let xi = &x_data[i * n_features..(i + 1) * n_features];
193 let mut pred = self.bias_;
194 for (j, center) in centers.iter().enumerate() {
195 pred += weights[j] * self.rbf_kernel(xi, center);
196 }
197 pred
198 })
199 .collect();
200
201 Tensor::from_slice(&predictions, &[n_samples]).unwrap()
202 }
203
204 pub fn predict_proba(&self, x: &Tensor) -> Tensor {
205 let predictions = self.predict(x);
206 let pred_data = predictions.data_f32();
207
208 let proba: Vec<f32> = pred_data.iter()
210 .map(|&p| 1.0 / (1.0 + (-p).exp()))
211 .collect();
212
213 Tensor::from_slice(&proba, &[proba.len()]).unwrap()
214 }
215}
216
/// Classifier built from one or more [`RBFNetwork`] regressors.
pub struct RBFClassifier {
    /// Number of centers requested for every underlying network.
    pub n_centers: usize,
    /// Kernel coefficient copied into every underlying network.
    pub gamma: f32,
    /// Trained networks: a single one when two classes were seen during
    /// fitting, otherwise one per class in the order of `classes_`.
    networks_: Vec<RBFNetwork>,
    /// Sorted, de-duplicated class labels (targets cast to `i32`).
    classes_: Vec<i32>,
}
224
225impl RBFClassifier {
226 pub fn new(n_centers: usize) -> Self {
227 RBFClassifier {
228 n_centers,
229 gamma: 1.0,
230 networks_: Vec::new(),
231 classes_: Vec::new(),
232 }
233 }
234
235 pub fn gamma(mut self, g: f32) -> Self {
236 self.gamma = g;
237 self
238 }
239
240 pub fn fit(&mut self, x: &Tensor, y: &Tensor) {
241 let y_data = y.data_f32();
242
243 let mut classes: Vec<i32> = y_data.iter().map(|&v| v as i32).collect();
245 classes.sort();
246 classes.dedup();
247
248 if classes.len() == 2 {
249 let mut network = RBFNetwork::new(self.n_centers);
251 network.gamma = self.gamma;
252
253 let y_binary: Vec<f32> = y_data.iter()
255 .map(|&v| if v as i32 == classes[1] { 1.0 } else { 0.0 })
256 .collect();
257 let y_tensor = Tensor::from_slice(&y_binary, &[y_binary.len()]).unwrap();
258
259 network.fit(x, &y_tensor);
260 self.networks_ = vec![network];
261 } else {
262 for &class in &classes {
264 let mut network = RBFNetwork::new(self.n_centers);
265 network.gamma = self.gamma;
266
267 let y_binary: Vec<f32> = y_data.iter()
268 .map(|&v| if v as i32 == class { 1.0 } else { 0.0 })
269 .collect();
270 let y_tensor = Tensor::from_slice(&y_binary, &[y_binary.len()]).unwrap();
271
272 network.fit(x, &y_tensor);
273 self.networks_.push(network);
274 }
275 }
276
277 self.classes_ = classes;
278 }
279
280 pub fn predict(&self, x: &Tensor) -> Tensor {
281 let n_samples = x.dims()[0];
282
283 if self.networks_.len() == 1 {
284 let proba = self.networks_[0].predict_proba(x);
286 let proba_data = proba.data_f32();
287
288 let predictions: Vec<f32> = proba_data.iter()
289 .map(|&p| if p >= 0.5 { self.classes_[1] as f32 } else { self.classes_[0] as f32 })
290 .collect();
291
292 Tensor::from_slice(&predictions, &[n_samples]).unwrap()
293 } else {
294 let scores: Vec<Vec<f32>> = self.networks_.iter()
296 .map(|net| net.predict(x).data_f32().clone())
297 .collect();
298
299 let predictions: Vec<f32> = (0..n_samples)
300 .map(|i| {
301 let mut max_score = f32::NEG_INFINITY;
302 let mut max_class = self.classes_[0];
303 for (j, class) in self.classes_.iter().enumerate() {
304 if scores[j][i] > max_score {
305 max_score = scores[j][i];
306 max_class = *class;
307 }
308 }
309 max_class as f32
310 })
311 .collect();
312
313 Tensor::from_slice(&predictions, &[n_samples]).unwrap()
314 }
315 }
316
317 pub fn predict_proba(&self, x: &Tensor) -> Tensor {
318 let n_samples = x.dims()[0];
319 let n_classes = self.classes_.len();
320
321 if self.networks_.len() == 1 {
322 let proba = self.networks_[0].predict_proba(x);
324 let proba_data = proba.data_f32();
325
326 let result: Vec<f32> = proba_data.iter()
327 .flat_map(|&p| vec![1.0 - p, p])
328 .collect();
329
330 Tensor::from_slice(&result, &[n_samples, 2]).unwrap()
331 } else {
332 let scores: Vec<Vec<f32>> = self.networks_.iter()
334 .map(|net| net.predict(x).data_f32().clone())
335 .collect();
336
337 let mut result = vec![0.0f32; n_samples * n_classes];
338 for i in 0..n_samples {
339 let max_score: f32 = (0..n_classes).map(|j| scores[j][i]).fold(f32::NEG_INFINITY, f32::max);
340 let exp_sum: f32 = (0..n_classes).map(|j| (scores[j][i] - max_score).exp()).sum();
341
342 for j in 0..n_classes {
343 result[i * n_classes + j] = (scores[j][i] - max_score).exp() / exp_sum;
344 }
345 }
346
347 Tensor::from_slice(&result, &[n_samples, n_classes]).unwrap()
348 }
349 }
350}
351
/// Solves the dense `n x n` system `a * x = b` by Gauss-Jordan elimination
/// with partial (row) pivoting.
///
/// `a` is row-major and must hold at least `n * n` values; `b` must hold at
/// least `n` values. Returns the `n` unknowns.
///
/// A column whose best available pivot has magnitude below `1e-10` is treated
/// as a free variable and its unknown is returned as `0.0`. Previously the
/// leftover right-hand-side entry of that row leaked into the result.
///
/// # Panics
///
/// Panics if `a` or `b` is shorter than described above.
fn solve_linear_system(a: &[f32], b: &[f32], n: usize) -> Vec<f32> {
    const SINGULAR_EPS: f32 = 1e-10;

    if n == 0 {
        return Vec::new();
    }
    assert!(
        a.len() >= n * n && b.len() >= n,
        "solve_linear_system: expected at least {} matrix and {} rhs values, got {} and {}",
        n * n,
        n,
        a.len(),
        b.len()
    );

    // Augmented matrix [a | b], row-major with `width` columns.
    let width = n + 1;
    let mut aug: Vec<f32> = Vec::with_capacity(n * width);
    for (row, &rhs) in a.chunks_exact(n).take(n).zip(b.iter()) {
        aug.extend_from_slice(row);
        aug.push(rhs);
    }

    let mut is_free = vec![false; n];

    for col in 0..n {
        // Partial pivoting: pick the row at or below the diagonal with the
        // largest magnitude in this column (the first such row wins ties).
        let mut pivot_row = col;
        for candidate in (col + 1)..n {
            if aug[candidate * width + col].abs() > aug[pivot_row * width + col].abs() {
                pivot_row = candidate;
            }
        }
        if pivot_row != col {
            for j in 0..width {
                aug.swap(col * width + j, pivot_row * width + j);
            }
        }

        let pivot = aug[col * width + col];
        if pivot.abs() < SINGULAR_EPS {
            is_free[col] = true;
            continue;
        }

        // Normalise the pivot row; entries left of `col` are already zero.
        for value in &mut aug[col * width + col..(col + 1) * width] {
            *value /= pivot;
        }
        let pivot_tail: Vec<f32> = aug[col * width + col..(col + 1) * width].to_vec();

        // Eliminate this column from every other row.
        for row in (0..n).filter(|&r| r != col) {
            let factor = aug[row * width + col];
            let target = &mut aug[row * width + col..(row + 1) * width];
            for (value, &p) in target.iter_mut().zip(pivot_tail.iter()) {
                *value -= factor * p;
            }
        }
    }

    (0..n)
        .map(|i| if is_free[i] { 0.0 } else { aug[i * width + n] })
        .collect()
}
394
#[cfg(test)]
mod tests {
    use super::*;

    /// Four 2-D points at the corners of the unit square with XOR-style
    /// targets, shared by both smoke tests.
    fn xor_data() -> (Tensor, Tensor) {
        let features = Tensor::from_slice(&[
            0.0, 0.0,
            1.0, 0.0,
            0.0, 1.0,
            1.0, 1.0,
        ], &[4, 2]).unwrap();
        let targets = Tensor::from_slice(&[0.0, 1.0, 1.0, 0.0], &[4]).unwrap();
        (features, targets)
    }

    /// Fitting and predicting with the regressor yields one value per sample.
    #[test]
    fn test_rbf_network_regression() {
        let (x, y) = xor_data();

        let mut model = RBFNetwork::new(4).gamma(1.0);
        model.fit(&x, &y);
        let output = model.predict(&x);

        assert_eq!(output.dims(), &[4]);
    }

    /// Fitting and predicting with the classifier yields one label per sample.
    #[test]
    fn test_rbf_classifier() {
        let (x, y) = xor_data();

        let mut model = RBFClassifier::new(4).gamma(1.0);
        model.fit(&x, &y);
        let output = model.predict(&x);

        assert_eq!(output.dims(), &[4]);
    }
}
432}