// helix/dna/mds/optimizer.rs
1use crate::dna::mds::codegen::{HelixIR, Instruction};
2use std::collections::{HashMap, HashSet};
3pub use crate::mds::codegen::{StringPool, SymbolTable, Metadata, ConstantPool, ConstantValue};
4use std::path::PathBuf;
5use anyhow::Result;
6
7
/// How aggressively the optimizer rewrites the IR.
///
/// Levels are cumulative: each level runs everything the previous one does.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptimizationLevel {
    /// No optimization at all.
    Zero,
    /// Basic passes: string dedup, dead-code removal, string-pool layout.
    One,
    /// Standard passes layered on top of the basic ones.
    Two,
    /// Aggressive whole-program passes layered on top of the standard ones.
    Three,
}
impl Default for OptimizationLevel {
    /// Defaults to `Two`, matching the usual `-O2` expectation.
    fn default() -> Self {
        Self::Two
    }
}
impl From<u8> for OptimizationLevel {
    /// Maps a numeric CLI-style level to an `OptimizationLevel`.
    /// Values above 3 saturate to `Three`.
    fn from(level: u8) -> Self {
        match level {
            0 => Self::Zero,
            1 => Self::One,
            2 => Self::Two,
            // 3 and everything above saturate to the maximum; the original
            // `3 | _` arm made the `3` alternative redundant with `_`.
            _ => Self::Three,
        }
    }
}
/// Runs optimization passes over a [`HelixIR`] and records what they did.
pub struct Optimizer {
    /// Highest tier of passes to run; see `OptimizationLevel`.
    level: OptimizationLevel,
    /// Counters accumulated by the passes, exposed read-only via `stats()`.
    stats: OptimizationStats,
}
34impl Optimizer {
35    pub fn new(level: OptimizationLevel) -> Self {
36        Self {
37            level,
38            stats: OptimizationStats::default(),
39        }
40    }
41    pub fn optimize(&mut self, ir: &mut HelixIR) {
42        match self.level {
43            OptimizationLevel::Zero => {}
44            OptimizationLevel::One => {
45                self.apply_basic_optimizations(ir);
46            }
47            OptimizationLevel::Two => {
48                self.apply_basic_optimizations(ir);
49                self.apply_standard_optimizations(ir);
50            }
51            OptimizationLevel::Three => {
52                self.apply_basic_optimizations(ir);
53                self.apply_standard_optimizations(ir);
54                self.apply_aggressive_optimizations(ir);
55            }
56        }
57    }
58    fn apply_basic_optimizations(&mut self, ir: &mut HelixIR) {
59        self.deduplicate_strings(ir);
60        self.remove_dead_code(ir);
61        self.optimize_string_pool(ir);
62    }
63    fn apply_standard_optimizations(&mut self, ir: &mut HelixIR) {
64        self.fold_constants(ir);
65        self.inline_small_functions(ir);
66        self.optimize_instruction_sequence(ir);
67        self.merge_duplicate_sections(ir);
68    }
69    fn apply_aggressive_optimizations(&mut self, ir: &mut HelixIR) {
70        self.eliminate_cross_references(ir);
71        self.optimize_pipelines(ir);
72        self.compress_data_sections(ir);
73        self.reorder_for_cache_locality(ir);
74    }
75    fn deduplicate_strings(&mut self, ir: &mut HelixIR) {
76        let mut seen = HashMap::new();
77        let mut new_strings = Vec::new();
78        let mut remap = HashMap::new();
79        for (idx, string) in ir.string_pool.strings.iter().enumerate() {
80            if let Some(&existing_idx) = seen.get(string) {
81                remap.insert(idx as u32, existing_idx);
82                self.stats.strings_deduplicated += 1;
83            } else {
84                let new_idx = new_strings.len() as u32;
85                seen.insert(string.clone(), new_idx);
86                new_strings.push(string.clone());
87                remap.insert(idx as u32, new_idx);
88            }
89        }
90        let original_size = ir.string_pool.strings.len();
91        ir.string_pool.strings = new_strings;
92        self.stats.strings_removed = original_size - ir.string_pool.strings.len();
93        for instruction in &mut ir.instructions {
94            self.remap_instruction_strings(instruction, &remap);
95        }
96    }
97    fn remove_dead_code(&mut self, ir: &mut HelixIR) {
98        let mut reachable = HashSet::new();
99        let mut work_list = vec![0];
100        while let Some(idx) = work_list.pop() {
101            if idx >= ir.instructions.len() || !reachable.insert(idx) {
102                continue;
103            }
104            work_list.push(idx + 1);
105        }
106        let mut new_instructions = Vec::new();
107        let mut remap = HashMap::new();
108        for (idx, instruction) in ir.instructions.iter().enumerate() {
109            if reachable.contains(&idx) {
110                remap.insert(idx, new_instructions.len());
111                new_instructions.push(instruction.clone());
112            } else {
113                self.stats.instructions_removed += 1;
114            }
115        }
116        for instruction in &mut new_instructions {
117            self.remap_jump_targets(instruction, &remap);
118        }
119        ir.instructions = new_instructions;
120    }
121    fn optimize_string_pool(&mut self, ir: &mut HelixIR) {
122        let mut frequency = HashMap::new();
123        for instruction in &ir.instructions {
124            self.count_string_usage(instruction, &mut frequency);
125        }
126        let mut indexed_strings: Vec<(u32, String)> = ir
127            .string_pool
128            .strings
129            .iter()
130            .enumerate()
131            .map(|(i, s)| (i as u32, s.clone()))
132            .collect();
133        indexed_strings
134            .sort_by_key(|(idx, _)| {
135                std::cmp::Reverse(frequency.get(idx).copied().unwrap_or(0))
136            });
137        let mut remap = HashMap::new();
138        let mut new_strings = Vec::new();
139        for (old_idx, string) in indexed_strings {
140            let new_idx = new_strings.len() as u32;
141            remap.insert(old_idx, new_idx);
142            new_strings.push(string);
143        }
144        ir.string_pool.strings = new_strings;
145        for instruction in &mut ir.instructions {
146            self.remap_instruction_strings(instruction, &remap);
147        }
148    }
149    fn fold_constants(&mut self, _ir: &mut HelixIR) {}
150    fn inline_small_functions(&mut self, ir: &mut HelixIR) {
151        let mut reference_count = std::collections::HashMap::new();
152        for instruction in &ir.instructions {
153            if let Instruction::ResolveReference { index, .. } = instruction {
154                *reference_count.entry(*index).or_insert(0) += 1;
155            }
156        }
157        self.stats.functions_inlined = 0;
158    }
159    fn optimize_instruction_sequence(&mut self, ir: &mut HelixIR) {
160        let mut seen_properties: std::collections::HashSet<(u32, u32)> = std::collections::HashSet::new();
161        let mut i = 0;
162        while i < ir.instructions.len() {
163            match &ir.instructions[i] {
164                Instruction::SetProperty { target, key, .. } => {
165                    let prop_key = (*target, *key);
166                    if seen_properties.contains(&prop_key) {
167                        ir.instructions.remove(i);
168                        self.stats.instructions_removed += 1;
169                        continue;
170                    } else {
171                        seen_properties.insert(prop_key);
172                    }
173                }
174                Instruction::SetCapability { agent, capability } => {
175                    let cap_key = (*agent, *capability);
176                    if seen_properties.contains(&cap_key) {
177                        ir.instructions.remove(i);
178                        self.stats.instructions_removed += 1;
179                        continue;
180                    } else {
181                        seen_properties.insert(cap_key);
182                    }
183                }
184                _ => {}
185            }
186            i += 1;
187        }
188    }
189    fn merge_duplicate_sections(&mut self, ir: &mut HelixIR) {
190        use std::collections::HashMap;
191        let mut agent_signatures: HashMap<String, Vec<u32>> = HashMap::new();
192        for (id, agent) in &ir.symbol_table.agents {
193            let signature = format!(
194                "{}-{}-{:?}-{:?}", agent.model_idx, agent.role_idx, agent.temperature,
195                agent.max_tokens
196            );
197            agent_signatures.entry(signature).or_insert_with(Vec::new).push(*id);
198        }
199        for (_, agents) in &agent_signatures {
200            if agents.len() > 1 {
201                self.stats.sections_merged += agents.len() - 1;
202            }
203        }
204        let mut workflow_signatures: HashMap<String, Vec<u32>> = HashMap::new();
205        for (id, workflow) in &ir.symbol_table.workflows {
206            let signature = format!("{:?}", workflow.trigger_type);
207            workflow_signatures.entry(signature).or_insert_with(Vec::new).push(*id);
208        }
209    }
210    fn eliminate_cross_references(&mut self, ir: &mut HelixIR) {
211        use std::collections::HashSet;
212        let mut referenced_agents: HashSet<u32> = HashSet::new();
213        let mut referenced_workflows: HashSet<u32> = HashSet::new();
214        for crew in ir.symbol_table.crews.values() {
215            for agent_id in &crew.agent_ids {
216                referenced_agents.insert(*agent_id);
217            }
218        }
219        for workflow in ir.symbol_table.workflows.values() {
220            if let Some(pipeline) = &workflow.pipeline {
221                for node_id in pipeline {
222                    referenced_workflows.insert(*node_id);
223                }
224            }
225        }
226        for instruction in &ir.instructions {
227            match instruction {
228                Instruction::ResolveReference { ref_type: _, index: _ } => {}
229                _ => {}
230            }
231        }
232        let unreferenced_agents: Vec<u32> = ir
233            .symbol_table
234            .agents
235            .keys()
236            .filter(|id| !referenced_agents.contains(id))
237            .cloned()
238            .collect();
239        for agent_id in unreferenced_agents {
240            ir.symbol_table.agents.remove(&agent_id);
241            self.stats.instructions_removed += 1;
242        }
243        let unreferenced_workflows: Vec<u32> = ir
244            .symbol_table
245            .workflows
246            .keys()
247            .filter(|id| !referenced_workflows.contains(id))
248            .cloned()
249            .collect();
250        for workflow_id in unreferenced_workflows {
251            ir.symbol_table.workflows.remove(&workflow_id);
252            self.stats.instructions_removed += 1;
253        }
254    }
255    fn optimize_pipelines(&mut self, ir: &mut HelixIR) {
256        for i in 0..ir.instructions.len() {
257            if let Instruction::DefinePipeline { .. } = &ir.instructions[i] {
258                self.stats.pipelines_optimized += 1;
259            }
260        }
261    }
262    fn compress_data_sections(&mut self, ir: &mut HelixIR) {
263        let mut string_frequency: HashMap<u32, usize> = HashMap::new();
264        for instruction in &ir.instructions {
265            match instruction {
266                Instruction::SetProperty { key, .. } => {
267                    *string_frequency.entry(*key).or_insert(0) += 1;
268                }
269                Instruction::SetCapability { capability, .. } => {
270                    *string_frequency.entry(*capability).or_insert(0) += 1;
271                }
272                Instruction::SetMetadata { key, value } => {
273                    *string_frequency.entry(*key).or_insert(0) += 1;
274                    *string_frequency.entry(*value).or_insert(0) += 1;
275                }
276                _ => {}
277            }
278        }
279        for agent in ir.symbol_table.agents.values() {
280            *string_frequency.entry(agent.name_idx).or_insert(0) += 1;
281            *string_frequency.entry(agent.model_idx).or_insert(0) += 1;
282            *string_frequency.entry(agent.role_idx).or_insert(0) += 1;
283        }
284        let _total_strings = ir.string_pool.strings.len();
285        let frequently_used = string_frequency
286            .iter()
287            .filter(|(_, count)| **count > 1)
288            .count();
289        if frequently_used > 0 {
290            self.stats.bytes_saved += frequently_used * 8;
291        }
292    }
293    fn reorder_for_cache_locality(&mut self, ir: &mut HelixIR) {
294        let mut reordered = Vec::new();
295        let mut agent_instructions = Vec::new();
296        let mut workflow_instructions = Vec::new();
297        let mut other_instructions = Vec::new();
298        for instruction in ir.instructions.drain(..) {
299            match &instruction {
300                Instruction::DeclareAgent(_) => agent_instructions.push(instruction),
301                Instruction::DeclareWorkflow(_) => {
302                    workflow_instructions.push(instruction)
303                }
304                Instruction::DefinePipeline { .. } => {
305                    workflow_instructions.push(instruction)
306                }
307                _ => other_instructions.push(instruction),
308            }
309        }
310        reordered.extend(agent_instructions);
311        reordered.extend(workflow_instructions);
312        reordered.extend(other_instructions);
313        ir.instructions = reordered;
314    }
315    fn remap_instruction_strings(
316        &self,
317        instruction: &mut Instruction,
318        remap: &HashMap<u32, u32>,
319    ) {
320        match instruction {
321            Instruction::SetProperty { key, value, .. } => {
322                if let Some(&new_idx) = remap.get(key) {
323                    *key = new_idx;
324                }
325                if let ConstantValue::String(idx) = value {
326                    if let Some(&new_idx) = remap.get(idx) {
327                        *idx = new_idx;
328                    }
329                }
330            }
331            Instruction::SetCapability { capability, .. } => {
332                if let Some(&new_idx) = remap.get(capability) {
333                    *capability = new_idx;
334                }
335            }
336            Instruction::SetMetadata { key, value } => {
337                if let Some(&new_idx) = remap.get(key) {
338                    *key = new_idx;
339                }
340                if let Some(&new_idx) = remap.get(value) {
341                    *value = new_idx;
342                }
343            }
344            _ => {}
345        }
346    }
347    fn remap_jump_targets(
348        &self,
349        _instruction: &mut Instruction,
350        _remap: &HashMap<usize, usize>,
351    ) {}
352    fn count_string_usage(
353        &self,
354        instruction: &Instruction,
355        frequency: &mut HashMap<u32, usize>,
356    ) {
357        match instruction {
358            Instruction::SetProperty { key, value, .. } => {
359                *frequency.entry(*key).or_insert(0) += 1;
360                if let ConstantValue::String(idx) = value {
361                    *frequency.entry(*idx).or_insert(0) += 1;
362                }
363            }
364            Instruction::SetCapability { capability, .. } => {
365                *frequency.entry(*capability).or_insert(0) += 1;
366            }
367            Instruction::SetMetadata { key, value } => {
368                *frequency.entry(*key).or_insert(0) += 1;
369                *frequency.entry(*value).or_insert(0) += 1;
370            }
371            _ => {}
372        }
373    }
374    pub fn stats(&self) -> &OptimizationStats {
375        &self.stats
376    }
377}
/// Counters describing what the optimizer accomplished during a run.
#[derive(Debug, Default)]
pub struct OptimizationStats {
    pub strings_deduplicated: usize,
    pub strings_removed: usize,
    pub instructions_removed: usize,
    pub constants_folded: usize,
    pub functions_inlined: usize,
    pub pipelines_optimized: usize,
    pub sections_merged: usize,
    pub bytes_saved: usize,
}
impl OptimizationStats {
    /// Renders the counters as a human-readable multi-line summary, one
    /// `- Label: value` line per counter under an "Optimization Results:"
    /// heading (no trailing newline).
    pub fn report(&self) -> String {
        let rows: [(&str, usize); 8] = [
            ("Strings deduplicated", self.strings_deduplicated),
            ("Strings removed", self.strings_removed),
            ("Instructions removed", self.instructions_removed),
            ("Constants folded", self.constants_folded),
            ("Functions inlined", self.functions_inlined),
            ("Pipelines optimized", self.pipelines_optimized),
            ("Sections merged", self.sections_merged),
            ("Total bytes saved", self.bytes_saved),
        ];
        let mut out = String::from("Optimization Results:");
        for (label, value) in rows.iter() {
            out.push_str(&format!("\n- {}: {}", label, value));
        }
        out
    }
}
#[cfg(test)]
mod tests {
    // Bug fix: the module previously imported only the codegen helper types
    // and never the items under test (`Optimizer`, `OptimizationLevel`,
    // `HelixIR`, `Instruction`), so these tests could not compile.
    // `use super::*` brings in everything the parent module has in scope,
    // including its `pub use` codegen re-exports (StringPool, SymbolTable,
    // Metadata, ConstantPool).
    use super::*;

    /// `From<u8>` maps 0..=3 onto the four levels and saturates above 3.
    #[test]
    fn test_optimization_levels() {
        assert_eq!(OptimizationLevel::from(0), OptimizationLevel::Zero);
        assert_eq!(OptimizationLevel::from(1), OptimizationLevel::One);
        assert_eq!(OptimizationLevel::from(2), OptimizationLevel::Two);
        assert_eq!(OptimizationLevel::from(3), OptimizationLevel::Three);
        assert_eq!(OptimizationLevel::from(99), OptimizationLevel::Three);
    }

    /// Deduplication collapses the repeated "hello" and records one hit.
    #[test]
    fn test_string_deduplication() {
        let mut ir = HelixIR {
            version: 1,
            metadata: Metadata::default(),
            symbol_table: SymbolTable::default(),
            instructions: vec![
                Instruction::DeclareAgent(0),
                Instruction::DeclareWorkflow(1),
                Instruction::DeclareContext(2),
            ],
            string_pool: StringPool {
                strings: vec![
                    "hello".to_string(),
                    "world".to_string(),
                    "hello".to_string(),
                ],
                index: std::collections::HashMap::new(),
            },
            constants: ConstantPool::default(),
        };
        let mut optimizer = Optimizer::new(OptimizationLevel::One);
        optimizer.deduplicate_strings(&mut ir);
        assert_eq!(ir.string_pool.strings.len(), 2);
        assert_eq!(optimizer.stats.strings_deduplicated, 1);
    }

    /// Constant folding is currently a stub, so the stream is unchanged.
    #[test]
    fn test_constant_folding() {
        let mut ir = HelixIR {
            version: 1,
            metadata: Metadata::default(),
            symbol_table: SymbolTable::default(),
            instructions: vec![Instruction::DeclareAgent(0)],
            string_pool: StringPool::default(),
            constants: ConstantPool::default(),
        };
        let mut optimizer = Optimizer::new(OptimizationLevel::Two);
        optimizer.fold_constants(&mut ir);
        assert_eq!(ir.instructions.len(), 1);
        match &ir.instructions[0] {
            Instruction::DeclareAgent(0) => {}
            _ => panic!("Expected DeclareAgent(0)"),
        }
    }
}
460
461pub fn optimize_command(
462    input: PathBuf,
463    output: Option<PathBuf>,
464    level: u8,
465    verbose: bool,
466) -> Result<(), Box<dyn std::error::Error>> {
467    let output_path = output.unwrap_or_else(|| input.clone());
468    if verbose {
469        println!("⚡ Optimizing: {}", input.display());
470        println!("  Level: {}", level);
471    }
472    let loader = crate::mds::loader::BinaryLoader::new();
473    let binary = loader.load_file(&input)?;
474    let serializer = crate::mds::serializer::BinarySerializer::new(false);
475    let mut ir = serializer.deserialize_to_ir(&binary)?;
476    let mut optimizer = crate::mds::optimizer::Optimizer::new(
477        OptimizationLevel::from(level),
478    );
479    optimizer.optimize(&mut ir);
480    let optimized_binary = serializer.serialize(ir, None)?;
481    serializer.write_to_file(&optimized_binary, &output_path)?;
482    println!("✅ Optimized successfully: {}", output_path.display());
483    if verbose {
484        let stats = optimizer.stats();
485        println!("\nOptimization Results:");
486        println!("{}", stats.report());
487    }
488    Ok(())
489}