use crate::dag::QueryDag;
use std::collections::HashMap;
/// Attention scores produced by a [`DagAttention`] implementation.
///
/// Maps a `usize` key (presumably a node id within the DAG — TODO confirm) to an
/// `f32` score.
pub type AttentionScores = HashMap<usize, f32>;
/// Tuning knobs shared by attention mechanisms.
///
/// Defaults (from the `Default` impl): `normalize = true`, `temperature = 1.0`,
/// `dropout = 0.0`.
///
/// `PartialEq` is derived so configurations can be compared directly, e.g. in tests
/// or to detect a configuration change. `Copy` is intentionally not derived, to avoid
/// tripping `clippy::clone_on_copy` at existing `.clone()` call sites.
#[derive(Debug, Clone, PartialEq)]
pub struct AttentionConfig {
    /// Whether scores should be normalized (presumably so they sum to 1 — confirm
    /// against implementations).
    pub normalize: bool,
    /// Temperature applied when scoring (presumably a softmax-style divisor, where
    /// smaller values sharpen the distribution — TODO confirm).
    pub temperature: f32,
    /// Dropout rate (presumably a probability in `[0, 1]`; `0.0` disables it —
    /// confirm).
    pub dropout: f32,
}
impl Default for AttentionConfig {
fn default() -> Self {
Self {
normalize: true,
temperature: 1.0,
dropout: 0.0,
}
}
}
#[derive(Debug, thiserror::Error)]
pub enum AttentionError {
#[error("Empty DAG")]
EmptyDag,
#[error("Cycle detected in DAG")]
CycleDetected,
#[error("Node {0} not found")]
NodeNotFound(usize),
#[error("Computation failed: {0}")]
ComputationFailed(String),
}
/// Common interface for attention mechanisms that score the nodes of a [`QueryDag`].
///
/// The `Send + Sync` supertraits mean implementors (and `Box<dyn DagAttention>`)
/// can be moved and shared across threads.
pub trait DagAttention: Send + Sync {
    /// Short, static identifier for this mechanism.
    fn name(&self) -> &'static str;

    /// Static, human-readable description of this mechanism's cost (presumably a
    /// big-O string such as `"O(n)"` — confirm with implementors).
    fn complexity(&self) -> &'static str;

    /// Computes attention scores for `dag`.
    ///
    /// On success, returns an [`AttentionScores`] map of `usize` keys (presumably
    /// node ids — TODO confirm) to `f32` scores.
    ///
    /// # Errors
    ///
    /// Returns an [`AttentionError`] when scores cannot be produced. Which variants
    /// occur is up to each implementation.
    fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError>;

    /// Receives observed execution times for `dag`.
    ///
    /// `execution_times` maps `usize` keys (presumably node ids) to `f64` values
    /// (units not visible here — confirm). The method takes `&mut self`, so
    /// implementations may update internal state from these observations.
    fn update(&mut self, dag: &QueryDag, execution_times: &HashMap<usize, f64>);
}