use super::trait_def::{AttentionError, AttentionScores, DagAttentionMechanism};
use crate::dag::QueryDag;
use std::collections::HashMap;
use std::time::Instant;
/// Tunable parameters for [`TemporalBTSPAttention`].
///
/// `PartialEq` is derived so configurations can be compared directly (for example in
/// assertions); `Eq` is not available because the fields are floating point.
#[derive(Debug, Clone, PartialEq)]
pub struct TemporalBTSPConfig {
    /// Length of the plateau window in milliseconds. A node counts as "in plateau" while
    /// fewer than this many milliseconds have elapsed since its plateau was last triggered.
    pub plateau_duration_ms: u64,
    /// Factor the previous eligibility trace is multiplied by on every trace update.
    pub eligibility_decay: f32,
    /// Factor the incoming signal is multiplied by before it is added to the trace.
    pub learning_rate: f32,
    /// Divisor applied to the (max-shifted) scores before exponentiation when the scores
    /// are normalized.
    pub temperature: f32,
    /// Initial score of every slot before the per-node topology scores are written.
    pub baseline_attention: f32,
}
impl Default for TemporalBTSPConfig {
fn default() -> Self {
Self {
plateau_duration_ms: 500,
eligibility_decay: 0.95,
learning_rate: 0.1,
temperature: 0.1,
baseline_attention: 0.5,
}
}
}
/// Attention mechanism that keeps per-node learned state between calls.
///
/// `Debug` and `Clone` are derived so the learned state can be inspected in logs and
/// snapshotted; every field already supports both.
#[derive(Debug, Clone)]
pub struct TemporalBTSPAttention {
    /// Parameters controlling decay, learning rate, plateau length and normalization.
    config: TemporalBTSPConfig,
    /// Eligibility trace per node id; trace updates keep each value within `[0.0, 1.0]`.
    eligibility_traces: HashMap<usize, f32>,
    /// Instant at which a plateau was most recently triggered, per node id.
    last_plateau: HashMap<usize, Instant>,
    /// Number of `update` calls since construction or the last `reset`.
    update_count: usize,
}
impl TemporalBTSPAttention {
    /// Creates a mechanism with the given configuration and no learned state.
    ///
    /// Both per-node maps start empty and the update counter starts at zero.
    pub fn new(config: TemporalBTSPConfig) -> Self {
        Self {
            config,
            eligibility_traces: HashMap::new(),
            last_plateau: HashMap::new(),
            update_count: 0,
        }
    }

    /// Decays the eligibility trace of `node_id` and adds the scaled `signal`.
    ///
    /// The new value is `old * eligibility_decay + signal * learning_rate`, limited to
    /// `[0.0, 1.0]`. A node without a trace starts from `0.0`.
    fn update_eligibility(&mut self, node_id: usize, signal: f32) {
        let decay = self.config.eligibility_decay;
        let gain = self.config.learning_rate;
        let trace = self.eligibility_traces.entry(node_id).or_insert(0.0);
        // `max` followed by `min` is kept instead of `clamp` on purpose: `f32::max` returns
        // the non-NaN operand, so a NaN intermediate collapses to 0.0 rather than
        // poisoning the stored trace.
        *trace = (*trace * decay + signal * gain).max(0.0).min(1.0);
    }

    /// Returns `true` while `node_id` is inside its plateau window.
    ///
    /// Nodes whose plateau was never triggered are not in plateau.
    fn is_plateau(&self, node_id: usize) -> bool {
        match self.last_plateau.get(&node_id) {
            Some(&started) => self.plateau_active(started),
            None => false,
        }
    }

    /// Reports whether a plateau that began at `started` is still younger than
    /// `plateau_duration_ms` (compared in whole milliseconds).
    fn plateau_active(&self, started: Instant) -> bool {
        started.elapsed().as_millis() < u128::from(self.config.plateau_duration_ms)
    }

    /// Starts (or restarts) the plateau window of `node_id` at the current instant.
    fn trigger_plateau(&mut self, node_id: usize) {
        self.last_plateau.insert(node_id, Instant::now());
    }

    /// Builds the base score vector from the cost and row estimates of each node.
    ///
    /// The vector has `dag.node_count()` slots, all initialised to `baseline_attention`.
    /// Every node whose id is a valid index overwrites its slot with the mean of
    /// `estimated_cost / 100` and `estimated_rows / 1000`, each capped at `1.0`.
    /// Nodes with an id outside the vector are skipped; slots that no node writes keep
    /// the baseline.
    fn compute_topology_attention(&self, dag: &QueryDag) -> Vec<f32> {
        let n = dag.node_count();
        let mut scores = vec![self.config.baseline_attention; n];
        for node in dag.nodes() {
            if let Some(slot) = scores.get_mut(node.id) {
                let cost_factor = (node.estimated_cost as f32 / 100.0).min(1.0);
                let rows_factor = (node.estimated_rows as f32 / 1000.0).min(1.0);
                *slot = 0.5 * cost_factor + 0.5 * rows_factor;
            }
        }
        scores
    }

    /// Multiplies the score of every traced node by `1.0 + trace`.
    ///
    /// Traces whose node id does not index into `base_scores` are ignored.
    fn apply_eligibility_modulation(&self, base_scores: &mut [f32]) {
        for (&node_id, &trace) in &self.eligibility_traces {
            if let Some(score) = base_scores.get_mut(node_id) {
                *score *= 1.0 + trace;
            }
        }
    }

    /// Multiplies the score of every node currently in plateau by `1.5`.
    ///
    /// The stored trigger instant is checked directly, so no second map lookup per node
    /// is needed. Entries whose node id does not index into `scores` are ignored.
    fn apply_plateau_boost(&self, scores: &mut [f32]) {
        for (&node_id, &started) in &self.last_plateau {
            if let Some(score) = scores.get_mut(node_id) {
                if self.plateau_active(started) {
                    *score *= 1.5;
                }
            }
        }
    }

    /// Normalizes `scores` in place with a temperature-scaled softmax.
    ///
    /// The maximum score is subtracted before exponentiation so the largest exponent is
    /// zero. Each exponential is computed once and then divided by the sum. When the sum
    /// is not a finite positive number (NaN inputs, a zero temperature, or an overflow
    /// caused by a negative temperature) every slot is set to `1 / len` instead of
    /// propagating NaN. An empty slice is left untouched.
    fn normalize_scores(&self, scores: &mut [f32]) {
        if scores.is_empty() {
            return;
        }

        let temperature = self.config.temperature;
        let max_score = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);

        let mut exp_sum = 0.0_f32;
        for score in scores.iter_mut() {
            *score = ((*score - max_score) / temperature).exp();
            exp_sum += *score;
        }

        if exp_sum.is_finite() && exp_sum > 0.0 {
            for score in scores.iter_mut() {
                *score /= exp_sum;
            }
        } else {
            let uniform = 1.0 / scores.len() as f32;
            scores.fill(uniform);
        }
    }
}
impl DagAttentionMechanism for TemporalBTSPAttention {
    /// Computes normalized attention scores for `dag` from the current learned state.
    ///
    /// The base topology scores are modulated by the eligibility traces, boosted for
    /// nodes currently in plateau, and then normalized. The result carries the metadata
    /// keys `mechanism`, `update_count`, `active_traces` (traces above `0.01`) and
    /// `active_plateaus` (nodes still inside their plateau window).
    ///
    /// # Errors
    ///
    /// Returns [`AttentionError::InvalidDag`] when the DAG has no nodes.
    fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError> {
        if dag.nodes.is_empty() {
            return Err(AttentionError::InvalidDag("Empty DAG".to_string()));
        }

        let mut scores = self.compute_topology_attention(dag);
        self.apply_eligibility_modulation(&mut scores);
        self.apply_plateau_boost(&mut scores);
        self.normalize_scores(&mut scores);

        let mut result = AttentionScores::new(scores)
            .with_metadata("mechanism".to_string(), "temporal_btsp".to_string())
            .with_metadata("update_count".to_string(), self.update_count.to_string());

        let active_traces = self
            .eligibility_traces
            .values()
            .filter(|&&trace| trace > 0.01)
            .count();
        result
            .metadata
            .insert("active_traces".to_string(), active_traces.to_string());

        let active_plateaus = self
            .last_plateau
            .keys()
            .filter(|&&node_id| self.is_plateau(node_id))
            .count();
        result
            .metadata
            .insert("active_plateaus".to_string(), active_plateaus.to_string());

        Ok(result)
    }

    /// Identifier of this mechanism.
    fn name(&self) -> &'static str {
        "temporal_btsp"
    }

    /// Complexity label reported for this mechanism.
    // NOTE(review): presumably `n` is the node count and `t` the number of tracked
    // traces/plateaus; the label itself does not say. Confirm against the trait docs.
    fn complexity(&self) -> &'static str {
        "O(n + t)"
    }

    /// Feeds observed execution times back into the eligibility traces.
    ///
    /// For every entry in `execution_times` whose node exists in `dag`, the ratio of the
    /// observed time to the node's estimated cost (floored at `0.001`) becomes a reward:
    /// `1 - ratio` when the node ran faster than estimated, otherwise
    /// `-min(ratio - 1, 1)`. The reward updates the node's trace, and a reward above
    /// `0.3` also triggers a plateau. Entries for nodes missing from `dag` are skipped.
    /// Every node id in `0..dag.node_count()` without an execution time only has its
    /// trace decayed.
    ///
    /// A negative execution time yields a reward above `1.0`; the stored trace is still
    /// limited to `[0.0, 1.0]` by the trace update.
    fn update(&mut self, dag: &QueryDag, execution_times: &HashMap<usize, f64>) {
        self.update_count += 1;

        for (&node_id, &exec_time) in execution_times {
            let node = match dag.get_node(node_id) {
                Some(node) => node,
                None => continue,
            };

            let expected_time = node.estimated_cost;
            let time_ratio = exec_time / expected_time.max(0.001);
            let reward = if time_ratio < 1.0 {
                1.0 - time_ratio as f32
            } else {
                -(time_ratio as f32 - 1.0).min(1.0)
            };

            self.update_eligibility(node_id, reward);
            if reward > 0.3 {
                self.trigger_plateau(node_id);
            }
        }

        // The map already answers "was this node executed?", so no separate set of its
        // keys has to be allocated on every update.
        for node_id in 0..dag.node_count() {
            if !execution_times.contains_key(&node_id) {
                self.update_eligibility(node_id, 0.0);
            }
        }
    }

    /// Discards all learned state: traces, plateau timestamps and the update counter.
    fn reset(&mut self) {
        self.eligibility_traces.clear();
        self.last_plateau.clear();
        self.update_count = 0;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::{OperatorNode, OperatorType};
    use std::thread::sleep;
    use std::time::Duration;

    /// Repeated positive signals must accumulate, not merely stay positive.
    #[test]
    fn test_eligibility_update() {
        let mut attention = TemporalBTSPAttention::new(TemporalBTSPConfig::default());

        attention.update_eligibility(0, 0.5);
        let first = *attention.eligibility_traces.get(&0).unwrap();
        assert!(first > 0.0);

        attention.update_eligibility(0, 0.5);
        let second = *attention.eligibility_traces.get(&0).unwrap();
        assert!(second > first);
        assert!(second <= 1.0);
    }

    /// A triggered plateau is active immediately and expires after the configured window.
    #[test]
    fn test_plateau_state() {
        let config = TemporalBTSPConfig {
            plateau_duration_ms: 100,
            ..TemporalBTSPConfig::default()
        };
        let mut attention = TemporalBTSPAttention::new(config);

        // A node that was never triggered is not in plateau.
        assert!(!attention.is_plateau(1));

        attention.trigger_plateau(0);
        assert!(attention.is_plateau(0));
        assert!(!attention.is_plateau(1));

        sleep(Duration::from_millis(150));
        assert!(!attention.is_plateau(0));
    }

    /// `forward` works before and after an `update`, and the update rewards the node that
    /// ran faster than estimated while leaving the slower node without a trace.
    #[test]
    fn test_temporal_attention() {
        let mut attention = TemporalBTSPAttention::new(TemporalBTSPConfig::default());

        let mut dag = QueryDag::new();
        for i in 0..3 {
            let mut node = OperatorNode::new(i, OperatorType::Scan);
            node.estimated_cost = 10.0;
            dag.add_node(node);
        }

        let result1 = attention.forward(&dag).unwrap();
        assert_eq!(result1.scores.len(), 3);

        let mut exec_times = HashMap::new();
        exec_times.insert(0, 5.0); // faster than the estimate: positive reward
        exec_times.insert(1, 15.0); // slower than the estimate: negative reward
        attention.update(&dag, &exec_times);
        assert_eq!(attention.update_count, 1);

        let result2 = attention.forward(&dag).unwrap();
        assert_eq!(result2.scores.len(), 3);

        assert!(attention.eligibility_traces.get(&0).unwrap() > &0.0);
        // A negative reward cannot push a trace below zero.
        let slow_trace = attention
            .eligibility_traces
            .get(&1)
            .copied()
            .unwrap_or(0.0);
        assert_eq!(slow_trace, 0.0);
    }
}