use std::collections::HashMap;
use std::path::PathBuf;
use candle_core::{DType, Device, IndexOp, Tensor};
use safetensors::tensor::SafeTensors;
use tracing::info;
use crate::error::{MIError, Result};
/// Identifier for a single CLT (cross-layer transcoder) feature:
/// the layer its encoder lives at plus its index within that layer.
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize,
)]
pub struct CltFeatureId {
    /// Source layer the feature is encoded at.
    pub layer: usize,
    /// Feature index within that layer (0..n_features_per_layer).
    pub index: usize,
}
impl std::fmt::Display for CltFeatureId {
    /// Renders the feature as `L<layer>:<index>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { layer, index } = *self;
        write!(f, "L{}:{}", layer, index)
    }
}
use crate::sparse::{FeatureId, SparseActivations};
// Marker impl: lets `CltFeatureId` key into the generic sparse-activation machinery.
impl FeatureId for CltFeatureId {}
/// A single edge in an [`AttributionGraph`]: a feature and its score.
#[derive(Debug, Clone)]
pub struct AttributionEdge {
    /// The CLT feature this edge refers to.
    pub feature: CltFeatureId,
    /// Projection score of the feature's decoder direction onto the probed direction.
    pub score: f32,
}
/// Features attributed to a direction at a given target layer, as produced by
/// `CrossLayerTranscoder::build_attribution_graph` (edges sorted by descending score).
#[derive(Debug, Clone)]
pub struct AttributionGraph {
    // Layer the scored decoder directions write into.
    target_layer: usize,
    // Edges in descending-score order.
    edges: Vec<AttributionEdge>,
}
impl AttributionGraph {
    /// Layer the attributed decoder directions write into.
    #[must_use]
    pub const fn target_layer(&self) -> usize {
        self.target_layer
    }
    /// All edges, in descending-score order.
    #[must_use]
    pub fn edges(&self) -> &[AttributionEdge] {
        self.edges.as_slice()
    }
    /// Number of edges in the graph.
    #[must_use]
    pub const fn len(&self) -> usize {
        self.edges.len()
    }
    /// `true` when the graph holds no edges.
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        self.edges.is_empty()
    }
    /// A copy of the graph keeping only the first `k` edges.
    #[must_use]
    pub fn top_k(&self, k: usize) -> Self {
        let keep = k.min(self.edges.len());
        Self {
            target_layer: self.target_layer,
            edges: self.edges[..keep].to_vec(),
        }
    }
    /// A copy of the graph keeping only edges with `|score| >= min_score`.
    #[must_use]
    pub fn threshold(&self, min_score: f32) -> Self {
        let mut edges = self.edges.clone();
        edges.retain(|edge| edge.score.abs() >= min_score);
        Self {
            target_layer: self.target_layer,
            edges,
        }
    }
    /// Feature ids of all edges, in edge order.
    #[must_use]
    pub fn features(&self) -> Vec<CltFeatureId> {
        let mut out = Vec::with_capacity(self.edges.len());
        for edge in &self.edges {
            out.push(edge.feature);
        }
        out
    }
    /// Consumes the graph, yielding its edges.
    #[must_use]
    pub fn into_edges(self) -> Vec<AttributionEdge> {
        self.edges
    }
}
/// On-disk layout of a transcoder repository.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum TranscoderSchema {
    /// Cross-layer transcoder split across `W_enc_{layer}.safetensors` /
    /// `W_dec_{layer}.safetensors` files, one pair per layer.
    CltSplit,
    /// Per-layer transcoder bundled as `layer_{layer}.safetensors`
    /// (holds `W_enc`, `b_enc`, `W_dec`, and `W_skip`).
    PltBundle,
    /// GemmaScope-style NPZ repository; loading is currently deferred
    /// (see `GEMMASCOPE_DEFERRAL_ERR`).
    GemmaScopeNpz,
}
impl TranscoderSchema {
    /// Whether a feature's decoder can write into layers after its source layer.
    #[must_use]
    pub const fn is_cross_layer(self) -> bool {
        match self {
            Self::CltSplit => true,
            Self::PltBundle | Self::GemmaScopeNpz => false,
        }
    }
    /// Whether this schema corresponds to JumpReLU-style transcoders
    /// (currently only GemmaScope).
    #[must_use]
    pub const fn is_jump_relu(self) -> bool {
        match self {
            Self::GemmaScopeNpz => true,
            Self::CltSplit | Self::PltBundle => false,
        }
    }
}
// Error message returned wherever GemmaScope NPZ loading would be required.
const GEMMASCOPE_DEFERRAL_ERR: &str = "GemmaScope loader lands in v0.1.10 — see roadmap Step 1.6";
/// Dimensions and provenance of a loaded cross-layer transcoder.
#[derive(Debug, Clone)]
pub struct CltConfig {
    /// Number of transcoder layers (encoder files found in the repo).
    pub n_layers: usize,
    /// Residual-stream width (second dimension of the encoder weight).
    pub d_model: usize,
    /// Features per layer (first dimension of the encoder weight).
    pub n_features_per_layer: usize,
    /// `n_layers * n_features_per_layer`.
    pub n_features_total: usize,
    /// Model name read from the repo's `config.yaml`, or `"unknown"`.
    pub model_name: String,
    /// Repository layout the transcoder was loaded from.
    pub schema: TranscoderSchema,
    /// Per-layer NPZ paths (GemmaScope only; empty for other schemas).
    pub gemmascope_npz_paths: Vec<String>,
}
/// Encoder weights for a single layer, held resident between `encode` calls.
struct LoadedEncoder {
    // Layer these weights belong to.
    layer: usize,
    // Encoder weight; 2D, [n_features_per_layer, d_model] (shape checked in `open`).
    w_enc: Tensor,
    // Encoder bias, added to the pre-activations.
    b_enc: Tensor,
}
/// A cross-layer transcoder whose weights are fetched lazily from a
/// Hugging Face repository, one layer file at a time.
pub struct CrossLayerTranscoder {
    // Hugging Face repo id the weights are downloaded from.
    repo_id: String,
    // Download configuration (progress reporting etc.).
    fetch_config: hf_fetch_model::FetchConfig,
    // Local path of each layer's encoder file, once downloaded.
    encoder_paths: Vec<Option<PathBuf>>,
    // Local path of each layer's decoder file (CltSplit only; other
    // schemas keep W_dec inside the encoder bundle).
    decoder_paths: Vec<Option<PathBuf>>,
    // Dimensions/schema discovered in `open`.
    config: CltConfig,
    // At most one layer's encoder is kept in memory at a time.
    loaded_encoder: Option<LoadedEncoder>,
    // (feature, target_layer) -> decoder column; populated by the
    // `cache_steering_vectors*` methods, read by `inject` and
    // `prepare_hook_injection`.
    steering_cache: HashMap<(CltFeatureId, usize), Tensor>,
}
impl CrossLayerTranscoder {
/// Opens a transcoder repo: lists its files, detects the [`TranscoderSchema`],
/// counts layers, reads `model_name` from `config.yaml` (best effort), then
/// downloads the layer-0 encoder to discover `d_model` and the number of
/// features per layer. No other weights are downloaded yet.
///
/// # Errors
/// [`MIError::Download`] on network/runtime failures; [`MIError::Config`] for
/// unrecognised layouts, empty repos, or the deferred GemmaScope schema.
#[allow(clippy::too_many_lines)]
pub fn open(clt_repo: &str) -> Result<Self> {
    // Progress events from the downloader are forwarded to tracing.
    let fetch_config = crate::download::fetch_config_builder()
        .on_progress(|event| {
            tracing::info!(
                filename = %event.filename,
                percent = event.percent,
                bytes_downloaded = event.bytes_downloaded,
                bytes_total = event.bytes_total,
                "CLT download progress",
            );
        })
        .build()
        .map_err(|e| MIError::Download(format!("failed to build fetch config: {e}")))?;
    // The repo-listing API is async; drive it with a throwaway runtime.
    let rt = tokio::runtime::Runtime::new()
        .map_err(|e| MIError::Download(format!("failed to create tokio runtime: {e}")))?;
    let hf_token = std::env::var("HF_TOKEN").ok();
    let http_client = hf_fetch_model::build_client(hf_token.as_deref())
        .map_err(|e| MIError::Download(format!("failed to build HTTP client: {e}")))?;
    let repo_files = rt
        .block_on(hf_fetch_model::repo::list_repo_files_with_metadata(
            clt_repo,
            None,
            None,
            &http_client,
        ))
        .map_err(|e| MIError::Download(format!("failed to list repo files: {e}")))?;
    let filenames: Vec<&str> = repo_files.iter().map(|f| f.filename.as_str()).collect();
    let schema = classify_transcoder_schema(&filenames).map_err(|_| {
        MIError::Config(format!(
            "unrecognised transcoder repo layout for {clt_repo}"
        ))
    })?;
    info!("Transcoder schema detected for {clt_repo}: {schema:?}");
    // GemmaScope loading is deferred (see GEMMASCOPE_DEFERRAL_ERR).
    if matches!(schema, TranscoderSchema::GemmaScopeNpz) {
        return Err(MIError::Config(GEMMASCOPE_DEFERRAL_ERR.into()));
    }
    // One encoder file per layer, so counting files gives the layer count.
    let n_layers = match schema {
        TranscoderSchema::CltSplit => repo_files
            .iter()
            .filter(|f| {
                f.filename.starts_with("W_enc_") && f.filename.ends_with(".safetensors")
            })
            .count(),
        TranscoderSchema::PltBundle => repo_files
            .iter()
            .filter(|f| {
                f.filename.starts_with("layer_") && f.filename.ends_with(".safetensors")
            })
            .count(),
        TranscoderSchema::GemmaScopeNpz => {
            return Err(MIError::Config(GEMMASCOPE_DEFERRAL_ERR.into()));
        }
    };
    if n_layers == 0 {
        return Err(MIError::Config(format!(
            "no encoder files found in {clt_repo} (schema={schema:?})"
        )));
    }
    // model_name is best-effort: a missing config.yaml is not fatal.
    let model_name = match hf_fetch_model::download_file_blocking(
        clt_repo.to_owned(),
        "config.yaml",
        &fetch_config,
    ) {
        Ok(outcome) => {
            let path = outcome.into_inner();
            let text = std::fs::read_to_string(&path)?;
            parse_yaml_value(&text, "model_name").unwrap_or_else(|| "unknown".to_owned())
        }
        Err(_) => "unknown".to_owned(),
    };
    // Download the layer-0 encoder to read the weight shape
    // [n_features_per_layer, d_model].
    let (enc0_filename, enc_tensor_name, _) = encoder_file_and_tensor_names(schema, 0, &[])?;
    let enc0_path = hf_fetch_model::download_file_blocking(
        clt_repo.to_owned(),
        &enc0_filename,
        &fetch_config,
    )
    .map_err(|e| MIError::Download(format!("failed to download {enc0_filename}: {e}")))?
    .into_inner();
    let data = std::fs::read(&enc0_path)?;
    let tensors = SafeTensors::deserialize(&data).map_err(|e| {
        MIError::Config(format!("failed to deserialize {enc_tensor_name}: {e}"))
    })?;
    let w_enc_view = tensors
        .tensor(&enc_tensor_name)
        .map_err(|e| MIError::Config(format!("tensor '{enc_tensor_name}' not found: {e}")))?;
    let shape = w_enc_view.shape();
    if shape.len() != 2 {
        return Err(MIError::Config(format!(
            "expected 2D encoder weight, got shape {shape:?}"
        )));
    }
    let n_features_per_layer = *shape
        .first()
        .ok_or_else(|| MIError::Config("encoder weight shape is empty".into()))?;
    let d_model = *shape.get(1).ok_or_else(|| {
        MIError::Config("encoder weight shape has fewer than 2 dimensions".into())
    })?;
    // Keep the already-downloaded layer-0 path so it isn't fetched twice.
    let mut encoder_paths: Vec<Option<PathBuf>> = vec![None; n_layers];
    if let Some(slot) = encoder_paths.first_mut() {
        *slot = Some(enc0_path);
    }
    let decoder_paths: Vec<Option<PathBuf>> = vec![None; n_layers];
    let config = CltConfig {
        n_layers,
        d_model,
        n_features_per_layer,
        n_features_total: n_layers * n_features_per_layer,
        model_name,
        schema,
        gemmascope_npz_paths: Vec::new(),
    };
    info!(
        "CLT config: {} layers, d_model={}, features_per_layer={}, total={}, schema={:?}",
        config.n_layers,
        config.d_model,
        config.n_features_per_layer,
        config.n_features_total,
        config.schema,
    );
    Ok(Self {
        repo_id: clt_repo.to_owned(),
        fetch_config,
        encoder_paths,
        decoder_paths,
        config,
        loaded_encoder: None,
        steering_cache: HashMap::new(),
    })
}
/// Dimensions and schema discovered when the transcoder was opened.
#[must_use]
pub const fn config(&self) -> &CltConfig {
    &self.config
}
/// Layer of the currently resident encoder, if one is loaded.
#[must_use]
pub fn loaded_encoder_layer(&self) -> Option<usize> {
    match &self.loaded_encoder {
        Some(enc) => Some(enc.layer),
        None => None,
    }
}
/// Returns the local path of `layer`'s encoder file, downloading it from the
/// repo on first use and memoising the path for later calls.
///
/// # Errors
/// [`MIError::Config`] when the schema cannot name the file;
/// [`MIError::Download`] when the download fails.
fn ensure_encoder_path(&mut self, layer: usize) -> Result<PathBuf> {
    // Fast path: already downloaded earlier.
    if let Some(path) = self
        .encoder_paths
        .get(layer)
        .and_then(std::option::Option::as_ref)
    {
        return Ok(path.clone());
    }
    let (filename, _, _) = encoder_file_and_tensor_names(
        self.config.schema,
        layer,
        &self.config.gemmascope_npz_paths,
    )?;
    // Fixed: the log and error messages previously printed the placeholder
    // "(unknown)" instead of the file actually being fetched.
    info!("Downloading {filename} from {}", self.repo_id);
    let path = hf_fetch_model::download_file_blocking(
        self.repo_id.clone(),
        &filename,
        &self.fetch_config,
    )
    .map_err(|e| MIError::Download(format!("failed to download {filename}: {e}")))?
    .into_inner();
    if let Some(slot) = self.encoder_paths.get_mut(layer) {
        *slot = Some(path.clone());
    }
    Ok(path)
}
/// Returns the local path of `layer`'s decoder file, downloading on first
/// use. Non-split schemas keep `W_dec` inside the encoder bundle, so they
/// delegate to [`Self::ensure_encoder_path`].
///
/// # Errors
/// [`MIError::Config`] when the schema cannot name the file;
/// [`MIError::Download`] when the download fails.
fn ensure_decoder_path(&mut self, layer: usize) -> Result<PathBuf> {
    if !matches!(self.config.schema, TranscoderSchema::CltSplit) {
        return self.ensure_encoder_path(layer);
    }
    if let Some(path) = self
        .decoder_paths
        .get(layer)
        .and_then(std::option::Option::as_ref)
    {
        return Ok(path.clone());
    }
    let (filename, _) = decoder_file_and_tensor_name(
        self.config.schema,
        layer,
        &self.config.gemmascope_npz_paths,
    )?;
    // Fixed: the log and error messages previously printed the placeholder
    // "(unknown)" instead of the file actually being fetched.
    info!("Downloading {filename} from {}", self.repo_id);
    let path = hf_fetch_model::download_file_blocking(
        self.repo_id.clone(),
        &filename,
        &self.fetch_config,
    )
    .map_err(|e| MIError::Download(format!("failed to download {filename}: {e}")))?
    .into_inner();
    if let Some(slot) = self.decoder_paths.get_mut(layer) {
        *slot = Some(path.clone());
    }
    Ok(path)
}
/// Loads the encoder weights (`W_enc`, `b_enc`) for `layer` onto `device`,
/// replacing any previously loaded layer. No-op when that layer is already
/// resident.
///
/// # Errors
/// [`MIError::Config`] for out-of-range layers or unreadable weight files.
pub fn load_encoder(&mut self, layer: usize, device: &Device) -> Result<()> {
    if layer >= self.config.n_layers {
        return Err(MIError::Config(format!(
            "layer {layer} out of range (CLT has {} layers)",
            self.config.n_layers
        )));
    }
    // Already resident — nothing to do.
    if let Some(ref enc) = self.loaded_encoder
        && enc.layer == layer
    {
        return Ok(());
    }
    // Drop the old encoder before reading the new one to bound memory use.
    self.loaded_encoder = None;
    info!("Loading CLT encoder for layer {layer}");
    let enc_path = self.ensure_encoder_path(layer)?;
    let data = std::fs::read(&enc_path)?;
    let st = SafeTensors::deserialize(&data).map_err(|e| {
        MIError::Config(format!("failed to deserialize encoder layer {layer}: {e}"))
    })?;
    let (_, w_enc_name, b_enc_name) = encoder_file_and_tensor_names(
        self.config.schema,
        layer,
        &self.config.gemmascope_npz_paths,
    )?;
    let w_enc = tensor_from_view(
        &st.tensor(&w_enc_name)
            .map_err(|e| MIError::Config(format!("tensor '{w_enc_name}' not found: {e}")))?,
        device,
    )?;
    let b_enc = tensor_from_view(
        &st.tensor(&b_enc_name)
            .map_err(|e| MIError::Config(format!("tensor '{b_enc_name}' not found: {e}")))?,
        device,
    )?;
    self.loaded_encoder = Some(LoadedEncoder {
        layer,
        w_enc,
        b_enc,
    });
    Ok(())
}
/// Loads the `W_skip` matrix for a PltBundle layer as an f32 tensor on
/// `device`.
///
/// # Errors
/// [`MIError::Config`] when the schema is not `PltBundle`, the layer is out
/// of range, or the tensor is missing/unreadable.
pub fn load_skip_matrix(&mut self, layer: usize, device: &Device) -> Result<Tensor> {
    if !matches!(self.config.schema, TranscoderSchema::PltBundle) {
        return Err(MIError::Config(format!(
            "load_skip_matrix: W_skip is only present in PltBundle schema \
             (current schema: {:?})",
            self.config.schema,
        )));
    }
    if layer >= self.config.n_layers {
        return Err(MIError::Config(format!(
            "layer {layer} out of range (transcoder has {} layers)",
            self.config.n_layers
        )));
    }
    info!("Loading W_skip for PltBundle layer {layer}");
    let bundle_path = self.ensure_encoder_path(layer)?;
    let bytes = std::fs::read(&bundle_path)?;
    let tensors = SafeTensors::deserialize(&bytes).map_err(|e| {
        MIError::Config(format!("failed to deserialize bundle layer {layer}: {e}"))
    })?;
    let skip_view = tensors.tensor("W_skip").map_err(|e| {
        MIError::Config(format!("tensor 'W_skip' not found in layer {layer}: {e}"))
    })?;
    // Convert to f32 so callers get a uniform dtype regardless of storage.
    Ok(tensor_from_view(&skip_view, device)?.to_dtype(DType::F32)?)
}
/// Encodes a residual vector into sparse CLT feature activations at `layer`:
/// ReLU over the pre-activations, keeping only strictly positive values,
/// sorted by descending activation.
///
/// Requires [`Self::load_encoder`] to have been called for `layer`.
pub fn encode(
    &self,
    residual: &Tensor,
    layer: usize,
) -> Result<SparseActivations<CltFeatureId>> {
    let pre_acts = self.encode_pre_activation_impl(residual, layer)?;
    let activations: Vec<f32> = pre_acts.relu()?.to_vec1()?;
    // Keep only firing features, tagged with their id.
    let mut features: Vec<(CltFeatureId, f32)> = activations
        .into_iter()
        .enumerate()
        .filter_map(|(index, value)| {
            (value > 0.0).then_some((CltFeatureId { layer, index }, value))
        })
        .collect();
    // Strongest activations first; incomparable (NaN) pairs left in place.
    features.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    Ok(SparseActivations { features })
}
/// Pre-ReLU encoder activations (`W_enc @ residual + b_enc`) for `layer`.
///
/// Requires [`Self::load_encoder`] to have been called for `layer`.
pub fn encode_pre_activation(&self, residual: &Tensor, layer: usize) -> Result<Tensor> {
    self.encode_pre_activation_impl(residual, layer)
}
/// Shared implementation for `encode` / `encode_pre_activation`: flattens
/// `residual` to 1D, upcasts everything to f32, and computes
/// `W_enc @ residual + b_enc`.
fn encode_pre_activation_impl(&self, residual: &Tensor, layer: usize) -> Result<Tensor> {
    let enc = self.loaded_encoder.as_ref().ok_or_else(|| {
        MIError::Hook(format!(
            "no encoder loaded — call load_encoder({layer}) first"
        ))
    })?;
    if enc.layer != layer {
        return Err(MIError::Hook(format!(
            "loaded encoder is for layer {}, but layer {layer} was requested",
            enc.layer
        )));
    }
    // Flatten so `residual` may arrive with leading singleton dims; assumes
    // it holds exactly d_model elements — the shape is not re-checked here.
    let residual_f32 = residual.flatten_all()?;
    let residual_f32 = residual_f32.to_dtype(DType::F32)?;
    let w_enc_f32 = enc.w_enc.to_dtype(DType::F32)?;
    let b_enc_f32 = enc.b_enc.to_dtype(DType::F32)?;
    // [n_features, d_model] @ [d_model, 1] -> [n_features]
    let pre_acts = w_enc_f32.matmul(&residual_f32.unsqueeze(1)?)?.squeeze(1)?;
    let pre_acts = (&pre_acts + &b_enc_f32)?;
    Ok(pre_acts)
}
/// Sparse activations for `residual` at `layer`, truncated to the `k`
/// strongest features.
pub fn top_k(
    &self,
    residual: &Tensor,
    layer: usize,
    k: usize,
) -> Result<SparseActivations<CltFeatureId>> {
    let mut sparse_activations = self.encode(residual, layer)?;
    sparse_activations.truncate(k);
    Ok(sparse_activations)
}
/// Returns the decoder direction for `feature` writing into `target_layer`,
/// on `device`. Served from the steering cache when present; on a miss the
/// source layer's decoder file is read from disk (downloading if needed).
///
/// NOTE(review): a cache miss does NOT populate `steering_cache`; only the
/// `cache_steering_vectors*` methods write to it, so repeated misses re-read
/// the decoder file each call — confirm this is intentional.
///
/// # Errors
/// [`MIError::Config`] when the source layer, target layer, or feature index
/// is out of range, or the decoder tensor cannot be read.
pub fn decoder_vector(
    &mut self,
    feature: &CltFeatureId,
    target_layer: usize,
    device: &Device,
) -> Result<Tensor> {
    if feature.layer >= self.config.n_layers {
        return Err(MIError::Config(format!(
            "feature source layer {} out of range (CLT has {} layers)",
            feature.layer, self.config.n_layers
        )));
    }
    // Decoders only write into the source layer or later ones.
    if target_layer < feature.layer || target_layer >= self.config.n_layers {
        return Err(MIError::Config(format!(
            "target layer {target_layer} must be >= source layer {} and < {}",
            feature.layer, self.config.n_layers
        )));
    }
    if feature.index >= self.config.n_features_per_layer {
        return Err(MIError::Config(format!(
            "feature index {} out of range (max {})",
            feature.index, self.config.n_features_per_layer
        )));
    }
    let cache_key = (*feature, target_layer);
    if let Some(cached) = self.steering_cache.get(&cache_key) {
        return Ok(cached.clone());
    }
    let target_offset = target_layer - feature.layer;
    let dec_path = self.ensure_decoder_path(feature.layer)?;
    let data = std::fs::read(&dec_path)?;
    let st = SafeTensors::deserialize(&data).map_err(|e| {
        MIError::Config(format!(
            "failed to deserialize decoder layer {}: {e}",
            feature.layer
        ))
    })?;
    let (_, dec_tensor_name) = decoder_file_and_tensor_name(
        self.config.schema,
        feature.layer,
        &self.config.gemmascope_npz_paths,
    )?;
    // Materialise on CPU first; move to the requested device at the end.
    let w_dec = tensor_from_view(
        &st.tensor(&dec_tensor_name).map_err(|e| {
            MIError::Config(format!("tensor '{dec_tensor_name}' not found: {e}"))
        })?,
        &Device::Cpu,
    )?;
    let column = decoder_row(&w_dec, feature.index, target_offset, self.config.schema)?;
    let column = column.to_device(device)?;
    Ok(column)
}
/// Bulk-loads decoder columns for the given `(feature, target_layer)` pairs
/// into the steering cache on `device`.
///
/// Pairs are grouped by source layer so each decoder file is read from disk
/// only once; already-cached pairs are skipped.
///
/// # Errors
/// [`MIError::Config`] when a source layer is out of range, when a target
/// layer precedes its source layer (previously this underflowed the `usize`
/// offset subtraction and panicked instead of erroring), or when a decoder
/// file cannot be read.
pub fn cache_steering_vectors(
    &mut self,
    features: &[(CltFeatureId, usize)],
    device: &Device,
) -> Result<()> {
    // Group requested (index, target) pairs by source layer, validating up
    // front so the offset arithmetic below cannot underflow.
    let mut by_source: HashMap<usize, Vec<(usize, usize)>> = HashMap::new();
    for (fid, target_layer) in features {
        if fid.layer >= self.config.n_layers {
            return Err(MIError::Config(format!(
                "feature source layer {} out of range (max {})",
                fid.layer,
                self.config.n_layers - 1
            )));
        }
        if *target_layer < fid.layer {
            return Err(MIError::Config(format!(
                "target layer {target_layer} must be >= source layer {}",
                fid.layer
            )));
        }
        by_source
            .entry(fid.layer)
            .or_default()
            .push((fid.index, *target_layer));
    }
    let mut loaded = 0_usize;
    let n_source_layers = by_source.len();
    for (layer_idx, (source_layer, entries)) in by_source.iter().enumerate() {
        info!(
            "cache_steering_vectors: loading decoder for source layer {} ({}/{})",
            source_layer,
            layer_idx + 1,
            n_source_layers
        );
        let mut by_target: HashMap<usize, Vec<usize>> = HashMap::new();
        for &(index, target_layer) in entries {
            by_target.entry(target_layer).or_default().push(index);
        }
        // Extract columns on CPU inside a scope so the (large) decoder
        // tensor is dropped before anything is moved to `device`.
        let mut cpu_columns: Vec<(CltFeatureId, usize, Tensor)> = Vec::new();
        {
            let dec_path = self.ensure_decoder_path(*source_layer)?;
            let data = std::fs::read(&dec_path)?;
            info!(
                "cache_steering_vectors: loaded {} MB for layer {}",
                data.len() / (1024 * 1024),
                source_layer
            );
            let st = SafeTensors::deserialize(&data).map_err(|e| {
                MIError::Config(format!(
                    "failed to deserialize decoder layer {source_layer}: {e}"
                ))
            })?;
            let (_, dec_tensor_name) = decoder_file_and_tensor_name(
                self.config.schema,
                *source_layer,
                &self.config.gemmascope_npz_paths,
            )?;
            let w_dec = tensor_from_view(
                &st.tensor(&dec_tensor_name).map_err(|e| {
                    MIError::Config(format!("tensor '{dec_tensor_name}' not found: {e}"))
                })?,
                &Device::Cpu,
            )?;
            for (target_layer, indices) in &by_target {
                // Cannot underflow: validated above that target >= source.
                let target_offset = target_layer - source_layer;
                for &index in indices {
                    let fid = CltFeatureId {
                        layer: *source_layer,
                        index,
                    };
                    let cache_key = (fid, *target_layer);
                    if !self.steering_cache.contains_key(&cache_key) {
                        let view =
                            decoder_row(&w_dec, index, target_offset, self.config.schema)?;
                        // Rebuild from raw values so the cached tensor is
                        // independent of `w_dec`'s storage.
                        let dims = view.dims().to_vec();
                        let values = view.to_dtype(DType::F32)?.to_vec1::<f32>()?;
                        let independent =
                            Tensor::from_vec(values, dims.as_slice(), &Device::Cpu)?;
                        cpu_columns.push((fid, *target_layer, independent));
                    }
                }
            }
        }
        // Move the new columns to `device`, skipping keys cached meanwhile.
        for (fid, target_layer, cpu_tensor) in cpu_columns {
            let cache_key = (fid, target_layer);
            if let std::collections::hash_map::Entry::Vacant(e) =
                self.steering_cache.entry(cache_key)
            {
                let device_tensor = cpu_tensor.to_device(device)?;
                e.insert(device_tensor);
                loaded += 1;
            }
        }
    }
    info!(
        "Cached {loaded} new steering vectors ({} total in cache)",
        self.steering_cache.len()
    );
    Ok(())
}
/// Caches decoder columns for each feature into *every* layer it can write
/// to: all layers from the source layer onward for the cross-layer
/// (CltSplit) schema, or just the source layer itself otherwise.
///
/// # Errors
/// [`MIError::Config`] when a feature's source layer is out of range or a
/// decoder file cannot be read.
pub fn cache_steering_vectors_all_downstream(
    &mut self,
    features: &[CltFeatureId],
    device: &Device,
) -> Result<()> {
    let n_layers = self.config.n_layers;
    // Group feature indices by source layer, validating ranges up front.
    let mut by_source: HashMap<usize, Vec<usize>> = HashMap::new();
    for fid in features {
        if fid.layer >= n_layers {
            return Err(MIError::Config(format!(
                "feature source layer {} out of range (max {})",
                fid.layer,
                n_layers - 1
            )));
        }
        by_source.entry(fid.layer).or_default().push(fid.index);
    }
    let mut loaded = 0_usize;
    let n_source_layers = by_source.len();
    for (layer_idx, (source_layer, indices)) in by_source.iter().enumerate() {
        // Cross-layer decoders span [source_layer, n_layers); per-layer
        // decoders only write into their own layer.
        let n_target_layers = if matches!(self.config.schema, TranscoderSchema::CltSplit) {
            n_layers - source_layer
        } else {
            1
        };
        info!(
            "cache_steering_vectors_all_downstream: loading decoder for source layer {} \
             ({}/{}, {} downstream layers)",
            source_layer,
            layer_idx + 1,
            n_source_layers,
            n_target_layers
        );
        // Extract columns on CPU inside a scope so the (large) decoder
        // tensor is dropped before anything is moved to `device`.
        let mut cpu_columns: Vec<(CltFeatureId, usize, Tensor)> = Vec::new();
        {
            let dec_path = self.ensure_decoder_path(*source_layer)?;
            let data = std::fs::read(&dec_path)?;
            info!(
                "cache_steering_vectors_all_downstream: loaded {} MB for layer {}",
                data.len() / (1024 * 1024),
                source_layer
            );
            let st = SafeTensors::deserialize(&data).map_err(|e| {
                MIError::Config(format!(
                    "failed to deserialize decoder layer {source_layer}: {e}"
                ))
            })?;
            let (_, dec_tensor_name) = decoder_file_and_tensor_name(
                self.config.schema,
                *source_layer,
                &self.config.gemmascope_npz_paths,
            )?;
            let w_dec = tensor_from_view(
                &st.tensor(&dec_tensor_name).map_err(|e| {
                    MIError::Config(format!("tensor '{dec_tensor_name}' not found: {e}"))
                })?,
                &Device::Cpu,
            )?;
            for &index in indices {
                let fid = CltFeatureId {
                    layer: *source_layer,
                    index,
                };
                for target_offset in 0..n_target_layers {
                    let target_layer = source_layer + target_offset;
                    let cache_key = (fid, target_layer);
                    if !self.steering_cache.contains_key(&cache_key) {
                        let view =
                            decoder_row(&w_dec, index, target_offset, self.config.schema)?;
                        // Rebuild from raw values so the cached tensor is
                        // independent of `w_dec`'s storage.
                        let dims = view.dims().to_vec();
                        let values = view.to_dtype(DType::F32)?.to_vec1::<f32>()?;
                        let independent =
                            Tensor::from_vec(values, dims.as_slice(), &Device::Cpu)?;
                        cpu_columns.push((fid, target_layer, independent));
                    }
                }
            }
        }
        // Move the new columns to `device`, skipping keys cached meanwhile.
        for (fid, target_layer, cpu_tensor) in cpu_columns {
            let cache_key = (fid, target_layer);
            if let std::collections::hash_map::Entry::Vacant(e) =
                self.steering_cache.entry(cache_key)
            {
                let device_tensor = cpu_tensor.to_device(device)?;
                e.insert(device_tensor);
                loaded += 1;
            }
        }
    }
    info!(
        "Cached {loaded} new steering vectors across all downstream layers ({} total in cache)",
        self.steering_cache.len()
    );
    Ok(())
}
/// Empties the steering-vector cache, logging how many entries were dropped.
pub fn clear_steering_cache(&mut self) {
    let count = self.steering_cache.len();
    self.steering_cache.clear();
    if count == 0 {
        return;
    }
    info!("Cleared {count} steering vectors from cache");
}
/// Number of (feature, target-layer) decoder columns currently cached.
#[must_use]
pub fn steering_cache_len(&self) -> usize {
    self.steering_cache.len()
}
/// Builds a `HookSpec` that adds the cached steering vectors for `features`
/// (scaled by `strength`) to the residual stream at token `position`, with
/// one `ResidPost` intervention per distinct target layer.
///
/// All `(feature, target_layer)` pairs must already be in the steering cache
/// (see `cache_steering_vectors`).
///
/// # Errors
/// [`MIError::Hook`] when a requested vector is not cached.
pub fn prepare_hook_injection(
    &self,
    features: &[(CltFeatureId, usize)],
    position: usize,
    seq_len: usize,
    strength: f32,
    device: &Device,
) -> Result<crate::hooks::HookSpec> {
    use crate::hooks::{HookPoint, HookSpec, Intervention};
    // Sum the steering vectors element-wise, grouped by target layer.
    let mut per_layer: HashMap<usize, Tensor> = HashMap::new();
    for (feature, target_layer) in features {
        let cache_key = (*feature, *target_layer);
        let cached = self.steering_cache.get(&cache_key).ok_or_else(|| {
            MIError::Hook(format!(
                "feature {feature} for target layer {target_layer} not in steering cache \
                 — call cache_steering_vectors() first"
            ))
        })?;
        let cached_f32 = cached.to_dtype(DType::F32)?;
        match per_layer.get_mut(target_layer) {
            Some(acc) => {
                let summed = (&*acc + &cached_f32)?;
                *acc = summed;
            }
            None => {
                per_layer.insert(*target_layer, cached_f32);
            }
        }
    }
    let mut hooks = HookSpec::new();
    let d_model = self.config.d_model;
    for (target_layer, accumulated) in &per_layer {
        let scaled = (accumulated * f64::from(strength))?;
        // Assemble a (1, seq_len, d_model) tensor that is zero everywhere
        // except the steering vector at `position`.
        let zeros = Tensor::zeros((1, seq_len, d_model), DType::F32, device)?;
        let scaled_3d = scaled.unsqueeze(0)?.unsqueeze(0)?;
        let mut parts: Vec<Tensor> = Vec::with_capacity(3);
        if position > 0 {
            parts.push(zeros.narrow(1, 0, position)?);
        }
        parts.push(scaled_3d);
        if position + 1 < seq_len {
            parts.push(zeros.narrow(1, position + 1, seq_len - position - 1)?);
        }
        let injection = Tensor::cat(&parts, 1)?;
        hooks.intervene(
            HookPoint::ResidPost(*target_layer),
            Intervention::Add(injection),
        );
    }
    Ok(hooks)
}
/// Adds the cached steering vectors for `features` (scaled by `strength`) to
/// `residual` at token `position`, returning the modified tensor.
/// `residual` must be (batch, seq_len, d_model).
///
/// # Errors
/// [`MIError::Config`] for an out-of-range position or mismatched width;
/// [`MIError::Hook`] when a vector is missing from the steering cache.
pub fn inject(
    &self,
    residual: &Tensor,
    features: &[(CltFeatureId, usize)],
    position: usize,
    strength: f32,
) -> Result<Tensor> {
    let (batch, seq_len, d_model) = residual.dims3()?;
    if position >= seq_len {
        return Err(MIError::Config(format!(
            "injection position {position} out of range (seq_len={seq_len})"
        )));
    }
    if d_model != self.config.d_model {
        return Err(MIError::Config(format!(
            "residual d_model={d_model} doesn't match CLT d_model={}",
            self.config.d_model
        )));
    }
    // Sum the requested steering vectors in f32 on the residual's device.
    let mut accumulated = Tensor::zeros((d_model,), DType::F32, residual.device())?;
    for (feature, target_layer) in features {
        let cache_key = (*feature, *target_layer);
        let cached = self.steering_cache.get(&cache_key).ok_or_else(|| {
            MIError::Hook(format!(
                "feature {feature} for target layer {target_layer} not in steering cache"
            ))
        })?;
        let cached_f32 = cached.to_dtype(DType::F32)?;
        accumulated = (&accumulated + &cached_f32)?;
    }
    let accumulated = (accumulated * f64::from(strength))?;
    let accumulated = accumulated.to_dtype(residual.dtype())?;
    // Add the steering vector to the single target position, then stitch the
    // sequence back together around it.
    let pos_slice = residual.narrow(1, position, 1)?;
    let steering_expanded = accumulated
        .unsqueeze(0)?
        .unsqueeze(0)?
        .expand((batch, 1, d_model))?;
    let pos_updated = (&pos_slice + &steering_expanded)?;
    let mut parts: Vec<Tensor> = Vec::with_capacity(3);
    if position > 0 {
        parts.push(residual.narrow(1, 0, position)?);
    }
    parts.push(pos_updated);
    if position + 1 < seq_len {
        parts.push(residual.narrow(1, position + 1, seq_len - position - 1)?);
    }
    Ok(Tensor::cat(&parts, 1)?)
}
/// Scores every feature by the projection of its decoder direction (into
/// `target_layer`) onto `direction`, returning the `top_k` highest-scoring
/// features across all eligible source layers.
///
/// With `cosine = true` both the direction and the decoder rows are
/// L2-normalised, giving cosine similarity instead of a raw dot product.
/// Non-finite scores are discarded.
///
/// # Errors
/// [`MIError::Config`] for a mis-shaped direction, an out-of-range target
/// layer, or unreadable decoder files.
pub fn score_features_by_decoder_projection(
    &mut self,
    direction: &Tensor,
    target_layer: usize,
    top_k: usize,
    cosine: bool,
) -> Result<Vec<(CltFeatureId, f32)>> {
    let d_model = self.config.d_model;
    if direction.dims() != [d_model] {
        return Err(MIError::Config(format!(
            "direction must have shape [{d_model}], got {:?}",
            direction.dims()
        )));
    }
    if target_layer >= self.config.n_layers {
        return Err(MIError::Config(format!(
            "target layer {target_layer} out of range (max {})",
            self.config.n_layers - 1
        )));
    }
    // Work on CPU in f32, matching where the decoder weights are loaded.
    let direction_f32 = direction.to_dtype(DType::F32)?.to_device(&Device::Cpu)?;
    // For cosine similarity, pre-normalise the probe direction, guarding
    // against division by a near-zero norm.
    let direction_norm = if cosine {
        let norm: f32 = direction_f32.sqr()?.sum_all()?.sqrt()?.to_scalar()?;
        if norm > 1e-10 {
            direction_f32.broadcast_div(&Tensor::new(norm, &Device::Cpu)?)?
        } else {
            direction_f32
        }
    } else {
        direction_f32
    };
    let mut all_scores: Vec<(CltFeatureId, f32)> = Vec::new();
    for source_layer in 0..self.config.n_layers {
        // Decoders never write backwards...
        if target_layer < source_layer {
            continue;
        }
        // ...and per-layer schemas only write into their own layer.
        if !self.config.schema.is_cross_layer() && source_layer != target_layer {
            continue;
        }
        let target_offset = target_layer - source_layer;
        let dec_path = self.ensure_decoder_path(source_layer)?;
        let data = std::fs::read(&dec_path)?;
        info!(
            "score_features_by_decoder_projection: loaded {} MB for layer {}",
            data.len() / (1024 * 1024),
            source_layer
        );
        let st = SafeTensors::deserialize(&data).map_err(|e| {
            MIError::Config(format!(
                "failed to deserialize decoder layer {source_layer}: {e}"
            ))
        })?;
        let (_, dec_tensor_name) = decoder_file_and_tensor_name(
            self.config.schema,
            source_layer,
            &self.config.gemmascope_npz_paths,
        )?;
        let w_dec = tensor_from_view(
            &st.tensor(&dec_tensor_name).map_err(|e| {
                MIError::Config(format!("tensor '{dec_tensor_name}' not found: {e}"))
            })?,
            &Device::Cpu,
        )?;
        let w_dec_f32 = w_dec.to_dtype(DType::F32)?;
        let dec_slice = decoder_layer_slice(&w_dec_f32, target_offset, self.config.schema)?;
        // One dot product per feature row of the layer slice.
        let raw_scores = dec_slice
            .matmul(&direction_norm.unsqueeze(1)?)?
            .squeeze(1)?;
        let scores_vec: Vec<f32> = if cosine {
            // Divide by each decoder row's norm to complete the cosine.
            let dec_norms = dec_slice.sqr()?.sum(1)?.sqrt()?;
            let cosine_scores = raw_scores.broadcast_div(&dec_norms)?;
            cosine_scores.to_vec1()?
        } else {
            raw_scores.to_vec1()?
        };
        for (idx, &score) in scores_vec.iter().enumerate() {
            if score.is_finite() {
                all_scores.push((
                    CltFeatureId {
                        layer: source_layer,
                        index: idx,
                    },
                    score,
                ));
            }
        }
        info!(
            "Scored {} features at source layer {source_layer} (target layer {target_layer})",
            scores_vec.len()
        );
    }
    // Highest score first, then keep the global top_k.
    all_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    all_scores.truncate(top_k);
    Ok(all_scores)
}
/// Batched variant of `score_features_by_decoder_projection`: scores every
/// feature against each direction in `directions` with a single matmul per
/// source layer, returning one `top_k` list per direction.
///
/// # Errors
/// [`MIError::Config`] for an empty or mis-shaped direction list, an
/// out-of-range target layer, or unreadable decoder files.
#[allow(clippy::too_many_lines)]
pub fn score_features_by_decoder_projection_batch(
    &mut self,
    directions: &[Tensor],
    target_layer: usize,
    top_k: usize,
    cosine: bool,
) -> Result<Vec<Vec<(CltFeatureId, f32)>>> {
    let d_model = self.config.d_model;
    let n_words = directions.len();
    if n_words == 0 {
        return Err(MIError::Config(
            "at least one direction vector required".into(),
        ));
    }
    for (i, dir) in directions.iter().enumerate() {
        if dir.dims() != [d_model] {
            return Err(MIError::Config(format!(
                "direction vector {i} must have shape [{d_model}], got {:?}",
                dir.dims()
            )));
        }
    }
    if target_layer >= self.config.n_layers {
        return Err(MIError::Config(format!(
            "target layer {target_layer} out of range (max {})",
            self.config.n_layers - 1
        )));
    }
    let dirs_f32: Vec<Tensor> = directions
        .iter()
        .map(|d| d.to_dtype(DType::F32)?.to_device(&Device::Cpu))
        .collect::<std::result::Result<_, _>>()?;
    // Stack to [n_words, d_model]; optionally L2-normalise each row,
    // clamping norms away from zero to avoid division blow-ups.
    let stacked = Tensor::stack(&dirs_f32, 0)?;
    let stacked_norm = if cosine {
        let norms = stacked.sqr()?.sum(1)?.sqrt()?;
        let ones = Tensor::ones_like(&norms)?;
        let safe_norms = norms.maximum(&(&ones * 1e-10f64)?)?;
        stacked.broadcast_div(&safe_norms.unsqueeze(1)?)?
    } else {
        stacked
    };
    // Transposed once so each layer's scoring is a single matmul.
    let directions_t = stacked_norm.t()?;
    let mut all_scores: Vec<Vec<(CltFeatureId, f32)>> =
        (0..n_words).map(|_| Vec::new()).collect();
    for source_layer in 0..self.config.n_layers {
        // Decoders never write backwards; per-layer schemas only write into
        // their own layer.
        if target_layer < source_layer {
            continue;
        }
        if !self.config.schema.is_cross_layer() && source_layer != target_layer {
            continue;
        }
        let target_offset = target_layer - source_layer;
        let dec_path = self.ensure_decoder_path(source_layer)?;
        let data = std::fs::read(&dec_path)?;
        info!(
            "score_features_batch: loaded {} MB for layer {}",
            data.len() / (1024 * 1024),
            source_layer
        );
        let st = SafeTensors::deserialize(&data).map_err(|e| {
            MIError::Config(format!(
                "failed to deserialize decoder layer {source_layer}: {e}"
            ))
        })?;
        let (_, dec_tensor_name) = decoder_file_and_tensor_name(
            self.config.schema,
            source_layer,
            &self.config.gemmascope_npz_paths,
        )?;
        let w_dec = tensor_from_view(
            &st.tensor(&dec_tensor_name).map_err(|e| {
                MIError::Config(format!("tensor '{dec_tensor_name}' not found: {e}"))
            })?,
            &Device::Cpu,
        )?;
        let w_dec_f32 = w_dec.to_dtype(DType::F32)?;
        let dec_slice = decoder_layer_slice(&w_dec_f32, target_offset, self.config.schema)?;
        // One matmul scores every (feature, word) pair for this layer.
        let raw_scores = dec_slice.matmul(&directions_t)?;
        let scores_2d: Vec<Vec<f32>> = if cosine {
            let dec_norms = dec_slice.sqr()?.sum(1)?.sqrt()?;
            let cosine_scores = raw_scores.broadcast_div(&dec_norms.unsqueeze(1)?)?;
            cosine_scores.t()?.to_vec2()?
        } else {
            raw_scores.t()?.to_vec2()?
        };
        // Transposed back, so scores_2d is indexed [word][feature].
        for (w, word_scores) in scores_2d.iter().enumerate() {
            for (idx, &score) in word_scores.iter().enumerate() {
                if score.is_finite()
                    && let Some(word_vec) = all_scores.get_mut(w)
                {
                    word_vec.push((
                        CltFeatureId {
                            layer: source_layer,
                            index: idx,
                        },
                        score,
                    ));
                }
            }
        }
        info!(
            "Batch scored {} words × {} features at source layer {} (target layer {})",
            n_words,
            scores_2d.first().map_or(0, Vec::len),
            source_layer,
            target_layer
        );
    }
    // Per direction: highest score first, then keep top_k.
    for word_scores in &mut all_scores {
        word_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        word_scores.truncate(top_k);
    }
    Ok(all_scores)
}
/// Returns an independent CPU copy of the decoder direction (into
/// `target_layer`) for each requested feature, grouped so each source
/// layer's decoder file is read only once.
///
/// # Errors
/// [`MIError::Config`] when the target layer or any feature's source layer
/// is out of range (the target must not precede the source), or a decoder
/// file cannot be read.
pub fn extract_decoder_vectors(
    &mut self,
    features: &[CltFeatureId],
    target_layer: usize,
) -> Result<HashMap<CltFeatureId, Tensor>> {
    if target_layer >= self.config.n_layers {
        return Err(MIError::Config(format!(
            "target layer {target_layer} out of range (max {})",
            self.config.n_layers - 1
        )));
    }
    // Group feature indices by source layer, validating ranges up front so
    // the offset subtraction below cannot underflow.
    let mut by_source: HashMap<usize, Vec<usize>> = HashMap::new();
    for fid in features {
        if fid.layer >= self.config.n_layers {
            return Err(MIError::Config(format!(
                "feature source layer {} out of range (max {})",
                fid.layer,
                self.config.n_layers - 1
            )));
        }
        if target_layer < fid.layer {
            return Err(MIError::Config(format!(
                "target layer {target_layer} must be >= source layer {}",
                fid.layer
            )));
        }
        by_source.entry(fid.layer).or_default().push(fid.index);
    }
    let mut result: HashMap<CltFeatureId, Tensor> = HashMap::new();
    let n_source_layers = by_source.len();
    for (layer_idx, (source_layer, indices)) in by_source.iter().enumerate() {
        info!(
            "extract_decoder_vectors: loading decoder for source layer {} ({}/{})",
            source_layer,
            layer_idx + 1,
            n_source_layers
        );
        let target_offset = target_layer - source_layer;
        let dec_path = self.ensure_decoder_path(*source_layer)?;
        let data = std::fs::read(&dec_path)?;
        let st = SafeTensors::deserialize(&data).map_err(|e| {
            MIError::Config(format!(
                "failed to deserialize decoder layer {source_layer}: {e}"
            ))
        })?;
        let (_, dec_tensor_name) = decoder_file_and_tensor_name(
            self.config.schema,
            *source_layer,
            &self.config.gemmascope_npz_paths,
        )?;
        let w_dec = tensor_from_view(
            &st.tensor(&dec_tensor_name).map_err(|e| {
                MIError::Config(format!("tensor '{dec_tensor_name}' not found: {e}"))
            })?,
            &Device::Cpu,
        )?;
        for &index in indices {
            let fid = CltFeatureId {
                layer: *source_layer,
                index,
            };
            // Duplicate feature ids are extracted only once.
            if let std::collections::hash_map::Entry::Vacant(e) = result.entry(fid) {
                let view = decoder_row(&w_dec, index, target_offset, self.config.schema)?;
                // Rebuild from raw values so the result does not alias
                // `w_dec`'s storage.
                let dims = view.dims().to_vec();
                let values = view.to_dtype(DType::F32)?.to_vec1::<f32>()?;
                let independent = Tensor::from_vec(values, dims.as_slice(), &Device::Cpu)?;
                e.insert(independent);
            }
        }
    }
    info!(
        "Extracted {} decoder vectors across {} source layers",
        result.len(),
        n_source_layers
    );
    Ok(result)
}
/// Scores features against `direction` and wraps the result as an
/// [`AttributionGraph`] (edges already sorted by descending score).
///
/// # Errors
/// Propagates errors from `score_features_by_decoder_projection`.
pub fn build_attribution_graph(
    &mut self,
    direction: &Tensor,
    target_layer: usize,
    top_k: usize,
    cosine: bool,
) -> Result<AttributionGraph> {
    let scored =
        self.score_features_by_decoder_projection(direction, target_layer, top_k, cosine)?;
    let mut edges = Vec::with_capacity(scored.len());
    for (feature, score) in scored {
        edges.push(AttributionEdge { feature, score });
    }
    Ok(AttributionGraph {
        target_layer,
        edges,
    })
}
/// Batched variant of [`Self::build_attribution_graph`]: one graph per
/// direction, all scored against the same `target_layer`.
///
/// # Errors
/// Propagates errors from `score_features_by_decoder_projection_batch`.
pub fn build_attribution_graph_batch(
    &mut self,
    directions: &[Tensor],
    target_layer: usize,
    top_k: usize,
    cosine: bool,
) -> Result<Vec<AttributionGraph>> {
    let batch = self.score_features_by_decoder_projection_batch(
        directions,
        target_layer,
        top_k,
        cosine,
    )?;
    let mut graphs = Vec::with_capacity(batch.len());
    for scored in batch {
        let mut edges = Vec::with_capacity(scored.len());
        for (feature, score) in scored {
            edges.push(AttributionEdge { feature, score });
        }
        graphs.push(AttributionGraph {
            target_layer,
            edges,
        });
    }
    Ok(graphs)
}
}
/// Infers the repository layout from its file listing.
///
/// Precedence: CltSplit (`W_enc_*.safetensors`) wins over PltBundle
/// (`layer_*.safetensors`), which wins over the two recognised GemmaScope
/// layouts (a direct NPZ tree, or a metadata repo with `config.yaml` plus
/// `features/layer_*` `.bin` files).
///
/// # Errors
/// [`MIError::Config`] when no known layout matches.
fn classify_transcoder_schema(filenames: &[&str]) -> Result<TranscoderSchema> {
    if filenames
        .iter()
        .any(|f| f.starts_with("W_enc_") && f.ends_with(".safetensors"))
    {
        return Ok(TranscoderSchema::CltSplit);
    }
    if filenames
        .iter()
        .any(|f| f.starts_with("layer_") && f.ends_with(".safetensors"))
    {
        return Ok(TranscoderSchema::PltBundle);
    }
    // GemmaScope layout A: params.npz files laid out by layer/width/l0.
    let has_npz_tree = filenames.iter().any(|f| {
        f.starts_with("layer_")
            && f.contains("/width_")
            && f.contains("/average_l0_")
            && f.ends_with("/params.npz")
    });
    // GemmaScope layout B: metadata repo with config.yaml + per-layer .bin
    // files (the split/bundle checks above have already been ruled out).
    let has_bin_metadata = filenames.iter().any(|f| {
        f.starts_with("features/layer_")
            && std::path::Path::new(f)
                .extension()
                .is_some_and(|ext| ext.eq_ignore_ascii_case("bin"))
    });
    if has_npz_tree || (filenames.contains(&"config.yaml") && has_bin_metadata) {
        return Ok(TranscoderSchema::GemmaScopeNpz);
    }
    Err(MIError::Config(
        "unrecognised transcoder repo layout".into(),
    ))
}
/// Maps (schema, layer) to the encoder's file name plus the `W_enc`/`b_enc`
/// tensor names inside that file.
///
/// For GemmaScope the file comes from the pre-resolved NPZ path list; a
/// missing entry there is reported as a configuration error.
fn encoder_file_and_tensor_names(
    schema: TranscoderSchema,
    layer: usize,
    gemmascope_npz_paths: &[String],
) -> Result<(String, String, String)> {
    let names = match schema {
        // CLT split: per-layer file with layer-suffixed tensor names.
        TranscoderSchema::CltSplit => (
            format!("W_enc_{layer}.safetensors"),
            format!("W_enc_{layer}"),
            format!("b_enc_{layer}"),
        ),
        // PLT bundle: fixed tensor names inside a per-layer bundle.
        TranscoderSchema::PltBundle => (
            format!("layer_{layer}.safetensors"),
            "W_enc".to_owned(),
            "b_enc".to_owned(),
        ),
        TranscoderSchema::GemmaScopeNpz => {
            let Some(path) = gemmascope_npz_paths.get(layer) else {
                return Err(MIError::Config(format!(
                    "GemmaScope NPZ path for layer {layer} missing (gemmascope_npz_paths has {} entries)",
                    gemmascope_npz_paths.len()
                )));
            };
            (path.clone(), "W_enc".to_owned(), "b_enc".to_owned())
        }
    };
    Ok(names)
}
/// Maps (schema, layer) to the decoder's file name and `W_dec` tensor name.
///
/// Mirrors [`encoder_file_and_tensor_names`]; for GemmaScope the file comes
/// from the pre-resolved NPZ path list.
fn decoder_file_and_tensor_name(
    schema: TranscoderSchema,
    layer: usize,
    gemmascope_npz_paths: &[String],
) -> Result<(String, String)> {
    let pair = match schema {
        TranscoderSchema::CltSplit => (
            format!("W_dec_{layer}.safetensors"),
            format!("W_dec_{layer}"),
        ),
        TranscoderSchema::PltBundle => {
            (format!("layer_{layer}.safetensors"), "W_dec".to_owned())
        }
        TranscoderSchema::GemmaScopeNpz => {
            let Some(path) = gemmascope_npz_paths.get(layer) else {
                return Err(MIError::Config(format!(
                    "GemmaScope NPZ path for layer {layer} missing (gemmascope_npz_paths has {} entries)",
                    gemmascope_npz_paths.len()
                )));
            };
            (path.clone(), "W_dec".to_owned())
        }
    };
    Ok(pair)
}
/// Extracts one feature's decoder vector for the given target offset.
///
/// CLT split decoders index `(feature, target_layer)` on the stored tensor;
/// per-layer schemas hold a single row per feature and reject any offset
/// other than 0.
fn decoder_row(
    w_dec: &Tensor,
    feature_index: usize,
    target_offset: usize,
    schema: TranscoderSchema,
) -> Result<Tensor> {
    if schema.is_cross_layer() {
        return Ok(w_dec.i((feature_index, target_offset))?);
    }
    if target_offset == 0 {
        Ok(w_dec.i(feature_index)?)
    } else {
        Err(MIError::Config(format!(
            "per-layer schema {schema:?} only writes to its own layer \
             (target_offset must be 0, got {target_offset})"
        )))
    }
}
/// Returns the decoder slab that writes into the target layer at
/// `target_offset` relative to the source layer.
///
/// CLT split decoders are sliced on their middle (target-layer) axis;
/// per-layer schemas only ever write at offset 0, where the whole stored
/// matrix is returned unchanged.
fn decoder_layer_slice(
    w_dec: &Tensor,
    target_offset: usize,
    schema: TranscoderSchema,
) -> Result<Tensor> {
    if schema.is_cross_layer() {
        return Ok(w_dec.i((.., target_offset, ..))?);
    }
    if target_offset == 0 {
        Ok(w_dec.clone())
    } else {
        Err(MIError::Config(format!(
            "per-layer schema {schema:?} only writes to its own layer \
             (target_offset must be 0, got {target_offset})"
        )))
    }
}
/// Converts a safetensors view into a candle tensor on `device`.
///
/// Only the float dtypes used by these checkpoints (BF16/F16/F32) are
/// accepted; any other dtype is reported as a configuration error.
///
/// # Errors
/// Returns [`MIError::Config`] for unsupported dtypes, or propagates the
/// candle error if the raw buffer cannot be interpreted.
fn tensor_from_view(view: &safetensors::tensor::TensorView<'_>, device: &Device) -> Result<Tensor> {
    #[allow(clippy::wildcard_enum_match_arm)]
    let dtype = match view.dtype() {
        safetensors::Dtype::BF16 => DType::BF16,
        safetensors::Dtype::F16 => DType::F16,
        safetensors::Dtype::F32 => DType::F32,
        other => {
            return Err(MIError::Config(format!(
                "unsupported CLT tensor dtype: {other:?}"
            )));
        }
    };
    // `view.shape()` is already a `&[usize]`, which `from_raw_buffer`
    // accepts directly — no need to copy it into a Vec first.
    let tensor = Tensor::from_raw_buffer(view.data(), dtype, view.shape(), device)?;
    Ok(tensor)
}
/// Extracts the scalar value for `key` from a minimal flat YAML document.
///
/// Handles only the `key: value` form used by transcoder config files:
/// the first matching line wins, surrounding double quotes are stripped,
/// comment lines are skipped, and keys that merely share a prefix with
/// `key` (e.g. `model_name_long` vs `model_name`) do not match.
///
/// Returns `None` when no line carries the key.
fn parse_yaml_value(yaml_text: &str, key: &str) -> Option<String> {
    for raw_line in yaml_text.lines() {
        let line = raw_line.trim();
        // Skip YAML comment lines so a commented-out key is never picked up.
        if line.starts_with('#') {
            continue;
        }
        if let Some(rest) = line.strip_prefix(key) {
            // Allow optional whitespace before the colon (`key : value` is
            // valid YAML). If the next non-space char is not ':', `key` only
            // matched a prefix of a longer key — keep scanning.
            if let Some(value) = rest.trim_start().strip_prefix(':') {
                return Some(value.trim().trim_matches('"').to_owned());
            }
        }
    }
    None
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
use super::*;
#[test]
fn clt_feature_id_display() {
    // Display format is "L<layer>:<index>".
    let id = CltFeatureId {
        layer: 5,
        index: 42,
    };
    assert_eq!(format!("{id}"), "L5:42");
}
#[test]
fn clt_feature_id_ordering() {
    // Derived Ord is lexicographic: layer first, then index.
    let low = CltFeatureId {
        layer: 0,
        index: 10,
    };
    let mid = CltFeatureId {
        layer: 0,
        index: 20,
    };
    let high = CltFeatureId { layer: 1, index: 0 };
    assert!(low < mid);
    assert!(mid < high);
}
#[test]
fn sparse_activations_basics() {
    // Three layer-0 features with descending activations.
    let sparse = SparseActivations {
        features: [(5_usize, 3.0_f32), (2, 2.0), (8, 1.0)]
            .into_iter()
            .map(|(index, act)| (CltFeatureId { layer: 0, index }, act))
            .collect(),
    };
    assert_eq!(sparse.len(), 3);
    assert!(!sparse.is_empty());
}
#[test]
fn sparse_activations_truncate() {
    let mut sparse = SparseActivations {
        features: [(5_usize, 3.0_f32), (2, 2.0), (8, 1.0)]
            .into_iter()
            .map(|(index, act)| (CltFeatureId { layer: 0, index }, act))
            .collect(),
    };
    // Truncation keeps the leading entries in their original order.
    sparse.truncate(2);
    assert_eq!(sparse.len(), 2);
    let kept: Vec<usize> = sparse.features.iter().map(|(f, _)| f.index).collect();
    assert_eq!(kept, vec![5, 2]);
}
#[test]
fn parse_yaml_value_basic() {
    let yaml = concat!(
        "model_name: \"google/gemma-2-2b\"\n",
        "model_kind: cross_layer_transcoder\n"
    );
    // Quoted and unquoted values parse; unknown keys yield None.
    let cases = [
        ("model_name", Some("google/gemma-2-2b")),
        ("model_kind", Some("cross_layer_transcoder")),
        ("missing_key", None),
    ];
    for (key, expected) in cases {
        assert_eq!(parse_yaml_value(yaml, key), expected.map(str::to_owned));
    }
}
#[test]
fn encode_synthetic() {
    // Each encoder row is one-hot: feature i reads residual dimension i, so
    // its pre-activation is residual[i] + b_enc[i].
    let device = Device::Cpu;
    let d_model = 8;
    let n_features = 4;
    #[rustfmt::skip]
    let w_enc_data: Vec<f32> = vec![
        1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ];
    let w_enc = Tensor::from_vec(w_enc_data, (n_features, d_model), &device).unwrap();
    // Biases kill features 1 (0.3 - 0.5) and 3 (1.0 - 2.0); feature 2 reads a
    // zero dimension — only feature 0 stays positive after ReLU.
    let b_enc_data: Vec<f32> = vec![0.0, -0.5, 0.0, -2.0]; let b_enc = Tensor::from_vec(b_enc_data, (n_features,), &device).unwrap();
    let residual_data: Vec<f32> = vec![1.5, 0.3, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0];
    let residual = Tensor::from_vec(residual_data, (d_model,), &device).unwrap();
    // Minimal single-layer transcoder with the encoder pre-loaded, so
    // encode() needs no files on disk.
    let clt = CrossLayerTranscoder {
        repo_id: "test".to_owned(),
        fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
        encoder_paths: vec![None],
        decoder_paths: vec![None],
        config: CltConfig {
            n_layers: 1,
            d_model,
            n_features_per_layer: n_features,
            n_features_total: n_features,
            model_name: "test".to_owned(),
            schema: TranscoderSchema::CltSplit,
            gemmascope_npz_paths: Vec::new(),
        },
        loaded_encoder: Some(LoadedEncoder {
            layer: 0,
            w_enc,
            b_enc,
        }),
        steering_cache: HashMap::new(),
    };
    let sparse = clt.encode(&residual, 0).unwrap();
    assert_eq!(sparse.len(), 1, "only feature 0 should be active");
    assert_eq!(sparse.features[0].0.index, 0);
    // Activation = residual[0] + b_enc[0] = 1.5 + 0.0.
    assert!((sparse.features[0].1 - 1.5).abs() < 1e-5);
}
#[test]
fn encode_wrong_layer_errors() {
    // Only the layer-0 encoder is loaded; asking encode() for layer 1 must
    // fail rather than silently reuse the wrong weights.
    let device = Device::Cpu;
    let w_enc = Tensor::zeros((4, 8), DType::F32, &device).unwrap();
    let b_enc = Tensor::zeros((4,), DType::F32, &device).unwrap();
    let residual = Tensor::zeros((8,), DType::F32, &device).unwrap();
    let clt = CrossLayerTranscoder {
        repo_id: "test".to_owned(),
        fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
        encoder_paths: vec![None; 2],
        decoder_paths: vec![None; 2],
        config: CltConfig {
            n_layers: 2,
            d_model: 8,
            n_features_per_layer: 4,
            n_features_total: 8,
            model_name: "test".to_owned(),
            schema: TranscoderSchema::CltSplit,
            gemmascope_npz_paths: Vec::new(),
        },
        // Encoder resident for layer 0 only.
        loaded_encoder: Some(LoadedEncoder {
            layer: 0,
            w_enc,
            b_enc,
        }),
        steering_cache: HashMap::new(),
    };
    let result = clt.encode(&residual, 1);
    assert!(result.is_err());
}
#[test]
fn inject_position() {
    // Steering must only alter the requested sequence position.
    let device = Device::Cpu;
    let d_model = 4;
    // All-ones residual: batch 1, 3 positions, 4 dims.
    let residual = Tensor::ones((1, 3, d_model), DType::F32, &device).unwrap();
    let fid = CltFeatureId { layer: 0, index: 0 };
    let target_layer = 1;
    let steering_vec =
        Tensor::from_vec(vec![10.0_f32, 20.0, 30.0, 40.0], (d_model,), &device).unwrap();
    // Pre-seed the steering cache so inject() needs no decoder files.
    let mut steering_cache = HashMap::new();
    steering_cache.insert((fid, target_layer), steering_vec);
    let clt = CrossLayerTranscoder {
        repo_id: "test".to_owned(),
        fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
        encoder_paths: vec![None; 2],
        decoder_paths: vec![None; 2],
        config: CltConfig {
            n_layers: 2,
            d_model,
            n_features_per_layer: 1,
            n_features_total: 2,
            model_name: "test".to_owned(),
            schema: TranscoderSchema::CltSplit,
            gemmascope_npz_paths: Vec::new(),
        },
        loaded_encoder: None,
        steering_cache,
    };
    // Inject at position 1 with strength 1.0.
    let result = clt
        .inject(&residual, &[(fid, target_layer)], 1, 1.0)
        .unwrap();
    // Position 0 untouched, position 1 gains the steering vector,
    // position 2 untouched.
    let pos0: Vec<f32> = result.i((0, 0)).unwrap().to_vec1().unwrap();
    assert_eq!(pos0, vec![1.0, 1.0, 1.0, 1.0]);
    let pos1: Vec<f32> = result.i((0, 1)).unwrap().to_vec1().unwrap();
    assert_eq!(pos1, vec![11.0, 21.0, 31.0, 41.0]);
    let pos2: Vec<f32> = result.i((0, 2)).unwrap().to_vec1().unwrap();
    assert_eq!(pos2, vec![1.0, 1.0, 1.0, 1.0]);
}
#[test]
fn prepare_hook_injection_creates_correct_hooks() {
    // Hook-based injection must register at the cached feature's target
    // layer, not at the feature's own layer or anywhere else.
    use crate::hooks::HookPoint;
    let device = Device::Cpu;
    let d_model = 4;
    let fid = CltFeatureId { layer: 0, index: 0 };
    let target_layer = 5;
    let steering_vec =
        Tensor::from_vec(vec![1.0_f32, 2.0, 3.0, 4.0], (d_model,), &device).unwrap();
    // Pre-seed the cache so no decoder files are needed.
    let mut steering_cache = HashMap::new();
    steering_cache.insert((fid, target_layer), steering_vec);
    let clt = CrossLayerTranscoder {
        repo_id: "test".to_owned(),
        fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
        encoder_paths: vec![None; 10],
        decoder_paths: vec![None; 10],
        config: CltConfig {
            n_layers: 10,
            d_model,
            n_features_per_layer: 1,
            n_features_total: 10,
            model_name: "test".to_owned(),
            schema: TranscoderSchema::CltSplit,
            gemmascope_npz_paths: Vec::new(),
        },
        loaded_encoder: None,
        steering_cache,
    };
    let hooks = clt
        .prepare_hook_injection(&[(fid, target_layer)], 2, 5, 1.0, &device)
        .unwrap();
    // Intervention appears at the target layer and at neither of the
    // spot-checked other layers.
    assert!(hooks.has_intervention_at(&HookPoint::ResidPost(target_layer)));
    assert!(!hooks.has_intervention_at(&HookPoint::ResidPost(0)));
    assert!(!hooks.has_intervention_at(&HookPoint::ResidPost(4)));
}
#[test]
fn attribution_edge_basics() {
    // An edge carries a feature id plus a signed score.
    let edge = AttributionEdge {
        feature: CltFeatureId {
            layer: 3,
            index: 42,
        },
        score: 0.75,
    };
    let CltFeatureId { layer, index } = edge.feature;
    assert_eq!((layer, index), (3, 42));
    assert!((edge.score - 0.75).abs() < f32::EPSILON);
}
#[test]
fn attribution_graph_empty() {
    // Accessors must all behave sensibly on an edge-free graph.
    let graph = AttributionGraph {
        target_layer: 5,
        edges: Vec::new(),
    };
    assert_eq!(graph.target_layer(), 5);
    assert_eq!(graph.len(), 0);
    assert!(graph.is_empty());
    assert!(graph.features().is_empty());
    assert!(graph.into_edges().is_empty());
}
#[test]
fn attribution_graph_top_k() {
    // Five rank-ordered edges over features (0,0), (0,1), (1,0), (1,1), (2,0).
    let scored = [
        ((0_usize, 0_usize), 5.0_f32),
        ((0, 1), 3.0),
        ((1, 0), 1.0),
        ((1, 1), -1.0),
        ((2, 0), -4.0),
    ];
    let edges = scored
        .into_iter()
        .map(|((layer, index), score)| AttributionEdge {
            feature: CltFeatureId { layer, index },
            score,
        })
        .collect();
    let graph = AttributionGraph {
        target_layer: 3,
        edges,
    };
    assert_eq!(graph.len(), 5);
    // top_k keeps the first k edges (the list is already rank-ordered).
    let top3 = graph.top_k(3);
    assert_eq!(top3.len(), 3);
    assert_eq!(top3.target_layer(), 3);
    assert!((top3.edges()[0].score - 5.0).abs() < f32::EPSILON);
    assert!((top3.edges()[1].score - 3.0).abs() < f32::EPSILON);
    assert!((top3.edges()[2].score - 1.0).abs() < f32::EPSILON);
    // Requesting more edges than exist keeps them all.
    assert_eq!(graph.top_k(10).len(), 5);
}
#[test]
fn attribution_graph_threshold() {
    // Features (0,0)..(2,0) via i/2, i%2, with signed scores.
    let scored = [5.0_f32, 3.0, 1.0, -1.0, -4.0];
    let edges = scored
        .iter()
        .enumerate()
        .map(|(i, &score)| AttributionEdge {
            feature: CltFeatureId {
                layer: i / 2,
                index: i % 2,
            },
            score,
        })
        .collect();
    let graph = AttributionGraph {
        target_layer: 3,
        edges,
    };
    // Threshold filters on |score|, so the -4.0 edge survives a 2.0 cutoff
    // while 1.0 and -1.0 are dropped.
    let pruned = graph.threshold(2.0);
    assert_eq!(pruned.len(), 3);
    assert!((pruned.edges()[0].score - 5.0).abs() < f32::EPSILON);
    assert!((pruned.edges()[1].score - 3.0).abs() < f32::EPSILON);
    assert!((pruned.edges()[2].score - -4.0).abs() < f32::EPSILON);
}
#[test]
fn attribution_graph_features() {
    let first = CltFeatureId { layer: 2, index: 7 };
    let second = CltFeatureId { layer: 0, index: 3 };
    let graph = AttributionGraph {
        target_layer: 5,
        edges: vec![
            AttributionEdge {
                feature: first,
                score: 1.0,
            },
            AttributionEdge {
                feature: second,
                score: 0.5,
            },
        ],
    };
    // features() keeps edge order and drops the scores.
    assert_eq!(graph.features(), vec![first, second]);
}
fn create_synthetic_decoder(
dir: &std::path::Path,
layer: usize,
n_features: usize,
n_target_layers: usize,
d_model: usize,
values: &[f32],
) -> PathBuf {
assert_eq!(values.len(), n_features * n_target_layers * d_model);
let bytes: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
let name = format!("W_dec_{layer}");
let shape = vec![n_features, n_target_layers, d_model];
let view =
safetensors::tensor::TensorView::new(safetensors::Dtype::F32, shape, &bytes).unwrap();
let mut tensors = HashMap::new();
tensors.insert(name, view);
let serialized = safetensors::serialize(&tensors, &None).unwrap();
let path = dir.join(format!("W_dec_{layer}.safetensors"));
std::fs::write(&path, serialized).unwrap();
path
}
#[test]
fn score_decoder_projection_synthetic() {
    let dir = tempfile::tempdir().unwrap();
    let d_model = 4;
    let n_features = 4;
    // Layer-0 decoder, shape (4 features, 2 target layers, 4 dims):
    // feature i writes only to target layer 1, on dimension i, weight 1.0.
    #[rustfmt::skip]
    let dec0_values: Vec<f32> = vec![
        0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0,
    ];
    let path0 = create_synthetic_decoder(dir.path(), 0, n_features, 2, d_model, &dec0_values);
    // Layer-1 decoder, shape (4, 1, 4): feature 0 writes 2.0 on dim 0,
    // feature 3 writes 3.0 on dim 1.
    #[rustfmt::skip]
    let dec1_values: Vec<f32> = vec![
        2.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0,
        0.0, 3.0, 0.0, 0.0,
    ];
    let path1 = create_synthetic_decoder(dir.path(), 1, n_features, 1, d_model, &dec1_values);
    let mut clt = CrossLayerTranscoder {
        repo_id: "test".to_owned(),
        fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
        encoder_paths: vec![None; 2],
        decoder_paths: vec![Some(path0), Some(path1)],
        config: CltConfig {
            n_layers: 2,
            d_model,
            n_features_per_layer: n_features,
            n_features_total: n_features * 2,
            model_name: "test".to_owned(),
            schema: TranscoderSchema::CltSplit,
            gemmascope_npz_paths: Vec::new(),
        },
        loaded_encoder: None,
        steering_cache: HashMap::new(),
    };
    // Probe along dim 0 at target layer 1: layer-1 feature 0 (dot 2.0) must
    // outrank layer-0 feature 0 (dot 1.0).
    let direction =
        Tensor::from_vec(vec![1.0_f32, 0.0, 0.0, 0.0], (d_model,), &Device::Cpu).unwrap();
    let scores = clt
        .score_features_by_decoder_projection(&direction, 1, 10, false)
        .unwrap();
    assert!(scores.len() >= 2, "expected at least 2 non-zero scores");
    assert_eq!(scores[0].0, CltFeatureId { layer: 1, index: 0 });
    assert!((scores[0].1 - 2.0).abs() < 1e-5);
    assert_eq!(scores[1].0, CltFeatureId { layer: 0, index: 0 });
    assert!((scores[1].1 - 1.0).abs() < 1e-5);
    // Probe along dim 1: layer-1 feature 3 (dot 3.0) then layer-0 feature 1.
    let direction2 =
        Tensor::from_vec(vec![0.0_f32, 1.0, 0.0, 0.0], (d_model,), &Device::Cpu).unwrap();
    let scores2 = clt
        .score_features_by_decoder_projection(&direction2, 1, 10, false)
        .unwrap();
    assert_eq!(scores2[0].0, CltFeatureId { layer: 1, index: 3 });
    assert!((scores2[0].1 - 3.0).abs() < 1e-5);
    assert_eq!(scores2[1].0, CltFeatureId { layer: 0, index: 1 });
    assert!((scores2[1].1 - 1.0).abs() < 1e-5);
}
#[test]
fn score_decoder_projection_cosine_synthetic() {
    let dir = tempfile::tempdir().unwrap();
    let d_model = 4;
    let n_features = 2;
    // Feature 0 decodes to 3*e0 (large norm); feature 1 to e0+e1 (norm √2).
    #[rustfmt::skip]
    let dec0_values: Vec<f32> = vec![
        3.0, 0.0, 0.0, 0.0,
        1.0, 1.0, 0.0, 0.0,
    ];
    let path0 = create_synthetic_decoder(dir.path(), 0, n_features, 1, d_model, &dec0_values);
    let mut clt = CrossLayerTranscoder {
        repo_id: "test".to_owned(),
        fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
        encoder_paths: vec![None],
        decoder_paths: vec![Some(path0)],
        config: CltConfig {
            n_layers: 1,
            d_model,
            n_features_per_layer: n_features,
            n_features_total: n_features,
            model_name: "test".to_owned(),
            schema: TranscoderSchema::CltSplit,
            gemmascope_npz_paths: Vec::new(),
        },
        loaded_encoder: None,
        steering_cache: HashMap::new(),
    };
    let direction =
        Tensor::from_vec(vec![1.0_f32, 0.0, 0.0, 0.0], (d_model,), &Device::Cpu).unwrap();
    // Raw dot products: 3.0 for feature 0, 1.0 for feature 1.
    let dot_scores = clt
        .score_features_by_decoder_projection(&direction, 0, 10, false)
        .unwrap();
    assert!((dot_scores[0].1 - 3.0).abs() < 1e-5);
    assert!((dot_scores[1].1 - 1.0).abs() < 1e-5);
    // Cosine normalises out the row norms: feature 0 aligns perfectly (1.0),
    // feature 1 sits at 45° (1/√2).
    let cos_scores = clt
        .score_features_by_decoder_projection(&direction, 0, 10, true)
        .unwrap();
    assert!(
        (cos_scores[0].1 - 1.0).abs() < 1e-4,
        "expected ~1.0, got {}",
        cos_scores[0].1
    );
    let expected_cos = 1.0 / 2.0_f32.sqrt();
    assert!(
        (cos_scores[1].1 - expected_cos).abs() < 1e-4,
        "expected ~{expected_cos}, got {}",
        cos_scores[1].1
    );
}
#[test]
fn score_decoder_projection_batch_synthetic() {
    let dir = tempfile::tempdir().unwrap();
    let d_model = 4;
    let n_features = 2;
    // Identity decoder: feature i decodes to basis vector e_i.
    #[rustfmt::skip]
    let dec0_values: Vec<f32> = vec![
        1.0, 0.0, 0.0, 0.0,
        0.0, 1.0, 0.0, 0.0,
    ];
    let path0 = create_synthetic_decoder(dir.path(), 0, n_features, 1, d_model, &dec0_values);
    let mut clt = CrossLayerTranscoder {
        repo_id: "test".to_owned(),
        fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
        encoder_paths: vec![None],
        decoder_paths: vec![Some(path0)],
        config: CltConfig {
            n_layers: 1,
            d_model,
            n_features_per_layer: n_features,
            n_features_total: n_features,
            model_name: "test".to_owned(),
            schema: TranscoderSchema::CltSplit,
            gemmascope_npz_paths: Vec::new(),
        },
        loaded_encoder: None,
        steering_cache: HashMap::new(),
    };
    // Two orthogonal probes, each matching exactly one feature.
    let dir0 =
        Tensor::from_vec(vec![1.0_f32, 0.0, 0.0, 0.0], (d_model,), &Device::Cpu).unwrap();
    let dir1 =
        Tensor::from_vec(vec![0.0_f32, 1.0, 0.0, 0.0], (d_model,), &Device::Cpu).unwrap();
    let batch = clt
        .score_features_by_decoder_projection_batch(&[dir0, dir1], 0, 10, false)
        .unwrap();
    // One ranking per direction; each picks out its matching feature with
    // score 1.0.
    assert_eq!(batch.len(), 2);
    assert_eq!(batch[0][0].0, CltFeatureId { layer: 0, index: 0 });
    assert!((batch[0][0].1 - 1.0).abs() < 1e-5);
    assert_eq!(batch[1][0].0, CltFeatureId { layer: 0, index: 1 });
    assert!((batch[1][0].1 - 1.0).abs() < 1e-5);
}
#[test]
fn extract_decoder_vectors_synthetic() {
    let dir = tempfile::tempdir().unwrap();
    let d_model = 4;
    let n_features = 3;
    // Values 1..=24 laid out as (3 features, 2 target layers, 4 dims): the
    // target-layer-1 slice of feature f is values[f*8 + 4 .. f*8 + 8].
    #[rustfmt::skip]
    let dec0_values: Vec<f32> = vec![
        1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
        9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
        17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0,
    ];
    let path0 = create_synthetic_decoder(dir.path(), 0, n_features, 2, d_model, &dec0_values);
    let mut clt = CrossLayerTranscoder {
        repo_id: "test".to_owned(),
        fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
        encoder_paths: vec![None; 2],
        decoder_paths: vec![Some(path0), None],
        config: CltConfig {
            n_layers: 2,
            d_model,
            n_features_per_layer: n_features,
            n_features_total: n_features * 2,
            model_name: "test".to_owned(),
            schema: TranscoderSchema::CltSplit,
            gemmascope_npz_paths: Vec::new(),
        },
        loaded_encoder: None,
        steering_cache: HashMap::new(),
    };
    let features = vec![
        CltFeatureId { layer: 0, index: 0 },
        CltFeatureId { layer: 0, index: 2 },
    ];
    // Extract the target-layer-1 decoder vectors for features 0 and 2.
    let vectors = clt.extract_decoder_vectors(&features, 1).unwrap();
    assert_eq!(vectors.len(), 2);
    let v0: Vec<f32> = vectors[&CltFeatureId { layer: 0, index: 0 }]
        .to_vec1()
        .unwrap();
    assert_eq!(v0, vec![5.0, 6.0, 7.0, 8.0]);
    let v2: Vec<f32> = vectors[&CltFeatureId { layer: 0, index: 2 }]
        .to_vec1()
        .unwrap();
    assert_eq!(v2, vec![21.0, 22.0, 23.0, 24.0]);
}
#[test]
fn build_attribution_graph_synthetic() {
    let dir = tempfile::tempdir().unwrap();
    let d_model = 4;
    let n_features = 2;
    // Feature 0 decodes to e0; feature 1 decodes to 2*e1.
    #[rustfmt::skip]
    let dec0_values: Vec<f32> = vec![
        1.0, 0.0, 0.0, 0.0,
        0.0, 2.0, 0.0, 0.0,
    ];
    let path0 = create_synthetic_decoder(dir.path(), 0, n_features, 1, d_model, &dec0_values);
    let mut clt = CrossLayerTranscoder {
        repo_id: "test".to_owned(),
        fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
        encoder_paths: vec![None],
        decoder_paths: vec![Some(path0)],
        config: CltConfig {
            n_layers: 1,
            d_model,
            n_features_per_layer: n_features,
            n_features_total: n_features,
            model_name: "test".to_owned(),
            schema: TranscoderSchema::CltSplit,
            gemmascope_npz_paths: Vec::new(),
        },
        loaded_encoder: None,
        steering_cache: HashMap::new(),
    };
    // A probe along e1 scores feature 1 at 2.0, so it heads the edge list.
    let direction =
        Tensor::from_vec(vec![0.0_f32, 1.0, 0.0, 0.0], (d_model,), &Device::Cpu).unwrap();
    let graph = clt
        .build_attribution_graph(&direction, 0, 10, false)
        .unwrap();
    assert_eq!(graph.target_layer(), 0);
    assert!(!graph.is_empty());
    assert_eq!(
        graph.edges()[0].feature,
        CltFeatureId { layer: 0, index: 1 }
    );
    assert!((graph.edges()[0].score - 2.0).abs() < 1e-5);
    // Thresholding at 1.0 drops feature 0 (|score| 0.0 along e1).
    let pruned = graph.threshold(1.0);
    assert_eq!(pruned.len(), 1);
    assert_eq!(pruned.features()[0], CltFeatureId { layer: 0, index: 1 });
}
#[test]
fn classify_clt_split_layout() {
    // Per-layer W_enc/W_dec safetensors mark a CLT split repo.
    let listing = [
        "W_enc_0.safetensors",
        "W_enc_1.safetensors",
        "W_dec_0.safetensors",
        "W_dec_1.safetensors",
        "config.yaml",
    ];
    let schema = classify_transcoder_schema(&listing).expect("CLT split must classify");
    assert_eq!(schema, TranscoderSchema::CltSplit);
    // Cross-layer, not JumpReLU.
    assert!(schema.is_cross_layer());
    assert!(!schema.is_jump_relu());
}
#[test]
fn classify_plt_bundle_layout() {
    // layer_*.safetensors bundles mark a PLT repo even with metadata around.
    let listing = [
        "layer_0.safetensors",
        "layer_1.safetensors",
        "features/layer_0.bin",
        "features/layer_1.bin",
        "README.md",
    ];
    let schema = classify_transcoder_schema(&listing).expect("PLT bundle must classify");
    assert_eq!(schema, TranscoderSchema::PltBundle);
    // Per-layer and not JumpReLU.
    assert!(!schema.is_cross_layer());
    assert!(!schema.is_jump_relu());
}
#[test]
fn classify_gemmascope_metadata_repo() {
    // config.yaml + features/*.bin with no safetensors marks a GemmaScope
    // metadata repo.
    let listing = [
        "config.yaml",
        "features/index.json.gz",
        "features/layer_0.bin",
        "features/layer_1.bin",
    ];
    let schema = classify_transcoder_schema(&listing).expect("metadata repo must classify");
    assert_eq!(schema, TranscoderSchema::GemmaScopeNpz);
    assert!(!schema.is_cross_layer());
    assert!(schema.is_jump_relu());
}
#[test]
fn classify_gemmascope_npz_direct() {
    // Raw GemmaScope repos expose layer_*/width_*/average_l0_*/params.npz.
    let listing = vec![
        "layer_0/width_16k/average_l0_100/params.npz",
        "layer_1/width_16k/average_l0_105/params.npz",
        "layer_2/width_16k/average_l0_108/params.npz",
    ];
    assert_eq!(
        classify_transcoder_schema(&listing).unwrap(),
        TranscoderSchema::GemmaScopeNpz
    );
}
#[test]
fn classify_unrecognised_layout_errors() {
    // A listing with no recognised markers must yield a Config error with
    // the canonical message.
    match classify_transcoder_schema(&["random_file.txt", "README.md"]) {
        Err(MIError::Config(msg)) => assert!(
            msg.contains("unrecognised transcoder repo layout"),
            "unexpected error message: {msg}"
        ),
        other => panic!("expected Err(MIError::Config), got {other:?}"),
    }
}
#[test]
fn classify_empty_listing_errors() {
    // An empty file listing cannot match any layout.
    let empty: &[&str] = &[];
    let err = classify_transcoder_schema(empty).unwrap_err();
    assert!(matches!(err, MIError::Config(_)), "got {err:?}");
}
#[test]
fn classify_prefers_clt_split_over_plt_bundle() {
    // When both markers are present, the CLT split layout wins.
    assert_eq!(
        classify_transcoder_schema(&["W_enc_0.safetensors", "layer_0.safetensors"]).unwrap(),
        TranscoderSchema::CltSplit
    );
}
#[test]
fn classify_bin_files_alone_are_not_enough() {
    // Metadata .bin files only count alongside a config.yaml.
    let listing = ["features/layer_0.bin", "features/layer_1.bin"];
    assert!(classify_transcoder_schema(&listing).is_err());
}
#[test]
fn gemmascope_deferral_error_message_is_informative() {
    // The deferral message must name both the follow-up release and the
    // roadmap step so users know where to look.
    for needle in ["v0.1.10", "Step 1.6"] {
        assert!(
            GEMMASCOPE_DEFERRAL_ERR.contains(needle),
            "deferral error must mention {needle}"
        );
    }
}
#[test]
fn encoder_file_and_tensor_names_clt_split() {
    // CLT split: layer-suffixed file and tensor names.
    let names = encoder_file_and_tensor_names(TranscoderSchema::CltSplit, 7, &[]).unwrap();
    assert_eq!(
        names,
        (
            "W_enc_7.safetensors".to_owned(),
            "W_enc_7".to_owned(),
            "b_enc_7".to_owned()
        )
    );
}
#[test]
fn encoder_file_and_tensor_names_plt_bundle() {
    // PLT bundles use fixed tensor names inside a per-layer file.
    let names = encoder_file_and_tensor_names(TranscoderSchema::PltBundle, 3, &[]).unwrap();
    assert_eq!(
        names,
        (
            "layer_3.safetensors".to_owned(),
            "W_enc".to_owned(),
            "b_enc".to_owned()
        )
    );
}
#[test]
fn encoder_file_and_tensor_names_gemmascope_needs_paths() {
    // Without a resolved NPZ path list the lookup must fail.
    assert!(encoder_file_and_tensor_names(TranscoderSchema::GemmaScopeNpz, 0, &[]).is_err());
    let paths: Vec<String> = [
        "layer_0/width_16k/average_l0_100/params.npz",
        "layer_1/width_16k/average_l0_105/params.npz",
    ]
    .iter()
    .map(|s| (*s).to_owned())
    .collect();
    // With paths, layer 1 resolves to the second entry and fixed names.
    let (filename, w_enc_name, b_enc_name) =
        encoder_file_and_tensor_names(TranscoderSchema::GemmaScopeNpz, 1, &paths).unwrap();
    assert_eq!(filename, paths[1]);
    assert_eq!(
        (w_enc_name.as_str(), b_enc_name.as_str()),
        ("W_enc", "b_enc")
    );
}
#[test]
fn decoder_file_and_tensor_name_all_schemas() {
    // CLT split: per-layer file with a layer-suffixed tensor name.
    assert_eq!(
        decoder_file_and_tensor_name(TranscoderSchema::CltSplit, 5, &[]).unwrap(),
        ("W_dec_5.safetensors".to_owned(), "W_dec_5".to_owned())
    );
    // PLT bundle: shared tensor name inside the layer bundle.
    assert_eq!(
        decoder_file_and_tensor_name(TranscoderSchema::PltBundle, 5, &[]).unwrap(),
        ("layer_5.safetensors".to_owned(), "W_dec".to_owned())
    );
    // GemmaScope needs a resolved NPZ path list.
    assert!(decoder_file_and_tensor_name(TranscoderSchema::GemmaScopeNpz, 0, &[]).is_err());
}
#[test]
fn decoder_row_clt_split_indexes_rank3() {
    // Shape (2 features, 3 target layers, 4 dims); only target layer 1 is
    // populated with recognisable values.
    let mut values = vec![0.0_f32; 2 * 3 * 4];
    values[4..8].copy_from_slice(&[5.0, 6.0, 7.0, 8.0]);
    values[16..20].copy_from_slice(&[9.0, 10.0, 11.0, 12.0]);
    let w_dec = Tensor::from_vec(values, (2, 3, 4), &Device::Cpu).unwrap();
    for (feature, expected) in [
        (0, vec![5.0_f32, 6.0, 7.0, 8.0]),
        (1, vec![9.0, 10.0, 11.0, 12.0]),
    ] {
        let row = decoder_row(&w_dec, feature, 1, TranscoderSchema::CltSplit).unwrap();
        assert_eq!(row.to_vec1::<f32>().unwrap(), expected);
    }
}
#[test]
fn decoder_row_plt_bundle_indexes_rank2_at_offset_zero() {
    // Rank-2 (2 features, 4 dims) decoder filled with 1..=8.
    let values: Vec<f32> = (1..=8).map(|v| v as f32).collect();
    let w_dec = Tensor::from_vec(values, (2, 4), &Device::Cpu).unwrap();
    // Feature 1 is the second row.
    let row = decoder_row(&w_dec, 1, 0, TranscoderSchema::PltBundle).unwrap();
    assert_eq!(row.to_vec1::<f32>().unwrap(), vec![5.0, 6.0, 7.0, 8.0]);
}
#[test]
fn decoder_row_plt_bundle_rejects_nonzero_offset() {
    // A per-layer decoder has no cross-layer axis, so offset 1 is invalid.
    let w_dec = Tensor::zeros((2, 4), DType::F32, &Device::Cpu).unwrap();
    match decoder_row(&w_dec, 0, 1, TranscoderSchema::PltBundle) {
        Err(MIError::Config(msg)) => {
            assert!(msg.contains("target_offset must be 0"), "got: {msg}");
        }
        other => panic!("expected Err(MIError::Config), got {other:?}"),
    }
}
#[test]
fn decoder_layer_slice_plt_bundle_rejects_nonzero_offset() {
    // Any non-zero cross-layer offset must be rejected for per-layer schemas.
    let w_dec = Tensor::zeros((2, 4), DType::F32, &Device::Cpu).unwrap();
    let result = decoder_layer_slice(&w_dec, 2, TranscoderSchema::PltBundle);
    assert!(result.is_err());
}
#[test]
fn decoder_layer_slice_plt_bundle_returns_full_rank2_at_offset_zero() {
    let expected = vec![vec![1.0_f32, 2.0, 3.0, 4.0], vec![5.0, 6.0, 7.0, 8.0]];
    let flat: Vec<f32> = expected.iter().flatten().copied().collect();
    let w_dec = Tensor::from_vec(flat, (2, 4), &Device::Cpu).unwrap();
    // Offset 0 returns the whole rank-2 matrix unchanged.
    let slice = decoder_layer_slice(&w_dec, 0, TranscoderSchema::PltBundle).unwrap();
    assert_eq!(slice.dims(), &[2, 4]);
    assert_eq!(slice.to_vec2::<f32>().unwrap(), expected);
}
/// Writes a synthetic PLT bundle `layer_{layer}.safetensors` containing the
/// five tensors a per-layer transcoder bundle carries (`W_enc`, `W_dec`,
/// `W_skip`, `b_enc`, `b_dec`) and returns the written path.
///
/// All inputs are row-major f32 data; lengths are validated up front so a
/// mis-sized fixture fails at the call site rather than inside safetensors.
fn create_synthetic_plt_bundle(
    dir: &std::path::Path,
    layer: usize,
    n_features: usize,
    d_model: usize,
    w_enc: &[f32],
    w_dec: &[f32],
    w_skip: &[f32],
    b_enc: &[f32],
    b_dec: &[f32],
) -> PathBuf {
    assert_eq!(w_enc.len(), n_features * d_model);
    assert_eq!(
        w_dec.len(),
        n_features * d_model,
        "PltBundle W_dec is rank-2"
    );
    assert_eq!(w_skip.len(), d_model * d_model);
    assert_eq!(b_enc.len(), n_features);
    assert_eq!(b_dec.len(), d_model);
    // (tensor name, shape, data) for every entry in the bundle; drives both
    // the byte conversion and the TensorView construction below.
    let specs: [(&str, Vec<usize>, &[f32]); 5] = [
        ("W_enc", vec![n_features, d_model], w_enc),
        ("W_dec", vec![n_features, d_model], w_dec),
        ("W_skip", vec![d_model, d_model], w_skip),
        ("b_enc", vec![n_features], b_enc),
        ("b_dec", vec![d_model], b_dec),
    ];
    // Byte buffers are collected first so they outlive the TensorViews that
    // borrow them.
    let byte_bufs: Vec<Vec<u8>> = specs
        .iter()
        .map(|(_, _, data)| data.iter().flat_map(|v| v.to_le_bytes()).collect())
        .collect();
    let mut tensors = HashMap::new();
    for ((name, shape, _), bytes) in specs.iter().zip(&byte_bufs) {
        tensors.insert(
            (*name).to_owned(),
            safetensors::tensor::TensorView::new(safetensors::Dtype::F32, shape.clone(), bytes)
                .unwrap(),
        );
    }
    let serialized = safetensors::serialize(&tensors, &None).unwrap();
    let path = dir.join(format!("layer_{layer}.safetensors"));
    std::fs::write(&path, serialized).unwrap();
    path
}
#[test]
fn plt_bundle_cache_steering_all_downstream_is_single_entry_per_feature() {
    let dir = tempfile::tempdir().unwrap();
    let d_model = 4;
    let n_features = 2;
    // Layer-0 W_dec rows are distinctive so the cached vector is recognisable.
    #[rustfmt::skip]
    let w_dec_0: Vec<f32> = vec![
        1.0, 2.0, 3.0, 4.0,
        5.0, 6.0, 7.0, 8.0,
    ];
    let path0 = create_synthetic_plt_bundle(
        dir.path(),
        0,
        n_features,
        d_model,
        &vec![0.0; n_features * d_model],
        &w_dec_0,
        &vec![0.0; d_model * d_model],
        &vec![0.0; n_features],
        &vec![0.0; d_model],
    );
    // Layer-1 bundle filled with 99s: if any of it shows up in the cache,
    // the per-layer invariant is broken.
    let path1 = create_synthetic_plt_bundle(
        dir.path(),
        1,
        n_features,
        d_model,
        &vec![0.0; n_features * d_model],
        &vec![99.0; n_features * d_model],
        &vec![0.0; d_model * d_model],
        &vec![0.0; n_features],
        &vec![0.0; d_model],
    );
    let mut clt = CrossLayerTranscoder {
        repo_id: "test".to_owned(),
        fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
        encoder_paths: vec![Some(path0), Some(path1)],
        decoder_paths: vec![None, None],
        config: CltConfig {
            n_layers: 2,
            d_model,
            n_features_per_layer: n_features,
            n_features_total: n_features * 2,
            model_name: "test".to_owned(),
            schema: TranscoderSchema::PltBundle,
            gemmascope_npz_paths: Vec::new(),
        },
        loaded_encoder: None,
        steering_cache: HashMap::new(),
    };
    let features = vec![CltFeatureId { layer: 0, index: 0 }];
    clt.cache_steering_vectors_all_downstream(&features, &Device::Cpu)
        .unwrap();
    assert_eq!(
        clt.steering_cache_len(),
        1,
        "PltBundle must cache exactly 1 entry per feature (not n_layers)"
    );
    // The lone entry is keyed (feature, source layer) and holds W_dec row 0.
    let cached = clt
        .steering_cache
        .get(&(CltFeatureId { layer: 0, index: 0 }, 0))
        .expect("entry at (feature, source_layer)");
    let values: Vec<f32> = cached.to_vec1().unwrap();
    assert_eq!(values, vec![1.0, 2.0, 3.0, 4.0]);
    assert!(
        !clt.steering_cache
            .contains_key(&(CltFeatureId { layer: 0, index: 0 }, 1)),
        "PltBundle must not cache downstream entries"
    );
}
#[test]
fn clt_split_cache_steering_all_downstream_caches_all_targets() {
    let dir = tempfile::tempdir().unwrap();
    let d_model = 4;
    let n_features = 2;
    // Layer-0 decoder shape (2 features, 2 target layers, 4 dims): each
    // feature writes to both downstream target layers.
    #[rustfmt::skip]
    let dec0_values: Vec<f32> = vec![
        1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
        0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0,
    ];
    let path0 = create_synthetic_decoder(dir.path(), 0, n_features, 2, d_model, &dec0_values);
    #[rustfmt::skip]
    let dec1_values: Vec<f32> = vec![
        2.0, 0.0, 0.0, 0.0,
        0.0, 2.0, 0.0, 0.0,
    ];
    let path1 = create_synthetic_decoder(dir.path(), 1, n_features, 1, d_model, &dec1_values);
    let mut clt = CrossLayerTranscoder {
        repo_id: "test".to_owned(),
        fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
        encoder_paths: vec![None, None],
        decoder_paths: vec![Some(path0), Some(path1)],
        config: CltConfig {
            n_layers: 2,
            d_model,
            n_features_per_layer: n_features,
            n_features_total: n_features * 2,
            model_name: "test".to_owned(),
            schema: TranscoderSchema::CltSplit,
            gemmascope_npz_paths: Vec::new(),
        },
        loaded_encoder: None,
        steering_cache: HashMap::new(),
    };
    let features = vec![CltFeatureId { layer: 0, index: 0 }];
    clt.cache_steering_vectors_all_downstream(&features, &Device::Cpu)
        .unwrap();
    // Cross-layer decoders produce one cache entry per downstream target.
    assert_eq!(
        clt.steering_cache_len(),
        2,
        "CltSplit: layer 0 writes to 2 downstream layers"
    );
    assert!(
        clt.steering_cache
            .contains_key(&(CltFeatureId { layer: 0, index: 0 }, 0))
    );
    assert!(
        clt.steering_cache
            .contains_key(&(CltFeatureId { layer: 0, index: 0 }, 1))
    );
}
#[test]
fn encode_pre_activation_matches_encode_postrelu() {
    let dir = tempfile::tempdir().unwrap();
    let d_model = 4;
    let n_features = 5;
    // Encoder rows chosen to yield a mix of positive and non-positive
    // pre-activations for the residual used below.
    #[rustfmt::skip]
    let w_enc: Vec<f32> = vec![
        1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.5, 0.5, 0.5, 0.5, -1.0, -1.0, -1.0, -1.0, ];
    let b_enc: Vec<f32> = vec![0.0, 0.0, 0.0, -1.0, 2.0];
    let path0 = create_synthetic_plt_bundle(
        dir.path(),
        0,
        n_features,
        d_model,
        &w_enc,
        &vec![0.0; n_features * d_model],
        &vec![0.0; d_model * d_model],
        &b_enc,
        &vec![0.0; d_model],
    );
    let mut clt = CrossLayerTranscoder {
        repo_id: "test".to_owned(),
        fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
        encoder_paths: vec![Some(path0)],
        decoder_paths: vec![None],
        config: CltConfig {
            n_layers: 1,
            d_model,
            n_features_per_layer: n_features,
            n_features_total: n_features,
            model_name: "test".to_owned(),
            schema: TranscoderSchema::PltBundle,
            gemmascope_npz_paths: Vec::new(),
        },
        loaded_encoder: None,
        steering_cache: HashMap::new(),
    };
    clt.load_encoder(0, &Device::Cpu).unwrap();
    let residual =
        Tensor::from_vec(vec![1.0_f32, 2.0, -0.5, 1.5], (d_model,), &Device::Cpu).unwrap();
    let pre_acts_tensor = clt.encode_pre_activation(&residual, 0).unwrap();
    let pre_acts: Vec<f32> = pre_acts_tensor.to_vec1().unwrap();
    assert_eq!(pre_acts.len(), n_features);
    let sparse = clt.encode(&residual, 0).unwrap();
    // Every sparse (post-ReLU) activation must equal its pre-activation, and
    // that pre-activation must be strictly positive.
    for (fid, act) in &sparse.features {
        let pre = pre_acts[fid.index];
        assert!(
            pre > 0.0,
            "sparse feature {fid:?} must have positive pre-act"
        );
        assert!(
            (pre - act).abs() < 1e-6,
            "feature {fid:?}: sparse={act}, pre={pre}"
        );
    }
    // Conversely, non-positive pre-activations must never appear in the
    // sparse output.
    let sparse_indices: std::collections::HashSet<usize> =
        sparse.features.iter().map(|(f, _)| f.index).collect();
    for (i, &pre) in pre_acts.iter().enumerate() {
        if pre <= 0.0 {
            assert!(
                !sparse_indices.contains(&i),
                "feature {i} pre-act {pre} <= 0 but appears in sparse output"
            );
        }
    }
}
#[test]
fn load_skip_matrix_round_trip_plt_bundle() {
    // The skip matrix written into a PLT bundle must come back unchanged as a
    // [d_model, d_model] tensor, row for row.
    let tmpdir = tempfile::tempdir().unwrap();
    let d_model = 3;
    let n_features = 2;
    // Distinct values 1..=9 so any transpose or reshape slip is visible.
    #[rustfmt::skip]
    let w_skip: Vec<f32> = vec![
        1.0, 2.0, 3.0,
        4.0, 5.0, 6.0,
        7.0, 8.0, 9.0,
    ];
    // Encoder/decoder weights and biases are irrelevant here — all zeros.
    let w_enc_zeros = vec![0.0; n_features * d_model];
    let w_dec_zeros = vec![0.0; n_features * d_model];
    let b_enc_zeros = vec![0.0; n_features];
    let b_dec_zeros = vec![0.0; d_model];
    let bundle_path = create_synthetic_plt_bundle(
        tmpdir.path(),
        0,
        n_features,
        d_model,
        &w_enc_zeros,
        &w_dec_zeros,
        &w_skip,
        &b_enc_zeros,
        &b_dec_zeros,
    );
    let mut clt = CrossLayerTranscoder {
        repo_id: "test".to_owned(),
        fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
        encoder_paths: vec![Some(bundle_path)],
        decoder_paths: vec![None],
        config: CltConfig {
            n_layers: 1,
            d_model,
            n_features_per_layer: n_features,
            n_features_total: n_features,
            model_name: "test".to_owned(),
            schema: TranscoderSchema::PltBundle,
            gemmascope_npz_paths: Vec::new(),
        },
        loaded_encoder: None,
        steering_cache: HashMap::new(),
    };
    let loaded = clt.load_skip_matrix(0, &Device::Cpu).unwrap();
    assert_eq!(loaded.dims(), &[d_model, d_model]);
    let rows: Vec<Vec<f32>> = loaded.to_vec2().unwrap();
    let expected: [Vec<f32>; 3] = [
        vec![1.0, 2.0, 3.0],
        vec![4.0, 5.0, 6.0],
        vec![7.0, 8.0, 9.0],
    ];
    for (row, want) in rows.iter().zip(expected.iter()) {
        assert_eq!(row, want);
    }
}
#[test]
fn load_skip_matrix_rejects_clt_split_schema() {
    // Skip matrices exist only in PltBundle files; requesting one under a
    // CltSplit schema must fail with a Config error naming the right schema.
    let mut clt = CrossLayerTranscoder {
        repo_id: "test".to_owned(),
        fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
        encoder_paths: vec![None],
        decoder_paths: vec![None],
        config: CltConfig {
            n_layers: 1,
            d_model: 4,
            n_features_per_layer: 2,
            n_features_total: 2,
            model_name: "test".to_owned(),
            schema: TranscoderSchema::CltSplit,
            gemmascope_npz_paths: Vec::new(),
        },
        loaded_encoder: None,
        steering_cache: HashMap::new(),
    };
    // Match directly on the error; any non-Config variant is a test failure.
    match clt.load_skip_matrix(0, &Device::Cpu).unwrap_err() {
        MIError::Config(msg) => assert!(
            msg.contains("PltBundle"),
            "error message should mention PltBundle schema: {msg}"
        ),
        other => panic!("expected MIError::Config, got {other:?}"),
    }
}
}