1pub(crate) mod npz;
59
60use std::path::Path;
61
62use candle_core::{DType, Device, Tensor};
63use safetensors::tensor::SafeTensors;
64use tracing::info;
65
66use crate::error::{MIError, Result};
67use crate::hooks::{HookPoint, HookSpec, Intervention};
68use crate::sparse::{FeatureId, SparseActivations};
69
/// Identifier of a single SAE feature: its position along the `d_sae` axis.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct SaeFeatureId {
    /// Zero-based feature index (the column of the encoded activations).
    pub index: usize,
}

impl std::fmt::Display for SaeFeatureId {
    /// Render as `SAE:<index>`, e.g. `SAE:42`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { index } = self;
        write!(f, "SAE:{index}")
    }
}
86
// Opt `SaeFeatureId` into the crate's `FeatureId` marker trait so it can be the
// feature key of `SparseActivations` (as returned by `encode_sparse`) —
// presumably required by the `SparseActivations<F>` bound; confirm in `crate::sparse`.
impl FeatureId for SaeFeatureId {}
88
/// Encoder activation function of a sparse autoencoder.
#[non_exhaustive]
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum SaeArchitecture {
    /// Plain `relu(pre_acts)`.
    ReLU,
    /// Thresholded activation: units not strictly above their per-feature
    /// `threshold` are zeroed. Requires a `threshold` tensor.
    JumpReLU,
    /// Keep (about) the `k` largest pre-activations per row and zero the rest;
    /// see the TopK helpers for the exact tie handling of each backend.
    TopK {
        /// Number of features kept active per row.
        k: usize,
    },
}
104
/// Input-normalization mode declared by an SAE's `cfg.json`.
///
/// Note: the encode path in this module does not read this value; it is
/// carried on the config for callers to inspect.
#[non_exhaustive]
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum NormalizeActivations {
    /// No normalization declared.
    None,
    /// `cfg.json` value `"expected_average_only_in"`.
    ExpectedAverageOnlyIn,
}
114
/// Which TopK implementation `encode_with_strategy` should use.
#[non_exhaustive]
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TopKStrategy {
    /// Use the CPU implementation for CPU tensors and the on-device one otherwise.
    Auto,
    /// Always copy to host memory and select per row on the CPU.
    Cpu,
    /// Always use the on-device sort-and-mask implementation.
    Gpu,
}
126
127#[derive(Debug, Clone)]
129pub struct SaeConfig {
130 pub d_in: usize,
132 pub d_sae: usize,
134 pub architecture: SaeArchitecture,
136 pub hook_name: String,
138 pub hook_point: HookPoint,
140 pub apply_b_dec_to_input: bool,
142 pub normalize_activations: NormalizeActivations,
144}
145
146#[derive(serde::Deserialize)]
151#[allow(clippy::missing_docs_in_private_items)]
152struct RawSaeConfig {
153 d_in: usize,
154 d_sae: usize,
155 #[serde(default)]
156 architecture: Option<String>,
157 #[serde(default)]
158 activation_fn_str: Option<String>,
159 #[serde(default)]
160 activation_fn_kwargs: Option<serde_json::Value>,
161 #[serde(default)]
162 hook_name: Option<String>,
163 #[serde(default)]
164 hook_point: Option<String>,
165 #[serde(default)]
166 apply_b_dec_to_input: bool,
167 #[serde(default)]
168 normalize_activations: Option<String>,
169}
170
171fn parse_sae_config(raw: RawSaeConfig) -> Result<SaeConfig> {
173 let architecture = resolve_architecture(
175 raw.architecture.as_deref(),
176 raw.activation_fn_str.as_deref(),
177 raw.activation_fn_kwargs.as_ref(),
178 )?;
179
180 let hook_name = raw
182 .hook_name
183 .or(raw.hook_point)
184 .unwrap_or_else(|| "unknown".to_owned());
185
186 let hook_point: HookPoint = hook_name
188 .parse()
189 .unwrap_or_else(|_: std::convert::Infallible| {
190 unreachable!()
192 });
193
194 let normalize_activations = match raw.normalize_activations.as_deref() {
195 Some("expected_average_only_in") => NormalizeActivations::ExpectedAverageOnlyIn,
196 _ => NormalizeActivations::None,
197 };
198
199 Ok(SaeConfig {
200 d_in: raw.d_in,
201 d_sae: raw.d_sae,
202 architecture,
203 hook_name,
204 hook_point,
205 apply_b_dec_to_input: raw.apply_b_dec_to_input,
206 normalize_activations,
207 })
208}
209
210fn resolve_architecture(
212 architecture: Option<&str>,
213 activation_fn_str: Option<&str>,
214 activation_fn_kwargs: Option<&serde_json::Value>,
215) -> Result<SaeArchitecture> {
216 match architecture {
218 Some("jumprelu") => return Ok(SaeArchitecture::JumpReLU),
219 Some("topk") => {
220 let k = extract_topk_k(activation_fn_kwargs)?;
221 return Ok(SaeArchitecture::TopK { k });
222 }
223 Some("standard") | None => {} Some(other) => {
225 return Err(MIError::Config(format!(
226 "unsupported SAE architecture: {other:?}"
227 )));
228 }
229 }
230
231 match activation_fn_str {
233 Some("relu") | None => Ok(SaeArchitecture::ReLU),
234 Some("jumprelu") => Ok(SaeArchitecture::JumpReLU),
235 Some("topk") => {
236 let k = extract_topk_k(activation_fn_kwargs)?;
237 Ok(SaeArchitecture::TopK { k })
238 }
239 Some(other) => Err(MIError::Config(format!(
240 "unsupported SAE activation function: {other:?}"
241 ))),
242 }
243}
244
245fn extract_topk_k(kwargs: Option<&serde_json::Value>) -> Result<usize> {
247 let k = kwargs
248 .and_then(|v| v.get("k"))
249 .and_then(serde_json::Value::as_u64)
250 .ok_or_else(|| {
251 MIError::Config("TopK SAE requires activation_fn_kwargs.k in cfg.json".into())
252 })?;
253 let k_usize = usize::try_from(k)
254 .map_err(|_| MIError::Config(format!("TopK k value {k} too large for usize")))?;
255 Ok(k_usize)
256}
257
258pub struct SparseAutoencoder {
289 config: SaeConfig,
291 w_enc: Tensor,
296 w_dec: Tensor,
301 b_enc: Tensor,
306 b_dec: Tensor,
311 threshold: Option<Tensor>,
316}
317
318impl SparseAutoencoder {
319 pub fn from_local(dir: &Path, device: &Device) -> Result<Self> {
333 let cfg_path = dir.join("cfg.json");
335 if !cfg_path.exists() {
336 return Err(MIError::Config(format!(
337 "cfg.json not found in {}",
338 dir.display()
339 )));
340 }
341 let cfg_text = std::fs::read_to_string(&cfg_path)?;
342 let raw: RawSaeConfig = serde_json::from_str(&cfg_text)
343 .map_err(|e| MIError::Config(format!("failed to parse cfg.json: {e}")))?;
344 let config = parse_sae_config(raw)?;
345
346 info!(
347 "SAE config: d_in={}, d_sae={}, arch={:?}, hook={}",
348 config.d_in, config.d_sae, config.architecture, config.hook_name
349 );
350
351 let weights_path = if dir.join("sae_weights.safetensors").exists() {
353 dir.join("sae_weights.safetensors")
354 } else if dir.join("model.safetensors").exists() {
355 dir.join("model.safetensors")
356 } else {
357 return Err(MIError::Config(format!(
358 "no safetensors file found in {}",
359 dir.display()
360 )));
361 };
362
363 let data = std::fs::read(&weights_path)?;
365 let st = SafeTensors::deserialize(&data)
366 .map_err(|e| MIError::Config(format!("failed to deserialize SAE weights: {e}")))?;
367
368 let w_enc = load_tensor(&st, "W_enc", device)?;
369 let w_dec = load_tensor(&st, "W_dec", device)?;
370 let b_enc = load_tensor(&st, "b_enc", device)?;
371 let b_dec = load_tensor(&st, "b_dec", device)?;
372 let threshold = st
373 .tensor("threshold")
374 .ok()
375 .map(|v| tensor_from_view(&v, device))
376 .transpose()?;
377
378 let w_enc = w_enc.to_dtype(DType::F32)?;
380 let w_dec = w_dec.to_dtype(DType::F32)?;
381 let b_enc = b_enc.to_dtype(DType::F32)?;
382 let b_dec = b_dec.to_dtype(DType::F32)?;
383 let threshold = threshold.map(|t| t.to_dtype(DType::F32)).transpose()?;
384
385 validate_shape(&w_enc, &[config.d_in, config.d_sae], "W_enc")?;
387 validate_shape(&w_dec, &[config.d_sae, config.d_in], "W_dec")?;
388 validate_shape(&b_enc, &[config.d_sae], "b_enc")?;
389 validate_shape(&b_dec, &[config.d_in], "b_dec")?;
390 if let Some(ref t) = threshold {
391 validate_shape(t, &[config.d_sae], "threshold")?;
392 }
393
394 if config.architecture == SaeArchitecture::JumpReLU && threshold.is_none() {
396 return Err(MIError::Config(
397 "JumpReLU SAE requires 'threshold' tensor in weights file".into(),
398 ));
399 }
400
401 info!(
402 "SAE loaded: {} weights on {:?}",
403 weights_path.display(),
404 device
405 );
406
407 Ok(Self {
408 config,
409 w_enc,
410 w_dec,
411 b_enc,
412 b_dec,
413 threshold,
414 })
415 }
416
417 pub fn from_npz(npz_path: &Path, hook_layer: usize, device: &Device) -> Result<Self> {
434 info!("Loading SAE from NPZ: {}", npz_path.display());
435 let tensors = npz::load_npz(npz_path, device)?;
436
437 let w_enc = tensors
438 .get("W_enc")
439 .ok_or_else(|| MIError::Config("NPZ missing W_enc".into()))?
440 .to_dtype(DType::F32)?;
441 let w_dec = tensors
442 .get("W_dec")
443 .ok_or_else(|| MIError::Config("NPZ missing W_dec".into()))?
444 .to_dtype(DType::F32)?;
445 let b_enc = tensors
446 .get("b_enc")
447 .ok_or_else(|| MIError::Config("NPZ missing b_enc".into()))?
448 .to_dtype(DType::F32)?;
449 let b_dec = tensors
450 .get("b_dec")
451 .ok_or_else(|| MIError::Config("NPZ missing b_dec".into()))?
452 .to_dtype(DType::F32)?;
453 let threshold = tensors
454 .get("threshold")
455 .map(|t| t.to_dtype(DType::F32))
456 .transpose()?;
457
458 let w_enc_dims = w_enc.dims();
460 if w_enc_dims.len() != 2 {
461 return Err(MIError::Config(format!(
462 "W_enc expected 2 dims, got {}",
463 w_enc_dims.len()
464 )));
465 }
466 let d_in = *w_enc_dims
467 .first()
468 .ok_or_else(|| MIError::Config("W_enc has no dimensions".into()))?;
469 let d_sae = *w_enc_dims
470 .get(1)
471 .ok_or_else(|| MIError::Config("W_enc has no second dimension".into()))?;
472
473 validate_shape(&w_enc, &[d_in, d_sae], "W_enc")?;
475 validate_shape(&w_dec, &[d_sae, d_in], "W_dec")?;
476 validate_shape(&b_enc, &[d_sae], "b_enc")?;
477 validate_shape(&b_dec, &[d_in], "b_dec")?;
478 if let Some(ref t) = threshold {
479 validate_shape(t, &[d_sae], "threshold")?;
480 }
481
482 let architecture = if threshold.is_some() {
484 SaeArchitecture::JumpReLU
485 } else {
486 SaeArchitecture::ReLU
487 };
488
489 let hook_name = format!("blocks.{hook_layer}.hook_resid_post");
490 let hook_point = hook_name
491 .parse::<HookPoint>()
492 .map_err(|e| MIError::Config(format!("failed to parse hook name: {e}")))?;
493
494 let config = SaeConfig {
495 d_in,
496 d_sae,
497 architecture,
498 hook_name,
499 hook_point,
500 apply_b_dec_to_input: false,
501 normalize_activations: NormalizeActivations::None,
502 };
503
504 info!(
505 "SAE from NPZ: d_in={d_in}, d_sae={d_sae}, arch={:?}, hook={}",
506 config.architecture, config.hook_name
507 );
508
509 Ok(Self {
510 config,
511 w_enc,
512 w_dec,
513 b_enc,
514 b_dec,
515 threshold,
516 })
517 }
518
519 pub fn from_pretrained_npz(
537 repo_id: &str,
538 npz_path: &str,
539 hook_layer: usize,
540 device: &Device,
541 ) -> Result<Self> {
542 let fetch_config = crate::download::fetch_config_builder()
543 .on_progress(|event| {
544 tracing::info!(
545 filename = %event.filename,
546 percent = event.percent,
547 bytes_downloaded = event.bytes_downloaded,
548 bytes_total = event.bytes_total,
549 "SAE NPZ download progress",
550 );
551 })
552 .build()
553 .map_err(|e| MIError::Download(format!("failed to build fetch config: {e}")))?;
554
555 info!("Downloading {npz_path} from {repo_id}");
556 let local_path =
557 hf_fetch_model::download_file_blocking(repo_id.to_owned(), npz_path, &fetch_config)
558 .map_err(|e| MIError::Download(format!("failed to download NPZ: {e}")))?
559 .into_inner();
560
561 Self::from_npz(&local_path, hook_layer, device)
562 }
563
564 pub fn from_pretrained(repo_id: &str, sae_id: &str, device: &Device) -> Result<Self> {
581 let fetch_config = crate::download::fetch_config_builder()
582 .on_progress(|event| {
583 tracing::info!(
584 filename = %event.filename,
585 percent = event.percent,
586 bytes_downloaded = event.bytes_downloaded,
587 bytes_total = event.bytes_total,
588 "SAE download progress",
589 );
590 })
591 .build()
592 .map_err(|e| MIError::Download(format!("failed to build fetch config: {e}")))?;
593
594 let cfg_remote = format!("{sae_id}/cfg.json");
596 info!("Downloading {cfg_remote} from {repo_id}");
597 let cfg_path =
598 hf_fetch_model::download_file_blocking(repo_id.to_owned(), &cfg_remote, &fetch_config)
599 .map_err(|e| MIError::Download(format!("failed to download cfg.json: {e}")))?
600 .into_inner();
601
602 let weights_remote = format!("{sae_id}/sae_weights.safetensors");
604 info!("Downloading {weights_remote} from {repo_id}");
605 let weights_path = hf_fetch_model::download_file_blocking(
606 repo_id.to_owned(),
607 &weights_remote,
608 &fetch_config,
609 )
610 .or_else(|_| {
611 let alt_remote = format!("{sae_id}/model.safetensors");
612 info!("Trying {alt_remote} from {repo_id}");
613 hf_fetch_model::download_file_blocking(repo_id.to_owned(), &alt_remote, &fetch_config)
614 })
615 .map_err(|e| MIError::Download(format!("failed to download SAE weights: {e}")))?
616 .into_inner();
617
618 let dir = cfg_path.parent().ok_or_else(|| {
620 MIError::Config("cannot determine SAE directory from cfg.json path".into())
621 })?;
622
623 if dir.join("sae_weights.safetensors").exists() || dir.join("model.safetensors").exists() {
627 Self::from_local(dir, device)
628 } else {
629 let weights_dir = weights_path.parent().ok_or_else(|| {
631 MIError::Config("cannot determine SAE directory from weights path".into())
632 })?;
633 let target_cfg = weights_dir.join("cfg.json");
635 if !target_cfg.exists() {
636 std::fs::copy(&cfg_path, &target_cfg)?;
637 }
638 Self::from_local(weights_dir, device)
639 }
640 }
641
642 #[must_use]
646 pub const fn config(&self) -> &SaeConfig {
647 &self.config
648 }
649
650 #[must_use]
652 pub const fn hook_point(&self) -> &HookPoint {
653 &self.config.hook_point
654 }
655
656 #[must_use]
658 pub const fn d_sae(&self) -> usize {
659 self.config.d_sae
660 }
661
662 #[must_use]
664 pub const fn d_in(&self) -> usize {
665 self.config.d_in
666 }
667
668 pub fn encode(&self, x: &Tensor) -> Result<Tensor> {
686 self.encode_with_strategy(x, &TopKStrategy::Auto)
687 }
688
689 pub fn encode_with_strategy(&self, x: &Tensor, strategy: &TopKStrategy) -> Result<Tensor> {
703 let dims = x.dims();
704 let last_dim = *dims
705 .last()
706 .ok_or_else(|| MIError::Config("cannot encode empty tensor".into()))?;
707 if last_dim != self.config.d_in {
708 return Err(MIError::Config(format!(
709 "input last dim {last_dim} != SAE d_in {}",
710 self.config.d_in
711 )));
712 }
713
714 let x_f32 = x.to_dtype(DType::F32)?;
716
717 let x_centered = if self.config.apply_b_dec_to_input {
719 let b_dec = broadcast_bias(&self.b_dec, x_f32.dims())?;
720 (&x_f32 - &b_dec)?
721 } else {
722 x_f32
723 };
724
725 let pre_acts = x_centered.broadcast_matmul(&self.w_enc)?;
728 let b_enc = broadcast_bias(&self.b_enc, pre_acts.dims())?;
730 let pre_acts = (&pre_acts + &b_enc)?;
731
732 match &self.config.architecture {
734 SaeArchitecture::ReLU => Ok(pre_acts.relu()?),
735 SaeArchitecture::JumpReLU => {
736 let threshold = self
737 .threshold
738 .as_ref()
739 .ok_or_else(|| MIError::Config("JumpReLU requires threshold tensor".into()))?;
740 let threshold = broadcast_bias(threshold, pre_acts.dims())?;
742 let mask = pre_acts.gt(&threshold)?;
744 let mask_f32 = mask.to_dtype(DType::F32)?;
745 Ok((&pre_acts * &mask_f32)?)
746 }
747 SaeArchitecture::TopK { k } => topk_activation(&pre_acts, *k, strategy),
748 }
749 }
750
751 pub fn encode_sparse(&self, x: &Tensor) -> Result<SparseActivations<SaeFeatureId>> {
764 let encoded = self.encode(&x.unsqueeze(0)?)?;
765 let encoded_1d = encoded.squeeze(0)?;
766
767 let values: Vec<f32> = encoded_1d.to_vec1()?;
769
770 let mut features: Vec<(SaeFeatureId, f32)> = values
771 .iter()
772 .enumerate()
773 .filter(|&(_, v)| *v > 0.0)
774 .map(|(i, v)| (SaeFeatureId { index: i }, *v))
775 .collect();
776
777 features.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
779
780 Ok(SparseActivations { features })
781 }
782
783 pub fn decode(&self, features: &Tensor) -> Result<Tensor> {
795 let features_f32 = features.to_dtype(DType::F32)?;
798 let decoded = features_f32.broadcast_matmul(&self.w_dec)?;
799 let b_dec = broadcast_bias(&self.b_dec, decoded.dims())?;
800 Ok((&decoded + &b_dec)?)
801 }
802
803 pub fn reconstruct(&self, x: &Tensor) -> Result<Tensor> {
816 let encoded = self.encode(x)?;
817 self.decode(&encoded)
818 }
819
820 pub fn reconstruction_error(&self, x: &Tensor) -> Result<f64> {
831 let x_f32 = x.to_dtype(DType::F32)?;
832 let x_hat = self.reconstruct(&x_f32)?;
833 let diff = (&x_f32 - &x_hat)?;
834 let mse: f32 = diff.sqr()?.mean_all()?.to_scalar()?;
835 Ok(f64::from(mse))
836 }
837
838 pub fn decoder_vector(&self, feature_idx: usize) -> Result<Tensor> {
850 if feature_idx >= self.config.d_sae {
851 return Err(MIError::Config(format!(
852 "feature index {feature_idx} out of range (d_sae={})",
853 self.config.d_sae
854 )));
855 }
856 Ok(self.w_dec.get(feature_idx)?)
858 }
859
860 pub fn prepare_hook_injection(
879 &self,
880 features: &[(usize, f32)],
881 position: usize,
882 seq_len: usize,
883 device: &Device,
884 ) -> Result<HookSpec> {
885 let d_in = self.config.d_in;
886
887 let mut accumulated = Tensor::zeros(d_in, DType::F32, device)?;
889 for &(feature_idx, strength) in features {
890 let dec_vec = self.decoder_vector(feature_idx)?;
891 let dec_vec = dec_vec.to_device(device)?;
892 let scaled = (&dec_vec * f64::from(strength))?;
893 accumulated = (&accumulated + &scaled)?;
894 }
895
896 let injection = Tensor::zeros((1, seq_len, d_in), DType::F32, device)?;
898 let scaled_3d = accumulated.unsqueeze(0)?.unsqueeze(0)?; let before = if position > 0 {
901 Some(injection.narrow(1, 0, position)?)
902 } else {
903 None
904 };
905 let after = if position + 1 < seq_len {
906 Some(injection.narrow(1, position + 1, seq_len - position - 1)?)
907 } else {
908 None
909 };
910
911 let mut parts: Vec<Tensor> = Vec::with_capacity(3);
912 if let Some(b) = before {
913 parts.push(b);
914 }
915 parts.push(scaled_3d);
916 if let Some(a) = after {
917 parts.push(a);
918 }
919
920 let injection = Tensor::cat(&parts, 1)?;
921
922 let mut hooks = HookSpec::new();
923 hooks.intervene(self.config.hook_point.clone(), Intervention::Add(injection));
924 Ok(hooks)
925 }
926}
927
928fn topk_activation(pre_acts: &Tensor, k: usize, strategy: &TopKStrategy) -> Result<Tensor> {
943 let use_cpu = match strategy {
944 TopKStrategy::Cpu => true,
945 TopKStrategy::Gpu => false,
946 TopKStrategy::Auto => matches!(pre_acts.device(), Device::Cpu),
947 };
948
949 if use_cpu {
950 topk_cpu(pre_acts, k)
951 } else {
952 topk_gpu(pre_acts, k)
953 }
954}
955
956fn topk_cpu(pre_acts: &Tensor, k: usize) -> Result<Tensor> {
958 let device = pre_acts.device().clone();
959 let shape = pre_acts.dims().to_vec();
960 let d_sae = *shape
961 .last()
962 .ok_or_else(|| MIError::Config("cannot apply TopK to empty tensor".into()))?;
963
964 let n: usize = shape.iter().take(shape.len() - 1).product();
966 let flat = pre_acts.reshape((n, d_sae))?.to_dtype(DType::F32)?;
967 let flat_cpu = flat.to_device(&Device::Cpu)?;
968
969 let mut result_data: Vec<f32> = Vec::with_capacity(n * d_sae);
970
971 for row_idx in 0..n {
972 let row = flat_cpu.get(row_idx)?;
973 let mut row_vec: Vec<f32> = row.to_vec1()?;
974
975 let k_clamped = k.min(d_sae);
977 if k_clamped > 0 && k_clamped < d_sae {
978 let mut indices: Vec<usize> = (0..d_sae).collect();
980 #[allow(clippy::indexing_slicing)]
981 indices.select_nth_unstable_by(k_clamped - 1, |&a, &b| {
983 let va = row_vec.get(b).copied().unwrap_or(f32::NEG_INFINITY);
984 let vb = row_vec.get(a).copied().unwrap_or(f32::NEG_INFINITY);
985 va.partial_cmp(&vb).unwrap_or(std::cmp::Ordering::Equal)
986 });
987 let threshold_idx = indices.get(k_clamped - 1).copied().unwrap_or(0);
988 let threshold = row_vec.get(threshold_idx).copied().unwrap_or(0.0);
989
990 for v in &mut row_vec {
992 if *v < threshold {
993 *v = 0.0;
994 }
995 }
996
997 let active: usize = row_vec.iter().filter(|&&v| v >= threshold).count();
1000 if active > k_clamped {
1001 let mut excess = active - k_clamped;
1002 for v in row_vec.iter_mut().rev() {
1003 if excess == 0 {
1004 break;
1005 }
1006 if (*v - threshold).abs() < f32::EPSILON {
1007 *v = 0.0;
1008 excess -= 1;
1009 }
1010 }
1011 }
1012 } else if k_clamped == 0 {
1013 row_vec.fill(0.0);
1014 }
1015 result_data.extend_from_slice(&row_vec);
1018 }
1019
1020 let result = Tensor::from_vec(result_data, (n, d_sae), &device)?;
1021 result.reshape(shape.as_slice()).map_err(Into::into)
1022}
1023
1024fn topk_gpu(pre_acts: &Tensor, k: usize) -> Result<Tensor> {
1026 let shape = pre_acts.dims().to_vec();
1027 let d_sae = *shape
1028 .last()
1029 .ok_or_else(|| MIError::Config("cannot apply TopK to empty tensor".into()))?;
1030
1031 let k_clamped = k.min(d_sae);
1032 if k_clamped == 0 {
1033 return Ok(pre_acts.zeros_like()?);
1034 }
1035 if k_clamped >= d_sae {
1036 return Ok(pre_acts.clone());
1037 }
1038
1039 let n: usize = shape.iter().take(shape.len() - 1).product();
1041 let flat = pre_acts.reshape((n, d_sae))?.to_dtype(DType::F32)?;
1042
1043 let (sorted_vals, _sorted_indices) = flat.sort_last_dim(false)?;
1045
1046 let kth_vals = sorted_vals.narrow(1, k_clamped - 1, 1)?;
1048
1049 let mask = flat.ge(&kth_vals)?;
1051 let mask_f32 = mask.to_dtype(DType::F32)?;
1052
1053 let result = (&flat * &mask_f32)?;
1054 result.reshape(shape.as_slice()).map_err(Into::into)
1055}
1056
1057fn broadcast_bias(bias: &Tensor, target_shape: &[usize]) -> Result<Tensor> {
1071 let ndim = target_shape.len();
1072 if ndim <= 1 {
1073 return Ok(bias.clone());
1074 }
1075 let mut shape = vec![1_usize; ndim];
1077 let last_dim = *target_shape
1078 .last()
1079 .ok_or_else(|| MIError::Config("cannot broadcast bias to empty shape".into()))?;
1080 if let Some(slot) = shape.last_mut() {
1081 *slot = last_dim;
1082 }
1083 let reshaped = bias.reshape(shape.as_slice())?;
1084 Ok(reshaped.broadcast_as(target_shape)?)
1085}
1086
1087fn tensor_from_view(view: &safetensors::tensor::TensorView<'_>, device: &Device) -> Result<Tensor> {
1097 let shape: Vec<usize> = view.shape().to_vec();
1098 #[allow(clippy::wildcard_enum_match_arm)]
1099 let dtype = match view.dtype() {
1101 safetensors::Dtype::BF16 => DType::BF16,
1102 safetensors::Dtype::F16 => DType::F16,
1103 safetensors::Dtype::F32 => DType::F32,
1104 other => {
1105 return Err(MIError::Config(format!(
1106 "unsupported SAE tensor dtype: {other:?}"
1107 )));
1108 }
1109 };
1110 let tensor = Tensor::from_raw_buffer(view.data(), dtype, &shape, device)?;
1111 Ok(tensor)
1112}
1113
1114fn load_tensor(st: &SafeTensors<'_>, name: &str, device: &Device) -> Result<Tensor> {
1116 let view = st
1117 .tensor(name)
1118 .map_err(|e| MIError::Config(format!("tensor '{name}' not found: {e}")))?;
1119 tensor_from_view(&view, device)
1120}
1121
1122fn validate_shape(tensor: &Tensor, expected: &[usize], name: &str) -> Result<()> {
1124 if tensor.dims() != expected {
1125 return Err(MIError::Config(format!(
1126 "SAE tensor '{name}' shape mismatch: expected {expected:?}, got {:?}",
1127 tensor.dims()
1128 )));
1129 }
1130 Ok(())
1131}
1132
#[cfg(test)]
mod tests {
    use super::*;

    /// ReLU SAE hooked at `blocks.0.hook_resid_post` with the given encoder and
    /// decoder weights, zero biases and no threshold. `d_in`/`d_sae` are taken
    /// from the shape of `w_enc`.
    fn relu_sae(w_enc: Tensor, w_dec: Tensor) -> SparseAutoencoder {
        let (d_in, d_sae) = w_enc.dims2().unwrap();
        let device = w_enc.device().clone();
        SparseAutoencoder {
            config: SaeConfig {
                d_in,
                d_sae,
                architecture: SaeArchitecture::ReLU,
                hook_name: "blocks.0.hook_resid_post".into(),
                hook_point: HookPoint::ResidPost(0),
                apply_b_dec_to_input: false,
                normalize_activations: NormalizeActivations::None,
            },
            w_enc,
            w_dec,
            b_enc: Tensor::zeros(d_sae, DType::F32, &device).unwrap(),
            b_dec: Tensor::zeros(d_in, DType::F32, &device).unwrap(),
            threshold: None,
        }
    }

    /// `n x n` F32 identity matrix.
    fn identity(n: usize, device: &Device) -> Tensor {
        let mut data = vec![0.0_f32; n * n];
        for i in 0..n {
            data[i * n + i] = 1.0;
        }
        Tensor::from_vec(data, (n, n), device).unwrap()
    }

    #[test]
    fn sae_feature_id_display() {
        let fid = SaeFeatureId { index: 42 };
        assert_eq!(fid.to_string(), "SAE:42");
    }

    #[test]
    fn resolve_architecture_relu_default() {
        let arch = resolve_architecture(None, None, None).unwrap();
        assert_eq!(arch, SaeArchitecture::ReLU);
    }

    #[test]
    fn resolve_architecture_relu_explicit() {
        let arch = resolve_architecture(Some("standard"), Some("relu"), None).unwrap();
        assert_eq!(arch, SaeArchitecture::ReLU);
    }

    #[test]
    fn resolve_architecture_jumprelu() {
        let arch = resolve_architecture(Some("jumprelu"), None, None).unwrap();
        assert_eq!(arch, SaeArchitecture::JumpReLU);
    }

    #[test]
    fn resolve_architecture_jumprelu_from_activation() {
        let arch = resolve_architecture(None, Some("jumprelu"), None).unwrap();
        assert_eq!(arch, SaeArchitecture::JumpReLU);
    }

    #[test]
    fn resolve_architecture_topk() {
        let kwargs = serde_json::json!({"k": 32});
        let arch = resolve_architecture(Some("topk"), None, Some(&kwargs)).unwrap();
        assert_eq!(arch, SaeArchitecture::TopK { k: 32 });
    }

    #[test]
    fn resolve_architecture_topk_from_activation() {
        let kwargs = serde_json::json!({"k": 64});
        let arch = resolve_architecture(None, Some("topk"), Some(&kwargs)).unwrap();
        assert_eq!(arch, SaeArchitecture::TopK { k: 64 });
    }

    #[test]
    fn resolve_architecture_topk_missing_k() {
        let result = resolve_architecture(Some("topk"), None, None);
        assert!(result.is_err());
    }

    #[test]
    fn resolve_architecture_unknown() {
        let result = resolve_architecture(Some("gated"), None, None);
        assert!(result.is_err());
    }

    #[test]
    fn resolve_architecture_unknown_activation() {
        assert!(resolve_architecture(None, Some("gelu"), None).is_err());
        assert!(resolve_architecture(Some("standard"), Some("gelu"), None).is_err());
    }

    #[test]
    fn parse_config_minimal() {
        let json = r#"{
            "d_in": 2304,
            "d_sae": 16384,
            "hook_name": "blocks.5.hook_resid_post"
        }"#;
        let raw: RawSaeConfig = serde_json::from_str(json).unwrap();
        let config = parse_sae_config(raw).unwrap();
        assert_eq!(config.d_in, 2304);
        assert_eq!(config.d_sae, 16384);
        assert_eq!(config.architecture, SaeArchitecture::ReLU);
        assert_eq!(config.hook_point, HookPoint::ResidPost(5));
        assert!(!config.apply_b_dec_to_input);
    }

    #[test]
    fn parse_config_hook_point_fallback() {
        let json = r#"{"d_in": 8, "d_sae": 16, "hook_point": "blocks.3.hook_resid_post"}"#;
        let raw: RawSaeConfig = serde_json::from_str(json).unwrap();
        let config = parse_sae_config(raw).unwrap();
        assert_eq!(config.hook_name, "blocks.3.hook_resid_post");
        assert_eq!(config.hook_point, HookPoint::ResidPost(3));
    }

    #[test]
    fn parse_config_jumprelu() {
        let json = r#"{
            "d_in": 2304,
            "d_sae": 16384,
            "architecture": "jumprelu",
            "hook_name": "blocks.20.hook_resid_post",
            "apply_b_dec_to_input": true,
            "normalize_activations": "expected_average_only_in"
        }"#;
        let raw: RawSaeConfig = serde_json::from_str(json).unwrap();
        let config = parse_sae_config(raw).unwrap();
        assert_eq!(config.architecture, SaeArchitecture::JumpReLU);
        assert_eq!(config.hook_point, HookPoint::ResidPost(20));
        assert!(config.apply_b_dec_to_input);
        assert_eq!(
            config.normalize_activations,
            NormalizeActivations::ExpectedAverageOnlyIn
        );
    }

    #[test]
    fn parse_config_topk() {
        let json = r#"{
            "d_in": 2304,
            "d_sae": 65536,
            "activation_fn_str": "topk",
            "activation_fn_kwargs": {"k": 32},
            "hook_name": "blocks.10.hook_resid_post"
        }"#;
        let raw: RawSaeConfig = serde_json::from_str(json).unwrap();
        let config = parse_sae_config(raw).unwrap();
        assert_eq!(config.architecture, SaeArchitecture::TopK { k: 32 });
    }

    #[test]
    fn topk_cpu_basic() {
        let data = Tensor::new(&[[5.0_f32, 3.0, 1.0, 4.0, 2.0]], &Device::Cpu).unwrap();
        let result = topk_cpu(&data, 2).unwrap();
        let vals: Vec<f32> = result.flatten_all().unwrap().to_vec1().unwrap();
        assert_eq!(vals, vec![5.0, 0.0, 0.0, 4.0, 0.0]);
    }

    #[test]
    fn topk_cpu_all_kept() {
        let data = Tensor::new(&[[1.0_f32, 2.0, 3.0]], &Device::Cpu).unwrap();
        let result = topk_cpu(&data, 5).unwrap();
        let vals: Vec<f32> = result.flatten_all().unwrap().to_vec1().unwrap();
        assert_eq!(vals, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn topk_cpu_none_kept() {
        let data = Tensor::new(&[[1.0_f32, 2.0, 3.0]], &Device::Cpu).unwrap();
        let result = topk_cpu(&data, 0).unwrap();
        let vals: Vec<f32> = result.flatten_all().unwrap().to_vec1().unwrap();
        assert_eq!(vals, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn topk_cpu_batched() {
        let data = Tensor::new(
            &[[5.0_f32, 3.0, 1.0, 4.0, 2.0], [1.0, 2.0, 3.0, 4.0, 5.0]],
            &Device::Cpu,
        )
        .unwrap();
        let result = topk_cpu(&data, 3).unwrap();
        let vals: Vec<Vec<f32>> = result.to_vec2().unwrap();
        assert_eq!(vals[0], vec![5.0, 3.0, 0.0, 4.0, 0.0]);
        assert_eq!(vals[1], vec![0.0, 0.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn topk_cpu_ties_keep_lowest_indices() {
        let data = Tensor::new(&[[2.0_f32, 2.0, 2.0, 1.0]], &Device::Cpu).unwrap();
        let result = topk_cpu(&data, 2).unwrap();
        let vals: Vec<f32> = result.flatten_all().unwrap().to_vec1().unwrap();
        assert_eq!(vals, vec![2.0, 2.0, 0.0, 0.0]);
    }

    #[test]
    fn topk_cpu_preserves_rank3_shape() {
        let data = Tensor::new(
            &[[[1.0_f32, 3.0, 2.0]], [[0.0, -1.0, 5.0]]],
            &Device::Cpu,
        )
        .unwrap();
        let result = topk_cpu(&data, 1).unwrap();
        assert_eq!(result.dims(), &[2, 1, 3]);
        let vals: Vec<f32> = result.flatten_all().unwrap().to_vec1().unwrap();
        assert_eq!(vals, vec![0.0, 3.0, 0.0, 0.0, 0.0, 5.0]);
    }

    #[test]
    fn sparse_activations_sae() {
        let features = vec![
            (SaeFeatureId { index: 5 }, 3.0),
            (SaeFeatureId { index: 2 }, 2.0),
            (SaeFeatureId { index: 8 }, 1.0),
        ];
        let sparse = SparseActivations { features };
        assert_eq!(sparse.len(), 3);
        assert!(!sparse.is_empty());
    }

    #[test]
    fn sparse_activations_truncate_sae() {
        let features = vec![
            (SaeFeatureId { index: 5 }, 3.0),
            (SaeFeatureId { index: 2 }, 2.0),
            (SaeFeatureId { index: 8 }, 1.0),
        ];
        let mut sparse = SparseActivations { features };
        sparse.truncate(2);
        assert_eq!(sparse.len(), 2);
        assert_eq!(sparse.features[0].0.index, 5);
        assert_eq!(sparse.features[1].0.index, 2);
    }

    #[test]
    fn encode_decode_roundtrip_shapes() {
        let d_in = 4;
        let d_sae = 8;
        let device = Device::Cpu;

        let w_enc = Tensor::randn(0.0_f32, 1.0, (d_in, d_sae), &device).unwrap();
        let w_dec = Tensor::randn(0.0_f32, 1.0, (d_sae, d_in), &device).unwrap();
        let sae = relu_sae(w_enc, w_dec);

        // Single row with an explicit batch axis.
        let x1 = Tensor::randn(0.0_f32, 1.0, (d_in,), &device).unwrap();
        let encoded = sae.encode(&x1.unsqueeze(0).unwrap()).unwrap();
        assert_eq!(encoded.dims(), &[1, d_sae]);

        // Batch of rows.
        let x2 = Tensor::randn(0.0_f32, 1.0, (3, d_in), &device).unwrap();
        let encoded = sae.encode(&x2).unwrap();
        assert_eq!(encoded.dims(), &[3, d_sae]);
        let decoded = sae.decode(&encoded).unwrap();
        assert_eq!(decoded.dims(), &[3, d_in]);

        // Batch x sequence.
        let x3 = Tensor::randn(0.0_f32, 1.0, (2, 5, d_in), &device).unwrap();
        let encoded = sae.encode(&x3).unwrap();
        assert_eq!(encoded.dims(), &[2, 5, d_sae]);
        let decoded = sae.decode(&encoded).unwrap();
        assert_eq!(decoded.dims(), &[2, 5, d_in]);

        let x_hat = sae.reconstruct(&x2).unwrap();
        assert_eq!(x_hat.dims(), &[3, d_in]);

        let mse = sae.reconstruction_error(&x2).unwrap();
        assert!(mse >= 0.0);
    }

    #[test]
    fn encode_rejects_wrong_input_dim() {
        let device = Device::Cpu;
        let sae = relu_sae(identity(3, &device), identity(3, &device));
        let x = Tensor::zeros((2, 4), DType::F32, &device).unwrap();
        assert!(sae.encode(&x).is_err());
    }

    #[test]
    fn reconstruction_is_exact_for_identity_weights() {
        let device = Device::Cpu;
        let sae = relu_sae(identity(3, &device), identity(3, &device));
        let x = Tensor::new(&[[1.0_f32, 2.0, 3.0]], &device).unwrap();
        let x_hat: Vec<Vec<f32>> = sae.reconstruct(&x).unwrap().to_vec2().unwrap();
        assert_eq!(x_hat, vec![vec![1.0, 2.0, 3.0]]);
        assert_eq!(sae.reconstruction_error(&x).unwrap(), 0.0);
    }

    #[test]
    fn jumprelu_zeroes_units_at_or_below_threshold() {
        let device = Device::Cpu;
        let mut sae = relu_sae(identity(3, &device), identity(3, &device));
        sae.config.architecture = SaeArchitecture::JumpReLU;
        sae.threshold = Some(Tensor::new(&[0.5_f32, 0.5, 0.5], &device).unwrap());
        let x = Tensor::new(&[[1.0_f32, 0.5, 0.2]], &device).unwrap();
        let encoded: Vec<Vec<f32>> = sae.encode(&x).unwrap().to_vec2().unwrap();
        assert_eq!(encoded, vec![vec![1.0, 0.0, 0.0]]);
    }

    #[test]
    fn encode_sparse_basic() {
        let d_in = 4;
        let d_sae = 8;
        let device = Device::Cpu;

        // Feature i (for i < d_in) copies input dimension i.
        let mut w_enc_data = vec![0.0_f32; d_in * d_sae];
        for i in 0..d_in {
            w_enc_data[i * d_sae + i] = 1.0;
        }
        let w_enc = Tensor::from_vec(w_enc_data, (d_in, d_sae), &device).unwrap();
        let w_dec = Tensor::randn(0.0_f32, 1.0, (d_sae, d_in), &device).unwrap();
        let sae = relu_sae(w_enc, w_dec);

        let x = Tensor::new(&[2.0_f32, -1.0, 3.0, 0.5], &device).unwrap();
        let sparse = sae.encode_sparse(&x).unwrap();

        // -1.0 is removed by ReLU; the rest come back strongest first.
        assert_eq!(sparse.len(), 3);
        assert_eq!(sparse.features[0].0.index, 2);
        assert_eq!(sparse.features[1].0.index, 0);
        assert_eq!(sparse.features[2].0.index, 3);
    }

    #[test]
    fn decoder_vector_basic() {
        let d_in = 4;
        let d_sae = 8;
        let device = Device::Cpu;

        let w_dec = Tensor::randn(0.0_f32, 1.0, (d_sae, d_in), &device).unwrap();
        let w_enc = Tensor::zeros((d_in, d_sae), DType::F32, &device).unwrap();
        let sae = relu_sae(w_enc, w_dec.clone());

        let vec0 = sae.decoder_vector(0).unwrap();
        assert_eq!(vec0.dims(), &[d_in]);
        assert_eq!(
            vec0.to_vec1::<f32>().unwrap(),
            w_dec.get(0).unwrap().to_vec1::<f32>().unwrap()
        );

        assert!(sae.decoder_vector(d_sae).is_err());
    }

    #[test]
    fn broadcast_bias_expands_leading_dims() {
        let bias = Tensor::new(&[1.0_f32, 2.0, 3.0], &Device::Cpu).unwrap();
        let expanded = broadcast_bias(&bias, &[2, 4, 3]).unwrap();
        assert_eq!(expanded.dims(), &[2, 4, 3]);
        let unchanged = broadcast_bias(&bias, &[3]).unwrap();
        assert_eq!(unchanged.dims(), &[3]);
    }

    #[test]
    fn validate_shape_reports_mismatch() {
        let t = Tensor::zeros((2, 3), DType::F32, &Device::Cpu).unwrap();
        assert!(validate_shape(&t, &[2, 3], "t").is_ok());
        assert!(validate_shape(&t, &[3, 2], "t").is_err());
    }

    #[test]
    fn prepare_injection_basic() {
        let d_in = 4;
        let d_sae = 8;
        let device = Device::Cpu;

        let w_enc = Tensor::zeros((d_in, d_sae), DType::F32, &device).unwrap();
        let w_dec = Tensor::ones((d_sae, d_in), DType::F32, &device).unwrap();
        let sae = relu_sae(w_enc, w_dec);

        let features = vec![(0_usize, 1.0_f32), (1, 0.5)];
        let hooks = sae
            .prepare_hook_injection(&features, 2, 5, &device)
            .unwrap();
        assert!(!hooks.is_empty());
    }
}