1use crate::builder::ArrayBuilder;
19use crate::{Array, ArrayRef, MapArray, StructArray};
20use arrow_buffer::Buffer;
21use arrow_buffer::{NullBuffer, NullBufferBuilder};
22use arrow_data::ArrayData;
23use arrow_schema::{ArrowError, DataType, Field, FieldRef};
24use std::any::Any;
25use std::sync::Arc;
26
27#[derive(Debug)]
58pub struct MapBuilder<K: ArrayBuilder, V: ArrayBuilder> {
59    offsets_builder: Vec<i32>,
60    null_buffer_builder: NullBufferBuilder,
61    field_names: MapFieldNames,
62    key_builder: K,
63    value_builder: V,
64    key_field: Option<FieldRef>,
65    value_field: Option<FieldRef>,
66}
67
/// Field names used when building the entries struct of a map array.
#[derive(Clone, Debug)]
pub struct MapFieldNames {
    /// Name of the entries struct field (the single child of the map).
    pub entry: String,
    /// Name of the key field inside the entries struct.
    pub key: String,
    /// Name of the value field inside the entries struct.
    pub value: String,
}
78
79impl Default for MapFieldNames {
80    fn default() -> Self {
81        Self {
82            entry: "entries".to_string(),
83            key: "keys".to_string(),
84            value: "values".to_string(),
85        }
86    }
87}
88
89impl<K: ArrayBuilder, V: ArrayBuilder> MapBuilder<K, V> {
90    pub fn new(field_names: Option<MapFieldNames>, key_builder: K, value_builder: V) -> Self {
92        let capacity = key_builder.len();
93        Self::with_capacity(field_names, key_builder, value_builder, capacity)
94    }
95
96    pub fn with_capacity(
98        field_names: Option<MapFieldNames>,
99        key_builder: K,
100        value_builder: V,
101        capacity: usize,
102    ) -> Self {
103        let mut offsets_builder = Vec::with_capacity(capacity + 1);
104        offsets_builder.push(0);
105        Self {
106            offsets_builder,
107            null_buffer_builder: NullBufferBuilder::new(capacity),
108            field_names: field_names.unwrap_or_default(),
109            key_builder,
110            value_builder,
111            key_field: None,
112            value_field: None,
113        }
114    }
115
116    pub fn with_keys_field(self, field: impl Into<FieldRef>) -> Self {
123        Self {
124            key_field: Some(field.into()),
125            ..self
126        }
127    }
128
129    pub fn with_values_field(self, field: impl Into<FieldRef>) -> Self {
136        Self {
137            value_field: Some(field.into()),
138            ..self
139        }
140    }
141
142    pub fn keys(&mut self) -> &mut K {
144        &mut self.key_builder
145    }
146
147    pub fn values(&mut self) -> &mut V {
149        &mut self.value_builder
150    }
151
152    pub fn entries(&mut self) -> (&mut K, &mut V) {
154        (&mut self.key_builder, &mut self.value_builder)
155    }
156
157    #[inline]
161    pub fn append(&mut self, is_valid: bool) -> Result<(), ArrowError> {
162        if self.key_builder.len() != self.value_builder.len() {
163            return Err(ArrowError::InvalidArgumentError(format!(
164                "Cannot append to a map builder when its keys and values have unequal lengths of {} and {}",
165                self.key_builder.len(),
166                self.value_builder.len()
167            )));
168        }
169        self.offsets_builder.push(self.key_builder.len() as i32);
170        self.null_buffer_builder.append(is_valid);
171        Ok(())
172    }
173
174    pub fn finish(&mut self) -> MapArray {
176        let len = self.len();
177        let keys_arr = self.key_builder.finish();
179        let values_arr = self.value_builder.finish();
180        let offset_buffer = Buffer::from_vec(std::mem::take(&mut self.offsets_builder));
181        self.offsets_builder.push(0);
182        let null_bit_buffer = self.null_buffer_builder.finish();
183
184        self.finish_helper(keys_arr, values_arr, offset_buffer, null_bit_buffer, len)
185    }
186
187    pub fn finish_cloned(&self) -> MapArray {
189        let len = self.len();
190        let keys_arr = self.key_builder.finish_cloned();
192        let values_arr = self.value_builder.finish_cloned();
193        let offset_buffer = Buffer::from_slice_ref(self.offsets_builder.as_slice());
194        let nulls = self.null_buffer_builder.finish_cloned();
195        self.finish_helper(keys_arr, values_arr, offset_buffer, nulls, len)
196    }
197
198    fn finish_helper(
199        &self,
200        keys_arr: Arc<dyn Array>,
201        values_arr: Arc<dyn Array>,
202        offset_buffer: Buffer,
203        nulls: Option<NullBuffer>,
204        len: usize,
205    ) -> MapArray {
206        assert!(
207            keys_arr.null_count() == 0,
208            "Keys array must have no null values, found {} null value(s)",
209            keys_arr.null_count()
210        );
211
212        let keys_field = match &self.key_field {
213            Some(f) => {
214                assert!(!f.is_nullable(), "Keys field must not be nullable");
215                f.clone()
216            }
217            None => Arc::new(Field::new(
218                self.field_names.key.as_str(),
219                keys_arr.data_type().clone(),
220                false, )),
222        };
223        let values_field = match &self.value_field {
224            Some(f) => f.clone(),
225            None => Arc::new(Field::new(
226                self.field_names.value.as_str(),
227                values_arr.data_type().clone(),
228                true,
229            )),
230        };
231
232        let struct_array =
233            StructArray::from(vec![(keys_field, keys_arr), (values_field, values_arr)]);
234
235        let map_field = Arc::new(Field::new(
236            self.field_names.entry.as_str(),
237            struct_array.data_type().clone(),
238            false, ));
240        let array_data = ArrayData::builder(DataType::Map(map_field, false)) .len(len)
242            .add_buffer(offset_buffer)
243            .add_child_data(struct_array.into_data())
244            .nulls(nulls);
245
246        let array_data = unsafe { array_data.build_unchecked() };
247
248        MapArray::from(array_data)
249    }
250
251    pub fn validity_slice(&self) -> Option<&[u8]> {
253        self.null_buffer_builder.as_slice()
254    }
255}
256
257impl<K: ArrayBuilder, V: ArrayBuilder> ArrayBuilder for MapBuilder<K, V> {
258    fn len(&self) -> usize {
259        self.null_buffer_builder.len()
260    }
261
262    fn finish(&mut self) -> ArrayRef {
263        Arc::new(self.finish())
264    }
265
266    fn finish_cloned(&self) -> ArrayRef {
268        Arc::new(self.finish_cloned())
269    }
270
271    fn as_any(&self) -> &dyn Any {
272        self
273    }
274
275    fn as_any_mut(&mut self) -> &mut dyn Any {
276        self
277    }
278
279    fn into_box_any(self: Box<Self>) -> Box<dyn Any> {
280        self
281    }
282}
283
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::{Int32Builder, StringBuilder, make_builder};
    use crate::{Int32Array, StringArray};
    use std::collections::HashMap;

    #[test]
    #[should_panic(expected = "Keys array must have no null values, found 1 null value(s)")]
    fn test_map_builder_with_null_keys_panics() {
        let mut builder = MapBuilder::new(None, StringBuilder::new(), Int32Builder::new());
        builder.keys().append_null();
        builder.values().append_value(42);
        builder.append(true).unwrap();

        builder.finish();
    }

    #[test]
    fn test_boxed_map_builder() {
        let keys_builder = make_builder(&DataType::Utf8, 5);
        let values_builder = make_builder(&DataType::Int32, 5);

        let mut builder = MapBuilder::new(None, keys_builder, values_builder);
        builder
            .keys()
            .as_any_mut()
            .downcast_mut::<StringBuilder>()
            .expect("should be an StringBuilder")
            .append_value("1");
        builder
            .values()
            .as_any_mut()
            .downcast_mut::<Int32Builder>()
            .expect("should be an Int32Builder")
            .append_value(42);
        builder.append(true).unwrap();

        let map_array = builder.finish();

        assert_eq!(
            map_array
                .keys()
                .as_any()
                .downcast_ref::<StringArray>()
                .expect("should be an StringArray")
                .value(0),
            "1"
        );
        assert_eq!(
            map_array
                .values()
                .as_any()
                .downcast_ref::<Int32Array>()
                .expect("should be an Int32Array")
                .value(0),
            42
        );
    }

    #[test]
    fn test_with_values_field() {
        let value_field = Arc::new(Field::new("bars", DataType::Int32, false));
        let mut builder = MapBuilder::new(None, Int32Builder::new(), Int32Builder::new())
            .with_values_field(value_field.clone());
        builder.keys().append_value(1);
        builder.values().append_value(2);
        builder.append(true).unwrap();
        builder.append(false).unwrap();
        builder.keys().append_value(3);
        builder.values().append_value(4);
        builder.append(true).unwrap();
        let map = builder.finish();

        assert_eq!(map.len(), 3);
        assert_eq!(
            map.data_type(),
            &DataType::Map(
                Arc::new(Field::new(
                    "entries",
                    DataType::Struct(
                        vec![
                            Arc::new(Field::new("keys", DataType::Int32, false)),
                            value_field.clone()
                        ]
                        .into()
                    ),
                    false,
                )),
                false
            )
        );

        builder.keys().append_value(5);
        builder.values().append_value(6);
        builder.append(true).unwrap();
        let map = builder.finish();

        assert_eq!(map.len(), 1);
        assert_eq!(
            map.data_type(),
            &DataType::Map(
                Arc::new(Field::new(
                    "entries",
                    DataType::Struct(
                        vec![
                            Arc::new(Field::new("keys", DataType::Int32, false)),
                            value_field
                        ]
                        .into()
                    ),
                    false,
                )),
                false
            )
        );
    }

    #[test]
    fn test_with_keys_field() {
        let mut key_metadata = HashMap::new();
        key_metadata.insert("foo".to_string(), "bar".to_string());
        let key_field = Arc::new(
            Field::new("keys", DataType::Int32, false).with_metadata(key_metadata.clone()),
        );
        let mut builder = MapBuilder::new(None, Int32Builder::new(), Int32Builder::new())
            .with_keys_field(key_field.clone());
        builder.keys().append_value(1);
        builder.values().append_value(2);
        builder.append(true).unwrap();
        let map = builder.finish();

        assert_eq!(map.len(), 1);
        assert_eq!(
            map.data_type(),
            &DataType::Map(
                Arc::new(Field::new(
                    "entries",
                    DataType::Struct(
                        vec![
                            Arc::new(
                                Field::new("keys", DataType::Int32, false)
                                    .with_metadata(key_metadata)
                            ),
                            Arc::new(Field::new("values", DataType::Int32, true))
                        ]
                        .into()
                    ),
                    false,
                )),
                false
            )
        );
    }

    #[test]
    #[should_panic(expected = "Keys field must not be nullable")]
    fn test_with_nullable_keys_field() {
        let mut builder = MapBuilder::new(None, Int32Builder::new(), Int32Builder::new())
            .with_keys_field(Arc::new(Field::new("keys", DataType::Int32, true)));

        builder.keys().append_value(1);
        builder.values().append_value(2);
        builder.append(true).unwrap();

        builder.finish();
    }

    #[test]
    #[should_panic(expected = "Incorrect datatype")]
    fn test_keys_field_type_mismatch() {
        let mut builder = MapBuilder::new(None, Int32Builder::new(), Int32Builder::new())
            .with_keys_field(Arc::new(Field::new("keys", DataType::Utf8, false)));

        builder.keys().append_value(1);
        builder.values().append_value(2);
        builder.append(true).unwrap();

        builder.finish();
    }

    /// `append` must reject mismatched key/value lengths without recording a slot.
    #[test]
    fn test_append_with_unequal_lengths_errors() {
        let mut builder = MapBuilder::new(None, Int32Builder::new(), Int32Builder::new());
        builder.keys().append_value(1);

        let err = builder.append(true).unwrap_err();
        assert!(
            err.to_string()
                .contains("keys and values have unequal lengths of 1 and 0"),
            "unexpected error: {err}"
        );
        // The failed append must not have added a map slot.
        assert_eq!(builder.len(), 0);
    }

    /// `finish_cloned` must leave the builder's accumulated state untouched.
    #[test]
    fn test_finish_cloned_preserves_builder_state() {
        let mut builder = MapBuilder::new(None, Int32Builder::new(), Int32Builder::new());
        builder.keys().append_value(1);
        builder.values().append_value(10);
        builder.append(true).unwrap();
        builder.append(false).unwrap();

        let cloned = builder.finish_cloned();
        assert_eq!(cloned.len(), 2);
        assert!(cloned.is_valid(0));
        assert!(cloned.is_null(1));

        builder.keys().append_value(2);
        builder.values().append_value(20);
        builder.append(true).unwrap();

        let map = builder.finish();
        assert_eq!(map.len(), 3);
        assert_eq!(map.value_offsets(), &[0, 1, 1, 2]);
        assert!(map.is_null(1));
    }
}