#[cfg(feature = "rustfft-backend")]
use rustfft::FftPlanner;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs::{self, File};
use std::io::{BufReader, BufWriter};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant, SystemTime};
use crate::error::{FFTError, FFTResult};
/// Serde adapter for `HashMap<PlanInfo, PlanMetrics>`.
///
/// JSON object keys must be strings, so the map cannot be serialized as a
/// JSON object keyed by `PlanInfo`; instead it is written as a sequence of
/// `(PlanInfo, PlanMetrics)` pairs and rebuilt on load.
mod plan_map_serde {
    use super::{PlanInfo, PlanMetrics};
    use serde::{Deserialize, Deserializer, Serialize, Serializer};
    use std::collections::HashMap;

    /// Serializes the map as a sequence of key/value pairs.
    ///
    /// Borrows each entry directly rather than cloning every key and value
    /// into a temporary `Vec` — serde serializes `&T` identically to `T`,
    /// so the output format is unchanged.
    pub fn serialize<S>(
        map: &HashMap<PlanInfo, PlanMetrics>,
        serializer: S,
    ) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let entries: Vec<(&PlanInfo, &PlanMetrics)> = map.iter().collect();
        entries.serialize(serializer)
    }

    /// Rebuilds the map from the serialized pair sequence.
    pub fn deserialize<'de, D>(
        deserializer: D,
    ) -> Result<HashMap<PlanInfo, PlanMetrics>, D::Error>
    where
        D: Deserializer<'de>,
    {
        let pairs: Vec<(PlanInfo, PlanMetrics)> = Vec::deserialize(deserializer)?;
        Ok(pairs.into_iter().collect())
    }
}
/// Identity of a cached FFT plan: the transform parameters plus provenance
/// metadata (when it was created and by which crate version).
///
/// NOTE: `PartialEq`/`Eq` are derived over all five fields, while the manual
/// `Hash` impl in this file hashes only `size`, `forward`, and `arch_id`.
/// That is contract-safe (equal values still hash equally), but entries that
/// differ only in `created_at`/`lib_version` are distinct map keys.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct PlanInfo {
/// FFT length (number of points).
pub size: usize,
/// `true` for a forward transform, `false` for inverse.
pub forward: bool,
/// CPU architecture/feature string the plan was created for (see `detect_arch_id`).
pub arch_id: String,
/// Creation timestamp, milliseconds since the Unix epoch.
pub created_at: u64,
/// Crate version (`CARGO_PKG_VERSION`) that produced the plan.
pub lib_version: String,
}
// Manual `Hash` that deliberately omits `created_at` and `lib_version`, so
// hashing keys on the plan's identity (size, direction, architecture) only.
// This is consistent with the derived `Eq`: values that compare equal agree
// on all fields, hence also on the hashed subset, so the Hash/Eq contract
// (a == b implies hash(a) == hash(b)) holds.
impl std::hash::Hash for PlanInfo {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.size.hash(state);
self.forward.hash(state);
self.arch_id.hash(state);
}
}
/// Serializable store of plan metadata: the plan→metrics map, aggregate
/// stats, and a last-updated timestamp.
#[derive(Serialize, Deserialize, Debug)]
pub struct PlanDatabase {
/// Plan identities mapped to their usage metrics. Serialized as a sequence
/// of pairs via `plan_map_serde` (JSON object keys must be strings).
#[serde(with = "plan_map_serde")]
pub plans: HashMap<PlanInfo, PlanMetrics>,
/// Aggregate counters carried alongside the plans.
pub stats: PlanDatabaseStats,
/// Timestamp of the last update, milliseconds since the Unix epoch.
pub last_updated: u64,
}
/// Usage metrics tracked per cached plan.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct PlanMetrics {
/// Running average execution time in nanoseconds (updated incrementally
/// in `record_plan_usage`).
pub avg_execution_ns: u64,
/// Number of times the plan has been used.
pub usage_count: u64,
/// Timestamp of last use, milliseconds since the Unix epoch.
pub last_used: u64,
}
/// Aggregate counters for plan-caching activity across the database.
#[derive(Serialize, Deserialize, Debug, Default, Clone)]
pub struct PlanDatabaseStats {
/// Total plans created from scratch.
pub total_plans_created: u64,
/// Total plans loaded back from the serialized database.
pub total_plans_loaded: u64,
/// Planner time avoided by reusing cached plans, in nanoseconds.
/// NOTE(review): no visible code in this file updates this field —
/// confirm the writer elsewhere in the crate.
pub time_saved_ns: u64,
}
/// Coordinates loading, querying, updating, and saving the on-disk FFT plan
/// database.
pub struct PlanSerializationManager {
/// Filesystem location of the JSON-serialized `PlanDatabase`.
db_path: PathBuf,
/// Shared, mutex-guarded in-memory database.
database: Arc<Mutex<PlanDatabase>>,
/// When `false`, lookups report nothing, recording is a no-op, and saves
/// are skipped.
enabled: bool,
}
impl PlanSerializationManager {
/// Creates a manager backed by the database file at `dbpath`.
///
/// If an existing database cannot be loaded (or a fresh one cannot be set
/// up on disk), the manager falls back to an empty in-memory database so
/// construction never fails. Persistence starts enabled.
pub fn new(dbpath: impl AsRef<Path>) -> Self {
    let db_path = dbpath.as_ref().to_path_buf();
    let database = match Self::load_or_create_database(&db_path) {
        Ok(db) => db,
        // Best-effort fallback: operate purely in memory on any load error.
        Err(_) => Arc::new(Mutex::new(PlanDatabase {
            plans: HashMap::new(),
            stats: PlanDatabaseStats::default(),
            last_updated: system_time_as_millis(),
        })),
    };
    Self {
        db_path,
        database,
        enabled: true,
    }
}
/// Loads the plan database from `path` if the file exists; otherwise
/// prepares an empty database (creating the parent directory so a later
/// save can succeed).
///
/// # Errors
/// `FFTError::IOError` if the file or parent directory cannot be accessed,
/// `FFTError::ValueError` if the file contents fail to parse as JSON.
fn load_or_create_database(path: &Path) -> FFTResult<Arc<Mutex<PlanDatabase>>> {
    let database = if path.exists() {
        let file = File::open(path)
            .map_err(|e| FFTError::IOError(format!("Failed to open plan database: {e}")))?;
        serde_json::from_reader(BufReader::new(file))
            .map_err(|e| FFTError::ValueError(format!("Failed to parse plan database: {e}")))?
    } else {
        // Ensure the directory hierarchy exists before a fresh database is
        // ever written there.
        if let Some(parent) = path.parent() {
            fs::create_dir_all(parent).map_err(|e| {
                FFTError::IOError(format!("Failed to create directory for plan database: {e}"))
            })?;
        }
        PlanDatabase {
            plans: HashMap::new(),
            stats: PlanDatabaseStats::default(),
            last_updated: system_time_as_millis(),
        }
    };
    Ok(Arc::new(Mutex::new(database)))
}
/// Builds a compile-time architecture identifier such as `"x86_64"`,
/// `"x86_64-avx-avx2"`, or `"aarch64"`, used to keep cached plans from one
/// CPU profile separate from another's.
///
/// Falls back to `"unknown-<arch>"` for targets without a dedicated branch.
pub fn detect_arch_id() -> String {
    let mut id = String::new();
    #[cfg(target_arch = "x86_64")]
    id.push_str("x86_64");
    #[cfg(target_arch = "aarch64")]
    id.push_str("aarch64");
    // Feature suffixes are appended cumulatively (avx2 builds also carry -avx).
    #[cfg(all(target_arch = "x86_64", target_feature = "avx"))]
    id.push_str("-avx");
    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    id.push_str("-avx2");
    if id.is_empty() {
        return format!("unknown-{}", std::env::consts::ARCH);
    }
    id
}
/// Crate version baked in at compile time; stored with each plan so entries
/// from older builds can be told apart.
fn get_lib_version() -> String {
    String::from(env!("CARGO_PKG_VERSION"))
}
/// Builds a `PlanInfo` key for a plan of the given `size` and direction,
/// stamped with the current architecture id, crate version, and creation
/// time (ms since the Unix epoch).
pub fn create_plan_info(&self, size: usize, forward: bool) -> PlanInfo {
    // Provenance fields are derived fresh on every call.
    let arch_id = Self::detect_arch_id();
    let lib_version = Self::get_lib_version();
    let created_at = system_time_as_millis();
    PlanInfo {
        size,
        forward,
        arch_id,
        created_at,
        lib_version,
    }
}
/// Returns `true` when a recorded plan matches `size`, `forward`, and the
/// current architecture id. Always `false` while persistence is disabled.
pub fn plan_exists(&self, size: usize, forward: bool) -> bool {
    if !self.enabled {
        return false;
    }
    let arch = Self::detect_arch_id();
    let guard = self.database.lock().expect("Operation failed");
    guard
        .plans
        .keys()
        .any(|k| (k.size, k.forward, k.arch_id.as_str()) == (size, forward, arch.as_str()))
}
/// Records one use of `plan_info`, updating its running-average execution
/// time, usage count, and last-used timestamp; new plans also bump the
/// `total_plans_created` counter. Triggers a periodic save at most once per
/// minute. No-op while persistence is disabled.
///
/// # Errors
/// Propagates `FFTError::IOError` from the periodic `save_database` call.
pub fn record_plan_usage(&self, plan_info: &PlanInfo, execution_timens: u64) -> FFTResult<()> {
    if !self.enabled {
        return Ok(());
    }
    let needs_save = {
        let mut db = self.database.lock().expect("Operation failed");
        let now = system_time_as_millis();
        let is_new = !db.plans.contains_key(plan_info);
        let metrics = db
            .plans
            .entry(plan_info.clone())
            .or_insert_with(|| PlanMetrics {
                avg_execution_ns: execution_timens,
                usage_count: 0,
                last_used: now,
            });
        metrics.usage_count += 1;
        metrics.last_used = now;
        // Incremental mean: new_avg = (old_avg * (n - 1) + sample) / n.
        metrics.avg_execution_ns = if metrics.usage_count > 1 {
            (((metrics.avg_execution_ns as f64) * (metrics.usage_count - 1) as f64
                + execution_timens as f64)
                / metrics.usage_count as f64) as u64
        } else {
            execution_timens
        };
        if is_new {
            db.stats.total_plans_created += 1;
        }
        // Throttle persistence to once per minute; bump the marker inside the
        // lock so concurrent callers don't all decide to save.
        if db.last_updated + 60000 < now {
            db.last_updated = now;
            true
        } else {
            false
        }
    };
    // BUGFIX: save only after the guard above has been dropped. The original
    // code called save_database() while still holding the database lock;
    // save_database() re-locks the same std::sync::Mutex, which is not
    // reentrant, so every periodic save deadlocked.
    if needs_save {
        self.save_database()?;
    }
    Ok(())
}
/// Persists the in-memory database to `db_path` as pretty-printed JSON.
/// No-op while persistence is disabled.
///
/// # Errors
/// `FFTError::IOError` if the file cannot be created, the database cannot
/// be serialized, or buffered data cannot be flushed to disk.
pub fn save_database(&self) -> FFTResult<()> {
    use std::io::Write;
    if !self.enabled {
        return Ok(());
    }
    let db = self.database.lock().expect("Operation failed");
    let file = File::create(&self.db_path)
        .map_err(|e| FFTError::IOError(format!("Failed to create plan database file: {e}")))?;
    let mut writer = BufWriter::new(file);
    serde_json::to_writer_pretty(&mut writer, &*db)
        .map_err(|e| FFTError::IOError(format!("Failed to serialize plan database: {e}")))?;
    // BUGFIX: BufWriter's Drop flushes but silently discards errors; flush
    // explicitly so a failed write is reported to the caller.
    writer
        .flush()
        .map_err(|e| FFTError::IOError(format!("Failed to flush plan database: {e}")))?;
    Ok(())
}
/// Turns persistence on or off at runtime. While disabled, `plan_exists`
/// reports `false`, `get_best_plan_metrics` returns `None`, and
/// `record_plan_usage`/`save_database` are no-ops.
pub fn set_enabled(&mut self, enabled: bool) {
self.enabled = enabled;
}
/// Returns the recorded plan matching `size`, `forward`, and the current
/// architecture with the lowest average execution time, or `None` if no
/// match exists or persistence is disabled.
pub fn get_best_plan_metrics(
    &self,
    size: usize,
    forward: bool,
) -> Option<(PlanInfo, PlanMetrics)> {
    if !self.enabled {
        return None;
    }
    let arch = Self::detect_arch_id();
    let guard = self.database.lock().expect("Operation failed");
    // Linear scan keeping the fastest matching candidate seen so far.
    let mut best: Option<(&PlanInfo, &PlanMetrics)> = None;
    for (info, metrics) in guard.plans.iter() {
        if info.size != size || info.forward != forward || info.arch_id != arch {
            continue;
        }
        let better = match best {
            Some((_, current)) => metrics.avg_execution_ns < current.avg_execution_ns,
            None => true,
        };
        if better {
            best = Some((info, metrics));
        }
    }
    best.map(|(info, metrics)| (info.clone(), metrics.clone()))
}
pub fn get_stats(&self) -> PlanDatabaseStats {
if let Ok(db) = self.database.lock() {
db.stats.clone()
} else {
PlanDatabaseStats::default()
}
}
}
/// Current wall-clock time as milliseconds since the Unix epoch.
/// A clock reading before the epoch degrades to 0 rather than panicking.
#[allow(dead_code)]
fn system_time_as_millis() -> u64 {
    match SystemTime::now().duration_since(SystemTime::UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
/// Creates a rustfft plan for `size` in the requested direction and reports
/// how long the planner took, in nanoseconds, so callers can record it in
/// the plan database.
#[cfg(feature = "rustfft-backend")]
#[allow(dead_code)]
pub fn create_and_time_plan(size: usize, forward: bool) -> (Arc<dyn rustfft::Fft<f64>>, u64) {
    let timer = Instant::now();
    let mut planner = FftPlanner::new();
    let plan = match forward {
        true => planner.plan_fft_forward(size),
        false => planner.plan_fft_inverse(size),
    };
    (plan, timer.elapsed().as_nanos() as u64)
}
/// Times a zero-filled c2c transform of `size` through the oxifft plan
/// cache and returns the elapsed nanoseconds (allocation + execution).
/// The transform result itself is discarded — only timing matters here.
#[cfg(all(feature = "oxifft", not(feature = "rustfft-backend")))]
#[allow(dead_code)]
pub fn create_and_time_plan_timing_only(size: usize, forward: bool) -> u64 {
    use crate::oxifft_plan_cache;
    use oxifft::{Complex as OxiComplex, Direction};
    let start = Instant::now();
    // `input` is never mutated, so it doesn't need `mut` (the original
    // declaration triggered an unused-mut warning).
    let input = vec![OxiComplex::zero(); size];
    let mut output = vec![OxiComplex::zero(); size];
    let direction = if forward {
        Direction::Forward
    } else {
        Direction::Backward
    };
    // Errors are deliberately ignored: a failed transform still yields a
    // timing sample, matching the original best-effort behavior.
    let _ = oxifft_plan_cache::execute_c2c(&input, &mut output, direction);
    start.elapsed().as_nanos() as u64
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    // Round-trip: record a plan, confirm lookup sees it, then persist to disk.
    #[test]
    fn test_plan_serialization_basic() {
        let dir = tempdir().expect("Operation failed");
        let db_file = dir.path().join("test_plan_db.json");
        let mgr = PlanSerializationManager::new(&db_file);
        let info = mgr.create_plan_info(1024, true);
        mgr.record_plan_usage(&info, 5000).expect("Operation failed");
        assert!(mgr.plan_exists(1024, true));
        mgr.save_database().expect("Operation failed");
        assert!(db_file.exists());
    }

    // The architecture id must be non-empty on every target.
    #[test]
    fn test_arch_detection() {
        assert!(!PlanSerializationManager::detect_arch_id().is_empty());
    }

    // With two candidates recorded, the best pick is one of them and is no
    // slower than the worst.
    #[test]
    fn test_get_best_plan() {
        let dir = tempdir().expect("Operation failed");
        let db_file = dir.path().join("test_best_plan.json");
        let mgr = PlanSerializationManager::new(&db_file);
        let first = mgr.create_plan_info(512, true);
        // Distinct created_at timestamps keep the two keys unequal.
        std::thread::sleep(Duration::from_millis(10));
        let second = mgr.create_plan_info(512, true);
        let slow = 8000u64;
        let fast = 5000u64;
        mgr.record_plan_usage(&first, slow).expect("Operation failed");
        mgr.record_plan_usage(&second, fast).expect("Operation failed");
        let best = mgr.get_best_plan_metrics(512, true);
        assert!(best.is_some());
        let (_, metrics) = best.expect("Operation failed");
        assert!(metrics.avg_execution_ns == slow || metrics.avg_execution_ns == fast);
        assert!(metrics.avg_execution_ns <= std::cmp::max(slow, fast));
    }
}