use std::str::FromStr;
use std::sync::{Arc, RwLock};
/// Logging verbosity levels, ordered from most verbose (`Trace`) to fully
/// silenced (`Off`); the derived `Ord` follows declaration order, so
/// `LogLevel::Trace < LogLevel::Off`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum LogLevel {
    Trace,
    Debug,
    Info,
    Warn,
    Error,
    Off,
}

impl LogLevel {
    /// Returns the canonical lowercase name for this level.
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Trace => "trace",
            Self::Debug => "debug",
            Self::Info => "info",
            Self::Warn => "warn",
            Self::Error => "error",
            Self::Off => "off",
        }
    }
}

impl std::fmt::Display for LogLevel {
    /// Formats the level as its canonical lowercase name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}

impl FromStr for LogLevel {
    type Err = String;

    /// Parses a level name case-insensitively; accepts the aliases
    /// `"warning"` (for `Warn`) and `"none"` (for `Off`).
    ///
    /// # Errors
    /// Returns a descriptive message for unrecognized names.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "trace" => Ok(Self::Trace),
            "debug" => Ok(Self::Debug),
            "info" => Ok(Self::Info),
            "warn" | "warning" => Ok(Self::Warn),
            "error" => Ok(Self::Error),
            "off" | "none" => Ok(Self::Off),
            _ => Err(format!("Invalid log level: {s}")),
        }
    }
}
/// Simulation backend to use when the caller does not request one explicitly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DefaultBackend {
    Cpu,
    Gpu,
    TensorNetwork,
    Stabilizer,
    Auto,
}

impl DefaultBackend {
    /// Returns the canonical lowercase (snake_case) name for this backend.
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Cpu => "cpu",
            Self::Gpu => "gpu",
            Self::TensorNetwork => "tensor_network",
            Self::Stabilizer => "stabilizer",
            Self::Auto => "auto",
        }
    }
}

impl std::fmt::Display for DefaultBackend {
    /// Formats the backend as its canonical snake_case name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}

impl FromStr for DefaultBackend {
    type Err = String;

    /// Parses a backend name case-insensitively; accepts the aliases
    /// `"tensor-network"` and `"tn"` for `TensorNetwork`.
    ///
    /// # Errors
    /// Returns a descriptive message for unrecognized names.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "cpu" => Ok(Self::Cpu),
            "gpu" => Ok(Self::Gpu),
            "tensor_network" | "tensor-network" | "tn" => Ok(Self::TensorNetwork),
            "stabilizer" => Ok(Self::Stabilizer),
            "auto" => Ok(Self::Auto),
            _ => Err(format!("Invalid backend: {s}")),
        }
    }
}
/// Plain-data snapshot of every runtime configuration option.
///
/// A `None` in an `Option` field means "no explicit value configured";
/// see the `Default` impl for the concrete defaults.
#[derive(Debug, Clone)]
pub struct ConfigData {
/// Worker-thread count override; `None` lets the runtime decide.
pub num_threads: Option<usize>,
/// Minimum severity that should be logged.
pub log_level: LogLevel,
/// Memory cap in bytes; `None` means no limit configured.
pub memory_limit_bytes: Option<usize>,
/// Backend used when a caller does not request one explicitly.
pub default_backend: DefaultBackend,
/// Whether GPU acceleration may be used.
pub enable_gpu: bool,
/// Whether SIMD acceleration may be used.
pub enable_simd: bool,
/// Whether telemetry collection is enabled.
pub enable_telemetry: bool,
/// Directory for on-disk caches; `None` means no cache dir configured.
pub cache_dir: Option<String>,
/// Maximum cache size in bytes; `None` means no cap configured.
pub max_cache_size_bytes: Option<usize>,
}
impl Default for ConfigData {
    /// Conservative defaults: automatic backend selection, GPU and SIMD
    /// enabled, telemetry off, no thread/memory overrides, no cache dir,
    /// and a 1 GiB cache-size cap.
    fn default() -> Self {
        // 1 GiB, expressed as a power of two for clarity.
        const ONE_GIB: usize = 1 << 30;
        Self {
            num_threads: None,
            log_level: LogLevel::Warn,
            memory_limit_bytes: None,
            default_backend: DefaultBackend::Auto,
            enable_gpu: true,
            enable_simd: true,
            enable_telemetry: false,
            cache_dir: None,
            max_cache_size_bytes: Some(ONE_GIB),
        }
    }
}
/// Thread-safe handle to runtime configuration; the process-wide instance
/// is obtained via `Config::global()`.
pub struct Config {
// Shared, lock-protected configuration state; `Arc` allows cheap cloning
// of the handle's interior if ever needed, `RwLock` allows many readers.
data: Arc<RwLock<ConfigData>>,
}
impl Config {
fn new() -> Self {
Self {
data: Arc::new(RwLock::new(ConfigData::default())),
}
}
pub fn global() -> &'static Self {
use std::sync::OnceLock;
static CONFIG: OnceLock<Config> = OnceLock::new();
CONFIG.get_or_init(|| {
let config = Self::new();
config.load_from_env();
config
})
}
pub fn builder() -> ConfigBuilder {
ConfigBuilder {
data: ConfigData::default(),
}
}
pub fn num_threads(&self) -> Option<usize> {
self.data
.read()
.unwrap_or_else(|e| e.into_inner())
.num_threads
}
pub fn set_num_threads(&self, num_threads: usize) {
self.data
.write()
.unwrap_or_else(|e| e.into_inner())
.num_threads = Some(num_threads);
}
pub fn log_level(&self) -> LogLevel {
self.data
.read()
.unwrap_or_else(|e| e.into_inner())
.log_level
}
pub fn set_log_level(&self, level: LogLevel) {
self.data
.write()
.unwrap_or_else(|e| e.into_inner())
.log_level = level;
}
pub fn memory_limit_bytes(&self) -> Option<usize> {
self.data
.read()
.unwrap_or_else(|e| e.into_inner())
.memory_limit_bytes
}
pub fn set_memory_limit_bytes(&self, limit: usize) {
self.data
.write()
.unwrap_or_else(|e| e.into_inner())
.memory_limit_bytes = Some(limit);
}
pub fn set_memory_limit_gb(&self, limit_gb: usize) {
self.set_memory_limit_bytes(limit_gb * 1024 * 1024 * 1024);
}
pub fn default_backend(&self) -> DefaultBackend {
self.data
.read()
.unwrap_or_else(|e| e.into_inner())
.default_backend
}
pub fn set_default_backend(&self, backend: DefaultBackend) {
self.data
.write()
.unwrap_or_else(|e| e.into_inner())
.default_backend = backend;
}
pub fn is_gpu_enabled(&self) -> bool {
self.data
.read()
.unwrap_or_else(|e| e.into_inner())
.enable_gpu
}
pub fn set_gpu_enabled(&self, enabled: bool) {
self.data
.write()
.unwrap_or_else(|e| e.into_inner())
.enable_gpu = enabled;
}
pub fn is_simd_enabled(&self) -> bool {
self.data
.read()
.unwrap_or_else(|e| e.into_inner())
.enable_simd
}
pub fn set_simd_enabled(&self, enabled: bool) {
self.data
.write()
.unwrap_or_else(|e| e.into_inner())
.enable_simd = enabled;
}
pub fn is_telemetry_enabled(&self) -> bool {
self.data
.read()
.unwrap_or_else(|e| e.into_inner())
.enable_telemetry
}
pub fn set_telemetry_enabled(&self, enabled: bool) {
self.data
.write()
.unwrap_or_else(|e| e.into_inner())
.enable_telemetry = enabled;
}
pub fn cache_dir(&self) -> Option<String> {
self.data
.read()
.unwrap_or_else(|e| e.into_inner())
.cache_dir
.clone()
}
pub fn set_cache_dir(&self, dir: impl Into<String>) {
self.data
.write()
.unwrap_or_else(|e| e.into_inner())
.cache_dir = Some(dir.into());
}
pub fn max_cache_size_bytes(&self) -> Option<usize> {
self.data
.read()
.unwrap_or_else(|e| e.into_inner())
.max_cache_size_bytes
}
pub fn set_max_cache_size_bytes(&self, size: usize) {
self.data
.write()
.unwrap_or_else(|e| e.into_inner())
.max_cache_size_bytes = Some(size);
}
pub fn set_max_cache_size_mb(&self, size_mb: usize) {
self.set_max_cache_size_bytes(size_mb * 1024 * 1024);
}
fn load_from_env(&self) {
if let Ok(val) = std::env::var("QUANTRS2_NUM_THREADS") {
if let Ok(num) = val.parse::<usize>() {
self.set_num_threads(num);
}
}
if let Ok(val) = std::env::var("QUANTRS2_LOG_LEVEL") {
if let Ok(level) = val.parse::<LogLevel>() {
self.set_log_level(level);
}
}
if let Ok(val) = std::env::var("QUANTRS2_MEMORY_LIMIT_GB") {
if let Ok(limit) = val.parse::<usize>() {
self.set_memory_limit_gb(limit);
}
}
if let Ok(val) = std::env::var("QUANTRS2_BACKEND") {
if let Ok(backend) = val.parse::<DefaultBackend>() {
self.set_default_backend(backend);
}
}
if let Ok(val) = std::env::var("QUANTRS2_ENABLE_GPU") {
if let Ok(enabled) = val.parse::<bool>() {
self.set_gpu_enabled(enabled);
}
}
if let Ok(val) = std::env::var("QUANTRS2_ENABLE_SIMD") {
if let Ok(enabled) = val.parse::<bool>() {
self.set_simd_enabled(enabled);
}
}
if let Ok(val) = std::env::var("QUANTRS2_ENABLE_TELEMETRY") {
if let Ok(enabled) = val.parse::<bool>() {
self.set_telemetry_enabled(enabled);
}
}
if let Ok(val) = std::env::var("QUANTRS2_CACHE_DIR") {
self.set_cache_dir(val);
}
if let Ok(val) = std::env::var("QUANTRS2_MAX_CACHE_SIZE_MB") {
if let Ok(size) = val.parse::<usize>() {
self.set_max_cache_size_mb(size);
}
}
}
pub fn reset(&self) {
*self.data.write().unwrap_or_else(|e| e.into_inner()) = ConfigData::default();
}
pub fn snapshot(&self) -> ConfigData {
self.data.read().unwrap_or_else(|e| e.into_inner()).clone()
}
}
/// Fluent builder for configuration data, obtained via `Config::builder()`;
/// finish with `build()` (returns the data) or `apply()` (installs globally).
pub struct ConfigBuilder {
// Accumulated settings, starting from `ConfigData::default()`.
data: ConfigData,
}
impl ConfigBuilder {
    /// Sets the worker-thread count override.
    #[must_use]
    pub const fn num_threads(mut self, num_threads: usize) -> Self {
        self.data.num_threads = Some(num_threads);
        self
    }

    /// Sets the logging verbosity.
    #[must_use]
    pub const fn log_level(mut self, level: LogLevel) -> Self {
        self.data.log_level = level;
        self
    }

    /// Sets the memory limit in bytes.
    #[must_use]
    pub const fn memory_limit_bytes(mut self, limit: usize) -> Self {
        self.data.memory_limit_bytes = Some(limit);
        self
    }

    /// Sets the memory limit in GiB. Saturates at `usize::MAX` instead of
    /// overflowing, which matters for large values on 32-bit targets.
    #[must_use]
    pub const fn memory_limit_gb(self, limit_gb: usize) -> Self {
        self.memory_limit_bytes(limit_gb.saturating_mul(1024 * 1024 * 1024))
    }

    /// Sets the default simulation backend.
    #[must_use]
    pub const fn default_backend(mut self, backend: DefaultBackend) -> Self {
        self.data.default_backend = backend;
        self
    }

    /// Enables or disables GPU acceleration.
    #[must_use]
    pub const fn enable_gpu(mut self, enabled: bool) -> Self {
        self.data.enable_gpu = enabled;
        self
    }

    /// Enables or disables SIMD acceleration.
    #[must_use]
    pub const fn enable_simd(mut self, enabled: bool) -> Self {
        self.data.enable_simd = enabled;
        self
    }

    /// Enables or disables telemetry collection.
    #[must_use]
    pub const fn enable_telemetry(mut self, enabled: bool) -> Self {
        self.data.enable_telemetry = enabled;
        self
    }

    /// Sets the cache directory.
    #[must_use]
    pub fn cache_dir(mut self, dir: impl Into<String>) -> Self {
        self.data.cache_dir = Some(dir.into());
        self
    }

    /// Sets the maximum cache size in bytes.
    #[must_use]
    pub const fn max_cache_size_bytes(mut self, size: usize) -> Self {
        self.data.max_cache_size_bytes = Some(size);
        self
    }

    /// Sets the maximum cache size in MiB, saturating instead of overflowing.
    #[must_use]
    pub const fn max_cache_size_mb(self, size_mb: usize) -> Self {
        self.max_cache_size_bytes(size_mb.saturating_mul(1024 * 1024))
    }

    /// Replaces the global configuration wholesale with the built settings.
    pub fn apply(self) {
        let config = Config::global();
        *config.data.write().unwrap_or_else(|e| e.into_inner()) = self.data;
    }

    /// Returns the built configuration data without touching global state.
    pub fn build(self) -> ConfigData {
        self.data
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parsing covers canonical names, mixed case, and rejection.
    #[test]
    fn test_log_level_parsing() {
        assert_eq!("trace".parse::<LogLevel>(), Ok(LogLevel::Trace));
        assert_eq!("DEBUG".parse::<LogLevel>(), Ok(LogLevel::Debug));
        assert_eq!("Info".parse::<LogLevel>(), Ok(LogLevel::Info));
        assert_eq!("warn".parse::<LogLevel>(), Ok(LogLevel::Warn));
        assert_eq!("error".parse::<LogLevel>(), Ok(LogLevel::Error));
        assert_eq!("off".parse::<LogLevel>(), Ok(LogLevel::Off));
        assert!("invalid".parse::<LogLevel>().is_err());
    }

    /// Parsing covers canonical names, aliases, mixed case, and rejection.
    #[test]
    fn test_backend_parsing() {
        assert_eq!("cpu".parse::<DefaultBackend>(), Ok(DefaultBackend::Cpu));
        assert_eq!("GPU".parse::<DefaultBackend>(), Ok(DefaultBackend::Gpu));
        assert_eq!(
            "tensor-network".parse::<DefaultBackend>(),
            Ok(DefaultBackend::TensorNetwork)
        );
        assert_eq!(
            "stabilizer".parse::<DefaultBackend>(),
            Ok(DefaultBackend::Stabilizer)
        );
        assert_eq!("auto".parse::<DefaultBackend>(), Ok(DefaultBackend::Auto));
        assert!("invalid".parse::<DefaultBackend>().is_err());
    }

    /// The builder must record every option and convert GiB/MiB to bytes.
    #[test]
    fn test_config_builder() {
        let config_data = Config::builder()
            .num_threads(8)
            .log_level(LogLevel::Debug)
            .memory_limit_gb(16)
            .default_backend(DefaultBackend::Gpu)
            .enable_gpu(true)
            .enable_simd(true)
            .enable_telemetry(false)
            .cache_dir("/tmp/quantrs2")
            .max_cache_size_mb(512)
            .build();
        assert_eq!(config_data.num_threads, Some(8));
        assert_eq!(config_data.log_level, LogLevel::Debug);
        assert_eq!(
            config_data.memory_limit_bytes,
            Some(16 * 1024 * 1024 * 1024)
        );
        assert_eq!(config_data.default_backend, DefaultBackend::Gpu);
        assert!(config_data.enable_gpu);
        assert!(config_data.enable_simd);
        assert!(!config_data.enable_telemetry);
        assert_eq!(config_data.cache_dir, Some("/tmp/quantrs2".to_string()));
        assert_eq!(config_data.max_cache_size_bytes, Some(512 * 1024 * 1024));
    }

    /// Round-trips every setter/getter pair on the global instance.
    /// This is the only test that mutates `Config::global()`, so it cannot
    /// race with the other tests even under the parallel test runner.
    #[test]
    fn test_global_config() {
        let config = Config::global();
        config.set_num_threads(4);
        assert_eq!(config.num_threads(), Some(4));
        config.set_log_level(LogLevel::Info);
        assert_eq!(config.log_level(), LogLevel::Info);
        config.set_memory_limit_gb(8);
        assert_eq!(config.memory_limit_bytes(), Some(8 * 1024 * 1024 * 1024));
        config.set_default_backend(DefaultBackend::TensorNetwork);
        assert_eq!(config.default_backend(), DefaultBackend::TensorNetwork);
        config.set_gpu_enabled(false);
        assert!(!config.is_gpu_enabled());
        config.set_simd_enabled(false);
        assert!(!config.is_simd_enabled());
        config.set_telemetry_enabled(true);
        assert!(config.is_telemetry_enabled());
        config.set_cache_dir("/test/cache");
        assert_eq!(config.cache_dir(), Some("/test/cache".to_string()));
        config.set_max_cache_size_mb(256);
        assert_eq!(config.max_cache_size_bytes(), Some(256 * 1024 * 1024));
    }

    /// `Config::snapshot` must reflect the state at the time of the call
    /// and be decoupled from later mutations. Uses a private instance via
    /// `Config::new()` so it cannot race with `test_global_config`.
    #[test]
    fn test_config_snapshot() {
        let config = Config::new();
        config.set_num_threads(6);
        config.set_log_level(LogLevel::Warn);
        config.set_memory_limit_bytes(1024);
        config.set_default_backend(DefaultBackend::Cpu);
        config.set_cache_dir("/test");
        config.set_max_cache_size_bytes(512);
        let snapshot = config.snapshot();
        assert_eq!(snapshot.num_threads, Some(6));
        assert_eq!(snapshot.log_level, LogLevel::Warn);
        assert_eq!(snapshot.memory_limit_bytes, Some(1024));
        assert_eq!(snapshot.default_backend, DefaultBackend::Cpu);
        // Unset fields keep their defaults.
        assert!(snapshot.enable_gpu);
        assert!(snapshot.enable_simd);
        assert!(!snapshot.enable_telemetry);
        assert_eq!(snapshot.cache_dir, Some("/test".to_string()));
        assert_eq!(snapshot.max_cache_size_bytes, Some(512));
        // The snapshot is an owned copy: later mutations must not leak in.
        config.set_num_threads(12);
        assert_eq!(snapshot.num_threads, Some(6));
    }
}