use std::collections::HashMap;
use std::fmt::Debug;
use std::marker::PhantomData;
use burn::module::Module;
use crate::constrained_module::Constrained;
use crate::prelude::*;
/// Hyper-parameters handed to [`MultiManifoldOptimizer`].
///
/// The field names follow the usual Adam-style naming. In this file the values are only
/// stored and read back — no update step consumes them yet.
#[derive(Debug, Clone, PartialEq)]
pub struct MultiManifoldOptimizerConfig {
    /// Step size. Defaults to `1e-3`.
    pub learning_rate: f64,
    /// Defaults to `0.9` (presumably the first-moment decay rate — confirm once used).
    pub beta1: f64,
    /// Defaults to `0.999` (presumably the second-moment decay rate — confirm once used).
    pub beta2: f64,
    /// Defaults to `1e-8` (presumably a numerical-stability epsilon — confirm once used).
    pub eps: f64,
    /// Defaults to `0.0`.
    pub weight_decay: f64,
}

impl Default for MultiManifoldOptimizerConfig {
    /// Returns `learning_rate = 1e-3`, `beta1 = 0.9`, `beta2 = 0.999`, `eps = 1e-8`,
    /// `weight_decay = 0.0`.
    fn default() -> Self {
        Self {
            learning_rate: 1e-3,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            weight_decay: 0.0,
        }
    }
}
/// Optimizer front-end intended to handle parameters constrained to different manifolds.
///
/// NOTE(review): every manifold-related method below is currently a no-op stub; the
/// optimizer only stores its [`MultiManifoldOptimizerConfig`].
#[derive(Debug)]
pub struct MultiManifoldOptimizer<B: Backend> {
    /// Hyper-parameters supplied at construction; exposed read-only via [`Self::config`].
    config: MultiManifoldOptimizerConfig,
    /// Ties the optimizer to backend `B` without storing a value of that type.
    _backend: PhantomData<B>,
}

impl<B: Backend> MultiManifoldOptimizer<B> {
    /// Creates an optimizer holding `config`.
    #[must_use]
    pub fn new(config: MultiManifoldOptimizerConfig) -> Self {
        Self {
            config,
            _backend: PhantomData,
        }
    }

    /// Returns the hyper-parameters this optimizer was created with.
    ///
    /// Reading the field through this getter is also what keeps `config` from being
    /// flagged as dead code, so no lint suppression is needed on the field.
    #[must_use]
    pub fn config(&self) -> &MultiManifoldOptimizerConfig {
        &self.config
    }

    /// Intended to discover manifold-constrained parameters inside `module`.
    ///
    /// Currently a no-op: `module` is not inspected and `self` is not modified.
    pub fn collect_manifolds<M: Module<B>>(&mut self, _module: &M) {}

    /// Intended to associate manifold type `M` with the parameter at `_path`.
    ///
    /// Currently a no-op: the path is dropped and nothing is recorded.
    pub fn register_manifold<M: Manifold<B> + Send + Sync + 'static>(&mut self, _path: String) {}

    /// Intended to project `module`'s parameters back onto their manifolds.
    ///
    /// Currently returns `module` unchanged. Takes `self` by value, so the optimizer is
    /// consumed by this call. The result must be used: discarding it drops the module.
    #[must_use]
    pub fn apply_constraints<M: Module<B>>(self, module: M) -> M {
        module
    }
}
/// A [`Module`] that exposes a hook for manifold constraints and a string description
/// of the manifold it is associated with.
pub trait ManifoldOptimizable<B>: Module<B>
where
    B: Backend,
{
    /// Returns the module after manifold constraints have been applied.
    ///
    /// Takes `self` by value, so ignoring the returned value would drop the module.
    #[must_use]
    fn apply_manifold_constraints(self) -> Self;

    /// Returns key/value string descriptors of the manifold.
    ///
    /// NOTE(review): the only implementation in this file emits a single
    /// `"manifold_type"` key; other implementations may differ.
    fn get_manifold_info(&self) -> HashMap<String, String>;
}
impl<B, M, Man> ManifoldOptimizable<B> for Constrained<M, Man>
where
    B: Backend,
    M: Module<B>,
    Man: Manifold<B> + Clone + Debug + Send,
{
    /// Currently the identity: the constrained module is handed back untouched.
    fn apply_manifold_constraints(self) -> Self {
        self
    }

    /// Reports the manifold under the single key `"manifold_type"`, whose value is the
    /// string form of `Man::name()`.
    fn get_manifold_info(&self) -> HashMap<String, String> {
        let manifold_name = Man::name().to_string();
        HashMap::from([(String::from("manifold_type"), manifold_name)])
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    use burn::nn::LinearConfig;

    type TestBackend = NdArray;

    /// Builds a 2x2 linear layer on the default device, wrapped with the `Euclidean`
    /// manifold.
    fn euclidean_linear() -> impl ManifoldOptimizable<TestBackend> {
        let device = Default::default();
        let layer = LinearConfig::new(2, 2).init::<TestBackend>(&device);
        Constrained::<_, Euclidean>::new(layer)
    }

    /// Optimizer built from the default configuration.
    fn default_optimizer() -> MultiManifoldOptimizer<TestBackend> {
        MultiManifoldOptimizer::new(MultiManifoldOptimizerConfig::default())
    }

    #[test]
    fn test_multi_manifold_optimizer() {
        let opt = default_optimizer();
        assert_eq!(opt.config.learning_rate, 1e-3);
    }

    #[test]
    fn test_constrained_module_trait() {
        let info = euclidean_linear().get_manifold_info();
        assert_eq!(
            info.get("manifold_type").map(String::as_str),
            Some("Euclidean")
        );
    }

    #[test]
    fn test_apply_constraints() {
        let opt = default_optimizer();
        let constrained = opt.apply_constraints(euclidean_linear());
        let info = constrained.get_manifold_info();
        assert_eq!(
            info.get("manifold_type").map(String::as_str),
            Some("Euclidean")
        );
    }
}