use std::fmt::Debug;
use crate::prelude::*;
pub mod steifiel;
pub use steifiel::SteifielsManifold;
pub mod sphere;
pub use sphere::Sphere;
pub mod matrix_groups;
pub use matrix_groups::OrthogonalGroup;
pub mod utils;
/// Interface for a manifold whose points are stored in burn tensors.
///
/// Shapes are interpreted as `[channel dims..., point dims...]`: the trailing
/// `RANK_PER_POINT` axes describe one point, and any leading axes index separate
/// points. [`Manifold::specific_name`] and [`Manifold::acceptable_shape`] both split
/// a [`Shape`] this way.
pub trait Manifold<B: Backend>: Clone + Send + Sync {
    /// Number of trailing tensor axes occupied by a single point.
    const RANK_PER_POINT: usize;

    /// Constructs a value of this manifold type.
    fn new() -> Self;

    /// Short human-readable name of the manifold, used by [`Manifold::specific_name`].
    fn name() -> &'static str;

    /// Describes the shape `s` as channel dimensions followed by the manifold's own
    /// (per-point) dimensions.
    ///
    /// If `s` has fewer than `RANK_PER_POINT` axes, the channel list is empty and every
    /// axis is reported as a manifold dimension, rather than panicking.
    #[must_use]
    fn specific_name(s: &Shape) -> String {
        let dims = &s.dims;
        // `saturating_sub` prevents a usize underflow (and the resulting panic in
        // `split_at`) for shapes that are too short to hold a single point.
        let split = dims.len().saturating_sub(Self::RANK_PER_POINT);
        let (channel_dims, manifold_dims) = dims.split_at(split);
        format!(
            "{channel_dims:?} Channels worth of points in {} with specific n's {manifold_dims:?}",
            Self::name()
        )
    }

    /// Returns `true` when `s` has at least `RANK_PER_POINT` axes and its trailing
    /// `RANK_PER_POINT` dimensions are accepted by [`Manifold::acceptable_dims`].
    #[must_use]
    fn acceptable_shape(s: &Shape) -> bool {
        if s.num_dims() < Self::RANK_PER_POINT {
            return false;
        }
        let (_, point_dims) = s.dims.split_at(s.num_dims() - Self::RANK_PER_POINT);
        Self::acceptable_dims(point_dims)
    }

    /// Returns whether the per-point dimensions `_a_is` are valid for this manifold.
    ///
    /// When called from [`Manifold::acceptable_shape`], `_a_is` holds exactly the
    /// trailing `RANK_PER_POINT` dimensions of the shape.
    fn acceptable_dims(_a_is: &[usize]) -> bool;

    /// Projects `vector` at `point`, presumably onto the tangent space there; the exact
    /// map is defined by each implementor.
    ///
    /// The default `egrad2rgrad`, `project_tangent` and (indirectly)
    /// `parallel_transport` all delegate to this method.
    fn project<const D: usize>(point: Tensor<B, D>, vector: Tensor<B, D>) -> Tensor<B, D>;

    /// Converts an ambient gradient `grad` at `point` into one usable on the manifold.
    ///
    /// By default this simply calls [`Manifold::project`]. Override it if the
    /// manifold needs more than a projection.
    fn egrad2rgrad<const D: usize>(point: Tensor<B, D>, grad: Tensor<B, D>) -> Tensor<B, D> {
        Self::project(point, grad)
    }

    /// Projects `vector` into the tangent space at `point`; defaults to
    /// [`Manifold::project`].
    fn project_tangent<const D: usize>(point: Tensor<B, D>, vector: Tensor<B, D>) -> Tensor<B, D> {
        Self::project(point, vector)
    }

    /// Moves `tangent` from the tangent space at `_point1` to the one at `point2`.
    ///
    /// The default ignores `_point1` and re-projects `tangent` at `point2` via
    /// [`Manifold::project_tangent`].
    fn parallel_transport<const D: usize>(
        _point1: Tensor<B, D>,
        point2: Tensor<B, D>,
        tangent: Tensor<B, D>,
    ) -> Tensor<B, D> {
        Self::project_tangent(point2, tangent)
    }

    /// Steps from `point` along `direction` (implementor-defined retraction).
    fn retract<const D: usize>(point: Tensor<B, D>, direction: Tensor<B, D>) -> Tensor<B, D>;

    /// Exponential map at `point`; by default approximated by [`Manifold::retract`].
    fn expmap<const D: usize>(point: Tensor<B, D>, direction: Tensor<B, D>) -> Tensor<B, D> {
        Self::retract(point, direction)
    }

    /// Inner product of `u` and `v` at `point`, returned as a tensor of the same rank.
    fn inner<const D: usize>(point: Tensor<B, D>, u: Tensor<B, D>, v: Tensor<B, D>)
        -> Tensor<B, D>;

    /// Maps `point` onto the manifold (implementor-defined).
    fn proj<const D: usize>(point: Tensor<B, D>) -> Tensor<B, D>;

    /// Boolean mask marking which entries of `point` are valid points of the manifold.
    fn is_in_manifold<const D: usize>(point: Tensor<B, D>) -> Tensor<B, D, burn::tensor::Bool>;

    /// Boolean mask marking where `vector` is a valid tangent vector at `point`.
    fn is_tangent_at<const D: usize>(
        point: Tensor<B, D>,
        vector: Tensor<B, D>,
    ) -> Tensor<B, D, burn::tensor::Bool>;
}
/// Flat Euclidean space; see its [`Manifold`] implementation for the tensor layout.
#[derive(Debug, Clone)]
pub struct Euclidean;
impl<B: Backend> Manifold<B> for Euclidean {
    // A point is a single vector occupying the last tensor axis.
    const RANK_PER_POINT: usize = 1;

    fn new() -> Self {
        Self
    }

    fn name() -> &'static str {
        "Euclidean"
    }

    /// Every vector is tangent in flat space, so projection returns `vector` unchanged.
    fn project<const D: usize>(_point: Tensor<B, D>, vector: Tensor<B, D>) -> Tensor<B, D> {
        vector
    }

    /// Steps from `point` by plain vector addition of `direction`.
    fn retract<const D: usize>(point: Tensor<B, D>, direction: Tensor<B, D>) -> Tensor<B, D> {
        point + direction
    }

    /// Dot product of `u` and `v` along the last (point) axis; the result keeps rank `D`.
    fn inner<const D: usize>(
        _point: Tensor<B, D>,
        u: Tensor<B, D>,
        v: Tensor<B, D>,
    ) -> Tensor<B, D> {
        (u * v).sum_dim(D - 1)
    }

    /// Marks a point as valid when none of its coordinates is NaN.
    ///
    /// `RANK_PER_POINT` is an axis *count*, not an axis index. The point occupies the
    /// trailing axis `D - RANK_PER_POINT`, the same axis `inner` reduces over. Reducing
    /// over axis `RANK_PER_POINT` would only coincide with this for rank-2 tensors.
    fn is_in_manifold<const D: usize>(
        point: Tensor<B, D>,
    ) -> burn::tensor::Tensor<B, D, burn::tensor::Bool> {
        let point_axis = D - <Self as Manifold<B>>::RANK_PER_POINT;
        // Detach so the validity check does not take part in autodiff tracking.
        point.detach().is_nan().any_dim(point_axis).bool_not()
    }

    /// Every tensor is already a point of Euclidean space, so it is returned unchanged.
    fn proj<const D: usize>(point: Tensor<B, D>) -> Tensor<B, D> {
        point
    }

    /// A vector is tangent at `point` when both `point` and `vector` are NaN-free
    /// along the point axis.
    fn is_tangent_at<const D: usize>(
        point: Tensor<B, D>,
        vector: Tensor<B, D>,
    ) -> Tensor<B, D, burn::tensor::Bool> {
        // In flat space the NaN-free check for a vector is the same as for a point.
        let point_ok = <Self as Manifold<B>>::is_in_manifold(point);
        let vector_ok = <Self as Manifold<B>>::is_in_manifold(vector);
        point_ok.bool_and(vector_ok)
    }

    /// Any per-point dimension is acceptable for Euclidean space.
    fn acceptable_dims(_a_is: &[usize]) -> bool {
        true
    }
}