1use alloc::format;
47use alloc::string::String;
48use alloc::vec::Vec;
49
50use crate::domain::DiVector;
51use crate::error::{RcfError, RcfResult};
52
53#[derive(Debug, Clone, PartialEq, Eq)]
55#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
56pub struct FeatureGroup {
57 name: String,
59 indices: Vec<usize>,
61}
62
63impl FeatureGroup {
64 pub fn new(
73 name: impl Into<String>,
74 indices: impl IntoIterator<Item = usize>,
75 ) -> RcfResult<Self> {
76 let name = name.into();
77 if name.is_empty() {
78 return Err(RcfError::InvalidConfig(
79 "FeatureGroup name must not be empty".into(),
80 ));
81 }
82 let indices: Vec<usize> = indices.into_iter().collect();
83 if indices.is_empty() {
84 return Err(RcfError::InvalidConfig(
85 format!("FeatureGroup \"{name}\" must declare at least one index").into(),
86 ));
87 }
88 Ok(Self { name, indices })
89 }
90
91 #[must_use]
93 pub fn name(&self) -> &str {
94 &self.name
95 }
96
97 #[must_use]
99 pub fn indices(&self) -> &[usize] {
100 &self.indices
101 }
102}
103
104#[derive(Debug, Clone, PartialEq, Eq, Default)]
106#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
107pub struct FeatureGroups {
108 groups: Vec<FeatureGroup>,
111 max_index: usize,
114}
115
116impl FeatureGroups {
117 #[must_use]
119 pub fn builder() -> FeatureGroupsBuilder {
120 FeatureGroupsBuilder::default()
121 }
122
123 #[must_use]
125 pub fn len(&self) -> usize {
126 self.groups.len()
127 }
128
129 #[must_use]
131 pub fn is_empty(&self) -> bool {
132 self.groups.is_empty()
133 }
134
135 #[must_use]
138 pub fn groups(&self) -> &[FeatureGroup] {
139 &self.groups
140 }
141
142 #[must_use]
146 pub fn max_index(&self) -> usize {
147 self.max_index
148 }
149
150 pub fn validate_for_dimension(&self, d: usize) -> RcfResult<()> {
158 if self.is_empty() {
159 return Ok(());
160 }
161 if self.max_index >= d {
162 return Err(RcfError::OutOfBounds {
163 index: self.max_index,
164 len: d,
165 });
166 }
167 Ok(())
168 }
169}
170
171#[derive(Debug, Default, Clone)]
173pub struct FeatureGroupsBuilder {
174 groups: Vec<FeatureGroup>,
176}
177
178impl FeatureGroupsBuilder {
179 #[must_use]
182 pub fn add(
183 mut self,
184 name: impl Into<String>,
185 indices: impl IntoIterator<Item = usize>,
186 ) -> Self {
187 let name = name.into();
188 let indices: Vec<usize> = indices.into_iter().collect();
189 self.groups.push(FeatureGroup { name, indices });
190 self
191 }
192
193 pub fn build(self) -> RcfResult<FeatureGroups> {
200 let mut max_index: usize = 0;
201 for (i, g) in self.groups.iter().enumerate() {
202 if g.name.is_empty() {
203 return Err(RcfError::InvalidConfig(
204 format!("FeatureGroup at position {i} has an empty name").into(),
205 ));
206 }
207 if g.indices.is_empty() {
208 return Err(RcfError::InvalidConfig(
209 format!(
210 "FeatureGroup \"{}\" must declare at least one index",
211 g.name
212 )
213 .into(),
214 ));
215 }
216 for &idx in &g.indices {
217 if idx > max_index {
218 max_index = idx;
219 }
220 }
221 }
222 for i in 0..self.groups.len() {
224 for j in (i + 1)..self.groups.len() {
225 if self.groups[i].name == self.groups[j].name {
226 return Err(RcfError::InvalidConfig(
227 format!("duplicate FeatureGroup name \"{}\"", self.groups[i].name).into(),
228 ));
229 }
230 }
231 }
232 Ok(FeatureGroups {
233 groups: self.groups,
234 max_index,
235 })
236 }
237}
238
/// Per-group score contributions together with the overall score they are
/// measured against.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct GroupScores {
    /// `(group name, contribution)` pairs, in the order given to
    /// [`GroupScores::new`].
    scores: Vec<(String, f64)>,
    /// Overall score; the denominator used by [`GroupScores::coverage`].
    total: f64,
}

impl GroupScores {
    /// Wraps already-computed per-group contributions and the overall total.
    ///
    /// The values are stored as given; nothing is validated or normalised.
    #[must_use]
    pub fn new(scores: Vec<(String, f64)>, total: f64) -> Self {
        Self { scores, total }
    }

    /// The `(group name, contribution)` pairs in their stored order.
    #[must_use]
    pub fn scores(&self) -> &[(String, f64)] {
        &self.scores
    }

    /// The overall score the contributions are compared against.
    #[must_use]
    pub fn total(&self) -> f64 {
        self.total
    }

    /// Number of groups that have a recorded contribution.
    #[must_use]
    pub fn len(&self) -> usize {
        self.scores.len()
    }

    /// `true` when no group contribution is recorded.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.scores.is_empty()
    }

    /// Sum of all group contributions (`0.0` when there are none).
    #[must_use]
    pub fn explained(&self) -> f64 {
        self.scores.iter().map(|(_, s)| *s).sum()
    }

    /// Ratio of [`GroupScores::explained`] to [`GroupScores::total`].
    ///
    /// Returns `0.0` when the total is zero, NaN, or infinite, so the result
    /// is never produced by dividing by an unusable denominator.
    #[must_use]
    pub fn coverage(&self) -> f64 {
        if self.total == 0.0 || !self.total.is_finite() {
            return 0.0;
        }
        self.explained() / self.total
    }

    /// Returns the group with the largest contribution, or `None` when no
    /// contributions are recorded.
    ///
    /// NaN contributions rank below every number, so a NaN entry is returned
    /// only when every contribution is NaN. When several entries tie for the
    /// maximum, the last of them is returned.
    #[must_use]
    pub fn top_group(&self) -> Option<(&str, f64)> {
        self.scores
            .iter()
            // A bare `partial_cmp(..).unwrap_or(Equal)` makes NaN "equal" to
            // everything, which lets a NaN displace the running maximum and
            // then be displaced by a smaller value. Order NaN explicitly.
            .max_by(|a, b| match (a.1.is_nan(), b.1.is_nan()) {
                (false, false) => a
                    .1
                    .partial_cmp(&b.1)
                    .unwrap_or(core::cmp::Ordering::Equal),
                (true, true) => core::cmp::Ordering::Equal,
                (true, false) => core::cmp::Ordering::Less,
                (false, true) => core::cmp::Ordering::Greater,
            })
            .map(|(n, s)| (n.as_str(), *s))
    }
}
319
320#[must_use]
324pub fn decompose(di: &DiVector, groups: &FeatureGroups) -> GroupScores {
325 let mut scores = Vec::with_capacity(groups.len());
326 for group in groups.groups() {
327 let contribution: f64 = group.indices.iter().map(|&i| di.per_dim_total(i)).sum();
328 scores.push((group.name.clone(), contribution));
329 }
330 GroupScores::new(scores, di.total())
331}
332
333impl<const D: usize> crate::forest::RandomCutForest<D> {
334 pub fn group_scores(&self, point: &[f64; D], groups: &FeatureGroups) -> RcfResult<GroupScores> {
345 groups.validate_for_dimension(D)?;
346 let di = self.attribution(point)?;
347 Ok(decompose(&di, groups))
348 }
349}
350
351impl<const D: usize> crate::thresholded::ThresholdedForest<D> {
352 pub fn group_scores(&self, point: &[f64; D], groups: &FeatureGroups) -> RcfResult<GroupScores> {
361 self.forest().group_scores(point, groups)
362 }
363}
364
#[cfg(feature = "std")]
impl<K, const D: usize> crate::pool::TenantForestPool<K, D>
where
    K: core::hash::Hash + Eq + Clone,
{
    /// Per-group score breakdown for `point` using the detector stored under
    /// `key`.
    ///
    /// When `key` is not yet in the pool, `score_only` is called first.
    /// NOTE(review): this relies on `score_only` inserting the tenant as a
    /// side effect; confirm against the pool implementation.
    ///
    /// # Errors
    ///
    /// Propagates errors from `score_only` and from the detector's
    /// `group_scores`.
    ///
    /// # Panics
    ///
    /// Panics if the tenant is still absent after the `score_only` call.
    pub fn group_scores(
        &mut self,
        key: &K,
        point: &[f64; D],
        groups: &FeatureGroups,
    ) -> RcfResult<GroupScores> {
        let already_present = self.contains(key);
        if !already_present {
            self.score_only(key, point)?;
        }
        self.get_mut(key)
            .expect("tenant was just forced into the pool")
            .group_scores(point, groups)
    }
}
405
#[cfg(test)]
// Exact `f64` equality is deliberate in these tests: the expected values are
// small whole numbers, which `f64` represents without rounding.
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    /// Two disjoint groups that together cover dimensions 0 through 3.
    fn four_d() -> FeatureGroups {
        let built = FeatureGroups::builder()
            .add("rate", [0, 1])
            .add("payload", [2, 3])
            .build();
        built.unwrap()
    }

    #[test]
    fn feature_group_rejects_empty_name() {
        let outcome = FeatureGroup::new("", [0]);
        assert!(outcome.is_err());
    }

    #[test]
    fn feature_group_rejects_empty_indices() {
        let no_indices = std::iter::empty::<usize>();
        assert!(FeatureGroup::new("rate", no_indices).is_err());
    }

    #[test]
    fn builder_rejects_empty_name() {
        let outcome = FeatureGroups::builder().add("", [0]).build();
        assert!(matches!(outcome, Err(RcfError::InvalidConfig(_))));
    }

    #[test]
    fn builder_rejects_empty_indices() {
        let outcome = FeatureGroups::builder()
            .add("rate", std::iter::empty::<usize>())
            .build();
        assert!(matches!(outcome, Err(RcfError::InvalidConfig(_))));
    }

    #[test]
    fn builder_rejects_duplicate_names() {
        let outcome = FeatureGroups::builder()
            .add("rate", [0])
            .add("rate", [1])
            .build();
        assert!(matches!(outcome, Err(RcfError::InvalidConfig(_))));
    }

    #[test]
    fn builder_tracks_max_index() {
        let set = FeatureGroups::builder()
            .add("a", [0, 3])
            .add("b", [7])
            .build()
            .unwrap();
        assert_eq!(set.max_index(), 7);
    }

    #[test]
    fn validate_for_dimension_passes_on_fit() {
        assert!(four_d().validate_for_dimension(4).is_ok());
    }

    #[test]
    fn validate_for_dimension_rejects_out_of_bounds() {
        let outcome = four_d().validate_for_dimension(3);
        assert!(matches!(outcome, Err(RcfError::OutOfBounds { .. })));
    }

    #[test]
    fn validate_for_dimension_accepts_empty_set() {
        let empty = FeatureGroups::default();
        for dimension in [0, 100] {
            assert!(empty.validate_for_dimension(dimension).is_ok());
        }
    }

    #[test]
    fn decompose_partitioning_matches_total() {
        let mut di = DiVector::zeros(4);
        for (dim, weight) in [(0, 1.0), (1, 2.0)] {
            di.add_high(dim, weight).unwrap();
        }
        for (dim, weight) in [(2, 3.0), (3, 4.0)] {
            di.add_low(dim, weight).unwrap();
        }

        let breakdown = decompose(&di, &four_d());
        let entries = breakdown.scores();
        assert_eq!(entries[0], (String::from("rate"), 3.0));
        assert_eq!(entries[1], (String::from("payload"), 7.0));
        assert_eq!(breakdown.total(), 10.0);
        assert_eq!(breakdown.explained(), 10.0);
        assert_eq!(breakdown.coverage(), 1.0);
    }

    #[test]
    fn decompose_with_gap_has_coverage_below_one() {
        let mut di = DiVector::zeros(4);
        for (dim, weight) in [(0, 1.0), (1, 1.0), (2, 2.0), (3, 2.0)] {
            di.add_high(dim, weight).unwrap();
        }
        // Only dimensions 0 and 1 are grouped; 2 and 3 stay unexplained.
        let only_rate = FeatureGroups::builder()
            .add("rate", [0, 1])
            .build()
            .unwrap();

        let breakdown = decompose(&di, &only_rate);
        assert_eq!(breakdown.scores()[0], (String::from("rate"), 2.0));
        assert_eq!(breakdown.total(), 6.0);
        let expected_coverage = 2.0 / 6.0;
        assert!((breakdown.coverage() - expected_coverage).abs() < 1e-12);
    }

    #[test]
    fn decompose_with_overlap_has_coverage_above_one() {
        let mut di = DiVector::zeros(2);
        for dim in [0, 1] {
            di.add_high(dim, 1.0).unwrap();
        }
        // Dimension 0 belongs to both groups, so it is counted twice.
        let overlapping = FeatureGroups::builder()
            .add("a", [0])
            .add("b", [0, 1])
            .build()
            .unwrap();

        let breakdown = decompose(&di, &overlapping);
        let entries = breakdown.scores();
        assert_eq!(entries[0], (String::from("a"), 1.0));
        assert_eq!(entries[1], (String::from("b"), 2.0));
        assert_eq!(breakdown.total(), 2.0);
        let expected_coverage = 3.0 / 2.0;
        assert!((breakdown.coverage() - expected_coverage).abs() < 1e-12);
    }

    #[test]
    fn top_group_picks_max_contribution() {
        let mut di = DiVector::zeros(4);
        for dim in [0, 1] {
            di.add_high(dim, 1.0).unwrap();
        }
        for dim in [2, 3] {
            di.add_low(dim, 5.0).unwrap();
        }

        let breakdown = decompose(&di, &four_d());
        assert_eq!(breakdown.top_group(), Some(("payload", 10.0)));
    }

    #[test]
    fn empty_group_scores_top_is_none() {
        let nothing = GroupScores::new(Vec::new(), 0.0);
        assert!(nothing.top_group().is_none());
        assert!(nothing.is_empty());
        assert_eq!(nothing.coverage(), 0.0);
    }

    #[test]
    fn coverage_handles_zero_total() {
        let entries = vec![(String::from("a"), 0.0)];
        let zero_total = GroupScores::new(entries, 0.0);
        assert_eq!(zero_total.coverage(), 0.0);
    }
}