use tensorlogic_ir::EinsumGraph;
use crate::{
attention::MultiHeadAttention,
config::{AttentionConfig, FeedForwardConfig},
error::Result,
ffn::FeedForward,
normalization::{LayerNorm, LayerNormConfig},
};
/// Configuration for a single encoder layer: one attention block, one
/// feed-forward block and the layer-norm settings used between them.
///
/// `EncoderConfig::validate` checks that the three sub-configurations agree
/// on the model dimension.
#[derive(Clone, Debug)]
pub struct EncoderConfig {
/// Settings for the attention sub-layer.
pub attention: AttentionConfig,
/// Settings for the feed-forward sub-layer.
pub feed_forward: FeedForwardConfig,
/// Settings used to construct both layer norms of the encoder.
pub layer_norm: LayerNormConfig,
/// Selects the pre-norm graph builder when `true` and the post-norm
/// builder when `false`; `EncoderConfig::new` sets it to `true`.
pub pre_norm: bool,
}
impl EncoderConfig {
pub fn new(d_model: usize, n_heads: usize, d_ff: usize) -> Result<Self> {
Ok(Self {
attention: AttentionConfig::new(d_model, n_heads)?,
feed_forward: FeedForwardConfig::new(d_model, d_ff),
layer_norm: LayerNormConfig::new(d_model),
pre_norm: true,
})
}
pub fn with_pre_norm(mut self, pre_norm: bool) -> Self {
self.pre_norm = pre_norm;
self
}
pub fn with_causal(mut self, causal: bool) -> Self {
self.attention = self.attention.with_causal(causal);
self
}
pub fn with_dropout(mut self, dropout: f64) -> Self {
self.attention = self.attention.with_dropout(dropout);
self.feed_forward = self.feed_forward.with_dropout(dropout);
self
}
pub fn validate(&self) -> Result<()> {
self.attention.validate()?;
self.feed_forward.validate()?;
self.layer_norm.validate()?;
if self.attention.d_model != self.feed_forward.d_model {
return Err(crate::error::TrustformerError::InvalidDimension {
expected: self.attention.d_model,
got: self.feed_forward.d_model,
context: "d_model mismatch between attention and FFN".to_string(),
});
}
if self.attention.d_model != self.layer_norm.normalized_shape {
return Err(crate::error::TrustformerError::InvalidDimension {
expected: self.attention.d_model,
got: self.layer_norm.normalized_shape,
context: "d_model mismatch with layer norm".to_string(),
});
}
Ok(())
}
}
/// A single encoder layer: the configuration it was built from plus the
/// sub-layers constructed from that configuration by `Encoder::new`.
#[derive(Clone, Debug)]
pub struct Encoder {
/// The configuration passed to `Encoder::new`, kept after validation.
pub config: EncoderConfig,
/// Attention sub-layer built from `config.attention`.
pub attention: MultiHeadAttention,
/// Feed-forward sub-layer built from `config.feed_forward`.
pub ffn: FeedForward,
/// First layer norm, built from `config.layer_norm`.
pub norm1: LayerNorm,
/// Second layer norm, built from the same `config.layer_norm`.
pub norm2: LayerNorm,
}
impl Encoder {
pub fn new(config: EncoderConfig) -> Result<Self> {
config.validate()?;
let attention = MultiHeadAttention::new(config.attention.clone())?;
let ffn = FeedForward::new(config.feed_forward.clone())?;
let norm1 = LayerNorm::new(config.layer_norm.clone())?;
let norm2 = LayerNorm::new(config.layer_norm.clone())?;
Ok(Self {
config,
attention,
ffn,
norm1,
norm2,
})
}
pub fn build_encoder_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
let input_tensor = 0;
if self.config.pre_norm {
self.build_pre_norm_encoder(graph, input_tensor)
} else {
self.build_post_norm_encoder(graph, input_tensor)
}
}
fn build_pre_norm_encoder(
&self,
graph: &mut EinsumGraph,
input_tensor: usize,
) -> Result<Vec<usize>> {
let normed1_outputs = self.norm1.build_layernorm_graph(graph)?;
let normed1 = normed1_outputs[0];
let q_tensor = graph.add_tensor("encoder_Q");
let k_tensor = graph.add_tensor("encoder_K");
let v_tensor = graph.add_tensor("encoder_V");
let _q_node = tensorlogic_ir::EinsumNode::elem_unary("identity", normed1, q_tensor);
let _k_node = tensorlogic_ir::EinsumNode::elem_unary("identity", normed1, k_tensor);
let _v_node = tensorlogic_ir::EinsumNode::elem_unary("identity", normed1, v_tensor);
let attn_outputs = self.attention.build_mha_graph(graph)?;
let attn_output = attn_outputs[0];
let residual1 = graph.add_tensor("encoder_residual1");
let res1_node =
tensorlogic_ir::EinsumNode::elem_binary("add", input_tensor, attn_output, residual1);
graph.add_node(res1_node)?;
let _normed2_outputs = self.norm2.build_layernorm_graph(graph)?;
let ffn_outputs = self.ffn.build_ffn_graph(graph)?;
let ffn_output = ffn_outputs[0];
let output = graph.add_tensor("encoder_output");
let res2_node =
tensorlogic_ir::EinsumNode::elem_binary("add", residual1, ffn_output, output);
graph.add_node(res2_node)?;
Ok(vec![output])
}
fn build_post_norm_encoder(
&self,
graph: &mut EinsumGraph,
input_tensor: usize,
) -> Result<Vec<usize>> {
let attn_outputs = self.attention.build_mha_graph(graph)?;
let attn_output = attn_outputs[0];
let residual1 = graph.add_tensor("encoder_residual1");
let res1_node =
tensorlogic_ir::EinsumNode::elem_binary("add", input_tensor, attn_output, residual1);
graph.add_node(res1_node)?;
let normed1_outputs = self.norm1.build_layernorm_graph(graph)?;
let normed1 = normed1_outputs[0];
let ffn_outputs = self.ffn.build_ffn_graph(graph)?;
let ffn_output = ffn_outputs[0];
let residual2 = graph.add_tensor("encoder_residual2");
let res2_node =
tensorlogic_ir::EinsumNode::elem_binary("add", normed1, ffn_output, residual2);
graph.add_node(res2_node)?;
let normed2_outputs = self.norm2.build_layernorm_graph(graph)?;
let output = normed2_outputs[0];
Ok(vec![output])
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The 512-wide, 8-head, 2048-FFN configuration shared by most tests.
    fn base_config() -> EncoderConfig {
        EncoderConfig::new(512, 8, 2048).expect("unwrap")
    }

    /// Builds an encoder from `config` and runs it over a fresh graph seeded
    /// with a single tensor named "x"; returns the graph and the outputs.
    fn build_graph(config: EncoderConfig) -> (EinsumGraph, Vec<usize>) {
        let encoder = Encoder::new(config).expect("unwrap");
        let mut graph = EinsumGraph::new();
        graph.add_tensor("x");
        let outputs = encoder.build_encoder_graph(&mut graph).expect("unwrap");
        (graph, outputs)
    }

    #[test]
    fn test_encoder_config_creation() {
        let cfg = base_config();
        assert_eq!(cfg.attention.d_model, 512);
        assert_eq!(cfg.attention.n_heads, 8);
        assert_eq!(cfg.feed_forward.d_ff, 2048);
        assert!(cfg.pre_norm);
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_encoder_config_with_dropout() {
        let cfg = base_config().with_dropout(0.1);
        assert!((cfg.attention.dropout - 0.1).abs() < 1e-10);
        assert!((cfg.feed_forward.dropout - 0.1).abs() < 1e-10);
    }

    #[test]
    fn test_encoder_config_pre_norm() {
        let cfg = base_config().with_pre_norm(false);
        assert!(!cfg.pre_norm);
    }

    #[test]
    fn test_encoder_creation() {
        let layer = Encoder::new(base_config()).expect("unwrap");
        assert_eq!(layer.config.attention.d_model, 512);
    }

    #[test]
    fn test_encoder_graph_building_pre_norm() {
        let (graph, outputs) = build_graph(base_config());
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_encoder_graph_building_post_norm() {
        let (graph, outputs) = build_graph(base_config().with_pre_norm(false));
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_encoder_config_validation() {
        assert!(base_config().validate().is_ok());

        // 512 is not divisible by 7 heads, so construction must fail.
        let rejected = EncoderConfig::new(512, 7, 2048);
        assert!(rejected.is_err());
    }

    #[test]
    fn test_encoder_with_causal() {
        let cfg = base_config().with_causal(true);
        assert!(cfg.attention.causal);
    }
}