use std::collections::HashMap;
use std::sync::Arc;

use arrow::array::*;
use arrow::compute;
use arrow::datatypes::{
    ArrowPrimitiveType, DataType, Field, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type,
    Int8Type, Schema, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
};
use arrow::record_batch::RecordBatch;
use bytes::Bytes;
use parquet::arrow::ArrowWriter;
use parquet::basic::Compression;
use parquet::file::properties::WriterProperties;
use tracing::{info, warn};
use uuid::Uuid;

use apiary_core::{
    ApiaryError, CellId, CellMetadata, CellSizingPolicy, ColumnStats, FieldDef, FrameSchema,
    Result, StorageBackend,
};

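/// Writes `RecordBatch` data into Parquet "cells" under a frame's storage
/// path, applying hive-style partitioning (`col=value/` directories) and
/// splitting oversized batches according to the sizing policy.
///
/// # Example
///
/// A minimal sketch (not compiled as a doc-test); the backend and sizing
/// setup mirror the tests at the bottom of this module, and `batch` is
/// assumed to match the frame schema.
///
/// ```ignore
/// let writer = CellWriter::new(
///     storage,                          // Arc<dyn StorageBackend>
///     "tenant/default/my_frame".into(), // storage prefix for the frame
///     frame_schema,                     // FrameSchema describing the columns
///     vec!["region".into()],            // partition columns
///     sizing,                           // CellSizingPolicy
/// );
/// let cells = writer.write(&batch).await?;
/// ```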
pub struct CellWriter {
    storage: Arc<dyn StorageBackend>,
    frame_path: String,
    schema: FrameSchema,
    partition_by: Vec<String>,
    sizing: CellSizingPolicy,
}

impl CellWriter {
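    /// Creates a writer for the frame at `frame_path` on the given storage
    /// backend. `partition_by` lists the columns used for hive-style
    /// partitioning; `sizing` controls when batches are split into cells.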
    pub fn new(
        storage: Arc<dyn StorageBackend>,
        frame_path: String,
        schema: FrameSchema,
        partition_by: Vec<String>,
        sizing: CellSizingPolicy,
    ) -> Self {
        Self {
            storage,
            frame_path,
            schema,
            partition_by,
            sizing,
        }
    }

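    /// Validates `batch` against the frame schema, groups rows by partition
    /// values, splits each partition into appropriately sized chunks, and
    /// writes one Parquet cell per chunk. Returns metadata for every cell
    /// written.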
    pub async fn write(&self, batch: &RecordBatch) -> Result<Vec<CellMetadata>> {
        self.validate_schema(batch)?;

        let partitions = self.partition_data(batch)?;

        let mut all_cells = Vec::new();

        for (partition_values, partition_batch) in &partitions {
            let sub_batches = self.split_by_size(partition_batch)?;

            for sub_batch in &sub_batches {
                let cell = self.write_cell(sub_batch, partition_values).await?;
                all_cells.push(cell);
            }
        }

        Ok(all_cells)
    }

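    /// Checks that every non-nullable schema column is present in `batch`,
    /// warns about extra columns, and rejects partition columns containing
    /// nulls or path-unsafe characters, which would otherwise corrupt the
    /// hive-style directory layout.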
    fn validate_schema(&self, batch: &RecordBatch) -> Result<()> {
        // Every non-nullable schema column must be present in the batch.
        for field in &self.schema.fields {
            if batch.schema().index_of(&field.name).is_err() && !field.nullable {
                return Err(ApiaryError::Schema {
                    message: format!("Missing non-nullable column '{}' in write data", field.name),
                });
            }
        }

        // Columns not in the schema are dropped with a warning.
        for batch_field in batch.schema().fields() {
            let in_schema = self
                .schema
                .fields
                .iter()
                .any(|f| f.name == *batch_field.name());
            if !in_schema {
                warn!(
                    column = %batch_field.name(),
                    "Extra column in write data will be dropped"
                );
            }
        }

        // Partition values become path components, so they must be non-null
        // and free of path separators and traversal sequences.
        for part_col in &self.partition_by {
            if let Ok(col_idx) = batch.schema().index_of(part_col) {
                let col = batch.column(col_idx);
                if col.null_count() > 0 {
                    return Err(ApiaryError::Schema {
                        message: format!("Partition column '{}' contains null values", part_col),
                    });
                }
                for row_idx in 0..batch.num_rows() {
                    let val = array_value_to_string(col, row_idx);
                    if val.contains("..")
                        || val.contains('/')
                        || val.contains('\\')
                        || val.contains('\0')
                    {
                        return Err(ApiaryError::Schema {
                            message: format!(
                                "Partition column '{}' contains invalid characters (path separators or '..'): '{}'",
                                part_col, val
                            ),
                        });
                    }
                }
            }
        }
        Ok(())
    }

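    /// Groups the batch's rows by their values in the partition columns,
    /// returning one `RecordBatch` per distinct combination along with the
    /// column-to-value map identifying the partition.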
    fn partition_data(
        &self,
        batch: &RecordBatch,
    ) -> Result<Vec<(HashMap<String, String>, RecordBatch)>> {
        if self.partition_by.is_empty() || batch.num_rows() == 0 {
            return Ok(vec![(HashMap::new(), batch.clone())]);
        }

        // Render each row's partition values as strings.
        let mut partition_keys: Vec<HashMap<String, String>> = Vec::new();
        for row_idx in 0..batch.num_rows() {
            let mut key = HashMap::new();
            for col_name in &self.partition_by {
                let col_idx =
                    batch
                        .schema()
                        .index_of(col_name)
                        .map_err(|_| ApiaryError::Schema {
                            message: format!("Partition column '{}' not found in data", col_name),
                        })?;
                let col = batch.column(col_idx);
                let val = array_value_to_string(col, row_idx);
                key.insert(col_name.clone(), val);
            }
            partition_keys.push(key);
        }

        // Group row indices by their canonical partition key string.
        let mut groups: HashMap<String, (HashMap<String, String>, Vec<usize>)> = HashMap::new();
        for (row_idx, key) in partition_keys.iter().enumerate() {
            let key_str = partition_key_string(key, &self.partition_by);
            groups
                .entry(key_str)
                .or_insert_with(|| (key.clone(), Vec::new()))
                .1
                .push(row_idx);
        }

        // Materialize one sub-batch per group with the take kernel.
        let mut result = Vec::new();
        for (_, (partition_values, row_indices)) in groups {
            let indices =
                UInt32Array::from(row_indices.iter().map(|i| *i as u32).collect::<Vec<_>>());
            let columns: Vec<ArrayRef> = batch
                .columns()
                .iter()
                .map(|col| compute::take(col, &indices, None))
                .collect::<std::result::Result<_, _>>()
                .map_err(|e| ApiaryError::Internal {
                    message: format!("Failed to take partition rows: {}", e),
                })?;
            let sub_batch = RecordBatch::try_new(batch.schema(), columns).map_err(|e| {
                ApiaryError::Internal {
                    message: format!("Failed to create partition batch: {}", e),
                }
            })?;
            result.push((partition_values, sub_batch));
        }

        Ok(result)
    }

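    /// Splits a batch into row slices so each lands near the target cell
    /// size. The estimate uses in-memory Arrow buffer sizes, so actual
    /// Parquet files will typically be smaller after encoding and compression.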
    fn split_by_size(&self, batch: &RecordBatch) -> Result<Vec<RecordBatch>> {
        let estimated_size: usize = batch
            .columns()
            .iter()
            .map(|col| col.get_buffer_memory_size())
            .sum();

        let target = self.sizing.target_cell_size as usize;

        if estimated_size <= target || batch.num_rows() <= 1 {
            return Ok(vec![batch.clone()]);
        }

        // Spread rows evenly over the number of chunks needed to hit the target.
        let num_chunks = estimated_size.div_ceil(target);
        let rows_per_chunk = batch.num_rows().div_ceil(num_chunks);
        let rows_per_chunk = rows_per_chunk.max(1);

        let mut batches = Vec::new();
        let mut start = 0;

        // `RecordBatch::slice` is zero-copy, so splitting is cheap.
        while start < batch.num_rows() {
            let end = (start + rows_per_chunk).min(batch.num_rows());
            let sub_batch = batch.slice(start, end - start);
            batches.push(sub_batch);
            start = end;
        }

        Ok(batches)
    }

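    /// Serializes one batch to Parquet, computes per-column statistics, and
    /// uploads the file under `frame_path/<partition dirs>/<cell id>.parquet`,
    /// returning the resulting cell metadata.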
    async fn write_cell(
        &self,
        batch: &RecordBatch,
        partition_values: &HashMap<String, String>,
    ) -> Result<CellMetadata> {
        let cell_id = CellId::new(format!("cell_{}", Uuid::new_v4()));
        let rows = batch.num_rows() as u64;

        // Build the hive-style prefix, e.g. "region=us/day=2024-01-01/".
        let partition_path = if partition_values.is_empty() {
            String::new()
        } else {
            let parts: Vec<String> = self
                .partition_by
                .iter()
                .filter_map(|col| {
                    partition_values
                        .get(col)
                        .map(|val| format!("{}={}", col, val))
                })
                .collect();
            parts.join("/") + "/"
        };

        let cell_filename = format!("{}.parquet", cell_id.as_str());
        let relative_path = format!("{}{}", partition_path, cell_filename);
        let storage_key = format!("{}/{}", self.frame_path, relative_path);

        let stats = compute_column_stats(batch, &self.schema)?;

        let parquet_bytes = write_parquet_bytes(batch)?;
        let byte_size = parquet_bytes.len() as u64;

        self.storage
            .put(&storage_key, Bytes::from(parquet_bytes))
            .await?;

        info!(
            cell_id = %cell_id,
            rows,
            bytes = byte_size,
            path = %relative_path,
            "Wrote cell to storage"
        );

        Ok(CellMetadata {
            id: cell_id,
            path: relative_path,
            format: "parquet".into(),
            partition_values: partition_values.clone(),
            rows,
            bytes: byte_size,
            stats,
        })
    }
}

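/// Maps a frame schema type name (case-insensitive) to an Arrow `DataType`.
/// Unrecognized names fall back to `Utf8`.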
pub fn type_string_to_arrow(type_str: &str) -> DataType {
    match type_str.to_lowercase().as_str() {
        "int8" => DataType::Int8,
        "int16" => DataType::Int16,
        "int32" => DataType::Int32,
        "int64" | "int" | "integer" => DataType::Int64,
        "uint8" => DataType::UInt8,
        "uint16" => DataType::UInt16,
        "uint32" => DataType::UInt32,
        "uint64" => DataType::UInt64,
        "float16" | "half" => DataType::Float16,
        "float32" | "float" => DataType::Float32,
        "float64" | "double" => DataType::Float64,
        "string" | "utf8" | "text" => DataType::Utf8,
        "boolean" | "bool" => DataType::Boolean,
        "datetime" | "timestamp" => {
            DataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, None)
        }
        "date" => DataType::Date32,
        "binary" | "bytes" => DataType::Binary,
        _ => DataType::Utf8,
    }
}

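/// Converts a `FrameSchema` into the equivalent Arrow `Schema`.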
pub fn frame_schema_to_arrow(schema: &FrameSchema) -> Schema {
    let fields: Vec<Field> = schema
        .fields
        .iter()
        .map(|f| Field::new(&f.name, type_string_to_arrow(&f.data_type), f.nullable))
        .collect();
    Schema::new(fields)
}

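/// Converts an Arrow `Schema` back into a `FrameSchema`, using the string
/// type names understood by `type_string_to_arrow`.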
pub fn arrow_schema_to_frame(schema: &Schema) -> FrameSchema {
    let fields: Vec<FieldDef> = schema
        .fields()
        .iter()
        .map(|f| FieldDef {
            name: f.name().clone(),
            data_type: arrow_type_to_string(f.data_type()),
            nullable: f.is_nullable(),
        })
        .collect();
    FrameSchema { fields }
}

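/// Inverse of `type_string_to_arrow` for the supported types; anything else
/// is reported as "string".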
fn arrow_type_to_string(dt: &DataType) -> String {
    match dt {
        DataType::Int8 => "int8".into(),
        DataType::Int16 => "int16".into(),
        DataType::Int32 => "int32".into(),
        DataType::Int64 => "int64".into(),
        DataType::UInt8 => "uint8".into(),
        DataType::UInt16 => "uint16".into(),
        DataType::UInt32 => "uint32".into(),
        DataType::UInt64 => "uint64".into(),
        DataType::Float16 => "float16".into(),
        DataType::Float32 => "float32".into(),
        DataType::Float64 => "float64".into(),
        DataType::Utf8 => "string".into(),
        DataType::Boolean => "boolean".into(),
        DataType::Timestamp(_, _) => "datetime".into(),
        DataType::Date32 | DataType::Date64 => "date".into(),
        DataType::Binary => "binary".into(),
        _ => "string".into(),
    }
}

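/// Serializes a `RecordBatch` to an in-memory Parquet file (LZ4_RAW
/// compression) and returns the encoded bytes.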
fn write_parquet_bytes(batch: &RecordBatch) -> Result<Vec<u8>> {
    let props = WriterProperties::builder()
        .set_compression(Compression::LZ4_RAW)
        .build();

    let mut buf: Vec<u8> = Vec::new();
    {
        // Scope the writer so its borrow of `buf` ends before we return it.
        let mut writer =
            ArrowWriter::try_new(&mut buf, batch.schema(), Some(props)).map_err(|e| {
                ApiaryError::Storage {
                    message: format!("Failed to create Parquet writer: {}", e),
                    source: None,
                }
            })?;
        writer.write(batch).map_err(|e| ApiaryError::Storage {
            message: format!("Failed to write Parquet data: {}", e),
            source: None,
        })?;
        writer.close().map_err(|e| ApiaryError::Storage {
            message: format!("Failed to close Parquet writer: {}", e),
            source: None,
        })?;
    }
    Ok(buf)
}

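/// Computes per-column min/max and null counts for every schema column
/// present in the batch. Distinct counts are not tracked.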
fn compute_column_stats(
    batch: &RecordBatch,
    schema: &FrameSchema,
) -> Result<HashMap<String, ColumnStats>> {
    let mut stats = HashMap::new();

    for field in &schema.fields {
        if let Ok(col_idx) = batch.schema().index_of(&field.name) {
            let col = batch.column(col_idx);
            let null_count = col.null_count() as u64;

            let (min_val, max_val) = compute_min_max(col);

            stats.insert(
                field.name.clone(),
                ColumnStats {
                    min: min_val,
                    max: max_val,
                    null_count,
                    distinct_count: None,
                },
            );
        }
    }

    Ok(stats)
}

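/// Dispatches min/max computation by data type. Types not handled here
/// (timestamps, dates, binary, nested types) yield `(None, None)`.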
fn compute_min_max(array: &dyn Array) -> (Option<serde_json::Value>, Option<serde_json::Value>) {
    if array.is_empty() || array.null_count() == array.len() {
        return (None, None);
    }

    match array.data_type() {
        DataType::Int8 => numeric_min_max::<Int8Type>(array),
        DataType::Int16 => numeric_min_max::<Int16Type>(array),
        DataType::Int32 => numeric_min_max::<Int32Type>(array),
        DataType::Int64 => numeric_min_max::<Int64Type>(array),
        DataType::UInt8 => numeric_min_max::<UInt8Type>(array),
        DataType::UInt16 => numeric_min_max::<UInt16Type>(array),
        DataType::UInt32 => numeric_min_max::<UInt32Type>(array),
        // u64 does not fit in i64, so it gets its own path.
        DataType::UInt64 => uint64_min_max(array),
        DataType::Float32 => float_min_max::<Float32Type>(array),
        DataType::Float64 => float_min_max::<Float64Type>(array),
        DataType::Utf8 => string_min_max(array),
        DataType::Boolean => bool_min_max(array),
        _ => (None, None),
    }
}

fn numeric_min_max<T>(array: &dyn Array) -> (Option<serde_json::Value>, Option<serde_json::Value>)
where
    T: ArrowPrimitiveType,
    T::Native: Into<i64>,
{
    let arr = array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
    let values: Vec<T::Native> = arr.iter().flatten().collect();
    if values.is_empty() {
        return (None, None);
    }
    let min: i64 = values.iter().copied().map(Into::into).min().unwrap();
    let max: i64 = values.iter().copied().map(Into::into).max().unwrap();
    (
        Some(serde_json::Value::Number(min.into())),
        Some(serde_json::Value::Number(max.into())),
    )
}

fn uint64_min_max(array: &dyn Array) -> (Option<serde_json::Value>, Option<serde_json::Value>) {
    let arr = array
        .as_any()
        .downcast_ref::<PrimitiveArray<UInt64Type>>()
        .unwrap();
    let values: Vec<u64> = arr.iter().flatten().collect();
    if values.is_empty() {
        return (None, None);
    }
    let min = *values.iter().min().unwrap();
    let max = *values.iter().max().unwrap();
    // serde_json represents u64 natively; converting through f64 would lose
    // precision for values above 2^53.
    (
        Some(serde_json::Value::Number(min.into())),
        Some(serde_json::Value::Number(max.into())),
    )
}

fn float_min_max<T>(array: &dyn Array) -> (Option<serde_json::Value>, Option<serde_json::Value>)
where
    T: ArrowPrimitiveType,
    T::Native: Into<f64>,
{
    let arr = array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
    let values: Vec<f64> = arr.iter().flatten().map(|v| v.into()).collect();
    if values.is_empty() {
        return (None, None);
    }
    // `partial_cmp` returns `None` for NaN; treating that as `Equal` keeps
    // the fold total, though NaN-heavy columns may yield arbitrary bounds.
    let min = values
        .iter()
        .copied()
        .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
        .unwrap();
    let max = values
        .iter()
        .copied()
        .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
        .unwrap();
    (Some(serde_json::json!(min)), Some(serde_json::json!(max)))
}

fn string_min_max(array: &dyn Array) -> (Option<serde_json::Value>, Option<serde_json::Value>) {
    let arr = array.as_any().downcast_ref::<StringArray>().unwrap();
    let values: Vec<&str> = arr.iter().flatten().collect();
    if values.is_empty() {
        return (None, None);
    }
    let min = values.iter().min().unwrap();
    let max = values.iter().max().unwrap();
    (
        Some(serde_json::Value::String(min.to_string())),
        Some(serde_json::Value::String(max.to_string())),
    )
}

fn bool_min_max(array: &dyn Array) -> (Option<serde_json::Value>, Option<serde_json::Value>) {
    let arr = array.as_any().downcast_ref::<BooleanArray>().unwrap();
    let values: Vec<bool> = arr.iter().flatten().collect();
    if values.is_empty() {
        return (None, None);
    }
    // min is false if any false is present; max is true if any true is present.
    let has_false = values.iter().any(|v| !v);
    let has_true = values.iter().any(|v| *v);
    (
        Some(serde_json::Value::Bool(!has_false)),
        Some(serde_json::Value::Bool(has_true)),
    )
}

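/// Renders a single array element as a string (used for partition values).
/// Null values render as "null".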
fn array_value_to_string(array: &ArrayRef, idx: usize) -> String {
    if array.is_null(idx) {
        return "null".into();
    }

    match array.data_type() {
        DataType::Utf8 => {
            let arr = array.as_any().downcast_ref::<StringArray>().unwrap();
            arr.value(idx).to_string()
        }
        DataType::Int8 => {
            let arr = array.as_any().downcast_ref::<Int8Array>().unwrap();
            arr.value(idx).to_string()
        }
        DataType::Int16 => {
            let arr = array.as_any().downcast_ref::<Int16Array>().unwrap();
            arr.value(idx).to_string()
        }
        DataType::Int32 => {
            let arr = array.as_any().downcast_ref::<Int32Array>().unwrap();
            arr.value(idx).to_string()
        }
        DataType::Int64 => {
            let arr = array.as_any().downcast_ref::<Int64Array>().unwrap();
            arr.value(idx).to_string()
        }
        DataType::Float32 => {
            let arr = array.as_any().downcast_ref::<Float32Array>().unwrap();
            arr.value(idx).to_string()
        }
        DataType::Float64 => {
            let arr = array.as_any().downcast_ref::<Float64Array>().unwrap();
            arr.value(idx).to_string()
        }
        DataType::Boolean => {
            let arr = array.as_any().downcast_ref::<BooleanArray>().unwrap();
            arr.value(idx).to_string()
        }
        // Fall back to Arrow's generic display for the single value rather
        // than debug-printing the entire array.
        _ => arrow::util::display::array_value_to_string(array.as_ref(), idx)
            .unwrap_or_else(|_| format!("{:?}", array)),
    }
}

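/// Builds the canonical `col=value/...` key used to group rows into
/// partitions, with columns in `partition_by` order.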
fn partition_key_string(values: &HashMap<String, String>, partition_by: &[String]) -> String {
    partition_by
        .iter()
        .map(|col| {
            let val = values.get(col).map(|v| v.as_str()).unwrap_or("");
            format!("{}={}", col, val)
        })
        .collect::<Vec<_>>()
        .join("/")
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_type_string_to_arrow() {
        assert_eq!(type_string_to_arrow("int64"), DataType::Int64);
        assert_eq!(type_string_to_arrow("float64"), DataType::Float64);
        assert_eq!(type_string_to_arrow("string"), DataType::Utf8);
        assert_eq!(type_string_to_arrow("boolean"), DataType::Boolean);
    }

    #[test]
    fn test_frame_schema_to_arrow() {
        let schema = FrameSchema {
            fields: vec![
                FieldDef {
                    name: "region".into(),
                    data_type: "string".into(),
                    nullable: false,
                },
                FieldDef {
                    name: "temp".into(),
                    data_type: "float64".into(),
                    nullable: true,
                },
            ],
        };
        let arrow_schema = frame_schema_to_arrow(&schema);
        assert_eq!(arrow_schema.fields().len(), 2);
        assert_eq!(arrow_schema.field(0).name(), "region");
        assert_eq!(*arrow_schema.field(0).data_type(), DataType::Utf8);
    }

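    #[test]
    fn test_schema_roundtrip() {
        // Round-trip check using only this module's converters:
        // FrameSchema -> Arrow Schema -> FrameSchema should preserve names,
        // type strings, and nullability for supported types.
        let schema = FrameSchema {
            fields: vec![FieldDef {
                name: "id".into(),
                data_type: "int64".into(),
                nullable: false,
            }],
        };
        let roundtripped = arrow_schema_to_frame(&frame_schema_to_arrow(&schema));
        assert_eq!(roundtripped.fields.len(), 1);
        assert_eq!(roundtripped.fields[0].name, "id");
        assert_eq!(roundtripped.fields[0].data_type, "int64");
        assert!(!roundtripped.fields[0].nullable);
    }
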
    #[test]
    fn test_write_parquet_bytes() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, false),
            Field::new("value", DataType::Float64, true),
        ]));

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec!["a", "b", "c"])),
                Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0])),
            ],
        )
        .unwrap();

        let bytes = write_parquet_bytes(&batch).unwrap();
        assert!(!bytes.is_empty());
        // Parquet files begin with the magic bytes "PAR1".
        assert_eq!(&bytes[0..4], b"PAR1");
    }

    #[test]
    fn test_compute_column_stats() {
        let schema = FrameSchema {
            fields: vec![
                FieldDef {
                    name: "name".into(),
                    data_type: "string".into(),
                    nullable: false,
                },
                FieldDef {
                    name: "value".into(),
                    data_type: "float64".into(),
                    nullable: true,
                },
            ],
        };

        let arrow_schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, false),
            Field::new("value", DataType::Float64, true),
        ]));

        let batch = RecordBatch::try_new(
            arrow_schema,
            vec![
                Arc::new(StringArray::from(vec!["alpha", "gamma", "beta"])),
                Arc::new(Float64Array::from(vec![10.5, 30.2, 20.1])),
            ],
        )
        .unwrap();

        let stats = compute_column_stats(&batch, &schema).unwrap();

        assert!(stats.contains_key("name"));
        assert!(stats.contains_key("value"));

        let name_stats = &stats["name"];
        assert_eq!(
            name_stats.min,
            Some(serde_json::Value::String("alpha".into()))
        );
        assert_eq!(
            name_stats.max,
            Some(serde_json::Value::String("gamma".into()))
        );

        let value_stats = &stats["value"];
        assert_eq!(value_stats.min, Some(serde_json::json!(10.5)));
        assert_eq!(value_stats.max, Some(serde_json::json!(30.2)));
    }

    #[test]
    fn test_validate_schema_missing_non_nullable_column() {
        let frame_schema = FrameSchema {
            fields: vec![
                FieldDef {
                    name: "id".into(),
                    data_type: "int64".into(),
                    nullable: false,
                },
                FieldDef {
                    name: "name".into(),
                    data_type: "string".into(),
                    nullable: false,
                },
            ],
        };

        let sizing = CellSizingPolicy::from_memory_per_bee(1024 * 1024 * 1024);
        let storage: Arc<dyn StorageBackend> = Arc::new(
            tokio::runtime::Runtime::new()
                .unwrap()
                .block_on(crate::local::LocalBackend::new(
                    tempfile::TempDir::new().unwrap().keep(),
                ))
                .unwrap(),
        );

        let writer = CellWriter::new(
            storage,
            "test/default/test_frame".into(),
            frame_schema,
            vec![],
            sizing,
        );

        // The batch carries only "id"; the non-nullable "name" is missing.
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
        let batch =
            RecordBatch::try_new(schema, vec![Arc::new(Int64Array::from(vec![1, 2, 3]))]).unwrap();

        let result = writer.validate_schema(&batch);
        assert!(result.is_err());
        let err = format!("{}", result.unwrap_err());
        assert!(
            err.contains("Missing non-nullable column 'name'"),
            "Error should mention missing non-nullable column, got: {err}"
        );
    }

    #[test]
    fn test_validate_schema_missing_nullable_column_ok() {
        let frame_schema = FrameSchema {
            fields: vec![
                FieldDef {
                    name: "id".into(),
                    data_type: "int64".into(),
                    nullable: false,
                },
                FieldDef {
                    name: "notes".into(),
                    data_type: "string".into(),
                    nullable: true,
                },
            ],
        };

        let sizing = CellSizingPolicy::from_memory_per_bee(1024 * 1024 * 1024);
        let storage: Arc<dyn StorageBackend> = Arc::new(
            tokio::runtime::Runtime::new()
                .unwrap()
                .block_on(crate::local::LocalBackend::new(
                    tempfile::TempDir::new().unwrap().keep(),
                ))
                .unwrap(),
        );

        let writer = CellWriter::new(
            storage,
            "test/default/test_frame".into(),
            frame_schema,
            vec![],
            sizing,
        );

        // Only "id" is present; the missing "notes" column is nullable.
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
        let batch =
            RecordBatch::try_new(schema, vec![Arc::new(Int64Array::from(vec![1, 2, 3]))]).unwrap();

        let result = writer.validate_schema(&batch);
        assert!(result.is_ok(), "Missing nullable column should not error");
    }
}
790}