Skip to main content

oxiphysics_gpu/
neural_physics.rs

// Copyright 2026 COOLJAPAN OU (Team KitaSan)
// SPDX-License-Identifier: Apache-2.0

//! Neural network-based physics acceleration (CPU mock).
//!
//! Provides a simple feed-forward neural network with multiple activation
//! functions, forward-pass inference, MSE loss, ML force potentials, and
//! collision probability prediction — all on CPU as a GPU mock backend.

// ── Activation type ──────────────────────────────────────────────────────────

/// Activation function applied element-wise to a layer's pre-activations.
///
/// The enum is a payload-free tag, so besides `Copy` and `PartialEq` it also
/// derives `Eq` and `Hash`; this allows it to be used as a `HashMap`/`HashSet`
/// key and to be embedded in other types that derive `Eq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ActivationType {
    /// Rectified linear unit: `max(0, x)`.
    Relu,
    /// Hyperbolic tangent: `tanh(x)`.
    Tanh,
    /// Logistic sigmoid: `1 / (1 + exp(-x))`.
    Sigmoid,
    /// Identity (no non-linearity): `x`.
    Linear,
}

// ── Layer and network structs ─────────────────────────────────────────────────

27/// A single fully-connected neural network layer.
28#[derive(Debug, Clone)]
29pub struct NeuralLayer {
30    /// Weight matrix: `weights[out][in]`.
31    pub weights: Vec<Vec<f32>>,
32    /// Bias vector, one entry per output neuron.
33    pub biases: Vec<f32>,
34    /// Activation function applied after the linear transform.
35    pub activation: ActivationType,
36}
37
38/// A feed-forward neural network composed of stacked [`NeuralLayer`]s.
39#[derive(Debug, Clone)]
40pub struct NeuralNet {
41    /// The ordered list of layers.
42    pub layers: Vec<NeuralLayer>,
43    /// Expected size of the input vector.
44    pub input_size: usize,
45    /// Size of the network's output.
46    pub output_size: usize,
47}

// ── Activation functions ──────────────────────────────────────────────────────

51/// Evaluate activation function `act` at scalar `x`.
52pub fn activate(x: f32, act: &ActivationType) -> f32 {
53    match act {
54        ActivationType::Relu => x.max(0.0),
55        ActivationType::Tanh => x.tanh(),
56        ActivationType::Sigmoid => 1.0 / (1.0 + (-x).exp()),
57        ActivationType::Linear => x,
58    }
59}
60
61/// Evaluate the derivative of activation function `act` at *pre-activation* `x`.
62pub fn activate_derivative(x: f32, act: &ActivationType) -> f32 {
63    match act {
64        ActivationType::Relu => {
65            if x > 0.0 {
66                1.0
67            } else {
68                0.0
69            }
70        }
71        ActivationType::Tanh => {
72            let t = x.tanh();
73            1.0 - t * t
74        }
75        ActivationType::Sigmoid => {
76            let s = 1.0 / (1.0 + (-x).exp());
77            s * (1.0 - s)
78        }
79        ActivationType::Linear => 1.0,
80    }
81}

// ── Inference ─────────────────────────────────────────────────────────────────

85/// Run a forward pass through `net`, returning the output vector.
86///
87/// # Panics
88/// Panics in debug mode if `input.len() != net.input_size`.
89pub fn forward_pass(net: &NeuralNet, input: &[f32]) -> Vec<f32> {
90    debug_assert_eq!(input.len(), net.input_size);
91    let mut current: Vec<f32> = input.to_vec();
92    for layer in &net.layers {
93        let n_out = layer.biases.len();
94        let mut next = Vec::with_capacity(n_out);
95        for o in 0..n_out {
96            let mut sum = layer.biases[o];
97            for (i, &inp) in current.iter().enumerate() {
98                if i < layer.weights[o].len() {
99                    sum += layer.weights[o][i] * inp;
100                }
101            }
102            next.push(activate(sum, &layer.activation));
103        }
104        current = next;
105    }
106    current
107}

// ── Loss ──────────────────────────────────────────────────────────────────────

/// Mean squared error: the average of `(predicted[i] - target[i])^2`.
///
/// Degenerate inputs are not treated as errors: when the slices are empty or
/// have different lengths the function returns `0.0`.
pub fn mse_loss(predicted: &[f32], target: &[f32]) -> f32 {
    let count = predicted.len();
    if count == 0 || count != target.len() {
        return 0.0;
    }
    let squared_error_sum: f32 = predicted
        .iter()
        .zip(target)
        .map(|(&p, &t)| {
            let diff = p - t;
            diff * diff
        })
        .sum();
    squared_error_sum / count as f32
}

// ── Physics applications ──────────────────────────────────────────────────────

129/// Predict interatomic forces using a neural network potential.
130///
131/// For each atom, concatenates its position `[x, y, z]` with its type index,
132/// runs a forward pass, and interprets the first three output components as the
133/// predicted force `[fx, fy, fz]`.
134#[allow(clippy::too_many_arguments)]
135pub fn neural_force_prediction(
136    net: &NeuralNet,
137    positions: &[[f32; 3]],
138    types: &[u32],
139) -> Vec<[f32; 3]> {
140    positions
141        .iter()
142        .zip(types.iter())
143        .map(|(pos, &atom_type)| {
144            let mut inp = Vec::with_capacity(net.input_size);
145            inp.push(pos[0]);
146            inp.push(pos[1]);
147            inp.push(pos[2]);
148            inp.push(atom_type as f32);
149            // Pad or truncate to net.input_size
150            inp.resize(net.input_size, 0.0);
151            let out = forward_pass(net, &inp);
152            let fx = out.first().copied().unwrap_or(0.0);
153            let fy = out.get(1).copied().unwrap_or(0.0);
154            let fz = out.get(2).copied().unwrap_or(0.0);
155            [fx, fy, fz]
156        })
157        .collect()
158}
159
160/// Predict collision probability between two spheres using a neural network.
161///
162/// Input features: relative displacement `[dx, dy, dz]`, radii `[ra, rb]`.
163/// Returns a scalar in `[0, 1]`.
164pub fn neural_collision_check(
165    net: &NeuralNet,
166    pos_a: [f32; 3],
167    pos_b: [f32; 3],
168    radii: [f32; 2],
169) -> f32 {
170    let dx = pos_b[0] - pos_a[0];
171    let dy = pos_b[1] - pos_a[1];
172    let dz = pos_b[2] - pos_a[2];
173    let mut inp = vec![dx, dy, dz, radii[0], radii[1]];
174    inp.resize(net.input_size, 0.0);
175    let out = forward_pass(net, &inp);
176    // Clamp output to [0, 1]
177    out.first().copied().unwrap_or(0.0).clamp(0.0, 1.0)
178}
179
180/// Run a batched GPU-style forward pass for multiple input vectors.
181pub fn gpu_neural_batch_forward(net: &NeuralNet, batch: &[Vec<f32>]) -> Vec<Vec<f32>> {
182    batch.iter().map(|inp| forward_pass(net, inp)).collect()
183}

// ── Network construction ──────────────────────────────────────────────────────

187/// Create a fully-connected network with the given layer sizes and random weights.
188///
189/// `layer_sizes` must contain at least 2 entries (input + output).
190/// All hidden layers use `activation`; the output layer uses `Linear`.
191pub fn create_network(layer_sizes: &[usize], activation: ActivationType) -> NeuralNet {
192    use rand::RngExt;
193    assert!(
194        layer_sizes.len() >= 2,
195        "Need at least input and output sizes"
196    );
197
198    let mut rng = rand::rng();
199    let mut layers = Vec::new();
200
201    for i in 0..layer_sizes.len() - 1 {
202        let n_in = layer_sizes[i];
203        let n_out = layer_sizes[i + 1];
204        let is_last = i == layer_sizes.len() - 2;
205        let act = if is_last {
206            ActivationType::Linear
207        } else {
208            activation
209        };
210
211        let scale = (2.0_f32 / n_in as f32).sqrt();
212        let weights: Vec<Vec<f32>> = (0..n_out)
213            .map(|_| (0..n_in).map(|_| rng.random_range(-scale..scale)).collect())
214            .collect();
215        let biases: Vec<f32> = (0..n_out).map(|_| 0.0_f32).collect();
216        layers.push(NeuralLayer {
217            weights,
218            biases,
219            activation: act,
220        });
221    }
222
223    NeuralNet {
224        input_size: layer_sizes[0],
225        output_size: *layer_sizes.last().expect("collection should not be empty"),
226        layers,
227    }
228}

// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// Default absolute tolerance for scalar comparisons.
    const TOL: f32 = 1e-6;

    /// Passes when `actual` is strictly closer than `tol` to `expected`.
    fn assert_close(actual: f32, expected: f32, tol: f32) {
        assert!((actual - expected).abs() < tol);
    }

    /// Passes when `p` lies in the closed unit interval.
    fn assert_probability(p: f32) {
        assert!((0.0..=1.0).contains(&p));
    }

    /// Small 2 → 3 → 1 ReLU network with random weights.
    fn simple_net() -> NeuralNet {
        create_network(&[2, 3, 1], ActivationType::Relu)
    }

    // ── activate ─────────────────────────────────────────────────────────────

    #[test]
    fn test_activate_relu_positive() {
        assert_close(activate(2.0, &ActivationType::Relu), 2.0, TOL);
    }

    #[test]
    fn test_activate_relu_negative() {
        assert_close(activate(-1.0, &ActivationType::Relu), 0.0, TOL);
    }

    #[test]
    fn test_activate_relu_zero() {
        assert_close(activate(0.0, &ActivationType::Relu), 0.0, TOL);
    }

    #[test]
    fn test_activate_tanh_zero() {
        assert_close(activate(0.0, &ActivationType::Tanh), 0.0, TOL);
    }

    #[test]
    fn test_activate_sigmoid_zero() {
        assert_close(activate(0.0, &ActivationType::Sigmoid), 0.5, TOL);
    }

    #[test]
    fn test_activate_linear() {
        assert_close(activate(3.125, &ActivationType::Linear), 3.125, TOL);
    }

    // ── activate_derivative ──────────────────────────────────────────────────

    #[test]
    fn test_activate_derivative_relu_positive() {
        assert_close(activate_derivative(1.0, &ActivationType::Relu), 1.0, TOL);
    }

    #[test]
    fn test_activate_derivative_relu_negative() {
        assert_close(activate_derivative(-1.0, &ActivationType::Relu), 0.0, TOL);
    }

    #[test]
    fn test_activate_derivative_tanh_zero() {
        assert_close(activate_derivative(0.0, &ActivationType::Tanh), 1.0, TOL);
    }

    #[test]
    fn test_activate_derivative_sigmoid_zero() {
        assert_close(
            activate_derivative(0.0, &ActivationType::Sigmoid),
            0.25,
            1e-5,
        );
    }

    #[test]
    fn test_activate_derivative_linear() {
        assert_close(activate_derivative(99.0, &ActivationType::Linear), 1.0, TOL);
    }

    // ── mse_loss ─────────────────────────────────────────────────────────────

    #[test]
    fn test_mse_loss_zero() {
        let values = [1.0_f32, 2.0, 3.0];
        assert_close(mse_loss(&values, &values), 0.0, TOL);
    }

    #[test]
    fn test_mse_loss_known() {
        assert_close(mse_loss(&[0.0, 0.0], &[1.0, 1.0]), 1.0, TOL);
    }

    #[test]
    fn test_mse_loss_empty() {
        assert_close(mse_loss(&[], &[]), 0.0, TOL);
    }

    #[test]
    fn test_mse_loss_length_mismatch() {
        assert_close(mse_loss(&[1.0], &[1.0, 2.0]), 0.0, TOL);
    }

    // ── create_network ───────────────────────────────────────────────────────

    #[test]
    fn test_create_network_sizes() {
        let net = create_network(&[4, 8, 8, 3], ActivationType::Relu);
        assert_eq!(
            (net.input_size, net.output_size, net.layers.len()),
            (4, 3, 3)
        );
    }

    #[test]
    fn test_create_network_layer_dims() {
        let net = create_network(&[3, 5, 2], ActivationType::Tanh);
        let (hidden, output) = (&net.layers[0], &net.layers[1]);
        assert_eq!(hidden.weights.len(), 5);
        assert_eq!(hidden.weights[0].len(), 3);
        assert_eq!(output.weights.len(), 2);
        assert_eq!(output.weights[0].len(), 5);
    }

    #[test]
    fn test_create_network_output_activation_linear() {
        let net = create_network(&[2, 4, 1], ActivationType::Relu);
        let last_activation = net.layers.last().map(|layer| layer.activation);
        assert_eq!(last_activation, Some(ActivationType::Linear));
    }

    // ── forward_pass ─────────────────────────────────────────────────────────

    #[test]
    fn test_forward_pass_output_size() {
        let out = forward_pass(&simple_net(), &[0.5, -0.3]);
        assert_eq!(out.len(), 1);
    }

    #[test]
    fn test_forward_pass_deterministic() {
        // The same network evaluated twice on the same input must agree exactly.
        let net = simple_net();
        let first = forward_pass(&net, &[1.0, 0.0]);
        let second = forward_pass(&net, &[1.0, 0.0]);
        assert_eq!(first, second);
    }

    #[test]
    fn test_forward_pass_zero_input() {
        let out = forward_pass(&simple_net(), &[0.0, 0.0]);
        assert_eq!(out.len(), 1);
    }

    #[test]
    fn test_forward_pass_sigmoid_net() {
        let net = create_network(&[2, 2, 1], ActivationType::Sigmoid);
        let out = forward_pass(&net, &[0.0, 0.0]);
        // The output layer is Linear, so only finiteness is checked here.
        assert!(out[0].is_finite());
    }

    // ── physics helpers ──────────────────────────────────────────────────────

    #[test]
    fn test_neural_force_prediction_shape() {
        let net = create_network(&[4, 8, 3], ActivationType::Relu);
        let positions = [[1.0_f32, 0.0, 0.0], [0.0, 1.0, 0.0]];
        let types = [0_u32, 1];
        let forces = neural_force_prediction(&net, &positions, &types);
        assert_eq!(forces.len(), 2);
    }

    #[test]
    fn test_neural_force_prediction_finite() {
        let net = create_network(&[4, 6, 3], ActivationType::Tanh);
        let forces = neural_force_prediction(&net, &[[0.0_f32; 3]], &[0_u32]);
        assert!(forces[0].iter().all(|component| component.is_finite()));
    }

    #[test]
    fn test_neural_collision_check_range() {
        let net = create_network(&[5, 4, 1], ActivationType::Sigmoid);
        let prob = neural_collision_check(&net, [0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.5, 0.5]);
        assert_probability(prob);
    }

    #[test]
    fn test_neural_collision_check_zero_sep() {
        let net = create_network(&[5, 4, 1], ActivationType::Sigmoid);
        let prob = neural_collision_check(&net, [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [1.0, 1.0]);
        assert_probability(prob);
    }

    // ── batched inference ────────────────────────────────────────────────────

    #[test]
    fn test_gpu_neural_batch_forward_shape() {
        let net = create_network(&[3, 4, 2], ActivationType::Relu);
        let batch: Vec<Vec<f32>> = vec![
            vec![1.0, 2.0, 3.0],
            vec![0.0, 0.0, 0.0],
            vec![-1.0, 0.5, 0.1],
        ];
        let results = gpu_neural_batch_forward(&net, &batch);
        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|row| row.len() == 2));
    }

    #[test]
    fn test_gpu_neural_batch_forward_empty() {
        let net = create_network(&[2, 2, 1], ActivationType::Linear);
        assert!(gpu_neural_batch_forward(&net, &[]).is_empty());
    }

    // ── additional edge cases ────────────────────────────────────────────────

    #[test]
    fn test_create_network_two_layers() {
        let net = create_network(&[1, 1], ActivationType::Linear);
        assert_eq!(
            (net.layers.len(), net.input_size, net.output_size),
            (1, 1, 1)
        );
    }

    #[test]
    fn test_network_weights_finite() {
        let net = create_network(&[5, 10, 3], ActivationType::Relu);
        let all_finite = net
            .layers
            .iter()
            .flat_map(|layer| layer.weights.iter())
            .flatten()
            .all(|weight| weight.is_finite());
        assert!(all_finite);
    }

    #[test]
    fn test_forward_pass_tanh_bounded() {
        let net = create_network(&[2, 4, 1], ActivationType::Tanh);
        let out = forward_pass(&net, &[100.0, -100.0]);
        // Hidden tanh units saturate in [-1, 1], so the linear output stays finite.
        assert!(out[0].is_finite());
    }

    #[test]
    fn test_mse_loss_asymmetric() {
        // (2² + 2²) / 2 = 4
        assert_close(mse_loss(&[2.0_f32, 0.0], &[0.0_f32, 2.0]), 4.0, 1e-5);
    }

    #[test]
    fn test_neural_force_empty_input() {
        let net = create_network(&[4, 4, 3], ActivationType::Linear);
        assert!(neural_force_prediction(&net, &[], &[]).is_empty());
    }

    #[test]
    fn test_batch_forward_single_item() {
        let net = create_network(&[2, 3, 1], ActivationType::Relu);
        let out = gpu_neural_batch_forward(&net, &[vec![0.5_f32, -0.5]]);
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].len(), 1);
    }
}