use crate::backend::BackendContext;
use crate::error::FFTResult;
use crate::plan_cache::get_global_cache;
use crate::worker_pool::get_global_pool;
/// Scoped FFT configuration that remembers the settings it replaced.
///
/// The `Drop` implementation for this type restores the plan-cache enabled
/// flag recorded in `previous_cache_enabled`.
pub struct FftContext {
/// Backend context created by `with_backend`, if one was requested.
backend_context: Option<BackendContext>,
/// Worker count read from the worker pool when `with_workers` was called.
previous_workers: Option<usize>,
/// Plan-cache enabled flag recorded by `with_cache`; restored on drop.
previous_cache_enabled: Option<bool>,
/// Worker pool obtained from `get_global_pool()` in `new`.
worker_pool: &'static crate::worker_pool::WorkerPool,
/// Plan cache obtained from `get_global_cache()` in `new`.
plan_cache: &'static crate::plan_cache::PlanCache,
}
impl FftContext {
pub fn new() -> Self {
Self {
backend_context: None,
previous_workers: None,
previous_cache_enabled: None,
worker_pool: get_global_pool(),
plan_cache: get_global_cache(),
}
}
}
impl Default for FftContext {
fn default() -> Self {
Self::new()
}
}
impl FftContext {
    /// Creates a [`BackendContext`] for `backendname` and stores it in this context.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `BackendContext::new` when the backend
    /// context cannot be created.
    pub fn with_backend(mut self, backendname: &str) -> FFTResult<Self> {
        // NOTE(review): calling this twice replaces the stored context; whether
        // nested backend switches unwind correctly depends on `BackendContext`
        // — confirm.
        self.backend_context = Some(BackendContext::new(backendname)?);
        Ok(self)
    }

    /// Records the worker pool's current worker count.
    ///
    /// The requested count (`_numworkers`) is currently **not** applied to the
    /// pool; only the existing count is remembered. The value is captured on
    /// the first call only, so repeated calls never overwrite the snapshot.
    pub fn with_workers(mut self, _numworkers: usize) -> Self {
        if self.previous_workers.is_none() {
            self.previous_workers = Some(self.worker_pool.get_workers());
        }
        self
    }

    /// Enables or disables the plan cache for the lifetime of this context.
    ///
    /// The cache's prior state is remembered and restored when the context is
    /// dropped. Only the state seen by the *first* call is remembered, so that
    /// `with_cache(a).with_cache(b)` still restores the original state rather
    /// than the intermediate value `a`.
    pub fn with_cache(mut self, enabled: bool) -> Self {
        if self.previous_cache_enabled.is_none() {
            self.previous_cache_enabled = Some(self.plan_cache.is_enabled());
        }
        self.plan_cache.set_enabled(enabled);
        self
    }

    /// Returns the context unchanged.
    ///
    /// Settings are applied by the `with_*` methods, so there is nothing left
    /// to do on entry.
    pub fn __enter__(self) -> Self {
        self
    }

    /// Consumes the context.
    ///
    /// The body is empty: taking `self` by value means the context is dropped
    /// when this method returns, which runs the `Drop` implementation.
    pub fn __exit__(self) {}
}
impl Drop for FftContext {
    /// Puts the plan cache back into the state recorded by `with_cache`.
    ///
    /// Does nothing when no previous cache state was recorded. The recorded
    /// worker count is not used here.
    fn drop(&mut self) {
        if let Some(was_enabled) = self.previous_cache_enabled.take() {
            self.plan_cache.set_enabled(was_enabled);
        }
    }
}
/// Collects optional FFT settings and turns them into an [`FftContext`].
///
/// Every field starts as `None`; only the settings that were explicitly
/// provided are applied by `build`.
pub struct FftContextBuilder {
/// Backend name to pass to `FftContext::with_backend`.
backend: Option<String>,
/// Worker count to pass to `FftContext::with_workers`.
workers: Option<usize>,
/// Cache on/off flag to pass to `FftContext::with_cache`.
cache_enabled: Option<bool>,
/// Requested cache size; stored by `cache_size` but not read by `build`.
cache_size: Option<usize>,
/// Requested cache TTL; stored by `cache_ttl` but not read by `build`.
cache_ttl: Option<std::time::Duration>,
}
impl FftContextBuilder {
pub fn new() -> Self {
Self {
backend: None,
workers: None,
cache_enabled: None,
cache_size: None,
cache_ttl: None,
}
}
pub fn backend(mut self, name: &str) -> Self {
self.backend = Some(name.to_string());
self
}
pub fn workers(mut self, count: usize) -> Self {
self.workers = Some(count);
self
}
pub fn cache_enabled(mut self, enabled: bool) -> Self {
self.cache_enabled = Some(enabled);
self
}
pub fn cache_size(mut self, size: usize) -> Self {
self.cache_size = Some(size);
self
}
pub fn cache_ttl(mut self, ttl: std::time::Duration) -> Self {
self.cache_ttl = Some(ttl);
self
}
pub fn build(self) -> FFTResult<FftContext> {
let mut context = FftContext::new();
if let Some(backend) = self.backend {
context = context.with_backend(&backend)?;
}
if let Some(workers) = self.workers {
context = context.with_workers(workers);
}
if let Some(enabled) = self.cache_enabled {
context = context.with_cache(enabled);
}
Ok(context)
}
}
impl Default for FftContextBuilder {
fn default() -> Self {
Self::new()
}
}
/// Starts a new, empty [`FftContextBuilder`].
///
/// Shorthand entry point for the builder chain, e.g.
/// `fft_context().workers(4).build()`.
#[allow(dead_code)]
pub fn fft_context() -> FftContextBuilder {
    let builder = FftContextBuilder::new();
    builder
}
/// Guard that owns an [`FftContext`] for as long as the guard is alive.
///
/// Dropping the guard drops the wrapped context.
pub struct FftSettingsGuard {
/// The owned context; never read, held only so it is dropped with the guard.
_context: FftContext,
}
impl FftSettingsGuard {
pub fn new(context: FftContext) -> Self {
Self { _context: context }
}
}
/// Runs `f` while the settings described by `builder` are in effect.
///
/// The context is built first and held in a guard for the duration of the
/// call; the guard is dropped after `f` returns.
///
/// # Errors
///
/// Returns the error from `builder.build()` without calling `f`.
#[allow(dead_code)]
pub fn with_fft_settings<F, R>(builder: FftContextBuilder, f: F) -> FFTResult<R>
where
    F: FnOnce() -> R,
{
    builder.build().map(|context| {
        // Keep the guard bound to a name so it outlives the call to `f`.
        let _guard = FftSettingsGuard::new(context);
        f()
    })
}
#[allow(dead_code)]
pub fn with_backend<F, R>(backend: &str, f: F) -> FFTResult<R>
where
F: FnOnce() -> R,
{
with_fft_settings(fft_context().backend(backend), f)
}
#[allow(dead_code)]
pub fn with_workers<F, R>(workers: usize, f: F) -> FFTResult<R>
where
F: FnOnce() -> R,
{
with_fft_settings(fft_context().workers(workers), f)
}
#[allow(dead_code)]
pub fn without_cache<F, R>(f: F) -> FFTResult<R>
where
F: FnOnce() -> R,
{
with_fft_settings(fft_context().cache_enabled(false), f)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A builder with backend, workers and cache set is expected to build.
    #[test]
    fn test_context_builder() {
        let context = fft_context()
            .backend("rustfft")
            .workers(4)
            .cache_enabled(true)
            .build()
            .expect("Operation failed");
        drop(context);
    }

    /// The closure's value is passed through `with_backend`.
    #[test]
    fn test_with_backend() {
        let outcome = with_backend("rustfft", || 42);
        assert_eq!(outcome.expect("Operation failed"), 42);
    }

    /// The closure's value is passed through `with_workers`.
    #[test]
    fn test_with_workers() {
        let outcome = with_workers(2, || 84);
        assert_eq!(outcome.expect("Operation failed"), 84);
    }

    /// The closure's value is passed through `without_cache`.
    #[test]
    fn test_without_cache() {
        let outcome = without_cache(|| 168);
        assert_eq!(outcome.expect("Operation failed"), 168);
    }

    /// All three settings can be combined in a single builder.
    #[test]
    fn test_combined_settings() {
        let settings = fft_context()
            .backend("rustfft")
            .workers(4)
            .cache_enabled(false);
        let outcome = with_fft_settings(settings, || 336);
        assert_eq!(outcome.expect("Operation failed"), 336);
    }
}