use std::{fmt::Debug, marker::PhantomData};
use crate::prelude::*;
use burn::{
module::{AutodiffModule, ModuleDisplay},
tensor::backend::AutodiffBackend,
};
/// Wrapper that associates a module `M` with a manifold type `Man`.
///
/// `Man` exists only at the type level: no manifold value is stored, the
/// association is carried by a zero-sized [`PhantomData`] marker.
#[derive(Debug, Clone)]
pub struct Constrained<M, Man> {
    /// The wrapped module that all trait implementations delegate to.
    module: M,
    /// Zero-sized marker tying this wrapper to the manifold type `Man`.
    _manifold: PhantomData<Man>,
}
impl<B, M, Man> Module<B> for Constrained<M, Man>
where
M: Module<B>,
B: Backend,
Man: Clone + Debug + Send,
{
type Record = M::Record;
fn collect_devices(&self, devices: burn::module::Devices<B>) -> burn::module::Devices<B> {
self.module.collect_devices(devices)
}
fn fork(self, device: &B::Device) -> Self {
let module = self.module.fork(device);
Self {
module,
_manifold: PhantomData,
}
}
fn to_device(self, device: &B::Device) -> Self {
let module = self.module.to_device(device);
Self {
module,
_manifold: PhantomData,
}
}
fn visit<Visitor: burn::module::ModuleVisitor<B>>(&self, visitor: &mut Visitor) {
self.module.visit(visitor);
}
fn map<Mapper: burn::module::ModuleMapper<B>>(self, mapper: &mut Mapper) -> Self {
let module = self.module.map(mapper);
Self {
module,
_manifold: PhantomData,
}
}
fn load_record(self, record: Self::Record) -> Self {
let module = self.module.load_record(record);
Self {
module,
_manifold: PhantomData,
}
}
fn into_record(self) -> Self::Record {
self.module.into_record()
}
}
/// Autodiff support for [`Constrained`], delegated to the wrapped module.
///
/// The inner-module type is the wrapped module's own `InnerModule`, so the
/// value returned by `valid` is NOT wrapped in `Constrained` and carries no
/// manifold marker.
// NOTE(review): dropping the wrapper here may or may not be intentional —
// confirm with callers before relying on it.
impl<B, M, Man> AutodiffModule<B> for Constrained<M, Man>
where
    B: AutodiffBackend,
    M: AutodiffModule<B>,
    Man: Clone + Debug + Send,
{
    type InnerModule = M::InnerModule;

    /// Returns the wrapped module's `valid()` result unchanged.
    fn valid(&self) -> Self::InnerModule {
        M::valid(&self.module)
    }
}
/// Default display behavior for [`Constrained`], delegated to the wrapped
/// module so the wrapper is transparent when displayed.
impl<M, Man> burn::module::ModuleDisplayDefault for Constrained<M, Man>
where
    M: burn::module::ModuleDisplayDefault,
    Man: Clone + Debug + Send,
{
    /// Returns the wrapped module's display content.
    fn content(&self, content: burn::module::Content) -> Option<burn::module::Content> {
        self.module.content(content)
    }

    /// Returns the wrapped module's parameter count.
    ///
    /// The trait's provided implementation returns 0, which would make the
    /// wrapper report no parameters even when the wrapped module has some;
    /// forwarding keeps the two in agreement.
    fn num_params(&self) -> usize {
        burn::module::ModuleDisplayDefault::num_params(&self.module)
    }
}
/// Display formatting for [`Constrained`].
impl<M, Man> ModuleDisplay for Constrained<M, Man>
where
    M: ModuleDisplay,
    Man: Clone + Debug + Send,
{
    /// Formats the wrapped module with `passed_settings` and surrounds the
    /// result with `Constrained<` … `>`.
    fn format(&self, passed_settings: burn::module::DisplaySettings) -> String {
        let inner = self.module.format(passed_settings);
        format!("Constrained<{inner}>")
    }
}
impl<M, Man> Constrained<M, Man> {
pub fn new(module: M) -> Self {
Self {
module,
_manifold: PhantomData,
}
}
pub fn inner(&self) -> &M {
&self.module
}
pub fn inner_mut(&mut self) -> &mut M {
&mut self.module
}
pub fn project_tensor<B, const D: usize>(
&self,
point: Tensor<B, D>,
vector: Tensor<B, D>,
) -> Tensor<B, D>
where
B: Backend,
M: Module<B>,
Man: Manifold<B> + Clone + Debug + Send,
{
Man::project(point, vector)
}
pub fn retract_tensor<B, const D: usize>(
&self,
point: Tensor<B, D>,
direction: Tensor<B, D>,
) -> Tensor<B, D>
where
B: Backend,
M: Module<B>,
Man: Manifold<B> + Clone + Debug + Send,
{
Man::retract(point, direction)
}
pub fn euclidean_to_riemannian<B, const D: usize>(
&self,
point: Tensor<B, D>,
grad: Tensor<B, D>,
) -> Tensor<B, D>
where
B: Backend,
M: Module<B>,
Man: Manifold<B> + Clone + Debug + Send,
{
Man::egrad2rgrad(point, grad)
}
pub fn project_to_manifold<B, const D: usize>(&self, point: Tensor<B, D>) -> Tensor<B, D>
where
B: Backend,
M: Module<B>,
Man: Manifold<B> + Clone + Debug + Send,
{
Man::proj(point)
}
pub fn manifold_name<B>(&self) -> &'static str
where
B: Backend,
Man: Manifold<B>,
{
Man::name()
}
}
/// Interface for modules that are associated with manifold constraints on
/// backend `B`.
pub trait ConstrainedModule<B>
where
    B: Backend,
{
    /// Reports whether this module has manifold constraints.
    ///
    /// The provided implementation always answers `true`.
    fn has_manifold_constraints(&self) -> bool {
        true
    }

    /// Consumes the module and returns the version with its manifold
    /// constraints applied; what "applied" means is up to the implementor.
    #[must_use]
    fn apply_manifold_constraints(self) -> Self;

    /// Returns descriptive key/value information about the manifold
    /// associated with this module.
    fn get_manifold_info(&self) -> std::collections::HashMap<String, String>;
}
/// [`ConstrainedModule`] implementation for the [`Constrained`] wrapper.
impl<B, M, Man> ConstrainedModule<B> for Constrained<M, Man>
where
    B: Backend,
    M: Module<B>,
    Man: Manifold<B> + Clone + Debug + Send,
{
    /// Returns the wrapper unchanged.
    // NOTE(review): no projection of the wrapped module happens here —
    // confirm the identity behavior is intended and not a placeholder.
    fn apply_manifold_constraints(self) -> Self {
        self
    }

    /// Returns a single-entry map whose `"manifold_type"` key holds
    /// `Man::name()`.
    fn get_manifold_info(&self) -> std::collections::HashMap<String, String> {
        std::collections::HashMap::from([(
            String::from("manifold_type"),
            Man::name().to_string(),
        )])
    }
}