use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
/// Operation identifier for the homotopy Euler predictor primitive.
///
/// Used in this file as the generator tag of the emitted IR region and as the
/// op id passed to `crate::invalid_output_program` when arguments are rejected.
pub const OP_ID: &str = "vyre-primitives::opt::homotopy_euler_predictor";
/// Builds the IR program for one Euler predictor step over `n_paths * n_dim` cells.
///
/// Every invocation `t` along axis 0 with `t < n_paths * n_dim` stores
/// `x_curr[t] + ((dt_scaled[0] * v[t]) >> 16)` into `x_pred[t]`.
/// NOTE(review): the `>> 16` presumably rescales a 16.16 fixed-point product;
/// confirm the encoding against the callers that fill `dt_scaled` and `v`.
///
/// # Parameters
/// * `x_curr` - name of the read-only `U32` buffer at binding 0 (`n_paths * n_dim` elements).
/// * `v` - name of the read-only `U32` buffer at binding 1 (`n_paths * n_dim` elements).
/// * `dt_scaled` - name of the read-only `U32` buffer at binding 2 (one element).
/// * `x_pred` - name of the read-write `U32` output buffer at binding 3
///   (`n_paths * n_dim` elements).
/// * `n_paths`, `n_dim` - grid extents; both must be non-zero and their product
///   must fit in a `u32`.
///
/// # Returns
/// A program with workgroup size `[256, 1, 1]` whose body is wrapped in a region
/// tagged with [`OP_ID`]. If `n_paths` or `n_dim` is zero, or if their product
/// overflows `u32`, the result of `crate::invalid_output_program` is returned
/// instead of a compute program.
#[must_use]
pub fn homotopy_euler_predictor(
    x_curr: &str,
    v: &str,
    dt_scaled: &str,
    x_pred: &str,
    n_paths: u32,
    n_dim: u32,
) -> Program {
    if n_paths == 0 {
        return crate::invalid_output_program(
            OP_ID,
            x_pred,
            DataType::U32,
            "Fix: homotopy_euler_predictor requires n_paths > 0, got 0.".to_string(),
        );
    }
    if n_dim == 0 {
        return crate::invalid_output_program(
            OP_ID,
            x_pred,
            DataType::U32,
            "Fix: homotopy_euler_predictor requires n_dim > 0, got 0.".to_string(),
        );
    }

    // A plain `n_paths * n_dim` panics in overflow-checked builds and wraps
    // otherwise; a wrapped count would under-size every buffer declaration and
    // the bounds guard below, so reject the request up front.
    let cells = match n_paths.checked_mul(n_dim) {
        Some(cells) => cells,
        None => {
            return crate::invalid_output_program(
                OP_ID,
                x_pred,
                DataType::U32,
                format!(
                    "Fix: homotopy_euler_predictor requires n_paths * n_dim to fit in u32, got n_paths={}, n_dim={}.",
                    n_paths, n_dim
                ),
            );
        }
    };

    let t = Expr::InvocationId { axis: 0 };

    // x_curr[t] + ((dt_scaled[0] * v[t]) >> 16)
    let value = Expr::add(
        Expr::load(x_curr, t.clone()),
        Expr::shr(
            Expr::mul(
                Expr::load(dt_scaled, Expr::u32(0)),
                Expr::load(v, t.clone()),
            ),
            Expr::u32(16),
        ),
    );

    // Only invocations that map to a real cell perform the store.
    let body = vec![Node::if_then(
        Expr::lt(t.clone(), Expr::u32(cells)),
        vec![Node::store(x_pred, t, value)],
    )];

    Program::wrapped(
        vec![
            BufferDecl::storage(x_curr, 0, BufferAccess::ReadOnly, DataType::U32).with_count(cells),
            BufferDecl::storage(v, 1, BufferAccess::ReadOnly, DataType::U32).with_count(cells),
            BufferDecl::storage(dt_scaled, 2, BufferAccess::ReadOnly, DataType::U32).with_count(1),
            BufferDecl::storage(x_pred, 3, BufferAccess::ReadWrite, DataType::U32)
                .with_count(cells),
        ],
        [256, 1, 1],
        vec![Node::Region {
            generator: Ident::from(OP_ID),
            source_region: None,
            body: Arc::new(body),
        }],
    )
}
/// CPU reference for one explicit Euler predictor step: `x + dt * v` per element.
///
/// Elements are paired by index. When `x_curr` and `v` differ in length, only
/// the leading pairs present in both slices are processed, so the output has
/// `min(x_curr.len(), v.len())` entries.
#[must_use]
pub fn homotopy_euler_predictor_cpu(x_curr: &[f64], v: &[f64], dt: f64) -> Vec<f64> {
    let paired = x_curr.len().min(v.len());
    let mut predicted = Vec::with_capacity(paired);
    for (position, velocity) in x_curr[..paired].iter().zip(&v[..paired]) {
        predicted.push(*position + dt * *velocity);
    }
    predicted
}
/// CPU reference for the linear homotopy `(1 - t) * g + t * f`, applied per element.
///
/// `g_x` and `f_x` are paired by index; when their lengths differ, only the
/// leading pairs present in both slices contribute to the output.
#[must_use]
pub fn linear_homotopy_cpu(g_x: &[f64], f_x: &[f64], t: f64) -> Vec<f64> {
    let start_weight = 1.0 - t;
    let blend = |start: f64, target: f64| start_weight * start + t * target;

    let mut blended = Vec::with_capacity(g_x.len().min(f_x.len()));
    blended.extend(
        g_x.iter()
            .zip(f_x)
            .map(|(&start, &target)| blend(start, target)),
    );
    blended
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Closeness check whose tolerance grows with the magnitude of the operands.
    fn approx_eq(a: f64, b: f64) -> bool {
        let tolerance = 1e-10 * (1.0 + a.abs() + b.abs());
        (a - b).abs() < tolerance
    }

    // A zero step must return the current state unchanged.
    #[test]
    fn cpu_zero_dt_holds_state() {
        let state = vec![1.0, 2.0, 3.0];
        let tangent = vec![10.0, 20.0, 30.0];
        assert_eq!(homotopy_euler_predictor_cpu(&state, &tangent, 0.0), state);
    }

    // A unit step adds the tangent once.
    #[test]
    fn cpu_unit_dt_advances_by_v() {
        let stepped = homotopy_euler_predictor_cpu(&[5.0], &[3.0], 1.0);
        assert!(approx_eq(stepped[0], 8.0));
    }

    // Both CPU helpers only process index pairs present in both inputs.
    #[test]
    fn cpu_mismatched_inputs_truncate_to_complete_pairs() {
        let predicted = homotopy_euler_predictor_cpu(&[1.0, 2.0], &[3.0], 1.0);
        assert_eq!(predicted, vec![4.0]);

        let blended = linear_homotopy_cpu(&[1.0], &[3.0, 5.0], 0.5);
        assert_eq!(blended, vec![2.0]);
    }

    // Integrates dx/dt = -x from x(0) = 1 over 100 steps of 0.01 and compares
    // the result against exp(-1).
    #[test]
    fn cpu_iterated_steps_track_known_path() {
        let dt = 0.01;
        let final_state = (0..100).fold(vec![1.0_f64], |state, _| {
            let tangent: Vec<f64> = state.iter().map(|&xi| -xi).collect();
            homotopy_euler_predictor_cpu(&state, &tangent, dt)
        });
        let exact = (-1.0f64).exp();
        assert!((final_state[0] - exact).abs() < 0.05);
    }

    // t = 0 reproduces the start values and t = 1 the target values.
    #[test]
    fn cpu_linear_homotopy_endpoints_match() {
        let start = vec![1.0, 2.0];
        let target = vec![10.0, 20.0];
        assert_eq!(linear_homotopy_cpu(&start, &target, 0.0), start);
        assert_eq!(linear_homotopy_cpu(&start, &target, 1.0), target);
    }

    // t = 0.5 yields the element-wise average.
    #[test]
    fn cpu_linear_homotopy_midpoint_averages() {
        let halfway = linear_homotopy_cpu(&[0.0, 4.0], &[10.0, 0.0], 0.5);
        assert!(approx_eq(halfway[0], 5.0));
        assert!(approx_eq(halfway[1], 2.0));
    }

    // Checks workgroup size, buffer declaration order, and element counts for a
    // 4 x 8 grid (32 cells).
    #[test]
    fn ir_program_buffer_layout() {
        let program = homotopy_euler_predictor("xc", "v", "dt", "xp", 4, 8);
        assert_eq!(program.workgroup_size, [256, 1, 1]);

        let declared: Vec<&str> = program.buffers.iter().map(|decl| decl.name()).collect();
        assert_eq!(declared, vec!["xc", "v", "dt", "xp"]);

        assert_eq!(program.buffers[0].count(), 32);
        assert_eq!(program.buffers[1].count(), 32);
        assert_eq!(program.buffers[2].count(), 1);
        assert_eq!(program.buffers[3].count(), 32);
    }

    // Zero paths must yield a program whose stats report a trap.
    #[test]
    fn zero_n_paths_traps() {
        let program = homotopy_euler_predictor("xc", "v", "dt", "xp", 0, 1);
        assert!(program.stats().trap());
    }

    // Zero dimensions must yield a program whose stats report a trap.
    #[test]
    fn zero_n_dim_traps() {
        let program = homotopy_euler_predictor("xc", "v", "dt", "xp", 1, 0);
        assert!(program.stats().trap());
    }
}