pub mod ablation;
pub mod config;
pub mod fast;
pub mod piecewise;
pub mod probing;
pub mod standardize;
pub mod surprise;
pub mod tasks;
use std::path::Path;
use candle_core::{DType, Device, IndexOp, Module, Tensor};
use candle_nn::{Embedding, VarBuilder};
use crate::backend::MIBackend;
use crate::error::{MIError, Result};
use crate::hooks::{HookCache, HookPoint, HookSpec};
pub use config::{StoicheiaArch, StoicheiaConfig, StoicheiaOutput, StoicheiaTask};
/// Reads model weights from `path` and returns them as raw safetensors bytes.
///
/// Supported formats:
/// - `.safetensors`: file contents returned verbatim.
/// - `.pth` / `.pkl`: parsed via `anamnesis` and converted to safetensors.
///
/// The extension check is case-insensitive, so `.PTH` and `.SafeTensors`
/// are accepted as well (the original comparison rejected them).
///
/// # Errors
/// Returns [`MIError::Config`] for a missing or unsupported extension, and
/// [`MIError::Model`] when a `.pth`/`.pkl` file fails to parse or convert.
fn load_weight_bytes(path: &Path) -> Result<Vec<u8>> {
    // Normalize the extension once so the match below is case-insensitive.
    let ext = path
        .extension()
        .and_then(|e| e.to_str())
        .map(str::to_ascii_lowercase);
    match ext.as_deref() {
        Some("safetensors") => Ok(std::fs::read(path)?),
        Some("pth" | "pkl") => {
            let parsed = anamnesis::parse_pth(path).map_err(|e| {
                MIError::Model(candle_core::Error::Msg(format!(
                    "failed to parse .pth file: {e}"
                )))
            })?;
            parsed.to_safetensors_bytes().map_err(|e| {
                MIError::Model(candle_core::Error::Msg(format!(
                    "failed to convert .pth to safetensors: {e}"
                )))
            })
        }
        Some(ext) => Err(MIError::Config(format!(
            "unsupported weight file extension: .{ext} \
             (expected .safetensors, .pth, or .pkl)"
        ))),
        None => Err(MIError::Config(
            "weight file has no extension \
             (expected .safetensors, .pth, or .pkl)"
                .into(),
        )),
    }
}
/// A minimal single-layer recurrent backend loaded from exported weights.
pub struct StoicheiaRnn {
    /// Input-to-hidden weights, loaded as (hidden_size, 1) — the model
    /// consumes one scalar input feature per timestep.
    weight_ih: Tensor,
    /// Hidden-to-hidden recurrence weights, (hidden_size, hidden_size).
    weight_hh: Tensor,
    /// Hidden-to-output projection weights, (output_size, hidden_size).
    weight_oh: Tensor,
    /// Model hyperparameters (hidden size, output size, task settings).
    config: StoicheiaConfig,
}
impl StoicheiaRnn {
    /// Loads the RNN's three weight matrices from `path` onto `device`.
    ///
    /// Expects tensors named `rnn.weight_ih_l0`, `rnn.weight_hh_l0`, and
    /// `linear.weight` in the weight file, all read as `f32`.
    ///
    /// # Errors
    /// Fails when the file cannot be read or converted, or when a tensor
    /// is missing or has an unexpected shape.
    #[allow(clippy::similar_names)]
    pub fn load(config: StoicheiaConfig, path: impl AsRef<Path>, device: &Device) -> Result<Self> {
        let bytes = load_weight_bytes(path.as_ref())?;
        let vb = VarBuilder::from_buffered_safetensors(bytes, DType::F32, device)?;
        let h = config.hidden_size;
        Ok(Self {
            weight_ih: vb.get((h, 1), "rnn.weight_ih_l0")?,
            weight_hh: vb.get((h, h), "rnn.weight_hh_l0")?,
            weight_oh: vb.get((config.output_size(), h), "linear.weight")?,
            config,
        })
    }
}
impl StoicheiaRnn {
    /// Input-to-hidden weight matrix, (hidden_size, 1).
    #[must_use]
    pub const fn weight_ih(&self) -> &Tensor {
        &self.weight_ih
    }
    /// Hidden-to-hidden recurrence weight matrix, (hidden_size, hidden_size).
    #[must_use]
    pub const fn weight_hh(&self) -> &Tensor {
        &self.weight_hh
    }
    /// Hidden-to-output projection weight matrix, (output_size, hidden_size).
    #[must_use]
    pub const fn weight_oh(&self) -> &Tensor {
        &self.weight_oh
    }
    /// The configuration this model was loaded with.
    #[must_use]
    pub const fn config(&self) -> &StoicheiaConfig {
        &self.config
    }
}
impl MIBackend for StoicheiaRnn {
    /// The RNN is a single recurrent layer.
    fn num_layers(&self) -> usize {
        1
    }
    fn hidden_size(&self) -> usize {
        self.config.hidden_size
    }
    fn vocab_size(&self) -> usize {
        self.config.output_size()
    }
    /// The RNN has no attention heads.
    fn num_heads(&self) -> usize {
        0
    }
    /// Runs the recurrence over `input` (shape (batch, seq_len), f32 values
    /// consumed one scalar per timestep) and captures any requested hooks.
    ///
    /// Supported custom hook points:
    /// - `rnn.hook_pre_activation.{t}`: pre-ReLU activation at timestep `t`.
    /// - `rnn.hook_hidden.{t}`: hidden state at timestep `t`.
    /// - `rnn.hook_final_state`: hidden state after the last timestep.
    /// - `rnn.hook_output`: logits before the (batch, 1, output) reshape.
    fn forward(&self, input: &Tensor, hooks: &HookSpec) -> Result<HookCache> {
        let device = input.device();
        let (batch_size, seq_len) = input.dims2()?;
        let h = self.config.hidden_size;
        let mut hidden = Tensor::zeros((batch_size, h), DType::F32, device)?;
        let mut cache = HookCache::new(Tensor::zeros(1, DType::F32, device)?);
        let has_hooks = !hooks.is_empty();
        // Pre-compute which timesteps need captures so the hot loop does
        // cheap set lookups instead of formatting hook names every step.
        let (captured_pre_act, captured_hidden) = if has_hooks {
            let pre_act: std::collections::HashSet<usize> = (0..seq_len)
                .filter(|t| {
                    hooks.is_captured(&HookPoint::Custom(format!("rnn.hook_pre_activation.{t}")))
                })
                .collect();
            let hid: std::collections::HashSet<usize> = (0..seq_len)
                .filter(|t| hooks.is_captured(&HookPoint::Custom(format!("rnn.hook_hidden.{t}"))))
                .collect();
            (pre_act, hid)
        } else {
            (
                std::collections::HashSet::new(),
                std::collections::HashSet::new(),
            )
        };
        // Hoist the loop-invariant weight transposes: the previous version
        // re-transposed both matrices on every timestep.
        let w_ih_t = self.weight_ih.t()?;
        let w_hh_t = self.weight_hh.t()?;
        for t in 0..seq_len {
            // (batch, 1) slice of the input at timestep t.
            let x_t = input.i((.., t..=t))?;
            let ih = x_t.matmul(&w_ih_t)?;
            let hh = hidden.matmul(&w_hh_t)?;
            let pre_act = (ih + hh)?;
            if captured_pre_act.contains(&t) {
                cache.store(
                    HookPoint::Custom(format!("rnn.hook_pre_activation.{t}")),
                    pre_act.clone(),
                );
            }
            hidden = pre_act.relu()?;
            if captured_hidden.contains(&t) {
                cache.store(
                    HookPoint::Custom(format!("rnn.hook_hidden.{t}")),
                    hidden.clone(),
                );
            }
        }
        if has_hooks {
            let final_hook = HookPoint::Custom("rnn.hook_final_state".into());
            if hooks.is_captured(&final_hook) {
                cache.store(final_hook, hidden.clone());
            }
        }
        // Project the final hidden state to logits: (batch, output_size).
        let output = hidden.matmul(&self.weight_oh.t()?)?;
        if has_hooks {
            let output_hook = HookPoint::Custom("rnn.hook_output".into());
            if hooks.is_captured(&output_hook) {
                cache.store(output_hook, output.clone());
            }
        }
        // Add a singleton sequence dim for a (batch, 1, output_size) output.
        let output_3d = output.unsqueeze(1)?;
        cache.set_output(output_3d);
        Ok(cache)
    }
    /// Projects hidden states to vocabulary logits via the output weights.
    fn project_to_vocab(&self, hidden: &Tensor) -> Result<Tensor> {
        Ok(hidden.matmul(&self.weight_oh.t()?)?)
    }
}
/// A bias-free self-attention layer with a fused QKV projection; the
/// attention pattern it produces carries a singleton head dimension.
struct AttentionLayer {
    /// Fused Q/K/V projection weights, (3 * hidden_size, hidden_size).
    in_proj_weight: Tensor,
    /// Output projection weights, (hidden_size, hidden_size).
    out_proj_weight: Tensor,
    /// Width of the hidden/residual stream.
    hidden_size: usize,
}
impl AttentionLayer {
    /// Applies self-attention to `hidden` and returns
    /// `(projected_output, raw_scores, softmaxed_pattern)`; scores and
    /// pattern carry a singleton head dimension at axis 1.
    fn forward(&self, hidden: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
        let d = self.hidden_size;
        // One fused matmul produces Q, K, and V side by side on the last axis.
        let fused = hidden.broadcast_matmul(&self.in_proj_weight.t()?)?;
        let q = fused.narrow(2, 0, d)?;
        let k = fused.narrow(2, d, d)?;
        let v = fused.narrow(2, 2 * d, d)?;
        // Scaled dot-product scores over the last two axes.
        #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
        let scale = (d as f64).sqrt();
        let raw = (q.matmul(&k.t()?)? / scale)?;
        // Insert the singleton head axis before softmaxing so the cached
        // scores/pattern tensors have the conventional head dimension.
        let scores = raw.unsqueeze(1)?;
        let pattern = candle_nn::ops::softmax_last_dim(&scores)?;
        let weighted = pattern.squeeze(1)?.matmul(&v)?;
        let projected = weighted.broadcast_matmul(&self.out_proj_weight.t()?)?;
        Ok((projected, scores, pattern))
    }
}
/// An attention-only transformer backend: token plus learned positional
/// embeddings, a stack of residual attention layers (no MLPs or layer
/// norms appear in the forward pass), and a linear unembedding.
pub struct StoicheiaTransformer {
    /// Token embedding table, (input_range, hidden_size).
    embed: Embedding,
    /// Learned positional embedding table, (seq_len, hidden_size).
    pos_embed: Embedding,
    /// One attention layer per transformer layer.
    attns: Vec<AttentionLayer>,
    /// Unembedding weights, (output_size, hidden_size).
    unembed_weight: Tensor,
    /// Model hyperparameters.
    config: StoicheiaConfig,
}
impl StoicheiaTransformer {
    /// Loads transformer weights from `path` onto `device`.
    ///
    /// Expects `embed.weight`, `pos_embed.weight`, per-layer
    /// `attns.{i}.in_proj_weight` / `attns.{i}.out_proj.weight`, and
    /// `unembed.weight`, all read as `f32`.
    ///
    /// # Errors
    /// Fails when the file cannot be read or converted, or when any tensor
    /// is missing or has an unexpected shape.
    pub fn load(config: StoicheiaConfig, path: impl AsRef<Path>, device: &Device) -> Result<Self> {
        let bytes = load_weight_bytes(path.as_ref())?;
        let vb = VarBuilder::from_buffered_safetensors(bytes, DType::F32, device)?;
        let h = config.hidden_size;
        let embed = Embedding::new(vb.get((config.input_range, h), "embed.weight")?, h);
        let pos_embed = Embedding::new(vb.get((config.seq_len, h), "pos_embed.weight")?, h);
        // Fetch each layer's projection weights; the first missing tensor
        // short-circuits the collect with its error.
        let attns = (0..config.num_layers)
            .map(|i| {
                Ok(AttentionLayer {
                    in_proj_weight: vb.get((3 * h, h), &format!("attns.{i}.in_proj_weight"))?,
                    out_proj_weight: vb.get((h, h), &format!("attns.{i}.out_proj.weight"))?,
                    hidden_size: h,
                })
            })
            .collect::<Result<Vec<_>>>()?;
        let unembed_weight = vb.get((config.output_size(), h), "unembed.weight")?;
        Ok(Self {
            embed,
            pos_embed,
            attns,
            unembed_weight,
            config,
        })
    }
}
impl MIBackend for StoicheiaTransformer {
fn num_layers(&self) -> usize {
self.config.num_layers
}
fn hidden_size(&self) -> usize {
self.config.hidden_size
}
fn vocab_size(&self) -> usize {
self.config.output_size()
}
fn num_heads(&self) -> usize {
self.config.num_heads
}
fn forward(&self, input_ids: &Tensor, hooks: &HookSpec) -> Result<HookCache> {
let device = input_ids.device();
let (batch, seq_len) = input_ids.dims2()?;
let mut cache = HookCache::new(Tensor::zeros(1, DType::F32, device)?);
let token_emb = self.embed.forward(input_ids)?;
#[allow(clippy::cast_possible_truncation, clippy::as_conversions)]
let positions: Vec<u32> = (0..seq_len as u32).collect();
let pos_ids = Tensor::new(&positions[..], device)?
.unsqueeze(0)?
.expand((batch, seq_len))?;
let pos_emb = self.pos_embed.forward(&pos_ids)?;
let mut hidden = (token_emb + pos_emb)?;
let has_hooks = !hooks.is_empty();
if has_hooks && hooks.is_captured(&HookPoint::Embed) {
cache.store(HookPoint::Embed, hidden.clone());
}
for (i, attn) in self.attns.iter().enumerate() {
if has_hooks && hooks.is_captured(&HookPoint::ResidPre(i)) {
cache.store(HookPoint::ResidPre(i), hidden.clone());
}
let (attn_out, scores, pattern) = attn.forward(&hidden)?;
if has_hooks && hooks.is_captured(&HookPoint::AttnScores(i)) {
cache.store(HookPoint::AttnScores(i), scores);
}
if has_hooks && hooks.is_captured(&HookPoint::AttnPattern(i)) {
cache.store(HookPoint::AttnPattern(i), pattern);
}
if has_hooks && hooks.is_captured(&HookPoint::AttnOut(i)) {
cache.store(HookPoint::AttnOut(i), attn_out.clone());
}
hidden = (hidden + attn_out)?;
if has_hooks && hooks.is_captured(&HookPoint::ResidPost(i)) {
cache.store(HookPoint::ResidPost(i), hidden.clone());
}
}
let last_hidden = hidden.i((.., seq_len - 1, ..))?;
let output = last_hidden.matmul(&self.unembed_weight.t()?)?;
let output_3d = output.unsqueeze(1)?;
cache.set_output(output_3d);
Ok(cache)
}
fn project_to_vocab(&self, hidden: &Tensor) -> Result<Tensor> {
Ok(hidden.matmul(&self.unembed_weight.t()?)?)
}
}