use crate::dag::QueryDag;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use thiserror::Error;
/// Result payload returned by [`DagAttentionMechanism::forward`].
///
/// Construct it with [`AttentionScores::new`] and extend it through the
/// `with_*` builder methods.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct AttentionScores {
    /// Raw score values.
    /// NOTE(review): presumably one entry per DAG node — confirm with implementors.
    pub scores: Vec<f32>,
    /// Optional nested table of weights; `None` unless explicitly supplied.
    /// NOTE(review): looks like a node-by-node edge matrix — confirm the layout.
    pub edge_weights: Option<Vec<Vec<f32>>>,
    /// Free-form string key/value annotations attached to the scores.
    pub metadata: HashMap<String, String>,
}
impl AttentionScores {
    /// Creates a score set from `scores` with no edge weights and empty metadata.
    #[must_use]
    pub fn new(scores: Vec<f32>) -> Self {
        Self {
            scores,
            edge_weights: None,
            metadata: HashMap::new(),
        }
    }

    /// Stores `weights` as the edge-weight table, replacing any previous one.
    ///
    /// Consumes `self` and returns the updated value so calls can be chained;
    /// dropping the return value discards the change.
    #[must_use = "builder methods return the updated value instead of mutating in place"]
    pub fn with_edge_weights(mut self, weights: Vec<Vec<f32>>) -> Self {
        self.edge_weights = Some(weights);
        self
    }

    /// Inserts a metadata entry; an existing entry under `key` is overwritten.
    ///
    /// Consumes `self` and returns the updated value so calls can be chained;
    /// dropping the return value discards the change.
    #[must_use = "builder methods return the updated value instead of mutating in place"]
    pub fn with_metadata(mut self, key: String, value: String) -> Self {
        self.metadata.insert(key, value);
        self
    }
}
#[derive(Debug, Error)]
pub enum AttentionError {
#[error("Invalid DAG structure: {0}")]
InvalidDag(String),
#[error("Dimension mismatch: expected {expected}, got {actual}")]
DimensionMismatch { expected: usize, actual: usize },
#[error("Computation failed: {0}")]
ComputationFailed(String),
#[error("Configuration error: {0}")]
ConfigError(String),
}
/// Common interface for attention mechanisms that operate on a [`QueryDag`].
///
/// Implementors must be `Send + Sync`.
pub trait DagAttentionMechanism: Send + Sync {
    /// Runs the mechanism over `dag` and returns the resulting scores.
    ///
    /// # Errors
    ///
    /// Returns an [`AttentionError`] when the implementation cannot produce scores.
    fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError>;

    /// Static name of the mechanism.
    /// NOTE(review): presumably used for logging or selection — confirm usage.
    fn name(&self) -> &'static str;

    /// Static string describing the mechanism's complexity.
    /// NOTE(review): presumably big-O notation; the format is up to implementors.
    fn complexity(&self) -> &'static str;

    /// Hook for feeding execution feedback back into the mechanism.
    ///
    /// The default implementation ignores both arguments and does nothing.
    /// NOTE(review): `_execution_times` presumably maps node ids to measured
    /// times — confirm the key meaning and the unit.
    fn update(&mut self, _dag: &QueryDag, _execution_times: &HashMap<usize, f64>) {}

    /// Hook for clearing internal state; the default implementation does nothing.
    fn reset(&mut self) {}
}