// entrenar_bench/cost.rs

//! Cost-Performance Analysis (ENT-029)
//!
//! Provides Pareto-frontier analysis for balancing training cost against model
//! performance, plus constraint-based configuration recommendations.

use serde::{Deserialize, Serialize};

7/// A single configuration with cost and performance metrics
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct CostPerformancePoint {
10    /// Configuration name or description
11    pub name: String,
12    /// Training cost in GPU-hours
13    pub gpu_hours: f64,
14    /// Estimated cloud cost in USD
15    pub cost_usd: f64,
16    /// Model accuracy (0.0 - 1.0)
17    pub accuracy: f64,
18    /// Model loss
19    pub loss: f64,
20    /// Memory usage in GB
21    pub memory_gb: f64,
22    /// Whether this point is on the Pareto frontier
23    pub is_pareto_optimal: bool,
24    /// Configuration parameters
25    pub config: ConfigParams,
26}
27
28/// Configuration parameters for a training run
29#[derive(Debug, Clone, Default, Serialize, Deserialize)]
30pub struct ConfigParams {
31    /// LoRA rank
32    pub lora_rank: Option<u32>,
33    /// Quantization bits (4, 8, 16, 32)
34    pub quant_bits: Option<u8>,
35    /// Temperature for distillation
36    pub temperature: Option<f32>,
37    /// Alpha for distillation
38    pub alpha: Option<f32>,
39    /// Batch size
40    pub batch_size: Option<usize>,
41    /// Learning rate
42    pub learning_rate: Option<f64>,
43}
44
45/// Cost model for different GPU types
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct CostModel {
48    /// GPU type name
49    pub gpu_type: String,
50    /// Cost per hour in USD
51    pub cost_per_hour: f64,
52    /// Memory in GB
53    pub memory_gb: f64,
54    /// Relative performance factor (vs baseline)
55    pub performance_factor: f64,
56}
57
58impl CostModel {
59    /// Create an A100 80GB cost model
60    pub fn a100_80gb() -> Self {
61        Self {
62            gpu_type: "A100-80GB".to_string(),
63            cost_per_hour: 2.21,
64            memory_gb: 80.0,
65            performance_factor: 1.0,
66        }
67    }
68
69    /// Create an A100 40GB cost model
70    pub fn a100_40gb() -> Self {
71        Self {
72            gpu_type: "A100-40GB".to_string(),
73            cost_per_hour: 1.10,
74            memory_gb: 40.0,
75            performance_factor: 0.9,
76        }
77    }
78
79    /// Create a V100 cost model
80    pub fn v100() -> Self {
81        Self {
82            gpu_type: "V100".to_string(),
83            cost_per_hour: 0.90,
84            memory_gb: 16.0,
85            performance_factor: 0.5,
86        }
87    }
88
89    /// Create a T4 cost model
90    pub fn t4() -> Self {
91        Self {
92            gpu_type: "T4".to_string(),
93            cost_per_hour: 0.35,
94            memory_gb: 16.0,
95            performance_factor: 0.25,
96        }
97    }
98
99    /// Create a custom cost model
100    pub fn custom(gpu_type: &str, cost_per_hour: f64, memory_gb: f64) -> Self {
101        Self {
102            gpu_type: gpu_type.to_string(),
103            cost_per_hour,
104            memory_gb,
105            performance_factor: 1.0,
106        }
107    }
108}
109
110/// Constraints for recommendations
111#[derive(Debug, Clone, Default, Serialize, Deserialize)]
112pub struct Constraints {
113    /// Maximum GPU-hours
114    pub max_gpu_hours: Option<f64>,
115    /// Maximum cost in USD
116    pub max_cost_usd: Option<f64>,
117    /// Minimum accuracy required
118    pub min_accuracy: Option<f64>,
119    /// Maximum memory in GB
120    pub max_memory_gb: Option<f64>,
121    /// Maximum loss
122    pub max_loss: Option<f64>,
123}
124
125impl Constraints {
126    /// Create new empty constraints
127    pub fn new() -> Self {
128        Self::default()
129    }
130
131    /// Set maximum GPU-hours
132    pub fn with_max_gpu_hours(mut self, hours: f64) -> Self {
133        self.max_gpu_hours = Some(hours);
134        self
135    }
136
137    /// Set maximum cost
138    pub fn with_max_cost(mut self, cost: f64) -> Self {
139        self.max_cost_usd = Some(cost);
140        self
141    }
142
143    /// Set minimum accuracy
144    pub fn with_min_accuracy(mut self, accuracy: f64) -> Self {
145        self.min_accuracy = Some(accuracy);
146        self
147    }
148
149    /// Set maximum memory
150    pub fn with_max_memory(mut self, memory_gb: f64) -> Self {
151        self.max_memory_gb = Some(memory_gb);
152        self
153    }
154
155    /// Check if a point satisfies all constraints
156    pub fn is_satisfied(&self, point: &CostPerformancePoint) -> bool {
157        if let Some(max_hours) = self.max_gpu_hours {
158            if point.gpu_hours > max_hours {
159                return false;
160            }
161        }
162        if let Some(max_cost) = self.max_cost_usd {
163            if point.cost_usd > max_cost {
164                return false;
165            }
166        }
167        if let Some(min_acc) = self.min_accuracy {
168            if point.accuracy < min_acc {
169                return false;
170            }
171        }
172        if let Some(max_mem) = self.max_memory_gb {
173            if point.memory_gb > max_mem {
174                return false;
175            }
176        }
177        if let Some(max_loss) = self.max_loss {
178            if point.loss > max_loss {
179                return false;
180            }
181        }
182        true
183    }
184}
185
186/// Cost-Performance Analysis result
187#[derive(Debug, Clone, Serialize, Deserialize)]
188pub struct CostPerformanceAnalysis {
189    /// All data points
190    pub points: Vec<CostPerformancePoint>,
191    /// Pareto-optimal points
192    pub pareto_frontier: Vec<CostPerformancePoint>,
193    /// Best by accuracy
194    pub best_accuracy: Option<CostPerformancePoint>,
195    /// Best by cost efficiency (accuracy per dollar)
196    pub best_efficiency: Option<CostPerformancePoint>,
197    /// Best by cost (lowest)
198    pub lowest_cost: Option<CostPerformancePoint>,
199}
200
201impl CostPerformanceAnalysis {
202    /// Compute Pareto frontier from data points
203    pub fn from_points(mut points: Vec<CostPerformancePoint>) -> Self {
204        // Compute Pareto frontier (minimize cost, maximize accuracy)
205        let pareto = compute_pareto_frontier(&points);
206
207        // Mark points that are Pareto optimal
208        for point in &mut points {
209            point.is_pareto_optimal = pareto.iter().any(|p| {
210                (p.cost_usd - point.cost_usd).abs() < 1e-6
211                    && (p.accuracy - point.accuracy).abs() < 1e-6
212            });
213        }
214
215        let pareto_frontier = pareto;
216
217        let best_accuracy = points
218            .iter()
219            .max_by(|a, b| {
220                a.accuracy
221                    .partial_cmp(&b.accuracy)
222                    .unwrap_or(std::cmp::Ordering::Equal)
223            })
224            .cloned();
225
226        let best_efficiency = points
227            .iter()
228            .filter(|p| p.cost_usd > 0.0)
229            .max_by(|a, b| {
230                let eff_a = a.accuracy / a.cost_usd;
231                let eff_b = b.accuracy / b.cost_usd;
232                eff_a
233                    .partial_cmp(&eff_b)
234                    .unwrap_or(std::cmp::Ordering::Equal)
235            })
236            .cloned();
237
238        let lowest_cost = points
239            .iter()
240            .min_by(|a, b| {
241                a.cost_usd
242                    .partial_cmp(&b.cost_usd)
243                    .unwrap_or(std::cmp::Ordering::Equal)
244            })
245            .cloned();
246
247        Self {
248            points,
249            pareto_frontier,
250            best_accuracy,
251            best_efficiency,
252            lowest_cost,
253        }
254    }
255
256    /// Get recommendations based on constraints
257    pub fn recommend(&self, constraints: &Constraints) -> Vec<Recommendation> {
258        let mut recommendations = Vec::new();
259
260        // Filter points that satisfy constraints
261        let valid_points: Vec<_> = self
262            .points
263            .iter()
264            .filter(|p| constraints.is_satisfied(p))
265            .collect();
266
267        if valid_points.is_empty() {
268            return recommendations;
269        }
270
271        // Best accuracy within constraints
272        if let Some(best_acc) = valid_points.iter().max_by(|a, b| {
273            a.accuracy
274                .partial_cmp(&b.accuracy)
275                .unwrap_or(std::cmp::Ordering::Equal)
276        }) {
277            recommendations.push(Recommendation {
278                reason: "Best accuracy within constraints".to_string(),
279                point: (*best_acc).clone(),
280            });
281        }
282
283        // Best efficiency within constraints
284        if let Some(best_eff) = valid_points
285            .iter()
286            .filter(|p| p.cost_usd > 0.0)
287            .max_by(|a, b| {
288                let eff_a = a.accuracy / a.cost_usd;
289                let eff_b = b.accuracy / b.cost_usd;
290                eff_a
291                    .partial_cmp(&eff_b)
292                    .unwrap_or(std::cmp::Ordering::Equal)
293            })
294        {
295            if recommendations
296                .iter()
297                .all(|r| r.point.name != best_eff.name)
298            {
299                recommendations.push(Recommendation {
300                    reason: "Best accuracy per dollar within constraints".to_string(),
301                    point: (*best_eff).clone(),
302                });
303            }
304        }
305
306        // Pareto-optimal within constraints
307        for point in &self.pareto_frontier {
308            if constraints.is_satisfied(point)
309                && recommendations.iter().all(|r| r.point.name != point.name)
310            {
311                recommendations.push(Recommendation {
312                    reason: "Pareto-optimal configuration".to_string(),
313                    point: point.clone(),
314                });
315            }
316        }
317
318        recommendations
319    }
320
321    /// Generate ASCII table for display
322    pub fn to_table(&self) -> String {
323        let mut table = String::new();
324        table.push_str("Cost-Performance Analysis\n");
325        table.push_str(
326            "┌────────────────────────┬───────────┬───────────┬──────────┬─────────┬─────────┐\n",
327        );
328        table.push_str(
329            "│ Configuration          │ GPU Hours │ Cost (USD)│ Accuracy │   Loss  │ Pareto? │\n",
330        );
331        table.push_str(
332            "├────────────────────────┼───────────┼───────────┼──────────┼─────────┼─────────┤\n",
333        );
334
335        for point in &self.points {
336            let pareto_mark = if point.is_pareto_optimal { "★" } else { " " };
337            table.push_str(&format!(
338                "│ {:22} │ {:>9.2} │ {:>9.2} │ {:>7.1}% │ {:>7.4} │    {}    │\n",
339                truncate(&point.name, 22),
340                point.gpu_hours,
341                point.cost_usd,
342                point.accuracy * 100.0,
343                point.loss,
344                pareto_mark
345            ));
346        }
347
348        table.push_str(
349            "└────────────────────────┴───────────┴───────────┴──────────┴─────────┴─────────┘\n",
350        );
351        table.push_str(
352            "\n★ = Pareto-optimal (no configuration is both cheaper AND more accurate)\n",
353        );
354
355        table
356    }
357}
358
359/// A recommendation with reasoning
360#[derive(Debug, Clone, Serialize, Deserialize)]
361pub struct Recommendation {
362    /// Reason for the recommendation
363    pub reason: String,
364    /// The recommended configuration
365    pub point: CostPerformancePoint,
366}
367
368/// Compute Pareto frontier for minimizing cost and maximizing accuracy
369fn compute_pareto_frontier(points: &[CostPerformancePoint]) -> Vec<CostPerformancePoint> {
370    let mut frontier = Vec::new();
371
372    for point in points {
373        // Check if this point is dominated by any other
374        let is_dominated = points.iter().any(|other| {
375            // Other dominates point if:
376            // - Other has lower or equal cost AND higher or equal accuracy
377            // - AND at least one is strictly better
378            other.cost_usd <= point.cost_usd
379                && other.accuracy >= point.accuracy
380                && (other.cost_usd < point.cost_usd || other.accuracy > point.accuracy)
381        });
382
383        if !is_dominated {
384            frontier.push(point.clone());
385        }
386    }
387
388    // Sort by cost
389    frontier.sort_by(|a, b| {
390        a.cost_usd
391            .partial_cmp(&b.cost_usd)
392            .unwrap_or(std::cmp::Ordering::Equal)
393    });
394    frontier
395}
396
/// Fit `s` into exactly `max_len` characters for fixed-width table output.
///
/// Short strings are right-padded with spaces. Longer strings are cut and end
/// in `"..."`. If `max_len` leaves no room for any text beside the ellipsis,
/// the result is `max_len` dots. Lengths are counted in `char`s, so non-ASCII
/// names never cause a slice at a non-character boundary. Wide glyphs (e.g. CJK)
/// may still render wider than one column each.
fn truncate(s: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";

    if s.chars().count() <= max_len {
        return format!("{s:max_len$}");
    }
    if max_len <= ELLIPSIS.len() {
        return ".".repeat(max_len);
    }

    let kept: String = s.chars().take(max_len - ELLIPSIS.len()).collect();
    format!("{kept}{ELLIPSIS}")
}

406/// Generate sample data points for testing/demo
407pub fn generate_sample_points(cost_model: &CostModel) -> Vec<CostPerformancePoint> {
408    // Sample configurations representing different trade-offs
409    vec![
410        // Full fine-tuning (expensive, high accuracy)
411        CostPerformancePoint {
412            name: "Full Fine-Tuning (7B)".to_string(),
413            gpu_hours: 120.0,
414            cost_usd: 120.0 * cost_model.cost_per_hour,
415            accuracy: 0.92,
416            loss: 0.25,
417            memory_gb: 56.0,
418            is_pareto_optimal: false,
419            config: ConfigParams {
420                lora_rank: None,
421                quant_bits: Some(16),
422                batch_size: Some(8),
423                learning_rate: Some(5e-5),
424                ..Default::default()
425            },
426        },
427        // LoRA (moderate cost, good accuracy)
428        CostPerformancePoint {
429            name: "LoRA r=64".to_string(),
430            gpu_hours: 24.0,
431            cost_usd: 24.0 * cost_model.cost_per_hour,
432            accuracy: 0.89,
433            loss: 0.30,
434            memory_gb: 28.0,
435            is_pareto_optimal: false,
436            config: ConfigParams {
437                lora_rank: Some(64),
438                quant_bits: Some(16),
439                batch_size: Some(16),
440                learning_rate: Some(2e-4),
441                ..Default::default()
442            },
443        },
444        // LoRA r=32 (cheaper, slightly lower accuracy)
445        CostPerformancePoint {
446            name: "LoRA r=32".to_string(),
447            gpu_hours: 18.0,
448            cost_usd: 18.0 * cost_model.cost_per_hour,
449            accuracy: 0.87,
450            loss: 0.33,
451            memory_gb: 24.0,
452            is_pareto_optimal: false,
453            config: ConfigParams {
454                lora_rank: Some(32),
455                quant_bits: Some(16),
456                batch_size: Some(16),
457                learning_rate: Some(2e-4),
458                ..Default::default()
459            },
460        },
461        // QLoRA 4-bit (low cost, good accuracy)
462        CostPerformancePoint {
463            name: "QLoRA 4-bit r=64".to_string(),
464            gpu_hours: 20.0,
465            cost_usd: 20.0 * cost_model.cost_per_hour,
466            accuracy: 0.86,
467            loss: 0.35,
468            memory_gb: 12.0,
469            is_pareto_optimal: false,
470            config: ConfigParams {
471                lora_rank: Some(64),
472                quant_bits: Some(4),
473                batch_size: Some(32),
474                learning_rate: Some(3e-4),
475                ..Default::default()
476            },
477        },
478        // Distillation (moderate cost, moderate accuracy)
479        CostPerformancePoint {
480            name: "Distillation T=4".to_string(),
481            gpu_hours: 36.0,
482            cost_usd: 36.0 * cost_model.cost_per_hour,
483            accuracy: 0.84,
484            loss: 0.38,
485            memory_gb: 32.0,
486            is_pareto_optimal: false,
487            config: ConfigParams {
488                temperature: Some(4.0),
489                alpha: Some(0.7),
490                batch_size: Some(16),
491                learning_rate: Some(1e-4),
492                ..Default::default()
493            },
494        },
495        // LoRA + Distillation (balanced)
496        CostPerformancePoint {
497            name: "LoRA + Distillation".to_string(),
498            gpu_hours: 32.0,
499            cost_usd: 32.0 * cost_model.cost_per_hour,
500            accuracy: 0.88,
501            loss: 0.31,
502            memory_gb: 26.0,
503            is_pareto_optimal: false,
504            config: ConfigParams {
505                lora_rank: Some(32),
506                temperature: Some(4.0),
507                alpha: Some(0.5),
508                batch_size: Some(16),
509                learning_rate: Some(2e-4),
510                ..Default::default()
511            },
512        },
513        // QLoRA 8-bit (moderate everything)
514        CostPerformancePoint {
515            name: "QLoRA 8-bit r=32".to_string(),
516            gpu_hours: 16.0,
517            cost_usd: 16.0 * cost_model.cost_per_hour,
518            accuracy: 0.85,
519            loss: 0.36,
520            memory_gb: 16.0,
521            is_pareto_optimal: false,
522            config: ConfigParams {
523                lora_rank: Some(32),
524                quant_bits: Some(8),
525                batch_size: Some(32),
526                learning_rate: Some(2e-4),
527                ..Default::default()
528            },
529        },
530        // Minimal LoRA (very cheap, lower accuracy)
531        CostPerformancePoint {
532            name: "LoRA r=8".to_string(),
533            gpu_hours: 8.0,
534            cost_usd: 8.0 * cost_model.cost_per_hour,
535            accuracy: 0.81,
536            loss: 0.42,
537            memory_gb: 18.0,
538            is_pareto_optimal: false,
539            config: ConfigParams {
540                lora_rank: Some(8),
541                quant_bits: Some(16),
542                batch_size: Some(32),
543                learning_rate: Some(5e-4),
544                ..Default::default()
545            },
546        },
547    ]
548}
549
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a point with the given cost/accuracy/loss and fixed other fields.
    fn point(name: &str, cost_usd: f64, accuracy: f64, loss: f64) -> CostPerformancePoint {
        CostPerformancePoint {
            name: name.to_string(),
            gpu_hours: cost_usd,
            cost_usd,
            accuracy,
            loss,
            memory_gb: 16.0,
            is_pareto_optimal: false,
            config: ConfigParams::default(),
        }
    }

    #[test]
    fn test_pareto_frontier() {
        let points = vec![
            CostPerformancePoint {
                name: "A".to_string(),
                gpu_hours: 10.0,
                cost_usd: 10.0,
                accuracy: 0.8,
                loss: 0.3,
                memory_gb: 16.0,
                is_pareto_optimal: false,
                config: Default::default(),
            },
            CostPerformancePoint {
                name: "B".to_string(),
                gpu_hours: 20.0,
                cost_usd: 20.0,
                accuracy: 0.9,
                loss: 0.2,
                memory_gb: 24.0,
                is_pareto_optimal: false,
                config: Default::default(),
            },
            CostPerformancePoint {
                name: "C".to_string(), // Dominated by B
                gpu_hours: 25.0,
                cost_usd: 25.0,
                accuracy: 0.85,
                loss: 0.25,
                memory_gb: 24.0,
                is_pareto_optimal: false,
                config: Default::default(),
            },
        ];

        let frontier = compute_pareto_frontier(&points);
        assert_eq!(frontier.len(), 2); // A and B are Pareto optimal
        assert!(frontier.iter().any(|p| p.name == "A"));
        assert!(frontier.iter().any(|p| p.name == "B"));
        assert!(!frontier.iter().any(|p| p.name == "C"));
    }

    #[test]
    fn test_constraints() {
        let constraints = Constraints::new()
            .with_max_cost(50.0)
            .with_min_accuracy(0.85);

        let point_good = CostPerformancePoint {
            name: "Good".to_string(),
            gpu_hours: 20.0,
            cost_usd: 40.0,
            accuracy: 0.90,
            loss: 0.25,
            memory_gb: 16.0,
            is_pareto_optimal: false,
            config: Default::default(),
        };

        let point_expensive = CostPerformancePoint {
            name: "Expensive".to_string(),
            gpu_hours: 30.0,
            cost_usd: 60.0,
            accuracy: 0.95,
            loss: 0.20,
            memory_gb: 16.0,
            is_pareto_optimal: false,
            config: Default::default(),
        };

        let point_low_acc = CostPerformancePoint {
            name: "LowAcc".to_string(),
            gpu_hours: 10.0,
            cost_usd: 20.0,
            accuracy: 0.80,
            loss: 0.35,
            memory_gb: 16.0,
            is_pareto_optimal: false,
            config: Default::default(),
        };

        assert!(constraints.is_satisfied(&point_good));
        assert!(!constraints.is_satisfied(&point_expensive)); // Too expensive
        assert!(!constraints.is_satisfied(&point_low_acc)); // Too low accuracy
    }

    #[test]
    fn test_max_loss_constraint() {
        let constraints = Constraints {
            max_loss: Some(0.3),
            ..Constraints::default()
        };
        assert!(constraints.is_satisfied(&point("ok", 10.0, 0.9, 0.25)));
        assert!(!constraints.is_satisfied(&point("lossy", 10.0, 0.9, 0.35)));
    }

    #[test]
    fn test_analysis_recommendations() {
        let cost_model = CostModel::a100_80gb();
        let points = generate_sample_points(&cost_model);
        let analysis = CostPerformanceAnalysis::from_points(points);

        assert!(!analysis.pareto_frontier.is_empty());
        assert!(analysis.best_accuracy.is_some());
        assert!(analysis.best_efficiency.is_some());

        let constraints = Constraints::new().with_max_cost(50.0);
        let recommendations = analysis.recommend(&constraints);
        assert!(!recommendations.is_empty());
    }

    #[test]
    fn test_sample_frontier_order_and_flags() {
        let analysis =
            CostPerformanceAnalysis::from_points(generate_sample_points(&CostModel::a100_80gb()));
        let names: Vec<&str> = analysis
            .pareto_frontier
            .iter()
            .map(|p| p.name.as_str())
            .collect();
        assert_eq!(
            names,
            [
                "LoRA r=8",
                "QLoRA 8-bit r=32",
                "LoRA r=32",
                "LoRA r=64",
                "Full Fine-Tuning (7B)",
            ]
        );

        let flagged: Vec<&str> = analysis
            .points
            .iter()
            .filter(|p| p.is_pareto_optimal)
            .map(|p| p.name.as_str())
            .collect();
        assert_eq!(flagged.len(), names.len());
        assert!(flagged.iter().all(|n| names.contains(n)));
    }

    #[test]
    fn test_empty_analysis() {
        let analysis = CostPerformanceAnalysis::from_points(Vec::new());
        assert!(analysis.points.is_empty());
        assert!(analysis.pareto_frontier.is_empty());
        assert!(analysis.best_accuracy.is_none());
        assert!(analysis.best_efficiency.is_none());
        assert!(analysis.lowest_cost.is_none());
        assert!(analysis.recommend(&Constraints::new()).is_empty());
    }

    #[test]
    fn test_truncate_pads_and_shortens() {
        assert_eq!(truncate("short", 10), "short     ");
        assert_eq!(truncate("a very long configuration name", 10), "a very ...");
    }

    #[test]
    fn test_table_lists_every_point() {
        let analysis =
            CostPerformanceAnalysis::from_points(generate_sample_points(&CostModel::t4()));
        let table = analysis.to_table();
        for p in &analysis.points {
            assert!(table.contains(p.name.as_str()));
        }
    }

    #[test]
    fn test_cost_models() {
        let a100 = CostModel::a100_80gb();
        assert_eq!(a100.gpu_type, "A100-80GB");
        assert!(a100.cost_per_hour > 0.0);

        let v100 = CostModel::v100();
        assert!(v100.cost_per_hour < a100.cost_per_hour);
    }
}