use tensorlogic_ir::{EinsumGraph, EinsumNode};
use crate::config::AttentionConfig;
use crate::error::Result;
#[derive(Clone, Debug)]
pub struct SelfAttention {
pub config: AttentionConfig,
}
impl SelfAttention {
pub fn new(config: AttentionConfig) -> Result<Self> {
config.validate()?;
Ok(Self { config })
}
pub fn build_attention_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
let scores_tensor = graph.add_tensor("attn_scores");
let scores_node = EinsumNode::new("bqd,bkd->bqk", vec![0, 1], vec![scores_tensor]);
graph.add_node(scores_node)?;
let scale_factor = (self.config.d_k as f64).sqrt();
let scale_tensor = graph.add_tensor("scale_factor");
let scaled_tensor = graph.add_tensor("scaled_scores");
let scale_node = EinsumNode::elem_binary(
format!("div_scalar_{}", scale_factor),
scores_tensor,
scale_tensor,
scaled_tensor,
);
graph.add_node(scale_node)?;
let softmax_tensor = graph.add_tensor("attention_weights");
let softmax_node = EinsumNode::elem_unary("softmax_k", scaled_tensor, softmax_tensor);
graph.add_node(softmax_node)?;
let output_tensor = graph.add_tensor("attn_output");
let output_node =
EinsumNode::new("bqk,bkv->bqv", vec![softmax_tensor, 2], vec![output_tensor]);
graph.add_node(output_node)?;
Ok(vec![output_tensor])
}
pub fn get_scale_factor(&self) -> f64 {
(self.config.d_k as f64).sqrt()
}
}
#[derive(Clone, Debug)]
pub struct MultiHeadAttention {
pub config: AttentionConfig,
}
impl MultiHeadAttention {
pub fn new(config: AttentionConfig) -> Result<Self> {
config.validate()?;
Ok(Self { config })
}
pub fn build_mha_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
let _h = self.config.n_heads; let d_k = self.config.d_k;
let q_split = graph.add_tensor("q_split");
let k_split = graph.add_tensor("k_split");
let v_split = graph.add_tensor("v_split");
let reshape_spec = format!("bsd->bsh{}", d_k);
let q_reshape = EinsumNode::new(&reshape_spec, vec![0], vec![q_split]);
graph.add_node(q_reshape)?;
let k_reshape = EinsumNode::new(&reshape_spec, vec![1], vec![k_split]);
graph.add_node(k_reshape)?;
let v_reshape = EinsumNode::new(&reshape_spec, vec![2], vec![v_split]);
graph.add_node(v_reshape)?;
let q_transposed = graph.add_tensor("q_transposed");
let k_transposed = graph.add_tensor("k_transposed");
let v_transposed = graph.add_tensor("v_transposed");
let transpose_node_q = EinsumNode::new("bshd->bhsd", vec![q_split], vec![q_transposed]);
graph.add_node(transpose_node_q)?;
let transpose_node_k = EinsumNode::new("bshd->bhsd", vec![k_split], vec![k_transposed]);
graph.add_node(transpose_node_k)?;
let transpose_node_v = EinsumNode::new("bshd->bhsd", vec![v_split], vec![v_transposed]);
graph.add_node(transpose_node_v)?;
let scores_tensor = graph.add_tensor("mha_scores");
let scores_node = EinsumNode::new(
"bhqd,bhkd->bhqk",
vec![q_transposed, k_transposed],
vec![scores_tensor],
);
graph.add_node(scores_node)?;
let scale_factor = (d_k as f64).sqrt();
let scale_tensor = graph.add_tensor("mha_scale");
let scaled_tensor = graph.add_tensor("mha_scaled_scores");
let scale_node = EinsumNode::elem_binary(
format!("div_scalar_{}", scale_factor),
scores_tensor,
scale_tensor,
scaled_tensor,
);
graph.add_node(scale_node)?;
let softmax_tensor = graph.add_tensor("mha_attention_weights");
let softmax_node = EinsumNode::elem_unary("softmax_k", scaled_tensor, softmax_tensor);
graph.add_node(softmax_node)?;
let attn_output = graph.add_tensor("mha_attn_output");
let attn_node = EinsumNode::new(
"bhqk,bhkv->bhqv",
vec![softmax_tensor, v_transposed],
vec![attn_output],
);
graph.add_node(attn_node)?;
let transposed_back = graph.add_tensor("mha_transposed_back");
let transpose_back_node =
EinsumNode::new("bhsd->bshd", vec![attn_output], vec![transposed_back]);
graph.add_node(transpose_back_node)?;
let output_tensor = graph.add_tensor("mha_output");
let reshape_back_spec = format!("bsh{}-:bsd", d_k);
let reshape_back = EinsumNode::new(
&reshape_back_spec,
vec![transposed_back],
vec![output_tensor],
);
graph.add_node(reshape_back)?;
Ok(vec![output_tensor])
}
pub fn num_heads(&self) -> usize {
self.config.n_heads
}
pub fn head_dim(&self) -> usize {
self.config.d_k
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_self_attention_creation() {
let config = AttentionConfig::new(512, 8).expect("unwrap");
let attn = SelfAttention::new(config).expect("unwrap");
assert_eq!(attn.config.d_model, 512);
assert_eq!(attn.config.n_heads, 8);
assert_eq!(attn.config.d_k, 64);
}
#[test]
fn test_self_attention_scale_factor() {
let config = AttentionConfig::new(512, 8).expect("unwrap");
let attn = SelfAttention::new(config).expect("unwrap");
let scale = attn.get_scale_factor();
assert!((scale - 8.0).abs() < 1e-10); }
#[test]
fn test_self_attention_graph_building() {
let config = AttentionConfig::new(512, 8).expect("unwrap");
let attn = SelfAttention::new(config).expect("unwrap");
let mut graph = EinsumGraph::new();
graph.add_tensor("Q");
graph.add_tensor("K");
graph.add_tensor("V");
let outputs = attn.build_attention_graph(&mut graph).expect("unwrap");
assert_eq!(outputs.len(), 1);
assert!(!graph.nodes.is_empty());
}
#[test]
fn test_multi_head_attention_creation() {
let config = AttentionConfig::new(512, 8).expect("unwrap");
let mha = MultiHeadAttention::new(config).expect("unwrap");
assert_eq!(mha.num_heads(), 8);
assert_eq!(mha.head_dim(), 64);
}
#[test]
fn test_multi_head_attention_graph_building() {
let config = AttentionConfig::new(512, 8).expect("unwrap");
let mha = MultiHeadAttention::new(config).expect("unwrap");
let mut graph = EinsumGraph::new();
graph.add_tensor("Q");
graph.add_tensor("K");
graph.add_tensor("V");
let outputs = mha.build_mha_graph(&mut graph).expect("unwrap");
assert_eq!(outputs.len(), 1);
assert!(!graph.nodes.is_empty());
}
#[test]
fn test_attention_invalid_config() {
let result = AttentionConfig::new(512, 7); assert!(result.is_err());
}
}