use crate::optimizer::AMSGradState;
use crate::BoxError;
use serde::{Deserialize, Serialize};
/// Axis-aligned box in dense (materialized) form: explicit per-dimension
/// lower (`min`) and upper (`max`) bounds.
///
/// NOTE(review): nothing here enforces `min.len() == max.len()` or
/// `min[i] <= max[i]`; `volume` tolerates inverted bounds by clamping each
/// extent to zero. Confirm upstream constructors maintain the invariant.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct DenseBox {
    // Per-dimension lower bounds.
    pub min: Vec<f32>,
    // Per-dimension upper bounds.
    pub max: Vec<f32>,
}
impl DenseBox {
    /// Creates a box from explicit per-dimension bounds.
    pub fn new(min: Vec<f32>, max: Vec<f32>) -> Self {
        Self { min, max }
    }

    /// Product of the per-dimension extents.
    ///
    /// A dimension with `max <= min` contributes a factor of 0, so an
    /// inverted or degenerate box has zero volume; a zero-dimensional box
    /// has volume 1 (empty product).
    #[inline]
    pub fn volume(&self) -> f32 {
        // Explicit left-to-right fold; identical accumulation order to an
        // iterator `product`.
        let mut vol = 1.0_f32;
        for (&lo, &hi) in self.min.iter().zip(self.max.iter()) {
            vol *= (hi - lo).max(0.0);
        }
        vol
    }
}
/// Box parametrization used during training: a center `mu` and a log-width
/// `delta` per dimension (the materialized half-width is `delta.exp() / 2.0`;
/// see `to_box`). Storing the log keeps widths positive under gradient steps.
///
/// `Deserialize` is implemented manually so the `mu.len() == delta.len()`
/// invariant is validated on deserialization.
#[derive(Debug, Clone, Serialize)]
pub struct TrainableBox {
    // Per-dimension box centers.
    pub(crate) mu: Vec<f32>,
    // Per-dimension log-widths (materialized width = delta.exp()).
    pub(crate) delta: Vec<f32>,
}
impl<'de> Deserialize<'de> for TrainableBox {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
#[derive(Deserialize)]
struct Raw {
mu: Vec<f32>,
delta: Vec<f32>,
}
let raw = Raw::deserialize(deserializer)?;
if raw.mu.len() != raw.delta.len() {
return Err(serde::de::Error::custom(format!(
"mu (len {}) and delta (len {}) must have same length",
raw.mu.len(),
raw.delta.len()
)));
}
Ok(Self {
mu: raw.mu,
delta: raw.delta,
})
}
}
impl TrainableBox {
    /// Builds a trainable box, validating that `mu` and `delta` have the
    /// same dimensionality.
    ///
    /// # Errors
    /// Returns `BoxError::DimensionMismatch` (`expected` = `mu` length,
    /// `actual` = `delta` length) when the lengths differ.
    pub fn new(mu: Vec<f32>, delta: Vec<f32>) -> Result<Self, BoxError> {
        if mu.len() != delta.len() {
            return Err(BoxError::DimensionMismatch {
                expected: mu.len(),
                actual: delta.len(),
            });
        }
        Ok(Self { mu, delta })
    }
    /// Initializes a box centered on `vector` with every dimension's width
    /// set to `init_width` (stored as its log).
    ///
    /// NOTE(review): `init_width <= 0.0` yields a non-finite log-width
    /// (`ln` of a non-positive number); callers presumably pass a strictly
    /// positive width — confirm.
    #[must_use]
    pub fn from_vector(vector: &[f32], init_width: f32) -> Self {
        let mu = vector.to_vec();
        let delta: Vec<f32> = vec![init_width.ln(); mu.len()];
        Self::new(mu, delta).expect("from_vector: mu and delta have same length by construction")
    }
    /// Per-dimension centers.
    #[must_use]
    pub fn mu(&self) -> &[f32] {
        &self.mu
    }
    /// Per-dimension log-widths.
    #[must_use]
    pub fn delta(&self) -> &[f32] {
        &self.delta
    }
    /// Dimensionality of the box.
    #[must_use]
    pub fn dim(&self) -> usize {
        self.mu.len()
    }
    /// Materializes the box: each dimension spans
    /// `mu - exp(delta)/2 ..= mu + exp(delta)/2`.
    #[must_use]
    pub(crate) fn to_box(&self) -> DenseBox {
        let min: Vec<f32> = self
            .mu
            .iter()
            .zip(self.delta.iter())
            .map(|(&m, &d)| m - (d.exp() / 2.0))
            .collect();
        let max: Vec<f32> = self
            .mu
            .iter()
            .zip(self.delta.iter())
            .map(|(&m, &d)| m + (d.exp() / 2.0))
            .collect();
        DenseBox::new(min, max)
    }
    /// Materializes into the ndarray backend representation.
    ///
    /// NOTE(review): the third argument to `NdarrayBox::new` is hard-coded
    /// to `1.0` — presumably a temperature/softness parameter; confirm
    /// against the backend's constructor.
    #[cfg(feature = "ndarray-backend")]
    #[cfg_attr(docsrs, doc(cfg(feature = "ndarray-backend")))]
    pub fn to_ndarray_box(&self) -> Result<crate::ndarray_backend::NdarrayBox, BoxError> {
        let dense = self.to_box();
        crate::ndarray_backend::NdarrayBox::new(
            ndarray::Array1::from(dense.min),
            ndarray::Array1::from(dense.max),
            1.0,
        )
    }
    /// Total number of trainable scalars (`mu` plus `delta`).
    #[must_use]
    pub fn num_parameters(&self) -> usize {
        2 * self.dim()
    }
    /// One AMSGrad step over all parameters.
    ///
    /// Gradients are concatenated as `[grad_mu | grad_delta]`, matching the
    /// layout of `state.m` / `state.v` / `state.v_hat`. Only the first
    /// moment is bias-corrected here; `v_hat` keeps the running max of `v`
    /// (the AMSGrad modification).
    ///
    /// Safety nets: non-finite gradients are treated as 0; an updated `mu`
    /// that goes non-finite is reset to 0.0; `delta` is clamped to
    /// `[ln 0.05, ln 10.0]` (widths in `[0.05, 10.0]`). Note `clamp` maps
    /// ±inf onto the interval bounds, so the non-finite fallback to
    /// `ln 0.5` only fires for NaN (which `clamp` propagates).
    ///
    /// # Panics
    /// Panics if `grad_mu` or `grad_delta` is shorter than `dim()`, or if
    /// `state` was sized for fewer than `num_parameters()` entries.
    pub fn update_amsgrad(
        &mut self,
        grad_mu: &[f32],
        grad_delta: &[f32],
        state: &mut AMSGradState,
    ) {
        let dim = self.dim();
        let n = self.num_parameters();
        let mut grads = Vec::with_capacity(n);
        grads.extend_from_slice(&grad_mu[..dim]);
        grads.extend_from_slice(&grad_delta[..dim]);
        state.t += 1;
        let t = state.t as f32;
        // Moment updates for the full concatenated parameter vector.
        for (i, &g) in grads.iter().enumerate().take(n) {
            let g_safe = if g.is_finite() { g } else { 0.0 };
            state.m[i] = state.beta1 * state.m[i] + (1.0 - state.beta1) * g_safe;
            let v_new = state.beta2 * state.v[i] + (1.0 - state.beta2) * g_safe * g_safe;
            state.v[i] = v_new;
            state.v_hat[i] = state.v_hat[i].max(v_new);
        }
        let bias_correction = 1.0 - state.beta1.powf(t);
        // Parameter step: first the centers...
        for i in 0..dim {
            let m_hat = state.m[i] / bias_correction;
            let update = state.lr * m_hat / (state.v_hat[i].sqrt() + state.epsilon);
            self.mu[i] -= update;
            if !self.mu[i].is_finite() {
                self.mu[i] = 0.0;
            }
        }
        // ...then the log-widths, offset by `dim` in the state vectors.
        for i in 0..dim {
            let idx = dim + i;
            let m_hat = state.m[idx] / bias_correction;
            let update = state.lr * m_hat / (state.v_hat[idx].sqrt() + state.epsilon);
            self.delta[i] -= update;
            self.delta[i] = self.delta[i].clamp(0.05_f32.ln(), 10.0_f32.ln());
            if !self.delta[i].is_finite() {
                self.delta[i] = 0.5_f32.ln();
            }
        }
    }
}
/// Cone set in dense (materialized) form: per-dimension axis angles and
/// aperture angles, both in radians.
///
/// NOTE(review): nothing here enforces `axes.len() == apertures.len()`;
/// `cone_distance` indexes `entity.axes[i]` up to `self.dim()` and will
/// panic on a shorter entity.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct DenseCone {
    // Per-dimension axis angles (radians).
    pub axes: Vec<f32>,
    // Per-dimension aperture angles (radians).
    pub apertures: Vec<f32>,
}
impl DenseCone {
    /// Creates a cone from per-dimension axis and aperture angles.
    pub fn new(axes: Vec<f32>, apertures: Vec<f32>) -> Self {
        Self { axes, apertures }
    }
    /// Dimensionality (number of angular coordinates).
    #[inline]
    pub fn dim(&self) -> usize {
        self.axes.len()
    }
    /// Angular distance from this (query) cone to `entity`.
    ///
    /// Per dimension, two angles Δ apart are compared via `|sin(Δ/2)|`
    /// (half the chord length on the unit circle), which makes the measure
    /// periodic in 2π. If the entity axis lies within the query's aperture
    /// band (`distance_to_axis < distance_base`) the term accumulates into
    /// `dist_in`, otherwise the minimum chord distance to either boundary
    /// angle accumulates into `dist_out`. The result down-weights the
    /// inside term by `cen`.
    ///
    /// NOTE(review): the inside test uses `q_aper / 2.0` while the
    /// boundary angles are `q_axis ± q_aper` (not `± q_aper / 2`) —
    /// confirm whether apertures are half-angles or full widths.
    ///
    /// # Panics
    /// Panics if `entity` has fewer dimensions than `self`.
    #[inline]
    pub fn cone_distance(&self, entity: &Self, cen: f32) -> f32 {
        let mut dist_out = 0.0_f32;
        let mut dist_in = 0.0_f32;
        for i in 0..self.dim() {
            let e = entity.axes[i];
            let q_axis = self.axes[i];
            let q_aper = self.apertures[i];
            let distance_to_axis = ((e - q_axis) / 2.0).sin().abs();
            let distance_base = (q_aper / 2.0).sin().abs();
            if distance_to_axis < distance_base {
                // In this branch `distance_to_axis < distance_base`, so the
                // `.min` always yields `distance_to_axis`.
                dist_in += distance_to_axis.min(distance_base);
            } else {
                // Chord distance to the nearer of the two cone boundaries.
                let delta1 = e - (q_axis - q_aper);
                let delta2 = e - (q_axis + q_aper);
                let d1 = (delta1 / 2.0).sin().abs();
                let d2 = (delta2 / 2.0).sin().abs();
                dist_out += d1.min(d2);
            }
        }
        dist_out + cen * dist_in
    }
}
/// Cone parametrization used during training. Raw values are unconstrained
/// reals that are squashed on materialization: `axes()` maps through
/// `tanh(raw) * pi` into (-pi, pi), and `apertures()` maps through
/// `tanh(2*raw) * pi/2 + pi/2` into (0, pi), so gradient steps can never
/// leave the valid angular ranges.
///
/// `Deserialize` is implemented manually so the
/// `raw_axes.len() == raw_apertures.len()` invariant is validated on
/// deserialization.
#[derive(Debug, Clone, Serialize)]
pub struct TrainableCone {
    // Unconstrained pre-tanh axis parameters.
    pub(crate) raw_axes: Vec<f32>,
    // Unconstrained pre-tanh aperture parameters.
    pub(crate) raw_apertures: Vec<f32>,
}
impl<'de> Deserialize<'de> for TrainableCone {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
#[derive(Deserialize)]
struct Raw {
raw_axes: Vec<f32>,
raw_apertures: Vec<f32>,
}
let raw = Raw::deserialize(deserializer)?;
if raw.raw_axes.len() != raw.raw_apertures.len() {
return Err(serde::de::Error::custom(format!(
"raw_axes (len {}) and raw_apertures (len {}) must have same length",
raw.raw_axes.len(),
raw.raw_apertures.len()
)));
}
Ok(Self {
raw_axes: raw.raw_axes,
raw_apertures: raw.raw_apertures,
})
}
}
impl TrainableCone {
    /// Builds a trainable cone, validating that `raw_axes` and
    /// `raw_apertures` have the same dimensionality.
    ///
    /// # Errors
    /// Returns `BoxError::DimensionMismatch` (`expected` = `raw_axes`
    /// length, `actual` = `raw_apertures` length) when the lengths differ.
    pub fn new(raw_axes: Vec<f32>, raw_apertures: Vec<f32>) -> Result<Self, BoxError> {
        if raw_axes.len() != raw_apertures.len() {
            return Err(BoxError::DimensionMismatch {
                expected: raw_axes.len(),
                actual: raw_apertures.len(),
            });
        }
        Ok(Self {
            raw_axes,
            raw_apertures,
        })
    }
    /// Initializes a cone whose materialized axes approximate `vector`
    /// (clamped to ±0.999·pi) and whose apertures all materialize to
    /// approximately `init_aperture`.
    ///
    /// The raw parameters are the `atanh` inverses of the squashing in
    /// `axes()` / `apertures()`; the 0.999 clamps keep `atanh` finite at
    /// the interval edges.
    #[must_use]
    pub fn from_vector(vector: &[f32], init_aperture: f32) -> Self {
        let pi = std::f32::consts::PI;
        let raw_axes: Vec<f32> = vector
            .iter()
            .map(|&v| {
                // Inverse of axes(): raw = atanh(v / pi).
                let clamped = (v / pi).clamp(-0.999, 0.999);
                clamped.atanh()
            })
            .collect();
        // Inverse of apertures(): raw = atanh((a - pi/2) / (pi/2)) / 2.
        let ratio = ((init_aperture - pi / 2.0) / (pi / 2.0)).clamp(-0.999, 0.999);
        let raw_aper = ratio.atanh() / 2.0;
        let raw_apertures = vec![raw_aper; vector.len()];
        Self::new(raw_axes, raw_apertures)
            .expect("from_vector: raw_axes and raw_apertures have same length by construction")
    }
    /// Unconstrained axis parameters.
    #[must_use]
    pub fn raw_axes(&self) -> &[f32] {
        &self.raw_axes
    }
    /// Unconstrained aperture parameters.
    #[must_use]
    pub fn raw_apertures(&self) -> &[f32] {
        &self.raw_apertures
    }
    /// Dimensionality of the cone.
    #[must_use]
    pub fn dim(&self) -> usize {
        self.raw_axes.len()
    }
    /// Materialized axis angles: `tanh(raw) * pi`, in (-pi, pi)
    /// (endpoints reachable only where `tanh` saturates in f32).
    #[must_use]
    pub fn axes(&self) -> Vec<f32> {
        self.raw_axes
            .iter()
            .map(|&r| r.tanh() * std::f32::consts::PI)
            .collect()
    }
    /// Materialized aperture angles: `tanh(2*raw) * pi/2 + pi/2`,
    /// in (0, pi).
    #[must_use]
    pub fn apertures(&self) -> Vec<f32> {
        let pi = std::f32::consts::PI;
        self.raw_apertures
            .iter()
            .map(|&r| (2.0 * r).tanh() * (pi / 2.0) + (pi / 2.0))
            .collect()
    }
    /// Arithmetic mean of the materialized apertures.
    ///
    /// NOTE(review): a zero-dimensional cone yields `0.0 / 0.0 == NaN`.
    #[must_use]
    pub fn mean_aperture(&self) -> f32 {
        let aps = self.apertures();
        aps.iter().sum::<f32>() / aps.len() as f32
    }
    /// Materializes both axes and apertures into a `DenseCone`.
    #[must_use]
    pub(crate) fn to_cone(&self) -> DenseCone {
        DenseCone::new(self.axes(), self.apertures())
    }
    /// Materializes into the ndarray backend representation.
    #[cfg(feature = "ndarray-backend")]
    #[cfg_attr(docsrs, doc(cfg(feature = "ndarray-backend")))]
    pub fn to_ndarray_cone(
        &self,
    ) -> Result<crate::ndarray_backend::NdarrayCone, crate::cone::ConeError> {
        crate::ndarray_backend::NdarrayCone::new(
            ndarray::Array1::from(self.axes()),
            ndarray::Array1::from(self.apertures()),
        )
    }
    /// Total number of trainable scalars (axes plus apertures).
    #[must_use]
    pub fn num_parameters(&self) -> usize {
        2 * self.dim()
    }
    /// Distance between the materialized cones; see
    /// `DenseCone::cone_distance` for the metric and the role of `cen`.
    pub fn cone_distance(&self, entity: &Self, cen: f32) -> f32 {
        self.to_cone().cone_distance(&entity.to_cone(), cen)
    }
    /// One AMSGrad step over all parameters.
    ///
    /// Gradients are concatenated as `[grad_axes | grad_apertures]`,
    /// matching the layout of `state.m` / `state.v` / `state.v_hat`. Only
    /// the first moment is bias-corrected; `v_hat` keeps the running max
    /// of `v` (the AMSGrad modification).
    ///
    /// Safety nets: non-finite gradients are treated as 0, and both raw
    /// parameter vectors are clamped to `[-6, 6]` (where `tanh` is already
    /// saturated). `clamp` maps ±inf onto the bounds, so the reset to 0.0
    /// only fires for NaN (which `clamp` propagates).
    ///
    /// # Panics
    /// Panics if `grad_axes` or `grad_apertures` is shorter than `dim()`,
    /// or if `state` was sized for fewer than `num_parameters()` entries.
    pub fn update_amsgrad(
        &mut self,
        grad_axes: &[f32],
        grad_apertures: &[f32],
        state: &mut AMSGradState,
    ) {
        let dim = self.dim();
        let n = self.num_parameters();
        let mut grads = Vec::with_capacity(n);
        grads.extend_from_slice(&grad_axes[..dim]);
        grads.extend_from_slice(&grad_apertures[..dim]);
        state.t += 1;
        let t = state.t as f32;
        // Moment updates for the full concatenated parameter vector.
        for (i, &g) in grads.iter().enumerate().take(n) {
            let g_safe = if g.is_finite() { g } else { 0.0 };
            state.m[i] = state.beta1 * state.m[i] + (1.0 - state.beta1) * g_safe;
            let v_new = state.beta2 * state.v[i] + (1.0 - state.beta2) * g_safe * g_safe;
            state.v[i] = v_new;
            state.v_hat[i] = state.v_hat[i].max(v_new);
        }
        let bias_correction = 1.0 - state.beta1.powf(t);
        // Parameter step: first the axes...
        for i in 0..dim {
            let m_hat = state.m[i] / bias_correction;
            let update = state.lr * m_hat / (state.v_hat[i].sqrt() + state.epsilon);
            self.raw_axes[i] -= update;
            self.raw_axes[i] = self.raw_axes[i].clamp(-6.0, 6.0);
            if !self.raw_axes[i].is_finite() {
                self.raw_axes[i] = 0.0;
            }
        }
        // ...then the apertures, offset by `dim` in the state vectors.
        for i in 0..dim {
            let idx = dim + i;
            let m_hat = state.m[idx] / bias_correction;
            let update = state.lr * m_hat / (state.v_hat[idx].sqrt() + state.epsilon);
            self.raw_apertures[i] -= update;
            self.raw_apertures[i] = self.raw_apertures[i].clamp(-6.0, 6.0);
            if !self.raw_apertures[i].is_finite() {
                self.raw_apertures[i] = 0.0;
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // apertures() must squash any raw value into (0, pi).
    #[test]
    fn trainable_cone_apertures_in_valid_range() {
        for raw_a in [-3.0, -1.0, 0.0, 1.0, 3.0] {
            let cone = TrainableCone::new(vec![0.0, 0.0], vec![raw_a, raw_a]).unwrap();
            let aps = cone.apertures();
            for (i, &a) in aps.iter().enumerate() {
                assert!(
                    a > 0.0 && a < std::f32::consts::PI,
                    "aperture[{i}] must be in (0, pi), got {a} for raw_aperture={raw_a}",
                );
            }
        }
        // At extreme raw values f32 tanh saturates to exactly ±1, so the
        // closed interval [0, pi] is the correct bound here.
        for raw_a in [-10.0, 10.0] {
            let cone = TrainableCone::new(vec![0.0], vec![raw_a]).unwrap();
            let a = cone.apertures()[0];
            assert!((0.0..=std::f32::consts::PI).contains(&a));
        }
    }
    // axes() must squash any raw value into [-pi, pi].
    #[test]
    fn trainable_cone_axes_in_valid_range() {
        for raw_a in [-10.0, -1.0, 0.0, 1.0, 10.0] {
            let cone = TrainableCone::new(vec![raw_a, raw_a], vec![0.0, 0.0]).unwrap();
            let axes = cone.axes();
            for (i, &a) in axes.iter().enumerate() {
                assert!(
                    (-std::f32::consts::PI..=std::f32::consts::PI).contains(&a),
                    "axes[{i}] must be in [-pi, pi], got {a} for raw_axis={raw_a}",
                );
            }
        }
    }
    // from_vector's atanh inverse should survive the tanh forward map
    // (loose 0.05 tolerance absorbs the 0.999 clamp and f32 rounding).
    #[test]
    fn trainable_cone_from_vector_roundtrip() {
        let init_aperture = 1.0_f32;
        let cone = TrainableCone::from_vector(&[1.0, 0.0, -0.5], init_aperture);
        let aps = cone.apertures();
        for &a in &aps {
            assert!(
                (a - init_aperture).abs() < 0.05,
                "aperture should roundtrip, expected {init_aperture} got {a}",
            );
        }
    }
    // Raw zeros materialize to axis 0 (tanh(0)=0) and aperture pi/2.
    #[test]
    fn trainable_cone_to_dense_cone() {
        let cone = TrainableCone::new(vec![0.0, 0.0], vec![0.0, 0.0]).unwrap();
        let dense = cone.to_cone();
        for &a in &dense.axes {
            assert!(a.abs() < 1e-6, "axis should be 0, got {a}");
        }
        for &a in &dense.apertures {
            assert!(
                (a - std::f32::consts::FRAC_PI_2).abs() < 1e-6,
                "aperture should be pi/2, got {a}"
            );
        }
    }
    // An entity inside a wide query cone should only pay the small
    // cen-weighted inside term.
    #[test]
    fn dense_cone_distance_wide_contains_narrow() {
        let wide = DenseCone::new(vec![0.5, 0.5], vec![2.5, 2.5]);
        let narrow = DenseCone::new(vec![0.5, 0.5], vec![0.3, 0.3]);
        let d = wide.cone_distance(&narrow, 0.02);
        assert!(
            d < 0.1,
            "wide cone should have low distance to narrow entity, got {d}"
        );
    }
    // Monotonicity sanity check: farther axis => larger distance.
    #[test]
    fn dense_cone_distance_far_entity_has_high_distance() {
        let query = DenseCone::new(vec![0.0, 0.0], vec![0.3, 0.3]);
        let near = DenseCone::new(vec![0.1, 0.1], vec![0.1, 0.1]);
        let far = DenseCone::new(vec![3.0, 3.0], vec![0.1, 0.1]);
        let d_near = query.cone_distance(&near, 0.02);
        let d_far = query.cone_distance(&far, 0.02);
        assert!(
            d_far > d_near,
            "far entity should have higher distance: near={d_near}, far={d_far}"
        );
    }
    // Smoke test: one optimizer step keeps raw parameters finite.
    #[test]
    fn trainable_cone_update_amsgrad_does_not_panic() {
        let mut cone = TrainableCone::new(vec![0.0, 0.0], vec![0.0, 0.0]).unwrap();
        let mut state = AMSGradState::new(cone.num_parameters(), 0.01);
        let grad_axes = vec![0.1, -0.1];
        let grad_apertures = vec![0.05, 0.05];
        cone.update_amsgrad(&grad_axes, &grad_apertures, &mut state);
        assert!(cone.raw_axes.iter().all(|x| x.is_finite()));
        assert!(cone.raw_apertures.iter().all(|x| x.is_finite()));
    }
    // Materialized axes/apertures must agree between the trainable and
    // ndarray representations.
    #[cfg(feature = "ndarray-backend")]
    #[test]
    fn trainable_cone_to_ndarray_cone_roundtrip() {
        let tc = TrainableCone::new(vec![0.5, -0.3, 1.0], vec![0.0, 1.0, -1.0]).unwrap();
        let nc = tc.to_ndarray_cone().unwrap();
        assert_eq!(nc.dim(), 3);
        let tc_axes = tc.axes();
        let nc_axes: Vec<f32> = nc.axes().to_vec();
        for (i, (&a, &b)) in tc_axes.iter().zip(nc_axes.iter()).enumerate() {
            assert!(
                (a - b).abs() < 1e-6,
                "axis[{i}] mismatch: trainable={a}, ndarray={b}"
            );
        }
        let tc_aps = tc.apertures();
        let nc_aps: Vec<f32> = nc.apertures().to_vec();
        for (i, (&a, &b)) in tc_aps.iter().zip(nc_aps.iter()).enumerate() {
            assert!(
                (a - b).abs() < 1e-6,
                "aperture[{i}] mismatch: trainable={a}, ndarray={b}"
            );
        }
    }
    // ndarray backend volume must match the dense extents product.
    #[cfg(feature = "ndarray-backend")]
    #[test]
    fn trainable_box_to_ndarray_box_roundtrip() {
        // Trait import presumably provides `volume()` on NdarrayBox —
        // confirm against crate::Box's definition.
        use crate::Box as BoxTrait;
        let tb = TrainableBox::new(vec![1.0, 2.0, 3.0], vec![0.0, 0.5, -0.5]).unwrap();
        let nb = tb.to_ndarray_box().unwrap();
        assert_eq!(nb.dim(), 3);
        let dense = tb.to_box();
        let volume = nb.volume().unwrap();
        let expected_vol: f32 = dense
            .min
            .iter()
            .zip(dense.max.iter())
            .map(|(&a, &b)| b - a)
            .product();
        assert!(
            (volume - expected_vol).abs() < 1e-5,
            "volume mismatch: got {volume}, expected {expected_vol}"
        );
    }
    // new() must reject mismatched mu/delta lengths.
    #[test]
    fn trainable_box_dimension_mismatch_returns_err() {
        let result = TrainableBox::new(vec![1.0, 2.0], vec![0.5]);
        assert!(
            matches!(
                result,
                Err(BoxError::DimensionMismatch {
                    expected: 2,
                    actual: 1
                })
            ),
            "expected DimensionMismatch, got {result:?}"
        );
    }
    // NOTE(review): this test uses only serde_json, yet is gated on the
    // `ndarray-backend` feature — presumably serde_json is a dev-dependency
    // pulled in under that feature; confirm against Cargo.toml.
    #[test]
    #[cfg(feature = "ndarray-backend")]
    fn trainable_box_serde_roundtrip() {
        let tb = TrainableBox::new(vec![1.0, -2.5, 3.0], vec![0.5, -0.3, 1.2]).unwrap();
        let json = serde_json::to_string(&tb).unwrap();
        let tb2: TrainableBox = serde_json::from_str(&json).unwrap();
        assert_eq!(tb.mu, tb2.mu);
        assert_eq!(tb.delta, tb2.delta);
    }
    // NOTE(review): same feature-gating question as the box roundtrip above.
    #[test]
    #[cfg(feature = "ndarray-backend")]
    fn trainable_cone_serde_roundtrip() {
        let tc = TrainableCone::new(vec![0.5, -0.3, 1.0], vec![0.0, 1.0, -1.0]).unwrap();
        let json = serde_json::to_string(&tc).unwrap();
        let tc2: TrainableCone = serde_json::from_str(&json).unwrap();
        assert_eq!(tc.raw_axes, tc2.raw_axes);
        assert_eq!(tc.raw_apertures, tc2.raw_apertures);
    }
    // new() must reject mismatched raw_axes/raw_apertures lengths.
    #[test]
    fn trainable_cone_dimension_mismatch_returns_err() {
        let result = TrainableCone::new(vec![0.0, 0.0, 0.0], vec![1.0]);
        assert!(
            matches!(
                result,
                Err(BoxError::DimensionMismatch {
                    expected: 3,
                    actual: 1
                })
            ),
            "expected DimensionMismatch, got {result:?}"
        );
    }
    // The manual Deserialize impl must reject length mismatches that the
    // derived impl would have accepted.
    #[test]
    #[cfg(feature = "ndarray-backend")]
    fn trainable_cone_deserialize_rejects_length_mismatch() {
        let json = r#"{"raw_axes":[1.0,2.0,3.0],"raw_apertures":[1.0]}"#;
        let result: Result<TrainableCone, _> = serde_json::from_str(json);
        assert!(result.is_err(), "should reject mismatched lengths");
    }
}