1use serde::{Deserialize, Serialize};
6
/// One candidate configuration together with its cost and quality metrics.
///
/// The cost/accuracy pair is what the Pareto-frontier computation in this file compares;
/// the remaining metrics are only checked against [`Constraints`] or printed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostPerformancePoint {
    /// Display label. `recommend` de-duplicates recommendations by this name and `to_table`
    /// prints it in a 22-character column.
    pub name: String,
    /// GPU time in hours; checked against `Constraints::max_gpu_hours`.
    pub gpu_hours: f64,
    /// Cost in US dollars. Lower is better; one of the two Pareto axes.
    pub cost_usd: f64,
    /// Accuracy as a fraction (`to_table` multiplies it by 100 to print a percentage).
    /// Higher is better; the other Pareto axis.
    pub accuracy: f64,
    /// Loss value; checked against `Constraints::max_loss` and printed in the table.
    pub loss: f64,
    /// Memory in gigabytes; checked against `Constraints::max_memory_gb`.
    /// NOTE(review): presumably peak GPU memory of the run — confirm with the data producer.
    pub memory_gb: f64,
    /// Whether the point is on the cost/accuracy Pareto frontier. Overwritten by
    /// `CostPerformanceAnalysis::from_points`; the sample and test data start it at `false`.
    pub is_pareto_optimal: bool,
    /// Hyper-parameters that describe how the configuration was produced.
    pub config: ConfigParams,
}
27
/// Optional hyper-parameters attached to a [`CostPerformancePoint`].
///
/// Every field defaults to `None`; the sample data only fills in the ones that apply to a given
/// technique. Nothing in this file reads these values after construction.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ConfigParams {
    /// LoRA rank (e.g. `Some(64)` for the "LoRA r=64" sample); `None` when LoRA is not used.
    pub lora_rank: Option<u32>,
    /// Quantisation bit width (the samples use 4, 8 and 16).
    pub quant_bits: Option<u8>,
    /// Distillation temperature (the "Distillation T=4" sample sets `Some(4.0)`).
    pub temperature: Option<f32>,
    /// Set together with `temperature` in the distillation samples.
    /// NOTE(review): presumably the distillation loss mixing weight — confirm.
    pub alpha: Option<f32>,
    /// Training batch size.
    pub batch_size: Option<usize>,
    /// Learning rate.
    pub learning_rate: Option<f64>,
}
44
/// Pricing and capacity description of a GPU type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostModel {
    /// Label of the GPU, e.g. `"A100-80GB"`.
    pub gpu_type: String,
    /// Price per GPU hour; `generate_sample_points` multiplies it by GPU hours to obtain
    /// `cost_usd`, so the unit is US dollars per hour.
    pub cost_per_hour: f64,
    /// Memory of the GPU in gigabytes.
    pub memory_gb: f64,
    /// Relative speed factor; the presets use `1.0` for the A100-80GB and smaller values for the
    /// other GPUs. Not read anywhere in this file.
    pub performance_factor: f64,
}
57
58impl CostModel {
59 pub fn a100_80gb() -> Self {
61 Self {
62 gpu_type: "A100-80GB".to_string(),
63 cost_per_hour: 2.21,
64 memory_gb: 80.0,
65 performance_factor: 1.0,
66 }
67 }
68
69 pub fn a100_40gb() -> Self {
71 Self {
72 gpu_type: "A100-40GB".to_string(),
73 cost_per_hour: 1.10,
74 memory_gb: 40.0,
75 performance_factor: 0.9,
76 }
77 }
78
79 pub fn v100() -> Self {
81 Self {
82 gpu_type: "V100".to_string(),
83 cost_per_hour: 0.90,
84 memory_gb: 16.0,
85 performance_factor: 0.5,
86 }
87 }
88
89 pub fn t4() -> Self {
91 Self {
92 gpu_type: "T4".to_string(),
93 cost_per_hour: 0.35,
94 memory_gb: 16.0,
95 performance_factor: 0.25,
96 }
97 }
98
99 pub fn custom(gpu_type: &str, cost_per_hour: f64, memory_gb: f64) -> Self {
101 Self {
102 gpu_type: gpu_type.to_string(),
103 cost_per_hour,
104 memory_gb,
105 performance_factor: 1.0,
106 }
107 }
108}
109
/// Optional bounds used to filter [`CostPerformancePoint`]s.
///
/// Each field is independent; `None` (the default) leaves that metric unconstrained.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Constraints {
    /// Upper bound on `gpu_hours`.
    pub max_gpu_hours: Option<f64>,
    /// Upper bound on `cost_usd`.
    pub max_cost_usd: Option<f64>,
    /// Lower bound on `accuracy`.
    pub min_accuracy: Option<f64>,
    /// Upper bound on `memory_gb`.
    pub max_memory_gb: Option<f64>,
    /// Upper bound on `loss`.
    pub max_loss: Option<f64>,
}
124
125impl Constraints {
126 pub fn new() -> Self {
128 Self::default()
129 }
130
131 pub fn with_max_gpu_hours(mut self, hours: f64) -> Self {
133 self.max_gpu_hours = Some(hours);
134 self
135 }
136
137 pub fn with_max_cost(mut self, cost: f64) -> Self {
139 self.max_cost_usd = Some(cost);
140 self
141 }
142
143 pub fn with_min_accuracy(mut self, accuracy: f64) -> Self {
145 self.min_accuracy = Some(accuracy);
146 self
147 }
148
149 pub fn with_max_memory(mut self, memory_gb: f64) -> Self {
151 self.max_memory_gb = Some(memory_gb);
152 self
153 }
154
155 pub fn is_satisfied(&self, point: &CostPerformancePoint) -> bool {
157 if let Some(max_hours) = self.max_gpu_hours {
158 if point.gpu_hours > max_hours {
159 return false;
160 }
161 }
162 if let Some(max_cost) = self.max_cost_usd {
163 if point.cost_usd > max_cost {
164 return false;
165 }
166 }
167 if let Some(min_acc) = self.min_accuracy {
168 if point.accuracy < min_acc {
169 return false;
170 }
171 }
172 if let Some(max_mem) = self.max_memory_gb {
173 if point.memory_gb > max_mem {
174 return false;
175 }
176 }
177 if let Some(max_loss) = self.max_loss {
178 if point.loss > max_loss {
179 return false;
180 }
181 }
182 true
183 }
184}
185
/// Result of analysing a set of [`CostPerformancePoint`]s; built by
/// `CostPerformanceAnalysis::from_points`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostPerformanceAnalysis {
    /// Every analysed point, in input order, with `is_pareto_optimal` filled in.
    pub points: Vec<CostPerformancePoint>,
    /// Copies of the points that no other point dominates on cost and accuracy, sorted by
    /// ascending cost.
    pub pareto_frontier: Vec<CostPerformancePoint>,
    /// Point with the highest accuracy; `None` when there are no points.
    pub best_accuracy: Option<CostPerformancePoint>,
    /// Point with the highest accuracy per dollar among points with a positive cost.
    pub best_efficiency: Option<CostPerformancePoint>,
    /// Point with the lowest cost; `None` when there are no points.
    pub lowest_cost: Option<CostPerformancePoint>,
}
200
201impl CostPerformanceAnalysis {
202 pub fn from_points(mut points: Vec<CostPerformancePoint>) -> Self {
204 let pareto = compute_pareto_frontier(&points);
206
207 for point in &mut points {
209 point.is_pareto_optimal = pareto.iter().any(|p| {
210 (p.cost_usd - point.cost_usd).abs() < 1e-6
211 && (p.accuracy - point.accuracy).abs() < 1e-6
212 });
213 }
214
215 let pareto_frontier = pareto;
216
217 let best_accuracy = points
218 .iter()
219 .max_by(|a, b| {
220 a.accuracy
221 .partial_cmp(&b.accuracy)
222 .unwrap_or(std::cmp::Ordering::Equal)
223 })
224 .cloned();
225
226 let best_efficiency = points
227 .iter()
228 .filter(|p| p.cost_usd > 0.0)
229 .max_by(|a, b| {
230 let eff_a = a.accuracy / a.cost_usd;
231 let eff_b = b.accuracy / b.cost_usd;
232 eff_a
233 .partial_cmp(&eff_b)
234 .unwrap_or(std::cmp::Ordering::Equal)
235 })
236 .cloned();
237
238 let lowest_cost = points
239 .iter()
240 .min_by(|a, b| {
241 a.cost_usd
242 .partial_cmp(&b.cost_usd)
243 .unwrap_or(std::cmp::Ordering::Equal)
244 })
245 .cloned();
246
247 Self {
248 points,
249 pareto_frontier,
250 best_accuracy,
251 best_efficiency,
252 lowest_cost,
253 }
254 }
255
256 pub fn recommend(&self, constraints: &Constraints) -> Vec<Recommendation> {
258 let mut recommendations = Vec::new();
259
260 let valid_points: Vec<_> = self
262 .points
263 .iter()
264 .filter(|p| constraints.is_satisfied(p))
265 .collect();
266
267 if valid_points.is_empty() {
268 return recommendations;
269 }
270
271 if let Some(best_acc) = valid_points.iter().max_by(|a, b| {
273 a.accuracy
274 .partial_cmp(&b.accuracy)
275 .unwrap_or(std::cmp::Ordering::Equal)
276 }) {
277 recommendations.push(Recommendation {
278 reason: "Best accuracy within constraints".to_string(),
279 point: (*best_acc).clone(),
280 });
281 }
282
283 if let Some(best_eff) = valid_points
285 .iter()
286 .filter(|p| p.cost_usd > 0.0)
287 .max_by(|a, b| {
288 let eff_a = a.accuracy / a.cost_usd;
289 let eff_b = b.accuracy / b.cost_usd;
290 eff_a
291 .partial_cmp(&eff_b)
292 .unwrap_or(std::cmp::Ordering::Equal)
293 })
294 {
295 if recommendations
296 .iter()
297 .all(|r| r.point.name != best_eff.name)
298 {
299 recommendations.push(Recommendation {
300 reason: "Best accuracy per dollar within constraints".to_string(),
301 point: (*best_eff).clone(),
302 });
303 }
304 }
305
306 for point in &self.pareto_frontier {
308 if constraints.is_satisfied(point)
309 && recommendations.iter().all(|r| r.point.name != point.name)
310 {
311 recommendations.push(Recommendation {
312 reason: "Pareto-optimal configuration".to_string(),
313 point: point.clone(),
314 });
315 }
316 }
317
318 recommendations
319 }
320
321 pub fn to_table(&self) -> String {
323 let mut table = String::new();
324 table.push_str("Cost-Performance Analysis\n");
325 table.push_str(
326 "┌────────────────────────┬───────────┬───────────┬──────────┬─────────┬─────────┐\n",
327 );
328 table.push_str(
329 "│ Configuration │ GPU Hours │ Cost (USD)│ Accuracy │ Loss │ Pareto? │\n",
330 );
331 table.push_str(
332 "├────────────────────────┼───────────┼───────────┼──────────┼─────────┼─────────┤\n",
333 );
334
335 for point in &self.points {
336 let pareto_mark = if point.is_pareto_optimal { "★" } else { " " };
337 table.push_str(&format!(
338 "│ {:22} │ {:>9.2} │ {:>9.2} │ {:>7.1}% │ {:>7.4} │ {} │\n",
339 truncate(&point.name, 22),
340 point.gpu_hours,
341 point.cost_usd,
342 point.accuracy * 100.0,
343 point.loss,
344 pareto_mark
345 ));
346 }
347
348 table.push_str(
349 "└────────────────────────┴───────────┴───────────┴──────────┴─────────┴─────────┘\n",
350 );
351 table.push_str(
352 "\n★ = Pareto-optimal (no configuration is both cheaper AND more accurate)\n",
353 );
354
355 table
356 }
357}
358
/// A suggested configuration together with the reason it was suggested; produced by
/// `CostPerformanceAnalysis::recommend`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Recommendation {
    /// Human-readable explanation, e.g. `"Best accuracy within constraints"`.
    pub reason: String,
    /// Copy of the recommended point.
    pub point: CostPerformancePoint,
}
367
368fn compute_pareto_frontier(points: &[CostPerformancePoint]) -> Vec<CostPerformancePoint> {
370 let mut frontier = Vec::new();
371
372 for point in points {
373 let is_dominated = points.iter().any(|other| {
375 other.cost_usd <= point.cost_usd
379 && other.accuracy >= point.accuracy
380 && (other.cost_usd < point.cost_usd || other.accuracy > point.accuracy)
381 });
382
383 if !is_dominated {
384 frontier.push(point.clone());
385 }
386 }
387
388 frontier.sort_by(|a, b| {
390 a.cost_usd
391 .partial_cmp(&b.cost_usd)
392 .unwrap_or(std::cmp::Ordering::Equal)
393 });
394 frontier
395}
396
/// Fits `s` into exactly `max_len` characters for a fixed-width table cell.
///
/// Strings of at most `max_len` characters are right-padded with spaces; longer ones are cut and
/// end in `...`. Lengths are counted in `char`s rather than bytes, so multi-byte UTF-8 text is
/// never sliced in the middle of a character (which would panic) and is padded to the same width
/// the formatter uses. When `max_len` is below 3 there is no room for the ellipsis and the string
/// is simply cut to `max_len` characters.
fn truncate(s: &str, max_len: usize) -> String {
    if s.chars().count() <= max_len {
        return format!("{s:max_len$}");
    }

    match max_len.checked_sub(3) {
        Some(keep) => {
            let head: String = s.chars().take(keep).collect();
            format!("{head}...")
        }
        // Fewer than three columns: a plain cut still honours the width.
        None => s.chars().take(max_len).collect(),
    }
}
405
406pub fn generate_sample_points(cost_model: &CostModel) -> Vec<CostPerformancePoint> {
408 vec![
410 CostPerformancePoint {
412 name: "Full Fine-Tuning (7B)".to_string(),
413 gpu_hours: 120.0,
414 cost_usd: 120.0 * cost_model.cost_per_hour,
415 accuracy: 0.92,
416 loss: 0.25,
417 memory_gb: 56.0,
418 is_pareto_optimal: false,
419 config: ConfigParams {
420 lora_rank: None,
421 quant_bits: Some(16),
422 batch_size: Some(8),
423 learning_rate: Some(5e-5),
424 ..Default::default()
425 },
426 },
427 CostPerformancePoint {
429 name: "LoRA r=64".to_string(),
430 gpu_hours: 24.0,
431 cost_usd: 24.0 * cost_model.cost_per_hour,
432 accuracy: 0.89,
433 loss: 0.30,
434 memory_gb: 28.0,
435 is_pareto_optimal: false,
436 config: ConfigParams {
437 lora_rank: Some(64),
438 quant_bits: Some(16),
439 batch_size: Some(16),
440 learning_rate: Some(2e-4),
441 ..Default::default()
442 },
443 },
444 CostPerformancePoint {
446 name: "LoRA r=32".to_string(),
447 gpu_hours: 18.0,
448 cost_usd: 18.0 * cost_model.cost_per_hour,
449 accuracy: 0.87,
450 loss: 0.33,
451 memory_gb: 24.0,
452 is_pareto_optimal: false,
453 config: ConfigParams {
454 lora_rank: Some(32),
455 quant_bits: Some(16),
456 batch_size: Some(16),
457 learning_rate: Some(2e-4),
458 ..Default::default()
459 },
460 },
461 CostPerformancePoint {
463 name: "QLoRA 4-bit r=64".to_string(),
464 gpu_hours: 20.0,
465 cost_usd: 20.0 * cost_model.cost_per_hour,
466 accuracy: 0.86,
467 loss: 0.35,
468 memory_gb: 12.0,
469 is_pareto_optimal: false,
470 config: ConfigParams {
471 lora_rank: Some(64),
472 quant_bits: Some(4),
473 batch_size: Some(32),
474 learning_rate: Some(3e-4),
475 ..Default::default()
476 },
477 },
478 CostPerformancePoint {
480 name: "Distillation T=4".to_string(),
481 gpu_hours: 36.0,
482 cost_usd: 36.0 * cost_model.cost_per_hour,
483 accuracy: 0.84,
484 loss: 0.38,
485 memory_gb: 32.0,
486 is_pareto_optimal: false,
487 config: ConfigParams {
488 temperature: Some(4.0),
489 alpha: Some(0.7),
490 batch_size: Some(16),
491 learning_rate: Some(1e-4),
492 ..Default::default()
493 },
494 },
495 CostPerformancePoint {
497 name: "LoRA + Distillation".to_string(),
498 gpu_hours: 32.0,
499 cost_usd: 32.0 * cost_model.cost_per_hour,
500 accuracy: 0.88,
501 loss: 0.31,
502 memory_gb: 26.0,
503 is_pareto_optimal: false,
504 config: ConfigParams {
505 lora_rank: Some(32),
506 temperature: Some(4.0),
507 alpha: Some(0.5),
508 batch_size: Some(16),
509 learning_rate: Some(2e-4),
510 ..Default::default()
511 },
512 },
513 CostPerformancePoint {
515 name: "QLoRA 8-bit r=32".to_string(),
516 gpu_hours: 16.0,
517 cost_usd: 16.0 * cost_model.cost_per_hour,
518 accuracy: 0.85,
519 loss: 0.36,
520 memory_gb: 16.0,
521 is_pareto_optimal: false,
522 config: ConfigParams {
523 lora_rank: Some(32),
524 quant_bits: Some(8),
525 batch_size: Some(32),
526 learning_rate: Some(2e-4),
527 ..Default::default()
528 },
529 },
530 CostPerformancePoint {
532 name: "LoRA r=8".to_string(),
533 gpu_hours: 8.0,
534 cost_usd: 8.0 * cost_model.cost_per_hour,
535 accuracy: 0.81,
536 loss: 0.42,
537 memory_gb: 18.0,
538 is_pareto_optimal: false,
539 config: ConfigParams {
540 lora_rank: Some(8),
541 quant_bits: Some(16),
542 batch_size: Some(32),
543 learning_rate: Some(5e-4),
544 ..Default::default()
545 },
546 },
547 ]
548}
549
#[cfg(test)]
mod tests {
    use super::*;

    /// Test fixture: a point with default hyper-parameters and the Pareto flag cleared.
    fn point(
        name: &str,
        gpu_hours: f64,
        cost_usd: f64,
        accuracy: f64,
        loss: f64,
        memory_gb: f64,
    ) -> CostPerformancePoint {
        CostPerformancePoint {
            name: name.to_string(),
            gpu_hours,
            cost_usd,
            accuracy,
            loss,
            memory_gb,
            is_pareto_optimal: false,
            config: Default::default(),
        }
    }

    #[test]
    fn test_pareto_frontier() {
        // C costs more than B and is less accurate, so only A and B are non-dominated.
        let points = vec![
            point("A", 10.0, 10.0, 0.8, 0.3, 16.0),
            point("B", 20.0, 20.0, 0.9, 0.2, 24.0),
            point("C", 25.0, 25.0, 0.85, 0.25, 24.0),
        ];

        let frontier = compute_pareto_frontier(&points);
        let on_frontier = |name: &str| frontier.iter().any(|p| p.name == name);

        assert_eq!(frontier.len(), 2);
        assert!(on_frontier("A"));
        assert!(on_frontier("B"));
        assert!(!on_frontier("C"));
    }

    #[test]
    fn test_constraints() {
        let constraints = Constraints::new()
            .with_max_cost(50.0)
            .with_min_accuracy(0.85);

        let point_good = point("Good", 20.0, 40.0, 0.90, 0.25, 16.0);
        let point_expensive = point("Expensive", 30.0, 60.0, 0.95, 0.20, 16.0);
        let point_low_acc = point("LowAcc", 10.0, 20.0, 0.80, 0.35, 16.0);

        assert!(constraints.is_satisfied(&point_good));
        // Over the cost cap.
        assert!(!constraints.is_satisfied(&point_expensive));
        // Under the accuracy floor.
        assert!(!constraints.is_satisfied(&point_low_acc));
    }

    #[test]
    fn test_analysis_recommendations() {
        let cost_model = CostModel::a100_80gb();
        let analysis = CostPerformanceAnalysis::from_points(generate_sample_points(&cost_model));

        assert!(!analysis.pareto_frontier.is_empty());
        assert!(analysis.best_accuracy.is_some());
        assert!(analysis.best_efficiency.is_some());

        let constraints = Constraints::new().with_max_cost(50.0);
        let recommendations = analysis.recommend(&constraints);
        assert!(!recommendations.is_empty());
    }

    #[test]
    fn test_cost_models() {
        let a100 = CostModel::a100_80gb();
        assert_eq!(a100.gpu_type, "A100-80GB");
        assert!(a100.cost_per_hour > 0.0);

        let v100 = CostModel::v100();
        assert!(v100.cost_per_hour < a100.cost_per_hour);
    }
}