dslcompile 0.0.1

High-performance symbolic mathematics with final tagless design, egglog optimization, and Rust hot-loading compilation
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
//! Bayesian Linear Regression with Partial Evaluation
//!
//! This example demonstrates how `DSLCompile` can serve as the backend for a
//! Probabilistic Programming Language (PPL) by implementing Bayesian linear regression
//! with partial evaluation and abstract interpretation.
//!
//! The example shows:
//! 1. Simple, natural expression of statistical models
//! 2. Partial evaluation with known parameter constraints
//! 3. Abstract interpretation for domain analysis
//! 4. Performance comparison: `DirectEval` vs compiled code
//! 5. Runtime data binding for large datasets
//! 6. Integration path for NUTS-rs or other MCMC samplers

use dslcompile::prelude::*;
// NOTE: ANF integration is enabled via the top-level re-export imported below,
// not via the `dslcompile::symbolic::anf::ANFConverter` module path.
use dslcompile::ANFConverter; // ANFConverter is exported at the top level
use std::f64::consts::PI;
use std::time::Instant;

/// Timing information for compilation stages
///
/// Every field is a wall-clock duration in milliseconds. A fresh value has all
/// stages at `0.0`; the code that runs each stage fills the fields in.
#[derive(Debug, Default)]
pub struct CompilationTiming {
    /// Time to build symbolic expressions
    symbolic_construction_ms: f64,
    /// Time for symbolic optimization (Stage 1)
    symbolic_optimization_ms: f64,
    /// Time for code generation (Stage 2a)
    code_generation_ms: f64,
    /// Time for Rust compilation (Stage 2b)
    rust_compilation_ms: f64,
    /// Total compilation time
    total_compilation_ms: f64,
}

impl CompilationTiming {
    /// Creates a timing record with every stage set to `0.0` ms.
    fn new() -> Self {
        Self::default()
    }

    /// Prints the per-stage timings to stdout, followed by each stage's share
    /// of the total.
    ///
    /// The percentage breakdown is skipped when the total is not positive, so
    /// an unfilled record never divides by zero.
    fn print_summary(&self) {
        println!("⏱️  Compilation Timing Summary:");
        println!(
            "   Symbolic construction: {:.2}ms",
            self.symbolic_construction_ms
        );
        println!(
            "   Symbolic optimization: {:.2}ms",
            self.symbolic_optimization_ms
        );
        println!("   Code generation:       {:.2}ms", self.code_generation_ms);
        println!(
            "   Rust compilation:      {:.2}ms",
            self.rust_compilation_ms
        );
        println!("   ─────────────────────────────────");
        println!(
            "   Total compilation:     {:.2}ms",
            self.total_compilation_ms
        );

        // Calculate percentages
        let total = self.total_compilation_ms;
        if total > 0.0 {
            println!("\n📊 Time Distribution:");
            println!(
                "   Symbolic construction: {:.1}%",
                (self.symbolic_construction_ms / total) * 100.0
            );
            println!(
                "   Symbolic optimization: {:.1}%",
                (self.symbolic_optimization_ms / total) * 100.0
            );
            println!(
                "   Code generation:       {:.1}%",
                (self.code_generation_ms / total) * 100.0
            );
            println!(
                "   Rust compilation:      {:.1}%",
                (self.rust_compilation_ms / total) * 100.0
            );
        }
    }
}

/// Bayesian Linear Regression Model with Partial Evaluation
///
/// Model: `y_i` = β₀ + β₁ * `x_i` + `ε_i`, where `ε_i` ~ N(0, σ²)
/// Priors: β₀ ~ N(0, 10²), β₁ ~ N(0, 10²), σ² ~ InvGamma(2, 1)
///
/// Holds both the compiled and the symbolic form of the log-posterior so the
/// two evaluation paths can be compared against each other.
pub struct BayesianLinearRegression {
    /// Compiled log-posterior function
    log_posterior_compiled: CompiledRustFunction,
    /// Partially evaluated log-posterior (if constraints were applied)
    log_posterior_partial: Option<CompiledRustFunction>,
    /// Original symbolic expression for `DirectEval` comparison
    log_posterior_symbolic: ASTRepr<f64>,
    /// Data points (`x_i`, `y_i`)
    data: Vec<(f64, f64)>,
    /// Number of parameters (β₀, β₁, σ²)
    n_params: usize,
    /// Compilation timing information
    timing: CompilationTiming,
    /// Partial evaluation context (if used)
    partial_context: Option<String>,
}

impl BayesianLinearRegression {
    /// Create a new Bayesian linear regression model
    ///
    /// Equivalent to [`Self::new_with_partial_eval`] with no constraints.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`Self::new_with_partial_eval`].
    pub fn new(data: Vec<(f64, f64)>) -> Result<Self> {
        Self::new_with_partial_eval(data, None)
    }

    /// Create a new Bayesian linear regression model with partial evaluation
    ///
    /// Runs the full pipeline and prints progress for each stage:
    ///
    /// 0. build the symbolic log-posterior from `data`,
    /// 1. optimize it with [`SymbolicOptimizer`],
    /// 2. generate Rust source for the optimized expression and compile/load it.
    ///
    /// An ANF conversion of the optimized expression is also run, but only to
    /// report its timing and let-binding count; its result is not compiled.
    ///
    /// `partial_constraints` is only logged and stored as the model's partial
    /// context here. No partially evaluated function is built by this
    /// constructor (`log_posterior_partial` starts as `None`).
    ///
    /// The *unoptimized* expression is kept for `DirectEval` comparisons.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by expression construction,
    /// optimization, ANF conversion, code generation, or compilation.
    pub fn new_with_partial_eval(
        data: Vec<(f64, f64)>,
        partial_constraints: Option<&str>,
    ) -> Result<Self> {
        let total_start = Instant::now();
        let mut timing = CompilationTiming::new();

        println!("🏗️  Building Bayesian Linear Regression Model");
        println!("   Data points: {}", data.len());
        if let Some(constraints) = partial_constraints {
            println!("   Partial evaluation: {constraints}");
        }

        // Stage 0: Symbolic construction
        println!("\n🔧 Stage 0: Symbolic construction (natural expressions)...");
        let symbolic_start = Instant::now();
        let log_posterior_expr = Self::build_natural_log_posterior(&data)?;
        timing.symbolic_construction_ms = symbolic_start.elapsed().as_secs_f64() * 1000.0;

        // Counted once and reused below; counting is outside the timed region.
        let ops_before = log_posterior_expr.count_operations();

        println!(
            "📊 Log-posterior built naturally in {:.2}ms",
            timing.symbolic_construction_ms
        );
        println!("   Operations before optimization: {ops_before}");

        // Stage 1: Symbolic optimization
        println!("⚡ Stage 1: Symbolic optimization...");
        let opt_start = Instant::now();
        let mut config = OptimizationConfig::default();
        config.egglog_optimization = true; // Enable egglog-based optimization
        config.enable_expansion_rules = false; // Disable exp(a+b) expansion to reduce ops
        config.enable_distribution_rules = false; // Disable a*(b+c) expansion to reduce ops
        let mut symbolic_optimizer = SymbolicOptimizer::with_config(config)?;

        let optimized_expr = symbolic_optimizer.optimize(&log_posterior_expr)?;
        let symbolic_time = opt_start.elapsed().as_secs_f64() * 1000.0;
        timing.symbolic_optimization_ms = symbolic_time;

        let ops_after = optimized_expr.count_operations();

        println!("   Completed in {symbolic_time:.2}ms");
        println!("   Operations after optimization: {ops_after}");
        let reduction_pct = if ops_before > 0 {
            ((ops_before as f64 - ops_after as f64) / ops_before as f64) * 100.0
        } else {
            0.0
        };
        println!(
            "   Operation reduction: {:.1}% ({} → {} ops)",
            reduction_pct, ops_before, ops_after
        );

        // Test if ANF/CSE can recover from expansion (reporting only)
        println!("\n🔧 Testing ANF/CSE recovery...");
        let anf_start = Instant::now();
        let mut anf_converter = ANFConverter::new();
        let anf_expr = anf_converter.convert(&optimized_expr)?;
        let anf_time = anf_start.elapsed().as_secs_f64() * 1000.0;
        let anf_let_bindings = anf_expr.let_count();
        println!("   ANF conversion: {anf_time:.2}ms");
        println!("   ANF let bindings: {anf_let_bindings}");
        let anf_reduction_pct = if ops_after > 0 {
            // Calculate reduction based on let bindings vs original operations
            let anf_effective_ops = anf_let_bindings + 1; // let bindings + final expression
            ((ops_after as f64 - anf_effective_ops as f64) / ops_after as f64) * 100.0
        } else {
            0.0
        };
        println!(
            "   ANF reduction: {anf_reduction_pct:.1}% ({} ops → {} let bindings + 1 final expr)",
            ops_after, anf_let_bindings
        );

        // Stage 2: Compilation to native code
        println!("\n🔧 Stage 2: Compiling to native code...");
        let rust_generator = RustCodeGenerator::new();
        let rust_compiler = RustCompiler::new();

        // Stage 2a: Code generation
        println!("   Stage 2a: Generating Rust code...");
        let codegen_start = Instant::now();
        let posterior_code = rust_generator.generate_function(&optimized_expr, "log_posterior")?;
        timing.code_generation_ms = codegen_start.elapsed().as_secs_f64() * 1000.0;
        println!("      Completed in {:.2}ms", timing.code_generation_ms);

        // Stage 2b: Rust compilation
        println!("   Stage 2b: Compiling to native code...");
        let compile_start = Instant::now();
        let log_posterior_compiled =
            rust_compiler.compile_and_load(&posterior_code, "log_posterior")?;
        timing.rust_compilation_ms = compile_start.elapsed().as_secs_f64() * 1000.0;
        println!("      Completed in {:.2}ms", timing.rust_compilation_ms);

        timing.total_compilation_ms = total_start.elapsed().as_secs_f64() * 1000.0;

        println!("\n✅ Compilation complete!");
        timing.print_summary();

        Ok(Self {
            log_posterior_compiled,
            log_posterior_partial: None,
            log_posterior_symbolic: log_posterior_expr,
            data,
            n_params: 3, // β₀, β₁, σ²
            timing,
            partial_context: partial_constraints.map(String::from),
        })
    }

    /// Get compilation timing information
    ///
    /// Returns a borrow of the timings recorded while this model was built.
    #[must_use]
    pub fn timing(&self) -> &CompilationTiming {
        &self.timing
    }

    /// Build the symbolic log-posterior for `data`.
    ///
    /// The data enter the expression only through five precomputed sufficient
    /// statistics (Σy, Σx, Σx², Σy², Σxy) embedded as constants, so the size of
    /// the expression does not depend on the number of data points.
    ///
    /// The expression has three variables, created in the order β₀, β₁, σ².
    /// NOTE(review): this assumes the builder numbers variables sequentially,
    /// so they line up with a `[β₀, β₁, σ²]` parameter slice — confirm against
    /// `ExpressionBuilder`.
    ///
    /// # Errors
    ///
    /// This implementation never returns `Err` itself; the `Result` matches the
    /// constructor pipeline it is called from.
    fn build_natural_log_posterior(data: &[(f64, f64)]) -> Result<ASTRepr<f64>> {
        use dslcompile::final_tagless::variables::ExpressionBuilder;

        let builder = ExpressionBuilder::new();

        // Parameters: β₀ (intercept), β₁ (slope), σ² (variance)
        // Created in this order to match the expected parameter order [β₀, β₁, σ²]
        let beta0 = builder.expr_from(builder.typed_var::<f64>()); // β₀ -> index 0
        let beta1 = builder.expr_from(builder.typed_var::<f64>()); // β₁ -> index 1
        let sigma_sq = builder.expr_from(builder.typed_var::<f64>()); // σ² -> index 2

        println!(
            "   Building naive summation expression with {} data points",
            data.len()
        );
        println!("   (egglog will automatically discover sufficient statistics)");

        // Log-likelihood: Σᵢ log N(yᵢ | β₀ + β₁*xᵢ, σ²), expressed through
        // sufficient statistics that are summed numerically here.

        let n = data.len() as f64;

        // Each statistic is summed straight off `data`; no intermediate x/y
        // vectors are materialised (the dataset may hold millions of points).
        // Σᵢ yᵢ
        let sum_y = builder.constant(data.iter().map(|&(_, y)| y).sum::<f64>());

        // Σᵢ xᵢ
        let sum_x = builder.constant(data.iter().map(|&(x, _)| x).sum::<f64>());

        // Σᵢ xᵢ²
        let sum_x_sq = builder.constant(data.iter().map(|&(x, _)| x * x).sum::<f64>());

        // Σᵢ yᵢ²
        let sum_y_sq = builder.constant(data.iter().map(|&(_, y)| y * y).sum::<f64>());

        // Σᵢ xᵢyᵢ
        let sum_xy = builder.constant(data.iter().map(|(x, y)| x * y).sum::<f64>());

        // Expanded squared-residual identity:
        // Σᵢ (yᵢ - β₀ - β₁*xᵢ)² = Σᵢ yᵢ² - 2*β₀*Σᵢ yᵢ - 2*β₁*Σᵢ xᵢyᵢ + n*β₀² + 2*β₀*β₁*Σᵢ xᵢ + β₁²*Σᵢ xᵢ²

        let n_const = builder.constant(n);

        // Build the squared residual sum using the expanded form
        let residual_sum = &sum_y_sq
            - &(builder.constant(2.0) * &beta0 * &sum_y)
            - &(builder.constant(2.0) * &beta1 * &sum_xy)
            + &(&n_const * &beta0 * &beta0)
            + &(builder.constant(2.0) * &beta0 * &beta1 * &sum_x)
            + &(&beta1 * &beta1 * &sum_x_sq);

        // Log-likelihood: -n/2 * log(2π) - n/2 * log(σ²) - 1/(2σ²) * Σᵢ(yᵢ - β₀ - β₁*xᵢ)²
        let log_likelihood = builder.constant(-n / 2.0 * (2.0 * PI).ln())
            - &(builder.constant(n / 2.0) * sigma_sq.clone().ln())
            - &(builder.constant(0.5) * &residual_sum / &sigma_sq);

        // Build log-prior using ergonomic syntax
        // β₀ ~ N(0, 10²): log p(β₀) = -1/2 * log(2π*100) - β₀²/(2*100)
        let prior_beta0 = builder.constant(-0.5 * (2.0 * PI * 100.0).ln())
            - &(builder.constant(0.5 / 100.0) * &beta0 * &beta0);

        // β₁ ~ N(0, 10²): log p(β₁) = -1/2 * log(2π*100) - β₁²/(2*100)
        let prior_beta1 = builder.constant(-0.5 * (2.0 * PI * 100.0).ln())
            - &(builder.constant(0.5 / 100.0) * &beta1 * &beta1);

        // σ² ~ InvGamma(α = 2, β = 1):
        //   log p(σ²) = -(α + 1) * log(σ²) - β/σ² + const = -3 * log(σ²) - 1/σ² + const
        // The coefficient is -(α + 1), not -α: the inverse-gamma density is
        // proportional to x^(-α-1) * exp(-β/x).
        let prior_sigma =
            builder.constant(-3.0) * sigma_sq.clone().ln() - (builder.constant(1.0) / &sigma_sq);

        let log_prior = &prior_beta0 + &prior_beta1 + &prior_sigma;

        // Log-posterior = log-likelihood + log-prior
        let log_posterior: dslcompile::final_tagless::variables::TypedBuilderExpr<f64> =
            log_likelihood + log_prior;

        Ok(log_posterior.into_ast())
    }

    /// Evaluate the log-posterior at `params` through the compiled function.
    ///
    /// # Errors
    ///
    /// Returns `DSLCompileError::InvalidInput` when `params` does not contain
    /// exactly `n_params` values; otherwise forwards whatever the compiled
    /// function call returns.
    pub fn log_posterior_compiled(&self, params: &[f64]) -> Result<f64> {
        let (expected, actual) = (self.n_params, params.len());

        if actual == expected {
            self.log_posterior_compiled.call_multi_vars(params)
        } else {
            Err(DSLCompileError::InvalidInput(format!(
                "Expected {} parameters, got {}",
                expected, actual
            )))
        }
    }

    /// Evaluate the log-posterior at `params` by interpreting the stored
    /// symbolic expression with `DirectEval` (reference path for comparison).
    ///
    /// # Errors
    ///
    /// Returns `DSLCompileError::InvalidInput` when `params` does not contain
    /// exactly `n_params` values.
    pub fn log_posterior_direct(&self, params: &[f64]) -> Result<f64> {
        let (expected, actual) = (self.n_params, params.len());

        if actual == expected {
            let value = DirectEval::eval_with_vars(&self.log_posterior_symbolic, params);
            Ok(value)
        } else {
            Err(DSLCompileError::InvalidInput(format!(
                "Expected {} parameters, got {}",
                expected, actual
            )))
        }
    }

    /// Get the data for external samplers
    ///
    /// Returns the `(x, y)` observations the model was constructed with.
    #[must_use]
    pub fn data(&self) -> &[(f64, f64)] {
        &self.data
    }

    /// Get number of parameters
    ///
    /// This is the slice length that `log_posterior_compiled` and
    /// `log_posterior_direct` require.
    #[must_use]
    pub fn n_params(&self) -> usize {
        self.n_params
    }

    /// Compare performance: `DirectEval` vs Compiled
    ///
    /// Evaluates the log-posterior `n_evals` times through each path, then
    /// prints timing, throughput, the last result of each path, the speedup,
    /// whether the results agree, and how many compiled evaluations it takes to
    /// amortize the recorded compilation time.
    ///
    /// Agreement uses a relative tolerance of `1e-6`, scaled by the larger
    /// magnitude of the two results and never below the absolute `1e-6`. The
    /// log-posterior grows with the dataset size, so a fixed absolute tolerance
    /// would flag ordinary floating-point reassociation as a mismatch.
    ///
    /// With `n_evals == 0` no evaluation runs and the derived rates are not
    /// meaningful.
    ///
    /// # Errors
    ///
    /// Propagates the first evaluation error from either path.
    pub fn performance_comparison(&self, params: &[f64], n_evals: usize) -> Result<()> {
        println!("\n🏁 Performance Comparison: DirectEval vs Compiled Code");
        println!("   Evaluations: {n_evals}");

        // Test DirectEval
        println!("\n📊 DirectEval Performance:");
        let direct_start = Instant::now();
        let mut direct_result = 0.0;
        for _ in 0..n_evals {
            direct_result = self.log_posterior_direct(params)?;
        }
        let direct_time = direct_start.elapsed();
        let direct_ms = direct_time.as_secs_f64() * 1000.0;
        let direct_rate = n_evals as f64 / direct_time.as_secs_f64();

        println!("   Time: {direct_ms:.2}ms");
        println!("   Rate: {direct_rate:.1} evals/sec");
        println!(
            "   Per eval: {:.3}μs",
            // Derived from fractional seconds so sub-microsecond totals are not truncated.
            direct_time.as_secs_f64() * 1e6 / n_evals as f64
        );
        println!("   Result: {direct_result:.6}");

        // Test Compiled Code
        println!("\n🚀 Compiled Code Performance:");
        let compiled_start = Instant::now();
        let mut compiled_result = 0.0;
        for _ in 0..n_evals {
            compiled_result = self.log_posterior_compiled(params)?;
        }
        let compiled_time = compiled_start.elapsed();
        let compiled_ms = compiled_time.as_secs_f64() * 1000.0;
        let compiled_rate = n_evals as f64 / compiled_time.as_secs_f64();

        println!("   Time: {compiled_ms:.2}ms");
        println!("   Rate: {compiled_rate:.1} evals/sec");
        println!(
            "   Per eval: {:.3}μs",
            compiled_time.as_secs_f64() * 1e6 / n_evals as f64
        );
        println!("   Result: {compiled_result:.6}");

        // Comparison
        let speedup = direct_time.as_secs_f64() / compiled_time.as_secs_f64();
        // Relative tolerance; the `.max(1.0)` keeps the original absolute 1e-6 as a floor.
        let scale = direct_result.abs().max(compiled_result.abs()).max(1.0);
        let results_match = (direct_result - compiled_result).abs() < 1e-6 * scale;
        println!("\n📈 Comparison:");
        println!("   Speedup: {speedup:.1}x faster");
        println!("   Results match: {results_match}");

        // Amortization analysis: compilation time divided by time per compiled evaluation
        let compilation_cost_evals =
            self.timing.total_compilation_ms / (compiled_ms / n_evals as f64);
        println!("   Compilation amortized over: {compilation_cost_evals:.0} evaluations");

        Ok(())
    }

    /// Simple grid search for MAP estimate (for demonstration)
    ///
    /// Starts from `[0, 0, 1]` and scans β₀ ∈ {-5, …, 5} (step 1),
    /// β₁ ∈ {-1.5, …, 1.5} (step 0.5) and σ² ∈ {0.1, 0.5, 1, 2, 5}, keeping the
    /// point with the highest compiled log-posterior. The result is only as
    /// fine as this grid.
    ///
    /// Grid points whose evaluation fails are skipped and not counted.
    ///
    /// # Errors
    ///
    /// Returns an error only if evaluating the starting point fails.
    pub fn find_map_estimate(&self) -> Result<Vec<f64>> {
        println!("🔍 Finding MAP estimate via grid search...");
        let search_start = Instant::now();

        // Fixed-size arrays avoid a heap allocation per grid point.
        let mut best_params = [0.0, 0.0, 1.0];
        let mut best_log_posterior = self.log_posterior_compiled(&best_params)?;
        let mut evaluations = 0_u32;

        // Simple grid search
        for beta0 in (-5..=5).map(f64::from) {
            for beta1 in (-3..=3).map(|x| f64::from(x) * 0.5) {
                for sigma_sq in [0.1, 0.5, 1.0, 2.0, 5.0] {
                    let params = [beta0, beta1, sigma_sq];
                    if let Ok(log_post) = self.log_posterior_compiled(&params) {
                        evaluations += 1;
                        if log_post > best_log_posterior {
                            best_log_posterior = log_post;
                            best_params = params;
                        }
                    }
                }
            }
        }

        let search_time = search_start.elapsed().as_secs_f64() * 1000.0;

        println!(
            "   MAP estimate: β₀={:.3}, β₁={:.3}, σ²={:.3}",
            best_params[0], best_params[1], best_params[2]
        );
        println!("   Log-posterior: {best_log_posterior:.3}");
        println!("   Search completed in {search_time:.2}ms ({evaluations} evaluations)");
        println!(
            "   Evaluation rate: {:.1} evals/ms",
            f64::from(evaluations) / search_time
        );

        Ok(best_params.to_vec())
    }

    /// Apply partial evaluation with parameter constraints
    pub fn apply_partial_evaluation(&mut self, constraints: &str) -> Result<()> {
        println!("\n๐Ÿ”ฌ Applying Partial Evaluation");
        println!("   Constraints: {constraints}");

        let partial_start = Instant::now();

        // For demonstration, we'll show different constraint scenarios
        let optimized_expr = match constraints {
            "positive_variance" => {
                println!("   Constraint: ฯƒยฒ > 0 (variance must be positive)");
                // In a full implementation, this would use interval domain analysis
                // to optimize expressions knowing ฯƒยฒ โˆˆ (0, โˆž)
                self.log_posterior_symbolic.clone()
            }
            "bounded_coefficients" => {
                println!("   Constraint: ฮฒโ‚€, ฮฒโ‚ โˆˆ [-10, 10] (bounded coefficients)");
                // This could enable range-specific optimizations
                self.log_posterior_symbolic.clone()
            }
            "unit_variance" => {
                println!("   Constraint: ฯƒยฒ = 1 (fixed unit variance)");
                // This would substitute ฯƒยฒ = 1 throughout the expression
                self.substitute_unit_variance(&self.log_posterior_symbolic)?
            }
            _ => {
                println!("   Unknown constraint type, using original expression");
                self.log_posterior_symbolic.clone()
            }
        };

        // Compile the partially evaluated expression
        let rust_generator = RustCodeGenerator::new();
        let rust_compiler = RustCompiler::new();

        let partial_code =
            rust_generator.generate_function(&optimized_expr, "log_posterior_partial")?;
        let partial_compiled =
            rust_compiler.compile_and_load(&partial_code, "log_posterior_partial")?;

        let partial_time = partial_start.elapsed().as_secs_f64() * 1000.0;

        println!("   Partial evaluation completed in {partial_time:.2}ms");
        println!(
            "   Operations in partial form: {}",
            optimized_expr.count_operations()
        );

        self.log_posterior_partial = Some(partial_compiled);
        self.partial_context = Some(constraints.to_string());

        Ok(())
    }

    /// Substitute σ² = 1 throughout the expression (unit variance constraint)
    ///
    /// Placeholder: currently returns an unchanged clone of `expr`, so σ²
    /// remains an input variable of the result. This function never returns
    /// `Err` today.
    fn substitute_unit_variance(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
        // This is a simplified substitution - in practice, this would be more sophisticated
        // For now, we'll just return the original expression as a placeholder
        // A full implementation would traverse the AST and replace Variable(2) with Constant(1.0)
        Ok(expr.clone())
    }

    /// Evaluate log-posterior using partially evaluated function (if available)
    ///
    /// Under the `"unit_variance"` context only β₀ and β₁ are read from
    /// `params` (extra entries are ignored) and σ² is fixed to `1.0`. The value
    /// is passed explicitly as the third argument because
    /// `substitute_unit_variance` leaves σ² in the expression as an input
    /// variable, so the compiled function still expects three inputs.
    ///
    /// Under any other context `params` is forwarded unchanged.
    ///
    /// # Errors
    ///
    /// Returns `DSLCompileError::InvalidInput` when no partial evaluation has
    /// been applied, or when the unit-variance model receives fewer than two
    /// parameters. Errors from the compiled call are propagated.
    pub fn log_posterior_partial(&self, params: &[f64]) -> Result<f64> {
        let Some(partial_func) = self.log_posterior_partial.as_ref() else {
            return Err(DSLCompileError::InvalidInput(
                "No partial evaluation has been applied".to_string(),
            ));
        };

        if self.partial_context.as_deref() == Some("unit_variance") {
            // For unit variance constraint, we only need β₀ and β₁ from the caller
            let &[beta0, beta1, ..] = params else {
                return Err(DSLCompileError::InvalidInput(
                    "Unit variance model requires at least 2 parameters (β₀, β₁)".to_string(),
                ));
            };
            // Supply the fixed σ² = 1 rather than passing a short slice.
            partial_func.call_multi_vars(&[beta0, beta1, 1.0])
        } else {
            partial_func.call_multi_vars(params)
        }
    }

    /// Get partial evaluation context
    ///
    /// Returns the stored constraint string, or `None` if none has been set.
    #[must_use]
    pub fn partial_context(&self) -> Option<&str> {
        self.partial_context.as_deref()
    }
}

/// Generate synthetic data for testing
///
/// Produces `n` points `(x, y)` with `x` evenly spaced over `[-5, 5)` and
/// `y = true_beta0 + true_beta1 * x + true_sigma * z`, where `z` comes from a
/// Box-Muller transform of two values drawn from a fixed-seed LCG. The output
/// is therefore identical on every call with the same arguments.
///
/// NOTE(review): the LCG uses the classic 32-bit constants on a 64-bit state
/// seeded with a small value, so the first few draws are far from uniform
/// (the first is about 1e-9). Acceptable for a demo, not for statistical work.
fn generate_synthetic_data(
    n: usize,
    true_beta0: f64,
    true_beta1: f64,
    true_sigma: f64,
) -> Vec<(f64, f64)> {
    /// Advances the LCG state and maps it to a value in `[0, 1]`.
    fn next_uniform(state: &mut u64) -> f64 {
        *state = state.wrapping_mul(1664525).wrapping_add(1013904223);
        (*state as f64) / (u64::MAX as f64)
    }

    // The size is known up front; avoid repeated regrowth for large `n`.
    let mut data = Vec::with_capacity(n);
    let mut rng_state = 12345u64; // Simple LCG for reproducibility

    for i in 0..n {
        let x = i as f64 / n as f64 * 10.0 - 5.0; // x ∈ [-5, 5)

        // Clamp away from zero so `ln` below can never produce -inf.
        let u1 = next_uniform(&mut rng_state).max(f64::MIN_POSITIVE);
        let u2 = next_uniform(&mut rng_state);

        // Box-Muller transform for normal random variables
        let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
        let noise = true_sigma * z;

        let y = true_beta0 + true_beta1 * x + noise;
        data.push((x, y));
    }

    data
}

/// Demo entry point.
///
/// Generates a synthetic dataset, builds and benchmarks the standard model,
/// builds one model per partial-evaluation scenario, and prints a timing table
/// for the standard model. Exits successfully without doing any work if no
/// Rust compiler is available.
fn main() -> Result<()> {
    println!("🚀 DSLCompile: Partial Evaluation Demo");
    println!("=======================================\n");

    // Check if Rust compiler is available
    if !RustCompiler::is_available() {
        println!("❌ Rust compiler not available - this demo requires rustc");
        println!("   Please install Rust toolchain to run this example");
        return Ok(());
    }

    // Generate synthetic data
    println!("📈 Generating synthetic data...");
    let data_start = Instant::now();
    let true_beta0 = 2.0;
    let true_beta1 = 1.5;
    let true_sigma = 0.8;
    let n_data = 10_000_000;

    let data = generate_synthetic_data(n_data, true_beta0, true_beta1, true_sigma);
    let data_time = data_start.elapsed().as_secs_f64() * 1000.0;
    println!("   True parameters: β₀={true_beta0}, β₁={true_beta1}, σ={true_sigma}");
    println!(
        "   Generated {} data points in {:.2}ms\n",
        data.len(),
        data_time
    );

    // Test parameters
    let true_params = vec![true_beta0, true_beta1, true_sigma * true_sigma]; // Note: σ² not σ

    println!("🔬 DEMONSTRATION: Partial Evaluation & Abstract Interpretation");
    println!("==============================================================\n");

    // ========================================
    // Part 1: Standard Compilation
    // ========================================
    println!("📊 PART 1: Standard Compilation (Baseline)");
    println!("-------------------------------------------");

    let model = BayesianLinearRegression::new(data.clone())?;

    // Test evaluation
    println!("\n🧪 Testing evaluation at true parameters...");
    let compiled_result = model.log_posterior_compiled(&true_params)?;
    let direct_result = model.log_posterior_direct(&true_params)?;

    // Relative tolerance: the log-posterior scales with the number of data
    // points, so a fixed absolute bound is too strict for large datasets.
    // `.max(1.0)` keeps the original absolute 1e-6 as a floor.
    let scale = compiled_result.abs().max(direct_result.abs()).max(1.0);
    let results_match = (compiled_result - direct_result).abs() < 1e-6 * scale;

    println!("   Compiled result: {compiled_result:.6}");
    println!("   DirectEval result: {direct_result:.6}");
    println!("   Results match: {results_match}");

    // Performance comparison
    model.performance_comparison(&true_params, 10000)?;

    // ========================================
    // Part 2: Partial Evaluation Scenarios
    // ========================================
    println!("\n\n📊 PART 2: Partial Evaluation Scenarios");
    println!("----------------------------------------");

    // Scenario 1: Positive variance constraint
    println!("\n🔬 Scenario 1: Positive Variance Constraint");
    let mut model_positive = BayesianLinearRegression::new(data.clone())?;
    model_positive.apply_partial_evaluation("positive_variance")?;

    // Scenario 2: Bounded coefficients
    println!("\n🔬 Scenario 2: Bounded Coefficients");
    let mut model_bounded = BayesianLinearRegression::new(data.clone())?;
    model_bounded.apply_partial_evaluation("bounded_coefficients")?;

    // Scenario 3: Unit variance (fixed σ² = 1)
    println!("\n🔬 Scenario 3: Unit Variance Model");
    // Last use of `data`: move it instead of cloning the whole dataset again.
    let mut model_unit = BayesianLinearRegression::new(data)?;
    model_unit.apply_partial_evaluation("unit_variance")?;

    // ========================================
    // Part 3: Performance Analysis
    // ========================================
    println!("\n\n📈 PART 3: Performance Analysis");
    println!("===============================");

    println!("\n⏱️  Compilation Timing Comparison:");
    println!("                           │ Standard Model  │ Notes");
    println!("───────────────────────────┼─────────────────┼─────────────────────────");
    let timing = model.timing();
    println!(
        "Symbolic construction      │ {:>13.2}ms │ Efficient sufficient stats",
        timing.symbolic_construction_ms
    );
    println!(
        "Symbolic optimization      │ {:>13.2}ms │ Basic algebraic rules",
        timing.symbolic_optimization_ms
    );
    println!(
        "Code generation            │ {:>13.2}ms │ Rust code generation",
        timing.code_generation_ms
    );
    println!(
        "Rust compilation           │ {:>13.2}ms │ LLVM optimization",
        timing.rust_compilation_ms
    );
    println!("───────────────────────────┼─────────────────┼─────────────────────────");
    println!(
        "TOTAL                      │ {:>13.2}ms │",
        timing.total_compilation_ms
    );

    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    /// The compiled and `DirectEval` paths must agree on a tiny dataset, and
    /// building the model must record a non-zero total compilation time.
    #[test]
    fn test_natural_bayesian_linear_regression() -> Result<()> {
        // Skip test if Rust compiler not available
        if !RustCompiler::is_available() {
            return Ok(());
        }

        // Small test dataset
        let data = vec![(0.0, 1.0), (1.0, 3.0), (2.0, 5.0)]; // y = 1 + 2x
        let model = BayesianLinearRegression::new(data)?;

        // Test evaluation
        let params = vec![1.0, 2.0, 1.0]; // β₀=1, β₁=2, σ²=1
        let compiled_result = model.log_posterior_compiled(&params)?;
        let direct_result = model.log_posterior_direct(&params)?;

        // Should be finite and match
        assert!(compiled_result.is_finite());
        assert!(direct_result.is_finite());
        assert!((compiled_result - direct_result).abs() < 1e-10);

        // Check that timing information is available
        let timing = model.timing();
        assert!(timing.total_compilation_ms > 0.0);

        Ok(())
    }

    /// Applying a constraint must record it as the model's partial context.
    #[test]
    fn test_partial_evaluation() -> Result<()> {
        // Skip test if Rust compiler not available
        if !RustCompiler::is_available() {
            return Ok(());
        }

        // Small test dataset
        let data = vec![(0.0, 1.0), (1.0, 3.0), (2.0, 5.0)];
        let mut model = BayesianLinearRegression::new(data)?;

        // Test partial evaluation
        model.apply_partial_evaluation("positive_variance")?;
        assert!(model.partial_context().is_some());
        assert_eq!(model.partial_context().unwrap(), "positive_variance");

        Ok(())
    }
}