#[cfg(feature = "ffi")]
pub(crate) mod ffi;
#[cfg(all(feature = "polars", feature = "contrib"))]
mod polars;
#[cfg(all(feature = "polars", feature = "contrib"))]
pub use polars::*;
#[cfg(feature = "contrib")]
use std::any::Any;
use std::collections::HashMap;
use std::hash::Hash;
use std::marker::PhantomData;
use std::ops::Bound;
use crate::core::Domain;
use crate::error::Fallible;
use crate::traits::{CheckAtom, CheckNull, HasNull, ProductOrd};
use std::fmt::{Debug, Formatter};
use bitvec::prelude::{BitVec, Lsb0};
#[cfg(feature = "contrib")]
mod poly;
/// Describes a set of values of the atomic type `T`.
///
/// The set may be narrowed to an interval through `bounds`, and the `nan`
/// flag records whether the domain contains nan.
#[derive(Clone, PartialEq)]
pub struct AtomDomain<T: CheckAtom> {
/// Optional interval restriction; `None` means no bounds are recorded.
pub bounds: Option<Bounds<T>>,
/// `true` when the domain contains nan (see `assert_non_nan`).
nan: bool,
}
impl<T: CheckAtom> Debug for AtomDomain<T> {
    /// Renders the domain as `AtomDomain(<bounds=…, ><nan=true, >T=<type>)`.
    ///
    /// The `bounds=` and `nan=true` fragments are only emitted when bounds
    /// are present and when the nan flag is set, respectively.
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
        let bounds_part = match &self.bounds {
            Some(bounds) => format!("bounds={:?}, ", bounds),
            None => String::new(),
        };
        let nan_part = if self.nan { "nan=true, " } else { "" };
        write!(
            f,
            "AtomDomain({}{}T={})",
            bounds_part,
            nan_part,
            type_name!(T)
        )
    }
}
impl<T: CheckAtom> Default for AtomDomain<T> {
    /// Builds an unbounded domain whose nan flag is taken from `T::NULLABLE`.
    fn default() -> Self {
        let nan = T::NULLABLE;
        Self { bounds: None, nan }
    }
}
impl<T: CheckAtom> AtomDomain<T> {
    /// Builds a domain from optional bounds and an optional nan marker.
    ///
    /// The nan flag is set exactly when `nan` is `Some`.
    pub fn new(bounds: Option<Bounds<T>>, nan: Option<NaN<T>>) -> Self {
        let nan = nan.is_some();
        Self { bounds, nan }
    }

    /// Returns the nan flag of this domain.
    pub fn nan(&self) -> bool {
        self.nan
    }

    /// Succeeds only when the nan flag is unset.
    ///
    /// # Errors
    /// Returns a `FailedFunction` error when the domain contains nan.
    pub fn assert_non_nan(&self) -> Fallible<()> {
        if self.nan() {
            fallible!(FailedFunction, "Domain contains nan")
        } else {
            Ok(())
        }
    }
}
impl<T: CheckAtom> AtomDomain<T> {
    /// Builds an unbounded domain with the nan flag unset.
    pub fn new_non_nan() -> Self {
        Self {
            nan: false,
            bounds: None,
        }
    }
}
impl<T: CheckAtom + PartialOrd + Debug> AtomDomain<T> {
    /// Builds a domain restricted to the inclusive interval `bounds`, with
    /// the nan flag unset.
    ///
    /// # Errors
    /// Propagates any error raised by `Bounds::new_closed`.
    pub fn new_closed(bounds: (T, T)) -> Fallible<Self> {
        let closed = Bounds::new_closed(bounds)?;
        Ok(Self {
            bounds: Some(closed),
            nan: false,
        })
    }

    /// Returns clones of the `(lower, upper)` endpoints when both are inclusive.
    ///
    /// # Errors
    /// Returns a `MakeTransformation` error when the domain has no bounds,
    /// or when either endpoint is not inclusive.
    pub fn get_closed_bounds(&self) -> Fallible<(T, T)> {
        let maybe_bounds = self.bounds.as_ref();
        let bounds = maybe_bounds.ok_or_else(|| {
            err!(
                MakeTransformation,
                "input domain must consist of bounded data. Either specify bounds in the input domain or use make_clamp."
            )
        })?;

        if let (Bound::Included(lower), Bound::Included(upper)) = (&bounds.lower, &bounds.upper) {
            Ok((lower.clone(), upper.clone()))
        } else {
            fallible!(MakeTransformation, "bounds are not closed")
        }
    }
}
impl<T: CheckAtom> Domain for AtomDomain<T> {
    type Carrier = T;

    /// Checks membership by delegating to `check_member` on the value,
    /// handing it a clone of the bounds together with the nan flag.
    fn member(&self, val: &Self::Carrier) -> Fallible<bool> {
        let bounds = self.bounds.clone();
        let nan = self.nan;
        val.check_member(bounds, nan)
    }
}
/// Zero-sized marker parameterized by `T`.
///
/// `AtomDomain::new` sets its nan flag when it is handed `Some(NaN)`.
#[derive(PartialEq)]
pub struct NaN<T> {
    /// Ties the marker to `T` without storing a value of that type.
    pub _marker: PhantomData<T>,
}

// Written by hand so that cloning does not require `T: Clone`.
impl<T> Clone for NaN<T> {
    fn clone(&self) -> Self {
        NaN {
            _marker: PhantomData,
        }
    }
}
impl<T: HasNull> Default for NaN<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: HasNull> NaN<T> {
    /// Creates the marker; only available for types implementing `HasNull`.
    pub fn new() -> Self {
        Self {
            _marker: PhantomData::<T>,
        }
    }
}
/// An interval over `T` described by a lower and an upper endpoint.
///
/// Each endpoint is a `std::ops::Bound`, so it may be inclusive, exclusive
/// or unbounded.
#[derive(Clone, PartialEq)]
pub struct Bounds<T> {
/// Lower endpoint of the interval.
lower: Bound<T>,
/// Upper endpoint of the interval.
upper: Bound<T>,
}
impl<T: PartialOrd + Debug + CheckNull> Bounds<T> {
    /// Builds bounds whose two endpoints are both inclusive.
    ///
    /// # Errors
    /// Propagates any error raised by [`Bounds::new`].
    pub fn new_closed(bounds: (T, T)) -> Fallible<Self> {
        Self::new((Bound::Included(bounds.0), Bound::Included(bounds.1)))
    }

    /// Validates a `(lower, upper)` endpoint pair and wraps it in `Bounds`.
    ///
    /// # Errors
    /// * `MakeDomain` when both endpoints carry a value and the lower value
    ///   is greater than the upper value.
    /// * `MakeDomain` when both values are equal and exactly one of the two
    ///   endpoints is exclusive.
    /// * `FailedFunction` when an endpoint value reports `is_null()`.
    pub fn new(bounds: (Bound<T>, Bound<T>)) -> Fallible<Self> {
        let (lower, upper) = bounds;

        // Extracts the endpoint value, if the endpoint is not unbounded.
        fn get<T>(value: &Bound<T>) -> Option<&T> {
            match value {
                Bound::Included(value) => Some(value),
                Bound::Excluded(value) => Some(value),
                Bound::Unbounded => None,
            }
        }

        // Ordering checks only apply when both endpoints carry a value.
        if let Some((v_lower, v_upper)) = get(&lower).zip(get(&upper)) {
            if v_lower > v_upper {
                return fallible!(
                    MakeDomain,
                    "lower bound ({:?}) may not be greater than upper bound ({:?})",
                    v_lower,
                    v_upper
                );
            }
            if v_lower == v_upper {
                match (&lower, &upper) {
                    (Bound::Included(l), Bound::Excluded(u)) => {
                        // The message names the upper bound first, so the
                        // upper value must be the first argument.
                        return fallible!(
                            MakeDomain,
                            "upper bound ({:?}) excludes inclusive lower bound ({:?})",
                            u,
                            l
                        );
                    }
                    (Bound::Excluded(l), Bound::Included(u)) => {
                        return fallible!(
                            MakeDomain,
                            "lower bound ({:?}) excludes inclusive upper bound ({:?})",
                            l,
                            u
                        );
                    }
                    // Two inclusive or two exclusive endpoints at the same
                    // value are accepted as-is.
                    _ => (),
                }
            }
        }

        if let Some(lower) = get(&lower) {
            if lower.is_null() {
                return fallible!(FailedFunction, "lower must not be null");
            }
        }
        if let Some(upper) = get(&upper) {
            if upper.is_null() {
                return fallible!(FailedFunction, "upper must not be null");
            }
        }
        Ok(Bounds { lower, upper })
    }

    /// Returns the lower endpoint value, or `None` when it is unbounded.
    ///
    /// Inclusive and exclusive endpoints are not distinguished.
    pub fn lower(&self) -> Option<&T> {
        match &self.lower {
            Bound::Included(v) => Some(v),
            Bound::Excluded(v) => Some(v),
            Bound::Unbounded => None,
        }
    }

    /// Returns the upper endpoint value, or `None` when it is unbounded.
    ///
    /// Inclusive and exclusive endpoints are not distinguished.
    pub fn upper(&self) -> Option<&T> {
        match &self.upper {
            Bound::Included(v) => Some(v),
            Bound::Excluded(v) => Some(v),
            Bound::Unbounded => None,
        }
    }
}
impl<T: Clone> Bounds<T> {
    /// Returns clones of the `(lower, upper)` values when both endpoints are
    /// inclusive.
    ///
    /// # Errors
    /// Returns a `MakeDomain` error when either endpoint is not inclusive.
    pub fn get_closed(&self) -> Fallible<(T, T)> {
        if let (Bound::Included(low), Bound::Included(high)) = (&self.lower, &self.upper) {
            return Ok((low.clone(), high.clone()));
        }
        fallible!(MakeDomain, "Bounds are not closed")
    }
}
impl<T: Debug> Debug for Bounds<T> {
    /// Renders the interval in mathematical notation, e.g. `[0, 1)`.
    ///
    /// Square brackets mark inclusive endpoints, parentheses mark exclusive
    /// ones, and unbounded endpoints are shown as `(-∞` / `∞)`.
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
        let left = match &self.lower {
            Bound::Unbounded => String::from("(-∞"),
            Bound::Excluded(value) => format!("({:?}", value),
            Bound::Included(value) => format!("[{:?}", value),
        };
        let right = match &self.upper {
            Bound::Unbounded => String::from("∞)"),
            Bound::Excluded(value) => format!("{:?})", value),
            Bound::Included(value) => format!("{:?}]", value),
        };
        write!(f, "{}, {}", left, right)
    }
}
impl<T: Clone + ProductOrd> Bounds<T> {
    /// Reports whether `val` lies inside the interval.
    ///
    /// The lower endpoint is tested first; the upper endpoint is only
    /// consulted when the lower test passes.
    ///
    /// # Errors
    /// Propagates any error raised by the `total_*` comparisons.
    pub fn member(&self, val: &T) -> Fallible<bool> {
        let satisfies_lower = match &self.lower {
            Bound::Unbounded => true,
            Bound::Excluded(limit) => val.total_gt(limit)?,
            Bound::Included(limit) => val.total_ge(limit)?,
        };
        if !satisfies_lower {
            return Ok(false);
        }

        let satisfies_upper = match &self.upper {
            Bound::Unbounded => true,
            Bound::Excluded(limit) => val.total_lt(limit)?,
            Bound::Included(limit) => val.total_le(limit)?,
        };
        Ok(satisfies_upper)
    }
}
/// Describes a set of hash maps by pairing a domain for the keys with a
/// domain for the values.
#[derive(Clone, PartialEq, Debug, Default)]
pub struct MapDomain<DK: Domain, DV: Domain>
where
DK::Carrier: Eq + Hash,
{
/// Domain every key is checked against.
pub key_domain: DK,
/// Domain every value is checked against.
pub value_domain: DV,
}
impl<DK: Domain, DV: Domain> MapDomain<DK, DV>
where
    DK::Carrier: Eq + Hash,
{
    /// Builds a map domain from a key domain and a value domain.
    pub fn new(key_domain: DK, value_domain: DV) -> Self {
        Self {
            key_domain,
            value_domain,
        }
    }
}
impl<DK: Domain, DV: Domain> Domain for MapDomain<DK, DV>
where
    DK::Carrier: Eq + Hash,
{
    type Carrier = HashMap<DK::Carrier, DV::Carrier>;

    /// A map is a member when every key belongs to `key_domain` and every
    /// value belongs to `value_domain`.
    ///
    /// For each entry the key is checked before the value, and the scan
    /// stops at the first non-member or error.
    fn member(&self, val: &Self::Carrier) -> Fallible<bool> {
        for (key, value) in val.iter() {
            if !self.key_domain.member(key)? {
                return Ok(false);
            }
            if !self.value_domain.member(value)? {
                return Ok(false);
            }
        }
        Ok(true)
    }
}
/// Describes a set of vectors via a domain for the elements and an optional
/// required length.
#[derive(Clone, PartialEq)]
pub struct VectorDomain<D: Domain> {
/// Domain every element is checked against.
pub element_domain: D,
/// Required vector length; `None` leaves the length unconstrained.
pub size: Option<usize>,
}
impl<D: Domain> Debug for VectorDomain<D> {
    /// Renders as `VectorDomain(<element domain>)`, appending `, size=<n>`
    /// when a size is set.
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
        let size_part = match self.size {
            Some(size) => format!(", size={:?}", size),
            None => String::new(),
        };
        write!(f, "VectorDomain({:?}{})", self.element_domain, size_part)
    }
}
impl<D: Domain + Default> Default for VectorDomain<D> {
fn default() -> Self {
Self::new(D::default())
}
}
impl<D: Domain> VectorDomain<D> {
    /// Builds a vector domain with no size constraint.
    pub fn new(element_domain: D) -> Self {
        Self {
            element_domain,
            size: None,
        }
    }

    /// Consumes the domain and returns it with the size set to `size`.
    pub fn with_size(self, size: usize) -> Self {
        Self {
            size: Some(size),
            ..self
        }
    }

    /// Consumes the domain and returns it with the size constraint removed.
    pub fn without_size(self) -> Self {
        Self { size: None, ..self }
    }
}
impl<D: Domain> Domain for VectorDomain<D> {
    type Carrier = Vec<D::Carrier>;

    /// A vector is a member when every element belongs to `element_domain`
    /// and, if a size is set, its length equals that size.
    ///
    /// Elements are checked before the length, so an error from an element
    /// check is reported even when the length would not match.
    fn member(&self, val: &Self::Carrier) -> Fallible<bool> {
        for element in val.iter() {
            let is_member = self.element_domain.member(element)?;
            if !is_member {
                return Ok(false);
            }
        }
        let length_ok = match self.size {
            Some(expected) => expected == val.len(),
            None => true,
        };
        Ok(length_ok)
    }
}
/// Bit vector type used as the carrier of `BitVectorDomain`: a `BitVec`
/// parameterized with `u8` and `Lsb0`.
pub type BitVector = BitVec<u8, Lsb0>;
/// Describes a set of bit vectors, optionally limited by the number of set
/// bits.
#[derive(Clone, PartialEq)]
pub struct BitVectorDomain {
/// Largest permitted count of set bits; `None` means no limit.
pub max_weight: Option<usize>,
}
impl Debug for BitVectorDomain {
    /// Renders as `BitVectorDomain(weight=<n>)`, or `BitVectorDomain()` when
    /// no maximum weight is set.
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
        let weight_part = match self.max_weight {
            Some(max_weight) => format!("weight={:?}", max_weight),
            None => String::new(),
        };
        write!(f, "BitVectorDomain({})", weight_part)
    }
}
impl Default for BitVectorDomain {
fn default() -> Self {
Self::new()
}
}
impl BitVectorDomain {
pub fn new() -> Self {
BitVectorDomain { max_weight: None }
}
pub fn with_max_weight(mut self, max_weight: usize) -> Self {
self.max_weight = Some(max_weight);
self
}
pub fn without_max_weight(mut self) -> Self {
self.max_weight = None;
self
}
}
impl Domain for BitVectorDomain {
    type Carrier = BitVector;

    /// A bit vector is a member when no limit is set, or when its number of
    /// set bits does not exceed `max_weight`.
    fn member(&self, val: &Self::Carrier) -> Fallible<bool> {
        let within_limit = match self.max_weight {
            Some(max_weight) => val.count_ones() <= max_weight,
            None => true,
        };
        Ok(within_limit)
    }
}
/// Describes a set of `Option` values whose `Some` payloads are checked
/// against an inner domain.
#[derive(Clone, PartialEq)]
pub struct OptionDomain<D: Domain> {
/// Domain applied to the value inside `Some`.
pub element_domain: D,
}
impl<D: Domain + Default> Default for OptionDomain<D> {
fn default() -> Self {
Self::new(D::default())
}
}
impl<D: Domain> OptionDomain<D> {
    /// Wraps `member_domain` as the domain applied to `Some` payloads.
    pub fn new(member_domain: D) -> Self {
        let element_domain = member_domain;
        Self { element_domain }
    }
}
impl<D: Domain> Debug for OptionDomain<D> {
    /// Renders as `OptionDomain(<inner domain>)`.
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
        let inner = &self.element_domain;
        write!(f, "OptionDomain({:?})", inner)
    }
}
impl<D: Domain> Domain for OptionDomain<D> {
    type Carrier = Option<D::Carrier>;

    /// `None` is always a member; `Some(v)` is a member exactly when `v`
    /// belongs to `element_domain`.
    fn member(&self, value: &Self::Carrier) -> Fallible<bool> {
        match value {
            Some(inner) => self.element_domain.member(inner),
            None => Ok(true),
        }
    }
}
// Yields the text after the final `::` of `std::any::type_name` for the
// given type, e.g. `alloc::string::String` becomes `String`.
macro_rules! type_name {
    ($t:ty) => {{
        let full_path: &'static str = std::any::type_name::<$t>();
        full_path.split("::").last().unwrap_or("")
    }};
}
pub(crate) use type_name;
#[cfg(feature = "contrib")]
pub use contrib::*;
#[cfg(feature = "contrib")]
mod contrib {
    use super::*;

    /// Describes a set of 2-tuples by pairing one domain per tuple position.
    #[derive(Clone, PartialEq, Debug)]
    pub struct PairDomain<D0: Domain, D1: Domain>(pub D0, pub D1);

    impl<D0: Domain, D1: Domain> PairDomain<D0, D1> {
        /// Builds a pair domain from the domains of the first and second
        /// tuple positions.
        pub fn new(element_domain0: D0, element_domain1: D1) -> Self {
            Self(element_domain0, element_domain1)
        }
    }

    impl<D0: Domain, D1: Domain> Domain for PairDomain<D0, D1> {
        type Carrier = (D0::Carrier, D1::Carrier);

        /// A tuple is a member when each position belongs to its domain.
        ///
        /// The second position is only checked when the first is a member.
        fn member(&self, val: &Self::Carrier) -> Fallible<bool> {
            let (first, second) = val;
            if !self.0.member(first)? {
                return Ok(false);
            }
            self.1.member(second)
        }
    }

    /// Describes a set of boxed values via a boxed inner domain.
    #[derive(Clone, PartialEq, Debug)]
    pub struct BoxDomain<D: Domain> {
        /// Domain applied to the value inside the box.
        element_domain: Box<D>,
    }

    impl<D: Domain> BoxDomain<D> {
        /// Wraps an already boxed inner domain.
        pub fn new(element_domain: Box<D>) -> Self {
            Self { element_domain }
        }
    }

    impl<D: Domain> Domain for BoxDomain<D> {
        type Carrier = Box<D::Carrier>;

        /// A boxed value is a member when its contents belong to the inner
        /// domain.
        fn member(&self, val: &Self::Carrier) -> Fallible<bool> {
            let inner: &D::Carrier = &**val;
            self.element_domain.member(inner)
        }
    }

    /// Describes a set of type-erased values (`Box<dyn Any>`) that are
    /// checked against `form_domain` after downcasting.
    #[derive(Clone, PartialEq, Debug)]
    pub struct DataDomain<D: Domain> {
        /// Domain applied to the downcast value.
        pub form_domain: D,
    }

    impl<D: Domain> DataDomain<D> {
        /// Wraps `form_domain` so it can be applied to type-erased values.
        pub fn new(form_domain: D) -> Self {
            Self { form_domain }
        }
    }

    impl<D: Domain> Domain for DataDomain<D>
    where
        D::Carrier: 'static,
    {
        type Carrier = Box<dyn Any>;

        /// Downcasts `val` to `D::Carrier` and checks it against
        /// `form_domain`.
        ///
        /// # Errors
        /// Returns a `FailedCast` error when `val` does not hold a
        /// `D::Carrier`.
        fn member(&self, val: &Self::Carrier) -> Fallible<bool> {
            let downcast = val.downcast_ref::<D::Carrier>();
            let concrete =
                downcast.ok_or_else(|| err!(FailedCast, "failed to downcast to carrier type"))?;
            self.form_domain.member(concrete)
        }
    }
}