use crate::hlist::{HCons, HNil};
use crate::indices::{Here, There};
use crate::traits::{Func, Poly, ToMut, ToRef};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// A type-level sum ("tagged union") built as a chain of binary choices.
///
/// `Inl` holds a value of the head type `H`; `Inr` defers to the tail `T`,
/// which in this module is another `Coproduct` or [`CNil`] closing the chain.
///
/// The declaration order of the variants is significant: the derived
/// `PartialOrd`/`Ord` place every `Inl` before every `Inr`.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum Coproduct<H, T> {
    /// The held value has the head type `H`.
    Inl(H),
    /// The held value lives somewhere in the tail `T`.
    Inr(T),
}
/// The empty coproduct: a type with no values, used to terminate a
/// `Coproduct` chain.
///
/// Because no `CNil` can ever be constructed, code that receives one may
/// `match` on it with zero arms.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum CNil {}
impl<Head, Tail> Coproduct<Head, Tail> {
    /// Builds a coproduct holding `to_insert` in the variant for type `T`.
    ///
    /// `Index` is normally inferred. It has to be spelled out only when `T`
    /// occurs more than once in the coproduct.
    #[inline(always)]
    pub fn inject<T, Index>(to_insert: T) -> Self
    where
        Self: CoprodInjector<T, Index>,
    {
        <Self as CoprodInjector<T, Index>>::inject(to_insert)
    }

    /// Borrows the held value if it is an `S`; returns `None` otherwise.
    #[inline(always)]
    pub fn get<S, Index>(&self) -> Option<&S>
    where
        Self: CoproductSelector<S, Index>,
    {
        <Self as CoproductSelector<S, Index>>::get(self)
    }

    /// Consumes the coproduct and returns the held value if it is a `T`.
    /// Returns `None` otherwise.
    #[inline(always)]
    pub fn take<T, Index>(self) -> Option<T>
    where
        Self: CoproductTaker<T, Index>,
    {
        <Self as CoproductTaker<T, Index>>::take(self)
    }

    /// Tries to extract a `T`.
    ///
    /// On a miss, returns the value re-wrapped in the smaller coproduct that
    /// excludes `T`'s position.
    #[inline(always)]
    pub fn uninject<T, Index>(self) -> Result<T, <Self as CoprodUninjector<T, Index>>::Remainder>
    where
        Self: CoprodUninjector<T, Index>,
    {
        <Self as CoprodUninjector<T, Index>>::uninject(self)
    }

    /// Tries to narrow the coproduct to the subset of types in `Targets`.
    ///
    /// On failure, returns the value as a coproduct of the types left over.
    #[inline(always)]
    pub fn subset<Targets, Indices>(
        self,
    ) -> Result<Targets, <Self as CoproductSubsetter<Targets, Indices>>::Remainder>
    where
        Self: CoproductSubsetter<Targets, Indices>,
    {
        <Self as CoproductSubsetter<Targets, Indices>>::subset(self)
    }

    /// Widens the coproduct into `Targets`, a coproduct that contains every
    /// type of this one (possibly in a different order).
    #[inline(always)]
    pub fn embed<Targets, Indices>(self) -> Targets
    where
        Self: CoproductEmbedder<Targets, Indices>,
    {
        <Self as CoproductEmbedder<Targets, Indices>>::embed(self)
    }

    /// Converts `&Coproduct<A, B, ..>` into `Coproduct<&A, &B, ..>`.
    #[inline(always)]
    pub fn to_ref<'a>(&'a self) -> <Self as ToRef<'a>>::Output
    where
        Self: ToRef<'a>,
    {
        <Self as ToRef<'a>>::to_ref(self)
    }

    /// Converts `&mut Coproduct<A, B, ..>` into `Coproduct<&mut A, &mut B, ..>`.
    #[inline(always)]
    pub fn to_mut<'a>(&'a mut self) -> <Self as ToMut<'a>>::Output
    where
        Self: ToMut<'a>,
    {
        <Self as ToMut<'a>>::to_mut(self)
    }

    /// Consumes the coproduct and passes the held value to the matching
    /// handler in `folder`.
    ///
    /// `folder` is either an `hlist!` of closures or a `Poly`. Every branch
    /// must produce an `Output`.
    #[inline(always)]
    pub fn fold<Output, Folder>(self, folder: Folder) -> Output
    where
        Self: CoproductFoldable<Folder, Output>,
    {
        <Self as CoproductFoldable<Folder, Output>>::fold(self, folder)
    }

    /// Transforms the held value and yields a coproduct of the mapped types.
    ///
    /// `mapper` may be:
    /// - an `hlist!` of functions, taken by value, `&` or `&mut`;
    /// - a `Poly`, taken by value, `&` or `&mut`;
    /// - a single closure applied to whichever variant is present.
    #[inline(always)]
    pub fn map<F>(self, mapper: F) -> <Self as CoproductMappable<F>>::Output
    where
        Self: CoproductMappable<F>,
    {
        <Self as CoproductMappable<F>>::map(self, mapper)
    }
}
impl<T> Coproduct<T, CNil> {
    /// Unwraps a single-variant coproduct into its only possible value.
    ///
    /// This cannot fail: the `Inr` side holds a `CNil`, which has no values.
    #[inline(always)]
    pub fn extract(self) -> T {
        match self {
            Coproduct::Inr(impossible) => match impossible {},
            Coproduct::Inl(value) => value,
        }
    }
}
/// Construction of a coproduct from one of its member types.
///
/// `Index` is a type-level position (`Here` / `There<_>` in this module). It
/// tells the compiler which variant `InjectType` belongs in.
pub trait CoprodInjector<InjectType, Index> {
    /// Wraps `value` in the variant located at `Index`.
    fn inject(value: InjectType) -> Self;
}
/// Injection at the head position: the value becomes the `Inl` variant.
impl<I, Tail> CoprodInjector<I, Here> for Coproduct<I, Tail> {
    fn inject(to_insert: I) -> Self {
        Self::Inl(to_insert)
    }
}
/// Injection past the head: build the tail coproduct, then wrap it in `Inr`.
impl<Head, I, Tail, TailIndex> CoprodInjector<I, There<TailIndex>> for Coproduct<Head, Tail>
where
    Tail: CoprodInjector<I, TailIndex>,
{
    fn inject(to_insert: I) -> Self {
        Self::Inr(<Tail as CoprodInjector<I, TailIndex>>::inject(to_insert))
    }
}
/// Borrowing access to a value of type `S` stored at type-level index `I`.
pub trait CoproductSelector<S, I> {
    /// Returns `Some(&value)` if the coproduct currently holds an `S` at
    /// index `I`, and `None` otherwise.
    fn get(&self) -> Option<&S>;
}
impl<Head, Tail> CoproductSelector<Head, Here> for Coproduct<Head, Tail> {
fn get(&self) -> Option<&Head> {
use self::Coproduct::*;
match *self {
Inl(ref thing) => Some(thing),
_ => None, }
}
}
impl<Head, FromTail, Tail, TailIndex> CoproductSelector<FromTail, There<TailIndex>>
for Coproduct<Head, Tail>
where
Tail: CoproductSelector<FromTail, TailIndex>,
{
fn get(&self) -> Option<&FromTail> {
use self::Coproduct::*;
match *self {
Inr(ref rest) => rest.get(),
_ => None, }
}
}
/// By-value extraction of a value of type `S` stored at type-level index `I`.
pub trait CoproductTaker<S, I> {
    /// Consumes the coproduct. Returns `Some(value)` if it held an `S` at
    /// index `I`, and `None` otherwise.
    fn take(self) -> Option<S>;
}
impl<Head, Tail> CoproductTaker<Head, Here> for Coproduct<Head, Tail> {
fn take(self) -> Option<Head> {
use self::Coproduct::*;
match self {
Inl(thing) => Some(thing),
_ => None, }
}
}
impl<Head, FromTail, Tail, TailIndex> CoproductTaker<FromTail, There<TailIndex>>
for Coproduct<Head, Tail>
where
Tail: CoproductTaker<FromTail, TailIndex>,
{
fn take(self) -> Option<FromTail> {
use self::Coproduct::*;
match self {
Inr(rest) => rest.take(),
_ => None, }
}
}
/// Reduction of a coproduct to a single `Output` using a `Folder`.
///
/// In this module the folder is either an `HCons` of closures or a `Poly`.
pub trait CoproductFoldable<Folder, Output> {
    /// Consumes `self` and passes the held value to the matching handler.
    fn fold(self, folder: Folder) -> Output;
}
impl<P, R, CH, CTail> CoproductFoldable<Poly<P>, R> for Coproduct<CH, CTail>
where
P: Func<CH, Output = R>,
CTail: CoproductFoldable<Poly<P>, R>,
{
fn fold(self, f: Poly<P>) -> R {
use self::Coproduct::*;
match self {
Inl(r) => P::call(r),
Inr(rest) => rest.fold(f),
}
}
}
impl<F, R, FTail, CH, CTail> CoproductFoldable<HCons<F, FTail>, R> for Coproduct<CH, CTail>
where
F: FnOnce(CH) -> R,
CTail: CoproductFoldable<FTail, R>,
{
fn fold(self, f: HCons<F, FTail>) -> R {
use self::Coproduct::*;
let f_head = f.head;
let f_tail = f.tail;
match self {
Inl(r) => (f_head)(r),
Inr(rest) => rest.fold(f_tail),
}
}
}
/// Folding the empty coproduct.
///
/// `CNil` has no values, so this method can never actually be invoked. It
/// exists only to terminate the recursion of the `Coproduct` impls.
impl<F, R> CoproductFoldable<F, R> for CNil {
    fn fold(self, _: F) -> R {
        // An empty match on the uninhabited `CNil` lets the compiler prove
        // this point is unreachable, instead of keeping a runtime panic path.
        match self {}
    }
}
/// Transformation of the value held by a coproduct.
///
/// `Mapper` is applied to whichever variant is present. The result is a new
/// coproduct whose variant types are the mapper's outputs.
pub trait CoproductMappable<Mapper> {
    /// The coproduct type produced by mapping.
    type Output;

    /// Consumes `self` and applies `mapper` to the held value.
    fn map(self, mapper: Mapper) -> Self::Output;
}
impl<F, R, MapperTail, CH, CTail> CoproductMappable<HCons<F, MapperTail>> for Coproduct<CH, CTail>
where
F: FnOnce(CH) -> R,
CTail: CoproductMappable<MapperTail>,
{
type Output = Coproduct<R, <CTail as CoproductMappable<MapperTail>>::Output>;
#[inline]
fn map(self, mapper: HCons<F, MapperTail>) -> Self::Output {
match self {
Coproduct::Inl(l) => Coproduct::Inl((mapper.head)(l)),
Coproduct::Inr(rest) => Coproduct::Inr(rest.map(mapper.tail)),
}
}
}
impl<'a, F, R, MapperTail, CH, CTail> CoproductMappable<&'a HCons<F, MapperTail>>
for Coproduct<CH, CTail>
where
F: Fn(CH) -> R,
CTail: CoproductMappable<&'a MapperTail>,
{
type Output = Coproduct<R, <CTail as CoproductMappable<&'a MapperTail>>::Output>;
#[inline]
fn map(self, mapper: &'a HCons<F, MapperTail>) -> Self::Output {
match self {
Coproduct::Inl(l) => Coproduct::Inl((mapper.head)(l)),
Coproduct::Inr(rest) => Coproduct::Inr(rest.map(&mapper.tail)),
}
}
}
impl<'a, F, R, MapperTail, CH, CTail> CoproductMappable<&'a mut HCons<F, MapperTail>>
for Coproduct<CH, CTail>
where
F: FnMut(CH) -> R,
CTail: CoproductMappable<&'a mut MapperTail>,
{
type Output = Coproduct<R, <CTail as CoproductMappable<&'a mut MapperTail>>::Output>;
#[inline]
fn map(self, mapper: &'a mut HCons<F, MapperTail>) -> Self::Output {
match self {
Coproduct::Inl(l) => Coproduct::Inl((mapper.head)(l)),
Coproduct::Inr(rest) => Coproduct::Inr(rest.map(&mut mapper.tail)),
}
}
}
impl<P, CH, CTail> CoproductMappable<Poly<P>> for Coproduct<CH, CTail>
where
P: Func<CH>,
CTail: CoproductMappable<Poly<P>>,
{
type Output = Coproduct<<P as Func<CH>>::Output, <CTail as CoproductMappable<Poly<P>>>::Output>;
#[inline]
fn map(self, poly: Poly<P>) -> Self::Output {
match self {
Coproduct::Inl(l) => Coproduct::Inl(P::call(l)),
Coproduct::Inr(rest) => Coproduct::Inr(rest.map(poly)),
}
}
}
impl<'a, P, CH, CTail> CoproductMappable<&'a Poly<P>> for Coproduct<CH, CTail>
where
P: Func<CH>,
CTail: CoproductMappable<&'a Poly<P>>,
{
type Output =
Coproduct<<P as Func<CH>>::Output, <CTail as CoproductMappable<&'a Poly<P>>>::Output>;
#[inline]
fn map(self, poly: &'a Poly<P>) -> Self::Output {
match self {
Coproduct::Inl(l) => Coproduct::Inl(P::call(l)),
Coproduct::Inr(rest) => Coproduct::Inr(rest.map(poly)),
}
}
}
impl<'a, P, CH, CTail> CoproductMappable<&'a mut Poly<P>> for Coproduct<CH, CTail>
where
P: Func<CH>,
CTail: CoproductMappable<&'a mut Poly<P>>,
{
type Output =
Coproduct<<P as Func<CH>>::Output, <CTail as CoproductMappable<&'a mut Poly<P>>>::Output>;
#[inline]
fn map(self, poly: &'a mut Poly<P>) -> Self::Output {
match self {
Coproduct::Inl(l) => Coproduct::Inl(P::call(l)),
Coproduct::Inr(rest) => Coproduct::Inr(rest.map(poly)),
}
}
}
/// Mapping with one `FnMut` closure that accepts every variant type.
///
/// The bound `FnMut(CH) -> R` is checked separately at each level of the
/// chain.
impl<F, R, CH, CTail> CoproductMappable<F> for Coproduct<CH, CTail>
where
    F: FnMut(CH) -> R,
    CTail: CoproductMappable<F>,
{
    type Output = Coproduct<R, <CTail as CoproductMappable<F>>::Output>;

    #[inline]
    fn map(self, mut func: F) -> Self::Output {
        match self {
            Coproduct::Inl(head) => Coproduct::Inl(func(head)),
            Coproduct::Inr(tail) => Coproduct::Inr(<CTail as CoproductMappable<F>>::map(tail, func)),
        }
    }
}
/// Mapping the empty coproduct: there is no value to transform, and the
/// result is `CNil` as well.
impl<F> CoproductMappable<F> for CNil {
    type Output = CNil;

    #[inline(always)]
    fn map(self, _mapper: F) -> Self::Output {
        match self {}
    }
}
/// Turns `&Coproduct<CH, CTail>` into `Coproduct<&CH, ..>`, recursively
/// borrowing through the tail.
impl<'a, CH: 'a, CTail> ToRef<'a> for Coproduct<CH, CTail>
where
    CTail: ToRef<'a>,
{
    type Output = Coproduct<&'a CH, <CTail as ToRef<'a>>::Output>;

    #[inline(always)]
    fn to_ref(&'a self) -> Self::Output {
        match self {
            Coproduct::Inl(head) => Coproduct::Inl(head),
            Coproduct::Inr(tail) => Coproduct::Inr(tail.to_ref()),
        }
    }
}
/// Borrowing the empty coproduct yields the empty coproduct.
impl<'a> ToRef<'a> for CNil {
    type Output = CNil;

    fn to_ref(&'a self) -> Self::Output {
        // Dereference first: `&CNil` does not count as uninhabited for
        // exhaustiveness checking, but `CNil` does.
        match *self {}
    }
}
/// Turns `&mut Coproduct<CH, CTail>` into `Coproduct<&mut CH, ..>`,
/// recursively borrowing through the tail.
impl<'a, CH: 'a, CTail> ToMut<'a> for Coproduct<CH, CTail>
where
    CTail: ToMut<'a>,
{
    type Output = Coproduct<&'a mut CH, <CTail as ToMut<'a>>::Output>;

    #[inline(always)]
    fn to_mut(&'a mut self) -> Self::Output {
        match self {
            Coproduct::Inl(head) => Coproduct::Inl(head),
            Coproduct::Inr(tail) => Coproduct::Inr(tail.to_mut()),
        }
    }
}
/// Mutably borrowing the empty coproduct yields the empty coproduct.
impl<'a> ToMut<'a> for CNil {
    type Output = CNil;

    fn to_mut(&'a mut self) -> Self::Output {
        // Dereference first: `&mut CNil` does not count as uninhabited for
        // exhaustiveness checking, but `CNil` does.
        match *self {}
    }
}
/// The inverse of [`CoprodInjector`]: try to pull a `T` back out of a coproduct.
pub trait CoprodUninjector<T, Idx>: CoprodInjector<T, Idx> {
    /// The coproduct of all variants except the one at `Idx`.
    type Remainder;

    /// Returns `Ok(value)` if the held value is the `T` at `Idx`. Otherwise
    /// returns `Err` with the value re-wrapped in [`Self::Remainder`].
    fn uninject(self) -> Result<T, Self::Remainder>;
}
/// Uninjecting the head: `Inl` is a hit. On a miss, the remainder is simply
/// the tail.
impl<Hd, Tl> CoprodUninjector<Hd, Here> for Coproduct<Hd, Tl> {
    type Remainder = Tl;

    fn uninject(self) -> Result<Hd, Tl> {
        match self {
            Coproduct::Inr(rest) => Err(rest),
            Coproduct::Inl(found) => Ok(found),
        }
    }
}
/// Uninjecting past the head.
///
/// The head value always belongs to the remainder. A tail value is
/// uninjected recursively, and a miss is re-wrapped in `Inr`.
impl<Hd, Tl, T, N> CoprodUninjector<T, There<N>> for Coproduct<Hd, Tl>
where
    Tl: CoprodUninjector<T, N>,
{
    type Remainder = Coproduct<Hd, Tl::Remainder>;

    fn uninject(self) -> Result<T, Self::Remainder> {
        match self {
            // The index points past the head, so the head value stays in the remainder.
            Coproduct::Inl(head) => Err(Coproduct::Inl(head)),
            Coproduct::Inr(tail) => {
                <Tl as CoprodUninjector<T, N>>::uninject(tail).map_err(Coproduct::Inr)
            }
        }
    }
}
/// Narrowing of a coproduct to a sub-coproduct `Targets`.
///
/// `Indices` is an hlist of type-level positions, one per target type.
pub trait CoproductSubsetter<Targets, Indices>: Sized {
    /// The coproduct of the types that are not in `Targets`.
    type Remainder;

    /// Returns `Ok` if the held value's type is one of `Targets`. Otherwise
    /// returns `Err` with the value re-wrapped in [`Self::Remainder`].
    fn subset(self) -> Result<Targets, Self::Remainder>;
}
/// Recursive subsetting for a target coproduct `Coproduct<THead, TTail>`.
///
/// First try to uninject `THead`. On a hit, it becomes the target's `Inl`.
/// On a miss, recurse on what is left (`Rem`) to look for the other target
/// types; a hit there is wrapped in `Inr`. The final remainder is whatever
/// the recursion could not match.
impl<Choices, THead, TTail, NHead, NTail, Rem>
    CoproductSubsetter<Coproduct<THead, TTail>, HCons<NHead, NTail>> for Choices
where
    Self: CoprodUninjector<THead, NHead, Remainder = Rem>,
    Rem: CoproductSubsetter<TTail, NTail>,
{
    type Remainder = <Rem as CoproductSubsetter<TTail, NTail>>::Remainder;

    fn subset(self) -> Result<Coproduct<THead, TTail>, Self::Remainder> {
        match self.uninject() {
            Ok(good) => Ok(Coproduct::Inl(good)),
            // `Result::map` shifts a successful tail match into `Inr` and
            // passes the remainder through unchanged.
            Err(bads) => bads.subset().map(Coproduct::Inr),
        }
    }
}
/// Base case: an empty target set can select nothing, so the whole input is
/// the remainder.
impl<Choices> CoproductSubsetter<CNil, HNil> for Choices {
    type Remainder = Self;

    #[inline(always)]
    fn subset(self) -> Result<CNil, Self::Remainder> {
        let leftover: Self = self;
        Err(leftover)
    }
}
/// Widening of a coproduct into a larger coproduct `Out` that contains all
/// of its types.
///
/// `Indices` is an hlist with one position in `Out` per source variant.
pub trait CoproductEmbedder<Out, Indices> {
    /// Moves the held value into the matching variant of `Out`.
    fn embed(self) -> Out;
}
/// Embedding the empty coproduct into itself. This can never run, because
/// `CNil` has no values.
impl CoproductEmbedder<CNil, HNil> for CNil {
    fn embed(self) -> CNil {
        match self {}
    }
}
/// Embedding the empty coproduct into any coproduct chain.
///
/// The `where` bound recurses over `Tail`, so the impl applies only when
/// `Tail` itself accepts a `CNil` embedding. The body can never run because
/// `CNil` has no values.
impl<Head, Tail> CoproductEmbedder<Coproduct<Head, Tail>, HNil> for CNil
where
    CNil: CoproductEmbedder<Tail, HNil>,
{
    fn embed(self) -> Coproduct<Head, Tail> {
        match self {}
    }
}
/// Embedding a non-empty coproduct.
///
/// A head value is injected into `Out` at position `NHead`. A tail value is
/// embedded recursively using the remaining indices `NTail`.
impl<Head, Tail, Out, NHead, NTail> CoproductEmbedder<Out, HCons<NHead, NTail>>
    for Coproduct<Head, Tail>
where
    Out: CoprodInjector<Head, NHead>,
    Tail: CoproductEmbedder<Out, NTail>,
{
    fn embed(self) -> Out {
        match self {
            Coproduct::Inr(rest) => <Tail as CoproductEmbedder<Out, NTail>>::embed(rest),
            Coproduct::Inl(value) => <Out as CoprodInjector<Head, NHead>>::inject(value),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::Coproduct::*;
    use super::*;

    /// `inject` picks the variant from the value's type, and `get` succeeds
    /// only for that type.
    #[test]
    fn test_coproduct_inject() {
        type I32StrBool = Coprod!(i32, &'static str, bool);
        let co1 = I32StrBool::inject(3);
        assert_eq!(co1, Inl(3));
        let get_from_1a: Option<&i32> = co1.get();
        let get_from_1b: Option<&bool> = co1.get();
        assert_eq!(get_from_1a, Some(&3));
        assert_eq!(get_from_1b, None);
        let co2 = I32StrBool::inject(false);
        assert_eq!(co2, Inr(Inr(Inl(false))));
        let get_from_2a: Option<&i32> = co2.get();
        let get_from_2b: Option<&bool> = co2.get();
        assert_eq!(get_from_2a, None);
        assert_eq!(get_from_2b, Some(&false));
    }

    /// By-value fold with an `hlist!` of closures, one per variant type.
    #[test]
    #[cfg(feature = "std")]
    fn test_coproduct_fold_consuming() {
        type I32F32Bool = Coprod!(i32, f32, bool);
        let co1 = I32F32Bool::inject(3);
        let folded = co1.fold(hlist![
            |i| format!("int {}", i),
            |f| format!("float {}", f),
            |b| (if b { "t" } else { "f" }).to_string(),
        ]);
        assert_eq!(folded, "int 3".to_string());
    }

    /// By-value fold with a `Poly` whose `Func` impls all return `bool`.
    #[test]
    fn test_coproduct_poly_fold_consuming() {
        type I32F32Bool = Coprod!(i32, f32, bool);
        // Items inside a block are order-independent, so these impls may
        // appear before `struct P` is declared.
        impl Func<i32> for P {
            type Output = bool;
            fn call(args: i32) -> Self::Output {
                args > 100
            }
        }
        impl Func<bool> for P {
            type Output = bool;
            fn call(args: bool) -> Self::Output {
                args
            }
        }
        impl Func<f32> for P {
            type Output = bool;
            fn call(args: f32) -> Self::Output {
                args > 9000f32
            }
        }
        struct P;
        let co1 = I32F32Bool::inject(3);
        let folded = co1.fold(Poly(P));
        assert!(!folded);
    }

    /// Folding through `to_ref` leaves the original coproducts usable.
    #[test]
    #[cfg(feature = "std")]
    fn test_coproduct_fold_non_consuming() {
        type I32F32Bool = Coprod!(i32, f32, bool);
        let co1 = I32F32Bool::inject(3);
        let co2 = I32F32Bool::inject(true);
        let co3 = I32F32Bool::inject(42f32);
        assert_eq!(
            co1.to_ref().fold(hlist![
                |&i| format!("int {}", i),
                |&f| format!("float {}", f),
                |&b| (if b { "t" } else { "f" }).to_string(),
            ]),
            "int 3".to_string()
        );
        assert_eq!(
            co2.to_ref().fold(hlist![
                |&i| format!("int {}", i),
                |&f| format!("float {}", f),
                |&b| (if b { "t" } else { "f" }).to_string(),
            ]),
            "t".to_string()
        );
        assert_eq!(
            co3.to_ref().fold(hlist![
                |&i| format!("int {}", i),
                |&f| format!("float {}", f),
                |&b| (if b { "t" } else { "f" }).to_string(),
            ]),
            "float 42".to_string()
        );
    }

    /// `uninject` succeeds only for the injected type and fails for the others.
    #[test]
    fn test_coproduct_uninject() {
        type I32StrBool = Coprod!(i32, &'static str, bool);
        let co1 = I32StrBool::inject(3);
        let co2 = I32StrBool::inject("hello");
        let co3 = I32StrBool::inject(false);
        let uninject_i32_co1: Result<i32, _> = co1.uninject();
        let uninject_str_co1: Result<&'static str, _> = co1.uninject();
        let uninject_bool_co1: Result<bool, _> = co1.uninject();
        assert_eq!(uninject_i32_co1, Ok(3));
        assert!(uninject_str_co1.is_err());
        assert!(uninject_bool_co1.is_err());
        let uninject_i32_co2: Result<i32, _> = co2.uninject();
        let uninject_str_co2: Result<&'static str, _> = co2.uninject();
        let uninject_bool_co2: Result<bool, _> = co2.uninject();
        assert!(uninject_i32_co2.is_err());
        assert_eq!(uninject_str_co2, Ok("hello"));
        assert!(uninject_bool_co2.is_err());
        let uninject_i32_co3: Result<i32, _> = co3.uninject();
        let uninject_str_co3: Result<&'static str, _> = co3.uninject();
        let uninject_bool_co3: Result<bool, _> = co3.uninject();
        assert!(uninject_i32_co3.is_err());
        assert!(uninject_str_co3.is_err());
        assert_eq!(uninject_bool_co3, Ok(false));
    }

    /// `subset` narrows to the requested types, returning the leftovers on a miss.
    #[test]
    fn test_coproduct_subset() {
        type I32StrBool = Coprod!(i32, &'static str, bool);
        let res: Result<CNil, _> = I32StrBool::inject(3).subset();
        assert!(res.is_err());
        // Compile-only check that `CNil` can be subset into `CNil`. It never
        // runs, because no `CNil` value can exist.
        if false {
            #[allow(unreachable_code, clippy::diverging_sub_expression)]
            {
                #[allow(unused)]
                let cnil: CNil = panic!();
                let _res: Result<CNil, _> = cnil.subset();
                let _ = res;
            }
        }
        {
            let co = I32StrBool::inject(3);
            let res: Result<Coprod!(bool, i32), _> = co.subset();
            assert_eq!(res, Ok(Coproduct::Inr(Coproduct::Inl(3))));
            let co = I32StrBool::inject("4");
            let res: Result<Coprod!(bool, i32), _> = co.subset();
            assert_eq!(res, Err(Coproduct::Inl("4")));
        }
    }

    /// `embed` widens and reorders coproducts. With repeated source types,
    /// every copy lands in the single matching target variant.
    #[test]
    fn test_coproduct_embed() {
        // Compile-only checks for embedding `CNil`. They never run.
        if false {
            #[allow(unreachable_code, clippy::diverging_sub_expression)]
            {
                #[allow(unused)]
                let cnil: CNil = panic!();
                let _: CNil = cnil.embed();
                #[allow(unused)]
                let cnil: CNil = panic!();
                let _: Coprod!(i32, bool) = cnil.embed();
            }
        }
        #[derive(Debug, PartialEq)]
        struct A;
        #[derive(Debug, PartialEq)]
        struct B;
        #[derive(Debug, PartialEq)]
        struct C;
        {
            let co_a = <Coprod!(C, A, B)>::inject(A);
            let co_b = <Coprod!(C, A, B)>::inject(B);
            let co_c = <Coprod!(C, A, B)>::inject(C);
            let out_a: Coprod!(A, B, C) = co_a.embed();
            let out_b: Coprod!(A, B, C) = co_b.embed();
            let out_c: Coprod!(A, B, C) = co_c.embed();
            assert_eq!(out_a, Coproduct::Inl(A));
            assert_eq!(out_b, Coproduct::Inr(Coproduct::Inl(B)));
            assert_eq!(out_c, Coproduct::Inr(Coproduct::Inr(Coproduct::Inl(C))));
        }
        #[allow(clippy::upper_case_acronyms)]
        {
            type ABC = Coprod!(A, B, C);
            type BBB = Coprod!(B, B, B);
            let b1 = BBB::inject::<_, Here>(B);
            let b2 = BBB::inject::<_, There<Here>>(B);
            let out1: ABC = b1.embed();
            let out2: ABC = b2.embed();
            assert_eq!(out1, Coproduct::Inr(Coproduct::Inl(B)));
            assert_eq!(out2, Coproduct::Inr(Coproduct::Inl(B)));
        }
    }

    /// Mapping a borrowed coproduct may return borrows tied to the input's lifetime.
    #[test]
    fn test_coproduct_map_ref() {
        type I32Bool = Coprod!(i32, bool);
        type I32BoolRef<'a> = Coprod!(i32, &'a bool);
        fn map_it(co: &I32Bool) -> I32BoolRef<'_> {
            let map_bool: fn(&bool) -> &bool = |b| b;
            let mapper = hlist![|n: &i32| *n + 3, map_bool];
            co.to_ref().map(mapper)
        }
        let co = I32Bool::inject(3);
        let new = map_it(&co);
        assert_eq!(new, I32BoolRef::inject(6))
    }

    /// Shared-reference mappers (`&hlist`, `&Poly`, `&closure`) can be reused.
    #[test]
    fn test_coproduct_map_with_ref_mapper() {
        type I32Bool = Coprod!(i32, bool);
        let mapper = hlist![|n| n + 3, |b: bool| !b];
        let co = I32Bool::inject(3);
        let co = co.map(&mapper);
        let co = co.map(&mapper);
        assert_eq!(co, I32Bool::inject(9));
        let mapper = poly_fn!(|n: i32| -> i32 { n + 3 }, |b: bool| -> bool { !b });
        let co = I32Bool::inject(3);
        let co = co.map(&mapper);
        let co = co.map(&mapper);
        assert_eq!(co, I32Bool::inject(9));
        type StrStr = Coprod!(String, String);
        let captured = String::from("!");
        let mapper = |s: String| format!("{}{}", s, &captured);
        let co = StrStr::Inl(String::from("hi"));
        let co = co.map(&mapper);
        let co = co.map(&mapper);
        assert_eq!(co, StrStr::Inl(String::from("hi!!")));
    }

    /// Mutable-reference mappers can update captured state. Only the closure
    /// for the present variant runs.
    #[test]
    fn test_coproduct_map_with_mut_mapper() {
        type I32Bool = Coprod!(i32, bool);
        let mut number = None;
        let mut boolean = None;
        let mut mapper = hlist![
            |n: i32| {
                number = Some(n);
                n
            },
            |b: bool| {
                boolean = Some(b);
                b
            },
        ];
        let co = I32Bool::inject(3);
        let co = co.map(&mut mapper);
        assert_eq!(co, I32Bool::inject(3));
        assert_eq!(number, Some(3));
        assert_eq!(boolean, None);
        let mut mapper = poly_fn!(
            |n: i32| -> i32 {
                n
            },
            |b: bool| -> bool {
                b
            },
        );
        let co = I32Bool::inject(3);
        let co = co.map(&mut mapper);
        assert_eq!(co, I32Bool::inject(3));
        type StrStr = Coprod!(String, String);
        let mut captured = String::new();
        let mut mapper = |s: String| {
            let s = format!("{s}!");
            captured.push_str(&s);
            s
        };
        let co = StrStr::Inl(String::from("hi"));
        let co = co.map(&mut mapper);
        let co = co.map(&mut mapper);
        assert_eq!(co, StrStr::Inl(String::from("hi!!")));
        assert_eq!(captured, String::from("hi!hi!!"));
    }
}