arrow_digest/
record_digest.rs

1use crate::{ArrayDigest, ArrayDigestV0, RecordDigest};
2use arrow::{
3    array::{Array, ArrayRef, StructArray},
4    buffer::NullBuffer,
5    datatypes::{DataType, Field, Fields, Schema},
6    record_batch::RecordBatch,
7};
8use digest::{Digest, Output, OutputSizeUser};
9
10/////////////////////////////////////////////////////////////////////////////////////////
11
/// Version-0 digest over a [`RecordBatch`] stream, parameterized by the
/// underlying hash function `Dig` (e.g. `sha3::Sha3_256`).
///
/// The schema (field names and nesting levels) is absorbed into `hasher`
/// at construction time; column data is accumulated incrementally into
/// `columns` and folded into `hasher` on `finalize()`.
pub struct RecordDigestV0<Dig: Digest> {
    // One digest per *leaf* (non-struct) column, in depth-first schema order.
    columns: Vec<ArrayDigestV0<Dig>>,
    // Top-level hasher: receives schema metadata in `new()` and each
    // column's final hash in `finalize()`.
    hasher: Dig,
}
16
17/////////////////////////////////////////////////////////////////////////////////////////
18
/// The record digest is exactly as wide as the underlying hash function's output.
impl<Dig: Digest> OutputSizeUser for RecordDigestV0<Dig> {
    type OutputSize = Dig::OutputSize;
}
22
23impl<Dig: Digest> RecordDigest for RecordDigestV0<Dig> {
24    fn digest(batch: &RecordBatch) -> Output<Dig> {
25        let mut d = Self::new(batch.schema().as_ref());
26        d.update(batch);
27        d.finalize()
28    }
29
30    fn new(schema: &Schema) -> Self {
31        let mut hasher = Dig::new();
32        let mut columns = Vec::new();
33
34        Self::walk_nested_fields(schema.fields(), 0, &mut |field, level| {
35            hasher.update((field.name().len() as u64).to_le_bytes());
36            hasher.update(field.name().as_bytes());
37            hasher.update((level as u64).to_le_bytes());
38
39            match field.data_type() {
40                DataType::Struct(_) => (),
41                _ => columns.push(ArrayDigestV0::new(field.data_type())),
42            }
43        });
44
45        Self { columns, hasher }
46    }
47
48    fn update(&mut self, batch: &RecordBatch) {
49        let mut col_index = 0;
50        Self::walk_nested_columns(
51            batch.columns().iter(),
52            None,
53            &mut |array, parent_null_bitmap| {
54                let col_digest = &mut self.columns[col_index];
55                col_digest.update(array.as_ref(), parent_null_bitmap);
56                col_index += 1;
57            },
58        );
59    }
60
61    fn finalize(mut self) -> Output<Dig> {
62        for c in self.columns {
63            let column_hash = c.finalize();
64            self.hasher.update(column_hash.as_slice());
65        }
66        self.hasher.finalize()
67    }
68}
69
70impl<Dig: Digest> RecordDigestV0<Dig> {
71    fn walk_nested_fields(fields: &Fields, level: usize, fun: &mut impl FnMut(&Field, usize)) {
72        for field in fields {
73            match field.data_type() {
74                DataType::Struct(nested_fields) => {
75                    fun(field, level);
76                    Self::walk_nested_fields(nested_fields, level + 1, fun);
77                }
78                _ => fun(field, level),
79            }
80        }
81    }
82
83    fn walk_nested_columns<'a>(
84        arrays: impl Iterator<Item = &'a ArrayRef>,
85        parent_null_bitmap: Option<&NullBuffer>,
86        fun: &mut impl FnMut(&ArrayRef, Option<&NullBuffer>),
87    ) {
88        for array in arrays {
89            match array.data_type() {
90                DataType::Struct(_) => {
91                    let array = array.as_any().downcast_ref::<StructArray>().unwrap();
92                    let array_data = array.to_data();
93
94                    let combined_nulls = crate::utils::maybe_combine_null_buffers(
95                        parent_null_bitmap,
96                        array_data.nulls(),
97                    );
98
99                    for i in 0..array.num_columns() {
100                        Self::walk_nested_columns(
101                            [array.column(i)].into_iter(),
102                            combined_nulls.as_option(),
103                            fun,
104                        );
105                    }
106                }
107                _ => fun(array, parent_null_bitmap),
108            }
109        }
110    }
111}
112
113/////////////////////////////////////////////////////////////////////////////////////////
114// Tests
115/////////////////////////////////////////////////////////////////////////////////////////
116
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::{
        array::{Array, Int32Array, StringArray},
        buffer::Buffer,
        datatypes::{DataType, Field, Schema},
        record_batch::RecordBatch,
    };
    use sha3::Sha3_256;
    use std::sync::Arc;

    /// Flat schema: identical data hashes equal, different data hashes unequal.
    #[test]
    fn test_batch_mixed() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, false),
        ]));

        let a: Arc<dyn Array> = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
        let b: Arc<dyn Array> = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
        let c: Arc<dyn Array> = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6]));
        let d: Arc<dyn Array> = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e", "d"]));

        // Two batches built from the same arrays must digest identically.
        let record_batch1 =
            RecordBatch::try_new(Arc::clone(&schema), vec![Arc::clone(&a), Arc::clone(&b)])
                .unwrap();
        let record_batch2 =
            RecordBatch::try_new(Arc::clone(&schema), vec![Arc::clone(&a), Arc::clone(&b)])
                .unwrap();
        // A batch with different data must digest differently.
        let record_batch3 =
            RecordBatch::try_new(Arc::clone(&schema), vec![Arc::clone(&c), Arc::clone(&d)])
                .unwrap();

        assert_eq!(
            RecordDigestV0::<Sha3_256>::digest(&record_batch1),
            RecordDigestV0::<Sha3_256>::digest(&record_batch2),
        );

        assert_ne!(
            RecordDigestV0::<Sha3_256>::digest(&record_batch2),
            RecordDigestV0::<Sha3_256>::digest(&record_batch3),
        );
    }

    /// Nested (struct) schema: digest is stable across re-hashing, sensitive
    /// to field names, insensitive to the nullability *flag* alone, but
    /// sensitive to actual null bitmap contents.
    #[test]
    fn test_batch_nested() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new(
                "b",
                DataType::Struct(Fields::from(vec![
                    Field::new("c", DataType::Utf8, false),
                    Field::new("d", DataType::Int32, false),
                ])),
                false,
            ),
        ]));

        let a: Arc<dyn Array> = Arc::new(Int32Array::from(vec![1, 2, 3]));
        let c: Arc<dyn Array> = Arc::new(StringArray::from(vec!["a", "b", "c"]));
        let d: Arc<dyn Array> = Arc::new(Int32Array::from(vec![3, 2, 1]));
        let b = Arc::new(StructArray::from(vec![
            (Arc::new(Field::new("c", DataType::Utf8, false)), c.clone()),
            (Arc::new(Field::new("d", DataType::Int32, false)), d.clone()),
        ]));

        let record_batch1 = RecordBatch::try_new(schema, vec![a.clone(), b.clone()]).unwrap();

        // Re-hashing the same batch is deterministic.
        assert_eq!(
            RecordDigestV0::<Sha3_256>::digest(&record_batch1),
            RecordDigestV0::<Sha3_256>::digest(&record_batch1),
        );

        // Renaming a column ("b" -> "bee") must change the digest.
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new(
                "bee",
                DataType::Struct(Fields::from(vec![
                    Field::new("c", DataType::Utf8, false),
                    Field::new("d", DataType::Int32, false),
                ])),
                false,
            ),
        ]));

        let record_batch2 = RecordBatch::try_new(schema, vec![a.clone(), b.clone()]).unwrap();

        assert_ne!(
            RecordDigestV0::<Sha3_256>::digest(&record_batch1),
            RecordDigestV0::<Sha3_256>::digest(&record_batch2),
        );

        // Marking the struct nullable with an all-valid bitmap (0b111 for 3
        // rows) must NOT change the digest relative to the non-nullable batch.
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new(
                "b",
                DataType::Struct(Fields::from(vec![
                    Field::new("c", DataType::Utf8, false),
                    Field::new("d", DataType::Int32, false),
                ])),
                true,
            ),
        ]));

        let b = Arc::new(StructArray::from((
            vec![
                (Arc::new(Field::new("c", DataType::Utf8, false)), c.clone()),
                (Arc::new(Field::new("d", DataType::Int32, false)), d.clone()),
            ],
            Buffer::from([0b111]),
        )));

        let record_batch3 =
            RecordBatch::try_new(schema.clone(), vec![a.clone(), b.clone()]).unwrap();

        assert_eq!(
            RecordDigestV0::<Sha3_256>::digest(&record_batch1),
            RecordDigestV0::<Sha3_256>::digest(&record_batch3),
        );

        // An actual null (middle row masked out by 0b101) MUST change the digest.
        let b = Arc::new(StructArray::from((
            vec![
                (Arc::new(Field::new("c", DataType::Utf8, false)), c.clone()),
                (Arc::new(Field::new("d", DataType::Int32, false)), d.clone()),
            ],
            Buffer::from([0b101]),
        )));

        let record_batch4 =
            RecordBatch::try_new(schema.clone(), vec![a.clone(), b.clone()]).unwrap();

        assert_ne!(
            RecordDigestV0::<Sha3_256>::digest(&record_batch1),
            RecordDigestV0::<Sha3_256>::digest(&record_batch4),
        );
    }
}
285}