use super::FixedPoint;
use super::FixedVector;
use super::FixedMatrix;
use super::tensor::Tensor;
use super::linalg::compute_tier_dot_raw;
use super::derived::inverse;
use crate::fixed_point::universal::fasc::stack_evaluator::BinaryStorage;
use crate::fixed_point::core_types::errors::OverflowDetected;
/// Step size `h` used by every central finite difference in this module.
///
/// `h` is an exact power of two, `2^-e`, built directly from its raw
/// representation as `1 << (frac_bits - e)` (assuming a `qM_N` table format
/// carries `N` fractional bits):
///
/// | `table_format` | `h`    |
/// |----------------|--------|
/// | `q16_16`       | 2^-5   |
/// | `q32_32`       | 2^-11  |
/// | `q64_64`       | 2^-21  |
/// | `q128_128`     | 2^-43  |
/// | `q256_256`     | 2^-85  |
///
/// Each exponent is roughly one third of the fractional bit count, which appears
/// chosen to balance the O(h²) truncation error of a central difference against
/// its O(ε/h) rounding error.
///
/// `divide_by_two_h` hard-codes the matching `1/(2h)` shift (`e - 1` bits);
/// the two functions must be changed together.
///
/// NOTE(review): exactly one `table_format` cfg must be active; with none the
/// body is empty and the function does not compile.
pub fn differentiation_step() -> FixedPoint {
// h = 2^-11
#[cfg(table_format = "q32_32")]
{ FixedPoint::from_raw(1i64 << (32 - 11)) }
// h = 2^-5
#[cfg(table_format = "q16_16")]
{ FixedPoint::from_raw(1i32 << (16 - 5)) }
// h = 2^-21
#[cfg(table_format = "q64_64")]
{ FixedPoint::from_raw(1i128 << (64 - 21)) }
// h = 2^-43 (raw storage is the crate's 256-bit integer)
#[cfg(table_format = "q128_128")]
{
use crate::fixed_point::I256;
FixedPoint::from_raw(I256::from_i128(1) << (128usize - 43))
}
// h = 2^-85 (raw storage is the crate's 512-bit integer)
#[cfg(table_format = "q256_256")]
{
use crate::fixed_point::I512;
FixedPoint::from_raw(I512::from_i128(1) << (256usize - 85))
}
}
/// Divides a central-difference numerator by `2h`, with `h = differentiation_step()`.
///
/// Because `h = 2^-e` is a power of two, `1 / (2h) = 2^(e - 1)` and the division
/// is an exact left shift of the raw value by `e - 1` bits. The shift amounts
/// below must stay in lockstep with the exponents in `differentiation_step`.
///
/// NOTE(review): a plain left shift does not saturate; callers are expected to
/// pass small differences of nearby values — confirm for large-magnitude metrics.
#[inline]
fn divide_by_two_h(val: FixedPoint) -> FixedPoint {
    let raw = val.raw();
    // 1/(2h) = 2^10
    #[cfg(table_format = "q32_32")]
    {
        FixedPoint::from_raw(raw << 10)
    }
    // 1/(2h) = 2^4
    #[cfg(table_format = "q16_16")]
    {
        FixedPoint::from_raw(raw << 4)
    }
    // 1/(2h) = 2^20
    #[cfg(table_format = "q64_64")]
    {
        FixedPoint::from_raw(raw << 20u32)
    }
    // 1/(2h) = 2^42
    #[cfg(table_format = "q128_128")]
    {
        FixedPoint::from_raw(raw << 42usize)
    }
    // 1/(2h) = 2^84
    #[cfg(table_format = "q256_256")]
    {
        FixedPoint::from_raw(raw << 84usize)
    }
}
/// A Riemannian metric field expressed in local coordinates.
///
/// Implementors supply the coordinate dimension and the metric matrix `g_ij`
/// at a point. The remaining methods have defaults. Override them when an exact
/// or cheaper closed form is known: `christoffel` and `scalar_curvature` consult
/// the closed-form hooks before falling back to numerical differentiation.
pub trait MetricProvider {
    /// Number of coordinates `n`. Metric matrices are `n × n` and Christoffel
    /// tensors are `n × n × n`.
    fn dimension(&self) -> usize;

    /// Metric matrix `g_ij` evaluated at the point `p`.
    fn metric(&self, p: &FixedVector) -> FixedMatrix;

    /// Inverse metric `g^ij` at `p`.
    ///
    /// The default evaluates [`MetricProvider::metric`] and hands it to
    /// `inverse`, returning that routine's result (including its error) as-is.
    fn metric_inverse(&self, p: &FixedVector) -> Result<FixedMatrix, OverflowDetected> {
        let g = self.metric(p);
        inverse(&g)
    }

    /// Exact Christoffel symbols at `p`, indexed `[k, i, j]` for `Γ^k_{ij}`.
    ///
    /// Returning `None` (the default) makes `christoffel` differentiate the
    /// metric numerically.
    fn christoffel_closed_form(&self, _p: &FixedVector) -> Option<Tensor> {
        None
    }

    /// Exact scalar curvature at `p`.
    ///
    /// Returning `None` (the default) makes `scalar_curvature` contract the
    /// numerically computed Ricci tensor.
    fn scalar_curvature_closed_form(&self, _p: &FixedVector) -> Option<FixedPoint> {
        None
    }
}
/// Central-difference partial derivative of the metric along coordinate `k`.
///
/// Returns the `n × n` matrix whose `(i, j)` entry approximates `∂g_ij/∂x^k` at
/// `p` as `(g(p + h e_k) − g(p − h e_k)) / (2h)`, with `h = differentiation_step()`.
/// `k` is used to index `p`, so it must be a valid coordinate index.
fn metric_partial(
    provider: &dyn MetricProvider,
    p: &FixedVector,
    k: usize,
) -> FixedMatrix {
    let step = differentiation_step();
    let dim = provider.dimension();

    // Metric one step forward along x^k.
    let forward = {
        let mut shifted = p.clone();
        shifted[k] = shifted[k] + step;
        provider.metric(&shifted)
    };
    // Metric one step backward along x^k.
    let backward = {
        let mut shifted = p.clone();
        shifted[k] = shifted[k] - step;
        provider.metric(&shifted)
    };

    let mut derivative = FixedMatrix::new(dim, dim);
    for row in 0..dim {
        for col in 0..dim {
            let delta = forward.get(row, col) - backward.get(row, col);
            derivative.set(row, col, divide_by_two_h(delta));
        }
    }
    derivative
}
/// Christoffel symbols of the second kind at `p`, stored as `gamma[k, i, j] = Γ^k_{ij}`.
///
/// Uses the provider's closed form when one is available. Otherwise it evaluates
///
/// `Γ^k_{ij} = ½ g^{kl} (∂_i g_{jl} + ∂_j g_{il} − ∂_l g_{ij})`
///
/// with metric derivatives from `metric_partial` (central differences). The
/// `l`-contraction goes through `compute_tier_dot_raw`.
///
/// # Errors
/// Propagates any error from [`MetricProvider::metric_inverse`].
pub fn christoffel(
    provider: &dyn MetricProvider,
    p: &FixedVector,
) -> Result<Tensor, OverflowDetected> {
    if let Some(gamma) = provider.christoffel_closed_form(p) {
        return Ok(gamma);
    }
    let n = provider.dimension();
    let g_inv = provider.metric_inverse(p)?;
    // dg[k].get(i, j) ≈ ∂_k g_ij
    let dg: Vec<FixedMatrix> = (0..n).map(|k| metric_partial(provider, p, k)).collect();
    let half = FixedPoint::one() / FixedPoint::from_int(2);

    // Row k of g^{-1} is shared by every (i, j); extract each row once instead
    // of n² times.
    let g_inv_rows: Vec<Vec<BinaryStorage>> = (0..n)
        .map(|k| (0..n).map(|l| g_inv.get(k, l).raw()).collect())
        .collect();

    let mut gamma = Tensor::new(&[n, n, n]);
    for i in 0..n {
        for j in 0..n {
            // The bracket depends on (i, j, l) only, not on k, so it is built
            // once per (i, j) and reused for every k.
            let bracket: Vec<BinaryStorage> = (0..n)
                .map(|l| (dg[i].get(j, l) + dg[j].get(l, i) - dg[l].get(i, j)).raw())
                .collect();
            for (k, g_inv_row) in g_inv_rows.iter().enumerate() {
                let contracted = FixedPoint::from_raw(compute_tier_dot_raw(g_inv_row, &bracket));
                gamma.set(&[k, i, j], half * contracted);
            }
        }
    }
    Ok(gamma)
}
/// Riemann curvature tensor at `p`, stored as `riemann[l, i, j, k] = R^l_{ijk}` with
///
/// `R^l_{ijk} = ∂_j Γ^l_{ik} − ∂_k Γ^l_{ij} + Γ^l_{jm} Γ^m_{ik} − Γ^l_{km} Γ^m_{ij}`,
///
/// i.e. `R(∂_j, ∂_k) ∂_i = R^l_{ijk} ∂_l`. The tensor is antisymmetric in `(j, k)`.
/// The Γ derivatives are central differences of [`christoffel`] (2n extra
/// evaluations). The quadratic terms use Γ at `p` and are contracted with
/// `compute_tier_dot_raw`.
///
/// # Errors
/// Propagates the first error returned by any [`christoffel`] evaluation.
pub fn riemann_curvature(
    provider: &dyn MetricProvider,
    p: &FixedVector,
) -> Result<Tensor, OverflowDetected> {
    let n = provider.dimension();
    let gamma_center = christoffel(provider, p)?;
    let dgamma = christoffel_partials(provider, p)?;

    // Pre-extract the vectors used by the quadratic terms once (2·n² vectors)
    // instead of building four fresh vectors for every component (4·n⁴):
    //   rows[a * n + b][m] = Γ^a_{b m}
    //   cols[a * n + b][m] = Γ^m_{a b}
    let mut rows: Vec<Vec<BinaryStorage>> = Vec::with_capacity(n * n);
    let mut cols: Vec<Vec<BinaryStorage>> = Vec::with_capacity(n * n);
    for a in 0..n {
        for b in 0..n {
            rows.push((0..n).map(|m| gamma_center.get(&[a, b, m]).raw()).collect());
            cols.push((0..n).map(|m| gamma_center.get(&[m, a, b]).raw()).collect());
        }
    }

    let mut riemann = Tensor::new(&[n, n, n, n]);
    for l in 0..n {
        for i in 0..n {
            for j in 0..n {
                for k in 0..n {
                    let deriv_term = dgamma[j].get(&[l, i, k]) - dgamma[k].get(&[l, i, j]);
                    // Γ^l_{jm} Γ^m_{ik}
                    let contraction_pos = FixedPoint::from_raw(
                        compute_tier_dot_raw(&rows[l * n + j], &cols[i * n + k]),
                    );
                    // Γ^l_{km} Γ^m_{ij}
                    let contraction_neg = FixedPoint::from_raw(
                        compute_tier_dot_raw(&rows[l * n + k], &cols[i * n + j]),
                    );
                    riemann.set(&[l, i, j, k], deriv_term + contraction_pos - contraction_neg);
                }
            }
        }
    }
    Ok(riemann)
}

/// Central-difference derivatives of the Christoffel symbols at `p`.
///
/// Returns `dgamma` with `dgamma[j].get(&[l, i, k])` ≈ `∂_j Γ^l_{ik}`, computed as
/// `(Γ(p + h e_j) − Γ(p − h e_j)) / (2h)` with `h = differentiation_step()`.
fn christoffel_partials(
    provider: &dyn MetricProvider,
    p: &FixedVector,
) -> Result<Vec<Tensor>, OverflowDetected> {
    let n = provider.dimension();
    let h = differentiation_step();
    let mut dgamma: Vec<Tensor> = Vec::with_capacity(n);
    for j in 0..n {
        let mut p_plus = p.clone();
        let mut p_minus = p.clone();
        p_plus[j] = p_plus[j] + h;
        p_minus[j] = p_minus[j] - h;
        let gamma_plus = christoffel(provider, &p_plus)?;
        let gamma_minus = christoffel(provider, &p_minus)?;
        let mut dg_j = Tensor::new(&[n, n, n]);
        for l in 0..n {
            for ii in 0..n {
                for kk in 0..n {
                    let diff = gamma_plus.get(&[l, ii, kk]) - gamma_minus.get(&[l, ii, kk]);
                    dg_j.set(&[l, ii, kk], divide_by_two_h(diff));
                }
            }
        }
        dgamma.push(dg_j);
    }
    Ok(dgamma)
}
/// Ricci tensor `R_ij = R^k_{ikj}` at `p`, contracted from [`riemann_curvature`].
///
/// The sum over `k` is done by `compute_tier_dot_raw` against a vector of ones
/// rather than by repeated `+`, so it uses that routine's accumulation.
///
/// # Errors
/// Propagates errors from [`riemann_curvature`].
pub fn ricci_tensor(
    provider: &dyn MetricProvider,
    p: &FixedVector,
) -> Result<FixedMatrix, OverflowDetected> {
    let n = provider.dimension();
    let riemann = riemann_curvature(provider, p)?;
    // The all-ones weight vector is identical for every (i, j); build it once
    // instead of n² times.
    let ones: Vec<BinaryStorage> = (0..n).map(|_| FixedPoint::one().raw()).collect();
    let mut ricci = FixedMatrix::new(n, n);
    for i in 0..n {
        for j in 0..n {
            let k_vals: Vec<BinaryStorage> = (0..n)
                .map(|k| riemann.get(&[k, i, k, j]).raw())
                .collect();
            let val = FixedPoint::from_raw(compute_tier_dot_raw(&k_vals, &ones));
            ricci.set(i, j, val);
        }
    }
    Ok(ricci)
}
/// Contracts an already computed Riemann tensor to the Ricci tensor,
/// `R_ij = Σ_k R^k_{ikj}`.
///
/// `riemann` is indexed `[l, i, j, k]` for `R^l_{ijk}`, and `n` is the dimension.
/// The sum is accumulated with plain fixed-point `+`, starting from zero.
pub fn ricci_from_riemann(riemann: &Tensor, n: usize) -> FixedMatrix {
    let mut ricci = FixedMatrix::new(n, n);
    let index_pairs = (0..n).flat_map(|row| (0..n).map(move |col| (row, col)));
    for (row, col) in index_pairs {
        let trace = (0..n).fold(FixedPoint::ZERO, |acc, k| {
            acc + riemann.get(&[k, row, k, col])
        });
        ricci.set(row, col, trace);
    }
    ricci
}
pub fn scalar_curvature(
provider: &dyn MetricProvider,
p: &FixedVector,
) -> Result<FixedPoint, OverflowDetected> {
if let Some(r) = provider.scalar_curvature_closed_form(p) {
return Ok(r);
}
let n = provider.dimension();
let g_inv = provider.metric_inverse(p)?;
let ricci = ricci_tensor(provider, p)?;
let mut g_flat = Vec::with_capacity(n * n);
let mut r_flat = Vec::with_capacity(n * n);
for i in 0..n {
for j in 0..n {
g_flat.push(g_inv.get(i, j).raw());
r_flat.push(ricci.get(i, j).raw());
}
}
Ok(FixedPoint::from_raw(compute_tier_dot_raw(&g_flat, &r_flat)))
}
pub fn scalar_from_ricci(g_inv: &FixedMatrix, ricci: &FixedMatrix) -> FixedPoint {
let n = g_inv.rows();
let mut g_flat = Vec::with_capacity(n * n);
let mut r_flat = Vec::with_capacity(n * n);
for i in 0..n {
for j in 0..n {
g_flat.push(g_inv.get(i, j).raw());
r_flat.push(ricci.get(i, j).raw());
}
}
FixedPoint::from_raw(compute_tier_dot_raw(&g_flat, &r_flat))
}
/// Sectional curvature `K(u, v)` of the 2-plane spanned by `u` and `v` at `p`:
///
/// `K(u, v) = Rm(u, v, v, u) / (g(u,u) g(v,v) − g(u,v)²)`,
///
/// where `Rm(u, v, v, u) = g(R(u, v) v, u) = g_{sl} u^l R^s_{ijk} v^i u^j v^k`. This
/// uses the index convention of [`riemann_curvature`], `R(∂_j, ∂_k) ∂_i = R^s_{ijk} ∂_s`.
/// With that convention, the unit sphere gives `K = +1`.
///
/// # Errors
/// Propagates errors from [`riemann_curvature`]. Returns
/// `OverflowDetected::DomainError` when the Gram determinant is zero, i.e. `u`
/// and `v` do not span a plane.
pub fn sectional_curvature(
    provider: &dyn MetricProvider,
    p: &FixedVector,
    u: &FixedVector,
    v: &FixedVector,
) -> Result<FixedPoint, OverflowDetected> {
    let n = provider.dimension();
    let g = provider.metric(p);
    let riemann = riemann_curvature(provider, p)?;
    let u_raw: Vec<BinaryStorage> = (0..n).map(|i| u[i].raw()).collect();
    let v_raw: Vec<BinaryStorage> = (0..n).map(|i| v[i].raw()).collect();

    // w_s = g_{ls} u^l: `u` with its index lowered.
    let mut w = FixedVector::new(n);
    for s in 0..n {
        let g_col: Vec<BinaryStorage> = (0..n).map(|l| g.get(l, s).raw()).collect();
        w[s] = FixedPoint::from_raw(compute_tier_dot_raw(&g_col, &u_raw));
    }

    // Rm(u, v, v, u) = w_s R^s_{ijk} v^i u^j v^k.
    // Slot placement matters: R^s_{ijk} is antisymmetric in (j, k), so feeding
    // `v` into both j and k would cancel every term and K would always be ~0.
    let mut numerator = FixedPoint::ZERO;
    for s in 0..n {
        for i in 0..n {
            for j in 0..n {
                for k in 0..n {
                    let r_comp = riemann.get(&[s, i, j, k]);
                    if !r_comp.is_zero() {
                        numerator = numerator + r_comp * v[i] * u[j] * v[k] * w[s];
                    }
                }
            }
        }
    }

    // Gram determinant g(u,u) g(v,v) − g(u,v)² of the plane spanned by u and v.
    let gu: Vec<BinaryStorage> = (0..n).map(|i| {
        let g_row: Vec<BinaryStorage> = (0..n).map(|j| g.get(i, j).raw()).collect();
        compute_tier_dot_raw(&g_row, &u_raw)
    }).collect();
    let uu = FixedPoint::from_raw(compute_tier_dot_raw(&u_raw, &gu));
    let gv: Vec<BinaryStorage> = (0..n).map(|i| {
        let g_row: Vec<BinaryStorage> = (0..n).map(|j| g.get(i, j).raw()).collect();
        compute_tier_dot_raw(&g_row, &v_raw)
    }).collect();
    let vv = FixedPoint::from_raw(compute_tier_dot_raw(&v_raw, &gv));
    let uv = FixedPoint::from_raw(compute_tier_dot_raw(&u_raw, &gv));
    let denom = uu * vv - uv * uv;
    if denom.is_zero() {
        return Err(OverflowDetected::DomainError);
    }
    Ok(numerator / denom)
}
pub struct EuclideanMetric {
pub dim: usize,
}
impl MetricProvider for EuclideanMetric {
fn dimension(&self) -> usize { self.dim }
fn metric(&self, _p: &FixedVector) -> FixedMatrix {
FixedMatrix::identity(self.dim)
}
fn metric_inverse(&self, _p: &FixedVector) -> Result<FixedMatrix, OverflowDetected> {
Ok(FixedMatrix::identity(self.dim))
}
}
/// Round 2-sphere of radius `radius` in (θ, φ) coordinates, with the polar angle
/// `θ = p[0]`: `g = diag(r², r² sin²θ)`. The metric does not depend on `p[1]`.
pub struct SphereMetric {
    /// Sphere radius `r`.
    pub radius: FixedPoint,
}
impl MetricProvider for SphereMetric {
    fn dimension(&self) -> usize {
        2
    }
    fn metric(&self, p: &FixedVector) -> FixedMatrix {
        let theta = p[0];
        let sin_theta = theta.sin();
        let r_sq = self.radius * self.radius;
        let g_phi_phi = r_sq * sin_theta * sin_theta;
        let off_diag = FixedPoint::ZERO;
        FixedMatrix::from_slice(2, 2, &[r_sq, off_diag, off_diag, g_phi_phi])
    }
    /// Closed-form Christoffel symbols:
    /// `Γ^θ_{φφ} = −sinθ cosθ` and `Γ^φ_{θφ} = Γ^φ_{φθ} = cotθ`, all others zero.
    ///
    /// NOTE(review): at the poles (`sin θ == 0`) cotθ is undefined; the cot
    /// entries are then left at zero instead of signalling an error.
    fn christoffel_closed_form(&self, p: &FixedVector) -> Option<Tensor> {
        let theta = p[0];
        let sin_t = theta.sin();
        let cos_t = theta.cos();
        let mut gamma = Tensor::new(&[2, 2, 2]);
        gamma.set(&[0, 1, 1], -sin_t * cos_t);
        if sin_t.is_zero() {
            return Some(gamma);
        }
        let cot_t = cos_t / sin_t;
        gamma.set(&[1, 0, 1], cot_t);
        gamma.set(&[1, 1, 0], cot_t);
        Some(gamma)
    }
    /// Constant scalar curvature `2 / r²`.
    fn scalar_curvature_closed_form(&self, _p: &FixedVector) -> Option<FixedPoint> {
        let two = FixedPoint::from_int(2);
        Some(two / (self.radius * self.radius))
    }
}
/// Upper half-plane model of the hyperbolic plane: `g = (1 / y²) I` with `y = p[1]`.
///
/// NOTE(review): points are expected to have `y > 0`. At `y == 0` the closed-form
/// Christoffel hook declines, and the numeric fallback then calls `metric`, which
/// divides by `y² = 0`.
pub struct HyperbolicMetric;
impl MetricProvider for HyperbolicMetric {
    fn dimension(&self) -> usize {
        2
    }
    fn metric(&self, p: &FixedVector) -> FixedMatrix {
        let y = p[1];
        let conformal_factor = FixedPoint::one() / (y * y);
        let off_diag = FixedPoint::ZERO;
        FixedMatrix::from_slice(2, 2, &[conformal_factor, off_diag, off_diag, conformal_factor])
    }
    /// Closed-form Christoffel symbols:
    /// `Γ^x_{xy} = Γ^x_{yx} = −1/y`, `Γ^y_{xx} = 1/y`, `Γ^y_{yy} = −1/y`, all others zero.
    /// Returns `None` when `y` is zero.
    fn christoffel_closed_form(&self, p: &FixedVector) -> Option<Tensor> {
        let y = p[1];
        if y.is_zero() {
            return None;
        }
        let inv_y = FixedPoint::one() / y;
        let neg_inv_y = -inv_y;
        let mut gamma = Tensor::new(&[2, 2, 2]);
        gamma.set(&[0, 0, 1], neg_inv_y);
        gamma.set(&[0, 1, 0], neg_inv_y);
        gamma.set(&[1, 0, 0], inv_y);
        gamma.set(&[1, 1, 1], neg_inv_y);
        Some(gamma)
    }
    /// Constant scalar curvature −2 (Gaussian curvature −1).
    fn scalar_curvature_closed_form(&self, _p: &FixedVector) -> Option<FixedPoint> {
        Some(FixedPoint::from_int(-2))
    }
}
use super::ode::{OdeSystem, rk4_step};
/// First-order form of the geodesic equation, for use with `rk4_step`.
///
/// The state has length `2 * n` and is laid out as `[x^0 .. x^{n-1}, v^0 .. v^{n-1}]`.
/// Its time derivative is `[v, a]` with `a^k = −Γ^k_{ij}(x) v^i v^j`.
pub struct GeodesicOde<'a> {
    provider: &'a dyn MetricProvider,
}
impl<'a> OdeSystem for GeodesicOde<'a> {
    fn eval(&self, _t: FixedPoint, state: &FixedVector) -> FixedVector {
        let n = self.provider.dimension();
        let mut x = FixedVector::new(n);
        let mut v = FixedVector::new(n);
        for i in 0..n {
            x[i] = state[i];
            v[i] = state[n + i];
        }
        // NOTE(review): `eval` cannot return an error, so a Christoffel failure is
        // turned into an all-zero derivative. The integrator then silently
        // freezes the state instead of reporting the problem.
        let gamma = match christoffel(self.provider, &x) {
            Ok(g) => g,
            Err(_) => return FixedVector::new(2 * n),
        };

        // v^i v^j, flattened row-major over (i, j). It does not depend on k,
        // so build it once rather than once per output component.
        let mut vv = Vec::with_capacity(n * n);
        for i in 0..n {
            for j in 0..n {
                vv.push((v[i] * v[j]).raw());
            }
        }

        let mut dstate = FixedVector::new(2 * n);
        for k in 0..n {
            dstate[k] = v[k];
        }
        for k in 0..n {
            let mut gamma_k = Vec::with_capacity(n * n);
            for i in 0..n {
                for j in 0..n {
                    gamma_k.push(gamma.get(&[k, i, j]).raw());
                }
            }
            let contraction = FixedPoint::from_raw(compute_tier_dot_raw(&gamma_k, &vv));
            dstate[n + k] = -contraction;
        }
        dstate
    }
}
pub fn geodesic_integrate(
provider: &dyn MetricProvider,
initial_point: &FixedVector,
initial_velocity: &FixedVector,
total_time: FixedPoint,
num_steps: usize,
) -> Result<Vec<FixedVector>, OverflowDetected> {
let n = provider.dimension();
let h = total_time / FixedPoint::from_int(num_steps as i32);
let mut state = FixedVector::new(2 * n);
for i in 0..n { state[i] = initial_point[i]; state[n + i] = initial_velocity[i]; }
let sys = GeodesicOde { provider };
let mut points = Vec::with_capacity(num_steps + 1);
let mut t = FixedPoint::ZERO;
let extract_pos = |s: &FixedVector| -> FixedVector {
let mut p = FixedVector::new(n);
for i in 0..n { p[i] = s[i]; }
p
};
points.push(extract_pos(&state));
for _ in 0..num_steps {
state = rk4_step(&sys, t, &state, h);
t = t + h;
points.push(extract_pos(&state));
}
Ok(points)
}
/// Parallel-transports `initial_vector` along the polyline `curve`.
///
/// Each segment applies one explicit Euler step of the transport equation
/// `dv^k = −Γ^k_{ij}(x) v^i dx^j`. Γ is evaluated at the segment's start point,
/// and `dx = curve[step + 1] − curve[step]`. A curve with fewer than two points
/// returns `initial_vector` unchanged.
///
/// When `reorthog_interval > 0`, every `reorthog_interval` segments the component
/// of `v` along the latest `dx` is removed. This projection uses the plain
/// `dot_precise` of `FixedVector`, not the metric `g`.
///
/// NOTE(review): that projection forces `v ⟂ dx`. It only agrees with parallel
/// transport when `v` is meant to stay normal to the curve (and the Euclidean
/// product is an acceptable stand-in for `g`) — confirm the intended use.
///
/// # Errors
/// Propagates any error from [`christoffel`] at a curve point.
pub fn parallel_transport_ode(
    provider: &dyn MetricProvider,
    curve: &[FixedVector],
    initial_vector: &FixedVector,
    reorthog_interval: usize,
) -> Result<FixedVector, OverflowDetected> {
    if curve.len() < 2 {
        return Ok(initial_vector.clone());
    }
    let n = provider.dimension();
    let mut v = initial_vector.clone();
    for (step, segment) in curve.windows(2).enumerate() {
        let (p, p_next) = (&segment[0], &segment[1]);
        let dx: Vec<FixedPoint> = (0..n).map(|i| p_next[i] - p[i]).collect();
        let gamma = christoffel(provider, p)?;

        // `v` and `dx` are fixed for the whole step, so convert them once
        // instead of once per (k, j) / per k.
        let v_raw: Vec<BinaryStorage> = (0..n).map(|i| v[i].raw()).collect();
        let dx_raw: Vec<BinaryStorage> = dx.iter().map(|d| d.raw()).collect();

        let mut v_new = FixedVector::new(n);
        for k in 0..n {
            // gamma_k_v[j] = Σ_i Γ^k_{ij} v^i
            let gamma_k_v: Vec<BinaryStorage> = (0..n).map(|j| {
                let gamma_kj: Vec<BinaryStorage> = (0..n)
                    .map(|i| gamma.get(&[k, i, j]).raw())
                    .collect();
                compute_tier_dot_raw(&gamma_kj, &v_raw)
            }).collect();
            let correction = FixedPoint::from_raw(compute_tier_dot_raw(&gamma_k_v, &dx_raw));
            v_new[k] = v[k] - correction;
        }
        v = v_new;

        if reorthog_interval > 0 && (step + 1) % reorthog_interval == 0 {
            let dx_vec = FixedVector::from_slice(&dx);
            let dx_norm_sq = dx_vec.dot_precise(&dx_vec);
            if !dx_norm_sq.is_zero() {
                let v_dot_dx = v.dot_precise(&dx_vec);
                let coeff = v_dot_dx / dx_norm_sq;
                v = &v - &(&dx_vec * coeff);
            }
        }
    }
    Ok(v)
}