use super::trait_def::{AttentionError, AttentionScores, DagAttentionMechanism};
use crate::dag::QueryDag;
use std::collections::{HashMap, HashSet, VecDeque};
/// Strategy used to assign a capacity to each edge of the flow network that the
/// min-cut computation builds from the query DAG.
///
/// Every edge leaving a node receives the same capacity, derived from that node.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FlowCapacity {
    /// Every edge has capacity `1.0`.
    UnitCapacity,
    /// Edges leaving a node use that node's `estimated_cost`.
    CostBased,
    /// Edges leaving a node use that node's `estimated_rows`.
    RowBased,
}

/// Configuration for the min-cut gated attention mechanism.
#[derive(Debug, Clone, PartialEq)]
pub struct MinCutConfig {
    /// Raw (pre-normalization) score given to nodes that do not touch the min cut;
    /// nodes that do touch the cut get a raw score of `1.0`.
    pub gate_threshold: f32,
    /// How edge capacities are derived when building the flow network.
    pub flow_capacity: FlowCapacity,
}

impl Default for MinCutConfig {
    /// `gate_threshold = 0.5` with [`FlowCapacity::UnitCapacity`].
    fn default() -> Self {
        Self {
            gate_threshold: 0.5,
            flow_capacity: FlowCapacity::UnitCapacity,
        }
    }
}
/// Attention mechanism that up-weights DAG nodes lying on a minimum s-t cut of the
/// query plan; all other nodes receive the configured gate threshold.
#[derive(Debug, Clone)]
pub struct MinCutGatedAttention {
    /// Gate threshold and edge-capacity strategy.
    config: MinCutConfig,
}
/// Capacities (original or residual) keyed by directed `(from, to)` node-id pairs.
type EdgeCapacities = HashMap<(usize, usize), f64>;
/// Neighbour lists listing every DAG edge in both directions, so the residual graph
/// can be walked along reverse (flow-cancelling) edges as well as forward ones.
type Adjacency = HashMap<usize, Vec<usize>>;

impl MinCutGatedAttention {
    /// Creates the mechanism from an explicit configuration.
    pub fn new(config: MinCutConfig) -> Self {
        Self { config }
    }

    /// Creates the mechanism using [`MinCutConfig::default`].
    pub fn with_defaults() -> Self {
        Self::new(MinCutConfig::default())
    }

    /// Returns every node incident to an edge of a minimum s-t cut of `dag`.
    ///
    /// The source is `dag.root()` and the sink is the first entry of `dag.leaves()`.
    /// Edge capacities follow `self.config.flow_capacity`. Maximum flow is found with
    /// Edmonds-Karp (shortest augmenting paths in the residual graph, including reverse
    /// edges). The cut separates the nodes still reachable from the source in the final
    /// residual graph from the rest. Both endpoints of each original edge crossing
    /// from the source side to the sink side are returned.
    ///
    /// Returns an empty set when the DAG has no root or no leaves.
    fn compute_min_cut(&self, dag: &QueryDag) -> HashSet<usize> {
        let source = match dag.root() {
            Some(root) => root,
            None => return HashSet::new(),
        };
        let sink = match dag.leaves().first() {
            Some(&leaf) => leaf,
            None => return HashSet::new(),
        };

        let (capacity, adjacency) = self.build_flow_network(dag);
        let mut residual = capacity.clone();

        while let Some(parent) = Self::find_augmenting_path(&adjacency, &residual, source, sink) {
            let path = Self::path_edges(&parent, source, sink);
            // The path can carry at most the residual capacity of its tightest edge.
            let bottleneck = path
                .iter()
                .map(|&(u, v)| Self::residual_capacity(&residual, u, v))
                .fold(f64::INFINITY, f64::min);
            for &(u, v) in &path {
                *residual.entry((u, v)).or_insert(0.0) -= bottleneck;
                // Pushed flow becomes reverse capacity so later paths can cancel it.
                *residual.entry((v, u)).or_insert(0.0) += bottleneck;
            }
        }

        let reachable = Self::residual_reachable(&adjacency, &residual, source);
        let mut cut_nodes = HashSet::new();
        for &(from, to) in capacity.keys() {
            if reachable.contains(&from) && !reachable.contains(&to) {
                cut_nodes.insert(from);
                cut_nodes.insert(to);
            }
        }
        cut_nodes
    }

    /// Builds the capacity of every `(node, child)` edge of `dag`, together with a
    /// bidirectional adjacency list. Nodes are visited in id order, so traversal is
    /// deterministic.
    fn build_flow_network(&self, dag: &QueryDag) -> (EdgeCapacities, Adjacency) {
        let mut capacity = EdgeCapacities::new();
        let mut adjacency = Adjacency::new();
        for node_id in 0..dag.node_count() {
            let node = match dag.get_node(node_id) {
                Some(node) => node,
                None => continue,
            };
            // All edges leaving a node share the capacity derived from that node.
            let cap = match self.config.flow_capacity {
                FlowCapacity::UnitCapacity => 1.0,
                FlowCapacity::CostBased => node.estimated_cost,
                FlowCapacity::RowBased => node.estimated_rows,
            };
            for &child in dag.children(node_id) {
                capacity.insert((node_id, child), cap);
                adjacency.entry(node_id).or_default().push(child);
                adjacency.entry(child).or_default().push(node_id);
            }
        }
        (capacity, adjacency)
    }

    /// Breadth-first search for a shortest `source` -> `sink` path whose edges all have
    /// positive residual capacity. Returns the BFS parent map if `sink` was reached.
    fn find_augmenting_path(
        adjacency: &Adjacency,
        residual: &EdgeCapacities,
        source: usize,
        sink: usize,
    ) -> Option<HashMap<usize, usize>> {
        let mut parent = HashMap::new();
        let mut visited = HashSet::new();
        let mut queue = VecDeque::new();
        visited.insert(source);
        queue.push_back(source);
        while let Some(u) = queue.pop_front() {
            if u == sink {
                break;
            }
            for &v in Self::neighbors(adjacency, u) {
                if Self::residual_capacity(residual, u, v) > 0.0 && visited.insert(v) {
                    parent.insert(v, u);
                    queue.push_back(v);
                }
            }
        }
        // `sink` only gains a parent when reached through at least one edge; this also
        // rules out the degenerate `source == sink` case, which would otherwise loop forever.
        if parent.contains_key(&sink) {
            Some(parent)
        } else {
            None
        }
    }

    /// Walks the BFS parent map back from `sink` to `source`, returning the path's edges.
    fn path_edges(parent: &HashMap<usize, usize>, source: usize, sink: usize) -> Vec<(usize, usize)> {
        let mut edges = Vec::new();
        let mut v = sink;
        while v != source {
            // Invariant: every BFS-tree node other than `source` has a parent entry.
            let u = parent[&v];
            edges.push((u, v));
            v = u;
        }
        edges
    }

    /// Nodes reachable from `source` through edges with positive residual capacity.
    fn residual_reachable(
        adjacency: &Adjacency,
        residual: &EdgeCapacities,
        source: usize,
    ) -> HashSet<usize> {
        let mut reachable = HashSet::new();
        let mut queue = VecDeque::new();
        reachable.insert(source);
        queue.push_back(source);
        while let Some(u) = queue.pop_front() {
            for &v in Self::neighbors(adjacency, u) {
                if Self::residual_capacity(residual, u, v) > 0.0 && reachable.insert(v) {
                    queue.push_back(v);
                }
            }
        }
        reachable
    }

    /// Neighbours of `node` in either edge direction (empty if it has no edges).
    fn neighbors(adjacency: &Adjacency, node: usize) -> &[usize] {
        adjacency.get(&node).map(Vec::as_slice).unwrap_or(&[])
    }

    /// Residual capacity of the directed edge `from -> to`; missing entries count as `0.0`.
    fn residual_capacity(residual: &EdgeCapacities, from: usize, to: usize) -> f64 {
        residual.get(&(from, to)).copied().unwrap_or(0.0)
    }
}
impl DagAttentionMechanism for MinCutGatedAttention {
    /// Produces one score per node id in `0..dag.node_count()`.
    ///
    /// Ids without a node keep `0.0`. Present nodes get a raw weight of `1.0` when they
    /// touch the minimum cut, and `config.gate_threshold` otherwise. If the raw weights
    /// sum to a positive value, every weight is divided by that sum.
    ///
    /// # Errors
    /// Returns `AttentionError::InvalidDag` when the DAG has no nodes.
    fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError> {
        let node_total = dag.node_count();
        if node_total == 0 {
            return Err(AttentionError::InvalidDag("Empty DAG".to_string()));
        }

        let cut = self.compute_min_cut(dag);
        let mut weights = vec![0.0f32; node_total];
        let mut weight_sum = 0.0f32;
        for (node_id, weight) in weights.iter_mut().enumerate() {
            if dag.get_node(node_id).is_some() {
                *weight = if cut.contains(&node_id) {
                    1.0
                } else {
                    self.config.gate_threshold
                };
                weight_sum += *weight;
            }
        }

        if weight_sum > 0.0 {
            weights.iter_mut().for_each(|w| *w /= weight_sum);
        }
        Ok(AttentionScores::new(weights))
    }

    /// Stable identifier for this mechanism.
    fn name(&self) -> &'static str {
        "mincut_gated"
    }

    /// Advertised asymptotic cost of `forward`.
    fn complexity(&self) -> &'static str {
        "O(n * e^2)"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::OperatorNode;

    /// Two scans feeding a hash join, with a filter and a projection on the join output.
    fn sample_dag() -> QueryDag {
        let mut dag = QueryDag::new();
        let scan_a = dag.add_node(OperatorNode::seq_scan(0, "table1"));
        let scan_b = dag.add_node(OperatorNode::seq_scan(0, "table2"));
        let join = dag.add_node(OperatorNode::hash_join(0, "id"));
        let filter = dag.add_node(OperatorNode::filter(0, "status = 'active'"));
        let project = dag.add_node(OperatorNode::project(0, vec!["name".to_string()]));
        dag.add_edge(scan_a, join).unwrap();
        dag.add_edge(scan_b, join).unwrap();
        dag.add_edge(join, filter).unwrap();
        dag.add_edge(join, project).unwrap();
        dag
    }

    #[test]
    fn test_mincut_gated_attention() {
        let dag = sample_dag();
        let attention = MinCutGatedAttention::with_defaults();
        let scores = attention.forward(&dag).unwrap();

        // Positive raw weights are normalized, so the scores form a distribution.
        let sum: f32 = scores.scores.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
        for &score in &scores.scores {
            assert!((0.0..=1.0).contains(&score), "score out of range: {}", score);
        }
    }

    #[test]
    fn test_empty_dag_is_rejected() {
        let attention = MinCutGatedAttention::with_defaults();
        assert!(attention.forward(&QueryDag::new()).is_err());
    }
}