// lance_file/previous/writer/mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4mod statistics;
5
6use std::collections::HashMap;
7use std::marker::PhantomData;
8
9use arrow_array::builder::{ArrayBuilder, PrimitiveBuilder};
10use arrow_array::cast::{as_large_list_array, as_list_array, as_struct_array};
11use arrow_array::types::{Int32Type, Int64Type};
12use arrow_array::{Array, ArrayRef, RecordBatch, StructArray};
13use arrow_buffer::ArrowNativeType;
14use arrow_data::ArrayData;
15use arrow_schema::DataType;
16use async_recursion::async_recursion;
17use async_trait::async_trait;
18use lance_arrow::*;
19use lance_core::datatypes::{Encoding, Field, NullabilityComparison, Schema, SchemaCompareOptions};
20use lance_core::{Error, Result};
21use lance_io::encodings::{
22    Encoder, binary::BinaryEncoder, dictionary::DictionaryEncoder, plain::PlainEncoder,
23};
24use lance_io::object_store::ObjectStore;
25use lance_io::traits::{WriteExt, Writer};
26use object_store::path::Path;
27use tokio::io::AsyncWriteExt;
28
29use crate::format::{MAGIC, MAJOR_VERSION, MINOR_VERSION};
30use crate::previous::format::metadata::{Metadata, StatisticsMetadata};
31use crate::previous::page_table::{PageInfo, PageTable};
32
/// The file format currently includes a "manifest" where it stores the schema for
/// self-describing files.  Historically this has been a table format manifest that
/// is empty except for the schema field.
///
/// Since this crate is not aware of the table format we need this to be provided
/// externally.  You should always use lance_table::io::manifest::ManifestDescribing
/// for this today.
#[async_trait]
pub trait ManifestProvider {
    /// Store the schema in the file
    ///
    /// This should just require writing the schema (or a manifest wrapper) as a proto struct
    ///
    /// Note: the dictionaries have already been written by this point and the schema should
    /// be populated with the dictionary lengths/offsets
    ///
    /// Returns the position of the stored schema in the file, or `None` if the
    /// implementation does not store one (see `NotSelfDescribing`).
    async fn store_schema(object_writer: &mut dyn Writer, schema: &Schema)
    -> Result<Option<usize>>;
}
51
/// Implementation of ManifestProvider that does not store the schema
///
/// Only available in tests, where the written file does not need to be
/// self-describing.
#[cfg(test)]
pub(crate) struct NotSelfDescribing {}

#[cfg(test)]
#[async_trait]
impl ManifestProvider for NotSelfDescribing {
    // Writes nothing; the file ends up with no manifest position recorded.
    async fn store_schema(_: &mut dyn Writer, _: &Schema) -> Result<Option<usize>> {
        Ok(None)
    }
}
63
/// [FileWriter] writes Arrow [RecordBatch] to one Lance file.
///
/// ```ignored
/// use lance::io::FileWriter;
/// use futures::stream::Stream;
///
/// let mut file_writer = FileWriter::try_new(object_store, &path, schema, &options).await?;
/// while let Ok(batch) = stream.next().await {
///     file_writer.write(&[batch]).await.unwrap();
/// }
/// // Need to close file writer to flush buffer and footer.
/// file_writer.finish().await?;
/// ```
pub struct FileWriter<M: ManifestProvider + Send + Sync> {
    /// Destination for all encoded bytes.
    pub object_writer: Box<dyn Writer>,
    /// File schema; dictionary info and user metadata are updated as the file
    /// is written and stored in the footer.
    schema: Schema,
    /// Id assigned to the next `write` call (each call is one on-disk batch).
    batch_id: i32,
    /// Maps (field id, batch id) to the position/length of each written page.
    page_table: PageTable,
    /// Footer metadata (batch lengths, page table position, stats, ...).
    metadata: Metadata,
    /// Collects per-batch column statistics, when enabled via options.
    stats_collector: Option<statistics::StatisticsCollector>,
    /// Zero-sized marker tying this writer to a [ManifestProvider] impl.
    manifest_provider: PhantomData<M>,
}
86
/// Options for creating a [FileWriter].
#[derive(Debug, Clone, Default)]
pub struct FileWriterOptions {
    /// The field ids to collect statistics for.
    ///
    /// If None, will collect for all fields in the schema (that support stats).
    /// If an empty vector, will not collect any statistics.
    pub collect_stats_for_fields: Option<Vec<i32>>,
}
95
96impl<M: ManifestProvider + Send + Sync> FileWriter<M> {
97    pub async fn try_new(
98        object_store: &ObjectStore,
99        path: &Path,
100        schema: Schema,
101        options: &FileWriterOptions,
102    ) -> Result<Self> {
103        let object_writer = object_store.create(path).await?;
104        Self::with_object_writer(object_writer, schema, options)
105    }
106
107    pub fn with_object_writer(
108        object_writer: Box<dyn Writer>,
109        schema: Schema,
110        options: &FileWriterOptions,
111    ) -> Result<Self> {
112        let collect_stats_for_fields = if let Some(stats_fields) = &options.collect_stats_for_fields
113        {
114            stats_fields.clone()
115        } else {
116            schema.field_ids()
117        };
118
119        let stats_collector = if !collect_stats_for_fields.is_empty() {
120            let stats_schema = schema.project_by_ids(&collect_stats_for_fields, true);
121            statistics::StatisticsCollector::try_new(&stats_schema)
122        } else {
123            None
124        };
125
126        Ok(Self {
127            object_writer,
128            schema,
129            batch_id: 0,
130            page_table: PageTable::default(),
131            metadata: Metadata::default(),
132            stats_collector,
133            manifest_provider: PhantomData,
134        })
135    }
136
    /// Return the schema of the file writer.
    ///
    /// Metadata added via [`Self::add_metadata`] is reflected here.
    pub fn schema(&self) -> &Schema {
        &self.schema
    }
141
142    fn verify_field_nullability(arr: &ArrayData, field: &Field) -> Result<()> {
143        if !field.nullable && arr.null_count() > 0 {
144            return Err(Error::invalid_input(format!(
145                "The field `{}` contained null values even though the field is marked non-null in the schema",
146                field.name
147            )));
148        }
149
150        for (child_field, child_arr) in field.children.iter().zip(arr.child_data()) {
151            Self::verify_field_nullability(child_arr, child_field)?;
152        }
153
154        Ok(())
155    }
156
157    fn verify_nullability_constraints(&self, batch: &RecordBatch) -> Result<()> {
158        for (col, field) in batch.columns().iter().zip(self.schema.fields.iter()) {
159            Self::verify_field_nullability(&col.to_data(), field)?;
160        }
161        Ok(())
162    }
163
    /// Write a [RecordBatch] to the open file.
    /// All RecordBatch will be treated as one RecordBatch on disk
    ///
    /// Returns [Err] if the schema does not match with the batch.
    pub async fn write(&mut self, batches: &[RecordBatch]) -> Result<()> {
        // Writing nothing is a no-op (and must not advance batch_id).
        if batches.is_empty() {
            return Ok(());
        }

        for batch in batches {
            // Compare, ignore metadata and dictionary
            //   dictionary should have been checked earlier and could be an expensive check
            let schema = Schema::try_from(batch.schema().as_ref())?;
            schema.check_compatible(
                &self.schema,
                &SchemaCompareOptions {
                    compare_nullability: NullabilityComparison::Ignore,
                    ..Default::default()
                },
            )?;
            self.verify_nullability_constraints(batch)?;
        }

        // If we are collecting stats for this column, collect them.
        // Statistics need to traverse nested arrays, so it's a separate loop
        // from writing which is done on top-level arrays.
        if let Some(stats_collector) = &mut self.stats_collector {
            for (field, arrays) in fields_in_batches(batches, &self.schema) {
                if let Some(stats_builder) = stats_collector.get_builder(field.id) {
                    let stats_row = statistics::collect_statistics(&arrays);
                    stats_builder.append(stats_row);
                }
            }
        }

        // Copy a list of fields to avoid borrow checker error.
        let fields = self.schema.fields.clone();
        for field in fields.iter() {
            // Gather this field's column from every batch; they are written
            // together as one page under the current batch id.
            let arrs = batches
                .iter()
                .map(|batch| {
                    batch.column_by_name(&field.name).ok_or_else(|| {
                        Error::invalid_input(format!(
                            "FileWriter::write: Field '{}' not found",
                            field.name
                        ))
                    })
                })
                .collect::<Result<Vec<_>>>()?;

            Self::write_array(
                self.object_writer.as_mut(),
                field,
                &arrs,
                self.batch_id,
                &mut self.page_table,
            )
            .await?;
        }
        // The combined row count of all input batches is recorded as a single
        // on-disk batch length.
        let batch_length = batches.iter().map(|b| b.num_rows() as i32).sum();
        self.metadata.push_batch_length(batch_length);

        // It's imperative we complete any in-flight requests, since we are
        // returning control to the caller. If the caller takes a long time to
        // write the next batch, the in-flight requests will not be polled and
        // may time out.
        self.object_writer.flush().await?;

        self.batch_id += 1;
        Ok(())
    }
235
236    /// Add schema metadata, as (key, value) pair to the file.
237    pub fn add_metadata(&mut self, key: &str, value: &str) {
238        self.schema
239            .metadata
240            .insert(key.to_string(), value.to_string());
241    }
242
243    pub async fn finish_with_metadata(
244        &mut self,
245        metadata: &HashMap<String, String>,
246    ) -> Result<usize> {
247        self.schema
248            .metadata
249            .extend(metadata.iter().map(|(k, y)| (k.clone(), y.clone())));
250        self.finish().await
251    }
252
253    pub async fn finish(&mut self) -> Result<usize> {
254        self.write_footer().await?;
255        Writer::shutdown(self.object_writer.as_mut()).await?;
256        let num_rows = self
257            .metadata
258            .batch_offsets
259            .last()
260            .cloned()
261            .unwrap_or_default();
262        Ok(num_rows as usize)
263    }
264
    /// Total records written in this file.
    pub fn len(&self) -> usize {
        self.metadata.len()
    }

    /// Total bytes written so far
    pub async fn tell(&mut self) -> Result<usize> {
        self.object_writer.tell().await
    }

    /// Return the id of the next batch to be written.
    pub fn next_batch_id(&self) -> i32 {
        self.batch_id
    }

    /// Returns true if no records have been written yet.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
283
    /// Dispatch on the arrays' [DataType] and write them with the matching
    /// encoder, recording the resulting page under `(field.id, batch_id)`.
    ///
    /// All arrays in `arrs` are assumed to share the data type of the first.
    /// Note that arm order matters: the guard arms (fixed-stride, binary-like,
    /// struct) are tried before the explicit `FixedSizeList`/`FixedSizeBinary`
    /// and list arms.
    #[async_recursion]
    async fn write_array(
        object_writer: &mut dyn Writer,
        field: &Field,
        arrs: &[&ArrayRef],
        batch_id: i32,
        page_table: &mut PageTable,
    ) -> Result<()> {
        assert!(!arrs.is_empty());
        let data_type = arrs[0].data_type();
        let arrs_ref = arrs.iter().map(|a| a.as_ref()).collect::<Vec<_>>();

        match data_type {
            DataType::Null => {
                Self::write_null_array(
                    object_writer,
                    field,
                    arrs_ref.as_slice(),
                    batch_id,
                    page_table,
                )
                .await
            }
            dt if dt.is_fixed_stride() => {
                Self::write_fixed_stride_array(
                    object_writer,
                    field,
                    arrs_ref.as_slice(),
                    batch_id,
                    page_table,
                )
                .await
            }
            dt if dt.is_binary_like() => {
                Self::write_binary_array(
                    object_writer,
                    field,
                    arrs_ref.as_slice(),
                    batch_id,
                    page_table,
                )
                .await
            }
            DataType::Dictionary(key_type, _) => {
                Self::write_dictionary_arr(
                    object_writer,
                    field,
                    arrs_ref.as_slice(),
                    key_type,
                    batch_id,
                    page_table,
                )
                .await
            }
            dt if dt.is_struct() => {
                Self::write_struct_array(
                    object_writer,
                    field,
                    arrs.iter().map(|a| as_struct_array(a)).collect::<Vec<_>>().as_slice(),
                    batch_id,
                    page_table,
                )
                .await
            }
            DataType::FixedSizeList(_, _) | DataType::FixedSizeBinary(_) => {
                Self::write_fixed_stride_array(
                    object_writer,
                    field,
                    arrs_ref.as_slice(),
                    batch_id,
                    page_table,
                )
                .await
            }
            DataType::List(_) => {
                Self::write_list_array(
                    object_writer,
                    field,
                    arrs_ref.as_slice(),
                    batch_id,
                    page_table,
                )
                .await
            }
            DataType::LargeList(_) => {
                Self::write_large_list_array(
                    object_writer,
                    field,
                    arrs_ref.as_slice(),
                    batch_id,
                    page_table,
                )
                .await
            }
            _ => Err(Error::schema(format!(
                "FileWriter::write: unsupported data type: {data_type}"
            ))),
        }
    }
384
385    async fn write_null_array(
386        object_writer: &mut dyn Writer,
387        field: &Field,
388        arrs: &[&dyn Array],
389        batch_id: i32,
390        page_table: &mut PageTable,
391    ) -> Result<()> {
392        let arrs_length: i32 = arrs.iter().map(|a| a.len() as i32).sum();
393        let page_info = PageInfo::new(object_writer.tell().await?, arrs_length as usize);
394        page_table.set(field.id, batch_id, page_info);
395        Ok(())
396    }
397
398    /// Write fixed size array, including, primtiives, fixed size binary, and fixed size list.
399    async fn write_fixed_stride_array(
400        object_writer: &mut dyn Writer,
401        field: &Field,
402        arrs: &[&dyn Array],
403        batch_id: i32,
404        page_table: &mut PageTable,
405    ) -> Result<()> {
406        assert_eq!(field.encoding, Some(Encoding::Plain));
407        assert!(!arrs.is_empty());
408        let data_type = arrs[0].data_type();
409
410        let mut encoder = PlainEncoder::new(object_writer, data_type);
411        let pos = encoder.encode(arrs).await?;
412        let arrs_length: i32 = arrs.iter().map(|a| a.len() as i32).sum();
413        let page_info = PageInfo::new(pos, arrs_length as usize);
414        page_table.set(field.id, batch_id, page_info);
415        Ok(())
416    }
417
418    /// Write var-length binary arrays.
419    async fn write_binary_array(
420        object_writer: &mut dyn Writer,
421        field: &Field,
422        arrs: &[&dyn Array],
423        batch_id: i32,
424        page_table: &mut PageTable,
425    ) -> Result<()> {
426        assert_eq!(field.encoding, Some(Encoding::VarBinary));
427        let mut encoder = BinaryEncoder::new(object_writer);
428        let pos = encoder.encode(arrs).await?;
429        let arrs_length: i32 = arrs.iter().map(|a| a.len() as i32).sum();
430        let page_info = PageInfo::new(pos, arrs_length as usize);
431        page_table.set(field.id, batch_id, page_info);
432        Ok(())
433    }
434
435    async fn write_dictionary_arr(
436        object_writer: &mut dyn Writer,
437        field: &Field,
438        arrs: &[&dyn Array],
439        key_type: &DataType,
440        batch_id: i32,
441        page_table: &mut PageTable,
442    ) -> Result<()> {
443        assert_eq!(field.encoding, Some(Encoding::Dictionary));
444
445        // Write the dictionary keys.
446        let mut encoder = DictionaryEncoder::new(object_writer, key_type);
447        let pos = encoder.encode(arrs).await?;
448        let arrs_length: i32 = arrs.iter().map(|a| a.len() as i32).sum();
449        let page_info = PageInfo::new(pos, arrs_length as usize);
450        page_table.set(field.id, batch_id, page_info);
451        Ok(())
452    }
453
454    #[async_recursion]
455    async fn write_struct_array(
456        object_writer: &mut dyn Writer,
457        field: &Field,
458        arrays: &[&StructArray],
459        batch_id: i32,
460        page_table: &mut PageTable,
461    ) -> Result<()> {
462        arrays
463            .iter()
464            .for_each(|a| assert_eq!(a.num_columns(), field.children.len()));
465
466        for child in &field.children {
467            let mut arrs: Vec<&ArrayRef> = Vec::new();
468            for struct_array in arrays {
469                let arr = struct_array
470                    .column_by_name(&child.name)
471                    .ok_or(Error::schema(format!(
472                        "FileWriter: schema mismatch: column {} does not exist in array: {:?}",
473                        child.name,
474                        struct_array.data_type()
475                    )))?;
476                arrs.push(arr);
477            }
478            Self::write_array(object_writer, child, arrs.as_slice(), batch_id, page_table).await?;
479        }
480        Ok(())
481    }
482
    /// Write a [DataType::List] column.
    ///
    /// The per-array value offsets are rewritten into one contiguous
    /// positions array (each array's offsets are rebased to continue where
    /// the previous array ended), which is written as a plain-encoded page.
    /// The referenced slice of the child values is then written recursively.
    async fn write_list_array(
        object_writer: &mut dyn Writer,
        field: &Field,
        arrs: &[&dyn Array],
        batch_id: i32,
        page_table: &mut PageTable,
    ) -> Result<()> {
        let capacity: usize = arrs.iter().map(|a| a.len()).sum();
        let mut list_arrs: Vec<ArrayRef> = Vec::new();
        let mut pos_builder: PrimitiveBuilder<Int32Type> =
            PrimitiveBuilder::with_capacity(capacity);

        let mut last_offset: usize = 0;
        pos_builder.append_value(last_offset as i32);
        for array in arrs.iter() {
            let list_arr = as_list_array(*array);
            let offsets = list_arr.value_offsets();

            assert!(!offsets.is_empty());
            // A (possibly sliced) list array may not start at offset zero;
            // only the referenced range of the child values is written.
            let start_offset = offsets[0].as_usize();
            let end_offset = offsets[offsets.len() - 1].as_usize();

            let list_values = list_arr.values();
            let sliced_values = list_values.slice(start_offset, end_offset - start_offset);
            list_arrs.push(sliced_values);

            // Rebase this array's offsets onto the running total so the
            // concatenated positions stay monotonically increasing.
            offsets
                .iter()
                .skip(1)
                .map(|b| b.as_usize() - start_offset + last_offset)
                .for_each(|o| pos_builder.append_value(o as i32));
            last_offset = pos_builder.values_slice()[pos_builder.len() - 1_usize] as usize;
        }

        let positions: &dyn Array = &pos_builder.finish();
        Self::write_fixed_stride_array(object_writer, field, &[positions], batch_id, page_table)
            .await?;
        let arrs = list_arrs.iter().collect::<Vec<_>>();
        Self::write_array(
            object_writer,
            &field.children[0],
            arrs.as_slice(),
            batch_id,
            page_table,
        )
        .await
    }
530
    /// Write a [DataType::LargeList] column.
    ///
    /// Same algorithm as `write_list_array`, but with 64-bit offsets: the
    /// offsets are rebased into one contiguous positions array, written as a
    /// plain-encoded page, then the sliced child values are written
    /// recursively.
    async fn write_large_list_array(
        object_writer: &mut dyn Writer,
        field: &Field,
        arrs: &[&dyn Array],
        batch_id: i32,
        page_table: &mut PageTable,
    ) -> Result<()> {
        let capacity: usize = arrs.iter().map(|a| a.len()).sum();
        let mut list_arrs: Vec<ArrayRef> = Vec::new();
        let mut pos_builder: PrimitiveBuilder<Int64Type> =
            PrimitiveBuilder::with_capacity(capacity);

        let mut last_offset: usize = 0;
        pos_builder.append_value(last_offset as i64);
        for array in arrs.iter() {
            let list_arr = as_large_list_array(*array);
            let offsets = list_arr.value_offsets();

            assert!(!offsets.is_empty());
            // A (possibly sliced) list array may not start at offset zero;
            // only the referenced range of the child values is written.
            let start_offset = offsets[0].as_usize();
            let end_offset = offsets[offsets.len() - 1].as_usize();

            let sliced_values = list_arr
                .values()
                .slice(start_offset, end_offset - start_offset);
            list_arrs.push(sliced_values);

            // Rebase this array's offsets onto the running total so the
            // concatenated positions stay monotonically increasing.
            offsets
                .iter()
                .skip(1)
                .map(|b| b.as_usize() - start_offset + last_offset)
                .for_each(|o| pos_builder.append_value(o as i64));
            last_offset = pos_builder.values_slice()[pos_builder.len() - 1_usize] as usize;
        }

        let positions: &dyn Array = &pos_builder.finish();
        Self::write_fixed_stride_array(object_writer, field, &[positions], batch_id, page_table)
            .await?;
        let arrs = list_arrs.iter().collect::<Vec<_>>();
        Self::write_array(
            object_writer,
            &field.children[0],
            arrs.as_slice(),
            batch_id,
            page_table,
        )
        .await
    }
579
    /// Write any collected statistics and return their metadata.
    ///
    /// The statistics are written as a single-batch mini table with its own
    /// page table. Returns `None` when statistics collection is disabled or
    /// the collected batch is empty.
    async fn write_statistics(&mut self) -> Result<Option<StatisticsMetadata>> {
        let statistics = self
            .stats_collector
            .as_mut()
            .map(|collector| collector.finish());

        match statistics {
            Some(Ok(stats_batch)) if stats_batch.num_rows() > 0 => {
                // One statistics row is collected per written batch.
                debug_assert_eq!(self.next_batch_id() as usize, stats_batch.num_rows());
                let schema = Schema::try_from(stats_batch.schema().as_ref())?;
                let leaf_field_ids = schema.field_ids();

                let mut stats_page_table = PageTable::default();
                for (i, field) in schema.fields.iter().enumerate() {
                    Self::write_array(
                        self.object_writer.as_mut(),
                        field,
                        &[stats_batch.column(i)],
                        0, // Only one batch for statistics.
                        &mut stats_page_table,
                    )
                    .await?;
                }

                let page_table_position = stats_page_table
                    .write(self.object_writer.as_mut(), 0)
                    .await?;

                Ok(Some(StatisticsMetadata {
                    schema,
                    leaf_field_ids,
                    page_table_position,
                }))
            }
            // Propagate a failure from the collector itself.
            Some(Err(e)) => Err(e),
            _ => Ok(None),
        }
    }
618
    /// Writes the dictionaries (using plain/binary encoding) into the file
    ///
    /// The offsets and lengths of the written buffers are stored in the given
    /// schema so that the dictionaries can be loaded in the future.
    async fn write_dictionaries(writer: &mut dyn Writer, schema: &mut Schema) -> Result<()> {
        // Write dictionary values.
        let max_field_id = schema.max_field_id().unwrap_or(-1);
        for field_id in 0..max_field_id + 1 {
            if let Some(field) = schema.mut_field_by_id(field_id)
                && field.data_type().is_dictionary()
            {
                // A dictionary field must have had its dictionary info
                // attached before the file is finished.
                let dict_info = field.dictionary.as_mut().ok_or_else(|| {
                    Error::io(format!("Lance field {} misses dictionary info", field.name))
                })?;

                let value_arr = dict_info.values.as_ref().ok_or_else(|| {
                    Error::invalid_input(format!(
                        "Lance field {} is dictionary type, but misses the dictionary value array",
                        field.name
                    ))
                })?;

                // Pick an encoder based on the dictionary *value* type.
                let data_type = value_arr.data_type();
                let pos = match data_type {
                    dt if dt.is_numeric() => {
                        let mut encoder = PlainEncoder::new(writer, dt);
                        encoder.encode(&[value_arr]).await?
                    }
                    dt if dt.is_binary_like() => {
                        let mut encoder = BinaryEncoder::new(writer);
                        encoder.encode(&[value_arr]).await?
                    }
                    _ => {
                        return Err(Error::schema(format!(
                            "Does not support {} as dictionary value type",
                            value_arr.data_type()
                        )));
                    }
                };
                // Record where the values were written so readers can load them.
                dict_info.offset = pos;
                dict_info.length = value_arr.len();
            }
        }
        Ok(())
    }
665
666    async fn write_footer(&mut self) -> Result<()> {
667        // Step 1. Write page table.
668        let field_id_offset = *self.schema.field_ids().iter().min().unwrap();
669        let pos = self
670            .page_table
671            .write(self.object_writer.as_mut(), field_id_offset)
672            .await?;
673        self.metadata.page_table_position = pos;
674
675        // Step 2. Write statistics.
676        self.metadata.stats_metadata = self.write_statistics().await?;
677
678        // Step 3. Write manifest and dictionary values.
679        Self::write_dictionaries(self.object_writer.as_mut(), &mut self.schema).await?;
680        let pos = M::store_schema(self.object_writer.as_mut(), &self.schema).await?;
681
682        // Step 4. Write metadata.
683        self.metadata.manifest_position = pos;
684        let pos = self.object_writer.write_struct(&self.metadata).await?;
685
686        // Step 5. Write magics.
687        self.object_writer
688            .write_magics(pos, MAJOR_VERSION, MINOR_VERSION, MAGIC)
689            .await
690    }
691}
692
693/// Walk through the schema and return arrays with their Lance field.
694///
695/// This skips over nested arrays and fields within list arrays. It does walk
696/// over the children of structs.
697fn fields_in_batches<'a>(
698    batches: &'a [RecordBatch],
699    schema: &'a Schema,
700) -> impl Iterator<Item = (&'a Field, Vec<&'a ArrayRef>)> {
701    let num_columns = batches[0].num_columns();
702    let array_iters = (0..num_columns).map(|col_i| {
703        batches
704            .iter()
705            .map(|batch| batch.column(col_i))
706            .collect::<Vec<_>>()
707    });
708    let mut to_visit: Vec<(&'a Field, Vec<&'a ArrayRef>)> =
709        schema.fields.iter().zip(array_iters).collect();
710
711    std::iter::from_fn(move || {
712        loop {
713            let (field, arrays): (_, Vec<&'a ArrayRef>) = to_visit.pop()?;
714            match field.data_type() {
715                DataType::Struct(_) => {
716                    for (i, child_field) in field.children.iter().enumerate() {
717                        let child_arrays = arrays
718                            .iter()
719                            .map(|arr| as_struct_array(*arr).column(i))
720                            .collect::<Vec<&'a ArrayRef>>();
721                        to_visit.push((child_field, child_arrays));
722                    }
723                    continue;
724                }
725                // We only walk structs right now.
726                _ if field.data_type().is_nested() => continue,
727                _ => return Some((field, arrays)),
728            }
729        }
730    })
731}
732
733#[cfg(test)]
734mod tests {
735    use super::*;
736
737    use std::sync::Arc;
738
739    use arrow_array::{
740        BooleanArray, Decimal128Array, Decimal256Array, DictionaryArray, DurationMicrosecondArray,
741        DurationMillisecondArray, DurationNanosecondArray, DurationSecondArray,
742        FixedSizeBinaryArray, FixedSizeListArray, Float32Array, Int32Array, Int64Array, ListArray,
743        NullArray, StringArray, TimestampMicrosecondArray, TimestampSecondArray, UInt8Array,
744        types::UInt32Type,
745    };
746    use arrow_buffer::i256;
747    use arrow_schema::{
748        Field as ArrowField, Fields as ArrowFields, Schema as ArrowSchema, TimeUnit,
749    };
750    use arrow_select::concat::concat_batches;
751
752    use crate::previous::reader::FileReader;
753
754    #[tokio::test]
755    async fn test_write_file() {
756        let arrow_schema = ArrowSchema::new(vec![
757            ArrowField::new("null", DataType::Null, true),
758            ArrowField::new("bool", DataType::Boolean, true),
759            ArrowField::new("i", DataType::Int64, true),
760            ArrowField::new("f", DataType::Float32, false),
761            ArrowField::new("b", DataType::Utf8, true),
762            ArrowField::new("decimal128", DataType::Decimal128(7, 3), false),
763            ArrowField::new("decimal256", DataType::Decimal256(7, 3), false),
764            ArrowField::new("duration_sec", DataType::Duration(TimeUnit::Second), false),
765            ArrowField::new(
766                "duration_msec",
767                DataType::Duration(TimeUnit::Millisecond),
768                false,
769            ),
770            ArrowField::new(
771                "duration_usec",
772                DataType::Duration(TimeUnit::Microsecond),
773                false,
774            ),
775            ArrowField::new(
776                "duration_nsec",
777                DataType::Duration(TimeUnit::Nanosecond),
778                false,
779            ),
780            ArrowField::new(
781                "d",
782                DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)),
783                true,
784            ),
785            ArrowField::new(
786                "fixed_size_list",
787                DataType::FixedSizeList(
788                    Arc::new(ArrowField::new("item", DataType::Float32, true)),
789                    16,
790                ),
791                true,
792            ),
793            ArrowField::new("fixed_size_binary", DataType::FixedSizeBinary(8), true),
794            ArrowField::new(
795                "l",
796                DataType::List(Arc::new(ArrowField::new("item", DataType::Utf8, true))),
797                true,
798            ),
799            ArrowField::new(
800                "large_l",
801                DataType::LargeList(Arc::new(ArrowField::new("item", DataType::Utf8, true))),
802                true,
803            ),
804            ArrowField::new(
805                "l_dict",
806                DataType::List(Arc::new(ArrowField::new(
807                    "item",
808                    DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)),
809                    true,
810                ))),
811                true,
812            ),
813            ArrowField::new(
814                "large_l_dict",
815                DataType::LargeList(Arc::new(ArrowField::new(
816                    "item",
817                    DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)),
818                    true,
819                ))),
820                true,
821            ),
822            ArrowField::new(
823                "s",
824                DataType::Struct(ArrowFields::from(vec![
825                    ArrowField::new("si", DataType::Int64, true),
826                    ArrowField::new("sb", DataType::Utf8, true),
827                ])),
828                true,
829            ),
830        ]);
831        let mut schema = Schema::try_from(&arrow_schema).unwrap();
832
833        let dict_vec = (0..100).map(|n| ["a", "b", "c"][n % 3]).collect::<Vec<_>>();
834        let dict_arr: DictionaryArray<UInt32Type> = dict_vec.into_iter().collect();
835
836        let fixed_size_list_arr = FixedSizeListArray::try_new_from_values(
837            Float32Array::from_iter((0..1600).map(|n| n as f32).collect::<Vec<_>>()),
838            16,
839        )
840        .unwrap();
841
842        let binary_data: [u8; 800] = [123; 800];
843        let fixed_size_binary_arr =
844            FixedSizeBinaryArray::try_new_from_values(&UInt8Array::from_iter(binary_data), 8)
845                .unwrap();
846
847        let list_offsets: Int32Array = (0..202).step_by(2).collect();
848        let list_values =
849            StringArray::from((0..200).map(|n| format!("str-{}", n)).collect::<Vec<_>>());
850        let list_arr: arrow_array::GenericListArray<i32> =
851            try_new_generic_list_array(list_values, &list_offsets).unwrap();
852
853        let large_list_offsets: Int64Array = (0..202).step_by(2).collect();
854        let large_list_values =
855            StringArray::from((0..200).map(|n| format!("str-{}", n)).collect::<Vec<_>>());
856        let large_list_arr: arrow_array::GenericListArray<i64> =
857            try_new_generic_list_array(large_list_values, &large_list_offsets).unwrap();
858
859        let list_dict_offsets: Int32Array = (0..202).step_by(2).collect();
860        let list_dict_vec = (0..200).map(|n| ["a", "b", "c"][n % 3]).collect::<Vec<_>>();
861        let list_dict_arr: DictionaryArray<UInt32Type> = list_dict_vec.into_iter().collect();
862        let list_dict_arr: arrow_array::GenericListArray<i32> =
863            try_new_generic_list_array(list_dict_arr, &list_dict_offsets).unwrap();
864
865        let large_list_dict_offsets: Int64Array = (0..202).step_by(2).collect();
866        let large_list_dict_vec = (0..200).map(|n| ["a", "b", "c"][n % 3]).collect::<Vec<_>>();
867        let large_list_dict_arr: DictionaryArray<UInt32Type> =
868            large_list_dict_vec.into_iter().collect();
869        let large_list_dict_arr: arrow_array::GenericListArray<i64> =
870            try_new_generic_list_array(large_list_dict_arr, &large_list_dict_offsets).unwrap();
871
872        let columns: Vec<ArrayRef> = vec![
873            Arc::new(NullArray::new(100)),
874            Arc::new(BooleanArray::from_iter(
875                (0..100).map(|f| Some(f % 3 == 0)).collect::<Vec<_>>(),
876            )),
877            Arc::new(Int64Array::from_iter((0..100).collect::<Vec<_>>())),
878            Arc::new(Float32Array::from_iter(
879                (0..100).map(|n| n as f32).collect::<Vec<_>>(),
880            )),
881            Arc::new(StringArray::from(
882                (0..100).map(|n| n.to_string()).collect::<Vec<_>>(),
883            )),
884            Arc::new(
885                Decimal128Array::from_iter_values(0..100)
886                    .with_precision_and_scale(7, 3)
887                    .unwrap(),
888            ),
889            Arc::new(
890                Decimal256Array::from_iter_values((0..100).map(|v| i256::from_i128(v as i128)))
891                    .with_precision_and_scale(7, 3)
892                    .unwrap(),
893            ),
894            Arc::new(DurationSecondArray::from_iter_values(0..100)),
895            Arc::new(DurationMillisecondArray::from_iter_values(0..100)),
896            Arc::new(DurationMicrosecondArray::from_iter_values(0..100)),
897            Arc::new(DurationNanosecondArray::from_iter_values(0..100)),
898            Arc::new(dict_arr),
899            Arc::new(fixed_size_list_arr),
900            Arc::new(fixed_size_binary_arr),
901            Arc::new(list_arr),
902            Arc::new(large_list_arr),
903            Arc::new(list_dict_arr),
904            Arc::new(large_list_dict_arr),
905            Arc::new(StructArray::from(vec![
906                (
907                    Arc::new(ArrowField::new("si", DataType::Int64, true)),
908                    Arc::new(Int64Array::from_iter((100..200).collect::<Vec<_>>())) as ArrayRef,
909                ),
910                (
911                    Arc::new(ArrowField::new("sb", DataType::Utf8, true)),
912                    Arc::new(StringArray::from(
913                        (0..100).map(|n| n.to_string()).collect::<Vec<_>>(),
914                    )) as ArrayRef,
915                ),
916            ])),
917        ];
918        let batch = RecordBatch::try_new(Arc::new(arrow_schema), columns).unwrap();
919        schema.set_dictionary(&batch).unwrap();
920
921        let store = ObjectStore::memory();
922        let path = Path::from("/foo");
923        let mut file_writer = FileWriter::<NotSelfDescribing>::try_new(
924            &store,
925            &path,
926            schema.clone(),
927            &Default::default(),
928        )
929        .await
930        .unwrap();
931        file_writer
932            .write(std::slice::from_ref(&batch))
933            .await
934            .unwrap();
935        file_writer.finish().await.unwrap();
936
937        let reader = FileReader::try_new(&store, &path, schema).await.unwrap();
938        let actual = reader.read_batch(0, .., reader.schema()).await.unwrap();
939        assert_eq!(actual, batch);
940    }
941
942    #[tokio::test]
943    async fn test_dictionary_first_element_file() {
944        let arrow_schema = ArrowSchema::new(vec![ArrowField::new(
945            "d",
946            DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)),
947            true,
948        )]);
949        let mut schema = Schema::try_from(&arrow_schema).unwrap();
950
951        let dict_vec = (0..100).map(|n| ["a", "b", "c"][n % 3]).collect::<Vec<_>>();
952        let dict_arr: DictionaryArray<UInt32Type> = dict_vec.into_iter().collect();
953
954        let columns: Vec<ArrayRef> = vec![Arc::new(dict_arr)];
955        let batch = RecordBatch::try_new(Arc::new(arrow_schema), columns).unwrap();
956        schema.set_dictionary(&batch).unwrap();
957
958        let store = ObjectStore::memory();
959        let path = Path::from("/foo");
960        let mut file_writer = FileWriter::<NotSelfDescribing>::try_new(
961            &store,
962            &path,
963            schema.clone(),
964            &Default::default(),
965        )
966        .await
967        .unwrap();
968        file_writer
969            .write(std::slice::from_ref(&batch))
970            .await
971            .unwrap();
972        file_writer.finish().await.unwrap();
973
974        let reader = FileReader::try_new(&store, &path, schema).await.unwrap();
975        let actual = reader.read_batch(0, .., reader.schema()).await.unwrap();
976        assert_eq!(actual, batch);
977    }
978
979    #[tokio::test]
980    async fn test_write_temporal_types() {
981        let arrow_schema = Arc::new(ArrowSchema::new(vec![
982            ArrowField::new(
983                "ts_notz",
984                DataType::Timestamp(TimeUnit::Second, None),
985                false,
986            ),
987            ArrowField::new(
988                "ts_tz",
989                DataType::Timestamp(TimeUnit::Microsecond, Some("America/Los_Angeles".into())),
990                false,
991            ),
992        ]));
993        let columns: Vec<ArrayRef> = vec![
994            Arc::new(TimestampSecondArray::from(vec![11111111, 22222222])),
995            Arc::new(
996                TimestampMicrosecondArray::from(vec![3333333, 4444444])
997                    .with_timezone("America/Los_Angeles"),
998            ),
999        ];
1000        let batch = RecordBatch::try_new(arrow_schema.clone(), columns).unwrap();
1001
1002        let schema = Schema::try_from(arrow_schema.as_ref()).unwrap();
1003        let store = ObjectStore::memory();
1004        let path = Path::from("/foo");
1005        let mut file_writer = FileWriter::<NotSelfDescribing>::try_new(
1006            &store,
1007            &path,
1008            schema.clone(),
1009            &Default::default(),
1010        )
1011        .await
1012        .unwrap();
1013        file_writer
1014            .write(std::slice::from_ref(&batch))
1015            .await
1016            .unwrap();
1017        file_writer.finish().await.unwrap();
1018
1019        let reader = FileReader::try_new(&store, &path, schema).await.unwrap();
1020        let actual = reader.read_batch(0, .., reader.schema()).await.unwrap();
1021        assert_eq!(actual, batch);
1022    }
1023
    #[tokio::test]
    async fn test_collect_stats() {
        // Validate:
        // Only collects stats for requested columns
        // Can collect stats in nested structs
        // Won't collect stats for list columns (for now)

        let arrow_schema = ArrowSchema::new(vec![
            ArrowField::new("i", DataType::Int64, true),
            ArrowField::new("i2", DataType::Int64, true),
            ArrowField::new(
                "l",
                DataType::List(Arc::new(ArrowField::new("item", DataType::Int32, true))),
                true,
            ),
            ArrowField::new(
                "s",
                DataType::Struct(ArrowFields::from(vec![
                    ArrowField::new("si", DataType::Int64, true),
                    ArrowField::new("sb", DataType::Utf8, true),
                ])),
                true,
            ),
        ]);

        let schema = Schema::try_from(&arrow_schema).unwrap();

        let store = ObjectStore::memory();
        let path = Path::from("/foo");

        // Request stats only for the two top-level ints and the two struct
        // children.  The expected stats below imply the field-id assignment
        // i=0, i2=1, s.si=5, s.sb=6; the list column and the struct root are
        // deliberately not requested.
        let options = FileWriterOptions {
            collect_stats_for_fields: Some(vec![0, 1, 5, 6]),
        };
        let mut file_writer =
            FileWriter::<NotSelfDescribing>::try_new(&store, &path, schema.clone(), &options)
                .await
                .unwrap();

        // First write -> first stats page (3 rows, no nulls).
        let batch1 = RecordBatch::try_new(
            Arc::new(arrow_schema.clone()),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])),
                Arc::new(Int64Array::from(vec![4, 5, 6])),
                Arc::new(ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                    Some(vec![Some(1i32), Some(2), Some(3)]),
                    Some(vec![Some(4), Some(5)]),
                    Some(vec![]),
                ])),
                Arc::new(StructArray::from(vec![
                    (
                        Arc::new(ArrowField::new("si", DataType::Int64, true)),
                        Arc::new(Int64Array::from(vec![1, 2, 3])) as ArrayRef,
                    ),
                    (
                        Arc::new(ArrowField::new("sb", DataType::Utf8, true)),
                        Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef,
                    ),
                ])),
            ],
        )
        .unwrap();
        file_writer.write(&[batch1]).await.unwrap();

        // Second write -> second stats page (2 rows, no nulls).
        let batch2 = RecordBatch::try_new(
            Arc::new(arrow_schema.clone()),
            vec![
                Arc::new(Int64Array::from(vec![5, 6])),
                Arc::new(Int64Array::from(vec![10, 11])),
                Arc::new(ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                    Some(vec![Some(1i32), Some(2), Some(3)]),
                    Some(vec![]),
                ])),
                Arc::new(StructArray::from(vec![
                    (
                        Arc::new(ArrowField::new("si", DataType::Int64, true)),
                        Arc::new(Int64Array::from(vec![4, 5])) as ArrayRef,
                    ),
                    (
                        Arc::new(ArrowField::new("sb", DataType::Utf8, true)),
                        Arc::new(StringArray::from(vec!["d", "e"])) as ArrayRef,
                    ),
                ])),
            ],
        )
        .unwrap();
        file_writer.write(&[batch2]).await.unwrap();

        file_writer.finish().await.unwrap();

        let reader = FileReader::try_new(&store, &path, schema).await.unwrap();

        // Stats are returned as one struct column per requested field id,
        // with one row per written page (two pages here).
        let read_stats = reader.read_page_stats(&[0, 1, 5, 6]).await.unwrap();
        assert!(read_stats.is_some());
        let read_stats = read_stats.unwrap();

        let expected_stats_schema = stats_schema([
            (0, DataType::Int64),
            (1, DataType::Int64),
            (5, DataType::Int64),
            (6, DataType::Utf8),
        ]);

        assert_eq!(read_stats.schema().as_ref(), &expected_stats_schema);

        let expected_stats = stats_batch(&[
            Stats {
                field_id: 0,
                null_counts: vec![0, 0],
                min_values: Arc::new(Int64Array::from(vec![1, 5])),
                max_values: Arc::new(Int64Array::from(vec![3, 6])),
            },
            Stats {
                field_id: 1,
                null_counts: vec![0, 0],
                min_values: Arc::new(Int64Array::from(vec![4, 10])),
                max_values: Arc::new(Int64Array::from(vec![6, 11])),
            },
            Stats {
                field_id: 5,
                null_counts: vec![0, 0],
                min_values: Arc::new(Int64Array::from(vec![1, 4])),
                max_values: Arc::new(Int64Array::from(vec![3, 5])),
            },
            // FIXME: these max values shouldn't be incremented
            // https://github.com/lancedb/lance/issues/1517
            Stats {
                field_id: 6,
                null_counts: vec![0, 0],
                min_values: Arc::new(StringArray::from(vec!["a", "d"])),
                max_values: Arc::new(StringArray::from(vec!["c", "e"])),
            },
        ]);

        assert_eq!(read_stats, expected_stats);
    }
1159
1160    fn stats_schema(data_fields: impl IntoIterator<Item = (i32, DataType)>) -> ArrowSchema {
1161        let fields = data_fields
1162            .into_iter()
1163            .map(|(field_id, data_type)| {
1164                Arc::new(ArrowField::new(
1165                    format!("{}", field_id),
1166                    DataType::Struct(
1167                        vec![
1168                            Arc::new(ArrowField::new("null_count", DataType::Int64, false)),
1169                            Arc::new(ArrowField::new("min_value", data_type.clone(), true)),
1170                            Arc::new(ArrowField::new("max_value", data_type, true)),
1171                        ]
1172                        .into(),
1173                    ),
1174                    false,
1175                ))
1176            })
1177            .collect::<Vec<_>>();
1178        ArrowSchema::new(fields)
1179    }
1180
    /// Expected per-column page statistics, used to build the reference batch
    /// in `stats_batch`.
    struct Stats {
        // Lance field id the statistics belong to.
        field_id: i32,
        // One null count per written page.
        null_counts: Vec<i64>,
        // Per-page minimum values; their data type also determines the stats column type.
        min_values: ArrayRef,
        // Per-page maximum values (same type and length as `min_values`).
        max_values: ArrayRef,
    }
1187
1188    fn stats_batch(stats: &[Stats]) -> RecordBatch {
1189        let schema = stats_schema(
1190            stats
1191                .iter()
1192                .map(|s| (s.field_id, s.min_values.data_type().clone())),
1193        );
1194
1195        let columns = stats
1196            .iter()
1197            .map(|s| {
1198                let data_type = s.min_values.data_type().clone();
1199                let fields = vec![
1200                    Arc::new(ArrowField::new("null_count", DataType::Int64, false)),
1201                    Arc::new(ArrowField::new("min_value", data_type.clone(), true)),
1202                    Arc::new(ArrowField::new("max_value", data_type, true)),
1203                ];
1204                let arrays = vec![
1205                    Arc::new(Int64Array::from(s.null_counts.clone())),
1206                    s.min_values.clone(),
1207                    s.max_values.clone(),
1208                ];
1209                Arc::new(StructArray::new(fields.into(), arrays, None)) as ArrayRef
1210            })
1211            .collect();
1212
1213        RecordBatch::try_new(Arc::new(schema), columns).unwrap()
1214    }
1215
1216    async fn read_file_as_one_batch(
1217        object_store: &ObjectStore,
1218        path: &Path,
1219        schema: Schema,
1220    ) -> RecordBatch {
1221        let reader = FileReader::try_new(object_store, path, schema)
1222            .await
1223            .unwrap();
1224        let mut batches = vec![];
1225        for i in 0..reader.num_batches() {
1226            batches.push(
1227                reader
1228                    .read_batch(i as i32, .., reader.schema())
1229                    .await
1230                    .unwrap(),
1231            );
1232        }
1233        let arrow_schema = Arc::new(reader.schema().into());
1234        concat_batches(&arrow_schema, &batches).unwrap()
1235    }
1236
1237    /// Test encoding arrays that share the same underneath buffer.
1238    #[tokio::test]
1239    async fn test_encode_slice() {
1240        let store = ObjectStore::memory();
1241        let path = Path::from("/shared_slice");
1242
1243        let arrow_schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
1244            "i",
1245            DataType::Int32,
1246            false,
1247        )]));
1248        let schema = Schema::try_from(arrow_schema.as_ref()).unwrap();
1249        let mut file_writer = FileWriter::<NotSelfDescribing>::try_new(
1250            &store,
1251            &path,
1252            schema.clone(),
1253            &Default::default(),
1254        )
1255        .await
1256        .unwrap();
1257
1258        let array = Int32Array::from_iter_values(0..1000);
1259
1260        for i in (0..1000).step_by(4) {
1261            let data = array.slice(i, 4);
1262            file_writer
1263                .write(&[RecordBatch::try_new(arrow_schema.clone(), vec![Arc::new(data)]).unwrap()])
1264                .await
1265                .unwrap();
1266        }
1267        file_writer.finish().await.unwrap();
1268        assert!(store.size(&path).await.unwrap() < 2 * 8 * 1000);
1269
1270        let batch = read_file_as_one_batch(&store, &path, schema).await;
1271        assert_eq!(batch.column_by_name("i").unwrap().as_ref(), &array);
1272    }
1273
1274    #[tokio::test]
1275    async fn test_write_schema_with_holes() {
1276        let store = ObjectStore::memory();
1277        let path = Path::from("test");
1278
1279        let mut field0 = Field::try_from(&ArrowField::new("a", DataType::Int32, true)).unwrap();
1280        field0.set_id(-1, &mut 0);
1281        assert_eq!(field0.id, 0);
1282        let mut field2 = Field::try_from(&ArrowField::new("b", DataType::Int32, true)).unwrap();
1283        field2.set_id(-1, &mut 2);
1284        assert_eq!(field2.id, 2);
1285        // There is a hole at field id 1.
1286        let schema = Schema {
1287            fields: vec![field0, field2],
1288            metadata: Default::default(),
1289        };
1290
1291        let arrow_schema = Arc::new(ArrowSchema::new(vec![
1292            ArrowField::new("a", DataType::Int32, true),
1293            ArrowField::new("b", DataType::Int32, true),
1294        ]));
1295        let data = RecordBatch::try_new(
1296            arrow_schema.clone(),
1297            vec![
1298                Arc::new(Int32Array::from_iter_values(0..10)),
1299                Arc::new(Int32Array::from_iter_values(10..20)),
1300            ],
1301        )
1302        .unwrap();
1303
1304        let mut file_writer = FileWriter::<NotSelfDescribing>::try_new(
1305            &store,
1306            &path,
1307            schema.clone(),
1308            &Default::default(),
1309        )
1310        .await
1311        .unwrap();
1312        file_writer.write(&[data]).await.unwrap();
1313        file_writer.finish().await.unwrap();
1314
1315        let page_table = file_writer.page_table;
1316        assert!(page_table.get(0, 0).is_some());
1317        assert!(page_table.get(2, 0).is_some());
1318    }
1319}