// cbtop/profile_persistence/overlay.rs
use super::config::{default_threads, BackendConfig, ProfileConfig, WorkloadConfig};
4
/// A set of optional overrides that can be layered on top of a base
/// [`ProfileConfig`] without mutating it.
///
/// Each field mirrors a field of `ProfileConfig`: a `Some` value replaces
/// the corresponding base value when `apply` is called, while `None`
/// leaves the base value untouched.
#[derive(Debug, Clone, Default)]
pub struct ProfileOverlay {
    /// Override for the refresh interval, in milliseconds.
    pub refresh_ms: Option<u64>,
    /// Override for the device index to target.
    pub device_index: Option<u32>,
    /// Override for the compute backend selection.
    pub backend: Option<BackendConfig>,
    /// Override for the load intensity (templates in this file use 0.0–1.0).
    pub load_intensity: Option<f64>,
    /// Override for the workload kind to run.
    pub workload: Option<WorkloadConfig>,
    /// Override for the workload problem size.
    pub problem_size: Option<usize>,
    /// Override for the worker thread count.
    pub threads: Option<usize>,
    /// Override for deterministic (reproducible) execution.
    pub deterministic: Option<bool>,
}
25
26impl ProfileOverlay {
27 pub fn new() -> Self {
29 Self::default()
30 }
31
32 pub fn refresh_ms(mut self, ms: u64) -> Self {
34 self.refresh_ms = Some(ms);
35 self
36 }
37
38 pub fn backend(mut self, backend: BackendConfig) -> Self {
40 self.backend = Some(backend);
41 self
42 }
43
44 pub fn workload(mut self, workload: WorkloadConfig) -> Self {
46 self.workload = Some(workload);
47 self
48 }
49
50 pub fn problem_size(mut self, size: usize) -> Self {
52 self.problem_size = Some(size);
53 self
54 }
55
56 pub fn apply(&self, mut profile: ProfileConfig) -> ProfileConfig {
58 if let Some(v) = self.refresh_ms {
59 profile.refresh_ms = v;
60 }
61 if let Some(v) = self.device_index {
62 profile.device_index = v;
63 }
64 if let Some(v) = self.backend {
65 profile.backend = v;
66 }
67 if let Some(v) = self.load_intensity {
68 profile.load_intensity = v;
69 }
70 if let Some(v) = self.workload {
71 profile.workload = v;
72 }
73 if let Some(v) = self.problem_size {
74 profile.problem_size = v;
75 }
76 if let Some(v) = self.threads {
77 profile.threads = v;
78 }
79 if let Some(v) = self.deterministic {
80 profile.deterministic = v;
81 }
82 profile
83 }
84
85 pub fn has_overrides(&self) -> bool {
87 self.refresh_ms.is_some()
88 || self.device_index.is_some()
89 || self.backend.is_some()
90 || self.load_intensity.is_some()
91 || self.workload.is_some()
92 || self.problem_size.is_some()
93 || self.threads.is_some()
94 || self.deterministic.is_some()
95 }
96}
97
98pub mod templates {
100 use super::*;
101
102 pub fn ml_training() -> ProfileConfig {
104 ProfileConfig {
105 name: "ml_training".to_string(),
106 description: "Optimized for ML training workloads".to_string(),
107 version: "1.0".to_string(),
108 refresh_ms: 200,
109 device_index: 0,
110 backend: BackendConfig::Cuda,
111 load_intensity: 0.75,
112 workload: WorkloadConfig::Gemm,
113 problem_size: 4_194_304,
114 threads: default_threads(),
115 deterministic: false,
116 metadata: [("use_case".to_string(), "training".to_string())]
117 .into_iter()
118 .collect(),
119 }
120 }
121
122 pub fn inference() -> ProfileConfig {
124 ProfileConfig {
125 name: "inference".to_string(),
126 description: "Optimized for inference workloads".to_string(),
127 version: "1.0".to_string(),
128 refresh_ms: 50,
129 device_index: 0,
130 backend: BackendConfig::Cuda,
131 load_intensity: 0.5,
132 workload: WorkloadConfig::Attention,
133 problem_size: 1_048_576,
134 threads: default_threads(),
135 deterministic: true,
136 metadata: [("use_case".to_string(), "inference".to_string())]
137 .into_iter()
138 .collect(),
139 }
140 }
141
142 pub fn stress_test() -> ProfileConfig {
144 ProfileConfig {
145 name: "stress_test".to_string(),
146 description: "Maximum stress for stability testing".to_string(),
147 version: "1.0".to_string(),
148 refresh_ms: 100,
149 device_index: 0,
150 backend: BackendConfig::All,
151 load_intensity: 1.0,
152 workload: WorkloadConfig::All,
153 problem_size: 16_777_216,
154 threads: default_threads(),
155 deterministic: false,
156 metadata: [("use_case".to_string(), "stress".to_string())]
157 .into_iter()
158 .collect(),
159 }
160 }
161
162 pub fn simd_only() -> ProfileConfig {
164 ProfileConfig {
165 name: "simd_only".to_string(),
166 description: "CPU SIMD operations only".to_string(),
167 version: "1.0".to_string(),
168 refresh_ms: 100,
169 device_index: 0,
170 backend: BackendConfig::Simd,
171 load_intensity: 0.5,
172 workload: WorkloadConfig::Elementwise,
173 problem_size: 1_048_576,
174 threads: default_threads(),
175 deterministic: false,
176 metadata: [("use_case".to_string(), "cpu".to_string())]
177 .into_iter()
178 .collect(),
179 }
180 }
181}