use alloc::format;
use alloc::string::String;
use alloc::vec::Vec;
use crate::domain::DiVector;
use crate::error::{RcfError, RcfResult};
/// A named list of feature indices.
///
/// Values built through [`FeatureGroup::new`] always have a non-empty name
/// and at least one index.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct FeatureGroup {
    /// Label identifying the group.
    name: String,
    /// Indices that belong to the group, in the order they were supplied.
    indices: Vec<usize>,
}

impl FeatureGroup {
    /// Creates a group from a name and the indices it covers.
    ///
    /// The name is checked first; `indices` is only consumed once the name
    /// has been accepted.
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::InvalidConfig`] when `name` is empty or when
    /// `indices` yields no elements.
    pub fn new(
        name: impl Into<String>,
        indices: impl IntoIterator<Item = usize>,
    ) -> RcfResult<Self> {
        let name: String = name.into();
        if name.is_empty() {
            return Err(RcfError::InvalidConfig(
                "FeatureGroup name must not be empty".into(),
            ));
        }

        let indices = indices.into_iter().collect::<Vec<usize>>();
        if indices.is_empty() {
            Err(RcfError::InvalidConfig(
                format!("FeatureGroup \"{name}\" must declare at least one index").into(),
            ))
        } else {
            Ok(Self { name, indices })
        }
    }

    /// Returns the group's name.
    #[must_use]
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns the group's indices as a slice.
    #[must_use]
    pub fn indices(&self) -> &[usize] {
        self.indices.as_slice()
    }
}
/// An ordered set of [`FeatureGroup`]s together with the largest index any of
/// them references.
///
/// Normally produced by [`FeatureGroupsBuilder::build`]. The `Default` value
/// holds no groups and a `max_index` of `0`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct FeatureGroups {
    /// Groups in the order they were added.
    groups: Vec<FeatureGroup>,
    /// Largest index across all groups (`0` when there are no groups).
    max_index: usize,
}

impl FeatureGroups {
    /// Starts an empty [`FeatureGroupsBuilder`].
    #[must_use]
    pub fn builder() -> FeatureGroupsBuilder {
        Default::default()
    }

    /// Number of groups in the set.
    #[must_use]
    pub fn len(&self) -> usize {
        self.groups().len()
    }

    /// Returns `true` when the set holds no groups.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the groups as a slice, in insertion order.
    #[must_use]
    pub fn groups(&self) -> &[FeatureGroup] {
        self.groups.as_slice()
    }

    /// Returns the stored maximum index.
    ///
    /// For an empty set this is `0`, which is indistinguishable from a set
    /// whose largest index is `0`; check [`FeatureGroups::is_empty`] first if
    /// that matters.
    #[must_use]
    pub fn max_index(&self) -> usize {
        self.max_index
    }

    /// Checks that every referenced index fits in a space of `d` dimensions.
    ///
    /// An empty set is valid for any `d`, including `0`.
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::OutOfBounds`] carrying the stored maximum index and
    /// `d` when the set is non-empty and that index is `>= d`.
    pub fn validate_for_dimension(&self, d: usize) -> RcfResult<()> {
        let out_of_range = !self.is_empty() && self.max_index >= d;
        if out_of_range {
            Err(RcfError::OutOfBounds {
                index: self.max_index,
                len: d,
            })
        } else {
            Ok(())
        }
    }
}
/// Collects named index groups and validates them all at once in
/// [`FeatureGroupsBuilder::build`].
#[derive(Debug, Default, Clone)]
pub struct FeatureGroupsBuilder {
    /// Groups queued so far, unvalidated, in insertion order.
    groups: Vec<FeatureGroup>,
}

impl FeatureGroupsBuilder {
    /// Queues a group with the given name and indices.
    ///
    /// Nothing is validated here; problems are reported by
    /// [`FeatureGroupsBuilder::build`].
    #[must_use]
    pub fn add(
        mut self,
        name: impl Into<String>,
        indices: impl IntoIterator<Item = usize>,
    ) -> Self {
        let group = FeatureGroup {
            name: name.into(),
            indices: indices.into_iter().collect(),
        };
        self.groups.push(group);
        self
    }

    /// Validates the queued groups and produces a [`FeatureGroups`].
    ///
    /// Every group is first checked for an empty name and an empty index
    /// list, in insertion order; only then are names checked for duplicates.
    /// A builder with no groups builds successfully into an empty set.
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::InvalidConfig`] when a group has an empty name, a
    /// group has no indices, or two groups share a name.
    pub fn build(self) -> RcfResult<FeatureGroups> {
        let Self { groups } = self;

        // Pass 1: per-group checks, reported for the earliest offending group.
        for (i, group) in groups.iter().enumerate() {
            if group.name.is_empty() {
                return Err(RcfError::InvalidConfig(
                    format!("FeatureGroup at position {i} has an empty name").into(),
                ));
            }
            if group.indices.is_empty() {
                return Err(RcfError::InvalidConfig(
                    format!(
                        "FeatureGroup \"{}\" must declare at least one index",
                        group.name
                    )
                    .into(),
                ));
            }
        }

        // Pass 2: the earliest group whose name reappears later is reported.
        for (position, group) in groups.iter().enumerate() {
            let repeated = groups
                .iter()
                .skip(position + 1)
                .any(|later| later.name == group.name);
            if repeated {
                return Err(RcfError::InvalidConfig(
                    format!("duplicate FeatureGroup name \"{}\"", group.name).into(),
                ));
            }
        }

        // Largest index over all groups; 0 when no groups were added.
        let max_index = groups
            .iter()
            .flat_map(|group| group.indices.iter().copied())
            .max()
            .unwrap_or(0);

        Ok(FeatureGroups { groups, max_index })
    }
}
/// Per-group contributions together with the total they are measured against.
///
/// `scores` holds one `(group name, contribution)` pair per group, in the
/// order supplied to [`GroupScores::new`]; `total` is the reference total
/// supplied alongside them.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct GroupScores {
    scores: Vec<(String, f64)>,
    total: f64,
}

impl GroupScores {
    /// Wraps the given `(name, contribution)` pairs and `total` as-is; no
    /// validation or normalisation is performed.
    #[must_use]
    pub fn new(scores: Vec<(String, f64)>, total: f64) -> Self {
        Self { scores, total }
    }

    /// Returns the `(name, contribution)` pairs in their stored order.
    #[must_use]
    pub fn scores(&self) -> &[(String, f64)] {
        &self.scores
    }

    /// Returns the reference total.
    #[must_use]
    pub fn total(&self) -> f64 {
        self.total
    }

    /// Number of groups scored.
    #[must_use]
    pub fn len(&self) -> usize {
        self.scores.len()
    }

    /// Returns `true` when no groups were scored.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.scores.is_empty()
    }

    /// Sum of all group contributions.
    ///
    /// Nothing is deduplicated here, so whatever each stored contribution
    /// counts is counted again in the sum.
    #[must_use]
    pub fn explained(&self) -> f64 {
        self.scores.iter().map(|(_, s)| *s).sum()
    }

    /// Ratio of [`GroupScores::explained`] to the total.
    ///
    /// Returns `0.0` when the total is zero, infinite, or NaN, so the
    /// division is never performed with an unusable denominator.
    #[must_use]
    pub fn coverage(&self) -> f64 {
        if self.total == 0.0 || !self.total.is_finite() {
            return 0.0;
        }
        self.explained() / self.total
    }

    /// Returns the group with the largest contribution, or `None` when there
    /// are no groups.
    ///
    /// NaN contributions rank below every other value, so a NaN entry is only
    /// returned when every contribution is NaN. When several groups tie for
    /// the maximum, the last of them in stored order is returned.
    #[must_use]
    pub fn top_group(&self) -> Option<(&str, f64)> {
        self.scores
            .iter()
            .max_by(|a, b| Self::nan_lowest_cmp(a.1, b.1))
            .map(|(n, s)| (n.as_str(), *s))
    }

    /// Total order on scores that places NaN below every other value.
    ///
    /// Treating an incomparable pair as "equal" would let a NaN in the middle
    /// of the list displace the running maximum (`max_by` keeps the later
    /// element on ties), so NaN is ranked explicitly instead.
    fn nan_lowest_cmp(lhs: f64, rhs: f64) -> core::cmp::Ordering {
        match (lhs.is_nan(), rhs.is_nan()) {
            (true, true) => core::cmp::Ordering::Equal,
            (true, false) => core::cmp::Ordering::Less,
            (false, true) => core::cmp::Ordering::Greater,
            // Neither side is NaN, so `partial_cmp` yields `Some`; the
            // fallback only exists to avoid a panic path.
            (false, false) => lhs
                .partial_cmp(&rhs)
                .unwrap_or(core::cmp::Ordering::Equal),
        }
    }
}
/// Splits an attribution vector into one contribution per feature group.
///
/// For each group, in order, the contribution is the sum of
/// `di.per_dim_total(i)` over the group's indices. The result pairs those
/// contributions with `di.total()`. Indices are not checked against the
/// vector here; callers in this file validate the groups beforehand.
#[must_use]
pub fn decompose(di: &DiVector, groups: &FeatureGroups) -> GroupScores {
    let scores: Vec<(String, f64)> = groups
        .groups()
        .iter()
        .map(|group| {
            let contribution: f64 = group
                .indices()
                .iter()
                .map(|&index| di.per_dim_total(index))
                .sum();
            (String::from(group.name()), contribution)
        })
        .collect();

    GroupScores::new(scores, di.total())
}
impl<const D: usize> crate::forest::RandomCutForest<D> {
    /// Computes per-group contributions for `point`.
    ///
    /// The groups are validated against the forest dimension `D` before the
    /// attribution for `point` is computed and decomposed.
    ///
    /// # Errors
    ///
    /// Propagates the error from `groups.validate_for_dimension(D)` and any
    /// error returned by `self.attribution(point)`.
    pub fn group_scores(&self, point: &[f64; D], groups: &FeatureGroups) -> RcfResult<GroupScores> {
        // Reject out-of-range groups first so attribution is never computed
        // for a configuration that cannot be decomposed.
        groups.validate_for_dimension(D)?;

        let attribution = self.attribution(point)?;
        let scores = decompose(&attribution, groups);
        Ok(scores)
    }
}
impl<const D: usize> crate::thresholded::ThresholdedForest<D> {
    /// Computes per-group contributions for `point` by delegating to the
    /// `group_scores` method of the value returned by `self.forest()`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the delegated call returns.
    pub fn group_scores(&self, point: &[f64; D], groups: &FeatureGroups) -> RcfResult<GroupScores> {
        let inner = self.forest();
        inner.group_scores(point, groups)
    }
}
#[cfg(feature = "std")]
impl<K, const D: usize> crate::pool::TenantForestPool<K, D>
where
    K: core::hash::Hash + Eq + Clone,
{
    /// Computes per-group contributions for `point` using the detector stored
    /// under `key`.
    ///
    /// When `contains(key)` is false, `score_only(key, point)` is called
    /// first; its result is discarded apart from error propagation.
    /// NOTE(review): this presumably inserts the tenant into the pool — the
    /// panic message below relies on it; confirm against `score_only`.
    ///
    /// # Errors
    ///
    /// Propagates errors from `score_only` and from the detector's
    /// `group_scores`.
    ///
    /// # Panics
    ///
    /// Panics if `get_mut(key)` yields no detector after the step above.
    pub fn group_scores(
        &mut self,
        key: &K,
        point: &[f64; D],
        groups: &FeatureGroups,
    ) -> RcfResult<GroupScores> {
        let already_present = self.contains(key);
        if !already_present {
            self.score_only(key, point)?;
        }

        self.get_mut(key)
            .expect("tenant was just forced into the pool")
            .group_scores(point, groups)
    }
}
// Unit tests for feature-group construction, validation, and decomposition.
// `clippy::float_cmp` is allowed because the assertions below compare floats
// exactly on purpose: the fixtures use small integer-valued inputs.
#[cfg(test)]
#[allow(clippy::float_cmp)] mod tests {
use super::*;
// Shared fixture: "rate" covers indices 0 and 1, "payload" covers 2 and 3.
fn four_d() -> FeatureGroups {
FeatureGroups::builder()
.add("rate", [0, 1])
.add("payload", [2, 3])
.build()
.unwrap()
}
// --- FeatureGroup::new validation ---
#[test]
fn feature_group_rejects_empty_name() {
assert!(FeatureGroup::new("", [0]).is_err());
}
#[test]
fn feature_group_rejects_empty_indices() {
assert!(FeatureGroup::new("rate", std::iter::empty::<usize>()).is_err());
}
// --- FeatureGroupsBuilder::build validation ---
// `add` never fails, so each of these checks the error surfaced by `build`.
#[test]
fn builder_rejects_empty_name() {
let err = FeatureGroups::builder().add("", [0]).build().unwrap_err();
assert!(matches!(err, RcfError::InvalidConfig(_)));
}
#[test]
fn builder_rejects_empty_indices() {
let err = FeatureGroups::builder()
.add("rate", std::iter::empty::<usize>())
.build()
.unwrap_err();
assert!(matches!(err, RcfError::InvalidConfig(_)));
}
#[test]
fn builder_rejects_duplicate_names() {
let err = FeatureGroups::builder()
.add("rate", [0])
.add("rate", [1])
.build()
.unwrap_err();
assert!(matches!(err, RcfError::InvalidConfig(_)));
}
// The maximum is taken across all groups, not just the first one.
#[test]
fn builder_tracks_max_index() {
let g = FeatureGroups::builder()
.add("a", [0, 3])
.add("b", [7])
.build()
.unwrap();
assert_eq!(g.max_index(), 7);
}
// --- FeatureGroups::validate_for_dimension ---
// Max index 3 fits in 4 dimensions.
#[test]
fn validate_for_dimension_passes_on_fit() {
let g = four_d();
g.validate_for_dimension(4).unwrap();
}
// Max index 3 does not fit in 3 dimensions (valid indices would be 0..=2).
#[test]
fn validate_for_dimension_rejects_out_of_bounds() {
let g = four_d();
let err = g.validate_for_dimension(3).unwrap_err();
assert!(matches!(err, RcfError::OutOfBounds { .. }));
}
// An empty set references no index, so it is valid even for dimension 0.
#[test]
fn validate_for_dimension_accepts_empty_set() {
let g = FeatureGroups::default();
g.validate_for_dimension(0).unwrap();
g.validate_for_dimension(100).unwrap();
}
// --- decompose ---
// Groups that partition all four dimensions: contributions sum to the total.
#[test]
fn decompose_partitioning_matches_total() {
let mut di = DiVector::zeros(4);
di.add_high(0, 1.0).unwrap();
di.add_high(1, 2.0).unwrap();
di.add_low(2, 3.0).unwrap();
di.add_low(3, 4.0).unwrap();
let scores = decompose(&di, &four_d());
assert_eq!(scores.scores()[0], ("rate".to_string(), 3.0));
assert_eq!(scores.scores()[1], ("payload".to_string(), 7.0));
assert_eq!(scores.total(), 10.0);
assert_eq!(scores.explained(), 10.0);
assert_eq!(scores.coverage(), 1.0);
}
// Dimensions 2 and 3 belong to no group, so only 2.0 of the 6.0 total is explained.
#[test]
fn decompose_with_gap_has_coverage_below_one() {
let mut di = DiVector::zeros(4);
di.add_high(0, 1.0).unwrap();
di.add_high(1, 1.0).unwrap();
di.add_high(2, 2.0).unwrap();
di.add_high(3, 2.0).unwrap();
let only_rate = FeatureGroups::builder()
.add("rate", [0, 1])
.build()
.unwrap();
let scores = decompose(&di, &only_rate);
assert_eq!(scores.scores()[0], ("rate".to_string(), 2.0));
assert_eq!(scores.total(), 6.0);
assert!((scores.coverage() - 2.0 / 6.0).abs() < 1e-12);
}
// Index 0 is in both groups, so it is counted twice and coverage exceeds 1.
#[test]
fn decompose_with_overlap_has_coverage_above_one() {
let mut di = DiVector::zeros(2);
di.add_high(0, 1.0).unwrap();
di.add_high(1, 1.0).unwrap();
let overlap = FeatureGroups::builder()
.add("a", [0])
.add("b", [0, 1])
.build()
.unwrap();
let scores = decompose(&di, &overlap);
assert_eq!(scores.scores()[0], ("a".to_string(), 1.0));
assert_eq!(scores.scores()[1], ("b".to_string(), 2.0));
assert_eq!(scores.total(), 2.0);
assert!((scores.coverage() - 3.0 / 2.0).abs() < 1e-12);
}
// --- GroupScores helpers ---
// "payload" sums to 10.0 against "rate" at 2.0.
#[test]
fn top_group_picks_max_contribution() {
let mut di = DiVector::zeros(4);
di.add_high(0, 1.0).unwrap();
di.add_high(1, 1.0).unwrap();
di.add_low(2, 5.0).unwrap();
di.add_low(3, 5.0).unwrap();
let scores = decompose(&di, &four_d());
let (name, value) = scores.top_group().unwrap();
assert_eq!(name, "payload");
assert_eq!(value, 10.0);
}
#[test]
fn empty_group_scores_top_is_none() {
let scores = GroupScores::new(vec![], 0.0);
assert!(scores.top_group().is_none());
assert!(scores.is_empty());
assert_eq!(scores.coverage(), 0.0);
}
// A zero total must yield 0.0 rather than the NaN that 0.0 / 0.0 would give.
#[test]
fn coverage_handles_zero_total() {
let scores = GroupScores::new(vec![("a".into(), 0.0)], 0.0);
assert_eq!(scores.coverage(), 0.0);
}
}