use std::collections::HashMap;
use std::path::PathBuf;
use candle_core::{DType, Device, IndexOp, Tensor};
use safetensors::tensor::SafeTensors;
use tracing::info;
use crate::error::{MIError, Result};
/// Identifier of a single CLT feature: the layer its encoder reads from plus
/// its index within that layer's feature dictionary.
///
/// Displays as `L<layer>:<index>` (see the `Display` impl below).
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize,
)]
pub struct CltFeatureId {
    // Source layer the feature's encoder reads from.
    pub layer: usize,
    // Feature index within that layer's dictionary.
    pub index: usize,
}
impl std::fmt::Display for CltFeatureId {
    /// Renders the id as `L<layer>:<index>`, e.g. `L5:42`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { layer, index } = *self;
        write!(f, "L{layer}:{index}")
    }
}
use crate::sparse::{FeatureId, SparseActivations};
// Marker impl: allows CltFeatureId to be the id type of SparseActivations.
impl FeatureId for CltFeatureId {}
/// One edge of an attribution graph: a CLT feature and its score against the
/// probed direction. Scores may be negative (anti-aligned features).
#[derive(Debug, Clone)]
pub struct AttributionEdge {
    // The contributing CLT feature.
    pub feature: CltFeatureId,
    // Projection (or cosine) score of the feature's decoder column.
    pub score: f32,
}
/// Features attributed to a direction at a given target layer.
#[derive(Debug, Clone)]
pub struct AttributionGraph {
    // Layer whose residual-stream direction was attributed.
    target_layer: usize,
    // Scored edges, in the order produced by the scoring methods
    // (highest score first).
    edges: Vec<AttributionEdge>,
}
impl AttributionGraph {
    /// Layer whose residual-stream direction this graph attributes.
    #[must_use]
    pub const fn target_layer(&self) -> usize {
        self.target_layer
    }
    /// All edges, in the order they were added.
    #[must_use]
    pub fn edges(&self) -> &[AttributionEdge] {
        &self.edges
    }
    /// Number of edges in the graph.
    #[must_use]
    pub const fn len(&self) -> usize {
        self.edges.len()
    }
    /// True when the graph holds no edges.
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        self.edges.is_empty()
    }
    /// A copy keeping only the first `k` edges (fewer if the graph is smaller).
    #[must_use]
    pub fn top_k(&self, k: usize) -> Self {
        let kept: Vec<AttributionEdge> = self.edges.iter().cloned().take(k).collect();
        Self {
            target_layer: self.target_layer,
            edges: kept,
        }
    }
    /// A copy keeping only edges whose absolute score reaches `min_score`.
    #[must_use]
    pub fn threshold(&self, min_score: f32) -> Self {
        let mut kept = self.edges.clone();
        kept.retain(|e| e.score.abs() >= min_score);
        Self {
            target_layer: self.target_layer,
            edges: kept,
        }
    }
    /// Feature ids of every edge, in edge order.
    #[must_use]
    pub fn features(&self) -> Vec<CltFeatureId> {
        let mut ids = Vec::with_capacity(self.edges.len());
        for edge in &self.edges {
            ids.push(edge.feature);
        }
        ids
    }
    /// Consume the graph and yield its edges.
    #[must_use]
    pub fn into_edges(self) -> Vec<AttributionEdge> {
        self.edges
    }
}
/// Dimensions and metadata of a CLT, discovered in `CrossLayerTranscoder::open`.
#[derive(Debug, Clone)]
pub struct CltConfig {
    // Number of layers (= number of W_enc_* shards in the repo).
    pub n_layers: usize,
    // Residual-stream width, read from W_enc_0's second dimension.
    pub d_model: usize,
    // Dictionary size per layer, read from W_enc_0's first dimension.
    pub n_features_per_layer: usize,
    // n_layers * n_features_per_layer.
    pub n_features_total: usize,
    // `model_name` from the repo's config.yaml, or "unknown" if unavailable.
    pub model_name: String,
}
// Encoder weights for a single layer, kept resident between encode() calls.
struct LoadedEncoder {
    // Layer index this encoder belongs to.
    layer: usize,
    // Encoder weight matrix, [n_features_per_layer, d_model].
    w_enc: Tensor,
    // Encoder bias, one value per feature.
    b_enc: Tensor,
}
/// Lazily-downloading wrapper around a cross-layer transcoder stored on the
/// Hugging Face Hub as per-layer `W_enc_*` / `W_dec_*` safetensors shards.
pub struct CrossLayerTranscoder {
    // Hub repo id the shards are fetched from.
    repo_id: String,
    // Download settings (progress logging) shared by all fetches.
    fetch_config: hf_fetch_model::FetchConfig,
    // Local paths of already-downloaded encoder shards, indexed by layer.
    encoder_paths: Vec<Option<PathBuf>>,
    // Local paths of already-downloaded decoder shards, indexed by layer.
    decoder_paths: Vec<Option<PathBuf>>,
    // Shapes and metadata discovered in `open`.
    config: CltConfig,
    // At most one layer's encoder is resident at a time.
    loaded_encoder: Option<LoadedEncoder>,
    // Decoder columns keyed by (feature, target_layer), resident on the
    // device they were cached for.
    steering_cache: HashMap<(CltFeatureId, usize), Tensor>,
}
impl CrossLayerTranscoder {
/// Open a CLT stored in the Hugging Face repo `clt_repo`.
///
/// Counts `W_enc_*.safetensors` files to determine the layer count, reads
/// `model_name` from `config.yaml` (best-effort: failures degrade to
/// `"unknown"`), and downloads the first encoder shard to discover
/// `d_model` and the per-layer feature count. All other shards are
/// downloaded lazily on first use.
///
/// # Errors
/// `MIError::Download` for network/runtime failures; `MIError::Config`
/// when the repo has no encoder files or `W_enc_0` has an unexpected shape.
pub fn open(clt_repo: &str) -> Result<Self> {
    let fetch_config = hf_fetch_model::FetchConfig::builder()
        .on_progress(|event| {
            tracing::info!(
                filename = %event.filename,
                percent = event.percent,
                bytes_downloaded = event.bytes_downloaded,
                bytes_total = event.bytes_total,
                "CLT download progress",
            );
        })
        .build()
        .map_err(|e| MIError::Download(format!("failed to build fetch config: {e}")))?;
    // Throwaway runtime: the repo-listing API is async but `open` is sync.
    let rt = tokio::runtime::Runtime::new()
        .map_err(|e| MIError::Download(format!("failed to create tokio runtime: {e}")))?;
    let repo_files = rt
        .block_on(hf_fetch_model::repo::list_repo_files_with_metadata(
            clt_repo, None, None,
        ))
        .map_err(|e| MIError::Download(format!("failed to list repo files: {e}")))?;
    // One `W_enc_<i>.safetensors` file per layer.
    let n_layers = repo_files
        .iter()
        .filter(|f| f.filename.starts_with("W_enc_") && f.filename.ends_with(".safetensors"))
        .count();
    if n_layers == 0 {
        return Err(MIError::Config(format!(
            "no CLT encoder files found in {clt_repo}"
        )));
    }
    // Best-effort: a missing or unreadable config.yaml yields "unknown".
    let model_name = match hf_fetch_model::download_file_blocking(
        clt_repo.to_owned(),
        "config.yaml",
        &fetch_config,
    ) {
        Ok(outcome) => {
            let path = outcome.into_inner();
            let text = std::fs::read_to_string(&path)?;
            parse_yaml_value(&text, "model_name").unwrap_or_else(|| "unknown".to_owned())
        }
        Err(_) => "unknown".to_owned(),
    };
    // Layer 0's encoder weight defines the shared shapes.
    let enc0_path = hf_fetch_model::download_file_blocking(
        clt_repo.to_owned(),
        "W_enc_0.safetensors",
        &fetch_config,
    )
    .map_err(|e| MIError::Download(format!("failed to download W_enc_0: {e}")))?
    .into_inner();
    let data = std::fs::read(&enc0_path)?;
    let tensors = SafeTensors::deserialize(&data)
        .map_err(|e| MIError::Config(format!("failed to deserialize W_enc_0: {e}")))?;
    let w_enc_view = tensors
        .tensor("W_enc_0")
        .map_err(|e| MIError::Config(format!("tensor 'W_enc_0' not found: {e}")))?;
    let shape = w_enc_view.shape();
    if shape.len() != 2 {
        return Err(MIError::Config(format!(
            "expected 2D encoder weight, got shape {shape:?}"
        )));
    }
    // W_enc is [n_features_per_layer, d_model].
    let n_features_per_layer = *shape
        .first()
        .ok_or_else(|| MIError::Config("encoder weight shape is empty".into()))?;
    let d_model = *shape.get(1).ok_or_else(|| {
        MIError::Config("encoder weight shape has fewer than 2 dimensions".into())
    })?;
    // Only layer 0's local path is known so far; the rest fill in lazily.
    let mut encoder_paths: Vec<Option<PathBuf>> = vec![None; n_layers];
    if let Some(slot) = encoder_paths.first_mut() {
        *slot = Some(enc0_path);
    }
    let decoder_paths: Vec<Option<PathBuf>> = vec![None; n_layers];
    let config = CltConfig {
        n_layers,
        d_model,
        n_features_per_layer,
        n_features_total: n_layers * n_features_per_layer,
        model_name,
    };
    info!(
        "CLT config: {} layers, d_model={}, features_per_layer={}, total={}",
        config.n_layers, config.d_model, config.n_features_per_layer, config.n_features_total
    );
    Ok(Self {
        repo_id: clt_repo.to_owned(),
        fetch_config,
        encoder_paths,
        decoder_paths,
        config,
        loaded_encoder: None,
        steering_cache: HashMap::new(),
    })
}
/// Configuration discovered when the CLT was opened.
#[must_use]
pub const fn config(&self) -> &CltConfig {
    &self.config
}
/// Layer of the currently resident encoder, if any is loaded.
#[must_use]
pub fn loaded_encoder_layer(&self) -> Option<usize> {
    self.loaded_encoder.as_ref().map(|e| e.layer)
}
/// Return the local path of `layer`'s encoder shard, downloading it on
/// first use and memoizing the path for subsequent calls.
///
/// # Errors
/// `MIError::Download` when the fetch fails.
fn ensure_encoder_path(&mut self, layer: usize) -> Result<PathBuf> {
    // Already downloaded: reuse the recorded path.
    if let Some(path) = self.encoder_paths.get(layer).and_then(Option::as_ref) {
        return Ok(path.clone());
    }
    let filename = format!("W_enc_{layer}.safetensors");
    // Fix: these messages previously printed the literal "(unknown)"
    // instead of the file being downloaded.
    info!("Downloading {filename} from {}", self.repo_id);
    let path = hf_fetch_model::download_file_blocking(
        self.repo_id.clone(),
        &filename,
        &self.fetch_config,
    )
    .map_err(|e| MIError::Download(format!("failed to download {filename}: {e}")))?
    .into_inner();
    // Memoize so repeat calls skip the download.
    if let Some(slot) = self.encoder_paths.get_mut(layer) {
        *slot = Some(path.clone());
    }
    Ok(path)
}
/// Return the local path of `layer`'s decoder shard, downloading it on
/// first use and memoizing the path for subsequent calls.
///
/// # Errors
/// `MIError::Download` when the fetch fails.
fn ensure_decoder_path(&mut self, layer: usize) -> Result<PathBuf> {
    // Already downloaded: reuse the recorded path.
    if let Some(path) = self.decoder_paths.get(layer).and_then(Option::as_ref) {
        return Ok(path.clone());
    }
    let filename = format!("W_dec_{layer}.safetensors");
    // Fix: these messages previously printed the literal "(unknown)"
    // instead of the file being downloaded.
    info!("Downloading {filename} from {}", self.repo_id);
    let path = hf_fetch_model::download_file_blocking(
        self.repo_id.clone(),
        &filename,
        &self.fetch_config,
    )
    .map_err(|e| MIError::Download(format!("failed to download {filename}: {e}")))?
    .into_inner();
    // Memoize so repeat calls skip the download.
    if let Some(slot) = self.decoder_paths.get_mut(layer) {
        *slot = Some(path.clone());
    }
    Ok(path)
}
/// Load the encoder weights for `layer` onto `device`, replacing any
/// previously resident encoder (only one layer is kept at a time).
/// No-op when that layer's encoder is already loaded.
///
/// # Errors
/// `MIError::Config` when `layer` is out of range or the shard is
/// malformed; download and I/O errors propagate.
pub fn load_encoder(&mut self, layer: usize, device: &Device) -> Result<()> {
    if layer >= self.config.n_layers {
        return Err(MIError::Config(format!(
            "layer {layer} out of range (CLT has {} layers)",
            self.config.n_layers
        )));
    }
    if let Some(ref enc) = self.loaded_encoder {
        if enc.layer == layer {
            // Already resident — nothing to do.
            return Ok(());
        }
    }
    // Drop the old encoder before loading the new one.
    self.loaded_encoder = None;
    info!("Loading CLT encoder for layer {layer}");
    let enc_path = self.ensure_encoder_path(layer)?;
    let data = std::fs::read(&enc_path)?;
    let st = SafeTensors::deserialize(&data).map_err(|e| {
        MIError::Config(format!("failed to deserialize encoder layer {layer}: {e}"))
    })?;
    let w_enc_name = format!("W_enc_{layer}");
    let b_enc_name = format!("b_enc_{layer}");
    let w_enc = tensor_from_view(
        &st.tensor(&w_enc_name)
            .map_err(|e| MIError::Config(format!("tensor '{w_enc_name}' not found: {e}")))?,
        device,
    )?;
    let b_enc = tensor_from_view(
        &st.tensor(&b_enc_name)
            .map_err(|e| MIError::Config(format!("tensor '{b_enc_name}' not found: {e}")))?,
        device,
    )?;
    self.loaded_encoder = Some(LoadedEncoder {
        layer,
        w_enc,
        b_enc,
    });
    Ok(())
}
/// Encode a residual-stream vector into sparse CLT feature activations for
/// `layer`, sorted strongest-first. Requires `load_encoder(layer)` to have
/// been called beforehand.
///
/// # Errors
/// `MIError::Hook` when no encoder (or the wrong layer's encoder) is loaded.
pub fn encode(
    &self,
    residual: &Tensor,
    layer: usize,
) -> Result<SparseActivations<CltFeatureId>> {
    let encoder = self.loaded_encoder.as_ref().ok_or_else(|| {
        MIError::Hook(format!(
            "no encoder loaded — call load_encoder({layer}) first"
        ))
    })?;
    if encoder.layer != layer {
        return Err(MIError::Hook(format!(
            "loaded encoder is for layer {}, but layer {layer} was requested",
            encoder.layer
        )));
    }
    // Flatten the residual and promote everything to f32 for the matmul.
    let x = residual.flatten_all()?.to_dtype(DType::F32)?;
    let w = encoder.w_enc.to_dtype(DType::F32)?;
    let b = encoder.b_enc.to_dtype(DType::F32)?;
    // Pre-activations W_enc @ x + b_enc, then ReLU.
    let pre = (w.matmul(&x.unsqueeze(1)?)?.squeeze(1)? + &b)?;
    let values: Vec<f32> = pre.relu()?.to_vec1()?;
    // Keep only strictly positive activations as sparse features.
    let mut features: Vec<(CltFeatureId, f32)> = values
        .into_iter()
        .enumerate()
        .filter(|&(_, v)| v > 0.0)
        .map(|(i, v)| (CltFeatureId { layer, index: i }, v))
        .collect();
    // Strongest activation first; NaNs compare as equal.
    features.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    Ok(SparseActivations { features })
}
/// Encode `residual` at `layer` and keep only the `k` strongest features.
///
/// # Errors
/// Same conditions as [`Self::encode`].
pub fn top_k(
    &self,
    residual: &Tensor,
    layer: usize,
    k: usize,
) -> Result<SparseActivations<CltFeatureId>> {
    let mut activations = self.encode(residual, layer)?;
    activations.truncate(k);
    Ok(activations)
}
/// Fetch the decoder column that writes `feature`'s contribution into the
/// residual stream at `target_layer` (which must be at or after the
/// feature's source layer). Served from the steering cache when present;
/// otherwise the source layer's W_dec shard is read from disk.
///
/// # Errors
/// `MIError::Config` for out-of-range layers/indices or malformed shards.
pub fn decoder_vector(
    &mut self,
    feature: &CltFeatureId,
    target_layer: usize,
    device: &Device,
) -> Result<Tensor> {
    if feature.layer >= self.config.n_layers {
        return Err(MIError::Config(format!(
            "feature source layer {} out of range (CLT has {} layers)",
            feature.layer, self.config.n_layers
        )));
    }
    if target_layer < feature.layer || target_layer >= self.config.n_layers {
        return Err(MIError::Config(format!(
            "target layer {target_layer} must be >= source layer {} and < {}",
            feature.layer, self.config.n_layers
        )));
    }
    if feature.index >= self.config.n_features_per_layer {
        return Err(MIError::Config(format!(
            "feature index {} out of range (max {})",
            feature.index, self.config.n_features_per_layer
        )));
    }
    let cache_key = (*feature, target_layer);
    if let Some(cached) = self.steering_cache.get(&cache_key) {
        return Ok(cached.clone());
    }
    // NOTE(review): cache misses are NOT inserted into steering_cache here —
    // only cache_steering_vectors() populates it, so repeated calls re-read
    // the shard from disk. Confirm this is intended.
    // Offset 0 targets the source layer itself.
    let target_offset = target_layer - feature.layer;
    let dec_path = self.ensure_decoder_path(feature.layer)?;
    let data = std::fs::read(&dec_path)?;
    let st = SafeTensors::deserialize(&data).map_err(|e| {
        MIError::Config(format!(
            "failed to deserialize decoder layer {}: {e}",
            feature.layer
        ))
    })?;
    let dec_name = format!("W_dec_{}", feature.layer);
    let w_dec = tensor_from_view(
        &st.tensor(&dec_name)
            .map_err(|e| MIError::Config(format!("tensor '{dec_name}' not found: {e}")))?,
        &Device::Cpu,
    )?;
    let column = w_dec.i((feature.index, target_offset))?;
    // NOTE(review): unlike the cache methods, no independent copy is made;
    // when `device` is CPU the returned column may share storage with the
    // full W_dec tensor — verify against candle's to_device semantics.
    let column = column.to_device(device)?;
    Ok(column)
}
/// Pre-load the decoder columns for every `(feature, target_layer)` pair
/// into the steering cache on `device`, grouping work by source layer so
/// each W_dec shard is read from disk at most once.
///
/// # Errors
/// Propagates download, I/O, and deserialization failures.
pub fn cache_steering_vectors(
    &mut self,
    features: &[(CltFeatureId, usize)],
    device: &Device,
) -> Result<()> {
    // Group requests by source layer: one shard read per layer.
    let mut by_source: HashMap<usize, Vec<(usize, usize)>> = HashMap::new();
    for (fid, target_layer) in features {
        by_source
            .entry(fid.layer)
            .or_default()
            .push((fid.index, *target_layer));
    }
    let mut loaded = 0_usize;
    let n_source_layers = by_source.len();
    for (layer_idx, (source_layer, entries)) in by_source.iter().enumerate() {
        info!(
            "cache_steering_vectors: loading decoder for source layer {} ({}/{})",
            source_layer,
            layer_idx + 1,
            n_source_layers
        );
        // Sub-group by target layer so each offset slice is computed once.
        let mut by_target: HashMap<usize, Vec<usize>> = HashMap::new();
        for &(index, target_layer) in entries {
            by_target.entry(target_layer).or_default().push(index);
        }
        let mut cpu_columns: Vec<(CltFeatureId, usize, Tensor)> = Vec::new();
        // Scoped so the raw shard bytes and the full W_dec tensor are
        // dropped before the columns are uploaded to `device`.
        {
            let dec_path = self.ensure_decoder_path(*source_layer)?;
            let data = std::fs::read(&dec_path)?;
            info!(
                "cache_steering_vectors: loaded {} MB for layer {}",
                data.len() / (1024 * 1024),
                source_layer
            );
            let st = SafeTensors::deserialize(&data).map_err(|e| {
                MIError::Config(format!(
                    "failed to deserialize decoder layer {source_layer}: {e}"
                ))
            })?;
            let dec_name = format!("W_dec_{source_layer}");
            let w_dec = tensor_from_view(
                &st.tensor(&dec_name).map_err(|e| {
                    MIError::Config(format!("tensor '{dec_name}' not found: {e}"))
                })?,
                &Device::Cpu,
            )?;
            for (target_layer, indices) in &by_target {
                // Offset 0 writes back into the source layer itself.
                let target_offset = target_layer - source_layer;
                for &index in indices {
                    let fid = CltFeatureId {
                        layer: *source_layer,
                        index,
                    };
                    let cache_key = (fid, *target_layer);
                    if !self.steering_cache.contains_key(&cache_key) {
                        let view = w_dec.i((index, target_offset))?;
                        let dims = view.dims().to_vec();
                        let values = view.to_dtype(DType::F32)?.to_vec1::<f32>()?;
                        // Round-trip through a Vec so the column owns its
                        // storage instead of aliasing the whole W_dec buffer.
                        let independent =
                            Tensor::from_vec(values, dims.as_slice(), &Device::Cpu)?;
                        cpu_columns.push((fid, *target_layer, independent));
                    }
                }
            }
        }
        // Upload the detached columns and record them in the cache.
        for (fid, target_layer, cpu_tensor) in cpu_columns {
            let cache_key = (fid, target_layer);
            if let std::collections::hash_map::Entry::Vacant(e) =
                self.steering_cache.entry(cache_key)
            {
                let device_tensor = cpu_tensor.to_device(device)?;
                e.insert(device_tensor);
                loaded += 1;
            }
        }
    }
    info!(
        "Cached {loaded} new steering vectors ({} total in cache)",
        self.steering_cache.len()
    );
    Ok(())
}
/// Like [`Self::cache_steering_vectors`], but caches each feature's decoder
/// column for *every* downstream layer (source layer through the last
/// layer) instead of a caller-supplied target list.
///
/// # Errors
/// `MIError::Config` when a feature's source layer is out of range;
/// download, I/O, and deserialization failures propagate.
pub fn cache_steering_vectors_all_downstream(
    &mut self,
    features: &[CltFeatureId],
    device: &Device,
) -> Result<()> {
    let n_layers = self.config.n_layers;
    // Group by source layer so each W_dec shard is read once.
    let mut by_source: HashMap<usize, Vec<usize>> = HashMap::new();
    for fid in features {
        if fid.layer >= n_layers {
            return Err(MIError::Config(format!(
                "feature source layer {} out of range (max {})",
                fid.layer,
                n_layers - 1
            )));
        }
        by_source.entry(fid.layer).or_default().push(fid.index);
    }
    let mut loaded = 0_usize;
    let n_source_layers = by_source.len();
    for (layer_idx, (source_layer, indices)) in by_source.iter().enumerate() {
        let n_target_layers = n_layers - source_layer;
        info!(
            "cache_steering_vectors_all_downstream: loading decoder for source layer {} \
             ({}/{}, {} downstream layers)",
            source_layer,
            layer_idx + 1,
            n_source_layers,
            n_target_layers
        );
        let mut cpu_columns: Vec<(CltFeatureId, usize, Tensor)> = Vec::new();
        // Scoped so the raw shard bytes and the full W_dec tensor drop
        // before the columns are uploaded to `device`.
        {
            let dec_path = self.ensure_decoder_path(*source_layer)?;
            let data = std::fs::read(&dec_path)?;
            info!(
                "cache_steering_vectors_all_downstream: loaded {} MB for layer {}",
                data.len() / (1024 * 1024),
                source_layer
            );
            let st = SafeTensors::deserialize(&data).map_err(|e| {
                MIError::Config(format!(
                    "failed to deserialize decoder layer {source_layer}: {e}"
                ))
            })?;
            let dec_name = format!("W_dec_{source_layer}");
            let w_dec = tensor_from_view(
                &st.tensor(&dec_name).map_err(|e| {
                    MIError::Config(format!("tensor '{dec_name}' not found: {e}"))
                })?,
                &Device::Cpu,
            )?;
            for &index in indices {
                let fid = CltFeatureId {
                    layer: *source_layer,
                    index,
                };
                // Offset 0 is the source layer itself; walk every
                // downstream layer from there.
                for target_offset in 0..n_target_layers {
                    let target_layer = source_layer + target_offset;
                    let cache_key = (fid, target_layer);
                    if !self.steering_cache.contains_key(&cache_key) {
                        let view = w_dec.i((index, target_offset))?;
                        let dims = view.dims().to_vec();
                        let values = view.to_dtype(DType::F32)?.to_vec1::<f32>()?;
                        // Own the data so the cache doesn't alias the
                        // full W_dec buffer.
                        let independent =
                            Tensor::from_vec(values, dims.as_slice(), &Device::Cpu)?;
                        cpu_columns.push((fid, target_layer, independent));
                    }
                }
            }
        }
        // Upload the detached columns and record them in the cache.
        for (fid, target_layer, cpu_tensor) in cpu_columns {
            let cache_key = (fid, target_layer);
            if let std::collections::hash_map::Entry::Vacant(e) =
                self.steering_cache.entry(cache_key)
            {
                let device_tensor = cpu_tensor.to_device(device)?;
                e.insert(device_tensor);
                loaded += 1;
            }
        }
    }
    info!(
        "Cached {loaded} new steering vectors across all downstream layers ({} total in cache)",
        self.steering_cache.len()
    );
    Ok(())
}
/// Drop every cached steering vector, logging only when something was evicted.
pub fn clear_steering_cache(&mut self) {
    match self.steering_cache.len() {
        0 => self.steering_cache.clear(),
        count => {
            self.steering_cache.clear();
            info!("Cleared {count} steering vectors from cache");
        }
    }
}
/// Number of steering vectors currently cached.
#[must_use]
pub fn steering_cache_len(&self) -> usize {
    self.steering_cache.len()
}
/// Build a `HookSpec` that adds the cached steering vectors (scaled by
/// `strength`) into the residual stream at a single sequence `position`,
/// one `ResidPost` intervention per distinct target layer.
///
/// # Errors
/// `MIError::Hook` when any requested vector is missing from the cache.
pub fn prepare_hook_injection(
    &self,
    features: &[(CltFeatureId, usize)],
    position: usize,
    seq_len: usize,
    strength: f32,
    device: &Device,
) -> Result<crate::hooks::HookSpec> {
    use crate::hooks::{HookPoint, HookSpec, Intervention};
    // Sum the cached vectors per target layer, all in f32.
    let mut per_layer: HashMap<usize, Tensor> = HashMap::new();
    for (feature, target_layer) in features {
        let cached = self
            .steering_cache
            .get(&(*feature, *target_layer))
            .ok_or_else(|| {
                MIError::Hook(format!(
                    "feature {feature} for target layer {target_layer} not in steering cache — call cache_steering_vectors() first"
                ))
            })?;
        let vector = cached.to_dtype(DType::F32)?;
        let summed = match per_layer.get(target_layer) {
            Some(acc) => (acc + &vector)?,
            None => vector,
        };
        per_layer.insert(*target_layer, summed);
    }
    let d_model = self.config.d_model;
    let mut hooks = HookSpec::new();
    for (target_layer, accumulated) in &per_layer {
        // Scale, then lift to (1, 1, d_model) for the single position.
        let scaled = (accumulated * f64::from(strength))?;
        let steer_3d = scaled.unsqueeze(0)?.unsqueeze(0)?;
        // Surround the steered position with zeros so only `position`
        // receives the addition.
        let zeros = Tensor::zeros((1, seq_len, d_model), DType::F32, device)?;
        let mut parts: Vec<Tensor> = Vec::with_capacity(3);
        if position > 0 {
            parts.push(zeros.narrow(1, 0, position)?);
        }
        parts.push(steer_3d);
        if position + 1 < seq_len {
            parts.push(zeros.narrow(1, position + 1, seq_len - position - 1)?);
        }
        hooks.intervene(
            HookPoint::ResidPost(*target_layer),
            Intervention::Add(Tensor::cat(&parts, 1)?),
        );
    }
    Ok(hooks)
}
/// Add the summed, `strength`-scaled steering vectors to `residual` at the
/// given sequence `position`, returning a new tensor of the same shape and
/// dtype. All requested vectors must already be in the steering cache.
///
/// # Errors
/// `MIError::Config` for a bad position or mismatched `d_model`;
/// `MIError::Hook` when a vector is missing from the cache.
pub fn inject(
    &self,
    residual: &Tensor,
    features: &[(CltFeatureId, usize)],
    position: usize,
    strength: f32,
) -> Result<Tensor> {
    let (batch, seq_len, d_model) = residual.dims3()?;
    if position >= seq_len {
        return Err(MIError::Config(format!(
            "injection position {position} out of range (seq_len={seq_len})"
        )));
    }
    if d_model != self.config.d_model {
        return Err(MIError::Config(format!(
            "residual d_model={d_model} doesn't match CLT d_model={}",
            self.config.d_model
        )));
    }
    // Sum all requested steering vectors in f32.
    let mut steer = Tensor::zeros((d_model,), DType::F32, residual.device())?;
    for (feature, target_layer) in features {
        let cached = self
            .steering_cache
            .get(&(*feature, *target_layer))
            .ok_or_else(|| {
                MIError::Hook(format!(
                    "feature {feature} for target layer {target_layer} not in steering cache"
                ))
            })?;
        steer = (&steer + &cached.to_dtype(DType::F32)?)?;
    }
    // Scale and match the residual's dtype before adding.
    let steer = (steer * f64::from(strength))?.to_dtype(residual.dtype())?;
    // Add the vector only at `position`, broadcast across the batch.
    let steered_pos = {
        let slice = residual.narrow(1, position, 1)?;
        let expanded = steer
            .unsqueeze(0)?
            .unsqueeze(0)?
            .expand((batch, 1, d_model))?;
        (&slice + &expanded)?
    };
    // Reassemble [before | steered position | after].
    let mut parts: Vec<Tensor> = Vec::with_capacity(3);
    if position > 0 {
        parts.push(residual.narrow(1, 0, position)?);
    }
    parts.push(steered_pos);
    if position + 1 < seq_len {
        parts.push(residual.narrow(1, position + 1, seq_len - position - 1)?);
    }
    Ok(Tensor::cat(&parts, 1)?)
}
/// Rank all features by how strongly their decoder column aimed at
/// `target_layer` projects onto `direction` (raw dot product, or cosine
/// similarity when `cosine` is set), returning the `top_k` highest-scoring
/// features. Reads every W_dec shard for source layers 0..=target_layer.
///
/// # Errors
/// `MIError::Config` for a bad direction shape, out-of-range target layer,
/// or malformed shards.
pub fn score_features_by_decoder_projection(
    &mut self,
    direction: &Tensor,
    target_layer: usize,
    top_k: usize,
    cosine: bool,
) -> Result<Vec<(CltFeatureId, f32)>> {
    let d_model = self.config.d_model;
    if direction.dims() != [d_model] {
        return Err(MIError::Config(format!(
            "direction must have shape [{d_model}], got {:?}",
            direction.dims()
        )));
    }
    if target_layer >= self.config.n_layers {
        return Err(MIError::Config(format!(
            "target layer {target_layer} out of range (max {})",
            self.config.n_layers - 1
        )));
    }
    let direction_f32 = direction.to_dtype(DType::F32)?.to_device(&Device::Cpu)?;
    // For cosine scoring, pre-normalize the query direction (a near-zero
    // norm leaves the vector untouched to avoid dividing by ~0).
    let direction_norm = if cosine {
        let norm: f32 = direction_f32.sqr()?.sum_all()?.sqrt()?.to_scalar()?;
        if norm > 1e-10 {
            direction_f32.broadcast_div(&Tensor::new(norm, &Device::Cpu)?)?
        } else {
            direction_f32
        }
    } else {
        direction_f32
    };
    let mut all_scores: Vec<(CltFeatureId, f32)> = Vec::new();
    for source_layer in 0..self.config.n_layers {
        // Features can only write to layers at or after their source.
        if target_layer < source_layer {
            continue;
        }
        let target_offset = target_layer - source_layer;
        let dec_path = self.ensure_decoder_path(source_layer)?;
        let data = std::fs::read(&dec_path)?;
        info!(
            "score_features_by_decoder_projection: loaded {} MB for layer {}",
            data.len() / (1024 * 1024),
            source_layer
        );
        let st = SafeTensors::deserialize(&data).map_err(|e| {
            MIError::Config(format!(
                "failed to deserialize decoder layer {source_layer}: {e}"
            ))
        })?;
        let dec_name = format!("W_dec_{source_layer}");
        let w_dec = tensor_from_view(
            &st.tensor(&dec_name)
                .map_err(|e| MIError::Config(format!("tensor '{dec_name}' not found: {e}")))?,
            &Device::Cpu,
        )?;
        let w_dec_f32 = w_dec.to_dtype(DType::F32)?;
        // Decoder columns aimed at target_layer, one row per feature.
        let dec_slice = w_dec_f32.i((.., target_offset, ..))?;
        let raw_scores = dec_slice
            .matmul(&direction_norm.unsqueeze(1)?)?
            .squeeze(1)?;
        let scores_vec: Vec<f32> = if cosine {
            // Divide by each decoder column's norm to complete the cosine.
            let dec_norms = dec_slice.sqr()?.sum(1)?.sqrt()?;
            let cosine_scores = raw_scores.broadcast_div(&dec_norms)?;
            cosine_scores.to_vec1()?
        } else {
            raw_scores.to_vec1()?
        };
        for (idx, &score) in scores_vec.iter().enumerate() {
            // Drop NaN/inf scores (e.g. zero-norm columns under cosine).
            if score.is_finite() {
                all_scores.push((
                    CltFeatureId {
                        layer: source_layer,
                        index: idx,
                    },
                    score,
                ));
            }
        }
        info!(
            "Scored {} features at source layer {source_layer} (target layer {target_layer})",
            scores_vec.len()
        );
    }
    // Highest score first, then keep the top_k.
    all_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    all_scores.truncate(top_k);
    Ok(all_scores)
}
/// Batched variant of [`Self::score_features_by_decoder_projection`]:
/// scores every direction against each W_dec shard with a single matmul
/// per layer, returning one top-`top_k` list per direction.
///
/// # Errors
/// `MIError::Config` when `directions` is empty, any direction has the
/// wrong shape, or `target_layer` is out of range.
pub fn score_features_by_decoder_projection_batch(
    &mut self,
    directions: &[Tensor],
    target_layer: usize,
    top_k: usize,
    cosine: bool,
) -> Result<Vec<Vec<(CltFeatureId, f32)>>> {
    let d_model = self.config.d_model;
    let n_words = directions.len();
    if n_words == 0 {
        return Err(MIError::Config(
            "at least one direction vector required".into(),
        ));
    }
    for (i, dir) in directions.iter().enumerate() {
        if dir.dims() != [d_model] {
            return Err(MIError::Config(format!(
                "direction vector {i} must have shape [{d_model}], got {:?}",
                dir.dims()
            )));
        }
    }
    if target_layer >= self.config.n_layers {
        return Err(MIError::Config(format!(
            "target layer {target_layer} out of range (max {})",
            self.config.n_layers - 1
        )));
    }
    let dirs_f32: Vec<Tensor> = directions
        .iter()
        .map(|d| d.to_dtype(DType::F32)?.to_device(&Device::Cpu))
        .collect::<std::result::Result<_, _>>()?;
    // Stack to [n_words, d_model]; for cosine scoring normalize each row,
    // clamping norms at 1e-10 to avoid dividing by zero.
    let stacked = Tensor::stack(&dirs_f32, 0)?;
    let stacked_norm = if cosine {
        let norms = stacked.sqr()?.sum(1)?.sqrt()?;
        let ones = Tensor::ones_like(&norms)?;
        let safe_norms = norms.maximum(&(&ones * 1e-10f64)?)?;
        stacked.broadcast_div(&safe_norms.unsqueeze(1)?)?
    } else {
        stacked
    };
    // Transposed for the per-layer matmul below.
    let directions_t = stacked_norm.t()?;
    let mut all_scores: Vec<Vec<(CltFeatureId, f32)>> =
        (0..n_words).map(|_| Vec::new()).collect();
    for source_layer in 0..self.config.n_layers {
        // Features can only write to layers at or after their source.
        if target_layer < source_layer {
            continue;
        }
        let target_offset = target_layer - source_layer;
        let dec_path = self.ensure_decoder_path(source_layer)?;
        let data = std::fs::read(&dec_path)?;
        info!(
            "score_features_batch: loaded {} MB for layer {}",
            data.len() / (1024 * 1024),
            source_layer
        );
        let st = SafeTensors::deserialize(&data).map_err(|e| {
            MIError::Config(format!(
                "failed to deserialize decoder layer {source_layer}: {e}"
            ))
        })?;
        let dec_name = format!("W_dec_{source_layer}");
        let w_dec = tensor_from_view(
            &st.tensor(&dec_name)
                .map_err(|e| MIError::Config(format!("tensor '{dec_name}' not found: {e}")))?,
            &Device::Cpu,
        )?;
        let w_dec_f32 = w_dec.to_dtype(DType::F32)?;
        // Decoder columns aimed at target_layer, one row per feature.
        let dec_slice = w_dec_f32.i((.., target_offset, ..))?;
        // All dot products at once: rows = features, columns = directions.
        let raw_scores = dec_slice.matmul(&directions_t)?;
        let scores_2d: Vec<Vec<f32>> = if cosine {
            let dec_norms = dec_slice.sqr()?.sum(1)?.sqrt()?;
            let cosine_scores = raw_scores.broadcast_div(&dec_norms.unsqueeze(1)?)?;
            cosine_scores.t()?.to_vec2()?
        } else {
            raw_scores.t()?.to_vec2()?
        };
        for (w, word_scores) in scores_2d.iter().enumerate() {
            for (idx, &score) in word_scores.iter().enumerate() {
                // Drop NaN/inf scores.
                if score.is_finite() {
                    if let Some(word_vec) = all_scores.get_mut(w) {
                        word_vec.push((
                            CltFeatureId {
                                layer: source_layer,
                                index: idx,
                            },
                            score,
                        ));
                    }
                }
            }
        }
        info!(
            "Batch scored {} words × {} features at source layer {} (target layer {})",
            n_words,
            scores_2d.first().map_or(0, Vec::len),
            source_layer,
            target_layer
        );
    }
    // Per direction: highest score first, keep only the top_k.
    for word_scores in &mut all_scores {
        word_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        word_scores.truncate(top_k);
    }
    Ok(all_scores)
}
/// Extract CPU-resident decoder columns for `features` aimed at
/// `target_layer`, reading each source layer's W_dec shard once. Unlike
/// the cache methods, results are returned to the caller rather than
/// stored in the steering cache.
///
/// # Errors
/// `MIError::Config` when a layer is out of range or `target_layer`
/// precedes a feature's source layer.
pub fn extract_decoder_vectors(
    &mut self,
    features: &[CltFeatureId],
    target_layer: usize,
) -> Result<HashMap<CltFeatureId, Tensor>> {
    if target_layer >= self.config.n_layers {
        return Err(MIError::Config(format!(
            "target layer {target_layer} out of range (max {})",
            self.config.n_layers - 1
        )));
    }
    // Validate every feature up front while grouping by source layer.
    let mut by_source: HashMap<usize, Vec<usize>> = HashMap::new();
    for fid in features {
        if fid.layer >= self.config.n_layers {
            return Err(MIError::Config(format!(
                "feature source layer {} out of range (max {})",
                fid.layer,
                self.config.n_layers - 1
            )));
        }
        if target_layer < fid.layer {
            return Err(MIError::Config(format!(
                "target layer {target_layer} must be >= source layer {}",
                fid.layer
            )));
        }
        by_source.entry(fid.layer).or_default().push(fid.index);
    }
    let mut result: HashMap<CltFeatureId, Tensor> = HashMap::new();
    let n_source_layers = by_source.len();
    for (layer_idx, (source_layer, indices)) in by_source.iter().enumerate() {
        info!(
            "extract_decoder_vectors: loading decoder for source layer {} ({}/{})",
            source_layer,
            layer_idx + 1,
            n_source_layers
        );
        // Offset 0 targets the source layer itself.
        let target_offset = target_layer - source_layer;
        let dec_path = self.ensure_decoder_path(*source_layer)?;
        let data = std::fs::read(&dec_path)?;
        let st = SafeTensors::deserialize(&data).map_err(|e| {
            MIError::Config(format!(
                "failed to deserialize decoder layer {source_layer}: {e}"
            ))
        })?;
        let dec_name = format!("W_dec_{source_layer}");
        let w_dec = tensor_from_view(
            &st.tensor(&dec_name)
                .map_err(|e| MIError::Config(format!("tensor '{dec_name}' not found: {e}")))?,
            &Device::Cpu,
        )?;
        for &index in indices {
            let fid = CltFeatureId {
                layer: *source_layer,
                index,
            };
            // Duplicate requests resolve to the first extraction.
            if let std::collections::hash_map::Entry::Vacant(e) = result.entry(fid) {
                let view = w_dec.i((index, target_offset))?;
                let dims = view.dims().to_vec();
                let values = view.to_dtype(DType::F32)?.to_vec1::<f32>()?;
                // Own the data so it doesn't alias the full W_dec buffer.
                let independent = Tensor::from_vec(values, dims.as_slice(), &Device::Cpu)?;
                e.insert(independent);
            }
        }
    }
    info!(
        "Extracted {} decoder vectors across {} source layers",
        result.len(),
        n_source_layers
    );
    Ok(result)
}
/// Score features against `direction` and package the top results as an
/// [`AttributionGraph`] for `target_layer`.
///
/// # Errors
/// Same conditions as [`Self::score_features_by_decoder_projection`].
pub fn build_attribution_graph(
    &mut self,
    direction: &Tensor,
    target_layer: usize,
    top_k: usize,
    cosine: bool,
) -> Result<AttributionGraph> {
    let edges = self
        .score_features_by_decoder_projection(direction, target_layer, top_k, cosine)?
        .into_iter()
        .map(|(feature, score)| AttributionEdge { feature, score })
        .collect();
    Ok(AttributionGraph {
        target_layer,
        edges,
    })
}
/// Batched variant of [`Self::build_attribution_graph`]: one graph per
/// input direction, sharing a single pass over the decoder shards.
///
/// # Errors
/// Same conditions as [`Self::score_features_by_decoder_projection_batch`].
pub fn build_attribution_graph_batch(
    &mut self,
    directions: &[Tensor],
    target_layer: usize,
    top_k: usize,
    cosine: bool,
) -> Result<Vec<AttributionGraph>> {
    let scored_batch =
        self.score_features_by_decoder_projection_batch(directions, target_layer, top_k, cosine)?;
    let mut graphs = Vec::with_capacity(scored_batch.len());
    for scored in scored_batch {
        let edges = scored
            .into_iter()
            .map(|(feature, score)| AttributionEdge { feature, score })
            .collect();
        graphs.push(AttributionGraph {
            target_layer,
            edges,
        });
    }
    Ok(graphs)
}
}
fn tensor_from_view(view: &safetensors::tensor::TensorView<'_>, device: &Device) -> Result<Tensor> {
let shape: Vec<usize> = view.shape().to_vec();
#[allow(clippy::wildcard_enum_match_arm)]
let dtype = match view.dtype() {
safetensors::Dtype::BF16 => DType::BF16,
safetensors::Dtype::F16 => DType::F16,
safetensors::Dtype::F32 => DType::F32,
other => {
return Err(MIError::Config(format!(
"unsupported CLT tensor dtype: {other:?}"
)));
}
};
let tensor = Tensor::from_raw_buffer(view.data(), dtype, &shape, device)?;
Ok(tensor)
}
/// Minimal line-oriented YAML lookup: returns the scalar value of the first
/// top-level-looking `key: value` line, with surrounding quotes stripped.
///
/// Generalized over the original: comment lines are skipped, whitespace is
/// tolerated before the colon (`key : value` is valid YAML), and single
/// quotes are stripped in addition to double quotes.
///
/// Returns `None` when the key is absent. Not a full YAML parser — nesting,
/// flow style, and multi-line values are deliberately out of scope.
fn parse_yaml_value(yaml_text: &str, key: &str) -> Option<String> {
    for line in yaml_text.lines() {
        let line = line.trim();
        // Ignore comment lines so `# key: value` is not picked up.
        if line.starts_with('#') {
            continue;
        }
        if let Some(rest) = line.strip_prefix(key) {
            // Require a colon (optionally preceded by whitespace) so a key
            // that is merely a prefix of another key does not match.
            if let Some(rest) = rest.trim_start().strip_prefix(':') {
                let value = rest.trim().trim_matches('"').trim_matches('\'');
                return Some(value.to_owned());
            }
        }
    }
    None
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
use super::*;
#[test]
fn clt_feature_id_display() {
let fid = CltFeatureId {
layer: 5,
index: 42,
};
assert_eq!(fid.to_string(), "L5:42");
}
#[test]
fn clt_feature_id_ordering() {
let a = CltFeatureId {
layer: 0,
index: 10,
};
let b = CltFeatureId {
layer: 0,
index: 20,
};
let c = CltFeatureId { layer: 1, index: 0 };
assert!(a < b);
assert!(b < c);
}
#[test]
fn sparse_activations_basics() {
let features = vec![
(CltFeatureId { layer: 0, index: 5 }, 3.0),
(CltFeatureId { layer: 0, index: 2 }, 2.0),
(CltFeatureId { layer: 0, index: 8 }, 1.0),
];
let sparse = SparseActivations { features };
assert_eq!(sparse.len(), 3);
assert!(!sparse.is_empty());
}
#[test]
fn sparse_activations_truncate() {
let features = vec![
(CltFeatureId { layer: 0, index: 5 }, 3.0),
(CltFeatureId { layer: 0, index: 2 }, 2.0),
(CltFeatureId { layer: 0, index: 8 }, 1.0),
];
let mut sparse = SparseActivations { features };
sparse.truncate(2);
assert_eq!(sparse.len(), 2);
assert_eq!(sparse.features[0].0.index, 5);
assert_eq!(sparse.features[1].0.index, 2);
}
#[test]
fn parse_yaml_value_basic() {
let yaml = "model_name: \"google/gemma-2-2b\"\nmodel_kind: cross_layer_transcoder\n";
assert_eq!(
parse_yaml_value(yaml, "model_name"),
Some("google/gemma-2-2b".to_owned())
);
assert_eq!(
parse_yaml_value(yaml, "model_kind"),
Some("cross_layer_transcoder".to_owned())
);
assert_eq!(parse_yaml_value(yaml, "missing_key"), None);
}
#[test]
fn encode_synthetic() {
let device = Device::Cpu;
let d_model = 8;
let n_features = 4;
#[rustfmt::skip]
let w_enc_data: Vec<f32> = vec![
1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ];
let w_enc = Tensor::from_vec(w_enc_data, (n_features, d_model), &device).unwrap();
let b_enc_data: Vec<f32> = vec![0.0, -0.5, 0.0, -2.0]; let b_enc = Tensor::from_vec(b_enc_data, (n_features,), &device).unwrap();
let residual_data: Vec<f32> = vec![1.5, 0.3, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0];
let residual = Tensor::from_vec(residual_data, (d_model,), &device).unwrap();
let clt = CrossLayerTranscoder {
repo_id: "test".to_owned(),
fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
encoder_paths: vec![None],
decoder_paths: vec![None],
config: CltConfig {
n_layers: 1,
d_model,
n_features_per_layer: n_features,
n_features_total: n_features,
model_name: "test".to_owned(),
},
loaded_encoder: Some(LoadedEncoder {
layer: 0,
w_enc,
b_enc,
}),
steering_cache: HashMap::new(),
};
let sparse = clt.encode(&residual, 0).unwrap();
assert_eq!(sparse.len(), 1, "only feature 0 should be active");
assert_eq!(sparse.features[0].0.index, 0);
assert!((sparse.features[0].1 - 1.5).abs() < 1e-5);
}
#[test]
fn encode_wrong_layer_errors() {
let device = Device::Cpu;
let w_enc = Tensor::zeros((4, 8), DType::F32, &device).unwrap();
let b_enc = Tensor::zeros((4,), DType::F32, &device).unwrap();
let residual = Tensor::zeros((8,), DType::F32, &device).unwrap();
let clt = CrossLayerTranscoder {
repo_id: "test".to_owned(),
fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
encoder_paths: vec![None; 2],
decoder_paths: vec![None; 2],
config: CltConfig {
n_layers: 2,
d_model: 8,
n_features_per_layer: 4,
n_features_total: 8,
model_name: "test".to_owned(),
},
loaded_encoder: Some(LoadedEncoder {
layer: 0,
w_enc,
b_enc,
}),
steering_cache: HashMap::new(),
};
let result = clt.encode(&residual, 1);
assert!(result.is_err());
}
#[test]
fn inject_position() {
let device = Device::Cpu;
let d_model = 4;
let residual = Tensor::ones((1, 3, d_model), DType::F32, &device).unwrap();
let fid = CltFeatureId { layer: 0, index: 0 };
let target_layer = 1;
let steering_vec =
Tensor::from_vec(vec![10.0_f32, 20.0, 30.0, 40.0], (d_model,), &device).unwrap();
let mut steering_cache = HashMap::new();
steering_cache.insert((fid, target_layer), steering_vec);
let clt = CrossLayerTranscoder {
repo_id: "test".to_owned(),
fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
encoder_paths: vec![None; 2],
decoder_paths: vec![None; 2],
config: CltConfig {
n_layers: 2,
d_model,
n_features_per_layer: 1,
n_features_total: 2,
model_name: "test".to_owned(),
},
loaded_encoder: None,
steering_cache,
};
let result = clt
.inject(&residual, &[(fid, target_layer)], 1, 1.0)
.unwrap();
let pos0: Vec<f32> = result.i((0, 0)).unwrap().to_vec1().unwrap();
assert_eq!(pos0, vec![1.0, 1.0, 1.0, 1.0]);
let pos1: Vec<f32> = result.i((0, 1)).unwrap().to_vec1().unwrap();
assert_eq!(pos1, vec![11.0, 21.0, 31.0, 41.0]);
let pos2: Vec<f32> = result.i((0, 2)).unwrap().to_vec1().unwrap();
assert_eq!(pos2, vec![1.0, 1.0, 1.0, 1.0]);
}
#[test]
fn prepare_hook_injection_creates_correct_hooks() {
    use crate::hooks::HookPoint;
    // A steering vector cached for (feature, layer 5) should produce an
    // intervention at ResidPost(5) and at no other residual hook point.
    let dev = Device::Cpu;
    let d_model = 4;
    let feature = CltFeatureId { layer: 0, index: 0 };
    let target_layer = 5;
    let steering =
        Tensor::from_vec(vec![1.0_f32, 2.0, 3.0, 4.0], (d_model,), &dev).unwrap();
    let mut cache = HashMap::new();
    cache.insert((feature, target_layer), steering);
    let transcoder = CrossLayerTranscoder {
        repo_id: "test".to_owned(),
        fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
        encoder_paths: vec![None; 10],
        decoder_paths: vec![None; 10],
        config: CltConfig {
            n_layers: 10,
            d_model,
            n_features_per_layer: 1,
            n_features_total: 10,
            model_name: "test".to_owned(),
        },
        loaded_encoder: None,
        steering_cache: cache,
    };
    let hooks = transcoder
        .prepare_hook_injection(&[(feature, target_layer)], 2, 5, 1.0, &dev)
        .unwrap();
    assert!(hooks.has_intervention_at(&HookPoint::ResidPost(target_layer)));
    // Spot-check a couple of unrelated layers for absence.
    for other in [0, 4] {
        assert!(!hooks.has_intervention_at(&HookPoint::ResidPost(other)));
    }
}
#[test]
fn attribution_edge_basics() {
    // An edge is just a (feature id, score) pair; verify the fields
    // round-trip through construction.
    let feature = CltFeatureId { layer: 3, index: 42 };
    let edge = AttributionEdge {
        feature,
        score: 0.75,
    };
    assert_eq!(edge.feature, CltFeatureId { layer: 3, index: 42 });
    assert!((edge.score - 0.75).abs() < f32::EPSILON);
}
#[test]
fn attribution_graph_empty() {
    // A graph with no edges must report empty through every accessor
    // while still remembering its target layer.
    let graph = AttributionGraph {
        target_layer: 5,
        edges: vec![],
    };
    assert_eq!(graph.target_layer(), 5);
    assert!(graph.is_empty());
    assert_eq!(graph.len(), 0);
    assert!(graph.features().is_empty());
    assert!(graph.into_edges().is_empty());
}
#[test]
fn attribution_graph_top_k() {
    // Edges are stored pre-sorted by score; top_k must preserve that
    // order and clamp k to the edge count.
    let specs = [
        ((0usize, 0usize), 5.0_f32),
        ((0, 1), 3.0),
        ((1, 0), 1.0),
        ((1, 1), -1.0),
        ((2, 0), -4.0),
    ];
    let edges = specs
        .iter()
        .map(|&((layer, index), score)| AttributionEdge {
            feature: CltFeatureId { layer, index },
            score,
        })
        .collect();
    let graph = AttributionGraph {
        target_layer: 3,
        edges,
    };
    assert_eq!(graph.len(), 5);
    let top3 = graph.top_k(3);
    assert_eq!(top3.target_layer(), 3);
    let kept: Vec<f32> = top3.edges().iter().map(|e| e.score).collect();
    assert_eq!(kept.len(), 3);
    for (got, want) in kept.iter().zip([5.0_f32, 3.0, 1.0]) {
        assert!((got - want).abs() < f32::EPSILON);
    }
    // Asking for more edges than exist returns everything.
    assert_eq!(graph.top_k(10).len(), 5);
}
#[test]
fn attribution_graph_threshold() {
    // threshold() prunes by absolute score, so -4.0 survives a 2.0 cutoff
    // while the +/-1.0 edges are dropped.
    let specs = [
        ((0usize, 0usize), 5.0_f32),
        ((0, 1), 3.0),
        ((1, 0), 1.0),
        ((1, 1), -1.0),
        ((2, 0), -4.0),
    ];
    let edges = specs
        .iter()
        .map(|&((layer, index), score)| AttributionEdge {
            feature: CltFeatureId { layer, index },
            score,
        })
        .collect();
    let graph = AttributionGraph {
        target_layer: 3,
        edges,
    };
    let pruned = graph.threshold(2.0);
    let survivors: Vec<f32> = pruned.edges().iter().map(|e| e.score).collect();
    assert_eq!(survivors.len(), 3);
    for (got, want) in survivors.iter().zip([5.0_f32, 3.0, -4.0]) {
        assert!((got - want).abs() < f32::EPSILON);
    }
}
#[test]
fn attribution_graph_features() {
    // features() drops the scores but keeps edge order intact.
    let first = CltFeatureId { layer: 2, index: 7 };
    let second = CltFeatureId { layer: 0, index: 3 };
    let graph = AttributionGraph {
        target_layer: 5,
        edges: vec![
            AttributionEdge {
                feature: first,
                score: 1.0,
            },
            AttributionEdge {
                feature: second,
                score: 0.5,
            },
        ],
    };
    assert_eq!(graph.features(), vec![first, second]);
}
/// Writes a single-tensor safetensors file `W_dec_{layer}.safetensors`
/// into `dir` with shape `(n_features, n_target_layers, d_model)` filled
/// from `values`, and returns the file path.
fn create_synthetic_decoder(
    dir: &std::path::Path,
    layer: usize,
    n_features: usize,
    n_target_layers: usize,
    d_model: usize,
    values: &[f32],
) -> PathBuf {
    // Every (feature, target-layer, model-dim) cell needs exactly one value.
    assert_eq!(values.len(), n_features * n_target_layers * d_model);
    // safetensors stores raw little-endian f32 bytes.
    let mut bytes = Vec::with_capacity(values.len() * 4);
    for v in values {
        bytes.extend_from_slice(&v.to_le_bytes());
    }
    let view = safetensors::tensor::TensorView::new(
        safetensors::Dtype::F32,
        vec![n_features, n_target_layers, d_model],
        &bytes,
    )
    .unwrap();
    let tensors = HashMap::from([(format!("W_dec_{layer}"), view)]);
    let path = dir.join(format!("W_dec_{layer}.safetensors"));
    std::fs::write(&path, safetensors::serialize(&tensors, &None).unwrap()).unwrap();
    path
}
#[test]
fn score_decoder_projection_synthetic() {
    // Layer-0 decoder: each feature writes a distinct basis axis into
    // target layer 1 (the second 4-value slice of each row) with gain 1.
    // Layer-1 decoder: feature 0 writes e0 with gain 2, feature 3 writes
    // e1 with gain 3. Scores against a probe direction should therefore
    // rank the layer-1 features first.
    let tmp = tempfile::tempdir().unwrap();
    let d_model = 4;
    let n_features = 4;
    #[rustfmt::skip]
    let dec0: Vec<f32> = vec![
        0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0,
    ];
    let path0 = create_synthetic_decoder(tmp.path(), 0, n_features, 2, d_model, &dec0);
    #[rustfmt::skip]
    let dec1: Vec<f32> = vec![
        2.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0,
        0.0, 3.0, 0.0, 0.0,
    ];
    let path1 = create_synthetic_decoder(tmp.path(), 1, n_features, 1, d_model, &dec1);
    let mut clt = CrossLayerTranscoder {
        repo_id: "test".to_owned(),
        fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
        encoder_paths: vec![None; 2],
        decoder_paths: vec![Some(path0), Some(path1)],
        config: CltConfig {
            n_layers: 2,
            d_model,
            n_features_per_layer: n_features,
            n_features_total: n_features * 2,
            model_name: "test".to_owned(),
        },
        loaded_encoder: None,
        steering_cache: HashMap::new(),
    };
    // Probe along e0: layer-1 feature 0 (gain 2) outranks layer-0
    // feature 0 (gain 1).
    let e0 = Tensor::from_vec(vec![1.0_f32, 0.0, 0.0, 0.0], (d_model,), &Device::Cpu).unwrap();
    let scores = clt
        .score_features_by_decoder_projection(&e0, 1, 10, false)
        .unwrap();
    assert!(scores.len() >= 2, "expected at least 2 non-zero scores");
    assert_eq!(scores[0].0, CltFeatureId { layer: 1, index: 0 });
    assert!((scores[0].1 - 2.0).abs() < 1e-5);
    assert_eq!(scores[1].0, CltFeatureId { layer: 0, index: 0 });
    assert!((scores[1].1 - 1.0).abs() < 1e-5);
    // Probe along e1: layer-1 feature 3 (gain 3) then layer-0 feature 1.
    let e1 = Tensor::from_vec(vec![0.0_f32, 1.0, 0.0, 0.0], (d_model,), &Device::Cpu).unwrap();
    let scores2 = clt
        .score_features_by_decoder_projection(&e1, 1, 10, false)
        .unwrap();
    assert_eq!(scores2[0].0, CltFeatureId { layer: 1, index: 3 });
    assert!((scores2[0].1 - 3.0).abs() < 1e-5);
    assert_eq!(scores2[1].0, CltFeatureId { layer: 0, index: 1 });
    assert!((scores2[1].1 - 1.0).abs() < 1e-5);
}
#[test]
fn score_decoder_projection_cosine_synthetic() {
    // Feature 0 points exactly along e0 with norm 3; feature 1 sits at
    // 45 degrees in the e0/e1 plane with norm sqrt(2). Dot-product scores
    // keep the norms; cosine scores normalize them away.
    let tmp = tempfile::tempdir().unwrap();
    let d_model = 4;
    let n_features = 2;
    #[rustfmt::skip]
    let dec0: Vec<f32> = vec![
        3.0, 0.0, 0.0, 0.0,
        1.0, 1.0, 0.0, 0.0,
    ];
    let path0 = create_synthetic_decoder(tmp.path(), 0, n_features, 1, d_model, &dec0);
    let mut clt = CrossLayerTranscoder {
        repo_id: "test".to_owned(),
        fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
        encoder_paths: vec![None],
        decoder_paths: vec![Some(path0)],
        config: CltConfig {
            n_layers: 1,
            d_model,
            n_features_per_layer: n_features,
            n_features_total: n_features,
            model_name: "test".to_owned(),
        },
        loaded_encoder: None,
        steering_cache: HashMap::new(),
    };
    let direction =
        Tensor::from_vec(vec![1.0_f32, 0.0, 0.0, 0.0], (d_model,), &Device::Cpu).unwrap();
    // Raw projections: 3.0 for feature 0, 1.0 for feature 1.
    let dot_scores = clt
        .score_features_by_decoder_projection(&direction, 0, 10, false)
        .unwrap();
    assert!((dot_scores[0].1 - 3.0).abs() < 1e-5);
    assert!((dot_scores[1].1 - 1.0).abs() < 1e-5);
    // Cosine similarities: 1.0 and 1/sqrt(2) respectively.
    let cos_scores = clt
        .score_features_by_decoder_projection(&direction, 0, 10, true)
        .unwrap();
    assert!(
        (cos_scores[0].1 - 1.0).abs() < 1e-4,
        "expected ~1.0, got {}",
        cos_scores[0].1
    );
    let expected_cos = 1.0 / 2.0_f32.sqrt();
    assert!(
        (cos_scores[1].1 - expected_cos).abs() < 1e-4,
        "expected ~{expected_cos}, got {}",
        cos_scores[1].1
    );
}
#[test]
fn score_decoder_projection_batch_synthetic() {
    // Two orthogonal probe directions, each aligned with one decoder row;
    // the batch API must pick the matching feature per direction.
    let tmp = tempfile::tempdir().unwrap();
    let d_model = 4;
    let n_features = 2;
    #[rustfmt::skip]
    let dec0: Vec<f32> = vec![
        1.0, 0.0, 0.0, 0.0,
        0.0, 1.0, 0.0, 0.0,
    ];
    let path0 = create_synthetic_decoder(tmp.path(), 0, n_features, 1, d_model, &dec0);
    let mut clt = CrossLayerTranscoder {
        repo_id: "test".to_owned(),
        fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
        encoder_paths: vec![None],
        decoder_paths: vec![Some(path0)],
        config: CltConfig {
            n_layers: 1,
            d_model,
            n_features_per_layer: n_features,
            n_features_total: n_features,
            model_name: "test".to_owned(),
        },
        loaded_encoder: None,
        steering_cache: HashMap::new(),
    };
    let make_dir =
        |data: Vec<f32>| Tensor::from_vec(data, (d_model,), &Device::Cpu).unwrap();
    let dirs = [
        make_dir(vec![1.0_f32, 0.0, 0.0, 0.0]),
        make_dir(vec![0.0_f32, 1.0, 0.0, 0.0]),
    ];
    let batch = clt
        .score_features_by_decoder_projection_batch(&dirs, 0, 10, false)
        .unwrap();
    assert_eq!(batch.len(), 2);
    // Direction i selects feature i with a unit projection.
    for (i, results) in batch.iter().enumerate() {
        assert_eq!(results[0].0, CltFeatureId { layer: 0, index: i });
        assert!((results[0].1 - 1.0).abs() < 1e-5);
    }
}
#[test]
fn extract_decoder_vectors_synthetic() {
    // Decoder shape is (3 features, 2 target layers, 4 dims), filled with
    // 1..=24 so every 4-value slice is uniquely identifiable.
    let tmp = tempfile::tempdir().unwrap();
    let d_model = 4;
    let n_features = 3;
    let dec0: Vec<f32> = (1..=24).map(|v| v as f32).collect();
    let path0 = create_synthetic_decoder(tmp.path(), 0, n_features, 2, d_model, &dec0);
    let mut clt = CrossLayerTranscoder {
        repo_id: "test".to_owned(),
        fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
        encoder_paths: vec![None; 2],
        decoder_paths: vec![Some(path0), None],
        config: CltConfig {
            n_layers: 2,
            d_model,
            n_features_per_layer: n_features,
            n_features_total: n_features * 2,
            model_name: "test".to_owned(),
        },
        loaded_encoder: None,
        steering_cache: HashMap::new(),
    };
    let features = vec![
        CltFeatureId { layer: 0, index: 0 },
        CltFeatureId { layer: 0, index: 2 },
    ];
    let vectors = clt.extract_decoder_vectors(&features, 1).unwrap();
    assert_eq!(vectors.len(), 2);
    // Target layer 1 is the second 4-value slice within each feature row.
    for (index, expected) in [
        (0usize, vec![5.0_f32, 6.0, 7.0, 8.0]),
        (2, vec![21.0, 22.0, 23.0, 24.0]),
    ] {
        let v: Vec<f32> = vectors[&CltFeatureId { layer: 0, index }]
            .to_vec1()
            .unwrap();
        assert_eq!(v, expected);
    }
}
#[test]
fn build_attribution_graph_synthetic() {
    // Feature 1 writes 2.0 along e1, so probing with e1 must put it at
    // the top of the attribution graph for layer 0.
    let tmp = tempfile::tempdir().unwrap();
    let d_model = 4;
    let n_features = 2;
    #[rustfmt::skip]
    let dec0: Vec<f32> = vec![
        1.0, 0.0, 0.0, 0.0,
        0.0, 2.0, 0.0, 0.0,
    ];
    let path0 = create_synthetic_decoder(tmp.path(), 0, n_features, 1, d_model, &dec0);
    let mut clt = CrossLayerTranscoder {
        repo_id: "test".to_owned(),
        fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
        encoder_paths: vec![None],
        decoder_paths: vec![Some(path0)],
        config: CltConfig {
            n_layers: 1,
            d_model,
            n_features_per_layer: n_features,
            n_features_total: n_features,
            model_name: "test".to_owned(),
        },
        loaded_encoder: None,
        steering_cache: HashMap::new(),
    };
    let direction =
        Tensor::from_vec(vec![0.0_f32, 1.0, 0.0, 0.0], (d_model,), &Device::Cpu).unwrap();
    let graph = clt
        .build_attribution_graph(&direction, 0, 10, false)
        .unwrap();
    assert_eq!(graph.target_layer(), 0);
    assert!(!graph.is_empty());
    let top = &graph.edges()[0];
    assert_eq!(top.feature, CltFeatureId { layer: 0, index: 1 });
    assert!((top.score - 2.0).abs() < 1e-5);
    // Thresholding at 1.0 keeps only the 2.0-scoring feature.
    let pruned = graph.threshold(1.0);
    assert_eq!(pruned.len(), 1);
    assert_eq!(pruned.features()[0], CltFeatureId { layer: 0, index: 1 });
}
}