//! ruvector-dag 2.0.6: Directed Acyclic Graph (DAG) structures for query plan
//! optimization with neural learning.
//!
//! Attention mechanism selection example

use ruvector_dag::attention::{
    CausalConeAttention, CausalConeConfig, DagAttention, TopologicalAttention, TopologicalConfig,
};
use ruvector_dag::dag::{OperatorNode, OperatorType, QueryDag};

fn main() {
    println!("=== Attention Mechanism Selection ===\n");

    // Create a sample DAG
    let dag = create_vector_search_dag();

    println!("Created vector search DAG:");
    println!("  Nodes: {}", dag.node_count());
    println!("  Edges: {}", dag.edge_count());

    // Test Topological Attention
    println!("\n--- Topological Attention ---");
    println!("Emphasizes node depth in the DAG hierarchy");

    let topo = TopologicalAttention::new(TopologicalConfig {
        decay_factor: 0.9,
        max_depth: 10,
    });

    let scores = topo.forward(&dag).unwrap();
    println!("\nAttention scores:");
    for (node_id, score) in &scores {
        let node = dag.get_node(*node_id).unwrap();
        println!("  Node {}: {:.4} - {:?}", node_id, score, node.op_type);
    }

    let sum: f32 = scores.values().sum();
    println!("\nSum of scores: {:.4} (should be ~1.0)", sum);

    // Test Causal Cone Attention
    println!("\n--- Causal Cone Attention ---");
    println!("Focuses on downstream dependencies");

    let causal = CausalConeAttention::new(CausalConeConfig {
        time_window_ms: 1000,
        future_discount: 0.85,
        ancestor_weight: 0.5,
    });

    let causal_scores = causal.forward(&dag).unwrap();
    println!("\nCausal cone scores:");
    for (node_id, score) in &causal_scores {
        let node = dag.get_node(*node_id).unwrap();
        println!("  Node {}: {:.4} - {:?}", node_id, score, node.op_type);
    }

    // Compare mechanisms
    println!("\n--- Comparison ---");
    println!("Node | Topological | Causal Cone | Difference");
    println!("-----|-------------|-------------|------------");
    for node_id in 0..dag.node_count() {
        let topo_score = scores.get(&node_id).unwrap_or(&0.0);
        let causal_score = causal_scores.get(&node_id).unwrap_or(&0.0);
        let diff = (topo_score - causal_score).abs();
        println!(
            "{:4} | {:11.4} | {:11.4} | {:11.4}",
            node_id, topo_score, causal_score, diff
        );
    }

    println!("\n=== Example Complete ===");
}

/// Builds a small vector-search query plan.
///
/// The plan has two leaves: an HNSW index scan over `embeddings_idx` and a sequential
/// scan over `metadata`. Both feed a nested-loop join. The join is then piped through
/// a category filter, a limit of 10 and a final result operator.
///
/// # Panics
/// Panics (via `unwrap`) if the DAG rejects any of the edges added here.
fn create_vector_search_dag() -> QueryDag {
    let mut graph = QueryDag::new();

    // Leaf operators: the primary HNSW vector search and the metadata table scan.
    let hnsw_id = graph.add_node(OperatorNode::hnsw_scan(0, "embeddings_idx", 64));
    let meta_id = graph.add_node(OperatorNode::seq_scan(1, "metadata"));

    // Join embeddings with metadata; both leaves feed the join.
    let join_id = graph.add_node(OperatorNode::new(2, OperatorType::NestedLoopJoin));
    for input_id in [hnsw_id, meta_id] {
        graph.add_edge(input_id, join_id).unwrap();
    }

    // Linear tail of the plan: join -> filter -> limit -> result.
    // Each stage is added and then wired to the stage before it.
    let mut upstream_id = join_id;
    for stage in [
        OperatorNode::filter(3, "category = 'tech'"),
        OperatorNode::limit(4, 10),
        OperatorNode::new(5, OperatorType::Result),
    ] {
        let stage_id = graph.add_node(stage);
        graph.add_edge(upstream_id, stage_id).unwrap();
        upstream_id = stage_id;
    }

    graph
}