use numpy::{IntoPyArray, PyArray2, PyReadonlyArray2, PyUntypedArrayMethods};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use crate::simd::{batch_evolve_multi_step, batch_evolve_rk4, SimdBatch};
use crate::vector3::Vector3;
/// Converts an `(N, 3)` NumPy array of `f64` into a [`SimdBatch`].
///
/// Row `i` of the array becomes the `i`-th [`Vector3`] handed to
/// [`SimdBatch::from_vector3_slice`]. Elements are read through a strided `ndarray` view
/// rather than the raw buffer, so the result is the same for C-ordered, Fortran-ordered and
/// non-contiguous (sliced or transposed) inputs.
///
/// # Errors
///
/// Returns a `ValueError` if the second dimension of `arr` is not 3.
fn numpy_to_simd(arr: &PyReadonlyArray2<f64>) -> PyResult<SimdBatch> {
    // `PyReadonlyArray2` is statically two-dimensional, so only the column count can be wrong.
    let shape = arr.shape();
    if shape[1] != 3 {
        return Err(PyValueError::new_err(format!(
            "expected shape (N, 3), got ({}, {})",
            shape[0], shape[1]
        )));
    }

    // `as_slice` also succeeds for Fortran-contiguous arrays, whose buffer is column-major;
    // reading `slice[i * 3 + k]` from it would not yield element (i, k). Indexing the view
    // goes through the array's strides and is therefore correct for every memory layout.
    let view = arr.as_array();
    let spins: Vec<Vector3<f64>> = view
        .outer_iter()
        .map(|row| Vector3::new(row[0], row[1], row[2]))
        .collect();

    Ok(SimdBatch::from_vector3_slice(&spins))
}
/// Copies a [`SimdBatch`] into a freshly allocated `(N, 3)` NumPy array.
///
/// The batch's separate `x`, `y` and `z` components are interleaved so that row `i` of the
/// result is `[x[i], y[i], z[i]]`.
fn simd_to_numpy<'py>(py: Python<'py>, batch: &SimdBatch) -> Bound<'py, PyArray2<f64>> {
    let rows = batch.len();

    // Build the row-major buffer: three consecutive values per batch entry.
    let mut interleaved: Vec<f64> = Vec::with_capacity(rows * 3);
    for row in 0..rows {
        interleaved.extend_from_slice(&[batch.x[row], batch.y[row], batch.z[row]]);
    }

    // The buffer holds exactly `rows * 3` values, so the reshape cannot fail.
    numpy::ndarray::Array2::from_shape_vec((rows, 3), interleaved)
        .expect("shape is always (n, 3)")
        .into_pyarray(py)
}
/// Advances a batch of vectors by one RK4 step and returns the result as a new array.
///
/// `m` and `h_eff` must both have shape `(N, 3)` with the same `N`. They are converted to
/// [`SimdBatch`] values and passed, together with `alpha`, `gamma` and `dt`, to
/// [`batch_evolve_rk4`]; the scalar parameters are forwarded unchanged.
///
/// # Errors
///
/// Returns a `ValueError` if either array fails conversion, if the row counts differ, or if
/// [`batch_evolve_rk4`] reports an error (its message is passed through).
#[pyfunction]
pub fn batch_rk4_step<'py>(
    py: Python<'py>,
    m: PyReadonlyArray2<'py, f64>,
    h_eff: PyReadonlyArray2<'py, f64>,
    alpha: f64,
    gamma: f64,
    dt: f64,
) -> PyResult<Bound<'py, PyArray2<f64>>> {
    let spins = numpy_to_simd(&m)?;
    let fields = numpy_to_simd(&h_eff)?;

    // Both inputs must describe the same number of vectors.
    let (spin_rows, field_rows) = (spins.len(), fields.len());
    if spin_rows != field_rows {
        return Err(PyValueError::new_err(format!(
            "m and h_eff must have the same number of rows: {} vs {}",
            spin_rows, field_rows
        )));
    }

    match batch_evolve_rk4(&spins, &fields, alpha, gamma, dt) {
        Ok(evolved) => Ok(simd_to_numpy(py, &evolved)),
        Err(err) => Err(PyValueError::new_err(err.to_string())),
    }
}
/// Advances a batch of vectors through `num_steps` RK4 steps and returns the final state.
///
/// `m` and `h_eff` must both have shape `(N, 3)` with the same `N`. They are converted to
/// [`SimdBatch`] values and passed, together with `alpha`, `gamma`, `dt` and `num_steps`, to
/// [`batch_evolve_multi_step`]; the scalar parameters are forwarded unchanged.
///
/// # Errors
///
/// Returns a `ValueError` if either array fails conversion, if the row counts differ, or if
/// [`batch_evolve_multi_step`] reports an error (its message is passed through).
#[pyfunction]
pub fn batch_rk4_multistep<'py>(
    py: Python<'py>,
    m: PyReadonlyArray2<'py, f64>,
    h_eff: PyReadonlyArray2<'py, f64>,
    alpha: f64,
    gamma: f64,
    dt: f64,
    num_steps: usize,
) -> PyResult<Bound<'py, PyArray2<f64>>> {
    let spins = numpy_to_simd(&m)?;
    let fields = numpy_to_simd(&h_eff)?;

    // Both inputs must describe the same number of vectors.
    let (spin_rows, field_rows) = (spins.len(), fields.len());
    if spin_rows != field_rows {
        return Err(PyValueError::new_err(format!(
            "m and h_eff must have the same number of rows: {} vs {}",
            spin_rows, field_rows
        )));
    }

    // The initial state is moved into the integrator; the field batch is only borrowed.
    match batch_evolve_multi_step(spins, &fields, alpha, gamma, dt, num_steps) {
        Ok(evolved) => Ok(simd_to_numpy(py, &evolved)),
        Err(err) => Err(PyValueError::new_err(err.to_string())),
    }
}