Skip to main content

candle_mi/sae/
mod.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! Sparse Autoencoder (SAE) support.
4//!
5//! Loads pre-trained SAE weights from SAELens-format safetensors + `cfg.json`
6//! or from Gemma Scope NPZ archives, encodes model activations into sparse
7//! feature vectors, decodes back to activation space, and produces steering
8//! vectors for injection.
9//!
10//! Each SAE targets a single hook point in the model (e.g., `resid_post` at
11//! layer 5). Multiple SAEs can be loaded independently for different hook
12//! points.
13//!
14//! # SAE Architecture
15//!
16//! A Sparse Autoencoder implements:
17//! ```text
18//! Encode:  features = activation_fn(x @ W_enc + b_enc)
19//! Decode:  x_hat = features @ W_dec + b_dec
20//! ```
21//!
22//! Supported activation functions:
23//! - **`ReLU`**: `features = ReLU(pre_acts)`
24//! - **`JumpReLU`**: `features = pre_acts * (pre_acts > threshold)`
25//! - **`TopK`**: keep only the top-k pre-activations, zero the rest
26//!
27//! # Weight File Formats
28//!
29//! ## `SAELens` safetensors format
30//!
31//! Each SAE directory contains:
32//! - `cfg.json`: configuration (`d_in`, `d_sae`, architecture, `hook_name`, ...)
33//! - `sae_weights.safetensors` (or `model.safetensors`): weight tensors
34//!
35//! ## Gemma Scope NPZ format
36//!
37//! A single `params.npz` file (`NumPy` ZIP archive) containing:
38//! - `W_enc.npy`: shape `[d_in, d_sae]` — encoder weight matrix
39//! - `W_dec.npy`: shape `[d_sae, d_in]` — decoder weight matrix
40//! - `b_enc.npy`: shape `[d_sae]` — encoder bias
41//! - `b_dec.npy`: shape `[d_in]` — decoder bias
42//! - `threshold.npy`: shape `[d_sae]` — `JumpReLU` threshold (optional)
43//!
44//! Architecture is auto-detected: presence of `threshold` → `JumpReLU`,
45//! otherwise `ReLU`.
46//!
47//! ## Tensor names (both formats)
48//!
49//! - `W_enc`: shape `[d_in, d_sae]` — encoder weight matrix
50//! - `W_dec`: shape `[d_sae, d_in]` — decoder weight matrix
51//! - `b_enc`: shape `[d_sae]` — encoder bias
52//! - `b_dec`: shape `[d_in]` — decoder bias
53//! - `threshold`: shape `[d_sae]` — `JumpReLU` threshold (optional)
54
55// `npz` adapts anamnesis NPZ tensors to candle. Originally private to the
56// `sae` module; now `pub(crate)` so the CLT GemmaScope loader can share
57// the same NPZ → candle bridge without duplicating the F32/F64 conversion.
58pub(crate) mod npz;
59
60use std::path::Path;
61
62use candle_core::{DType, Device, Tensor};
63use safetensors::tensor::SafeTensors;
64use tracing::info;
65
66use crate::error::{MIError, Result};
67use crate::hooks::{HookPoint, HookSpec, Intervention};
68use crate::sparse::{FeatureId, SparseActivations};
69
70// ---------------------------------------------------------------------------
71// Public types
72// ---------------------------------------------------------------------------
73
/// Index of one latent feature inside an SAE dictionary.
///
/// Formatted with [`Display`](std::fmt::Display) as `SAE:<index>`.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct SaeFeatureId {
    /// Position of the feature in the dictionary, in `0..d_sae`.
    pub index: usize,
}

impl std::fmt::Display for SaeFeatureId {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { index } = *self;
        write!(formatter, "SAE:{index}")
    }
}
86
// Empty impl: `SaeFeatureId` opts into the crate's `FeatureId` trait without
// overriding anything here. `encode_sparse` returns
// `SparseActivations<SaeFeatureId>`; presumably `SparseActivations` bounds its
// id type on `FeatureId` — confirm in `crate::sparse`.
impl FeatureId for SaeFeatureId {}
88
/// Activation function applied by the SAE encoder to the pre-activations
/// `pre_acts = x @ W_enc + b_enc` (with `x` optionally centred by `b_dec`).
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SaeArchitecture {
    /// Standard `ReLU`: `features = ReLU(x @ W_enc + b_enc)`.
    ReLU,
    /// `JumpReLU`: `features = pre_acts * (pre_acts > threshold)`.
    /// Requires a learned `threshold` tensor of shape `[d_sae]`.
    JumpReLU,
    /// `TopK`: keep only the top-k pre-activations, zero the rest.
    TopK {
        /// Number of features to keep active.
        k: usize,
    },
}
104
/// How SAE inputs are normalized before encoding.
#[non_exhaustive]
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum NormalizeActivations {
    /// Inputs are used as-is.
    None,
    /// Inputs are scaled by the expected average L2 norm estimated during
    /// training (`SAELens` config value `"expected_average_only_in"`).
    ExpectedAverageOnlyIn,
}
114
/// How the `TopK` activation is computed.
#[non_exhaustive]
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TopKStrategy {
    /// Choose per device: direct computation on CPU, sort-based on GPU.
    Auto,
    /// Always compute on the CPU (move to CPU, build the mask, move back).
    Cpu,
    /// Always use the GPU-side sort-based path.
    Gpu,
}
126
127/// Configuration for a Sparse Autoencoder, parsed from `cfg.json`.
128#[derive(Debug, Clone)]
129pub struct SaeConfig {
130    /// Input dimension (must match model hidden size at the hook point).
131    pub d_in: usize,
132    /// Dictionary size (number of SAE features).
133    pub d_sae: usize,
134    /// Encoder architecture (activation function).
135    pub architecture: SaeArchitecture,
136    /// `SAELens` hook name string (e.g., `"blocks.5.hook_resid_post"`).
137    pub hook_name: String,
138    /// Parsed hook point from the hook name.
139    pub hook_point: HookPoint,
140    /// Whether to subtract `b_dec` from input before encoding.
141    pub apply_b_dec_to_input: bool,
142    /// Input normalization strategy.
143    pub normalize_activations: NormalizeActivations,
144}
145
146// ---------------------------------------------------------------------------
147// Internal config parsing
148// ---------------------------------------------------------------------------
149
/// Raw, unvalidated view of an `SAELens` `cfg.json`.
///
/// Only `d_in` and `d_sae` are mandatory; every other field falls back to its
/// default via `#[serde(default)]`. Converted into [`SaeConfig`] by
/// [`parse_sae_config`].
#[derive(serde::Deserialize)]
// NOTE(review): the reason for this allow is not stated in the source;
// presumably it covers items generated by the serde derive — confirm before
// removing it.
#[allow(clippy::missing_docs_in_private_items)]
struct RawSaeConfig {
    /// Input dimension (hidden size at the hook point).
    d_in: usize,
    /// Dictionary size (number of features).
    d_sae: usize,
    /// Architecture name; `"standard"`, `"jumprelu"` and `"topk"` are
    /// understood by `resolve_architecture`.
    #[serde(default)]
    architecture: Option<String>,
    /// Activation function name, consulted when `architecture` is absent or
    /// `"standard"`.
    #[serde(default)]
    activation_fn_str: Option<String>,
    /// Extra activation arguments; `TopK` reads its `k` from here.
    #[serde(default)]
    activation_fn_kwargs: Option<serde_json::Value>,
    /// Hook name; preferred over `hook_point` when both are present.
    #[serde(default)]
    hook_name: Option<String>,
    /// Alternative hook name field, used when `hook_name` is absent.
    #[serde(default)]
    hook_point: Option<String>,
    /// Whether `b_dec` is subtracted from the input before encoding
    /// (`false` when absent).
    #[serde(default)]
    apply_b_dec_to_input: bool,
    /// Normalization strategy name (e.g. `"expected_average_only_in"`).
    #[serde(default)]
    normalize_activations: Option<String>,
}
170
171/// Parse a `RawSaeConfig` into a validated `SaeConfig`.
172fn parse_sae_config(raw: RawSaeConfig) -> Result<SaeConfig> {
173    // Resolve architecture from `architecture` or `activation_fn_str`.
174    let architecture = resolve_architecture(
175        raw.architecture.as_deref(),
176        raw.activation_fn_str.as_deref(),
177        raw.activation_fn_kwargs.as_ref(),
178    )?;
179
180    // Resolve hook name from `hook_name` or `hook_point`.
181    let hook_name = raw
182        .hook_name
183        .or(raw.hook_point)
184        .unwrap_or_else(|| "unknown".to_owned());
185
186    // Parse hook name to HookPoint via FromStr.
187    let hook_point: HookPoint = hook_name
188        .parse()
189        .unwrap_or_else(|_: std::convert::Infallible| {
190            // EXHAUSTIVE: Infallible can never happen, parse always succeeds
191            unreachable!()
192        });
193
194    let normalize_activations = match raw.normalize_activations.as_deref() {
195        Some("expected_average_only_in") => NormalizeActivations::ExpectedAverageOnlyIn,
196        _ => NormalizeActivations::None,
197    };
198
199    Ok(SaeConfig {
200        d_in: raw.d_in,
201        d_sae: raw.d_sae,
202        architecture,
203        hook_name,
204        hook_point,
205        apply_b_dec_to_input: raw.apply_b_dec_to_input,
206        normalize_activations,
207    })
208}
209
210/// Resolve SAE architecture from config fields.
211fn resolve_architecture(
212    architecture: Option<&str>,
213    activation_fn_str: Option<&str>,
214    activation_fn_kwargs: Option<&serde_json::Value>,
215) -> Result<SaeArchitecture> {
216    // Check `architecture` field first.
217    match architecture {
218        Some("jumprelu") => return Ok(SaeArchitecture::JumpReLU),
219        Some("topk") => {
220            let k = extract_topk_k(activation_fn_kwargs)?;
221            return Ok(SaeArchitecture::TopK { k });
222        }
223        Some("standard") | None => {} // fall through to activation_fn_str
224        Some(other) => {
225            return Err(MIError::Config(format!(
226                "unsupported SAE architecture: {other:?}"
227            )));
228        }
229    }
230
231    // Fall back to `activation_fn_str`.
232    match activation_fn_str {
233        Some("relu") | None => Ok(SaeArchitecture::ReLU),
234        Some("jumprelu") => Ok(SaeArchitecture::JumpReLU),
235        Some("topk") => {
236            let k = extract_topk_k(activation_fn_kwargs)?;
237            Ok(SaeArchitecture::TopK { k })
238        }
239        Some(other) => Err(MIError::Config(format!(
240            "unsupported SAE activation function: {other:?}"
241        ))),
242    }
243}
244
245/// Extract `k` from `activation_fn_kwargs.k`.
246fn extract_topk_k(kwargs: Option<&serde_json::Value>) -> Result<usize> {
247    let k = kwargs
248        .and_then(|v| v.get("k"))
249        .and_then(serde_json::Value::as_u64)
250        .ok_or_else(|| {
251            MIError::Config("TopK SAE requires activation_fn_kwargs.k in cfg.json".into())
252        })?;
253    let k_usize = usize::try_from(k)
254        .map_err(|_| MIError::Config(format!("TopK k value {k} too large for usize")))?;
255    Ok(k_usize)
256}
257
258// ---------------------------------------------------------------------------
259// SparseAutoencoder
260// ---------------------------------------------------------------------------
261
262/// A Sparse Autoencoder for mechanistic interpretability.
263///
264/// Loads SAE weights from SAELens-format safetensors + `cfg.json`,
265/// encodes model activations into sparse feature vectors, decodes
266/// back to activation space, and produces steering vectors for injection.
267///
268/// Each SAE targets a single hook point in the model (e.g., `resid_post`
269/// at layer 5). Multiple SAEs can be loaded independently for different
270/// hook points.
271///
272/// # Example
273///
274/// ```no_run
275/// # fn main() -> candle_mi::Result<()> {
276/// use candle_mi::sae::SparseAutoencoder;
277/// use candle_core::Device;
278///
279/// let sae = SparseAutoencoder::from_pretrained(
280///     "jbloom/Gemma-2-2B-Residual-Stream-SAEs",
281///     "gemma-2-2b-res-jb/blocks.20.hook_resid_post",
282///     &Device::Cpu,
283/// )?;
284/// println!("SAE: d_in={}, d_sae={}", sae.d_in(), sae.d_sae());
285/// # Ok(())
286/// # }
287/// ```
288pub struct SparseAutoencoder {
289    /// SAE configuration parsed from `cfg.json`.
290    config: SaeConfig,
291    /// Encoder weight matrix.
292    ///
293    /// # Shapes
294    /// - `w_enc`: `[d_in, d_sae]`
295    w_enc: Tensor,
296    /// Decoder weight matrix.
297    ///
298    /// # Shapes
299    /// - `w_dec`: `[d_sae, d_in]`
300    w_dec: Tensor,
301    /// Encoder bias vector.
302    ///
303    /// # Shapes
304    /// - `b_enc`: `[d_sae]`
305    b_enc: Tensor,
306    /// Decoder bias vector.
307    ///
308    /// # Shapes
309    /// - `b_dec`: `[d_in]`
310    b_dec: Tensor,
311    /// `JumpReLU` threshold (only present for `JumpReLU` architecture).
312    ///
313    /// # Shapes
314    /// - `threshold`: `[d_sae]`
315    threshold: Option<Tensor>,
316}
317
318impl SparseAutoencoder {
319    // --- Loading ---
320
321    /// Load an SAE from a local directory containing safetensors + `cfg.json`.
322    ///
323    /// Expects either `sae_weights.safetensors` or `model.safetensors`
324    /// plus a `cfg.json` file.
325    ///
326    /// # Errors
327    ///
328    /// Returns [`MIError::Config`] if `cfg.json` is missing or malformed.
329    /// Returns [`MIError::Config`] if weight shapes don't match `cfg.json` dimensions.
330    /// Returns [`MIError::Model`] on tensor loading failure.
331    /// Returns [`MIError::Io`] if files cannot be read.
332    pub fn from_local(dir: &Path, device: &Device) -> Result<Self> {
333        // Parse cfg.json.
334        let cfg_path = dir.join("cfg.json");
335        if !cfg_path.exists() {
336            return Err(MIError::Config(format!(
337                "cfg.json not found in {}",
338                dir.display()
339            )));
340        }
341        let cfg_text = std::fs::read_to_string(&cfg_path)?;
342        let raw: RawSaeConfig = serde_json::from_str(&cfg_text)
343            .map_err(|e| MIError::Config(format!("failed to parse cfg.json: {e}")))?;
344        let config = parse_sae_config(raw)?;
345
346        info!(
347            "SAE config: d_in={}, d_sae={}, arch={:?}, hook={}",
348            config.d_in, config.d_sae, config.architecture, config.hook_name
349        );
350
351        // Find safetensors file.
352        let weights_path = if dir.join("sae_weights.safetensors").exists() {
353            dir.join("sae_weights.safetensors")
354        } else if dir.join("model.safetensors").exists() {
355            dir.join("model.safetensors")
356        } else {
357            return Err(MIError::Config(format!(
358                "no safetensors file found in {}",
359                dir.display()
360            )));
361        };
362
363        // Load weights.
364        let data = std::fs::read(&weights_path)?;
365        let st = SafeTensors::deserialize(&data)
366            .map_err(|e| MIError::Config(format!("failed to deserialize SAE weights: {e}")))?;
367
368        let w_enc = load_tensor(&st, "W_enc", device)?;
369        let w_dec = load_tensor(&st, "W_dec", device)?;
370        let b_enc = load_tensor(&st, "b_enc", device)?;
371        let b_dec = load_tensor(&st, "b_dec", device)?;
372        let threshold = st
373            .tensor("threshold")
374            .ok()
375            .map(|v| tensor_from_view(&v, device))
376            .transpose()?;
377
378        // PROMOTE: F32 for numerical stability in matmul and bias add
379        let w_enc = w_enc.to_dtype(DType::F32)?;
380        let w_dec = w_dec.to_dtype(DType::F32)?;
381        let b_enc = b_enc.to_dtype(DType::F32)?;
382        let b_dec = b_dec.to_dtype(DType::F32)?;
383        let threshold = threshold.map(|t| t.to_dtype(DType::F32)).transpose()?;
384
385        // Validate shapes.
386        validate_shape(&w_enc, &[config.d_in, config.d_sae], "W_enc")?;
387        validate_shape(&w_dec, &[config.d_sae, config.d_in], "W_dec")?;
388        validate_shape(&b_enc, &[config.d_sae], "b_enc")?;
389        validate_shape(&b_dec, &[config.d_in], "b_dec")?;
390        if let Some(ref t) = threshold {
391            validate_shape(t, &[config.d_sae], "threshold")?;
392        }
393
394        // Validate JumpReLU has threshold.
395        if config.architecture == SaeArchitecture::JumpReLU && threshold.is_none() {
396            return Err(MIError::Config(
397                "JumpReLU SAE requires 'threshold' tensor in weights file".into(),
398            ));
399        }
400
401        info!(
402            "SAE loaded: {} weights on {:?}",
403            weights_path.display(),
404            device
405        );
406
407        Ok(Self {
408            config,
409            w_enc,
410            w_dec,
411            b_enc,
412            b_dec,
413            threshold,
414        })
415    }
416
417    /// Load an SAE from a Gemma Scope NPZ file (`params.npz`).
418    ///
419    /// The NPZ file must contain `W_enc`, `W_dec`, `b_enc`, `b_dec` arrays,
420    /// and optionally `threshold` (for `JumpReLU`). Config is inferred from
421    /// tensor shapes since NPZ files have no `cfg.json`.
422    ///
423    /// # Arguments
424    /// * `npz_path` — Path to the `params.npz` file
425    /// * `hook_layer` — Which model layer this SAE hooks into
426    /// * `device` — Target device (CPU or CUDA)
427    ///
428    /// # Errors
429    ///
430    /// Returns [`MIError::Config`] if required tensors are missing or shapes
431    /// are inconsistent.
432    /// Returns [`MIError::Io`] if the file cannot be read.
433    pub fn from_npz(npz_path: &Path, hook_layer: usize, device: &Device) -> Result<Self> {
434        info!("Loading SAE from NPZ: {}", npz_path.display());
435        let tensors = npz::load_npz(npz_path, device)?;
436
437        let w_enc = tensors
438            .get("W_enc")
439            .ok_or_else(|| MIError::Config("NPZ missing W_enc".into()))?
440            .to_dtype(DType::F32)?;
441        let w_dec = tensors
442            .get("W_dec")
443            .ok_or_else(|| MIError::Config("NPZ missing W_dec".into()))?
444            .to_dtype(DType::F32)?;
445        let b_enc = tensors
446            .get("b_enc")
447            .ok_or_else(|| MIError::Config("NPZ missing b_enc".into()))?
448            .to_dtype(DType::F32)?;
449        let b_dec = tensors
450            .get("b_dec")
451            .ok_or_else(|| MIError::Config("NPZ missing b_dec".into()))?
452            .to_dtype(DType::F32)?;
453        let threshold = tensors
454            .get("threshold")
455            .map(|t| t.to_dtype(DType::F32))
456            .transpose()?;
457
458        // Infer dimensions from W_enc: [d_in, d_sae].
459        let w_enc_dims = w_enc.dims();
460        if w_enc_dims.len() != 2 {
461            return Err(MIError::Config(format!(
462                "W_enc expected 2 dims, got {}",
463                w_enc_dims.len()
464            )));
465        }
466        let d_in = *w_enc_dims
467            .first()
468            .ok_or_else(|| MIError::Config("W_enc has no dimensions".into()))?;
469        let d_sae = *w_enc_dims
470            .get(1)
471            .ok_or_else(|| MIError::Config("W_enc has no second dimension".into()))?;
472
473        // Validate shapes.
474        validate_shape(&w_enc, &[d_in, d_sae], "W_enc")?;
475        validate_shape(&w_dec, &[d_sae, d_in], "W_dec")?;
476        validate_shape(&b_enc, &[d_sae], "b_enc")?;
477        validate_shape(&b_dec, &[d_in], "b_dec")?;
478        if let Some(ref t) = threshold {
479            validate_shape(t, &[d_sae], "threshold")?;
480        }
481
482        // Auto-detect architecture: threshold present → JumpReLU, else ReLU.
483        let architecture = if threshold.is_some() {
484            SaeArchitecture::JumpReLU
485        } else {
486            SaeArchitecture::ReLU
487        };
488
489        let hook_name = format!("blocks.{hook_layer}.hook_resid_post");
490        let hook_point = hook_name
491            .parse::<HookPoint>()
492            .map_err(|e| MIError::Config(format!("failed to parse hook name: {e}")))?;
493
494        let config = SaeConfig {
495            d_in,
496            d_sae,
497            architecture,
498            hook_name,
499            hook_point,
500            apply_b_dec_to_input: false,
501            normalize_activations: NormalizeActivations::None,
502        };
503
504        info!(
505            "SAE from NPZ: d_in={d_in}, d_sae={d_sae}, arch={:?}, hook={}",
506            config.architecture, config.hook_name
507        );
508
509        Ok(Self {
510            config,
511            w_enc,
512            w_dec,
513            b_enc,
514            b_dec,
515            threshold,
516        })
517    }
518
519    /// Load an SAE from a `HuggingFace` repository containing an NPZ file.
520    ///
521    /// Downloads the NPZ file via `hf-fetch-model`, then delegates to
522    /// [`from_npz`](Self::from_npz).
523    ///
524    /// # Arguments
525    /// * `repo_id` — `HuggingFace` repository ID
526    ///   (e.g., `"google/gemma-scope-2b-pt-res"`)
527    /// * `npz_path` — Path within the repo to the NPZ file
528    ///   (e.g., `"layer_0/width_16k/average_l0_105/params.npz"`)
529    /// * `hook_layer` — Which model layer this SAE hooks into
530    /// * `device` — Target device (CPU or CUDA)
531    ///
532    /// # Errors
533    ///
534    /// Returns [`MIError::Download`] if the file cannot be fetched.
535    /// Returns [`MIError::Config`] if the NPZ format is invalid.
536    pub fn from_pretrained_npz(
537        repo_id: &str,
538        npz_path: &str,
539        hook_layer: usize,
540        device: &Device,
541    ) -> Result<Self> {
542        let fetch_config = crate::download::fetch_config_builder()
543            .on_progress(|event| {
544                tracing::info!(
545                    filename = %event.filename,
546                    percent = event.percent,
547                    bytes_downloaded = event.bytes_downloaded,
548                    bytes_total = event.bytes_total,
549                    "SAE NPZ download progress",
550                );
551            })
552            .build()
553            .map_err(|e| MIError::Download(format!("failed to build fetch config: {e}")))?;
554
555        info!("Downloading {npz_path} from {repo_id}");
556        let local_path =
557            hf_fetch_model::download_file_blocking(repo_id.to_owned(), npz_path, &fetch_config)
558                .map_err(|e| MIError::Download(format!("failed to download NPZ: {e}")))?
559                .into_inner();
560
561        Self::from_npz(&local_path, hook_layer, device)
562    }
563
564    /// Load an SAE from a `HuggingFace` repository.
565    ///
566    /// Downloads safetensors + `cfg.json` via `hf-fetch-model`, then delegates
567    /// to [`from_local`](Self::from_local).
568    ///
569    /// # Arguments
570    /// * `repo_id` — `HuggingFace` repository ID
571    ///   (e.g., `"jbloom/Gemma-2-2B-Residual-Stream-SAEs"`)
572    /// * `sae_id` — Subdirectory within the repo
573    ///   (e.g., `"gemma-2-2b-res-jb/blocks.20.hook_resid_post"`)
574    /// * `device` — Target device (CPU or CUDA)
575    ///
576    /// # Errors
577    ///
578    /// Returns [`MIError::Download`] if files cannot be fetched.
579    /// Returns [`MIError::Config`] if the SAE format is invalid.
580    pub fn from_pretrained(repo_id: &str, sae_id: &str, device: &Device) -> Result<Self> {
581        let fetch_config = crate::download::fetch_config_builder()
582            .on_progress(|event| {
583                tracing::info!(
584                    filename = %event.filename,
585                    percent = event.percent,
586                    bytes_downloaded = event.bytes_downloaded,
587                    bytes_total = event.bytes_total,
588                    "SAE download progress",
589                );
590            })
591            .build()
592            .map_err(|e| MIError::Download(format!("failed to build fetch config: {e}")))?;
593
594        // Download cfg.json.
595        let cfg_remote = format!("{sae_id}/cfg.json");
596        info!("Downloading {cfg_remote} from {repo_id}");
597        let cfg_path =
598            hf_fetch_model::download_file_blocking(repo_id.to_owned(), &cfg_remote, &fetch_config)
599                .map_err(|e| MIError::Download(format!("failed to download cfg.json: {e}")))?
600                .into_inner();
601
602        // Download weights: try sae_weights.safetensors, fall back to model.safetensors.
603        let weights_remote = format!("{sae_id}/sae_weights.safetensors");
604        info!("Downloading {weights_remote} from {repo_id}");
605        let weights_path = hf_fetch_model::download_file_blocking(
606            repo_id.to_owned(),
607            &weights_remote,
608            &fetch_config,
609        )
610        .or_else(|_| {
611            let alt_remote = format!("{sae_id}/model.safetensors");
612            info!("Trying {alt_remote} from {repo_id}");
613            hf_fetch_model::download_file_blocking(repo_id.to_owned(), &alt_remote, &fetch_config)
614        })
615        .map_err(|e| MIError::Download(format!("failed to download SAE weights: {e}")))?
616        .into_inner();
617
618        // Both files are in cache; determine the common directory.
619        let dir = cfg_path.parent().ok_or_else(|| {
620            MIError::Config("cannot determine SAE directory from cfg.json path".into())
621        })?;
622
623        // Verify the weights file is in the same directory (or just load by path).
624        // hf-fetch-model may place files in the same cache dir; if not, we need
625        // to construct the dir from the weights path instead.
626        if dir.join("sae_weights.safetensors").exists() || dir.join("model.safetensors").exists() {
627            Self::from_local(dir, device)
628        } else {
629            // Files might be in different cache locations; load manually.
630            let weights_dir = weights_path.parent().ok_or_else(|| {
631                MIError::Config("cannot determine SAE directory from weights path".into())
632            })?;
633            // Copy cfg.json to weights dir if needed.
634            let target_cfg = weights_dir.join("cfg.json");
635            if !target_cfg.exists() {
636                std::fs::copy(&cfg_path, &target_cfg)?;
637            }
638            Self::from_local(weights_dir, device)
639        }
640    }
641
642    // --- Accessors ---
643
644    /// Access the SAE configuration.
645    #[must_use]
646    pub const fn config(&self) -> &SaeConfig {
647        &self.config
648    }
649
650    /// The hook point this SAE targets.
651    #[must_use]
652    pub const fn hook_point(&self) -> &HookPoint {
653        &self.config.hook_point
654    }
655
656    /// Dictionary size (number of features).
657    #[must_use]
658    pub const fn d_sae(&self) -> usize {
659        self.config.d_sae
660    }
661
662    /// Input dimension.
663    #[must_use]
664    pub const fn d_in(&self) -> usize {
665        self.config.d_in
666    }
667
668    // --- Encoding ---
669
670    /// Encode activations into SAE feature space (dense output).
671    ///
672    /// Applies the full encoder: `pre_acts = x @ W_enc + b_enc`, then the
673    /// architecture-specific activation function (`ReLU`, `JumpReLU`, or `TopK`).
674    ///
675    /// Uses [`TopKStrategy::Auto`] for `TopK` SAEs.
676    ///
677    /// # Shapes
678    /// - `x`: `[..., d_in]` — activations with any leading dimensions
679    /// - returns: `[..., d_sae]` — encoded features (mostly sparse)
680    ///
681    /// # Errors
682    ///
683    /// Returns [`MIError::Config`] if the last dimension of `x` != `d_in`.
684    /// Returns [`MIError::Model`] on tensor operation failure.
685    pub fn encode(&self, x: &Tensor) -> Result<Tensor> {
686        self.encode_with_strategy(x, &TopKStrategy::Auto)
687    }
688
689    /// Encode activations with an explicit [`TopKStrategy`].
690    ///
691    /// Same as [`encode()`](Self::encode) but allows overriding the `TopK`
692    /// computation strategy.
693    ///
694    /// # Shapes
695    /// - `x`: `[..., d_in]` — activations with any leading dimensions
696    /// - returns: `[..., d_sae]` — encoded features (mostly sparse)
697    ///
698    /// # Errors
699    ///
700    /// Returns [`MIError::Config`] if the last dimension of `x` != `d_in`.
701    /// Returns [`MIError::Model`] on tensor operation failure.
702    pub fn encode_with_strategy(&self, x: &Tensor, strategy: &TopKStrategy) -> Result<Tensor> {
703        let dims = x.dims();
704        let last_dim = *dims
705            .last()
706            .ok_or_else(|| MIError::Config("cannot encode empty tensor".into()))?;
707        if last_dim != self.config.d_in {
708            return Err(MIError::Config(format!(
709                "input last dim {last_dim} != SAE d_in {}",
710                self.config.d_in
711            )));
712        }
713
714        // PROMOTE: F32 for matmul precision
715        let x_f32 = x.to_dtype(DType::F32)?;
716
717        // Optionally subtract b_dec from input (centering).
718        let x_centered = if self.config.apply_b_dec_to_input {
719            let b_dec = broadcast_bias(&self.b_dec, x_f32.dims())?;
720            (&x_f32 - &b_dec)?
721        } else {
722            x_f32
723        };
724
725        // pre_acts = x @ W_enc + b_enc
726        // [... , d_in] @ [d_in, d_sae] → [..., d_sae]
727        let pre_acts = x_centered.broadcast_matmul(&self.w_enc)?;
728        // Broadcast b_enc [d_sae] to match pre_acts leading dims.
729        let b_enc = broadcast_bias(&self.b_enc, pre_acts.dims())?;
730        let pre_acts = (&pre_acts + &b_enc)?;
731
732        // Apply activation function.
733        match &self.config.architecture {
734            SaeArchitecture::ReLU => Ok(pre_acts.relu()?),
735            SaeArchitecture::JumpReLU => {
736                let threshold = self
737                    .threshold
738                    .as_ref()
739                    .ok_or_else(|| MIError::Config("JumpReLU requires threshold tensor".into()))?;
740                // Broadcast threshold to match pre_acts leading dims.
741                let threshold = broadcast_bias(threshold, pre_acts.dims())?;
742                // mask = (pre_acts > threshold), features = pre_acts * mask
743                let mask = pre_acts.gt(&threshold)?;
744                let mask_f32 = mask.to_dtype(DType::F32)?;
745                Ok((&pre_acts * &mask_f32)?)
746            }
747            SaeArchitecture::TopK { k } => topk_activation(&pre_acts, *k, strategy),
748        }
749    }
750
751    /// Encode a single activation vector into sparse SAE features.
752    ///
753    /// Returns only non-zero features, sorted by magnitude descending.
754    ///
755    /// # Shapes
756    /// - `x`: `[d_in]` — single activation vector
757    /// - returns: [`SparseActivations<SaeFeatureId>`] with `(SaeFeatureId, f32)` pairs
758    ///
759    /// # Errors
760    ///
761    /// Returns [`MIError::Config`] if `x` has wrong dimension.
762    /// Returns [`MIError::Model`] on tensor operation failure.
763    pub fn encode_sparse(&self, x: &Tensor) -> Result<SparseActivations<SaeFeatureId>> {
764        let encoded = self.encode(&x.unsqueeze(0)?)?;
765        let encoded_1d = encoded.squeeze(0)?;
766
767        // Transfer to CPU for sparse extraction.
768        let values: Vec<f32> = encoded_1d.to_vec1()?;
769
770        let mut features: Vec<(SaeFeatureId, f32)> = values
771            .iter()
772            .enumerate()
773            .filter(|&(_, v)| *v > 0.0)
774            .map(|(i, v)| (SaeFeatureId { index: i }, *v))
775            .collect();
776
777        // Sort by activation magnitude (descending).
778        features.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
779
780        Ok(SparseActivations { features })
781    }
782
783    // --- Decoding ---
784
785    /// Decode SAE features back to activation space.
786    ///
787    /// # Shapes
788    /// - `features`: `[..., d_sae]` — encoded feature activations
789    /// - returns: `[..., d_in]` — reconstructed activations
790    ///
791    /// # Errors
792    ///
793    /// Returns [`MIError::Model`] on tensor operation failure.
794    pub fn decode(&self, features: &Tensor) -> Result<Tensor> {
795        // x_hat = features @ W_dec + b_dec
796        // [..., d_sae] @ [d_sae, d_in] → [..., d_in]
797        let features_f32 = features.to_dtype(DType::F32)?;
798        let decoded = features_f32.broadcast_matmul(&self.w_dec)?;
799        let b_dec = broadcast_bias(&self.b_dec, decoded.dims())?;
800        Ok((&decoded + &b_dec)?)
801    }
802
803    // --- Reconstruction ---
804
805    /// Reconstruct activations through the SAE (encode then decode).
806    ///
807    /// # Shapes
808    /// - `x`: `[..., d_in]` — original activations
809    /// - returns: `[..., d_in]` — reconstructed activations
810    ///
811    /// # Errors
812    ///
813    /// Returns [`MIError::Config`] if the last dimension of `x` != `d_in`.
814    /// Returns [`MIError::Model`] on tensor operation failure.
815    pub fn reconstruct(&self, x: &Tensor) -> Result<Tensor> {
816        let encoded = self.encode(x)?;
817        self.decode(&encoded)
818    }
819
820    /// Compute reconstruction MSE loss.
821    ///
822    /// # Shapes
823    /// - `x`: `[..., d_in]` — original activations
824    /// - returns: scalar `f64` mean squared error
825    ///
826    /// # Errors
827    ///
828    /// Returns [`MIError::Config`] if the last dimension of `x` != `d_in`.
829    /// Returns [`MIError::Model`] on tensor operation failure.
830    pub fn reconstruction_error(&self, x: &Tensor) -> Result<f64> {
831        let x_f32 = x.to_dtype(DType::F32)?;
832        let x_hat = self.reconstruct(&x_f32)?;
833        let diff = (&x_f32 - &x_hat)?;
834        let mse: f32 = diff.sqr()?.mean_all()?.to_scalar()?;
835        Ok(f64::from(mse))
836    }
837
838    // --- Steering ---
839
840    /// Extract a single feature's decoder vector (steering direction).
841    ///
842    /// # Shapes
843    /// - returns: `[d_in]` — decoder vector on the SAE's device
844    ///
845    /// # Errors
846    ///
847    /// Returns [`MIError::Config`] if `feature_idx` >= `d_sae`.
848    /// Returns [`MIError::Model`] on tensor operation failure.
849    pub fn decoder_vector(&self, feature_idx: usize) -> Result<Tensor> {
850        if feature_idx >= self.config.d_sae {
851            return Err(MIError::Config(format!(
852                "feature index {feature_idx} out of range (d_sae={})",
853                self.config.d_sae
854            )));
855        }
856        // W_dec: [d_sae, d_in] → row feature_idx → [d_in]
857        Ok(self.w_dec.get(feature_idx)?)
858    }
859
860    /// Build a [`HookSpec`] that injects SAE decoder vectors into the model.
861    ///
862    /// Creates an [`Intervention::Add`] at this SAE's hook point with the
863    /// accumulated (scaled) decoder vectors placed at the given position.
864    ///
865    /// # Shapes
866    /// - Internally constructs `[1, seq_len, d_in]` with the vector at `position`.
867    ///
868    /// # Arguments
869    /// * `features` — List of `(feature_index, strength)` pairs
870    /// * `position` — Token position in the sequence to inject at
871    /// * `seq_len` — Total sequence length
872    /// * `device` — Device to construct injection tensors on
873    ///
874    /// # Errors
875    ///
876    /// Returns [`MIError::Config`] if any `feature_index` >= `d_sae`.
877    /// Returns [`MIError::Model`] on tensor construction failure.
878    pub fn prepare_hook_injection(
879        &self,
880        features: &[(usize, f32)],
881        position: usize,
882        seq_len: usize,
883        device: &Device,
884    ) -> Result<HookSpec> {
885        let d_in = self.config.d_in;
886
887        // Accumulate weighted decoder vectors.
888        let mut accumulated = Tensor::zeros(d_in, DType::F32, device)?;
889        for &(feature_idx, strength) in features {
890            let dec_vec = self.decoder_vector(feature_idx)?;
891            let dec_vec = dec_vec.to_device(device)?;
892            let scaled = (&dec_vec * f64::from(strength))?;
893            accumulated = (&accumulated + &scaled)?;
894        }
895
896        // Build [1, seq_len, d_in] with vector at `position`.
897        let injection = Tensor::zeros((1, seq_len, d_in), DType::F32, device)?;
898        let scaled_3d = accumulated.unsqueeze(0)?.unsqueeze(0)?; // [1, 1, d_in]
899
900        let before = if position > 0 {
901            Some(injection.narrow(1, 0, position)?)
902        } else {
903            None
904        };
905        let after = if position + 1 < seq_len {
906            Some(injection.narrow(1, position + 1, seq_len - position - 1)?)
907        } else {
908            None
909        };
910
911        let mut parts: Vec<Tensor> = Vec::with_capacity(3);
912        if let Some(b) = before {
913            parts.push(b);
914        }
915        parts.push(scaled_3d);
916        if let Some(a) = after {
917            parts.push(a);
918        }
919
920        let injection = Tensor::cat(&parts, 1)?;
921
922        let mut hooks = HookSpec::new();
923        hooks.intervene(self.config.hook_point.clone(), Intervention::Add(injection));
924        Ok(hooks)
925    }
926}
927
928// ---------------------------------------------------------------------------
929// TopK activation
930// ---------------------------------------------------------------------------
931
932/// Apply top-k activation: keep only the k largest values, zero the rest.
933///
934/// # Shapes
935/// - `pre_acts`: `[..., d_sae]` — pre-activation values
936/// - returns: same shape, with all but top-k values zeroed per last-dim slice
937///
938/// # Strategy
939/// - [`TopKStrategy::Auto`]: CPU → direct iteration, GPU → sort-based
940/// - [`TopKStrategy::Cpu`]: force CPU path
941/// - [`TopKStrategy::Gpu`]: force GPU sort-based path
942fn topk_activation(pre_acts: &Tensor, k: usize, strategy: &TopKStrategy) -> Result<Tensor> {
943    let use_cpu = match strategy {
944        TopKStrategy::Cpu => true,
945        TopKStrategy::Gpu => false,
946        TopKStrategy::Auto => matches!(pre_acts.device(), Device::Cpu),
947    };
948
949    if use_cpu {
950        topk_cpu(pre_acts, k)
951    } else {
952        topk_gpu(pre_acts, k)
953    }
954}
955
956/// `TopK` via CPU-side partial sort.
957fn topk_cpu(pre_acts: &Tensor, k: usize) -> Result<Tensor> {
958    let device = pre_acts.device().clone();
959    let shape = pre_acts.dims().to_vec();
960    let d_sae = *shape
961        .last()
962        .ok_or_else(|| MIError::Config("cannot apply TopK to empty tensor".into()))?;
963
964    // Flatten to 2D: [n, d_sae]
965    let n: usize = shape.iter().take(shape.len() - 1).product();
966    let flat = pre_acts.reshape((n, d_sae))?.to_dtype(DType::F32)?;
967    let flat_cpu = flat.to_device(&Device::Cpu)?;
968
969    let mut result_data: Vec<f32> = Vec::with_capacity(n * d_sae);
970
971    for row_idx in 0..n {
972        let row = flat_cpu.get(row_idx)?;
973        let mut row_vec: Vec<f32> = row.to_vec1()?;
974
975        // Find the k-th largest value via partial sort.
976        let k_clamped = k.min(d_sae);
977        if k_clamped > 0 && k_clamped < d_sae {
978            // Partial sort: put the k largest elements at the front.
979            let mut indices: Vec<usize> = (0..d_sae).collect();
980            #[allow(clippy::indexing_slicing)]
981            // CONTIGUOUS: indices and row_vec are both exactly d_sae elements
982            indices.select_nth_unstable_by(k_clamped - 1, |&a, &b| {
983                let va = row_vec.get(b).copied().unwrap_or(f32::NEG_INFINITY);
984                let vb = row_vec.get(a).copied().unwrap_or(f32::NEG_INFINITY);
985                va.partial_cmp(&vb).unwrap_or(std::cmp::Ordering::Equal)
986            });
987            let threshold_idx = indices.get(k_clamped - 1).copied().unwrap_or(0);
988            let threshold = row_vec.get(threshold_idx).copied().unwrap_or(0.0);
989
990            // Zero values below threshold.
991            for v in &mut row_vec {
992                if *v < threshold {
993                    *v = 0.0;
994                }
995            }
996
997            // If there are ties at threshold, we might keep more than k.
998            // Count how many are >= threshold and zero extras from the end.
999            let active: usize = row_vec.iter().filter(|&&v| v >= threshold).count();
1000            if active > k_clamped {
1001                let mut excess = active - k_clamped;
1002                for v in row_vec.iter_mut().rev() {
1003                    if excess == 0 {
1004                        break;
1005                    }
1006                    if (*v - threshold).abs() < f32::EPSILON {
1007                        *v = 0.0;
1008                        excess -= 1;
1009                    }
1010                }
1011            }
1012        } else if k_clamped == 0 {
1013            row_vec.fill(0.0);
1014        }
1015        // k_clamped >= d_sae: keep all values
1016
1017        result_data.extend_from_slice(&row_vec);
1018    }
1019
1020    let result = Tensor::from_vec(result_data, (n, d_sae), &device)?;
1021    result.reshape(shape.as_slice()).map_err(Into::into)
1022}
1023
1024/// `TopK` via GPU sort-based masking.
1025fn topk_gpu(pre_acts: &Tensor, k: usize) -> Result<Tensor> {
1026    let shape = pre_acts.dims().to_vec();
1027    let d_sae = *shape
1028        .last()
1029        .ok_or_else(|| MIError::Config("cannot apply TopK to empty tensor".into()))?;
1030
1031    let k_clamped = k.min(d_sae);
1032    if k_clamped == 0 {
1033        return Ok(pre_acts.zeros_like()?);
1034    }
1035    if k_clamped >= d_sae {
1036        return Ok(pre_acts.clone());
1037    }
1038
1039    // Flatten to 2D for sort_last_dim.
1040    let n: usize = shape.iter().take(shape.len() - 1).product();
1041    let flat = pre_acts.reshape((n, d_sae))?.to_dtype(DType::F32)?;
1042
1043    // Sort descending along last dim.
1044    let (sorted_vals, _sorted_indices) = flat.sort_last_dim(false)?;
1045
1046    // Get the k-th largest value per row: [n, 1]
1047    let kth_vals = sorted_vals.narrow(1, k_clamped - 1, 1)?;
1048
1049    // Mask: keep values >= kth value.
1050    let mask = flat.ge(&kth_vals)?;
1051    let mask_f32 = mask.to_dtype(DType::F32)?;
1052
1053    let result = (&flat * &mask_f32)?;
1054    result.reshape(shape.as_slice()).map_err(Into::into)
1055}
1056
1057// ---------------------------------------------------------------------------
1058// Helper functions
1059// ---------------------------------------------------------------------------
1060
1061/// Broadcast a 1D bias `[dim]` to match an arbitrary target shape.
1062///
1063/// For example, bias `[d_sae]` with target shape `[batch, seq, d_sae]`
1064/// is reshaped to `[1, 1, d_sae]` then broadcast to `[batch, seq, d_sae]`.
1065///
1066/// # Shapes
1067/// - `bias`: `[dim]` — 1D bias vector
1068/// - `target_shape`: the shape to broadcast into (last dim must match)
1069/// - returns: bias with same shape as `target_shape`
1070fn broadcast_bias(bias: &Tensor, target_shape: &[usize]) -> Result<Tensor> {
1071    let ndim = target_shape.len();
1072    if ndim <= 1 {
1073        return Ok(bias.clone());
1074    }
1075    // Reshape [dim] → [1, 1, ..., dim] with (ndim - 1) leading 1s.
1076    let mut shape = vec![1_usize; ndim];
1077    let last_dim = *target_shape
1078        .last()
1079        .ok_or_else(|| MIError::Config("cannot broadcast bias to empty shape".into()))?;
1080    if let Some(slot) = shape.last_mut() {
1081        *slot = last_dim;
1082    }
1083    let reshaped = bias.reshape(shape.as_slice())?;
1084    Ok(reshaped.broadcast_as(target_shape)?)
1085}
1086
1087/// Convert a safetensors `TensorView` to a candle `Tensor`.
1088///
1089/// # Shapes
1090/// - Preserves the original tensor shape from safetensors.
1091///
1092/// # Errors
1093///
1094/// Returns [`MIError::Config`] if the tensor dtype is not supported (BF16, F16, F32).
1095/// Returns [`MIError::Model`] on tensor construction failure.
1096fn tensor_from_view(view: &safetensors::tensor::TensorView<'_>, device: &Device) -> Result<Tensor> {
1097    let shape: Vec<usize> = view.shape().to_vec();
1098    #[allow(clippy::wildcard_enum_match_arm)]
1099    // EXHAUSTIVE: safetensors exposes many dtypes; SAEs only use float types
1100    let dtype = match view.dtype() {
1101        safetensors::Dtype::BF16 => DType::BF16,
1102        safetensors::Dtype::F16 => DType::F16,
1103        safetensors::Dtype::F32 => DType::F32,
1104        other => {
1105            return Err(MIError::Config(format!(
1106                "unsupported SAE tensor dtype: {other:?}"
1107            )));
1108        }
1109    };
1110    let tensor = Tensor::from_raw_buffer(view.data(), dtype, &shape, device)?;
1111    Ok(tensor)
1112}
1113
1114/// Load a named tensor from safetensors.
1115fn load_tensor(st: &SafeTensors<'_>, name: &str, device: &Device) -> Result<Tensor> {
1116    let view = st
1117        .tensor(name)
1118        .map_err(|e| MIError::Config(format!("tensor '{name}' not found: {e}")))?;
1119    tensor_from_view(&view, device)
1120}
1121
1122/// Validate that a tensor has the expected shape.
1123fn validate_shape(tensor: &Tensor, expected: &[usize], name: &str) -> Result<()> {
1124    if tensor.dims() != expected {
1125        return Err(MIError::Config(format!(
1126            "SAE tensor '{name}' shape mismatch: expected {expected:?}, got {:?}",
1127            tensor.dims()
1128        )));
1129    }
1130    Ok(())
1131}
1132
1133// ---------------------------------------------------------------------------
1134// Tests
1135// ---------------------------------------------------------------------------
1136
#[cfg(test)]
mod tests {
    use super::*;

    /// Tiny `ReLU` SAE on `blocks.0.hook_resid_post` with zero biases and the
    /// given encoder / decoder weights.
    fn tiny_sae(d_in: usize, d_sae: usize, w_enc: Tensor, w_dec: Tensor) -> SparseAutoencoder {
        let device = Device::Cpu;
        SparseAutoencoder {
            config: SaeConfig {
                d_in,
                d_sae,
                architecture: SaeArchitecture::ReLU,
                hook_name: "blocks.0.hook_resid_post".into(),
                hook_point: HookPoint::ResidPost(0),
                apply_b_dec_to_input: false,
                normalize_activations: NormalizeActivations::None,
            },
            w_enc,
            w_dec,
            b_enc: Tensor::zeros(d_sae, DType::F32, &device).unwrap(),
            b_dec: Tensor::zeros(d_in, DType::F32, &device).unwrap(),
            threshold: None,
        }
    }

    /// Three features already ordered by descending activation.
    fn three_features() -> Vec<(SaeFeatureId, f32)> {
        vec![
            (SaeFeatureId { index: 5 }, 3.0),
            (SaeFeatureId { index: 2 }, 2.0),
            (SaeFeatureId { index: 8 }, 1.0),
        ]
    }

    // --- Identifiers & architecture resolution ---

    #[test]
    fn sae_feature_id_display() {
        let rendered = SaeFeatureId { index: 42 }.to_string();
        assert_eq!(rendered, "SAE:42");
    }

    #[test]
    fn resolve_architecture_relu_default() {
        let resolved = resolve_architecture(None, None, None).unwrap();
        assert_eq!(resolved, SaeArchitecture::ReLU);
    }

    #[test]
    fn resolve_architecture_relu_explicit() {
        let resolved = resolve_architecture(Some("standard"), Some("relu"), None).unwrap();
        assert_eq!(resolved, SaeArchitecture::ReLU);
    }

    #[test]
    fn resolve_architecture_jumprelu() {
        let resolved = resolve_architecture(Some("jumprelu"), None, None).unwrap();
        assert_eq!(resolved, SaeArchitecture::JumpReLU);
    }

    #[test]
    fn resolve_architecture_jumprelu_from_activation() {
        let resolved = resolve_architecture(None, Some("jumprelu"), None).unwrap();
        assert_eq!(resolved, SaeArchitecture::JumpReLU);
    }

    #[test]
    fn resolve_architecture_topk() {
        let kwargs = serde_json::json!({"k": 32});
        let resolved = resolve_architecture(Some("topk"), None, Some(&kwargs)).unwrap();
        assert_eq!(resolved, SaeArchitecture::TopK { k: 32 });
    }

    #[test]
    fn resolve_architecture_topk_from_activation() {
        let kwargs = serde_json::json!({"k": 64});
        let resolved = resolve_architecture(None, Some("topk"), Some(&kwargs)).unwrap();
        assert_eq!(resolved, SaeArchitecture::TopK { k: 64 });
    }

    #[test]
    fn resolve_architecture_topk_missing_k() {
        assert!(resolve_architecture(Some("topk"), None, None).is_err());
    }

    #[test]
    fn resolve_architecture_unknown() {
        assert!(resolve_architecture(Some("gated"), None, None).is_err());
    }

    // --- cfg.json parsing ---

    #[test]
    fn parse_config_minimal() {
        let json = r#"{
            "d_in": 2304,
            "d_sae": 16384,
            "hook_name": "blocks.5.hook_resid_post"
        }"#;
        let raw: RawSaeConfig = serde_json::from_str(json).unwrap();
        let parsed = parse_sae_config(raw).unwrap();
        assert_eq!(parsed.d_in, 2304);
        assert_eq!(parsed.d_sae, 16384);
        assert_eq!(parsed.architecture, SaeArchitecture::ReLU);
        assert_eq!(parsed.hook_point, HookPoint::ResidPost(5));
        assert!(!parsed.apply_b_dec_to_input);
    }

    #[test]
    fn parse_config_jumprelu() {
        let json = r#"{
            "d_in": 2304,
            "d_sae": 16384,
            "architecture": "jumprelu",
            "hook_name": "blocks.20.hook_resid_post",
            "apply_b_dec_to_input": true,
            "normalize_activations": "expected_average_only_in"
        }"#;
        let raw: RawSaeConfig = serde_json::from_str(json).unwrap();
        let parsed = parse_sae_config(raw).unwrap();
        assert_eq!(parsed.architecture, SaeArchitecture::JumpReLU);
        assert_eq!(parsed.hook_point, HookPoint::ResidPost(20));
        assert!(parsed.apply_b_dec_to_input);
        assert_eq!(
            parsed.normalize_activations,
            NormalizeActivations::ExpectedAverageOnlyIn
        );
    }

    #[test]
    fn parse_config_topk() {
        let json = r#"{
            "d_in": 2304,
            "d_sae": 65536,
            "activation_fn_str": "topk",
            "activation_fn_kwargs": {"k": 32},
            "hook_name": "blocks.10.hook_resid_post"
        }"#;
        let raw: RawSaeConfig = serde_json::from_str(json).unwrap();
        let parsed = parse_sae_config(raw).unwrap();
        assert_eq!(parsed.architecture, SaeArchitecture::TopK { k: 32 });
    }

    // --- TopK (CPU path) ---

    #[test]
    fn topk_cpu_basic() {
        let input = Tensor::new(&[[5.0_f32, 3.0, 1.0, 4.0, 2.0]], &Device::Cpu).unwrap();
        let output = topk_cpu(&input, 2).unwrap();
        let flat: Vec<f32> = output.flatten_all().unwrap().to_vec1().unwrap();
        assert_eq!(flat, vec![5.0, 0.0, 0.0, 4.0, 0.0]);
    }

    #[test]
    fn topk_cpu_all_kept() {
        let input = Tensor::new(&[[1.0_f32, 2.0, 3.0]], &Device::Cpu).unwrap();
        let output = topk_cpu(&input, 5).unwrap();
        let flat: Vec<f32> = output.flatten_all().unwrap().to_vec1().unwrap();
        assert_eq!(flat, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn topk_cpu_none_kept() {
        let input = Tensor::new(&[[1.0_f32, 2.0, 3.0]], &Device::Cpu).unwrap();
        let output = topk_cpu(&input, 0).unwrap();
        let flat: Vec<f32> = output.flatten_all().unwrap().to_vec1().unwrap();
        assert_eq!(flat, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn topk_cpu_batched() {
        let input = Tensor::new(
            &[[5.0_f32, 3.0, 1.0, 4.0, 2.0], [1.0, 2.0, 3.0, 4.0, 5.0]],
            &Device::Cpu,
        )
        .unwrap();
        let rows: Vec<Vec<f32>> = topk_cpu(&input, 3).unwrap().to_vec2().unwrap();
        assert_eq!(
            rows,
            vec![
                vec![5.0, 3.0, 0.0, 4.0, 0.0],
                vec![0.0, 0.0, 3.0, 4.0, 5.0],
            ]
        );
    }

    // --- Sparse activation containers ---

    #[test]
    fn sparse_activations_sae() {
        let sparse = SparseActivations {
            features: three_features(),
        };
        assert_eq!(sparse.len(), 3);
        assert!(!sparse.is_empty());
    }

    #[test]
    fn sparse_activations_truncate_sae() {
        let mut sparse = SparseActivations {
            features: three_features(),
        };
        sparse.truncate(2);
        assert_eq!(sparse.len(), 2);
        let kept: Vec<usize> = sparse.features.iter().map(|(id, _)| id.index).collect();
        assert_eq!(kept, vec![5, 2]);
    }

    // --- Encode / decode / steering on tiny SAEs ---

    #[test]
    fn encode_decode_roundtrip_shapes() {
        let (d_in, d_sae) = (4, 8);
        let device = Device::Cpu;
        let sae = tiny_sae(
            d_in,
            d_sae,
            Tensor::randn(0.0_f32, 1.0, (d_in, d_sae), &device).unwrap(),
            Tensor::randn(0.0_f32, 1.0, (d_sae, d_in), &device).unwrap(),
        );

        // A single activation vector, lifted to a batch of one before encoding.
        let single = Tensor::randn(0.0_f32, 1.0, (d_in,), &device)
            .unwrap()
            .unsqueeze(0)
            .unwrap();
        assert_eq!(sae.encode(&single).unwrap().dims(), &[1, d_sae]);

        // [batch, d_in] → [batch, d_sae] → [batch, d_in]
        let batch = Tensor::randn(0.0_f32, 1.0, (3, d_in), &device).unwrap();
        let batch_features = sae.encode(&batch).unwrap();
        assert_eq!(batch_features.dims(), &[3, d_sae]);
        assert_eq!(sae.decode(&batch_features).unwrap().dims(), &[3, d_in]);

        // [batch, seq, d_in] → [batch, seq, d_sae] → [batch, seq, d_in]
        let sequence = Tensor::randn(0.0_f32, 1.0, (2, 5, d_in), &device).unwrap();
        let seq_features = sae.encode(&sequence).unwrap();
        assert_eq!(seq_features.dims(), &[2, 5, d_sae]);
        assert_eq!(sae.decode(&seq_features).unwrap().dims(), &[2, 5, d_in]);

        // Full round trip and its error on the 2-D batch.
        assert_eq!(sae.reconstruct(&batch).unwrap().dims(), &[3, d_in]);
        assert!(sae.reconstruction_error(&batch).unwrap() >= 0.0);
    }

    #[test]
    fn encode_sparse_basic() {
        let (d_in, d_sae) = (4, 8);
        let device = Device::Cpu;

        // Identity-like encoder: input dim i feeds feature i (row i, column i).
        let w_enc_data: Vec<f32> = (0..d_in * d_sae)
            .map(|flat| if flat / d_sae == flat % d_sae { 1.0 } else { 0.0 })
            .collect();
        let sae = tiny_sae(
            d_in,
            d_sae,
            Tensor::from_vec(w_enc_data, (d_in, d_sae), &device).unwrap(),
            Tensor::randn(0.0_f32, 1.0, (d_sae, d_in), &device).unwrap(),
        );

        let x = Tensor::new(&[2.0_f32, -1.0, 3.0, 0.5], &device).unwrap();
        let sparse = sae.encode_sparse(&x).unwrap();

        // ReLU drops the -1.0; the survivors come back in descending order
        // (3.0 at index 2, 2.0 at index 0, 0.5 at index 3).
        assert_eq!(sparse.len(), 3);
        let order: Vec<usize> = sparse.features.iter().map(|(id, _)| id.index).collect();
        assert_eq!(order, vec![2, 0, 3]);
    }

    #[test]
    fn decoder_vector_basic() {
        let (d_in, d_sae) = (4, 8);
        let device = Device::Cpu;
        let sae = tiny_sae(
            d_in,
            d_sae,
            Tensor::zeros((d_in, d_sae), DType::F32, &device).unwrap(),
            Tensor::randn(0.0_f32, 1.0, (d_sae, d_in), &device).unwrap(),
        );

        assert_eq!(sae.decoder_vector(0).unwrap().dims(), &[d_in]);
        // One past the last feature must be rejected.
        assert!(sae.decoder_vector(d_sae).is_err());
    }

    #[test]
    fn prepare_injection_basic() {
        let (d_in, d_sae) = (4, 8);
        let device = Device::Cpu;
        let sae = tiny_sae(
            d_in,
            d_sae,
            Tensor::zeros((d_in, d_sae), DType::F32, &device).unwrap(),
            Tensor::ones((d_sae, d_in), DType::F32, &device).unwrap(),
        );

        let steering = vec![(0_usize, 1.0_f32), (1, 0.5)];
        let hooks = sae
            .prepare_hook_injection(&steering, 2, 5, &device)
            .unwrap();
        assert!(!hooks.is_empty());
    }
}