apiary_storage/cell_writer.rs

//! Cell writer for creating Parquet files in object storage.
//!
//! The cell writer handles partitioning incoming data, writing Parquet cells
//! with LZ4 compression, computing cell-level statistics, and respecting
//! leafcutter sizing policies.

use std::collections::HashMap;
use std::sync::Arc;

use arrow::array::*;
use arrow::compute;
use arrow::datatypes::{
    ArrowPrimitiveType, DataType, Field, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type,
    Int8Type, Schema, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
};
use arrow::record_batch::RecordBatch;
use bytes::Bytes;
use parquet::arrow::ArrowWriter;
use parquet::basic::Compression;
use parquet::file::properties::WriterProperties;
use tracing::{info, warn};
use uuid::Uuid;

use apiary_core::{
    ApiaryError, CellId, CellMetadata, CellSizingPolicy, ColumnStats, FieldDef, FrameSchema,
    Result, StorageBackend,
};

/// Writes Arrow RecordBatches as Parquet cells to object storage.
pub struct CellWriter {
    storage: Arc<dyn StorageBackend>,
    frame_path: String,
    schema: FrameSchema,
    partition_by: Vec<String>,
    sizing: CellSizingPolicy,
}

impl CellWriter {
    /// Create a new CellWriter.
    pub fn new(
        storage: Arc<dyn StorageBackend>,
        frame_path: String,
        schema: FrameSchema,
        partition_by: Vec<String>,
        sizing: CellSizingPolicy,
    ) -> Self {
        Self {
            storage,
            frame_path,
            schema,
            partition_by,
            sizing,
        }
    }

    /// Write a RecordBatch, partitioning and sizing as needed.
    /// Returns the CellMetadata for each cell written.
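    ///
    /// # Example
    ///
    /// A minimal usage sketch (not compiled as a doctest); it assumes an
    /// available `StorageBackend` implementation plus a schema, sizing policy,
    /// and populated `RecordBatch` already in scope:
    ///
    /// ```ignore
    /// let writer = CellWriter::new(
    ///     storage,                       // Arc<dyn StorageBackend>
    ///     "tenant/default/my_frame".into(),
    ///     frame_schema,
    ///     vec!["region".into()],         // partition by the `region` column
    ///     sizing_policy,
    /// );
    /// let cells = writer.write(&batch).await?;
    /// for cell in &cells {
    ///     println!("wrote {} rows to {}", cell.rows, cell.path);
    /// }
    /// ```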
    pub async fn write(&self, batch: &RecordBatch) -> Result<Vec<CellMetadata>> {
        // Validate schema
        self.validate_schema(batch)?;

        // Partition the data
        let partitions = self.partition_data(batch)?;

        let mut all_cells = Vec::new();

        for (partition_values, partition_batch) in &partitions {
            // Split by cell sizing if needed
            let sub_batches = self.split_by_size(partition_batch)?;

            for sub_batch in &sub_batches {
                let cell = self.write_cell(sub_batch, partition_values).await?;
                all_cells.push(cell);
            }
        }

        Ok(all_cells)
    }

    /// Validate the incoming batch against the frame schema.
    ///
    /// Rules (from architecture doc 02-storage-engine.md):
    /// - Safely castable types → implicit cast handled at read time
    /// - Extra columns in write data → dropped with warning
    /// - Missing nullable column → filled with null
    /// - Missing non-nullable column → error
    /// - Null value in partition column → error
    ///
    /// Partition values containing path separators, NUL, or ".." are also rejected.
    fn validate_schema(&self, batch: &RecordBatch) -> Result<()> {
        // Check for missing non-nullable columns
        for field in &self.schema.fields {
            let found = batch.schema().index_of(&field.name).ok();
            if found.is_none() && !field.nullable {
                return Err(ApiaryError::Schema {
                    message: format!("Missing non-nullable column '{}' in write data", field.name),
                });
            }
        }

        // Check extra columns (warn but don't error)
        for batch_field in batch.schema().fields() {
            let in_schema = self
                .schema
                .fields
                .iter()
                .any(|f| f.name == *batch_field.name());
            if !in_schema {
                warn!(
                    column = %batch_field.name(),
                    "Extra column in write data will be dropped"
                );
            }
        }

        // Check that partition columns have no nulls and no path traversal characters
        for part_col in &self.partition_by {
            if let Ok(col_idx) = batch.schema().index_of(part_col) {
                let col = batch.column(col_idx);
                if col.null_count() > 0 {
                    return Err(ApiaryError::Schema {
                        message: format!("Partition column '{}' contains null values", part_col),
                    });
                }
                // Validate partition values don't contain path traversal characters
                for row_idx in 0..batch.num_rows() {
                    let val = array_value_to_string(col, row_idx);
                    if val.contains("..")
                        || val.contains('/')
                        || val.contains('\\')
                        || val.contains('\0')
                    {
                        return Err(ApiaryError::Schema {
                            message: format!(
                                "Partition column '{}' contains invalid characters (path separators or '..'): '{}'",
                                part_col, val
                            ),
                        });
                    }
                }
            }
        }
        Ok(())
    }

    /// Partition data by partition column values.
    fn partition_data(
        &self,
        batch: &RecordBatch,
    ) -> Result<Vec<(HashMap<String, String>, RecordBatch)>> {
        if self.partition_by.is_empty() || batch.num_rows() == 0 {
            return Ok(vec![(HashMap::new(), batch.clone())]);
        }

        // Build partition keys for each row
        let mut partition_keys: Vec<HashMap<String, String>> = Vec::new();
        for row_idx in 0..batch.num_rows() {
            let mut key = HashMap::new();
            for col_name in &self.partition_by {
                let col_idx =
                    batch
                        .schema()
                        .index_of(col_name)
                        .map_err(|_| ApiaryError::Schema {
                            message: format!("Partition column '{}' not found in data", col_name),
                        })?;
                let col = batch.column(col_idx);
                let val = array_value_to_string(col, row_idx);
                key.insert(col_name.clone(), val);
            }
            partition_keys.push(key);
        }

        // Group rows by partition key
        let mut groups: HashMap<String, (HashMap<String, String>, Vec<usize>)> = HashMap::new();
        for (row_idx, key) in partition_keys.iter().enumerate() {
            let key_str = partition_key_string(key, &self.partition_by);
            groups
                .entry(key_str)
                .or_insert_with(|| (key.clone(), Vec::new()))
                .1
                .push(row_idx);
        }

        // Build sub-batches by taking each group's rows from the original batch.
        // The `take` unwrap is safe here: every index comes from enumerating the
        // batch's own rows, so it is always in bounds.
        let mut result = Vec::new();
        for (_, (partition_values, row_indices)) in groups {
            let indices =
                UInt32Array::from(row_indices.iter().map(|i| *i as u32).collect::<Vec<_>>());
            let columns: Vec<ArrayRef> = batch
                .columns()
                .iter()
                .map(|col| compute::take(col, &indices, None).unwrap())
                .collect();
            let sub_batch = RecordBatch::try_new(batch.schema(), columns).map_err(|e| {
                ApiaryError::Internal {
                    message: format!("Failed to create partition batch: {}", e),
                }
            })?;
            result.push((partition_values, sub_batch));
        }

        Ok(result)
    }

    /// Split a batch into multiple batches if it exceeds the target cell size.
    fn split_by_size(&self, batch: &RecordBatch) -> Result<Vec<RecordBatch>> {
        // Estimate the size of the batch
        let estimated_size: usize = batch
            .columns()
            .iter()
            .map(|col| col.get_buffer_memory_size())
            .sum();

        let target = self.sizing.target_cell_size as usize;

        if estimated_size <= target || batch.num_rows() <= 1 {
            return Ok(vec![batch.clone()]);
        }

        // Calculate how many chunks we need
        let num_chunks = estimated_size.div_ceil(target);
        let rows_per_chunk = batch.num_rows().div_ceil(num_chunks);
        let rows_per_chunk = rows_per_chunk.max(1);

        let mut batches = Vec::new();
        let mut start = 0;

        // `slice` is zero-copy: each chunk references the original batch's buffers.
        while start < batch.num_rows() {
            let end = (start + rows_per_chunk).min(batch.num_rows());
            let sub_batch = batch.slice(start, end - start);
            batches.push(sub_batch);
            start = end;
        }

        Ok(batches)
    }

    /// Write a single cell to storage. Returns metadata about the cell.
    async fn write_cell(
        &self,
        batch: &RecordBatch,
        partition_values: &HashMap<String, String>,
    ) -> Result<CellMetadata> {
        let cell_id = CellId::new(format!("cell_{}", Uuid::new_v4()));
        let rows = batch.num_rows() as u64;

        // Build the storage path
        let partition_path = if partition_values.is_empty() {
            String::new()
        } else {
            let parts: Vec<String> = self
                .partition_by
                .iter()
                .filter_map(|col| {
                    partition_values
                        .get(col)
                        .map(|val| format!("{}={}", col, val))
                })
                .collect();
            parts.join("/") + "/"
        };

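        // The pieces below combine into a key of this shape (illustrative values):
        //   <frame_path>/region=us-east/cell_<uuid>.parquet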
        let cell_filename = format!("{}.parquet", cell_id.as_str());
        let relative_path = format!("{}{}", partition_path, cell_filename);
        let storage_key = format!("{}/{}", self.frame_path, relative_path);

        // Compute cell-level statistics
        let stats = compute_column_stats(batch, &self.schema)?;

        // Write the Parquet file
        let parquet_bytes = write_parquet_bytes(batch)?;
        let byte_size = parquet_bytes.len() as u64;

        self.storage
            .put(&storage_key, Bytes::from(parquet_bytes))
            .await?;

        info!(
            cell_id = %cell_id,
            rows,
            bytes = byte_size,
            path = %relative_path,
            "Wrote cell to storage"
        );

        Ok(CellMetadata {
            id: cell_id,
            path: relative_path,
            format: "parquet".into(),
            partition_values: partition_values.clone(),
            rows,
            bytes: byte_size,
            stats,
        })
    }
}

/// Convert the frame's type string to an Arrow DataType.
pub fn type_string_to_arrow(type_str: &str) -> DataType {
    match type_str.to_lowercase().as_str() {
        "int8" => DataType::Int8,
        "int16" => DataType::Int16,
        "int32" => DataType::Int32,
        "int64" | "int" | "integer" => DataType::Int64,
        "uint8" => DataType::UInt8,
        "uint16" => DataType::UInt16,
        "uint32" => DataType::UInt32,
        "uint64" => DataType::UInt64,
        "float16" | "half" => DataType::Float16,
        "float32" | "float" => DataType::Float32,
        "float64" | "double" => DataType::Float64,
        "string" | "utf8" | "text" => DataType::Utf8,
        "boolean" | "bool" => DataType::Boolean,
        "datetime" | "timestamp" => {
            DataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, None)
        }
        "date" => DataType::Date32,
        "binary" | "bytes" => DataType::Binary,
        _ => DataType::Utf8, // Default to string
    }
}

/// Build an Arrow Schema from a FrameSchema.
pub fn frame_schema_to_arrow(schema: &FrameSchema) -> Schema {
    let fields: Vec<Field> = schema
        .fields
        .iter()
        .map(|f| Field::new(&f.name, type_string_to_arrow(&f.data_type), f.nullable))
        .collect();
    Schema::new(fields)
}

/// Build a FrameSchema from an Arrow Schema.
pub fn arrow_schema_to_frame(schema: &Schema) -> FrameSchema {
    let fields: Vec<FieldDef> = schema
        .fields()
        .iter()
        .map(|f| FieldDef {
            name: f.name().clone(),
            data_type: arrow_type_to_string(f.data_type()),
            nullable: f.is_nullable(),
        })
        .collect();
    FrameSchema { fields }
}

/// Convert Arrow DataType to type string.
fn arrow_type_to_string(dt: &DataType) -> String {
    match dt {
        DataType::Int8 => "int8".into(),
        DataType::Int16 => "int16".into(),
        DataType::Int32 => "int32".into(),
        DataType::Int64 => "int64".into(),
        DataType::UInt8 => "uint8".into(),
        DataType::UInt16 => "uint16".into(),
        DataType::UInt32 => "uint32".into(),
        DataType::UInt64 => "uint64".into(),
        DataType::Float16 => "float16".into(),
        DataType::Float32 => "float32".into(),
        DataType::Float64 => "float64".into(),
        DataType::Utf8 => "string".into(),
        DataType::Boolean => "boolean".into(),
        DataType::Timestamp(_, _) => "datetime".into(),
        DataType::Date32 | DataType::Date64 => "date".into(),
        DataType::Binary => "binary".into(),
        _ => "string".into(),
    }
}

/// Write a RecordBatch to Parquet bytes in memory with LZ4 compression.
fn write_parquet_bytes(batch: &RecordBatch) -> Result<Vec<u8>> {
    let props = WriterProperties::builder()
        .set_compression(Compression::LZ4_RAW)
        .build();

    let mut buf: Vec<u8> = Vec::new();
    {
        let mut writer =
            ArrowWriter::try_new(&mut buf, batch.schema(), Some(props)).map_err(|e| {
                ApiaryError::Storage {
                    message: format!("Failed to create Parquet writer: {}", e),
                    source: None,
                }
            })?;
        writer.write(batch).map_err(|e| ApiaryError::Storage {
            message: format!("Failed to write Parquet data: {}", e),
            source: None,
        })?;
        writer.close().map_err(|e| ApiaryError::Storage {
            message: format!("Failed to close Parquet writer: {}", e),
            source: None,
        })?;
    }
    Ok(buf)
}

/// Compute column-level statistics from a RecordBatch.
fn compute_column_stats(
    batch: &RecordBatch,
    schema: &FrameSchema,
) -> Result<HashMap<String, ColumnStats>> {
    let mut stats = HashMap::new();

    for field in &schema.fields {
        if let Ok(col_idx) = batch.schema().index_of(&field.name) {
            let col = batch.column(col_idx);
            let null_count = col.null_count() as u64;

            let (min_val, max_val) = compute_min_max(col);

            stats.insert(
                field.name.clone(),
                ColumnStats {
                    min: min_val,
                    max: max_val,
                    null_count,
                    distinct_count: None,
                },
            );
        }
    }

    Ok(stats)
}

/// Compute min/max values for an array.
fn compute_min_max(array: &dyn Array) -> (Option<serde_json::Value>, Option<serde_json::Value>) {
    if array.is_empty() || array.null_count() == array.len() {
        return (None, None);
    }

    match array.data_type() {
        DataType::Int8 => numeric_min_max::<Int8Type>(array),
        DataType::Int16 => numeric_min_max::<Int16Type>(array),
        DataType::Int32 => numeric_min_max::<Int32Type>(array),
        DataType::Int64 => numeric_min_max::<Int64Type>(array),
        DataType::UInt8 => numeric_min_max::<UInt8Type>(array),
        DataType::UInt16 => numeric_min_max::<UInt16Type>(array),
        DataType::UInt32 => numeric_min_max::<UInt32Type>(array),
        DataType::UInt64 => uint64_min_max(array),
        DataType::Float32 => float_min_max::<Float32Type>(array),
        DataType::Float64 => float_min_max::<Float64Type>(array),
        DataType::Utf8 => string_min_max(array),
        DataType::Boolean => bool_min_max(array),
        _ => (None, None),
    }
}

fn numeric_min_max<T>(array: &dyn Array) -> (Option<serde_json::Value>, Option<serde_json::Value>)
where
    T: ArrowPrimitiveType,
    T::Native: Into<i64>,
{
    let arr = array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
    let values: Vec<T::Native> = arr.iter().flatten().collect();
    if values.is_empty() {
        return (None, None);
    }
    let min: i64 = values.iter().copied().map(Into::into).min().unwrap();
    let max: i64 = values.iter().copied().map(Into::into).max().unwrap();
    (
        Some(serde_json::Value::Number(min.into())),
        Some(serde_json::Value::Number(max.into())),
    )
}

fn uint64_min_max(array: &dyn Array) -> (Option<serde_json::Value>, Option<serde_json::Value>) {
    let arr = array
        .as_any()
        .downcast_ref::<PrimitiveArray<UInt64Type>>()
        .unwrap();
    let values: Vec<u64> = arr.iter().flatten().collect();
    if values.is_empty() {
        return (None, None);
    }
    let min = *values.iter().min().unwrap();
    let max = *values.iter().max().unwrap();
    // Stored as f64 in the JSON stats; values above 2^53 lose precision.
    (
        Some(serde_json::json!(min as f64)),
        Some(serde_json::json!(max as f64)),
    )
}

fn float_min_max<T>(array: &dyn Array) -> (Option<serde_json::Value>, Option<serde_json::Value>)
where
    T: ArrowPrimitiveType,
    T::Native: Into<f64>,
{
    let arr = array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
    let values: Vec<f64> = arr.iter().flatten().map(|v| v.into()).collect();
    if values.is_empty() {
        return (None, None);
    }
    let min = values
        .iter()
        .copied()
        .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
        .unwrap();
    let max = values
        .iter()
        .copied()
        .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
        .unwrap();
    (Some(serde_json::json!(min)), Some(serde_json::json!(max)))
}

fn string_min_max(array: &dyn Array) -> (Option<serde_json::Value>, Option<serde_json::Value>) {
    let arr = array.as_any().downcast_ref::<StringArray>().unwrap();
    let values: Vec<&str> = arr.iter().flatten().collect();
    if values.is_empty() {
        return (None, None);
    }
    let min = values.iter().min().unwrap();
    let max = values.iter().max().unwrap();
    (
        Some(serde_json::Value::String(min.to_string())),
        Some(serde_json::Value::String(max.to_string())),
    )
}

fn bool_min_max(array: &dyn Array) -> (Option<serde_json::Value>, Option<serde_json::Value>) {
    let arr = array.as_any().downcast_ref::<BooleanArray>().unwrap();
    let values: Vec<bool> = arr.iter().flatten().collect();
    if values.is_empty() {
        return (None, None);
    }
    // Min is false if any value is false, otherwise true
    let has_false = values.iter().any(|v| !v);
    // Max is true if any value is true, otherwise false
    let has_true = values.iter().any(|v| *v);
    (
        Some(serde_json::Value::Bool(!has_false)),
        Some(serde_json::Value::Bool(has_true)),
    )
}

/// Extract a string value from an array at a given index.
fn array_value_to_string(array: &ArrayRef, idx: usize) -> String {
    if array.is_null(idx) {
        return "null".into();
    }

    match array.data_type() {
        DataType::Utf8 => {
            let arr = array.as_any().downcast_ref::<StringArray>().unwrap();
            arr.value(idx).to_string()
        }
        DataType::Int8 => {
            let arr = array.as_any().downcast_ref::<Int8Array>().unwrap();
            arr.value(idx).to_string()
        }
        DataType::Int16 => {
            let arr = array.as_any().downcast_ref::<Int16Array>().unwrap();
            arr.value(idx).to_string()
        }
        DataType::Int32 => {
            let arr = array.as_any().downcast_ref::<Int32Array>().unwrap();
            arr.value(idx).to_string()
        }
        DataType::Int64 => {
            let arr = array.as_any().downcast_ref::<Int64Array>().unwrap();
            arr.value(idx).to_string()
        }
        DataType::Float32 => {
            let arr = array.as_any().downcast_ref::<Float32Array>().unwrap();
            arr.value(idx).to_string()
        }
        DataType::Float64 => {
            let arr = array.as_any().downcast_ref::<Float64Array>().unwrap();
            arr.value(idx).to_string()
        }
        DataType::Boolean => {
            let arr = array.as_any().downcast_ref::<BooleanArray>().unwrap();
            arr.value(idx).to_string()
        }
        // Coarse fallback for unsupported types: Debug-formats the whole array,
        // not the single value. Partition columns should use a type listed above.
        _ => format!("{:?}", array),
    }
}

/// Build a canonical string from partition values for grouping.
fn partition_key_string(values: &HashMap<String, String>, partition_by: &[String]) -> String {
    partition_by
        .iter()
        .map(|col| {
            let val = values.get(col).map(|v| v.as_str()).unwrap_or("");
            format!("{}={}", col, val)
        })
        .collect::<Vec<_>>()
        .join("/")
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_type_string_to_arrow() {
        assert_eq!(type_string_to_arrow("int64"), DataType::Int64);
        assert_eq!(type_string_to_arrow("float64"), DataType::Float64);
        assert_eq!(type_string_to_arrow("string"), DataType::Utf8);
        assert_eq!(type_string_to_arrow("boolean"), DataType::Boolean);
    }

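    // A minimal round-trip sketch: converting a FrameSchema to Arrow and back
    // should preserve field names, type strings, and nullability.
    #[test]
    fn test_frame_schema_arrow_round_trip() {
        let original = FrameSchema {
            fields: vec![
                FieldDef {
                    name: "id".into(),
                    data_type: "int64".into(),
                    nullable: false,
                },
                FieldDef {
                    name: "label".into(),
                    data_type: "string".into(),
                    nullable: true,
                },
            ],
        };

        let round_tripped = arrow_schema_to_frame(&frame_schema_to_arrow(&original));
        assert_eq!(round_tripped.fields.len(), 2);
        assert_eq!(round_tripped.fields[0].name, "id");
        assert_eq!(round_tripped.fields[0].data_type, "int64");
        assert!(!round_tripped.fields[0].nullable);
        assert_eq!(round_tripped.fields[1].data_type, "string");
        assert!(round_tripped.fields[1].nullable);
    }
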
    #[test]
    fn test_frame_schema_to_arrow() {
        let schema = FrameSchema {
            fields: vec![
                FieldDef {
                    name: "region".into(),
                    data_type: "string".into(),
                    nullable: false,
                },
                FieldDef {
                    name: "temp".into(),
                    data_type: "float64".into(),
                    nullable: true,
                },
            ],
        };
        let arrow_schema = frame_schema_to_arrow(&schema);
        assert_eq!(arrow_schema.fields().len(), 2);
        assert_eq!(arrow_schema.field(0).name(), "region");
        assert_eq!(*arrow_schema.field(0).data_type(), DataType::Utf8);
    }

    #[test]
    fn test_write_parquet_bytes() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, false),
            Field::new("value", DataType::Float64, true),
        ]));

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec!["a", "b", "c"])),
                Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0])),
            ],
        )
        .unwrap();

        let bytes = write_parquet_bytes(&batch).unwrap();
        assert!(!bytes.is_empty());
        // Parquet magic bytes
        assert_eq!(&bytes[0..4], b"PAR1");
    }

    #[test]
    fn test_compute_column_stats() {
        let schema = FrameSchema {
            fields: vec![
                FieldDef {
                    name: "name".into(),
                    data_type: "string".into(),
                    nullable: false,
                },
                FieldDef {
                    name: "value".into(),
                    data_type: "float64".into(),
                    nullable: true,
                },
            ],
        };

        let arrow_schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, false),
            Field::new("value", DataType::Float64, true),
        ]));

        let batch = RecordBatch::try_new(
            arrow_schema,
            vec![
                Arc::new(StringArray::from(vec!["alpha", "gamma", "beta"])),
                Arc::new(Float64Array::from(vec![10.5, 30.2, 20.1])),
            ],
        )
        .unwrap();

        let stats = compute_column_stats(&batch, &schema).unwrap();

        assert!(stats.contains_key("name"));
        assert!(stats.contains_key("value"));

        let name_stats = &stats["name"];
        assert_eq!(
            name_stats.min,
            Some(serde_json::Value::String("alpha".into()))
        );
        assert_eq!(
            name_stats.max,
            Some(serde_json::Value::String("gamma".into()))
        );

        let value_stats = &stats["value"];
        assert_eq!(value_stats.min, Some(serde_json::json!(10.5)));
        assert_eq!(value_stats.max, Some(serde_json::json!(30.2)));
    }

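    // A small sketch of null handling: null_count should reflect null entries
    // while min/max are computed over the non-null values only.
    #[test]
    fn test_compute_column_stats_with_nulls() {
        let schema = FrameSchema {
            fields: vec![FieldDef {
                name: "value".into(),
                data_type: "float64".into(),
                nullable: true,
            }],
        };

        let arrow_schema = Arc::new(Schema::new(vec![Field::new(
            "value",
            DataType::Float64,
            true,
        )]));

        let batch = RecordBatch::try_new(
            arrow_schema,
            vec![Arc::new(Float64Array::from(vec![
                Some(10.5),
                None,
                Some(20.1),
            ]))],
        )
        .unwrap();

        let stats = compute_column_stats(&batch, &schema).unwrap();
        let value_stats = &stats["value"];
        assert_eq!(value_stats.null_count, 1);
        assert_eq!(value_stats.min, Some(serde_json::json!(10.5)));
        assert_eq!(value_stats.max, Some(serde_json::json!(20.1)));
    }
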
    #[test]
    fn test_validate_schema_missing_non_nullable_column() {
        let frame_schema = FrameSchema {
            fields: vec![
                FieldDef {
                    name: "id".into(),
                    data_type: "int64".into(),
                    nullable: false,
                },
                FieldDef {
                    name: "name".into(),
                    data_type: "string".into(),
                    nullable: false,
                },
            ],
        };

        let sizing = CellSizingPolicy::from_memory_per_bee(1024 * 1024 * 1024);
        let storage: Arc<dyn StorageBackend> = Arc::new(
            tokio::runtime::Runtime::new()
                .unwrap()
                .block_on(crate::local::LocalBackend::new(
                    tempfile::TempDir::new().unwrap().keep(),
                ))
                .unwrap(),
        );

        let writer = CellWriter::new(
            storage,
            "test/default/test_frame".into(),
            frame_schema,
            vec![],
            sizing,
        );

        // Batch only has 'id' column, missing non-nullable 'name'
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
        let batch =
            RecordBatch::try_new(schema, vec![Arc::new(Int64Array::from(vec![1, 2, 3]))]).unwrap();

        let result = writer.validate_schema(&batch);
        assert!(result.is_err());
        let err = format!("{}", result.unwrap_err());
        assert!(
            err.contains("Missing non-nullable column 'name'"),
            "Error should mention missing non-nullable column, got: {err}"
        );
    }

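    // A sketch of the partition-value safety rule exercised by validate_schema:
    // values containing path separators or ".." are rejected before any storage
    // path is built. Storage setup mirrors the tests above.
    #[test]
    fn test_validate_schema_rejects_path_traversal_in_partition_values() {
        let frame_schema = FrameSchema {
            fields: vec![FieldDef {
                name: "region".into(),
                data_type: "string".into(),
                nullable: false,
            }],
        };

        let sizing = CellSizingPolicy::from_memory_per_bee(1024 * 1024 * 1024);
        let storage: Arc<dyn StorageBackend> = Arc::new(
            tokio::runtime::Runtime::new()
                .unwrap()
                .block_on(crate::local::LocalBackend::new(
                    tempfile::TempDir::new().unwrap().keep(),
                ))
                .unwrap(),
        );

        let writer = CellWriter::new(
            storage,
            "test/default/test_frame".into(),
            frame_schema,
            vec!["region".into()],
            sizing,
        );

        let schema = Arc::new(Schema::new(vec![Field::new(
            "region",
            DataType::Utf8,
            false,
        )]));
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(StringArray::from(vec!["us-east", "../escape"]))],
        )
        .unwrap();

        let result = writer.validate_schema(&batch);
        assert!(
            result.is_err(),
            "Partition value containing '..' should be rejected"
        );
    }
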
    #[test]
    fn test_validate_schema_missing_nullable_column_ok() {
        let frame_schema = FrameSchema {
            fields: vec![
                FieldDef {
                    name: "id".into(),
                    data_type: "int64".into(),
                    nullable: false,
                },
                FieldDef {
                    name: "notes".into(),
                    data_type: "string".into(),
                    nullable: true, // nullable; should not error if missing
                },
            ],
        };

        let sizing = CellSizingPolicy::from_memory_per_bee(1024 * 1024 * 1024);
        let storage: Arc<dyn StorageBackend> = Arc::new(
            tokio::runtime::Runtime::new()
                .unwrap()
                .block_on(crate::local::LocalBackend::new(
                    tempfile::TempDir::new().unwrap().keep(),
                ))
                .unwrap(),
        );

        let writer = CellWriter::new(
            storage,
            "test/default/test_frame".into(),
            frame_schema,
            vec![],
            sizing,
        );

        // Batch only has 'id' column, missing nullable 'notes'
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
        let batch =
            RecordBatch::try_new(schema, vec![Arc::new(Int64Array::from(vec![1, 2, 3]))]).unwrap();

        let result = writer.validate_schema(&batch);
        assert!(result.is_ok(), "Missing nullable column should not error");
    }
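
    // A minimal check of partition-key canonicalization: keys follow the order
    // of `partition_by`, not the HashMap's iteration order.
    #[test]
    fn test_partition_key_string_is_ordered() {
        use std::collections::HashMap;

        let mut values = HashMap::new();
        values.insert("region".to_string(), "us-east".to_string());
        values.insert("day".to_string(), "2024-01-01".to_string());

        let partition_by = vec!["region".to_string(), "day".to_string()];
        let key = partition_key_string(&values, &partition_by);
        assert_eq!(key, "region=us-east/day=2024-01-01");
    }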
}