1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
use crate::basis::{Composable, Projection, Projector};
use crate::core::DenseT;
use crate::geometry::{Card, Space};

fn stack_projections(p1: Projection, n1: usize, p2: Projection, n2: usize) -> Projection {
    match (p1, p2) {
        (Projection::Sparse(mut p1_indices), Projection::Sparse(p2_indices)) => {
            p2_indices.iter().for_each(|&i| {
                p1_indices.insert(i + n1);
            });

            Projection::Sparse(p1_indices)
        },
        (p1, p2) => {
            let mut all_activations = p1.expanded(n1).to_vec();
            all_activations.extend_from_slice(p2.expanded(n2).as_slice().unwrap());

            Projection::Dense(DenseT::from_vec(all_activations))
        },
    }
}

/// Stack the output of two `Projector` instances.
#[derive(Clone, Copy, Serialize, Deserialize, Debug)]
pub struct Stack<P1, P2> {
    p1: P1,
    p2: P2,
}

impl<P1, P2> Stack<P1, P2> {
    pub fn new(p1: P1, p2: P2) -> Self { Stack { p1, p2 } }
}

impl<P1: Space, P2: Space> Space for Stack<P1, P2> {
    type Value = Projection;

    fn dim(&self) -> usize { self.p1.dim() + self.p2.dim() }

    fn card(&self) -> Card { self.p1.card() * self.p2.card() }
}

impl<I: ?Sized, P1: Projector<I>, P2: Projector<I>> Projector<I> for Stack<P1, P2> {
    fn project(&self, input: &I) -> Projection {
        stack_projections(
            self.p1.project(input),
            self.p1.dim(),
            self.p2.project(input),
            self.p2.dim(),
        )
    }
}

impl<P1, P2> Composable for Stack<P1, P2> {}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::basis::{
        fixed::{Constant, Indices},
        Projector,
    };
    use crate::geometry::Vector;
    use std::iter;

    #[test]
    fn test_stack_constant() {
        let p = Stack::new(Constant::zeros(10), Constant::ones(10));
        let output: Projection =
            Vector::from_iter(iter::repeat(0.0).take(10).chain(iter::repeat(1.0).take(10))).into();

        assert_eq!(p.dim(), 20);
        assert_eq!(p.project(&[0.0]), output);
        assert_eq!(p.project(&[0.0, 1.0]), output);
        assert_eq!(p.project(&[-1.0, 1.0]), output);
    }

    #[test]
    fn test_stack_indices() {
        let p = Stack::new(Indices::new(10, vec![5]), Indices::new(10, vec![0]));
        let output: Projection = vec![5, 10].into();

        assert_eq!(p.dim(), 20);
        assert_eq!(p.project(&[0.0]), output);
        assert_eq!(p.project(&[0.0, 1.0]), output);
        assert_eq!(p.project(&[-1.0, 1.0]), output);
    }

    #[test]
    fn test_stack_mixed() {
        let p = Stack::new(Constant::ones(10), Indices::new(10, vec![0]));
        let output: Projection =
            Vector::from_iter(iter::repeat(1.0).take(11).chain(iter::repeat(0.0).take(9))).into();

        assert_eq!(p.dim(), 20);
        assert_eq!(p.project(&[0.0]), output);
        assert_eq!(p.project(&[0.0, 1.0]), output);
        assert_eq!(p.project(&[-1.0, 1.0]), output);
    }
}