1use arrow::array::{
2 ArrayRef, BooleanBuilder, Float64Builder, Int32Builder, ListBuilder, StringBuilder,
3};
4use arrow::datatypes::{DataType, Field, Schema};
5use serde_json::Value;
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use crate::{ProcessedRow, SofError};
10
11pub fn infer_arrow_type(values: &[Option<Value>]) -> DataType {
12 let mut type_counts: HashMap<String, usize> = HashMap::new();
13 let mut has_array = false;
14 let mut has_object = false;
15 let mut array_element_type = None;
16
17 for value in values.iter().flatten() {
18 match value {
19 Value::Bool(_) => {
20 *type_counts.entry("bool".to_string()).or_insert(0) += 1;
21 }
22 Value::Number(n) => {
23 if n.is_i64() || n.is_u64() {
24 *type_counts.entry("integer".to_string()).or_insert(0) += 1;
25 } else {
26 *type_counts.entry("decimal".to_string()).or_insert(0) += 1;
27 }
28 }
29 Value::String(_) => {
30 *type_counts.entry("string".to_string()).or_insert(0) += 1;
31 }
32 Value::Array(arr) => {
33 has_array = true;
34 if !arr.is_empty() && array_element_type.is_none() {
35 let element_values: Vec<Option<Value>> =
36 arr.iter().map(|v| Some(v.clone())).collect();
37 array_element_type = Some(infer_arrow_type(&element_values));
38 }
39 }
40 Value::Object(_) => {
41 has_object = true;
42 }
43 Value::Null => {}
44 }
45 }
46
47 if has_array {
48 if let Some(element_type) = array_element_type {
49 return DataType::List(Arc::new(Field::new("item", element_type, true)));
50 } else {
51 return DataType::List(Arc::new(Field::new("item", DataType::Utf8, true)));
52 }
53 }
54
55 if has_object {
56 return DataType::Utf8;
57 }
58
59 let most_common = type_counts
60 .into_iter()
61 .max_by_key(|(_, count)| *count)
62 .map(|(type_name, _)| type_name);
63
64 match most_common.as_deref() {
65 Some("bool") => DataType::Boolean,
66 Some("integer") => DataType::Int32,
67 Some("decimal") => DataType::Float64,
68 Some("string") => DataType::Utf8,
69 _ => DataType::Utf8,
70 }
71}
72
73pub fn create_arrow_schema(columns: &[String], rows: &[ProcessedRow]) -> Result<Schema, SofError> {
74 let sample_size = std::cmp::min(100, rows.len());
75 let mut fields = Vec::new();
76
77 for (col_idx, column_name) in columns.iter().enumerate() {
78 let sample_values: Vec<Option<Value>> = rows
79 .iter()
80 .take(sample_size)
81 .map(|row| row.values.get(col_idx).cloned().flatten())
82 .collect();
83
84 let data_type = infer_arrow_type(&sample_values);
85 let field = Field::new(column_name, data_type, true);
86 fields.push(field);
87 }
88
89 Ok(Schema::new(fields))
90}
91
92fn build_array_from_values(
93 values: Vec<Option<Value>>,
94 data_type: &DataType,
95) -> Result<ArrayRef, SofError> {
96 match data_type {
97 DataType::Boolean => {
98 let mut builder = BooleanBuilder::new();
99 for value in values {
100 match value {
101 Some(Value::Bool(b)) => builder.append_value(b),
102 _ => builder.append_null(),
103 }
104 }
105 Ok(Arc::new(builder.finish()))
106 }
107 DataType::Int32 => {
108 let mut builder = Int32Builder::new();
109 for value in values {
110 match value {
111 Some(Value::Number(n)) if n.is_i64() => {
112 if let Some(i) = n.as_i64() {
113 builder.append_value(i as i32);
114 } else {
115 builder.append_null();
116 }
117 }
118 _ => builder.append_null(),
119 }
120 }
121 Ok(Arc::new(builder.finish()))
122 }
123 DataType::Float64 => {
124 let mut builder = Float64Builder::new();
125 for value in values {
126 match value {
127 Some(Value::Number(n)) => {
128 if let Some(f) = n.as_f64() {
129 builder.append_value(f);
130 } else {
131 builder.append_null();
132 }
133 }
134 _ => builder.append_null(),
135 }
136 }
137 Ok(Arc::new(builder.finish()))
138 }
139 DataType::Utf8 => {
140 let mut builder = StringBuilder::new();
141 for value in values {
142 match value {
143 Some(Value::String(s)) => builder.append_value(s),
144 Some(Value::Number(n)) => builder.append_value(n.to_string()),
145 Some(Value::Bool(b)) => builder.append_value(b.to_string()),
146 Some(Value::Object(_)) | Some(Value::Array(_)) => {
147 builder.append_value(
148 serde_json::to_string(&value.unwrap())
149 .unwrap_or_else(|_| "null".to_string()),
150 );
151 }
152 _ => builder.append_null(),
153 }
154 }
155 Ok(Arc::new(builder.finish()))
156 }
157 DataType::List(field) => {
158 let element_type = field.data_type();
159 match element_type {
160 DataType::Utf8 => {
161 let mut builder = ListBuilder::new(StringBuilder::new());
162 for value in values {
163 match value {
164 Some(Value::Array(arr)) => {
165 for elem in arr {
166 match elem {
167 Value::String(s) => builder.values().append_value(s),
168 _ => builder.values().append_value(elem.to_string()),
169 }
170 }
171 builder.append(true);
172 }
173 _ => builder.append(false),
174 }
175 }
176 Ok(Arc::new(builder.finish()))
177 }
178 _ => {
179 let mut string_builder = ListBuilder::new(StringBuilder::new());
180 for value in values {
181 match value {
182 Some(Value::Array(arr)) => {
183 for elem in arr {
184 string_builder.values().append_value(elem.to_string());
185 }
186 string_builder.append(true);
187 }
188 _ => string_builder.append(false),
189 }
190 }
191 Ok(Arc::new(string_builder.finish()))
192 }
193 }
194 }
195 _ => Err(SofError::ParquetConversionError(format!(
196 "Unsupported data type for Parquet conversion: {:?}",
197 data_type
198 ))),
199 }
200}
201
202pub fn process_to_arrow_arrays(
203 schema: &Schema,
204 _columns: &[String],
205 rows: &[ProcessedRow],
206) -> Result<Vec<ArrayRef>, SofError> {
207 let mut arrays = Vec::new();
208
209 for (col_idx, field) in schema.fields().iter().enumerate() {
210 let values: Vec<Option<Value>> = rows
211 .iter()
212 .map(|row| row.values.get(col_idx).cloned().flatten())
213 .collect();
214
215 let array = build_array_from_values(values, field.data_type())?;
216 arrays.push(array);
217 }
218
219 Ok(arrays)
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Array;
    use serde_json::json;

    // Type inference: one test per scalar / container kind.

    #[test]
    fn test_infer_boolean_type() {
        let samples = [
            Some(json!(true)),
            Some(json!(false)),
            None,
            Some(json!(true)),
        ];
        let inferred = infer_arrow_type(&samples);
        assert_eq!(inferred, DataType::Boolean);
    }

    #[test]
    fn test_infer_integer_type() {
        let samples = [Some(json!(42)), Some(json!(100)), None, Some(json!(-5))];
        let inferred = infer_arrow_type(&samples);
        assert_eq!(inferred, DataType::Int32);
    }

    #[test]
    fn test_infer_decimal_type() {
        let samples = [
            Some(json!(std::f64::consts::PI)),
            Some(json!(std::f64::consts::E)),
            None,
            Some(json!(1.0)),
        ];
        let inferred = infer_arrow_type(&samples);
        assert_eq!(inferred, DataType::Float64);
    }

    #[test]
    fn test_infer_string_type() {
        let samples = [
            Some(json!("hello")),
            Some(json!("world")),
            None,
            Some(json!("test")),
        ];
        let inferred = infer_arrow_type(&samples);
        assert_eq!(inferred, DataType::Utf8);
    }

    #[test]
    fn test_infer_array_type() {
        let samples = [Some(json!(["a", "b", "c"])), Some(json!(["d", "e"])), None];
        if let DataType::List(item_field) = infer_arrow_type(&samples) {
            assert_eq!(item_field.name(), "item");
            assert_eq!(item_field.data_type(), &DataType::Utf8);
        } else {
            panic!("Expected List type");
        }
    }

    #[test]
    fn test_infer_object_type_as_string() {
        let samples = [
            Some(json!({"key": "value"})),
            Some(json!({"foo": "bar"})),
            None,
        ];
        let inferred = infer_arrow_type(&samples);
        assert_eq!(inferred, DataType::Utf8);
    }

    #[test]
    fn test_mixed_types_favor_most_common() {
        // Three strings against a single integer: the string type must win.
        let samples = [
            Some(json!("string1")),
            Some(json!("string2")),
            Some(json!(42)),
            Some(json!("string3")),
        ];
        let inferred = infer_arrow_type(&samples);
        assert_eq!(inferred, DataType::Utf8);
    }

    // Schema creation.

    #[test]
    fn test_create_schema_basic() {
        let columns: Vec<String> = ["id", "name", "age"]
            .iter()
            .map(|column| column.to_string())
            .collect();
        let first_row = ProcessedRow {
            values: vec![Some(json!("123")), Some(json!("John Doe")), Some(json!(42))],
        };
        let second_row = ProcessedRow {
            values: vec![
                Some(json!("456")),
                Some(json!("Jane Smith")),
                Some(json!(35)),
            ],
        };
        let rows = vec![first_row, second_row];

        let schema = create_arrow_schema(&columns, &rows).unwrap();
        assert_eq!(schema.fields().len(), 3);

        let expected = [
            ("id", DataType::Utf8),
            ("name", DataType::Utf8),
            ("age", DataType::Int32),
        ];
        for (position, (name, data_type)) in expected.iter().enumerate() {
            assert_eq!(schema.field(position).name(), *name);
            assert_eq!(schema.field(position).data_type(), data_type);
        }
    }

    // Array construction.

    #[test]
    fn test_build_boolean_array() {
        let input = vec![
            Some(json!(true)),
            None,
            Some(json!(false)),
            Some(json!(true)),
        ];
        let built = build_array_from_values(input, &DataType::Boolean).unwrap();
        let booleans = built
            .as_any()
            .downcast_ref::<arrow::array::BooleanArray>()
            .unwrap();

        assert_eq!(built.len(), 4);
        assert!(booleans.value(0));
        assert!(built.is_null(1));
        assert!(!booleans.value(2));
        assert!(booleans.value(3));
    }

    #[test]
    fn test_build_string_array_with_mixed_types() {
        let input = vec![
            Some(json!("text")),
            Some(json!(42)),
            Some(json!(true)),
            Some(json!({"key": "value"})),
            None,
        ];
        let built = build_array_from_values(input, &DataType::Utf8).unwrap();
        let strings = built
            .as_any()
            .downcast_ref::<arrow::array::StringArray>()
            .unwrap();

        assert_eq!(built.len(), 5);
        assert_eq!(strings.value(0), "text");
        assert_eq!(strings.value(1), "42");
        assert_eq!(strings.value(2), "true");
        assert!(strings.value(3).contains("\"key\""));
        assert!(built.is_null(4));
    }
}