use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use rayon::prelude::*;
use crate::constants::GAMMA;
use crate::error::{Error, Result};
use crate::vector3::Vector3;
/// Applies `f` to every element of `params` on the rayon thread pool.
///
/// The output holds exactly one value per parameter, in the same order as
/// `params`, no matter which worker thread produced it.
pub fn parallel_sweep<P, R, F>(params: &[P], f: F) -> Vec<R>
where
    P: Send + Sync,
    R: Send,
    F: Fn(&P) -> R + Send + Sync,
{
    let mut out = Vec::with_capacity(params.len());
    params.par_iter().map(f).collect_into_vec(&mut out);
    out
}
/// Evaluates `f` on every element of `params` in parallel and counts completions.
///
/// Results are returned in the same order as `params`. The returned counter is
/// incremented once per finished evaluation. It is created inside this function
/// and handed back only after every evaluation has finished, so the first value
/// the caller can observe is already `params.len()`.
///
/// To monitor progress *while* the sweep is running (e.g. from another thread),
/// use [`parallel_sweep_with_counter`] with a counter the caller owns.
pub fn parallel_sweep_with_progress<P, R, F>(params: &[P], f: F) -> (Vec<R>, Arc<AtomicUsize>)
where
    P: Send + Sync,
    R: Send,
    F: Fn(&P) -> R + Send + Sync,
{
    let progress = Arc::new(AtomicUsize::new(0));
    let results = parallel_sweep_with_counter(params, f, &progress);
    (results, progress)
}

/// Evaluates `f` on every element of `params` in parallel, adding one to
/// `progress` after each evaluation completes.
///
/// Results are returned in the same order as `params`. The counter is not reset
/// first: it ends at its initial value plus `params.len()`. Because the caller
/// owns the counter (e.g. an `Arc<AtomicUsize>` shared with a monitoring thread),
/// it can be polled while the sweep is still running. Increments use
/// `Ordering::Relaxed`, so a concurrent reader should treat the value only as a
/// progress indicator, not as a synchronisation point for the results.
pub fn parallel_sweep_with_counter<P, R, F>(params: &[P], f: F, progress: &AtomicUsize) -> Vec<R>
where
    P: Send + Sync,
    R: Send,
    F: Fn(&P) -> R + Send + Sync,
{
    params
        .par_iter()
        .map(|p| {
            let result = f(p);
            progress.fetch_add(1, Ordering::Relaxed);
            result
        })
        .collect()
}
#[derive(Debug, Clone)]
pub struct ParameterSweep<P, R> {
pub parameters: Vec<P>,
pub results: Vec<Option<R>>,
}
impl<P, R> ParameterSweep<P, R>
where
P: Send + Sync + Clone,
R: Send + Clone,
{
pub fn new(parameters: Vec<P>) -> Self {
let n = parameters.len();
Self {
parameters,
results: vec![None; n],
}
}
pub fn execute<F>(&mut self, f: F)
where
F: Fn(&P) -> R + Send + Sync,
{
let computed: Vec<R> = self.parameters.par_iter().map(f).collect();
self.results = computed.into_iter().map(Some).collect();
}
pub fn execute_with_progress<F>(&mut self, f: F) -> Arc<AtomicUsize>
where
F: Fn(&P) -> R + Send + Sync,
{
let progress = Arc::new(AtomicUsize::new(0));
let progress_ref = Arc::clone(&progress);
let computed: Vec<R> = self
.parameters
.par_iter()
.map(|p| {
let result = f(p);
progress_ref.fetch_add(1, Ordering::Relaxed);
result
})
.collect();
self.results = computed.into_iter().map(Some).collect();
progress
}
pub fn completed_count(&self) -> usize {
self.results.iter().filter(|r| r.is_some()).count()
}
pub fn is_complete(&self) -> bool {
self.completed_count() == self.parameters.len()
}
pub fn results_ref(&self) -> &[Option<R>] {
&self.results
}
}
/// Outcome of relaxing the magnetisation at one applied-field magnitude in `field_sweep`.
#[derive(Debug, Clone)]
pub struct FieldSweepResult {
    /// Field magnitude copied from the corresponding `fields` entry; the applied field
    /// vector is this value times the normalised sweep direction.
    /// NOTE(review): units presumably tesla (the tests describe `10.0` as "10T") — confirm.
    pub field: f64,
    /// Magnetisation after relaxation (renormalised after every step), or after
    /// `max_steps` steps if the torque never fell below `tol`.
    pub magnetization: Vector3<f64>,
    /// Dot product of `magnetization` with the normalised field direction.
    pub m_parallel: f64,
}
/// Relaxes the magnetisation under each applied-field magnitude, in parallel.
///
/// For every `h` in `fields`, the applied field is `h` times the normalised
/// `field_direction`. Starting from the normalised `m_init`, the magnetisation is
/// advanced with explicit Euler steps of size `dt` using the LLG torque and
/// renormalised after each step. Relaxation stops as soon as the torque magnitude
/// drops below `tol`, or after `max_steps` steps.
///
/// Results keep the order of `fields`. Hitting `max_steps` without converging is
/// not signalled.
pub fn field_sweep(
    fields: &[f64],
    field_direction: Vector3<f64>,
    m_init: Vector3<f64>,
    alpha: f64,
    dt: f64,
    max_steps: usize,
    tol: f64,
) -> Vec<FieldSweepResult> {
    let axis = field_direction.normalize();

    // Integrates the LLG equation under a fixed applied field and returns the
    // final (normalised) magnetisation.
    let relax = |applied: Vector3<f64>| -> Vector3<f64> {
        let mut state = m_init.normalize();
        for _ in 0..max_steps {
            let torque = llg_torque_simple(state, applied, alpha);
            if torque.magnitude() < tol {
                break;
            }
            state = (state + torque * dt).normalize();
        }
        state
    };

    fields
        .par_iter()
        .map(|&field| {
            let magnetization = relax(axis * field);
            let m_parallel = magnetization.dot(&axis);
            FieldSweepResult {
                field,
                magnetization,
                m_parallel,
            }
        })
        .collect()
}
/// Mean-field result for a single temperature, as produced by `temperature_sweep`.
///
/// All fields are plain `f64`, so the type also derives `PartialEq` to allow
/// direct comparison of results (e.g. in tests or caching).
#[derive(Debug, Clone, PartialEq)]
pub struct TemperatureSweepResult {
    /// Temperature this point was evaluated at (same units as the Curie temperature).
    pub temperature: f64,
    /// Reduced mean magnetisation, clamped to `[0, 1]` by the sweep.
    pub avg_magnetization: f64,
    /// Heuristic spread of the magnetisation, clamped to `[0, 1]` by the sweep.
    pub magnetization_std: f64,
}
/// Mean-field (Weiss–Langevin) magnetisation, evaluated in parallel over `temperatures`.
///
/// For each temperature `T > 0` the self-consistency equation
///
/// ```text
/// m = L((mu * h_eff + 3 * k_B * T_c * m) / (k_B * T))
/// ```
///
/// is solved by fixed-point iteration. Details:
///
/// * `L` is the Langevin function and `k_B = 1.380649e-23` (SI, J/K).
/// * Iteration starts from `m = 0.5` and runs at most 200 steps, stopping once
///   successive iterates differ by less than `1e-12`.
/// * Results are returned in the same order as `temperatures`.
/// * `T <= 0` is reported as fully magnetised: `avg_magnetization = 1`,
///   `magnetization_std = 0`.
/// * `magnetization_std` is a heuristic, `sqrt(k_B T / (mu * max(h_eff, 1e-30)))`.
///   It is scaled by `1 - m` below `T_c` and capped at 1 otherwise; both outputs
///   are clamped to `[0, 1]`.
///
/// NOTE(review): just above `T_c` (small field) each step contracts the error only
/// by roughly `T_c / T`. Within a few percent of the transition, 200 steps may
/// therefore not reach the tolerance, and the last iterate is returned as-is.
///
/// # Errors
///
/// Returns [`Error::InvalidParameter`] if `curie_temp` is not strictly positive.
/// NaN counts as not positive.
pub fn temperature_sweep(
    temperatures: &[f64],
    h_eff: f64,
    mu: f64,
    curie_temp: f64,
) -> Result<Vec<TemperatureSweepResult>> {
    // `curie_temp <= 0.0` alone is false for NaN. That would let a NaN Curie
    // temperature through and silently turn every result into NaN.
    if curie_temp.is_nan() || curie_temp <= 0.0 {
        return Err(Error::InvalidParameter {
            param: "curie_temp".to_string(),
            reason: "Curie temperature must be positive".to_string(),
        });
    }

    const KB: f64 = 1.380649e-23;
    const MAX_ITERATIONS: usize = 200;
    const TOLERANCE: f64 = 1e-12;

    // Temperature-independent pieces, hoisted out of the per-temperature closure.
    let zeeman = mu * h_eff;
    // Floor the field so the spread estimate never divides by zero.
    let spread_denominator = mu * h_eff.max(1e-30);

    let results: Vec<TemperatureSweepResult> = temperatures
        .par_iter()
        .map(|&temp| {
            if temp <= 0.0 {
                return TemperatureSweepResult {
                    temperature: temp,
                    avg_magnetization: 1.0,
                    magnetization_std: 0.0,
                };
            }

            let thermal = KB * temp;
            let mut m = 0.5_f64;
            for _ in 0..MAX_ITERATIONS {
                // Molecular field 3·k_B·T_c·m places the mean-field transition at T_c.
                let h_total = zeeman + 3.0 * curie_temp * m * KB;
                let m_new = langevin(h_total / thermal);
                let dm = (m_new - m).abs();
                m = m_new;
                if dm < TOLERANCE {
                    break;
                }
            }

            let thermal_scale = (thermal / spread_denominator).sqrt();
            let sigma = if temp < curie_temp {
                thermal_scale * (1.0 - m).max(0.0)
            } else {
                thermal_scale.min(1.0)
            };

            TemperatureSweepResult {
                temperature: temp,
                avg_magnetization: m.clamp(0.0, 1.0),
                magnetization_std: sigma.clamp(0.0, 1.0),
            }
        })
        .collect();
    Ok(results)
}
/// One evaluated point of a two-parameter grid sweep.
///
/// `PartialEq` is derived (and so is only available when `R: PartialEq`) so that
/// sweep outputs can be compared directly.
#[derive(Debug, Clone, PartialEq)]
pub struct SweepPoint2D<R> {
    /// Value taken from the first parameter list.
    pub param1: f64,
    /// Value taken from the second parameter list.
    pub param2: f64,
    /// Output of the sweep function evaluated at `(param1, param2)`.
    pub result: R,
}
pub fn parallel_sweep_2d<R, F>(params1: &[f64], params2: &[f64], f: F) -> Vec<SweepPoint2D<R>>
where
R: Send,
F: Fn(f64, f64) -> R + Send + Sync,
{
let grid: Vec<(f64, f64)> = params1
.iter()
.flat_map(|&p1| params2.iter().map(move |&p2| (p1, p2)))
.collect();
grid.par_iter()
.map(|&(p1, p2)| SweepPoint2D {
param1: p1,
param2: p2,
result: f(p1, p2),
})
.collect()
}
/// Right-hand side `dm/dt` of the Landau–Lifshitz–Gilbert equation in explicit form:
/// `-GAMMA / (1 + alpha²) · (m × H + alpha · m × (m × H))`.
///
/// `m` is the magnetisation direction, `h_ext` the applied field and `alpha` the
/// Gilbert damping constant.
fn llg_torque_simple(m: Vector3<f64>, h_ext: Vector3<f64>, alpha: f64) -> Vector3<f64> {
    let precession = m.cross(&h_ext);
    let damping = m.cross(&precession) * alpha;
    let gamma_eff = -GAMMA / (1.0 + alpha * alpha);
    (precession + damping) * gamma_eff
}
/// Langevin function `L(x) = coth(x) - 1/x`, with `L(0) = 0`.
///
/// For `|x| < 1e-2` the Taylor series `x/3 - x³/45 + 2x⁵/945 - x⁷/4725` is used.
/// There, `coth(x) - 1/x` would subtract two nearly equal numbers of size ~`1/x`
/// and lose most significant digits (near `x = 1e-4` only about seven survive).
/// The series truncation error on that interval is below one ulp.
///
/// Elsewhere `coth` is computed as `1 / tanh(x)`. Unlike `cosh(x) / sinh(x)`,
/// which becomes `inf / inf` for `|x|` beyond ~710, this stays finite and
/// saturates to exactly `±1` for large `|x|`, so no separate asymptotic branch is
/// needed. The function is odd, tends to `±1` at `±∞`, and propagates NaN.
fn langevin(x: f64) -> f64 {
    if x.abs() < 1e-2 {
        // Horner form of the odd series through the x⁷ term.
        let x2 = x * x;
        x * (1.0 / 3.0 - x2 * (1.0 / 45.0 - x2 * (2.0 / 945.0 - x2 / 4725.0)))
    } else {
        1.0 / x.tanh() - 1.0 / x
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parallel_sweep_trivial() {
        let params: Vec<f64> = (0..50).map(|i| i as f64).collect();
        let results = parallel_sweep(&params, |&x| x * x);
        assert_eq!(results.len(), 50);
        for (i, &r) in results.iter().enumerate() {
            let expected = (i as f64) * (i as f64);
            assert!(
                (r - expected).abs() < 1e-15,
                "index {}: got {}, expected {}",
                i,
                r,
                expected,
            );
        }
    }

    #[test]
    fn test_parallel_sweep_empty() {
        let params: Vec<f64> = Vec::new();
        let results = parallel_sweep(&params, |&x| x * 2.0);
        assert!(results.is_empty());
    }

    #[test]
    fn test_parallel_sweep_deterministic() {
        let params: Vec<f64> = (0..100).map(|i| i as f64 * 0.1).collect();
        let r1 = parallel_sweep(&params, |&x| x.sin());
        let r2 = parallel_sweep(&params, |&x| x.sin());
        assert_eq!(r1.len(), r2.len());
        for (a, b) in r1.iter().zip(r2.iter()) {
            assert!(
                (a - b).abs() < 1e-15,
                "determinism violated: {} != {}",
                a,
                b,
            );
        }
    }

    #[test]
    fn test_parameter_sweep_builder() {
        let params: Vec<i32> = (0..20).collect();
        let mut sweep = ParameterSweep::<i32, i64>::new(params);
        assert!(!sweep.is_complete());
        assert_eq!(sweep.completed_count(), 0);
        sweep.execute(|&x| (x as i64) * (x as i64));
        assert!(sweep.is_complete());
        assert_eq!(sweep.completed_count(), 20);
        let results = sweep.results_ref();
        assert_eq!(results[5], Some(25), "5^2 should be 25");
    }

    #[test]
    fn test_field_sweep_saturation() {
        let fields: Vec<f64> = vec![0.01, 0.1, 1.0, 10.0];
        let dir = Vector3::new(0.0, 0.0, 1.0);
        let m_init = Vector3::new(1.0, 0.0, 0.0);
        let alpha = 0.5;
        let dt = 1e-13;
        let max_steps = 1_000_000;
        let tol = 1e-10;
        let results = field_sweep(&fields, dir, m_init, alpha, dt, max_steps, tol);
        assert_eq!(results.len(), 4);
        let last = &results[3];
        assert!(
            last.m_parallel > 0.9,
            "at H=10T, m_parallel={} should be > 0.9",
            last.m_parallel,
        );
    }

    #[test]
    fn test_temperature_sweep_ordering() {
        let temps: Vec<f64> = vec![10.0, 100.0, 300.0, 500.0, 800.0, 1000.0];
        let h_eff = 1.0;
        let mu = 9.274e-24;
        let tc = 600.0;
        let results =
            temperature_sweep(&temps, h_eff, mu, tc).expect("temperature sweep should succeed");
        assert_eq!(results.len(), 6);
        let m_low = results[0].avg_magnetization;
        let m_high = results[5].avg_magnetization;
        assert!(
            m_low >= m_high,
            "m(10K)={} should be >= m(1000K)={}",
            m_low,
            m_high,
        );
    }

    #[test]
    fn test_temperature_sweep_nonpositive_temperature() {
        let temps = [0.0, -5.0];
        let results = temperature_sweep(&temps, 1.0, 9.274e-24, 600.0)
            .expect("temperature sweep should succeed");
        assert_eq!(results.len(), 2);
        for r in &results {
            assert_eq!(r.avg_magnetization, 1.0);
            assert_eq!(r.magnetization_std, 0.0);
        }
    }

    #[test]
    fn test_temperature_sweep_error_on_bad_curie() {
        let temps = vec![100.0];
        let result = temperature_sweep(&temps, 1.0, 1e-23, 0.0);
        assert!(result.is_err());
        let result = temperature_sweep(&temps, 1.0, 1e-23, -100.0);
        assert!(result.is_err());
    }

    #[test]
    fn test_sweep_with_progress() {
        let params: Vec<u32> = (0..30).collect();
        let (results, progress) = parallel_sweep_with_progress(&params, |&x| x * 2);
        assert_eq!(results.len(), 30);
        assert_eq!(progress.load(Ordering::Relaxed), 30);
        for (i, &r) in results.iter().enumerate() {
            assert_eq!(r, (i as u32) * 2);
        }
    }

    #[test]
    fn test_2d_sweep() {
        let p1: Vec<f64> = vec![1.0, 2.0, 3.0];
        let p2: Vec<f64> = vec![10.0, 20.0];
        let results = parallel_sweep_2d(&p1, &p2, |a, b| a + b);
        assert_eq!(results.len(), 6);
        let mut found = [false; 6];
        for r in &results {
            let expected = r.param1 + r.param2;
            assert!(
                (r.result - expected).abs() < 1e-15,
                "({}, {}): got {}, expected {}",
                r.param1,
                r.param2,
                r.result,
                expected,
            );
            let idx = p1
                .iter()
                .position(|&x| (x - r.param1).abs() < 1e-15)
                .expect("param1 should be in p1");
            let idy = p2
                .iter()
                .position(|&x| (x - r.param2).abs() < 1e-15)
                .expect("param2 should be in p2");
            found[idx * p2.len() + idy] = true;
        }
        assert!(found.iter().all(|&f| f), "not all grid points covered");
    }

    #[test]
    fn test_2d_sweep_row_major_order() {
        let p1 = [1.0, 2.0];
        let p2 = [10.0, 20.0, 30.0];
        let results = parallel_sweep_2d(&p1, &p2, |a, b| a * b);
        let coords: Vec<(f64, f64)> = results.iter().map(|r| (r.param1, r.param2)).collect();
        assert_eq!(
            coords,
            vec![
                (1.0, 10.0),
                (1.0, 20.0),
                (1.0, 30.0),
                (2.0, 10.0),
                (2.0, 20.0),
                (2.0, 30.0),
            ]
        );
    }

    #[test]
    fn test_langevin_function() {
        assert!((langevin(0.0)).abs() < 1e-10);
        assert!((langevin(100.0) - 1.0).abs() < 0.011);
        let x = 1e-5;
        assert!((langevin(x) - x / 3.0).abs() < 1e-15);
    }

    #[test]
    fn test_langevin_is_odd() {
        for &x in &[1e-5, 0.5, 3.0, 50.0] {
            assert!(
                (langevin(-x) + langevin(x)).abs() < 1e-15,
                "L(-{}) should equal -L({})",
                x,
                x,
            );
        }
    }
}