// SPDX-License-Identifier: AGPL-3.0-or-later
//! `RecordBatch` builder that accumulates records and flushes Arrow batches
//!
//! Feeds decoded COBOL records into column accumulators and produces
//! Arrow `RecordBatch` objects when the configured batch size is reached.

use arrow::array::RecordBatch;
use arrow::datatypes::Schema as ArrowSchema;
use std::sync::Arc;

use copybook_core::schema::FieldKind;

use crate::builders::{ColumnAccumulator, create_accumulator};
use crate::decode_direct::decode_record_to_columns;
use crate::options::ArrowOptions;
use crate::{ArrowError, Result};

18/// Builds Arrow `RecordBatch` objects from a stream of raw COBOL records.
19pub struct RecordBatchBuilder {
20    arrow_schema: Arc<ArrowSchema>,
21    cobol_schema: copybook_core::Schema,
22    accumulators: Vec<Box<dyn ColumnAccumulator>>,
23    options: ArrowOptions,
24    current_count: usize,
25}
26
27impl RecordBatchBuilder {
28    /// Create a new builder.
29    ///
30    /// # Errors
31    ///
32    /// Returns an error if accumulator creation fails for any field.
33    #[inline]
34    pub fn new(
35        arrow_schema: Arc<ArrowSchema>,
36        cobol_schema: &copybook_core::Schema,
37        options: &ArrowOptions,
38    ) -> Result<Self> {
39        let accumulators = create_accumulators(cobol_schema, options)?;
40        Ok(Self {
41            arrow_schema,
42            cobol_schema: cobol_schema.clone(),
43            accumulators,
44            options: options.clone(),
45            current_count: 0,
46        })
47    }
48
49    /// Append a raw record. Returns `Some(batch)` when `batch_size` is reached.
50    ///
51    /// # Errors
52    ///
53    /// Returns an error if decoding or batch construction fails.
54    #[inline]
55    pub fn append_record(&mut self, record: &[u8]) -> Result<Option<RecordBatch>> {
56        decode_record_to_columns(
57            &self.cobol_schema,
58            record,
59            &mut self.accumulators,
60            &self.options,
61        )?;
62        self.current_count += 1;
63
64        if self.current_count >= self.options.batch_size {
65            let batch = self.build_batch()?;
66            self.reset_accumulators()?;
67            return Ok(Some(batch));
68        }
69
70        Ok(None)
71    }
72
73    /// Flush remaining records as a partial batch (may be empty).
74    ///
75    /// # Errors
76    ///
77    /// Returns an error if batch construction fails.
78    #[inline]
79    pub fn flush(&mut self) -> Result<Option<RecordBatch>> {
80        if self.current_count == 0 {
81            return Ok(None);
82        }
83        let batch = self.build_batch()?;
84        self.reset_accumulators()?;
85        Ok(Some(batch))
86    }
87
88    /// Build a `RecordBatch` from the current accumulators.
89    fn build_batch(&mut self) -> Result<RecordBatch> {
90        let columns: Vec<_> = self.accumulators.iter_mut().map(|a| a.finish()).collect();
91        RecordBatch::try_new(self.arrow_schema.clone(), columns)
92            .map_err(|e| ArrowError::ColumnBuild(format!("RecordBatch build failed: {e}")))
93    }
94
95    /// Reset accumulators for the next batch.
96    fn reset_accumulators(&mut self) -> Result<()> {
97        self.accumulators = create_accumulators(&self.cobol_schema, &self.options)?;
98        self.current_count = 0;
99        Ok(())
100    }
101}
102
103/// Create one accumulator per leaf field, matching the Arrow schema field order.
104fn create_accumulators(
105    schema: &copybook_core::Schema,
106    options: &ArrowOptions,
107) -> Result<Vec<Box<dyn ColumnAccumulator>>> {
108    let mut accumulators = Vec::new();
109    for field in &schema.fields {
110        collect_accumulators(field, options, &mut accumulators)?;
111    }
112    Ok(accumulators)
113}
114
115/// Recursively collect accumulators for leaf fields.
116fn collect_accumulators(
117    field: &copybook_core::schema::Field,
118    options: &ArrowOptions,
119    output: &mut Vec<Box<dyn ColumnAccumulator>>,
120) -> Result<()> {
121    // Skip non-storage fields
122    if matches!(
123        field.kind,
124        FieldKind::Condition { .. } | FieldKind::Renames { .. }
125    ) {
126        return Ok(());
127    }
128
129    // Skip FILLER unless configured
130    if (field.name.starts_with("_filler_") || field.name.eq_ignore_ascii_case("FILLER"))
131        && !options.emit_filler
132    {
133        return Ok(());
134    }
135
136    // Handle groups (flatten only for now)
137    if matches!(field.kind, FieldKind::Group) {
138        if options.flatten_groups {
139            for child in &field.children {
140                collect_accumulators(child, options, output)?;
141            }
142        }
143        return Ok(());
144    }
145
146    // Scalar field
147    let acc = create_accumulator(&field.kind, field.len, options)?;
148    output.push(acc);
149    Ok(())
150}
151
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::schema_convert::cobol_schema_to_arrow;
    use copybook_core::schema::{Field, FieldKind, Schema};

    /// Build a leaf field at the given offset/length for test schemas.
    fn leaf(name: &str, kind: FieldKind, offset: u32, len: u32) -> Field {
        let mut field = Field::with_kind(5, name.to_string(), kind);
        field.path = name.to_string();
        field.offset = offset;
        field.len = len;
        field
    }

    #[test]
    fn test_batch_builder_flush() {
        let kind = FieldKind::PackedDecimal {
            digits: 5,
            scale: 2,
            signed: true,
        };
        let schema = Schema::from_fields(vec![leaf("AMOUNT", kind, 0, 3)]);
        let opts = ArrowOptions::default();
        let arrow_schema = cobol_schema_to_arrow(&schema, &opts).unwrap();
        let mut builder =
            RecordBatchBuilder::new(Arc::new(arrow_schema), &schema, &opts).unwrap();

        // Packed decimal: 12345 with sign C (positive) -> 0x12 0x34 0x5C
        // One record is below the default batch size, so nothing is emitted yet.
        assert!(builder.append_record(&[0x12, 0x34, 0x5C]).unwrap().is_none());

        let flushed = builder.flush().unwrap().expect("partial batch expected");
        assert_eq!(flushed.num_rows(), 1);
        assert_eq!(flushed.num_columns(), 1);
    }

    #[test]
    fn test_batch_builder_auto_flush() {
        let schema =
            Schema::from_fields(vec![leaf("NAME", FieldKind::Alphanum { len: 4 }, 0, 4)]);
        let opts = ArrowOptions {
            batch_size: 2, // small batch for testing
            codepage: copybook_codec::Codepage::ASCII,
            ..ArrowOptions::default()
        };

        let arrow_schema = cobol_schema_to_arrow(&schema, &opts).unwrap();
        let mut builder =
            RecordBatchBuilder::new(Arc::new(arrow_schema), &schema, &opts).unwrap();

        // First record buffers; second hits batch_size and auto-flushes.
        assert!(builder.append_record(b"ABCD").unwrap().is_none());
        let full = builder.append_record(b"EFGH").unwrap();
        assert_eq!(full.expect("batch at size limit").num_rows(), 2);
    }
}