pub struct CrossLayerTranscoder { /* private fields */ }
Cross-Layer Transcoder.
Loads CLT encoder/decoder weights on-demand from HuggingFace safetensors,
with memory-efficient streaming (only one encoder on GPU at a time)
and a micro-cache for steering vectors.
Downloads are lazy: open() only fetches config and the first
encoder for dimension detection. Subsequent files are downloaded as needed by
load_encoder(), decoder_vector(),
and cache_steering_vectors().
§Example
use candle_mi::clt::CrossLayerTranscoder;
use candle_core::Device;
let mut clt = CrossLayerTranscoder::open("mntss/clt-gemma-2-2b-426k")?;
println!("CLT: {} layers, d_model={}", clt.config().n_layers, clt.config().d_model);
// Load encoder for layer 10
let device = Device::Cpu;
clt.load_encoder(10, &device)?;
Implementations§
impl CrossLayerTranscoder
pub fn open(clt_repo: &str) -> Result<Self>
Open a CLT from HuggingFace and detect its configuration.
Only downloads config.yaml and W_enc_0.safetensors (~75 MB).
All other encoder/decoder files are downloaded lazily on first use.
§Arguments
clt_repo — HuggingFace repository ID (e.g., "mntss/clt-gemma-2-2b-426k")
§Errors
Returns MIError::Download if the repository is inaccessible or files
cannot be fetched. Returns MIError::Config if the weight format is
unexpected.
pub fn loaded_encoder_layer(&self) -> Option<usize>
Returns the layer index of the currently loaded encoder, or None if no encoder is loaded.
pub fn load_encoder(&mut self, layer: usize, device: &Device) -> Result<()>
Load a single encoder’s weights to the specified device.
Frees any previously loaded encoder first (stream-and-free pattern). Peak GPU overhead: ~75 MB for CLT-426K, ~450 MB for CLT-2.5M.
§Arguments
layer — Layer index (0..n_layers)
device — Target device (CPU or CUDA)
§Errors
Returns MIError::Config if the layer is out of range.
Returns MIError::Download if the encoder file cannot be fetched.
Returns MIError::Model on tensor deserialization failure.
pub fn encode(
    &self,
    residual: &Tensor,
    layer: usize,
) -> Result<SparseActivations<CltFeatureId>>
Encode a residual stream activation into sparse CLT features.
The residual should be the “residual mid” activation at the given layer (after attention, before MLP).
Returns all features that pass the ReLU threshold, sorted by
activation magnitude in descending order.
§Shapes
residual: [d_model] — residual stream activation at one position
returns: SparseActivations<CltFeatureId> with (CltFeatureId, f32) pairs
§Requires
load_encoder(layer) must have been called first.
§Errors
Returns MIError::Hook if no encoder is loaded or the wrong layer is loaded.
Returns MIError::Model on tensor operation failure.
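The encoding step can be sketched in plain Rust without candle tensors — compute ReLU(W_enc · residual + b_enc), keep the features that stay positive, and sort by magnitude. The names and slice-based layout here are illustrative, not the crate's internals:

```rust
/// ReLU(W_enc · residual + b_enc), returning (feature_index, activation)
/// pairs for features that remain positive, sorted descending by magnitude.
fn encode_sparse(
    w_enc: &[Vec<f32>], // [n_features][d_model]
    b_enc: &[f32],      // [n_features]
    residual: &[f32],   // [d_model]
) -> Vec<(usize, f32)> {
    let mut acts: Vec<(usize, f32)> = w_enc
        .iter()
        .zip(b_enc)
        .enumerate()
        .filter_map(|(i, (row, b))| {
            // Pre-activation: dot product with the residual plus bias.
            let pre: f32 = row.iter().zip(residual).map(|(w, x)| w * x).sum::<f32>() + b;
            (pre > 0.0).then_some((i, pre)) // ReLU threshold: drop non-positive features
        })
        .collect();
    acts.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    acts
}

fn main() {
    let w_enc = vec![vec![1.0, 0.0], vec![0.0, -1.0], vec![0.5, 0.5]];
    let b_enc = vec![0.0, 0.0, 0.0];
    let residual = vec![2.0, 1.0];
    // Feature 0 → 2.0, feature 1 → -1.0 (dropped), feature 2 → 1.5.
    assert_eq!(encode_sparse(&w_enc, &b_enc, &residual), vec![(0, 2.0), (2, 1.5)]);
}
```

top_k() is the same computation truncated to the k largest entries after the descending sort.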
pub fn top_k(
    &self,
    residual: &Tensor,
    layer: usize,
    k: usize,
) -> Result<SparseActivations<CltFeatureId>>
Encode and return only the top-k most active features.
§Shapes
residual: [d_model] — residual stream activation at one position
returns: SparseActivations<CltFeatureId> truncated to at most k entries
§Requires
load_encoder(layer) must have been called first.
§Errors
Same as encode().
pub fn decoder_vector(
    &mut self,
    feature: &CltFeatureId,
    target_layer: usize,
    device: &Device,
) -> Result<Tensor>
Extract a single feature’s decoder vector for a target downstream layer.
Loads from safetensors on demand. Checks the steering cache first to avoid redundant file reads.
§Shapes
returns: [d_model] — decoder vector on device
§Arguments
feature — The CLT feature to extract the decoder for
target_layer — The downstream layer to decode to (must be >= feature.layer)
device — Device to place the resulting tensor on
§Errors
Returns MIError::Config if layer indices are out of range.
Returns MIError::Download if the decoder file cannot be fetched.
Returns MIError::Model on tensor operation failure.
pub fn cache_steering_vectors(
    &mut self,
    features: &[(CltFeatureId, usize)],
    device: &Device,
) -> Result<()>
Pre-load decoder vectors into the steering micro-cache.
Each entry is a (CltFeatureId, target_layer) pair. Vectors are
loaded to the specified device and kept pinned for repeated injection.
Uses an OOM-safe pattern: loads each decoder file to CPU, extracts needed columns as independent F32 tensors, drops the large file, then moves small tensors to the target device.
Memory: 50 features × 2304 × 4 bytes = ~450 KB (negligible).
§Errors
Returns MIError::Download if decoder files cannot be fetched.
Returns MIError::Model on tensor operation failure.
pub fn cache_steering_vectors_all_downstream(
    &mut self,
    features: &[CltFeatureId],
    device: &Device,
) -> Result<()>
Cache steering vectors for ALL downstream layers of each feature.
For each feature at source layer l, caches decoder vectors for every
downstream target layer l..n_layers. This enables multi-layer
“clamping” injection where the steering signal propagates through all
downstream transformer layers.
Same OOM-safe pattern as cache_steering_vectors().
§Arguments
features — Feature IDs to cache (all downstream layers are cached automatically)
device — Device to store cached tensors on (typically GPU)
§Errors
Returns MIError::Config if any feature layer is out of range.
Returns MIError::Download if decoder files cannot be fetched.
Returns MIError::Model on tensor operation failure.
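The "all downstream" expansion is just an enumeration: each feature at source layer l contributes one (source, target) pair per layer in l..n_layers, and one decoder vector is cached per pair. A minimal sketch, with feature IDs reduced to their source-layer index for illustration:

```rust
/// Expand features into (source_layer, target_layer) pairs covering
/// every downstream layer l..n_layers — one cached vector per pair.
fn downstream_pairs(feature_layers: &[usize], n_layers: usize) -> Vec<(usize, usize)> {
    feature_layers
        .iter()
        .flat_map(|&l| (l..n_layers).map(move |t| (l, t)))
        .collect()
}

fn main() {
    // A feature at layer 2 in a 4-layer model decodes to layers 2 and 3.
    assert_eq!(downstream_pairs(&[2], 4), vec![(2, 2), (2, 3)]);
}
```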
pub fn clear_steering_cache(&mut self)
Clear all cached steering vectors, freeing device memory.
pub fn steering_cache_len(&self) -> usize
Number of vectors currently in the steering cache.
pub fn prepare_hook_injection(
    &self,
    features: &[(CltFeatureId, usize)],
    position: usize,
    seq_len: usize,
    strength: f32,
    device: &Device,
) -> Result<HookSpec>
Build a crate::HookSpec that injects CLT decoder vectors into the residual stream.
Groups cached steering vectors by target layer, accumulates them per layer,
scales by strength, and creates crate::Intervention::Add entries on
crate::HookPoint::ResidPost for each target layer. The resulting HookSpec
can be passed directly to MIModel::forward().
§Shapes
Internally constructs [1, seq_len, d_model] tensors with the steering vector placed at position and zeros elsewhere.
§Arguments
features — List of (feature_id, target_layer) pairs (must be cached)
position — Token position in the sequence to inject at
seq_len — Total sequence length (needed to construct position-specific tensors)
strength — Scalar multiplier for the accumulated steering vectors
device — Device to construct injection tensors on
§Errors
Returns MIError::Hook if any feature is not in the steering cache.
Returns MIError::Model on tensor construction failure.
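The position-masked tensor described under §Shapes can be sketched as a flat row-major [seq_len, d_model] buffer (batch dim omitted): zeros everywhere except the strength-scaled steering vector at position. This is an illustrative stand-in for the candle tensor construction, not the crate's code:

```rust
/// Build a flat row-major [seq_len, d_model] buffer that is zero
/// everywhere except the scaled steering vector at `position`.
fn injection_buffer(steer: &[f32], position: usize, seq_len: usize, strength: f32) -> Vec<f32> {
    let d_model = steer.len();
    let mut buf = vec![0.0f32; seq_len * d_model];
    for (j, &v) in steer.iter().enumerate() {
        buf[position * d_model + j] = strength * v;
    }
    buf
}

fn main() {
    // seq_len = 3, d_model = 2: only position 1 carries the scaled vector.
    let buf = injection_buffer(&[1.0, -2.0], 1, 3, 0.5);
    assert_eq!(buf, vec![0.0, 0.0, 0.5, -1.0, 0.0, 0.0]);
}
```

Because the buffer is zero away from position, adding it to the residual stream (Intervention::Add on ResidPost) perturbs exactly one token position.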
pub fn inject(
    &self,
    residual: &Tensor,
    features: &[(CltFeatureId, usize)],
    position: usize,
    strength: f32,
) -> Result<Tensor>
Inject cached steering vectors directly into a residual stream tensor.
Convenience method for use outside the forward pass (e.g., in analysis
scripts). Returns a new tensor with the injection applied:
residual[:, position, :] += strength × Σ decoder_vectors
§Shapes
residual: [batch, seq_len, d_model] — hidden states
returns: [batch, seq_len, d_model] — modified hidden states
§Arguments
residual — Hidden states tensor
features — List of (feature, target_layer) pairs to inject (must be cached)
position — Token position in the sequence to inject at
strength — Scalar multiplier for the steering vectors
§Errors
Returns MIError::Hook if any feature is not in the steering cache.
Returns MIError::Config if dimensions don’t match.
Returns MIError::Model on tensor operation failure.
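The update formula residual[:, position, :] += strength × Σ decoder_vectors can be demonstrated on a flat [seq_len, d_model] buffer (batch dim omitted); this plain-Rust sketch is illustrative, not the crate's tensor code:

```rust
/// residual[position] += strength * Σ vectors, on a flat row-major
/// [seq_len, d_model] buffer.
fn inject_at(
    residual: &mut [f32],
    d_model: usize,
    position: usize,
    strength: f32,
    vectors: &[Vec<f32>],
) {
    for v in vectors {
        for (j, &x) in v.iter().enumerate() {
            residual[position * d_model + j] += strength * x;
        }
    }
}

fn main() {
    // seq_len = 2, d_model = 2; two unit vectors injected at position 1.
    let mut residual = vec![0.0f32; 4];
    inject_at(&mut residual, 2, 1, 2.0, &[vec![1.0, 0.0], vec![0.0, 1.0]]);
    assert_eq!(residual, vec![0.0, 0.0, 2.0, 2.0]);
}
```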
pub fn score_features_by_decoder_projection(
    &mut self,
    direction: &Tensor,
    target_layer: usize,
    top_k: usize,
    cosine: bool,
) -> Result<Vec<(CltFeatureId, f32)>>
Score all CLT features by how strongly their decoder vector at
target_layer projects along a given direction vector.
For each source layer in 0..=target_layer: loads that layer's decoder file
to CPU, extracts the target-layer slice [n_features, d_model], and computes
scores = slice @ direction.
When cosine is true, scores are normalized by both the direction
vector norm and each decoder row norm (cosine similarity).
§Shapes
direction: [d_model] — target direction vector (e.g., token embedding)
returns: top-k (CltFeatureId, f32) pairs, sorted by score descending
§Arguments
direction — [d_model] direction vector to project decoders onto
target_layer — downstream layer to examine decoders at
top_k — number of top-scoring features to return
cosine — whether to use cosine similarity instead of dot product
§Errors
Returns MIError::Config if direction shape is wrong or target_layer
is out of range.
Returns MIError::Download if decoder files cannot be fetched.
Returns MIError::Model on tensor operation failure.
§Memory
Processes one decoder file at a time on CPU (up to ~2 GB for layer 0). No GPU memory required.
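The scoring itself is a dot product per decoder row, optionally normalized to cosine similarity, followed by a top-k cut. A plain-Rust sketch with Vec<f32> standing in for the [n_features, d_model] decoder slice (names are illustrative):

```rust
/// Score each row against `direction` (dot product, or cosine if
/// `cosine` is true), then keep the top-k by descending score.
fn score_rows(rows: &[Vec<f32>], direction: &[f32], top_k: usize, cosine: bool) -> Vec<(usize, f32)> {
    let dir_norm = direction.iter().map(|x| x * x).sum::<f32>().sqrt();
    let mut scores: Vec<(usize, f32)> = rows
        .iter()
        .enumerate()
        .map(|(i, row)| {
            let dot: f32 = row.iter().zip(direction).map(|(a, b)| a * b).sum();
            let s = if cosine {
                // Normalize by both the row norm and the direction norm.
                let row_norm = row.iter().map(|x| x * x).sum::<f32>().sqrt();
                dot / (row_norm * dir_norm)
            } else {
                dot
            };
            (i, s)
        })
        .collect();
    scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    scores.truncate(top_k);
    scores
}

fn main() {
    let rows = vec![vec![1.0, 0.0], vec![10.0, 10.0], vec![0.0, 1.0]];
    let dir = vec![1.0, 0.0];
    // Dot product favors the large row; cosine favors the aligned row.
    assert_eq!(score_rows(&rows, &dir, 1, false)[0].0, 1);
    assert_eq!(score_rows(&rows, &dir, 1, true)[0].0, 0);
}
```

The cosine option matters when decoder rows differ widely in norm: raw dot products rank large rows first, while cosine ranks by direction alone.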
pub fn score_features_by_decoder_projection_batch(
    &mut self,
    directions: &[Tensor],
    target_layer: usize,
    top_k: usize,
    cosine: bool,
) -> Result<Vec<Vec<(CltFeatureId, f32)>>>
Batch version of score_features_by_decoder_projection.
Scores multiple direction vectors against all decoder files in a single
pass. Each decoder file is loaded once for all directions, reducing
I/O from n_words × n_layers file reads to just n_layers.
§Shapes
directions: slice of [d_model] tensors (one per word/direction)
returns: one Vec<(CltFeatureId, f32)> per direction (top-k per word)
§Arguments
directions — slice of [d_model] direction vectors
target_layer — downstream layer to examine decoders at
top_k — number of top-scoring features to return per direction
cosine — whether to use cosine similarity
§Errors
Returns MIError::Config if any direction has wrong shape, directions is
empty, or target_layer is out of range.
Returns MIError::Download if decoder files cannot be fetched.
Returns MIError::Model on tensor operation failure.
§Memory
Stacks directions to [n_words, d_model] on CPU. Each decoder file is
loaded one at a time (up to ~2 GB for layer 0). No GPU memory required.
pub fn extract_decoder_vectors(
    &mut self,
    features: &[CltFeatureId],
    target_layer: usize,
) -> Result<HashMap<CltFeatureId, Tensor>>
Extract decoder vectors for a set of features at a specific target layer.
Groups features by source layer, loads each decoder file once, and
extracts the decoder vector at the target layer offset as an independent
F32 CPU tensor. Uses the OOM-safe to_vec1 + from_vec pattern to
ensure large decoder files are freed before processing the next layer.
§Shapes
returns: HashMap<CltFeatureId, Tensor> where each tensor is [d_model] (F32, CPU)
§Arguments
features — feature IDs to extract decoder vectors for
target_layer — downstream layer to extract decoders at
§Errors
Returns MIError::Config if any feature layer or target_layer is out
of range, or if target_layer < feature.layer for any feature.
Returns MIError::Download if decoder files cannot be fetched.
Returns MIError::Model on tensor operation failure.
§Memory
Loads each decoder to CPU (up to ~2 GB), extracts independent F32 tensors, then drops the large file before processing the next layer.
pub fn build_attribution_graph(
    &mut self,
    direction: &Tensor,
    target_layer: usize,
    top_k: usize,
    cosine: bool,
) -> Result<AttributionGraph>
Build an attribution graph by scoring features against a direction.
Convenience wrapper around
score_features_by_decoder_projection
that returns an AttributionGraph instead of a raw Vec.
§Shapes
direction: [d_model]
§Errors
Same as score_features_by_decoder_projection.
pub fn build_attribution_graph_batch(
    &mut self,
    directions: &[Tensor],
    target_layer: usize,
    top_k: usize,
    cosine: bool,
) -> Result<Vec<AttributionGraph>>
Build attribution graphs for multiple directions in a single pass.
Convenience wrapper around
score_features_by_decoder_projection_batch
that returns Vec<AttributionGraph>.
§Shapes
directions: slice of [d_model] tensors
§Errors
Same as score_features_by_decoder_projection_batch.
Auto Trait Implementations§
impl Freeze for CrossLayerTranscoder
impl !RefUnwindSafe for CrossLayerTranscoder
impl Send for CrossLayerTranscoder
impl Sync for CrossLayerTranscoder
impl Unpin for CrossLayerTranscoder
impl UnsafeUnpin for CrossLayerTranscoder
impl !UnwindSafe for CrossLayerTranscoder