use crate::ScalarLoss;
use crate::{Beam, Optic, Prism};
use std::convert::Infallible;
use terni::Imperfect;
/// A prism described by three plain function pointers over a source type `S`
/// and a focus type `A`.
///
/// All fields are `fn` pointers rather than closures, so a value of this type
/// carries no captured state and is exactly three pointers wide.
pub struct OpticPrism<S, A> {
    /// Predicate over a borrowed `S`.
    match_fn: fn(&S) -> bool,
    /// Produces an `A` from a borrowed `S`.
    extract_fn: fn(&S) -> A,
    /// Builds an `S` from an owned `A`.
    review_fn: fn(A) -> S,
}

// `Clone` and `Copy` are implemented by hand instead of derived.
// `#[derive(Clone, Copy)]` would add `S: Clone + Copy` and `A: Clone + Copy`
// bounds, even though no `S` or `A` value is stored here: the fields are
// function pointers, which are `Copy` for every `S` and `A`.
impl<S, A> Clone for OpticPrism<S, A> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<S, A> Copy for OpticPrism<S, A> {}
impl<S: 'static, A: 'static> OpticPrism<S, A> {
pub fn new(matches: fn(&S) -> bool, extract: fn(&S) -> A, review: fn(A) -> S) -> Self {
OpticPrism {
match_fn: matches,
extract_fn: extract,
review_fn: review,
}
}
pub fn matches(&self, s: &S) -> bool {
(self.match_fn)(s)
}
pub fn extract(&self, s: &S) -> Option<A> {
if (self.match_fn)(s) {
Some((self.extract_fn)(s))
} else {
None
}
}
pub fn review(&self, a: A) -> S {
(self.review_fn)(a)
}
}
impl<S: 'static, A: 'static> crate::Addressable for OpticPrism<S, A> {
    /// Derives the prism's `Oid` from the addresses of its three functions.
    ///
    /// Each function pointer is cast to `usize` and serialized little-endian,
    /// in the order match, extract, review; the concatenated bytes are passed
    /// to `crate::Oid::hash`.
    ///
    /// NOTE(review): because the input is raw function addresses, the value
    /// presumably differs between builds or processes; confirm before
    /// persisting it.
    fn oid(&self) -> crate::Oid {
        let addresses = [
            self.match_fn as usize,
            self.extract_fn as usize,
            self.review_fn as usize,
        ];
        // 24 bytes covers three 8-byte addresses.
        let mut buf = Vec::with_capacity(24);
        for address in addresses.iter() {
            buf.extend_from_slice(&address.to_le_bytes());
        }
        crate::Oid::hash(&buf)
    }
}
/// Beam-based pipeline for the prism: `focus` applies the match/extract pair,
/// while `project` and `settle` forward the focused value unchanged.
impl<S: Clone + 'static, A: Clone + 'static> Prism for OpticPrism<S, A> {
    /// Seed beam whose current value is the source `S`.
    type Input = Optic<(), S, Infallible, ScalarLoss>;
    /// Beam produced by `focus`: came from an `S`, now holds an `A`.
    type Focused = Optic<S, A, Infallible, ScalarLoss>;
    /// Beam produced by `project`.
    type Projected = Optic<A, A, Infallible, ScalarLoss>;
    /// Beam produced by `settle`.
    type Refracted = Optic<A, A, Infallible, ScalarLoss>;

    /// Applies the prism to the beam's current `S`.
    ///
    /// On a match the extracted value is handed to `beam.next`. On a mismatch
    /// the extraction function is still called to obtain a placeholder value,
    /// which is handed to `beam.tick` as an `Imperfect::partial` carrying a
    /// `ScalarLoss` of `f64::INFINITY`.
    ///
    /// # Panics
    ///
    /// Panics with `focus: Err beam` if `beam.result().ok()` is `None`.
    fn focus(&self, beam: Self::Input) -> Self::Focused {
        // Clone so the borrow from `result()` ends before the beam is consumed.
        let source = beam.result().ok().expect("focus: Err beam").clone();
        if !(self.match_fn)(&source) {
            let sentinel = (self.extract_fn)(&source);
            return beam.tick(Imperfect::partial(sentinel, ScalarLoss::new(f64::INFINITY)));
        }
        let focused = (self.extract_fn)(&source);
        beam.next(focused)
    }

    /// Forwards a clone of the beam's current value through `beam.next`.
    ///
    /// # Panics
    ///
    /// Panics with `project: Err beam` if `beam.result().ok()` is `None`.
    fn project(&self, beam: Self::Focused) -> Self::Projected {
        let current = beam.result().ok().expect("project: Err beam").clone();
        beam.next(current)
    }

    /// Forwards a clone of the beam's current value through `beam.next`.
    ///
    /// # Panics
    ///
    /// Panics with `settle: Err beam` if `beam.result().ok()` is `None`.
    fn settle(&self, beam: Self::Projected) -> Self::Refracted {
        let current = beam.result().ok().expect("settle: Err beam").clone();
        beam.next(current)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Beam as BeamTrait;

    /// Two-variant fixture; most tests use a prism that targets `Circle`.
    #[derive(Clone, Debug, PartialEq)]
    enum Shape {
        Circle(i32),
        Square(i32),
    }

    fn shape_is_circle(s: &Shape) -> bool {
        match s {
            Shape::Circle(_) => true,
            Shape::Square(_) => false,
        }
    }

    /// Returns the radius of a circle, or `-1` for any other shape.
    fn shape_extract_circle(s: &Shape) -> i32 {
        match s {
            Shape::Circle(radius) => *radius,
            Shape::Square(_) => -1,
        }
    }

    fn shape_review_circle(r: i32) -> Shape {
        Shape::Circle(r)
    }

    /// Prism focusing on the `Circle` variant of `Shape`.
    fn circle_prism() -> OpticPrism<Shape, i32> {
        OpticPrism::new(shape_is_circle, shape_extract_circle, shape_review_circle)
    }

    #[test]
    fn optic_prism_matches() {
        let prism = circle_prism();
        let circle = Shape::Circle(5);
        let square = Shape::Square(3);
        assert!(prism.matches(&circle));
        assert!(!prism.matches(&square));
    }

    #[test]
    fn optic_prism_extract_some_on_match() {
        let prism = circle_prism();
        let extracted = prism.extract(&Shape::Circle(5));
        assert_eq!(extracted, Some(5));
    }

    #[test]
    fn optic_prism_extract_none_on_mismatch() {
        let prism = circle_prism();
        let extracted = prism.extract(&Shape::Square(3));
        assert_eq!(extracted, None);
    }

    #[test]
    fn optic_prism_review() {
        let prism = circle_prism();
        let rebuilt = prism.review(7);
        assert_eq!(rebuilt, Shape::Circle(7));
    }

    /// Wraps `v` in a seed beam with a unit input.
    fn seed<T: Clone>(v: T) -> Optic<(), T, Infallible, ScalarLoss> {
        Optic::ok((), v)
    }

    #[test]
    fn optic_prism_focus_matching_is_lossless() {
        let prism = circle_prism();
        let focused = prism.focus(seed(Shape::Circle(42)));
        assert_eq!(focused.result().ok(), Some(&42));
        assert!(!focused.is_partial());
    }

    #[test]
    fn optic_prism_focus_nonmatch_produces_infinite_loss() {
        let prism = circle_prism();
        let focused = prism.focus(seed(Shape::Square(3)));
        assert!(focused.is_partial());
        // The carried value is the extractor's `-1` fallback.
        assert_eq!(focused.result().ok(), Some(&-1));
    }

    #[test]
    fn optic_prism_full_pipeline_matching() {
        let prism = circle_prism();
        let start = seed(Shape::Circle(10));
        let refracted = prism.settle(prism.project(prism.focus(start)));
        assert_eq!(refracted.result().ok(), Some(&10));
        assert!(!refracted.is_partial());
    }

    #[test]
    fn optic_prism_full_pipeline_nonmatch_carries_loss() {
        let prism = circle_prism();

        let focused = prism.focus(seed(Shape::Square(3)));
        assert!(focused.is_partial());

        let projected = prism.project(focused);
        assert!(projected.is_partial());

        let refracted = prism.settle(projected);
        assert!(refracted.is_partial());
    }

    #[test]
    fn optic_prism_is_clone_and_copy() {
        let original = circle_prism();
        // NOTE(review): `original` is never used after this assignment, so the
        // test compiles whether or not the prism is `Copy`; only `clone` is
        // really exercised here.
        let rebound = original;
        let cloned = rebound.clone();
        assert!(cloned.matches(&Shape::Circle(1)));
    }

    #[test]
    fn optic_prism_same_fns_same_oid() {
        use crate::Addressable;

        let first = circle_prism();
        let second = circle_prism();
        assert_eq!(first.oid(), second.oid());
    }

    #[test]
    fn optic_prism_different_fns_different_oid() {
        use crate::Addressable;

        fn square_match(s: &Shape) -> bool {
            match s {
                Shape::Square(_) => true,
                Shape::Circle(_) => false,
            }
        }

        fn square_extract(s: &Shape) -> i32 {
            match s {
                Shape::Square(side) => *side,
                Shape::Circle(_) => -1,
            }
        }

        fn square_review(r: i32) -> Shape {
            Shape::Square(r)
        }

        let circle = circle_prism();
        let square = OpticPrism::new(square_match, square_extract, square_review);
        assert_ne!(circle.oid(), square.oid());
    }
}