cbtop/profile_persistence/
config.rs1use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6use super::{ProfileError, ProfileResult};
7
/// Serializable description of a single profile.
///
/// Only `name` is mandatory when deserializing: every other field carries a
/// `#[serde(default ...)]` attribute and falls back to the value noted on the
/// field when it is absent from the input.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProfileConfig {
    /// Profile name. Has no serde default, so it must be present in
    /// serialized input.
    pub name: String,
    /// Free-form description; an empty string when omitted.
    #[serde(default)]
    pub description: String,
    /// Version string; `default_version()` (`"1.0"`) when omitted.
    #[serde(default = "default_version")]
    pub version: String,
    /// Refresh value, in milliseconds judging by the field name (the unit is
    /// not otherwise established in this file); `default_refresh_ms()`
    /// (`100`) when omitted.
    #[serde(default = "default_refresh_ms")]
    pub refresh_ms: u64,
    /// Device index; `0` when omitted.
    #[serde(default)]
    pub device_index: u32,
    /// Backend selection; `BackendConfig::default()` (`All`) when omitted.
    #[serde(default)]
    pub backend: BackendConfig,
    /// Load intensity; `default_load_intensity()` (`0.0`) when omitted.
    ///
    /// The `load_intensity` builder method clamps its argument to
    /// `0.0..=1.0`, but nothing in this struct enforces that range for
    /// deserialized or directly assigned values.
    #[serde(default = "default_load_intensity")]
    pub load_intensity: f64,
    /// Workload selection; `WorkloadConfig::default()` (`Gemm`) when omitted.
    #[serde(default)]
    pub workload: WorkloadConfig,
    /// Problem size; `default_problem_size()` (`1_048_576`) when omitted.
    /// NOTE(review): the unit (elements vs. bytes) is not visible here.
    #[serde(default = "default_problem_size")]
    pub problem_size: usize,
    /// Thread count; `default_threads()` when omitted.
    #[serde(default = "default_threads")]
    pub threads: usize,
    /// Deterministic-mode flag; `false` when omitted.
    #[serde(default)]
    pub deterministic: bool,
    /// Arbitrary string key/value pairs; an empty map when omitted.
    #[serde(default)]
    pub metadata: HashMap<String, String>,
}
47
48pub(super) fn default_version() -> String {
49 "1.0".to_string()
50}
51
52pub(super) fn default_refresh_ms() -> u64 {
53 100
54}
55
56pub(super) fn default_load_intensity() -> f64 {
57 0.0
58}
59
60pub(super) fn default_problem_size() -> usize {
61 1_048_576
62}
63
64pub(super) fn default_threads() -> usize {
65 std::thread::available_parallelism()
66 .map(|n| n.get())
67 .unwrap_or(1)
68}
69
/// Backend selection stored in a profile.
///
/// Because of `rename_all = "lowercase"`, serde reads and writes each variant
/// as its lowercase name. The `Default` value is [`BackendConfig::All`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum BackendConfig {
    /// Serialized as `"simd"`.
    Simd,
    /// Serialized as `"wgpu"`.
    Wgpu,
    /// Serialized as `"cuda"`.
    Cuda,
    /// Serialized as `"all"`; this is the default variant.
    #[default]
    All,
}
80
/// Workload selection stored in a profile.
///
/// Because of `rename_all = "lowercase"`, serde reads and writes each variant
/// as its lowercase name. The `Default` value is [`WorkloadConfig::Gemm`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum WorkloadConfig {
    /// Serialized as `"gemm"`; this is the default variant.
    #[default]
    Gemm,
    /// Serialized as `"conv2d"`.
    Conv2d,
    /// Serialized as `"attention"`.
    Attention,
    /// Serialized as `"bandwidth"`.
    Bandwidth,
    /// Serialized as `"elementwise"`.
    Elementwise,
    /// Serialized as `"reduction"`.
    Reduction,
    /// Serialized as `"all"`.
    All,
}
94
95impl Default for ProfileConfig {
96 fn default() -> Self {
97 Self {
98 name: "default".to_string(),
99 description: String::new(),
100 version: default_version(),
101 refresh_ms: default_refresh_ms(),
102 device_index: 0,
103 backend: BackendConfig::default(),
104 load_intensity: default_load_intensity(),
105 workload: WorkloadConfig::default(),
106 problem_size: default_problem_size(),
107 threads: default_threads(),
108 deterministic: false,
109 metadata: HashMap::new(),
110 }
111 }
112}
113
114impl ProfileConfig {
115 pub fn new(name: &str) -> ProfileResult<Self> {
117 validate_profile_name(name)?;
118 let mut config = Self::default();
119 config.name = name.to_string();
120 Ok(config)
121 }
122
123 pub fn with_description(name: &str, description: &str) -> ProfileResult<Self> {
125 let mut config = Self::new(name)?;
126 config.description = description.to_string();
127 Ok(config)
128 }
129
130 pub fn backend(mut self, backend: BackendConfig) -> Self {
132 self.backend = backend;
133 self
134 }
135
136 pub fn workload(mut self, workload: WorkloadConfig) -> Self {
138 self.workload = workload;
139 self
140 }
141
142 pub fn problem_size(mut self, size: usize) -> Self {
144 self.problem_size = size;
145 self
146 }
147
148 pub fn load_intensity(mut self, intensity: f64) -> Self {
150 self.load_intensity = intensity.clamp(0.0, 1.0);
151 self
152 }
153
154 pub fn threads(mut self, threads: usize) -> Self {
156 self.threads = threads;
157 self
158 }
159
160 pub fn with_metadata(mut self, key: &str, value: &str) -> Self {
162 self.metadata.insert(key.to_string(), value.to_string());
163 self
164 }
165
166 pub fn to_toml(&self) -> ProfileResult<String> {
168 toml::to_string_pretty(self).map_err(|e| ProfileError::ParseError(e.to_string()))
169 }
170
171 pub fn from_toml(toml_str: &str) -> ProfileResult<Self> {
173 toml::from_str(toml_str).map_err(|e| ProfileError::ParseError(e.to_string()))
174 }
175}
176
177pub(super) fn validate_profile_name(name: &str) -> ProfileResult<()> {
179 if name.is_empty() {
180 return Err(ProfileError::InvalidName(
181 "name cannot be empty".to_string(),
182 ));
183 }
184
185 if name.len() > 64 {
186 return Err(ProfileError::InvalidName(
187 "name cannot exceed 64 characters".to_string(),
188 ));
189 }
190
191 if !name
193 .chars()
194 .all(|c| c.is_alphanumeric() || c == '_' || c == '-')
195 {
196 return Err(ProfileError::InvalidName(
197 "name can only contain alphanumeric, underscore, or hyphen".to_string(),
198 ));
199 }
200
201 if let Some(first) = name.chars().next() {
203 if first == '-' || first.is_numeric() {
204 return Err(ProfileError::InvalidName(
205 "name cannot start with hyphen or number".to_string(),
206 ));
207 }
208 }
209
210 Ok(())
211}