1use arrow::array::{
23 ArrayRef, BooleanBuilder, Float64Builder, Int32Builder, ListBuilder, StringBuilder,
24};
25use arrow::datatypes::{DataType, Field, Schema};
26use serde_json::Value;
27use std::collections::HashMap;
28use std::sync::Arc;
29
30use crate::{ProcessedRow, SofError};
31
32pub fn infer_arrow_type(values: &[Option<Value>]) -> DataType {
33 let mut type_counts: HashMap<String, usize> = HashMap::new();
34 let mut has_array = false;
35 let mut has_object = false;
36 let mut array_element_type = None;
37
38 for value in values.iter().flatten() {
39 match value {
40 Value::Bool(_) => {
41 *type_counts.entry("bool".to_string()).or_insert(0) += 1;
42 }
43 Value::Number(n) => {
44 if n.is_i64() || n.is_u64() {
45 *type_counts.entry("integer".to_string()).or_insert(0) += 1;
46 } else {
47 *type_counts.entry("decimal".to_string()).or_insert(0) += 1;
48 }
49 }
50 Value::String(_) => {
51 *type_counts.entry("string".to_string()).or_insert(0) += 1;
52 }
53 Value::Array(arr) => {
54 has_array = true;
55 if !arr.is_empty() && array_element_type.is_none() {
56 let element_values: Vec<Option<Value>> =
57 arr.iter().map(|v| Some(v.clone())).collect();
58 array_element_type = Some(infer_arrow_type(&element_values));
59 }
60 }
61 Value::Object(_) => {
62 has_object = true;
63 }
64 Value::Null => {}
65 }
66 }
67
68 if has_array {
69 if let Some(element_type) = array_element_type {
70 return DataType::List(Arc::new(Field::new("item", element_type, true)));
71 } else {
72 return DataType::List(Arc::new(Field::new("item", DataType::Utf8, true)));
73 }
74 }
75
76 if has_object {
77 return DataType::Utf8;
78 }
79
80 let most_common = type_counts
81 .into_iter()
82 .max_by_key(|(_, count)| *count)
83 .map(|(type_name, _)| type_name);
84
85 match most_common.as_deref() {
86 Some("bool") => DataType::Boolean,
87 Some("integer") => DataType::Int32,
88 Some("decimal") => DataType::Float64,
89 Some("string") => DataType::Utf8,
90 _ => DataType::Utf8,
91 }
92}
93
94pub fn create_arrow_schema(columns: &[String], rows: &[ProcessedRow]) -> Result<Schema, SofError> {
95 let sample_size = std::cmp::min(100, rows.len());
96 let mut fields = Vec::new();
97
98 for (col_idx, column_name) in columns.iter().enumerate() {
99 let sample_values: Vec<Option<Value>> = rows
100 .iter()
101 .take(sample_size)
102 .map(|row| row.values.get(col_idx).cloned().flatten())
103 .collect();
104
105 let data_type = infer_arrow_type(&sample_values);
106 let field = Field::new(column_name, data_type, true);
107 fields.push(field);
108 }
109
110 Ok(Schema::new(fields))
111}
112
113fn build_array_from_values(
114 values: Vec<Option<Value>>,
115 data_type: &DataType,
116) -> Result<ArrayRef, SofError> {
117 match data_type {
118 DataType::Boolean => {
119 let mut builder = BooleanBuilder::new();
120 for value in values {
121 match value {
122 Some(Value::Bool(b)) => builder.append_value(b),
123 _ => builder.append_null(),
124 }
125 }
126 Ok(Arc::new(builder.finish()))
127 }
128 DataType::Int32 => {
129 let mut builder = Int32Builder::new();
130 for value in values {
131 match value {
132 Some(Value::Number(n)) if n.is_i64() => {
133 if let Some(i) = n.as_i64() {
134 builder.append_value(i as i32);
135 } else {
136 builder.append_null();
137 }
138 }
139 _ => builder.append_null(),
140 }
141 }
142 Ok(Arc::new(builder.finish()))
143 }
144 DataType::Float64 => {
145 let mut builder = Float64Builder::new();
146 for value in values {
147 match value {
148 Some(Value::Number(n)) => {
149 if let Some(f) = n.as_f64() {
150 builder.append_value(f);
151 } else {
152 builder.append_null();
153 }
154 }
155 _ => builder.append_null(),
156 }
157 }
158 Ok(Arc::new(builder.finish()))
159 }
160 DataType::Utf8 => {
161 let mut builder = StringBuilder::new();
162 for value in values {
163 match value {
164 Some(Value::String(s)) => builder.append_value(s),
165 Some(Value::Number(n)) => builder.append_value(n.to_string()),
166 Some(Value::Bool(b)) => builder.append_value(b.to_string()),
167 Some(Value::Object(_)) | Some(Value::Array(_)) => {
168 builder.append_value(
169 serde_json::to_string(&value.unwrap())
170 .unwrap_or_else(|_| "null".to_string()),
171 );
172 }
173 _ => builder.append_null(),
174 }
175 }
176 Ok(Arc::new(builder.finish()))
177 }
178 DataType::List(field) => {
179 let element_type = field.data_type();
180 match element_type {
181 DataType::Utf8 => {
182 let mut builder = ListBuilder::new(StringBuilder::new());
183 for value in values {
184 match value {
185 Some(Value::Array(arr)) => {
186 for elem in arr {
187 match elem {
188 Value::String(s) => builder.values().append_value(s),
189 _ => builder.values().append_value(elem.to_string()),
190 }
191 }
192 builder.append(true);
193 }
194 _ => builder.append(false),
195 }
196 }
197 Ok(Arc::new(builder.finish()))
198 }
199 _ => {
200 let mut string_builder = ListBuilder::new(StringBuilder::new());
201 for value in values {
202 match value {
203 Some(Value::Array(arr)) => {
204 for elem in arr {
205 string_builder.values().append_value(elem.to_string());
206 }
207 string_builder.append(true);
208 }
209 _ => string_builder.append(false),
210 }
211 }
212 Ok(Arc::new(string_builder.finish()))
213 }
214 }
215 }
216 _ => Err(SofError::ParquetConversionError(format!(
217 "Unsupported data type for Parquet conversion: {:?}",
218 data_type
219 ))),
220 }
221}
222
223pub fn process_to_arrow_arrays(
224 schema: &Schema,
225 _columns: &[String],
226 rows: &[ProcessedRow],
227) -> Result<Vec<ArrayRef>, SofError> {
228 let mut arrays = Vec::new();
229
230 for (col_idx, field) in schema.fields().iter().enumerate() {
231 let values: Vec<Option<Value>> = rows
232 .iter()
233 .map(|row| row.values.get(col_idx).cloned().flatten())
234 .collect();
235
236 let array = build_array_from_values(values, field.data_type())?;
237 arrays.push(array);
238 }
239
240 Ok(arrays)
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Array;
    use serde_json::json;

    // Booleans with an interleaved missing value infer as Boolean.
    #[test]
    fn test_infer_boolean_type() {
        let samples = [
            Some(json!(true)),
            Some(json!(false)),
            None,
            Some(json!(true)),
        ];
        assert_eq!(infer_arrow_type(&samples), DataType::Boolean);
    }

    // Small signed integers infer as Int32.
    #[test]
    fn test_infer_integer_type() {
        let samples = [Some(json!(42)), Some(json!(100)), None, Some(json!(-5))];
        assert_eq!(infer_arrow_type(&samples), DataType::Int32);
    }

    // Floating-point numbers infer as Float64.
    #[test]
    fn test_infer_decimal_type() {
        let samples = [
            Some(json!(std::f64::consts::PI)),
            Some(json!(std::f64::consts::E)),
            None,
            Some(json!(1.0)),
        ];
        assert_eq!(infer_arrow_type(&samples), DataType::Float64);
    }

    // Strings infer as Utf8.
    #[test]
    fn test_infer_string_type() {
        let samples = [
            Some(json!("hello")),
            Some(json!("world")),
            None,
            Some(json!("test")),
        ];
        assert_eq!(infer_arrow_type(&samples), DataType::Utf8);
    }

    // Arrays of strings infer as a List with a Utf8 "item" field.
    #[test]
    fn test_infer_array_type() {
        let samples = [Some(json!(["a", "b", "c"])), Some(json!(["d", "e"])), None];
        let inferred = infer_arrow_type(&samples);
        if let DataType::List(item_field) = inferred {
            assert_eq!(item_field.name(), "item");
            assert_eq!(item_field.data_type(), &DataType::Utf8);
        } else {
            panic!("Expected List type");
        }
    }

    // Objects are represented as Utf8.
    #[test]
    fn test_infer_object_type_as_string() {
        let samples = [
            Some(json!({"key": "value"})),
            Some(json!({"foo": "bar"})),
            None,
        ];
        assert_eq!(infer_arrow_type(&samples), DataType::Utf8);
    }

    // With mixed scalars, the most frequent kind decides the type.
    #[test]
    fn test_mixed_types_favor_most_common() {
        let samples = [
            Some(json!("string1")),
            Some(json!("string2")),
            Some(json!(42)),
            Some(json!("string3")),
        ];
        assert_eq!(infer_arrow_type(&samples), DataType::Utf8);
    }

    // Schema creation keeps column order and infers a type per column.
    #[test]
    fn test_create_schema_basic() {
        let columns = vec!["id".to_string(), "name".to_string(), "age".to_string()];
        let rows = vec![
            ProcessedRow {
                values: vec![Some(json!("123")), Some(json!("John Doe")), Some(json!(42))],
            },
            ProcessedRow {
                values: vec![
                    Some(json!("456")),
                    Some(json!("Jane Smith")),
                    Some(json!(35)),
                ],
            },
        ];

        let schema = create_arrow_schema(&columns, &rows).unwrap();

        let expected = [
            ("id", DataType::Utf8),
            ("name", DataType::Utf8),
            ("age", DataType::Int32),
        ];
        assert_eq!(schema.fields().len(), 3);
        for (index, (name, data_type)) in expected.iter().enumerate() {
            assert_eq!(schema.field(index).name(), *name);
            assert_eq!(schema.field(index).data_type(), data_type);
        }
    }

    // Boolean columns keep values and turn missing entries into nulls.
    #[test]
    fn test_build_boolean_array() {
        let input = vec![
            Some(json!(true)),
            None,
            Some(json!(false)),
            Some(json!(true)),
        ];
        let column = build_array_from_values(input, &DataType::Boolean).unwrap();
        let booleans = column
            .as_any()
            .downcast_ref::<arrow::array::BooleanArray>()
            .unwrap();

        assert_eq!(column.len(), 4);
        assert!(booleans.value(0));
        assert!(column.is_null(1));
        assert!(!booleans.value(2));
        assert!(booleans.value(3));
    }

    // Utf8 columns stringify scalars and serialise objects as JSON text.
    #[test]
    fn test_build_string_array_with_mixed_types() {
        let input = vec![
            Some(json!("text")),
            Some(json!(42)),
            Some(json!(true)),
            Some(json!({"key": "value"})),
            None,
        ];
        let column = build_array_from_values(input, &DataType::Utf8).unwrap();
        let strings = column
            .as_any()
            .downcast_ref::<arrow::array::StringArray>()
            .unwrap();

        assert_eq!(column.len(), 5);
        assert_eq!(strings.value(0), "text");
        assert_eq!(strings.value(1), "42");
        assert_eq!(strings.value(2), "true");
        assert!(strings.value(3).contains("\"key\""));
        assert!(column.is_null(4));
    }
}
390}