use anyhow::Result;
#[cfg(feature = "native")]
use tracing::info;
#[cfg(not(feature = "native"))]
macro_rules! info {
($($arg:tt)*) => {{}};
}
use crate::pipeline::types::{PipelineContext, PipelineStage};
pub struct OptimizationStage {
pub(crate) enable_gpu: bool,
pub(crate) enable_simd: bool,
pub(crate) gpu_threshold: usize,
pub(crate) backend_selector: crate::backend::BackendSelector,
}
impl OptimizationStage {
pub fn new(enable_gpu: bool, enable_simd: bool, gpu_threshold: usize) -> Self {
Self {
enable_gpu,
enable_simd,
gpu_threshold,
backend_selector: crate::backend::BackendSelector::new(),
}
}
pub fn analyze_optimizations(&self) -> Vec<String> {
contract_pre_analyze!(self);
use crate::backend::OpComplexity;
let mut recommendations = Vec::new();
let workloads = vec![
("Element-wise operations", OpComplexity::Low, 1_000_000),
("Vector reductions", OpComplexity::Medium, 50_000),
("Matrix multiplications", OpComplexity::High, 100_000),
];
for (name, complexity, size) in workloads {
let backend = self.backend_selector.select_with_moe(complexity, size);
recommendations
.push(format!("{}: {} backend recommended ({} elements)", name, backend, size));
}
recommendations
}
}
#[async_trait::async_trait]
impl PipelineStage for OptimizationStage {
fn name(&self) -> &'static str {
"Optimization"
}
async fn execute(&self, mut ctx: PipelineContext) -> Result<PipelineContext> {
info!(
"Applying optimizations using MoE routing (GPU: {}, SIMD: {})",
self.enable_gpu, self.enable_simd
);
let moe_recommendations = self.analyze_optimizations();
info!("MoE backend recommendations:");
for rec in &moe_recommendations {
info!(" - {}", rec);
}
if self.enable_simd {
ctx.optimizations.push("SIMD vectorization enabled".to_string());
}
if self.enable_gpu {
ctx.optimizations
.push(format!("GPU dispatch enabled (threshold: {})", self.gpu_threshold));
}
ctx.optimizations.extend(moe_recommendations);
ctx.metadata
.insert("optimizations_applied".to_string(), serde_json::json!(ctx.optimizations));
ctx.metadata.insert("moe_routing_enabled".to_string(), serde_json::json!(true));
Ok(ctx)
}
}