1use serde::{Deserialize, Serialize};
29use std::time::{Duration, Instant};
30
/// A hardware target described by the attributes relevant to its kind.
///
/// The inherent methods on this type derive rough throughput and power
/// estimates from some of these fields; the remaining fields are purely
/// descriptive.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ComputeDevice {
    /// A general-purpose CPU.
    Cpu {
        /// Number of cores.
        cores: u32,
        /// Number of hardware threads per core.
        threads_per_core: u32,
        /// Instruction-set architecture of the CPU.
        architecture: CpuArchitecture,
    },
    /// A GPU.
    Gpu {
        /// Free-form device name.
        name: String,
        /// Device memory size in gigabytes.
        memory_gb: f32,
        /// Optional compute-capability identifier (free-form string);
        /// `None` when not specified.
        compute_capability: Option<String>,
        /// Manufacturer of the GPU.
        vendor: GpuVendor,
    },
    /// A TPU.
    Tpu {
        /// Hardware generation.
        version: TpuVersion,
        /// Number of cores.
        cores: u32,
    },
    /// An Apple Silicon chip.
    AppleSilicon {
        /// Chip model.
        chip: AppleChip,
        /// Number of neural-engine cores.
        neural_engine_cores: u32,
        /// Number of GPU cores.
        gpu_cores: u32,
        /// Memory size in gigabytes.
        memory_gb: u32,
    },
    /// An edge device characterised only by its power budget.
    Edge {
        /// Free-form device name.
        name: String,
        /// Power budget in watts.
        power_budget_watts: f32,
    },
}
80
/// CPU instruction-set architectures that a [`ComputeDevice::Cpu`] can report.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CpuArchitecture {
    /// 64-bit x86.
    X86_64,
    /// 64-bit ARM.
    Aarch64,
    /// 64-bit RISC-V.
    Riscv64,
    /// 32-bit WebAssembly.
    Wasm32,
}
93
/// GPU manufacturers that a [`ComputeDevice::Gpu`] can report.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum GpuVendor {
    /// NVIDIA.
    Nvidia,
    /// AMD.
    Amd,
    /// Intel.
    Intel,
    /// Apple.
    Apple,
}
106
/// TPU hardware generations that a [`ComputeDevice::Tpu`] can report.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TpuVersion {
    /// Version 2.
    V2,
    /// Version 3.
    V3,
    /// Version 4.
    V4,
    /// Version 5e.
    V5e,
    /// Version 5p.
    V5p,
}
121
/// Apple Silicon chip models that a [`ComputeDevice::AppleSilicon`] can report.
///
/// Variants are grouped by generation (M1 to M4), each with its tiers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AppleChip {
    /// M1.
    M1,
    /// M1 Pro.
    M1Pro,
    /// M1 Max.
    M1Max,
    /// M1 Ultra.
    M1Ultra,
    /// M2.
    M2,
    /// M2 Pro.
    M2Pro,
    /// M2 Max.
    M2Max,
    /// M2 Ultra.
    M2Ultra,
    /// M3.
    M3,
    /// M3 Pro.
    M3Pro,
    /// M3 Max.
    M3Max,
    /// M4.
    M4,
    /// M4 Pro.
    M4Pro,
    /// M4 Max.
    M4Max,
}
154
155impl ComputeDevice {
156 #[must_use]
158 pub fn theoretical_flops(&self) -> f64 {
159 match self {
160 ComputeDevice::Cpu {
161 cores,
162 threads_per_core,
163 architecture,
164 } => {
165 let base_flops = match architecture {
166 CpuArchitecture::X86_64 => 32.0, CpuArchitecture::Aarch64 => 16.0, CpuArchitecture::Riscv64 => 8.0,
169 CpuArchitecture::Wasm32 => 4.0,
170 };
171 f64::from(*cores) * f64::from(*threads_per_core) * base_flops * 1e9
172 }
173 ComputeDevice::Gpu {
174 memory_gb, vendor, ..
175 } => {
176 let bandwidth_factor = match vendor {
178 GpuVendor::Nvidia => 15.0,
179 GpuVendor::Amd => 12.0,
180 GpuVendor::Intel => 8.0,
181 GpuVendor::Apple => 10.0,
182 };
183 f64::from(*memory_gb) * bandwidth_factor * 1e12
184 }
185 ComputeDevice::Tpu { version, cores } => {
186 let flops_per_core = match version {
187 TpuVersion::V2 => 45e12,
188 TpuVersion::V3 => 90e12,
189 TpuVersion::V4 => 275e12,
190 TpuVersion::V5e => 197e12,
191 TpuVersion::V5p => 459e12,
192 };
193 f64::from(*cores) * flops_per_core
194 }
195 ComputeDevice::AppleSilicon {
196 chip, gpu_cores, ..
197 } => {
198 let flops_per_gpu_core = match chip {
199 AppleChip::M1 | AppleChip::M1Pro | AppleChip::M1Max | AppleChip::M1Ultra => {
200 128e9
201 }
202 AppleChip::M2 | AppleChip::M2Pro | AppleChip::M2Max | AppleChip::M2Ultra => {
203 150e9
204 }
205 AppleChip::M3 | AppleChip::M3Pro | AppleChip::M3Max => 180e9,
206 AppleChip::M4 | AppleChip::M4Pro | AppleChip::M4Max => 200e9,
207 };
208 f64::from(*gpu_cores) * flops_per_gpu_core
209 }
210 ComputeDevice::Edge {
211 power_budget_watts, ..
212 } => {
213 f64::from(*power_budget_watts) * 10e9
215 }
216 }
217 }
218
219 #[must_use]
221 pub fn estimated_power_watts(&self) -> f32 {
222 match self {
223 ComputeDevice::Cpu { cores, .. } => (*cores as f32) * 15.0,
224 ComputeDevice::Gpu {
225 memory_gb, vendor, ..
226 } => {
227 let base = match vendor {
228 GpuVendor::Nvidia => 30.0,
229 GpuVendor::Amd => 35.0,
230 GpuVendor::Intel => 25.0,
231 GpuVendor::Apple => 20.0,
232 };
233 *memory_gb * base
234 }
235 ComputeDevice::Tpu { version, cores } => {
236 let per_core = match version {
237 TpuVersion::V2 => 40.0,
238 TpuVersion::V3 => 50.0,
239 TpuVersion::V4 => 60.0,
240 TpuVersion::V5e => 45.0,
241 TpuVersion::V5p => 70.0,
242 };
243 (*cores as f32) * per_core
244 }
245 ComputeDevice::AppleSilicon { chip, .. } => match chip {
246 AppleChip::M1 => 20.0,
247 AppleChip::M1Pro => 30.0,
248 AppleChip::M1Max => 40.0,
249 AppleChip::M1Ultra => 60.0,
250 AppleChip::M2 => 22.0,
251 AppleChip::M2Pro => 32.0,
252 AppleChip::M2Max => 45.0,
253 AppleChip::M2Ultra => 65.0,
254 AppleChip::M3 => 24.0,
255 AppleChip::M3Pro => 35.0,
256 AppleChip::M3Max => 50.0,
257 AppleChip::M4 => 25.0,
258 AppleChip::M4Pro => 38.0,
259 AppleChip::M4Max => 55.0,
260 },
261 ComputeDevice::Edge {
262 power_budget_watts, ..
263 } => *power_budget_watts,
264 }
265 }
266
267 #[must_use]
269 pub fn default_cpu() -> Self {
270 ComputeDevice::Cpu {
271 cores: 8,
272 threads_per_core: 2,
273 #[cfg(target_arch = "x86_64")]
274 architecture: CpuArchitecture::X86_64,
275 #[cfg(target_arch = "aarch64")]
276 architecture: CpuArchitecture::Aarch64,
277 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
278 architecture: CpuArchitecture::X86_64,
279 }
280 }
281}
282
283impl Default for ComputeDevice {
284 fn default() -> Self {
285 Self::default_cpu()
286 }
287}
288
/// Energy figures for a run, with an optional CO2 estimate.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct EnergyMetrics {
    /// Total energy in joules.
    pub total_joules: f64,
    /// Average power in watts.
    pub average_power_watts: f64,
    /// Peak power in watts.
    pub peak_power_watts: f64,
    /// Duration in seconds.
    pub duration_seconds: f64,
    /// Estimated CO2 emissions in grams; `None` until a carbon intensity has
    /// been applied.
    pub co2_grams: Option<f64>,
    /// PUE factor (presumably power usage effectiveness) that multiplies the
    /// CO2 estimate.
    pub pue: f64,
}
305
306impl EnergyMetrics {
307 #[must_use]
309 pub fn new(
310 total_joules: f64,
311 average_power_watts: f64,
312 peak_power_watts: f64,
313 duration_seconds: f64,
314 ) -> Self {
315 Self {
316 total_joules,
317 average_power_watts,
318 peak_power_watts,
319 duration_seconds,
320 co2_grams: None,
321 pue: 1.0,
322 }
323 }
324
325 #[must_use]
332 pub fn with_carbon_intensity(mut self, carbon_intensity_g_per_kwh: f64) -> Self {
333 let kwh = self.total_joules / 3_600_000.0;
334 self.co2_grams = Some(kwh * carbon_intensity_g_per_kwh * self.pue);
335 self
336 }
337
338 #[must_use]
346 pub fn with_pue(mut self, pue: f64) -> Self {
347 let old_pue = self.pue;
348 self.pue = pue;
349 if let Some(co2) = self.co2_grams {
351 self.co2_grams = Some(co2 / old_pue * pue);
352 }
353 self
354 }
355
356 #[must_use]
358 pub fn flops_per_watt(&self, total_flops: f64) -> f64 {
359 if self.average_power_watts > 0.0 {
360 total_flops / self.average_power_watts
361 } else {
362 0.0
363 }
364 }
365}
366
367impl Default for EnergyMetrics {
368 fn default() -> Self {
369 Self::new(0.0, 0.0, 0.0, 0.0)
370 }
371}
372
/// Monetary cost breakdown for a run.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CostMetrics {
    /// Compute cost.
    pub compute_cost_usd: f64,
    /// Storage cost.
    pub storage_cost_usd: f64,
    /// Network cost.
    pub network_cost_usd: f64,
    /// Total cost across the categories above.
    pub total_cost_usd: f64,
    /// Total cost divided by the number of samples; `None` until a positive
    /// sample count has been supplied.
    pub cost_per_sample: Option<f64>,
    /// Currency label for the amounts.
    pub currency: String,
}
389
390impl CostMetrics {
391 #[must_use]
393 pub fn new(compute_cost: f64, storage_cost: f64, network_cost: f64) -> Self {
394 Self {
395 compute_cost_usd: compute_cost,
396 storage_cost_usd: storage_cost,
397 network_cost_usd: network_cost,
398 total_cost_usd: compute_cost + storage_cost + network_cost,
399 cost_per_sample: None,
400 currency: "USD".to_string(),
401 }
402 }
403
404 #[must_use]
406 pub fn with_samples(mut self, total_samples: u64) -> Self {
407 if total_samples > 0 {
408 self.cost_per_sample = Some(self.total_cost_usd / total_samples as f64);
409 }
410 self
411 }
412
413 #[must_use]
415 pub fn cost_per_sample(&self) -> f64 {
416 self.cost_per_sample.unwrap_or(0.0)
417 }
418}
419
420impl Default for CostMetrics {
421 fn default() -> Self {
422 Self::new(0.0, 0.0, 0.0)
423 }
424}
425
/// Accumulates sample counts and elapsed time for one generation run on a
/// given device, from which energy and cost metrics can be derived.
#[derive(Debug, Clone)]
pub struct GenerationExperiment {
    /// Experiment name.
    pub name: String,
    /// Device the experiment runs on.
    pub device: ComputeDevice,
    /// Start of the timing interval currently in progress, if any.
    start_time: Option<Instant>,
    /// Number of samples recorded so far.
    pub samples_generated: u64,
    /// Total time accumulated so far.
    pub total_duration: Duration,
    /// Compute price per hour.
    pub hourly_rate_usd: f64,
    /// Carbon intensity used for the CO2 estimate, in grams per kWh.
    pub carbon_intensity: f64,
}
446
447impl GenerationExperiment {
448 #[must_use]
450 pub fn new(name: &str, device: ComputeDevice) -> Self {
451 Self {
452 name: name.to_string(),
453 device,
454 start_time: None,
455 samples_generated: 0,
456 total_duration: Duration::ZERO,
457 hourly_rate_usd: 0.10, carbon_intensity: 386.0, }
460 }
461
462 #[must_use]
464 pub fn with_hourly_rate(mut self, rate_usd: f64) -> Self {
465 self.hourly_rate_usd = rate_usd;
466 self
467 }
468
469 #[must_use]
471 pub fn with_carbon_intensity(mut self, g_per_kwh: f64) -> Self {
472 self.carbon_intensity = g_per_kwh;
473 self
474 }
475
476 pub fn start(&mut self) {
478 self.start_time = Some(Instant::now());
479 }
480
481 pub fn record_samples(&mut self, count: u64, duration: Duration) {
483 self.samples_generated += count;
484 self.total_duration += duration;
485 }
486
487 pub fn stop(&mut self) {
489 if let Some(start) = self.start_time.take() {
490 self.total_duration += start.elapsed();
491 }
492 }
493
494 #[must_use]
496 pub fn finalize(&self) -> ExperimentMetrics {
497 let duration_secs = self.total_duration.as_secs_f64();
498 let power_watts = f64::from(self.device.estimated_power_watts());
499
500 let total_joules = power_watts * duration_secs;
502
503 let energy =
504 EnergyMetrics::new(total_joules, power_watts, power_watts * 1.2, duration_secs)
505 .with_carbon_intensity(self.carbon_intensity);
506
507 let hours = duration_secs / 3600.0;
509 let compute_cost = self.hourly_rate_usd * hours;
510 let cost = CostMetrics::new(compute_cost, 0.0, 0.0).with_samples(self.samples_generated);
511
512 ExperimentMetrics {
513 name: self.name.clone(),
514 samples_generated: self.samples_generated,
515 duration: self.total_duration,
516 energy,
517 cost,
518 samples_per_second: if duration_secs > 0.0 {
519 self.samples_generated as f64 / duration_secs
520 } else {
521 0.0
522 },
523 }
524 }
525}
526
/// Serializable summary of a finished experiment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExperimentMetrics {
    /// Experiment name.
    pub name: String,
    /// Number of samples generated.
    pub samples_generated: u64,
    /// Total duration; (de)serialized through `duration_serde` as a
    /// floating-point number of seconds.
    #[serde(with = "duration_serde")]
    pub duration: Duration,
    /// Energy figures.
    pub energy: EnergyMetrics,
    /// Cost figures.
    pub cost: CostMetrics,
    /// Throughput in samples per second.
    pub samples_per_second: f64,
}
544
545impl ExperimentMetrics {
546 #[must_use]
548 pub fn cost_per_sample(&self) -> f64 {
549 self.cost.cost_per_sample()
550 }
551
552 #[must_use]
554 pub fn co2_per_sample(&self) -> f64 {
555 if self.samples_generated > 0 {
556 self.energy.co2_grams.unwrap_or(0.0) / self.samples_generated as f64
557 } else {
558 0.0
559 }
560 }
561}
562
563mod duration_serde {
565 use serde::{Deserialize, Deserializer, Serialize, Serializer};
566 use std::time::Duration;
567
568 pub(super) fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
569 where
570 S: Serializer,
571 {
572 duration.as_secs_f64().serialize(serializer)
573 }
574
575 pub(super) fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
576 where
577 D: Deserializer<'de>,
578 {
579 let secs = f64::deserialize(deserializer)?;
580 Ok(Duration::from_secs_f64(secs))
581 }
582}
583
#[cfg(test)]
mod tests {
    use super::*;

    /// A CPU device yields positive throughput and power estimates.
    #[test]
    fn test_compute_device_cpu() {
        let device = ComputeDevice::Cpu {
            cores: 8,
            threads_per_core: 2,
            architecture: CpuArchitecture::X86_64,
        };
        assert!(device.theoretical_flops() > 0.0);
        assert!(device.estimated_power_watts() > 0.0);
    }

    /// A 24 GB NVIDIA GPU exceeds 1 TFLOP/s and 100 W in the estimates.
    #[test]
    fn test_compute_device_gpu() {
        let device = ComputeDevice::Gpu {
            name: "RTX 4090".to_string(),
            memory_gb: 24.0,
            compute_capability: Some("8.9".to_string()),
            vendor: GpuVendor::Nvidia,
        };
        assert!(device.theoretical_flops() > 1e12);
        assert!(device.estimated_power_watts() > 100.0);
    }

    /// An M3 Max uses its fixed per-chip power figure.
    #[test]
    fn test_compute_device_apple_silicon() {
        let device = ComputeDevice::AppleSilicon {
            chip: AppleChip::M3Max,
            neural_engine_cores: 16,
            gpu_cores: 40,
            memory_gb: 64,
        };
        assert!(device.theoretical_flops() > 1e12);
        assert_eq!(device.estimated_power_watts(), 50.0);
    }

    /// Carbon intensity and PUE combine into the expected CO2 figure.
    #[test]
    fn test_energy_metrics() {
        let energy = EnergyMetrics::new(3600.0, 100.0, 120.0, 36.0)
            .with_carbon_intensity(386.0)
            .with_pue(1.2);

        // 3600 J = 0.001 kWh; 0.001 kWh * 386 g/kWh = 0.386 g at PUE 1.0,
        // rescaled to 0.386 * 1.2 = 0.4632 g once the new PUE is applied.
        let co2 = energy.co2_grams.expect("carbon intensity was supplied");
        assert!((co2 - 0.4632).abs() < 1e-9);
        assert!((energy.pue - 1.2).abs() < f64::EPSILON);
    }

    /// The total is the sum of the components; per-sample cost divides it.
    #[test]
    fn test_cost_metrics() {
        let cost = CostMetrics::new(1.0, 0.1, 0.05).with_samples(1000);
        assert!((cost.total_cost_usd - 1.15).abs() < 0.0001);
        assert!((cost.cost_per_sample() - 0.00115).abs() < 0.0001);
    }

    /// Recorded samples and duration flow through to the final metrics.
    #[test]
    fn test_generation_experiment() {
        let device = ComputeDevice::default_cpu();
        let mut experiment = GenerationExperiment::new("test-run", device)
            .with_hourly_rate(0.50)
            .with_carbon_intensity(200.0);

        experiment.record_samples(1000, Duration::from_secs(60));
        let metrics = experiment.finalize();

        assert_eq!(metrics.samples_generated, 1000);
        assert!(metrics.samples_per_second > 10.0);
        assert!(metrics.cost_per_sample() > 0.0);
    }

    /// `start`/`stop` adds at least the slept time to the total duration.
    #[test]
    fn test_experiment_start_stop() {
        let device = ComputeDevice::default();
        let mut experiment = GenerationExperiment::new("timed-run", device);

        experiment.start();
        std::thread::sleep(Duration::from_millis(10));
        experiment.stop();

        assert!(experiment.total_duration.as_millis() >= 10);
    }

    /// The default device is a CPU with a non-zero core count.
    #[test]
    fn test_compute_device_default() {
        let device = ComputeDevice::default();
        match device {
            ComputeDevice::Cpu { cores, .. } => assert!(cores > 0),
            _ => panic!("Expected CPU device"),
        }
    }

    /// Default energy metrics are zeroed with a PUE of 1.0.
    #[test]
    fn test_energy_metrics_default() {
        let energy = EnergyMetrics::default();
        assert_eq!(energy.total_joules, 0.0);
        assert_eq!(energy.pue, 1.0);
    }

    /// Default cost metrics are zeroed and denominated in USD.
    #[test]
    fn test_cost_metrics_default() {
        let cost = CostMetrics::default();
        assert_eq!(cost.total_cost_usd, 0.0);
        assert_eq!(cost.currency, "USD");
    }

    /// Four TPU v4 cores exceed 1 PFLOP/s in the estimate.
    #[test]
    fn test_tpu_device() {
        let device = ComputeDevice::Tpu {
            version: TpuVersion::V4,
            cores: 4,
        };
        assert!(device.theoretical_flops() > 1e15);
    }

    /// An edge device reports its power budget as its power estimate.
    #[test]
    fn test_edge_device() {
        let device = ComputeDevice::Edge {
            name: "Jetson Nano".to_string(),
            power_budget_watts: 10.0,
        };
        assert_eq!(device.estimated_power_watts(), 10.0);
    }

    /// CO2 per sample matches the value derived from the default settings.
    #[test]
    fn test_experiment_metrics_co2_per_sample() {
        let device = ComputeDevice::default_cpu();
        let mut experiment = GenerationExperiment::new("co2-test", device);
        experiment.record_samples(100, Duration::from_secs(10));
        let metrics = experiment.finalize();

        // The default CPU has 8 cores at 15 W each = 120 W; over 10 s that is
        // 1200 J. At the default 386 g/kWh and PUE 1.0, spread over 100
        // samples:
        let expected = 1200.0 / 3_600_000.0 * 386.0 / 100.0;
        assert!((metrics.co2_per_sample() - expected).abs() < 1e-9);
    }

    /// Metrics survive a JSON round trip.
    #[test]
    fn test_experiment_metrics_serialization() {
        let device = ComputeDevice::default_cpu();
        let mut experiment = GenerationExperiment::new("serial-test", device);
        experiment.record_samples(50, Duration::from_secs(5));
        let metrics = experiment.finalize();

        let json = serde_json::to_string(&metrics).expect("serialization");
        assert!(json.contains("serial-test"));

        let parsed: ExperimentMetrics = serde_json::from_str(&json).expect("deserialization");
        assert_eq!(parsed.samples_generated, 50);
    }
}