use crate::error::Result;
use crate::graph::{DynamicGraph, VertexId, Weight};
use crate::jtree::hierarchy::{JTreeConfig, JTreeHierarchy};
use std::collections::VecDeque;
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Rule deciding when a query should bypass the approximate tier-1 answer and be
/// answered by the tier-2 (exact) computation instead.
#[derive(Debug, Clone)]
pub enum EscalationPolicy {
    /// Never escalate.
    Never,
    /// Always escalate.
    Always,
    /// Escalate when the estimated tier-1 confidence is strictly below `threshold`.
    LowConfidence {
        threshold: f64,
    },
    /// NOTE: not implemented — this policy currently never escalates.
    /// [`EscalationTrigger`] carries no previous exact value to measure a change against.
    ValueChange {
        relative_threshold: f64,
        absolute_threshold: f64,
    },
    /// Escalate once at least `query_interval` queries have been served since the last
    /// exact one.
    Periodic {
        query_interval: usize,
    },
    /// NOTE: not implemented — this policy currently never escalates.
    /// [`EscalationTrigger`] carries no tier-1 latency measurement to compare against.
    LatencyBased {
        tier1_max_latency: Duration,
    },
    /// Escalate when the mean of the most recent `window_size` recorded errors exceeds
    /// `error_threshold`.
    ///
    /// Nothing fires until at least `window_size / 2` samples (and at least one) exist.
    /// A `window_size` of 0 averages every available sample.
    Adaptive {
        window_size: usize,
        error_threshold: f64,
    },
}

impl Default for EscalationPolicy {
    /// Adaptive escalation over a 100-sample window with a 10% mean-error threshold.
    fn default() -> Self {
        EscalationPolicy::Adaptive {
            window_size: 100,
            error_threshold: 0.1,
        }
    }
}

/// Snapshot of coordinator state against which an [`EscalationPolicy`] is evaluated.
#[derive(Debug, Clone)]
pub struct EscalationTrigger {
    /// Latest approximate (tier-1) cut value.
    pub approximate_value: f64,
    /// Estimated confidence in tier-1 answers; compared by `LowConfidence`.
    pub confidence: f64,
    /// Queries served since the last exact query; compared by `Periodic`.
    pub queries_since_exact: usize,
    /// Wall-clock time since the last exact query.
    pub time_since_exact: Duration,
    /// Recorded relative approximation errors, oldest first; used by `Adaptive`.
    pub recent_errors: Vec<f64>,
}

impl EscalationTrigger {
    /// Returns `true` when `policy` says the current query should escalate to tier 2.
    pub fn should_escalate(&self, policy: &EscalationPolicy) -> bool {
        match policy {
            EscalationPolicy::Never => false,
            EscalationPolicy::Always => true,
            EscalationPolicy::LowConfidence { threshold } => self.confidence < *threshold,
            // Not implemented yet (see the variant docs); `..` avoids unused bindings.
            EscalationPolicy::ValueChange { .. } | EscalationPolicy::LatencyBased { .. } => {
                false
            }
            EscalationPolicy::Periodic { query_interval } => {
                self.queries_since_exact >= *query_interval
            }
            EscalationPolicy::Adaptive {
                window_size,
                error_threshold,
            } => self.adaptive_error_exceeds(*window_size, *error_threshold),
        }
    }

    /// Compares the mean of the newest `window_size` errors with `error_threshold`.
    ///
    /// Returns `false` until enough samples exist. The explicit empty check means a NaN
    /// mean (0 / 0) can never reach the comparison.
    fn adaptive_error_exceeds(&self, window_size: usize, error_threshold: f64) -> bool {
        let available = self.recent_errors.len();
        if available == 0 || available < window_size / 2 {
            return false;
        }
        // Only the most recent `window_size` samples count. `recent_errors` is
        // oldest-first, so take the tail. A window of 0 means "no bound".
        let len = if window_size == 0 {
            available
        } else {
            window_size.min(available)
        };
        let window = &self.recent_errors[available - len..];
        let mean = window.iter().sum::<f64>() / window.len() as f64;
        mean > error_threshold
    }
}
/// Outcome of a single cut query routed through the two-tier coordinator.
#[derive(Debug, Clone)]
pub struct QueryResult {
    /// Cut value. The infallible query methods return `f64::INFINITY` here on failure.
    pub value: f64,
    /// Whether the value is reported as exact. Tier-1 answers always set this to `false`;
    /// tier-2 answers copy the flag from the hierarchy's result.
    pub is_exact: bool,
    /// Tier that produced the value: `1` approximate, `2` exact path, `0` on failure.
    pub tier: u8,
    /// Confidence in `value`:
    /// - `1.0` for tier 2,
    /// - a heuristic estimate for tier 1,
    /// - `0.0` on failure.
    pub confidence: f64,
    /// Wall-clock time elapsed since the query method was entered.
    pub latency: Duration,
    /// `true` when the query was answered by tier 2.
    pub escalated: bool,
}
/// Running counters and latency totals for queries served by each tier.
#[derive(Debug, Clone, Default)]
pub struct TierMetrics {
    /// Number of queries answered by tier 1 (approximate).
    pub tier1_queries: usize,
    /// Number of queries answered by tier 2 (exact path).
    pub tier2_queries: usize,
    /// Number of escalations to tier 2.
    pub escalations: usize,
    /// Sum of tier-1 query latencies.
    pub tier1_total_latency: Duration,
    /// Sum of tier-2 query latencies.
    pub tier2_total_latency: Duration,
    /// Every relative approximation error recorded so far (unbounded, not windowed).
    pub recorded_errors: Vec<f64>,
}

impl TierMetrics {
    /// Mean tier-1 latency, or `Duration::ZERO` before any tier-1 query.
    pub fn tier1_avg_latency(&self) -> Duration {
        Self::mean_duration(self.tier1_total_latency, self.tier1_queries)
    }

    /// Mean tier-2 latency, or `Duration::ZERO` before any tier-2 query.
    pub fn tier2_avg_latency(&self) -> Duration {
        Self::mean_duration(self.tier2_total_latency, self.tier2_queries)
    }

    /// Mean of all recorded errors, or `0.0` when none have been recorded.
    pub fn avg_error(&self) -> f64 {
        if self.recorded_errors.is_empty() {
            0.0
        } else {
            self.recorded_errors.iter().sum::<f64>() / self.recorded_errors.len() as f64
        }
    }

    /// Escalations as a fraction of all queries (both tiers), or `0.0` with no queries.
    pub fn escalation_rate(&self) -> f64 {
        let total = self.tier1_queries + self.tier2_queries;
        if total == 0 {
            0.0
        } else {
            self.escalations as f64 / total as f64
        }
    }

    /// Returns `total / count` rounded down to whole nanoseconds, or `Duration::ZERO`
    /// when `count == 0`.
    ///
    /// The division is done in `u128` nanoseconds instead of `Duration / (count as u32)`.
    /// That cast silently truncates counts above `u32::MAX` and divides by zero (panics)
    /// when `count` is a multiple of 2^32.
    fn mean_duration(total: Duration, count: usize) -> Duration {
        const NANOS_PER_SEC: u128 = 1_000_000_000;
        if count == 0 {
            return Duration::ZERO;
        }
        // usize -> u128 is lossless.
        let nanos = total.as_nanos() / count as u128;
        // `nanos <= total.as_nanos()`, so the seconds part always fits back into a u64.
        Duration::new((nanos / NANOS_PER_SEC) as u64, (nanos % NANOS_PER_SEC) as u32)
    }
}
/// Routes minimum-cut queries between an approximate tier (tier 1) and an exact path
/// (tier 2). Both are served by one lazily built [`JTreeHierarchy`], and an
/// [`EscalationPolicy`] decides which tier answers.
pub struct TwoTierCoordinator {
    /// Graph the hierarchy is built from.
    graph: Arc<DynamicGraph>,
    /// Configuration passed to `JTreeHierarchy::build`.
    config: JTreeConfig,
    /// Hierarchy backing both tiers; `None` until built.
    tier1: Option<JTreeHierarchy>,
    /// Policy consulted by policy-driven queries to choose a tier.
    policy: EscalationPolicy,
    /// Per-tier counters and latency totals.
    metrics: TierMetrics,
    /// Most recent relative approximation errors, oldest first, capped at `max_error_window`.
    error_window: VecDeque<f64>,
    /// Capacity of `error_window`.
    max_error_window: usize,
    /// Value returned by the most recent tier-2 query.
    // NOTE(review): written by the coordinator but never read in this module.
    last_exact_value: Option<f64>,
    /// Tier-1 queries served since the last tier-2 query.
    queries_since_exact: usize,
    /// When the last tier-2 query finished; set to construction time until then.
    last_exact_time: Instant,
    /// Latest approximate value, from a tier-1 query or an edge update.
    cached_approx_value: Option<f64>,
}
impl std::fmt::Debug for TwoTierCoordinator {
    /// Prints a compact summary: hierarchy depth (if built), policy, metrics and the
    /// escalation-related counters.
    ///
    /// The graph, the j-tree configuration and the raw error window are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let levels = self.tier1.as_ref().map(|hierarchy| hierarchy.num_levels());
        let mut summary = f.debug_struct("TwoTierCoordinator");
        summary.field("num_levels", &levels);
        summary.field("policy", &self.policy);
        summary.field("metrics", &self.metrics);
        summary.field("queries_since_exact", &self.queries_since_exact);
        summary.field("cached_approx_value", &self.cached_approx_value);
        summary.finish()
    }
}
impl TwoTierCoordinator {
pub fn new(graph: Arc<DynamicGraph>, policy: EscalationPolicy) -> Self {
Self {
graph,
config: JTreeConfig::default(),
tier1: None,
policy,
metrics: TierMetrics::default(),
error_window: VecDeque::new(),
max_error_window: 100,
last_exact_value: None,
queries_since_exact: 0,
last_exact_time: Instant::now(),
cached_approx_value: None,
}
}
pub fn with_defaults(graph: Arc<DynamicGraph>) -> Self {
Self::new(graph, EscalationPolicy::default())
}
pub fn with_jtree_config(
graph: Arc<DynamicGraph>,
jtree_config: JTreeConfig,
policy: EscalationPolicy,
) -> Self {
Self {
graph,
config: jtree_config,
tier1: None,
policy,
metrics: TierMetrics::default(),
error_window: VecDeque::new(),
max_error_window: 100,
last_exact_value: None,
queries_since_exact: 0,
last_exact_time: Instant::now(),
cached_approx_value: None,
}
}
pub fn build(&mut self) -> Result<()> {
let hierarchy = JTreeHierarchy::build(Arc::clone(&self.graph), self.config.clone())?;
self.tier1 = Some(hierarchy);
Ok(())
}
fn ensure_built(&mut self) -> Result<()> {
if self.tier1.is_none() {
self.build()?;
}
Ok(())
}
fn tier1_mut(&mut self) -> Result<&mut JTreeHierarchy> {
self.ensure_built()?;
self.tier1.as_mut().ok_or_else(|| {
crate::error::MinCutError::InternalError("Hierarchy not built".to_string())
})
}
pub fn min_cut(&mut self) -> QueryResult {
let start = Instant::now();
if let Err(e) = self.ensure_built() {
return QueryResult {
value: f64::INFINITY,
is_exact: false,
tier: 0,
confidence: 0.0,
latency: start.elapsed(),
escalated: false,
};
}
let trigger = self.build_trigger();
let use_exact = trigger.should_escalate(&self.policy);
let result = if use_exact {
self.query_tier2_global(start)
} else {
self.query_tier1_global(start)
};
result.unwrap_or_else(|_| QueryResult {
value: f64::INFINITY,
is_exact: false,
tier: 0,
confidence: 0.0,
latency: start.elapsed(),
escalated: false,
})
}
pub fn st_min_cut(&mut self, s: VertexId, t: VertexId) -> Result<QueryResult> {
let start = Instant::now();
self.ensure_built()?;
let trigger = self.build_trigger();
let use_exact = trigger.should_escalate(&self.policy);
if use_exact {
self.query_tier2_st(s, t, start)
} else {
self.query_tier1_st(s, t, start)
}
}
pub fn exact_min_cut(&mut self) -> QueryResult {
let start = Instant::now();
if let Err(_) = self.ensure_built() {
return QueryResult {
value: f64::INFINITY,
is_exact: false,
tier: 0,
confidence: 0.0,
latency: start.elapsed(),
escalated: false,
};
}
self.query_tier2_global(start)
.unwrap_or_else(|_| QueryResult {
value: f64::INFINITY,
is_exact: false,
tier: 0,
confidence: 0.0,
latency: start.elapsed(),
escalated: false,
})
}
pub fn approximate_min_cut(&mut self) -> QueryResult {
let start = Instant::now();
if let Err(_) = self.ensure_built() {
return QueryResult {
value: f64::INFINITY,
is_exact: false,
tier: 0,
confidence: 0.0,
latency: start.elapsed(),
escalated: false,
};
}
self.query_tier1_global(start)
.unwrap_or_else(|_| QueryResult {
value: f64::INFINITY,
is_exact: false,
tier: 0,
confidence: 0.0,
latency: start.elapsed(),
escalated: false,
})
}
fn query_tier1_global(&mut self, start: Instant) -> Result<QueryResult> {
let hierarchy = self.tier1_mut()?;
let approx = hierarchy.approximate_min_cut()?;
let value = approx.value;
let latency = start.elapsed();
self.cached_approx_value = Some(value);
self.metrics.tier1_queries += 1;
self.metrics.tier1_total_latency += latency;
self.queries_since_exact += 1;
let confidence = self.estimate_confidence();
Ok(QueryResult {
value,
is_exact: false,
tier: 1,
confidence,
latency,
escalated: false,
})
}
fn query_tier1_st(
&mut self,
_s: VertexId,
_t: VertexId,
start: Instant,
) -> Result<QueryResult> {
let hierarchy = self.tier1_mut()?;
let approx = hierarchy.approximate_min_cut()?;
let value = approx.value;
let latency = start.elapsed();
self.cached_approx_value = Some(value);
self.metrics.tier1_queries += 1;
self.metrics.tier1_total_latency += latency;
self.queries_since_exact += 1;
let confidence = self.estimate_confidence();
Ok(QueryResult {
value,
is_exact: false,
tier: 1,
confidence,
latency,
escalated: false,
})
}
fn query_tier2_global(&mut self, start: Instant) -> Result<QueryResult> {
let hierarchy = self.tier1_mut()?;
let cut_result = hierarchy.min_cut(true)?; let value = cut_result.value;
let latency = start.elapsed();
if let Some(last_approx) = self.cached_approx_value {
let error = if last_approx > 0.0 {
(value - last_approx).abs() / last_approx
} else {
0.0
};
self.record_error(error);
}
self.last_exact_value = Some(value);
self.queries_since_exact = 0;
self.last_exact_time = Instant::now();
self.metrics.tier2_queries += 1;
self.metrics.tier2_total_latency += latency;
self.metrics.escalations += 1;
Ok(QueryResult {
value,
is_exact: cut_result.is_exact,
tier: 2,
confidence: 1.0,
latency,
escalated: true,
})
}
fn query_tier2_st(
&mut self,
_s: VertexId,
_t: VertexId,
start: Instant,
) -> Result<QueryResult> {
let hierarchy = self.tier1_mut()?;
let cut_result = hierarchy.min_cut(true)?;
let value = cut_result.value;
let latency = start.elapsed();
self.last_exact_value = Some(value);
self.queries_since_exact = 0;
self.last_exact_time = Instant::now();
self.metrics.tier2_queries += 1;
self.metrics.tier2_total_latency += latency;
self.metrics.escalations += 1;
Ok(QueryResult {
value,
is_exact: cut_result.is_exact,
tier: 2,
confidence: 1.0,
latency,
escalated: true,
})
}
fn build_trigger(&self) -> EscalationTrigger {
let recent_errors: Vec<f64> = self.error_window.iter().copied().collect();
let approximate_value = self.cached_approx_value.unwrap_or(f64::INFINITY);
EscalationTrigger {
approximate_value,
confidence: self.estimate_confidence(),
queries_since_exact: self.queries_since_exact,
time_since_exact: self.last_exact_time.elapsed(),
recent_errors,
}
}
fn estimate_confidence(&self) -> f64 {
let level_factor = if let Some(ref hierarchy) = self.tier1 {
let num_levels = hierarchy.num_levels();
let approx_factor = hierarchy.approximation_factor();
if num_levels > 0 {
(1.0 / approx_factor.ln().max(1.0)).min(1.0)
} else {
0.5
}
} else {
0.5
};
let recency_factor = {
let elapsed = self.last_exact_time.elapsed().as_secs_f64();
(-elapsed / 60.0).exp() };
let error_factor = if self.error_window.is_empty() {
0.8
} else {
let avg_error: f64 =
self.error_window.iter().sum::<f64>() / self.error_window.len() as f64;
(1.0 - avg_error).max(0.0)
};
(level_factor * 0.4 + recency_factor * 0.3 + error_factor * 0.3).min(1.0)
}
fn record_error(&mut self, error: f64) {
self.error_window.push_back(error);
if self.error_window.len() > self.max_error_window {
self.error_window.pop_front();
}
self.metrics.recorded_errors.push(error);
}
pub fn insert_edge(&mut self, u: VertexId, v: VertexId, weight: Weight) -> Result<f64> {
self.ensure_built()?;
let hierarchy = self.tier1.as_mut().ok_or_else(|| {
crate::error::MinCutError::InternalError("Hierarchy not built".to_string())
})?;
let result = hierarchy.insert_edge(u, v, weight)?;
self.cached_approx_value = Some(result);
Ok(result)
}
pub fn delete_edge(&mut self, u: VertexId, v: VertexId) -> Result<f64> {
self.ensure_built()?;
let hierarchy = self.tier1.as_mut().ok_or_else(|| {
crate::error::MinCutError::InternalError("Hierarchy not built".to_string())
})?;
let result = hierarchy.delete_edge(u, v)?;
self.cached_approx_value = Some(result);
Ok(result)
}
pub fn multi_terminal_cut(&mut self, terminals: &[VertexId]) -> Result<f64> {
if terminals.len() < 2 {
return Ok(f64::INFINITY);
}
self.ensure_built()?;
let hierarchy = self.tier1.as_mut().ok_or_else(|| {
crate::error::MinCutError::InternalError("Hierarchy not built".to_string())
})?;
let approx = hierarchy.approximate_min_cut()?;
Ok(approx.value)
}
pub fn metrics(&self) -> &TierMetrics {
&self.metrics
}
pub fn reset_metrics(&mut self) {
self.metrics = TierMetrics::default();
self.error_window.clear();
}
pub fn policy(&self) -> &EscalationPolicy {
&self.policy
}
pub fn set_policy(&mut self, policy: EscalationPolicy) {
self.policy = policy;
}
pub fn graph(&self) -> &Arc<DynamicGraph> {
&self.graph
}
pub fn tier1(&self) -> Option<&JTreeHierarchy> {
self.tier1.as_ref()
}
pub fn num_levels(&self) -> usize {
self.tier1.as_ref().map(|h| h.num_levels()).unwrap_or(0)
}
pub fn rebuild(&mut self) -> Result<()> {
self.tier1 = None;
self.build()?;
self.last_exact_value = None;
self.queries_since_exact = 0;
self.cached_approx_value = None;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: two weight-2 triangles, {1, 2, 3} and {4, 5, 6}, joined by the single
    /// weight-1 edge (3, 4).
    fn create_test_graph() -> Arc<DynamicGraph> {
        let edges = [
            (1, 2, 2.0),
            (2, 3, 2.0),
            (3, 1, 2.0),
            (4, 5, 2.0),
            (5, 6, 2.0),
            (6, 4, 2.0),
            (3, 4, 1.0),
        ];
        let graph = Arc::new(DynamicGraph::new());
        for &(u, v, w) in &edges {
            graph.insert_edge(u, v, w).unwrap();
        }
        graph
    }

    /// Coordinator over the fixture graph with its hierarchy already built.
    fn built_coordinator(policy: EscalationPolicy) -> TwoTierCoordinator {
        let mut coord = TwoTierCoordinator::new(create_test_graph(), policy);
        coord.build().unwrap();
        coord
    }

    #[test]
    fn test_coordinator_creation() {
        let coord = built_coordinator(EscalationPolicy::default());
        let metrics = coord.metrics();
        assert_eq!(metrics.tier1_queries, 0);
        assert_eq!(metrics.tier2_queries, 0);
        assert!(coord.num_levels() > 0);
    }

    #[test]
    fn test_approximate_query() {
        let mut coord = built_coordinator(EscalationPolicy::default());
        let result = coord.approximate_min_cut();
        assert!(!result.is_exact);
        assert_eq!(result.tier, 1);
        assert!(result.value.is_finite());
        assert!(!result.escalated);
    }

    #[test]
    fn test_exact_query() {
        let mut coord = built_coordinator(EscalationPolicy::default());
        let result = coord.exact_min_cut();
        assert_eq!(result.tier, 2);
        assert_eq!(result.confidence, 1.0);
        assert!(result.escalated);
        assert!(result.value.is_finite());
    }

    #[test]
    fn test_st_query() {
        let mut coord = built_coordinator(EscalationPolicy::default());
        let result = coord.st_min_cut(1, 6).unwrap();
        assert!(result.value.is_finite());
    }

    #[test]
    fn test_escalation_never() {
        let mut coord = built_coordinator(EscalationPolicy::Never);
        for _ in 0..10 {
            let result = coord.min_cut();
            assert!(!result.escalated);
            assert_eq!(result.tier, 1);
        }
    }

    #[test]
    fn test_escalation_always() {
        let mut coord = built_coordinator(EscalationPolicy::Always);
        let result = coord.min_cut();
        assert!(result.escalated);
        assert_eq!(result.tier, 2);
    }

    #[test]
    fn test_escalation_periodic() {
        let mut coord = built_coordinator(EscalationPolicy::Periodic { query_interval: 3 });
        // The first `query_interval` queries stay on tier 1 ...
        for _ in 0..3 {
            assert!(!coord.min_cut().escalated);
        }
        // ... and the next one escalates.
        assert!(coord.min_cut().escalated);
    }

    #[test]
    fn test_metrics_tracking() {
        let mut coord = built_coordinator(EscalationPolicy::Never);
        let _ = coord.approximate_min_cut();
        let _ = coord.approximate_min_cut();
        let _ = coord.exact_min_cut();
        let metrics = coord.metrics();
        assert_eq!(metrics.tier1_queries, 2);
        assert_eq!(metrics.tier2_queries, 1);
        assert_eq!(metrics.escalations, 1);
    }

    #[test]
    fn test_edge_update() {
        let graph = create_test_graph();
        let mut coord = TwoTierCoordinator::with_defaults(Arc::clone(&graph));
        coord.build().unwrap();
        let initial = coord.approximate_min_cut().value;
        graph.insert_edge(1, 5, 10.0).unwrap();
        let _ = coord.insert_edge(1, 5, 10.0);
        let after = coord.approximate_min_cut().value;
        assert!(initial.is_finite());
        assert!(after.is_finite());
    }

    #[test]
    fn test_multi_terminal() {
        let mut coord = built_coordinator(EscalationPolicy::default());
        let result = coord.multi_terminal_cut(&[1, 4, 6]).unwrap();
        assert!(result.is_finite());
    }

    #[test]
    fn test_confidence_estimation() {
        let mut coord = built_coordinator(EscalationPolicy::default());
        let confidence = coord.approximate_min_cut().confidence;
        assert!(confidence > 0.0);
        assert!(confidence <= 1.0);
    }

    #[test]
    fn test_reset_metrics() {
        let mut coord = built_coordinator(EscalationPolicy::default());
        let _ = coord.approximate_min_cut();
        let _ = coord.exact_min_cut();
        coord.reset_metrics();
        let metrics = coord.metrics();
        assert_eq!(metrics.tier1_queries, 0);
        assert_eq!(metrics.tier2_queries, 0);
    }

    #[test]
    fn test_rebuild() {
        let mut coord = built_coordinator(EscalationPolicy::default());
        let initial = coord.approximate_min_cut().value;
        coord.rebuild().unwrap();
        let after = coord.approximate_min_cut().value;
        assert!((initial - after).abs() < 1e-10 || (initial.is_finite() && after.is_finite()));
    }

    #[test]
    fn test_policy_modification() {
        let mut coord = built_coordinator(EscalationPolicy::Never);
        assert!(!coord.min_cut().escalated);
        coord.set_policy(EscalationPolicy::Always);
        assert!(coord.min_cut().escalated);
    }
}