Skip to main content

oxiphysics_gpu/kernels/md_force/
virialstresstensorkernel_traits.rs

1//! # VirialStressTensorKernel - Trait Implementations
2//!
3//! This module contains trait implementations for `VirialStressTensorKernel`.
4//!
5//! ## Implemented Traits
6//!
7//! - `ComputeKernel`
8//!
9//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
10
11#[allow(unused_imports)]
12use super::functions::*;
13use crate::compute::ComputeKernel;
14
15use super::types::VirialStressTensorKernel;
16
17#[allow(clippy::needless_range_loop)]
18impl ComputeKernel for VirialStressTensorKernel {
19    fn name(&self) -> &str {
20        "VirialStressTensorKernel"
21    }
22    fn execute(&self, inputs: &[&[f64]], outputs: &mut [Vec<f64>], work_size: usize) {
23        if inputs.len() < 2 || outputs.is_empty() {
24            return;
25        }
26        let pos = inputs[0];
27        let epsilon = inputs[1][0];
28        let sigma = inputs[1][1];
29        let cutoff = inputs[1][2];
30        let n = work_size;
31        let cutoff2 = cutoff * cutoff;
32        let mut w = [0.0f64; 6];
33        for i in 0..n {
34            for j in (i + 1)..n {
35                let dx = pos[i * 3] - pos[j * 3];
36                let dy = pos[i * 3 + 1] - pos[j * 3 + 1];
37                let dz = pos[i * 3 + 2] - pos[j * 3 + 2];
38                let r2 = dx * dx + dy * dy + dz * dz;
39                if r2 >= cutoff2 || r2 < 1e-30 {
40                    continue;
41                }
42                let r2_inv = 1.0 / r2;
43                let sr2 = sigma * sigma * r2_inv;
44                let sr6 = sr2 * sr2 * sr2;
45                let sr12 = sr6 * sr6;
46                let f_over_r2 = 24.0 * epsilon * (2.0 * sr12 - sr6) * r2_inv;
47                let r_vec = [dx, dy, dz];
48                let c = [(0usize, 0usize), (1, 1), (2, 2), (0, 1), (0, 2), (1, 2)];
49                for (ci, &(a, b)) in c.iter().enumerate() {
50                    w[ci] -= r_vec[a] * f_over_r2 * r_vec[b];
51                }
52            }
53        }
54        outputs[0] = w.to_vec();
55    }
56}