batuta/pipeline/stages/
optimization.rs1use anyhow::Result;
4
5#[cfg(feature = "native")]
6use tracing::info;
7
/// Fallback for `tracing::info!` used when the `native` feature is disabled.
///
/// The invocation produces no output and evaluates none of its arguments at
/// runtime. The arguments are still passed through `format_args!` inside a
/// branch that is never taken, so that:
///
/// * the format string and its arguments are type-checked in every build
///   configuration, not only when `native` is enabled, and
/// * variables that exist solely to be logged count as "used", avoiding
///   unused-variable warnings in non-`native` builds.
///
/// Only the `format!`-style call shape (`info!("fmt", args...)`) is accepted.
#[cfg(not(feature = "native"))]
macro_rules! info {
    ($($arg:tt)*) => {{
        // Never executed: exists only so the compiler checks the arguments.
        if false {
            let _ = format_args!($($arg)*);
        }
    }};
}
13
14use crate::pipeline::types::{PipelineContext, PipelineStage};
15
16pub struct OptimizationStage {
18 pub(crate) enable_gpu: bool,
19 pub(crate) enable_simd: bool,
20 pub(crate) gpu_threshold: usize,
21 pub(crate) backend_selector: crate::backend::BackendSelector,
22}
23
24impl OptimizationStage {
25 pub fn new(enable_gpu: bool, enable_simd: bool, gpu_threshold: usize) -> Self {
26 Self {
27 enable_gpu,
28 enable_simd,
29 gpu_threshold,
30 backend_selector: crate::backend::BackendSelector::new(),
31 }
32 }
33
34 pub fn analyze_optimizations(&self) -> Vec<String> {
36 contract_pre_analyze!(self);
37 use crate::backend::OpComplexity;
38
39 let mut recommendations = Vec::new();
40
41 let workloads = vec![
43 ("Element-wise operations", OpComplexity::Low, 1_000_000),
44 ("Vector reductions", OpComplexity::Medium, 50_000),
45 ("Matrix multiplications", OpComplexity::High, 100_000),
46 ];
47
48 for (name, complexity, size) in workloads {
49 let backend = self.backend_selector.select_with_moe(complexity, size);
50 recommendations
51 .push(format!("{}: {} backend recommended ({} elements)", name, backend, size));
52 }
53
54 recommendations
55 }
56}
57
58#[async_trait::async_trait]
59impl PipelineStage for OptimizationStage {
60 fn name(&self) -> &'static str {
61 "Optimization"
62 }
63
64 async fn execute(&self, mut ctx: PipelineContext) -> Result<PipelineContext> {
65 info!(
66 "Applying optimizations using MoE routing (GPU: {}, SIMD: {})",
67 self.enable_gpu, self.enable_simd
68 );
69
70 let moe_recommendations = self.analyze_optimizations();
72
73 info!("MoE backend recommendations:");
74 for rec in &moe_recommendations {
75 info!(" - {}", rec);
76 }
77
78 if self.enable_simd {
80 ctx.optimizations.push("SIMD vectorization enabled".to_string());
81 }
82
83 if self.enable_gpu {
84 ctx.optimizations
85 .push(format!("GPU dispatch enabled (threshold: {})", self.gpu_threshold));
86 }
87
88 ctx.optimizations.extend(moe_recommendations);
90
91 ctx.metadata
93 .insert("optimizations_applied".to_string(), serde_json::json!(ctx.optimizations));
94
95 ctx.metadata.insert("moe_routing_enabled".to_string(), serde_json::json!(true));
96
97 Ok(ctx)
98 }
99}