use ruvector_dag::attention::{
CausalConeAttention, CausalConeConfig, DagAttention, TopologicalAttention, TopologicalConfig,
};
use ruvector_dag::dag::{OperatorNode, OperatorType, QueryDag};
fn main() {
println!("=== Attention Mechanism Selection ===\n");
let dag = create_vector_search_dag();
println!("Created vector search DAG:");
println!(" Nodes: {}", dag.node_count());
println!(" Edges: {}", dag.edge_count());
println!("\n--- Topological Attention ---");
println!("Emphasizes node depth in the DAG hierarchy");
let topo = TopologicalAttention::new(TopologicalConfig {
decay_factor: 0.9,
max_depth: 10,
});
let scores = topo.forward(&dag).unwrap();
println!("\nAttention scores:");
for (node_id, score) in &scores {
let node = dag.get_node(*node_id).unwrap();
println!(" Node {}: {:.4} - {:?}", node_id, score, node.op_type);
}
let sum: f32 = scores.values().sum();
println!("\nSum of scores: {:.4} (should be ~1.0)", sum);
println!("\n--- Causal Cone Attention ---");
println!("Focuses on downstream dependencies");
let causal = CausalConeAttention::new(CausalConeConfig {
time_window_ms: 1000,
future_discount: 0.85,
ancestor_weight: 0.5,
});
let causal_scores = causal.forward(&dag).unwrap();
println!("\nCausal cone scores:");
for (node_id, score) in &causal_scores {
let node = dag.get_node(*node_id).unwrap();
println!(" Node {}: {:.4} - {:?}", node_id, score, node.op_type);
}
println!("\n--- Comparison ---");
println!("Node | Topological | Causal Cone | Difference");
println!("-----|-------------|-------------|------------");
for node_id in 0..dag.node_count() {
let topo_score = scores.get(&node_id).unwrap_or(&0.0);
let causal_score = causal_scores.get(&node_id).unwrap_or(&0.0);
let diff = (topo_score - causal_score).abs();
println!(
"{:4} | {:11.4} | {:11.4} | {:11.4}",
node_id, topo_score, causal_score, diff
);
}
println!("\n=== Example Complete ===");
}
/// Builds the sample query plan used by the example.
///
/// Shape of the plan:
///
/// ```text
/// hnsw_scan(embeddings_idx) ─┐
///                            ├─> NestedLoopJoin -> filter -> limit(10) -> Result
/// seq_scan(metadata) ────────┘
/// ```
///
/// Panics if any edge insertion returns an error.
fn create_vector_search_dag() -> QueryDag {
    let mut plan = QueryDag::new();

    // Two independent sources feed a single nested-loop join.
    let index_scan = plan.add_node(OperatorNode::hnsw_scan(0, "embeddings_idx", 64));
    let metadata_scan = plan.add_node(OperatorNode::seq_scan(1, "metadata"));
    let joined = plan.add_node(OperatorNode::new(2, OperatorType::NestedLoopJoin));
    for source in [index_scan, metadata_scan] {
        plan.add_edge(source, joined).unwrap();
    }

    // The rest of the plan is a linear chain hanging off the join. Each new
    // operator is linked to the one before it, in the same order as the
    // original hand-written sequence.
    let mut upstream = joined;
    for op in [
        OperatorNode::filter(3, "category = 'tech'"),
        OperatorNode::limit(4, 10),
        OperatorNode::new(5, OperatorType::Result),
    ] {
        let downstream = plan.add_node(op);
        plan.add_edge(upstream, downstream).unwrap();
        upstream = downstream;
    }

    plan
}