candle_mi/clt/mod.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! Cross-Layer Transcoder (CLT) support.
4//!
5//! Loads pre-trained CLT weights from `HuggingFace` (circuit-tracer format),
6//! encodes residual stream activations into sparse feature activations,
7//! injects decoder vectors into the residual stream for steering, and
8//! scores features by decoder projection for attribution graph construction.
9//!
10//! Memory-efficient: uses stream-and-free for encoders (~75 MB/layer on GPU)
11//! and a micro-cache for steering vectors (~450 KB for 50 features).
12//! Decoder scoring operates entirely on CPU (one file at a time, up to ~2 GB).
13//!
14//! # CLT Architecture
15//!
16//! A cross-layer transcoder at layer `l` implements:
17//! ```text
18//! Encode:  features = ReLU(W_enc[l] @ residual_mid[l] + b_enc[l])
19//! Decode:  For each downstream layer l' >= l:
20//!            mlp_out_hat[l'] += W_dec[l, l'] @ features + b_dec[l']
//! Inject:  residual[pos] += strength × W_dec_{l}[feature_idx, target_layer - l, :]
22//! ```
23//!
24//! # Weight File Layout (circuit-tracer format)
25//!
26//! Each encoder file `W_enc_{l}.safetensors` contains:
27//! - `W_enc_{l}`: shape `[n_features, d_model]` (BF16) — encoder weight matrix
28//! - `b_enc_{l}`: shape `[n_features]` (BF16) — encoder bias
29//! - `b_dec_{l}`: shape `[d_model]` (BF16) — decoder bias for target layer l
30//!
31//! Each decoder file `W_dec_{l}.safetensors` contains:
32//! - `W_dec_{l}`: shape `[n_features, n_target_layers, d_model]` (BF16)
//!   where `n_target_layers = n_layers - l` (layer `l` writes to layers `l..n_layers`)
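//!
//! # Worked Example (CLT-426K on Gemma 2 2B)
//!
//! For concreteness, assuming the dimensions used elsewhere in this module
//! (26 layers, `d_model = 2304`, 16384 features per layer, BF16 on disk):
//!
//! ```text
//! n_features_total = 26 × 16384                  = 425,984 features
//! W_enc_l size     = 16384 × 2304 × 2 bytes      ≈ 75 MB per encoder file
//! W_dec_0 size     = 16384 × 26 × 2304 × 2 bytes ≈ 1.96 GB (largest decoder file)
//! W_dec_25 size    = 16384 × 1 × 2304 × 2 bytes  ≈ 75 MB (smallest decoder file)
//! ```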
34
35use std::collections::HashMap;
36use std::path::PathBuf;
37
38use candle_core::{DType, Device, IndexOp, Tensor};
39use safetensors::tensor::SafeTensors;
40use tracing::info;
41
42use crate::error::{MIError, Result};
43
44// ---------------------------------------------------------------------------
45// Public types
46// ---------------------------------------------------------------------------
47
48/// Identifies a single CLT feature by its source layer and index within that layer.
49#[derive(
50    Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize,
51)]
52pub struct CltFeatureId {
53    /// Source layer where this feature's encoder lives (`0..n_layers`).
54    pub layer: usize,
55    /// Feature index within the layer (`0..n_features_per_layer`).
56    pub index: usize,
57}
58
59impl std::fmt::Display for CltFeatureId {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        write!(f, "L{}:{}", self.layer, self.index)
62    }
63}
64
65use crate::sparse::{FeatureId, SparseActivations};
66
67impl FeatureId for CltFeatureId {}
68
69/// A single edge in a CLT attribution graph.
70///
71/// Represents a feature's decoder projection score onto a target direction
72/// at a specific downstream layer. Positive scores indicate alignment,
73/// negative scores indicate opposition.
74#[derive(Debug, Clone)]
75pub struct AttributionEdge {
76    /// The CLT feature contributing this edge.
77    pub feature: CltFeatureId,
78    /// Decoder projection score (dot product or cosine similarity).
79    pub score: f32,
80}
81
82/// Attribution graph for CLT circuit analysis.
83///
84/// Represents a set of CLT features scored by how strongly their decoder
85/// vectors project along a target direction at a specific layer. Built by
86/// [`CrossLayerTranscoder::build_attribution_graph()`] or
87/// [`CrossLayerTranscoder::build_attribution_graph_batch()`].
88///
89/// Edges are always sorted by score in descending order.
90///
91/// # Pruning
92///
93/// - [`top_k()`](Self::top_k): keep only the k highest-scoring features
94/// - [`threshold()`](Self::threshold): keep features with |score| above a minimum
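///
/// # Example
///
/// A minimal pruning sketch (illustrative only; assumes a graph has already been
/// built, e.g. via `build_attribution_graph()`):
///
/// ```no_run
/// # use candle_mi::clt::AttributionGraph;
/// # let graph: AttributionGraph = unimplemented!();
/// let strongest = graph.top_k(20);       // keep the 20 highest-scoring edges
/// let confident = graph.threshold(0.1);  // keep edges with |score| >= 0.1
/// for edge in confident.edges() {
///     println!("{} -> {:.3}", edge.feature, edge.score);
/// }
/// # let _ = strongest;
/// ```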
95#[derive(Debug, Clone)]
96pub struct AttributionGraph {
97    /// Target layer these scores were computed for.
98    target_layer: usize,
99    /// Edges sorted by score descending.
100    edges: Vec<AttributionEdge>,
101}
102
103impl AttributionGraph {
104    /// Target layer this graph was scored against.
105    #[must_use]
106    pub const fn target_layer(&self) -> usize {
107        self.target_layer
108    }
109
110    /// All edges, sorted by score descending.
111    #[must_use]
112    pub fn edges(&self) -> &[AttributionEdge] {
113        &self.edges
114    }
115
116    /// Number of edges in the graph.
117    #[must_use]
118    pub const fn len(&self) -> usize {
119        self.edges.len()
120    }
121
122    /// Whether the graph has no edges.
123    #[must_use]
124    pub const fn is_empty(&self) -> bool {
125        self.edges.is_empty()
126    }
127
128    /// Return a new graph with only the top-k highest-scoring edges.
129    #[must_use]
130    pub fn top_k(&self, k: usize) -> Self {
131        Self {
132            target_layer: self.target_layer,
133            edges: self.edges.iter().take(k).cloned().collect(),
134        }
135    }
136
137    /// Return a new graph keeping only edges whose absolute score meets
138    /// or exceeds `min_score`.
139    #[must_use]
140    pub fn threshold(&self, min_score: f32) -> Self {
141        Self {
142            target_layer: self.target_layer,
143            edges: self
144                .edges
145                .iter()
146                .filter(|e| e.score.abs() >= min_score)
147                .cloned()
148                .collect(),
149        }
150    }
151
152    /// Extract the feature IDs from all edges in score order.
153    #[must_use]
154    pub fn features(&self) -> Vec<CltFeatureId> {
155        self.edges.iter().map(|e| e.feature).collect()
156    }
157
158    /// Consume the graph and return its edges.
159    #[must_use]
160    pub fn into_edges(self) -> Vec<AttributionEdge> {
161        self.edges
162    }
163}
164
165/// CLT configuration auto-detected from tensor shapes.
166#[derive(Debug, Clone)]
167pub struct CltConfig {
168    /// Number of layers in the base model (26 for Gemma 2 2B).
169    pub n_layers: usize,
170    /// Hidden dimension of the base model (2304 for Gemma 2 2B).
171    pub d_model: usize,
172    /// Number of features per encoder layer (16384 for CLT-426K).
173    pub n_features_per_layer: usize,
174    /// Total feature count across all layers.
175    pub n_features_total: usize,
176    /// Base model name from config.yaml.
177    pub model_name: String,
178}
179
180// ---------------------------------------------------------------------------
181// Internal types
182// ---------------------------------------------------------------------------
183
184/// Currently loaded encoder weights on GPU.
185struct LoadedEncoder {
186    /// Layer index this encoder corresponds to.
187    layer: usize,
188    /// Encoder weight matrix.
189    ///
190    /// # Shapes
191    /// - `w_enc`: `[n_features, d_model]`
192    w_enc: Tensor,
193    /// Encoder bias vector.
194    ///
195    /// # Shapes
196    /// - `b_enc`: `[n_features]`
197    b_enc: Tensor,
198}
199
200// ---------------------------------------------------------------------------
201// CrossLayerTranscoder
202// ---------------------------------------------------------------------------
203
204/// Cross-Layer Transcoder.
205///
206/// Loads CLT encoder/decoder weights on-demand from `HuggingFace` safetensors,
207/// with memory-efficient streaming (only one encoder on GPU at a time)
208/// and a micro-cache for steering vectors.
209///
210/// Downloads are lazy: [`open()`](Self::open) only fetches config and the first
211/// encoder for dimension detection. Subsequent files are downloaded as needed by
212/// [`load_encoder()`](Self::load_encoder), [`decoder_vector()`](Self::decoder_vector),
213/// and [`cache_steering_vectors()`](Self::cache_steering_vectors).
214///
215/// # Example
216///
217/// ```no_run
218/// # fn main() -> candle_mi::Result<()> {
219/// use candle_mi::clt::CrossLayerTranscoder;
220/// use candle_core::Device;
221///
222/// let mut clt = CrossLayerTranscoder::open("mntss/clt-gemma-2-2b-426k")?;
223/// println!("CLT: {} layers, d_model={}", clt.config().n_layers, clt.config().d_model);
224///
225/// // Load encoder for layer 10
226/// let device = Device::Cpu;
227/// clt.load_encoder(10, &device)?;
228/// # Ok(())
229/// # }
230/// ```
231pub struct CrossLayerTranscoder {
232    /// `HuggingFace` repository ID for on-demand downloads.
233    repo_id: String,
234    /// Fetch configuration for `hf-fetch-model` downloads.
235    fetch_config: hf_fetch_model::FetchConfig,
236    /// Local paths to already-downloaded encoder files (None = not yet downloaded).
237    encoder_paths: Vec<Option<PathBuf>>,
238    /// Local paths to already-downloaded decoder files (None = not yet downloaded).
239    decoder_paths: Vec<Option<PathBuf>>,
240    /// Auto-detected configuration.
241    config: CltConfig,
242    /// Currently loaded encoder (stream-and-free: only one at a time).
243    loaded_encoder: Option<LoadedEncoder>,
244    /// Micro-cache: pre-extracted steering vectors pinned on device.
245    /// Key: (`feature_id`, `target_layer`), Value: decoder vector `[d_model]` on device.
246    steering_cache: HashMap<(CltFeatureId, usize), Tensor>,
247}
248
249impl CrossLayerTranscoder {
250    /// Open a CLT from `HuggingFace` and detect its configuration.
251    ///
252    /// Only downloads `config.yaml` and `W_enc_0.safetensors` (~75 MB).
253    /// All other encoder/decoder files are downloaded lazily on first use.
254    ///
255    /// # Arguments
256    /// * `clt_repo` — `HuggingFace` repository ID (e.g., `"mntss/clt-gemma-2-2b-426k"`)
257    ///
258    /// # Errors
259    ///
260    /// Returns [`MIError::Download`] if the repository is inaccessible or files
261    /// cannot be fetched. Returns [`MIError::Config`] if the weight format is
262    /// unexpected.
263    pub fn open(clt_repo: &str) -> Result<Self> {
264        let fetch_config = hf_fetch_model::FetchConfig::builder()
265            .on_progress(|event| {
266                tracing::info!(
267                    filename = %event.filename,
268                    percent = event.percent,
269                    bytes_downloaded = event.bytes_downloaded,
270                    bytes_total = event.bytes_total,
271                    "CLT download progress",
272                );
273            })
274            .build()
275            .map_err(|e| MIError::Download(format!("failed to build fetch config: {e}")))?;
276
277        // Detect n_layers by listing repo files (no downloads needed).
278        let rt = tokio::runtime::Runtime::new()
279            .map_err(|e| MIError::Download(format!("failed to create tokio runtime: {e}")))?;
280        let repo_files = rt
281            .block_on(hf_fetch_model::repo::list_repo_files_with_metadata(
282                clt_repo, None, None,
283            ))
284            .map_err(|e| MIError::Download(format!("failed to list repo files: {e}")))?;
285        let n_layers = repo_files
286            .iter()
287            .filter(|f| f.filename.starts_with("W_enc_") && f.filename.ends_with(".safetensors"))
288            .count();
289        if n_layers == 0 {
290            return Err(MIError::Config(format!(
291                "no CLT encoder files found in {clt_repo}"
292            )));
293        }
294
295        // Parse config.yaml for model_name (simple line-by-line, no serde_yaml dep).
296        let model_name = match hf_fetch_model::download_file_blocking(
297            clt_repo.to_owned(),
298            "config.yaml",
299            &fetch_config,
300        ) {
301            Ok(outcome) => {
302                let path = outcome.into_inner();
303                let text = std::fs::read_to_string(&path)?;
304                parse_yaml_value(&text, "model_name").unwrap_or_else(|| "unknown".to_owned())
305            }
306            Err(_) => "unknown".to_owned(),
307        };
308
309        // Download W_enc_0 for dimension detection (~75 MB).
310        let enc0_path = hf_fetch_model::download_file_blocking(
311            clt_repo.to_owned(),
312            "W_enc_0.safetensors",
313            &fetch_config,
314        )
315        .map_err(|e| MIError::Download(format!("failed to download W_enc_0: {e}")))?
316        .into_inner();
317
318        let data = std::fs::read(&enc0_path)?;
319        let tensors = SafeTensors::deserialize(&data)
320            .map_err(|e| MIError::Config(format!("failed to deserialize W_enc_0: {e}")))?;
321        let w_enc_view = tensors
322            .tensor("W_enc_0")
323            .map_err(|e| MIError::Config(format!("tensor 'W_enc_0' not found: {e}")))?;
324        let shape = w_enc_view.shape();
325        if shape.len() != 2 {
326            return Err(MIError::Config(format!(
327                "expected 2D encoder weight, got shape {shape:?}"
328            )));
329        }
330        let n_features_per_layer = *shape
331            .first()
332            .ok_or_else(|| MIError::Config("encoder weight shape is empty".into()))?;
333        let d_model = *shape.get(1).ok_or_else(|| {
334            MIError::Config("encoder weight shape has fewer than 2 dimensions".into())
335        })?;
336
337        // Initialise paths: only first encoder known, rest downloaded lazily.
338        let mut encoder_paths: Vec<Option<PathBuf>> = vec![None; n_layers];
339        if let Some(slot) = encoder_paths.first_mut() {
340            *slot = Some(enc0_path);
341        }
342        let decoder_paths: Vec<Option<PathBuf>> = vec![None; n_layers];
343
344        let config = CltConfig {
345            n_layers,
346            d_model,
347            n_features_per_layer,
348            n_features_total: n_layers * n_features_per_layer,
349            model_name,
350        };
351        info!(
352            "CLT config: {} layers, d_model={}, features_per_layer={}, total={}",
353            config.n_layers, config.d_model, config.n_features_per_layer, config.n_features_total
354        );
355
356        Ok(Self {
357            repo_id: clt_repo.to_owned(),
358            fetch_config,
359            encoder_paths,
360            decoder_paths,
361            config,
362            loaded_encoder: None,
363            steering_cache: HashMap::new(),
364        })
365    }
366
367    /// Access the auto-detected CLT configuration.
368    #[must_use]
369    pub const fn config(&self) -> &CltConfig {
370        &self.config
371    }
372
373    /// Check whether an encoder is currently loaded and for which layer.
374    #[must_use]
375    pub fn loaded_encoder_layer(&self) -> Option<usize> {
376        self.loaded_encoder.as_ref().map(|e| e.layer)
377    }
378
379    // --- Lazy download helpers ---
380
381    /// Ensure the encoder file for a given layer is downloaded. Returns the path.
382    fn ensure_encoder_path(&mut self, layer: usize) -> Result<PathBuf> {
383        if let Some(path) = self
384            .encoder_paths
385            .get(layer)
386            .and_then(std::option::Option::as_ref)
387        {
388            // BORROW: explicit .clone() — PathBuf from Vec
389            return Ok(path.clone());
390        }
391        let filename = format!("W_enc_{layer}.safetensors");
392        info!("Downloading {filename} from {}", self.repo_id);
393        let path = hf_fetch_model::download_file_blocking(
394            self.repo_id.clone(),
395            &filename,
396            &self.fetch_config,
397        )
398        .map_err(|e| MIError::Download(format!("failed to download {filename}: {e}")))?
399        .into_inner();
400        if let Some(slot) = self.encoder_paths.get_mut(layer) {
401            // BORROW: explicit .clone() — store PathBuf in cache
402            *slot = Some(path.clone());
403        }
404        Ok(path)
405    }
406
407    /// Ensure the decoder file for a given layer is downloaded. Returns the path.
408    fn ensure_decoder_path(&mut self, layer: usize) -> Result<PathBuf> {
409        if let Some(path) = self
410            .decoder_paths
411            .get(layer)
412            .and_then(std::option::Option::as_ref)
413        {
414            // BORROW: explicit .clone() — PathBuf from Vec
415            return Ok(path.clone());
416        }
417        let filename = format!("W_dec_{layer}.safetensors");
418        info!("Downloading {filename} from {}", self.repo_id);
419        let path = hf_fetch_model::download_file_blocking(
420            self.repo_id.clone(),
421            &filename,
422            &self.fetch_config,
423        )
424        .map_err(|e| MIError::Download(format!("failed to download {filename}: {e}")))?
425        .into_inner();
426        if let Some(slot) = self.decoder_paths.get_mut(layer) {
427            // BORROW: explicit .clone() — store PathBuf in cache
428            *slot = Some(path.clone());
429        }
430        Ok(path)
431    }
432
433    // --- Encoder loading (stream-and-free) ---
434
435    /// Load a single encoder's weights to the specified device.
436    ///
437    /// Frees any previously loaded encoder first (stream-and-free pattern).
438    /// Peak GPU overhead: ~75 MB for CLT-426K, ~450 MB for CLT-2.5M.
439    ///
440    /// # Arguments
441    /// * `layer` — Layer index (`0..n_layers`)
442    /// * `device` — Target device (CPU or CUDA)
443    ///
444    /// # Errors
445    ///
446    /// Returns [`MIError::Config`] if the layer is out of range.
447    /// Returns [`MIError::Download`] if the encoder file cannot be fetched.
448    /// Returns [`MIError::Model`] on tensor deserialization failure.
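    ///
    /// # Example
    ///
    /// A small sketch of the stream-and-free pattern (illustrative; the layer
    /// indices are placeholders):
    ///
    /// ```no_run
    /// # fn main() -> candle_mi::Result<()> {
    /// use candle_core::Device;
    /// use candle_mi::clt::CrossLayerTranscoder;
    ///
    /// let mut clt = CrossLayerTranscoder::open("mntss/clt-gemma-2-2b-426k")?;
    /// let device = Device::Cpu;
    ///
    /// clt.load_encoder(5, &device)?;
    /// assert_eq!(clt.loaded_encoder_layer(), Some(5));
    ///
    /// // Loading another layer drops the previous encoder first.
    /// clt.load_encoder(10, &device)?;
    /// assert_eq!(clt.loaded_encoder_layer(), Some(10));
    /// # Ok(())
    /// # }
    /// ```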
449    pub fn load_encoder(&mut self, layer: usize, device: &Device) -> Result<()> {
450        if layer >= self.config.n_layers {
451            return Err(MIError::Config(format!(
452                "layer {layer} out of range (CLT has {} layers)",
453                self.config.n_layers
454            )));
455        }
456
457        // Skip if already loaded.
458        if let Some(ref enc) = self.loaded_encoder {
459            if enc.layer == layer {
460                return Ok(());
461            }
462        }
463
464        // Drop previous encoder (frees GPU memory).
465        self.loaded_encoder = None;
466
467        info!("Loading CLT encoder for layer {layer}");
468
469        let enc_path = self.ensure_encoder_path(layer)?;
470        let data = std::fs::read(&enc_path)?;
471        let st = SafeTensors::deserialize(&data).map_err(|e| {
472            MIError::Config(format!("failed to deserialize encoder layer {layer}: {e}"))
473        })?;
474
475        let w_enc_name = format!("W_enc_{layer}");
476        let b_enc_name = format!("b_enc_{layer}");
477
478        let w_enc = tensor_from_view(
479            &st.tensor(&w_enc_name)
480                .map_err(|e| MIError::Config(format!("tensor '{w_enc_name}' not found: {e}")))?,
481            device,
482        )?;
483        let b_enc = tensor_from_view(
484            &st.tensor(&b_enc_name)
485                .map_err(|e| MIError::Config(format!("tensor '{b_enc_name}' not found: {e}")))?,
486            device,
487        )?;
488
489        self.loaded_encoder = Some(LoadedEncoder {
490            layer,
491            w_enc,
492            b_enc,
493        });
494
495        Ok(())
496    }
497
498    // --- Encoding ---
499
500    /// Encode a residual stream activation into sparse CLT features.
501    ///
502    /// The residual should be the "residual mid" activation at the given layer
503    /// (after attention, before MLP).
504    ///
505    /// Returns all features that pass the `ReLU` threshold, sorted by
506    /// activation magnitude in descending order.
507    ///
508    /// # Shapes
509    /// - `residual`: `[d_model]` — residual stream activation at one position
510    /// - returns: [`SparseActivations<CltFeatureId>`] with `(CltFeatureId, f32)` pairs
511    ///
512    /// # Requires
513    /// [`load_encoder(layer)`](Self::load_encoder) must have been called first.
514    ///
515    /// # Errors
516    ///
517    /// Returns [`MIError::Hook`] if no encoder is loaded or the wrong layer is loaded.
518    /// Returns [`MIError::Model`] on tensor operation failure.
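    ///
    /// # Example
    ///
    /// A minimal sketch (illustrative; in practice `residual` would be a captured
    /// residual-mid activation rather than zeros):
    ///
    /// ```no_run
    /// # fn main() -> candle_mi::Result<()> {
    /// use candle_core::{DType, Device, Tensor};
    /// use candle_mi::clt::CrossLayerTranscoder;
    ///
    /// let mut clt = CrossLayerTranscoder::open("mntss/clt-gemma-2-2b-426k")?;
    /// let device = Device::Cpu;
    /// clt.load_encoder(10, &device)?;
    ///
    /// // Placeholder for a captured residual-mid activation at layer 10.
    /// let residual = Tensor::zeros((clt.config().d_model,), DType::F32, &device)?;
    ///
    /// // Sparse feature activations, sorted by magnitude (descending).
    /// let active = clt.encode(&residual, 10)?;
    /// # let _ = active;
    /// # Ok(())
    /// # }
    /// ```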
519    pub fn encode(
520        &self,
521        residual: &Tensor,
522        layer: usize,
523    ) -> Result<SparseActivations<CltFeatureId>> {
524        let enc = self.loaded_encoder.as_ref().ok_or_else(|| {
525            MIError::Hook(format!(
526                "no encoder loaded — call load_encoder({layer}) first"
527            ))
528        })?;
529        if enc.layer != layer {
530            return Err(MIError::Hook(format!(
531                "loaded encoder is for layer {}, but layer {layer} was requested",
532                enc.layer
533            )));
534        }
535
536        // Compute pre-activations in F32 for numerical stability.
537        // W_enc: [n_features, d_model], residual: [d_model]
538        // pre_acts = W_enc @ residual + b_enc → [n_features]
539        let residual_f32 = residual.flatten_all()?;
540        // PROMOTE: matmul and bias add require F32 for numerical stability
541        let residual_f32 = residual_f32.to_dtype(DType::F32)?;
542        let w_enc_f32 = enc.w_enc.to_dtype(DType::F32)?;
543        let b_enc_f32 = enc.b_enc.to_dtype(DType::F32)?;
544
545        let pre_acts = w_enc_f32.matmul(&residual_f32.unsqueeze(1)?)?.squeeze(1)?;
546        let pre_acts = (&pre_acts + &b_enc_f32)?;
547
548        // ReLU activation.
549        let acts = pre_acts.relu()?;
550
551        // Transfer to CPU for sparse extraction.
552        let acts_vec: Vec<f32> = acts.to_vec1()?;
553
554        let mut features: Vec<(CltFeatureId, f32)> = acts_vec
555            .iter()
556            .enumerate()
557            .filter(|&(_, v)| *v > 0.0)
558            .map(|(i, v)| (CltFeatureId { layer, index: i }, *v))
559            .collect();
560
561        // Sort by activation magnitude (descending).
562        features.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
563
564        Ok(SparseActivations { features })
565    }
566
567    /// Encode and return only the top-k most active features.
568    ///
569    /// # Shapes
570    /// - `residual`: `[d_model]` — residual stream activation at one position
571    /// - returns: [`SparseActivations<CltFeatureId>`] truncated to at most `k` entries
572    ///
573    /// # Requires
574    /// [`load_encoder(layer)`](Self::load_encoder) must have been called first.
575    ///
576    /// # Errors
577    ///
578    /// Same as [`encode()`](Self::encode).
579    pub fn top_k(
580        &self,
581        residual: &Tensor,
582        layer: usize,
583        k: usize,
584    ) -> Result<SparseActivations<CltFeatureId>> {
585        let mut sparse = self.encode(residual, layer)?;
586        sparse.truncate(k);
587        Ok(sparse)
588    }
589
590    // --- Decoder access ---
591
592    /// Extract a single feature's decoder vector for a target downstream layer.
593    ///
594    /// Loads from safetensors on demand. Checks the steering cache first
595    /// to avoid redundant file reads.
596    ///
597    /// # Shapes
598    /// - returns: `[d_model]` — decoder vector on `device`
599    ///
600    /// # Arguments
601    /// * `feature` — The CLT feature to extract the decoder for
602    /// * `target_layer` — The downstream layer to decode to (must be >= feature.layer)
603    /// * `device` — Device to place the resulting tensor on
604    ///
605    /// # Errors
606    ///
607    /// Returns [`MIError::Config`] if layer indices are out of range.
608    /// Returns [`MIError::Download`] if the decoder file cannot be fetched.
609    /// Returns [`MIError::Model`] on tensor operation failure.
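    ///
    /// # Example
    ///
    /// A minimal sketch (the feature index and layer indices are placeholders):
    ///
    /// ```no_run
    /// # fn main() -> candle_mi::Result<()> {
    /// use candle_core::Device;
    /// use candle_mi::clt::{CltFeatureId, CrossLayerTranscoder};
    ///
    /// let mut clt = CrossLayerTranscoder::open("mntss/clt-gemma-2-2b-426k")?;
    /// let feature = CltFeatureId { layer: 5, index: 123 };
    ///
    /// // Decoder write direction of feature L5:123 into layer 12's residual stream.
    /// let vector = clt.decoder_vector(&feature, 12, &Device::Cpu)?;
    /// assert_eq!(vector.dims(), &[clt.config().d_model]);
    /// # Ok(())
    /// # }
    /// ```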
610    pub fn decoder_vector(
611        &mut self,
612        feature: &CltFeatureId,
613        target_layer: usize,
614        device: &Device,
615    ) -> Result<Tensor> {
616        if feature.layer >= self.config.n_layers {
617            return Err(MIError::Config(format!(
618                "feature source layer {} out of range (CLT has {} layers)",
619                feature.layer, self.config.n_layers
620            )));
621        }
622        if target_layer < feature.layer || target_layer >= self.config.n_layers {
623            return Err(MIError::Config(format!(
624                "target layer {target_layer} must be >= source layer {} and < {}",
625                feature.layer, self.config.n_layers
626            )));
627        }
628        if feature.index >= self.config.n_features_per_layer {
629            return Err(MIError::Config(format!(
630                "feature index {} out of range (max {})",
631                feature.index, self.config.n_features_per_layer
632            )));
633        }
634
635        // Check steering cache first.
636        let cache_key = (*feature, target_layer);
637        if let Some(cached) = self.steering_cache.get(&cache_key) {
638            return Ok(cached.clone());
639        }
640
641        // W_dec_l has shape [n_features, n_layers - l, d_model]
642        // target_offset = target_layer - feature.layer
643        let target_offset = target_layer - feature.layer;
644
645        let dec_path = self.ensure_decoder_path(feature.layer)?;
646        let data = std::fs::read(&dec_path)?;
647        let st = SafeTensors::deserialize(&data).map_err(|e| {
648            MIError::Config(format!(
649                "failed to deserialize decoder layer {}: {e}",
650                feature.layer
651            ))
652        })?;
653
654        let dec_name = format!("W_dec_{}", feature.layer);
655        let w_dec = tensor_from_view(
656            &st.tensor(&dec_name)
657                .map_err(|e| MIError::Config(format!("tensor '{dec_name}' not found: {e}")))?,
658            &Device::Cpu,
659        )?;
660
661        // w_dec[feature.index, target_offset, :] → [d_model]
662        let column = w_dec.i((feature.index, target_offset))?;
663
664        // Transfer to target device.
665        let column = column.to_device(device)?;
666
667        Ok(column)
668    }
669
670    // --- Micro-cache ---
671
672    /// Pre-load decoder vectors into the steering micro-cache.
673    ///
674    /// Each entry is a `(CltFeatureId, target_layer)` pair. Vectors are
675    /// loaded to the specified device and kept pinned for repeated injection.
676    ///
677    /// Uses an OOM-safe pattern: loads each decoder file to CPU, extracts needed
678    /// columns as independent F32 tensors, drops the large file, then moves
679    /// small tensors to the target device.
680    ///
681    /// Memory: 50 features × 2304 × 4 bytes = ~450 KB (negligible).
682    ///
683    /// # Errors
684    ///
685    /// Returns [`MIError::Download`] if decoder files cannot be fetched.
686    /// Returns [`MIError::Model`] on tensor operation failure.
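    ///
    /// # Example
    ///
    /// A minimal sketch (feature ids and target layers are placeholders):
    ///
    /// ```no_run
    /// # fn main() -> candle_mi::Result<()> {
    /// use candle_core::Device;
    /// use candle_mi::clt::{CltFeatureId, CrossLayerTranscoder};
    ///
    /// let mut clt = CrossLayerTranscoder::open("mntss/clt-gemma-2-2b-426k")?;
    /// let features = vec![
    ///     (CltFeatureId { layer: 5, index: 123 }, 5),
    ///     (CltFeatureId { layer: 5, index: 456 }, 12),
    /// ];
    /// clt.cache_steering_vectors(&features, &Device::Cpu)?;
    /// assert_eq!(clt.steering_cache_len(), 2);
    /// # Ok(())
    /// # }
    /// ```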
687    pub fn cache_steering_vectors(
688        &mut self,
689        features: &[(CltFeatureId, usize)],
690        device: &Device,
691    ) -> Result<()> {
692        // Group by source layer to batch decoder file reads.
693        let mut by_source: HashMap<usize, Vec<(usize, usize)>> = HashMap::new();
694        for (fid, target_layer) in features {
695            by_source
696                .entry(fid.layer)
697                .or_default()
698                .push((fid.index, *target_layer));
699        }
700
701        let mut loaded = 0_usize;
702        let n_source_layers = by_source.len();
703        for (layer_idx, (source_layer, entries)) in by_source.iter().enumerate() {
704            info!(
705                "cache_steering_vectors: loading decoder for source layer {} ({}/{})",
706                source_layer,
707                layer_idx + 1,
708                n_source_layers
709            );
710
711            // Group by target_layer to identify needed offsets.
712            let mut by_target: HashMap<usize, Vec<usize>> = HashMap::new();
713            for &(index, target_layer) in entries {
714                by_target.entry(target_layer).or_default().push(index);
715            }
716
717            // Load decoder file, extract needed columns as independent CPU
718            // tensors, then drop the large file data BEFORE any GPU transfer.
719            // This prevents OOM when early-layer decoders can be >1.6 GB each.
720            let mut cpu_columns: Vec<(CltFeatureId, usize, Tensor)> = Vec::new();
721            {
722                let dec_path = self.ensure_decoder_path(*source_layer)?;
723                let data = std::fs::read(&dec_path)?;
724                info!(
725                    "cache_steering_vectors: loaded {} MB for layer {}",
726                    data.len() / (1024 * 1024),
727                    source_layer
728                );
729                let st = SafeTensors::deserialize(&data).map_err(|e| {
730                    MIError::Config(format!(
731                        "failed to deserialize decoder layer {source_layer}: {e}"
732                    ))
733                })?;
734                let dec_name = format!("W_dec_{source_layer}");
735                let w_dec = tensor_from_view(
736                    &st.tensor(&dec_name).map_err(|e| {
737                        MIError::Config(format!("tensor '{dec_name}' not found: {e}"))
738                    })?,
739                    &Device::Cpu,
740                )?;
741
742                for (target_layer, indices) in &by_target {
743                    let target_offset = target_layer - source_layer;
744                    for &index in indices {
745                        let fid = CltFeatureId {
746                            layer: *source_layer,
747                            index,
748                        };
749                        let cache_key = (fid, *target_layer);
750                        if !self.steering_cache.contains_key(&cache_key) {
751                            // Extract as independent F32 tensor: to_dtype +
752                            // to_vec1 copies data OUT of candle's Arc storage,
753                            // so dropping w_dec truly frees the ~1.6 GB decoder.
754                            let view = w_dec.i((index, target_offset))?;
755                            let dims = view.dims().to_vec();
756                            // PROMOTE: F32 for numerical stability in accumulation
757                            let values = view.to_dtype(DType::F32)?.to_vec1::<f32>()?;
758                            let independent =
759                                Tensor::from_vec(values, dims.as_slice(), &Device::Cpu)?;
760                            cpu_columns.push((fid, *target_layer, independent));
761                        }
762                    }
763                }
764                // data, st, w_dec all drop here — freeing the large decoder file
765            }
766
767            // Now move the small independent columns to the target device.
768            for (fid, target_layer, cpu_tensor) in cpu_columns {
769                let cache_key = (fid, target_layer);
770                if let std::collections::hash_map::Entry::Vacant(e) =
771                    self.steering_cache.entry(cache_key)
772                {
773                    let device_tensor = cpu_tensor.to_device(device)?;
774                    e.insert(device_tensor);
775                    loaded += 1;
776                }
777            }
778        }
779
780        info!(
781            "Cached {loaded} new steering vectors ({} total in cache)",
782            self.steering_cache.len()
783        );
784        Ok(())
785    }
786
787    /// Cache steering vectors for ALL downstream layers of each feature.
788    ///
789    /// For each feature at source layer `l`, caches decoder vectors for every
790    /// downstream target layer `l..n_layers`. This enables multi-layer
791    /// "clamping" injection where the steering signal propagates through all
792    /// downstream transformer layers.
793    ///
794    /// Same OOM-safe pattern as [`cache_steering_vectors()`](Self::cache_steering_vectors).
795    ///
796    /// # Arguments
797    /// * `features` — Feature IDs to cache (all downstream layers are cached automatically)
798    /// * `device` — Device to store cached tensors on (typically GPU)
799    ///
800    /// # Errors
801    ///
802    /// Returns [`MIError::Config`] if any feature layer is out of range.
803    /// Returns [`MIError::Download`] if decoder files cannot be fetched.
804    /// Returns [`MIError::Model`] on tensor operation failure.
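    ///
    /// # Example
    ///
    /// A minimal sketch (feature id is a placeholder): for one feature at layer 20
    /// of a 26-layer CLT, this caches the 6 decoder vectors for target layers 20..26:
    ///
    /// ```no_run
    /// # fn main() -> candle_mi::Result<()> {
    /// use candle_core::Device;
    /// use candle_mi::clt::{CltFeatureId, CrossLayerTranscoder};
    ///
    /// let mut clt = CrossLayerTranscoder::open("mntss/clt-gemma-2-2b-426k")?;
    /// let feature = CltFeatureId { layer: 20, index: 42 };
    /// clt.cache_steering_vectors_all_downstream(&[feature], &Device::Cpu)?;
    /// assert_eq!(
    ///     clt.steering_cache_len(),
    ///     clt.config().n_layers - feature.layer,
    /// );
    /// # Ok(())
    /// # }
    /// ```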
805    pub fn cache_steering_vectors_all_downstream(
806        &mut self,
807        features: &[CltFeatureId],
808        device: &Device,
809    ) -> Result<()> {
810        let n_layers = self.config.n_layers;
811
812        // Group by source layer to batch decoder file reads.
813        let mut by_source: HashMap<usize, Vec<usize>> = HashMap::new();
814        for fid in features {
815            if fid.layer >= n_layers {
816                return Err(MIError::Config(format!(
817                    "feature source layer {} out of range (max {})",
818                    fid.layer,
819                    n_layers - 1
820                )));
821            }
822            by_source.entry(fid.layer).or_default().push(fid.index);
823        }
824
825        let mut loaded = 0_usize;
826        let n_source_layers = by_source.len();
827        for (layer_idx, (source_layer, indices)) in by_source.iter().enumerate() {
828            let n_target_layers = n_layers - source_layer;
829            info!(
830                "cache_steering_vectors_all_downstream: loading decoder for source layer {} \
831                 ({}/{}, {} downstream layers)",
832                source_layer,
833                layer_idx + 1,
834                n_source_layers,
835                n_target_layers
836            );
837
838            // Load decoder file, extract ALL offsets as independent CPU tensors, then drop.
839            let mut cpu_columns: Vec<(CltFeatureId, usize, Tensor)> = Vec::new();
840            {
841                let dec_path = self.ensure_decoder_path(*source_layer)?;
842                let data = std::fs::read(&dec_path)?;
843                info!(
844                    "cache_steering_vectors_all_downstream: loaded {} MB for layer {}",
845                    data.len() / (1024 * 1024),
846                    source_layer
847                );
848                let st = SafeTensors::deserialize(&data).map_err(|e| {
849                    MIError::Config(format!(
850                        "failed to deserialize decoder layer {source_layer}: {e}"
851                    ))
852                })?;
853                let dec_name = format!("W_dec_{source_layer}");
854                let w_dec = tensor_from_view(
855                    &st.tensor(&dec_name).map_err(|e| {
856                        MIError::Config(format!("tensor '{dec_name}' not found: {e}"))
857                    })?,
858                    &Device::Cpu,
859                )?;
860
861                for &index in indices {
862                    let fid = CltFeatureId {
863                        layer: *source_layer,
864                        index,
865                    };
866                    for target_offset in 0..n_target_layers {
867                        let target_layer = source_layer + target_offset;
868                        let cache_key = (fid, target_layer);
869                        if !self.steering_cache.contains_key(&cache_key) {
870                            let view = w_dec.i((index, target_offset))?;
871                            let dims = view.dims().to_vec();
872                            // PROMOTE: F32 for numerical stability in accumulation
873                            let values = view.to_dtype(DType::F32)?.to_vec1::<f32>()?;
874                            let independent =
875                                Tensor::from_vec(values, dims.as_slice(), &Device::Cpu)?;
876                            cpu_columns.push((fid, target_layer, independent));
877                        }
878                    }
879                }
880                // data, st, w_dec all drop here — freeing the large decoder file
881            }
882
883            // Move small independent columns to the target device.
884            for (fid, target_layer, cpu_tensor) in cpu_columns {
885                let cache_key = (fid, target_layer);
886                if let std::collections::hash_map::Entry::Vacant(e) =
887                    self.steering_cache.entry(cache_key)
888                {
889                    let device_tensor = cpu_tensor.to_device(device)?;
890                    e.insert(device_tensor);
891                    loaded += 1;
892                }
893            }
894        }
895
896        info!(
897            "Cached {loaded} new steering vectors across all downstream layers ({} total in cache)",
898            self.steering_cache.len()
899        );
900        Ok(())
901    }
902
903    /// Clear all cached steering vectors, freeing device memory.
904    pub fn clear_steering_cache(&mut self) {
905        let count = self.steering_cache.len();
906        self.steering_cache.clear();
907        if count > 0 {
908            info!("Cleared {count} steering vectors from cache");
909        }
910    }
911
912    /// Number of vectors currently in the steering cache.
913    #[must_use]
914    pub fn steering_cache_len(&self) -> usize {
915        self.steering_cache.len()
916    }
917
918    // --- Injection ---
919
920    /// Build a [`crate::HookSpec`] that injects CLT decoder vectors into the residual stream.
921    ///
922    /// Groups cached steering vectors by target layer, accumulates them per layer,
923    /// scales by `strength`, and creates [`crate::Intervention::Add`] entries on
924    /// [`crate::HookPoint::ResidPost`] for each target layer. The resulting `HookSpec`
925    /// can be passed directly to [`MIModel::forward()`](crate::MIModel::forward).
926    ///
927    /// # Shapes
928    /// - Internally constructs `[1, seq_len, d_model]` tensors with the steering
929    ///   vector placed at `position` and zeros elsewhere.
930    ///
931    /// # Arguments
932    /// * `features` — List of `(feature_id, target_layer)` pairs (must be cached)
933    /// * `position` — Token position in the sequence to inject at
934    /// * `seq_len` — Total sequence length (needed to construct position-specific tensors)
935    /// * `strength` — Scalar multiplier for the accumulated steering vectors
936    /// * `device` — Device to construct injection tensors on
937    ///
938    /// # Errors
939    ///
940    /// Returns [`MIError::Hook`] if any feature is not in the steering cache.
941    /// Returns [`MIError::Model`] on tensor construction failure.
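    ///
    /// # Example
    ///
    /// A minimal sketch of building an injection hook (feature, position, and
    /// strength are placeholders; the resulting `HookSpec` would then be passed
    /// to the model's forward pass):
    ///
    /// ```no_run
    /// # fn main() -> candle_mi::Result<()> {
    /// use candle_core::Device;
    /// use candle_mi::clt::{CltFeatureId, CrossLayerTranscoder};
    ///
    /// let mut clt = CrossLayerTranscoder::open("mntss/clt-gemma-2-2b-426k")?;
    /// let device = Device::Cpu;
    ///
    /// let feature = (CltFeatureId { layer: 5, index: 123 }, 12);
    /// clt.cache_steering_vectors(&[feature], &device)?;
    ///
    /// // Add 4× the cached decoder vector at token position 7 of a 16-token prompt.
    /// let hooks = clt.prepare_hook_injection(&[feature], 7, 16, 4.0, &device)?;
    /// # let _ = hooks;
    /// # Ok(())
    /// # }
    /// ```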
942    pub fn prepare_hook_injection(
943        &self,
944        features: &[(CltFeatureId, usize)],
945        position: usize,
946        seq_len: usize,
947        strength: f32,
948        device: &Device,
949    ) -> Result<crate::hooks::HookSpec> {
950        use crate::hooks::{HookPoint, HookSpec, Intervention};
951
952        // Group features by target layer and accumulate their decoder vectors.
953        let mut per_layer: HashMap<usize, Tensor> = HashMap::new();
954        for (feature, target_layer) in features {
955            let cache_key = (*feature, *target_layer);
956            let cached = self.steering_cache.get(&cache_key).ok_or_else(|| {
957                MIError::Hook(format!(
958                    "feature {feature} for target layer {target_layer} not in steering cache \
959                     — call cache_steering_vectors() first"
960                ))
961            })?;
962            // PROMOTE: accumulate in F32 for numerical stability
963            let cached_f32 = cached.to_dtype(DType::F32)?;
964            if let Some(acc) = per_layer.get_mut(target_layer) {
965                let acc_ref: &Tensor = acc;
966                *acc = (acc_ref + &cached_f32)?;
967            } else {
968                per_layer.insert(*target_layer, cached_f32);
969            }
970        }
971
972        // Build HookSpec with Intervention::Add at each target layer.
973        let mut hooks = HookSpec::new();
974        let d_model = self.config.d_model;
975
976        for (target_layer, accumulated) in &per_layer {
977            // Scale by strength.
978            let scaled = (accumulated * f64::from(strength))?;
979
980            // Build a [1, seq_len, d_model] tensor with the vector at `position`.
981            let mut injection = Tensor::zeros((1, seq_len, d_model), DType::F32, device)?;
982
983            // Place the scaled vector at the target position.
984            let scaled_3d = scaled.unsqueeze(0)?.unsqueeze(0)?; // [1, 1, d_model]
985            let before = if position > 0 {
986                Some(injection.narrow(1, 0, position)?)
987            } else {
988                None
989            };
990            let after = if position + 1 < seq_len {
991                Some(injection.narrow(1, position + 1, seq_len - position - 1)?)
992            } else {
993                None
994            };
995
996            let mut parts: Vec<Tensor> = Vec::with_capacity(3);
997            if let Some(b) = before {
998                parts.push(b);
999            }
1000            parts.push(scaled_3d);
1001            if let Some(a) = after {
1002                parts.push(a);
1003            }
1004
1005            injection = Tensor::cat(&parts, 1)?;
1006
1007            hooks.intervene(
1008                HookPoint::ResidPost(*target_layer),
1009                Intervention::Add(injection),
1010            );
1011        }
1012
1013        Ok(hooks)
1014    }
1015
1016    /// Inject cached steering vectors directly into a residual stream tensor.
1017    ///
1018    /// Convenience method for use outside the forward pass (e.g., in analysis
1019    /// scripts). Returns a new tensor with the injection applied:
1020    /// `residual[:, position, :] += strength × Σ decoder_vectors`
1021    ///
1022    /// # Shapes
1023    /// - `residual`: `[batch, seq_len, d_model]` — hidden states
1024    /// - returns: `[batch, seq_len, d_model]` — modified hidden states
1025    ///
1026    /// # Arguments
1027    /// * `residual` — Hidden states tensor
1028    /// * `features` — List of `(feature, target_layer)` pairs to inject (must be cached)
1029    /// * `position` — Token position in the sequence to inject at
1030    /// * `strength` — Scalar multiplier for the steering vectors
1031    ///
1032    /// # Errors
1033    ///
1034    /// Returns [`MIError::Hook`] if any feature is not in the steering cache.
1035    /// Returns [`MIError::Config`] if dimensions don't match.
1036    /// Returns [`MIError::Model`] on tensor operation failure.
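    ///
    /// # Example
    ///
    /// A minimal sketch (the residual tensor here is a zero placeholder; in
    /// practice it would be captured hidden states):
    ///
    /// ```no_run
    /// # fn main() -> candle_mi::Result<()> {
    /// use candle_core::{DType, Device, Tensor};
    /// use candle_mi::clt::{CltFeatureId, CrossLayerTranscoder};
    ///
    /// let mut clt = CrossLayerTranscoder::open("mntss/clt-gemma-2-2b-426k")?;
    /// let device = Device::Cpu;
    ///
    /// let feature = (CltFeatureId { layer: 5, index: 123 }, 12);
    /// clt.cache_steering_vectors(&[feature], &device)?;
    ///
    /// // [batch=1, seq_len=8, d_model] placeholder hidden states.
    /// let residual = Tensor::zeros((1, 8, clt.config().d_model), DType::F32, &device)?;
    /// let steered = clt.inject(&residual, &[feature], 3, 4.0)?;
    /// assert_eq!(steered.dims(), residual.dims());
    /// # Ok(())
    /// # }
    /// ```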
1037    pub fn inject(
1038        &self,
1039        residual: &Tensor,
1040        features: &[(CltFeatureId, usize)],
1041        position: usize,
1042        strength: f32,
1043    ) -> Result<Tensor> {
1044        let (batch, seq_len, d_model) = residual.dims3()?;
1045        if position >= seq_len {
1046            return Err(MIError::Config(format!(
1047                "injection position {position} out of range (seq_len={seq_len})"
1048            )));
1049        }
1050        if d_model != self.config.d_model {
1051            return Err(MIError::Config(format!(
1052                "residual d_model={d_model} doesn't match CLT d_model={}",
1053                self.config.d_model
1054            )));
1055        }
1056
1057        // Accumulate all steering vectors into one vector (F32 for stability).
1058        let mut accumulated = Tensor::zeros((d_model,), DType::F32, residual.device())?;
1059        for (feature, target_layer) in features {
1060            let cache_key = (*feature, *target_layer);
1061            let cached = self.steering_cache.get(&cache_key).ok_or_else(|| {
1062                MIError::Hook(format!(
1063                    "feature {feature} for target layer {target_layer} not in steering cache"
1064                ))
1065            })?;
1066            // PROMOTE: accumulate in F32 for numerical stability
1067            let cached_f32 = cached.to_dtype(DType::F32)?;
1068            accumulated = (&accumulated + &cached_f32)?;
1069        }
1070
1071        // Scale by strength.
1072        let accumulated = (accumulated * f64::from(strength))?;
1073
1074        // Convert to residual dtype.
1075        let accumulated = accumulated.to_dtype(residual.dtype())?;
1076
1077        // Build steering tensor and inject at position.
1078        let pos_slice = residual.narrow(1, position, 1)?; // [batch, 1, d_model]
1079        let steering_expanded = accumulated
1080            .unsqueeze(0)?
1081            .unsqueeze(0)?
1082            .expand((batch, 1, d_model))?; // [batch, 1, d_model]
1083        let pos_updated = (&pos_slice + &steering_expanded)?;
1084
1085        // Reassemble: before + updated_position + after.
1086        let mut parts: Vec<Tensor> = Vec::with_capacity(3);
1087        if position > 0 {
1088            parts.push(residual.narrow(1, 0, position)?);
1089        }
1090        parts.push(pos_updated);
1091        if position + 1 < seq_len {
1092            parts.push(residual.narrow(1, position + 1, seq_len - position - 1)?);
1093        }
1094
1095        let result = Tensor::cat(&parts, 1)?;
1096        Ok(result)
1097    }
1098
1099    // --- Attribution / decoder scoring ---
1100
1101    /// Score all CLT features by how strongly their decoder vector at
1102    /// `target_layer` projects along a given direction vector.
1103    ///
1104    /// For each source layer `0..n_layers` where `source_layer <= target_layer`:
1105    /// loads the decoder file to CPU, extracts the target layer slice
1106    /// `[n_features, d_model]`, and computes `scores = slice @ direction`.
1107    ///
1108    /// When `cosine` is true, scores are normalized by both the direction
1109    /// vector norm and each decoder row norm (cosine similarity).
1110    ///
1111    /// # Shapes
1112    /// - `direction`: `[d_model]` — target direction vector (e.g., token embedding)
1113    /// - returns: top-k `(CltFeatureId, f32)` pairs, sorted by score descending
1114    ///
1115    /// # Arguments
1116    /// * `direction` — `[d_model]` direction vector to project decoders onto
1117    /// * `target_layer` — downstream layer to examine decoders at
1118    /// * `top_k` — number of top-scoring features to return
1119    /// * `cosine` — whether to use cosine similarity instead of dot product
1120    ///
1121    /// # Errors
1122    ///
1123    /// Returns [`MIError::Config`] if `direction` shape is wrong or `target_layer`
1124    /// is out of range.
1125    /// Returns [`MIError::Download`] if decoder files cannot be fetched.
1126    /// Returns [`MIError::Model`] on tensor operation failure.
1127    ///
1128    /// # Memory
1129    ///
1130    /// Processes one decoder file at a time on CPU (up to ~2 GB for layer 0).
1131    /// No GPU memory required.
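    ///
    /// # Example
    ///
    /// A minimal sketch (the direction here is a zero placeholder; in practice it
    /// would be something like a token's embedding or unembedding row):
    ///
    /// ```no_run
    /// # fn main() -> candle_mi::Result<()> {
    /// use candle_core::{DType, Device, Tensor};
    /// use candle_mi::clt::CrossLayerTranscoder;
    ///
    /// let mut clt = CrossLayerTranscoder::open("mntss/clt-gemma-2-2b-426k")?;
    /// let direction = Tensor::zeros((clt.config().d_model,), DType::F32, &Device::Cpu)?;
    ///
    /// // Top 50 features whose decoder at layer 20 aligns with `direction` (cosine).
    /// let top = clt.score_features_by_decoder_projection(&direction, 20, 50, true)?;
    /// for (feature, score) in &top {
    ///     println!("{feature}: {score:.4}");
    /// }
    /// # Ok(())
    /// # }
    /// ```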
1132    pub fn score_features_by_decoder_projection(
1133        &mut self,
1134        direction: &Tensor,
1135        target_layer: usize,
1136        top_k: usize,
1137        cosine: bool,
1138    ) -> Result<Vec<(CltFeatureId, f32)>> {
1139        let d_model = self.config.d_model;
1140        if direction.dims() != [d_model] {
1141            return Err(MIError::Config(format!(
1142                "direction must have shape [{d_model}], got {:?}",
1143                direction.dims()
1144            )));
1145        }
1146        if target_layer >= self.config.n_layers {
1147            return Err(MIError::Config(format!(
1148                "target layer {target_layer} out of range (max {})",
1149                self.config.n_layers - 1
1150            )));
1151        }
1152
1153        // PROMOTE: F32 for dot-product precision matching Python reference
1154        let direction_f32 = direction.to_dtype(DType::F32)?.to_device(&Device::Cpu)?;
1155
1156        // Optionally normalize direction to unit length for cosine similarity.
1157        let direction_norm = if cosine {
1158            let norm: f32 = direction_f32.sqr()?.sum_all()?.sqrt()?.to_scalar()?;
1159            if norm > 1e-10 {
1160                direction_f32.broadcast_div(&Tensor::new(norm, &Device::Cpu)?)?
1161            } else {
1162                direction_f32
1163            }
1164        } else {
1165            direction_f32
1166        };
1167
1168        let mut all_scores: Vec<(CltFeatureId, f32)> = Vec::new();
1169
1170        for source_layer in 0..self.config.n_layers {
1171            if target_layer < source_layer {
1172                continue; // This source layer cannot decode to target_layer.
1173            }
1174            let target_offset = target_layer - source_layer;
1175
1176            // Load decoder file to CPU.
1177            let dec_path = self.ensure_decoder_path(source_layer)?;
1178            let data = std::fs::read(&dec_path)?;
1179            info!(
1180                "score_features_by_decoder_projection: loaded {} MB for layer {}",
1181                data.len() / (1024 * 1024),
1182                source_layer
1183            );
1184            let st = SafeTensors::deserialize(&data).map_err(|e| {
1185                MIError::Config(format!(
1186                    "failed to deserialize decoder layer {source_layer}: {e}"
1187                ))
1188            })?;
1189
1190            let dec_name = format!("W_dec_{source_layer}");
1191            let w_dec = tensor_from_view(
1192                &st.tensor(&dec_name)
1193                    .map_err(|e| MIError::Config(format!("tensor '{dec_name}' not found: {e}")))?,
1194                &Device::Cpu,
1195            )?;
1196            // PROMOTE: decoder weights are BF16 on disk; F32 for matmul precision
1197            let w_dec_f32 = w_dec.to_dtype(DType::F32)?;
1198
1199            // Extract target layer slice: [n_features, d_model]
1200            let dec_slice = w_dec_f32.i((.., target_offset, ..))?;
1201
1202            // raw_scores = dec_slice @ direction_norm → [n_features]
1203            let raw_scores = dec_slice
1204                .matmul(&direction_norm.unsqueeze(1)?)?
1205                .squeeze(1)?;
1206
1207            let scores_vec: Vec<f32> = if cosine {
1208                // Divide by each decoder row's L2 norm → cosine similarity.
1209                let dec_norms = dec_slice.sqr()?.sum(1)?.sqrt()?;
1210                let cosine_scores = raw_scores.broadcast_div(&dec_norms)?;
1211                cosine_scores.to_vec1()?
1212            } else {
1213                raw_scores.to_vec1()?
1214            };
1215
1216            for (idx, &score) in scores_vec.iter().enumerate() {
1217                if score.is_finite() {
1218                    all_scores.push((
1219                        CltFeatureId {
1220                            layer: source_layer,
1221                            index: idx,
1222                        },
1223                        score,
1224                    ));
1225                }
1226            }
1227
1228            info!(
1229                "Scored {} features at source layer {source_layer} (target layer {target_layer})",
1230                scores_vec.len()
1231            );
1232        }
1233
1234        // Sort by score descending, take top-k.
1235        all_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1236        all_scores.truncate(top_k);
1237
1238        Ok(all_scores)
1239    }
1240
1241    /// Batch version of [`score_features_by_decoder_projection`](Self::score_features_by_decoder_projection).
1242    ///
1243    /// Scores multiple direction vectors against all decoder files in a single
1244    /// pass. Each decoder file is loaded **once** for all directions, reducing
1245    /// I/O from `n_words × n_layers` file reads to just `n_layers`.
1246    ///
1247    /// # Shapes
1248    /// - `directions`: slice of `[d_model]` tensors (one per word/direction)
1249    /// - returns: one `Vec<(CltFeatureId, f32)>` per direction (top-k per word)
1250    ///
1251    /// # Arguments
1252    /// * `directions` — slice of `[d_model]` direction vectors
1253    /// * `target_layer` — downstream layer to examine decoders at
1254    /// * `top_k` — number of top-scoring features to return per direction
1255    /// * `cosine` — whether to use cosine similarity
1256    ///
1257    /// # Errors
1258    ///
1259    /// Returns [`MIError::Config`] if any direction has wrong shape, directions is
1260    /// empty, or `target_layer` is out of range.
1261    /// Returns [`MIError::Download`] if decoder files cannot be fetched.
1262    /// Returns [`MIError::Model`] on tensor operation failure.
1263    ///
1264    /// # Memory
1265    ///
1266    /// Stacks directions to `[n_words, d_model]` on CPU. Each decoder file
1267    /// loaded one at a time (up to ~2 GB for layer 0). No GPU memory required.
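    ///
    /// # Example
    ///
    /// A minimal sketch scoring two placeholder directions in one pass:
    ///
    /// ```no_run
    /// # fn main() -> candle_mi::Result<()> {
    /// use candle_core::{DType, Device, Tensor};
    /// use candle_mi::clt::CrossLayerTranscoder;
    ///
    /// let mut clt = CrossLayerTranscoder::open("mntss/clt-gemma-2-2b-426k")?;
    /// let d_model = clt.config().d_model;
    /// let directions = vec![
    ///     Tensor::zeros((d_model,), DType::F32, &Device::Cpu)?,
    ///     Tensor::zeros((d_model,), DType::F32, &Device::Cpu)?,
    /// ];
    ///
    /// // One top-k list per direction; each decoder file is read only once.
    /// let per_direction =
    ///     clt.score_features_by_decoder_projection_batch(&directions, 20, 50, true)?;
    /// assert_eq!(per_direction.len(), directions.len());
    /// # Ok(())
    /// # }
    /// ```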
1268    pub fn score_features_by_decoder_projection_batch(
1269        &mut self,
1270        directions: &[Tensor],
1271        target_layer: usize,
1272        top_k: usize,
1273        cosine: bool,
1274    ) -> Result<Vec<Vec<(CltFeatureId, f32)>>> {
1275        let d_model = self.config.d_model;
1276        let n_words = directions.len();
1277        if n_words == 0 {
1278            return Err(MIError::Config(
1279                "at least one direction vector required".into(),
1280            ));
1281        }
1282        for (i, dir) in directions.iter().enumerate() {
1283            if dir.dims() != [d_model] {
1284                return Err(MIError::Config(format!(
1285                    "direction vector {i} must have shape [{d_model}], got {:?}",
1286                    dir.dims()
1287                )));
1288            }
1289        }
1290        if target_layer >= self.config.n_layers {
1291            return Err(MIError::Config(format!(
1292                "target layer {target_layer} out of range (max {})",
1293                self.config.n_layers - 1
1294            )));
1295        }
1296
1297        // PROMOTE: directions may arrive as BF16; F32 for matmul precision
1298        let dirs_f32: Vec<Tensor> = directions
1299            .iter()
1300            .map(|d| d.to_dtype(DType::F32)?.to_device(&Device::Cpu))
1301            .collect::<std::result::Result<_, _>>()?;
1302        let stacked = Tensor::stack(&dirs_f32, 0)?; // [n_words, d_model]
1303
1304        // For cosine: row-normalize direction vectors to unit length.
1305        let stacked_norm = if cosine {
1306            let norms = stacked.sqr()?.sum(1)?.sqrt()?; // [n_words]
1307            let ones = Tensor::ones_like(&norms)?;
1308            let safe_norms = norms.maximum(&(&ones * 1e-10f64)?)?; // [n_words]
1309            stacked.broadcast_div(&safe_norms.unsqueeze(1)?)?
1310        } else {
1311            stacked
1312        };
1313        let directions_t = stacked_norm.t()?; // [d_model, n_words]
1314
1315        // Per-word score accumulators.
1316        let mut all_scores: Vec<Vec<(CltFeatureId, f32)>> =
1317            (0..n_words).map(|_| Vec::new()).collect();
1318
1319        for source_layer in 0..self.config.n_layers {
1320            if target_layer < source_layer {
1321                continue;
1322            }
1323            let target_offset = target_layer - source_layer;
1324
1325            // Load decoder file ONCE for all words.
1326            let dec_path = self.ensure_decoder_path(source_layer)?;
1327            let data = std::fs::read(&dec_path)?;
1328            info!(
1329                "score_features_batch: loaded {} MB for layer {}",
1330                data.len() / (1024 * 1024),
1331                source_layer
1332            );
1333            let st = SafeTensors::deserialize(&data).map_err(|e| {
1334                MIError::Config(format!(
1335                    "failed to deserialize decoder layer {source_layer}: {e}"
1336                ))
1337            })?;
1338            let dec_name = format!("W_dec_{source_layer}");
1339            let w_dec = tensor_from_view(
1340                &st.tensor(&dec_name)
1341                    .map_err(|e| MIError::Config(format!("tensor '{dec_name}' not found: {e}")))?,
1342                &Device::Cpu,
1343            )?;
1344            // PROMOTE: decoder weights are BF16 on disk; F32 for matmul precision
1345            let w_dec_f32 = w_dec.to_dtype(DType::F32)?;
1346            let dec_slice = w_dec_f32.i((.., target_offset, ..))?; // [n_features, d_model]
1347
1348            // Batch matmul: [n_features, d_model] × [d_model, n_words] = [n_features, n_words]
1349            let raw_scores = dec_slice.matmul(&directions_t)?;
1350
1351            // For cosine, divide by per-feature decoder norms; transpose to [n_words, n_features] for extraction.
1352            let scores_2d: Vec<Vec<f32>> = if cosine {
1353                let dec_norms = dec_slice.sqr()?.sum(1)?.sqrt()?; // [n_features]
1354                let cosine_scores = raw_scores.broadcast_div(&dec_norms.unsqueeze(1)?)?;
1355                cosine_scores.t()?.to_vec2()?
1356            } else {
1357                raw_scores.t()?.to_vec2()?
1358            };
1359
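            // Collect per-word (feature, score) pairs, skipping non-finite scores
            // (these can arise from zero-norm decoder rows under cosine scoring).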
1360            for (w, word_scores) in scores_2d.iter().enumerate() {
1361                for (idx, &score) in word_scores.iter().enumerate() {
1362                    if score.is_finite() {
1363                        if let Some(word_vec) = all_scores.get_mut(w) {
1364                            word_vec.push((
1365                                CltFeatureId {
1366                                    layer: source_layer,
1367                                    index: idx,
1368                                },
1369                                score,
1370                            ));
1371                        }
1372                    }
1373                }
1374            }
1375
1376            info!(
1377                "Batch scored {} words × {} features at source layer {} (target layer {})",
1378                n_words,
1379                scores_2d.first().map_or(0, Vec::len),
1380                source_layer,
1381                target_layer
1382            );
1383        }
1384
1385        // Sort and truncate per word.
1386        for word_scores in &mut all_scores {
1387            word_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1388            word_scores.truncate(top_k);
1389        }
1390
1391        Ok(all_scores)
1392    }
1393
1394    /// Extract decoder vectors for a set of features at a specific target layer.
1395    ///
1396    /// Groups features by source layer, loads each decoder file once, and
1397    /// extracts the decoder vector at the target layer offset as an independent
1398    /// F32 CPU tensor. Uses the OOM-safe `to_vec1` + `from_vec` pattern to
1399    /// ensure large decoder files are freed before processing the next layer.
1400    ///
1401    /// # Shapes
1402    /// - returns: `HashMap<CltFeatureId, Tensor>` where each tensor is `[d_model]` (F32, CPU)
1403    ///
1404    /// # Arguments
1405    /// * `features` — feature IDs to extract decoder vectors for
1406    /// * `target_layer` — downstream layer to extract decoders at
1407    ///
1408    /// # Errors
1409    ///
1410    /// Returns [`MIError::Config`] if any feature layer or `target_layer` is out
1411    /// of range, or if `target_layer < feature.layer` for any feature.
1412    /// Returns [`MIError::Download`] if decoder files cannot be fetched.
1413    /// Returns [`MIError::Model`] on tensor operation failure.
1414    ///
1415    /// # Memory
1416    ///
1417    /// Loads each decoder to CPU (up to ~2 GB), extracts independent F32
1418    /// tensors, then drops the large file before processing the next layer.
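    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled as a doctest), assuming `clt` is an
    /// already-loaded `CrossLayerTranscoder`; the feature IDs and target layer
    /// are illustrative:
    ///
    /// ```ignore
    /// let features = vec![
    ///     CltFeatureId { layer: 0, index: 7 },
    ///     CltFeatureId { layer: 3, index: 120 },
    /// ];
    /// // One independent [d_model] F32 CPU tensor per feature, decoded at layer 5.
    /// let vectors = clt.extract_decoder_vectors(&features, 5)?;
    /// let v = &vectors[&CltFeatureId { layer: 0, index: 7 }];
    /// ```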
1419    pub fn extract_decoder_vectors(
1420        &mut self,
1421        features: &[CltFeatureId],
1422        target_layer: usize,
1423    ) -> Result<HashMap<CltFeatureId, Tensor>> {
1424        if target_layer >= self.config.n_layers {
1425            return Err(MIError::Config(format!(
1426                "target layer {target_layer} out of range (max {})",
1427                self.config.n_layers - 1
1428            )));
1429        }
1430
1431        // Group by source layer.
1432        let mut by_source: HashMap<usize, Vec<usize>> = HashMap::new();
1433        for fid in features {
1434            if fid.layer >= self.config.n_layers {
1435                return Err(MIError::Config(format!(
1436                    "feature source layer {} out of range (max {})",
1437                    fid.layer,
1438                    self.config.n_layers - 1
1439                )));
1440            }
1441            if target_layer < fid.layer {
1442                return Err(MIError::Config(format!(
1443                    "target layer {target_layer} must be >= source layer {}",
1444                    fid.layer
1445                )));
1446            }
1447            by_source.entry(fid.layer).or_default().push(fid.index);
1448        }
1449
1450        let mut result: HashMap<CltFeatureId, Tensor> = HashMap::new();
1451        let n_source_layers = by_source.len();
1452
1453        for (layer_idx, (source_layer, indices)) in by_source.iter().enumerate() {
1454            info!(
1455                "extract_decoder_vectors: loading decoder for source layer {} ({}/{})",
1456                source_layer,
1457                layer_idx + 1,
1458                n_source_layers
1459            );
1460            let target_offset = target_layer - source_layer;
1461
1462            // Load decoder file to CPU, extract needed rows as independent tensors.
1463            let dec_path = self.ensure_decoder_path(*source_layer)?;
1464            let data = std::fs::read(&dec_path)?;
1465            let st = SafeTensors::deserialize(&data).map_err(|e| {
1466                MIError::Config(format!(
1467                    "failed to deserialize decoder layer {source_layer}: {e}"
1468                ))
1469            })?;
1470            let dec_name = format!("W_dec_{source_layer}");
1471            let w_dec = tensor_from_view(
1472                &st.tensor(&dec_name)
1473                    .map_err(|e| MIError::Config(format!("tensor '{dec_name}' not found: {e}")))?,
1474                &Device::Cpu,
1475            )?;
1476
1477            for &index in indices {
1478                let fid = CltFeatureId {
1479                    layer: *source_layer,
1480                    index,
1481                };
1482                if let std::collections::hash_map::Entry::Vacant(e) = result.entry(fid) {
1483                    // Extract as independent F32 tensor (OOM-safe copy).
1484                    let view = w_dec.i((index, target_offset))?;
1485                    let dims = view.dims().to_vec();
1486                    // PROMOTE: decoder weights are BF16 on disk; extract as F32
1487                    let values = view.to_dtype(DType::F32)?.to_vec1::<f32>()?;
1488                    let independent = Tensor::from_vec(values, dims.as_slice(), &Device::Cpu)?;
1489                    e.insert(independent);
1490                }
1491            }
1492            // data, st, w_dec drop here — freeing the large decoder file.
1493        }
1494
1495        info!(
1496            "Extracted {} decoder vectors across {} source layers",
1497            result.len(),
1498            n_source_layers
1499        );
1500
1501        Ok(result)
1502    }
1503
1504    /// Build an attribution graph by scoring features against a direction.
1505    ///
1506    /// Convenience wrapper around
1507    /// [`score_features_by_decoder_projection`](Self::score_features_by_decoder_projection)
1508    /// that returns an [`AttributionGraph`] instead of a raw Vec.
1509    ///
1510    /// # Shapes
1511    /// - `direction`: `[d_model]`
1512    ///
1513    /// # Errors
1514    ///
1515    /// Same as [`score_features_by_decoder_projection`](Self::score_features_by_decoder_projection).
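    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled as a doctest), assuming `clt` is an
    /// already-loaded `CrossLayerTranscoder` and `direction` is a `[d_model]`
    /// tensor (e.g. an unembedding direction for a token of interest):
    ///
    /// ```ignore
    /// // Top 50 features whose decoders best align (cosine) with `direction` at layer 10.
    /// let graph = clt.build_attribution_graph(&direction, 10, 50, true)?;
    /// for edge in graph.edges() {
    ///     println!("{} -> {:+.3}", edge.feature, edge.score);
    /// }
    /// ```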
1516    pub fn build_attribution_graph(
1517        &mut self,
1518        direction: &Tensor,
1519        target_layer: usize,
1520        top_k: usize,
1521        cosine: bool,
1522    ) -> Result<AttributionGraph> {
1523        let scored =
1524            self.score_features_by_decoder_projection(direction, target_layer, top_k, cosine)?;
1525        Ok(AttributionGraph {
1526            target_layer,
1527            edges: scored
1528                .into_iter()
1529                .map(|(feature, score)| AttributionEdge { feature, score })
1530                .collect(),
1531        })
1532    }
1533
1534    /// Build attribution graphs for multiple directions in a single pass.
1535    ///
1536    /// Convenience wrapper around
1537    /// [`score_features_by_decoder_projection_batch`](Self::score_features_by_decoder_projection_batch)
1538    /// that returns `Vec<AttributionGraph>`.
1539    ///
1540    /// # Shapes
1541    /// - `directions`: slice of `[d_model]` tensors
1542    ///
1543    /// # Errors
1544    ///
1545    /// Same as [`score_features_by_decoder_projection_batch`](Self::score_features_by_decoder_projection_batch).
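    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled as a doctest), assuming `clt` is an
    /// already-loaded `CrossLayerTranscoder` and `dir_a` / `dir_b` are
    /// `[d_model]` direction tensors:
    ///
    /// ```ignore
    /// // A single pass over the decoder files scores both directions at layer 10.
    /// let graphs = clt.build_attribution_graph_batch(&[dir_a, dir_b], 10, 50, true)?;
    /// assert_eq!(graphs.len(), 2);
    /// // Edges are sorted by score; prune weak ones if desired.
    /// let strong = graphs[0].threshold(0.2);
    /// ```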
1546    pub fn build_attribution_graph_batch(
1547        &mut self,
1548        directions: &[Tensor],
1549        target_layer: usize,
1550        top_k: usize,
1551        cosine: bool,
1552    ) -> Result<Vec<AttributionGraph>> {
1553        let batch = self.score_features_by_decoder_projection_batch(
1554            directions,
1555            target_layer,
1556            top_k,
1557            cosine,
1558        )?;
1559        Ok(batch
1560            .into_iter()
1561            .map(|scored| AttributionGraph {
1562                target_layer,
1563                edges: scored
1564                    .into_iter()
1565                    .map(|(feature, score)| AttributionEdge { feature, score })
1566                    .collect(),
1567            })
1568            .collect())
1569    }
1570}
1571
1572// ---------------------------------------------------------------------------
1573// Helper functions
1574// ---------------------------------------------------------------------------
1575
1576/// Convert a safetensors `TensorView` to a candle `Tensor`.
1577///
1578/// # Shapes
1579/// - Preserves the original tensor shape from safetensors.
1580///
1581/// # Errors
1582///
1583/// Returns [`MIError::Config`] if the tensor dtype is not one of the supported float types (BF16, F16, F32).
1584/// Returns [`MIError::Model`] on tensor construction failure.
1585fn tensor_from_view(view: &safetensors::tensor::TensorView<'_>, device: &Device) -> Result<Tensor> {
1586    let shape: Vec<usize> = view.shape().to_vec();
1587    #[allow(clippy::wildcard_enum_match_arm)]
1588    // EXHAUSTIVE: safetensors exposes many dtypes; CLTs only use float types
1589    let dtype = match view.dtype() {
1590        safetensors::Dtype::BF16 => DType::BF16,
1591        safetensors::Dtype::F16 => DType::F16,
1592        safetensors::Dtype::F32 => DType::F32,
1593        other => {
1594            return Err(MIError::Config(format!(
1595                "unsupported CLT tensor dtype: {other:?}"
1596            )));
1597        }
1598    };
1599    let tensor = Tensor::from_raw_buffer(view.data(), dtype, &shape, device)?;
1600    Ok(tensor)
1601}
1602
1603/// Parse a scalar value from a simple YAML file by key.
1604///
1605/// No `serde_yaml` dependency — matches `key: value` lines and strips surrounding double quotes.
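///
/// A minimal sketch of the matching behavior (private helper, not run as a doctest):
///
/// ```ignore
/// let yaml = "model_name: \"google/gemma-2-2b\"\n";
/// assert_eq!(
///     parse_yaml_value(yaml, "model_name").as_deref(),
///     Some("google/gemma-2-2b"),
/// );
/// ```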
1606fn parse_yaml_value(yaml_text: &str, key: &str) -> Option<String> {
1607    for line in yaml_text.lines() {
1608        let line = line.trim();
1609        if let Some(rest) = line.strip_prefix(key) {
1610            if let Some(rest) = rest.strip_prefix(':') {
1611                let value = rest.trim().trim_matches('"');
1612                return Some(value.to_owned());
1613            }
1614        }
1615    }
1616    None
1617}
1618
1619// ---------------------------------------------------------------------------
1620// Tests
1621// ---------------------------------------------------------------------------
1622
1623#[cfg(test)]
1624#[allow(clippy::unwrap_used, clippy::expect_used)]
1625mod tests {
1626    use super::*;
1627
1628    #[test]
1629    fn clt_feature_id_display() {
1630        let fid = CltFeatureId {
1631            layer: 5,
1632            index: 42,
1633        };
1634        assert_eq!(fid.to_string(), "L5:42");
1635    }
1636
1637    #[test]
1638    fn clt_feature_id_ordering() {
1639        let a = CltFeatureId {
1640            layer: 0,
1641            index: 10,
1642        };
1643        let b = CltFeatureId {
1644            layer: 0,
1645            index: 20,
1646        };
1647        let c = CltFeatureId { layer: 1, index: 0 };
1648        assert!(a < b);
1649        assert!(b < c);
1650    }
1651
1652    #[test]
1653    fn sparse_activations_basics() {
1654        let features = vec![
1655            (CltFeatureId { layer: 0, index: 5 }, 3.0),
1656            (CltFeatureId { layer: 0, index: 2 }, 2.0),
1657            (CltFeatureId { layer: 0, index: 8 }, 1.0),
1658        ];
1659        let sparse = SparseActivations { features };
1660        assert_eq!(sparse.len(), 3);
1661        assert!(!sparse.is_empty());
1662    }
1663
1664    #[test]
1665    fn sparse_activations_truncate() {
1666        let features = vec![
1667            (CltFeatureId { layer: 0, index: 5 }, 3.0),
1668            (CltFeatureId { layer: 0, index: 2 }, 2.0),
1669            (CltFeatureId { layer: 0, index: 8 }, 1.0),
1670        ];
1671        let mut sparse = SparseActivations { features };
1672        sparse.truncate(2);
1673        assert_eq!(sparse.len(), 2);
1674        assert_eq!(sparse.features[0].0.index, 5);
1675        assert_eq!(sparse.features[1].0.index, 2);
1676    }
1677
1678    #[test]
1679    fn parse_yaml_value_basic() {
1680        let yaml = "model_name: \"google/gemma-2-2b\"\nmodel_kind: cross_layer_transcoder\n";
1681        assert_eq!(
1682            parse_yaml_value(yaml, "model_name"),
1683            Some("google/gemma-2-2b".to_owned())
1684        );
1685        assert_eq!(
1686            parse_yaml_value(yaml, "model_kind"),
1687            Some("cross_layer_transcoder".to_owned())
1688        );
1689        assert_eq!(parse_yaml_value(yaml, "missing_key"), None);
1690    }
1691
1692    #[test]
1693    fn encode_synthetic() {
1694        // Create a small synthetic encoder: 4 features, d_model=8
1695        let device = Device::Cpu;
1696        let d_model = 8;
1697        let n_features = 4;
1698
1699        // W_enc: [4, 8] — identity-like rows so we can predict output
1700        #[rustfmt::skip]
1701        let w_enc_data: Vec<f32> = vec![
1702            1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // feature 0: picks up residual[0]
1703            0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // feature 1: picks up residual[1]
1704            0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, // feature 2: picks up residual[2]
1705            0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, // feature 3: picks up residual[3]
1706        ];
1707        let w_enc = Tensor::from_vec(w_enc_data, (n_features, d_model), &device).unwrap();
1708
1709        // b_enc: [4] — bias shifts to test ReLU
1710        let b_enc_data: Vec<f32> = vec![0.0, -0.5, 0.0, -2.0]; // feature 3 will need residual[3] > 2.0
1711        let b_enc = Tensor::from_vec(b_enc_data, (n_features,), &device).unwrap();
1712
1713        // Residual: [8] — values: [1.5, 0.3, 0.0, 1.0, ...]
1714        let residual_data: Vec<f32> = vec![1.5, 0.3, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0];
1715        let residual = Tensor::from_vec(residual_data, (d_model,), &device).unwrap();
1716
1717        // Expected pre_acts = W_enc @ residual + b_enc
1718        // = [1.5, 0.3, 0.0, 1.0] + [0.0, -0.5, 0.0, -2.0]
1719        // = [1.5, -0.2, 0.0, -1.0]
1720        // After ReLU: [1.5, 0.0, 0.0, 0.0]
1721        // Only feature 0 is active with activation 1.5
1722
1723        // Create a fake loaded encoder
1724        let clt = CrossLayerTranscoder {
1725            repo_id: "test".to_owned(),
1726            fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
1727            encoder_paths: vec![None],
1728            decoder_paths: vec![None],
1729            config: CltConfig {
1730                n_layers: 1,
1731                d_model,
1732                n_features_per_layer: n_features,
1733                n_features_total: n_features,
1734                model_name: "test".to_owned(),
1735            },
1736            loaded_encoder: Some(LoadedEncoder {
1737                layer: 0,
1738                w_enc,
1739                b_enc,
1740            }),
1741            steering_cache: HashMap::new(),
1742        };
1743
1744        let sparse = clt.encode(&residual, 0).unwrap();
1745        assert_eq!(sparse.len(), 1, "only feature 0 should be active");
1746        assert_eq!(sparse.features[0].0.index, 0);
1747        assert!((sparse.features[0].1 - 1.5).abs() < 1e-5);
1748    }
1749
1750    #[test]
1751    fn encode_wrong_layer_errors() {
1752        let device = Device::Cpu;
1753        let w_enc = Tensor::zeros((4, 8), DType::F32, &device).unwrap();
1754        let b_enc = Tensor::zeros((4,), DType::F32, &device).unwrap();
1755        let residual = Tensor::zeros((8,), DType::F32, &device).unwrap();
1756
1757        let clt = CrossLayerTranscoder {
1758            repo_id: "test".to_owned(),
1759            fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
1760            encoder_paths: vec![None; 2],
1761            decoder_paths: vec![None; 2],
1762            config: CltConfig {
1763                n_layers: 2,
1764                d_model: 8,
1765                n_features_per_layer: 4,
1766                n_features_total: 8,
1767                model_name: "test".to_owned(),
1768            },
1769            loaded_encoder: Some(LoadedEncoder {
1770                layer: 0,
1771                w_enc,
1772                b_enc,
1773            }),
1774            steering_cache: HashMap::new(),
1775        };
1776
1777        // Requesting layer 1 when layer 0 is loaded should error.
1778        let result = clt.encode(&residual, 1);
1779        assert!(result.is_err());
1780    }
1781
1782    #[test]
1783    fn inject_position() {
1784        let device = Device::Cpu;
1785        let d_model = 4;
1786
1787        // Residual: [1, 3, 4] — batch=1, seq_len=3, d_model=4
1788        let residual = Tensor::ones((1, 3, d_model), DType::F32, &device).unwrap();
1789
1790        // Create a CLT with a pre-cached steering vector.
1791        let fid = CltFeatureId { layer: 0, index: 0 };
1792        let target_layer = 1;
1793        let steering_vec =
1794            Tensor::from_vec(vec![10.0_f32, 20.0, 30.0, 40.0], (d_model,), &device).unwrap();
1795
1796        let mut steering_cache = HashMap::new();
1797        steering_cache.insert((fid, target_layer), steering_vec);
1798
1799        let clt = CrossLayerTranscoder {
1800            repo_id: "test".to_owned(),
1801            fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
1802            encoder_paths: vec![None; 2],
1803            decoder_paths: vec![None; 2],
1804            config: CltConfig {
1805                n_layers: 2,
1806                d_model,
1807                n_features_per_layer: 1,
1808                n_features_total: 2,
1809                model_name: "test".to_owned(),
1810            },
1811            loaded_encoder: None,
1812            steering_cache,
1813        };
1814
1815        // Inject at position 1 with strength 1.0
1816        let result = clt
1817            .inject(&residual, &[(fid, target_layer)], 1, 1.0)
1818            .unwrap();
1819
1820        // Position 0 should be unchanged (all 1.0)
1821        let pos0: Vec<f32> = result.i((0, 0)).unwrap().to_vec1().unwrap();
1822        assert_eq!(pos0, vec![1.0, 1.0, 1.0, 1.0]);
1823
1824        // Position 1 should have the steering vector added (1 + [10, 20, 30, 40])
1825        let pos1: Vec<f32> = result.i((0, 1)).unwrap().to_vec1().unwrap();
1826        assert_eq!(pos1, vec![11.0, 21.0, 31.0, 41.0]);
1827
1828        // Position 2 should be unchanged
1829        let pos2: Vec<f32> = result.i((0, 2)).unwrap().to_vec1().unwrap();
1830        assert_eq!(pos2, vec![1.0, 1.0, 1.0, 1.0]);
1831    }
1832
1833    #[test]
1834    fn prepare_hook_injection_creates_correct_hooks() {
1835        use crate::hooks::HookPoint;
1836
1837        let device = Device::Cpu;
1838        let d_model = 4;
1839
1840        let fid = CltFeatureId { layer: 0, index: 0 };
1841        let target_layer = 5;
1842        let steering_vec =
1843            Tensor::from_vec(vec![1.0_f32, 2.0, 3.0, 4.0], (d_model,), &device).unwrap();
1844
1845        let mut steering_cache = HashMap::new();
1846        steering_cache.insert((fid, target_layer), steering_vec);
1847
1848        let clt = CrossLayerTranscoder {
1849            repo_id: "test".to_owned(),
1850            fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
1851            encoder_paths: vec![None; 10],
1852            decoder_paths: vec![None; 10],
1853            config: CltConfig {
1854                n_layers: 10,
1855                d_model,
1856                n_features_per_layer: 1,
1857                n_features_total: 10,
1858                model_name: "test".to_owned(),
1859            },
1860            loaded_encoder: None,
1861            steering_cache,
1862        };
1863
1864        let hooks = clt
1865            .prepare_hook_injection(&[(fid, target_layer)], 2, 5, 1.0, &device)
1866            .unwrap();
1867
1868        // Should have an intervention at ResidPost(5).
1869        assert!(hooks.has_intervention_at(&HookPoint::ResidPost(target_layer)));
1870        // Should NOT have interventions at other layers.
1871        assert!(!hooks.has_intervention_at(&HookPoint::ResidPost(0)));
1872        assert!(!hooks.has_intervention_at(&HookPoint::ResidPost(4)));
1873    }
1874
1875    // ====================================================================
1876    // Attribution graph — pure type tests
1877    // ====================================================================
1878
1879    #[test]
1880    fn attribution_edge_basics() {
1881        let edge = AttributionEdge {
1882            feature: CltFeatureId {
1883                layer: 3,
1884                index: 42,
1885            },
1886            score: 0.75,
1887        };
1888        assert_eq!(edge.feature.layer, 3);
1889        assert_eq!(edge.feature.index, 42);
1890        assert!((edge.score - 0.75).abs() < f32::EPSILON);
1891    }
1892
1893    #[test]
1894    fn attribution_graph_empty() {
1895        let graph = AttributionGraph {
1896            target_layer: 5,
1897            edges: Vec::new(),
1898        };
1899        assert_eq!(graph.target_layer(), 5);
1900        assert!(graph.is_empty());
1901        assert_eq!(graph.len(), 0);
1902        assert!(graph.features().is_empty());
1903        assert!(graph.into_edges().is_empty());
1904    }
1905
1906    #[test]
1907    fn attribution_graph_top_k() {
1908        let edges = vec![
1909            AttributionEdge {
1910                feature: CltFeatureId { layer: 0, index: 0 },
1911                score: 5.0,
1912            },
1913            AttributionEdge {
1914                feature: CltFeatureId { layer: 0, index: 1 },
1915                score: 3.0,
1916            },
1917            AttributionEdge {
1918                feature: CltFeatureId { layer: 1, index: 0 },
1919                score: 1.0,
1920            },
1921            AttributionEdge {
1922                feature: CltFeatureId { layer: 1, index: 1 },
1923                score: -1.0,
1924            },
1925            AttributionEdge {
1926                feature: CltFeatureId { layer: 2, index: 0 },
1927                score: -4.0,
1928            },
1929        ];
1930        let graph = AttributionGraph {
1931            target_layer: 3,
1932            edges,
1933        };
1934
1935        assert_eq!(graph.len(), 5);
1936
1937        let top3 = graph.top_k(3);
1938        assert_eq!(top3.len(), 3);
1939        assert_eq!(top3.target_layer(), 3);
1940        assert!((top3.edges()[0].score - 5.0).abs() < f32::EPSILON);
1941        assert!((top3.edges()[1].score - 3.0).abs() < f32::EPSILON);
1942        assert!((top3.edges()[2].score - 1.0).abs() < f32::EPSILON);
1943
1944        // top_k larger than graph size returns all edges.
1945        let top10 = graph.top_k(10);
1946        assert_eq!(top10.len(), 5);
1947    }
1948
1949    #[test]
1950    fn attribution_graph_threshold() {
1951        let edges = vec![
1952            AttributionEdge {
1953                feature: CltFeatureId { layer: 0, index: 0 },
1954                score: 5.0,
1955            },
1956            AttributionEdge {
1957                feature: CltFeatureId { layer: 0, index: 1 },
1958                score: 3.0,
1959            },
1960            AttributionEdge {
1961                feature: CltFeatureId { layer: 1, index: 0 },
1962                score: 1.0,
1963            },
1964            AttributionEdge {
1965                feature: CltFeatureId { layer: 1, index: 1 },
1966                score: -1.0,
1967            },
1968            AttributionEdge {
1969                feature: CltFeatureId { layer: 2, index: 0 },
1970                score: -4.0,
1971            },
1972        ];
1973        let graph = AttributionGraph {
1974            target_layer: 3,
1975            edges,
1976        };
1977
1978        // Threshold at 2.0 keeps |score| >= 2.0: 5.0, 3.0, -4.0
1979        let pruned = graph.threshold(2.0);
1980        assert_eq!(pruned.len(), 3);
1981        assert!((pruned.edges()[0].score - 5.0).abs() < f32::EPSILON);
1982        assert!((pruned.edges()[1].score - 3.0).abs() < f32::EPSILON);
1983        assert!((pruned.edges()[2].score - -4.0).abs() < f32::EPSILON);
1984    }
1985
1986    #[test]
1987    fn attribution_graph_features() {
1988        let edges = vec![
1989            AttributionEdge {
1990                feature: CltFeatureId { layer: 2, index: 7 },
1991                score: 1.0,
1992            },
1993            AttributionEdge {
1994                feature: CltFeatureId { layer: 0, index: 3 },
1995                score: 0.5,
1996            },
1997        ];
1998        let graph = AttributionGraph {
1999            target_layer: 5,
2000            edges,
2001        };
2002
2003        let features = graph.features();
2004        assert_eq!(features.len(), 2);
2005        assert_eq!(features[0], CltFeatureId { layer: 2, index: 7 });
2006        assert_eq!(features[1], CltFeatureId { layer: 0, index: 3 });
2007    }
2008
2009    // ====================================================================
2010    // Attribution graph — synthetic decoder file tests
2011    // ====================================================================
2012
2013    /// Create a synthetic decoder safetensors file and return its path.
2014    fn create_synthetic_decoder(
2015        dir: &std::path::Path,
2016        layer: usize,
2017        n_features: usize,
2018        n_target_layers: usize,
2019        d_model: usize,
2020        values: &[f32],
2021    ) -> PathBuf {
2022        assert_eq!(values.len(), n_features * n_target_layers * d_model);
2023        let bytes: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
2024        let name = format!("W_dec_{layer}");
2025        let shape = vec![n_features, n_target_layers, d_model];
2026        let view =
2027            safetensors::tensor::TensorView::new(safetensors::Dtype::F32, shape, &bytes).unwrap();
2028        let mut tensors = HashMap::new();
2029        tensors.insert(name, view);
2030        let serialized = safetensors::serialize(&tensors, &None).unwrap();
2031        let path = dir.join(format!("W_dec_{layer}.safetensors"));
2032        std::fs::write(&path, serialized).unwrap();
2033        path
2034    }
2035
2036    #[test]
2037    fn score_decoder_projection_synthetic() {
2038        // 2 layers, 4 features/layer, d_model=4.
2039        // Layer 0 can decode to layers 0 and 1. Layer 1 can decode to layer 1.
2040        // Target layer = 1.
2041        let dir = tempfile::tempdir().unwrap();
2042        let d_model = 4;
2043        let n_features = 4;
2044
2045        // W_dec_0: [4 features, 2 target_layers, 4 d_model]
2046        // Feature 0, offset 1 (target layer 1): [1, 0, 0, 0]
2047        // Feature 1, offset 1: [0, 1, 0, 0]
2048        // Feature 2, offset 1: [0, 0, 1, 0]
2049        // Feature 3, offset 1: [0, 0, 0, 1]
2050        #[rustfmt::skip]
2051        let dec0_values: Vec<f32> = vec![
2052            // feature 0: offset 0, offset 1
2053            0.0, 0.0, 0.0, 0.0,  1.0, 0.0, 0.0, 0.0,
2054            // feature 1
2055            0.0, 0.0, 0.0, 0.0,  0.0, 1.0, 0.0, 0.0,
2056            // feature 2
2057            0.0, 0.0, 0.0, 0.0,  0.0, 0.0, 1.0, 0.0,
2058            // feature 3
2059            0.0, 0.0, 0.0, 0.0,  0.0, 0.0, 0.0, 1.0,
2060        ];
2061        let path0 = create_synthetic_decoder(dir.path(), 0, n_features, 2, d_model, &dec0_values);
2062
2063        // W_dec_1: [4 features, 1 target_layer, 4 d_model]
2064        // Feature 0, offset 0 (target layer 1): [2, 0, 0, 0]  (strong on dim 0)
2065        // Feature 1: [0, 0, 0, 0]
2066        // Feature 2: [0, 0, 0, 0]
2067        // Feature 3: [0, 3, 0, 0]  (strong on dim 1)
2068        #[rustfmt::skip]
2069        let dec1_values: Vec<f32> = vec![
2070            2.0, 0.0, 0.0, 0.0,
2071            0.0, 0.0, 0.0, 0.0,
2072            0.0, 0.0, 0.0, 0.0,
2073            0.0, 3.0, 0.0, 0.0,
2074        ];
2075        let path1 = create_synthetic_decoder(dir.path(), 1, n_features, 1, d_model, &dec1_values);
2076
2077        let mut clt = CrossLayerTranscoder {
2078            repo_id: "test".to_owned(),
2079            fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
2080            encoder_paths: vec![None; 2],
2081            decoder_paths: vec![Some(path0), Some(path1)],
2082            config: CltConfig {
2083                n_layers: 2,
2084                d_model,
2085                n_features_per_layer: n_features,
2086                n_features_total: n_features * 2,
2087                model_name: "test".to_owned(),
2088            },
2089            loaded_encoder: None,
2090            steering_cache: HashMap::new(),
2091        };
2092
2093        // Direction: [1, 0, 0, 0] — should pick up L0:0 (score=1) and L1:0 (score=2).
2094        let direction =
2095            Tensor::from_vec(vec![1.0_f32, 0.0, 0.0, 0.0], (d_model,), &Device::Cpu).unwrap();
2096
2097        let scores = clt
2098            .score_features_by_decoder_projection(&direction, 1, 10, false)
2099            .unwrap();
2100
2101        // Top scorer should be L1:0 (score=2), then L0:0 (score=1).
2102        assert!(scores.len() >= 2, "expected at least 2 non-zero scores");
2103        assert_eq!(scores[0].0, CltFeatureId { layer: 1, index: 0 });
2104        assert!((scores[0].1 - 2.0).abs() < 1e-5);
2105        assert_eq!(scores[1].0, CltFeatureId { layer: 0, index: 0 });
2106        assert!((scores[1].1 - 1.0).abs() < 1e-5);
2107
2108        // Direction: [0, 1, 0, 0] — should pick up L1:3 (score=3) and L0:1 (score=1).
2109        let direction2 =
2110            Tensor::from_vec(vec![0.0_f32, 1.0, 0.0, 0.0], (d_model,), &Device::Cpu).unwrap();
2111
2112        let scores2 = clt
2113            .score_features_by_decoder_projection(&direction2, 1, 10, false)
2114            .unwrap();
2115
2116        assert_eq!(scores2[0].0, CltFeatureId { layer: 1, index: 3 });
2117        assert!((scores2[0].1 - 3.0).abs() < 1e-5);
2118        assert_eq!(scores2[1].0, CltFeatureId { layer: 0, index: 1 });
2119        assert!((scores2[1].1 - 1.0).abs() < 1e-5);
2120    }
2121
2122    #[test]
2123    fn score_decoder_projection_cosine_synthetic() {
2124        // Single-layer setup: compare cosine scores against raw dot-product scores.
2125        let dir = tempfile::tempdir().unwrap();
2126        let d_model = 4;
2127        let n_features = 2;
2128
2129        // W_dec_0: [2 features, 1 target_layer, 4 d_model]
2130        // Feature 0: [3, 0, 0, 0]  (length 3, aligned with [1,0,0,0])
2131        // Feature 1: [1, 1, 0, 0]  (length sqrt(2), partially aligned)
2132        #[rustfmt::skip]
2133        let dec0_values: Vec<f32> = vec![
2134            3.0, 0.0, 0.0, 0.0,
2135            1.0, 1.0, 0.0, 0.0,
2136        ];
2137        let path0 = create_synthetic_decoder(dir.path(), 0, n_features, 1, d_model, &dec0_values);
2138
2139        let mut clt = CrossLayerTranscoder {
2140            repo_id: "test".to_owned(),
2141            fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
2142            encoder_paths: vec![None],
2143            decoder_paths: vec![Some(path0)],
2144            config: CltConfig {
2145                n_layers: 1,
2146                d_model,
2147                n_features_per_layer: n_features,
2148                n_features_total: n_features,
2149                model_name: "test".to_owned(),
2150            },
2151            loaded_encoder: None,
2152            steering_cache: HashMap::new(),
2153        };
2154
2155        let direction =
2156            Tensor::from_vec(vec![1.0_f32, 0.0, 0.0, 0.0], (d_model,), &Device::Cpu).unwrap();
2157
2158        // Dot product: feature 0 = 3.0, feature 1 = 1.0.
2159        let dot_scores = clt
2160            .score_features_by_decoder_projection(&direction, 0, 10, false)
2161            .unwrap();
2162        assert!((dot_scores[0].1 - 3.0).abs() < 1e-5);
2163        assert!((dot_scores[1].1 - 1.0).abs() < 1e-5);
2164
2165        // Cosine: feature 0 = 1.0 (perfectly aligned), feature 1 = 1/sqrt(2) ≈ 0.707.
2166        let cos_scores = clt
2167            .score_features_by_decoder_projection(&direction, 0, 10, true)
2168            .unwrap();
2169        assert!(
2170            (cos_scores[0].1 - 1.0).abs() < 1e-4,
2171            "expected ~1.0, got {}",
2172            cos_scores[0].1
2173        );
2174        let expected_cos = 1.0 / 2.0_f32.sqrt();
2175        assert!(
2176            (cos_scores[1].1 - expected_cos).abs() < 1e-4,
2177            "expected ~{expected_cos}, got {}",
2178            cos_scores[1].1
2179        );
2180    }
2181
2182    #[test]
2183    fn score_decoder_projection_batch_synthetic() {
2184        let dir = tempfile::tempdir().unwrap();
2185        let d_model = 4;
2186        let n_features = 2;
2187
2188        // W_dec_0: feature 0 = [1,0,0,0], feature 1 = [0,1,0,0]
2189        #[rustfmt::skip]
2190        let dec0_values: Vec<f32> = vec![
2191            1.0, 0.0, 0.0, 0.0,
2192            0.0, 1.0, 0.0, 0.0,
2193        ];
2194        let path0 = create_synthetic_decoder(dir.path(), 0, n_features, 1, d_model, &dec0_values);
2195
2196        let mut clt = CrossLayerTranscoder {
2197            repo_id: "test".to_owned(),
2198            fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
2199            encoder_paths: vec![None],
2200            decoder_paths: vec![Some(path0)],
2201            config: CltConfig {
2202                n_layers: 1,
2203                d_model,
2204                n_features_per_layer: n_features,
2205                n_features_total: n_features,
2206                model_name: "test".to_owned(),
2207            },
2208            loaded_encoder: None,
2209            steering_cache: HashMap::new(),
2210        };
2211
2212        // Two directions: [1,0,0,0] and [0,1,0,0].
2213        let dir0 =
2214            Tensor::from_vec(vec![1.0_f32, 0.0, 0.0, 0.0], (d_model,), &Device::Cpu).unwrap();
2215        let dir1 =
2216            Tensor::from_vec(vec![0.0_f32, 1.0, 0.0, 0.0], (d_model,), &Device::Cpu).unwrap();
2217
2218        let batch = clt
2219            .score_features_by_decoder_projection_batch(&[dir0, dir1], 0, 10, false)
2220            .unwrap();
2221
2222        assert_eq!(batch.len(), 2);
2223
2224        // Direction 0 should score feature 0 highest.
2225        assert_eq!(batch[0][0].0, CltFeatureId { layer: 0, index: 0 });
2226        assert!((batch[0][0].1 - 1.0).abs() < 1e-5);
2227
2228        // Direction 1 should score feature 1 highest.
2229        assert_eq!(batch[1][0].0, CltFeatureId { layer: 0, index: 1 });
2230        assert!((batch[1][0].1 - 1.0).abs() < 1e-5);
2231    }
2232
2233    #[test]
2234    fn extract_decoder_vectors_synthetic() {
2235        let dir = tempfile::tempdir().unwrap();
2236        let d_model = 4;
2237        let n_features = 3;
2238
2239        // W_dec_0: [3 features, 2 target_layers, 4 d_model]
2240        #[rustfmt::skip]
2241        let dec0_values: Vec<f32> = vec![
2242            // feature 0: offset 0, offset 1
2243            1.0, 2.0, 3.0, 4.0,  5.0, 6.0, 7.0, 8.0,
2244            // feature 1
2245            9.0, 10.0, 11.0, 12.0,  13.0, 14.0, 15.0, 16.0,
2246            // feature 2
2247            17.0, 18.0, 19.0, 20.0,  21.0, 22.0, 23.0, 24.0,
2248        ];
2249        let path0 = create_synthetic_decoder(dir.path(), 0, n_features, 2, d_model, &dec0_values);
2250
2251        let mut clt = CrossLayerTranscoder {
2252            repo_id: "test".to_owned(),
2253            fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
2254            encoder_paths: vec![None; 2],
2255            decoder_paths: vec![Some(path0), None],
2256            config: CltConfig {
2257                n_layers: 2,
2258                d_model,
2259                n_features_per_layer: n_features,
2260                n_features_total: n_features * 2,
2261                model_name: "test".to_owned(),
2262            },
2263            loaded_encoder: None,
2264            steering_cache: HashMap::new(),
2265        };
2266
2267        let features = vec![
2268            CltFeatureId { layer: 0, index: 0 },
2269            CltFeatureId { layer: 0, index: 2 },
2270        ];
2271
2272        // Extract at target_layer=1 (offset 1 for source layer 0).
2273        let vectors = clt.extract_decoder_vectors(&features, 1).unwrap();
2274        assert_eq!(vectors.len(), 2);
2275
2276        // Feature 0, offset 1: [5, 6, 7, 8]
2277        let v0: Vec<f32> = vectors[&CltFeatureId { layer: 0, index: 0 }]
2278            .to_vec1()
2279            .unwrap();
2280        assert_eq!(v0, vec![5.0, 6.0, 7.0, 8.0]);
2281
2282        // Feature 2, offset 1: [21, 22, 23, 24]
2283        let v2: Vec<f32> = vectors[&CltFeatureId { layer: 0, index: 2 }]
2284            .to_vec1()
2285            .unwrap();
2286        assert_eq!(v2, vec![21.0, 22.0, 23.0, 24.0]);
2287    }
2288
2289    #[test]
2290    fn build_attribution_graph_synthetic() {
2291        let dir = tempfile::tempdir().unwrap();
2292        let d_model = 4;
2293        let n_features = 2;
2294
2295        #[rustfmt::skip]
2296        let dec0_values: Vec<f32> = vec![
2297            1.0, 0.0, 0.0, 0.0,
2298            0.0, 2.0, 0.0, 0.0,
2299        ];
2300        let path0 = create_synthetic_decoder(dir.path(), 0, n_features, 1, d_model, &dec0_values);
2301
2302        let mut clt = CrossLayerTranscoder {
2303            repo_id: "test".to_owned(),
2304            fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
2305            encoder_paths: vec![None],
2306            decoder_paths: vec![Some(path0)],
2307            config: CltConfig {
2308                n_layers: 1,
2309                d_model,
2310                n_features_per_layer: n_features,
2311                n_features_total: n_features,
2312                model_name: "test".to_owned(),
2313            },
2314            loaded_encoder: None,
2315            steering_cache: HashMap::new(),
2316        };
2317
2318        let direction =
2319            Tensor::from_vec(vec![0.0_f32, 1.0, 0.0, 0.0], (d_model,), &Device::Cpu).unwrap();
2320
2321        let graph = clt
2322            .build_attribution_graph(&direction, 0, 10, false)
2323            .unwrap();
2324
2325        assert_eq!(graph.target_layer(), 0);
2326        assert!(!graph.is_empty());
2327        // Feature 1 has score 2.0, feature 0 has score 0.0.
2328        assert_eq!(
2329            graph.edges()[0].feature,
2330            CltFeatureId { layer: 0, index: 1 }
2331        );
2332        assert!((graph.edges()[0].score - 2.0).abs() < 1e-5);
2333
2334        // Pruning: threshold at 1.0 should keep only feature 1.
2335        let pruned = graph.threshold(1.0);
2336        assert_eq!(pruned.len(), 1);
2337        assert_eq!(pruned.features()[0], CltFeatureId { layer: 0, index: 1 });
2338    }
2339}