1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
use rten_tensor::prelude::*;
use rten_tensor::{NdTensorView, SliceItem, Tensor, TensorView};

use crate::ops::{resolve_axis, InputList, OpError, Operator, Output};
use crate::static_dims;

/// Split `input` into consecutive pieces along `axis`.
///
/// `split` lists the size of each piece along `axis`, in output order. One
/// output tensor is produced per entry; each output is an owned copy of the
/// corresponding slice of `input`, spanning the full range of every other
/// dimension.
///
/// `axis` is resolved against the rank of `input` via `resolve_axis`, so
/// negative values are accepted (for a 2-D input, `-1` selects the same
/// dimension as `1`).
///
/// # Errors
///
/// - Propagates the error from `resolve_axis` if `axis` is out of range.
/// - Returns `OpError::InvalidValue` if any split size is negative.
/// - Returns `OpError::InvalidValue` if the split sizes do not add up to the
///   size of `input` along `axis`.
pub fn split<T: Copy>(
    input: TensorView<T>,
    axis: isize,
    split: &NdTensorView<i32, 1>,
) -> Result<Vec<Tensor<T>>, OpError> {
    let axis = resolve_axis(input.ndim(), axis)?;

    if split.iter().any(|&size| size < 0) {
        return Err(OpError::InvalidValue("Split sizes must be >= 0"));
    }

    // Accumulate in `usize` with overflow checking. Summing as `i32` can
    // overflow for large sizes: that panics in debug builds and wraps in
    // release builds, where a wrapped total may happen to equal the dimension
    // size and let out-of-range sizes through to the slicing code below.
    // The `as usize` cast is lossless because negative sizes were rejected
    // above.
    let split_sum = split
        .iter()
        .try_fold(0usize, |total, &size| total.checked_add(size as usize));
    if split_sum != Some(input.size(axis)) {
        return Err(OpError::InvalidValue(
            "Split sizes do not sum to dimension size",
        ));
    }

    // Offset along `axis` at which the next piece begins. The sizes sum to
    // the dimension size, so `split_start + split_size` never exceeds it.
    let mut split_start = 0;
    let outputs = split
        .iter()
        .map(|&split_size| {
            let split_size = split_size as usize;
            let split_end = split_start + split_size;

            // Take `split_start..split_end` along `axis` and everything along
            // the remaining dimensions.
            let slice_range: Vec<SliceItem> = (0..input.ndim())
                .map(|dim| {
                    if dim == axis {
                        (split_start..split_end).into()
                    } else {
                        SliceItem::full_range()
                    }
                })
                .collect();

            split_start = split_end;

            input.slice_dyn(slice_range.as_slice()).to_tensor()
        })
        .collect();

    Ok(outputs)
}

/// Operator that splits an input tensor into several tensors along one axis.
///
/// The sizes of the pieces are not stored here; they are read from the
/// operator's second input when it is run.
#[derive(Debug)]
pub struct Split {
    /// Axis to split along. Passed unchanged to `split`, which resolves it
    /// against the input's rank (negative values are accepted there).
    pub axis: isize,
}

impl Operator for Split {
    /// Name under which this operator is reported.
    fn name(&self) -> &str {
        "Split"
    }

    /// Run the split on the operator inputs.
    ///
    /// Input 0 is the `f32` tensor to split and input 1 is an `i32` tensor
    /// holding the size of each piece, which is converted to a 1-D view via
    /// `static_dims!`. Any error from fetching the inputs, from that
    /// conversion, or from `split` itself is returned to the caller.
    fn run(&self, inputs: InputList) -> Result<Vec<Output>, OpError> {
        let data = inputs.require_as::<f32>(0)?;

        // NOTE: the identifier handed to `static_dims!` is kept as `splits`
        // in case the macro uses the name when building its error message.
        let splits = inputs.require_as::<i32>(1)?;
        let split_sizes = static_dims!(splits, 1)?;

        let pieces = split(data, self.axis, &split_sizes)?;
        Ok(pieces.into_iter().map(Into::into).collect())
    }
}

#[cfg(test)]
mod tests {
    use rten_tensor::prelude::*;
    use rten_tensor::tensor;

    use crate::ops::{split, OpError};

    /// Splitting a (5, 2) tensor into its two columns gives the same result
    /// whether the last axis is given as `1` or as `-1`.
    #[test]
    fn test_split() {
        let input = tensor!((5, 2); [0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]);

        // First pass uses a positive axis, second pass a negative one.
        for axis in [1isize, -1] {
            let splits = &[1, 1];
            let results = split(input.view(), axis, &splits.into()).unwrap();

            assert_eq!(results.len(), 2, "axis {}", axis);
            assert_eq!(results[0].data().unwrap(), &[0., 2., 4., 6., 8.]);
            assert_eq!(results[1].data().unwrap(), &[1., 3., 5., 7., 9.]);
        }
    }

    /// Invalid axes and invalid split sizes are reported as `InvalidValue`
    /// errors with the expected messages.
    #[test]
    fn test_split_invalid_inputs() {
        let input = tensor!((5, 2); [0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]);

        // Axes 2 and -3 are both out of range for a 2-D input.
        let splits = &[1, 1];
        for axis in [2isize, -3] {
            let result = split(input.view(), axis, &splits.into());
            assert_eq!(
                result.err(),
                Some(OpError::InvalidValue("Axis is invalid")),
                "axis {}",
                axis
            );
        }

        // Sizes add up to 3, but the dimension being split has size 2.
        let splits = &[1, 2];
        let error = split(input.view(), 1, &splits.into()).err();
        assert_eq!(
            error,
            Some(OpError::InvalidValue(
                "Split sizes do not sum to dimension size"
            ))
        );

        // A negative size is rejected before the sum is checked.
        let splits = &[1, -2];
        let error = split(input.view(), 1, &splits.into()).err();
        assert_eq!(
            error,
            Some(OpError::InvalidValue("Split sizes must be >= 0"))
        );
    }
}