1use super::{
2 column_commitment_metadata::ColumnCommitmentMetadataMismatch, ColumnCommitmentMetadata,
3 CommittableColumn,
4};
5use crate::base::{database::ColumnField, map::IndexMap};
6use alloc::string::{String, ToString};
7use snafu::Snafu;
8use sqlparser::ast::Ident;
9
/// Map from a column's identifier ([`Ident`]) to that column's [`ColumnCommitmentMetadata`].
///
/// Entry order is significant: the `try_union` / `try_difference` operations provided by
/// [`ColumnCommitmentMetadataMapExt`] pair up the entries of two maps by position.
pub type ColumnCommitmentMetadataMap = IndexMap<Ident, ColumnCommitmentMetadata>;
12
/// Errors produced when two [`ColumnCommitmentMetadataMap`]s cannot be combined.
#[derive(Debug, Snafu)]
pub enum ColumnCommitmentsMismatch {
    /// The metadata of a pair of corresponding columns could not be combined.
    ///
    /// Transparently wraps the underlying per-column error.
    #[snafu(transparent)]
    ColumnCommitmentMetadata {
        /// The per-column mismatch that caused this error.
        source: ColumnCommitmentMetadataMismatch,
    },
    /// The two maps do not contain the same number of columns.
    #[snafu(display("commitments with different column counts cannot operate with each other"))]
    NumColumns,
    /// Two columns at the same position have different identifiers.
    #[snafu(display("column with ident {id_a} cannot operate with column with ident {id_b}"))]
    Ident {
        /// Identifier of the column from the left-hand map, rendered as a string.
        id_a: String,
        /// Identifier of the column from the right-hand map, rendered as a string.
        id_b: String,
    },
}
36
/// Extension trait providing constructors and fallible combination operations for
/// [`ColumnCommitmentMetadataMap`].
pub trait ColumnCommitmentMetadataMapExt {
    /// Builds a map with one entry per [`ColumnField`] in `columns`.
    ///
    /// In the implementation for [`ColumnCommitmentMetadataMap`], each entry is keyed by the
    /// field's name and holds the metadata returned by
    /// `ColumnCommitmentMetadata::from_column_type_with_max_bounds` for the field's data type.
    fn from_column_fields_with_max_bounds(columns: &[ColumnField]) -> Self;

    /// Builds a map from `(identifier, column)` pairs.
    ///
    /// In the implementation for [`ColumnCommitmentMetadataMap`], each identifier is cloned and
    /// paired with the metadata returned by `ColumnCommitmentMetadata::from_column`.
    fn from_columns<'a>(
        columns: impl IntoIterator<Item = (&'a Ident, &'a CommittableColumn<'a>)>,
    ) -> Self
    where
        Self: Sized;

    /// Combines two maps by taking the union of the metadata of corresponding columns.
    ///
    /// # Errors
    ///
    /// The implementation for [`ColumnCommitmentMetadataMap`] returns
    /// - [`ColumnCommitmentsMismatch::NumColumns`] if the maps differ in length,
    /// - [`ColumnCommitmentsMismatch::Ident`] if the identifiers at some position differ,
    /// - [`ColumnCommitmentsMismatch::ColumnCommitmentMetadata`] if the per-column union fails.
    fn try_union(self, other: Self) -> Result<Self, ColumnCommitmentsMismatch>
    where
        Self: Sized;

    /// Combines two maps by taking the difference of the metadata of corresponding columns
    /// (`self` minus `other`).
    ///
    /// # Errors
    ///
    /// The implementation for [`ColumnCommitmentMetadataMap`] returns
    /// - [`ColumnCommitmentsMismatch::NumColumns`] if the maps differ in length,
    /// - [`ColumnCommitmentsMismatch::Ident`] if the identifiers at some position differ,
    /// - [`ColumnCommitmentsMismatch::ColumnCommitmentMetadata`] if the per-column difference
    ///   fails.
    fn try_difference(self, other: Self) -> Result<Self, ColumnCommitmentsMismatch>
    where
        Self: Sized;
}
60
61impl ColumnCommitmentMetadataMapExt for ColumnCommitmentMetadataMap {
62 fn from_column_fields_with_max_bounds(columns: &[ColumnField]) -> Self {
63 columns
64 .iter()
65 .map(|f| {
66 (
67 f.name().clone(),
68 ColumnCommitmentMetadata::from_column_type_with_max_bounds(f.data_type()),
69 )
70 })
71 .collect()
72 }
73
74 fn from_columns<'a>(
75 columns: impl IntoIterator<Item = (&'a Ident, &'a CommittableColumn<'a>)>,
76 ) -> Self
77 where
78 Self: Sized,
79 {
80 columns
81 .into_iter()
82 .map(|(identifier, column)| {
83 (
84 identifier.clone(),
85 ColumnCommitmentMetadata::from_column(column),
86 )
87 })
88 .collect()
89 }
90
91 fn try_union(self, other: Self) -> Result<Self, ColumnCommitmentsMismatch>
92 where
93 Self: Sized,
94 {
95 if self.len() != other.len() {
96 return Err(ColumnCommitmentsMismatch::NumColumns);
97 }
98
99 self.into_iter()
100 .zip(other)
101 .map(|((identifier_a, metadata_a), (identifier_b, metadata_b))| {
102 if identifier_a != identifier_b {
103 Err(ColumnCommitmentsMismatch::Ident {
104 id_a: identifier_a.to_string(),
105 id_b: identifier_b.to_string(),
106 })?;
107 }
108
109 Ok((identifier_a, metadata_a.try_union(metadata_b)?))
110 })
111 .collect()
112 }
113
114 fn try_difference(self, other: Self) -> Result<Self, ColumnCommitmentsMismatch>
115 where
116 Self: Sized,
117 {
118 if self.len() != other.len() {
119 return Err(ColumnCommitmentsMismatch::NumColumns);
120 }
121
122 self.into_iter()
123 .zip(other)
124 .map(|((identifier_a, metadata_a), (identifier_b, metadata_b))| {
125 if identifier_a != identifier_b {
126 Err(ColumnCommitmentsMismatch::Ident {
127 id_a: identifier_a.to_string(),
128 id_b: identifier_b.to_string(),
129 })?;
130 }
131
132 Ok((identifier_a, metadata_a.try_difference(metadata_b)?))
133 })
134 .collect()
135 }
136}
137
138#[cfg(test)]
139mod tests {
140 use super::*;
141 use crate::base::{
142 commitment::{column_bounds::Bounds, ColumnBounds},
143 database::{owned_table_utility::*, ColumnType, OwnedTable},
144 scalar::test_scalar::TestScalar,
145 };
146 use alloc::vec::Vec;
147 use itertools::Itertools;
148
149 fn metadata_map_from_owned_table(
150 table: &OwnedTable<TestScalar>,
151 ) -> ColumnCommitmentMetadataMap {
152 let (identifiers, columns): (Vec<&Ident>, Vec<CommittableColumn>) = table
153 .inner_table()
154 .into_iter()
155 .map(|(identifier, owned_column)| (identifier, CommittableColumn::from(owned_column)))
156 .unzip();
157
158 ColumnCommitmentMetadataMap::from_columns(identifiers.into_iter().zip(columns.iter()))
159 }
160
161 #[test]
162 fn we_can_construct_metadata_map_from_columns() {
163 let empty_metadata_map = ColumnCommitmentMetadataMap::from_columns([]);
165 assert_eq!(empty_metadata_map.len(), 0);
166
167 let table: OwnedTable<TestScalar> = owned_table([
169 bigint("bigint_column", [1, 5, -5, 0]),
170 int128("int128_column", [100, 200, 300, 400]),
171 varchar("varchar_column", ["Lorem", "ipsum", "dolor", "sit"]),
172 scalar("scalar_column", [1000, 2000, -1000, 0]),
173 ]);
174
175 let metadata_map = metadata_map_from_owned_table(&table);
176
177 assert_eq!(metadata_map.len(), 4);
178
179 let (index_0, metadata_0) = metadata_map.get_index(0).unwrap();
180 assert_eq!(index_0.value.as_str(), "bigint_column");
181 assert_eq!(metadata_0.column_type(), &ColumnType::BigInt);
182 if let ColumnBounds::BigInt(Bounds::Sharp(bounds)) = metadata_0.bounds() {
183 assert_eq!(bounds.min(), &-5);
184 assert_eq!(bounds.max(), &5);
185 } else {
186 panic!("metadata constructed from BigInt column should have BigInt/Sharp bounds");
187 }
188
189 let (index_1, metadata_1) = metadata_map.get_index(1).unwrap();
190 assert_eq!(index_1.value.as_str(), "int128_column");
191 assert_eq!(metadata_1.column_type(), &ColumnType::Int128);
192 if let ColumnBounds::Int128(Bounds::Sharp(bounds)) = metadata_1.bounds() {
193 assert_eq!(bounds.min(), &100);
194 assert_eq!(bounds.max(), &400);
195 } else {
196 panic!("metadata constructed from Int128 column should have Int128/Sharp bounds");
197 }
198
199 let (index_2, metadata_2) = metadata_map.get_index(2).unwrap();
200 assert_eq!(index_2.value.as_str(), "varchar_column");
201 assert_eq!(metadata_2.column_type(), &ColumnType::VarChar);
202 assert_eq!(metadata_2.bounds(), &ColumnBounds::NoOrder);
203
204 let (index_3, metadata_3) = metadata_map.get_index(3).unwrap();
205 assert_eq!(index_3.value.as_str(), "scalar_column");
206 assert_eq!(metadata_3.column_type(), &ColumnType::Scalar);
207 assert_eq!(metadata_3.bounds(), &ColumnBounds::NoOrder);
208 }
209
210 #[test]
211 fn we_can_union_matching_metadata_maps() {
212 let table_a = owned_table([
213 bigint("bigint_column", [1, 5]),
214 int128("int128_column", [100, 200]),
215 varchar("varchar_column", ["Lorem", "ipsum"]),
216 scalar("scalar_column", [1000, 2000]),
217 ]);
218 let metadata_a = metadata_map_from_owned_table(&table_a);
219
220 let table_b = owned_table([
221 bigint("bigint_column", [-5, 0, 10]),
222 int128("int128_column", [300, 400, 500]),
223 varchar("varchar_column", ["dolor", "sit", "amet"]),
224 scalar("scalar_column", [-1000, 0, -2000]),
225 ]);
226 let metadata_b = metadata_map_from_owned_table(&table_b);
227
228 let table_c = owned_table([
229 bigint("bigint_column", [1, 5, -5, 0, 10]),
230 int128("int128_column", [100, 200, 300, 400, 500]),
231 varchar("varchar_column", ["Lorem", "ipsum", "dolor", "sit", "amet"]),
232 scalar("scalar_column", [1000, 2000, -1000, 0, -2000]),
233 ]);
234 let metadata_c = metadata_map_from_owned_table(&table_c);
235
236 assert_eq!(metadata_a.try_union(metadata_b).unwrap(), metadata_c);
237 }
238 #[test]
239 fn we_can_difference_matching_metadata_maps() {
240 let table_a = owned_table([
241 bigint("bigint_column", [1, 5]),
242 int128("int128_column", [100, 200]),
243 varchar("varchar_column", ["Lorem", "ipsum"]),
244 scalar("scalar_column", [1000, 2000]),
245 ]);
246 let metadata_a = metadata_map_from_owned_table(&table_a);
247
248 let table_b = owned_table([
249 bigint("bigint_column", [1, 5, -5, 0, 10]),
250 int128("int128_column", [100, 200, 300, 400, 500]),
251 varchar("varchar_column", ["Lorem", "ipsum", "dolor", "sit", "amet"]),
252 scalar("scalar_column", [1000, 2000, -1000, 0, -2000]),
253 ]);
254 let metadata_b = metadata_map_from_owned_table(&table_b);
255
256 let b_difference_a = metadata_b.try_difference(metadata_a.clone()).unwrap();
257
258 assert_eq!(b_difference_a.len(), 4);
259
260 let (index_0, metadata_0) = b_difference_a.get_index(0).unwrap();
262 assert_eq!(index_0.value.as_str(), "bigint_column");
263 assert_eq!(metadata_0.column_type(), &ColumnType::BigInt);
264 if let ColumnBounds::BigInt(Bounds::Bounded(bounds)) = metadata_0.bounds() {
265 assert_eq!(bounds.min(), &-5);
266 assert_eq!(bounds.max(), &10);
267 } else {
268 panic!("difference of overlapping bounds should be Bounded");
269 }
270
271 let (index_1, metadata_1) = b_difference_a.get_index(1).unwrap();
272 assert_eq!(index_1.value.as_str(), "int128_column");
273 assert_eq!(metadata_1.column_type(), &ColumnType::Int128);
274 if let ColumnBounds::Int128(Bounds::Bounded(bounds)) = metadata_1.bounds() {
275 assert_eq!(bounds.min(), &100);
276 assert_eq!(bounds.max(), &500);
277 } else {
278 panic!("difference of overlapping bounds should be Bounded");
279 }
280
281 assert_eq!(
283 b_difference_a.get_index(2).unwrap(),
284 metadata_a.get_index(2).unwrap()
285 );
286
287 assert_eq!(
288 b_difference_a.get_index(3).unwrap(),
289 metadata_a.get_index(3).unwrap()
290 );
291 }
292
293 #[test]
294 fn we_cannot_perform_arithmetic_on_metadata_maps_with_different_column_counts() {
295 let table_a = owned_table([
296 bigint("bigint_column", [1, 5, -5, 0, 10]),
297 int128("int128_column", [100, 200, 300, 400, 500]),
298 varchar("varchar_column", ["Lorem", "ipsum", "dolor", "sit", "amet"]),
299 scalar("scalar_column", [1000, 2000, -1000, 0, -2000]),
300 ]);
301 let metadata_a = metadata_map_from_owned_table(&table_a);
302
303 let table_b = owned_table([
304 bigint("bigint_column", [1, 5, -5, 0, 10]),
305 varchar("varchar_column", ["Lorem", "ipsum", "dolor", "sit", "amet"]),
306 ]);
307 let metadata_b = metadata_map_from_owned_table(&table_b);
308
309 assert!(matches!(
310 metadata_a.clone().try_union(metadata_b.clone()),
311 Err(ColumnCommitmentsMismatch::NumColumns)
312 ));
313 assert!(matches!(
314 metadata_b.try_union(metadata_a.clone()),
315 Err(ColumnCommitmentsMismatch::NumColumns)
316 ));
317
318 let empty_metadata = ColumnCommitmentMetadataMap::default();
319
320 assert!(matches!(
321 metadata_a.clone().try_union(empty_metadata.clone()),
322 Err(ColumnCommitmentsMismatch::NumColumns)
323 ));
324 assert!(matches!(
325 empty_metadata.try_union(metadata_a),
326 Err(ColumnCommitmentsMismatch::NumColumns)
327 ));
328 }
329
330 #[expect(clippy::similar_names)]
331 #[test]
332 fn we_cannot_perform_arithmetic_on_mismatched_metadata_maps_with_same_column_counts() {
333 let id_a = "column_a";
334 let id_b = "column_b";
335 let id_c = "column_c";
336 let id_d = "column_d";
337 let ints = [1i64, 2, 3, 4];
338 let strings = ["Lorem", "ipsum", "dolor", "sit"];
339
340 let ab_ii_metadata =
341 metadata_map_from_owned_table(&owned_table([bigint(id_a, ints), bigint(id_b, ints)]));
342
343 let ab_iv_metadata = metadata_map_from_owned_table(&owned_table([
344 bigint(id_a, ints),
345 varchar(id_b, strings),
346 ]));
347
348 let ab_vi_metadata = metadata_map_from_owned_table(&owned_table([
349 varchar(id_a, strings),
350 bigint(id_b, ints),
351 ]));
352
353 let ad_ii_metadata =
354 metadata_map_from_owned_table(&owned_table([bigint(id_a, ints), bigint(id_d, ints)]));
355
356 let cb_ii_metadata =
357 metadata_map_from_owned_table(&owned_table([bigint(id_c, ints), bigint(id_b, ints)]));
358
359 let cd_vv_metadata = metadata_map_from_owned_table(&owned_table([
360 varchar(id_c, strings),
361 varchar(id_d, strings),
362 ]));
363
364 let mismatched_metadata_maps = [
367 ab_ii_metadata,
368 ab_iv_metadata,
369 ab_vi_metadata,
370 ad_ii_metadata,
371 cb_ii_metadata,
372 cd_vv_metadata,
373 ];
374
375 for (metadata_map_a, metadata_map_b) in
376 mismatched_metadata_maps.into_iter().tuple_combinations()
377 {
378 assert!(metadata_map_a
379 .clone()
380 .try_union(metadata_map_b.clone())
381 .is_err());
382 assert!(metadata_map_b
383 .clone()
384 .try_union(metadata_map_a.clone())
385 .is_err());
386 assert!(metadata_map_a
387 .clone()
388 .try_difference(metadata_map_b.clone())
389 .is_err());
390 assert!(metadata_map_b.try_difference(metadata_map_a).is_err());
391 }
392 }
393}