use std::sync::Arc;
use datafusion_common::Result;
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_physical_optimizer::PhysicalOptimizerRule;
use datafusion_physical_plan::ExecutionPlan;
use crate::operators::{
CausalChainExec, Complexity, GraphActivationExec, IterativeRetrievalExec, QualityGateExec,
};
#[derive(Debug, Default)]
pub struct DepthSchedulingRule {
forced_complexity: Option<Complexity>,
}
impl DepthSchedulingRule {
pub fn new() -> Self {
Self::default()
}
pub fn with_complexity(complexity: Complexity) -> Self {
Self {
forced_complexity: Some(complexity),
}
}
fn should_prune(plan: &dyn ExecutionPlan, complexity: Complexity) -> bool {
match complexity {
Complexity::Simple => {
plan.as_any()
.downcast_ref::<GraphActivationExec>()
.is_some()
|| plan.as_any().downcast_ref::<CausalChainExec>().is_some()
|| plan
.as_any()
.downcast_ref::<IterativeRetrievalExec>()
.is_some()
|| plan.as_any().downcast_ref::<QualityGateExec>().is_some()
}
Complexity::Medium => {
plan.as_any().downcast_ref::<CausalChainExec>().is_some()
|| plan
.as_any()
.downcast_ref::<IterativeRetrievalExec>()
.is_some()
}
Complexity::Complex => false, }
}
}
impl PhysicalOptimizerRule for DepthSchedulingRule {
fn optimize(
&self,
plan: Arc<dyn ExecutionPlan>,
_config: &datafusion_common::config::ConfigOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
let complexity = self.forced_complexity.unwrap_or(Complexity::Complex);
if complexity == Complexity::Complex {
return Ok(plan);
}
plan.transform_down(|node| {
if Self::should_prune(node.as_ref(), complexity) {
let children = node.children();
if let Some(child) = children.first() {
Ok(Transformed::yes(Arc::clone(child)))
} else {
let schema = node.schema();
Ok(Transformed::yes(
Arc::new(datafusion_physical_plan::empty::EmptyExec::new(schema))
as Arc<dyn ExecutionPlan>,
))
}
} else {
Ok(Transformed::no(node))
}
})
.map(|t| t.data)
}
fn name(&self) -> &str {
"DepthSchedulingRule"
}
fn schema_check(&self) -> bool {
true
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::operators::{ActivationMode, IterativeConfig};
use arrow_array::{RecordBatch, StringArray};
use arrow_schema::{DataType, Field, Schema};
use datafusion_common::config::ConfigOptions;
use datafusion_datasource::memory::MemorySourceConfig;
fn leaf_plan() -> Arc<dyn ExecutionPlan> {
let schema = Arc::new(Schema::new(vec![Field::new(
"content",
DataType::Utf8,
false,
)]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(StringArray::from(vec!["test"]))],
)
.unwrap();
MemorySourceConfig::try_new_exec(&[vec![batch]], schema, None).unwrap()
}
#[test]
fn simple_prunes_graph_activation() {
let leaf = leaf_plan();
let graph = Arc::new(
GraphActivationExec::new(leaf, 10, ActivationMode::Spreading, 2, 0.01, 0.1).unwrap(),
) as Arc<dyn ExecutionPlan>;
let rule = DepthSchedulingRule::with_complexity(Complexity::Simple);
let config = ConfigOptions::new();
let optimized = rule.optimize(graph, &config).unwrap();
assert!(
optimized
.as_any()
.downcast_ref::<GraphActivationExec>()
.is_none(),
"Simple should prune GraphActivationExec"
);
}
#[test]
fn simple_prunes_causal_chain() {
let leaf = leaf_plan();
let causal = Arc::new(CausalChainExec::new(leaf, 3, 0.3)) as Arc<dyn ExecutionPlan>;
let rule = DepthSchedulingRule::with_complexity(Complexity::Simple);
let config = ConfigOptions::new();
let optimized = rule.optimize(causal, &config).unwrap();
assert!(
optimized
.as_any()
.downcast_ref::<CausalChainExec>()
.is_none(),
"Simple should prune CausalChainExec"
);
}
#[test]
fn simple_prunes_iterative_retrieval() {
let leaf = leaf_plan();
let iterative = Arc::new(IterativeRetrievalExec::new(
leaf,
IterativeConfig::default(),
)) as Arc<dyn ExecutionPlan>;
let rule = DepthSchedulingRule::with_complexity(Complexity::Simple);
let config = ConfigOptions::new();
let optimized = rule.optimize(iterative, &config).unwrap();
assert!(
optimized
.as_any()
.downcast_ref::<IterativeRetrievalExec>()
.is_none(),
"Simple should prune IterativeRetrievalExec"
);
}
#[test]
fn medium_keeps_graph_activation() {
let leaf = leaf_plan();
let graph = Arc::new(
GraphActivationExec::new(leaf, 10, ActivationMode::Spreading, 2, 0.01, 0.1).unwrap(),
) as Arc<dyn ExecutionPlan>;
let rule = DepthSchedulingRule::with_complexity(Complexity::Medium);
let config = ConfigOptions::new();
let optimized = rule.optimize(graph, &config).unwrap();
assert!(
optimized
.as_any()
.downcast_ref::<GraphActivationExec>()
.is_some(),
"Medium should keep GraphActivationExec"
);
}
#[test]
fn medium_prunes_causal_chain() {
let leaf = leaf_plan();
let causal = Arc::new(CausalChainExec::new(leaf, 3, 0.3)) as Arc<dyn ExecutionPlan>;
let rule = DepthSchedulingRule::with_complexity(Complexity::Medium);
let config = ConfigOptions::new();
let optimized = rule.optimize(causal, &config).unwrap();
assert!(
optimized
.as_any()
.downcast_ref::<CausalChainExec>()
.is_none(),
"Medium should prune CausalChainExec"
);
}
#[test]
fn complex_keeps_all() {
let leaf = leaf_plan();
let graph = Arc::new(
GraphActivationExec::new(leaf, 10, ActivationMode::Spreading, 2, 0.01, 0.1).unwrap(),
) as Arc<dyn ExecutionPlan>;
let rule = DepthSchedulingRule::with_complexity(Complexity::Complex);
let config = ConfigOptions::new();
let optimized = rule.optimize(graph, &config).unwrap();
assert!(
optimized
.as_any()
.downcast_ref::<GraphActivationExec>()
.is_some(),
"Complex should keep GraphActivationExec"
);
}
#[test]
fn default_rule_is_complex() {
let leaf = leaf_plan();
let graph = Arc::new(
GraphActivationExec::new(leaf, 10, ActivationMode::Spreading, 2, 0.01, 0.1).unwrap(),
) as Arc<dyn ExecutionPlan>;
let rule = DepthSchedulingRule::new();
let config = ConfigOptions::new();
let optimized = rule.optimize(graph, &config).unwrap();
assert!(
optimized
.as_any()
.downcast_ref::<GraphActivationExec>()
.is_some(),
"Default (Complex) should keep all operators"
);
}
}