duskphantom_frontend/transform/
reshape_array.rs

// Copyright 2024 Duskphantom Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
16
17use std::collections::VecDeque;
18
19use anyhow::Result;
20
21use crate::{Expr, Type};
22
23/// Reshape a possibly flattened constant array to nested.
24///
25/// # Panics
26/// Please make sure `arr` is non-empty.
27pub fn reshape_const_array(arr: &mut VecDeque<Expr>, ty: &Type) -> Result<Expr> {
28    if arr.is_empty() {
29        // Use default initializer for `{}`
30        return ty.default_initializer();
31    }
32    if let Type::Array(element_ty, len) = ty {
33        let size = len.to_i32()?;
34        let mut new_arr: Vec<Expr> = vec![];
35        for _ in 0..size {
36            let Some(first_item) = arr.pop_front() else {
37                // Later elements are missing, fill with default initializer
38                new_arr.push(element_ty.default_initializer()?);
39                continue;
40            };
41            if let Expr::Array(arr) = first_item {
42                // First element is array, sub-array is nested
43                new_arr.push(reshape_const_array(&mut VecDeque::from(arr), element_ty)?);
44            } else {
45                // First element is non-array, sub-array is flattened
46                arr.push_front(first_item);
47                new_arr.push(reshape_const_array(arr, element_ty)?);
48            }
49        }
50        Ok(Expr::Array(new_arr))
51    } else {
52        Ok(arr.pop_front().unwrap())
53    }
54}
55
56/// Reshape a possibly flattened array to nested.
57///
58/// # Panics
59/// Please make sure `arr` is non-empty.
60pub fn reshape_array(arr: &mut VecDeque<Expr>, ty: &Type) -> Result<Expr> {
61    if let Type::Array(element_ty, len) = ty {
62        let size = len.to_i32()?;
63        let mut new_arr: Vec<Expr> = vec![];
64        for _ in 0..size {
65            let Some(first_item) = arr.pop_front() else {
66                break;
67            };
68            if let Expr::Array(arr) = first_item {
69                // First element is array, sub-array is nested
70                new_arr.push(reshape_array(&mut VecDeque::from(arr), element_ty)?);
71            } else {
72                // First element is non-array, sub-array is flattened
73                arr.push_front(first_item);
74                new_arr.push(reshape_array(arr, element_ty)?);
75            }
76        }
77        Ok(Expr::Array(new_arr))
78    } else {
79        Ok(arr.pop_front().unwrap())
80    }
81}
82
#[cfg(test)]
mod tests {
    use std::collections::VecDeque;

    use super::{reshape_array, reshape_const_array};
    use crate::{Expr, Type};

    /// `int[2]`.
    fn int_2() -> Type {
        Type::Array(Type::Int.into(), Expr::Int(2).into())
    }

    /// `int[2][2]`, the shape shared by every case below.
    fn int_2x2() -> Type {
        Type::Array(int_2().into(), Expr::Int(2).into())
    }

    /// `{{1, 2}, {3, 4}}` in nested form.
    fn nested_1234() -> Expr {
        Expr::Array(vec![
            Expr::Array(vec![Expr::Int(1), Expr::Int(2)]),
            Expr::Array(vec![Expr::Int(3), Expr::Int(4)]),
        ])
    }

    /// Run `reshape_array` on `items` with shape `int[2][2]`.
    fn reshape_2x2(items: Vec<Expr>) -> Expr {
        reshape_array(&mut VecDeque::from(items), &int_2x2()).unwrap()
    }

    /// Run `reshape_const_array` on `items` with shape `int[2][2]`.
    fn reshape_const_2x2(items: Vec<Expr>) -> Expr {
        reshape_const_array(&mut VecDeque::from(items), &int_2x2()).unwrap()
    }

    /// Default initializer of a scalar `int`, used by `reshape_const_array` for padding.
    fn int_default() -> Expr {
        Type::Int.default_initializer().unwrap()
    }

    #[test]
    fn test_reshape_flattened_array() {
        let items = vec![Expr::Int(1), Expr::Int(2), Expr::Int(3), Expr::Int(4)];
        assert_eq!(reshape_2x2(items), nested_1234());
    }

    #[test]
    fn test_reshape_nested_array() {
        let items = vec![
            Expr::Array(vec![Expr::Int(1), Expr::Int(2)]),
            Expr::Array(vec![Expr::Int(3), Expr::Int(4)]),
        ];
        assert_eq!(reshape_2x2(items), nested_1234());
    }

    #[test]
    fn test_reshape_mixed_array() {
        let items = vec![
            Expr::Int(1),
            Expr::Int(2),
            Expr::Array(vec![Expr::Int(3), Expr::Int(4)]),
        ];
        assert_eq!(reshape_2x2(items), nested_1234());
    }

    #[test]
    fn test_reshape_mixed_array_2() {
        let items = vec![
            Expr::Array(vec![Expr::Int(1), Expr::Int(2)]),
            Expr::Int(3),
            Expr::Int(4),
        ];
        assert_eq!(reshape_2x2(items), nested_1234());
    }

    #[test]
    fn test_reshape_fractured_array() {
        // Non-constant reshaping leaves missing elements out instead of padding.
        let items = vec![Expr::Array(vec![Expr::Int(1)]), Expr::Int(3)];
        assert_eq!(
            reshape_2x2(items),
            Expr::Array(vec![
                Expr::Array(vec![Expr::Int(1)]),
                Expr::Array(vec![Expr::Int(3)]),
            ])
        );
    }

    #[test]
    fn test_reshape_const_nested_array() {
        let items = vec![
            Expr::Array(vec![Expr::Int(1), Expr::Int(2)]),
            Expr::Array(vec![Expr::Int(3), Expr::Int(4)]),
        ];
        assert_eq!(reshape_const_2x2(items), nested_1234());
    }

    #[test]
    fn test_reshape_const_empty_uses_default_initializer() {
        let ty = int_2x2();
        let res = reshape_const_array(&mut VecDeque::new(), &ty).unwrap();
        assert_eq!(res, ty.default_initializer().unwrap());
    }

    #[test]
    fn test_reshape_const_pads_missing_elements() {
        // Constant reshaping pads every level to its declared length.
        let items = vec![Expr::Array(vec![Expr::Int(1)]), Expr::Int(3)];
        assert_eq!(
            reshape_const_2x2(items),
            Expr::Array(vec![
                Expr::Array(vec![Expr::Int(1), int_default()]),
                Expr::Array(vec![Expr::Int(3), int_default()]),
            ])
        );
    }
}