#![allow(missing_docs)]
use crate::BoxError;
use serde::{Deserialize, Serialize};
/// A concept embedded as an axis-aligned box in center/offset form.
///
/// In dimension `i` the box spans `[center[i] - offset[i], center[i] + offset[i]]`
/// (see [`TransBoxConcept::bounds`]). Values built with [`TransBoxConcept::new`] have
/// equal-length vectors, finite centers, and finite, non-negative offsets.
///
/// NOTE(review): `Deserialize` is derived, so deserialized values are built field by
/// field without going through `new`'s validation. `center_mut`/`offset_mut` can
/// likewise break these invariants.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TransBoxConcept {
    /// Box center, one coordinate per dimension.
    center: Vec<f32>,
    /// Half-width of the box in each dimension.
    offset: Vec<f32>,
}
/// A role (relation) embedding, stored in the same center/offset form as a concept.
///
/// A role acts on a concept box by adding its centers and offsets element-wise (see
/// [`TransBoxRole::apply`]). Values built with [`TransBoxRole::new`] have equal-length
/// vectors, finite centers, and finite, non-negative offsets.
///
/// NOTE(review): `Deserialize` is derived, so deserialized values skip `new`'s
/// validation. `center_mut`/`offset_mut` can likewise break these invariants.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TransBoxRole {
    /// Translation added to a concept's center.
    center: Vec<f32>,
    /// Amount added to a concept's per-dimension half-width.
    offset: Vec<f32>,
}
/// A collection of concept and role embeddings that share one dimension.
///
/// [`TransBoxModel::new`] checks that every concept and role has dimension `dim`.
///
/// NOTE(review): `Deserialize` is derived, so a deserialized model does not re-run
/// that dimension check.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransBoxModel {
    /// Concept boxes, each expected to have `dim` dimensions.
    concepts: Vec<TransBoxConcept>,
    /// Role embeddings, each expected to have `dim` dimensions.
    roles: Vec<TransBoxRole>,
    /// Shared embedding dimension.
    dim: usize,
}
impl TransBoxConcept {
    /// Creates a concept box from its `center` and its per-dimension half-widths
    /// (`offset`).
    ///
    /// # Errors
    ///
    /// * [`BoxError::DimensionMismatch`] when the vectors differ in length; `expected`
    ///   is `center.len()` and `actual` is `offset.len()`.
    /// * [`BoxError::InvalidBounds`] for the first invalid dimension, scanning from
    ///   index 0. Within a dimension the center is checked before the offset.
    ///   - A non-finite center is reported with `min` and `max` both set to that center.
    ///   - A negative or non-finite offset is reported with `min = 0.0` and `max` set to
    ///     the offset.
    pub fn new(center: Vec<f32>, offset: Vec<f32>) -> Result<Self, BoxError> {
        if center.len() != offset.len() {
            return Err(BoxError::DimensionMismatch {
                expected: center.len(),
                actual: offset.len(),
            });
        }

        let first_violation = center
            .iter()
            .zip(&offset)
            .enumerate()
            .find_map(|(dim, (&c, &o))| {
                if !c.is_finite() {
                    let bad = f64::from(c);
                    Some(BoxError::InvalidBounds {
                        dim,
                        min: bad,
                        max: bad,
                    })
                } else if o.is_finite() && o >= 0.0 {
                    None
                } else {
                    Some(BoxError::InvalidBounds {
                        dim,
                        min: 0.0,
                        max: f64::from(o),
                    })
                }
            });

        match first_violation {
            Some(err) => Err(err),
            None => Ok(Self { center, offset }),
        }
    }

    /// Number of dimensions, i.e. the length of the center vector.
    #[must_use]
    pub fn dim(&self) -> usize {
        self.center().len()
    }

    /// The box center, one coordinate per dimension.
    pub fn center(&self) -> &[f32] {
        self.center.as_slice()
    }

    /// The per-dimension half-widths of the box.
    pub fn offset(&self) -> &[f32] {
        self.offset.as_slice()
    }

    /// Returns `(min, max)` corners: `center - offset` and `center + offset`, per
    /// dimension.
    #[must_use]
    pub fn bounds(&self) -> (Vec<f32>, Vec<f32>) {
        let (lower, upper): (Vec<f32>, Vec<f32>) = self
            .center
            .iter()
            .zip(&self.offset)
            .map(|(&c, &o)| (c - o, c + o))
            .unzip();
        (lower, upper)
    }

    /// Mutable access to the center coordinates.
    ///
    /// Writes through this slice are not re-validated; callers must keep values finite.
    pub fn center_mut(&mut self) -> &mut [f32] {
        self.center.as_mut_slice()
    }

    /// Mutable access to the half-widths.
    ///
    /// Writes through this slice are not re-validated; callers must keep values finite
    /// and non-negative.
    pub fn offset_mut(&mut self) -> &mut [f32] {
        self.offset.as_mut_slice()
    }
}
impl TransBoxRole {
    /// Creates a role from its center translation and its per-dimension offset growth.
    ///
    /// # Errors
    ///
    /// * [`BoxError::DimensionMismatch`] when the vectors differ in length; `expected`
    ///   is `center.len()` and `actual` is `offset.len()`.
    /// * [`BoxError::InvalidBounds`] for the first invalid dimension, scanning from
    ///   index 0. The center is checked before the offset.
    ///   - A non-finite center is reported with `min == max == center`.
    ///   - A negative or non-finite offset is reported with `min = 0.0` and `max` set to
    ///     the offset.
    pub fn new(center: Vec<f32>, offset: Vec<f32>) -> Result<Self, BoxError> {
        let (n_center, n_offset) = (center.len(), offset.len());
        if n_center != n_offset {
            return Err(BoxError::DimensionMismatch {
                expected: n_center,
                actual: n_offset,
            });
        }

        center
            .iter()
            .zip(&offset)
            .enumerate()
            .try_for_each(|(dim, (&c, &o))| {
                if !c.is_finite() {
                    let bad = f64::from(c);
                    Err(BoxError::InvalidBounds {
                        dim,
                        min: bad,
                        max: bad,
                    })
                } else if o.is_finite() && o >= 0.0 {
                    Ok(())
                } else {
                    Err(BoxError::InvalidBounds {
                        dim,
                        min: 0.0,
                        max: f64::from(o),
                    })
                }
            })?;

        Ok(Self { center, offset })
    }

    /// Number of dimensions, i.e. the length of the center vector.
    #[must_use]
    pub fn dim(&self) -> usize {
        self.center().len()
    }

    /// The center translation, one coordinate per dimension.
    pub fn center(&self) -> &[f32] {
        self.center.as_slice()
    }

    /// The per-dimension offset added to a concept's half-widths.
    pub fn offset(&self) -> &[f32] {
        self.offset.as_slice()
    }

    /// Applies this role to `concept` by adding centers and adding offsets
    /// element-wise.
    ///
    /// With non-negative role offsets, the result is at least as wide as `concept` in
    /// every dimension.
    ///
    /// # Errors
    ///
    /// * [`BoxError::DimensionMismatch`] with `expected` set to the concept's dimension
    ///   and `actual` set to the role's, if they differ.
    /// * Any error from [`TransBoxConcept::new`], e.g. a sum overflowing to infinity.
    pub fn apply(&self, concept: &TransBoxConcept) -> Result<TransBoxConcept, BoxError> {
        let (role_dim, concept_dim) = (self.center.len(), concept.center.len());
        if role_dim != concept_dim {
            return Err(BoxError::DimensionMismatch {
                expected: concept_dim,
                actual: role_dim,
            });
        }
        TransBoxConcept::new(
            Self::elementwise_sum(&concept.center, &self.center),
            Self::elementwise_sum(&concept.offset, &self.offset),
        )
    }

    /// Composes two roles by adding their centers and offsets element-wise.
    ///
    /// # Errors
    ///
    /// * [`BoxError::DimensionMismatch`] with `expected` set to `other`'s dimension and
    ///   `actual` set to `self`'s, if they differ.
    /// * Any error from [`TransBoxRole::new`], e.g. a sum overflowing to infinity.
    pub fn compose(&self, other: &TransBoxRole) -> Result<TransBoxRole, BoxError> {
        let (own_dim, other_dim) = (self.center.len(), other.center.len());
        if own_dim != other_dim {
            return Err(BoxError::DimensionMismatch {
                expected: other_dim,
                actual: own_dim,
            });
        }
        TransBoxRole::new(
            Self::elementwise_sum(&self.center, &other.center),
            Self::elementwise_sum(&self.offset, &other.offset),
        )
    }

    /// Mutable access to the center translation (not re-validated on write).
    pub fn center_mut(&mut self) -> &mut [f32] {
        self.center.as_mut_slice()
    }

    /// Mutable access to the offsets (not re-validated on write; keep them
    /// non-negative).
    pub fn offset_mut(&mut self) -> &mut [f32] {
        self.offset.as_mut_slice()
    }

    /// Element-wise `lhs[i] + rhs[i]`, truncated to the shorter input as `zip` does.
    fn elementwise_sum(lhs: &[f32], rhs: &[f32]) -> Vec<f32> {
        lhs.iter().zip(rhs).map(|(&l, &r)| l + r).collect()
    }
}
impl TransBoxModel {
pub fn new(
concepts: Vec<TransBoxConcept>,
roles: Vec<TransBoxRole>,
dim: usize,
) -> Result<Self, BoxError> {
for c in &concepts {
if c.dim() != dim {
return Err(BoxError::DimensionMismatch {
expected: dim,
actual: c.dim(),
});
}
}
for r in &roles {
if r.center.len() != dim {
return Err(BoxError::DimensionMismatch {
expected: dim,
actual: r.center.len(),
});
}
}
Ok(Self {
concepts,
roles,
dim,
})
}
#[must_use]
pub fn dim(&self) -> usize {
self.dim
}
#[must_use]
pub fn num_concepts(&self) -> usize {
self.concepts.len()
}
#[must_use]
pub fn num_roles(&self) -> usize {
self.roles.len()
}
pub fn concepts(&self) -> &[TransBoxConcept] {
&self.concepts
}
pub fn roles(&self) -> &[TransBoxRole] {
&self.roles
}
}
/// Box-inclusion loss: how far box A (`center_a`, `offset_a`) protrudes from box B
/// (`center_b`, `offset_b`).
///
/// The per-dimension violation is
/// `relu(|center_a[i] - center_b[i]| + offset_a[i] - offset_b[i] - margin)`. The result
/// is the Euclidean (L2) norm of those violations.
///
/// The loss is `0.0` exactly when
/// `|center_a[i] - center_b[i]| + offset_a[i] <= offset_b[i] + margin` holds in every
/// dimension. With `margin == 0.0` that means A is contained in B. A positive `margin`
/// tolerates that much protrusion; a negative one demands a gap. Empty inputs give `0.0`.
///
/// NOTE(review): `f32::max` returns the non-NaN operand. A NaN violation (e.g. from a
/// NaN `margin`) therefore contributes `0.0` instead of propagating.
///
/// # Errors
///
/// [`BoxError::DimensionMismatch`] if any slice differs in length from `center_a`.
/// `expected` is `center_a.len()`. `actual` is the length of the first mismatching
/// slice, checked in the order `offset_a`, `center_b`, `offset_b`.
pub fn inclusion_loss(
    center_a: &[f32],
    offset_a: &[f32],
    center_b: &[f32],
    offset_b: &[f32],
    margin: f32,
) -> Result<f32, BoxError> {
    let dim = center_a.len();
    // Report the length that actually disagrees. Taking the max of all lengths could
    // echo `dim` back (e.g. "expected 2, actual 2") when only one slice is short.
    if let Some(actual) = [offset_a.len(), center_b.len(), offset_b.len()]
        .iter()
        .copied()
        .find(|&len| len != dim)
    {
        return Err(BoxError::DimensionMismatch {
            expected: dim,
            actual,
        });
    }

    let sum_sq: f32 = center_a
        .iter()
        .zip(offset_a)
        .zip(center_b.iter().zip(offset_b))
        .map(|((&ca, &oa), (&cb, &ob))| {
            let violation = ((ca - cb).abs() + oa - ob - margin).max(0.0);
            violation * violation
        })
        .sum();
    Ok(sum_sq.sqrt())
}
/// Scores the triple `(head, role, tail)` as the inclusion loss of `head`, translated
/// by `role`, inside `tail`.
///
/// The head box is moved with [`TransBoxRole::apply`]. The result is then compared
/// with `tail` through [`inclusion_loss`] using `margin`. Lower is better; `0.0` means
/// the translated head satisfies the inclusion condition in every dimension.
///
/// # Errors
///
/// Propagates errors from [`TransBoxRole::apply`]: a role/head dimension mismatch or an
/// invalid translated box. Also propagates errors from [`inclusion_loss`]: a dimension
/// mismatch with `tail`.
pub fn score_triple(
    head: &TransBoxConcept,
    role: &TransBoxRole,
    tail: &TransBoxConcept,
    margin: f32,
) -> Result<f32, BoxError> {
    let moved_head = role.apply(head)?;
    let (inner_center, inner_offset) = (moved_head.center(), moved_head.offset());
    let (outer_center, outer_offset) = (tail.center(), tail.offset());
    inclusion_loss(inner_center, inner_offset, outer_center, outer_offset, margin)
}
pub fn existential_transbox(
role: &TransBoxRole,
filler: &TransBoxConcept,
) -> Result<TransBoxConcept, BoxError> {
role.apply(filler)
}
/// Subsumption loss for `subsumed ⊑ subsumer`: the [`inclusion_loss`] of the
/// `subsumed` box inside the `subsumer` box.
///
/// Note the argument order. The broader concept (`subsumer`) is passed first, but in
/// [`inclusion_loss`] it is the *containing* box B.
///
/// # Errors
///
/// [`BoxError::DimensionMismatch`] (from [`inclusion_loss`]) if the two concepts differ
/// in dimension.
pub fn subsumption_loss(
    subsumer: &TransBoxConcept,
    subsumed: &TransBoxConcept,
    margin: f32,
) -> Result<f32, BoxError> {
    let (inner, outer) = (subsumed, subsumer);
    inclusion_loss(
        inner.center(),
        inner.offset(),
        outer.center(),
        outer.offset(),
        margin,
    )
}
/// Returns the axis-aligned intersection of two concept boxes.
///
/// In each dimension the intersection spans `[max(lo_a, lo_b), min(hi_a, hi_b)]`, where
/// `lo = center - offset` and `hi = center + offset`. The result is returned in
/// center/offset form.
///
/// Where the boxes do not overlap in a dimension, that dimension's offset is clamped to
/// `0.0`. This gives a degenerate interval at the midpoint of the gap, so the result is
/// still a valid [`TransBoxConcept`].
///
/// # Errors
///
/// * [`BoxError::DimensionMismatch`] if `a` and `b` have different dimensions.
/// * Any error from [`TransBoxConcept::new`], e.g. a center that overflows to a
///   non-finite value.
pub fn intersection(
    a: &TransBoxConcept,
    b: &TransBoxConcept,
) -> Result<TransBoxConcept, BoxError> {
    if a.center.len() != b.center.len() {
        return Err(BoxError::DimensionMismatch {
            expected: a.center.len(),
            actual: b.center.len(),
        });
    }
    let (center, offset): (Vec<f32>, Vec<f32>) = a
        .center
        .iter()
        .zip(&a.offset)
        .zip(b.center.iter().zip(&b.offset))
        .map(|((&ca, &oa), (&cb, &ob))| {
            // Exact interval intersection. Averaging the two centers is only correct when
            // both offsets are equal. Otherwise it yields a box not contained in either input.
            let lo = (ca - oa).max(cb - ob);
            let hi = (ca + oa).min(cb + ob);
            ((lo + hi) / 2.0, ((hi - lo) / 2.0).max(0.0))
        })
        .unzip();
    TransBoxConcept::new(center, offset)
}
// Unit tests for the TransBox concept/role types, the loss functions, and intersection.
// Most expected values are exact binary fractions, compared with a 1e-6 tolerance.
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn concept_new_valid() {
        let c = TransBoxConcept::new(vec![0.0, 0.0], vec![1.0, 1.0]).unwrap();
        assert_eq!(c.dim(), 2);
    }
    #[test]
    fn concept_rejects_negative_offset() {
        assert!(TransBoxConcept::new(vec![0.0], vec![-1.0]).is_err());
    }
    // Covers both the non-finite-center branch and the non-finite-offset branch.
    #[test]
    fn concept_rejects_non_finite() {
        assert!(TransBoxConcept::new(vec![f32::NAN], vec![1.0]).is_err());
        assert!(TransBoxConcept::new(vec![0.0], vec![f32::INFINITY]).is_err());
    }
    #[test]
    fn concept_rejects_dim_mismatch() {
        assert!(TransBoxConcept::new(vec![0.0, 0.0], vec![1.0]).is_err());
    }
    #[test]
    fn role_new_valid() {
        let r = TransBoxRole::new(vec![1.0, 0.0], vec![0.5, 0.5]).unwrap();
        assert_eq!(r.center.len(), 2);
    }
    #[test]
    fn role_rejects_negative_offset() {
        assert!(TransBoxRole::new(vec![0.0], vec![-0.1]).is_err());
    }
    // apply adds centers and offsets: (0,0)+(1,2) and (0.5,0.5)+(0.3,0.4).
    #[test]
    fn role_apply_additive() {
        let c = TransBoxConcept::new(vec![0.0, 0.0], vec![0.5, 0.5]).unwrap();
        let r = TransBoxRole::new(vec![1.0, 2.0], vec![0.3, 0.4]).unwrap();
        let t = r.apply(&c).unwrap();
        assert!((t.center()[0] - 1.0).abs() < 1e-6);
        assert!((t.center()[1] - 2.0).abs() < 1e-6);
        assert!((t.offset()[0] - 0.8).abs() < 1e-6);
        assert!((t.offset()[1] - 0.9).abs() < 1e-6);
    }
    #[test]
    fn role_apply_dimension_mismatch() {
        let c = TransBoxConcept::new(vec![0.0, 0.0], vec![1.0, 1.0]).unwrap();
        let r = TransBoxRole::new(vec![0.0], vec![1.0]).unwrap();
        assert!(r.apply(&c).is_err());
    }
    // compose adds both roles' centers and offsets element-wise.
    #[test]
    fn role_compose() {
        let r1 = TransBoxRole::new(vec![1.0, 0.0], vec![0.5, 0.5]).unwrap();
        let r2 = TransBoxRole::new(vec![0.0, 1.0], vec![0.3, 0.3]).unwrap();
        let composed = r1.compose(&r2).unwrap();
        assert!((composed.center()[0] - 1.0).abs() < 1e-6);
        assert!((composed.center()[1] - 1.0).abs() < 1e-6);
        assert!((composed.offset()[0] - 0.8).abs() < 1e-6);
        assert!((composed.offset()[1] - 0.8).abs() < 1e-6);
    }
    // A (half-width 0.5) sits inside B (half-width 1.0) with the same center: no violation.
    #[test]
    fn inclusion_loss_contained_is_zero() {
        let ca = vec![0.0, 0.0];
        let oa = vec![0.5, 0.5];
        let cb = vec![0.0, 0.0];
        let ob = vec![1.0, 1.0];
        let loss = inclusion_loss(&ca, &oa, &cb, &ob, 0.0).unwrap();
        assert!(loss.abs() < 1e-6, "contained loss = {loss}, expected 0");
    }
    // A is wider than B, so each dimension violates by 0.5.
    #[test]
    fn inclusion_loss_not_contained_is_positive() {
        let ca = vec![0.0, 0.0];
        let oa = vec![1.0, 1.0];
        let cb = vec![0.0, 0.0];
        let ob = vec![0.5, 0.5];
        let loss = inclusion_loss(&ca, &oa, &cb, &ob, 0.0).unwrap();
        assert!(loss > 0.0, "non-contained loss = {loss}, expected > 0");
    }
    // Per dimension: 0 + 0.4 - 0.5 - 0.1 = -0.2, which the relu clamps to 0.
    #[test]
    fn inclusion_loss_with_margin() {
        let ca = vec![0.0, 0.0];
        let oa = vec![0.4, 0.4];
        let cb = vec![0.0, 0.0];
        let ob = vec![0.5, 0.5];
        let loss = inclusion_loss(&ca, &oa, &cb, &ob, 0.1).unwrap();
        assert!(loss.abs() < 1e-6, "with margin loss = {loss}, expected 0");
    }
    #[test]
    fn inclusion_loss_dimension_mismatch() {
        assert!(inclusion_loss(&[0.0], &[1.0], &[0.0, 0.0], &[1.0, 1.0], 0.0).is_err());
    }
    // A zero role leaves the head unchanged, and the head is inside the tail.
    #[test]
    fn score_triple_perfect_match() {
        let h = TransBoxConcept::new(vec![0.0, 0.0], vec![0.3, 0.3]).unwrap();
        let r = TransBoxRole::new(vec![0.0, 0.0], vec![0.0, 0.0]).unwrap();
        let t = TransBoxConcept::new(vec![0.0, 0.0], vec![0.5, 0.5]).unwrap();
        let s = score_triple(&h, &r, &t, 0.0).unwrap();
        assert!(s.abs() < 1e-6, "perfect match score = {s}, expected 0");
    }
    // The role shifts the head to (5, 5), far outside the tail at the origin.
    #[test]
    fn score_triple_mismatch_is_positive() {
        let h = TransBoxConcept::new(vec![0.0, 0.0], vec![0.5, 0.5]).unwrap();
        let r = TransBoxRole::new(vec![5.0, 5.0], vec![0.0, 0.0]).unwrap();
        let t = TransBoxConcept::new(vec![0.0, 0.0], vec![0.5, 0.5]).unwrap();
        let s = score_triple(&h, &r, &t, 0.0).unwrap();
        assert!(s > 0.0, "mismatch score = {s}, expected > 0");
    }
    #[test]
    fn existential_transbox_additive() {
        let role = TransBoxRole::new(vec![1.0, 0.0], vec![0.5, 0.5]).unwrap();
        let filler = TransBoxConcept::new(vec![0.0, 1.0], vec![0.3, 0.3]).unwrap();
        let ex = existential_transbox(&role, &filler).unwrap();
        assert!((ex.center()[0] - 1.0).abs() < 1e-6);
        assert!((ex.center()[1] - 1.0).abs() < 1e-6);
        assert!((ex.offset()[0] - 0.8).abs() < 1e-6);
        assert!((ex.offset()[1] - 0.8).abs() < 1e-6);
    }
    // The child box (half-width 1) lies inside the parent box (half-width 2).
    #[test]
    fn subsumption_loss_valid_is_zero() {
        let parent = TransBoxConcept::new(vec![0.0, 0.0], vec![2.0, 2.0]).unwrap();
        let child = TransBoxConcept::new(vec![0.0, 0.0], vec![1.0, 1.0]).unwrap();
        let loss = subsumption_loss(&parent, &child, 0.0).unwrap();
        assert!(loss.abs() < 1e-6, "valid subsumption loss = {loss}");
    }
    // The child is wider than the parent, so it cannot be contained.
    #[test]
    fn subsumption_loss_invalid_is_positive() {
        let parent = TransBoxConcept::new(vec![0.0, 0.0], vec![0.5, 0.5]).unwrap();
        let child = TransBoxConcept::new(vec![0.0, 0.0], vec![1.0, 1.0]).unwrap();
        let loss = subsumption_loss(&parent, &child, 0.0).unwrap();
        assert!(
            loss > 0.0,
            "invalid subsumption loss = {loss}, expected > 0"
        );
    }
    // Dimension 0: [-1, 1] ∩ [0, 2] = [0, 1], so center 0.5 and half-width 0.5.
    #[test]
    fn intersection_of_overlapping() {
        let a = TransBoxConcept::new(vec![0.0, 0.0], vec![1.0, 1.0]).unwrap();
        let b = TransBoxConcept::new(vec![1.0, 1.0], vec![1.0, 1.0]).unwrap();
        let inter = intersection(&a, &b).unwrap();
        assert!((inter.center()[0] - 0.5).abs() < 1e-6);
        assert!((inter.offset()[0] - 0.5).abs() < 1e-6);
    }
    #[test]
    fn intersection_of_identical() {
        let a = TransBoxConcept::new(vec![0.0, 0.0], vec![1.0, 1.0]).unwrap();
        let inter = intersection(&a, &a).unwrap();
        assert!((inter.center()[0]).abs() < 1e-6);
        assert!((inter.offset()[0] - 1.0).abs() < 1e-6);
    }
    // Non-overlapping boxes must yield a degenerate (zero half-width) result, not an error.
    #[test]
    fn intersection_disjoint_has_zero_offset() {
        let a = TransBoxConcept::new(vec![0.0, 0.0], vec![0.5, 0.5]).unwrap();
        let b = TransBoxConcept::new(vec![10.0, 10.0], vec![0.5, 0.5]).unwrap();
        let inter = intersection(&a, &b).unwrap();
        assert!(inter.offset()[0].abs() < 1e-6);
    }
    #[test]
    fn model_construction() {
        let concepts = vec![
            TransBoxConcept::new(vec![0.0, 0.0], vec![1.0, 1.0]).unwrap(),
            TransBoxConcept::new(vec![1.0, 1.0], vec![0.5, 0.5]).unwrap(),
        ];
        let roles = vec![TransBoxRole::new(vec![0.0, 0.0], vec![0.5, 0.5]).unwrap()];
        let model = TransBoxModel::new(concepts, roles, 2).unwrap();
        assert_eq!(model.num_concepts(), 2);
        assert_eq!(model.num_roles(), 1);
        assert_eq!(model.dim(), 2);
    }
    // The concepts match dim = 1, but the role has dimension 2, so construction fails.
    #[test]
    fn model_rejects_dim_mismatch() {
        let concepts = vec![TransBoxConcept::new(vec![0.0], vec![1.0]).unwrap()];
        let roles = vec![TransBoxRole::new(vec![0.0, 0.0], vec![1.0, 1.0]).unwrap()];
        assert!(TransBoxModel::new(concepts, roles, 1).is_err());
    }
    // bounds() = (center - offset, center + offset) per dimension.
    #[test]
    fn concept_bounds() {
        let c = TransBoxConcept::new(vec![1.0, 2.0], vec![0.5, 1.0]).unwrap();
        let (min, max) = c.bounds();
        assert!((min[0] - 0.5).abs() < 1e-6);
        assert!((min[1] - 1.0).abs() < 1e-6);
        assert!((max[0] - 1.5).abs() < 1e-6);
        assert!((max[1] - 3.0).abs() < 1e-6);
    }
}