Skip to main content

candle_mi/sae/
mod.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! Sparse Autoencoder (SAE) support.
4//!
5//! Loads pre-trained SAE weights from SAELens-format safetensors + `cfg.json`
6//! or from Gemma Scope NPZ archives, encodes model activations into sparse
7//! feature vectors, decodes back to activation space, and produces steering
8//! vectors for injection.
9//!
10//! Each SAE targets a single hook point in the model (e.g., `resid_post` at
11//! layer 5). Multiple SAEs can be loaded independently for different hook
12//! points.
13//!
14//! # SAE Architecture
15//!
16//! A Sparse Autoencoder implements:
17//! ```text
18//! Encode:  features = activation_fn(x @ W_enc + b_enc)
19//! Decode:  x_hat = features @ W_dec + b_dec
20//! ```
21//!
22//! Supported activation functions:
23//! - **`ReLU`**: `features = ReLU(pre_acts)`
24//! - **`JumpReLU`**: `features = pre_acts * (pre_acts > threshold)`
25//! - **`TopK`**: keep only the top-k pre-activations, zero the rest
26//!
27//! # Weight File Formats
28//!
29//! ## `SAELens` safetensors format
30//!
31//! Each SAE directory contains:
32//! - `cfg.json`: configuration (`d_in`, `d_sae`, architecture, `hook_name`, ...)
33//! - `sae_weights.safetensors` (or `model.safetensors`): weight tensors
34//!
35//! ## Gemma Scope NPZ format
36//!
37//! A single `params.npz` file (`NumPy` ZIP archive) containing:
38//! - `W_enc.npy`: shape `[d_in, d_sae]` — encoder weight matrix
39//! - `W_dec.npy`: shape `[d_sae, d_in]` — decoder weight matrix
40//! - `b_enc.npy`: shape `[d_sae]` — encoder bias
41//! - `b_dec.npy`: shape `[d_in]` — decoder bias
42//! - `threshold.npy`: shape `[d_sae]` — `JumpReLU` threshold (optional)
43//!
44//! Architecture is auto-detected: presence of `threshold` → `JumpReLU`,
45//! otherwise `ReLU`.
46//!
47//! ## Tensor names (both formats)
48//!
49//! - `W_enc`: shape `[d_in, d_sae]` — encoder weight matrix
50//! - `W_dec`: shape `[d_sae, d_in]` — decoder weight matrix
51//! - `b_enc`: shape `[d_sae]` — encoder bias
52//! - `b_dec`: shape `[d_in]` — decoder bias
53//! - `threshold`: shape `[d_sae]` — `JumpReLU` threshold (optional)
54
55mod npz;
56
57use std::path::Path;
58
59use candle_core::{DType, Device, Tensor};
60use safetensors::tensor::SafeTensors;
61use tracing::info;
62
63use crate::error::{MIError, Result};
64use crate::hooks::{HookPoint, HookSpec, Intervention};
65use crate::sparse::{FeatureId, SparseActivations};
66
67// ---------------------------------------------------------------------------
68// Public types
69// ---------------------------------------------------------------------------
70
/// Typed handle for one entry of an SAE dictionary.
///
/// Equality, ordering and hashing are all derived from [`index`](Self::index),
/// so ids sort in dictionary order and can be used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SaeFeatureId {
    /// Zero-based position of the feature in the SAE dictionary (`0..d_sae`).
    pub index: usize,
}

impl std::fmt::Display for SaeFeatureId {
    /// Renders the id as `SAE:<index>` (for example `SAE:42`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { index } = self;
        write!(f, "SAE:{index}")
    }
}
83
84impl FeatureId for SaeFeatureId {}
85
/// Which activation function the SAE encoder applies to its pre-activations.
///
/// In every variant the pre-activations are `x @ W_enc + b_enc`.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SaeArchitecture {
    /// Plain `ReLU`: `features = ReLU(x @ W_enc + b_enc)`.
    ReLU,
    /// `JumpReLU`: `features = pre_acts * (pre_acts > threshold)`.
    ///
    /// Needs a learned `threshold` tensor of shape `[d_sae]` in the weights.
    JumpReLU,
    /// `TopK`: only the `k` largest pre-activations survive; the rest become zero.
    TopK {
        /// How many features stay active per input.
        k: usize,
    },
}
101
/// How inputs are normalized before they are fed to the SAE encoder.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NormalizeActivations {
    /// Inputs are used as-is.
    None,
    /// Inputs are scaled by the expected average L2 norm estimated during training.
    ExpectedAverageOnlyIn,
}
111
/// Where and how the `TopK` mask is computed.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TopKStrategy {
    /// Pick per device: direct computation on CPU, sort-based on GPU.
    Auto,
    /// Always compute on the CPU (move to CPU, build the mask, move back).
    Cpu,
    /// Always use the sort-based GPU computation.
    Gpu,
}
123
124/// Configuration for a Sparse Autoencoder, parsed from `cfg.json`.
125#[derive(Debug, Clone)]
126pub struct SaeConfig {
127    /// Input dimension (must match model hidden size at the hook point).
128    pub d_in: usize,
129    /// Dictionary size (number of SAE features).
130    pub d_sae: usize,
131    /// Encoder architecture (activation function).
132    pub architecture: SaeArchitecture,
133    /// `SAELens` hook name string (e.g., `"blocks.5.hook_resid_post"`).
134    pub hook_name: String,
135    /// Parsed hook point from the hook name.
136    pub hook_point: HookPoint,
137    /// Whether to subtract `b_dec` from input before encoding.
138    pub apply_b_dec_to_input: bool,
139    /// Input normalization strategy.
140    pub normalize_activations: NormalizeActivations,
141}
142
143// ---------------------------------------------------------------------------
144// Internal config parsing
145// ---------------------------------------------------------------------------
146
147#[derive(serde::Deserialize)]
148#[allow(clippy::missing_docs_in_private_items)]
149struct RawSaeConfig {
150    d_in: usize,
151    d_sae: usize,
152    #[serde(default)]
153    architecture: Option<String>,
154    #[serde(default)]
155    activation_fn_str: Option<String>,
156    #[serde(default)]
157    activation_fn_kwargs: Option<serde_json::Value>,
158    #[serde(default)]
159    hook_name: Option<String>,
160    #[serde(default)]
161    hook_point: Option<String>,
162    #[serde(default)]
163    apply_b_dec_to_input: bool,
164    #[serde(default)]
165    normalize_activations: Option<String>,
166}
167
168/// Parse a `RawSaeConfig` into a validated `SaeConfig`.
169fn parse_sae_config(raw: RawSaeConfig) -> Result<SaeConfig> {
170    // Resolve architecture from `architecture` or `activation_fn_str`.
171    let architecture = resolve_architecture(
172        raw.architecture.as_deref(),
173        raw.activation_fn_str.as_deref(),
174        raw.activation_fn_kwargs.as_ref(),
175    )?;
176
177    // Resolve hook name from `hook_name` or `hook_point`.
178    let hook_name = raw
179        .hook_name
180        .or(raw.hook_point)
181        .unwrap_or_else(|| "unknown".to_owned());
182
183    // Parse hook name to HookPoint via FromStr.
184    let hook_point: HookPoint = hook_name
185        .parse()
186        .unwrap_or_else(|_: std::convert::Infallible| {
187            // EXHAUSTIVE: Infallible can never happen, parse always succeeds
188            unreachable!()
189        });
190
191    let normalize_activations = match raw.normalize_activations.as_deref() {
192        Some("expected_average_only_in") => NormalizeActivations::ExpectedAverageOnlyIn,
193        _ => NormalizeActivations::None,
194    };
195
196    Ok(SaeConfig {
197        d_in: raw.d_in,
198        d_sae: raw.d_sae,
199        architecture,
200        hook_name,
201        hook_point,
202        apply_b_dec_to_input: raw.apply_b_dec_to_input,
203        normalize_activations,
204    })
205}
206
207/// Resolve SAE architecture from config fields.
208fn resolve_architecture(
209    architecture: Option<&str>,
210    activation_fn_str: Option<&str>,
211    activation_fn_kwargs: Option<&serde_json::Value>,
212) -> Result<SaeArchitecture> {
213    // Check `architecture` field first.
214    match architecture {
215        Some("jumprelu") => return Ok(SaeArchitecture::JumpReLU),
216        Some("topk") => {
217            let k = extract_topk_k(activation_fn_kwargs)?;
218            return Ok(SaeArchitecture::TopK { k });
219        }
220        Some("standard") | None => {} // fall through to activation_fn_str
221        Some(other) => {
222            return Err(MIError::Config(format!(
223                "unsupported SAE architecture: {other:?}"
224            )));
225        }
226    }
227
228    // Fall back to `activation_fn_str`.
229    match activation_fn_str {
230        Some("relu") | None => Ok(SaeArchitecture::ReLU),
231        Some("jumprelu") => Ok(SaeArchitecture::JumpReLU),
232        Some("topk") => {
233            let k = extract_topk_k(activation_fn_kwargs)?;
234            Ok(SaeArchitecture::TopK { k })
235        }
236        Some(other) => Err(MIError::Config(format!(
237            "unsupported SAE activation function: {other:?}"
238        ))),
239    }
240}
241
242/// Extract `k` from `activation_fn_kwargs.k`.
243fn extract_topk_k(kwargs: Option<&serde_json::Value>) -> Result<usize> {
244    let k = kwargs
245        .and_then(|v| v.get("k"))
246        .and_then(serde_json::Value::as_u64)
247        .ok_or_else(|| {
248            MIError::Config("TopK SAE requires activation_fn_kwargs.k in cfg.json".into())
249        })?;
250    let k_usize = usize::try_from(k)
251        .map_err(|_| MIError::Config(format!("TopK k value {k} too large for usize")))?;
252    Ok(k_usize)
253}
254
255// ---------------------------------------------------------------------------
256// SparseAutoencoder
257// ---------------------------------------------------------------------------
258
259/// A Sparse Autoencoder for mechanistic interpretability.
260///
261/// Loads SAE weights from SAELens-format safetensors + `cfg.json`,
262/// encodes model activations into sparse feature vectors, decodes
263/// back to activation space, and produces steering vectors for injection.
264///
265/// Each SAE targets a single hook point in the model (e.g., `resid_post`
266/// at layer 5). Multiple SAEs can be loaded independently for different
267/// hook points.
268///
269/// # Example
270///
271/// ```no_run
272/// # fn main() -> candle_mi::Result<()> {
273/// use candle_mi::sae::SparseAutoencoder;
274/// use candle_core::Device;
275///
276/// let sae = SparseAutoencoder::from_pretrained(
277///     "jbloom/Gemma-2-2B-Residual-Stream-SAEs",
278///     "gemma-2-2b-res-jb/blocks.20.hook_resid_post",
279///     &Device::Cpu,
280/// )?;
281/// println!("SAE: d_in={}, d_sae={}", sae.d_in(), sae.d_sae());
282/// # Ok(())
283/// # }
284/// ```
285pub struct SparseAutoencoder {
286    /// SAE configuration parsed from `cfg.json`.
287    config: SaeConfig,
288    /// Encoder weight matrix.
289    ///
290    /// # Shapes
291    /// - `w_enc`: `[d_in, d_sae]`
292    w_enc: Tensor,
293    /// Decoder weight matrix.
294    ///
295    /// # Shapes
296    /// - `w_dec`: `[d_sae, d_in]`
297    w_dec: Tensor,
298    /// Encoder bias vector.
299    ///
300    /// # Shapes
301    /// - `b_enc`: `[d_sae]`
302    b_enc: Tensor,
303    /// Decoder bias vector.
304    ///
305    /// # Shapes
306    /// - `b_dec`: `[d_in]`
307    b_dec: Tensor,
308    /// `JumpReLU` threshold (only present for `JumpReLU` architecture).
309    ///
310    /// # Shapes
311    /// - `threshold`: `[d_sae]`
312    threshold: Option<Tensor>,
313}
314
315impl SparseAutoencoder {
316    // --- Loading ---
317
318    /// Load an SAE from a local directory containing safetensors + `cfg.json`.
319    ///
320    /// Expects either `sae_weights.safetensors` or `model.safetensors`
321    /// plus a `cfg.json` file.
322    ///
323    /// # Errors
324    ///
325    /// Returns [`MIError::Config`] if `cfg.json` is missing or malformed.
326    /// Returns [`MIError::Config`] if weight shapes don't match `cfg.json` dimensions.
327    /// Returns [`MIError::Model`] on tensor loading failure.
328    /// Returns [`MIError::Io`] if files cannot be read.
329    pub fn from_local(dir: &Path, device: &Device) -> Result<Self> {
330        // Parse cfg.json.
331        let cfg_path = dir.join("cfg.json");
332        if !cfg_path.exists() {
333            return Err(MIError::Config(format!(
334                "cfg.json not found in {}",
335                dir.display()
336            )));
337        }
338        let cfg_text = std::fs::read_to_string(&cfg_path)?;
339        let raw: RawSaeConfig = serde_json::from_str(&cfg_text)
340            .map_err(|e| MIError::Config(format!("failed to parse cfg.json: {e}")))?;
341        let config = parse_sae_config(raw)?;
342
343        info!(
344            "SAE config: d_in={}, d_sae={}, arch={:?}, hook={}",
345            config.d_in, config.d_sae, config.architecture, config.hook_name
346        );
347
348        // Find safetensors file.
349        let weights_path = if dir.join("sae_weights.safetensors").exists() {
350            dir.join("sae_weights.safetensors")
351        } else if dir.join("model.safetensors").exists() {
352            dir.join("model.safetensors")
353        } else {
354            return Err(MIError::Config(format!(
355                "no safetensors file found in {}",
356                dir.display()
357            )));
358        };
359
360        // Load weights.
361        let data = std::fs::read(&weights_path)?;
362        let st = SafeTensors::deserialize(&data)
363            .map_err(|e| MIError::Config(format!("failed to deserialize SAE weights: {e}")))?;
364
365        let w_enc = load_tensor(&st, "W_enc", device)?;
366        let w_dec = load_tensor(&st, "W_dec", device)?;
367        let b_enc = load_tensor(&st, "b_enc", device)?;
368        let b_dec = load_tensor(&st, "b_dec", device)?;
369        let threshold = st
370            .tensor("threshold")
371            .ok()
372            .map(|v| tensor_from_view(&v, device))
373            .transpose()?;
374
375        // PROMOTE: F32 for numerical stability in matmul and bias add
376        let w_enc = w_enc.to_dtype(DType::F32)?;
377        let w_dec = w_dec.to_dtype(DType::F32)?;
378        let b_enc = b_enc.to_dtype(DType::F32)?;
379        let b_dec = b_dec.to_dtype(DType::F32)?;
380        let threshold = threshold.map(|t| t.to_dtype(DType::F32)).transpose()?;
381
382        // Validate shapes.
383        validate_shape(&w_enc, &[config.d_in, config.d_sae], "W_enc")?;
384        validate_shape(&w_dec, &[config.d_sae, config.d_in], "W_dec")?;
385        validate_shape(&b_enc, &[config.d_sae], "b_enc")?;
386        validate_shape(&b_dec, &[config.d_in], "b_dec")?;
387        if let Some(ref t) = threshold {
388            validate_shape(t, &[config.d_sae], "threshold")?;
389        }
390
391        // Validate JumpReLU has threshold.
392        if config.architecture == SaeArchitecture::JumpReLU && threshold.is_none() {
393            return Err(MIError::Config(
394                "JumpReLU SAE requires 'threshold' tensor in weights file".into(),
395            ));
396        }
397
398        info!(
399            "SAE loaded: {} weights on {:?}",
400            weights_path.display(),
401            device
402        );
403
404        Ok(Self {
405            config,
406            w_enc,
407            w_dec,
408            b_enc,
409            b_dec,
410            threshold,
411        })
412    }
413
414    /// Load an SAE from a Gemma Scope NPZ file (`params.npz`).
415    ///
416    /// The NPZ file must contain `W_enc`, `W_dec`, `b_enc`, `b_dec` arrays,
417    /// and optionally `threshold` (for `JumpReLU`). Config is inferred from
418    /// tensor shapes since NPZ files have no `cfg.json`.
419    ///
420    /// # Arguments
421    /// * `npz_path` — Path to the `params.npz` file
422    /// * `hook_layer` — Which model layer this SAE hooks into
423    /// * `device` — Target device (CPU or CUDA)
424    ///
425    /// # Errors
426    ///
427    /// Returns [`MIError::Config`] if required tensors are missing or shapes
428    /// are inconsistent.
429    /// Returns [`MIError::Io`] if the file cannot be read.
430    pub fn from_npz(npz_path: &Path, hook_layer: usize, device: &Device) -> Result<Self> {
431        info!("Loading SAE from NPZ: {}", npz_path.display());
432        let tensors = npz::load_npz(npz_path, device)?;
433
434        let w_enc = tensors
435            .get("W_enc")
436            .ok_or_else(|| MIError::Config("NPZ missing W_enc".into()))?
437            .to_dtype(DType::F32)?;
438        let w_dec = tensors
439            .get("W_dec")
440            .ok_or_else(|| MIError::Config("NPZ missing W_dec".into()))?
441            .to_dtype(DType::F32)?;
442        let b_enc = tensors
443            .get("b_enc")
444            .ok_or_else(|| MIError::Config("NPZ missing b_enc".into()))?
445            .to_dtype(DType::F32)?;
446        let b_dec = tensors
447            .get("b_dec")
448            .ok_or_else(|| MIError::Config("NPZ missing b_dec".into()))?
449            .to_dtype(DType::F32)?;
450        let threshold = tensors
451            .get("threshold")
452            .map(|t| t.to_dtype(DType::F32))
453            .transpose()?;
454
455        // Infer dimensions from W_enc: [d_in, d_sae].
456        let w_enc_dims = w_enc.dims();
457        if w_enc_dims.len() != 2 {
458            return Err(MIError::Config(format!(
459                "W_enc expected 2 dims, got {}",
460                w_enc_dims.len()
461            )));
462        }
463        let d_in = *w_enc_dims
464            .first()
465            .ok_or_else(|| MIError::Config("W_enc has no dimensions".into()))?;
466        let d_sae = *w_enc_dims
467            .get(1)
468            .ok_or_else(|| MIError::Config("W_enc has no second dimension".into()))?;
469
470        // Validate shapes.
471        validate_shape(&w_enc, &[d_in, d_sae], "W_enc")?;
472        validate_shape(&w_dec, &[d_sae, d_in], "W_dec")?;
473        validate_shape(&b_enc, &[d_sae], "b_enc")?;
474        validate_shape(&b_dec, &[d_in], "b_dec")?;
475        if let Some(ref t) = threshold {
476            validate_shape(t, &[d_sae], "threshold")?;
477        }
478
479        // Auto-detect architecture: threshold present → JumpReLU, else ReLU.
480        let architecture = if threshold.is_some() {
481            SaeArchitecture::JumpReLU
482        } else {
483            SaeArchitecture::ReLU
484        };
485
486        let hook_name = format!("blocks.{hook_layer}.hook_resid_post");
487        let hook_point = hook_name
488            .parse::<HookPoint>()
489            .map_err(|e| MIError::Config(format!("failed to parse hook name: {e}")))?;
490
491        let config = SaeConfig {
492            d_in,
493            d_sae,
494            architecture,
495            hook_name,
496            hook_point,
497            apply_b_dec_to_input: false,
498            normalize_activations: NormalizeActivations::None,
499        };
500
501        info!(
502            "SAE from NPZ: d_in={d_in}, d_sae={d_sae}, arch={:?}, hook={}",
503            config.architecture, config.hook_name
504        );
505
506        Ok(Self {
507            config,
508            w_enc,
509            w_dec,
510            b_enc,
511            b_dec,
512            threshold,
513        })
514    }
515
516    /// Load an SAE from a `HuggingFace` repository containing an NPZ file.
517    ///
518    /// Downloads the NPZ file via `hf-fetch-model`, then delegates to
519    /// [`from_npz`](Self::from_npz).
520    ///
521    /// # Arguments
522    /// * `repo_id` — `HuggingFace` repository ID
523    ///   (e.g., `"google/gemma-scope-2b-pt-res"`)
524    /// * `npz_path` — Path within the repo to the NPZ file
525    ///   (e.g., `"layer_0/width_16k/average_l0_105/params.npz"`)
526    /// * `hook_layer` — Which model layer this SAE hooks into
527    /// * `device` — Target device (CPU or CUDA)
528    ///
529    /// # Errors
530    ///
531    /// Returns [`MIError::Download`] if the file cannot be fetched.
532    /// Returns [`MIError::Config`] if the NPZ format is invalid.
533    pub fn from_pretrained_npz(
534        repo_id: &str,
535        npz_path: &str,
536        hook_layer: usize,
537        device: &Device,
538    ) -> Result<Self> {
539        let fetch_config = hf_fetch_model::FetchConfig::builder()
540            .on_progress(|event| {
541                tracing::info!(
542                    filename = %event.filename,
543                    percent = event.percent,
544                    bytes_downloaded = event.bytes_downloaded,
545                    bytes_total = event.bytes_total,
546                    "SAE NPZ download progress",
547                );
548            })
549            .build()
550            .map_err(|e| MIError::Download(format!("failed to build fetch config: {e}")))?;
551
552        info!("Downloading {npz_path} from {repo_id}");
553        let local_path =
554            hf_fetch_model::download_file_blocking(repo_id.to_owned(), npz_path, &fetch_config)
555                .map_err(|e| MIError::Download(format!("failed to download NPZ: {e}")))?
556                .into_inner();
557
558        Self::from_npz(&local_path, hook_layer, device)
559    }
560
561    /// Load an SAE from a `HuggingFace` repository.
562    ///
563    /// Downloads safetensors + `cfg.json` via `hf-fetch-model`, then delegates
564    /// to [`from_local`](Self::from_local).
565    ///
566    /// # Arguments
567    /// * `repo_id` — `HuggingFace` repository ID
568    ///   (e.g., `"jbloom/Gemma-2-2B-Residual-Stream-SAEs"`)
569    /// * `sae_id` — Subdirectory within the repo
570    ///   (e.g., `"gemma-2-2b-res-jb/blocks.20.hook_resid_post"`)
571    /// * `device` — Target device (CPU or CUDA)
572    ///
573    /// # Errors
574    ///
575    /// Returns [`MIError::Download`] if files cannot be fetched.
576    /// Returns [`MIError::Config`] if the SAE format is invalid.
577    pub fn from_pretrained(repo_id: &str, sae_id: &str, device: &Device) -> Result<Self> {
578        let fetch_config = hf_fetch_model::FetchConfig::builder()
579            .on_progress(|event| {
580                tracing::info!(
581                    filename = %event.filename,
582                    percent = event.percent,
583                    bytes_downloaded = event.bytes_downloaded,
584                    bytes_total = event.bytes_total,
585                    "SAE download progress",
586                );
587            })
588            .build()
589            .map_err(|e| MIError::Download(format!("failed to build fetch config: {e}")))?;
590
591        // Download cfg.json.
592        let cfg_remote = format!("{sae_id}/cfg.json");
593        info!("Downloading {cfg_remote} from {repo_id}");
594        let cfg_path =
595            hf_fetch_model::download_file_blocking(repo_id.to_owned(), &cfg_remote, &fetch_config)
596                .map_err(|e| MIError::Download(format!("failed to download cfg.json: {e}")))?
597                .into_inner();
598
599        // Download weights: try sae_weights.safetensors, fall back to model.safetensors.
600        let weights_remote = format!("{sae_id}/sae_weights.safetensors");
601        info!("Downloading {weights_remote} from {repo_id}");
602        let weights_path = hf_fetch_model::download_file_blocking(
603            repo_id.to_owned(),
604            &weights_remote,
605            &fetch_config,
606        )
607        .or_else(|_| {
608            let alt_remote = format!("{sae_id}/model.safetensors");
609            info!("Trying {alt_remote} from {repo_id}");
610            hf_fetch_model::download_file_blocking(repo_id.to_owned(), &alt_remote, &fetch_config)
611        })
612        .map_err(|e| MIError::Download(format!("failed to download SAE weights: {e}")))?
613        .into_inner();
614
615        // Both files are in cache; determine the common directory.
616        let dir = cfg_path.parent().ok_or_else(|| {
617            MIError::Config("cannot determine SAE directory from cfg.json path".into())
618        })?;
619
620        // Verify the weights file is in the same directory (or just load by path).
621        // hf-fetch-model may place files in the same cache dir; if not, we need
622        // to construct the dir from the weights path instead.
623        if dir.join("sae_weights.safetensors").exists() || dir.join("model.safetensors").exists() {
624            Self::from_local(dir, device)
625        } else {
626            // Files might be in different cache locations; load manually.
627            let weights_dir = weights_path.parent().ok_or_else(|| {
628                MIError::Config("cannot determine SAE directory from weights path".into())
629            })?;
630            // Copy cfg.json to weights dir if needed.
631            let target_cfg = weights_dir.join("cfg.json");
632            if !target_cfg.exists() {
633                std::fs::copy(&cfg_path, &target_cfg)?;
634            }
635            Self::from_local(weights_dir, device)
636        }
637    }
638
639    // --- Accessors ---
640
641    /// Access the SAE configuration.
642    #[must_use]
643    pub const fn config(&self) -> &SaeConfig {
644        &self.config
645    }
646
647    /// The hook point this SAE targets.
648    #[must_use]
649    pub const fn hook_point(&self) -> &HookPoint {
650        &self.config.hook_point
651    }
652
653    /// Dictionary size (number of features).
654    #[must_use]
655    pub const fn d_sae(&self) -> usize {
656        self.config.d_sae
657    }
658
659    /// Input dimension.
660    #[must_use]
661    pub const fn d_in(&self) -> usize {
662        self.config.d_in
663    }
664
665    // --- Encoding ---
666
667    /// Encode activations into SAE feature space (dense output).
668    ///
669    /// Applies the full encoder: `pre_acts = x @ W_enc + b_enc`, then the
670    /// architecture-specific activation function (`ReLU`, `JumpReLU`, or `TopK`).
671    ///
672    /// Uses [`TopKStrategy::Auto`] for `TopK` SAEs.
673    ///
674    /// # Shapes
675    /// - `x`: `[..., d_in]` — activations with any leading dimensions
676    /// - returns: `[..., d_sae]` — encoded features (mostly sparse)
677    ///
678    /// # Errors
679    ///
680    /// Returns [`MIError::Config`] if the last dimension of `x` != `d_in`.
681    /// Returns [`MIError::Model`] on tensor operation failure.
682    pub fn encode(&self, x: &Tensor) -> Result<Tensor> {
683        self.encode_with_strategy(x, &TopKStrategy::Auto)
684    }
685
686    /// Encode activations with an explicit [`TopKStrategy`].
687    ///
688    /// Same as [`encode()`](Self::encode) but allows overriding the `TopK`
689    /// computation strategy.
690    ///
691    /// # Shapes
692    /// - `x`: `[..., d_in]` — activations with any leading dimensions
693    /// - returns: `[..., d_sae]` — encoded features (mostly sparse)
694    ///
695    /// # Errors
696    ///
697    /// Returns [`MIError::Config`] if the last dimension of `x` != `d_in`.
698    /// Returns [`MIError::Model`] on tensor operation failure.
699    pub fn encode_with_strategy(&self, x: &Tensor, strategy: &TopKStrategy) -> Result<Tensor> {
700        let dims = x.dims();
701        let last_dim = *dims
702            .last()
703            .ok_or_else(|| MIError::Config("cannot encode empty tensor".into()))?;
704        if last_dim != self.config.d_in {
705            return Err(MIError::Config(format!(
706                "input last dim {last_dim} != SAE d_in {}",
707                self.config.d_in
708            )));
709        }
710
711        // PROMOTE: F32 for matmul precision
712        let x_f32 = x.to_dtype(DType::F32)?;
713
714        // Optionally subtract b_dec from input (centering).
715        let x_centered = if self.config.apply_b_dec_to_input {
716            let b_dec = broadcast_bias(&self.b_dec, x_f32.dims())?;
717            (&x_f32 - &b_dec)?
718        } else {
719            x_f32
720        };
721
722        // pre_acts = x @ W_enc + b_enc
723        // [... , d_in] @ [d_in, d_sae] → [..., d_sae]
724        let pre_acts = x_centered.broadcast_matmul(&self.w_enc)?;
725        // Broadcast b_enc [d_sae] to match pre_acts leading dims.
726        let b_enc = broadcast_bias(&self.b_enc, pre_acts.dims())?;
727        let pre_acts = (&pre_acts + &b_enc)?;
728
729        // Apply activation function.
730        match &self.config.architecture {
731            SaeArchitecture::ReLU => Ok(pre_acts.relu()?),
732            SaeArchitecture::JumpReLU => {
733                let threshold = self
734                    .threshold
735                    .as_ref()
736                    .ok_or_else(|| MIError::Config("JumpReLU requires threshold tensor".into()))?;
737                // Broadcast threshold to match pre_acts leading dims.
738                let threshold = broadcast_bias(threshold, pre_acts.dims())?;
739                // mask = (pre_acts > threshold), features = pre_acts * mask
740                let mask = pre_acts.gt(&threshold)?;
741                let mask_f32 = mask.to_dtype(DType::F32)?;
742                Ok((&pre_acts * &mask_f32)?)
743            }
744            SaeArchitecture::TopK { k } => topk_activation(&pre_acts, *k, strategy),
745        }
746    }
747
748    /// Encode a single activation vector into sparse SAE features.
749    ///
750    /// Returns only non-zero features, sorted by magnitude descending.
751    ///
752    /// # Shapes
753    /// - `x`: `[d_in]` — single activation vector
754    /// - returns: [`SparseActivations<SaeFeatureId>`] with `(SaeFeatureId, f32)` pairs
755    ///
756    /// # Errors
757    ///
758    /// Returns [`MIError::Config`] if `x` has wrong dimension.
759    /// Returns [`MIError::Model`] on tensor operation failure.
760    pub fn encode_sparse(&self, x: &Tensor) -> Result<SparseActivations<SaeFeatureId>> {
761        let encoded = self.encode(&x.unsqueeze(0)?)?;
762        let encoded_1d = encoded.squeeze(0)?;
763
764        // Transfer to CPU for sparse extraction.
765        let values: Vec<f32> = encoded_1d.to_vec1()?;
766
767        let mut features: Vec<(SaeFeatureId, f32)> = values
768            .iter()
769            .enumerate()
770            .filter(|&(_, v)| *v > 0.0)
771            .map(|(i, v)| (SaeFeatureId { index: i }, *v))
772            .collect();
773
774        // Sort by activation magnitude (descending).
775        features.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
776
777        Ok(SparseActivations { features })
778    }
779
780    // --- Decoding ---
781
782    /// Decode SAE features back to activation space.
783    ///
784    /// # Shapes
785    /// - `features`: `[..., d_sae]` — encoded feature activations
786    /// - returns: `[..., d_in]` — reconstructed activations
787    ///
788    /// # Errors
789    ///
790    /// Returns [`MIError::Model`] on tensor operation failure.
791    pub fn decode(&self, features: &Tensor) -> Result<Tensor> {
792        // x_hat = features @ W_dec + b_dec
793        // [..., d_sae] @ [d_sae, d_in] → [..., d_in]
794        let features_f32 = features.to_dtype(DType::F32)?;
795        let decoded = features_f32.broadcast_matmul(&self.w_dec)?;
796        let b_dec = broadcast_bias(&self.b_dec, decoded.dims())?;
797        Ok((&decoded + &b_dec)?)
798    }
799
800    // --- Reconstruction ---
801
802    /// Reconstruct activations through the SAE (encode then decode).
803    ///
804    /// # Shapes
805    /// - `x`: `[..., d_in]` — original activations
806    /// - returns: `[..., d_in]` — reconstructed activations
807    ///
808    /// # Errors
809    ///
810    /// Returns [`MIError::Config`] if the last dimension of `x` != `d_in`.
811    /// Returns [`MIError::Model`] on tensor operation failure.
812    pub fn reconstruct(&self, x: &Tensor) -> Result<Tensor> {
813        let encoded = self.encode(x)?;
814        self.decode(&encoded)
815    }
816
817    /// Compute reconstruction MSE loss.
818    ///
819    /// # Shapes
820    /// - `x`: `[..., d_in]` — original activations
821    /// - returns: scalar `f64` mean squared error
822    ///
823    /// # Errors
824    ///
825    /// Returns [`MIError::Config`] if the last dimension of `x` != `d_in`.
826    /// Returns [`MIError::Model`] on tensor operation failure.
827    pub fn reconstruction_error(&self, x: &Tensor) -> Result<f64> {
828        let x_f32 = x.to_dtype(DType::F32)?;
829        let x_hat = self.reconstruct(&x_f32)?;
830        let diff = (&x_f32 - &x_hat)?;
831        let mse: f32 = diff.sqr()?.mean_all()?.to_scalar()?;
832        Ok(f64::from(mse))
833    }
834
835    // --- Steering ---
836
837    /// Extract a single feature's decoder vector (steering direction).
838    ///
839    /// # Shapes
840    /// - returns: `[d_in]` — decoder vector on the SAE's device
841    ///
842    /// # Errors
843    ///
844    /// Returns [`MIError::Config`] if `feature_idx` >= `d_sae`.
845    /// Returns [`MIError::Model`] on tensor operation failure.
846    pub fn decoder_vector(&self, feature_idx: usize) -> Result<Tensor> {
847        if feature_idx >= self.config.d_sae {
848            return Err(MIError::Config(format!(
849                "feature index {feature_idx} out of range (d_sae={})",
850                self.config.d_sae
851            )));
852        }
853        // W_dec: [d_sae, d_in] → row feature_idx → [d_in]
854        Ok(self.w_dec.get(feature_idx)?)
855    }
856
857    /// Build a [`HookSpec`] that injects SAE decoder vectors into the model.
858    ///
859    /// Creates an [`Intervention::Add`] at this SAE's hook point with the
860    /// accumulated (scaled) decoder vectors placed at the given position.
861    ///
862    /// # Shapes
863    /// - Internally constructs `[1, seq_len, d_in]` with the vector at `position`.
864    ///
865    /// # Arguments
866    /// * `features` — List of `(feature_index, strength)` pairs
867    /// * `position` — Token position in the sequence to inject at
868    /// * `seq_len` — Total sequence length
869    /// * `device` — Device to construct injection tensors on
870    ///
871    /// # Errors
872    ///
873    /// Returns [`MIError::Config`] if any `feature_index` >= `d_sae`.
874    /// Returns [`MIError::Model`] on tensor construction failure.
875    pub fn prepare_hook_injection(
876        &self,
877        features: &[(usize, f32)],
878        position: usize,
879        seq_len: usize,
880        device: &Device,
881    ) -> Result<HookSpec> {
882        let d_in = self.config.d_in;
883
884        // Accumulate weighted decoder vectors.
885        let mut accumulated = Tensor::zeros(d_in, DType::F32, device)?;
886        for &(feature_idx, strength) in features {
887            let dec_vec = self.decoder_vector(feature_idx)?;
888            let dec_vec = dec_vec.to_device(device)?;
889            let scaled = (&dec_vec * f64::from(strength))?;
890            accumulated = (&accumulated + &scaled)?;
891        }
892
893        // Build [1, seq_len, d_in] with vector at `position`.
894        let injection = Tensor::zeros((1, seq_len, d_in), DType::F32, device)?;
895        let scaled_3d = accumulated.unsqueeze(0)?.unsqueeze(0)?; // [1, 1, d_in]
896
897        let before = if position > 0 {
898            Some(injection.narrow(1, 0, position)?)
899        } else {
900            None
901        };
902        let after = if position + 1 < seq_len {
903            Some(injection.narrow(1, position + 1, seq_len - position - 1)?)
904        } else {
905            None
906        };
907
908        let mut parts: Vec<Tensor> = Vec::with_capacity(3);
909        if let Some(b) = before {
910            parts.push(b);
911        }
912        parts.push(scaled_3d);
913        if let Some(a) = after {
914            parts.push(a);
915        }
916
917        let injection = Tensor::cat(&parts, 1)?;
918
919        let mut hooks = HookSpec::new();
920        hooks.intervene(self.config.hook_point.clone(), Intervention::Add(injection));
921        Ok(hooks)
922    }
923}
924
925// ---------------------------------------------------------------------------
926// TopK activation
927// ---------------------------------------------------------------------------
928
929/// Apply top-k activation: keep only the k largest values, zero the rest.
930///
931/// # Shapes
932/// - `pre_acts`: `[..., d_sae]` — pre-activation values
933/// - returns: same shape, with all but top-k values zeroed per last-dim slice
934///
935/// # Strategy
936/// - [`TopKStrategy::Auto`]: CPU → direct iteration, GPU → sort-based
937/// - [`TopKStrategy::Cpu`]: force CPU path
938/// - [`TopKStrategy::Gpu`]: force GPU sort-based path
939fn topk_activation(pre_acts: &Tensor, k: usize, strategy: &TopKStrategy) -> Result<Tensor> {
940    let use_cpu = match strategy {
941        TopKStrategy::Cpu => true,
942        TopKStrategy::Gpu => false,
943        TopKStrategy::Auto => matches!(pre_acts.device(), Device::Cpu),
944    };
945
946    if use_cpu {
947        topk_cpu(pre_acts, k)
948    } else {
949        topk_gpu(pre_acts, k)
950    }
951}
952
953/// `TopK` via CPU-side partial sort.
954fn topk_cpu(pre_acts: &Tensor, k: usize) -> Result<Tensor> {
955    let device = pre_acts.device().clone();
956    let shape = pre_acts.dims().to_vec();
957    let d_sae = *shape
958        .last()
959        .ok_or_else(|| MIError::Config("cannot apply TopK to empty tensor".into()))?;
960
961    // Flatten to 2D: [n, d_sae]
962    let n: usize = shape.iter().take(shape.len() - 1).product();
963    let flat = pre_acts.reshape((n, d_sae))?.to_dtype(DType::F32)?;
964    let flat_cpu = flat.to_device(&Device::Cpu)?;
965
966    let mut result_data: Vec<f32> = Vec::with_capacity(n * d_sae);
967
968    for row_idx in 0..n {
969        let row = flat_cpu.get(row_idx)?;
970        let mut row_vec: Vec<f32> = row.to_vec1()?;
971
972        // Find the k-th largest value via partial sort.
973        let k_clamped = k.min(d_sae);
974        if k_clamped > 0 && k_clamped < d_sae {
975            // Partial sort: put the k largest elements at the front.
976            let mut indices: Vec<usize> = (0..d_sae).collect();
977            #[allow(clippy::indexing_slicing)]
978            // CONTIGUOUS: indices and row_vec are both exactly d_sae elements
979            indices.select_nth_unstable_by(k_clamped - 1, |&a, &b| {
980                let va = row_vec.get(b).copied().unwrap_or(f32::NEG_INFINITY);
981                let vb = row_vec.get(a).copied().unwrap_or(f32::NEG_INFINITY);
982                va.partial_cmp(&vb).unwrap_or(std::cmp::Ordering::Equal)
983            });
984            let threshold_idx = indices.get(k_clamped - 1).copied().unwrap_or(0);
985            let threshold = row_vec.get(threshold_idx).copied().unwrap_or(0.0);
986
987            // Zero values below threshold.
988            for v in &mut row_vec {
989                if *v < threshold {
990                    *v = 0.0;
991                }
992            }
993
994            // If there are ties at threshold, we might keep more than k.
995            // Count how many are >= threshold and zero extras from the end.
996            let active: usize = row_vec.iter().filter(|&&v| v >= threshold).count();
997            if active > k_clamped {
998                let mut excess = active - k_clamped;
999                for v in row_vec.iter_mut().rev() {
1000                    if excess == 0 {
1001                        break;
1002                    }
1003                    if (*v - threshold).abs() < f32::EPSILON {
1004                        *v = 0.0;
1005                        excess -= 1;
1006                    }
1007                }
1008            }
1009        } else if k_clamped == 0 {
1010            row_vec.fill(0.0);
1011        }
1012        // k_clamped >= d_sae: keep all values
1013
1014        result_data.extend_from_slice(&row_vec);
1015    }
1016
1017    let result = Tensor::from_vec(result_data, (n, d_sae), &device)?;
1018    result.reshape(shape.as_slice()).map_err(Into::into)
1019}
1020
1021/// `TopK` via GPU sort-based masking.
1022fn topk_gpu(pre_acts: &Tensor, k: usize) -> Result<Tensor> {
1023    let shape = pre_acts.dims().to_vec();
1024    let d_sae = *shape
1025        .last()
1026        .ok_or_else(|| MIError::Config("cannot apply TopK to empty tensor".into()))?;
1027
1028    let k_clamped = k.min(d_sae);
1029    if k_clamped == 0 {
1030        return Ok(pre_acts.zeros_like()?);
1031    }
1032    if k_clamped >= d_sae {
1033        return Ok(pre_acts.clone());
1034    }
1035
1036    // Flatten to 2D for sort_last_dim.
1037    let n: usize = shape.iter().take(shape.len() - 1).product();
1038    let flat = pre_acts.reshape((n, d_sae))?.to_dtype(DType::F32)?;
1039
1040    // Sort descending along last dim.
1041    let (sorted_vals, _sorted_indices) = flat.sort_last_dim(false)?;
1042
1043    // Get the k-th largest value per row: [n, 1]
1044    let kth_vals = sorted_vals.narrow(1, k_clamped - 1, 1)?;
1045
1046    // Mask: keep values >= kth value.
1047    let mask = flat.ge(&kth_vals)?;
1048    let mask_f32 = mask.to_dtype(DType::F32)?;
1049
1050    let result = (&flat * &mask_f32)?;
1051    result.reshape(shape.as_slice()).map_err(Into::into)
1052}
1053
1054// ---------------------------------------------------------------------------
1055// Helper functions
1056// ---------------------------------------------------------------------------
1057
1058/// Broadcast a 1D bias `[dim]` to match an arbitrary target shape.
1059///
1060/// For example, bias `[d_sae]` with target shape `[batch, seq, d_sae]`
1061/// is reshaped to `[1, 1, d_sae]` then broadcast to `[batch, seq, d_sae]`.
1062///
1063/// # Shapes
1064/// - `bias`: `[dim]` — 1D bias vector
1065/// - `target_shape`: the shape to broadcast into (last dim must match)
1066/// - returns: bias with same shape as `target_shape`
1067fn broadcast_bias(bias: &Tensor, target_shape: &[usize]) -> Result<Tensor> {
1068    let ndim = target_shape.len();
1069    if ndim <= 1 {
1070        return Ok(bias.clone());
1071    }
1072    // Reshape [dim] → [1, 1, ..., dim] with (ndim - 1) leading 1s.
1073    let mut shape = vec![1_usize; ndim];
1074    let last_dim = *target_shape
1075        .last()
1076        .ok_or_else(|| MIError::Config("cannot broadcast bias to empty shape".into()))?;
1077    if let Some(slot) = shape.last_mut() {
1078        *slot = last_dim;
1079    }
1080    let reshaped = bias.reshape(shape.as_slice())?;
1081    Ok(reshaped.broadcast_as(target_shape)?)
1082}
1083
1084/// Convert a safetensors `TensorView` to a candle `Tensor`.
1085///
1086/// # Shapes
1087/// - Preserves the original tensor shape from safetensors.
1088///
1089/// # Errors
1090///
1091/// Returns [`MIError::Config`] if the tensor dtype is not supported (BF16, F16, F32).
1092/// Returns [`MIError::Model`] on tensor construction failure.
1093fn tensor_from_view(view: &safetensors::tensor::TensorView<'_>, device: &Device) -> Result<Tensor> {
1094    let shape: Vec<usize> = view.shape().to_vec();
1095    #[allow(clippy::wildcard_enum_match_arm)]
1096    // EXHAUSTIVE: safetensors exposes many dtypes; SAEs only use float types
1097    let dtype = match view.dtype() {
1098        safetensors::Dtype::BF16 => DType::BF16,
1099        safetensors::Dtype::F16 => DType::F16,
1100        safetensors::Dtype::F32 => DType::F32,
1101        other => {
1102            return Err(MIError::Config(format!(
1103                "unsupported SAE tensor dtype: {other:?}"
1104            )));
1105        }
1106    };
1107    let tensor = Tensor::from_raw_buffer(view.data(), dtype, &shape, device)?;
1108    Ok(tensor)
1109}
1110
1111/// Load a named tensor from safetensors.
1112fn load_tensor(st: &SafeTensors<'_>, name: &str, device: &Device) -> Result<Tensor> {
1113    let view = st
1114        .tensor(name)
1115        .map_err(|e| MIError::Config(format!("tensor '{name}' not found: {e}")))?;
1116    tensor_from_view(&view, device)
1117}
1118
1119/// Validate that a tensor has the expected shape.
1120fn validate_shape(tensor: &Tensor, expected: &[usize], name: &str) -> Result<()> {
1121    if tensor.dims() != expected {
1122        return Err(MIError::Config(format!(
1123            "SAE tensor '{name}' shape mismatch: expected {expected:?}, got {:?}",
1124            tensor.dims()
1125        )));
1126    }
1127    Ok(())
1128}
1129
1130// ---------------------------------------------------------------------------
1131// Tests
1132// ---------------------------------------------------------------------------
1133
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a small CPU `ReLU` SAE with zero biases and no threshold.
    ///
    /// Shared by the tests below so each one only spells out the weights it
    /// actually cares about.
    fn tiny_relu_sae(d_in: usize, d_sae: usize, w_enc: Tensor, w_dec: Tensor) -> SparseAutoencoder {
        let device = Device::Cpu;
        SparseAutoencoder {
            config: SaeConfig {
                d_in,
                d_sae,
                architecture: SaeArchitecture::ReLU,
                hook_name: "blocks.0.hook_resid_post".into(),
                hook_point: HookPoint::ResidPost(0),
                apply_b_dec_to_input: false,
                normalize_activations: NormalizeActivations::None,
            },
            w_enc,
            w_dec,
            b_enc: Tensor::zeros(d_sae, DType::F32, &device).unwrap(),
            b_dec: Tensor::zeros(d_in, DType::F32, &device).unwrap(),
            threshold: None,
        }
    }

    // --- Feature ids ---

    #[test]
    fn sae_feature_id_display() {
        let fid = SaeFeatureId { index: 42 };
        assert_eq!(fid.to_string(), "SAE:42");
    }

    // --- Architecture resolution ---

    #[test]
    fn resolve_architecture_relu_default() {
        let arch = resolve_architecture(None, None, None).unwrap();
        assert_eq!(arch, SaeArchitecture::ReLU);
    }

    #[test]
    fn resolve_architecture_relu_explicit() {
        let arch = resolve_architecture(Some("standard"), Some("relu"), None).unwrap();
        assert_eq!(arch, SaeArchitecture::ReLU);
    }

    #[test]
    fn resolve_architecture_jumprelu() {
        let arch = resolve_architecture(Some("jumprelu"), None, None).unwrap();
        assert_eq!(arch, SaeArchitecture::JumpReLU);
    }

    #[test]
    fn resolve_architecture_jumprelu_from_activation() {
        let arch = resolve_architecture(None, Some("jumprelu"), None).unwrap();
        assert_eq!(arch, SaeArchitecture::JumpReLU);
    }

    #[test]
    fn resolve_architecture_topk() {
        let kwargs = serde_json::json!({"k": 32});
        let arch = resolve_architecture(Some("topk"), None, Some(&kwargs)).unwrap();
        assert_eq!(arch, SaeArchitecture::TopK { k: 32 });
    }

    #[test]
    fn resolve_architecture_topk_from_activation() {
        let kwargs = serde_json::json!({"k": 64});
        let arch = resolve_architecture(None, Some("topk"), Some(&kwargs)).unwrap();
        assert_eq!(arch, SaeArchitecture::TopK { k: 64 });
    }

    #[test]
    fn resolve_architecture_topk_missing_k() {
        let result = resolve_architecture(Some("topk"), None, None);
        assert!(result.is_err());
    }

    #[test]
    fn resolve_architecture_unknown() {
        let result = resolve_architecture(Some("gated"), None, None);
        assert!(result.is_err());
    }

    // --- Config parsing ---

    #[test]
    fn parse_config_minimal() {
        let json = r#"{
            "d_in": 2304,
            "d_sae": 16384,
            "hook_name": "blocks.5.hook_resid_post"
        }"#;
        let raw: RawSaeConfig = serde_json::from_str(json).unwrap();
        let config = parse_sae_config(raw).unwrap();
        assert_eq!(config.d_in, 2304);
        assert_eq!(config.d_sae, 16384);
        assert_eq!(config.architecture, SaeArchitecture::ReLU);
        assert_eq!(config.hook_point, HookPoint::ResidPost(5));
        assert!(!config.apply_b_dec_to_input);
    }

    #[test]
    fn parse_config_jumprelu() {
        let json = r#"{
            "d_in": 2304,
            "d_sae": 16384,
            "architecture": "jumprelu",
            "hook_name": "blocks.20.hook_resid_post",
            "apply_b_dec_to_input": true,
            "normalize_activations": "expected_average_only_in"
        }"#;
        let raw: RawSaeConfig = serde_json::from_str(json).unwrap();
        let config = parse_sae_config(raw).unwrap();
        assert_eq!(config.architecture, SaeArchitecture::JumpReLU);
        assert_eq!(config.hook_point, HookPoint::ResidPost(20));
        assert!(config.apply_b_dec_to_input);
        assert_eq!(
            config.normalize_activations,
            NormalizeActivations::ExpectedAverageOnlyIn
        );
    }

    #[test]
    fn parse_config_topk() {
        let json = r#"{
            "d_in": 2304,
            "d_sae": 65536,
            "activation_fn_str": "topk",
            "activation_fn_kwargs": {"k": 32},
            "hook_name": "blocks.10.hook_resid_post"
        }"#;
        let raw: RawSaeConfig = serde_json::from_str(json).unwrap();
        let config = parse_sae_config(raw).unwrap();
        assert_eq!(config.architecture, SaeArchitecture::TopK { k: 32 });
    }

    // --- TopK (CPU path) ---

    #[test]
    fn topk_cpu_basic() {
        let data = Tensor::new(&[[5.0_f32, 3.0, 1.0, 4.0, 2.0]], &Device::Cpu).unwrap();
        let result = topk_cpu(&data, 2).unwrap();
        let vals: Vec<f32> = result.flatten_all().unwrap().to_vec1().unwrap();
        assert_eq!(vals, vec![5.0, 0.0, 0.0, 4.0, 0.0]);
    }

    #[test]
    fn topk_cpu_all_kept() {
        let data = Tensor::new(&[[1.0_f32, 2.0, 3.0]], &Device::Cpu).unwrap();
        let result = topk_cpu(&data, 5).unwrap();
        let vals: Vec<f32> = result.flatten_all().unwrap().to_vec1().unwrap();
        assert_eq!(vals, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn topk_cpu_none_kept() {
        let data = Tensor::new(&[[1.0_f32, 2.0, 3.0]], &Device::Cpu).unwrap();
        let result = topk_cpu(&data, 0).unwrap();
        let vals: Vec<f32> = result.flatten_all().unwrap().to_vec1().unwrap();
        assert_eq!(vals, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn topk_cpu_batched() {
        let data = Tensor::new(
            &[[5.0_f32, 3.0, 1.0, 4.0, 2.0], [1.0, 2.0, 3.0, 4.0, 5.0]],
            &Device::Cpu,
        )
        .unwrap();
        let result = topk_cpu(&data, 3).unwrap();
        let vals: Vec<Vec<f32>> = result.to_vec2().unwrap();
        assert_eq!(vals[0], vec![5.0, 3.0, 0.0, 4.0, 0.0]);
        assert_eq!(vals[1], vec![0.0, 0.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn topk_cpu_ties_keep_earliest() {
        // Three values tie for the top spot but only two may survive; the
        // entries with the lowest indices are the ones kept.
        let data = Tensor::new(&[[3.0_f32, 3.0, 3.0, 1.0]], &Device::Cpu).unwrap();
        let result = topk_cpu(&data, 2).unwrap();
        let vals: Vec<f32> = result.flatten_all().unwrap().to_vec1().unwrap();
        assert_eq!(vals, vec![3.0, 3.0, 0.0, 0.0]);
    }

    // --- SparseActivations ---

    #[test]
    fn sparse_activations_sae() {
        let features = vec![
            (SaeFeatureId { index: 5 }, 3.0),
            (SaeFeatureId { index: 2 }, 2.0),
            (SaeFeatureId { index: 8 }, 1.0),
        ];
        let sparse = SparseActivations { features };
        assert_eq!(sparse.len(), 3);
        assert!(!sparse.is_empty());
    }

    #[test]
    fn sparse_activations_truncate_sae() {
        let features = vec![
            (SaeFeatureId { index: 5 }, 3.0),
            (SaeFeatureId { index: 2 }, 2.0),
            (SaeFeatureId { index: 8 }, 1.0),
        ];
        let mut sparse = SparseActivations { features };
        sparse.truncate(2);
        assert_eq!(sparse.len(), 2);
        assert_eq!(sparse.features[0].0.index, 5);
        assert_eq!(sparse.features[1].0.index, 2);
    }

    // --- Encode / decode ---

    #[test]
    fn encode_decode_roundtrip_shapes() {
        // Create a tiny SAE for shape testing.
        let d_in = 4;
        let d_sae = 8;
        let device = Device::Cpu;

        let w_enc = Tensor::randn(0.0_f32, 1.0, (d_in, d_sae), &device).unwrap();
        let w_dec = Tensor::randn(0.0_f32, 1.0, (d_sae, d_in), &device).unwrap();
        let sae = tiny_relu_sae(d_in, d_sae, w_enc, w_dec);

        // Single vector, lifted to a one-row batch: [d_in] → [1, d_in].
        let x1 = Tensor::randn(0.0_f32, 1.0, (d_in,), &device).unwrap();
        let encoded = sae.encode(&x1.unsqueeze(0).unwrap()).unwrap();
        assert_eq!(encoded.dims(), &[1, d_sae]);

        // Test 2D input.
        let x2 = Tensor::randn(0.0_f32, 1.0, (3, d_in), &device).unwrap();
        let encoded = sae.encode(&x2).unwrap();
        assert_eq!(encoded.dims(), &[3, d_sae]);
        let decoded = sae.decode(&encoded).unwrap();
        assert_eq!(decoded.dims(), &[3, d_in]);

        // Test 3D input.
        let x3 = Tensor::randn(0.0_f32, 1.0, (2, 5, d_in), &device).unwrap();
        let encoded = sae.encode(&x3).unwrap();
        assert_eq!(encoded.dims(), &[2, 5, d_sae]);
        let decoded = sae.decode(&encoded).unwrap();
        assert_eq!(decoded.dims(), &[2, 5, d_in]);

        // Test reconstruction.
        let x_hat = sae.reconstruct(&x2).unwrap();
        assert_eq!(x_hat.dims(), &[3, d_in]);

        // Test reconstruction error.
        let mse = sae.reconstruction_error(&x2).unwrap();
        assert!(mse >= 0.0);
    }

    #[test]
    fn encode_sparse_basic() {
        let d_in = 4;
        let d_sae = 8;
        let device = Device::Cpu;

        // Use identity-like encoder to get predictable output.
        let mut w_enc_data = vec![0.0_f32; d_in * d_sae];
        // Map input dim 0 → feature 0, dim 1 → feature 1, etc.
        for i in 0..d_in {
            w_enc_data[i * d_sae + i] = 1.0;
        }
        let w_enc = Tensor::from_vec(w_enc_data, (d_in, d_sae), &device).unwrap();
        let w_dec = Tensor::randn(0.0_f32, 1.0, (d_sae, d_in), &device).unwrap();
        let sae = tiny_relu_sae(d_in, d_sae, w_enc, w_dec);

        let x = Tensor::new(&[2.0_f32, -1.0, 3.0, 0.5], &device).unwrap();
        let sparse = sae.encode_sparse(&x).unwrap();

        // Only positive values should appear: 2.0, 3.0, 0.5
        assert_eq!(sparse.len(), 3);
        // Should be sorted descending.
        assert_eq!(sparse.features[0].0.index, 2); // 3.0
        assert_eq!(sparse.features[1].0.index, 0); // 2.0
        assert_eq!(sparse.features[2].0.index, 3); // 0.5
    }

    // --- Steering ---

    #[test]
    fn decoder_vector_basic() {
        let d_in = 4;
        let d_sae = 8;
        let device = Device::Cpu;

        let w_enc = Tensor::zeros((d_in, d_sae), DType::F32, &device).unwrap();
        let w_dec = Tensor::randn(0.0_f32, 1.0, (d_sae, d_in), &device).unwrap();
        let sae = tiny_relu_sae(d_in, d_sae, w_enc, w_dec.clone());

        let vec0 = sae.decoder_vector(0).unwrap();
        assert_eq!(vec0.dims(), &[d_in]);

        // The returned vector must be row 0 of the decoder matrix.
        let expected: Vec<f32> = w_dec.get(0).unwrap().to_vec1().unwrap();
        let actual: Vec<f32> = vec0.to_vec1().unwrap();
        assert_eq!(actual, expected);

        // Out of range should error.
        assert!(sae.decoder_vector(d_sae).is_err());
    }

    #[test]
    fn prepare_injection_basic() {
        let d_in = 4;
        let d_sae = 8;
        let device = Device::Cpu;

        let w_enc = Tensor::zeros((d_in, d_sae), DType::F32, &device).unwrap();
        let w_dec = Tensor::ones((d_sae, d_in), DType::F32, &device).unwrap();
        let sae = tiny_relu_sae(d_in, d_sae, w_enc, w_dec);

        let features = vec![(0_usize, 1.0_f32), (1, 0.5)];
        let hooks = sae
            .prepare_hook_injection(&features, 2, 5, &device)
            .unwrap();
        assert!(!hooks.is_empty());
    }
}