use std::any::Any;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::{Arc, Mutex, Weak};
use once_cell::sync::Lazy;
use crate::algebra::scalar::KrystScalar;
use crate::matrix::backend::SparseBackend;
use crate::matrix::op::{StructureId, ValuesId};
/// Storage format requested for, or provided by, an operator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OpFormat {
    /// No particular format; `BackendFormatSupport::supports` accepts it unconditionally.
    Any,
    /// Dense storage.
    Dense,
    /// Compressed sparse row storage.
    Csr,
    /// Compressed sparse column storage.
    Csc,
    /// Block compressed sparse row storage.
    BlockCsr,
}
/// Capability flags stating which concrete storage formats a sparse backend provides.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BackendFormatSupport {
    /// Dense storage is available.
    pub dense: bool,
    /// CSR storage is available.
    pub csr: bool,
    /// CSC storage is available.
    pub csc: bool,
    /// Block-CSR storage is available.
    pub block_csr: bool,
}
impl BackendFormatSupport {
    /// Creates a support table from one flag per concrete format, in the order
    /// dense, CSR, CSC, block-CSR.
    pub const fn new(dense: bool, csr: bool, csc: bool, block_csr: bool) -> Self {
        Self { dense, csr, csc, block_csr }
    }

    /// Reports whether `format` is available according to these flags.
    ///
    /// [`OpFormat::Any`] names no concrete format, so it is always reported as supported.
    pub const fn supports(self, format: OpFormat) -> bool {
        let Self { dense, csr, csc, block_csr } = self;
        match format {
            OpFormat::Any => true,
            OpFormat::Dense => dense,
            OpFormat::Csr => csr,
            OpFormat::Csc => csc,
            OpFormat::BlockCsr => block_csr,
        }
    }
}
impl OpFormat {
    /// Returns `true` for [`OpFormat::Any`], i.e. when no concrete format is named.
    ///
    /// `const` so it can be evaluated in constant contexts, like the sibling
    /// `BackendFormatSupport::supports`.
    #[inline]
    pub const fn is_any(self) -> bool {
        matches!(self, OpFormat::Any)
    }
}
/// A request for one specific concrete format.
///
/// Every hint converts into an [`OpFormat`] via `From`; `OpFormat::Any` and
/// `OpFormat::BlockCsr` have no hint counterpart.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FormatHint {
    /// Prefer CSR storage.
    Csr,
    /// Prefer dense storage.
    Dense,
    /// Prefer CSC storage.
    Csc,
}
/// Widens a concrete [`FormatHint`] into the equally named [`OpFormat`] variant.
impl From<FormatHint> for OpFormat {
    fn from(hint: FormatHint) -> Self {
        match hint {
            FormatHint::Dense => Self::Dense,
            FormatHint::Csc => Self::Csc,
            FormatHint::Csr => Self::Csr,
        }
    }
}
impl OpFormat {
    /// Narrows this format to the matching [`FormatHint`], if there is one.
    ///
    /// Returns `None` for [`OpFormat::Any`] and [`OpFormat::BlockCsr`], which have no hint
    /// counterpart; for every other variant this inverts `From<FormatHint> for OpFormat`.
    /// Declared `const` so it can be used in constant contexts, matching the crate's other
    /// small format helpers such as `BackendFormatSupport::supports`.
    pub const fn to_format_hint(self) -> Option<FormatHint> {
        match self {
            OpFormat::Csr => Some(FormatHint::Csr),
            OpFormat::Dense => Some(FormatHint::Dense),
            OpFormat::Csc => Some(FormatHint::Csc),
            OpFormat::Any | OpFormat::BlockCsr => None,
        }
    }
}
/// Access to an operator's data in the sparse storage formats of backend `B`.
///
/// The `as_*` methods borrow an already-stored representation (the defaults return
/// `None`); the `to_*_cached` methods return a shared [`Arc`] to a converted copy.
pub trait AsFormat<S: KrystScalar, B: SparseBackend<S>> {
    /// Borrows the operator as `B::Csr` without converting, if it is stored that way.
    ///
    /// The default implementation returns `None`.
    fn as_csr(&self) -> Option<&B::Csr> {
        None
    }

    /// Returns a shared CSR representation of the operator.
    ///
    /// NOTE(review): `drop_tol` is presumably a magnitude threshold below which entries
    /// are dropped during conversion — confirm against implementors.
    fn to_csr_cached(&self, drop_tol: S::Real) -> Arc<B::Csr>;

    /// Borrows the operator as `B::Csc` without converting, if it is stored that way.
    ///
    /// The default implementation returns `None`.
    fn as_csc(&self) -> Option<&B::Csc> {
        None
    }

    /// Returns a shared CSC representation of the operator.
    ///
    /// NOTE(review): see `to_csr_cached` regarding the meaning of `drop_tol`.
    fn to_csc_cached(&self, drop_tol: S::Real) -> Arc<B::Csc>;

    /// Identifier of the operator's sparsity structure, for keying conversion caches.
    ///
    /// The default is the constant `StructureId(0)`. NOTE(review): implementors whose
    /// structure can change in place presumably need to override this so stale cached
    /// conversions are not reused — confirm.
    fn structure_id_for_cache(&self) -> StructureId {
        StructureId(0)
    }

    /// Identifier of the operator's numeric values, for keying conversion caches.
    ///
    /// The default is the constant `ValuesId(0)`. NOTE(review): implementors whose values
    /// can change in place presumably need to override this — confirm.
    fn values_id_for_cache(&self) -> ValuesId {
        ValuesId(0)
    }

    /// Storage formats natively provided by backend `B` (`B::FORMAT_SUPPORT`).
    ///
    /// This function has no receiver, so it is bounded by `Self: Sized` to keep it out
    /// of the vtable. Without that bound the trait is not dyn-compatible and
    /// `dyn AsFormat<S, B>` could not be used as a trait object.
    fn backend_format_support() -> BackendFormatSupport
    where
        Self: Sized,
    {
        B::FORMAT_SUPPORT
    }
}
/// Trait-object form of [`AsFormat`] over `crate::matrix::backend::DefaultBackend`.
///
/// Only defined when the `backend-faer` feature is enabled.
#[cfg(feature = "backend-faer")]
pub type DefaultAsFormat<S> = dyn self::AsFormat<S, crate::matrix::backend::DefaultBackend>;
/// Key identifying one cached format conversion.
///
/// Keys compare and hash field-by-field on their raw integer values. The drop tolerance
/// is held as `f64` bits so that `Eq` and `Hash` can be implemented.
#[derive(Clone, Copy, Debug)]
pub(crate) struct FormatKey {
    /// Address identifying the source object (NOTE(review): presumably a data pointer
    /// of the operator being converted — confirm with callers).
    pub base_ptr: usize,
    /// Raw `StructureId` value.
    pub structure_id: u64,
    /// Raw `ValuesId` value.
    pub values_id: u64,
    /// Bit pattern of the `f64` drop tolerance.
    pub drop_tol_bits: u64,
}

impl PartialEq for FormatKey {
    fn eq(&self, other: &Self) -> bool {
        let lhs = (self.base_ptr, self.structure_id, self.values_id, self.drop_tol_bits);
        let rhs = (other.base_ptr, other.structure_id, other.values_id, other.drop_tol_bits);
        lhs == rhs
    }
}

/// Equality is on plain integers, so it is a total equivalence relation.
impl Eq for FormatKey {}

impl Hash for FormatKey {
    /// Feeds the fields to `state` in declaration order, consistent with `eq`.
    fn hash<H: Hasher>(&self, state: &mut H) {
        (self.base_ptr, self.structure_id, self.values_id, self.drop_tol_bits).hash(state);
    }
}
/// Builds the cache key for one conversion of the object at `base_ptr`.
///
/// The tolerance is stored as raw bits so the key can be `Eq`/`Hash`. Before that,
/// `-0.0` is canonicalized to `+0.0`. The two compare equal, so any comparison against
/// the tolerance treats them identically, but their bit patterns differ. Without this,
/// one logical conversion would be split across two cache entries.
#[inline]
pub(crate) fn format_key_from_ptr(
    base_ptr: usize,
    structure_id: StructureId,
    values_id: ValuesId,
    drop_tol: f64,
) -> FormatKey {
    // `-0.0 == 0.0` holds, so both zeros map to `+0.0`; every other value, NaN included,
    // passes through unchanged.
    let drop_tol = if drop_tol == 0.0 { 0.0 } else { drop_tol };
    FormatKey {
        base_ptr,
        structure_id: structure_id.0,
        values_id: values_id.0,
        drop_tol_bits: drop_tol.to_bits(),
    }
}
/// Type-erased, weakly held conversion results keyed by [`FormatKey`].
///
/// Entries are [`Weak`], so the map never keeps a converted value alive by itself.
type CacheMap = HashMap<FormatKey, Weak<dyn Any + Send + Sync>>;

/// Process-wide cache of CSR conversions, created on first access.
pub(crate) static CSR_CACHE: Lazy<Mutex<CacheMap>> = Lazy::new(|| Mutex::new(CacheMap::new()));

/// Process-wide cache of CSC conversions, created on first access.
pub(crate) static CSC_CACHE: Lazy<Mutex<CacheMap>> = Lazy::new(|| Mutex::new(CacheMap::new()));
/// Returns the shared CSR conversion cached under `key`, calling `build` on a miss.
///
/// The cache stores only [`Weak`] handles, so a cached value lives exactly as long as
/// some caller holds an [`Arc`] to it. A dead handle, or a value whose concrete type is
/// not `T`, counts as a miss.
///
/// `build` runs without the cache lock held, so a slow conversion does not block other
/// lookups. If another thread stored a live `T` under the same key in the meantime, that
/// value is returned and the freshly built one is discarded, so callers share a single
/// instance. Before inserting, entries whose values have already been dropped are pruned.
/// Otherwise the map would grow with every distinct key, each dead `Weak` still pinning
/// its `Arc` allocation.
///
/// A poisoned lock is recovered instead of causing a panic: a map of `Weak` handles stays
/// usable even if a previous lock holder panicked.
#[inline]
pub(crate) fn get_or_insert_csr<T: 'static + Send + Sync>(
    key: FormatKey,
    build: impl FnOnce() -> Arc<T>,
) -> Arc<T> {
    // Fast path: a live entry of the requested type already exists.
    {
        let cache = CSR_CACHE.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
        if let Some(existing) = cache.get(&key).and_then(|w| w.upgrade()) {
            if let Ok(typed) = existing.downcast::<T>() {
                return typed;
            }
        }
    }

    let built: Arc<T> = build();
    // Declared after `built`, so the guard is released before `built` is dropped on the
    // early-return path below.
    let mut cache = CSR_CACHE.lock().unwrap_or_else(|poisoned| poisoned.into_inner());

    // A concurrent caller may have finished building the same key while `build` ran.
    if let Some(existing) = cache.get(&key).and_then(|w| w.upgrade()) {
        if let Ok(typed) = existing.downcast::<T>() {
            return typed;
        }
    }

    cache.retain(|_, weak| weak.strong_count() > 0);
    let erased: Arc<dyn Any + Send + Sync> = built.clone();
    cache.insert(key, Arc::downgrade(&erased));
    built
}
/// Returns the shared CSC conversion cached under `key`, calling `build` on a miss.
///
/// The cache stores only [`Weak`] handles, so a cached value lives exactly as long as
/// some caller holds an [`Arc`] to it. A dead handle, or a value whose concrete type is
/// not `T`, counts as a miss.
///
/// `build` runs without the cache lock held, so a slow conversion does not block other
/// lookups. If another thread stored a live `T` under the same key in the meantime, that
/// value is returned and the freshly built one is discarded, so callers share a single
/// instance. Before inserting, entries whose values have already been dropped are pruned.
/// Otherwise the map would grow with every distinct key, each dead `Weak` still pinning
/// its `Arc` allocation.
///
/// A poisoned lock is recovered instead of causing a panic: a map of `Weak` handles stays
/// usable even if a previous lock holder panicked.
#[inline]
pub(crate) fn get_or_insert_csc<T: 'static + Send + Sync>(
    key: FormatKey,
    build: impl FnOnce() -> Arc<T>,
) -> Arc<T> {
    // Fast path: a live entry of the requested type already exists.
    {
        let cache = CSC_CACHE.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
        if let Some(existing) = cache.get(&key).and_then(|w| w.upgrade()) {
            if let Ok(typed) = existing.downcast::<T>() {
                return typed;
            }
        }
    }

    let built: Arc<T> = build();
    // Declared after `built`, so the guard is released before `built` is dropped on the
    // early-return path below.
    let mut cache = CSC_CACHE.lock().unwrap_or_else(|poisoned| poisoned.into_inner());

    // A concurrent caller may have finished building the same key while `build` ran.
    if let Some(existing) = cache.get(&key).and_then(|w| w.upgrade()) {
        if let Ok(typed) = existing.downcast::<T>() {
            return typed;
        }
    }

    cache.retain(|_, weak| weak.strong_count() > 0);
    let erased: Arc<dyn Any + Send + Sync> = built.clone();
    cache.insert(key, Arc::downgrade(&erased));
    built
}