Skip to main content

candle_mi/
sparse.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! Shared sparse-feature types used by both CLT and SAE modules.
4//!
5//! [`FeatureId`] is a marker trait for feature identifiers, and
6//! [`SparseActivations`] stores the non-zero activations in descending
7//! magnitude order.  These live here so that the public type identity is
8//! stable regardless of which feature flags are enabled.
9
/// Common bound for identifiers of features in a sparse activation vector.
///
/// The trait declares no methods of its own; it only bundles the standard
/// traits an identifier must provide: cheap copying, total ordering and
/// equality, hashing, and both debug and user-facing formatting.
///
/// Implementors named by this crate's docs are `CltFeatureId` (layer +
/// index, behind the `clt` feature) and `SaeFeatureId` (index only, behind
/// the `sae` feature).
pub trait FeatureId
where
    Self: Copy
        + Clone
        + Ord
        + PartialOrd
        + Eq
        + PartialEq
        + std::hash::Hash
        + std::fmt::Debug
        + std::fmt::Display,
{
}
27
28/// Sparse representation of feature activations.
29///
30/// Only features with non-zero activation are stored,
31/// sorted by activation magnitude in descending order.
32///
33/// Generic over the feature identifier type `F`:
34/// - `CltFeatureId` for CLT features (layer + index, requires `clt` feature)
35/// - `SaeFeatureId` for SAE features (index only, requires `sae` feature)
36#[derive(Debug, Clone)]
37pub struct SparseActivations<F: FeatureId> {
38    /// Active features with their activation magnitudes, sorted descending.
39    pub features: Vec<(F, f32)>,
40}
41
42impl<F: FeatureId> SparseActivations<F> {
43    /// Number of active features.
44    #[must_use]
45    pub const fn len(&self) -> usize {
46        self.features.len()
47    }
48
49    /// Whether no features are active.
50    #[must_use]
51    pub const fn is_empty(&self) -> bool {
52        self.features.is_empty()
53    }
54
55    /// Truncate to the top-k most active features.
56    pub fn truncate(&mut self, k: usize) {
57        self.features.truncate(k);
58    }
59}