1use super::ValueIter;
19use arrow_schema::{ArrowError, DataType, Field, Fields, Schema};
20use indexmap::map::IndexMap as HashMap;
21use indexmap::set::IndexSet as HashSet;
22use serde_json::Value;
23use std::borrow::Borrow;
24use std::io::{BufRead, Seek};
25use std::sync::Arc;
26
27#[derive(Debug, Clone)]
28enum InferredType {
29 Scalar(HashSet<DataType>),
30 Array(Box<InferredType>),
31 Object(HashMap<String, InferredType>),
32 Any,
33}
34
35impl InferredType {
36 fn merge(&mut self, other: InferredType) -> Result<(), ArrowError> {
37 match (self, other) {
38 (InferredType::Array(s), InferredType::Array(o)) => {
39 s.merge(*o)?;
40 }
41 (InferredType::Scalar(self_hs), InferredType::Scalar(other_hs)) => {
42 other_hs.into_iter().for_each(|v| {
43 self_hs.insert(v);
44 });
45 }
46 (InferredType::Object(self_map), InferredType::Object(other_map)) => {
47 for (k, v) in other_map {
48 self_map.entry(k).or_insert(InferredType::Any).merge(v)?;
49 }
50 }
51 (s @ InferredType::Any, v) => {
52 *s = v;
53 }
54 (_, InferredType::Any) => {}
55 (InferredType::Array(self_inner_type), other_scalar @ InferredType::Scalar(_)) => {
57 self_inner_type.merge(other_scalar)?;
58 }
59 (s @ InferredType::Scalar(_), InferredType::Array(mut other_inner_type)) => {
60 other_inner_type.merge(s.clone())?;
61 *s = InferredType::Array(other_inner_type);
62 }
63 (s, o) => {
65 return Err(ArrowError::JsonError(format!(
66 "Incompatible type found during schema inference: {s:?} v.s. {o:?}",
67 )));
68 }
69 }
70
71 Ok(())
72 }
73
74 fn is_none_or_any(ty: Option<&Self>) -> bool {
75 matches!(ty, Some(Self::Any) | None)
76 }
77}
78
79fn list_type_of(ty: DataType) -> DataType {
81 DataType::List(Arc::new(Field::new_list_field(ty, true)))
82}
83
84fn coerce_data_type(dt: Vec<&DataType>) -> DataType {
90 let mut dt_iter = dt.into_iter().cloned();
91 let dt_init = dt_iter.next().unwrap_or(DataType::Utf8);
92
93 dt_iter.fold(dt_init, |l, r| match (l, r) {
94 (DataType::Null, o) | (o, DataType::Null) => o,
95 (DataType::Boolean, DataType::Boolean) => DataType::Boolean,
96 (DataType::Int64, DataType::Int64) => DataType::Int64,
97 (DataType::Float64, DataType::Float64)
98 | (DataType::Float64, DataType::Int64)
99 | (DataType::Int64, DataType::Float64) => DataType::Float64,
100 (DataType::List(l), DataType::List(r)) => {
101 list_type_of(coerce_data_type(vec![l.data_type(), r.data_type()]))
102 }
103 (DataType::List(e), not_list) | (not_list, DataType::List(e)) => {
105 list_type_of(coerce_data_type(vec![e.data_type(), ¬_list]))
106 }
107 _ => DataType::Utf8,
108 })
109}
110
111fn generate_datatype(t: &InferredType) -> Result<DataType, ArrowError> {
112 Ok(match t {
113 InferredType::Scalar(hs) => coerce_data_type(hs.iter().collect()),
114 InferredType::Object(spec) => DataType::Struct(generate_fields(spec)?),
115 InferredType::Array(ele_type) => list_type_of(generate_datatype(ele_type)?),
116 InferredType::Any => DataType::Null,
117 })
118}
119
120fn generate_fields(spec: &HashMap<String, InferredType>) -> Result<Fields, ArrowError> {
121 spec.iter()
122 .map(|(k, types)| Ok(Field::new(k, generate_datatype(types)?, true)))
123 .collect()
124}
125
126fn generate_schema(spec: HashMap<String, InferredType>) -> Result<Schema, ArrowError> {
128 Ok(Schema::new(generate_fields(&spec)?))
129}
130
131pub fn infer_json_schema_from_seekable<R: BufRead + Seek>(
156 mut reader: R,
157 max_read_records: Option<usize>,
158) -> Result<(Schema, usize), ArrowError> {
159 let schema = infer_json_schema(&mut reader, max_read_records);
160 reader.rewind()?;
162
163 schema
164}
165
166pub fn infer_json_schema<R: BufRead>(
204 reader: R,
205 max_read_records: Option<usize>,
206) -> Result<(Schema, usize), ArrowError> {
207 let mut values = ValueIter::new(reader, max_read_records);
208 let schema = infer_json_schema_from_iterator(&mut values)?;
209 Ok((schema, values.record_count()))
210}
211
212fn set_object_scalar_field_type(
213 field_types: &mut HashMap<String, InferredType>,
214 key: &str,
215 ftype: DataType,
216) -> Result<(), ArrowError> {
217 if InferredType::is_none_or_any(field_types.get(key)) {
218 field_types.insert(key.to_string(), InferredType::Scalar(HashSet::new()));
219 }
220
221 match field_types.get_mut(key).unwrap() {
222 InferredType::Scalar(hs) => {
223 hs.insert(ftype);
224 Ok(())
225 }
226 scalar_array @ InferredType::Array(_) => {
229 let mut hs = HashSet::new();
230 hs.insert(ftype);
231 scalar_array.merge(InferredType::Scalar(hs))?;
232 Ok(())
233 }
234 t => Err(ArrowError::JsonError(format!(
235 "Expected scalar or scalar array JSON type, found: {t:?}",
236 ))),
237 }
238}
239
240fn infer_scalar_array_type(array: &[Value]) -> Result<InferredType, ArrowError> {
241 let mut hs = HashSet::new();
242
243 for v in array {
244 match v {
245 Value::Null => {}
246 Value::Number(n) => {
247 if n.is_i64() {
248 hs.insert(DataType::Int64);
249 } else {
250 hs.insert(DataType::Float64);
251 }
252 }
253 Value::Bool(_) => {
254 hs.insert(DataType::Boolean);
255 }
256 Value::String(_) => {
257 hs.insert(DataType::Utf8);
258 }
259 Value::Array(_) | Value::Object(_) => {
260 return Err(ArrowError::JsonError(format!(
261 "Expected scalar value for scalar array, got: {v:?}"
262 )));
263 }
264 }
265 }
266
267 Ok(InferredType::Scalar(hs))
268}
269
270fn infer_nested_array_type(array: &[Value]) -> Result<InferredType, ArrowError> {
271 let mut inner_ele_type = InferredType::Any;
272
273 for v in array {
274 match v {
275 Value::Array(inner_array) => {
276 inner_ele_type.merge(infer_array_element_type(inner_array)?)?;
277 }
278 x => {
279 return Err(ArrowError::JsonError(format!(
280 "Got non array element in nested array: {x:?}"
281 )));
282 }
283 }
284 }
285
286 Ok(InferredType::Array(Box::new(inner_ele_type)))
287}
288
289fn infer_struct_array_type(array: &[Value]) -> Result<InferredType, ArrowError> {
290 let mut field_types = HashMap::new();
291
292 for v in array {
293 match v {
294 Value::Object(map) => {
295 collect_field_types_from_object(&mut field_types, map)?;
296 }
297 _ => {
298 return Err(ArrowError::JsonError(format!(
299 "Expected struct value for struct array, got: {v:?}"
300 )));
301 }
302 }
303 }
304
305 Ok(InferredType::Object(field_types))
306}
307
308fn infer_array_element_type(array: &[Value]) -> Result<InferredType, ArrowError> {
309 match array.iter().take(1).next() {
310 None => Ok(InferredType::Any), Some(a) => match a {
312 Value::Array(_) => infer_nested_array_type(array),
313 Value::Object(_) => infer_struct_array_type(array),
314 _ => infer_scalar_array_type(array),
315 },
316 }
317}
318
319fn collect_field_types_from_object(
320 field_types: &mut HashMap<String, InferredType>,
321 map: &serde_json::map::Map<String, Value>,
322) -> Result<(), ArrowError> {
323 for (k, v) in map {
324 match v {
325 Value::Array(array) => {
326 let ele_type = infer_array_element_type(array)?;
327
328 if InferredType::is_none_or_any(field_types.get(k)) {
329 match ele_type {
330 InferredType::Scalar(_) => {
331 field_types.insert(
332 k.to_string(),
333 InferredType::Array(Box::new(InferredType::Scalar(HashSet::new()))),
334 );
335 }
336 InferredType::Object(_) => {
337 field_types.insert(
338 k.to_string(),
339 InferredType::Array(Box::new(InferredType::Object(HashMap::new()))),
340 );
341 }
342 InferredType::Any | InferredType::Array(_) => {
343 field_types.insert(
346 k.to_string(),
347 InferredType::Array(Box::new(InferredType::Any)),
348 );
349 }
350 }
351 }
352
353 match field_types.get_mut(k).unwrap() {
354 InferredType::Array(inner_type) => {
355 inner_type.merge(ele_type)?;
356 }
357 field_type @ InferredType::Scalar(_) => {
360 field_type.merge(ele_type)?;
361 *field_type = InferredType::Array(Box::new(field_type.clone()));
362 }
363 t => {
364 return Err(ArrowError::JsonError(format!(
365 "Expected array json type, found: {t:?}",
366 )));
367 }
368 }
369 }
370 Value::Bool(_) => {
371 set_object_scalar_field_type(field_types, k, DataType::Boolean)?;
372 }
373 Value::Null => {
374 if !field_types.contains_key(k) {
377 field_types.insert(k.to_string(), InferredType::Any);
378 }
379 }
380 Value::Number(n) => {
381 if n.is_i64() {
382 set_object_scalar_field_type(field_types, k, DataType::Int64)?;
383 } else {
384 set_object_scalar_field_type(field_types, k, DataType::Float64)?;
385 }
386 }
387 Value::String(_) => {
388 set_object_scalar_field_type(field_types, k, DataType::Utf8)?;
389 }
390 Value::Object(inner_map) => {
391 if let InferredType::Any = field_types.get(k).unwrap_or(&InferredType::Any) {
392 field_types.insert(k.to_string(), InferredType::Object(HashMap::new()));
393 }
394 match field_types.get_mut(k).unwrap() {
395 InferredType::Object(inner_field_types) => {
396 collect_field_types_from_object(inner_field_types, inner_map)?;
397 }
398 t => {
399 return Err(ArrowError::JsonError(format!(
400 "Expected object json type, found: {t:?}",
401 )));
402 }
403 }
404 }
405 }
406 }
407
408 Ok(())
409}
410
411pub fn infer_json_schema_from_iterator<I, V>(value_iter: I) -> Result<Schema, ArrowError>
425where
426 I: Iterator<Item = Result<V, ArrowError>>,
427 V: Borrow<Value>,
428{
429 let mut field_types: HashMap<String, InferredType> = HashMap::new();
430
431 for record in value_iter {
432 match record?.borrow() {
433 Value::Object(map) => {
434 collect_field_types_from_object(&mut field_types, map)?;
435 }
436 value => {
437 return Err(ArrowError::JsonError(format!(
438 "Expected JSON record to be an object, found {value:?}"
439 )));
440 }
441 };
442 }
443
444 generate_schema(field_types)
445}
446
#[cfg(test)]
mod tests {
    use super::*;
    use flate2::read::GzDecoder;
    use std::fs::File;
    use std::io::{BufReader, Cursor};

    /// The plain and the gzip-decoded fixture must both infer the same
    /// schema and row count.
    #[test]
    fn test_json_infer_schema() {
        let expected = Schema::new(vec![
            Field::new("a", DataType::Int64, true),
            Field::new("b", list_type_of(DataType::Float64), true),
            Field::new("c", list_type_of(DataType::Boolean), true),
            Field::new("d", list_type_of(DataType::Utf8), true),
        ]);

        let plain_file = File::open("test/data/mixed_arrays.json").unwrap();
        let mut plain = BufReader::new(plain_file);
        let (inferred, rows) = infer_json_schema_from_seekable(&mut plain, None).unwrap();
        assert_eq!(inferred, expected);
        assert_eq!(rows, 4);

        let gz_file = File::open("test/data/mixed_arrays.json.gz").unwrap();
        let mut decoded = BufReader::new(GzDecoder::new(&gz_file));
        let (inferred, rows) = infer_json_schema(&mut decoded, None).unwrap();
        assert_eq!(inferred, expected);
        assert_eq!(rows, 4);
    }

    /// `max_read_records` caps the number of rows read. The second call
    /// relies on the first one having rewound the reader.
    #[test]
    fn test_row_limit() {
        let file = File::open("test/data/basic.json").unwrap();
        let mut reader = BufReader::new(file);

        let (_, all_rows) = infer_json_schema_from_seekable(&mut reader, None).unwrap();
        assert_eq!(all_rows, 12);

        let (_, limited_rows) = infer_json_schema_from_seekable(&mut reader, Some(5)).unwrap();
        assert_eq!(limited_rows, 5);
    }

    /// Nested objects become nested structs, and a `null` in place of an
    /// object does not disturb the struct type.
    #[test]
    fn test_json_infer_schema_nested_structs() {
        let b_type = DataType::Struct(vec![Field::new("c", DataType::Utf8, true)].into());
        let c1_type = DataType::Struct(Fields::from(vec![
            Field::new("a", DataType::Boolean, true),
            Field::new("b", b_type, true),
        ]));
        let expected = Schema::new(vec![
            Field::new("c1", c1_type, true),
            Field::new("c2", DataType::Int64, true),
            Field::new("c3", DataType::Utf8, true),
        ]);

        let records: Vec<Result<Value, ArrowError>> = vec![
            Ok(serde_json::json!({"c1": {"a": true, "b": {"c": "text"}}, "c2": 1})),
            Ok(serde_json::json!({"c1": {"a": false, "b": null}, "c2": 0})),
            Ok(serde_json::json!({"c1": {"a": true, "b": {"c": "text"}}, "c3": "ok"})),
        ];
        let inferred = infer_json_schema_from_iterator(records.into_iter()).unwrap();

        assert_eq!(inferred, expected);
    }

    /// Objects inside lists are merged field by field, and a list that is
    /// always empty stays a list of `Null`.
    #[test]
    fn test_json_infer_schema_struct_in_list() {
        let c1_item = DataType::Struct(Fields::from(vec![
            Field::new("a", DataType::Utf8, true),
            Field::new("b", DataType::Int64, true),
            Field::new("c", DataType::Boolean, true),
        ]));
        let expected = Schema::new(vec![
            Field::new("c1", list_type_of(c1_item), true),
            Field::new("c2", DataType::Float64, true),
            Field::new("c3", list_type_of(DataType::Null), true),
        ]);

        let records: Vec<Result<Value, ArrowError>> = vec![
            Ok(serde_json::json!({
                "c1": [{"a": "foo", "b": 100}], "c2": 1, "c3": [],
            })),
            Ok(serde_json::json!({
                "c1": [{"a": "bar", "b": 2}, {"a": "foo", "c": true}], "c2": 0, "c3": [],
            })),
            Ok(serde_json::json!({"c1": [], "c2": 0.5, "c3": []})),
        ];
        let inferred = infer_json_schema_from_iterator(records.into_iter()).unwrap();

        assert_eq!(inferred, expected);
    }

    /// An initially empty list is refined to a list of lists by later
    /// records.
    #[test]
    fn test_json_infer_schema_nested_list() {
        let expected = Schema::new(vec![
            Field::new("c1", list_type_of(list_type_of(DataType::Utf8)), true),
            Field::new("c2", DataType::Float64, true),
        ]);

        let records: Vec<Result<Value, ArrowError>> = vec![
            Ok(serde_json::json!({
                "c1": [],
                "c2": 12,
            })),
            Ok(serde_json::json!({
                "c1": [["a", "b"], ["c"]],
            })),
            Ok(serde_json::json!({
                "c1": [["foo"]],
                "c2": 0.11,
            })),
        ];
        let inferred = infer_json_schema_from_iterator(records.into_iter()).unwrap();

        assert_eq!(inferred, expected);
    }

    /// Integers outside the `i64` range are inferred as `Float64`.
    #[test]
    fn test_infer_json_schema_bigger_than_i64_max() {
        let bigger_than_i64_max = (i64::MAX as i128) + 1;
        let smaller_than_i64_min = (i64::MIN as i128) - 1;
        let json = format!(
            "{{ \"bigger_than_i64_max\": {bigger_than_i64_max}, \"smaller_than_i64_min\": {smaller_than_i64_min} }}",
        );

        let mut reader = BufReader::new(json.as_bytes());
        let (inferred, _) = infer_json_schema(&mut reader, Some(1)).unwrap();
        let fields = inferred.fields();

        let (_, big_field) = fields.find("bigger_than_i64_max").unwrap();
        assert_eq!(big_field.data_type(), &DataType::Float64);
        let (_, small_field) = fields.find("smaller_than_i64_min").unwrap();
        assert_eq!(small_field.data_type(), &DataType::Float64);
    }

    /// A scalar coerced with a list yields a list of the coerced element
    /// type, and incompatible scalars fall back to `Utf8`.
    #[test]
    fn test_coercion_scalar_and_list() {
        // (scalar, list element type, expected list element type)
        let cases = [
            (DataType::Float64, DataType::Float64, DataType::Float64),
            (DataType::Float64, DataType::Int64, DataType::Float64),
            (DataType::Int64, DataType::Int64, DataType::Int64),
            (DataType::Boolean, DataType::Float64, DataType::Utf8),
        ];

        for (scalar, list_element, expected_element) in cases {
            assert_eq!(
                list_type_of(expected_element),
                coerce_data_type(vec![&scalar, &list_type_of(list_element)])
            );
        }
    }

    /// Malformed input is reported as a JSON error.
    #[test]
    fn test_invalid_json_infer_schema() {
        let err = infer_json_schema_from_seekable(Cursor::new(b"}"), None).unwrap_err();
        assert_eq!(
            err.to_string(),
            "Json error: Not valid JSON: expected value at line 1 column 1",
        );
    }

    /// Fields that are only ever `null` or an empty list stay `Null` and
    /// list-of-`Null`. A `null` before or after a typed value does not
    /// change the inferred type.
    #[test]
    fn test_null_field_inferred_as_null() {
        let data = r#"
            {"in":1, "ni":null, "ns":null, "sn":"4", "n":null, "an":[], "na": null, "nas":null}
            {"in":null, "ni":2, "ns":"3", "sn":null, "n":null, "an":null, "na": [], "nas":["8"]}
            {"in":1, "ni":null, "ns":null, "sn":"4", "n":null, "an":[], "na": null, "nas":[]}
        "#;
        let (inferred, _) =
            infer_json_schema_from_seekable(Cursor::new(data), None).expect("infer");

        // The expected fields are listed alphabetically, presumably because
        // the JSON map iterates its keys in sorted order. TODO confirm.
        let expected = Schema::new(vec![
            Field::new("an", list_type_of(DataType::Null), true),
            Field::new("in", DataType::Int64, true),
            Field::new("n", DataType::Null, true),
            Field::new("na", list_type_of(DataType::Null), true),
            Field::new("nas", list_type_of(DataType::Utf8), true),
            Field::new("ni", DataType::Int64, true),
            Field::new("ns", DataType::Utf8, true),
            Field::new("sn", DataType::Utf8, true),
        ]);
        assert_eq!(inferred, expected);
    }

    /// A field first seen as `null` can later be refined into a struct.
    #[test]
    fn test_infer_from_null_then_object() {
        let data = r#"
            {"obj":null}
            {"obj":{"foo":1}}
        "#;
        let (inferred, _) =
            infer_json_schema_from_seekable(Cursor::new(data), None).expect("infer");

        let obj_type = DataType::Struct(Fields::from(vec![Field::new(
            "foo",
            DataType::Int64,
            true,
        )]));
        let expected = Schema::new(vec![Field::new("obj", obj_type, true)]);
        assert_eq!(inferred, expected);
    }
}