1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
use crate::arrays::Array;
use crate::traits::{
    create::ArrayCreate,
    errors::ArrayError,
    manipulate::{
        ArrayManipulate,
        broadcast::ArrayBroadcast,
    },
    meta::ArrayMeta,
    types::{
        numeric::Numeric,
        tuple_numeric::Tuple2,
    },
    validators::{
        validate_has_error::ValidateHasError,
        validate_shape::ValidateShape,
    },
};

impl <N: Numeric> ArrayBroadcast<N> for Array<N> {

    fn broadcast(&self, other: &Array<N>) -> Result<Array<Tuple2<N>>, ArrayError> {
        self.get_shape()?.is_broadcastable(&other.get_shape()?)?;

        let final_shape = self.broadcast_shape(other.get_shape()?)?;

        let inner_arrays_self = self.extract_inner_arrays();
        let inner_arrays_other = other.extract_inner_arrays();

        let output_elements = inner_arrays_self.iter().cycle()
            .zip(inner_arrays_other.iter().cycle())
            .flat_map( | (inner_self, inner_other) | match (inner_self.len(), inner_other.len()) {
                (1, _) => inner_self.iter().cycle()
                    .zip(inner_other.iter())
                    .take(final_shape[final_shape.len() - 1])
                    .map( | ( & a, & b) | Tuple2(a, b))
                    .collect::< Vec < _ > > (),
                (_, 1) => inner_self.iter()
                    .zip(inner_other.iter().cycle())
                    .take(final_shape[final_shape.len() - 1])
                    .map(| ( & a, & b) | Tuple2(a, b))
                    .collect::<Vec < _ > > (),
                _ => inner_self.iter().cycle()
                    .zip(inner_other.iter().cycle())
                    .take(final_shape[final_shape.len() - 1])
                    .map( |( & a, & b) | Tuple2(a, b))
                    .collect::< Vec< _ > > (),
            })
            .take(final_shape.iter().product())
            .collect:: < Vec<_ > > ();

        Array::new(output_elements, final_shape)
    }

    fn broadcast_to(&self, shape: Vec<usize>) -> Result<Array<N>, ArrayError> {
        self.get_shape()?.is_broadcastable(&shape)?;

        if self.get_shape()?.iter().product::<usize>() == shape.iter().product::<usize>() {
            self.reshape(shape)
        } else {
            let output_elements: Vec<N> = self.elements
                .chunks_exact(self.shape[self.shape.len() - 1])
                .flat_map(|inner| {
                    let extended_inner = inner.iter()
                        .cycle()
                        .take(shape[shape.len() - 1])
                        .copied()
                        .collect::<Vec<N>>();
                    extended_inner.into_iter()
                })
                .cycle()
                .take(shape.iter().product())
                .collect();

            Array::new(output_elements, shape)
        }
    }

    fn broadcast_arrays(arrays: Vec<Array<N>>) -> Result<Vec<Array<N>>, ArrayError> {
        arrays.iter().map(|array| array.get_shape()).collect::<Vec<Result<Vec<usize>, ArrayError>>>().has_error()?;
        let shapes = arrays.iter()
            .map(|array| array.get_shape().unwrap())
            .collect::<Vec<_>>();

        let common_shape = Self::common_broadcast_shape(&shapes);
        if let Ok(common_shape) = common_shape {
            let result = arrays.iter()
                .map(|array| array.broadcast_to(common_shape.clone()))
                .collect::<Vec<Result<Self, _>>>()
                .has_error()?
                .into_iter().map(|a| a.unwrap())
                .collect();
            Ok(result)
        } else {
            Err(common_shape.err().unwrap())
        }
    }
}

impl<N: Numeric> ArrayBroadcast<N> for Result<Array<N>, ArrayError> {

    /// Broadcasts the wrapped array against `other`, or returns a clone of the stored error.
    fn broadcast(&self, other: &Array<N>) -> Result<Array<Tuple2<N>>, ArrayError> {
        // Borrow the array instead of cloning the whole `Result` (and its element buffer).
        self.as_ref().map_err(|error| error.clone())?.broadcast(other)
    }

    /// Broadcasts the wrapped array to `shape`, or returns a clone of the stored error.
    fn broadcast_to(&self, shape: Vec<usize>) -> Result<Array<N>, ArrayError> {
        self.as_ref().map_err(|error| error.clone())?.broadcast_to(shape)
    }

    /// Forwards to `Array::broadcast_arrays`; the wrapper value plays no role here.
    fn broadcast_arrays(arrays: Vec<Array<N>>) -> Result<Vec<Array<N>>, ArrayError> {
        Array::broadcast_arrays(arrays)
    }
}

impl<N: Numeric> Array<N> {

    /// Returns the broadcast shape of `self` combined with `shape`, leading axis first.
    ///
    /// # Errors
    ///
    /// Returns `ArrayError::BroadcastShapeMismatch` if the shapes are incompatible.
    fn broadcast_shape(&self, shape: Vec<usize>) -> Result<Vec<usize>, ArrayError> {
        // Reuse the n-ary rule so the result is in leading-to-trailing order; comparing the
        // reversed shapes without reversing back would return the axes backwards.
        Self::common_broadcast_shape(&[self.shape.clone(), shape])
    }

    /// Computes the shape that every shape in `shapes` can be broadcast to.
    ///
    /// Shapes are right-aligned and missing leading axes count as size `1`. On each axis all
    /// sizes other than `1` must agree; that size (or `1` if there is none) becomes the
    /// common size. An empty `shapes` slice yields an empty shape.
    ///
    /// # Errors
    ///
    /// Returns `ArrayError::BroadcastShapeMismatch` if two shapes disagree on an axis and
    /// neither size is `1`.
    fn common_broadcast_shape(shapes: &[Vec<usize>]) -> Result<Vec<usize>, ArrayError> {
        let max_dim = shapes.iter().map(Vec::len).max().unwrap_or(0);

        let mut common_shape = Vec::with_capacity(max_dim);
        // Walk the axes from the trailing end so shorter shapes line up on the right.
        for axis_from_end in 0..max_dim {
            let mut common_dim = 1;
            for shape in shapes {
                let dim = shape.iter().rev().nth(axis_from_end).copied().unwrap_or(1);
                if dim == 1 || dim == common_dim {
                    continue;
                }
                if common_dim != 1 {
                    return Err(ArrayError::BroadcastShapeMismatch);
                }
                // The first non-1 size wins; a 0 must not be widened to 1 by taking a max.
                common_dim = dim;
            }
            common_shape.push(common_dim);
        }

        common_shape.reverse();
        Ok(common_shape)
    }

    /// Splits the elements into the innermost (last-axis) rows, in storage order.
    ///
    /// Arrays with at most one axis are returned as a single row. A zero-length last axis
    /// yields one empty row per combination of the leading axes.
    fn extract_inner_arrays(&self) -> Vec<Vec<N>> {
        let (last, leading) = match self.shape.split_last() {
            Some((&last, leading)) if !leading.is_empty() => (last, leading),
            _ => return vec![self.elements.clone()],
        };

        if last == 0 {
            // `chunks_exact(0)` panics, but the row count is still given by the leading axes.
            return vec![Vec::new(); leading.iter().product()];
        }

        self.elements
            .chunks_exact(last)
            .map(Vec::from)
            .collect()
    }
}