1use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::fs::{self, File};
10use std::io::{BufReader, BufWriter};
11use std::path::{Path, PathBuf};
12use std::sync::{Arc, Mutex};
13use std::time::{Duration, Instant, SystemTime};
14
15use crate::error::{FFTError, FFTResult};
16
17mod plan_map_serde {
19 use super::{PlanInfo, PlanMetrics};
20 use serde::{Deserialize, Deserializer, Serialize, Serializer};
21 use std::collections::HashMap;
22
23 pub fn serialize<S>(
24 map: &HashMap<PlanInfo, PlanMetrics>,
25 serializer: S,
26 ) -> Result<S::Ok, S::Error>
27 where
28 S: Serializer,
29 {
30 let vec: Vec<(PlanInfo, PlanMetrics)> =
32 map.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
33 vec.serialize(serializer)
34 }
35
36 pub fn deserialize<'de, D>(deserializer: D) -> Result<HashMap<PlanInfo, PlanMetrics>, D::Error>
37 where
38 D: Deserializer<'de>,
39 {
40 let vec: Vec<(PlanInfo, PlanMetrics)> = Vec::deserialize(deserializer)?;
42 Ok(vec.into_iter().collect())
43 }
44}
45
46#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
48pub struct PlanInfo {
49 pub size: usize,
51 pub forward: bool,
53 pub arch_id: String,
55 pub created_at: u64,
57 pub lib_version: String,
59}
60
61impl std::hash::Hash for PlanInfo {
63 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
64 self.size.hash(state);
65 self.forward.hash(state);
66 self.arch_id.hash(state);
67 }
69}
70
71#[derive(Serialize, Deserialize, Debug)]
73pub struct PlanDatabase {
74 #[serde(with = "plan_map_serde")]
76 pub plans: HashMap<PlanInfo, PlanMetrics>,
77 pub stats: PlanDatabaseStats,
79 pub last_updated: u64,
81}
82
83#[derive(Serialize, Deserialize, Debug, Clone)]
85pub struct PlanMetrics {
86 pub avg_execution_ns: u64,
88 pub usage_count: u64,
90 pub last_used: u64,
92}
93
94#[derive(Serialize, Deserialize, Debug, Default, Clone)]
96pub struct PlanDatabaseStats {
97 pub total_plans_created: u64,
99 pub total_plans_loaded: u64,
101 pub time_saved_ns: u64,
103}
104
105pub struct PlanSerializationManager {
107 db_path: PathBuf,
109 database: Arc<Mutex<PlanDatabase>>,
111 enabled: bool,
113}
114
115impl PlanSerializationManager {
116 pub fn new(dbpath: impl AsRef<Path>) -> Self {
118 let dbpath = dbpath.as_ref().to_path_buf();
119 let database = Self::load_or_create_database(&dbpath).unwrap_or_else(|_| {
120 Arc::new(Mutex::new(PlanDatabase {
121 plans: HashMap::new(),
122 stats: PlanDatabaseStats::default(),
123 last_updated: system_time_as_millis(),
124 }))
125 });
126
127 Self {
128 db_path: dbpath,
129 database,
130 enabled: true,
131 }
132 }
133
134 fn load_or_create_database(path: &Path) -> FFTResult<Arc<Mutex<PlanDatabase>>> {
136 if path.exists() {
137 let file = File::open(path)
138 .map_err(|e| FFTError::IOError(format!("Failed to open plan database: {e}")))?;
139 let reader = BufReader::new(file);
140 let database: PlanDatabase = serde_json::from_reader(reader)
141 .map_err(|e| FFTError::ValueError(format!("Failed to parse plan database: {e}")))?;
142 Ok(Arc::new(Mutex::new(database)))
143 } else {
144 if let Some(parent) = path.parent() {
146 fs::create_dir_all(parent).map_err(|e| {
147 FFTError::IOError(format!("Failed to create directory for plan database: {e}"))
148 })?;
149 }
150
151 let database = PlanDatabase {
153 plans: HashMap::new(),
154 stats: PlanDatabaseStats::default(),
155 last_updated: system_time_as_millis(),
156 };
157 Ok(Arc::new(Mutex::new(database)))
158 }
159 }
160
161 pub fn detect_arch_id() -> String {
163 let mut arch_id = String::new();
166
167 #[cfg(target_arch = "x86_64")]
168 {
169 arch_id.push_str("x86_64");
170 }
171
172 #[cfg(target_arch = "aarch64")]
173 {
174 arch_id.push_str("aarch64");
175 }
176
177 #[cfg(all(target_arch = "x86_64", target_feature = "avx"))]
179 {
180 arch_id.push_str("-avx");
181 }
182
183 #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
184 {
185 arch_id.push_str("-avx2");
186 }
187
188 if arch_id.is_empty() {
189 arch_id = format!("unknown-{}", std::env::consts::ARCH);
190 }
191
192 arch_id
193 }
194
195 fn get_lib_version() -> String {
197 env!("CARGO_PKG_VERSION").to_string()
198 }
199
200 pub fn create_plan_info(&self, size: usize, forward: bool) -> PlanInfo {
202 PlanInfo {
203 size,
204 forward,
205 arch_id: Self::detect_arch_id(),
206 created_at: system_time_as_millis(),
207 lib_version: Self::get_lib_version(),
208 }
209 }
210
211 pub fn plan_exists(&self, size: usize, forward: bool) -> bool {
213 if !self.enabled {
214 return false;
215 }
216
217 let arch_id = Self::detect_arch_id();
218 let db = self.database.lock().expect("Operation failed");
219
220 db.plans
221 .keys()
222 .any(|info| info.size == size && info.forward == forward && info.arch_id == arch_id)
223 }
224
225 pub fn record_plan_usage(&self, plan_info: &PlanInfo, execution_timens: u64) -> FFTResult<()> {
227 if !self.enabled {
228 return Ok(());
229 }
230
231 let mut db = self.database.lock().expect("Operation failed");
232
233 let metrics = db
235 .plans
236 .entry(plan_info.clone())
237 .or_insert_with(|| PlanMetrics {
238 avg_execution_ns: execution_timens,
239 usage_count: 0,
240 last_used: system_time_as_millis(),
241 });
242
243 metrics.usage_count += 1;
245 metrics.last_used = system_time_as_millis();
246
247 metrics.avg_execution_ns = if metrics.usage_count > 1 {
249 ((metrics.avg_execution_ns as f64 * (metrics.usage_count - 1) as f64)
250 + execution_timens as f64)
251 / metrics.usage_count as f64
252 } else {
253 execution_timens as f64
254 } as u64;
255
256 if db.last_updated + 60000 < system_time_as_millis() {
258 self.save_database()?;
260 db.last_updated = system_time_as_millis();
261 }
262
263 Ok(())
264 }
265
266 pub fn save_database(&self) -> FFTResult<()> {
268 if !self.enabled {
269 return Ok(());
270 }
271
272 let db = self.database.lock().expect("Operation failed");
273 let file = File::create(&self.db_path)
274 .map_err(|e| FFTError::IOError(format!("Failed to create plan database file: {e}")))?;
275
276 let writer = BufWriter::new(file);
277 serde_json::to_writer_pretty(writer, &*db)
278 .map_err(|e| FFTError::IOError(format!("Failed to serialize plan database: {e}")))?;
279
280 Ok(())
281 }
282
283 pub fn set_enabled(&mut self, enabled: bool) {
285 self.enabled = enabled;
286 }
287
288 pub fn get_best_plan_metrics(
290 &self,
291 size: usize,
292 forward: bool,
293 ) -> Option<(PlanInfo, PlanMetrics)> {
294 if !self.enabled {
295 return None;
296 }
297
298 let arch_id = Self::detect_arch_id();
299 let db = self.database.lock().expect("Operation failed");
300
301 db.plans
302 .iter()
303 .filter(|(info_, _)| {
304 info_.size == size && info_.forward == forward && info_.arch_id == arch_id
305 })
306 .min_by_key(|(_, metrics)| metrics.avg_execution_ns)
307 .map(|(info, metrics)| (info.clone(), metrics.clone()))
308 }
309
310 pub fn get_stats(&self) -> PlanDatabaseStats {
312 if let Ok(db) = self.database.lock() {
313 db.stats.clone()
314 } else {
315 PlanDatabaseStats::default()
316 }
317 }
318}
319
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns 0 if the system clock reads earlier than the epoch. The `u128`
/// millisecond count is narrowed to `u64` with an `as` cast.
fn system_time_as_millis() -> u64 {
    match SystemTime::now().duration_since(SystemTime::UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
328
/// Runs one complex-to-complex FFT of length `size` through
/// `oxifft_plan_cache::execute_c2c` and returns the elapsed wall-clock time in
/// nanoseconds.
///
/// The FFT direction is forward when `forward` is `true`, otherwise backward.
/// The zeroed input/output buffers are allocated before the timer starts, so
/// the measurement covers only the FFT call. That call presumably includes
/// plan lookup/creation in the cache; confirm against `oxifft_plan_cache`.
/// Durations that do not fit in `u64` saturate to `u64::MAX`.
///
/// Without the `oxifft` feature this does nothing and returns 0.
// May be unused depending on which features and modules reference it.
#[allow(dead_code)]
pub fn create_and_time_plan(size: usize, forward: bool) -> u64 {
    #[cfg(feature = "oxifft")]
    {
        use crate::oxifft_plan_cache;
        use oxifft::{Complex as OxiComplex, Direction};

        let input = vec![OxiComplex::new(0.0, 0.0); size];
        let mut output = vec![OxiComplex::new(0.0, 0.0); size];

        let direction = if forward {
            Direction::Forward
        } else {
            Direction::Backward
        };

        // Start timing only after buffer setup so allocation cost is excluded.
        let start = Instant::now();
        // NOTE(review): the result is discarded, so a failed transform still
        // yields a timing. Confirm this is acceptable for callers.
        let _ = oxifft_plan_cache::execute_c2c(&input, &mut output, direction);

        // Saturate instead of silently truncating the u128 nanosecond count.
        start.elapsed().as_nanos().min(u64::MAX as u128) as u64
    }

    #[cfg(not(feature = "oxifft"))]
    {
        let _ = (size, forward);
        0u64
    }
}
361
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Creates a manager backed by `file_name` inside a fresh temporary directory.
    ///
    /// The `TempDir` handle is returned so the caller keeps the directory alive
    /// for the whole test.
    fn manager_in_temp_dir(
        file_name: &str,
    ) -> (tempfile::TempDir, PathBuf, PlanSerializationManager) {
        let dir = tempdir().expect("Operation failed");
        let path = dir.path().join(file_name);
        let manager = PlanSerializationManager::new(&path);
        (dir, path, manager)
    }

    #[test]
    fn test_plan_serialization_basic() {
        let (_dir, db_path, manager) = manager_in_temp_dir("test_plan_db.json");

        let info = manager.create_plan_info(1024, true);
        manager
            .record_plan_usage(&info, 5000)
            .expect("Operation failed");
        assert!(manager.plan_exists(1024, true));

        manager.save_database().expect("Operation failed");
        assert!(db_path.exists());
    }

    #[test]
    fn test_arch_detection() {
        assert!(!PlanSerializationManager::detect_arch_id().is_empty());
    }

    #[test]
    fn test_get_best_plan() {
        let (_dir, _db_path, manager) = manager_in_temp_dir("test_best_plan.json");

        let slower_plan = manager.create_plan_info(512, true);
        // Sleep so the second plan gets a later `created_at` and is a distinct key.
        std::thread::sleep(Duration::from_millis(10));
        let faster_plan = manager.create_plan_info(512, true);

        let (slow_ns, fast_ns) = (8000u64, 5000u64);
        manager
            .record_plan_usage(&slower_plan, slow_ns)
            .expect("Operation failed");
        manager
            .record_plan_usage(&faster_plan, fast_ns)
            .expect("Operation failed");

        let (_, metrics) = manager
            .get_best_plan_metrics(512, true)
            .expect("Operation failed");
        assert!([slow_ns, fast_ns].contains(&metrics.avg_execution_ns));
        assert!(metrics.avg_execution_ns <= slow_ns.max(fast_ns));
    }
}
435}