1mod npz;
56
57use std::path::Path;
58
59use candle_core::{DType, Device, Tensor};
60use safetensors::tensor::SafeTensors;
61use tracing::info;
62
63use crate::error::{MIError, Result};
64use crate::hooks::{HookPoint, HookSpec, Intervention};
65use crate::sparse::{FeatureId, SparseActivations};
66
/// Identifier for a single SAE feature: its column index in the dictionary.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SaeFeatureId {
    /// Zero-based feature index into the SAE's `d_sae`-wide dictionary.
    pub index: usize,
}
77
78impl std::fmt::Display for SaeFeatureId {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 write!(f, "SAE:{}", self.index)
81 }
82}
83
// Marker impl so `SaeFeatureId` can key `SparseActivations`.
impl FeatureId for SaeFeatureId {}
85
/// Activation architecture of a sparse autoencoder.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SaeArchitecture {
    /// Standard ReLU on the pre-activations.
    ReLU,
    /// JumpReLU: pre-activations pass only where they strictly exceed a
    /// learned per-feature threshold (requires a `threshold` tensor in the
    /// weights file).
    JumpReLU,
    /// TopK: keep only the `k` largest pre-activations per position.
    TopK {
        /// Number of features kept active per position.
        k: usize,
    },
}
101
/// Input-normalization scheme declared in the SAE's `cfg.json`.
///
/// NOTE(review): this is only recorded from the config; `encode` does not
/// apply any normalization — confirm callers pre-normalize when required.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NormalizeActivations {
    /// No normalization declared.
    None,
    /// The `expected_average_only_in` scheme from SAELens configs.
    ExpectedAverageOnlyIn,
}
111
/// Execution strategy for the TopK activation (see `topk_activation`).
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TopKStrategy {
    /// Pick automatically: the CPU path when the tensor lives on the CPU,
    /// the device path otherwise.
    Auto,
    /// Always use the row-by-row CPU implementation.
    Cpu,
    /// Always use the sort-based device implementation.
    Gpu,
}
123
/// Parsed, validated SAE configuration.
#[derive(Debug, Clone)]
pub struct SaeConfig {
    /// Model activation (input) dimension.
    pub d_in: usize,
    /// Number of dictionary features.
    pub d_sae: usize,
    /// Activation architecture.
    pub architecture: SaeArchitecture,
    /// Raw hook name string (e.g. `blocks.5.hook_resid_post`).
    pub hook_name: String,
    /// Hook point parsed from `hook_name`.
    pub hook_point: HookPoint,
    /// Whether to subtract `b_dec` from the input before encoding.
    pub apply_b_dec_to_input: bool,
    /// Declared input-normalization scheme (not applied by `encode`).
    pub normalize_activations: NormalizeActivations,
}
142
/// Direct deserialization target for `cfg.json`, normalized into
/// [`SaeConfig`] by `parse_sae_config`.
///
/// Most fields are optional to tolerate the several config layouts seen in
/// released SAEs (explicit `architecture` vs. `activation_fn_str`,
/// `hook_name` vs. legacy `hook_point`).
#[derive(serde::Deserialize)]
#[allow(clippy::missing_docs_in_private_items)]
struct RawSaeConfig {
    d_in: usize,
    d_sae: usize,
    #[serde(default)]
    architecture: Option<String>,
    #[serde(default)]
    activation_fn_str: Option<String>,
    #[serde(default)]
    activation_fn_kwargs: Option<serde_json::Value>,
    #[serde(default)]
    hook_name: Option<String>,
    #[serde(default)]
    hook_point: Option<String>,
    #[serde(default)]
    apply_b_dec_to_input: bool,
    #[serde(default)]
    normalize_activations: Option<String>,
}
167
168fn parse_sae_config(raw: RawSaeConfig) -> Result<SaeConfig> {
170 let architecture = resolve_architecture(
172 raw.architecture.as_deref(),
173 raw.activation_fn_str.as_deref(),
174 raw.activation_fn_kwargs.as_ref(),
175 )?;
176
177 let hook_name = raw
179 .hook_name
180 .or(raw.hook_point)
181 .unwrap_or_else(|| "unknown".to_owned());
182
183 let hook_point: HookPoint = hook_name
185 .parse()
186 .unwrap_or_else(|_: std::convert::Infallible| {
187 unreachable!()
189 });
190
191 let normalize_activations = match raw.normalize_activations.as_deref() {
192 Some("expected_average_only_in") => NormalizeActivations::ExpectedAverageOnlyIn,
193 _ => NormalizeActivations::None,
194 };
195
196 Ok(SaeConfig {
197 d_in: raw.d_in,
198 d_sae: raw.d_sae,
199 architecture,
200 hook_name,
201 hook_point,
202 apply_b_dec_to_input: raw.apply_b_dec_to_input,
203 normalize_activations,
204 })
205}
206
207fn resolve_architecture(
209 architecture: Option<&str>,
210 activation_fn_str: Option<&str>,
211 activation_fn_kwargs: Option<&serde_json::Value>,
212) -> Result<SaeArchitecture> {
213 match architecture {
215 Some("jumprelu") => return Ok(SaeArchitecture::JumpReLU),
216 Some("topk") => {
217 let k = extract_topk_k(activation_fn_kwargs)?;
218 return Ok(SaeArchitecture::TopK { k });
219 }
220 Some("standard") | None => {} Some(other) => {
222 return Err(MIError::Config(format!(
223 "unsupported SAE architecture: {other:?}"
224 )));
225 }
226 }
227
228 match activation_fn_str {
230 Some("relu") | None => Ok(SaeArchitecture::ReLU),
231 Some("jumprelu") => Ok(SaeArchitecture::JumpReLU),
232 Some("topk") => {
233 let k = extract_topk_k(activation_fn_kwargs)?;
234 Ok(SaeArchitecture::TopK { k })
235 }
236 Some(other) => Err(MIError::Config(format!(
237 "unsupported SAE activation function: {other:?}"
238 ))),
239 }
240}
241
242fn extract_topk_k(kwargs: Option<&serde_json::Value>) -> Result<usize> {
244 let k = kwargs
245 .and_then(|v| v.get("k"))
246 .and_then(serde_json::Value::as_u64)
247 .ok_or_else(|| {
248 MIError::Config("TopK SAE requires activation_fn_kwargs.k in cfg.json".into())
249 })?;
250 let k_usize = usize::try_from(k)
251 .map_err(|_| MIError::Config(format!("TopK k value {k} too large for usize")))?;
252 Ok(k_usize)
253}
254
/// A loaded sparse autoencoder: encoder/decoder weights plus configuration.
pub struct SparseAutoencoder {
    /// Parsed configuration (dims, architecture, hook point).
    config: SaeConfig,
    /// Encoder weights, shape `[d_in, d_sae]`, f32.
    w_enc: Tensor,
    /// Decoder weights, shape `[d_sae, d_in]`, f32.
    w_dec: Tensor,
    /// Encoder bias, shape `[d_sae]`, f32.
    b_enc: Tensor,
    /// Decoder bias, shape `[d_in]`, f32.
    b_dec: Tensor,
    /// Per-feature JumpReLU thresholds, shape `[d_sae]`; present iff the
    /// weights file ships a `threshold` tensor.
    threshold: Option<Tensor>,
}
314
315impl SparseAutoencoder {
    /// Loads an SAE from a local directory.
    ///
    /// Expects `cfg.json` plus either `sae_weights.safetensors` or
    /// `model.safetensors` inside `dir`. All weights are placed on `device`,
    /// upcast to f32, and shape-checked against the config.
    ///
    /// # Errors
    /// `MIError::Config` when a file is missing or malformed, when tensor
    /// shapes disagree with `d_in`/`d_sae`, or when a JumpReLU SAE lacks its
    /// `threshold` tensor; filesystem errors are propagated.
    pub fn from_local(dir: &Path, device: &Device) -> Result<Self> {
        let cfg_path = dir.join("cfg.json");
        if !cfg_path.exists() {
            return Err(MIError::Config(format!(
                "cfg.json not found in {}",
                dir.display()
            )));
        }
        let cfg_text = std::fs::read_to_string(&cfg_path)?;
        let raw: RawSaeConfig = serde_json::from_str(&cfg_text)
            .map_err(|e| MIError::Config(format!("failed to parse cfg.json: {e}")))?;
        let config = parse_sae_config(raw)?;

        info!(
            "SAE config: d_in={}, d_sae={}, arch={:?}, hook={}",
            config.d_in, config.d_sae, config.architecture, config.hook_name
        );

        // Released SAEs use either filename; prefer `sae_weights`.
        let weights_path = if dir.join("sae_weights.safetensors").exists() {
            dir.join("sae_weights.safetensors")
        } else if dir.join("model.safetensors").exists() {
            dir.join("model.safetensors")
        } else {
            return Err(MIError::Config(format!(
                "no safetensors file found in {}",
                dir.display()
            )));
        };

        let data = std::fs::read(&weights_path)?;
        let st = SafeTensors::deserialize(&data)
            .map_err(|e| MIError::Config(format!("failed to deserialize SAE weights: {e}")))?;

        let w_enc = load_tensor(&st, "W_enc", device)?;
        let w_dec = load_tensor(&st, "W_dec", device)?;
        let b_enc = load_tensor(&st, "b_enc", device)?;
        let b_dec = load_tensor(&st, "b_dec", device)?;
        // `threshold` is optional: only JumpReLU SAEs ship it.
        let threshold = st
            .tensor("threshold")
            .ok()
            .map(|v| tensor_from_view(&v, device))
            .transpose()?;

        // Normalize everything to f32 so encode/decode math is uniform.
        let w_enc = w_enc.to_dtype(DType::F32)?;
        let w_dec = w_dec.to_dtype(DType::F32)?;
        let b_enc = b_enc.to_dtype(DType::F32)?;
        let b_dec = b_dec.to_dtype(DType::F32)?;
        let threshold = threshold.map(|t| t.to_dtype(DType::F32)).transpose()?;

        validate_shape(&w_enc, &[config.d_in, config.d_sae], "W_enc")?;
        validate_shape(&w_dec, &[config.d_sae, config.d_in], "W_dec")?;
        validate_shape(&b_enc, &[config.d_sae], "b_enc")?;
        validate_shape(&b_dec, &[config.d_in], "b_dec")?;
        if let Some(ref t) = threshold {
            validate_shape(t, &[config.d_sae], "threshold")?;
        }

        // JumpReLU cannot run without its thresholds; fail at load time
        // rather than on the first `encode`.
        if config.architecture == SaeArchitecture::JumpReLU && threshold.is_none() {
            return Err(MIError::Config(
                "JumpReLU SAE requires 'threshold' tensor in weights file".into(),
            ));
        }

        info!(
            "SAE loaded: {} weights on {:?}",
            weights_path.display(),
            device
        );

        Ok(Self {
            config,
            w_enc,
            w_dec,
            b_enc,
            b_dec,
            threshold,
        })
    }
413
    /// Loads an SAE from a single NPZ archive (GemmaScope-style layout).
    ///
    /// Dimensions are inferred from `W_enc` (`[d_in, d_sae]`); the
    /// architecture is JumpReLU iff a `threshold` array is present, ReLU
    /// otherwise; the hook is fixed to `blocks.{hook_layer}.hook_resid_post`.
    /// `apply_b_dec_to_input` is assumed `false` for NPZ releases.
    ///
    /// # Errors
    /// `MIError::Config` when a required array is missing or shapes are
    /// inconsistent.
    pub fn from_npz(npz_path: &Path, hook_layer: usize, device: &Device) -> Result<Self> {
        info!("Loading SAE from NPZ: {}", npz_path.display());
        let tensors = npz::load_npz(npz_path, device)?;

        let w_enc = tensors
            .get("W_enc")
            .ok_or_else(|| MIError::Config("NPZ missing W_enc".into()))?
            .to_dtype(DType::F32)?;
        let w_dec = tensors
            .get("W_dec")
            .ok_or_else(|| MIError::Config("NPZ missing W_dec".into()))?
            .to_dtype(DType::F32)?;
        let b_enc = tensors
            .get("b_enc")
            .ok_or_else(|| MIError::Config("NPZ missing b_enc".into()))?
            .to_dtype(DType::F32)?;
        let b_dec = tensors
            .get("b_dec")
            .ok_or_else(|| MIError::Config("NPZ missing b_dec".into()))?
            .to_dtype(DType::F32)?;
        let threshold = tensors
            .get("threshold")
            .map(|t| t.to_dtype(DType::F32))
            .transpose()?;

        // Infer dims from W_enc, which must be [d_in, d_sae].
        let w_enc_dims = w_enc.dims();
        if w_enc_dims.len() != 2 {
            return Err(MIError::Config(format!(
                "W_enc expected 2 dims, got {}",
                w_enc_dims.len()
            )));
        }
        let d_in = *w_enc_dims
            .first()
            .ok_or_else(|| MIError::Config("W_enc has no dimensions".into()))?;
        let d_sae = *w_enc_dims
            .get(1)
            .ok_or_else(|| MIError::Config("W_enc has no second dimension".into()))?;

        // The remaining tensors must agree with the inferred dims.
        validate_shape(&w_enc, &[d_in, d_sae], "W_enc")?;
        validate_shape(&w_dec, &[d_sae, d_in], "W_dec")?;
        validate_shape(&b_enc, &[d_sae], "b_enc")?;
        validate_shape(&b_dec, &[d_in], "b_dec")?;
        if let Some(ref t) = threshold {
            validate_shape(t, &[d_sae], "threshold")?;
        }

        // Presence of thresholds distinguishes JumpReLU from plain ReLU.
        let architecture = if threshold.is_some() {
            SaeArchitecture::JumpReLU
        } else {
            SaeArchitecture::ReLU
        };

        let hook_name = format!("blocks.{hook_layer}.hook_resid_post");
        // NOTE(review): `parse_sae_config` treats `HookPoint` parsing as
        // infallible (`Infallible` error type); this error arm looks
        // unreachable — confirm and consider unifying the two call sites.
        let hook_point = hook_name
            .parse::<HookPoint>()
            .map_err(|e| MIError::Config(format!("failed to parse hook name: {e}")))?;

        let config = SaeConfig {
            d_in,
            d_sae,
            architecture,
            hook_name,
            hook_point,
            apply_b_dec_to_input: false,
            normalize_activations: NormalizeActivations::None,
        };

        info!(
            "SAE from NPZ: d_in={d_in}, d_sae={d_sae}, arch={:?}, hook={}",
            config.architecture, config.hook_name
        );

        Ok(Self {
            config,
            w_enc,
            w_dec,
            b_enc,
            b_dec,
            threshold,
        })
    }
515
    /// Downloads a single NPZ file from a Hugging Face repo and loads it via
    /// [`Self::from_npz`].
    ///
    /// Download progress is surfaced through `tracing` events.
    ///
    /// # Errors
    /// `MIError::Download` on fetch failures; otherwise whatever
    /// `from_npz` returns.
    pub fn from_pretrained_npz(
        repo_id: &str,
        npz_path: &str,
        hook_layer: usize,
        device: &Device,
    ) -> Result<Self> {
        let fetch_config = hf_fetch_model::FetchConfig::builder()
            .on_progress(|event| {
                tracing::info!(
                    filename = %event.filename,
                    percent = event.percent,
                    bytes_downloaded = event.bytes_downloaded,
                    bytes_total = event.bytes_total,
                    "SAE NPZ download progress",
                );
            })
            .build()
            .map_err(|e| MIError::Download(format!("failed to build fetch config: {e}")))?;

        info!("Downloading {npz_path} from {repo_id}");
        let local_path =
            hf_fetch_model::download_file_blocking(repo_id.to_owned(), npz_path, &fetch_config)
                .map_err(|e| MIError::Download(format!("failed to download NPZ: {e}")))?
                .into_inner();

        Self::from_npz(&local_path, hook_layer, device)
    }
560
    /// Downloads an SAE (`cfg.json` + weights) from a Hugging Face repo and
    /// loads it via [`Self::from_local`].
    ///
    /// Tries `{sae_id}/sae_weights.safetensors` first, then
    /// `{sae_id}/model.safetensors`. If the downloader places the two files
    /// in different directories, `cfg.json` is copied next to the weights so
    /// `from_local` sees both.
    ///
    /// # Errors
    /// `MIError::Download` on fetch failures, plus anything `from_local`
    /// returns.
    pub fn from_pretrained(repo_id: &str, sae_id: &str, device: &Device) -> Result<Self> {
        let fetch_config = hf_fetch_model::FetchConfig::builder()
            .on_progress(|event| {
                tracing::info!(
                    filename = %event.filename,
                    percent = event.percent,
                    bytes_downloaded = event.bytes_downloaded,
                    bytes_total = event.bytes_total,
                    "SAE download progress",
                );
            })
            .build()
            .map_err(|e| MIError::Download(format!("failed to build fetch config: {e}")))?;

        let cfg_remote = format!("{sae_id}/cfg.json");
        info!("Downloading {cfg_remote} from {repo_id}");
        let cfg_path =
            hf_fetch_model::download_file_blocking(repo_id.to_owned(), &cfg_remote, &fetch_config)
                .map_err(|e| MIError::Download(format!("failed to download cfg.json: {e}")))?
                .into_inner();

        // Prefer the SAELens filename, fall back to `model.safetensors`.
        let weights_remote = format!("{sae_id}/sae_weights.safetensors");
        info!("Downloading {weights_remote} from {repo_id}");
        let weights_path = hf_fetch_model::download_file_blocking(
            repo_id.to_owned(),
            &weights_remote,
            &fetch_config,
        )
        .or_else(|_| {
            let alt_remote = format!("{sae_id}/model.safetensors");
            info!("Trying {alt_remote} from {repo_id}");
            hf_fetch_model::download_file_blocking(repo_id.to_owned(), &alt_remote, &fetch_config)
        })
        .map_err(|e| MIError::Download(format!("failed to download SAE weights: {e}")))?
        .into_inner();

        let dir = cfg_path.parent().ok_or_else(|| {
            MIError::Config("cannot determine SAE directory from cfg.json path".into())
        })?;

        // Happy path: cfg and weights landed in the same directory. Otherwise
        // copy cfg.json next to the weights and load from there.
        if dir.join("sae_weights.safetensors").exists() || dir.join("model.safetensors").exists() {
            Self::from_local(dir, device)
        } else {
            let weights_dir = weights_path.parent().ok_or_else(|| {
                MIError::Config("cannot determine SAE directory from weights path".into())
            })?;
            let target_cfg = weights_dir.join("cfg.json");
            if !target_cfg.exists() {
                std::fs::copy(&cfg_path, &target_cfg)?;
            }
            Self::from_local(weights_dir, device)
        }
    }
638
    /// Returns the parsed SAE configuration.
    #[must_use]
    pub const fn config(&self) -> &SaeConfig {
        &self.config
    }

    /// Returns the hook point this SAE attaches to.
    #[must_use]
    pub const fn hook_point(&self) -> &HookPoint {
        &self.config.hook_point
    }

    /// Number of dictionary features.
    #[must_use]
    pub const fn d_sae(&self) -> usize {
        self.config.d_sae
    }

    /// Model activation (input) dimension.
    #[must_use]
    pub const fn d_in(&self) -> usize {
        self.config.d_in
    }
664
    /// Encodes activations into SAE feature space using the automatic TopK
    /// strategy.
    ///
    /// See [`Self::encode_with_strategy`] for shape requirements.
    pub fn encode(&self, x: &Tensor) -> Result<Tensor> {
        self.encode_with_strategy(x, &TopKStrategy::Auto)
    }

    /// Encodes activations, choosing the TopK execution path explicitly.
    ///
    /// `x` may have any number of leading dims; its last dim must equal
    /// `d_in`. Computes `act(x @ W_enc + b_enc)` (optionally centering `x`
    /// on `b_dec` first), where `act` depends on the configured
    /// architecture. Input is upcast to f32.
    ///
    /// NOTE(review): `normalize_activations` is not applied here — confirm
    /// callers pre-normalize when the config declares a scheme.
    ///
    /// # Errors
    /// `MIError::Config` when the last dim does not match `d_in`; tensor-op
    /// failures are propagated.
    pub fn encode_with_strategy(&self, x: &Tensor, strategy: &TopKStrategy) -> Result<Tensor> {
        let dims = x.dims();
        let last_dim = *dims
            .last()
            .ok_or_else(|| MIError::Config("cannot encode empty tensor".into()))?;
        if last_dim != self.config.d_in {
            return Err(MIError::Config(format!(
                "input last dim {last_dim} != SAE d_in {}",
                self.config.d_in
            )));
        }

        let x_f32 = x.to_dtype(DType::F32)?;

        // Optionally center the input on the decoder bias (SAELens option).
        let x_centered = if self.config.apply_b_dec_to_input {
            let b_dec = broadcast_bias(&self.b_dec, x_f32.dims())?;
            (&x_f32 - &b_dec)?
        } else {
            x_f32
        };

        // Pre-activations: x @ W_enc + b_enc, broadcast over leading dims.
        let pre_acts = x_centered.broadcast_matmul(&self.w_enc)?;
        let b_enc = broadcast_bias(&self.b_enc, pre_acts.dims())?;
        let pre_acts = (&pre_acts + &b_enc)?;

        match &self.config.architecture {
            SaeArchitecture::ReLU => Ok(pre_acts.relu()?),
            SaeArchitecture::JumpReLU => {
                // Pass values only where they strictly exceed the learned
                // per-feature threshold.
                let threshold = self
                    .threshold
                    .as_ref()
                    .ok_or_else(|| MIError::Config("JumpReLU requires threshold tensor".into()))?;
                let threshold = broadcast_bias(threshold, pre_acts.dims())?;
                let mask = pre_acts.gt(&threshold)?;
                let mask_f32 = mask.to_dtype(DType::F32)?;
                Ok((&pre_acts * &mask_f32)?)
            }
            SaeArchitecture::TopK { k } => topk_activation(&pre_acts, *k, strategy),
        }
    }
747
    /// Encodes a single 1-D activation vector (`[d_in]`) and returns the
    /// strictly positive features, sorted by activation value descending.
    ///
    /// # Errors
    /// Propagates `encode` failures (e.g. wrong input length).
    pub fn encode_sparse(&self, x: &Tensor) -> Result<SparseActivations<SaeFeatureId>> {
        // `encode` expects a leading batch dim; add it, then drop it again.
        let encoded = self.encode(&x.unsqueeze(0)?)?;
        let encoded_1d = encoded.squeeze(0)?;

        let values: Vec<f32> = encoded_1d.to_vec1()?;

        // Keep only strictly positive activations.
        let mut features: Vec<(SaeFeatureId, f32)> = values
            .iter()
            .enumerate()
            .filter(|&(_, v)| *v > 0.0)
            .map(|(i, v)| (SaeFeatureId { index: i }, *v))
            .collect();

        // Strongest first; NaNs (if any) compare as equal.
        features.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        Ok(SparseActivations { features })
    }
779
    /// Decodes SAE features back into activation space:
    /// `features @ W_dec + b_dec`.
    ///
    /// `features` may have any leading dims; its last dim must be `d_sae`
    /// (enforced by the matmul). Input is upcast to f32.
    pub fn decode(&self, features: &Tensor) -> Result<Tensor> {
        let features_f32 = features.to_dtype(DType::F32)?;
        let decoded = features_f32.broadcast_matmul(&self.w_dec)?;
        let b_dec = broadcast_bias(&self.b_dec, decoded.dims())?;
        Ok((&decoded + &b_dec)?)
    }

    /// Round-trips `x` through `encode` then `decode`.
    pub fn reconstruct(&self, x: &Tensor) -> Result<Tensor> {
        let encoded = self.encode(x)?;
        self.decode(&encoded)
    }

    /// Mean-squared reconstruction error of `x`, averaged over all elements.
    pub fn reconstruction_error(&self, x: &Tensor) -> Result<f64> {
        let x_f32 = x.to_dtype(DType::F32)?;
        let x_hat = self.reconstruct(&x_f32)?;
        let diff = (&x_f32 - &x_hat)?;
        let mse: f32 = diff.sqr()?.mean_all()?.to_scalar()?;
        Ok(f64::from(mse))
    }
834
    /// Returns the decoder (dictionary) vector for one feature: row
    /// `feature_idx` of `W_dec`, shape `[d_in]`.
    ///
    /// # Errors
    /// `MIError::Config` when `feature_idx >= d_sae`.
    pub fn decoder_vector(&self, feature_idx: usize) -> Result<Tensor> {
        if feature_idx >= self.config.d_sae {
            return Err(MIError::Config(format!(
                "feature index {feature_idx} out of range (d_sae={})",
                self.config.d_sae
            )));
        }
        Ok(self.w_dec.get(feature_idx)?)
    }
856
857 pub fn prepare_hook_injection(
876 &self,
877 features: &[(usize, f32)],
878 position: usize,
879 seq_len: usize,
880 device: &Device,
881 ) -> Result<HookSpec> {
882 let d_in = self.config.d_in;
883
884 let mut accumulated = Tensor::zeros(d_in, DType::F32, device)?;
886 for &(feature_idx, strength) in features {
887 let dec_vec = self.decoder_vector(feature_idx)?;
888 let dec_vec = dec_vec.to_device(device)?;
889 let scaled = (&dec_vec * f64::from(strength))?;
890 accumulated = (&accumulated + &scaled)?;
891 }
892
893 let injection = Tensor::zeros((1, seq_len, d_in), DType::F32, device)?;
895 let scaled_3d = accumulated.unsqueeze(0)?.unsqueeze(0)?; let before = if position > 0 {
898 Some(injection.narrow(1, 0, position)?)
899 } else {
900 None
901 };
902 let after = if position + 1 < seq_len {
903 Some(injection.narrow(1, position + 1, seq_len - position - 1)?)
904 } else {
905 None
906 };
907
908 let mut parts: Vec<Tensor> = Vec::with_capacity(3);
909 if let Some(b) = before {
910 parts.push(b);
911 }
912 parts.push(scaled_3d);
913 if let Some(a) = after {
914 parts.push(a);
915 }
916
917 let injection = Tensor::cat(&parts, 1)?;
918
919 let mut hooks = HookSpec::new();
920 hooks.intervene(self.config.hook_point.clone(), Intervention::Add(injection));
921 Ok(hooks)
922 }
923}
924
925fn topk_activation(pre_acts: &Tensor, k: usize, strategy: &TopKStrategy) -> Result<Tensor> {
940 let use_cpu = match strategy {
941 TopKStrategy::Cpu => true,
942 TopKStrategy::Gpu => false,
943 TopKStrategy::Auto => matches!(pre_acts.device(), Device::Cpu),
944 };
945
946 if use_cpu {
947 topk_cpu(pre_acts, k)
948 } else {
949 topk_gpu(pre_acts, k)
950 }
951}
952
953fn topk_cpu(pre_acts: &Tensor, k: usize) -> Result<Tensor> {
955 let device = pre_acts.device().clone();
956 let shape = pre_acts.dims().to_vec();
957 let d_sae = *shape
958 .last()
959 .ok_or_else(|| MIError::Config("cannot apply TopK to empty tensor".into()))?;
960
961 let n: usize = shape.iter().take(shape.len() - 1).product();
963 let flat = pre_acts.reshape((n, d_sae))?.to_dtype(DType::F32)?;
964 let flat_cpu = flat.to_device(&Device::Cpu)?;
965
966 let mut result_data: Vec<f32> = Vec::with_capacity(n * d_sae);
967
968 for row_idx in 0..n {
969 let row = flat_cpu.get(row_idx)?;
970 let mut row_vec: Vec<f32> = row.to_vec1()?;
971
972 let k_clamped = k.min(d_sae);
974 if k_clamped > 0 && k_clamped < d_sae {
975 let mut indices: Vec<usize> = (0..d_sae).collect();
977 #[allow(clippy::indexing_slicing)]
978 indices.select_nth_unstable_by(k_clamped - 1, |&a, &b| {
980 let va = row_vec.get(b).copied().unwrap_or(f32::NEG_INFINITY);
981 let vb = row_vec.get(a).copied().unwrap_or(f32::NEG_INFINITY);
982 va.partial_cmp(&vb).unwrap_or(std::cmp::Ordering::Equal)
983 });
984 let threshold_idx = indices.get(k_clamped - 1).copied().unwrap_or(0);
985 let threshold = row_vec.get(threshold_idx).copied().unwrap_or(0.0);
986
987 for v in &mut row_vec {
989 if *v < threshold {
990 *v = 0.0;
991 }
992 }
993
994 let active: usize = row_vec.iter().filter(|&&v| v >= threshold).count();
997 if active > k_clamped {
998 let mut excess = active - k_clamped;
999 for v in row_vec.iter_mut().rev() {
1000 if excess == 0 {
1001 break;
1002 }
1003 if (*v - threshold).abs() < f32::EPSILON {
1004 *v = 0.0;
1005 excess -= 1;
1006 }
1007 }
1008 }
1009 } else if k_clamped == 0 {
1010 row_vec.fill(0.0);
1011 }
1012 result_data.extend_from_slice(&row_vec);
1015 }
1016
1017 let result = Tensor::from_vec(result_data, (n, d_sae), &device)?;
1018 result.reshape(shape.as_slice()).map_err(Into::into)
1019}
1020
/// Device-friendly TopK: sorts each row once (descending) and masks values
/// below the per-row k-th largest value.
///
/// NOTE(review): the `ge` mask keeps *every* element tied with the k-th
/// largest value, so rows with ties may retain more than `k` entries —
/// unlike the exact-k CPU path. Confirm this looseness is acceptable.
fn topk_gpu(pre_acts: &Tensor, k: usize) -> Result<Tensor> {
    let shape = pre_acts.dims().to_vec();
    let d_sae = *shape
        .last()
        .ok_or_else(|| MIError::Config("cannot apply TopK to empty tensor".into()))?;

    // Degenerate cases: keep nothing / keep everything.
    let k_clamped = k.min(d_sae);
    if k_clamped == 0 {
        return Ok(pre_acts.zeros_like()?);
    }
    if k_clamped >= d_sae {
        return Ok(pre_acts.clone());
    }

    // Flatten leading dims into a batch of rows.
    let n: usize = shape.iter().take(shape.len() - 1).product();
    let flat = pre_acts.reshape((n, d_sae))?.to_dtype(DType::F32)?;

    let (sorted_vals, _sorted_indices) = flat.sort_last_dim(false)?;

    // Per-row cutoff: the k-th largest value, shape [n, 1].
    let kth_vals = sorted_vals.narrow(1, k_clamped - 1, 1)?;

    // Keep values >= the cutoff; zero the rest.
    let mask = flat.ge(&kth_vals)?;
    let mask_f32 = mask.to_dtype(DType::F32)?;

    let result = (&flat * &mask_f32)?;
    result.reshape(shape.as_slice()).map_err(Into::into)
}
1053
1054fn broadcast_bias(bias: &Tensor, target_shape: &[usize]) -> Result<Tensor> {
1068 let ndim = target_shape.len();
1069 if ndim <= 1 {
1070 return Ok(bias.clone());
1071 }
1072 let mut shape = vec![1_usize; ndim];
1074 let last_dim = *target_shape
1075 .last()
1076 .ok_or_else(|| MIError::Config("cannot broadcast bias to empty shape".into()))?;
1077 if let Some(slot) = shape.last_mut() {
1078 *slot = last_dim;
1079 }
1080 let reshaped = bias.reshape(shape.as_slice())?;
1081 Ok(reshaped.broadcast_as(target_shape)?)
1082}
1083
/// Builds a candle tensor on `device` from a safetensors view.
///
/// Supports the float dtypes expected in SAE checkpoints (BF16/F16/F32).
///
/// # Errors
/// `MIError::Config` for any other source dtype.
fn tensor_from_view(view: &safetensors::tensor::TensorView<'_>, device: &Device) -> Result<Tensor> {
    let shape: Vec<usize> = view.shape().to_vec();
    // Unexpected dtypes are a config error, not a panic.
    #[allow(clippy::wildcard_enum_match_arm)]
    let dtype = match view.dtype() {
        safetensors::Dtype::BF16 => DType::BF16,
        safetensors::Dtype::F16 => DType::F16,
        safetensors::Dtype::F32 => DType::F32,
        other => {
            return Err(MIError::Config(format!(
                "unsupported SAE tensor dtype: {other:?}"
            )));
        }
    };
    let tensor = Tensor::from_raw_buffer(view.data(), dtype, &shape, device)?;
    Ok(tensor)
}
1110
1111fn load_tensor(st: &SafeTensors<'_>, name: &str, device: &Device) -> Result<Tensor> {
1113 let view = st
1114 .tensor(name)
1115 .map_err(|e| MIError::Config(format!("tensor '{name}' not found: {e}")))?;
1116 tensor_from_view(&view, device)
1117}
1118
1119fn validate_shape(tensor: &Tensor, expected: &[usize], name: &str) -> Result<()> {
1121 if tensor.dims() != expected {
1122 return Err(MIError::Config(format!(
1123 "SAE tensor '{name}' shape mismatch: expected {expected:?}, got {:?}",
1124 tensor.dims()
1125 )));
1126 }
1127 Ok(())
1128}
1129
#[cfg(test)]
mod tests {
    //! Unit tests: config parsing, architecture resolution, CPU TopK
    //! behavior, and shape round-trips through encode/decode on tiny CPU
    //! tensors.
    use super::*;

    #[test]
    fn sae_feature_id_display() {
        let fid = SaeFeatureId { index: 42 };
        assert_eq!(fid.to_string(), "SAE:42");
    }

    #[test]
    fn resolve_architecture_relu_default() {
        let arch = resolve_architecture(None, None, None).unwrap();
        assert_eq!(arch, SaeArchitecture::ReLU);
    }

    #[test]
    fn resolve_architecture_relu_explicit() {
        let arch = resolve_architecture(Some("standard"), Some("relu"), None).unwrap();
        assert_eq!(arch, SaeArchitecture::ReLU);
    }

    #[test]
    fn resolve_architecture_jumprelu() {
        let arch = resolve_architecture(Some("jumprelu"), None, None).unwrap();
        assert_eq!(arch, SaeArchitecture::JumpReLU);
    }

    #[test]
    fn resolve_architecture_jumprelu_from_activation() {
        let arch = resolve_architecture(None, Some("jumprelu"), None).unwrap();
        assert_eq!(arch, SaeArchitecture::JumpReLU);
    }

    #[test]
    fn resolve_architecture_topk() {
        let kwargs = serde_json::json!({"k": 32});
        let arch = resolve_architecture(Some("topk"), None, Some(&kwargs)).unwrap();
        assert_eq!(arch, SaeArchitecture::TopK { k: 32 });
    }

    #[test]
    fn resolve_architecture_topk_from_activation() {
        let kwargs = serde_json::json!({"k": 64});
        let arch = resolve_architecture(None, Some("topk"), Some(&kwargs)).unwrap();
        assert_eq!(arch, SaeArchitecture::TopK { k: 64 });
    }

    #[test]
    fn resolve_architecture_topk_missing_k() {
        let result = resolve_architecture(Some("topk"), None, None);
        assert!(result.is_err());
    }

    #[test]
    fn resolve_architecture_unknown() {
        let result = resolve_architecture(Some("gated"), None, None);
        assert!(result.is_err());
    }

    #[test]
    fn parse_config_minimal() {
        // Only the required keys: everything else should default sensibly.
        let json = r#"{
            "d_in": 2304,
            "d_sae": 16384,
            "hook_name": "blocks.5.hook_resid_post"
        }"#;
        let raw: RawSaeConfig = serde_json::from_str(json).unwrap();
        let config = parse_sae_config(raw).unwrap();
        assert_eq!(config.d_in, 2304);
        assert_eq!(config.d_sae, 16384);
        assert_eq!(config.architecture, SaeArchitecture::ReLU);
        assert_eq!(config.hook_point, HookPoint::ResidPost(5));
        assert!(!config.apply_b_dec_to_input);
    }

    #[test]
    fn parse_config_jumprelu() {
        let json = r#"{
            "d_in": 2304,
            "d_sae": 16384,
            "architecture": "jumprelu",
            "hook_name": "blocks.20.hook_resid_post",
            "apply_b_dec_to_input": true,
            "normalize_activations": "expected_average_only_in"
        }"#;
        let raw: RawSaeConfig = serde_json::from_str(json).unwrap();
        let config = parse_sae_config(raw).unwrap();
        assert_eq!(config.architecture, SaeArchitecture::JumpReLU);
        assert_eq!(config.hook_point, HookPoint::ResidPost(20));
        assert!(config.apply_b_dec_to_input);
        assert_eq!(
            config.normalize_activations,
            NormalizeActivations::ExpectedAverageOnlyIn
        );
    }

    #[test]
    fn parse_config_topk() {
        // TopK resolved from the activation-function fields.
        let json = r#"{
            "d_in": 2304,
            "d_sae": 65536,
            "activation_fn_str": "topk",
            "activation_fn_kwargs": {"k": 32},
            "hook_name": "blocks.10.hook_resid_post"
        }"#;
        let raw: RawSaeConfig = serde_json::from_str(json).unwrap();
        let config = parse_sae_config(raw).unwrap();
        assert_eq!(config.architecture, SaeArchitecture::TopK { k: 32 });
    }

    #[test]
    fn topk_cpu_basic() {
        let data = Tensor::new(&[[5.0_f32, 3.0, 1.0, 4.0, 2.0]], &Device::Cpu).unwrap();
        let result = topk_cpu(&data, 2).unwrap();
        let vals: Vec<f32> = result.flatten_all().unwrap().to_vec1().unwrap();
        assert_eq!(vals, vec![5.0, 0.0, 0.0, 4.0, 0.0]);
    }

    #[test]
    fn topk_cpu_all_kept() {
        // k larger than the row keeps everything.
        let data = Tensor::new(&[[1.0_f32, 2.0, 3.0]], &Device::Cpu).unwrap();
        let result = topk_cpu(&data, 5).unwrap();
        let vals: Vec<f32> = result.flatten_all().unwrap().to_vec1().unwrap();
        assert_eq!(vals, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn topk_cpu_none_kept() {
        // k == 0 zeroes the row.
        let data = Tensor::new(&[[1.0_f32, 2.0, 3.0]], &Device::Cpu).unwrap();
        let result = topk_cpu(&data, 0).unwrap();
        let vals: Vec<f32> = result.flatten_all().unwrap().to_vec1().unwrap();
        assert_eq!(vals, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn topk_cpu_batched() {
        // Top-k is applied per row independently.
        let data = Tensor::new(
            &[[5.0_f32, 3.0, 1.0, 4.0, 2.0], [1.0, 2.0, 3.0, 4.0, 5.0]],
            &Device::Cpu,
        )
        .unwrap();
        let result = topk_cpu(&data, 3).unwrap();
        let vals: Vec<Vec<f32>> = result.to_vec2().unwrap();
        assert_eq!(vals[0], vec![5.0, 3.0, 0.0, 4.0, 0.0]);
        assert_eq!(vals[1], vec![0.0, 0.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn sparse_activations_sae() {
        let features = vec![
            (SaeFeatureId { index: 5 }, 3.0),
            (SaeFeatureId { index: 2 }, 2.0),
            (SaeFeatureId { index: 8 }, 1.0),
        ];
        let sparse = SparseActivations { features };
        assert_eq!(sparse.len(), 3);
        assert!(!sparse.is_empty());
    }

    #[test]
    fn sparse_activations_truncate_sae() {
        let features = vec![
            (SaeFeatureId { index: 5 }, 3.0),
            (SaeFeatureId { index: 2 }, 2.0),
            (SaeFeatureId { index: 8 }, 1.0),
        ];
        let mut sparse = SparseActivations { features };
        sparse.truncate(2);
        assert_eq!(sparse.len(), 2);
        assert_eq!(sparse.features[0].0.index, 5);
        assert_eq!(sparse.features[1].0.index, 2);
    }

    #[test]
    fn encode_decode_roundtrip_shapes() {
        let d_in = 4;
        let d_sae = 8;
        let device = Device::Cpu;

        let w_enc = Tensor::randn(0.0_f32, 1.0, (d_in, d_sae), &device).unwrap();
        let w_dec = Tensor::randn(0.0_f32, 1.0, (d_sae, d_in), &device).unwrap();
        let b_enc = Tensor::zeros(d_sae, DType::F32, &device).unwrap();
        let b_dec = Tensor::zeros(d_in, DType::F32, &device).unwrap();

        let sae = SparseAutoencoder {
            config: SaeConfig {
                d_in,
                d_sae,
                architecture: SaeArchitecture::ReLU,
                hook_name: "blocks.0.hook_resid_post".into(),
                hook_point: HookPoint::ResidPost(0),
                apply_b_dec_to_input: false,
                normalize_activations: NormalizeActivations::None,
            },
            w_enc,
            w_dec,
            b_enc,
            b_dec,
            threshold: None,
        };

        // 1-D input (with explicit batch dim).
        let x1 = Tensor::randn(0.0_f32, 1.0, (d_in,), &device).unwrap();
        let encoded = sae.encode(&x1.unsqueeze(0).unwrap()).unwrap();
        assert_eq!(encoded.dims(), &[1, d_sae]);

        // 2-D batch.
        let x2 = Tensor::randn(0.0_f32, 1.0, (3, d_in), &device).unwrap();
        let encoded = sae.encode(&x2).unwrap();
        assert_eq!(encoded.dims(), &[3, d_sae]);
        let decoded = sae.decode(&encoded).unwrap();
        assert_eq!(decoded.dims(), &[3, d_in]);

        // 3-D batch x seq.
        let x3 = Tensor::randn(0.0_f32, 1.0, (2, 5, d_in), &device).unwrap();
        let encoded = sae.encode(&x3).unwrap();
        assert_eq!(encoded.dims(), &[2, 5, d_sae]);
        let decoded = sae.decode(&encoded).unwrap();
        assert_eq!(decoded.dims(), &[2, 5, d_in]);

        let x_hat = sae.reconstruct(&x2).unwrap();
        assert_eq!(x_hat.dims(), &[3, d_in]);

        let mse = sae.reconstruction_error(&x2).unwrap();
        assert!(mse >= 0.0);
    }

    #[test]
    fn encode_sparse_basic() {
        let d_in = 4;
        let d_sae = 8;
        let device = Device::Cpu;

        // Identity-like encoder: feature i mirrors input dim i (i < d_in).
        let mut w_enc_data = vec![0.0_f32; d_in * d_sae];
        for i in 0..d_in {
            w_enc_data[i * d_sae + i] = 1.0;
        }
        let w_enc = Tensor::from_vec(w_enc_data, (d_in, d_sae), &device).unwrap();
        let w_dec = Tensor::randn(0.0_f32, 1.0, (d_sae, d_in), &device).unwrap();
        let b_enc = Tensor::zeros(d_sae, DType::F32, &device).unwrap();
        let b_dec = Tensor::zeros(d_in, DType::F32, &device).unwrap();

        let sae = SparseAutoencoder {
            config: SaeConfig {
                d_in,
                d_sae,
                architecture: SaeArchitecture::ReLU,
                hook_name: "blocks.0.hook_resid_post".into(),
                hook_point: HookPoint::ResidPost(0),
                apply_b_dec_to_input: false,
                normalize_activations: NormalizeActivations::None,
            },
            w_enc,
            w_dec,
            b_enc,
            b_dec,
            threshold: None,
        };

        // Input -1.0 is ReLU'd away; the rest come back sorted descending.
        let x = Tensor::new(&[2.0_f32, -1.0, 3.0, 0.5], &device).unwrap();
        let sparse = sae.encode_sparse(&x).unwrap();

        assert_eq!(sparse.len(), 3);
        assert_eq!(sparse.features[0].0.index, 2); // activation 3.0
        assert_eq!(sparse.features[1].0.index, 0); // activation 2.0
        assert_eq!(sparse.features[2].0.index, 3); // activation 0.5
    }

    #[test]
    fn decoder_vector_basic() {
        let d_in = 4;
        let d_sae = 8;
        let device = Device::Cpu;

        let w_dec = Tensor::randn(0.0_f32, 1.0, (d_sae, d_in), &device).unwrap();
        let sae = SparseAutoencoder {
            config: SaeConfig {
                d_in,
                d_sae,
                architecture: SaeArchitecture::ReLU,
                hook_name: "blocks.0.hook_resid_post".into(),
                hook_point: HookPoint::ResidPost(0),
                apply_b_dec_to_input: false,
                normalize_activations: NormalizeActivations::None,
            },
            w_enc: Tensor::zeros((d_in, d_sae), DType::F32, &device).unwrap(),
            w_dec: w_dec.clone(),
            b_enc: Tensor::zeros(d_sae, DType::F32, &device).unwrap(),
            b_dec: Tensor::zeros(d_in, DType::F32, &device).unwrap(),
            threshold: None,
        };

        let vec0 = sae.decoder_vector(0).unwrap();
        assert_eq!(vec0.dims(), &[d_in]);

        // Index == d_sae is out of range.
        assert!(sae.decoder_vector(d_sae).is_err());
    }

    #[test]
    fn prepare_injection_basic() {
        let d_in = 4;
        let d_sae = 8;
        let device = Device::Cpu;

        let sae = SparseAutoencoder {
            config: SaeConfig {
                d_in,
                d_sae,
                architecture: SaeArchitecture::ReLU,
                hook_name: "blocks.0.hook_resid_post".into(),
                hook_point: HookPoint::ResidPost(0),
                apply_b_dec_to_input: false,
                normalize_activations: NormalizeActivations::None,
            },
            w_enc: Tensor::zeros((d_in, d_sae), DType::F32, &device).unwrap(),
            w_dec: Tensor::ones((d_sae, d_in), DType::F32, &device).unwrap(),
            b_enc: Tensor::zeros(d_sae, DType::F32, &device).unwrap(),
            b_dec: Tensor::zeros(d_in, DType::F32, &device).unwrap(),
            threshold: None,
        };

        let features = vec![(0_usize, 1.0_f32), (1, 0.5)];
        let hooks = sae
            .prepare_hook_injection(&features, 2, 5, &device)
            .unwrap();
        assert!(!hooks.is_empty());
    }
}