#![allow(clippy::type_complexity)]
/// A pair of conversion functions between `A` and `B`, packaged as one value.
///
/// `to` maps an `A` into a `B` and `from` maps a `B` back into an `A`.
/// Nothing in this type checks that the two functions really are mutual
/// inverses; that is the responsibility of whoever constructs the value.
///
/// The function types default to plain function pointers, so `Iso<A, B>`
/// names the pointer-based form. Closures and other callables are supported
/// through the `To` and `From` parameters.
pub struct Iso<A, B, To = fn(A) -> B, From = fn(B) -> A> {
    /// Conversion from `A` to `B`.
    pub to: To,
    /// Conversion from `B` back to `A`.
    pub from: From,
    // Ties `A` and `B` to the type without storing a value of either.
    _marker: std::marker::PhantomData<fn(A) -> B>,
}

// `Clone` and `Copy` are implemented by hand rather than derived: the derive
// would additionally require `A` and `B` themselves to be `Clone`/`Copy`,
// even though only the two functions are stored.
impl<A, B, To, From> Clone for Iso<A, B, To, From>
where
    To: Clone,
    From: Clone,
{
    fn clone(&self) -> Self {
        Self {
            to: self.to.clone(),
            from: self.from.clone(),
            _marker: std::marker::PhantomData,
        }
    }
}

impl<A, B, To, From> Copy for Iso<A, B, To, From>
where
    To: Copy,
    From: Copy,
{
}

impl<A, B, To, From> Iso<A, B, To, From>
where
    To: Fn(A) -> B,
    From: Fn(B) -> A,
{
    /// Builds an `Iso` from a forward function `to` and a backward function `from`.
    ///
    /// The functions are stored as given; no inverse check is performed.
    pub fn new(to: To, from: From) -> Self {
        Self {
            to,
            from,
            _marker: std::marker::PhantomData,
        }
    }

    /// Applies the forward function to `a`.
    #[inline]
    pub fn forward(&self, a: A) -> B {
        (self.to)(a)
    }

    /// Applies the backward function to `b`.
    #[inline]
    pub fn backward(&self, b: B) -> A {
        (self.from)(b)
    }

    /// Consumes the `Iso` and returns one with the two directions exchanged.
    ///
    /// Both functions are moved, never duplicated, so they do not need to be
    /// `Clone`.
    pub fn reverse(self) -> Iso<B, A, From, To> {
        Iso {
            to: self.from,
            from: self.to,
            _marker: std::marker::PhantomData,
        }
    }
}
/// Returns the `Iso` from `A` to itself in which both directions hand their
/// argument back unchanged.
#[inline]
pub fn identity<A>() -> Iso<A, A, fn(A) -> A, fn(A) -> A> {
    // One generic pass-through serves as both the forward and backward function.
    fn same<T>(value: T) -> T {
        value
    }

    Iso {
        to: same::<A>,
        from: same::<A>,
        _marker: std::marker::PhantomData,
    }
}
/// Returns the `Iso` between `(A, B)` and `(B, A)` that exchanges the two
/// components of a pair.
#[inline]
pub fn swap<A, B>() -> Iso<(A, B), (B, A), fn((A, B)) -> (B, A), fn((B, A)) -> (A, B)> {
    // Exchanging components is the same operation in both directions; only the
    // type arguments differ.
    fn flip<X, Y>((first, second): (X, Y)) -> (Y, X) {
        (second, first)
    }

    Iso {
        to: flip::<A, B>,
        from: flip::<B, A>,
        _marker: std::marker::PhantomData,
    }
}
/// Returns the `Iso` between `A` and `(A, ())`: forward pairs the value with
/// the unit on the right, backward drops the unit again.
#[inline]
pub fn unit_right<A>() -> Iso<A, (A, ()), fn(A) -> (A, ()), fn((A, ())) -> A> {
    let attach: fn(A) -> (A, ()) = |value| (value, ());
    let detach: fn((A, ())) -> A = |pair| pair.0;

    Iso {
        to: attach,
        from: detach,
        _marker: std::marker::PhantomData,
    }
}
/// Returns the `Iso` between `A` and `((), A)`: forward pairs the value with
/// the unit on the left, backward drops the unit again.
#[inline]
pub fn unit_left<A>() -> Iso<A, ((), A), fn(A) -> ((), A), fn(((), A)) -> A> {
    fn attach<T>(value: T) -> ((), T) {
        ((), value)
    }
    fn detach<T>(((), value): ((), T)) -> T {
        value
    }

    Iso {
        to: attach::<A>,
        from: detach::<A>,
        _marker: std::marker::PhantomData,
    }
}
/// Returns the `Iso` that re-nests a three-way product: forward turns
/// `((a, b), c)` into `(a, (b, c))`, backward does the opposite.
#[inline]
pub fn assoc_product<A, B, C>()
-> Iso<((A, B), C), (A, (B, C)), fn(((A, B), C)) -> (A, (B, C)), fn((A, (B, C))) -> ((A, B), C)> {
    let nest_right: fn(((A, B), C)) -> (A, (B, C)) = |(head, c)| {
        let (a, b) = head;
        (a, (b, c))
    };
    let nest_left: fn((A, (B, C))) -> ((A, B), C) = |(a, tail)| {
        let (b, c) = tail;
        ((a, b), c)
    };

    Iso {
        to: nest_right,
        from: nest_left,
        _marker: std::marker::PhantomData,
    }
}
/// Adapts a curried function into one that takes its two arguments as a pair.
///
/// `f` takes an `A` and returns a function of `B`; the returned closure takes
/// `(A, B)`, applies `f` to the first component and the resulting function to
/// the second.
#[inline]
pub fn uncurry<A, B, C, F, G>(f: F) -> impl Fn((A, B)) -> C
where
    F: Fn(A) -> G,
    G: Fn(B) -> C,
{
    move |pair: (A, B)| {
        let (first, second) = pair;
        let partially_applied = f(first);
        partially_applied(second)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use rstest::rstest;
mod iso_struct {
    use super::*;

    // `Iso::new` stores both closures; `forward`/`backward` invoke them.
    #[test]
    fn iso_new_creates_isomorphism() {
        let render = |n: i32| n.to_string();
        let parse = |s: String| s.parse::<i32>().unwrap_or(0);
        let int_text = Iso::new(render, parse);

        let rendered = int_text.forward(42);
        assert_eq!(rendered, "42");

        let parsed = int_text.backward("42".to_string());
        assert_eq!(parsed, 42);
    }

    // After `reverse`, `forward` runs the original backward function and vice versa.
    #[test]
    fn iso_reverse_swaps_directions() {
        let widen = |n: i32| n as f64;
        let truncate = |f: f64| f as i32;
        let reversed = Iso::new(widen, truncate).reverse();

        assert_eq!(reversed.forward(3.7), 3);
        assert_eq!(reversed.backward(5), 5.0);
    }
}
mod identity_iso {
use super::*;
#[rstest]
#[case::integer(42_i32)]
#[case::zero(0_i32)]
#[case::negative(-7_i32)]
fn identity_roundtrips(#[case] value: i32) {
let iso = identity::<i32>();
assert_eq!((iso.to)(value), value);
assert_eq!((iso.from)(value), value);
}
#[test]
fn identity_with_string() {
let iso = identity::<String>();
let s = String::from("hello");
assert_eq!((iso.to)(s.clone()), s);
}
}
mod swap_iso {
    use super::*;

    // Forward direction exchanges the two components, including their types.
    #[test]
    fn swap_exchanges_components() {
        let flip = swap::<i32, &str>().to;
        assert_eq!(flip((1, "hello")), ("hello", 1));
    }

    // Forward followed by backward restores the starting pair.
    #[test]
    fn swap_roundtrips() {
        let iso = swap::<i32, i32>();
        let start = (1, 2);
        let flipped = (iso.to)(start);
        let restored = (iso.from)(flipped);
        assert_eq!(restored, start);
    }

    // Applying the forward direction twice also restores the starting pair.
    #[test]
    fn swap_is_self_inverse() {
        let flip = swap::<i32, i32>().to;
        let start = (3, 7);
        let twice = flip(flip(start));
        assert_eq!(twice, start);
    }
}
mod unit_isos {
    use super::*;

    // Forward direction of `unit_right` appends `()` as the second component.
    #[test]
    fn unit_right_adds_unit() {
        let attach = unit_right::<i32>().to;
        assert_eq!(attach(42), (42, ()));
    }

    // Forward then backward yields the original value.
    #[test]
    fn unit_right_roundtrips() {
        let Iso { to, from, .. } = unit_right::<i32>();
        let with_unit = to(42);
        assert_eq!(from(with_unit), 42);
    }

    // Forward direction of `unit_left` prepends `()` as the first component.
    #[test]
    fn unit_left_adds_unit() {
        let attach = unit_left::<i32>().to;
        assert_eq!(attach(42), ((), 42));
    }

    // Forward then backward yields the original value.
    #[test]
    fn unit_left_roundtrips() {
        let Iso { to, from, .. } = unit_left::<i32>();
        let with_unit = to(42);
        assert_eq!(from(with_unit), 42);
    }
}
mod assoc_product_iso {
    use super::*;

    // Forward direction moves the nesting from the left pair to the right pair.
    #[test]
    fn assoc_product_reassociates() {
        let nest_right = assoc_product::<i32, i32, i32>().to;
        let regrouped = nest_right(((1, 2), 3));
        assert_eq!(regrouped, (1, (2, 3)));
    }

    // Forward then backward restores the left-nested input.
    #[test]
    fn assoc_product_roundtrips() {
        let Iso { to, from, .. } = assoc_product::<i32, i32, i32>();
        let start = ((1, 2), 3);
        let regrouped = to(start);
        assert_eq!(from(regrouped), start);
    }
}
mod uncurry_tests {
    use super::*;

    // A curried addition becomes a function of a pair.
    #[test]
    fn uncurry_converts_curried_function() {
        let add = uncurry(|a: i32| move |b: i32| a + b);
        assert_eq!(add((2, 3)), 5);
    }

    // Argument order is preserved: the first component is the minuend.
    #[test]
    fn uncurry_with_subtraction() {
        let sub = uncurry(|a: i32| move |b: i32| a - b);
        assert_eq!(sub((10, 3)), 7);
    }

    // The uncurried function can be called repeatedly with different pairs.
    #[test]
    fn uncurry_with_multiple_inputs() {
        let mul = uncurry(|a: i32| move |b: i32| a * b);
        let samples = [(1, 2), (0, 5), (-3, 4)];
        for (a, b) in samples.iter().copied() {
            assert_eq!(mul((a, b)), a * b);
        }
    }
}
mod laws {
    use super::*;

    // Composing the identity's two directions in either order returns the input.
    #[test]
    fn identity_is_reflexive() {
        let Iso { to, from, .. } = identity::<i32>();
        let samples = [0, 1, -5, 100];
        for x in samples.iter().copied() {
            assert_eq!(from(to(x)), x);
            assert_eq!(to(from(x)), x);
        }
    }

    // Each direction of `swap` matches the opposite direction of its reverse.
    #[test]
    fn swap_is_symmetric() {
        let iso = swap::<i32, i32>();
        // `reverse` takes `self`; `iso` is used again below, which relies on it being `Copy`.
        let reversed = iso.reverse();
        let samples = [(1, 2), (0, 0), (-5, 10)];
        for pair in samples.iter().copied() {
            assert_eq!((iso.to)(pair), (reversed.from)(pair));
            assert_eq!((iso.from)(pair), (reversed.to)(pair));
        }
    }

    // backward(forward(x)) == x for `unit_right`.
    #[test]
    fn unit_right_inverse_law() {
        let Iso { to, from, .. } = unit_right::<i32>();
        let samples = [0, 42, -7];
        for x in samples.iter().copied() {
            assert_eq!(from(to(x)), x);
        }
    }

    // forward(backward(p)) == p for `unit_right`.
    #[test]
    fn unit_right_inverse_law_reverse() {
        let Iso { to, from, .. } = unit_right::<i32>();
        let samples = [0, 42, -7];
        for x in samples.iter().copied() {
            let with_unit = (x, ());
            assert_eq!(to(from(with_unit)), with_unit);
        }
    }

    // Both round trips of `assoc_product` return their starting value.
    #[test]
    fn assoc_product_inverse_laws() {
        let Iso { to, from, .. } = assoc_product::<i32, i32, i32>();
        let nested_left = ((1, 2), 3);
        let nested_right = (1, (2, 3));
        assert_eq!(from(to(nested_left)), nested_left);
        assert_eq!(to(from(nested_right)), nested_right);
    }
}
}