candle-mi 0.1.12

Mechanistic interpretability for language models in Rust, built on candle
Documentation
(* ::Package:: *)

(* ================================================================== *)
(* Causal Trace Heatmap: Meng et al. (2022) Figure 1e                 *)
(* candle-mi example: activation_patching                              *)
(*                                                                     *)
(* Reads JSON output from --output and plots a layer x position        *)
(* heatmap of recovery percentages. The grid is transposed to match    *)
(* the paper: tokens on Y-axis, layers on X-axis.                     *)
(* ================================================================== *)

(* --- Import JSON output from the Rust example --- *)
(* Change this to the JSON file you want to plot. The file is looked up  *)
(* next to this notebook/script.                                          *)
jsonFile = "gemma-2-2b.json";

(* NotebookDirectory[] needs a saved notebook in a front end. When this  *)
(* file is run from a bare kernel it does not return a string, so fall    *)
(* back to the directory of the file being loaded, then to the current    *)
(* working directory.                                                     *)
inputDir = Quiet[NotebookDirectory[]];
If[!StringQ[inputDir],
  inputDir = If[StringQ[$InputFileName] && $InputFileName =!= "",
    DirectoryName[$InputFileName],
    Directory[]
  ]
];
jsonPath = FileNameJoin[{inputDir, jsonFile}];

raw = Import[jsonPath, "RawJSON"];

(* "RawJSON" yields an Association for a JSON object. Stop here with a   *)
(* clear message rather than letting every lookup below fail in turn.     *)
If[!AssociationQ[raw],
  Print["Error: could not read a JSON object from ", jsonPath];
  Abort[]
];

(* Fields read from the JSON object *)
modelId    = raw["model_id"];
tokens     = raw["tokens"];
nLayers    = raw["n_layers"];
seqLen     = raw["seq_len"];
subjectPos = raw["subject_pos"];
grid       = raw["grid"];

Print["Model: ", modelId];
Print["Tokens: ", tokens];
Print["Grid: ", nLayers, " layers x ", seqLen, " positions"];
(* subjectPos is used as a 0-based index here; Part is 1-based, hence +1 *)
Print["Subject position: ", subjectPos, " (\"", tokens[[subjectPos + 1]], "\")"];

(* ================================================================== *)
(* PLOT 1: Causal Trace Heatmap (Figure 1e)                           *)
(*                                                                     *)
(* Paper orientation: X = layer, Y = token position (top to bottom).  *)
(* grid is [layer][position], so Transpose gives [position][layer].   *)
(* ================================================================== *)

(* Rows of gridT are token positions, columns are layers. *)
gridT = Transpose[grid];

(* Token labels for the Y-axis, one per row of gridT. No reversal is     *)
(* needed: MatrixPlot draws row 1 at the top and the frame ticks below    *)
(* are given as row indices, so the first token ends up at the top.       *)
tokenLabels = tokens;

heatmap = MatrixPlot[
  gridT,
  ColorFunction -> "TemperatureMap",
  PlotLegends -> BarLegend[Automatic, LegendLabel -> "Recovery (%)"],
  (* Spelled out as {{left, right}, {bottom, top}} so the labels cannot  *)
  (* land on the wrong axes through the two-element shorthand.            *)
  FrameLabel -> {{"Token", None}, {"Layer", None}},
  FrameTicks -> {
    (* left: token text at each row index; right: no ticks *)
    {Table[{i, tokenLabels[[i]]}, {i, 1, seqLen}], None},
    (* bottom: column index i is labelled as 0-based layer i - 1 *)
    {Table[{i, i - 1}, {i, 1, nLayers}], None}
  },
  PlotLabel -> Style[
    "Causal Trace (Meng et al. Figure 1e)\n" <> modelId,
    14, Bold
  ],
  ImageSize -> 700,
  AspectRatio -> seqLen / nLayers
];

(* ================================================================== *)
(* PLOT 2: Subject-position recovery curve (layer sweep)               *)
(* ================================================================== *)

subjectRecovery = raw["subject_recovery"];

(* Pair every value with a 0-based layer index so the X-axis agrees with *)
(* the 0-based layer ticks of the heatmap; a bare list would be drawn at  *)
(* x = 1..n. Assumes subject_recovery holds one entry per layer, in layer *)
(* order -- TODO confirm against the Rust example's output.               *)
recoveryPoints = MapIndexed[{First[#2] - 1, #1} &, subjectRecovery];

(* Red marker on the maximum of the curve; omitted for an empty curve,   *)
(* where Ordering would return {} and First would fail.                   *)
peakMarker = If[recoveryPoints === {},
  {},
  With[{peakIdx = First[Ordering[subjectRecovery, -1]]},
    {Red, PointSize[Large], Point[recoveryPoints[[peakIdx]]]}
  ]
];

recoveryCurve = ListLinePlot[
  recoveryPoints,
  PlotRange -> {Automatic, {-5, 105}},
  AxesLabel -> {"Layer", "Recovery (%)"},
  PlotLabel -> Style[
    "Subject-Position Recovery\n" <> modelId <>
    " (pos " <> ToString[subjectPos] <> " = \"" <>
    tokens[[subjectPos + 1]] <> "\")",
    12, Bold
  ],
  PlotStyle -> {Thick, Blue},
  GridLines -> Automatic,
  ImageSize -> 500,
  Epilog -> peakMarker
];

(* ================================================================== *)
(* Export PNGs                                                         *)
(* ================================================================== *)

(* Base directory for the "plots" folder. NotebookDirectory[] needs a    *)
(* saved notebook in a front end; from a bare kernel it does not return   *)
(* a string, so fall back to the directory of the file being loaded,      *)
(* then to the current working directory.                                 *)
plotBase = Quiet[NotebookDirectory[]];
If[!StringQ[plotBase],
  plotBase = If[StringQ[$InputFileName] && $InputFileName =!= "",
    DirectoryName[$InputFileName],
    Directory[]
  ]
];

plotDir = FileNameJoin[{plotBase, "plots"}];
If[!DirectoryQ[plotDir], CreateDirectory[plotDir]];

(* Replace "/" in the model id so it cannot be read as a path separator  *)
(* inside the file name.                                                  *)
prefix = StringReplace[modelId, "/" -> "_"] <> "_";

Export[
  FileNameJoin[{plotDir, prefix <> "causal_trace_heatmap.png"}],
  heatmap, ImageResolution -> 200
];
Export[
  FileNameJoin[{plotDir, prefix <> "subject_recovery.png"}],
  recoveryCurve, ImageResolution -> 200
];

Print["Exported: ", prefix, "causal_trace_heatmap.png"];
Print["Exported: ", prefix, "subject_recovery.png"];

(* Display: title and model id, then each plot preceded by a 20pt spacer *)
Column[
  Join[
    {
      Style["Causal Trace Analysis", 20, Bold],
      Style[modelId, 14]
    },
    Catenate[{Spacer[20], #} & /@ {heatmap, recoveryCurve}]
  ]
]