use crate::ScalarLoss;
use terni::Loss;
/// A value that can be combined with another of its kind and summarized by a
/// scalar loss.
///
/// Implementors are presumably expected to make `Self::default()` a two-sided
/// identity for [`Carrier::compose`] (the `ScalarConnection` tests check exactly
/// this for the scalar case).
pub trait Carrier: Clone + Default {
    /// Combines `self` with `other`, consuming both, and returns the result.
    ///
    /// Both operands are moved in, so discarding the return value loses them.
    #[must_use]
    fn compose(self, other: Self) -> Self;

    /// Returns the scalar loss that summarizes this value.
    fn norm(&self) -> ScalarLoss;
}
/// The simplest carrier: a connection described entirely by one scalar loss.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct ScalarConnection {
    /// The loss carried by this connection.
    pub loss: ScalarLoss,
}
impl ScalarConnection {
    /// Builds a connection whose loss is `ScalarLoss::zero()`.
    pub fn zero() -> Self {
        let loss = ScalarLoss::zero();
        Self { loss }
    }

    /// Builds a connection whose loss is `ScalarLoss::new(loss)`.
    pub fn new(loss: f64) -> Self {
        let wrapped = ScalarLoss::new(loss);
        Self { loss: wrapped }
    }
}
impl Carrier for ScalarConnection {
    /// Composes two scalar connections by adding their losses as `f64` values
    /// and re-wrapping the sum with `ScalarLoss::new`.
    fn compose(self, other: Self) -> Self {
        let lhs = self.loss.as_f64();
        let rhs = other.loss.as_f64();
        let total = lhs + rhs;
        Self {
            loss: ScalarLoss::new(total),
        }
    }

    /// Returns a clone of the stored loss.
    fn norm(&self) -> ScalarLoss {
        ScalarLoss::clone(&self.loss)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn scalar_connection_zero() {
        let c = ScalarConnection::zero();
        assert!(c.norm().is_zero());
    }

    #[test]
    fn scalar_connection_new() {
        let c = ScalarConnection::new(2.5);
        assert_eq!(c.norm().as_f64(), 2.5);
    }

    #[test]
    fn scalar_connection_compose() {
        let a = ScalarConnection::new(1.0);
        let b = ScalarConnection::new(2.0);
        let c = a.compose(b);
        assert_eq!(c.norm().as_f64(), 3.0);
    }

    #[test]
    fn scalar_connection_identity_left() {
        let id = ScalarConnection::default();
        let c = ScalarConnection::new(1.5);
        let result = id.compose(c.clone());
        assert_eq!(result.norm().as_f64(), c.norm().as_f64());
    }

    #[test]
    fn scalar_connection_identity_right() {
        let id = ScalarConnection::default();
        let c = ScalarConnection::new(1.5);
        let result = c.clone().compose(id);
        assert_eq!(result.norm().as_f64(), c.norm().as_f64());
    }

    /// `zero()` should behave as a two-sided identity for `compose`, just like `default()`.
    #[test]
    fn scalar_connection_zero_is_identity() {
        let c = ScalarConnection::new(1.5);
        let left = ScalarConnection::zero().compose(c.clone());
        let right = c.clone().compose(ScalarConnection::zero());
        assert_eq!(left.norm().as_f64(), c.norm().as_f64());
        assert_eq!(right.norm().as_f64(), c.norm().as_f64());
    }

    /// `compose` should be associative.
    #[test]
    fn scalar_connection_compose_is_associative() {
        // 0.5, 1.0 and 1.5 are exactly representable in binary, so every partial
        // sum here is exact and the comparison is not subject to rounding.
        let a = ScalarConnection::new(0.5);
        let b = ScalarConnection::new(1.0);
        let c = ScalarConnection::new(1.5);
        let left = a.clone().compose(b.clone()).compose(c.clone());
        let right = a.compose(b.compose(c));
        assert_eq!(left.norm().as_f64(), right.norm().as_f64());
        assert_eq!(left.norm().as_f64(), 3.0);
    }
}