use crate::tensor::DenseTensor;
use crate::tensor::traits::TensorBase;
/// Tag identifying which kind of payload a [`GraphNode`] carries.
///
/// Each variant is the tag written by the `GraphNode` constructor of the
/// corresponding name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GraphNodeType {
    /// Tag assigned by [`GraphNode::token_embedding`].
    TokenEmbedding,
    /// Tag assigned by [`GraphNode::hidden_state`].
    HiddenState,
    /// Tag assigned by [`GraphNode::attention_output`].
    AttentionOutput,
    /// Tag assigned by [`GraphNode::ffn_output`].
    FFNOutput,
}
/// Payload of a token-embedding graph node: a token id, a position and the
/// tensor holding the embedding.
#[derive(Debug, Clone)]
pub struct TokenEmbeddingNode {
    /// Token identifier supplied by the caller.
    pub token_id: usize,
    /// Position supplied by the caller (presumably the index within the
    /// sequence; confirm against callers).
    pub position: usize,
    /// Embedding tensor. `hidden_dim` reads index 1 of its shape, so a shape
    /// with at least two entries is expected.
    pub embedding: DenseTensor,
}
impl TokenEmbeddingNode {
pub fn new(token_id: usize, position: usize, embedding: DenseTensor) -> Self {
Self {
token_id,
position,
embedding,
}
}
pub fn hidden_dim(&self) -> usize {
self.embedding.shape()[1]
}
}
/// Payload of a hidden-state graph node: a layer index, a position and the
/// tensor holding the state.
#[derive(Debug, Clone)]
pub struct HiddenStateNode {
    /// Layer index supplied by the caller.
    pub layer: usize,
    /// Position supplied by the caller (presumably the index within the
    /// sequence; confirm against callers).
    pub position: usize,
    /// State tensor. `hidden_dim` reads index 1 of its shape, so a shape with
    /// at least two entries is expected.
    pub state: DenseTensor,
}
impl HiddenStateNode {
pub fn new(layer: usize, position: usize, state: DenseTensor) -> Self {
Self {
layer,
position,
state,
}
}
pub fn hidden_dim(&self) -> usize {
self.state.shape()[1]
}
}
/// Payload of an attention-output graph node for one head at one query
/// position.
#[derive(Debug, Clone)]
pub struct AttentionOutputNode {
    /// Layer index supplied by the caller.
    pub layer: usize,
    /// Head index supplied by the caller.
    pub head: usize,
    /// Query position supplied by the caller.
    pub query_pos: usize,
    /// Positions attended to; `num_attended` reports the length of this list.
    pub attended_positions: Vec<usize>,
    /// Attention weights (presumably one per entry of `attended_positions`;
    /// nothing in this type enforces matching lengths).
    pub weights: Vec<f64>,
    /// Output tensor. `head_dim` reads index 1 of its shape, so a shape with
    /// at least two entries is expected.
    pub output: DenseTensor,
}
impl AttentionOutputNode {
pub fn new(
layer: usize,
head: usize,
query_pos: usize,
attended_positions: Vec<usize>,
weights: Vec<f64>,
output: DenseTensor,
) -> Self {
Self {
layer,
head,
query_pos,
attended_positions,
weights,
output,
}
}
pub fn head_dim(&self) -> usize {
self.output.shape()[1]
}
pub fn num_attended(&self) -> usize {
self.attended_positions.len()
}
}
/// Payload of an FFN-output graph node: a layer index, a position and the
/// tensor holding the output.
#[derive(Debug, Clone)]
pub struct FFNOutputNode {
    /// Layer index supplied by the caller.
    pub layer: usize,
    /// Position supplied by the caller (presumably the index within the
    /// sequence; confirm against callers).
    pub position: usize,
    /// Output tensor. `hidden_dim` reads index 1 of its shape, so a shape with
    /// at least two entries is expected.
    pub output: DenseTensor,
}
impl FFNOutputNode {
pub fn new(layer: usize, position: usize, output: DenseTensor) -> Self {
Self {
layer,
position,
output,
}
}
pub fn hidden_dim(&self) -> usize {
self.output.shape()[1]
}
}
/// A graph node: common bookkeeping fields plus one optional slot per payload
/// kind.
///
/// The constructors on this type fill exactly the slot that matches
/// `node_type` and leave the other three as `None`. All fields are public, so
/// that pairing is a convention of the constructors, not something the type
/// enforces.
#[derive(Debug, Clone)]
pub struct GraphNode {
    /// Which kind of payload this node was built with.
    pub node_type: GraphNodeType,
    /// Node identifier supplied by the caller.
    pub id: usize,
    /// Layer index; the token-embedding constructor always stores 0 here.
    pub layer: usize,
    /// Position; the attention-output constructor stores the query position
    /// here.
    pub position: usize,
    /// Payload slot filled by the token-embedding constructor.
    pub token_embedding: Option<TokenEmbeddingNode>,
    /// Payload slot filled by the hidden-state constructor.
    pub hidden_state: Option<HiddenStateNode>,
    /// Payload slot filled by the attention-output constructor.
    pub attention_output: Option<AttentionOutputNode>,
    /// Payload slot filled by the FFN-output constructor.
    pub ffn_output: Option<FFNOutputNode>,
}
impl GraphNode {
pub fn token_embedding(id: usize, token_id: usize, position: usize, embedding: DenseTensor) -> Self {
Self {
node_type: GraphNodeType::TokenEmbedding,
id,
layer: 0,
position,
token_embedding: Some(TokenEmbeddingNode::new(token_id, position, embedding)),
hidden_state: None,
attention_output: None,
ffn_output: None,
}
}
pub fn hidden_state(id: usize, layer: usize, position: usize, state: DenseTensor) -> Self {
Self {
node_type: GraphNodeType::HiddenState,
id,
layer,
position,
token_embedding: None,
hidden_state: Some(HiddenStateNode::new(layer, position, state)),
attention_output: None,
ffn_output: None,
}
}
pub fn attention_output(
id: usize,
layer: usize,
head: usize,
query_pos: usize,
attended_positions: Vec<usize>,
weights: Vec<f64>,
output: DenseTensor,
) -> Self {
Self {
node_type: GraphNodeType::AttentionOutput,
id,
layer,
position: query_pos,
token_embedding: None,
hidden_state: None,
attention_output: Some(AttentionOutputNode::new(
layer,
head,
query_pos,
attended_positions,
weights,
output,
)),
ffn_output: None,
}
}
pub fn ffn_output(id: usize, layer: usize, position: usize, output: DenseTensor) -> Self {
Self {
node_type: GraphNodeType::FFNOutput,
id,
layer,
position,
token_embedding: None,
hidden_state: None,
attention_output: None,
ffn_output: Some(FFNOutputNode::new(layer, position, output)),
}
}
pub fn get_embedding(&self) -> Option<&TokenEmbeddingNode> {
self.token_embedding.as_ref()
}
pub fn get_hidden_state(&self) -> Option<&HiddenStateNode> {
self.hidden_state.as_ref()
}
pub fn get_attention_output(&self) -> Option<&AttentionOutputNode> {
self.attention_output.as_ref()
}
pub fn get_ffn_output(&self) -> Option<&FFNOutputNode> {
self.ffn_output.as_ref()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The token-embedding constructor tags the node, pins `layer` to 0 and
    /// fills only the token-embedding slot.
    #[test]
    fn test_token_embedding_node() {
        let embedding = DenseTensor::new(vec![0.1, 0.2, 0.3, 0.4], vec![1, 4]);
        let node = GraphNode::token_embedding(0, 10, 0, embedding);

        assert_eq!(node.node_type, GraphNodeType::TokenEmbedding);
        assert_eq!(node.id, 0);
        assert_eq!(node.layer, 0);
        assert_eq!(node.position, 0);

        let emb = node.get_embedding().unwrap();
        assert_eq!(emb.token_id, 10);
        assert_eq!(emb.position, 0);
        assert_eq!(emb.hidden_dim(), 4);

        // Every other payload slot must stay empty.
        assert!(node.get_hidden_state().is_none());
        assert!(node.get_attention_output().is_none());
        assert!(node.get_ffn_output().is_none());
    }

    /// The hidden-state constructor copies `layer`/`position` to both the
    /// node and its payload and fills only the hidden-state slot.
    #[test]
    fn test_hidden_state_node() {
        let state = DenseTensor::new(vec![0.1, 0.2, 0.3], vec![1, 3]);
        let node = GraphNode::hidden_state(1, 5, 2, state);

        assert_eq!(node.node_type, GraphNodeType::HiddenState);
        assert_eq!(node.id, 1);
        assert_eq!(node.layer, 5);
        assert_eq!(node.position, 2);

        let hidden = node.get_hidden_state().unwrap();
        assert_eq!(hidden.layer, 5);
        assert_eq!(hidden.position, 2);
        assert_eq!(hidden.hidden_dim(), 3);

        // Every other payload slot must stay empty.
        assert!(node.get_embedding().is_none());
        assert!(node.get_attention_output().is_none());
        assert!(node.get_ffn_output().is_none());
    }

    /// The attention-output constructor mirrors `query_pos` into the node's
    /// `position`, forwards the remaining arguments unchanged and fills only
    /// the attention-output slot.
    #[test]
    fn test_attention_output_node() {
        let output = DenseTensor::new(vec![0.1, 0.2], vec![1, 2]);
        let node = GraphNode::attention_output(
            10,
            3,
            2,
            5,
            vec![3, 4, 5],
            vec![0.3, 0.5, 0.2],
            output,
        );

        assert_eq!(node.node_type, GraphNodeType::AttentionOutput);
        assert_eq!(node.id, 10);
        assert_eq!(node.layer, 3);
        // The node-level position is taken from the query position.
        assert_eq!(node.position, 5);

        let attn = node.get_attention_output().unwrap();
        assert_eq!(attn.layer, 3);
        assert_eq!(attn.head, 2);
        assert_eq!(attn.query_pos, 5);
        assert_eq!(attn.num_attended(), 3);
        assert_eq!(attn.head_dim(), 2);
        // The vectors are moved in untouched, so exact equality is safe here.
        assert_eq!(attn.attended_positions, vec![3, 4, 5]);
        assert_eq!(attn.weights, vec![0.3, 0.5, 0.2]);

        // Every other payload slot must stay empty.
        assert!(node.get_embedding().is_none());
        assert!(node.get_hidden_state().is_none());
        assert!(node.get_ffn_output().is_none());
    }

    /// The FFN-output constructor copies `layer`/`position` to both the node
    /// and its payload and fills only the FFN-output slot.
    #[test]
    fn test_ffn_output_node() {
        let output = DenseTensor::new(vec![0.1, 0.2, 0.3], vec![1, 3]);
        let node = GraphNode::ffn_output(20, 7, 4, output);

        assert_eq!(node.node_type, GraphNodeType::FFNOutput);
        assert_eq!(node.id, 20);
        assert_eq!(node.layer, 7);
        assert_eq!(node.position, 4);

        let ffn = node.get_ffn_output().unwrap();
        assert_eq!(ffn.layer, 7);
        assert_eq!(ffn.position, 4);
        assert_eq!(ffn.hidden_dim(), 3);

        // Every other payload slot must stay empty.
        assert!(node.get_embedding().is_none());
        assert!(node.get_hidden_state().is_none());
        assert!(node.get_attention_output().is_none());
    }
}