1use std::collections::{HashMap, HashSet};
7use std::sync::Arc;
8
9use arrow::array::{
10 Array, ArrayRef, BooleanArray, BooleanBuilder, Float64Array, Float64Builder, Int64Array,
11 Int64Builder, StringArray, StringBuilder,
12};
13use arrow::datatypes::{DataType, Field, Schema};
14use arrow::record_batch::RecordBatch;
15
16use crate::rset::Rset;
17use crate::{Db, OwnedRset, Record, SelectionExpression};
18
19pub fn rec_to_record_batch(
20 db: &mut Db,
21 record_type: &str,
22) -> Result<(Arc<Schema>, RecordBatch), Box<dyn std::error::Error>> {
23 let rset = db
24 .rset_by_type(record_type)
25 .ok_or_else(|| format!("no record set of type {record_type:?}"))?;
26 rec_to_record_batch_from_rset(&rset)
27}
28
29pub fn rec_to_record_batch_from_rset(
33 rset: &Rset<'_>,
34) -> Result<(Arc<Schema>, RecordBatch), Box<dyn std::error::Error>> {
35 let mut declared_types: HashMap<String, String> = HashMap::new();
36 if let Some(desc) = rset.descriptor() {
37 for f in desc.fields() {
38 if f.name() == "%type" {
39 if let Some((field, ty)) = split_type_decl(&f.value()) {
40 declared_types.insert(field, ty);
41 }
42 }
43 }
44 }
45
46 let (column_order, rows) = collect_rows_from_rset(rset)?;
47 let schema = build_schema(&column_order, &declared_types);
48 let columns = build_columns(&schema, &rows);
49 let batch = RecordBatch::try_new(Arc::clone(&schema), columns)?;
50 Ok((schema, batch))
51}
52
53pub fn rec_to_filtered_batch(
58 db: &mut Db,
59 record_type: &str,
60 schema: &Arc<Schema>,
61 selection_expression: &SelectionExpression,
62) -> Result<RecordBatch, Box<dyn std::error::Error>> {
63 let rset = db
64 .rset_by_type(record_type)
65 .ok_or_else(|| format!("no record set of type {record_type:?}"))?;
66 rec_to_filtered_batch_from_rset(&rset, schema, selection_expression)
67}
68
69pub fn rec_to_filtered_batch_from_rset(
71 rset: &Rset<'_>,
72 schema: &Arc<Schema>,
73 selection_expression: &SelectionExpression,
74) -> Result<RecordBatch, Box<dyn std::error::Error>> {
75 let mut rows: Vec<HashMap<String, String>> = Vec::new();
76 for (i, record) in rset.records().enumerate() {
77 if !selection_expression.matches(&record) {
78 continue;
79 }
80 let mut row: HashMap<String, String> = HashMap::new();
81 for f in record.fields() {
82 let name = f.name();
83 if name.starts_with('%') {
84 continue;
85 }
86 if row.contains_key(&name) {
87 return Err(format!(
88 "field {:?} repeated in record {} (1-based); use a List<T> mapping (not yet supported) or remove the repeat",
89 name,
90 i + 1
91 )
92 .into());
93 }
94 row.insert(name.clone(), f.value());
95 }
96 rows.push(row);
97 }
98 let columns = build_columns(schema, &rows);
99 Ok(RecordBatch::try_new(Arc::clone(schema), columns)?)
100}
101
/// Splits the value of a `%type` descriptor field into
/// `(field name, type description)`.
///
/// The value is trimmed, then cut at its first whitespace character; both
/// halves are trimmed again. Returns `None` when the trimmed value contains
/// no whitespace (i.e. there is a name but no type description, or nothing
/// at all).
pub fn split_type_decl(value: &str) -> Option<(String, String)> {
    value
        .trim()
        .split_once(char::is_whitespace)
        .map(|(field, ty)| (field.trim().to_owned(), ty.trim().to_owned()))
}
107
108pub fn collect_rows(
109 db: &mut Db,
110 record_type: &str,
111) -> Result<(Vec<String>, Vec<HashMap<String, String>>), Box<dyn std::error::Error>> {
112 let rset = db
113 .rset_by_type(record_type)
114 .ok_or_else(|| format!("no record set of type {record_type:?}"))?;
115 collect_rows_from_rset(&rset)
116}
117
118pub fn collect_rows_from_rset(
119 rset: &Rset<'_>,
120) -> Result<(Vec<String>, Vec<HashMap<String, String>>), Box<dyn std::error::Error>> {
121 let mut column_order: Vec<String> = Vec::new();
122 let mut seen: HashSet<String> = HashSet::new();
123 let mut rows: Vec<HashMap<String, String>> = Vec::new();
124
125 for (i, record) in rset.records().enumerate() {
126 let mut row: HashMap<String, String> = HashMap::new();
127 for f in record.fields() {
128 let name = f.name();
129 if name.starts_with('%') {
130 continue;
131 }
132 if row.contains_key(&name) {
133 return Err(format!(
134 "field {:?} repeated in record {} (1-based); use a List<T> mapping (not yet supported) or remove the repeat",
135 name,
136 i + 1
137 )
138 .into());
139 }
140 row.insert(name.clone(), f.value());
141 if seen.insert(name.clone()) {
142 column_order.push(name);
143 }
144 }
145 rows.push(row);
146 }
147 Ok((column_order, rows))
148}
149
150pub fn build_schema(
151 column_order: &[String],
152 declared: &HashMap<String, String>,
153) -> Arc<Schema> {
154 let fields: Vec<Field> = column_order
155 .iter()
156 .map(|name| {
157 let dt = match declared.get(name) {
158 Some(t) => map_rec_type(t),
159 None => {
160 log::info!("no %type for field {name:?}; falling back to Utf8");
161 DataType::Utf8
162 }
163 };
164 Field::new(name, dt, true)
165 })
166 .collect();
167 Arc::new(Schema::new(fields))
168}
169
170pub fn map_rec_type(t: &str) -> DataType {
171 match t.split_whitespace().next().unwrap_or("") {
172 "int" | "range" => DataType::Int64,
173 "real" => DataType::Float64,
174 "bool" => DataType::Boolean,
175 _ => DataType::Utf8,
176 }
177}
178
179pub fn build_columns(schema: &Schema, rows: &[HashMap<String, String>]) -> Vec<ArrayRef> {
180 schema
181 .fields()
182 .iter()
183 .map(|f| build_column(f, rows))
184 .collect()
185}
186
187pub fn build_column(field: &Field, rows: &[HashMap<String, String>]) -> ArrayRef {
188 let name = field.name();
189 match field.data_type() {
190 DataType::Int64 => {
191 let mut b = Int64Builder::with_capacity(rows.len());
192 for row in rows {
193 match row.get(name).map(|s| s.trim()) {
194 Some(s) if s.is_empty() => b.append_null(),
195 Some(s) => match s.parse::<i64>() {
196 Ok(v) => b.append_value(v),
197 Err(_) => {
198 log::warn!("field {name:?}: cannot parse {s:?} as int; nulled");
199 b.append_null();
200 }
201 },
202 None => b.append_null(),
203 }
204 }
205 Arc::new(b.finish())
206 }
207 DataType::Float64 => {
208 let mut b = Float64Builder::with_capacity(rows.len());
209 for row in rows {
210 match row.get(name).map(|s| s.trim()) {
211 Some(s) if s.is_empty() => b.append_null(),
212 Some(s) => match s.parse::<f64>() {
213 Ok(v) => b.append_value(v),
214 Err(_) => {
215 log::warn!("field {name:?}: cannot parse {s:?} as real; nulled");
216 b.append_null();
217 }
218 },
219 None => b.append_null(),
220 }
221 }
222 Arc::new(b.finish())
223 }
224 DataType::Boolean => {
225 let mut b = BooleanBuilder::with_capacity(rows.len());
226 for row in rows {
227 match row.get(name).map(|s| s.trim()) {
228 Some(s) if s.is_empty() => b.append_null(),
229 Some(s) => match parse_rec_bool(s) {
230 Some(v) => b.append_value(v),
231 None => {
232 log::warn!("field {name:?}: cannot parse {s:?} as bool; nulled");
233 b.append_null();
234 }
235 },
236 None => b.append_null(),
237 }
238 }
239 Arc::new(b.finish())
240 }
241 DataType::Utf8 => {
242 let mut b = StringBuilder::with_capacity(rows.len(), rows.len() * 16);
243 for row in rows {
244 match row.get(name) {
245 Some(s) => b.append_value(s),
246 None => b.append_null(),
247 }
248 }
249 Arc::new(b.finish())
250 }
251 other => panic!("unsupported arrow type {other:?}"),
252 }
253}
254
/// Parses a boolean value.
///
/// `yes`, `true` and `1` give `Some(true)`; `no`, `false` and `0` give
/// `Some(false)`. The comparison is exact: no trimming and no case folding
/// is done here, so any other input gives `None`.
pub fn parse_rec_bool(s: &str) -> Option<bool> {
    const TRUE_SPELLINGS: [&str; 3] = ["yes", "true", "1"];
    const FALSE_SPELLINGS: [&str; 3] = ["no", "false", "0"];

    if TRUE_SPELLINGS.contains(&s) {
        Some(true)
    } else if FALSE_SPELLINGS.contains(&s) {
        Some(false)
    } else {
        None
    }
}
262
263pub fn record_batches_to_rec_string(
273 record_type: &str,
274 schema: &Schema,
275 batches: &[RecordBatch],
276) -> Result<String, Box<dyn std::error::Error>> {
277 if record_type.is_empty() {
278 return Err("record_type must be a non-empty rec type name".into());
279 }
280
281 let mut db = Db::new();
282 let mut rset = OwnedRset::new();
283 rset.set_descriptor(build_descriptor(record_type, schema)?);
284
285 for batch in batches {
286 if batch.num_columns() != schema.fields().len() {
287 return Err(format!(
288 "batch has {} columns but schema has {}",
289 batch.num_columns(),
290 schema.fields().len()
291 )
292 .into());
293 }
294 for row in 0..batch.num_rows() {
295 let mut record = Record::new();
296 for (col_idx, field) in schema.fields().iter().enumerate() {
297 let array = batch.column(col_idx).as_ref();
298 if array.is_null(row) {
299 continue;
300 }
301 let value = format_arrow_value(field, array, row)?;
302 record.append_field(field.name(), &value)?;
303 }
304 rset.append_record(record)?;
305 }
306 }
307
308 db.append_rset(rset)?;
309 Ok(db.to_rec_string()?)
310}
311
312fn build_descriptor(
313 record_type: &str,
314 schema: &Schema,
315) -> Result<Record, Box<dyn std::error::Error>> {
316 let mut desc = Record::new();
317 desc.append_field("%rec", record_type)?;
318 for field in schema.fields() {
319 if let Some(rec_ty) = map_arrow_to_rec_type(field.data_type())? {
320 desc.append_field("%type", &format!("{} {}", field.name(), rec_ty))?;
321 }
322 }
323 for field in schema.fields() {
324 if !field.is_nullable() {
325 desc.append_field("%mandatory", field.name())?;
326 }
327 }
328 Ok(desc)
329}
330
331pub fn map_arrow_to_rec_type(
335 dt: &DataType,
336) -> Result<Option<&'static str>, Box<dyn std::error::Error>> {
337 Ok(match dt {
338 DataType::Int64 => Some("int"),
339 DataType::Float64 => Some("real"),
340 DataType::Boolean => Some("bool"),
341 DataType::Utf8 => None,
342 other => {
343 return Err(format!("unsupported arrow type {other:?} for rec output").into());
344 }
345 })
346}
347
348pub fn format_arrow_value(
349 field: &Field,
350 array: &dyn Array,
351 row: usize,
352) -> Result<String, Box<dyn std::error::Error>> {
353 match field.data_type() {
354 DataType::Int64 => {
355 let a = array
356 .as_any()
357 .downcast_ref::<Int64Array>()
358 .ok_or("expected Int64Array")?;
359 Ok(a.value(row).to_string())
360 }
361 DataType::Float64 => {
362 let a = array
363 .as_any()
364 .downcast_ref::<Float64Array>()
365 .ok_or("expected Float64Array")?;
366 Ok(format_rec_float(a.value(row)))
367 }
368 DataType::Boolean => {
369 let a = array
370 .as_any()
371 .downcast_ref::<BooleanArray>()
372 .ok_or("expected BooleanArray")?;
373 Ok(if a.value(row) { "yes" } else { "no" }.to_string())
374 }
375 DataType::Utf8 => {
376 let a = array
377 .as_any()
378 .downcast_ref::<StringArray>()
379 .ok_or("expected StringArray")?;
380 Ok(a.value(row).to_string())
381 }
382 other => Err(format!("unsupported arrow type {other:?} for rec output").into()),
383 }
384}
385
/// Formats a float for rec output.
///
/// Finite values with no fractional part are written with exactly one
/// decimal digit (`42.0` rather than `42`) so they still read as reals.
/// Every other value, including NaN and the infinities, uses the default
/// `Display` form of `f64`.
fn format_rec_float(f: f64) -> String {
    let whole_number = f.is_finite() && f.fract() == 0.0;
    if !whole_number {
        return f.to_string();
    }
    format!("{f:.1}")
}
396
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{BooleanArray, Float64Array, Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};

    /// Four-column "Book" schema: one mandatory string plus one nullable
    /// column of each typed kind.
    fn sample_schema() -> Arc<Schema> {
        let fields = vec![
            Field::new("Title", DataType::Utf8, false),
            Field::new("Year", DataType::Int64, true),
            Field::new("Price", DataType::Float64, true),
            Field::new("InPrint", DataType::Boolean, true),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Two rows matching `sample_schema`; the second row has a null `Year`.
    fn sample_batch(schema: &Arc<Schema>) -> RecordBatch {
        let titles: ArrayRef = Arc::new(StringArray::from(vec![
            Some("Refactoring"),
            Some("TDD"),
        ]));
        let years: ArrayRef = Arc::new(Int64Array::from(vec![Some(1999), None]));
        let prices: ArrayRef = Arc::new(Float64Array::from(vec![Some(42.0), Some(19.95)]));
        let in_print: ArrayRef = Arc::new(BooleanArray::from(vec![Some(true), Some(false)]));
        let columns = vec![titles, years, prices, in_print];
        RecordBatch::try_new(Arc::clone(schema), columns).unwrap()
    }

    /// Serialises a single batch under `record_type`.
    fn render(
        record_type: &str,
        schema: &Schema,
        batch: &RecordBatch,
    ) -> Result<String, Box<dyn std::error::Error>> {
        record_batches_to_rec_string(record_type, schema, std::slice::from_ref(batch))
    }

    /// Rec text of the sample batch written as record type `Book`.
    fn sample_text() -> String {
        let schema = sample_schema();
        let batch = sample_batch(&schema);
        render("Book", &schema, &batch).unwrap()
    }

    #[test]
    fn descriptor_carries_types_and_mandatory() {
        let text = sample_text();
        assert!(text.contains("%rec: Book"));
        assert!(text.contains("%type: Year int"));
        assert!(text.contains("%type: Price real"));
        assert!(text.contains("%type: InPrint bool"));
        // String columns carry no %type declaration.
        assert!(!text.contains("%type: Title"));
        // Only the non-nullable column is marked mandatory.
        assert!(text.contains("%mandatory: Title"));
        assert!(!text.contains("%mandatory: Year"));
    }

    #[test]
    fn integer_valued_float_keeps_decimal() {
        let text = sample_text();
        assert!(text.contains("Price: 42.0"));
        assert!(text.contains("Price: 19.95"));
    }

    #[test]
    fn bool_writes_yes_no() {
        let text = sample_text();
        assert!(text.contains("InPrint: yes"));
        assert!(text.contains("InPrint: no"));
    }

    #[test]
    fn null_field_is_omitted() {
        let text = sample_text();
        let start = text.find("Title: TDD").expect("TDD record present");
        // The record runs up to the next blank line, or to the end of the text.
        let tail = &text[start..];
        let block = tail.split("\n\n").next().unwrap_or(tail);
        assert!(!block.contains("Year:"), "Year should be omitted: {block:?}");
    }

    #[test]
    fn round_trip_through_librec_parser() {
        let schema = sample_schema();
        let batch = sample_batch(&schema);
        let text = render("Book", &schema, &batch).unwrap();

        let mut db = Db::parse_str(&text).unwrap();
        let (schema2, batch2) = rec_to_record_batch(&mut db, "Book").unwrap();

        let names: Vec<&str> = schema2
            .fields()
            .iter()
            .map(|f| f.name().as_str())
            .collect();
        assert_eq!(names, vec!["Title", "Year", "Price", "InPrint"]);

        assert_eq!(schema2.field(0).data_type(), &DataType::Utf8);
        assert_eq!(schema2.field(1).data_type(), &DataType::Int64);
        assert_eq!(schema2.field(2).data_type(), &DataType::Float64);
        assert_eq!(schema2.field(3).data_type(), &DataType::Boolean);

        assert_eq!(batch2.num_rows(), batch.num_rows());
    }

    #[test]
    fn empty_record_type_rejected() {
        let schema = sample_schema();
        let batch = sample_batch(&schema);
        assert!(render("", &schema, &batch).is_err());
    }

    #[test]
    fn unsupported_arrow_type_errors() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "Stamp",
            DataType::Int32,
            true,
        )]));
        let arr: ArrayRef = Arc::new(arrow::array::Int32Array::from(vec![Some(1)]));
        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![arr]).unwrap();
        assert!(render("T", &schema, &batch).is_err());
    }
}