use tensorlogic_ir::{EinsumGraph, EinsumNode};
use crate::config::FeedForwardConfig;
use crate::error::Result;
/// Position-wise feed-forward layer description.
///
/// Holds only the configuration; the computation itself is emitted into an
/// `EinsumGraph` by the layer's graph-building method.
#[derive(Debug, Clone)]
pub struct FeedForward {
    /// Configuration this layer was created with.
    pub config: FeedForwardConfig,
}
impl FeedForward {
    /// Creates a feed-forward layer from `config`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `config.validate()`.
    pub fn new(config: FeedForwardConfig) -> Result<Self> {
        config.validate()?;
        let layer = Self { config };
        Ok(layer)
    }

    /// Appends the feed-forward computation to `graph`.
    ///
    /// The emitted nodes are, in order:
    ///
    /// 1. einsum `bsd,df->bsf` over tensors `0` and `1`,
    /// 2. element-wise `add` of tensor `2`,
    /// 3. the configured activation as an element-wise unary op,
    /// 4. einsum `bsf,fd->bsd` with tensor `3`,
    /// 5. element-wise `add` of tensor `4`.
    ///
    /// Inputs are addressed by fixed tensor id rather than by name. The unit
    /// tests in this file register `x`, `W1`, `b1`, `W2`, `b2` (in that order)
    /// before calling this method; callers are presumably expected to do the
    /// same — TODO confirm against the crate-level documentation.
    ///
    /// # Returns
    ///
    /// A single-element vector holding the id of the `ffn_output` tensor.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by `EinsumGraph::add_node`.
    pub fn build_ffn_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
        // Fixed ids of the tensors this method reads from `graph`.
        const INPUT: usize = 0;
        const WEIGHT_1: usize = 1;
        const BIAS_1: usize = 2;
        const WEIGHT_2: usize = 3;
        const BIAS_2: usize = 4;

        // First projection: d_model -> d_ff.
        let hidden = graph.add_tensor("ffn_linear1");
        graph.add_node(EinsumNode::new(
            "bsd,df->bsf",
            vec![INPUT, WEIGHT_1],
            vec![hidden],
        ))?;

        // First bias.
        let hidden_biased = graph.add_tensor("ffn_bias1");
        graph.add_node(EinsumNode::elem_binary(
            "add",
            hidden,
            BIAS_1,
            hidden_biased,
        ))?;

        // Non-linearity named by the configuration.
        let activated = graph.add_tensor("ffn_activation");
        graph.add_node(EinsumNode::elem_unary(
            &self.config.activation,
            hidden_biased,
            activated,
        ))?;

        // Second projection: d_ff -> d_model.
        let projected = graph.add_tensor("ffn_linear2");
        graph.add_node(EinsumNode::new(
            "bsf,fd->bsd",
            vec![activated, WEIGHT_2],
            vec![projected],
        ))?;

        // Second bias produces the layer output.
        let output = graph.add_tensor("ffn_output");
        graph.add_node(EinsumNode::elem_binary("add", projected, BIAS_2, output))?;

        Ok(vec![output])
    }

    /// Returns `d_ff / d_model` computed in `f64`.
    pub fn expansion_ratio(&self) -> f64 {
        let hidden_width = self.config.d_ff as f64;
        let model_width = self.config.d_model as f64;
        hidden_width / model_width
    }

    /// Returns the activation name stored in the configuration.
    pub fn activation(&self) -> &str {
        let name = &self.config.activation;
        name
    }
}
/// Gated feed-forward layer description.
///
/// Holds only the configuration; the computation itself is emitted into an
/// `EinsumGraph` by the layer's graph-building method.
#[derive(Debug, Clone)]
pub struct GatedFeedForward {
    /// Configuration this layer was created with.
    pub config: FeedForwardConfig,
}
impl GatedFeedForward {
    /// Creates a gated feed-forward layer from `config`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `config.validate()`.
    pub fn new(config: FeedForwardConfig) -> Result<Self> {
        config.validate()?;
        let layer = Self { config };
        Ok(layer)
    }

    /// Appends the gated feed-forward computation to `graph`.
    ///
    /// The emitted nodes are, in order:
    ///
    /// 1. gate projection: einsum `bsd,df->bsf` over tensors `0` and `1`,
    /// 2. `sigmoid` of the gate projection,
    /// 3. value projection: einsum `bsd,df->bsf` over tensors `0` and `2`,
    /// 4. the configured activation applied to the value projection,
    /// 5. element-wise `mul` of the two activated branches,
    /// 6. output projection: einsum `bsf,fd->bsd` with tensor `3`.
    ///
    /// Inputs are addressed by fixed tensor id rather than by name. The unit
    /// tests in this file register `x`, `W_gate`, `W_value`, `W_out` (in that
    /// order) before calling this method; callers are presumably expected to
    /// do the same — TODO confirm against the crate-level documentation.
    ///
    /// NOTE(review): both branches pass through a non-linearity here (sigmoid
    /// on the gate, the configured activation on the value). Confirm this is
    /// the intended gating formula.
    ///
    /// # Returns
    ///
    /// A single-element vector holding the id of the `glu_output` tensor.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by `EinsumGraph::add_node`.
    pub fn build_glu_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
        // Fixed ids of the tensors this method reads from `graph`.
        const INPUT: usize = 0;
        const WEIGHT_GATE: usize = 1;
        const WEIGHT_VALUE: usize = 2;
        const WEIGHT_OUT: usize = 3;

        // Gate branch: projection followed by sigmoid.
        let gate = graph.add_tensor("glu_gate_proj");
        graph.add_node(EinsumNode::new(
            "bsd,df->bsf",
            vec![INPUT, WEIGHT_GATE],
            vec![gate],
        ))?;
        let gate_out = graph.add_tensor("glu_gate_activated");
        graph.add_node(EinsumNode::elem_unary("sigmoid", gate, gate_out))?;

        // Value branch: projection followed by the configured activation.
        let value = graph.add_tensor("glu_value_proj");
        graph.add_node(EinsumNode::new(
            "bsd,df->bsf",
            vec![INPUT, WEIGHT_VALUE],
            vec![value],
        ))?;
        let value_out = graph.add_tensor("glu_value_activated");
        graph.add_node(EinsumNode::elem_unary(
            &self.config.activation,
            value,
            value_out,
        ))?;

        // Combine the branches element-wise.
        let product = graph.add_tensor("glu_gated");
        graph.add_node(EinsumNode::elem_binary("mul", gate_out, value_out, product))?;

        // Project back to the model width.
        let output = graph.add_tensor("glu_output");
        graph.add_node(EinsumNode::new(
            "bsf,fd->bsd",
            vec![product, WEIGHT_OUT],
            vec![output],
        ))?;

        Ok(vec![output])
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `FeedForward` from the 512 / 2048 configuration used throughout.
    fn default_ffn() -> FeedForward {
        FeedForward::new(FeedForwardConfig::new(512, 2048)).expect("unwrap")
    }

    /// Builds a `GatedFeedForward` from the 512 / 2048 configuration.
    fn default_glu() -> GatedFeedForward {
        GatedFeedForward::new(FeedForwardConfig::new(512, 2048)).expect("unwrap")
    }

    /// Returns a fresh graph with `names` registered as tensors, in order.
    fn graph_with_tensors(names: &[&'static str]) -> EinsumGraph {
        let mut graph = EinsumGraph::new();
        for name in names {
            graph.add_tensor(*name);
        }
        graph
    }

    /// Construction keeps the configured sizes and the activation reported as "gelu".
    #[test]
    fn test_ffn_creation() {
        let layer = default_ffn();
        assert_eq!(layer.config.d_model, 512);
        assert_eq!(layer.config.d_ff, 2048);
        assert_eq!(layer.activation(), "gelu");
    }

    /// 2048 / 512 must come out as 4 within floating-point tolerance.
    #[test]
    fn test_ffn_expansion_ratio() {
        let layer = default_ffn();
        let deviation = (layer.expansion_ratio() - 4.0).abs();
        assert!(deviation < 1e-10);
    }

    /// An activation set through the builder is reported back unchanged.
    #[test]
    fn test_ffn_with_custom_activation() {
        let relu_config = FeedForwardConfig::new(512, 2048).with_activation("relu");
        let layer = FeedForward::new(relu_config).expect("unwrap");
        assert_eq!(layer.activation(), "relu");
    }

    /// Building on a graph holding the five input tensors yields one output and some nodes.
    #[test]
    fn test_ffn_graph_building() {
        let layer = default_ffn();
        let mut graph = graph_with_tensors(&["x", "W1", "b1", "W2", "b2"]);
        let outputs = layer.build_ffn_graph(&mut graph).expect("unwrap");
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    /// Construction of the gated variant keeps the configured sizes.
    #[test]
    fn test_gated_ffn_creation() {
        let layer = default_glu();
        assert_eq!(layer.config.d_model, 512);
        assert_eq!(layer.config.d_ff, 2048);
    }

    /// Building on a graph holding the four input tensors yields one output and some nodes.
    #[test]
    fn test_gated_ffn_graph_building() {
        let layer = default_glu();
        let mut graph = graph_with_tensors(&["x", "W_gate", "W_value", "W_out"]);
        let outputs = layer.build_glu_graph(&mut graph).expect("unwrap");
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }
}