candle_mi/sparse.rs
// SPDX-License-Identifier: MIT OR Apache-2.0

//! Shared sparse-feature types used by both the CLT and SAE modules.
//!
//! [`FeatureId`] is a marker trait for feature identifiers, and
//! [`SparseActivations`] stores the non-zero activations in descending
//! magnitude order. These types live here so that their public identity is
//! stable regardless of which feature flags are enabled.
9
/// Marker trait for the identifiers that key a sparse activation vector.
///
/// The trait adds no methods of its own. It bundles the capabilities every
/// feature identifier must have:
/// - cheap value semantics (`Copy`/`Clone`);
/// - total ordering and equality (`Ord`, `PartialOrd`, `Eq`, `PartialEq`);
/// - hashability (`Hash`);
/// - both debug and user-facing formatting (`Debug`, `Display`).
///
/// Known implementors are `CltFeatureId` (layer + index, behind the `clt`
/// feature) and `SaeFeatureId` (index only, behind the `sae` feature).
pub trait FeatureId
where
    Self: Copy
        + Clone
        + Eq
        + PartialEq
        + Ord
        + PartialOrd
        + std::hash::Hash
        + std::fmt::Debug
        + std::fmt::Display,
{
}
27
28/// Sparse representation of feature activations.
29///
30/// Only features with non-zero activation are stored,
31/// sorted by activation magnitude in descending order.
32///
33/// Generic over the feature identifier type `F`:
34/// - `CltFeatureId` for CLT features (layer + index, requires `clt` feature)
35/// - `SaeFeatureId` for SAE features (index only, requires `sae` feature)
36#[derive(Debug, Clone)]
37pub struct SparseActivations<F: FeatureId> {
38 /// Active features with their activation magnitudes, sorted descending.
39 pub features: Vec<(F, f32)>,
40}
41
42impl<F: FeatureId> SparseActivations<F> {
43 /// Number of active features.
44 #[must_use]
45 pub const fn len(&self) -> usize {
46 self.features.len()
47 }
48
49 /// Whether no features are active.
50 #[must_use]
51 pub const fn is_empty(&self) -> bool {
52 self.features.is_empty()
53 }
54
55 /// Truncate to the top-k most active features.
56 pub fn truncate(&mut self, k: usize) {
57 self.features.truncate(k);
58 }
59}