Skip to main content

recutils_rs/
arrow.rs

//! Convert rec records into Apache Arrow `RecordBatch`es, and serialize
//! `RecordBatch`es back into rec text.
//!
//! Gated behind the `arrow` cargo feature. Honors `%type:` declarations
//! from the rset descriptor; untyped fields fall back to `Utf8`.
5
6use std::collections::{HashMap, HashSet};
7use std::sync::Arc;
8
9use arrow::array::{
10    Array, ArrayRef, BooleanArray, BooleanBuilder, Float64Array, Float64Builder, Int64Array,
11    Int64Builder, StringArray, StringBuilder,
12};
13use arrow::datatypes::{DataType, Field, Schema};
14use arrow::record_batch::RecordBatch;
15
16use crate::rset::Rset;
17use crate::{Db, OwnedRset, Record, SelectionExpression};
18
19pub fn rec_to_record_batch(
20    db: &mut Db,
21    record_type: &str,
22) -> Result<(Arc<Schema>, RecordBatch), Box<dyn std::error::Error>> {
23    let rset = db
24        .rset_by_type(record_type)
25        .ok_or_else(|| format!("no record set of type {record_type:?}"))?;
26    rec_to_record_batch_from_rset(&rset)
27}
28
29/// Build the `(schema, batch)` for an arbitrary [`Rset`], including
30/// anonymous record sets that have no `%rec:` descriptor (so they can't be
31/// looked up by [`Db::rset_by_type`]).
32pub fn rec_to_record_batch_from_rset(
33    rset: &Rset<'_>,
34) -> Result<(Arc<Schema>, RecordBatch), Box<dyn std::error::Error>> {
35    let mut declared_types: HashMap<String, String> = HashMap::new();
36    if let Some(desc) = rset.descriptor() {
37        for f in desc.fields() {
38            if f.name() == "%type" {
39                if let Some((field, ty)) = split_type_decl(&f.value()) {
40                    declared_types.insert(field, ty);
41                }
42            }
43        }
44    }
45
46    let (column_order, rows) = collect_rows_from_rset(rset)?;
47    let schema = build_schema(&column_order, &declared_types);
48    let columns = build_columns(&schema, &rows);
49    let batch = RecordBatch::try_new(Arc::clone(&schema), columns)?;
50    Ok((schema, batch))
51}
52
53/// Build a [`RecordBatch`] for the records of `record_type` that match the
54/// given selection expression, using the caller-provided `schema` (so the
55/// column set stays stable even when the filter excludes every record that
56/// has a particular field).
57pub fn rec_to_filtered_batch(
58    db: &mut Db,
59    record_type: &str,
60    schema: &Arc<Schema>,
61    selection_expression: &SelectionExpression,
62) -> Result<RecordBatch, Box<dyn std::error::Error>> {
63    let rset = db
64        .rset_by_type(record_type)
65        .ok_or_else(|| format!("no record set of type {record_type:?}"))?;
66    rec_to_filtered_batch_from_rset(&rset, schema, selection_expression)
67}
68
69/// Same as [`rec_to_filtered_batch`] but for an arbitrary [`Rset`].
70pub fn rec_to_filtered_batch_from_rset(
71    rset: &Rset<'_>,
72    schema: &Arc<Schema>,
73    selection_expression: &SelectionExpression,
74) -> Result<RecordBatch, Box<dyn std::error::Error>> {
75    let mut rows: Vec<HashMap<String, String>> = Vec::new();
76    for (i, record) in rset.records().enumerate() {
77        if !selection_expression.matches(&record) {
78            continue;
79        }
80        let mut row: HashMap<String, String> = HashMap::new();
81        for f in record.fields() {
82            let name = f.name();
83            if name.starts_with('%') {
84                continue;
85            }
86            if row.contains_key(&name) {
87                return Err(format!(
88                    "field {:?} repeated in record {} (1-based); use a List<T> mapping (not yet supported) or remove the repeat",
89                    name,
90                    i + 1
91                )
92                .into());
93            }
94            row.insert(name.clone(), f.value());
95        }
96        rows.push(row);
97    }
98    let columns = build_columns(schema, &rows);
99    Ok(RecordBatch::try_new(Arc::clone(schema), columns)?)
100}
101
/// Split the value of a `%type:` line into `(field, type description)`.
///
/// The value is trimmed, then cut at its first whitespace character: the
/// part before it is the field token, the (trimmed) remainder is the type
/// description. Returns `None` when the trimmed value contains no
/// whitespace, i.e. when there is no type description at all.
pub fn split_type_decl(value: &str) -> Option<(String, String)> {
    let mut pieces = value.trim().splitn(2, char::is_whitespace);
    let name = pieces.next()?;
    let rest = pieces.next()?;
    Some((name.trim().to_owned(), rest.trim().to_owned()))
}
107
108pub fn collect_rows(
109    db: &mut Db,
110    record_type: &str,
111) -> Result<(Vec<String>, Vec<HashMap<String, String>>), Box<dyn std::error::Error>> {
112    let rset = db
113        .rset_by_type(record_type)
114        .ok_or_else(|| format!("no record set of type {record_type:?}"))?;
115    collect_rows_from_rset(&rset)
116}
117
118pub fn collect_rows_from_rset(
119    rset: &Rset<'_>,
120) -> Result<(Vec<String>, Vec<HashMap<String, String>>), Box<dyn std::error::Error>> {
121    let mut column_order: Vec<String> = Vec::new();
122    let mut seen: HashSet<String> = HashSet::new();
123    let mut rows: Vec<HashMap<String, String>> = Vec::new();
124
125    for (i, record) in rset.records().enumerate() {
126        let mut row: HashMap<String, String> = HashMap::new();
127        for f in record.fields() {
128            let name = f.name();
129            if name.starts_with('%') {
130                continue;
131            }
132            if row.contains_key(&name) {
133                return Err(format!(
134                    "field {:?} repeated in record {} (1-based); use a List<T> mapping (not yet supported) or remove the repeat",
135                    name,
136                    i + 1
137                )
138                .into());
139            }
140            row.insert(name.clone(), f.value());
141            if seen.insert(name.clone()) {
142                column_order.push(name);
143            }
144        }
145        rows.push(row);
146    }
147    Ok((column_order, rows))
148}
149
150pub fn build_schema(
151    column_order: &[String],
152    declared: &HashMap<String, String>,
153) -> Arc<Schema> {
154    let fields: Vec<Field> = column_order
155        .iter()
156        .map(|name| {
157            let dt = match declared.get(name) {
158                Some(t) => map_rec_type(t),
159                None => {
160                    log::info!("no %type for field {name:?}; falling back to Utf8");
161                    DataType::Utf8
162                }
163            };
164            Field::new(name, dt, true)
165        })
166        .collect();
167    Arc::new(Schema::new(fields))
168}
169
170pub fn map_rec_type(t: &str) -> DataType {
171    match t.split_whitespace().next().unwrap_or("") {
172        "int" | "range" => DataType::Int64,
173        "real" => DataType::Float64,
174        "bool" => DataType::Boolean,
175        _ => DataType::Utf8,
176    }
177}
178
179pub fn build_columns(schema: &Schema, rows: &[HashMap<String, String>]) -> Vec<ArrayRef> {
180    schema
181        .fields()
182        .iter()
183        .map(|f| build_column(f, rows))
184        .collect()
185}
186
187pub fn build_column(field: &Field, rows: &[HashMap<String, String>]) -> ArrayRef {
188    let name = field.name();
189    match field.data_type() {
190        DataType::Int64 => {
191            let mut b = Int64Builder::with_capacity(rows.len());
192            for row in rows {
193                match row.get(name).map(|s| s.trim()) {
194                    Some(s) if s.is_empty() => b.append_null(),
195                    Some(s) => match s.parse::<i64>() {
196                        Ok(v) => b.append_value(v),
197                        Err(_) => {
198                            log::warn!("field {name:?}: cannot parse {s:?} as int; nulled");
199                            b.append_null();
200                        }
201                    },
202                    None => b.append_null(),
203                }
204            }
205            Arc::new(b.finish())
206        }
207        DataType::Float64 => {
208            let mut b = Float64Builder::with_capacity(rows.len());
209            for row in rows {
210                match row.get(name).map(|s| s.trim()) {
211                    Some(s) if s.is_empty() => b.append_null(),
212                    Some(s) => match s.parse::<f64>() {
213                        Ok(v) => b.append_value(v),
214                        Err(_) => {
215                            log::warn!("field {name:?}: cannot parse {s:?} as real; nulled");
216                            b.append_null();
217                        }
218                    },
219                    None => b.append_null(),
220                }
221            }
222            Arc::new(b.finish())
223        }
224        DataType::Boolean => {
225            let mut b = BooleanBuilder::with_capacity(rows.len());
226            for row in rows {
227                match row.get(name).map(|s| s.trim()) {
228                    Some(s) if s.is_empty() => b.append_null(),
229                    Some(s) => match parse_rec_bool(s) {
230                        Some(v) => b.append_value(v),
231                        None => {
232                            log::warn!("field {name:?}: cannot parse {s:?} as bool; nulled");
233                            b.append_null();
234                        }
235                    },
236                    None => b.append_null(),
237                }
238            }
239            Arc::new(b.finish())
240        }
241        DataType::Utf8 => {
242            let mut b = StringBuilder::with_capacity(rows.len(), rows.len() * 16);
243            for row in rows {
244                match row.get(name) {
245                    Some(s) => b.append_value(s),
246                    None => b.append_null(),
247                }
248            }
249            Arc::new(b.finish())
250        }
251        other => panic!("unsupported arrow type {other:?}"),
252    }
253}
254
/// Parse a boolean literal.
///
/// `yes`, `true` and `1` yield `Some(true)`; `no`, `false` and `0` yield
/// `Some(false)`. The comparison is exact (case-sensitive, no trimming);
/// any other input yields `None`.
pub fn parse_rec_bool(s: &str) -> Option<bool> {
    if matches!(s, "yes" | "true" | "1") {
        Some(true)
    } else if matches!(s, "no" | "false" | "0") {
        Some(false)
    } else {
        None
    }
}
262
263/// Serialize `batches` as a `.rec` file body containing a single record set
264/// of type `record_type`. The descriptor block carries `%rec:`, one `%type:`
265/// line per non-Utf8 column, and one `%mandatory:` line per non-nullable
266/// Arrow field. Null values are omitted from the produced records (rec
267/// convention: absent field == null).
268///
269/// Each batch's column count and layout must match `schema`. Unsupported
270/// Arrow types (anything beyond Int64 / Float64 / Boolean / Utf8) return an
271/// error rather than producing a lossy serialization.
272pub fn record_batches_to_rec_string(
273    record_type: &str,
274    schema: &Schema,
275    batches: &[RecordBatch],
276) -> Result<String, Box<dyn std::error::Error>> {
277    if record_type.is_empty() {
278        return Err("record_type must be a non-empty rec type name".into());
279    }
280
281    let mut db = Db::new();
282    let mut rset = OwnedRset::new();
283    rset.set_descriptor(build_descriptor(record_type, schema)?);
284
285    for batch in batches {
286        if batch.num_columns() != schema.fields().len() {
287            return Err(format!(
288                "batch has {} columns but schema has {}",
289                batch.num_columns(),
290                schema.fields().len()
291            )
292            .into());
293        }
294        for row in 0..batch.num_rows() {
295            let mut record = Record::new();
296            for (col_idx, field) in schema.fields().iter().enumerate() {
297                let array = batch.column(col_idx).as_ref();
298                if array.is_null(row) {
299                    continue;
300                }
301                let value = format_arrow_value(field, array, row)?;
302                record.append_field(field.name(), &value)?;
303            }
304            rset.append_record(record)?;
305        }
306    }
307
308    db.append_rset(rset)?;
309    Ok(db.to_rec_string()?)
310}
311
312fn build_descriptor(
313    record_type: &str,
314    schema: &Schema,
315) -> Result<Record, Box<dyn std::error::Error>> {
316    let mut desc = Record::new();
317    desc.append_field("%rec", record_type)?;
318    for field in schema.fields() {
319        if let Some(rec_ty) = map_arrow_to_rec_type(field.data_type())? {
320            desc.append_field("%type", &format!("{} {}", field.name(), rec_ty))?;
321        }
322    }
323    for field in schema.fields() {
324        if !field.is_nullable() {
325            desc.append_field("%mandatory", field.name())?;
326        }
327    }
328    Ok(desc)
329}
330
331/// Inverse of [`map_rec_type`]. Returns `Ok(None)` for `Utf8`, since rec's
332/// untyped default is string and emitting `%type: <name> string` would be
333/// noise. Returns `Err` for Arrow types we don't know how to round-trip.
334pub fn map_arrow_to_rec_type(
335    dt: &DataType,
336) -> Result<Option<&'static str>, Box<dyn std::error::Error>> {
337    Ok(match dt {
338        DataType::Int64 => Some("int"),
339        DataType::Float64 => Some("real"),
340        DataType::Boolean => Some("bool"),
341        DataType::Utf8 => None,
342        other => {
343            return Err(format!("unsupported arrow type {other:?} for rec output").into());
344        }
345    })
346}
347
348pub fn format_arrow_value(
349    field: &Field,
350    array: &dyn Array,
351    row: usize,
352) -> Result<String, Box<dyn std::error::Error>> {
353    match field.data_type() {
354        DataType::Int64 => {
355            let a = array
356                .as_any()
357                .downcast_ref::<Int64Array>()
358                .ok_or("expected Int64Array")?;
359            Ok(a.value(row).to_string())
360        }
361        DataType::Float64 => {
362            let a = array
363                .as_any()
364                .downcast_ref::<Float64Array>()
365                .ok_or("expected Float64Array")?;
366            Ok(format_rec_float(a.value(row)))
367        }
368        DataType::Boolean => {
369            let a = array
370                .as_any()
371                .downcast_ref::<BooleanArray>()
372                .ok_or("expected BooleanArray")?;
373            Ok(if a.value(row) { "yes" } else { "no" }.to_string())
374        }
375        DataType::Utf8 => {
376            let a = array
377                .as_any()
378                .downcast_ref::<StringArray>()
379                .ok_or("expected StringArray")?;
380            Ok(a.value(row).to_string())
381        }
382        other => Err(format!("unsupported arrow type {other:?} for rec output").into()),
383    }
384}
385
/// Format an `f64` for rec output.
///
/// Finite values with no fractional part are written with exactly one
/// decimal (`"1.0"` rather than `"1"`), which keeps round-trips stable when
/// the file is read back without `%type: real` (e.g. by a human-trimmed
/// descriptor). Every other value uses the default `f64` formatting.
fn format_rec_float(f: f64) -> String {
    let integer_valued = f.is_finite() && f.fract() == 0.0;
    if !integer_valued {
        return f.to_string();
    }
    format!("{f:.1}")
}
396
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{BooleanArray, Float64Array, Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};

    /// Four-column schema covering every supported Arrow type; only `Title`
    /// is non-nullable.
    fn sample_schema() -> Arc<Schema> {
        let fields = vec![
            Field::new("Title", DataType::Utf8, false),
            Field::new("Year", DataType::Int64, true),
            Field::new("Price", DataType::Float64, true),
            Field::new("InPrint", DataType::Boolean, true),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Two rows laid out per [`sample_schema`]; the second row has a null
    /// `Year`.
    fn sample_batch(schema: &Arc<Schema>) -> RecordBatch {
        let columns: Vec<ArrayRef> = vec![
            Arc::new(StringArray::from(vec![Some("Refactoring"), Some("TDD")])) as ArrayRef,
            Arc::new(Int64Array::from(vec![Some(1999), None])) as ArrayRef,
            Arc::new(Float64Array::from(vec![Some(42.0), Some(19.95)])) as ArrayRef,
            Arc::new(BooleanArray::from(vec![Some(true), Some(false)])) as ArrayRef,
        ];
        RecordBatch::try_new(Arc::clone(schema), columns).unwrap()
    }

    /// Rec text for the sample batch serialized as record type `Book`.
    fn sample_rec_text() -> String {
        let schema = sample_schema();
        let batch = sample_batch(&schema);
        record_batches_to_rec_string("Book", &schema, std::slice::from_ref(&batch)).unwrap()
    }

    #[test]
    fn descriptor_carries_types_and_mandatory() {
        let text = sample_rec_text();
        assert!(text.contains("%rec: Book"));
        for line in ["%type: Year int", "%type: Price real", "%type: InPrint bool"] {
            assert!(text.contains(line));
        }
        // Utf8 columns get no %type: line.
        assert!(!text.contains("%type: Title"));
        // Only the non-nullable Arrow field becomes %mandatory.
        assert!(text.contains("%mandatory: Title"));
        assert!(!text.contains("%mandatory: Year"));
    }

    #[test]
    fn integer_valued_float_keeps_decimal() {
        let text = sample_rec_text();
        assert!(text.contains("Price: 42.0"));
        assert!(text.contains("Price: 19.95"));
    }

    #[test]
    fn bool_writes_yes_no() {
        let text = sample_rec_text();
        assert!(text.contains("InPrint: yes"));
        assert!(text.contains("InPrint: no"));
    }

    #[test]
    fn null_field_is_omitted() {
        let text = sample_rec_text();
        // The second record has Year=null; it should not emit a Year field.
        // Anchor on the unique Title "TDD" to find the second record block.
        let start = text.find("Title: TDD").expect("TDD record present");
        // The block runs up to the next blank line, or to EOF.
        let block = text[start..].split("\n\n").next().unwrap_or("");
        assert!(!block.contains("Year:"), "Year should be omitted: {block:?}");
    }

    #[test]
    fn round_trip_through_librec_parser() {
        let schema = sample_schema();
        let batch = sample_batch(&schema);
        let text =
            record_batches_to_rec_string("Book", &schema, std::slice::from_ref(&batch))
                .unwrap();

        let mut db = Db::parse_str(&text).unwrap();
        let (schema2, batch2) = rec_to_record_batch(&mut db, "Book").unwrap();

        // Same column set in the same order.
        let names: Vec<&str> =
            schema2.fields().iter().map(|f| f.name().as_str()).collect();
        assert_eq!(names, vec!["Title", "Year", "Price", "InPrint"]);
        // Types survive the round-trip.
        let expected_types = [
            DataType::Utf8,
            DataType::Int64,
            DataType::Float64,
            DataType::Boolean,
        ];
        for (idx, expected) in expected_types.iter().enumerate() {
            assert_eq!(schema2.field(idx).data_type(), expected);
        }
        // Row count is preserved.
        assert_eq!(batch2.num_rows(), batch.num_rows());
    }

    #[test]
    fn empty_record_type_rejected() {
        let schema = sample_schema();
        let batch = sample_batch(&schema);
        let outcome =
            record_batches_to_rec_string("", &schema, std::slice::from_ref(&batch));
        assert!(outcome.is_err());
    }

    #[test]
    fn unsupported_arrow_type_errors() {
        let stamp = Field::new("Stamp", DataType::Int32, true);
        let schema = Arc::new(Schema::new(vec![stamp]));
        let arr: ArrayRef = Arc::new(arrow::array::Int32Array::from(vec![Some(1)]));
        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![arr]).unwrap();
        let outcome =
            record_batches_to_rec_string("T", &schema, std::slice::from_ref(&batch));
        assert!(outcome.is_err());
    }
}
534}