qfall_math/integer_mod_q/mat_zq/
concat.rs1use super::MatZq;
12use crate::{
13 error::MathError,
14 traits::{CompareBase, Concatenate, MatrixDimensions},
15};
16use flint_sys::fmpz_mod_mat::{fmpz_mod_mat_concat_horizontal, fmpz_mod_mat_concat_vertical};
17
18impl Concatenate for &MatZq {
19 type Output = MatZq;
20
21 fn concat_vertical(self, other: Self) -> Result<Self::Output, crate::error::MathError> {
48 if self.get_num_columns() != other.get_num_columns() {
49 return Err(MathError::MismatchingMatrixDimension(format!(
50 "Tried to concatenate vertically a '{}x{}' matrix and a '{}x{}' matrix.",
51 self.get_num_rows(),
52 self.get_num_columns(),
53 other.get_num_rows(),
54 other.get_num_columns()
55 )));
56 }
57
58 if !self.compare_base(other) {
59 return Err(self.call_compare_base_error(other).unwrap());
60 }
61
62 let mut out = MatZq::new(
63 self.get_num_rows() + other.get_num_rows(),
64 self.get_num_columns(),
65 self.get_mod(),
66 );
67 unsafe {
68 fmpz_mod_mat_concat_vertical(&mut out.matrix, &self.matrix, &other.matrix);
69 }
70 Ok(out)
71 }
72
73 fn concat_horizontal(self, other: Self) -> Result<Self::Output, crate::error::MathError> {
100 if self.get_num_rows() != other.get_num_rows() {
101 return Err(MathError::MismatchingMatrixDimension(format!(
102 "Tried to concatenate horizontally a '{}x{}' matrix and a '{}x{}' matrix.",
103 self.get_num_rows(),
104 self.get_num_columns(),
105 other.get_num_rows(),
106 other.get_num_columns()
107 )));
108 }
109
110 if !self.compare_base(other) {
111 return Err(self.call_compare_base_error(other).unwrap());
112 }
113
114 let mut out = MatZq::new(
115 self.get_num_rows(),
116 self.get_num_columns() + other.get_num_columns(),
117 self.get_mod(),
118 );
119 unsafe {
120 fmpz_mod_mat_concat_horizontal(&mut out.matrix, &self.matrix, &other.matrix);
121 }
122
123 Ok(out)
124 }
125}
126
#[cfg(test)]
mod test_concatenate {
    use crate::{
        integer_mod_q::MatZq,
        traits::{Concatenate, MatrixDimensions},
    };
    use std::str::FromStr;

    /// Vertical concatenation sums the row counts, keeps the column count,
    /// and fails when the column counts differ.
    #[test]
    fn dimensions_vertical() {
        let rows_13_cols_5 = MatZq::new(13, 5, 17);
        let rows_17_cols_5 = MatZq::new(17, 5, 17);
        let rows_17_cols_6 = MatZq::new(17, 6, 17);

        let stacked = rows_13_cols_5.concat_vertical(&rows_17_cols_5).unwrap();

        assert_eq!(5, stacked.get_num_columns());
        assert_eq!(30, stacked.get_num_rows());

        let mismatched = rows_13_cols_5.concat_vertical(&rows_17_cols_6);
        assert!(mismatched.is_err());
    }

    /// Horizontal concatenation sums the column counts, keeps the row count,
    /// and fails when the row counts differ.
    #[test]
    fn dimensions_horizontal() {
        let rows_13_cols_5 = MatZq::new(13, 5, 17);
        let rows_17_cols_5 = MatZq::new(17, 5, 17);
        let rows_17_cols_6 = MatZq::new(17, 6, 17);

        let joined = rows_17_cols_5.concat_horizontal(&rows_17_cols_6).unwrap();

        assert_eq!(11, joined.get_num_columns());
        assert_eq!(17, joined.get_num_rows());

        let mismatched = rows_13_cols_5.concat_horizontal(&rows_17_cols_5);
        assert!(mismatched.is_err());
    }

    /// Matrices with equal dimensions but different moduli can be concatenated
    /// in neither direction.
    #[test]
    fn mismatching_moduli() {
        let mod_17 = MatZq::new(2, 2, 17);
        let mod_19 = MatZq::new(2, 2, 19);

        let horizontal_result = mod_17.concat_horizontal(&mod_19);
        let vertical_result = mod_17.concat_vertical(&mod_19);

        assert!(horizontal_result.is_err());
        assert!(vertical_result.is_err());
    }

    /// Vertical concatenation keeps all entries in place, including entries
    /// at the `i64` bounds and a modulus of `u64::MAX`.
    #[test]
    fn vertically_correct() {
        let top_text = format!(
            "[[1, 2, {}],[4, 5, {}]] mod {}",
            i64::MIN,
            i64::MAX,
            u64::MAX
        );
        let bottom_text = format!("[[-1, 2, -17]] mod {}", u64::MAX);
        let top = MatZq::from_str(&top_text).unwrap();
        let bottom = MatZq::from_str(&bottom_text).unwrap();

        let actual = top.concat_vertical(&bottom).unwrap();

        let expected_text = format!(
            "[[1, 2, {}],[4, 5, {}],[-1, 2, -17]] mod {}",
            i64::MIN,
            i64::MAX,
            u64::MAX,
        );
        let expected = MatZq::from_str(&expected_text).unwrap();
        assert_eq!(expected, actual);
    }

    /// Horizontal concatenation keeps all entries in place, including entries
    /// at the `i64` bounds and a modulus of `u64::MAX`.
    #[test]
    fn horizontally_correct() {
        let left_text = format!(
            "[[1, 2, {}],[4, 5, {}]] mod {}",
            i64::MIN,
            i64::MAX,
            u64::MAX
        );
        let right_text = format!("[[-1, 2],[4, 5]] mod {}", u64::MAX);
        let left = MatZq::from_str(&left_text).unwrap();
        let right = MatZq::from_str(&right_text).unwrap();

        let actual = left.concat_horizontal(&right).unwrap();

        let expected_text = format!(
            "[[1, 2, {}, -1, 2],[4, 5, {}, 4, 5]] mod {}",
            i64::MIN,
            i64::MAX,
            u64::MAX
        );
        let expected = MatZq::from_str(&expected_text).unwrap();
        assert_eq!(expected, actual);
    }
}