use std::{default::Default, fmt, fmt::Write, iter::IntoIterator, marker::PhantomData};
use foldhash::fast::RandomState;
use indexmap::{IndexMap, map::Entry};
use pad_adapter::PadAdapter;
use super::{
Key, Symbol, TypedSymbol,
symbol::{DefaultSymbolHandler, KeyFormatter},
};
use crate::{
linear::LinearValues,
variables::{VariableDtype, VariableSafe},
};
/// Heterogeneous container mapping [`Key`]s to type-erased variables.
///
/// Each variable is stored as a `Box<dyn VariableSafe>` inside an [`IndexMap`],
/// so iteration follows the map's insertion order. Typed access is provided by
/// the `get`/`insert`/`remove` family of methods, which downcast to the concrete
/// variable type.
#[derive(Default, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Values {
    // Keyed, order-preserving storage hashed with foldhash's `RandomState`.
    values: IndexMap<Key, Box<dyn VariableSafe>, RandomState>,
}
impl Values {
pub fn new() -> Self {
Values::default()
}
pub fn len(&self) -> usize {
self.values.len()
}
pub fn is_empty(&self) -> bool {
self.values.is_empty()
}
pub fn entry(&mut self, key: impl Symbol) -> Entry<'_, Key, Box<dyn VariableSafe>> {
self.values.entry(key.into())
}
pub fn insert<S, V>(&mut self, symbol: S, value: V) -> Option<Box<dyn VariableSafe>>
where
S: TypedSymbol<V>,
V: VariableDtype,
{
self.values.insert(symbol.into(), Box::new(value))
}
pub fn insert_unchecked<S, V>(&mut self, symbol: S, value: V) -> Option<Box<dyn VariableSafe>>
where
S: Symbol,
V: VariableDtype,
{
self.values.insert(symbol.into(), Box::new(value))
}
pub(crate) fn get_raw<S>(&self, symbol: S) -> Option<&dyn VariableSafe>
where
S: Symbol,
{
self.values.get(&symbol.into()).map(|f| f.as_ref())
}
pub fn get<S, V>(&self, symbol: S) -> Option<&V>
where
S: TypedSymbol<V>,
V: VariableDtype,
{
self.values
.get(&symbol.into())
.and_then(|value| value.downcast_ref::<V>())
}
pub fn get_unchecked<S, V>(&self, symbol: S) -> Option<&V>
where
S: Symbol,
V: VariableDtype,
{
self.values
.get(&symbol.into())
.and_then(|value| value.downcast_ref::<V>())
}
pub fn get_mut<S, V>(&mut self, symbol: S) -> Option<&mut V>
where
S: TypedSymbol<V>,
V: VariableDtype,
{
self.values
.get_mut(&symbol.into())
.and_then(|value| value.downcast_mut::<V>())
}
pub fn get_unchecked_mut<S, V>(&mut self, symbol: S) -> Option<&mut V>
where
S: Symbol,
V: VariableDtype,
{
self.values
.get_mut(&symbol.into())
.and_then(|value| value.downcast_mut::<V>())
}
pub fn remove<S, V>(&mut self, symbol: S) -> Option<V>
where
S: TypedSymbol<V>,
V: VariableDtype,
{
self.values
.shift_remove(&symbol.into())
.and_then(|value| value.downcast::<V>().ok())
.map(|value| *value)
}
pub fn iter(&self) -> impl Iterator<Item = (&Key, &Box<dyn VariableSafe>)> {
self.values.iter()
}
pub fn filter<'a, T: 'a + VariableSafe>(&'a self) -> impl Iterator<Item = &'a T> {
self.values
.iter()
.filter_map(|(_, value)| value.downcast_ref::<T>())
}
pub fn oplus_mut(&mut self, delta: &LinearValues) {
for (key, value) in delta.iter() {
if let Some(v) = self.values.get_mut(key) {
assert!(v.dim() == value.len(), "Dimension mismatch in values oplus",);
v.oplus_mut(value);
}
}
}
}
impl fmt::Debug for Values {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Debug::fmt(&ValuesFormatter::<DefaultSymbolHandler>::new(self), f)
}
}
impl fmt::Display for Values {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Display::fmt(&ValuesFormatter::<DefaultSymbolHandler>::new(self), f)
}
}
/// Formatting wrapper around [`Values`] that prints each key with the key
/// formatter `KF`.
///
/// Implements both `Display` and `Debug`. The formatter's precision (default 3)
/// is applied to every value, and the alternate flag (`{:#}` / `{:#?}`) switches
/// to a multi-line, indented layout.
pub struct ValuesFormatter<'v, KF> {
    // Borrowed container being printed.
    values: &'v Values,
    // Zero-sized marker: `KF` is only used through its associated `fmt` function.
    kf: PhantomData<KF>,
}
impl<'v, KF> ValuesFormatter<'v, KF> {
pub fn new(values: &'v Values) -> Self {
Self {
values,
kf: Default::default(),
}
}
}
impl<KF: KeyFormatter> fmt::Display for ValuesFormatter<'_, KF> {
    /// Prints `Values { key: value, ... }` on one line, or one entry per
    /// indented line when the alternate flag is set. Values use the caller's
    /// precision, or 3 digits if none was given.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let digits = f.precision().unwrap_or(3);

        if !f.alternate() {
            f.write_str("Values { ")?;
            for (key, var) in self.values.iter() {
                KF::fmt(f, *key)?;
                write!(f, ": {var:.digits$}, ")?;
            }
            return f.write_str("}");
        }

        f.write_str("Values {\n")?;
        {
            // The adapter mutably borrows `f`; keep it scoped so the closing
            // brace can be written to `f` afterwards.
            let mut padded = PadAdapter::new(f);
            for (key, var) in self.values.iter() {
                KF::fmt(&mut padded, *key)?;
                writeln!(padded, ": {var:#.digits$},")?;
            }
        }
        f.write_str("}")
    }
}
impl<KF: KeyFormatter> fmt::Debug for ValuesFormatter<'_, KF> {
    /// Debug counterpart of the `Display` impl: same layout, but each value is
    /// rendered with its `Debug` implementation. Values use the caller's
    /// precision, or 3 digits if none was given.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let digits = f.precision().unwrap_or(3);

        if !f.alternate() {
            f.write_str("Values { ")?;
            for (key, var) in self.values.iter() {
                KF::fmt(f, *key)?;
                write!(f, ": {var:.digits$?}, ")?;
            }
            return f.write_str("}");
        }

        f.write_str("Values {\n")?;
        {
            // The adapter mutably borrows `f`; keep it scoped so the closing
            // brace can be written to `f` afterwards.
            let mut padded = PadAdapter::new(f);
            for (key, var) in self.values.iter() {
                KF::fmt(&mut padded, *key)?;
                writeln!(padded, ": {var:#.digits$?},")?;
            }
        }
        f.write_str("}")
    }
}
impl IntoIterator for Values {
type Item = (Key, Box<dyn VariableSafe>);
type IntoIter = indexmap::map::IntoIter<Key, Box<dyn VariableSafe>>;
fn into_iter(self) -> Self::IntoIter {
self.values.into_iter()
}
}