use std::sync::Arc;
use datafusion::arrow::array::RecordBatch;
use datafusion::dataframe::DataFrame;
use tracing::{debug, info, instrument, warn};
use crate::core::error::{AnamError, Result};
use crate::execution::dispatcher::DevicePool;
use crate::model::fao::FaoRef;
use crate::model::registry::ModelRegistry;
/// Optional bounds a caller can place on plan selection.
///
/// A field left as `None` imposes no bound; the bounds are evaluated by
/// `CandidatePlan::satisfies`.
#[derive(Debug, Clone)]
pub struct QueryConstraints {
/// Upper bound on a plan's estimated latency, in milliseconds.
pub max_latency_ms: Option<f64>,
/// Lower bound on a plan's estimated accuracy.
// NOTE(review): the tests use values such as 0.95, so this is presumably a
// fraction in 0..=1 — confirm.
pub min_accuracy: Option<f64>,
/// Upper bound on a plan's estimated cost (the unit is not defined in this file).
pub max_cost: Option<f64>,
}
/// One candidate way of executing a query: an operator plus its estimated
/// latency, accuracy and cost.
#[derive(Debug, Clone)]
pub struct CandidatePlan {
/// The operator this plan refers to.
pub fao_ref: FaoRef,
/// Estimated latency in milliseconds; lower is better.
pub est_latency_ms: f64,
/// Estimated accuracy; higher is better.
pub est_accuracy: f64,
/// Estimated cost; lower is better (the unit is not defined in this file).
pub est_cost: f64,
}
impl CandidatePlan {
    /// Reports whether this plan's estimates stay within every bound that is
    /// set in `constraints`.
    ///
    /// A bound that is `None` is ignored. A bound is only violated when the
    /// comparison is strictly on the wrong side, so an estimate exactly equal
    /// to its bound still satisfies it.
    pub fn satisfies(&self, constraints: &QueryConstraints) -> bool {
        let too_slow = matches!(
            constraints.max_latency_ms,
            Some(limit) if self.est_latency_ms > limit
        );
        let too_inaccurate = matches!(
            constraints.min_accuracy,
            Some(floor) if self.est_accuracy < floor
        );
        let too_costly = matches!(
            constraints.max_cost,
            Some(budget) if self.est_cost > budget
        );
        !(too_slow || too_inaccurate || too_costly)
    }

    /// Reports whether this plan Pareto-dominates `other`.
    ///
    /// That is the case when this plan is no worse on latency, accuracy and
    /// cost, and strictly better on at least one of them. Two plans with
    /// identical estimates do not dominate each other.
    pub fn dominates(&self, other: &CandidatePlan) -> bool {
        let no_worse = self.est_latency_ms <= other.est_latency_ms
            && self.est_accuracy >= other.est_accuracy
            && self.est_cost <= other.est_cost;
        if !no_worse {
            return false;
        }
        self.est_latency_ms < other.est_latency_ms
            || self.est_accuracy > other.est_accuracy
            || self.est_cost < other.est_cost
    }
}
/// Chooses among registered operators by comparing estimated latency,
/// accuracy and cost.
pub struct ParetoOptimizer {
/// Source of the operators that are turned into candidate plans.
registry: Arc<ModelRegistry>,
/// Supplies the speed multiplier applied to latency and cost estimates.
device_pool: Arc<DevicePool>,
}
/// Manual `Debug` that prints only the type name; the registry and device
/// pool fields are intentionally left out of the output.
impl std::fmt::Debug for ParetoOptimizer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("ParetoOptimizer")
    }
}
impl ParetoOptimizer {
/// Creates an optimizer that reads operators from `registry` and device
/// speed information from `device_pool`.
pub fn new(registry: Arc<ModelRegistry>, device_pool: Arc<DevicePool>) -> Self {
    Self { registry, device_pool }
}
pub fn parse_constraints(&self, query: &str) -> Result<(String, Option<QueryConstraints>)> {
let query_trimmed = query.trim().trim_end_matches(';');
if let Some(with_start) = query_trimmed.to_uppercase().rfind("WITH (") {
let clean_sql = query_trimmed[..with_start].trim().to_string();
let with_clause = &query_trimmed[with_start + 6..];
let with_body = with_clause.trim_end_matches(')').trim();
let mut constraints = QueryConstraints {
max_latency_ms: None,
min_accuracy: None,
max_cost: None,
};
for part in with_body.split(',') {
let part = part.trim();
if let Some((key, val)) = part.split_once('=') {
let key = key.trim().to_lowercase();
let val = val.trim();
match key.as_str() {
"max_latency_ms" => {
constraints.max_latency_ms = val.parse().ok();
}
"min_accuracy" => {
constraints.min_accuracy = val.parse().ok();
}
"max_cost" => {
constraints.max_cost = val.parse().ok();
}
_ => {
warn!(key = %key, "unknown constraint in WITH clause");
}
}
}
}
Ok((clean_sql, Some(constraints)))
} else {
Ok((query_trimmed.to_string(), None))
}
}
/// Collects `df` and logs which registered operator the Pareto analysis
/// would select under `constraints`.
///
/// The query is executed first. If the registry lists any operators, they are
/// costed against the resulting row count, reduced to the Pareto frontier and
/// filtered by `constraints`; the lowest-latency feasible plan is reported
/// with `info!`, or a `warn!` is emitted when none is feasible. In this
/// function the selection is informational only: the returned batches are
/// exactly what `df.collect()` produced.
///
/// # Errors
/// Returns `AnamError::DataFusion` when collecting the `DataFrame` fails.
#[instrument(skip(self, df))]
pub async fn execute_with_constraints(
    &self,
    df: DataFrame,
    constraints: QueryConstraints,
) -> Result<Vec<RecordBatch>> {
    info!(?constraints, "executing with Pareto optimization");
    let batches = df.collect().await.map_err(AnamError::DataFusion)?;

    let operators = self.registry.list_operators();
    if operators.is_empty() {
        return Ok(batches);
    }

    let total_rows: usize = batches.iter().map(|batch| batch.num_rows()).sum();
    let candidates = self.enumerate_candidates(&operators, total_rows);
    let frontier = self.compute_pareto_frontier(&candidates);

    // `total_cmp` gives NaN estimates a fixed place in the ordering, so the
    // chosen minimum does not depend on the order of the candidates.
    let best = frontier
        .iter()
        .filter(|plan| plan.satisfies(&constraints))
        .min_by(|a, b| a.est_latency_ms.total_cmp(&b.est_latency_ms));

    match best {
        Some(best) => {
            info!(
                fao = %best.fao_ref.function_id,
                latency = best.est_latency_ms,
                accuracy = best.est_accuracy,
                "selected optimal plan from Pareto frontier"
            );
        }
        None => {
            warn!("no feasible plan on Pareto frontier — using default execution");
        }
    }

    Ok(batches)
}
/// Builds one `CandidatePlan` per operator for a result of `total_rows` rows.
///
/// Latency is the operator's base latency scaled by `total_rows / 1000`
/// (never below a factor of 1) and divided by the device pool's speed
/// multiplier. Cost is the base latency times `0.001`, divided by the same
/// multiplier; it does not depend on the row count. Accuracy is copied from
/// the operator unchanged.
fn enumerate_candidates(&self, operators: &[FaoRef], total_rows: usize) -> Vec<CandidatePlan> {
    // Row-count factor in units of 1000 rows, floored at 1.0.
    let row_scale = (total_rows as f64 / 1000.0).max(1.0);

    let mut plans = Vec::with_capacity(operators.len());
    for operator in operators {
        // Queried once per operator, as each plan is built.
        let speed = self.device_pool.speed_multiplier();
        plans.push(CandidatePlan {
            fao_ref: operator.clone(),
            est_latency_ms: operator.est_latency_ms * row_scale / speed,
            est_accuracy: operator.est_accuracy,
            est_cost: operator.est_latency_ms * 0.001 / speed,
        });
    }
    plans
}
/// Returns clones of the candidates that no other candidate dominates,
/// in their original order.
///
/// Every candidate is compared against every other one via
/// `CandidatePlan::dominates`. Candidates with identical estimates do not
/// dominate each other, so duplicates all remain on the frontier.
pub fn compute_pareto_frontier(&self, candidates: &[CandidatePlan]) -> Vec<CandidatePlan> {
    let frontier: Vec<CandidatePlan> = candidates
        .iter()
        .enumerate()
        .filter(|&(idx, plan)| {
            let is_dominated = candidates
                .iter()
                .enumerate()
                .any(|(rival_idx, rival)| rival_idx != idx && rival.dominates(plan));
            !is_dominated
        })
        .map(|(_, plan)| plan.clone())
        .collect();

    debug!(
        candidates = candidates.len(),
        frontier = frontier.len(),
        "computed Pareto frontier"
    );
    frontier
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Optimizer built from `ModelRegistry::new()` and a CPU-only device pool.
    fn cpu_optimizer() -> ParetoOptimizer {
        ParetoOptimizer::new(
            Arc::new(ModelRegistry::new()),
            Arc::new(DevicePool::cpu_only()),
        )
    }

    /// Candidate whose operator carries the same latency and accuracy as the plan.
    fn plan(
        function_id: &'static str,
        model_id: &'static str,
        latency_ms: f64,
        accuracy: f64,
        cost: f64,
    ) -> CandidatePlan {
        CandidatePlan {
            fao_ref: FaoRef {
                function_id: function_id.into(),
                version: "1".into(),
                model_id: model_id.into(),
                est_latency_ms: latency_ms,
                est_accuracy: accuracy,
            },
            est_latency_ms: latency_ms,
            est_accuracy: accuracy,
            est_cost: cost,
        }
    }

    #[test]
    fn parse_with_clause() {
        let query = "SELECT * FROM HighRisk WITH (max_latency_ms = 50, min_accuracy = 0.95)";
        let (sql, constraints) = cpu_optimizer().parse_constraints(query).unwrap();

        assert_eq!(sql, "SELECT * FROM HighRisk");
        let c = constraints.unwrap();
        assert_eq!(c.max_latency_ms, Some(50.0));
        assert_eq!(c.min_accuracy, Some(0.95));
    }

    #[test]
    fn pareto_frontier_basic() {
        let candidates = vec![
            plan("fast", "m1", 10.0, 0.8, 0.01),
            plan("accurate", "m2", 100.0, 0.99, 0.1),
            plan("dominated", "m3", 100.0, 0.8, 0.1),
        ];

        let frontier = cpu_optimizer().compute_pareto_frontier(&candidates);

        assert_eq!(frontier.len(), 2);
        let ids: Vec<_> = frontier
            .iter()
            .map(|c| c.fao_ref.function_id.as_str())
            .collect();
        assert!(ids.contains(&"fast"));
        assert!(ids.contains(&"accurate"));
    }
}