use rand::{RngExt, SeedableRng};
use rand_chacha::ChaCha8Rng;
use rayon::prelude::*;
/// One full turn in radians (2π). Phases are wrapped with `rem_euclid(TWO_PI)`,
/// and Box–Muller angles are drawn as `TWO_PI * u` for a uniform `u`.
const TWO_PI: f64 = 2.0 * std::f64::consts::PI;
/// State and scratch buffers for a Kuramoto-style network of `n` phase oscillators.
///
/// Construct it with `KuramotoSolver::new`, which validates all lengths and values.
/// The public fields can be mutated directly, but that bypasses those checks.
#[derive(Debug, Clone)]
pub struct KuramotoSolver {
    /// Number of oscillators (`omega.len()` at construction).
    pub n: usize,
    /// Cached `1.0 / n`; scales the `coupling` sum in each step.
    pub n_inv: f64,
    /// Natural frequency of each oscillator (constant drift term).
    pub omega: Vec<f64>,
    /// Row-major `n × n` matrix. Entry `i * n + j` multiplies `sin(θ_j − θ_i)`
    /// in the derivative of oscillator `i`.
    pub coupling: Vec<f64>,
    /// Current phases. Each step re-wraps them with `rem_euclid(2π)`.
    pub phases: Vec<f64>,
    /// Scale applied to the per-step standard-normal noise samples.
    pub noise_amp: f64,
    /// Coefficient of the `cos θ_i` term used by `step_ssgf` (not by `step`).
    pub field_pressure: f64,
    /// Scratch: per-oscillator phase derivative for the current step.
    dtheta: Vec<f64>,
    /// Scratch: row-major `n × n`, entry `i * n + j` holds `sin(θ_j − θ_i)`.
    sin_diff: Vec<f64>,
    /// Scratch: per-oscillator standard-normal sample (zeros when noise is off).
    noise: Vec<f64>,
    /// Scratch: `cos θ_i` when `field_pressure != 0`, zeros otherwise.
    cos_theta: Vec<f64>,
    /// Scratch: `sigma_g`-scaled `w_flat`-weighted sin sums from `step_ssgf`.
    geo_coupling: Vec<f64>,
    /// Scratch: `pgbo_weight`-scaled `h_flat`-weighted sin sums from `step_ssgf`.
    pgbo_coupling: Vec<f64>,
}
impl KuramotoSolver {
pub fn new(
omega: Vec<f64>,
coupling_flat: Vec<f64>,
initial_phases: Vec<f64>,
noise_amp: f64,
) -> Self {
let n = omega.len();
assert!(n > 0, "omega must not be empty");
assert_eq!(
initial_phases.len(),
n,
"initial_phases length mismatch: got {}, expected {}",
initial_phases.len(),
n
);
assert_eq!(
coupling_flat.len(),
n * n,
"coupling length mismatch: got {}, expected {}",
coupling_flat.len(),
n * n
);
assert_all_finite("omega", &omega);
assert_all_finite("coupling", &coupling_flat);
assert_all_finite("initial_phases", &initial_phases);
assert!(
noise_amp.is_finite() && noise_amp >= 0.0,
"noise_amp must be finite and non-negative"
);
Self {
n,
n_inv: 1.0 / n as f64,
omega,
coupling: coupling_flat,
phases: initial_phases,
noise_amp,
field_pressure: 0.0,
dtheta: vec![0.0; n],
sin_diff: vec![0.0; n * n],
noise: vec![0.0; n],
cos_theta: vec![0.0; n],
geo_coupling: vec![0.0; n],
pgbo_coupling: vec![0.0; n],
}
}
pub fn set_field_pressure(&mut self, f: f64) {
assert!(f.is_finite(), "field_pressure must be finite");
self.field_pressure = f;
}
pub fn step(&mut self, dt: f64, seed: u64) -> f64 {
assert_dt(dt);
let n = self.n;
let phases = &self.phases;
self.sin_diff
.par_chunks_mut(n)
.enumerate()
.for_each(|(row_idx, row)| {
let theta_n = phases[row_idx];
for (col_idx, value) in row.iter_mut().enumerate() {
*value = (phases[col_idx] - theta_n).sin();
}
});
if seed == 0 || self.noise_amp == 0.0 {
self.noise.fill(0.0);
} else {
fill_standard_normals(&mut self.noise, seed);
}
self.dtheta
.par_iter_mut()
.enumerate()
.for_each(|(row_idx, dtheta_n)| {
let coupling_row = &self.coupling[row_idx * n..(row_idx + 1) * n];
let sin_row = &self.sin_diff[row_idx * n..(row_idx + 1) * n];
let coupling_sum = crate::simd::dot_f64_dispatch(coupling_row, sin_row);
*dtheta_n = self.omega[row_idx]
+ coupling_sum * self.n_inv
+ self.noise_amp * self.noise[row_idx];
});
for (phase, dtheta) in self.phases.iter_mut().zip(self.dtheta.iter()) {
*phase = (*phase + dtheta * dt).rem_euclid(TWO_PI);
}
self.order_parameter()
}
pub fn run(&mut self, n_steps: usize, dt: f64, seed: u64) -> Vec<f64> {
assert_dt(dt);
let mut order_values = Vec::with_capacity(n_steps);
for step_idx in 0..n_steps {
let step_seed = if seed == 0 {
0
} else {
seed.wrapping_add(step_idx as u64)
};
order_values.push(self.step(dt, step_seed));
}
order_values
}
#[allow(clippy::too_many_arguments)]
pub fn step_ssgf(
&mut self,
dt: f64,
seed: u64,
w_flat: &[f64],
sigma_g: f64,
h_flat: &[f64],
pgbo_weight: f64,
) -> f64 {
assert_dt(dt);
assert!(sigma_g.is_finite(), "sigma_g must be finite");
assert!(pgbo_weight.is_finite(), "pgbo_weight must be finite");
let n = self.n;
let phases = &self.phases;
self.sin_diff
.par_chunks_mut(n)
.enumerate()
.for_each(|(row_idx, row)| {
let theta_n = phases[row_idx];
for (col_idx, value) in row.iter_mut().enumerate() {
*value = (phases[col_idx] - theta_n).sin();
}
});
if seed == 0 || self.noise_amp == 0.0 {
self.noise.fill(0.0);
} else {
fill_standard_normals(&mut self.noise, seed);
}
let has_geo = !w_flat.is_empty() && sigma_g != 0.0;
if has_geo {
assert_eq!(
w_flat.len(),
n * n,
"w_flat length mismatch: got {}, expected {}",
w_flat.len(),
n * n
);
assert_all_finite("w_flat", w_flat);
self.geo_coupling
.par_iter_mut()
.enumerate()
.for_each(|(row_idx, geo_n)| {
let w_row = &w_flat[row_idx * n..(row_idx + 1) * n];
let sin_row = &self.sin_diff[row_idx * n..(row_idx + 1) * n];
*geo_n = sigma_g
* w_row
.iter()
.zip(sin_row.iter())
.map(|(w, s)| w * s)
.sum::<f64>();
});
} else {
self.geo_coupling.fill(0.0);
}
let has_pgbo = !h_flat.is_empty() && pgbo_weight != 0.0;
if has_pgbo {
assert_eq!(
h_flat.len(),
n * n,
"h_flat length mismatch: got {}, expected {}",
h_flat.len(),
n * n
);
assert_all_finite("h_flat", h_flat);
self.pgbo_coupling
.par_iter_mut()
.enumerate()
.for_each(|(row_idx, pgbo_n)| {
let h_row = &h_flat[row_idx * n..(row_idx + 1) * n];
let sin_row = &self.sin_diff[row_idx * n..(row_idx + 1) * n];
*pgbo_n = pgbo_weight
* h_row
.iter()
.zip(sin_row.iter())
.map(|(h, s)| h * s)
.sum::<f64>();
});
} else {
self.pgbo_coupling.fill(0.0);
}
if self.field_pressure != 0.0 {
for (c, &theta) in self.cos_theta.iter_mut().zip(phases.iter()) {
*c = theta.cos();
}
} else {
self.cos_theta.fill(0.0);
}
self.dtheta
.par_iter_mut()
.enumerate()
.for_each(|(i, dtheta_n)| {
let coupling_row = &self.coupling[i * n..(i + 1) * n];
let sin_row = &self.sin_diff[i * n..(i + 1) * n];
let coupling_sum = crate::simd::dot_f64_dispatch(coupling_row, sin_row);
*dtheta_n = self.omega[i]
+ coupling_sum * self.n_inv
+ self.geo_coupling[i]
+ self.pgbo_coupling[i]
+ self.field_pressure * self.cos_theta[i]
+ self.noise_amp * self.noise[i];
});
for (phase, dtheta) in self.phases.iter_mut().zip(self.dtheta.iter()) {
*phase = (*phase + dtheta * dt).rem_euclid(TWO_PI);
}
self.order_parameter()
}
#[allow(clippy::too_many_arguments)]
pub fn run_ssgf(
&mut self,
n_steps: usize,
dt: f64,
seed: u64,
w_flat: &[f64],
sigma_g: f64,
h_flat: &[f64],
pgbo_weight: f64,
) -> Vec<f64> {
assert_dt(dt);
assert!(sigma_g.is_finite(), "sigma_g must be finite");
assert!(pgbo_weight.is_finite(), "pgbo_weight must be finite");
if !w_flat.is_empty() {
assert_eq!(
w_flat.len(),
self.n * self.n,
"w_flat length mismatch: got {}, expected {}",
w_flat.len(),
self.n * self.n
);
assert_all_finite("w_flat", w_flat);
}
if !h_flat.is_empty() {
assert_eq!(
h_flat.len(),
self.n * self.n,
"h_flat length mismatch: got {}, expected {}",
h_flat.len(),
self.n * self.n
);
assert_all_finite("h_flat", h_flat);
}
let mut order_values = Vec::with_capacity(n_steps);
for step_idx in 0..n_steps {
let step_seed = if seed == 0 {
0
} else {
seed.wrapping_add(step_idx as u64)
};
order_values.push(self.step_ssgf(dt, step_seed, w_flat, sigma_g, h_flat, pgbo_weight));
}
order_values
}
pub fn order_parameter(&self) -> f64 {
if self.phases.is_empty() {
return 0.0;
}
let n_inv = 1.0 / self.phases.len() as f64;
let mean_cos = self.phases.iter().map(|theta| theta.cos()).sum::<f64>() * n_inv;
let mean_sin = self.phases.iter().map(|theta| theta.sin()).sum::<f64>() * n_inv;
(mean_cos * mean_cos + mean_sin * mean_sin).sqrt()
}
pub fn get_phases(&self) -> &[f64] {
&self.phases
}
pub fn set_phases(&mut self, phases: Vec<f64>) {
assert_eq!(
phases.len(),
self.n,
"phases length mismatch: got {}, expected {}",
phases.len(),
self.n
);
assert_all_finite("phases", &phases);
self.phases = phases;
}
pub fn set_coupling(&mut self, coupling_flat: Vec<f64>) {
assert_eq!(
coupling_flat.len(),
self.n * self.n,
"coupling length mismatch: got {}, expected {}",
coupling_flat.len(),
self.n * self.n
);
assert_all_finite("coupling", &coupling_flat);
self.coupling = coupling_flat;
}
}
/// Panics unless `dt` is a finite, strictly positive step size.
/// NaN, ±∞, zero and negative values are all rejected.
fn assert_dt(dt: f64) {
    let is_valid_step = dt > 0.0 && dt.is_finite();
    assert!(is_valid_step, "dt must be finite and positive");
}
/// Panics with `"<name> values must be finite"` if any element of `values` is
/// NaN or infinite. An empty slice always passes.
fn assert_all_finite(name: &str, values: &[f64]) {
    let has_non_finite = values.iter().any(|v| !v.is_finite());
    if has_non_finite {
        panic!("{name} values must be finite");
    }
}
/// Fills `out` with standard-normal samples using the Box–Muller transform on a
/// `ChaCha8Rng` seeded from `seed`. The same `seed` and `out.len()` therefore
/// always produce the same values.
///
/// Two uniforms are drawn per output pair: first the radius uniform, then the
/// angle uniform. The pair is written as (`r·cos`, `r·sin`). A trailing odd
/// element still draws both uniforms and keeps only the cosine branch. The
/// radius uniform is clamped to `f64::MIN_POSITIVE` so `ln` never sees zero.
fn fill_standard_normals(out: &mut [f64], seed: u64) {
    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    // Produces one Box–Muller (radius, angle) pair from two fresh uniforms.
    let mut next_polar = || {
        let radius_uniform = rng.random::<f64>().max(f64::MIN_POSITIVE);
        let angle_uniform = rng.random::<f64>();
        ((-2.0 * radius_uniform.ln()).sqrt(), TWO_PI * angle_uniform)
    };
    let mut pairs = out.chunks_exact_mut(2);
    for pair in &mut pairs {
        let (radius, angle) = next_polar();
        pair[0] = radius * angle.cos();
        pair[1] = radius * angle.sin();
    }
    if let [last] = pairs.into_remainder() {
        let (radius, angle) = next_polar();
        *last = radius * angle.cos();
    }
}