use crate::{LieAlgebra, LieGroup};
use std::fmt;
use std::ops::{Add, Mul, MulAssign, Neg, Sub};
/// Element of the Lie algebra of the positive reals ℝ⁺.
///
/// The algebra is one-dimensional: a single real coordinate, which
/// `RPlus::exp` maps to the group element `e^x`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct RPlusAlgebra(pub(crate) f64);

/// Owned + owned: adds the coordinates.
impl Add for RPlusAlgebra {
    type Output = Self;

    fn add(self, rhs: Self) -> Self {
        let RPlusAlgebra(lhs) = self;
        let RPlusAlgebra(rhs) = rhs;
        RPlusAlgebra(lhs + rhs)
    }
}

/// Owned + borrowed; the type is `Copy`, so the borrow is only read.
impl Add<&RPlusAlgebra> for RPlusAlgebra {
    type Output = RPlusAlgebra;

    fn add(self, rhs: &RPlusAlgebra) -> RPlusAlgebra {
        RPlusAlgebra(self.0 + rhs.0)
    }
}

/// Borrowed + owned.
impl Add<RPlusAlgebra> for &RPlusAlgebra {
    type Output = RPlusAlgebra;

    fn add(self, rhs: RPlusAlgebra) -> RPlusAlgebra {
        RPlusAlgebra(self.0 + rhs.0)
    }
}

/// Borrowed + borrowed.
impl Add<&RPlusAlgebra> for &RPlusAlgebra {
    type Output = RPlusAlgebra;

    fn add(self, rhs: &RPlusAlgebra) -> RPlusAlgebra {
        RPlusAlgebra(self.0 + rhs.0)
    }
}

/// Subtracts the coordinates.
impl Sub for RPlusAlgebra {
    type Output = Self;

    fn sub(self, rhs: Self) -> Self {
        RPlusAlgebra(self.0 - rhs.0)
    }
}

/// Negates the coordinate.
impl Neg for RPlusAlgebra {
    type Output = Self;

    fn neg(self) -> Self {
        let RPlusAlgebra(x) = self;
        RPlusAlgebra(-x)
    }
}

/// Scalar on the right: `x * s`.
impl Mul<f64> for RPlusAlgebra {
    type Output = Self;

    fn mul(self, scalar: f64) -> Self {
        RPlusAlgebra(self.0 * scalar)
    }
}

/// Scalar on the left: `s * x`, computed exactly as `x * s`.
impl Mul<RPlusAlgebra> for f64 {
    type Output = RPlusAlgebra;

    fn mul(self, rhs: RPlusAlgebra) -> RPlusAlgebra {
        RPlusAlgebra(rhs.0 * self)
    }
}

impl RPlusAlgebra {
    /// Wraps a raw coordinate as an algebra element.
    #[must_use]
    pub fn new(value: f64) -> Self {
        RPlusAlgebra(value)
    }

    /// Returns the raw coordinate.
    #[must_use]
    pub fn value(&self) -> f64 {
        let RPlusAlgebra(coordinate) = *self;
        coordinate
    }
}
impl LieAlgebra for RPlusAlgebra {
const DIM: usize = 1;
fn zero() -> Self {
Self(0.0)
}
fn add(&self, other: &Self) -> Self {
Self(self.0 + other.0)
}
fn scale(&self, scalar: f64) -> Self {
Self(self.0 * scalar)
}
fn norm(&self) -> f64 {
self.0.abs()
}
fn basis_element(i: usize) -> Self {
assert_eq!(i, 0, "ℝ⁺ algebra is 1-dimensional");
Self(1.0)
}
fn from_components(components: &[f64]) -> Self {
assert_eq!(components.len(), 1, "ℝ⁺ algebra has dimension 1");
Self(components[0])
}
fn to_components(&self) -> Vec<f64> {
vec![self.0]
}
fn bracket(&self, _other: &Self) -> Self {
Self::zero()
}
#[inline]
fn inner(&self, other: &Self) -> f64 {
self.0 * other.0
}
}
/// Element of the multiplicative group of positive reals ℝ⁺.
///
/// The inherent constructors below always produce a strictly positive,
/// finite `value`, except that a NaN passed to [`RPlus::from_log`] propagates.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct RPlus {
    value: f64,
}

impl RPlus {
    /// Creates a group element from a strictly positive, finite real.
    ///
    /// # Panics
    ///
    /// Panics if `value` is zero, negative, NaN, or infinite.
    #[must_use]
    pub fn from_value(value: f64) -> Self {
        // `value > 0.0` on its own already rejects NaN, but it accepts +∞,
        // which has no finite logarithm and whose inverse is 0.
        assert!(
            value.is_finite() && value > 0.0,
            "ℝ⁺ elements must be positive and finite, got {}",
            value
        );
        Self { value }
    }

    /// Creates a group element, forcing `value` into `[1e-10, f64::MAX]`.
    ///
    /// Values below `1e-10` (zero and negatives included) and NaN become
    /// `1e-10`; `+∞` becomes `f64::MAX`. Never panics.
    #[must_use]
    pub fn from_value_clamped(value: f64) -> Self {
        // `f64::max`/`f64::min` return the non-NaN operand, so NaN lands on the floor.
        Self {
            value: value.max(1e-10).min(f64::MAX),
        }
    }

    /// Returns the underlying positive real.
    #[must_use]
    pub fn value(&self) -> f64 {
        self.value
    }

    /// Creates the group element `e^log_value`.
    ///
    /// `f64::exp` underflows to `0.0` below roughly `-745` and overflows to
    /// `+∞` above roughly `709.78`; either result would leave ℝ⁺. The value is
    /// therefore saturated into `[f64::MIN_POSITIVE, f64::MAX]`, which also keeps
    /// `1 / value` finite. A NaN input yields a NaN value.
    #[must_use]
    pub fn from_log(log_value: f64) -> Self {
        // `f64::clamp` passes NaN through unchanged.
        Self {
            value: log_value.exp().clamp(f64::MIN_POSITIVE, f64::MAX),
        }
    }

    /// Uniform scaling by the factor `e^magnitude`; equivalent to [`RPlus::from_log`].
    #[must_use]
    pub fn scaling(magnitude: f64) -> Self {
        Self::from_log(magnitude)
    }

    /// Samples a log-normal element: `ln(value)` is drawn from
    /// `Normal(log_mean, log_std)` and mapped through [`RPlus::from_log`].
    ///
    /// # Panics
    ///
    /// Panics if `rand_distr::Normal::new` rejects `(log_mean, log_std)`.
    #[cfg(feature = "rand")]
    #[must_use]
    pub fn random<R: rand::Rng>(rng: &mut R, log_mean: f64, log_std: f64) -> Self {
        use rand::distributions::Distribution;
        use rand_distr::Normal;
        let normal =
            Normal::new(log_mean, log_std).expect("log_std must be non-negative and finite");
        Self::from_log(normal.sample(rng))
    }
}
/// Absolute-difference approximate equality on the single coordinate.
///
/// Delegates to `approx`'s `f64` implementation, so the boundary is inclusive
/// (`<=`), matching `f64` and this type's `RelativeEq` impl, which also
/// delegates to `f64`. In particular, equal values compare equal even when
/// `epsilon == 0.0`.
impl approx::AbsDiffEq for RPlusAlgebra {
    type Epsilon = f64;

    /// Tolerance used by `abs_diff_eq!` when no `epsilon` is supplied.
    fn default_epsilon() -> Self::Epsilon {
        1e-10
    }

    fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
        approx::AbsDiffEq::abs_diff_eq(&self.0, &other.0, epsilon)
    }
}
/// Relative approximate equality, forwarded to `approx`'s `f64`
/// implementation on the single coordinate.
impl approx::RelativeEq for RPlusAlgebra {
    /// Relative tolerance used by `relative_eq!` when none is supplied.
    fn default_max_relative() -> Self::Epsilon {
        1.0e-10
    }

    fn relative_eq(
        &self,
        other: &Self,
        epsilon: Self::Epsilon,
        max_relative: Self::Epsilon,
    ) -> bool {
        let (lhs, rhs) = (self.0, other.0);
        <f64 as approx::RelativeEq>::relative_eq(&lhs, &rhs, epsilon, max_relative)
    }
}
/// Renders as `r+(x)` with four digits after the decimal point.
impl fmt::Display for RPlusAlgebra {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let coordinate = self.0;
        f.write_fmt(format_args!("r+({:.4})", coordinate))
    }
}
impl Mul<&RPlus> for &RPlus {
type Output = RPlus;
fn mul(self, rhs: &RPlus) -> RPlus {
self.compose(rhs)
}
}
/// `a * &b` is the group product, i.e. `compose`.
impl Mul<&RPlus> for RPlus {
    type Output = RPlus;

    fn mul(self, rhs: &RPlus) -> RPlus {
        self.compose(rhs)
    }
}

/// `a * b` with both operands owned. `RPlus` is `Copy`, so callers should
/// not have to write `a * &b`; this mirrors the full set of `Add` forms on
/// `RPlusAlgebra`.
impl Mul for RPlus {
    type Output = RPlus;

    fn mul(self, rhs: RPlus) -> RPlus {
        self.compose(&rhs)
    }
}

/// `&a * b`, completing the owned/borrowed operand combinations.
impl Mul<RPlus> for &RPlus {
    type Output = RPlus;

    fn mul(self, rhs: RPlus) -> RPlus {
        self.compose(&rhs)
    }
}
/// `a *= &b` replaces `a` with the group product `a · b`.
impl MulAssign<&RPlus> for RPlus {
    fn mul_assign(&mut self, rhs: &RPlus) {
        let product = self.compose(rhs);
        *self = product;
    }
}
impl LieGroup for RPlus {
const MATRIX_DIM: usize = 1;
type Algebra = RPlusAlgebra;
fn identity() -> Self {
Self { value: 1.0 }
}
fn compose(&self, other: &Self) -> Self {
Self {
value: self.value * other.value,
}
}
fn inverse(&self) -> Self {
Self {
value: 1.0 / self.value,
}
}
fn conjugate_transpose(&self) -> Self {
self.inverse()
}
fn adjoint_action(&self, algebra_element: &RPlusAlgebra) -> RPlusAlgebra {
*algebra_element
}
fn distance_to_identity(&self) -> f64 {
self.value.ln().abs()
}
fn exp(tangent: &RPlusAlgebra) -> Self {
Self::from_log(tangent.0)
}
fn log(&self) -> crate::error::LogResult<RPlusAlgebra> {
Ok(RPlusAlgebra(self.value.ln()))
}
}
/// Renders as `ℝ⁺(x)` with four digits after the decimal point.
impl std::fmt::Display for RPlus {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let magnitude = self.value;
        f.write_fmt(format_args!("ℝ⁺({:.4})", magnitude))
    }
}
// ℝ⁺ is commutative: `compose` multiplies the underlying reals, and real
// multiplication commutes, so a·b == b·a.
impl crate::Abelian for RPlus {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by the floating-point checks below.
    const TOL: f64 = 1e-10;

    /// Asserts `|actual - expected| < TOL`, reporting the caller's location.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < TOL);
    }

    #[test]
    fn test_identity() {
        assert_close(RPlus::identity().value(), 1.0);
    }

    #[test]
    fn test_compose() {
        let (a, b) = (RPlus::from_value(2.0), RPlus::from_value(3.0));
        assert_close(a.compose(&b).value(), 6.0);
    }

    #[test]
    fn test_inverse() {
        let four = RPlus::from_value(4.0);
        let quarter = four.inverse();
        assert_close(quarter.value(), 0.25);
        assert_close(four.compose(&quarter).value(), 1.0);
    }

    #[test]
    fn test_exp_log_roundtrip() {
        let tangent = RPlusAlgebra(1.5);
        let recovered = RPlus::exp(&tangent).log().unwrap();
        assert_close(recovered.value(), tangent.value());
    }

    #[test]
    fn test_log_exp_roundtrip() {
        let g = RPlus::from_value(std::f64::consts::E);
        let back = RPlus::exp(&g.log().unwrap());
        assert_close(back.value(), g.value());
    }

    #[test]
    fn test_distance_to_identity() {
        // The distance is |ln x| >= 0, so closeness to 0 is the same as `< TOL`.
        assert_close(RPlus::identity().distance_to_identity(), 0.0);
        assert_close(
            RPlus::from_value(std::f64::consts::E).distance_to_identity(),
            1.0,
        );
        // x and 1/x are equally far from the identity.
        assert_close(
            RPlus::from_value(2.0).distance_to_identity(),
            RPlus::from_value(0.5).distance_to_identity(),
        );
    }

    #[test]
    fn test_algebra_operations() {
        let x = RPlusAlgebra(1.0);
        let y = RPlusAlgebra(2.0);
        assert_close(x.add(&y).value(), 3.0);
        assert_close(x.scale(3.0).value(), 3.0);
        assert_close(RPlusAlgebra::zero().value(), 0.0);
    }

    #[test]
    fn test_abelian_property() {
        let a = RPlus::from_value(2.0);
        let b = RPlus::from_value(3.0);
        assert_close(a.compose(&b).value(), b.compose(&a).value());
    }

    #[test]
    fn test_from_value_clamped() {
        let negative = RPlus::from_value_clamped(-0.5);
        assert!(negative.value() > 0.0);
        assert!(negative.value() >= 1e-10);
        assert!(RPlus::from_value_clamped(0.0).value() > 0.0);
        assert_close(RPlus::from_value_clamped(2.5).value(), 2.5);
    }

    #[test]
    fn test_scaling() {
        let e = std::f64::consts::E;
        assert_close(RPlus::scaling(0.0).value(), 1.0);
        assert_close(RPlus::scaling(1.0).value(), e);
        assert_close(RPlus::scaling(-1.0).value(), 1.0 / e);
    }

    #[test]
    #[cfg(feature = "rand")]
    fn test_random() {
        use rand::SeedableRng;
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        for _ in 0..100 {
            assert!(RPlus::random(&mut rng, 0.0, 1.0).value() > 0.0);
        }
        let n = 1000;
        let log_sum: f64 = (0..n)
            .map(|_| RPlus::random(&mut rng, 0.5, 0.1).value().ln())
            .sum();
        let log_mean = log_sum / n as f64;
        assert!(
            (log_mean - 0.5).abs() < 0.1,
            "Log mean should be approximately 0.5"
        );
    }

    #[test]
    fn test_adjoint() {
        let adjoint = RPlus::from_value(3.0).conjugate_transpose();
        assert!(
            (adjoint.value() - 1.0 / 3.0).abs() < 1e-10,
            "Adjoint should equal inverse"
        );
    }

    #[test]
    fn test_adjoint_action() {
        let x = RPlusAlgebra(2.5);
        let acted = RPlus::from_value(5.0).adjoint_action(&x);
        assert_close(acted.value(), x.value());
    }

    #[test]
    fn test_display() {
        let rendered = RPlus::from_value(2.5).to_string();
        assert!(rendered.contains("2.5"));
        assert!(rendered.contains("ℝ⁺"));
    }

    #[test]
    fn test_algebra_dim() {
        assert_eq!(RPlusAlgebra::DIM, 1);
    }

    #[test]
    fn test_algebra_basis_element() {
        assert_close(RPlusAlgebra::basis_element(0).value(), 1.0);
    }

    #[test]
    #[should_panic(expected = "1-dimensional")]
    fn test_algebra_basis_element_out_of_bounds() {
        let _ = RPlusAlgebra::basis_element(1);
    }

    #[test]
    fn test_algebra_from_to_components() {
        let x = RPlusAlgebra::from_components(&[3.5]);
        assert_close(x.value(), 3.5);
        let components = x.to_components();
        assert_eq!(components.len(), 1);
        assert_close(components[0], 3.5);
    }

    #[test]
    #[should_panic(expected = "dimension 1")]
    fn test_algebra_from_components_wrong_dim() {
        let _ = RPlusAlgebra::from_components(&[1.0, 2.0]);
    }

    #[test]
    fn test_algebra_norm() {
        assert_close(RPlusAlgebra(-3.0).norm(), 3.0);
        assert_close(RPlusAlgebra(2.5).norm(), 2.5);
    }

    #[test]
    fn test_from_log() {
        let e = std::f64::consts::E;
        assert_close(RPlus::from_log(0.0).value(), 1.0);
        assert_close(RPlus::from_log(1.0).value(), e);
        assert_close(RPlus::from_log(-1.0).value(), 1.0 / e);
    }

    #[test]
    fn test_group_dim() {
        assert_eq!(RPlus::MATRIX_DIM, 1);
    }

    #[test]
    #[should_panic(expected = "positive")]
    fn test_from_value_panics_on_zero() {
        let _ = RPlus::from_value(0.0);
    }

    #[test]
    #[should_panic(expected = "positive")]
    fn test_from_value_panics_on_negative() {
        let _ = RPlus::from_value(-1.0);
    }
}