use serde::{Deserialize, Serialize};
use tensorlogic_ir::{EinsumGraph, EinsumNode, TLExpr, Term};
use crate::{
config::AttentionConfig,
error::{Result, TrustformerError},
};
/// Selects how the rule tensor is folded into the scaled attention scores.
///
/// The per-variant descriptions below state what
/// `RuleBasedAttention::build_rule_attention_graph` emits for each variant;
/// "rule tensor" means the tensor that builder addresses as graph index 3
/// (registered under the name `rule_mask` in this file's tests).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum RuleAttentionType {
/// The scaled scores are combined with the rule tensor by an element-wise
/// `"mul"` node.
Hard,
/// The rule tensor is passed through a `mul_scalar_{strength}` node and the
/// result is added to the scaled scores.
Soft {
/// Multiplier applied to the rule tensor; `RuleAttentionConfig::validate`
/// rejects values outside `[0, 1]`.
strength: f64,
},
/// The scaled scores are weighted by `1.0 - rule_weight`, the rule tensor by
/// `rule_weight`, and the two results are added.
Gated {
/// Share given to the rule tensor; `RuleAttentionConfig::validate` rejects
/// values outside `[0, 1]`.
rule_weight: f64,
},
}
/// Configuration for `RuleBasedAttention`: a base attention configuration plus
/// the way rules are combined with the attention scores.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct RuleAttentionConfig {
/// Underlying attention configuration. It is checked by `validate`, and its
/// `d_k` field supplies the score scaling factor in the graph builder.
pub base_attention: AttentionConfig,
/// How the rule tensor is combined with the scaled scores.
pub rule_type: RuleAttentionType,
/// Set to `true` by every constructor in this file and changeable through
/// `with_normalize_after_rules`.
/// NOTE(review): no code visible in this file reads this flag — confirm
/// where (or whether) it is meant to take effect.
pub normalize_after_rules: bool,
}
impl RuleAttentionConfig {
pub fn hard(base_attention: AttentionConfig) -> Self {
Self {
base_attention,
rule_type: RuleAttentionType::Hard,
normalize_after_rules: true,
}
}
pub fn soft(base_attention: AttentionConfig, strength: f64) -> Self {
Self {
base_attention,
rule_type: RuleAttentionType::Soft { strength },
normalize_after_rules: true,
}
}
pub fn gated(base_attention: AttentionConfig, rule_weight: f64) -> Self {
Self {
base_attention,
rule_type: RuleAttentionType::Gated { rule_weight },
normalize_after_rules: true,
}
}
pub fn with_normalize_after_rules(mut self, normalize: bool) -> Self {
self.normalize_after_rules = normalize;
self
}
pub fn validate(&self) -> Result<()> {
self.base_attention.validate()?;
match &self.rule_type {
RuleAttentionType::Soft { strength } => {
if !(0.0..=1.0).contains(strength) {
return Err(TrustformerError::InvalidDimension {
expected: 1,
got: 0,
context: format!("strength must be in [0,1], got {}", strength),
});
}
}
RuleAttentionType::Gated { rule_weight } => {
if !(0.0..=1.0).contains(rule_weight) {
return Err(TrustformerError::InvalidDimension {
expected: 1,
got: 0,
context: format!("rule_weight must be in [0,1], got {}", rule_weight),
});
}
}
RuleAttentionType::Hard => {}
}
Ok(())
}
}
/// Attention layer description whose scores are adjusted by a rule tensor
/// according to `config.rule_type`.
#[derive(Clone, Debug)]
pub struct RuleBasedAttention {
/// Configuration; `RuleBasedAttention::new` validates it before storing it.
pub config: RuleAttentionConfig,
/// Optional logical rule attached with `with_rule`. The graph builder in this
/// file does not read it; it is only stored and returned by `get_rule`.
pub attention_rule: Option<TLExpr>,
}
impl RuleBasedAttention {
    // Graph indices of tensors this builder expects to exist already. The tests
    // in this file register `Q`, `K`, `V` and `rule_mask` (in that order) on a
    // fresh graph before building — presumably `add_tensor` numbers tensors
    // sequentially from zero; TODO confirm.
    const QUERY_SLOT: usize = 0;
    const KEY_SLOT: usize = 1;
    const VALUE_SLOT: usize = 2;
    const RULE_SLOT: usize = 3;

    /// Validates `config` and creates an instance with no rule attached.
    ///
    /// # Errors
    ///
    /// Returns whatever `config.validate()` reports.
    pub fn new(config: RuleAttentionConfig) -> Result<Self> {
        config.validate().map(|()| Self {
            config,
            attention_rule: None,
        })
    }

    /// Returns the instance with `rule` stored as its attention rule.
    pub fn with_rule(self, rule: TLExpr) -> Self {
        Self {
            attention_rule: Some(rule),
            ..self
        }
    }

    /// Appends the rule-attention computation to `graph` and returns a
    /// one-element vector holding the index of the output tensor.
    ///
    /// Emitted nodes, in order:
    /// 1. `bqd,bkd->bqk` over tensors 0 and 1 (raw scores);
    /// 2. `div_scalar_{sqrt(d_k)}` (scaled scores);
    /// 3. the rule combination selected by `config.rule_type`, using tensor 3;
    /// 4. `softmax_k`;
    /// 5. `bqk,bkv->bqv` over the softmax result and tensor 2.
    ///
    /// NOTE(review): `config.normalize_after_rules` and `attention_rule` are
    /// not consulted here; the `softmax_k` node is always emitted.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by `graph.add_node`.
    pub fn build_rule_attention_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
        // Raw scores.
        let raw_scores = graph.add_tensor("rule_attn_scores");
        graph.add_node(EinsumNode::new(
            "bqd,bkd->bqk",
            vec![Self::QUERY_SLOT, Self::KEY_SLOT],
            vec![raw_scores],
        ))?;

        // Scaling by sqrt(d_k); the scalar is carried in the operation name.
        let sqrt_d_k = (self.config.base_attention.d_k as f64).sqrt();
        let scale_slot = graph.add_tensor("rule_attn_scale");
        let scaled_scores = graph.add_tensor("rule_attn_scaled_scores");
        graph.add_node(EinsumNode::elem_binary(
            format!("div_scalar_{}", sqrt_d_k),
            raw_scores,
            scale_slot,
            scaled_scores,
        ))?;

        // Rule combination.
        let adjusted_scores = match &self.config.rule_type {
            RuleAttentionType::Hard => Self::emit_hard(graph, scaled_scores)?,
            RuleAttentionType::Soft { strength } => {
                Self::emit_soft(graph, scaled_scores, *strength)?
            }
            RuleAttentionType::Gated { rule_weight } => {
                Self::emit_gated(graph, scaled_scores, *rule_weight)?
            }
        };

        // Softmax followed by the contraction with tensor 2.
        let weights = graph.add_tensor("rule_attention_weights");
        graph.add_node(EinsumNode::elem_unary("softmax_k", adjusted_scores, weights))?;

        let attended = graph.add_tensor("rule_attn_output");
        graph.add_node(EinsumNode::new(
            "bqk,bkv->bqv",
            vec![weights, Self::VALUE_SLOT],
            vec![attended],
        ))?;

        Ok(vec![attended])
    }

    /// Hard variant: element-wise `"mul"` of the scaled scores with the rule tensor.
    fn emit_hard(graph: &mut EinsumGraph, scaled_scores: usize) -> Result<usize> {
        let masked = graph.add_tensor("rule_hard_masked");
        graph.add_node(EinsumNode::elem_binary(
            "mul",
            scaled_scores,
            Self::RULE_SLOT,
            masked,
        ))?;
        Ok(masked)
    }

    /// Soft variant: scales the rule tensor by `strength`, then adds it to the scaled scores.
    fn emit_soft(graph: &mut EinsumGraph, scaled_scores: usize, strength: f64) -> Result<usize> {
        let strength_slot = graph.add_tensor("strength_const");
        let bias = graph.add_tensor("scaled_bias");
        graph.add_node(EinsumNode::elem_binary(
            format!("mul_scalar_{}", strength),
            Self::RULE_SLOT,
            strength_slot,
            bias,
        ))?;

        let biased = graph.add_tensor("rule_soft_biased");
        graph.add_node(EinsumNode::elem_binary("add", scaled_scores, bias, biased))?;
        Ok(biased)
    }

    /// Gated variant: `(1 - rule_weight) * scores + rule_weight * rule tensor`.
    fn emit_gated(
        graph: &mut EinsumGraph,
        scaled_scores: usize,
        rule_weight: f64,
    ) -> Result<usize> {
        let content_weight = 1.0 - rule_weight;

        let content_part = graph.add_tensor("rule_content_weighted");
        let content_slot = graph.add_tensor("content_weight_const");
        graph.add_node(EinsumNode::elem_binary(
            format!("mul_scalar_{}", content_weight),
            scaled_scores,
            content_slot,
            content_part,
        ))?;

        let rule_part = graph.add_tensor("rule_weighted");
        let rule_slot = graph.add_tensor("rule_weight_const");
        graph.add_node(EinsumNode::elem_binary(
            format!("mul_scalar_{}", rule_weight),
            Self::RULE_SLOT,
            rule_slot,
            rule_part,
        ))?;

        let blended = graph.add_tensor("rule_gated");
        graph.add_node(EinsumNode::elem_binary(
            "add",
            content_part,
            rule_part,
            blended,
        ))?;
        Ok(blended)
    }

    /// Returns the attached rule, if one was set with [`Self::with_rule`].
    pub fn get_rule(&self) -> Option<&TLExpr> {
        self.attention_rule.as_ref()
    }
}
/// Attention layer description whose weights come from a pre-registered
/// structure tensor rather than from query/key scores (see
/// `build_structured_attention_graph`).
#[derive(Clone, Debug)]
pub struct StructuredAttention {
/// Attention configuration; `StructuredAttention::new` validates it before
/// storing it.
pub config: AttentionConfig,
/// Optional predicate attached with `with_predicate`. The graph builder in
/// this file does not read it; it is only stored and returned by
/// `get_predicate`.
pub structure_predicate: Option<TLExpr>,
}
impl StructuredAttention {
    // Graph indices of tensors this builder expects to exist already. The test
    // in this file registers `V` and then `structure_matrix` on a fresh graph —
    // presumably `add_tensor` numbers tensors sequentially from zero; TODO confirm.
    const VALUE_SLOT: usize = 0;
    const STRUCTURE_SLOT: usize = 1;

    /// Validates `config` and creates an instance with no predicate attached.
    ///
    /// # Errors
    ///
    /// Returns whatever `config.validate()` reports.
    pub fn new(config: AttentionConfig) -> Result<Self> {
        config.validate().map(|()| Self {
            config,
            structure_predicate: None,
        })
    }

    /// Returns the instance with `predicate` stored as its structure predicate.
    pub fn with_predicate(self, predicate: TLExpr) -> Self {
        Self {
            structure_predicate: Some(predicate),
            ..self
        }
    }

    /// Appends the structured-attention computation to `graph` and returns a
    /// one-element vector holding the index of the output tensor.
    ///
    /// Two nodes are emitted: `softmax_k` over tensor 1, then `bqk,bkv->bqv`
    /// over the softmax result and tensor 0.
    ///
    /// NOTE(review): neither `config` nor `structure_predicate` is consulted here.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by `graph.add_node`.
    pub fn build_structured_attention_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
        let weights = graph.add_tensor("struct_attn_normalized");
        graph.add_node(EinsumNode::elem_unary(
            "softmax_k",
            Self::STRUCTURE_SLOT,
            weights,
        ))?;

        let attended = graph.add_tensor("struct_attn_output");
        graph.add_node(EinsumNode::new(
            "bqk,bkv->bqv",
            vec![weights, Self::VALUE_SLOT],
            vec![attended],
        ))?;

        Ok(vec![attended])
    }

    /// Returns the attached predicate, if one was set with [`Self::with_predicate`].
    pub fn get_predicate(&self) -> Option<&TLExpr> {
        self.structure_predicate.as_ref()
    }
}
/// Ready-made predicate expressions for common attention rules.
pub mod patterns {
    use super::*;

    /// Builds `name(first, second)` where both arguments are variables.
    fn binary_var_pred(name: &str, first: &str, second: &str) -> TLExpr {
        let args = [first, second]
            .iter()
            .map(|var| Term::Var(var.to_string()))
            .collect();
        TLExpr::Pred {
            name: name.to_string(),
            args,
        }
    }

    /// Returns the predicate `SyntacticDep(head_idx, dep_idx)` over two variables.
    pub fn syntactic_dependency(head_idx: &str, dep_idx: &str) -> TLExpr {
        binary_var_pred("SyntacticDep", head_idx, dep_idx)
    }

    /// Returns the predicate `Coref(mention1, mention2)` over two variables.
    pub fn coreference(mention1: &str, mention2: &str) -> TLExpr {
        binary_var_pred("Coref", mention1, mention2)
    }

    /// Returns a `GreaterThan` predicate with two constant arguments: the
    /// `Debug` rendering of `Similarity(token1, token2)` and the `Display`
    /// rendering of `threshold`.
    pub fn semantic_similarity(token1: &str, token2: &str, threshold: f64) -> TLExpr {
        let similarity = binary_var_pred("Similarity", token1, token2);
        // The inner predicate is embedded as text, not as a nested expression.
        let similarity_text = format!("{:?}", similarity);
        TLExpr::Pred {
            name: "GreaterThan".to_string(),
            args: vec![
                Term::Const(similarity_text),
                Term::Const(threshold.to_string()),
            ],
        }
    }

    /// Returns the predicate `ContainedIn(child, parent)` over two variables.
    pub fn hierarchical(child: &str, parent: &str) -> TLExpr {
        binary_var_pred("ContainedIn", child, parent)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Base attention configuration shared by every test.
    fn base_attention() -> AttentionConfig {
        AttentionConfig::new(512, 8).expect("unwrap")
    }

    /// Fresh graph with the given tensors registered in order.
    fn graph_with_inputs(names: &[&str]) -> EinsumGraph {
        let mut graph = EinsumGraph::new();
        for name in names.iter().copied() {
            graph.add_tensor(name);
        }
        graph
    }

    /// Asserts that `expr` is a two-argument predicate called `expected_name`.
    fn assert_binary_pred(expr: TLExpr, expected_name: &str) {
        match expr {
            TLExpr::Pred { name, args } => {
                assert_eq!(name, expected_name);
                assert_eq!(args.len(), 2);
            }
            _ => panic!("Expected Pred"),
        }
    }

    #[test]
    fn test_hard_rule_attention_config() {
        let config = RuleAttentionConfig::hard(base_attention());
        assert_eq!(config.rule_type, RuleAttentionType::Hard);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_soft_rule_attention_config() {
        let config = RuleAttentionConfig::soft(base_attention(), 0.5);
        assert_eq!(config.rule_type, RuleAttentionType::Soft { strength: 0.5 });
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_gated_rule_attention_config() {
        let config = RuleAttentionConfig::gated(base_attention(), 0.7);
        assert_eq!(
            config.rule_type,
            RuleAttentionType::Gated { rule_weight: 0.7 }
        );
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_rule_based_attention_creation() {
        let attn =
            RuleBasedAttention::new(RuleAttentionConfig::hard(base_attention())).expect("unwrap");
        assert!(attn.get_rule().is_none());
    }

    #[test]
    fn test_rule_based_attention_with_rule() {
        let attn = RuleBasedAttention::new(RuleAttentionConfig::hard(base_attention()))
            .expect("unwrap")
            .with_rule(patterns::syntactic_dependency("i", "j"));
        assert!(attn.get_rule().is_some());
    }

    #[test]
    fn test_rule_based_attention_graph_building() {
        let attn = RuleBasedAttention::new(RuleAttentionConfig::soft(base_attention(), 0.5))
            .expect("unwrap");
        let mut graph = graph_with_inputs(&["Q", "K", "V", "rule_mask"]);
        let outputs = attn.build_rule_attention_graph(&mut graph).expect("unwrap");
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_structured_attention_creation() {
        let attn = StructuredAttention::new(base_attention()).expect("unwrap");
        assert!(attn.get_predicate().is_none());
    }

    #[test]
    fn test_structured_attention_with_predicate() {
        let attn = StructuredAttention::new(base_attention())
            .expect("unwrap")
            .with_predicate(patterns::coreference("m1", "m2"));
        assert!(attn.get_predicate().is_some());
    }

    #[test]
    fn test_structured_attention_graph_building() {
        let attn = StructuredAttention::new(base_attention()).expect("unwrap");
        let mut graph = graph_with_inputs(&["V", "structure_matrix"]);
        let outputs = attn
            .build_structured_attention_graph(&mut graph)
            .expect("unwrap");
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_pattern_syntactic_dependency() {
        assert_binary_pred(patterns::syntactic_dependency("head", "dep"), "SyntacticDep");
    }

    #[test]
    fn test_pattern_coreference() {
        assert_binary_pred(patterns::coreference("m1", "m2"), "Coref");
    }

    #[test]
    fn test_invalid_soft_strength() {
        let config = RuleAttentionConfig::soft(base_attention(), 1.5);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_invalid_gated_weight() {
        let config = RuleAttentionConfig::gated(base_attention(), -0.1);
        assert!(config.validate().is_err());
    }
}