1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
use continuous::Interval;
use core::{Space, Card, Surjection};
use discrete::Partition;
use std::fmt::{self, Display};

/// Two-dimensional product space built from a pair of component spaces.
///
/// The two factors are independent type parameters, so they need not be the
/// same kind of space (e.g. a continuous interval paired with a discrete one).
/// Field `.0` is the first factor and field `.1` is the second.
#[derive(Clone, Copy, Serialize, Deserialize, Debug)]
pub struct PairSpace<D1: Space, D2: Space>(pub D1, pub D2);

impl<D1, D2> PairSpace<D1, D2>
where
    D1: Space,
    D2: Space,
{
    /// Creates a product space with `d1` as the first factor and `d2` as the second.
    pub fn new(d1: D1, d2: D2) -> Self {
        PairSpace(d1, d2)
    }
}

impl PairSpace<Interval, Interval> {
    pub fn partitioned(self, density: usize) -> PairSpace<Partition, Partition> {
        PairSpace(
            Partition::from_interval(self.0, density),
            Partition::from_interval(self.1, density),
        )
    }
}

impl<D1, D2> Space for PairSpace<D1, D2>
where
    D1: Space,
    D2: Space,
{
    /// A point in the pair space is one value from each factor.
    type Value = (D1::Value, D2::Value);

    /// Always reports two dimensions, one per factor.
    ///
    /// NOTE(review): this does not consult `self.0.dim()` / `self.1.dim()`,
    /// so nested product spaces still report 2. Confirm that is intended.
    fn dim(&self) -> usize {
        2
    }

    /// Cardinality of the product, computed by multiplying the factors' cardinalities.
    fn card(&self) -> Card {
        let first_card = self.0.card();
        let second_card = self.1.card();

        first_card * second_card
    }
}

impl<D1, X1, D2, X2> Surjection<(X1, X2), (D1::Value, D2::Value)> for PairSpace<D1, D2>
where
    D1: Space + Surjection<X1, <D1 as Space>::Value>,
    D2: Space + Surjection<X2, <D2 as Space>::Value>,
{
    /// Projects a pair onto the product space component-wise.
    ///
    /// The first element is mapped by the first factor and the second element
    /// by the second factor. The two mappings are independent of each other.
    fn map(&self, val: (X1, X2)) -> (D1::Value, D2::Value) {
        let (lhs, rhs) = val;

        let mapped_lhs = self.0.map(lhs);
        let mapped_rhs = self.1.map(rhs);

        (mapped_lhs, mapped_rhs)
    }
}

impl<D1, D2> fmt::Display for PairSpace<D1, D2>
where
    D1: Space + Display,
    D2: Space + Display,
{
    /// Writes the pair as `(first, second)`, using each factor's own `Display`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let (first, second) = (&self.0, &self.1);

        write!(f, "({}, {})", first, second)
    }
}

#[cfg(test)]
mod tests {
    use core::{Space, Card, Surjection};
    use continuous::Interval;
    use discrete::{Ordinal, Partition};
    use product::PairSpace;

    #[test]
    fn test_dim() {
        assert_eq!(PairSpace::new(Ordinal::new(2), Ordinal::new(2)).dim(), 2);

        // Factors of different kinds still form a two-dimensional space.
        assert_eq!(
            PairSpace::new(Interval::bounded(0.0, 1.0), Ordinal::new(3)).dim(),
            2
        );
    }

    #[test]
    fn test_card() {
        assert_eq!(
            PairSpace::new(Ordinal::new(2), Ordinal::new(2)).card(),
            Card::Finite(4)
        );

        // 2 x 2 == 2 + 2, so use distinct sizes to check that the
        // cardinality really is a product.
        assert_eq!(
            PairSpace::new(Ordinal::new(2), Ordinal::new(3)).card(),
            Card::Finite(6)
        );
    }

    #[test]
    fn test_partitioned() {
        let ps = PairSpace::new(Interval::bounded(0.0, 5.0), Interval::bounded(1.0, 2.0));
        let ps = ps.partitioned(5);

        assert_eq!(ps.0, Partition::new(0.0, 5.0, 5));
        assert_eq!(ps.1, Partition::new(1.0, 2.0, 5));
    }

    #[test]
    fn test_surjection() {
        let ps = PairSpace::new(Interval::bounded(0.0, 5.0), Interval::bounded(1.0, 2.0));

        // Out-of-range components are clamped to the nearest bound per axis.
        assert_eq!(ps.map((6.0, 0.0)), (5.0, 1.0));
        assert_eq!(ps.map((2.5, 1.5)), (2.5, 1.5));
        assert_eq!(ps.map((-1.0, 10.0)), (0.0, 2.0));
    }
}