
Struct CrossLayerTranscoder 

Source
pub struct CrossLayerTranscoder { /* private fields */ }

Cross-Layer Transcoder.

Loads CLT encoder/decoder weights on-demand from HuggingFace safetensors, with memory-efficient streaming (only one encoder on GPU at a time) and a micro-cache for steering vectors.

Downloads are lazy: open() only fetches config and the first encoder for dimension detection. Subsequent files are downloaded as needed by load_encoder(), decoder_vector(), and cache_steering_vectors().

§Example

use candle_mi::clt::CrossLayerTranscoder;
use candle_core::Device;

let mut clt = CrossLayerTranscoder::open("mntss/clt-gemma-2-2b-426k")?;
println!("CLT: {} layers, d_model={}", clt.config().n_layers, clt.config().d_model);

// Load encoder for layer 10
let device = Device::Cpu;
clt.load_encoder(10, &device)?;

Implementations§

Source§

impl CrossLayerTranscoder

Source

pub fn open(clt_repo: &str) -> Result<Self>

Open a CLT from HuggingFace and detect its configuration.

Only downloads config.yaml and W_enc_0.safetensors (~75 MB). All other encoder/decoder files are downloaded lazily on first use.

§Arguments
  • clt_repo — HuggingFace repository ID (e.g., "mntss/clt-gemma-2-2b-426k")
§Errors

Returns MIError::Download if the repository is inaccessible or files cannot be fetched. Returns MIError::Config if the weight format is unexpected.

Source

pub const fn config(&self) -> &CltConfig

Access the auto-detected CLT configuration.

Source

pub fn loaded_encoder_layer(&self) -> Option<usize>

Check whether an encoder is currently loaded and for which layer.

Source

pub fn load_encoder(&mut self, layer: usize, device: &Device) -> Result<()>

Load a single encoder’s weights to the specified device.

Frees any previously loaded encoder first (stream-and-free pattern). Peak GPU overhead: ~75 MB for CLT-426K, ~450 MB for CLT-2.5M.

§Arguments
  • layer — Layer index (0..n_layers)
  • device — Target device (CPU or CUDA)
§Errors

Returns MIError::Config if the layer is out of range. Returns MIError::Download if the encoder file cannot be fetched. Returns MIError::Model on tensor deserialization failure.

Source

pub fn encode( &self, residual: &Tensor, layer: usize, ) -> Result<SparseActivations<CltFeatureId>>

Encode a residual stream activation into sparse CLT features.

The residual should be the “residual mid” activation at the given layer (after attention, before MLP).

Returns all features that pass the ReLU threshold, sorted by activation magnitude in descending order.
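The threshold-and-sort step can be illustrated with a std-only Rust sketch (the function name and plain-`Vec` representation are illustrative, not part of the crate's API):

```rust
// Illustrative sketch: apply ReLU to pre-activations, keep nonzero
// features as (index, value) pairs, and sort by magnitude descending.
fn relu_sparse(pre_acts: &[f32]) -> Vec<(usize, f32)> {
    let mut active: Vec<(usize, f32)> = pre_acts
        .iter()
        .enumerate()
        .filter(|&(_, &v)| v > 0.0)
        .map(|(i, &v)| (i, v))
        .collect();
    // Descending by activation magnitude (all surviving values are positive).
    active.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    active
}

fn main() {
    let sparse = relu_sparse(&[0.2, -1.0, 3.5, 0.0, 1.1]);
    println!("{:?}", sparse); // [(2, 3.5), (4, 1.1), (0, 0.2)]
}
```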

§Requires

load_encoder(layer) must have been called first.

§Errors

Returns MIError::Hook if no encoder is loaded or the wrong layer is loaded. Returns MIError::Model on tensor operation failure.

Source

pub fn top_k( &self, residual: &Tensor, layer: usize, k: usize, ) -> Result<SparseActivations<CltFeatureId>>

Encode and return only the top-k most active features.

§Requires

load_encoder(layer) must have been called first.

§Errors

Same as encode().

Source

pub fn decoder_vector( &mut self, feature: &CltFeatureId, target_layer: usize, device: &Device, ) -> Result<Tensor>

Extract a single feature’s decoder vector for a target downstream layer.

Loads from safetensors on demand. Checks the steering cache first to avoid redundant file reads.

§Shapes
  • returns: [d_model] — decoder vector on device
§Arguments
  • feature — The CLT feature to extract the decoder for
  • target_layer — The downstream layer to decode to (must be >= feature.layer)
  • device — Device to place the resulting tensor on
§Errors

Returns MIError::Config if layer indices are out of range. Returns MIError::Download if the decoder file cannot be fetched. Returns MIError::Model on tensor operation failure.

Source

pub fn cache_steering_vectors( &mut self, features: &[(CltFeatureId, usize)], device: &Device, ) -> Result<()>

Pre-load decoder vectors into the steering micro-cache.

Each entry is a (CltFeatureId, target_layer) pair. Vectors are loaded to the specified device and kept pinned for repeated injection.

Uses an OOM-safe pattern: loads each decoder file to CPU, extracts needed columns as independent F32 tensors, drops the large file, then moves small tensors to the target device.

Memory: 50 features × 2304 × 4 bytes = ~450 KB (negligible).
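The load-extract-drop pattern can be sketched with nested `Vec`s standing in for tensors (names are illustrative, not crate internals):

```rust
// Illustrative: copy only the needed rows out of a large buffer as
// independent small vectors, then free the buffer before the next load.
fn extract_rows(big_file: Vec<Vec<f32>>, wanted: &[usize]) -> Vec<Vec<f32>> {
    let small: Vec<Vec<f32>> = wanted.iter().map(|&i| big_file[i].clone()).collect();
    drop(big_file); // the GB-scale decoder buffer is freed here
    small // only the small per-feature vectors remain; these go to the device
}

fn main() {
    let file = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
    let vecs = extract_rows(file, &[2, 0]);
    println!("{:?}", vecs); // [[5.0, 6.0], [1.0, 2.0]]
}
```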

§Errors

Returns MIError::Download if decoder files cannot be fetched. Returns MIError::Model on tensor operation failure.

Source

pub fn cache_steering_vectors_all_downstream( &mut self, features: &[CltFeatureId], device: &Device, ) -> Result<()>

Cache steering vectors for ALL downstream layers of each feature.

For each feature at source layer l, caches decoder vectors for every downstream target layer l..n_layers. This enables multi-layer “clamping” injection where the steering signal propagates through all downstream transformer layers.
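The downstream enumeration amounts to expanding each source layer into its target range; a minimal sketch (illustrative function, not crate API):

```rust
// Illustrative: a feature at `source_layer` decodes to every target
// layer in source_layer..n_layers.
fn downstream_targets(source_layer: usize, n_layers: usize) -> Vec<usize> {
    (source_layer..n_layers).collect()
}

fn main() {
    // A feature at layer 3 in a 6-layer CLT decodes to layers 3, 4, 5.
    println!("{:?}", downstream_targets(3, 6)); // [3, 4, 5]
}
```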

Same OOM-safe pattern as cache_steering_vectors().

§Arguments
  • features — Feature IDs to cache (all downstream layers are cached automatically)
  • device — Device to store cached tensors on (typically GPU)
§Errors

Returns MIError::Config if any feature layer is out of range. Returns MIError::Download if decoder files cannot be fetched. Returns MIError::Model on tensor operation failure.

Source

pub fn clear_steering_cache(&mut self)

Clear all cached steering vectors, freeing device memory.

Source

pub fn steering_cache_len(&self) -> usize

Number of vectors currently in the steering cache.

Source

pub fn prepare_hook_injection( &self, features: &[(CltFeatureId, usize)], position: usize, seq_len: usize, strength: f32, device: &Device, ) -> Result<HookSpec>

Build a crate::HookSpec that injects CLT decoder vectors into the residual stream.

Groups cached steering vectors by target layer, accumulates them per layer, scales by strength, and creates crate::Intervention::Add entries on crate::HookPoint::ResidPost for each target layer. The resulting HookSpec can be passed directly to MIModel::forward().

§Shapes
  • Internally constructs [1, seq_len, d_model] tensors with the steering vector placed at position and zeros elsewhere.
§Arguments
  • features — List of (feature_id, target_layer) pairs (must be cached)
  • position — Token position in the sequence to inject at
  • seq_len — Total sequence length (needed to construct position-specific tensors)
  • strength — Scalar multiplier for the accumulated steering vectors
  • device — Device to construct injection tensors on
§Errors

Returns MIError::Hook if any feature is not in the steering cache. Returns MIError::Model on tensor construction failure.
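The position-specific tensor described under §Shapes can be sketched with nested `Vec`s in place of a [1, seq_len, d_model] tensor (illustrative names, batch dimension dropped):

```rust
// Illustrative: build a [seq_len][d_model] buffer that is zero
// everywhere except `position`, which holds strength * steering_vec.
fn injection_buffer(seq_len: usize, position: usize, steering: &[f32], strength: f32) -> Vec<Vec<f32>> {
    let d_model = steering.len();
    let mut out = vec![vec![0.0f32; d_model]; seq_len];
    for (j, &v) in steering.iter().enumerate() {
        out[position][j] = strength * v;
    }
    out
}

fn main() {
    let buf = injection_buffer(3, 1, &[1.0, 2.0], 2.0);
    println!("{:?}", buf); // [[0.0, 0.0], [2.0, 4.0], [0.0, 0.0]]
}
```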

Source

pub fn inject( &self, residual: &Tensor, features: &[(CltFeatureId, usize)], position: usize, strength: f32, ) -> Result<Tensor>

Inject cached steering vectors directly into a residual stream tensor.

Convenience method for use outside the forward pass (e.g., in analysis scripts). Returns a new tensor with the injection applied: residual[:, position, :] += strength × Σ decoder_vectors

§Shapes
  • residual: [batch, seq_len, d_model] — hidden states
  • returns: [batch, seq_len, d_model] — modified hidden states
§Arguments
  • residual — Hidden states tensor
  • features — List of (feature, target_layer) pairs to inject (must be cached)
  • position — Token position in the sequence to inject at
  • strength — Scalar multiplier for the steering vectors
§Errors

Returns MIError::Hook if any feature is not in the steering cache. Returns MIError::Config if dimensions don’t match. Returns MIError::Model on tensor operation failure.
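The update rule residual[:, position, :] += strength × Σ decoder_vectors can be checked with a std-only sketch (batch dimension dropped; names illustrative, not the crate's implementation):

```rust
// Illustrative: accumulate the steering vectors, scale by strength,
// and add the result at one position of a [seq_len][d_model] buffer.
fn inject_at(mut residual: Vec<Vec<f32>>, position: usize, vectors: &[Vec<f32>], strength: f32) -> Vec<Vec<f32>> {
    let d_model = residual[position].len();
    let mut sum = vec![0.0f32; d_model];
    for v in vectors {
        for (s, &x) in sum.iter_mut().zip(v) {
            *s += x;
        }
    }
    for (r, &s) in residual[position].iter_mut().zip(&sum) {
        *r += strength * s;
    }
    residual
}

fn main() {
    let res = inject_at(
        vec![vec![1.0, 1.0], vec![0.0, 0.0]],
        1,
        &[vec![1.0, 2.0], vec![3.0, 4.0]],
        0.5,
    );
    println!("{:?}", res); // [[1.0, 1.0], [2.0, 3.0]]
}
```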

Source

pub fn score_features_by_decoder_projection( &mut self, direction: &Tensor, target_layer: usize, top_k: usize, cosine: bool, ) -> Result<Vec<(CltFeatureId, f32)>>

Score all CLT features by how strongly their decoder vector at target_layer projects along a given direction vector.

For each source layer 0..n_layers where source_layer <= target_layer: loads the decoder file to CPU, extracts the target layer slice [n_features, d_model], and computes scores = slice @ direction.

When cosine is true, scores are normalized by both the direction vector norm and each decoder row norm (cosine similarity).
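The two scoring modes reduce to a dot product with optional norm division; a minimal sketch for one decoder row (illustrative function, not crate API):

```rust
// Illustrative: score one decoder row against a direction, optionally
// normalizing by both norms (cosine similarity).
fn score_row(row: &[f32], direction: &[f32], cosine: bool) -> f32 {
    let dot: f32 = row.iter().zip(direction).map(|(a, b)| a * b).sum();
    if cosine {
        let row_norm = row.iter().map(|x| x * x).sum::<f32>().sqrt();
        let dir_norm = direction.iter().map(|x| x * x).sum::<f32>().sqrt();
        dot / (row_norm * dir_norm)
    } else {
        dot
    }
}

fn main() {
    let (row, dir) = ([3.0f32, 4.0], [6.0f32, 8.0]);
    println!("{}", score_row(&row, &dir, false)); // 50
    println!("{}", score_row(&row, &dir, true));  // 1 (parallel vectors)
}
```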

§Shapes
  • direction: [d_model] — target direction vector (e.g., token embedding)
  • returns: top-k (CltFeatureId, f32) pairs, sorted by score descending
§Arguments
  • direction — [d_model] direction vector to project decoders onto
  • target_layer — downstream layer to examine decoders at
  • top_k — number of top-scoring features to return
  • cosine — whether to use cosine similarity instead of dot product
§Errors

Returns MIError::Config if direction shape is wrong or target_layer is out of range. Returns MIError::Download if decoder files cannot be fetched. Returns MIError::Model on tensor operation failure.

§Memory

Processes one decoder file at a time on CPU (up to ~2 GB for layer 0). No GPU memory required.

Source

pub fn score_features_by_decoder_projection_batch( &mut self, directions: &[Tensor], target_layer: usize, top_k: usize, cosine: bool, ) -> Result<Vec<Vec<(CltFeatureId, f32)>>>

Batch version of score_features_by_decoder_projection.

Scores multiple direction vectors against all decoder files in a single pass. Each decoder file is loaded once for all directions, reducing I/O from n_words × n_layers file reads to just n_layers.

§Shapes
  • directions: slice of [d_model] tensors (one per word/direction)
  • returns: one Vec<(CltFeatureId, f32)> per direction (top-k per word)
§Arguments
  • directions — slice of [d_model] direction vectors
  • target_layer — downstream layer to examine decoders at
  • top_k — number of top-scoring features to return per direction
  • cosine — whether to use cosine similarity
§Errors

Returns MIError::Config if any direction has wrong shape, directions is empty, or target_layer is out of range. Returns MIError::Download if decoder files cannot be fetched. Returns MIError::Model on tensor operation failure.

§Memory

Stacks directions to [n_words, d_model] on CPU. Each decoder file loaded one at a time (up to ~2 GB for layer 0). No GPU memory required.

Source

pub fn extract_decoder_vectors( &mut self, features: &[CltFeatureId], target_layer: usize, ) -> Result<HashMap<CltFeatureId, Tensor>>

Extract decoder vectors for a set of features at a specific target layer.

Groups features by source layer, loads each decoder file once, and extracts the decoder vector at the target layer offset as an independent F32 CPU tensor. Uses the OOM-safe to_vec1 + from_vec pattern to ensure large decoder files are freed before processing the next layer.

§Shapes
  • returns: HashMap<CltFeatureId, Tensor> where each tensor is [d_model] (F32, CPU)
§Arguments
  • features — feature IDs to extract decoder vectors for
  • target_layer — downstream layer to extract decoders at
§Errors

Returns MIError::Config if any feature layer or target_layer is out of range, or if target_layer < feature.layer for any feature. Returns MIError::Download if decoder files cannot be fetched. Returns MIError::Model on tensor operation failure.

§Memory

Loads each decoder to CPU (up to ~2 GB), extracts independent F32 tensors, then drops the large file before processing the next layer.

Source

pub fn build_attribution_graph( &mut self, direction: &Tensor, target_layer: usize, top_k: usize, cosine: bool, ) -> Result<AttributionGraph>

Build an attribution graph by scoring features against a direction.

Convenience wrapper around score_features_by_decoder_projection that returns an AttributionGraph instead of a raw Vec.

§Shapes
  • direction: [d_model]
§Errors

Same as score_features_by_decoder_projection.

Source

pub fn build_attribution_graph_batch( &mut self, directions: &[Tensor], target_layer: usize, top_k: usize, cosine: bool, ) -> Result<Vec<AttributionGraph>>

Build attribution graphs for multiple directions in a single pass.

Convenience wrapper around score_features_by_decoder_projection_batch that returns Vec<AttributionGraph>.

§Shapes
  • directions: slice of [d_model] tensors
§Errors

Same as score_features_by_decoder_projection_batch.

Auto Trait Implementations§

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T> Instrument for T

Source§

fn instrument(self, span: Span) -> Instrumented<Self>

Instruments this type with the provided Span, returning an Instrumented wrapper. Read more
Source§

fn in_current_span(self) -> Instrumented<Self>

Instruments this type with the current Span, returning an Instrumented wrapper. Read more
Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T> IntoEither for T

Source§

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

impl<T> Pointable for T

Source§

const ALIGN: usize

The alignment of the pointer.
Source§

type Init = T

The type for initializers.
Source§

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a pointer with the given initializer. Read more
Source§

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer. Read more
Source§

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer. Read more
Source§

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer. Read more
Source§

impl<T> PolicyExt for T
where T: ?Sized,

Source§

fn and<P, B, E>(self, other: P) -> And<T, P>
where T: Policy<B, E>, P: Policy<B, E>,

Create a new Policy that returns Action::Follow only if self and other return Action::Follow. Read more
Source§

fn or<P, B, E>(self, other: P) -> Or<T, P>
where T: Policy<B, E>, P: Policy<B, E>,

Create a new Policy that returns Action::Follow if either self or other returns Action::Follow. Read more
Source§

impl<T> Same for T

Source§

type Output = T

Should always be Self
Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.
Source§

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

Source§

fn vzip(self) -> V

Source§

impl<T> WithSubscriber for T

Source§

fn with_subscriber<S>(self, subscriber: S) -> WithDispatch<Self>
where S: Into<Dispatch>,

Attaches the provided Subscriber to this type, returning a WithDispatch wrapper. Read more
Source§

fn with_current_subscriber(self) -> WithDispatch<Self>

Attaches the current default Subscriber to this type, returning a WithDispatch wrapper. Read more
Source§

impl<T> ErasedDestructor for T
where T: 'static,