use scirs2_core::ndarray::{Array, ArrayD, Axis, IxDyn};
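/// Errors returned by the scoring and aggregation routines in this module.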
#[derive(Debug, thiserror::Error)]
pub enum ScoringError {
#[error("Shape mismatch: input {input:?}, weights {weights:?}")]
ShapeMismatch {
input: Vec<usize>,
weights: Vec<usize>,
},
#[error("Axis {axis} out of range for {ndim}D tensor")]
AxisOutOfRange {
axis: usize,
ndim: usize,
},
#[error("Division by zero in weight normalization")]
ZeroWeightSum,
#[error("Invalid probability value {value}: must be in [0, 1]")]
InvalidProbability {
value: f64,
},
#[error("Empty input tensor")]
EmptyInput,
}
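/// How input scores are interpreted during aggregation.
///
/// `Standard` treats inputs as probabilities in `[0, 1]`; `LogProbability`
/// and `LogOdds` treat them as log-domain scores and aggregate with
/// log-sum-exp style operations.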
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ScoringMode {
Standard,
LogProbability,
LogOdds,
}
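/// Shared configuration for the aggregators and quantifiers.
///
/// `log_floor` clamps log-domain values away from negative infinity; the
/// default, `ln(f64::MIN_POSITIVE)` (about -708.4), is the most negative
/// finite log of a positive `f64`.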
#[derive(Debug, Clone)]
pub struct ScoringConfig {
pub mode: ScoringMode,
pub log_floor: f64,
pub temperature: f64,
}
impl Default for ScoringConfig {
fn default() -> Self {
Self {
mode: ScoringMode::Standard,
            log_floor: f64::MIN_POSITIVE.ln(),
            temperature: 1.0,
}
}
}
impl ScoringConfig {
pub fn log_probability() -> Self {
Self {
mode: ScoringMode::LogProbability,
..Self::default()
}
}
pub fn log_odds() -> Self {
Self {
mode: ScoringMode::LogOdds,
..Self::default()
}
}
pub fn with_temperature(mut self, t: f64) -> Self {
self.temperature = t;
self
}
}
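/// Numerically stable log-sum-exp over a slice using the max-shift trick:
/// `ln Σ exp(x_i) = m + ln Σ exp(x_i - m)` with `m = max_i x_i`, so no
/// exponential can overflow.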
fn log_sum_exp_slice(slice: &[f64], log_floor: f64) -> f64 {
if slice.is_empty() {
return log_floor;
}
let max = slice.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
if max == f64::NEG_INFINITY {
return log_floor;
}
let sum_exp: f64 = slice.iter().map(|&x| (x - max).exp()).sum();
max + sum_exp.ln()
}
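/// Applies [`log_sum_exp_slice`] to every lane along `axis`, reducing that
/// dimension away.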
fn log_sum_exp_along_axis(
input: &ArrayD<f64>,
axis: usize,
log_floor: f64,
) -> Result<ArrayD<f64>, ScoringError> {
if axis >= input.ndim() {
return Err(ScoringError::AxisOutOfRange {
axis,
ndim: input.ndim(),
});
}
if input.is_empty() {
return Err(ScoringError::EmptyInput);
}
Ok(input.map_axis(Axis(axis), |lane| {
let s: Vec<f64> = lane.iter().cloned().collect();
log_sum_exp_slice(&s, log_floor)
}))
}
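/// Log-space product along an axis. Since `ln Π p_i = Σ ln p_i`, a product
/// of probabilities is a plain sum of their logs.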
fn log_product_along_axis(input: &ArrayD<f64>, axis: usize) -> Result<ArrayD<f64>, ScoringError> {
if axis >= input.ndim() {
return Err(ScoringError::AxisOutOfRange {
axis,
ndim: input.ndim(),
});
}
if input.is_empty() {
return Err(ScoringError::EmptyInput);
}
Ok(input.map_axis(Axis(axis), |lane| lane.iter().sum::<f64>()))
}
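/// Numerically stable aggregation primitives over log-domain tensors.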
pub struct LogSpaceAggregator {
config: ScoringConfig,
}
impl LogSpaceAggregator {
pub fn new(config: ScoringConfig) -> Self {
Self { config }
}
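    /// Log-sum-exp over the whole tensor (`axis = None`, returning a 0-D
    /// array) or along a single axis.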
pub fn log_sum_exp(
&self,
input: &ArrayD<f64>,
axis: Option<usize>,
) -> Result<ArrayD<f64>, ScoringError> {
if input.is_empty() {
return Err(ScoringError::EmptyInput);
}
match axis {
None => {
let flat: Vec<f64> = input.iter().cloned().collect();
let result = log_sum_exp_slice(&flat, self.config.log_floor);
Ok(ArrayD::from_elem(IxDyn(&[]), result))
}
Some(ax) => log_sum_exp_along_axis(input, ax, self.config.log_floor),
}
}
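    /// Log-space product (sum of logs), clamped from below by `log_floor`.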
pub fn log_product(
&self,
input: &ArrayD<f64>,
axis: Option<usize>,
) -> Result<ArrayD<f64>, ScoringError> {
if input.is_empty() {
return Err(ScoringError::EmptyInput);
}
match axis {
None => {
let result: f64 = input.iter().sum();
let result = result.max(self.config.log_floor);
Ok(ArrayD::from_elem(IxDyn(&[]), result))
}
Some(ax) => {
let out = log_product_along_axis(input, ax)?;
Ok(out.mapv(|v| v.max(self.config.log_floor)))
}
}
}
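    /// Element-wise stable `ln(exp(a) + exp(b))` for two same-shaped tensors.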
pub fn log_add_exp(
&self,
a: &ArrayD<f64>,
b: &ArrayD<f64>,
) -> Result<ArrayD<f64>, ScoringError> {
if a.shape() != b.shape() {
return Err(ScoringError::ShapeMismatch {
input: a.shape().to_vec(),
weights: b.shape().to_vec(),
});
}
        let result = scirs2_core::ndarray::Zip::from(a)
            .and(b)
            .map_collect(|&ai, &bi| {
                let max = ai.max(bi);
                let min = ai.min(bi);
                if max == f64::NEG_INFINITY {
                    self.config.log_floor
                } else {
                    // ln(e^a + e^b) = max + ln(1 + e^(min - max)); ln_1p keeps
                    // this accurate when the exponential term is tiny.
                    max + (min - max).exp().ln_1p()
                }
            });
Ok(result)
}
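    /// Validates that every entry is a probability in `[0, 1]` and maps it to
    /// log space, clamping zeros to `log_floor` so the result stays finite.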
pub fn to_log_space(&self, probs: &ArrayD<f64>) -> Result<ArrayD<f64>, ScoringError> {
for &v in probs.iter() {
if !v.is_finite() || !(0.0..=1.0).contains(&v) {
return Err(ScoringError::InvalidProbability { value: v });
}
}
let floor = self.config.log_floor;
Ok(probs.mapv(|p| if p <= 0.0 { floor } else { p.ln().max(floor) }))
}
pub fn from_log_space(&self, log_probs: &ArrayD<f64>) -> Result<ArrayD<f64>, ScoringError> {
Ok(log_probs.mapv(|lp| lp.exp()))
}
}
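/// Checks that `weights` is compatible with `input`: either the same shape
/// (or the same length, for a full reduction), or a 1-D vector whose length
/// matches the reduced axis.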
fn validate_weights_for_axis(
input: &ArrayD<f64>,
weights: &ArrayD<f64>,
axis: Option<usize>,
) -> Result<(), ScoringError> {
match axis {
None => {
if weights.shape() != input.shape() && weights.len() != input.len() {
return Err(ScoringError::ShapeMismatch {
input: input.shape().to_vec(),
weights: weights.shape().to_vec(),
});
}
}
Some(ax) => {
if ax >= input.ndim() {
return Err(ScoringError::AxisOutOfRange {
axis: ax,
ndim: input.ndim(),
});
}
let expected_len = input.shape()[ax];
let compatible = weights.shape() == input.shape()
|| (weights.ndim() == 1 && weights.len() == expected_len);
if !compatible {
return Err(ScoringError::ShapeMismatch {
input: input.shape().to_vec(),
weights: weights.shape().to_vec(),
});
}
}
}
Ok(())
}
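/// Soft quantifiers (`exists` / `forall`) with per-element weights.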
pub struct WeightedQuantifier {
config: ScoringConfig,
}
impl WeightedQuantifier {
pub fn new(config: ScoringConfig) -> Self {
Self { config }
}
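    /// Weighted soft-exists. In `Standard` mode this is the weighted
    /// arithmetic mean `Σ w_i x_i / Σ w_i`; in the log modes it is the
    /// weight-normalized log-sum-exp of `ln(w_i) + x_i`. A high score
    /// anywhere pushes the result up.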
pub fn weighted_exists(
&self,
input: &ArrayD<f64>,
weights: &ArrayD<f64>,
axis: Option<usize>,
) -> Result<ArrayD<f64>, ScoringError> {
if input.is_empty() {
return Err(ScoringError::EmptyInput);
}
validate_weights_for_axis(input, weights, axis)?;
match self.config.mode {
ScoringMode::Standard => self.weighted_exists_standard(input, weights, axis),
ScoringMode::LogProbability | ScoringMode::LogOdds => {
self.weighted_exists_log(input, weights, axis)
}
}
}
fn weighted_exists_standard(
&self,
input: &ArrayD<f64>,
weights: &ArrayD<f64>,
axis: Option<usize>,
) -> Result<ArrayD<f64>, ScoringError> {
let w = broadcast_weights(weights, input, axis)?;
let weight_sum: f64 = w.iter().sum();
if weight_sum == 0.0 {
return Err(ScoringError::ZeroWeightSum);
}
match axis {
None => {
let numerator: f64 = input.iter().zip(w.iter()).map(|(&x, &wi)| wi * x).sum();
let result = numerator / weight_sum;
Ok(ArrayD::from_elem(IxDyn(&[]), result))
}
Some(ax) => {
let weighted = input * &w;
let num = weighted.sum_axis(Axis(ax));
let w_sum = w.sum_axis(Axis(ax));
let result = scirs2_core::ndarray::Zip::from(&num)
.and(&w_sum)
.map_collect(|&n, &ws| if ws == 0.0 { 0.0 } else { n / ws });
Ok(result)
}
}
}
fn weighted_exists_log(
&self,
input: &ArrayD<f64>,
weights: &ArrayD<f64>,
axis: Option<usize>,
) -> Result<ArrayD<f64>, ScoringError> {
let w = broadcast_weights(weights, input, axis)?;
let weight_sum: f64 = w.iter().sum();
if weight_sum == 0.0 {
return Err(ScoringError::ZeroWeightSum);
}
let log_norm = weight_sum.ln();
let floor = self.config.log_floor;
let log_w_plus_x =
scirs2_core::ndarray::Zip::from(&w)
.and(input)
.map_collect(|&wi, &xi| {
if wi <= 0.0 {
floor
} else {
(wi.ln() + xi).max(floor)
}
});
        let agg = LogSpaceAggregator::new(self.config.clone());
        let lse = agg.log_sum_exp(&log_w_plus_x, axis)?;
        match axis {
            None => Ok(lse.mapv(|v| v - log_norm)),
            Some(ax) => {
                // Normalize each lane by its own weight sum, matching the
                // standard-mode path, rather than by the global sum.
                let w_sum_ax = w.sum_axis(Axis(ax));
                Ok(scirs2_core::ndarray::Zip::from(&lse)
                    .and(&w_sum_ax)
                    .map_collect(|&l, &ws| {
                        if ws <= 0.0 {
                            floor
                        } else {
                            (l - ws.ln()).max(floor)
                        }
                    }))
            }
        }
}
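    /// Weighted soft-forall. In `Standard` mode this is the weighted
    /// geometric mean `Π x_i^{w_i / Σ w}`; in the log modes, where inputs are
    /// already logs, it reduces to the weighted arithmetic mean. Any low
    /// score pulls the result down.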
pub fn weighted_forall(
&self,
input: &ArrayD<f64>,
weights: &ArrayD<f64>,
axis: Option<usize>,
) -> Result<ArrayD<f64>, ScoringError> {
if input.is_empty() {
return Err(ScoringError::EmptyInput);
}
validate_weights_for_axis(input, weights, axis)?;
match self.config.mode {
ScoringMode::Standard => self.weighted_forall_standard(input, weights, axis),
ScoringMode::LogProbability | ScoringMode::LogOdds => {
self.weighted_forall_log(input, weights, axis)
}
}
}
fn weighted_forall_standard(
&self,
input: &ArrayD<f64>,
weights: &ArrayD<f64>,
axis: Option<usize>,
) -> Result<ArrayD<f64>, ScoringError> {
let w = broadcast_weights(weights, input, axis)?;
let weight_sum: f64 = w.iter().sum();
if weight_sum == 0.0 {
return Err(ScoringError::ZeroWeightSum);
}
let log_input = input.mapv(|x| {
if x <= 0.0 {
self.config.log_floor
} else {
x.ln()
}
});
match axis {
None => {
let log_geo: f64 = log_input
.iter()
.zip(w.iter())
.map(|(&lx, &wi)| lx * wi / weight_sum)
.sum();
Ok(ArrayD::from_elem(IxDyn(&[]), log_geo.exp()))
}
Some(ax) => {
let w_sum_ax = w.sum_axis(Axis(ax));
let weighted_log = &log_input * &w;
let num = weighted_log.sum_axis(Axis(ax));
let result = scirs2_core::ndarray::Zip::from(&num)
.and(&w_sum_ax)
.map_collect(|&n, &ws| {
if ws == 0.0 {
                            1.0
                        } else {
(n / ws).exp()
}
});
Ok(result)
}
}
}
fn weighted_forall_log(
&self,
input: &ArrayD<f64>,
weights: &ArrayD<f64>,
axis: Option<usize>,
) -> Result<ArrayD<f64>, ScoringError> {
let w = broadcast_weights(weights, input, axis)?;
let weight_sum: f64 = w.iter().sum();
if weight_sum == 0.0 {
return Err(ScoringError::ZeroWeightSum);
}
match axis {
None => {
let result: f64 = input
.iter()
.zip(w.iter())
.map(|(&xi, &wi)| xi * wi / weight_sum)
.sum();
Ok(ArrayD::from_elem(IxDyn(&[]), result))
}
Some(ax) => {
let w_sum_ax = w.sum_axis(Axis(ax));
let weighted = input * &w;
let num = weighted.sum_axis(Axis(ax));
let result = scirs2_core::ndarray::Zip::from(&num)
.and(&w_sum_ax)
.map_collect(|&n, &ws| if ws == 0.0 { 0.0 } else { n / ws });
Ok(result)
}
}
}
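    /// Input gradient of [`weighted_exists`](Self::weighted_exists): for the
    /// weighted mean, `∂out/∂x_i = w_i / Σ w` over the reduced lane, scaled
    /// by the upstream `grad`. The log modes reuse the same linear weighting
    /// rather than the softmax Jacobian of log-sum-exp.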
pub fn weighted_exists_grad(
&self,
grad: &ArrayD<f64>,
input: &ArrayD<f64>,
weights: &ArrayD<f64>,
axis: Option<usize>,
) -> Result<ArrayD<f64>, ScoringError> {
if input.is_empty() {
return Err(ScoringError::EmptyInput);
}
validate_weights_for_axis(input, weights, axis)?;
let w = broadcast_weights(weights, input, axis)?;
let weight_sum: f64 = w.iter().sum();
if weight_sum == 0.0 {
return Err(ScoringError::ZeroWeightSum);
}
        match axis {
            None => {
                let w_norm = w.mapv(|wi| wi / weight_sum);
                let g_scalar = grad.iter().next().copied().unwrap_or(0.0);
                Ok(w_norm.mapv(|wn| wn * g_scalar))
            }
            Some(ax) => {
                // Normalize per lane so the gradient matches the forward
                // pass, which divides by the weight sum along `ax`, not the
                // global sum.
                let w_sum_ax = w.sum_axis(Axis(ax)).insert_axis(Axis(ax));
                let w_norm = scirs2_core::ndarray::Zip::from(&w)
                    .and_broadcast(&w_sum_ax)
                    .map_collect(|&wi, &ws| if ws == 0.0 { 0.0 } else { wi / ws });
                let grad_expanded = grad.view().insert_axis(Axis(ax));
                Ok(&w_norm * &grad_expanded)
            }
        }
}
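    /// Input gradient of [`weighted_forall`](Self::weighted_forall). For the
    /// geometric mean `G`, `∂G/∂x_i = (w_i / Σ w) · G / x_i`; in the log
    /// modes the forward pass is linear in the inputs, so the gradient is
    /// just the normalized weight.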
pub fn weighted_forall_grad(
&self,
grad: &ArrayD<f64>,
input: &ArrayD<f64>,
weights: &ArrayD<f64>,
axis: Option<usize>,
) -> Result<ArrayD<f64>, ScoringError> {
if input.is_empty() {
return Err(ScoringError::EmptyInput);
}
validate_weights_for_axis(input, weights, axis)?;
let w = broadcast_weights(weights, input, axis)?;
let weight_sum: f64 = w.iter().sum();
if weight_sum == 0.0 {
return Err(ScoringError::ZeroWeightSum);
}
match self.config.mode {
ScoringMode::Standard => {
let log_input = input.mapv(|x| {
if x <= 0.0 {
self.config.log_floor
} else {
x.ln()
}
});
let forall_out = match axis {
None => {
let log_geo: f64 = log_input
.iter()
.zip(w.iter())
.map(|(&lx, &wi)| lx * wi / weight_sum)
.sum();
ArrayD::from_elem(input.raw_dim(), log_geo.exp())
}
Some(ax) => {
let w_sum_ax = w.sum_axis(Axis(ax));
let weighted_log = &log_input * &w;
let num = weighted_log.sum_axis(Axis(ax));
let out_no_ax = scirs2_core::ndarray::Zip::from(&num)
.and(&w_sum_ax)
.map_collect(|&n, &ws| if ws == 0.0 { 1.0 } else { (n / ws).exp() });
out_no_ax
.insert_axis(Axis(ax))
.broadcast(input.raw_dim())
.map_or_else(|| Array::zeros(input.raw_dim()), |v| v.to_owned())
}
};
                // Normalize weights per lane for an axis reduction, matching
                // the forward pass; use the global sum for a full reduction.
                let w_norm = match axis {
                    None => w.mapv(|wi| wi / weight_sum),
                    Some(ax) => {
                        let w_sum_ax = w.sum_axis(Axis(ax)).insert_axis(Axis(ax));
                        scirs2_core::ndarray::Zip::from(&w)
                            .and_broadcast(&w_sum_ax)
                            .map_collect(|&wi, &ws| if ws == 0.0 { 0.0 } else { wi / ws })
                    }
                };
                let scale = scirs2_core::ndarray::Zip::from(&w_norm)
                    .and(&forall_out)
                    .and(input)
                    .map_collect(|&wn, &out_v, &xi| {
                        if xi == 0.0 {
                            0.0
                        } else {
                            wn * out_v / xi
                        }
                    });
match axis {
None => {
let g_scalar = grad.iter().next().copied().unwrap_or(0.0);
Ok(scale.mapv(|s| s * g_scalar))
}
Some(ax) => {
let grad_expanded = grad.view().insert_axis(Axis(ax));
Ok(&scale * &grad_expanded)
}
}
}
ScoringMode::LogProbability | ScoringMode::LogOdds => {
                match axis {
                    None => {
                        let w_norm = w.mapv(|wi| wi / weight_sum);
                        let g_scalar = grad.iter().next().copied().unwrap_or(0.0);
                        Ok(w_norm.mapv(|wn| wn * g_scalar))
                    }
                    Some(ax) => {
                        // The log-space forward pass is a per-lane weighted
                        // mean, so the gradient is the per-lane normalized
                        // weight.
                        let w_sum_ax = w.sum_axis(Axis(ax)).insert_axis(Axis(ax));
                        let w_norm = scirs2_core::ndarray::Zip::from(&w)
                            .and_broadcast(&w_sum_ax)
                            .map_collect(|&wi, &ws| if ws == 0.0 { 0.0 } else { wi / ws });
                        let grad_expanded = grad.view().insert_axis(Axis(ax));
                        Ok(&w_norm * &grad_expanded)
                    }
                }
}
}
}
}
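/// Expands `weights` to `input`'s shape: a same-shape tensor is cloned
/// through, a same-length tensor is reshaped for full reductions, and a 1-D
/// vector is reshaped to length 1 on every other axis and broadcast.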
fn broadcast_weights(
weights: &ArrayD<f64>,
input: &ArrayD<f64>,
axis: Option<usize>,
) -> Result<ArrayD<f64>, ScoringError> {
if weights.shape() == input.shape() {
return Ok(weights.clone());
}
match axis {
None => {
if weights.len() != input.len() {
return Err(ScoringError::ShapeMismatch {
input: input.shape().to_vec(),
weights: weights.shape().to_vec(),
});
}
weights
.clone()
.into_shape_with_order(input.raw_dim())
.map_err(|_| ScoringError::ShapeMismatch {
input: input.shape().to_vec(),
weights: weights.shape().to_vec(),
})
}
Some(ax) => {
if weights.ndim() == 1 && weights.len() == input.shape()[ax] {
let mut shape = vec![1usize; input.ndim()];
shape[ax] = input.shape()[ax];
let reshaped = weights
.clone()
.into_shape_with_order(IxDyn(&shape))
.map_err(|_| ScoringError::ShapeMismatch {
input: input.shape().to_vec(),
weights: weights.shape().to_vec(),
})?;
reshaped
.broadcast(input.raw_dim())
.map(|v| v.to_owned())
.ok_or_else(|| ScoringError::ShapeMismatch {
input: input.shape().to_vec(),
weights: weights.shape().to_vec(),
})
} else if weights.shape() == input.shape() {
Ok(weights.clone())
} else {
Err(ScoringError::ShapeMismatch {
input: input.shape().to_vec(),
weights: weights.shape().to_vec(),
})
}
}
}
}
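/// Free-function convenience wrapper around [`LogSpaceAggregator::log_sum_exp`].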
pub fn log_sum_exp(
input: &ArrayD<f64>,
axis: Option<usize>,
config: ScoringConfig,
) -> Result<ArrayD<f64>, ScoringError> {
LogSpaceAggregator::new(config).log_sum_exp(input, axis)
}
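/// Free-function convenience wrapper around [`WeightedQuantifier::weighted_exists`].
///
/// A minimal sketch of the call pattern (illustrative values only):
///
/// ```ignore
/// let input = Array::from_vec(vec![0.2, 0.8]).into_dyn();
/// let weights = Array::from_vec(vec![1.0, 3.0]).into_dyn();
/// // Weighted mean: (1.0 * 0.2 + 3.0 * 0.8) / 4.0 = 0.65
/// let score = weighted_soft_exists(&input, &weights, None, ScoringConfig::default())?;
/// ```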
pub fn weighted_soft_exists(
input: &ArrayD<f64>,
weights: &ArrayD<f64>,
axis: Option<usize>,
config: ScoringConfig,
) -> Result<ArrayD<f64>, ScoringError> {
WeightedQuantifier::new(config).weighted_exists(input, weights, axis)
}
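/// Free-function convenience wrapper around [`WeightedQuantifier::weighted_forall`].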
pub fn weighted_soft_forall(
input: &ArrayD<f64>,
weights: &ArrayD<f64>,
axis: Option<usize>,
config: ScoringConfig,
) -> Result<ArrayD<f64>, ScoringError> {
WeightedQuantifier::new(config).weighted_forall(input, weights, axis)
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::Array2;
const EPS: f64 = 1e-9;
fn config() -> ScoringConfig {
ScoringConfig::default()
}
fn agg() -> LogSpaceAggregator {
LogSpaceAggregator::new(config())
}
fn make_1d(data: Vec<f64>) -> ArrayD<f64> {
Array::from_vec(data).into_dyn()
}
fn make_2d(data: Vec<Vec<f64>>) -> ArrayD<f64> {
let rows = data.len();
let cols = data[0].len();
let flat: Vec<f64> = data.into_iter().flatten().collect();
Array2::from_shape_vec((rows, cols), flat)
.expect("valid shape")
.into_dyn()
}
#[test]
fn test_log_sum_exp_scalar() {
let input = make_1d(vec![3.0]);
let result = agg().log_sum_exp(&input, None).expect("log_sum_exp scalar");
assert!(
(result[[]] - 3.0).abs() < EPS,
"expected 3.0, got {}",
result[[]]
);
}
#[test]
fn test_log_sum_exp_zeros() {
let n = 4usize;
let input = make_1d(vec![0.0; n]);
let result = agg().log_sum_exp(&input, None).expect("log_sum_exp zeros");
let expected = (n as f64).ln();
assert!(
(result[[]] - expected).abs() < EPS,
"expected log({}), got {}",
n,
result[[]]
);
}
#[test]
fn test_log_sum_exp_vs_naive() {
let vals = vec![1.0, 2.0, 3.0];
let input = make_1d(vals.clone());
let result = agg().log_sum_exp(&input, None).expect("vs naive");
let naive = vals.iter().map(|&x| x.exp()).sum::<f64>().ln();
assert!(
(result[[]] - naive).abs() < 1e-10,
"stable != naive: {} vs {}",
result[[]],
naive
);
}
#[test]
fn test_log_sum_exp_numerical_stability() {
let input = make_1d(vec![300.0, 299.0, 298.0]);
let result = agg()
.log_sum_exp(&input, None)
.expect("numerical stability");
assert!(
result[[]].is_finite(),
"result should be finite, got {}",
result[[]]
);
let expected = 300.0 + (1.0 + (-1.0_f64).exp() + (-2.0_f64).exp()).ln();
assert!((result[[]] - expected).abs() < 1e-10);
}
#[test]
fn test_log_sum_exp_axis_0() {
let input = make_2d(vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]);
let result = agg().log_sum_exp(&input, Some(0)).expect("axis 0");
assert_eq!(result.shape(), &[3]);
for col in 0..3 {
let a = (col + 1) as f64;
let b = (col + 4) as f64;
let expected = a.max(b) + (1.0 + (a.min(b) - a.max(b)).exp()).ln();
assert!((result[[col]] - expected).abs() < 1e-10);
}
}
#[test]
fn test_log_sum_exp_axis_1() {
let input = make_2d(vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]);
let result = agg().log_sum_exp(&input, Some(1)).expect("axis 1");
assert_eq!(result.shape(), &[2]);
for row in 0..2 {
let vals: Vec<f64> = (1..=3).map(|c| (row * 3 + c) as f64).collect();
let expected_v = vals.iter().map(|&v| v.exp()).sum::<f64>().ln();
assert!(
(result[[row]] - expected_v).abs() < 1e-8,
"row {}: {} vs {}",
row,
result[[row]],
expected_v
);
}
}
#[test]
fn test_log_sum_exp_full_reduction() {
let input = make_2d(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
let result = agg().log_sum_exp(&input, None).expect("full reduction");
assert_eq!(result.shape(), &[] as &[usize]);
let naive = (1.0_f64.exp() + 2.0_f64.exp() + 3.0_f64.exp() + 4.0_f64.exp()).ln();
assert!((result[[]] - naive).abs() < 1e-8);
}
#[test]
fn test_log_product_basic() {
let input = make_1d(vec![0.5_f64.ln(), 0.25_f64.ln()]);
let result = agg().log_product(&input, None).expect("log_product basic");
let expected = 0.125_f64.ln();
assert!((result[[]] - expected).abs() < 1e-10);
}
#[test]
fn test_log_add_exp_symmetry() {
let a = make_1d(vec![1.0, 2.0, 3.0]);
let b = make_1d(vec![3.0, 1.0, 2.0]);
let ab = agg().log_add_exp(&a, &b).expect("log_add_exp ab");
let ba = agg().log_add_exp(&b, &a).expect("log_add_exp ba");
for i in 0..3 {
assert!(
(ab[[i]] - ba[[i]]).abs() < EPS,
"symmetry violated at {}",
i
);
}
}
#[test]
fn test_to_log_space_range() {
let probs = make_1d(vec![0.0, 0.1, 0.5, 0.9, 1.0]);
let result = agg().to_log_space(&probs).expect("to_log_space");
for &v in result.iter() {
assert!(v <= 0.0, "log-probability must be <= 0, got {}", v);
}
}
#[test]
fn test_from_log_space_roundtrip() {
let probs = make_1d(vec![0.1, 0.5, 0.9]);
let log_p = agg().to_log_space(&probs).expect("to_log_space");
let recovered = agg().from_log_space(&log_p).expect("from_log_space");
for i in 0..3 {
assert!(
(probs[[i]] - recovered[[i]]).abs() < 1e-12,
"roundtrip failed at {}: {} != {}",
i,
probs[[i]],
recovered[[i]]
);
}
}
#[test]
fn test_log_floor_prevents_neg_inf() {
        let probs = make_1d(vec![0.0, 0.5, 1.0]);
        let result = agg().to_log_space(&probs).expect("log_floor");
for &v in result.iter() {
assert!(v.is_finite(), "value should be finite, got {}", v);
}
assert!(result[[0]] <= 0.0, "floor should be <= 0");
}
#[test]
fn test_weighted_exists_uniform_weights() {
let input = make_1d(vec![0.2, 0.4, 0.6, 0.8]);
let weights = make_1d(vec![1.0, 1.0, 1.0, 1.0]);
let q = WeightedQuantifier::new(config());
let result = q
.weighted_exists(&input, &weights, None)
.expect("uniform weights");
        let expected = 0.5;
        assert!(
(result[[]] - expected).abs() < EPS,
"expected {}, got {}",
expected,
result[[]]
);
}
#[test]
fn test_weighted_exists_zero_weight_error() {
let input = make_1d(vec![0.5, 0.5]);
let weights = make_1d(vec![0.0, 0.0]);
let q = WeightedQuantifier::new(config());
let result = q.weighted_exists(&input, &weights, None);
assert!(
matches!(result, Err(ScoringError::ZeroWeightSum)),
"expected ZeroWeightSum error"
);
}
#[test]
fn test_weighted_exists_concentrated_weight() {
let input = make_1d(vec![0.1, 0.3, 0.7, 0.9]);
let weights = make_1d(vec![0.0, 0.0, 1.0, 0.0]);
let q = WeightedQuantifier::new(config());
let result = q
.weighted_exists(&input, &weights, None)
.expect("concentrated weight");
assert!(
(result[[]] - 0.7).abs() < EPS,
"expected 0.7, got {}",
result[[]]
);
}
#[test]
fn test_weighted_forall_uniform() {
let vals = vec![0.5, 0.25, 1.0, 0.5];
let input = make_1d(vals.clone());
let weights = make_1d(vec![1.0; 4]);
let q = WeightedQuantifier::new(config());
let result = q
.weighted_forall(&input, &weights, None)
.expect("forall uniform");
let geo: f64 = vals.iter().product::<f64>().powf(0.25);
assert!(
(result[[]] - geo).abs() < 1e-10,
"expected {}, got {}",
geo,
result[[]]
);
}
#[test]
fn test_weighted_exists_gradient_shape() {
let input = make_2d(vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]]);
let weights = make_2d(vec![vec![1.0, 2.0, 1.0], vec![1.0, 2.0, 1.0]]);
let q = WeightedQuantifier::new(config());
let out = q
.weighted_exists(&input, &weights, Some(1))
.expect("forward");
assert_eq!(out.shape(), &[2]);
let grad = Array::ones(out.raw_dim());
let d_input = q
.weighted_exists_grad(&grad, &input, &weights, Some(1))
.expect("grad");
assert_eq!(
d_input.shape(),
input.shape(),
"gradient should match input shape"
);
}
#[test]
fn test_weighted_exists_gradient_finite() {
let input = make_1d(vec![0.2, 0.5, 0.8]);
let weights = make_1d(vec![1.0, 3.0, 1.0]);
let q = WeightedQuantifier::new(config());
let out = q.weighted_exists(&input, &weights, None).expect("forward");
let grad = Array::ones(out.raw_dim());
let d_input = q
.weighted_exists_grad(&grad, &input, &weights, None)
.expect("grad");
for &v in d_input.iter() {
assert!(v.is_finite(), "gradient must be finite, got {}", v);
}
}
#[test]
fn test_scoring_config_default() {
let cfg = ScoringConfig::default();
assert_eq!(cfg.mode, ScoringMode::Standard);
assert!((cfg.temperature - 1.0).abs() < EPS);
assert!(cfg.log_floor < -100.0, "log_floor should be very negative");
assert!(cfg.log_floor.is_finite(), "log_floor must be finite");
}
#[test]
fn test_scoring_config_builders() {
let lp = ScoringConfig::log_probability();
assert_eq!(lp.mode, ScoringMode::LogProbability);
let lo = ScoringConfig::log_odds();
assert_eq!(lo.mode, ScoringMode::LogOdds);
let with_t = ScoringConfig::default().with_temperature(0.5);
assert!((with_t.temperature - 0.5).abs() < EPS);
}
#[test]
fn test_free_function_log_sum_exp() {
let input = make_1d(vec![0.0, 0.0, 0.0]);
let result = log_sum_exp(&input, None, config()).expect("free fn log_sum_exp");
let expected = (3.0_f64).ln();
assert!((result[[]] - expected).abs() < EPS);
}
#[test]
fn test_log_space_quantifier_mode_via_gradient_ops() {
use crate::gradient_ops::{soft_exists, QuantifierMode};
let input = make_1d(vec![0.0, 0.0, 0.0]);
let scoring_cfg = ScoringConfig::log_probability();
let mode = QuantifierMode::LogSpace(scoring_cfg);
let result = soft_exists(&input, None, mode).expect("log_space quantifier");
        let expected = (3.0_f64).ln();
        assert!(
(result[[]] - expected).abs() < 1e-10,
"expected log(3)={}, got {}",
expected,
result[[]]
);
}
}