mod npz;
use std::path::Path;
use candle_core::{DType, Device, Tensor};
use safetensors::tensor::SafeTensors;
use tracing::info;
use crate::error::{MIError, Result};
use crate::hooks::{HookPoint, HookSpec, Intervention};
use crate::sparse::{FeatureId, SparseActivations};
/// Identifier of a single learned SAE feature (dictionary element).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SaeFeatureId {
    /// Zero-based feature index in `[0, d_sae)`.
    pub index: usize,
}

impl std::fmt::Display for SaeFeatureId {
    /// Renders as `SAE:<index>`, e.g. `SAE:42`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "SAE:{}", self.index)
    }
}

// Marker impl so SAE feature ids can key `SparseActivations`.
impl FeatureId for SaeFeatureId {}
/// Activation architecture of a loaded SAE.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SaeArchitecture {
    /// Plain ReLU on the pre-activations.
    ReLU,
    /// JumpReLU: values must strictly exceed a per-feature learned threshold.
    JumpReLU,
    /// TopK: only the `k` largest pre-activations per input survive.
    TopK {
        /// Number of features kept active per input row.
        k: usize,
    },
}
/// Input-normalization scheme recorded in the SAE's `cfg.json`.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NormalizeActivations {
    /// No normalization.
    None,
    /// The `"expected_average_only_in"` scheme named in `cfg.json`.
    ExpectedAverageOnlyIn,
}

/// Where the TopK masking step should run.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TopKStrategy {
    /// Follow the device the activations already live on.
    Auto,
    /// Always run the exact host-side implementation.
    Cpu,
    /// Always run the device-side (full sort) implementation.
    Gpu,
}
/// Validated SAE configuration, parsed from `cfg.json` or inferred from NPZ.
#[derive(Debug, Clone)]
pub struct SaeConfig {
    /// Model activation dimension the SAE reads and writes.
    pub d_in: usize,
    /// Number of learned features (dictionary size).
    pub d_sae: usize,
    /// Activation architecture variant.
    pub architecture: SaeArchitecture,
    /// Raw hook name string (e.g. `blocks.5.hook_resid_post`).
    pub hook_name: String,
    /// Parsed hook location corresponding to `hook_name`.
    pub hook_point: HookPoint,
    /// Whether to subtract `b_dec` from the input before encoding.
    pub apply_b_dec_to_input: bool,
    /// Input normalization recorded in the config (not applied here).
    pub normalize_activations: NormalizeActivations,
}
/// Direct deserialization target for `cfg.json`; normalized into
/// [`SaeConfig`] by `parse_sae_config`.
#[derive(serde::Deserialize)]
#[allow(clippy::missing_docs_in_private_items)]
struct RawSaeConfig {
    d_in: usize,
    d_sae: usize,
    // Newer configs name the architecture explicitly...
    #[serde(default)]
    architecture: Option<String>,
    // ...older ones describe it via the activation function name + kwargs.
    #[serde(default)]
    activation_fn_str: Option<String>,
    #[serde(default)]
    activation_fn_kwargs: Option<serde_json::Value>,
    // `hook_name` is preferred; `hook_point` is a legacy alias.
    #[serde(default)]
    hook_name: Option<String>,
    #[serde(default)]
    hook_point: Option<String>,
    #[serde(default)]
    apply_b_dec_to_input: bool,
    #[serde(default)]
    normalize_activations: Option<String>,
}
/// Convert a raw `cfg.json` deserialization into a validated [`SaeConfig`].
///
/// Resolution rules:
/// * architecture comes from `architecture` / `activation_fn_str`
///   (see [`resolve_architecture`]);
/// * the hook name prefers `hook_name` over the legacy `hook_point` key and
///   falls back to `"unknown"` when both are absent;
/// * `normalize_activations` recognizes `"expected_average_only_in"`; any
///   other value (or none) maps to [`NormalizeActivations::None`].
///
/// # Errors
/// Returns [`MIError::Config`] when the architecture cannot be resolved
/// (unknown name, or TopK without a `k` value).
fn parse_sae_config(raw: RawSaeConfig) -> Result<SaeConfig> {
    let architecture = resolve_architecture(
        raw.architecture.as_deref(),
        raw.activation_fn_str.as_deref(),
        raw.activation_fn_kwargs.as_ref(),
    )?;
    let hook_name = raw
        .hook_name
        .or(raw.hook_point)
        .unwrap_or_else(|| "unknown".to_owned());
    // `HookPoint: FromStr<Err = Infallible>`; the empty match on the error
    // proves impossibility at the type level instead of leaving a runtime
    // `unreachable!()` panic path in the code.
    let hook_point: HookPoint = match hook_name.parse() {
        Ok(point) => point,
        Err(never) => match never {},
    };
    let normalize_activations = match raw.normalize_activations.as_deref() {
        Some("expected_average_only_in") => NormalizeActivations::ExpectedAverageOnlyIn,
        _ => NormalizeActivations::None,
    };
    Ok(SaeConfig {
        d_in: raw.d_in,
        d_sae: raw.d_sae,
        architecture,
        hook_name,
        hook_point,
        apply_b_dec_to_input: raw.apply_b_dec_to_input,
        normalize_activations,
    })
}
/// Determine the SAE activation architecture from `cfg.json` fields.
///
/// The explicit `architecture` field wins; `"standard"` (or no field at all)
/// defers to `activation_fn_str`, which itself defaults to plain ReLU.
///
/// # Errors
/// Returns [`MIError::Config`] for unknown names or TopK without a `k`.
fn resolve_architecture(
    architecture: Option<&str>,
    activation_fn_str: Option<&str>,
    activation_fn_kwargs: Option<&serde_json::Value>,
) -> Result<SaeArchitecture> {
    // Pass 1: the explicit architecture field takes precedence.
    match architecture {
        Some("jumprelu") => return Ok(SaeArchitecture::JumpReLU),
        Some("topk") => {
            return Ok(SaeArchitecture::TopK {
                k: extract_topk_k(activation_fn_kwargs)?,
            });
        }
        Some("standard") | None => {}
        Some(other) => {
            return Err(MIError::Config(format!(
                "unsupported SAE architecture: {other:?}"
            )));
        }
    }
    // Pass 2: fall back to the activation-function name.
    match activation_fn_str {
        Some("relu") | None => Ok(SaeArchitecture::ReLU),
        Some("jumprelu") => Ok(SaeArchitecture::JumpReLU),
        Some("topk") => Ok(SaeArchitecture::TopK {
            k: extract_topk_k(activation_fn_kwargs)?,
        }),
        Some(other) => Err(MIError::Config(format!(
            "unsupported SAE activation function: {other:?}"
        ))),
    }
}
/// Read the integer `k` out of `activation_fn_kwargs` for a TopK SAE.
///
/// # Errors
/// Returns [`MIError::Config`] when `k` is absent, not an unsigned integer,
/// or does not fit in `usize`.
fn extract_topk_k(kwargs: Option<&serde_json::Value>) -> Result<usize> {
    let Some(raw_k) = kwargs
        .and_then(|v| v.get("k"))
        .and_then(serde_json::Value::as_u64)
    else {
        return Err(MIError::Config(
            "TopK SAE requires activation_fn_kwargs.k in cfg.json".into(),
        ));
    };
    // Guard the u64 -> usize narrowing (relevant on 32-bit targets).
    usize::try_from(raw_k)
        .map_err(|_| MIError::Config(format!("TopK k value {raw_k} too large for usize")))
}
/// A sparse autoencoder with its weights resident on a candle device.
///
/// Tensor layout (all f32 after loading):
/// `w_enc`: `[d_in, d_sae]`, `w_dec`: `[d_sae, d_in]`,
/// `b_enc`: `[d_sae]`, `b_dec`: `[d_in]`,
/// `threshold`: `[d_sae]`, present only for JumpReLU SAEs.
pub struct SparseAutoencoder {
    /// Parsed configuration (dimensions, architecture, hook location).
    config: SaeConfig,
    /// Encoder weight matrix.
    w_enc: Tensor,
    /// Decoder weight matrix.
    w_dec: Tensor,
    /// Encoder bias.
    b_enc: Tensor,
    /// Decoder bias.
    b_dec: Tensor,
    /// Per-feature JumpReLU thresholds, when the checkpoint ships them.
    threshold: Option<Tensor>,
}
impl SparseAutoencoder {
/// Load an SAE from a local directory containing `cfg.json` plus either
/// `sae_weights.safetensors` or `model.safetensors`.
///
/// All weights are converted to f32 on `device` and validated against the
/// dimensions declared in `cfg.json`.
///
/// # Errors
/// Returns [`MIError::Config`] when files are missing, the config or weights
/// fail to parse, shapes disagree with the config, or a JumpReLU SAE lacks
/// its `threshold` tensor.
pub fn from_local(dir: &Path, device: &Device) -> Result<Self> {
    let cfg_path = dir.join("cfg.json");
    if !cfg_path.exists() {
        return Err(MIError::Config(format!(
            "cfg.json not found in {}",
            dir.display()
        )));
    }
    let cfg_text = std::fs::read_to_string(&cfg_path)?;
    let raw: RawSaeConfig = serde_json::from_str(&cfg_text)
        .map_err(|e| MIError::Config(format!("failed to parse cfg.json: {e}")))?;
    let config = parse_sae_config(raw)?;
    info!(
        "SAE config: d_in={}, d_sae={}, arch={:?}, hook={}",
        config.d_in, config.d_sae, config.architecture, config.hook_name
    );
    // Repos ship the weights under either name; prefer `sae_weights`.
    let weights_path = if dir.join("sae_weights.safetensors").exists() {
        dir.join("sae_weights.safetensors")
    } else if dir.join("model.safetensors").exists() {
        dir.join("model.safetensors")
    } else {
        return Err(MIError::Config(format!(
            "no safetensors file found in {}",
            dir.display()
        )));
    };
    let data = std::fs::read(&weights_path)?;
    let st = SafeTensors::deserialize(&data)
        .map_err(|e| MIError::Config(format!("failed to deserialize SAE weights: {e}")))?;
    let w_enc = load_tensor(&st, "W_enc", device)?;
    let w_dec = load_tensor(&st, "W_dec", device)?;
    let b_enc = load_tensor(&st, "b_enc", device)?;
    let b_dec = load_tensor(&st, "b_dec", device)?;
    // `threshold` is optional: only JumpReLU checkpoints include it.
    let threshold = st
        .tensor("threshold")
        .ok()
        .map(|v| tensor_from_view(&v, device))
        .transpose()?;
    // Normalize every tensor to f32 regardless of the on-disk dtype.
    let w_enc = w_enc.to_dtype(DType::F32)?;
    let w_dec = w_dec.to_dtype(DType::F32)?;
    let b_enc = b_enc.to_dtype(DType::F32)?;
    let b_dec = b_dec.to_dtype(DType::F32)?;
    let threshold = threshold.map(|t| t.to_dtype(DType::F32)).transpose()?;
    validate_shape(&w_enc, &[config.d_in, config.d_sae], "W_enc")?;
    validate_shape(&w_dec, &[config.d_sae, config.d_in], "W_dec")?;
    validate_shape(&b_enc, &[config.d_sae], "b_enc")?;
    validate_shape(&b_dec, &[config.d_in], "b_dec")?;
    if let Some(ref t) = threshold {
        validate_shape(t, &[config.d_sae], "threshold")?;
    }
    // A JumpReLU SAE is unusable without its thresholds; fail loudly here
    // rather than at encode time.
    if config.architecture == SaeArchitecture::JumpReLU && threshold.is_none() {
        return Err(MIError::Config(
            "JumpReLU SAE requires 'threshold' tensor in weights file".into(),
        ));
    }
    info!(
        "SAE loaded: {} weights on {:?}",
        weights_path.display(),
        device
    );
    Ok(Self {
        config,
        w_enc,
        w_dec,
        b_enc,
        b_dec,
        threshold,
    })
}
/// Load an SAE from a NumPy `.npz` archive.
///
/// Dimensions are inferred from `W_enc`'s shape; the architecture is
/// inferred from the presence (`JumpReLU`) or absence (`ReLU`) of a
/// `threshold` array. The hook is fixed to
/// `blocks.{hook_layer}.hook_resid_post`.
///
/// # Errors
/// Returns [`MIError::Config`] when a required array is missing or shapes
/// are inconsistent.
pub fn from_npz(npz_path: &Path, hook_layer: usize, device: &Device) -> Result<Self> {
    info!("Loading SAE from NPZ: {}", npz_path.display());
    let tensors = npz::load_npz(npz_path, device)?;
    let w_enc = tensors
        .get("W_enc")
        .ok_or_else(|| MIError::Config("NPZ missing W_enc".into()))?
        .to_dtype(DType::F32)?;
    let w_dec = tensors
        .get("W_dec")
        .ok_or_else(|| MIError::Config("NPZ missing W_dec".into()))?
        .to_dtype(DType::F32)?;
    let b_enc = tensors
        .get("b_enc")
        .ok_or_else(|| MIError::Config("NPZ missing b_enc".into()))?
        .to_dtype(DType::F32)?;
    let b_dec = tensors
        .get("b_dec")
        .ok_or_else(|| MIError::Config("NPZ missing b_dec".into()))?
        .to_dtype(DType::F32)?;
    // Optional: present only for JumpReLU checkpoints.
    let threshold = tensors
        .get("threshold")
        .map(|t| t.to_dtype(DType::F32))
        .transpose()?;
    // Infer (d_in, d_sae) from the encoder matrix, which must be 2-D.
    let w_enc_dims = w_enc.dims();
    if w_enc_dims.len() != 2 {
        return Err(MIError::Config(format!(
            "W_enc expected 2 dims, got {}",
            w_enc_dims.len()
        )));
    }
    let d_in = *w_enc_dims
        .first()
        .ok_or_else(|| MIError::Config("W_enc has no dimensions".into()))?;
    let d_sae = *w_enc_dims
        .get(1)
        .ok_or_else(|| MIError::Config("W_enc has no second dimension".into()))?;
    validate_shape(&w_enc, &[d_in, d_sae], "W_enc")?;
    validate_shape(&w_dec, &[d_sae, d_in], "W_dec")?;
    validate_shape(&b_enc, &[d_sae], "b_enc")?;
    validate_shape(&b_dec, &[d_in], "b_dec")?;
    if let Some(ref t) = threshold {
        validate_shape(t, &[d_sae], "threshold")?;
    }
    let architecture = if threshold.is_some() {
        SaeArchitecture::JumpReLU
    } else {
        SaeArchitecture::ReLU
    };
    let hook_name = format!("blocks.{hook_layer}.hook_resid_post");
    let hook_point = hook_name
        .parse::<HookPoint>()
        .map_err(|e| MIError::Config(format!("failed to parse hook name: {e}")))?;
    let config = SaeConfig {
        d_in,
        d_sae,
        architecture,
        hook_name,
        hook_point,
        // NPZ archives carry no config, so these stay at their defaults.
        apply_b_dec_to_input: false,
        normalize_activations: NormalizeActivations::None,
    };
    info!(
        "SAE from NPZ: d_in={d_in}, d_sae={d_sae}, arch={:?}, hook={}",
        config.architecture, config.hook_name
    );
    Ok(Self {
        config,
        w_enc,
        w_dec,
        b_enc,
        b_dec,
        threshold,
    })
}
/// Download an `.npz` SAE from a Hugging Face repo, then load it via
/// [`Self::from_npz`].
///
/// Download progress is forwarded to `tracing` at info level.
///
/// # Errors
/// Returns [`MIError::Download`] on fetch failures, plus any error
/// [`Self::from_npz`] can produce.
pub fn from_pretrained_npz(
    repo_id: &str,
    npz_path: &str,
    hook_layer: usize,
    device: &Device,
) -> Result<Self> {
    let fetch_config = crate::download::fetch_config_builder()
        .on_progress(|event| {
            tracing::info!(
                filename = %event.filename,
                percent = event.percent,
                bytes_downloaded = event.bytes_downloaded,
                bytes_total = event.bytes_total,
                "SAE NPZ download progress",
            );
        })
        .build()
        .map_err(|e| MIError::Download(format!("failed to build fetch config: {e}")))?;
    info!("Downloading {npz_path} from {repo_id}");
    let local_path =
        hf_fetch_model::download_file_blocking(repo_id.to_owned(), npz_path, &fetch_config)
            .map_err(|e| MIError::Download(format!("failed to download NPZ: {e}")))?
            .into_inner();
    Self::from_npz(&local_path, hook_layer, device)
}
/// Download an SAE (`cfg.json` + safetensors) from a Hugging Face repo and
/// load it via [`Self::from_local`].
///
/// Tries `{sae_id}/sae_weights.safetensors` first, then falls back to
/// `{sae_id}/model.safetensors`. If the cached cfg and weights end up in
/// different directories, the cfg is copied next to the weights so
/// `from_local` can see both.
///
/// # Errors
/// Returns [`MIError::Download`] on fetch failures, plus any error
/// [`Self::from_local`] can produce.
pub fn from_pretrained(repo_id: &str, sae_id: &str, device: &Device) -> Result<Self> {
    let fetch_config = crate::download::fetch_config_builder()
        .on_progress(|event| {
            tracing::info!(
                filename = %event.filename,
                percent = event.percent,
                bytes_downloaded = event.bytes_downloaded,
                bytes_total = event.bytes_total,
                "SAE download progress",
            );
        })
        .build()
        .map_err(|e| MIError::Download(format!("failed to build fetch config: {e}")))?;
    let cfg_remote = format!("{sae_id}/cfg.json");
    info!("Downloading {cfg_remote} from {repo_id}");
    let cfg_path =
        hf_fetch_model::download_file_blocking(repo_id.to_owned(), &cfg_remote, &fetch_config)
            .map_err(|e| MIError::Download(format!("failed to download cfg.json: {e}")))?
            .into_inner();
    let weights_remote = format!("{sae_id}/sae_weights.safetensors");
    info!("Downloading {weights_remote} from {repo_id}");
    let weights_path = hf_fetch_model::download_file_blocking(
        repo_id.to_owned(),
        &weights_remote,
        &fetch_config,
    )
    .or_else(|_| {
        // Some repos ship the weights under the alternate file name.
        let alt_remote = format!("{sae_id}/model.safetensors");
        info!("Trying {alt_remote} from {repo_id}");
        hf_fetch_model::download_file_blocking(repo_id.to_owned(), &alt_remote, &fetch_config)
    })
    .map_err(|e| MIError::Download(format!("failed to download SAE weights: {e}")))?
    .into_inner();
    let dir = cfg_path.parent().ok_or_else(|| {
        MIError::Config("cannot determine SAE directory from cfg.json path".into())
    })?;
    if dir.join("sae_weights.safetensors").exists() || dir.join("model.safetensors").exists() {
        // cfg and weights share a directory: load directly.
        Self::from_local(dir, device)
    } else {
        // Otherwise co-locate cfg.json with the weights before loading.
        let weights_dir = weights_path.parent().ok_or_else(|| {
            MIError::Config("cannot determine SAE directory from weights path".into())
        })?;
        let target_cfg = weights_dir.join("cfg.json");
        if !target_cfg.exists() {
            std::fs::copy(&cfg_path, &target_cfg)?;
        }
        Self::from_local(weights_dir, device)
    }
}
/// Full SAE configuration.
#[must_use]
pub const fn config(&self) -> &SaeConfig {
    &self.config
}

/// Hook point at which this SAE reads (and can steer) activations.
#[must_use]
pub const fn hook_point(&self) -> &HookPoint {
    &self.config.hook_point
}

/// Number of learned features (dictionary size).
#[must_use]
pub const fn d_sae(&self) -> usize {
    self.config.d_sae
}

/// Model activation dimension this SAE operates on.
#[must_use]
pub const fn d_in(&self) -> usize {
    self.config.d_in
}
/// Encode `x` into SAE feature space with the default TopK placement
/// ([`TopKStrategy::Auto`]). See [`Self::encode_with_strategy`].
pub fn encode(&self, x: &Tensor) -> Result<Tensor> {
    self.encode_with_strategy(x, &TopKStrategy::Auto)
}
/// Encode `x` into SAE feature space using an explicit TopK strategy.
///
/// Accepts any tensor whose last dimension is `d_in`; leading dimensions
/// are treated as batch dimensions. Computes
/// `act((x [- b_dec]) @ W_enc + b_enc)` where `act` depends on the
/// configured architecture.
///
/// # Errors
/// Fails when the last dimension is not `d_in`, when a JumpReLU SAE has no
/// threshold tensor, or when an underlying tensor op fails.
pub fn encode_with_strategy(&self, x: &Tensor, strategy: &TopKStrategy) -> Result<Tensor> {
    let dims = x.dims();
    let last_dim = *dims
        .last()
        .ok_or_else(|| MIError::Config("cannot encode empty tensor".into()))?;
    if last_dim != self.config.d_in {
        return Err(MIError::Config(format!(
            "input last dim {last_dim} != SAE d_in {}",
            self.config.d_in
        )));
    }
    let x_f32 = x.to_dtype(DType::F32)?;
    // Some checkpoints subtract the decoder bias before encoding.
    let x_centered = if self.config.apply_b_dec_to_input {
        let b_dec = broadcast_bias(&self.b_dec, x_f32.dims())?;
        (&x_f32 - &b_dec)?
    } else {
        x_f32
    };
    let pre_acts = x_centered.broadcast_matmul(&self.w_enc)?;
    let b_enc = broadcast_bias(&self.b_enc, pre_acts.dims())?;
    let pre_acts = (&pre_acts + &b_enc)?;
    match &self.config.architecture {
        SaeArchitecture::ReLU => Ok(pre_acts.relu()?),
        SaeArchitecture::JumpReLU => {
            let threshold = self
                .threshold
                .as_ref()
                .ok_or_else(|| MIError::Config("JumpReLU requires threshold tensor".into()))?;
            // Pass values strictly above the per-feature threshold; zero
            // everything else.
            let threshold = broadcast_bias(threshold, pre_acts.dims())?;
            let mask = pre_acts.gt(&threshold)?;
            let mask_f32 = mask.to_dtype(DType::F32)?;
            Ok((&pre_acts * &mask_f32)?)
        }
        SaeArchitecture::TopK { k } => topk_activation(&pre_acts, *k, strategy),
    }
}
/// Encode a single `d_in`-vector and return only its positive activations,
/// strongest first.
///
/// # Errors
/// Propagates any failure from [`Self::encode`] or from reading the result
/// back to the host.
pub fn encode_sparse(&self, x: &Tensor) -> Result<SparseActivations<SaeFeatureId>> {
    // Run the dense encoder on a batch of one, then drop the batch dim.
    let dense = self.encode(&x.unsqueeze(0)?)?.squeeze(0)?;
    let acts: Vec<f32> = dense.to_vec1()?;
    // Collect only strictly positive activations.
    let mut features: Vec<(SaeFeatureId, f32)> = Vec::new();
    for (index, &value) in acts.iter().enumerate() {
        if value > 0.0 {
            features.push((SaeFeatureId { index }, value));
        }
    }
    // Order by descending activation; the stable sort keeps equal values in
    // index order.
    features.sort_by(|lhs, rhs| {
        rhs.1
            .partial_cmp(&lhs.1)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    Ok(SparseActivations { features })
}
/// Decode feature activations back to model space:
/// `features @ W_dec + b_dec`.
pub fn decode(&self, features: &Tensor) -> Result<Tensor> {
    let features_f32 = features.to_dtype(DType::F32)?;
    let decoded = features_f32.broadcast_matmul(&self.w_dec)?;
    let b_dec = broadcast_bias(&self.b_dec, decoded.dims())?;
    Ok((&decoded + &b_dec)?)
}
/// Round-trip `x` through the SAE: `decode(encode(x))`.
pub fn reconstruct(&self, x: &Tensor) -> Result<Tensor> {
    let encoded = self.encode(x)?;
    self.decode(&encoded)
}
/// Mean squared error between `x` and its SAE reconstruction, averaged over
/// all elements.
pub fn reconstruction_error(&self, x: &Tensor) -> Result<f64> {
    let x_f32 = x.to_dtype(DType::F32)?;
    let x_hat = self.reconstruct(&x_f32)?;
    let diff = (&x_f32 - &x_hat)?;
    let mse: f32 = diff.sqr()?.mean_all()?.to_scalar()?;
    Ok(f64::from(mse))
}
/// Return the `d_in`-dimensional decoder vector (row of `W_dec`) for one
/// feature.
///
/// # Errors
/// Returns [`MIError::Config`] when `feature_idx >= d_sae`.
pub fn decoder_vector(&self, feature_idx: usize) -> Result<Tensor> {
    let d_sae = self.config.d_sae;
    if feature_idx < d_sae {
        Ok(self.w_dec.get(feature_idx)?)
    } else {
        Err(MIError::Config(format!(
            "feature index {feature_idx} out of range (d_sae={d_sae})"
        )))
    }
}
pub fn prepare_hook_injection(
&self,
features: &[(usize, f32)],
position: usize,
seq_len: usize,
device: &Device,
) -> Result<HookSpec> {
let d_in = self.config.d_in;
let mut accumulated = Tensor::zeros(d_in, DType::F32, device)?;
for &(feature_idx, strength) in features {
let dec_vec = self.decoder_vector(feature_idx)?;
let dec_vec = dec_vec.to_device(device)?;
let scaled = (&dec_vec * f64::from(strength))?;
accumulated = (&accumulated + &scaled)?;
}
let injection = Tensor::zeros((1, seq_len, d_in), DType::F32, device)?;
let scaled_3d = accumulated.unsqueeze(0)?.unsqueeze(0)?;
let before = if position > 0 {
Some(injection.narrow(1, 0, position)?)
} else {
None
};
let after = if position + 1 < seq_len {
Some(injection.narrow(1, position + 1, seq_len - position - 1)?)
} else {
None
};
let mut parts: Vec<Tensor> = Vec::with_capacity(3);
if let Some(b) = before {
parts.push(b);
}
parts.push(scaled_3d);
if let Some(a) = after {
parts.push(a);
}
let injection = Tensor::cat(&parts, 1)?;
let mut hooks = HookSpec::new();
hooks.intervene(self.config.hook_point.clone(), Intervention::Add(injection));
Ok(hooks)
}
}
/// Dispatch TopK masking to the host or device implementation.
fn topk_activation(pre_acts: &Tensor, k: usize, strategy: &TopKStrategy) -> Result<Tensor> {
    match strategy {
        TopKStrategy::Cpu => topk_cpu(pre_acts, k),
        TopKStrategy::Gpu => topk_gpu(pre_acts, k),
        // Auto: follow wherever the activations already live.
        TopKStrategy::Auto => {
            if matches!(pre_acts.device(), Device::Cpu) {
                topk_cpu(pre_acts, k)
            } else {
                topk_gpu(pre_acts, k)
            }
        }
    }
}
/// Exact host-side TopK: keep the `min(k, d_sae)` largest values per row,
/// zero the rest.
///
/// Ties at the k-th value are broken by zeroing later (higher-index)
/// duplicates, so exactly `min(k, d_sae)` entries survive per row.
fn topk_cpu(pre_acts: &Tensor, k: usize) -> Result<Tensor> {
    let device = pre_acts.device().clone();
    let shape = pre_acts.dims().to_vec();
    let d_sae = *shape
        .last()
        .ok_or_else(|| MIError::Config("cannot apply TopK to empty tensor".into()))?;
    // Collapse all leading dims into one batch dimension of n rows.
    let n: usize = shape.iter().take(shape.len() - 1).product();
    let flat = pre_acts.reshape((n, d_sae))?.to_dtype(DType::F32)?;
    let flat_cpu = flat.to_device(&Device::Cpu)?;
    let mut result_data: Vec<f32> = Vec::with_capacity(n * d_sae);
    for row_idx in 0..n {
        let row = flat_cpu.get(row_idx)?;
        let mut row_vec: Vec<f32> = row.to_vec1()?;
        let k_clamped = k.min(d_sae);
        if k_clamped > 0 && k_clamped < d_sae {
            // Partial selection to find the index of the k-th largest value
            // without fully sorting the row.
            let mut indices: Vec<usize> = (0..d_sae).collect();
            #[allow(clippy::indexing_slicing)]
            indices.select_nth_unstable_by(k_clamped - 1, |&a, &b| {
                // Note the a/b swap: this orders indices by DESCENDING value.
                let va = row_vec.get(b).copied().unwrap_or(f32::NEG_INFINITY);
                let vb = row_vec.get(a).copied().unwrap_or(f32::NEG_INFINITY);
                va.partial_cmp(&vb).unwrap_or(std::cmp::Ordering::Equal)
            });
            let threshold_idx = indices.get(k_clamped - 1).copied().unwrap_or(0);
            let threshold = row_vec.get(threshold_idx).copied().unwrap_or(0.0);
            // Zero everything strictly below the k-th largest value.
            for v in &mut row_vec {
                if *v < threshold {
                    *v = 0.0;
                }
            }
            // Ties with the threshold value can leave more than k survivors;
            // zero the excess, scanning from the end of the row.
            let active: usize = row_vec.iter().filter(|&&v| v >= threshold).count();
            if active > k_clamped {
                let mut excess = active - k_clamped;
                for v in row_vec.iter_mut().rev() {
                    if excess == 0 {
                        break;
                    }
                    if (*v - threshold).abs() < f32::EPSILON {
                        *v = 0.0;
                        excess -= 1;
                    }
                }
            }
        } else if k_clamped == 0 {
            // k == 0 keeps nothing.
            row_vec.fill(0.0);
        }
        result_data.extend_from_slice(&row_vec);
    }
    let result = Tensor::from_vec(result_data, (n, d_sae), &device)?;
    result.reshape(shape.as_slice()).map_err(Into::into)
}
/// Device-side TopK via a full sort of the last dimension.
///
/// Keeps every value `>=` the k-th largest, so rows with ties at the k-th
/// value may retain more than `k` entries — unlike [`topk_cpu`], which
/// resolves ties exactly.
fn topk_gpu(pre_acts: &Tensor, k: usize) -> Result<Tensor> {
    let shape = pre_acts.dims().to_vec();
    let d_sae = *shape
        .last()
        .ok_or_else(|| MIError::Config("cannot apply TopK to empty tensor".into()))?;
    let k_clamped = k.min(d_sae);
    if k_clamped == 0 {
        return Ok(pre_acts.zeros_like()?);
    }
    if k_clamped >= d_sae {
        // Keeping every element: nothing to mask.
        return Ok(pre_acts.clone());
    }
    // Collapse all leading dims into one batch dimension.
    let n: usize = shape.iter().take(shape.len() - 1).product();
    let flat = pre_acts.reshape((n, d_sae))?.to_dtype(DType::F32)?;
    // Descending sort (`asc = false`), then read the k-th largest per row.
    let (sorted_vals, _sorted_indices) = flat.sort_last_dim(false)?;
    let kth_vals = sorted_vals.narrow(1, k_clamped - 1, 1)?;
    let mask = flat.ge(&kth_vals)?;
    let mask_f32 = mask.to_dtype(DType::F32)?;
    let result = (&flat * &mask_f32)?;
    result.reshape(shape.as_slice()).map_err(Into::into)
}
/// Reshape a 1-D bias so it broadcasts over `target_shape`, i.e. to
/// `[1, …, 1, last_dim]`. Targets of rank 0 or 1 are returned unchanged.
fn broadcast_bias(bias: &Tensor, target_shape: &[usize]) -> Result<Tensor> {
    // Rank 0 / rank 1 targets need no reshaping: the bias already lines up.
    // (This also covers the empty-shape case, matching the original early
    // return for ndim <= 1.)
    let Some((&last_dim, leading)) = target_shape.split_last() else {
        return Ok(bias.clone());
    };
    if leading.is_empty() {
        return Ok(bias.clone());
    }
    // Build [1, …, 1, last_dim] with one leading 1 per batch dimension.
    let mut shape = vec![1_usize; leading.len()];
    shape.push(last_dim);
    let reshaped = bias.reshape(shape.as_slice())?;
    Ok(reshaped.broadcast_as(target_shape)?)
}
/// Convert a safetensors view into a candle tensor on `device`.
///
/// Only the float dtypes BF16/F16/F32 are accepted; anything else is a
/// config error.
fn tensor_from_view(view: &safetensors::tensor::TensorView<'_>, device: &Device) -> Result<Tensor> {
    let shape: Vec<usize> = view.shape().to_vec();
    #[allow(clippy::wildcard_enum_match_arm)]
    let dtype = match view.dtype() {
        safetensors::Dtype::BF16 => DType::BF16,
        safetensors::Dtype::F16 => DType::F16,
        safetensors::Dtype::F32 => DType::F32,
        other => {
            return Err(MIError::Config(format!(
                "unsupported SAE tensor dtype: {other:?}"
            )));
        }
    };
    let tensor = Tensor::from_raw_buffer(view.data(), dtype, &shape, device)?;
    Ok(tensor)
}
/// Look up `name` in a safetensors file and convert it onto `device`.
///
/// # Errors
/// Returns [`MIError::Config`] when the tensor is absent or its dtype is
/// unsupported.
fn load_tensor(st: &SafeTensors<'_>, name: &str, device: &Device) -> Result<Tensor> {
    match st.tensor(name) {
        Ok(view) => tensor_from_view(&view, device),
        Err(e) => Err(MIError::Config(format!("tensor '{name}' not found: {e}"))),
    }
}
/// Check that `tensor` has exactly the `expected` dimensions.
///
/// # Errors
/// Returns [`MIError::Config`] naming the tensor on any mismatch.
fn validate_shape(tensor: &Tensor, expected: &[usize], name: &str) -> Result<()> {
    let actual = tensor.dims();
    if actual == expected {
        Ok(())
    } else {
        Err(MIError::Config(format!(
            "SAE tensor '{name}' shape mismatch: expected {expected:?}, got {actual:?}"
        )))
    }
}
#[cfg(test)]
mod tests {
use super::*;
// `Display` renders as `SAE:<index>`.
#[test]
fn sae_feature_id_display() {
    let fid = SaeFeatureId { index: 42 };
    assert_eq!(fid.to_string(), "SAE:42");
}

// No hints at all -> plain ReLU.
#[test]
fn resolve_architecture_relu_default() {
    let arch = resolve_architecture(None, None, None).unwrap();
    assert_eq!(arch, SaeArchitecture::ReLU);
}

// "standard" + "relu" is still plain ReLU.
#[test]
fn resolve_architecture_relu_explicit() {
    let arch = resolve_architecture(Some("standard"), Some("relu"), None).unwrap();
    assert_eq!(arch, SaeArchitecture::ReLU);
}

// The explicit architecture field wins.
#[test]
fn resolve_architecture_jumprelu() {
    let arch = resolve_architecture(Some("jumprelu"), None, None).unwrap();
    assert_eq!(arch, SaeArchitecture::JumpReLU);
}

// The architecture may also come from the activation-function name.
#[test]
fn resolve_architecture_jumprelu_from_activation() {
    let arch = resolve_architecture(None, Some("jumprelu"), None).unwrap();
    assert_eq!(arch, SaeArchitecture::JumpReLU);
}

// TopK reads `k` from activation_fn_kwargs.
#[test]
fn resolve_architecture_topk() {
    let kwargs = serde_json::json!({"k": 32});
    let arch = resolve_architecture(Some("topk"), None, Some(&kwargs)).unwrap();
    assert_eq!(arch, SaeArchitecture::TopK { k: 32 });
}

#[test]
fn resolve_architecture_topk_from_activation() {
    let kwargs = serde_json::json!({"k": 64});
    let arch = resolve_architecture(None, Some("topk"), Some(&kwargs)).unwrap();
    assert_eq!(arch, SaeArchitecture::TopK { k: 64 });
}

// TopK without `k` is a configuration error.
#[test]
fn resolve_architecture_topk_missing_k() {
    let result = resolve_architecture(Some("topk"), None, None);
    assert!(result.is_err());
}

// Unknown architecture names are rejected.
#[test]
fn resolve_architecture_unknown() {
    let result = resolve_architecture(Some("gated"), None, None);
    assert!(result.is_err());
}
// Minimal cfg.json: architecture and flags take their defaults.
#[test]
fn parse_config_minimal() {
    let json = r#"{
"d_in": 2304,
"d_sae": 16384,
"hook_name": "blocks.5.hook_resid_post"
}"#;
    let raw: RawSaeConfig = serde_json::from_str(json).unwrap();
    let config = parse_sae_config(raw).unwrap();
    assert_eq!(config.d_in, 2304);
    assert_eq!(config.d_sae, 16384);
    assert_eq!(config.architecture, SaeArchitecture::ReLU);
    assert_eq!(config.hook_point, HookPoint::ResidPost(5));
    assert!(!config.apply_b_dec_to_input);
}

// Full JumpReLU config including the normalization flag.
#[test]
fn parse_config_jumprelu() {
    let json = r#"{
"d_in": 2304,
"d_sae": 16384,
"architecture": "jumprelu",
"hook_name": "blocks.20.hook_resid_post",
"apply_b_dec_to_input": true,
"normalize_activations": "expected_average_only_in"
}"#;
    let raw: RawSaeConfig = serde_json::from_str(json).unwrap();
    let config = parse_sae_config(raw).unwrap();
    assert_eq!(config.architecture, SaeArchitecture::JumpReLU);
    assert_eq!(config.hook_point, HookPoint::ResidPost(20));
    assert!(config.apply_b_dec_to_input);
    assert_eq!(
        config.normalize_activations,
        NormalizeActivations::ExpectedAverageOnlyIn
    );
}

// TopK resolved from the activation-function fields.
#[test]
fn parse_config_topk() {
    let json = r#"{
"d_in": 2304,
"d_sae": 65536,
"activation_fn_str": "topk",
"activation_fn_kwargs": {"k": 32},
"hook_name": "blocks.10.hook_resid_post"
}"#;
    let raw: RawSaeConfig = serde_json::from_str(json).unwrap();
    let config = parse_sae_config(raw).unwrap();
    assert_eq!(config.architecture, SaeArchitecture::TopK { k: 32 });
}
// Keeps exactly the 2 largest values, zeroes the rest.
#[test]
fn topk_cpu_basic() {
    let data = Tensor::new(&[[5.0_f32, 3.0, 1.0, 4.0, 2.0]], &Device::Cpu).unwrap();
    let result = topk_cpu(&data, 2).unwrap();
    let vals: Vec<f32> = result.flatten_all().unwrap().to_vec1().unwrap();
    assert_eq!(vals, vec![5.0, 0.0, 0.0, 4.0, 0.0]);
}

// k larger than the row keeps everything.
#[test]
fn topk_cpu_all_kept() {
    let data = Tensor::new(&[[1.0_f32, 2.0, 3.0]], &Device::Cpu).unwrap();
    let result = topk_cpu(&data, 5).unwrap();
    let vals: Vec<f32> = result.flatten_all().unwrap().to_vec1().unwrap();
    assert_eq!(vals, vec![1.0, 2.0, 3.0]);
}

// k == 0 zeroes the whole row.
#[test]
fn topk_cpu_none_kept() {
    let data = Tensor::new(&[[1.0_f32, 2.0, 3.0]], &Device::Cpu).unwrap();
    let result = topk_cpu(&data, 0).unwrap();
    let vals: Vec<f32> = result.flatten_all().unwrap().to_vec1().unwrap();
    assert_eq!(vals, vec![0.0, 0.0, 0.0]);
}

// Each batch row is masked independently.
#[test]
fn topk_cpu_batched() {
    let data = Tensor::new(
        &[[5.0_f32, 3.0, 1.0, 4.0, 2.0], [1.0, 2.0, 3.0, 4.0, 5.0]],
        &Device::Cpu,
    )
    .unwrap();
    let result = topk_cpu(&data, 3).unwrap();
    let vals: Vec<Vec<f32>> = result.to_vec2().unwrap();
    assert_eq!(vals[0], vec![5.0, 3.0, 0.0, 4.0, 0.0]);
    assert_eq!(vals[1], vec![0.0, 0.0, 3.0, 4.0, 5.0]);
}
// `SparseActivations` works with SAE feature ids.
#[test]
fn sparse_activations_sae() {
    let features = vec![
        (SaeFeatureId { index: 5 }, 3.0),
        (SaeFeatureId { index: 2 }, 2.0),
        (SaeFeatureId { index: 8 }, 1.0),
    ];
    let sparse = SparseActivations { features };
    assert_eq!(sparse.len(), 3);
    assert!(!sparse.is_empty());
}

// Truncation keeps the leading (strongest-first) entries.
#[test]
fn sparse_activations_truncate_sae() {
    let features = vec![
        (SaeFeatureId { index: 5 }, 3.0),
        (SaeFeatureId { index: 2 }, 2.0),
        (SaeFeatureId { index: 8 }, 1.0),
    ];
    let mut sparse = SparseActivations { features };
    sparse.truncate(2);
    assert_eq!(sparse.len(), 2);
    assert_eq!(sparse.features[0].0.index, 5);
    assert_eq!(sparse.features[1].0.index, 2);
}
// Shape contract of encode/decode/reconstruct for 1-D, 2-D, and 3-D inputs.
#[test]
fn encode_decode_roundtrip_shapes() {
    let d_in = 4;
    let d_sae = 8;
    let device = Device::Cpu;
    let w_enc = Tensor::randn(0.0_f32, 1.0, (d_in, d_sae), &device).unwrap();
    let w_dec = Tensor::randn(0.0_f32, 1.0, (d_sae, d_in), &device).unwrap();
    let b_enc = Tensor::zeros(d_sae, DType::F32, &device).unwrap();
    let b_dec = Tensor::zeros(d_in, DType::F32, &device).unwrap();
    let sae = SparseAutoencoder {
        config: SaeConfig {
            d_in,
            d_sae,
            architecture: SaeArchitecture::ReLU,
            hook_name: "blocks.0.hook_resid_post".into(),
            hook_point: HookPoint::ResidPost(0),
            apply_b_dec_to_input: false,
            normalize_activations: NormalizeActivations::None,
        },
        w_enc,
        w_dec,
        b_enc,
        b_dec,
        threshold: None,
    };
    // 1-D input, unsqueezed to a batch of one.
    let x1 = Tensor::randn(0.0_f32, 1.0, (d_in,), &device).unwrap();
    let encoded = sae.encode(&x1.unsqueeze(0).unwrap()).unwrap();
    assert_eq!(encoded.dims(), &[1, d_sae]);
    // 2-D batch.
    let x2 = Tensor::randn(0.0_f32, 1.0, (3, d_in), &device).unwrap();
    let encoded = sae.encode(&x2).unwrap();
    assert_eq!(encoded.dims(), &[3, d_sae]);
    let decoded = sae.decode(&encoded).unwrap();
    assert_eq!(decoded.dims(), &[3, d_in]);
    // 3-D (batch, seq, d_in).
    let x3 = Tensor::randn(0.0_f32, 1.0, (2, 5, d_in), &device).unwrap();
    let encoded = sae.encode(&x3).unwrap();
    assert_eq!(encoded.dims(), &[2, 5, d_sae]);
    let decoded = sae.decode(&encoded).unwrap();
    assert_eq!(decoded.dims(), &[2, 5, d_in]);
    let x_hat = sae.reconstruct(&x2).unwrap();
    assert_eq!(x_hat.dims(), &[3, d_in]);
    let mse = sae.reconstruction_error(&x2).unwrap();
    assert!(mse >= 0.0);
}
// With an identity-like encoder, positive inputs map to matching features,
// sorted by descending activation; negative inputs are dropped by the ReLU.
#[test]
fn encode_sparse_basic() {
    let d_in = 4;
    let d_sae = 8;
    let device = Device::Cpu;
    // W_enc[i][i] = 1.0 routes input dim i to feature i.
    let mut w_enc_data = vec![0.0_f32; d_in * d_sae];
    for i in 0..d_in {
        w_enc_data[i * d_sae + i] = 1.0;
    }
    let w_enc = Tensor::from_vec(w_enc_data, (d_in, d_sae), &device).unwrap();
    let w_dec = Tensor::randn(0.0_f32, 1.0, (d_sae, d_in), &device).unwrap();
    let b_enc = Tensor::zeros(d_sae, DType::F32, &device).unwrap();
    let b_dec = Tensor::zeros(d_in, DType::F32, &device).unwrap();
    let sae = SparseAutoencoder {
        config: SaeConfig {
            d_in,
            d_sae,
            architecture: SaeArchitecture::ReLU,
            hook_name: "blocks.0.hook_resid_post".into(),
            hook_point: HookPoint::ResidPost(0),
            apply_b_dec_to_input: false,
            normalize_activations: NormalizeActivations::None,
        },
        w_enc,
        w_dec,
        b_enc,
        b_dec,
        threshold: None,
    };
    let x = Tensor::new(&[2.0_f32, -1.0, 3.0, 0.5], &device).unwrap();
    let sparse = sae.encode_sparse(&x).unwrap();
    // -1.0 is dropped; 3.0 > 2.0 > 0.5 fixes the ordering.
    assert_eq!(sparse.len(), 3);
    assert_eq!(sparse.features[0].0.index, 2);
    assert_eq!(sparse.features[1].0.index, 0);
    assert_eq!(sparse.features[2].0.index, 3);
}
// In-range indices yield a d_in vector; out-of-range indices error.
#[test]
fn decoder_vector_basic() {
    let d_in = 4;
    let d_sae = 8;
    let device = Device::Cpu;
    let w_dec = Tensor::randn(0.0_f32, 1.0, (d_sae, d_in), &device).unwrap();
    let sae = SparseAutoencoder {
        config: SaeConfig {
            d_in,
            d_sae,
            architecture: SaeArchitecture::ReLU,
            hook_name: "blocks.0.hook_resid_post".into(),
            hook_point: HookPoint::ResidPost(0),
            apply_b_dec_to_input: false,
            normalize_activations: NormalizeActivations::None,
        },
        w_enc: Tensor::zeros((d_in, d_sae), DType::F32, &device).unwrap(),
        w_dec: w_dec.clone(),
        b_enc: Tensor::zeros(d_sae, DType::F32, &device).unwrap(),
        b_dec: Tensor::zeros(d_in, DType::F32, &device).unwrap(),
        threshold: None,
    };
    let vec0 = sae.decoder_vector(0).unwrap();
    assert_eq!(vec0.dims(), &[d_in]);
    assert!(sae.decoder_vector(d_sae).is_err());
}
// A valid injection (position 2 of 5) produces a non-empty hook spec.
#[test]
fn prepare_injection_basic() {
    let d_in = 4;
    let d_sae = 8;
    let device = Device::Cpu;
    let sae = SparseAutoencoder {
        config: SaeConfig {
            d_in,
            d_sae,
            architecture: SaeArchitecture::ReLU,
            hook_name: "blocks.0.hook_resid_post".into(),
            hook_point: HookPoint::ResidPost(0),
            apply_b_dec_to_input: false,
            normalize_activations: NormalizeActivations::None,
        },
        w_enc: Tensor::zeros((d_in, d_sae), DType::F32, &device).unwrap(),
        w_dec: Tensor::ones((d_sae, d_in), DType::F32, &device).unwrap(),
        b_enc: Tensor::zeros(d_sae, DType::F32, &device).unwrap(),
        b_dec: Tensor::zeros(d_in, DType::F32, &device).unwrap(),
        threshold: None,
    };
    let features = vec![(0_usize, 1.0_f32), (1, 0.5)];
    let hooks = sae
        .prepare_hook_injection(&features, 2, 5, &device)
        .unwrap();
    assert!(!hooks.is_empty());
}
}