use super::FixedPoint;
use super::FixedVector;
use super::FixedMatrix;
use super::linalg::compute_tier_dot_raw;
use crate::fixed_point::core_types::errors::OverflowDetected;
use crate::fixed_point::universal::fasc::stack_evaluator::BinaryStorage;
/// A fiber bundle whose points are stored as flat fixed-point coordinate vectors.
///
/// Every implementation in this module lays out a total-space point as the
/// `base_dim()` base coordinates followed by the `fiber_dim()` fiber coordinates.
pub trait FiberBundle {
    /// Projects a total-space point onto its base coordinates.
    fn project(&self, total_point: &FixedVector) -> FixedVector;

    /// Number of base-space coordinates.
    fn base_dim(&self) -> usize;

    /// Number of fiber coordinates.
    fn fiber_dim(&self) -> usize;

    /// Number of total-space coordinates; defaults to `base_dim() + fiber_dim()`.
    fn total_dim(&self) -> usize {
        let base = self.base_dim();
        base + self.fiber_dim()
    }

    /// Assembles a total-space point from base and fiber coordinates.
    fn lift(&self, base: &FixedVector, fiber: &FixedVector) -> FixedVector;

    /// Splits a total-space point into its `(base, fiber)` coordinates.
    fn local_trivialization(&self, total_point: &FixedVector) -> (FixedVector, FixedVector);
}
/// A connection on a [`FiberBundle`]: a rule splitting total-space tangent
/// vectors into horizontal and vertical parts, plus parallel transport.
pub trait BundleConnection: FiberBundle {
    /// Lifts `base_tangent` at `total_point` to a horizontal total-space tangent.
    ///
    /// # Errors
    /// The `Result` lets implementations report fixed-point overflow; the
    /// implementations in this module always return `Ok`.
    fn horizontal_lift(
        &self,
        total_point: &FixedVector,
        base_tangent: &FixedVector,
    ) -> Result<FixedVector, OverflowDetected>;

    /// Returns the vertical (fiber-direction) part of `total_tangent`.
    fn vertical_component(
        &self,
        total_point: &FixedVector,
        total_tangent: &FixedVector,
    ) -> FixedVector;

    /// Returns the horizontal part of `total_tangent`.
    ///
    /// By default this is `total_tangent` minus its vertical component.
    fn horizontal_component(
        &self,
        total_point: &FixedVector,
        total_tangent: &FixedVector,
    ) -> FixedVector {
        let vertical_part = self.vertical_component(total_point, total_tangent);
        total_tangent - &vertical_part
    }

    /// Parallel-transports `initial_fiber` along the sampled `base_path`.
    ///
    /// # Errors
    /// The `Result` lets implementations report fixed-point overflow.
    fn parallel_transport_along(
        &self,
        base_path: &[FixedVector],
        initial_fiber: &FixedVector,
    ) -> Result<FixedVector, OverflowDetected>;
}
/// The product bundle `B × F` with the flat (trivial) connection.
///
/// A total-space point is the `base_dimension` base coordinates followed by
/// the `fiber_dimension` fiber coordinates.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TrivialBundle {
    /// Number of base-space coordinates.
    pub base_dimension: usize,
    /// Number of fiber coordinates.
    pub fiber_dimension: usize,
}
impl FiberBundle for TrivialBundle {
    /// Keeps only the leading `base_dimension` coordinates.
    fn project(&self, total_point: &FixedVector) -> FixedVector {
        let mut out = FixedVector::new(self.base_dimension);
        (0..self.base_dimension).for_each(|idx| out[idx] = total_point[idx]);
        out
    }

    fn base_dim(&self) -> usize {
        self.base_dimension
    }

    fn fiber_dim(&self) -> usize {
        self.fiber_dimension
    }

    /// Concatenates `base` (first `base_dimension` entries) and `fiber`
    /// (first `fiber_dimension` entries).
    fn lift(&self, base: &FixedVector, fiber: &FixedVector) -> FixedVector {
        let offset = self.base_dimension;
        let len = offset + self.fiber_dimension;
        let mut total = FixedVector::new(len);
        for idx in 0..len {
            total[idx] = if idx < offset { base[idx] } else { fiber[idx - offset] };
        }
        total
    }

    /// Splits at index `base_dimension`: the base part is exactly `project`.
    fn local_trivialization(&self, total_point: &FixedVector) -> (FixedVector, FixedVector) {
        let base_part = self.project(total_point);
        let offset = self.base_dimension;
        let mut fiber_part = FixedVector::new(self.fiber_dimension);
        for idx in 0..self.fiber_dimension {
            fiber_part[idx] = total_point[offset + idx];
        }
        (base_part, fiber_part)
    }
}
impl BundleConnection for TrivialBundle {
    /// With the flat connection the horizontal lift is the base tangent padded
    /// with zeros in the fiber slots; the point itself is irrelevant.
    fn horizontal_lift(
        &self,
        _total_point: &FixedVector,
        base_tangent: &FixedVector,
    ) -> Result<FixedVector, OverflowDetected> {
        let mut lifted = FixedVector::new(self.total_dim());
        (0..self.base_dimension).for_each(|idx| lifted[idx] = base_tangent[idx]);
        Ok(lifted)
    }

    /// Keeps the fiber slots of `total_tangent` and zeros the base slots.
    fn vertical_component(
        &self,
        _total_point: &FixedVector,
        total_tangent: &FixedVector,
    ) -> FixedVector {
        let start = self.base_dimension;
        let mut vertical = FixedVector::new(self.total_dim());
        for slot in start..start + self.fiber_dimension {
            vertical[slot] = total_tangent[slot];
        }
        vertical
    }

    /// Flat transport never changes the fiber value, whatever the path.
    fn parallel_transport_along(
        &self,
        _base_path: &[FixedVector],
        initial_fiber: &FixedVector,
    ) -> Result<FixedVector, OverflowDetected> {
        let transported = initial_fiber.clone();
        Ok(transported)
    }
}
/// A bundle with `fiber_dim_val` fiber coordinates over `base_dim_val` base
/// coordinates, optionally carrying connection coefficients.
///
/// The coefficients take no base-point argument (see `get_coeff`), so the
/// connection is constant over the base.
pub struct VectorBundle {
/// Number of base-space coordinates.
pub base_dim_val: usize,
/// Number of fiber coordinates.
pub fiber_dim_val: usize,
/// Connection coefficients Γ^a_{b i}, flattened in `[a][b][i]` order:
/// index `a * fiber_dim_val * base_dim_val + b * base_dim_val + i`, total
/// length `fiber_dim_val² × base_dim_val` (enforced by `with_connection`).
/// `None` means every coefficient reads as zero (a flat connection).
pub connection_coeffs: Option<Vec<FixedPoint>>,
}
impl VectorBundle {
    /// Creates a bundle with no connection coefficients (all Γ read as zero).
    pub fn flat(base_dim: usize, fiber_dim: usize) -> Self {
        Self { base_dim_val: base_dim, fiber_dim_val: fiber_dim, connection_coeffs: None }
    }

    /// Creates a bundle with the given connection coefficients, laid out as
    /// `[a][b][i]` with `a, b < fiber_dim` and `i < base_dim`.
    ///
    /// # Panics
    /// Panics if `coeffs.len() != fiber_dim * fiber_dim * base_dim`.
    pub fn with_connection(base_dim: usize, fiber_dim: usize, coeffs: Vec<FixedPoint>) -> Self {
        let expected_len = fiber_dim * fiber_dim * base_dim;
        assert_eq!(coeffs.len(), expected_len,
            "Connection coefficients must have size fiber_dim² × base_dim");
        Self { connection_coeffs: Some(coeffs), ..Self::flat(base_dim, fiber_dim) }
    }

    /// Reads Γ^a_{b i}; a flat bundle yields zero for every index.
    fn get_coeff(&self, a: usize, b: usize, i: usize) -> FixedPoint {
        let (k, n) = (self.fiber_dim_val, self.base_dim_val);
        self.connection_coeffs
            .as_ref()
            .map_or(FixedPoint::ZERO, |table| table[a * k * n + b * n + i])
    }
}
impl FiberBundle for VectorBundle {
    /// Keeps only the leading `base_dim_val` coordinates.
    fn project(&self, total_point: &FixedVector) -> FixedVector {
        let n = self.base_dim_val;
        let mut head = FixedVector::new(n);
        (0..n).for_each(|idx| head[idx] = total_point[idx]);
        head
    }

    fn base_dim(&self) -> usize {
        self.base_dim_val
    }

    fn fiber_dim(&self) -> usize {
        self.fiber_dim_val
    }

    /// Concatenates `base` (first `base_dim_val` entries) and `fiber`
    /// (first `fiber_dim_val` entries).
    fn lift(&self, base: &FixedVector, fiber: &FixedVector) -> FixedVector {
        let n = self.base_dim_val;
        let len = n + self.fiber_dim_val;
        let mut total = FixedVector::new(len);
        for idx in 0..len {
            total[idx] = if idx < n { base[idx] } else { fiber[idx - n] };
        }
        total
    }

    /// Splits at index `base_dim_val`: the base part is exactly `project`.
    fn local_trivialization(&self, total_point: &FixedVector) -> (FixedVector, FixedVector) {
        let head = self.project(total_point);
        let n = self.base_dim_val;
        let mut tail = FixedVector::new(self.fiber_dim_val);
        (0..self.fiber_dim_val).for_each(|idx| tail[idx] = total_point[n + idx]);
        (head, tail)
    }
}
impl BundleConnection for VectorBundle {
    /// Horizontal lift of `base_tangent` (v) at `total_point` (fiber part f).
    ///
    /// Base slots copy `v`. Fiber slot `a` is `-Σ_{b,i} Γ^a_{b i} f^b v^i`,
    /// with terms ordered b-major and summed through `compute_tier_dot_raw`
    /// against a vector of ones.
    fn horizontal_lift(
        &self,
        total_point: &FixedVector,
        base_tangent: &FixedVector,
    ) -> Result<FixedVector, OverflowDetected> {
        let (n, k) = (self.base_dim_val, self.fiber_dim_val);
        let fiber_coords: Vec<FixedPoint> = (0..k).map(|b| total_point[n + b]).collect();
        // Every term list has exactly k * n entries, so one ones-vector serves all rows.
        let unit: Vec<BinaryStorage> = vec![FixedPoint::one().raw(); k * n];
        let mut lifted = FixedVector::new(n + k);
        for i in 0..n {
            lifted[i] = base_tangent[i];
        }
        for a in 0..k {
            let products: Vec<BinaryStorage> = (0..k)
                .flat_map(|b| (0..n).map(move |i| (b, i)))
                .map(|(b, i)| (self.get_coeff(a, b, i) * fiber_coords[b] * base_tangent[i]).raw())
                .collect();
            let sum = FixedPoint::from_raw(compute_tier_dot_raw(&products, &unit));
            lifted[n + a] = -sum;
        }
        Ok(lifted)
    }

    /// Keeps the fiber slots of `total_tangent` and zeros the base slots.
    fn vertical_component(
        &self,
        _total_point: &FixedVector,
        total_tangent: &FixedVector,
    ) -> FixedVector {
        let (n, k) = (self.base_dim_val, self.fiber_dim_val);
        let mut vertical = FixedVector::new(n + k);
        for slot in n..n + k {
            vertical[slot] = total_tangent[slot];
        }
        vertical
    }

    /// Transports `initial_fiber` along `base_path` with one explicit Euler
    /// step per consecutive pair of points: `f^a ← f^a − Σ_{b,i} Γ^a_{b i} f^b dx^i`.
    ///
    /// Paths with fewer than two points return `initial_fiber` unchanged.
    fn parallel_transport_along(
        &self,
        base_path: &[FixedVector],
        initial_fiber: &FixedVector,
    ) -> Result<FixedVector, OverflowDetected> {
        if base_path.len() < 2 {
            return Ok(initial_fiber.clone());
        }
        let (n, k) = (self.base_dim_val, self.fiber_dim_val);
        let unit: Vec<BinaryStorage> = vec![FixedPoint::one().raw(); k * n];
        let mut current = initial_fiber.clone();
        for segment in base_path.windows(2) {
            let (from, to) = (&segment[0], &segment[1]);
            let dx: Vec<FixedPoint> = (0..n).map(|i| to[i] - from[i]).collect();
            // Build the next value from the old one so every component sees the same f.
            let mut next = FixedVector::new(k);
            for a in 0..k {
                let products: Vec<BinaryStorage> = (0..k)
                    .flat_map(|b| (0..n).map(move |i| (b, i)))
                    .map(|(b, i)| (self.get_coeff(a, b, i) * current[b] * dx[i]).raw())
                    .collect();
                let correction = FixedPoint::from_raw(compute_tier_dot_raw(&products, &unit));
                next[a] = current[a] - correction;
            }
            current = next;
        }
        Ok(current)
    }
}
/// A bundle described chart-by-chart through matrix-valued transition functions.
pub struct PrincipalBundle {
    /// Number of base-space coordinates.
    pub base_dim_val: usize,
    /// Number of fiber coordinates in a total-space point (the value
    /// reported as the fiber dimension).
    pub group_dim: usize,
    /// Size of the transition matrices: `trivial` fills them with
    /// `matrix_dim × matrix_dim` identities, and `verify_cocycle` compares
    /// that many rows and columns.
    pub matrix_dim: usize,
    /// Number of charts in the atlas.
    pub num_charts: usize,
    /// Transition matrices stored row-major by chart pair: the matrix for
    /// `(alpha, beta)` lives at index `alpha * num_charts + beta`.
    pub transitions: Vec<FixedMatrix>,
}
impl PrincipalBundle {
    /// Builds a bundle with `num_charts` charts whose every transition is the
    /// `matrix_dim × matrix_dim` identity.
    pub fn trivial(base_dim: usize, group_dim: usize, matrix_dim: usize, num_charts: usize) -> Self {
        let id = FixedMatrix::identity(matrix_dim);
        let transitions = vec![id; num_charts * num_charts];
        Self {
            base_dim_val: base_dim,
            group_dim,
            matrix_dim,
            num_charts,
            transitions,
        }
    }

    /// Flat index of the `(alpha, beta)` transition in `transitions`.
    ///
    /// Panics on out-of-range charts; without the check an out-of-range `beta`
    /// would silently address a different chart pair.
    fn transition_index(&self, alpha: usize, beta: usize) -> usize {
        assert!(
            alpha < self.num_charts && beta < self.num_charts,
            "chart indices ({}, {}) out of range for {} charts",
            alpha,
            beta,
            self.num_charts
        );
        alpha * self.num_charts + beta
    }

    /// Returns the transition matrix stored for the chart pair `(alpha, beta)`.
    ///
    /// # Panics
    /// Panics if `alpha` or `beta` is not less than `num_charts`.
    pub fn transition(&self, alpha: usize, beta: usize) -> &FixedMatrix {
        &self.transitions[self.transition_index(alpha, beta)]
    }

    /// Stores `g` for `(alpha, beta)` and its inverse for `(beta, alpha)`.
    ///
    /// When `alpha == beta` only `g` is stored.
    ///
    /// # Errors
    /// Propagates the error from `super::derived::inverse` (presumably a
    /// singular or overflowing `g`). In that case the bundle is left unchanged.
    ///
    /// # Panics
    /// Panics if `alpha` or `beta` is not less than `num_charts`.
    pub fn set_transition(
        &mut self,
        alpha: usize,
        beta: usize,
        g: FixedMatrix,
    ) -> Result<(), OverflowDetected> {
        let forward = self.transition_index(alpha, beta);
        let backward = self.transition_index(beta, alpha);
        // Invert before writing anything so a failure leaves the bundle untouched.
        let g_inv = super::derived::inverse(&g)?;
        self.transitions[forward] = g;
        // For alpha == beta both slots coincide; writing g_inv would overwrite g.
        if alpha != beta {
            self.transitions[backward] = g_inv;
        }
        Ok(())
    }

    /// Checks the cocycle condition `g_{αβ} · g_{βγ} = g_{αγ}` for every chart
    /// triple, entry by entry.
    ///
    /// Returns `(ok, max_err)`. `ok` is true when no absolute entry error
    /// exceeds `tol`. `max_err` is the largest absolute entry error observed,
    /// or zero when there are no charts.
    pub fn verify_cocycle(&self, tol: FixedPoint) -> (bool, FixedPoint) {
        let mut max_err = FixedPoint::ZERO;
        let mut ok = true;
        for alpha in 0..self.num_charts {
            for beta in 0..self.num_charts {
                let g_ab = self.transition(alpha, beta);
                for gamma in 0..self.num_charts {
                    let g_bg = self.transition(beta, gamma);
                    let g_ag = self.transition(alpha, gamma);
                    let product = g_ab * g_bg;
                    for i in 0..self.matrix_dim {
                        for j in 0..self.matrix_dim {
                            let err = (product.get(i, j) - g_ag.get(i, j)).abs();
                            if err > max_err {
                                max_err = err;
                            }
                            if err > tol {
                                ok = false;
                            }
                        }
                    }
                }
            }
        }
        (ok, max_err)
    }
}
impl FiberBundle for PrincipalBundle {
    /// Keeps only the leading `base_dim_val` coordinates.
    fn project(&self, total_point: &FixedVector) -> FixedVector {
        let width = self.base_dim_val;
        let mut base_coords = FixedVector::new(width);
        (0..width).for_each(|idx| base_coords[idx] = total_point[idx]);
        base_coords
    }

    fn base_dim(&self) -> usize {
        self.base_dim_val
    }

    /// The fiber coordinates of a total-space point number `group_dim`.
    fn fiber_dim(&self) -> usize {
        self.group_dim
    }

    /// Concatenates `base` (first `base_dim_val` entries) and `fiber`
    /// (first `group_dim` entries).
    fn lift(&self, base: &FixedVector, fiber: &FixedVector) -> FixedVector {
        let width = self.base_dim_val;
        let len = width + self.group_dim;
        let mut total = FixedVector::new(len);
        for idx in 0..len {
            total[idx] = if idx < width { base[idx] } else { fiber[idx - width] };
        }
        total
    }

    /// Splits at index `base_dim_val`: the base part is exactly `project`.
    fn local_trivialization(&self, total_point: &FixedVector) -> (FixedVector, FixedVector) {
        let base_coords = self.project(total_point);
        let width = self.base_dim_val;
        let mut group_coords = FixedVector::new(self.group_dim);
        for idx in 0..self.group_dim {
            group_coords[idx] = total_point[width + idx];
        }
        (base_coords, group_coords)
    }
}
/// Applies a group element, given as a matrix, to a fiber element.
///
/// This is a direct call to `FixedMatrix::mul_vector`, returning the
/// matrix-vector product `group_element · fiber_element`.
pub fn apply_representation(
group_element: &FixedMatrix,
fiber_element: &FixedVector,
) -> FixedVector {
group_element.mul_vector(fiber_element)
}
pub fn change_chart(
bundle: &PrincipalBundle,
alpha: usize,
beta: usize,
fiber_alpha: &FixedVector,
) -> FixedVector {
let g = bundle.transition(alpha, beta);
g.mul_vector(fiber_alpha)
}
/// Computes the curvature tensor of `bundle`'s connection.
///
/// The result has shape `[k, k, n, n]`, where `k` is the fiber dimension and
/// `n` the base dimension. Its entries are
/// `F[a][b][i][j] = Σ_c Γ^a_{c i} Γ^c_{b j} − Σ_c Γ^a_{c j} Γ^c_{b i}`.
/// Each sum is accumulated through `compute_tier_dot_raw`.
///
/// The stored coefficients take no base-point argument, so they are constant
/// over the base. The derivative terms `∂_i Γ^a_{b j} − ∂_j Γ^a_{b i}` of the
/// general formula therefore vanish, and `_base_point` is unused.
///
/// A bundle without coefficients gets a freshly constructed tensor of the same
/// shape with no entries set (presumably all zeros).
///
/// # Errors
/// Currently always returns `Ok`; the `Result` is kept for interface stability.
pub fn vector_bundle_curvature(
    bundle: &VectorBundle,
    _base_point: &FixedVector,
) -> Result<super::tensor::Tensor, OverflowDetected> {
    let k = bundle.fiber_dim_val;
    let n = bundle.base_dim_val;
    let mut curv = super::tensor::Tensor::new(&[k, k, n, n]);
    if bundle.connection_coeffs.is_none() {
        return Ok(curv);
    }
    // Each term list has exactly k entries, so one ones-vector and two reusable
    // buffers serve every (a, b, i, j).
    let ones: Vec<BinaryStorage> = vec![FixedPoint::one().raw(); k];
    let mut pos_terms: Vec<BinaryStorage> = Vec::with_capacity(k);
    let mut neg_terms: Vec<BinaryStorage> = Vec::with_capacity(k);
    for a in 0..k {
        for b in 0..k {
            for i in 0..n {
                for j in 0..n {
                    pos_terms.clear();
                    neg_terms.clear();
                    for c in 0..k {
                        pos_terms.push((bundle.get_coeff(a, c, i) * bundle.get_coeff(c, b, j)).raw());
                        neg_terms.push((bundle.get_coeff(a, c, j) * bundle.get_coeff(c, b, i)).raw());
                    }
                    let quad_pos = FixedPoint::from_raw(compute_tier_dot_raw(&pos_terms, &ones));
                    let quad_neg = FixedPoint::from_raw(compute_tier_dot_raw(&neg_terms, &ones));
                    curv.set(&[a, b, i, j], quad_pos - quad_neg);
                }
            }
        }
    }
    Ok(curv)
}