use anyhow::Result;
use rlx_ir::HirGraphExt;
use rlx_ir::hir::HirMut;
use std::sync::{Arc, Mutex};
use super::BlockStage;
use crate::context::FlowCtx;
use crate::value::FlowValue;
/// Block stage that taps one layer's key/value projections.
///
/// When emitted, the stage multiplies its input by the layer's K projection
/// weight (and by the V projection weight unless `k_eq_v` is set), applies
/// `rope_n` to the K result, and pushes the resulting node ids into `outputs`.
/// The stage input itself is returned unchanged.
#[derive(Debug, Clone)]
pub struct GemmaKvTapStage {
    /// Prefix used to build parameter names, e.g.
    /// `{layer_prefix}.self_attn.k_proj.weight`.
    pub layer_prefix: String,
    /// Head dimension forwarded to `rope_n`.
    pub head_dim: usize,
    /// Rotary width forwarded to `rope_n`; `layer` initialises it to `head_dim`.
    /// NOTE(review): presumably the number of rotated dimensions per head — confirm
    /// against `rope_n`.
    pub n_rot: usize,
    /// Optional rope table name handed to `resolve_rope_handles`; `None` lets
    /// that helper pick its own default.
    pub rope_table: Option<String>,
    /// When `true`, no V projection weight is loaded and the (un-rotated) K
    /// projection node is recorded as V.
    pub k_eq_v: bool,
    /// Shared sink receiving, per emit, the rotated-K node id followed by the
    /// V node id.
    pub outputs: Arc<Mutex<Vec<rlx_ir::HirNodeId>>>,
}
impl GemmaKvTapStage {
    /// Creates a tap for layer `layer_idx`, using the parameter prefix
    /// `model.layers.{layer_idx}`.
    ///
    /// `n_rot` starts equal to `head_dim`, no rope table is selected, and
    /// `k_eq_v` is off. `_eps` is accepted but not used by this constructor.
    /// Tapped node ids are pushed into `sink`.
    pub fn layer(
        layer_idx: usize,
        head_dim: usize,
        _eps: f32,
        sink: Arc<Mutex<Vec<rlx_ir::HirNodeId>>>,
    ) -> Self {
        let layer_prefix = format!("model.layers.{layer_idx}");
        Self {
            layer_prefix,
            head_dim,
            n_rot: head_dim,
            rope_table: None,
            k_eq_v: false,
            outputs: sink,
        }
    }

    /// Returns the stage with `n_rot` replaced by the given value.
    pub fn with_n_rot(self, n_rot: usize) -> Self {
        Self { n_rot, ..self }
    }

    /// Returns the stage with `rope_table` set to `Some(name)`.
    pub fn with_rope_table(self, name: impl Into<String>) -> Self {
        let table_name: String = name.into();
        Self {
            rope_table: Some(table_name),
            ..self
        }
    }

    /// Returns the stage with `k_eq_v` enabled, so the K projection is reused
    /// as V when the stage is emitted.
    pub fn with_k_eq_v(self) -> Self {
        Self {
            k_eq_v: true,
            ..self
        }
    }
}
impl BlockStage for GemmaKvTapStage {
    /// Emits the K/V tap nodes for this layer and records them in `outputs`.
    ///
    /// Steps, in order:
    /// 1. resolve the rope `cos`/`sin` handles for `rope_table`;
    /// 2. load `{layer_prefix}.self_attn.k_proj.weight` and, unless `k_eq_v`
    ///    is set, `{layer_prefix}.self_attn.v_proj.weight`;
    /// 3. build `k = mm(input, k_w)` and `v = mm(input, v_w)` (or `v = k` when
    ///    `k_eq_v` is set), then `k_rope = rope_n(k, cos, sin, head_dim, n_rot)`;
    /// 4. push `k_rope` followed by `v` into the shared sink.
    ///
    /// The stage is a pass-through: it returns `Ok(Some(input))`.
    ///
    /// # Errors
    /// Propagates any error from `resolve_rope_handles` or `load_param`.
    ///
    /// # Panics
    /// Panics with "kv tap sink" if the sink mutex is poisoned.
    fn emit(&self, ctx: &mut FlowCtx<'_>, input: FlowValue) -> Result<Option<FlowValue>> {
        let lp = &self.layer_prefix;
        let (cos, sin) = super::self_attn::resolve_rope_handles(ctx, self.rope_table.as_deref())?;
        let k_w = ctx.load_param(&format!("{lp}.self_attn.k_proj.weight"), true)?;
        let v_w = if self.k_eq_v {
            None
        } else {
            Some(ctx.load_param(&format!("{lp}.self_attn.v_proj.weight"), true)?)
        };

        let mut gb = HirMut::new(ctx.hir());
        let k = gb.mm(input.id, k_w);
        let v = match v_w {
            Some(w) => gb.mm(input.id, w),
            // Shared K/V: V is the K projection taken before rope is applied.
            None => k,
        };
        let k_rope = gb.rope_n(k, cos, sin, self.head_dim, self.n_rot);

        // Take the lock once so the (k_rope, v) pair lands in the sink as an
        // adjacent unit; locking per push would let another holder of the same
        // sink interleave its entries between the two.
        {
            let mut sink = self.outputs.lock().expect("kv tap sink");
            sink.push(k_rope);
            sink.push(v);
        }

        Ok(Some(input))
    }
}