1use std::collections::HashMap;
36use std::path::PathBuf;
37
38use candle_core::{DType, Device, IndexOp, Tensor};
39use safetensors::tensor::SafeTensors;
40use tracing::info;
41
42use crate::error::{MIError, Result};
43
/// Identifies one CLT feature by the layer that produced it and its position
/// within that layer.
///
/// The comparison traits are derived, so ids order by `layer` first and by
/// `index` second.
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize,
)]
pub struct CltFeatureId {
    /// Source layer of the feature (the encoder layer it was read from).
    pub layer: usize,
    /// Index of the feature within that layer's activation vector.
    pub index: usize,
}
58
59impl std::fmt::Display for CltFeatureId {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 write!(f, "L{}:{}", self.layer, self.index)
62 }
63}
64
65use crate::sparse::{FeatureId, SparseActivations};
66
// Marker impl: lets `CltFeatureId` be used as the id type of `SparseActivations`.
impl FeatureId for CltFeatureId {}
68
/// One scored feature inside an [`AttributionGraph`].
#[derive(Debug, Clone)]
pub struct AttributionEdge {
    /// The feature this edge refers to.
    pub feature: CltFeatureId,
    /// Score assigned to the feature by the decoder-projection scoring
    /// (a dot product, or a cosine similarity when cosine mode is requested).
    pub score: f32,
}
81
/// A list of scored features for a single target layer.
///
/// Graphs built by `build_attribution_graph*` hold their edges in descending
/// score order, which is what `top_k` relies on.
#[derive(Debug, Clone)]
pub struct AttributionGraph {
    /// Layer the scores were computed against.
    target_layer: usize,
    /// Scored features, in the order they were produced.
    edges: Vec<AttributionEdge>,
}
102
103impl AttributionGraph {
104 #[must_use]
106 pub const fn target_layer(&self) -> usize {
107 self.target_layer
108 }
109
110 #[must_use]
112 pub fn edges(&self) -> &[AttributionEdge] {
113 &self.edges
114 }
115
116 #[must_use]
118 pub const fn len(&self) -> usize {
119 self.edges.len()
120 }
121
122 #[must_use]
124 pub const fn is_empty(&self) -> bool {
125 self.edges.is_empty()
126 }
127
128 #[must_use]
130 pub fn top_k(&self, k: usize) -> Self {
131 Self {
132 target_layer: self.target_layer,
133 edges: self.edges.iter().take(k).cloned().collect(),
134 }
135 }
136
137 #[must_use]
140 pub fn threshold(&self, min_score: f32) -> Self {
141 Self {
142 target_layer: self.target_layer,
143 edges: self
144 .edges
145 .iter()
146 .filter(|e| e.score.abs() >= min_score)
147 .cloned()
148 .collect(),
149 }
150 }
151
152 #[must_use]
154 pub fn features(&self) -> Vec<CltFeatureId> {
155 self.edges.iter().map(|e| e.feature).collect()
156 }
157
158 #[must_use]
160 pub fn into_edges(self) -> Vec<AttributionEdge> {
161 self.edges
162 }
163}
164
/// Dimensions and metadata of a CLT, as discovered by `CrossLayerTranscoder::open`.
#[derive(Debug, Clone)]
pub struct CltConfig {
    /// Number of layers; `open` sets this to the count of `W_enc_*.safetensors`
    /// files in the repository.
    pub n_layers: usize,
    /// Model width; `open` reads it from the second dimension of `W_enc_0`.
    pub d_model: usize,
    /// Features per layer; `open` reads it from the first dimension of `W_enc_0`.
    pub n_features_per_layer: usize,
    /// `n_layers * n_features_per_layer`.
    pub n_features_total: usize,
    /// Value of `model_name` in the repository's `config.yaml`, or `"unknown"`
    /// when that file or key is unavailable.
    pub model_name: String,
}
179
/// Encoder weights for the single layer currently resident on a device.
struct LoadedEncoder {
    /// Layer these weights belong to.
    layer: usize,
    /// Tensor `W_enc_{layer}`; `open` requires `W_enc_0` to be 2-D
    /// `[n_features_per_layer, d_model]` — presumably every layer shares that shape.
    w_enc: Tensor,
    /// Tensor `b_enc_{layer}`, added to the encoder pre-activations in `encode`.
    b_enc: Tensor,
}
199
/// Handle on a cross-layer transcoder stored as per-layer safetensors files in
/// a remote repository.
///
/// Weight files are downloaded lazily; at most one encoder is held in memory
/// at a time, and individual decoder vectors can be cached for steering.
pub struct CrossLayerTranscoder {
    /// Repository id the weight files are downloaded from.
    repo_id: String,
    /// Download configuration reused for every file fetch.
    fetch_config: hf_fetch_model::FetchConfig,
    /// Per-layer local path of `W_enc_{layer}.safetensors`; `None` until downloaded.
    encoder_paths: Vec<Option<PathBuf>>,
    /// Per-layer local path of `W_dec_{layer}.safetensors`; `None` until downloaded.
    decoder_paths: Vec<Option<PathBuf>>,
    /// Dimensions discovered when the transcoder was opened.
    config: CltConfig,
    /// Encoder currently loaded by `load_encoder`, if any.
    loaded_encoder: Option<LoadedEncoder>,
    /// Decoder vectors keyed by `(feature, target_layer)`, filled by the
    /// `cache_steering_vectors*` methods.
    steering_cache: HashMap<(CltFeatureId, usize), Tensor>,
}
248
249impl CrossLayerTranscoder {
250 pub fn open(clt_repo: &str) -> Result<Self> {
264 let fetch_config = hf_fetch_model::FetchConfig::builder()
265 .on_progress(|event| {
266 tracing::info!(
267 filename = %event.filename,
268 percent = event.percent,
269 bytes_downloaded = event.bytes_downloaded,
270 bytes_total = event.bytes_total,
271 "CLT download progress",
272 );
273 })
274 .build()
275 .map_err(|e| MIError::Download(format!("failed to build fetch config: {e}")))?;
276
277 let rt = tokio::runtime::Runtime::new()
279 .map_err(|e| MIError::Download(format!("failed to create tokio runtime: {e}")))?;
280 let repo_files = rt
281 .block_on(hf_fetch_model::repo::list_repo_files_with_metadata(
282 clt_repo, None, None,
283 ))
284 .map_err(|e| MIError::Download(format!("failed to list repo files: {e}")))?;
285 let n_layers = repo_files
286 .iter()
287 .filter(|f| f.filename.starts_with("W_enc_") && f.filename.ends_with(".safetensors"))
288 .count();
289 if n_layers == 0 {
290 return Err(MIError::Config(format!(
291 "no CLT encoder files found in {clt_repo}"
292 )));
293 }
294
295 let model_name = match hf_fetch_model::download_file_blocking(
297 clt_repo.to_owned(),
298 "config.yaml",
299 &fetch_config,
300 ) {
301 Ok(outcome) => {
302 let path = outcome.into_inner();
303 let text = std::fs::read_to_string(&path)?;
304 parse_yaml_value(&text, "model_name").unwrap_or_else(|| "unknown".to_owned())
305 }
306 Err(_) => "unknown".to_owned(),
307 };
308
309 let enc0_path = hf_fetch_model::download_file_blocking(
311 clt_repo.to_owned(),
312 "W_enc_0.safetensors",
313 &fetch_config,
314 )
315 .map_err(|e| MIError::Download(format!("failed to download W_enc_0: {e}")))?
316 .into_inner();
317
318 let data = std::fs::read(&enc0_path)?;
319 let tensors = SafeTensors::deserialize(&data)
320 .map_err(|e| MIError::Config(format!("failed to deserialize W_enc_0: {e}")))?;
321 let w_enc_view = tensors
322 .tensor("W_enc_0")
323 .map_err(|e| MIError::Config(format!("tensor 'W_enc_0' not found: {e}")))?;
324 let shape = w_enc_view.shape();
325 if shape.len() != 2 {
326 return Err(MIError::Config(format!(
327 "expected 2D encoder weight, got shape {shape:?}"
328 )));
329 }
330 let n_features_per_layer = *shape
331 .first()
332 .ok_or_else(|| MIError::Config("encoder weight shape is empty".into()))?;
333 let d_model = *shape.get(1).ok_or_else(|| {
334 MIError::Config("encoder weight shape has fewer than 2 dimensions".into())
335 })?;
336
337 let mut encoder_paths: Vec<Option<PathBuf>> = vec![None; n_layers];
339 if let Some(slot) = encoder_paths.first_mut() {
340 *slot = Some(enc0_path);
341 }
342 let decoder_paths: Vec<Option<PathBuf>> = vec![None; n_layers];
343
344 let config = CltConfig {
345 n_layers,
346 d_model,
347 n_features_per_layer,
348 n_features_total: n_layers * n_features_per_layer,
349 model_name,
350 };
351 info!(
352 "CLT config: {} layers, d_model={}, features_per_layer={}, total={}",
353 config.n_layers, config.d_model, config.n_features_per_layer, config.n_features_total
354 );
355
356 Ok(Self {
357 repo_id: clt_repo.to_owned(),
358 fetch_config,
359 encoder_paths,
360 decoder_paths,
361 config,
362 loaded_encoder: None,
363 steering_cache: HashMap::new(),
364 })
365 }
366
367 #[must_use]
369 pub const fn config(&self) -> &CltConfig {
370 &self.config
371 }
372
373 #[must_use]
375 pub fn loaded_encoder_layer(&self) -> Option<usize> {
376 self.loaded_encoder.as_ref().map(|e| e.layer)
377 }
378
379 fn ensure_encoder_path(&mut self, layer: usize) -> Result<PathBuf> {
383 if let Some(path) = self
384 .encoder_paths
385 .get(layer)
386 .and_then(std::option::Option::as_ref)
387 {
388 return Ok(path.clone());
390 }
391 let filename = format!("W_enc_{layer}.safetensors");
392 info!("Downloading {filename} from {}", self.repo_id);
393 let path = hf_fetch_model::download_file_blocking(
394 self.repo_id.clone(),
395 &filename,
396 &self.fetch_config,
397 )
398 .map_err(|e| MIError::Download(format!("failed to download {filename}: {e}")))?
399 .into_inner();
400 if let Some(slot) = self.encoder_paths.get_mut(layer) {
401 *slot = Some(path.clone());
403 }
404 Ok(path)
405 }
406
407 fn ensure_decoder_path(&mut self, layer: usize) -> Result<PathBuf> {
409 if let Some(path) = self
410 .decoder_paths
411 .get(layer)
412 .and_then(std::option::Option::as_ref)
413 {
414 return Ok(path.clone());
416 }
417 let filename = format!("W_dec_{layer}.safetensors");
418 info!("Downloading {filename} from {}", self.repo_id);
419 let path = hf_fetch_model::download_file_blocking(
420 self.repo_id.clone(),
421 &filename,
422 &self.fetch_config,
423 )
424 .map_err(|e| MIError::Download(format!("failed to download {filename}: {e}")))?
425 .into_inner();
426 if let Some(slot) = self.decoder_paths.get_mut(layer) {
427 *slot = Some(path.clone());
429 }
430 Ok(path)
431 }
432
433 pub fn load_encoder(&mut self, layer: usize, device: &Device) -> Result<()> {
450 if layer >= self.config.n_layers {
451 return Err(MIError::Config(format!(
452 "layer {layer} out of range (CLT has {} layers)",
453 self.config.n_layers
454 )));
455 }
456
457 if let Some(ref enc) = self.loaded_encoder {
459 if enc.layer == layer {
460 return Ok(());
461 }
462 }
463
464 self.loaded_encoder = None;
466
467 info!("Loading CLT encoder for layer {layer}");
468
469 let enc_path = self.ensure_encoder_path(layer)?;
470 let data = std::fs::read(&enc_path)?;
471 let st = SafeTensors::deserialize(&data).map_err(|e| {
472 MIError::Config(format!("failed to deserialize encoder layer {layer}: {e}"))
473 })?;
474
475 let w_enc_name = format!("W_enc_{layer}");
476 let b_enc_name = format!("b_enc_{layer}");
477
478 let w_enc = tensor_from_view(
479 &st.tensor(&w_enc_name)
480 .map_err(|e| MIError::Config(format!("tensor '{w_enc_name}' not found: {e}")))?,
481 device,
482 )?;
483 let b_enc = tensor_from_view(
484 &st.tensor(&b_enc_name)
485 .map_err(|e| MIError::Config(format!("tensor '{b_enc_name}' not found: {e}")))?,
486 device,
487 )?;
488
489 self.loaded_encoder = Some(LoadedEncoder {
490 layer,
491 w_enc,
492 b_enc,
493 });
494
495 Ok(())
496 }
497
498 pub fn encode(
520 &self,
521 residual: &Tensor,
522 layer: usize,
523 ) -> Result<SparseActivations<CltFeatureId>> {
524 let enc = self.loaded_encoder.as_ref().ok_or_else(|| {
525 MIError::Hook(format!(
526 "no encoder loaded — call load_encoder({layer}) first"
527 ))
528 })?;
529 if enc.layer != layer {
530 return Err(MIError::Hook(format!(
531 "loaded encoder is for layer {}, but layer {layer} was requested",
532 enc.layer
533 )));
534 }
535
536 let residual_f32 = residual.flatten_all()?;
540 let residual_f32 = residual_f32.to_dtype(DType::F32)?;
542 let w_enc_f32 = enc.w_enc.to_dtype(DType::F32)?;
543 let b_enc_f32 = enc.b_enc.to_dtype(DType::F32)?;
544
545 let pre_acts = w_enc_f32.matmul(&residual_f32.unsqueeze(1)?)?.squeeze(1)?;
546 let pre_acts = (&pre_acts + &b_enc_f32)?;
547
548 let acts = pre_acts.relu()?;
550
551 let acts_vec: Vec<f32> = acts.to_vec1()?;
553
554 let mut features: Vec<(CltFeatureId, f32)> = acts_vec
555 .iter()
556 .enumerate()
557 .filter(|&(_, v)| *v > 0.0)
558 .map(|(i, v)| (CltFeatureId { layer, index: i }, *v))
559 .collect();
560
561 features.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
563
564 Ok(SparseActivations { features })
565 }
566
567 pub fn top_k(
580 &self,
581 residual: &Tensor,
582 layer: usize,
583 k: usize,
584 ) -> Result<SparseActivations<CltFeatureId>> {
585 let mut sparse = self.encode(residual, layer)?;
586 sparse.truncate(k);
587 Ok(sparse)
588 }
589
590 pub fn decoder_vector(
611 &mut self,
612 feature: &CltFeatureId,
613 target_layer: usize,
614 device: &Device,
615 ) -> Result<Tensor> {
616 if feature.layer >= self.config.n_layers {
617 return Err(MIError::Config(format!(
618 "feature source layer {} out of range (CLT has {} layers)",
619 feature.layer, self.config.n_layers
620 )));
621 }
622 if target_layer < feature.layer || target_layer >= self.config.n_layers {
623 return Err(MIError::Config(format!(
624 "target layer {target_layer} must be >= source layer {} and < {}",
625 feature.layer, self.config.n_layers
626 )));
627 }
628 if feature.index >= self.config.n_features_per_layer {
629 return Err(MIError::Config(format!(
630 "feature index {} out of range (max {})",
631 feature.index, self.config.n_features_per_layer
632 )));
633 }
634
635 let cache_key = (*feature, target_layer);
637 if let Some(cached) = self.steering_cache.get(&cache_key) {
638 return Ok(cached.clone());
639 }
640
641 let target_offset = target_layer - feature.layer;
644
645 let dec_path = self.ensure_decoder_path(feature.layer)?;
646 let data = std::fs::read(&dec_path)?;
647 let st = SafeTensors::deserialize(&data).map_err(|e| {
648 MIError::Config(format!(
649 "failed to deserialize decoder layer {}: {e}",
650 feature.layer
651 ))
652 })?;
653
654 let dec_name = format!("W_dec_{}", feature.layer);
655 let w_dec = tensor_from_view(
656 &st.tensor(&dec_name)
657 .map_err(|e| MIError::Config(format!("tensor '{dec_name}' not found: {e}")))?,
658 &Device::Cpu,
659 )?;
660
661 let column = w_dec.i((feature.index, target_offset))?;
663
664 let column = column.to_device(device)?;
666
667 Ok(column)
668 }
669
670 pub fn cache_steering_vectors(
688 &mut self,
689 features: &[(CltFeatureId, usize)],
690 device: &Device,
691 ) -> Result<()> {
692 let mut by_source: HashMap<usize, Vec<(usize, usize)>> = HashMap::new();
694 for (fid, target_layer) in features {
695 by_source
696 .entry(fid.layer)
697 .or_default()
698 .push((fid.index, *target_layer));
699 }
700
701 let mut loaded = 0_usize;
702 let n_source_layers = by_source.len();
703 for (layer_idx, (source_layer, entries)) in by_source.iter().enumerate() {
704 info!(
705 "cache_steering_vectors: loading decoder for source layer {} ({}/{})",
706 source_layer,
707 layer_idx + 1,
708 n_source_layers
709 );
710
711 let mut by_target: HashMap<usize, Vec<usize>> = HashMap::new();
713 for &(index, target_layer) in entries {
714 by_target.entry(target_layer).or_default().push(index);
715 }
716
717 let mut cpu_columns: Vec<(CltFeatureId, usize, Tensor)> = Vec::new();
721 {
722 let dec_path = self.ensure_decoder_path(*source_layer)?;
723 let data = std::fs::read(&dec_path)?;
724 info!(
725 "cache_steering_vectors: loaded {} MB for layer {}",
726 data.len() / (1024 * 1024),
727 source_layer
728 );
729 let st = SafeTensors::deserialize(&data).map_err(|e| {
730 MIError::Config(format!(
731 "failed to deserialize decoder layer {source_layer}: {e}"
732 ))
733 })?;
734 let dec_name = format!("W_dec_{source_layer}");
735 let w_dec = tensor_from_view(
736 &st.tensor(&dec_name).map_err(|e| {
737 MIError::Config(format!("tensor '{dec_name}' not found: {e}"))
738 })?,
739 &Device::Cpu,
740 )?;
741
742 for (target_layer, indices) in &by_target {
743 let target_offset = target_layer - source_layer;
744 for &index in indices {
745 let fid = CltFeatureId {
746 layer: *source_layer,
747 index,
748 };
749 let cache_key = (fid, *target_layer);
750 if !self.steering_cache.contains_key(&cache_key) {
751 let view = w_dec.i((index, target_offset))?;
755 let dims = view.dims().to_vec();
756 let values = view.to_dtype(DType::F32)?.to_vec1::<f32>()?;
758 let independent =
759 Tensor::from_vec(values, dims.as_slice(), &Device::Cpu)?;
760 cpu_columns.push((fid, *target_layer, independent));
761 }
762 }
763 }
764 }
766
767 for (fid, target_layer, cpu_tensor) in cpu_columns {
769 let cache_key = (fid, target_layer);
770 if let std::collections::hash_map::Entry::Vacant(e) =
771 self.steering_cache.entry(cache_key)
772 {
773 let device_tensor = cpu_tensor.to_device(device)?;
774 e.insert(device_tensor);
775 loaded += 1;
776 }
777 }
778 }
779
780 info!(
781 "Cached {loaded} new steering vectors ({} total in cache)",
782 self.steering_cache.len()
783 );
784 Ok(())
785 }
786
787 pub fn cache_steering_vectors_all_downstream(
806 &mut self,
807 features: &[CltFeatureId],
808 device: &Device,
809 ) -> Result<()> {
810 let n_layers = self.config.n_layers;
811
812 let mut by_source: HashMap<usize, Vec<usize>> = HashMap::new();
814 for fid in features {
815 if fid.layer >= n_layers {
816 return Err(MIError::Config(format!(
817 "feature source layer {} out of range (max {})",
818 fid.layer,
819 n_layers - 1
820 )));
821 }
822 by_source.entry(fid.layer).or_default().push(fid.index);
823 }
824
825 let mut loaded = 0_usize;
826 let n_source_layers = by_source.len();
827 for (layer_idx, (source_layer, indices)) in by_source.iter().enumerate() {
828 let n_target_layers = n_layers - source_layer;
829 info!(
830 "cache_steering_vectors_all_downstream: loading decoder for source layer {} \
831 ({}/{}, {} downstream layers)",
832 source_layer,
833 layer_idx + 1,
834 n_source_layers,
835 n_target_layers
836 );
837
838 let mut cpu_columns: Vec<(CltFeatureId, usize, Tensor)> = Vec::new();
840 {
841 let dec_path = self.ensure_decoder_path(*source_layer)?;
842 let data = std::fs::read(&dec_path)?;
843 info!(
844 "cache_steering_vectors_all_downstream: loaded {} MB for layer {}",
845 data.len() / (1024 * 1024),
846 source_layer
847 );
848 let st = SafeTensors::deserialize(&data).map_err(|e| {
849 MIError::Config(format!(
850 "failed to deserialize decoder layer {source_layer}: {e}"
851 ))
852 })?;
853 let dec_name = format!("W_dec_{source_layer}");
854 let w_dec = tensor_from_view(
855 &st.tensor(&dec_name).map_err(|e| {
856 MIError::Config(format!("tensor '{dec_name}' not found: {e}"))
857 })?,
858 &Device::Cpu,
859 )?;
860
861 for &index in indices {
862 let fid = CltFeatureId {
863 layer: *source_layer,
864 index,
865 };
866 for target_offset in 0..n_target_layers {
867 let target_layer = source_layer + target_offset;
868 let cache_key = (fid, target_layer);
869 if !self.steering_cache.contains_key(&cache_key) {
870 let view = w_dec.i((index, target_offset))?;
871 let dims = view.dims().to_vec();
872 let values = view.to_dtype(DType::F32)?.to_vec1::<f32>()?;
874 let independent =
875 Tensor::from_vec(values, dims.as_slice(), &Device::Cpu)?;
876 cpu_columns.push((fid, target_layer, independent));
877 }
878 }
879 }
880 }
882
883 for (fid, target_layer, cpu_tensor) in cpu_columns {
885 let cache_key = (fid, target_layer);
886 if let std::collections::hash_map::Entry::Vacant(e) =
887 self.steering_cache.entry(cache_key)
888 {
889 let device_tensor = cpu_tensor.to_device(device)?;
890 e.insert(device_tensor);
891 loaded += 1;
892 }
893 }
894 }
895
896 info!(
897 "Cached {loaded} new steering vectors across all downstream layers ({} total in cache)",
898 self.steering_cache.len()
899 );
900 Ok(())
901 }
902
903 pub fn clear_steering_cache(&mut self) {
905 let count = self.steering_cache.len();
906 self.steering_cache.clear();
907 if count > 0 {
908 info!("Cleared {count} steering vectors from cache");
909 }
910 }
911
912 #[must_use]
914 pub fn steering_cache_len(&self) -> usize {
915 self.steering_cache.len()
916 }
917
918 pub fn prepare_hook_injection(
943 &self,
944 features: &[(CltFeatureId, usize)],
945 position: usize,
946 seq_len: usize,
947 strength: f32,
948 device: &Device,
949 ) -> Result<crate::hooks::HookSpec> {
950 use crate::hooks::{HookPoint, HookSpec, Intervention};
951
952 let mut per_layer: HashMap<usize, Tensor> = HashMap::new();
954 for (feature, target_layer) in features {
955 let cache_key = (*feature, *target_layer);
956 let cached = self.steering_cache.get(&cache_key).ok_or_else(|| {
957 MIError::Hook(format!(
958 "feature {feature} for target layer {target_layer} not in steering cache \
959 — call cache_steering_vectors() first"
960 ))
961 })?;
962 let cached_f32 = cached.to_dtype(DType::F32)?;
964 if let Some(acc) = per_layer.get_mut(target_layer) {
965 let acc_ref: &Tensor = acc;
966 *acc = (acc_ref + &cached_f32)?;
967 } else {
968 per_layer.insert(*target_layer, cached_f32);
969 }
970 }
971
972 let mut hooks = HookSpec::new();
974 let d_model = self.config.d_model;
975
976 for (target_layer, accumulated) in &per_layer {
977 let scaled = (accumulated * f64::from(strength))?;
979
980 let mut injection = Tensor::zeros((1, seq_len, d_model), DType::F32, device)?;
982
983 let scaled_3d = scaled.unsqueeze(0)?.unsqueeze(0)?; let before = if position > 0 {
986 Some(injection.narrow(1, 0, position)?)
987 } else {
988 None
989 };
990 let after = if position + 1 < seq_len {
991 Some(injection.narrow(1, position + 1, seq_len - position - 1)?)
992 } else {
993 None
994 };
995
996 let mut parts: Vec<Tensor> = Vec::with_capacity(3);
997 if let Some(b) = before {
998 parts.push(b);
999 }
1000 parts.push(scaled_3d);
1001 if let Some(a) = after {
1002 parts.push(a);
1003 }
1004
1005 injection = Tensor::cat(&parts, 1)?;
1006
1007 hooks.intervene(
1008 HookPoint::ResidPost(*target_layer),
1009 Intervention::Add(injection),
1010 );
1011 }
1012
1013 Ok(hooks)
1014 }
1015
1016 pub fn inject(
1038 &self,
1039 residual: &Tensor,
1040 features: &[(CltFeatureId, usize)],
1041 position: usize,
1042 strength: f32,
1043 ) -> Result<Tensor> {
1044 let (batch, seq_len, d_model) = residual.dims3()?;
1045 if position >= seq_len {
1046 return Err(MIError::Config(format!(
1047 "injection position {position} out of range (seq_len={seq_len})"
1048 )));
1049 }
1050 if d_model != self.config.d_model {
1051 return Err(MIError::Config(format!(
1052 "residual d_model={d_model} doesn't match CLT d_model={}",
1053 self.config.d_model
1054 )));
1055 }
1056
1057 let mut accumulated = Tensor::zeros((d_model,), DType::F32, residual.device())?;
1059 for (feature, target_layer) in features {
1060 let cache_key = (*feature, *target_layer);
1061 let cached = self.steering_cache.get(&cache_key).ok_or_else(|| {
1062 MIError::Hook(format!(
1063 "feature {feature} for target layer {target_layer} not in steering cache"
1064 ))
1065 })?;
1066 let cached_f32 = cached.to_dtype(DType::F32)?;
1068 accumulated = (&accumulated + &cached_f32)?;
1069 }
1070
1071 let accumulated = (accumulated * f64::from(strength))?;
1073
1074 let accumulated = accumulated.to_dtype(residual.dtype())?;
1076
1077 let pos_slice = residual.narrow(1, position, 1)?; let steering_expanded = accumulated
1080 .unsqueeze(0)?
1081 .unsqueeze(0)?
1082 .expand((batch, 1, d_model))?; let pos_updated = (&pos_slice + &steering_expanded)?;
1084
1085 let mut parts: Vec<Tensor> = Vec::with_capacity(3);
1087 if position > 0 {
1088 parts.push(residual.narrow(1, 0, position)?);
1089 }
1090 parts.push(pos_updated);
1091 if position + 1 < seq_len {
1092 parts.push(residual.narrow(1, position + 1, seq_len - position - 1)?);
1093 }
1094
1095 let result = Tensor::cat(&parts, 1)?;
1096 Ok(result)
1097 }
1098
1099 pub fn score_features_by_decoder_projection(
1133 &mut self,
1134 direction: &Tensor,
1135 target_layer: usize,
1136 top_k: usize,
1137 cosine: bool,
1138 ) -> Result<Vec<(CltFeatureId, f32)>> {
1139 let d_model = self.config.d_model;
1140 if direction.dims() != [d_model] {
1141 return Err(MIError::Config(format!(
1142 "direction must have shape [{d_model}], got {:?}",
1143 direction.dims()
1144 )));
1145 }
1146 if target_layer >= self.config.n_layers {
1147 return Err(MIError::Config(format!(
1148 "target layer {target_layer} out of range (max {})",
1149 self.config.n_layers - 1
1150 )));
1151 }
1152
1153 let direction_f32 = direction.to_dtype(DType::F32)?.to_device(&Device::Cpu)?;
1155
1156 let direction_norm = if cosine {
1158 let norm: f32 = direction_f32.sqr()?.sum_all()?.sqrt()?.to_scalar()?;
1159 if norm > 1e-10 {
1160 direction_f32.broadcast_div(&Tensor::new(norm, &Device::Cpu)?)?
1161 } else {
1162 direction_f32
1163 }
1164 } else {
1165 direction_f32
1166 };
1167
1168 let mut all_scores: Vec<(CltFeatureId, f32)> = Vec::new();
1169
1170 for source_layer in 0..self.config.n_layers {
1171 if target_layer < source_layer {
1172 continue; }
1174 let target_offset = target_layer - source_layer;
1175
1176 let dec_path = self.ensure_decoder_path(source_layer)?;
1178 let data = std::fs::read(&dec_path)?;
1179 info!(
1180 "score_features_by_decoder_projection: loaded {} MB for layer {}",
1181 data.len() / (1024 * 1024),
1182 source_layer
1183 );
1184 let st = SafeTensors::deserialize(&data).map_err(|e| {
1185 MIError::Config(format!(
1186 "failed to deserialize decoder layer {source_layer}: {e}"
1187 ))
1188 })?;
1189
1190 let dec_name = format!("W_dec_{source_layer}");
1191 let w_dec = tensor_from_view(
1192 &st.tensor(&dec_name)
1193 .map_err(|e| MIError::Config(format!("tensor '{dec_name}' not found: {e}")))?,
1194 &Device::Cpu,
1195 )?;
1196 let w_dec_f32 = w_dec.to_dtype(DType::F32)?;
1198
1199 let dec_slice = w_dec_f32.i((.., target_offset, ..))?;
1201
1202 let raw_scores = dec_slice
1204 .matmul(&direction_norm.unsqueeze(1)?)?
1205 .squeeze(1)?;
1206
1207 let scores_vec: Vec<f32> = if cosine {
1208 let dec_norms = dec_slice.sqr()?.sum(1)?.sqrt()?;
1210 let cosine_scores = raw_scores.broadcast_div(&dec_norms)?;
1211 cosine_scores.to_vec1()?
1212 } else {
1213 raw_scores.to_vec1()?
1214 };
1215
1216 for (idx, &score) in scores_vec.iter().enumerate() {
1217 if score.is_finite() {
1218 all_scores.push((
1219 CltFeatureId {
1220 layer: source_layer,
1221 index: idx,
1222 },
1223 score,
1224 ));
1225 }
1226 }
1227
1228 info!(
1229 "Scored {} features at source layer {source_layer} (target layer {target_layer})",
1230 scores_vec.len()
1231 );
1232 }
1233
1234 all_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1236 all_scores.truncate(top_k);
1237
1238 Ok(all_scores)
1239 }
1240
1241 pub fn score_features_by_decoder_projection_batch(
1269 &mut self,
1270 directions: &[Tensor],
1271 target_layer: usize,
1272 top_k: usize,
1273 cosine: bool,
1274 ) -> Result<Vec<Vec<(CltFeatureId, f32)>>> {
1275 let d_model = self.config.d_model;
1276 let n_words = directions.len();
1277 if n_words == 0 {
1278 return Err(MIError::Config(
1279 "at least one direction vector required".into(),
1280 ));
1281 }
1282 for (i, dir) in directions.iter().enumerate() {
1283 if dir.dims() != [d_model] {
1284 return Err(MIError::Config(format!(
1285 "direction vector {i} must have shape [{d_model}], got {:?}",
1286 dir.dims()
1287 )));
1288 }
1289 }
1290 if target_layer >= self.config.n_layers {
1291 return Err(MIError::Config(format!(
1292 "target layer {target_layer} out of range (max {})",
1293 self.config.n_layers - 1
1294 )));
1295 }
1296
1297 let dirs_f32: Vec<Tensor> = directions
1299 .iter()
1300 .map(|d| d.to_dtype(DType::F32)?.to_device(&Device::Cpu))
1301 .collect::<std::result::Result<_, _>>()?;
1302 let stacked = Tensor::stack(&dirs_f32, 0)?; let stacked_norm = if cosine {
1306 let norms = stacked.sqr()?.sum(1)?.sqrt()?; let ones = Tensor::ones_like(&norms)?;
1308 let safe_norms = norms.maximum(&(&ones * 1e-10f64)?)?; stacked.broadcast_div(&safe_norms.unsqueeze(1)?)?
1310 } else {
1311 stacked
1312 };
1313 let directions_t = stacked_norm.t()?; let mut all_scores: Vec<Vec<(CltFeatureId, f32)>> =
1317 (0..n_words).map(|_| Vec::new()).collect();
1318
1319 for source_layer in 0..self.config.n_layers {
1320 if target_layer < source_layer {
1321 continue;
1322 }
1323 let target_offset = target_layer - source_layer;
1324
1325 let dec_path = self.ensure_decoder_path(source_layer)?;
1327 let data = std::fs::read(&dec_path)?;
1328 info!(
1329 "score_features_batch: loaded {} MB for layer {}",
1330 data.len() / (1024 * 1024),
1331 source_layer
1332 );
1333 let st = SafeTensors::deserialize(&data).map_err(|e| {
1334 MIError::Config(format!(
1335 "failed to deserialize decoder layer {source_layer}: {e}"
1336 ))
1337 })?;
1338 let dec_name = format!("W_dec_{source_layer}");
1339 let w_dec = tensor_from_view(
1340 &st.tensor(&dec_name)
1341 .map_err(|e| MIError::Config(format!("tensor '{dec_name}' not found: {e}")))?,
1342 &Device::Cpu,
1343 )?;
1344 let w_dec_f32 = w_dec.to_dtype(DType::F32)?;
1346 let dec_slice = w_dec_f32.i((.., target_offset, ..))?; let raw_scores = dec_slice.matmul(&directions_t)?;
1350
1351 let scores_2d: Vec<Vec<f32>> = if cosine {
1353 let dec_norms = dec_slice.sqr()?.sum(1)?.sqrt()?; let cosine_scores = raw_scores.broadcast_div(&dec_norms.unsqueeze(1)?)?;
1355 cosine_scores.t()?.to_vec2()?
1356 } else {
1357 raw_scores.t()?.to_vec2()?
1358 };
1359
1360 for (w, word_scores) in scores_2d.iter().enumerate() {
1361 for (idx, &score) in word_scores.iter().enumerate() {
1362 if score.is_finite() {
1363 if let Some(word_vec) = all_scores.get_mut(w) {
1364 word_vec.push((
1365 CltFeatureId {
1366 layer: source_layer,
1367 index: idx,
1368 },
1369 score,
1370 ));
1371 }
1372 }
1373 }
1374 }
1375
1376 info!(
1377 "Batch scored {} words × {} features at source layer {} (target layer {})",
1378 n_words,
1379 scores_2d.first().map_or(0, Vec::len),
1380 source_layer,
1381 target_layer
1382 );
1383 }
1384
1385 for word_scores in &mut all_scores {
1387 word_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1388 word_scores.truncate(top_k);
1389 }
1390
1391 Ok(all_scores)
1392 }
1393
1394 pub fn extract_decoder_vectors(
1420 &mut self,
1421 features: &[CltFeatureId],
1422 target_layer: usize,
1423 ) -> Result<HashMap<CltFeatureId, Tensor>> {
1424 if target_layer >= self.config.n_layers {
1425 return Err(MIError::Config(format!(
1426 "target layer {target_layer} out of range (max {})",
1427 self.config.n_layers - 1
1428 )));
1429 }
1430
1431 let mut by_source: HashMap<usize, Vec<usize>> = HashMap::new();
1433 for fid in features {
1434 if fid.layer >= self.config.n_layers {
1435 return Err(MIError::Config(format!(
1436 "feature source layer {} out of range (max {})",
1437 fid.layer,
1438 self.config.n_layers - 1
1439 )));
1440 }
1441 if target_layer < fid.layer {
1442 return Err(MIError::Config(format!(
1443 "target layer {target_layer} must be >= source layer {}",
1444 fid.layer
1445 )));
1446 }
1447 by_source.entry(fid.layer).or_default().push(fid.index);
1448 }
1449
1450 let mut result: HashMap<CltFeatureId, Tensor> = HashMap::new();
1451 let n_source_layers = by_source.len();
1452
1453 for (layer_idx, (source_layer, indices)) in by_source.iter().enumerate() {
1454 info!(
1455 "extract_decoder_vectors: loading decoder for source layer {} ({}/{})",
1456 source_layer,
1457 layer_idx + 1,
1458 n_source_layers
1459 );
1460 let target_offset = target_layer - source_layer;
1461
1462 let dec_path = self.ensure_decoder_path(*source_layer)?;
1464 let data = std::fs::read(&dec_path)?;
1465 let st = SafeTensors::deserialize(&data).map_err(|e| {
1466 MIError::Config(format!(
1467 "failed to deserialize decoder layer {source_layer}: {e}"
1468 ))
1469 })?;
1470 let dec_name = format!("W_dec_{source_layer}");
1471 let w_dec = tensor_from_view(
1472 &st.tensor(&dec_name)
1473 .map_err(|e| MIError::Config(format!("tensor '{dec_name}' not found: {e}")))?,
1474 &Device::Cpu,
1475 )?;
1476
1477 for &index in indices {
1478 let fid = CltFeatureId {
1479 layer: *source_layer,
1480 index,
1481 };
1482 if let std::collections::hash_map::Entry::Vacant(e) = result.entry(fid) {
1483 let view = w_dec.i((index, target_offset))?;
1485 let dims = view.dims().to_vec();
1486 let values = view.to_dtype(DType::F32)?.to_vec1::<f32>()?;
1488 let independent = Tensor::from_vec(values, dims.as_slice(), &Device::Cpu)?;
1489 e.insert(independent);
1490 }
1491 }
1492 }
1494
1495 info!(
1496 "Extracted {} decoder vectors across {} source layers",
1497 result.len(),
1498 n_source_layers
1499 );
1500
1501 Ok(result)
1502 }
1503
1504 pub fn build_attribution_graph(
1517 &mut self,
1518 direction: &Tensor,
1519 target_layer: usize,
1520 top_k: usize,
1521 cosine: bool,
1522 ) -> Result<AttributionGraph> {
1523 let scored =
1524 self.score_features_by_decoder_projection(direction, target_layer, top_k, cosine)?;
1525 Ok(AttributionGraph {
1526 target_layer,
1527 edges: scored
1528 .into_iter()
1529 .map(|(feature, score)| AttributionEdge { feature, score })
1530 .collect(),
1531 })
1532 }
1533
1534 pub fn build_attribution_graph_batch(
1547 &mut self,
1548 directions: &[Tensor],
1549 target_layer: usize,
1550 top_k: usize,
1551 cosine: bool,
1552 ) -> Result<Vec<AttributionGraph>> {
1553 let batch = self.score_features_by_decoder_projection_batch(
1554 directions,
1555 target_layer,
1556 top_k,
1557 cosine,
1558 )?;
1559 Ok(batch
1560 .into_iter()
1561 .map(|scored| AttributionGraph {
1562 target_layer,
1563 edges: scored
1564 .into_iter()
1565 .map(|(feature, score)| AttributionEdge { feature, score })
1566 .collect(),
1567 })
1568 .collect())
1569 }
1570}
1571
1572fn tensor_from_view(view: &safetensors::tensor::TensorView<'_>, device: &Device) -> Result<Tensor> {
1586 let shape: Vec<usize> = view.shape().to_vec();
1587 #[allow(clippy::wildcard_enum_match_arm)]
1588 let dtype = match view.dtype() {
1590 safetensors::Dtype::BF16 => DType::BF16,
1591 safetensors::Dtype::F16 => DType::F16,
1592 safetensors::Dtype::F32 => DType::F32,
1593 other => {
1594 return Err(MIError::Config(format!(
1595 "unsupported CLT tensor dtype: {other:?}"
1596 )));
1597 }
1598 };
1599 let tensor = Tensor::from_raw_buffer(view.data(), dtype, &shape, device)?;
1600 Ok(tensor)
1601}
1602
/// Extracts the scalar value of `key` from flat `key: value` YAML text.
///
/// This is a deliberately small line scanner, not a YAML parser: each line is
/// trimmed, and the first line that starts with `key` followed by optional
/// whitespace and a colon wins. Nesting is ignored, so an indented key
/// matches as well.
///
/// The value is trimmed. A value wrapped in a pair of single quotes has that
/// pair removed; otherwise any leading and trailing double quotes are removed.
///
/// Returns `None` when no line matches.
fn parse_yaml_value(yaml_text: &str, key: &str) -> Option<String> {
    yaml_text.lines().find_map(|line| {
        let rest = line.trim().strip_prefix(key)?;
        // YAML allows whitespace between the key and the colon. A longer key
        // that merely shares the prefix is rejected here because its next
        // character is neither whitespace nor ':'.
        let rest = rest.trim_start().strip_prefix(':')?;
        let value = rest.trim();
        let unquoted = value
            .strip_prefix('\'')
            .and_then(|inner| inner.strip_suffix('\''))
            .unwrap_or_else(|| value.trim_matches('"'));
        Some(unquoted.to_owned())
    })
}
1618
1619#[cfg(test)]
1624#[allow(clippy::unwrap_used, clippy::expect_used)]
1625mod tests {
1626 use super::*;
1627
1628 #[test]
1629 fn clt_feature_id_display() {
1630 let fid = CltFeatureId {
1631 layer: 5,
1632 index: 42,
1633 };
1634 assert_eq!(fid.to_string(), "L5:42");
1635 }
1636
1637 #[test]
1638 fn clt_feature_id_ordering() {
1639 let a = CltFeatureId {
1640 layer: 0,
1641 index: 10,
1642 };
1643 let b = CltFeatureId {
1644 layer: 0,
1645 index: 20,
1646 };
1647 let c = CltFeatureId { layer: 1, index: 0 };
1648 assert!(a < b);
1649 assert!(b < c);
1650 }
1651
1652 #[test]
1653 fn sparse_activations_basics() {
1654 let features = vec![
1655 (CltFeatureId { layer: 0, index: 5 }, 3.0),
1656 (CltFeatureId { layer: 0, index: 2 }, 2.0),
1657 (CltFeatureId { layer: 0, index: 8 }, 1.0),
1658 ];
1659 let sparse = SparseActivations { features };
1660 assert_eq!(sparse.len(), 3);
1661 assert!(!sparse.is_empty());
1662 }
1663
1664 #[test]
1665 fn sparse_activations_truncate() {
1666 let features = vec![
1667 (CltFeatureId { layer: 0, index: 5 }, 3.0),
1668 (CltFeatureId { layer: 0, index: 2 }, 2.0),
1669 (CltFeatureId { layer: 0, index: 8 }, 1.0),
1670 ];
1671 let mut sparse = SparseActivations { features };
1672 sparse.truncate(2);
1673 assert_eq!(sparse.len(), 2);
1674 assert_eq!(sparse.features[0].0.index, 5);
1675 assert_eq!(sparse.features[1].0.index, 2);
1676 }
1677
1678 #[test]
1679 fn parse_yaml_value_basic() {
1680 let yaml = "model_name: \"google/gemma-2-2b\"\nmodel_kind: cross_layer_transcoder\n";
1681 assert_eq!(
1682 parse_yaml_value(yaml, "model_name"),
1683 Some("google/gemma-2-2b".to_owned())
1684 );
1685 assert_eq!(
1686 parse_yaml_value(yaml, "model_kind"),
1687 Some("cross_layer_transcoder".to_owned())
1688 );
1689 assert_eq!(parse_yaml_value(yaml, "missing_key"), None);
1690 }
1691
1692 #[test]
1693 fn encode_synthetic() {
1694 let device = Device::Cpu;
1696 let d_model = 8;
1697 let n_features = 4;
1698
1699 #[rustfmt::skip]
1701 let w_enc_data: Vec<f32> = vec![
1702 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ];
1707 let w_enc = Tensor::from_vec(w_enc_data, (n_features, d_model), &device).unwrap();
1708
1709 let b_enc_data: Vec<f32> = vec![0.0, -0.5, 0.0, -2.0]; let b_enc = Tensor::from_vec(b_enc_data, (n_features,), &device).unwrap();
1712
1713 let residual_data: Vec<f32> = vec![1.5, 0.3, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0];
1715 let residual = Tensor::from_vec(residual_data, (d_model,), &device).unwrap();
1716
1717 let clt = CrossLayerTranscoder {
1725 repo_id: "test".to_owned(),
1726 fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
1727 encoder_paths: vec![None],
1728 decoder_paths: vec![None],
1729 config: CltConfig {
1730 n_layers: 1,
1731 d_model,
1732 n_features_per_layer: n_features,
1733 n_features_total: n_features,
1734 model_name: "test".to_owned(),
1735 },
1736 loaded_encoder: Some(LoadedEncoder {
1737 layer: 0,
1738 w_enc,
1739 b_enc,
1740 }),
1741 steering_cache: HashMap::new(),
1742 };
1743
1744 let sparse = clt.encode(&residual, 0).unwrap();
1745 assert_eq!(sparse.len(), 1, "only feature 0 should be active");
1746 assert_eq!(sparse.features[0].0.index, 0);
1747 assert!((sparse.features[0].1 - 1.5).abs() < 1e-5);
1748 }
1749
1750 #[test]
1751 fn encode_wrong_layer_errors() {
1752 let device = Device::Cpu;
1753 let w_enc = Tensor::zeros((4, 8), DType::F32, &device).unwrap();
1754 let b_enc = Tensor::zeros((4,), DType::F32, &device).unwrap();
1755 let residual = Tensor::zeros((8,), DType::F32, &device).unwrap();
1756
1757 let clt = CrossLayerTranscoder {
1758 repo_id: "test".to_owned(),
1759 fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
1760 encoder_paths: vec![None; 2],
1761 decoder_paths: vec![None; 2],
1762 config: CltConfig {
1763 n_layers: 2,
1764 d_model: 8,
1765 n_features_per_layer: 4,
1766 n_features_total: 8,
1767 model_name: "test".to_owned(),
1768 },
1769 loaded_encoder: Some(LoadedEncoder {
1770 layer: 0,
1771 w_enc,
1772 b_enc,
1773 }),
1774 steering_cache: HashMap::new(),
1775 };
1776
1777 let result = clt.encode(&residual, 1);
1779 assert!(result.is_err());
1780 }
1781
1782 #[test]
1783 fn inject_position() {
1784 let device = Device::Cpu;
1785 let d_model = 4;
1786
1787 let residual = Tensor::ones((1, 3, d_model), DType::F32, &device).unwrap();
1789
1790 let fid = CltFeatureId { layer: 0, index: 0 };
1792 let target_layer = 1;
1793 let steering_vec =
1794 Tensor::from_vec(vec![10.0_f32, 20.0, 30.0, 40.0], (d_model,), &device).unwrap();
1795
1796 let mut steering_cache = HashMap::new();
1797 steering_cache.insert((fid, target_layer), steering_vec);
1798
1799 let clt = CrossLayerTranscoder {
1800 repo_id: "test".to_owned(),
1801 fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
1802 encoder_paths: vec![None; 2],
1803 decoder_paths: vec![None; 2],
1804 config: CltConfig {
1805 n_layers: 2,
1806 d_model,
1807 n_features_per_layer: 1,
1808 n_features_total: 2,
1809 model_name: "test".to_owned(),
1810 },
1811 loaded_encoder: None,
1812 steering_cache,
1813 };
1814
1815 let result = clt
1817 .inject(&residual, &[(fid, target_layer)], 1, 1.0)
1818 .unwrap();
1819
1820 let pos0: Vec<f32> = result.i((0, 0)).unwrap().to_vec1().unwrap();
1822 assert_eq!(pos0, vec![1.0, 1.0, 1.0, 1.0]);
1823
1824 let pos1: Vec<f32> = result.i((0, 1)).unwrap().to_vec1().unwrap();
1826 assert_eq!(pos1, vec![11.0, 21.0, 31.0, 41.0]);
1827
1828 let pos2: Vec<f32> = result.i((0, 2)).unwrap().to_vec1().unwrap();
1830 assert_eq!(pos2, vec![1.0, 1.0, 1.0, 1.0]);
1831 }
1832
1833 #[test]
1834 fn prepare_hook_injection_creates_correct_hooks() {
1835 use crate::hooks::HookPoint;
1836
1837 let device = Device::Cpu;
1838 let d_model = 4;
1839
1840 let fid = CltFeatureId { layer: 0, index: 0 };
1841 let target_layer = 5;
1842 let steering_vec =
1843 Tensor::from_vec(vec![1.0_f32, 2.0, 3.0, 4.0], (d_model,), &device).unwrap();
1844
1845 let mut steering_cache = HashMap::new();
1846 steering_cache.insert((fid, target_layer), steering_vec);
1847
1848 let clt = CrossLayerTranscoder {
1849 repo_id: "test".to_owned(),
1850 fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
1851 encoder_paths: vec![None; 10],
1852 decoder_paths: vec![None; 10],
1853 config: CltConfig {
1854 n_layers: 10,
1855 d_model,
1856 n_features_per_layer: 1,
1857 n_features_total: 10,
1858 model_name: "test".to_owned(),
1859 },
1860 loaded_encoder: None,
1861 steering_cache,
1862 };
1863
1864 let hooks = clt
1865 .prepare_hook_injection(&[(fid, target_layer)], 2, 5, 1.0, &device)
1866 .unwrap();
1867
1868 assert!(hooks.has_intervention_at(&HookPoint::ResidPost(target_layer)));
1870 assert!(!hooks.has_intervention_at(&HookPoint::ResidPost(0)));
1872 assert!(!hooks.has_intervention_at(&HookPoint::ResidPost(4)));
1873 }
1874
1875 #[test]
1880 fn attribution_edge_basics() {
1881 let edge = AttributionEdge {
1882 feature: CltFeatureId {
1883 layer: 3,
1884 index: 42,
1885 },
1886 score: 0.75,
1887 };
1888 assert_eq!(edge.feature.layer, 3);
1889 assert_eq!(edge.feature.index, 42);
1890 assert!((edge.score - 0.75).abs() < f32::EPSILON);
1891 }
1892
1893 #[test]
1894 fn attribution_graph_empty() {
1895 let graph = AttributionGraph {
1896 target_layer: 5,
1897 edges: Vec::new(),
1898 };
1899 assert_eq!(graph.target_layer(), 5);
1900 assert!(graph.is_empty());
1901 assert_eq!(graph.len(), 0);
1902 assert!(graph.features().is_empty());
1903 assert!(graph.into_edges().is_empty());
1904 }
1905
1906 #[test]
1907 fn attribution_graph_top_k() {
1908 let edges = vec![
1909 AttributionEdge {
1910 feature: CltFeatureId { layer: 0, index: 0 },
1911 score: 5.0,
1912 },
1913 AttributionEdge {
1914 feature: CltFeatureId { layer: 0, index: 1 },
1915 score: 3.0,
1916 },
1917 AttributionEdge {
1918 feature: CltFeatureId { layer: 1, index: 0 },
1919 score: 1.0,
1920 },
1921 AttributionEdge {
1922 feature: CltFeatureId { layer: 1, index: 1 },
1923 score: -1.0,
1924 },
1925 AttributionEdge {
1926 feature: CltFeatureId { layer: 2, index: 0 },
1927 score: -4.0,
1928 },
1929 ];
1930 let graph = AttributionGraph {
1931 target_layer: 3,
1932 edges,
1933 };
1934
1935 assert_eq!(graph.len(), 5);
1936
1937 let top3 = graph.top_k(3);
1938 assert_eq!(top3.len(), 3);
1939 assert_eq!(top3.target_layer(), 3);
1940 assert!((top3.edges()[0].score - 5.0).abs() < f32::EPSILON);
1941 assert!((top3.edges()[1].score - 3.0).abs() < f32::EPSILON);
1942 assert!((top3.edges()[2].score - 1.0).abs() < f32::EPSILON);
1943
1944 let top10 = graph.top_k(10);
1946 assert_eq!(top10.len(), 5);
1947 }
1948
1949 #[test]
1950 fn attribution_graph_threshold() {
1951 let edges = vec![
1952 AttributionEdge {
1953 feature: CltFeatureId { layer: 0, index: 0 },
1954 score: 5.0,
1955 },
1956 AttributionEdge {
1957 feature: CltFeatureId { layer: 0, index: 1 },
1958 score: 3.0,
1959 },
1960 AttributionEdge {
1961 feature: CltFeatureId { layer: 1, index: 0 },
1962 score: 1.0,
1963 },
1964 AttributionEdge {
1965 feature: CltFeatureId { layer: 1, index: 1 },
1966 score: -1.0,
1967 },
1968 AttributionEdge {
1969 feature: CltFeatureId { layer: 2, index: 0 },
1970 score: -4.0,
1971 },
1972 ];
1973 let graph = AttributionGraph {
1974 target_layer: 3,
1975 edges,
1976 };
1977
1978 let pruned = graph.threshold(2.0);
1980 assert_eq!(pruned.len(), 3);
1981 assert!((pruned.edges()[0].score - 5.0).abs() < f32::EPSILON);
1982 assert!((pruned.edges()[1].score - 3.0).abs() < f32::EPSILON);
1983 assert!((pruned.edges()[2].score - -4.0).abs() < f32::EPSILON);
1984 }
1985
1986 #[test]
1987 fn attribution_graph_features() {
1988 let edges = vec![
1989 AttributionEdge {
1990 feature: CltFeatureId { layer: 2, index: 7 },
1991 score: 1.0,
1992 },
1993 AttributionEdge {
1994 feature: CltFeatureId { layer: 0, index: 3 },
1995 score: 0.5,
1996 },
1997 ];
1998 let graph = AttributionGraph {
1999 target_layer: 5,
2000 edges,
2001 };
2002
2003 let features = graph.features();
2004 assert_eq!(features.len(), 2);
2005 assert_eq!(features[0], CltFeatureId { layer: 2, index: 7 });
2006 assert_eq!(features[1], CltFeatureId { layer: 0, index: 3 });
2007 }
2008
2009 fn create_synthetic_decoder(
2015 dir: &std::path::Path,
2016 layer: usize,
2017 n_features: usize,
2018 n_target_layers: usize,
2019 d_model: usize,
2020 values: &[f32],
2021 ) -> PathBuf {
2022 assert_eq!(values.len(), n_features * n_target_layers * d_model);
2023 let bytes: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
2024 let name = format!("W_dec_{layer}");
2025 let shape = vec![n_features, n_target_layers, d_model];
2026 let view =
2027 safetensors::tensor::TensorView::new(safetensors::Dtype::F32, shape, &bytes).unwrap();
2028 let mut tensors = HashMap::new();
2029 tensors.insert(name, view);
2030 let serialized = safetensors::serialize(&tensors, &None).unwrap();
2031 let path = dir.join(format!("W_dec_{layer}.safetensors"));
2032 std::fs::write(&path, serialized).unwrap();
2033 path
2034 }
2035
2036 #[test]
2037 fn score_decoder_projection_synthetic() {
2038 let dir = tempfile::tempdir().unwrap();
2042 let d_model = 4;
2043 let n_features = 4;
2044
2045 #[rustfmt::skip]
2051 let dec0_values: Vec<f32> = vec![
2052 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0,
2054 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
2056 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0,
2058 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0,
2060 ];
2061 let path0 = create_synthetic_decoder(dir.path(), 0, n_features, 2, d_model, &dec0_values);
2062
2063 #[rustfmt::skip]
2069 let dec1_values: Vec<f32> = vec![
2070 2.0, 0.0, 0.0, 0.0,
2071 0.0, 0.0, 0.0, 0.0,
2072 0.0, 0.0, 0.0, 0.0,
2073 0.0, 3.0, 0.0, 0.0,
2074 ];
2075 let path1 = create_synthetic_decoder(dir.path(), 1, n_features, 1, d_model, &dec1_values);
2076
2077 let mut clt = CrossLayerTranscoder {
2078 repo_id: "test".to_owned(),
2079 fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
2080 encoder_paths: vec![None; 2],
2081 decoder_paths: vec![Some(path0), Some(path1)],
2082 config: CltConfig {
2083 n_layers: 2,
2084 d_model,
2085 n_features_per_layer: n_features,
2086 n_features_total: n_features * 2,
2087 model_name: "test".to_owned(),
2088 },
2089 loaded_encoder: None,
2090 steering_cache: HashMap::new(),
2091 };
2092
2093 let direction =
2095 Tensor::from_vec(vec![1.0_f32, 0.0, 0.0, 0.0], (d_model,), &Device::Cpu).unwrap();
2096
2097 let scores = clt
2098 .score_features_by_decoder_projection(&direction, 1, 10, false)
2099 .unwrap();
2100
2101 assert!(scores.len() >= 2, "expected at least 2 non-zero scores");
2103 assert_eq!(scores[0].0, CltFeatureId { layer: 1, index: 0 });
2104 assert!((scores[0].1 - 2.0).abs() < 1e-5);
2105 assert_eq!(scores[1].0, CltFeatureId { layer: 0, index: 0 });
2106 assert!((scores[1].1 - 1.0).abs() < 1e-5);
2107
2108 let direction2 =
2110 Tensor::from_vec(vec![0.0_f32, 1.0, 0.0, 0.0], (d_model,), &Device::Cpu).unwrap();
2111
2112 let scores2 = clt
2113 .score_features_by_decoder_projection(&direction2, 1, 10, false)
2114 .unwrap();
2115
2116 assert_eq!(scores2[0].0, CltFeatureId { layer: 1, index: 3 });
2117 assert!((scores2[0].1 - 3.0).abs() < 1e-5);
2118 assert_eq!(scores2[1].0, CltFeatureId { layer: 0, index: 1 });
2119 assert!((scores2[1].1 - 1.0).abs() < 1e-5);
2120 }
2121
2122 #[test]
2123 fn score_decoder_projection_cosine_synthetic() {
2124 let dir = tempfile::tempdir().unwrap();
2126 let d_model = 4;
2127 let n_features = 2;
2128
2129 #[rustfmt::skip]
2133 let dec0_values: Vec<f32> = vec![
2134 3.0, 0.0, 0.0, 0.0,
2135 1.0, 1.0, 0.0, 0.0,
2136 ];
2137 let path0 = create_synthetic_decoder(dir.path(), 0, n_features, 1, d_model, &dec0_values);
2138
2139 let mut clt = CrossLayerTranscoder {
2140 repo_id: "test".to_owned(),
2141 fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
2142 encoder_paths: vec![None],
2143 decoder_paths: vec![Some(path0)],
2144 config: CltConfig {
2145 n_layers: 1,
2146 d_model,
2147 n_features_per_layer: n_features,
2148 n_features_total: n_features,
2149 model_name: "test".to_owned(),
2150 },
2151 loaded_encoder: None,
2152 steering_cache: HashMap::new(),
2153 };
2154
2155 let direction =
2156 Tensor::from_vec(vec![1.0_f32, 0.0, 0.0, 0.0], (d_model,), &Device::Cpu).unwrap();
2157
2158 let dot_scores = clt
2160 .score_features_by_decoder_projection(&direction, 0, 10, false)
2161 .unwrap();
2162 assert!((dot_scores[0].1 - 3.0).abs() < 1e-5);
2163 assert!((dot_scores[1].1 - 1.0).abs() < 1e-5);
2164
2165 let cos_scores = clt
2167 .score_features_by_decoder_projection(&direction, 0, 10, true)
2168 .unwrap();
2169 assert!(
2170 (cos_scores[0].1 - 1.0).abs() < 1e-4,
2171 "expected ~1.0, got {}",
2172 cos_scores[0].1
2173 );
2174 let expected_cos = 1.0 / 2.0_f32.sqrt();
2175 assert!(
2176 (cos_scores[1].1 - expected_cos).abs() < 1e-4,
2177 "expected ~{expected_cos}, got {}",
2178 cos_scores[1].1
2179 );
2180 }
2181
2182 #[test]
2183 fn score_decoder_projection_batch_synthetic() {
2184 let dir = tempfile::tempdir().unwrap();
2185 let d_model = 4;
2186 let n_features = 2;
2187
2188 #[rustfmt::skip]
2190 let dec0_values: Vec<f32> = vec![
2191 1.0, 0.0, 0.0, 0.0,
2192 0.0, 1.0, 0.0, 0.0,
2193 ];
2194 let path0 = create_synthetic_decoder(dir.path(), 0, n_features, 1, d_model, &dec0_values);
2195
2196 let mut clt = CrossLayerTranscoder {
2197 repo_id: "test".to_owned(),
2198 fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
2199 encoder_paths: vec![None],
2200 decoder_paths: vec![Some(path0)],
2201 config: CltConfig {
2202 n_layers: 1,
2203 d_model,
2204 n_features_per_layer: n_features,
2205 n_features_total: n_features,
2206 model_name: "test".to_owned(),
2207 },
2208 loaded_encoder: None,
2209 steering_cache: HashMap::new(),
2210 };
2211
2212 let dir0 =
2214 Tensor::from_vec(vec![1.0_f32, 0.0, 0.0, 0.0], (d_model,), &Device::Cpu).unwrap();
2215 let dir1 =
2216 Tensor::from_vec(vec![0.0_f32, 1.0, 0.0, 0.0], (d_model,), &Device::Cpu).unwrap();
2217
2218 let batch = clt
2219 .score_features_by_decoder_projection_batch(&[dir0, dir1], 0, 10, false)
2220 .unwrap();
2221
2222 assert_eq!(batch.len(), 2);
2223
2224 assert_eq!(batch[0][0].0, CltFeatureId { layer: 0, index: 0 });
2226 assert!((batch[0][0].1 - 1.0).abs() < 1e-5);
2227
2228 assert_eq!(batch[1][0].0, CltFeatureId { layer: 0, index: 1 });
2230 assert!((batch[1][0].1 - 1.0).abs() < 1e-5);
2231 }
2232
2233 #[test]
2234 fn extract_decoder_vectors_synthetic() {
2235 let dir = tempfile::tempdir().unwrap();
2236 let d_model = 4;
2237 let n_features = 3;
2238
2239 #[rustfmt::skip]
2241 let dec0_values: Vec<f32> = vec![
2242 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
2244 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
2246 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0,
2248 ];
2249 let path0 = create_synthetic_decoder(dir.path(), 0, n_features, 2, d_model, &dec0_values);
2250
2251 let mut clt = CrossLayerTranscoder {
2252 repo_id: "test".to_owned(),
2253 fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
2254 encoder_paths: vec![None; 2],
2255 decoder_paths: vec![Some(path0), None],
2256 config: CltConfig {
2257 n_layers: 2,
2258 d_model,
2259 n_features_per_layer: n_features,
2260 n_features_total: n_features * 2,
2261 model_name: "test".to_owned(),
2262 },
2263 loaded_encoder: None,
2264 steering_cache: HashMap::new(),
2265 };
2266
2267 let features = vec![
2268 CltFeatureId { layer: 0, index: 0 },
2269 CltFeatureId { layer: 0, index: 2 },
2270 ];
2271
2272 let vectors = clt.extract_decoder_vectors(&features, 1).unwrap();
2274 assert_eq!(vectors.len(), 2);
2275
2276 let v0: Vec<f32> = vectors[&CltFeatureId { layer: 0, index: 0 }]
2278 .to_vec1()
2279 .unwrap();
2280 assert_eq!(v0, vec![5.0, 6.0, 7.0, 8.0]);
2281
2282 let v2: Vec<f32> = vectors[&CltFeatureId { layer: 0, index: 2 }]
2284 .to_vec1()
2285 .unwrap();
2286 assert_eq!(v2, vec![21.0, 22.0, 23.0, 24.0]);
2287 }
2288
2289 #[test]
2290 fn build_attribution_graph_synthetic() {
2291 let dir = tempfile::tempdir().unwrap();
2292 let d_model = 4;
2293 let n_features = 2;
2294
2295 #[rustfmt::skip]
2296 let dec0_values: Vec<f32> = vec![
2297 1.0, 0.0, 0.0, 0.0,
2298 0.0, 2.0, 0.0, 0.0,
2299 ];
2300 let path0 = create_synthetic_decoder(dir.path(), 0, n_features, 1, d_model, &dec0_values);
2301
2302 let mut clt = CrossLayerTranscoder {
2303 repo_id: "test".to_owned(),
2304 fetch_config: hf_fetch_model::FetchConfig::builder().build().unwrap(),
2305 encoder_paths: vec![None],
2306 decoder_paths: vec![Some(path0)],
2307 config: CltConfig {
2308 n_layers: 1,
2309 d_model,
2310 n_features_per_layer: n_features,
2311 n_features_total: n_features,
2312 model_name: "test".to_owned(),
2313 },
2314 loaded_encoder: None,
2315 steering_cache: HashMap::new(),
2316 };
2317
2318 let direction =
2319 Tensor::from_vec(vec![0.0_f32, 1.0, 0.0, 0.0], (d_model,), &Device::Cpu).unwrap();
2320
2321 let graph = clt
2322 .build_attribution_graph(&direction, 0, 10, false)
2323 .unwrap();
2324
2325 assert_eq!(graph.target_layer(), 0);
2326 assert!(!graph.is_empty());
2327 assert_eq!(
2329 graph.edges()[0].feature,
2330 CltFeatureId { layer: 0, index: 1 }
2331 );
2332 assert!((graph.edges()[0].score - 2.0).abs() < 1e-5);
2333
2334 let pruned = graph.threshold(1.0);
2336 assert_eq!(pruned.len(), 1);
2337 assert_eq!(pruned.features()[0], CltFeatureId { layer: 0, index: 1 });
2338 }
2339}