arkflow_plugin/processor/
json.rs

//! JSON Processor Components
//!
//! Processors for converting between JSON-encoded binary data and the Arrow format
4
5use arkflow_core::processor::{register_processor_builder, Processor, ProcessorBuilder};
6use arkflow_core::{Bytes, Content, Error, MessageBatch};
7use async_trait::async_trait;
8use datafusion::arrow;
9use datafusion::arrow::array::{
10    ArrayRef, BooleanArray, Float64Array, Int64Array, NullArray, StringArray, UInt64Array,
11};
12use datafusion::arrow::datatypes::{DataType, Field, Schema};
13use datafusion::arrow::record_batch::RecordBatch;
14use serde_json::Value;
15use std::sync::Arc;
16
17/// Arrow format conversion processor configuration
18
19/// Arrow Format Conversion Processor
20
21pub struct JsonToArrowProcessor;
22
23#[async_trait]
24impl Processor for JsonToArrowProcessor {
25    async fn process(&self, msg_batch: MessageBatch) -> Result<Vec<MessageBatch>, Error> {
26        match msg_batch.content {
27            Content::Arrow(_) => Err(Error::Process("The input must be binary data".to_string())),
28            Content::Binary(v) => {
29                let mut batches = Vec::with_capacity(v.len());
30                for x in v {
31                    let record_batch = json_to_arrow(&x)?;
32                    batches.push(record_batch)
33                }
34                if batches.is_empty() {
35                    return Ok(vec![]);
36                }
37
38                let schema = batches[0].schema();
39                let batch = arrow::compute::concat_batches(&schema, &batches)
40                    .map_err(|e| Error::Process(format!("Merge batches failed: {}", e)))?;
41                Ok(vec![MessageBatch::new_arrow(batch)])
42            }
43        }
44    }
45
46    async fn close(&self) -> Result<(), Error> {
47        Ok(())
48    }
49}
50
51pub struct ArrowToJsonProcessor;
52
53#[async_trait]
54impl Processor for ArrowToJsonProcessor {
55    async fn process(&self, msg_batch: MessageBatch) -> Result<Vec<MessageBatch>, Error> {
56        match msg_batch.content {
57            Content::Arrow(v) => {
58                let json_data = arrow_to_json(&v)?;
59                Ok(vec![MessageBatch::new_binary(vec![json_data])])
60            }
61            Content::Binary(_) => Err(Error::Process(
62                "The input must be in Arrow format".to_string(),
63            )),
64        }
65    }
66
67    async fn close(&self) -> Result<(), Error> {
68        Ok(())
69    }
70}
71
72fn json_to_arrow(content: &Bytes) -> Result<RecordBatch, Error> {
73    // 解析JSON内容
74    let json_value: Value = serde_json::from_slice(content)
75        .map_err(|e| Error::Process(format!("JSON解析错误: {}", e)))?;
76
77    match json_value {
78        Value::Object(obj) => {
79            // 单个对象转换为单行表
80            let mut fields = Vec::new();
81            let mut columns: Vec<ArrayRef> = Vec::new();
82
83            // 提取所有字段和值
84            for (key, value) in obj {
85                match value {
86                    Value::Null => {
87                        fields.push(Field::new(&key, DataType::Null, true));
88                        // 空值列处理
89                        columns.push(Arc::new(NullArray::new(1)));
90                    }
91                    Value::Bool(v) => {
92                        fields.push(Field::new(&key, DataType::Boolean, false));
93                        columns.push(Arc::new(BooleanArray::from(vec![v])));
94                    }
95                    Value::Number(v) => {
96                        if v.is_i64() {
97                            fields.push(Field::new(&key, DataType::Int64, false));
98                            columns.push(Arc::new(Int64Array::from(vec![v.as_i64().unwrap()])));
99                        } else if v.is_u64() {
100                            fields.push(Field::new(&key, DataType::UInt64, false));
101                            columns.push(Arc::new(UInt64Array::from(vec![v.as_u64().unwrap()])));
102                        } else {
103                            fields.push(Field::new(&key, DataType::Float64, false));
104                            columns.push(Arc::new(Float64Array::from(vec![v
105                                .as_f64()
106                                .unwrap_or(0.0)])));
107                        }
108                    }
109                    Value::String(v) => {
110                        fields.push(Field::new(&key, DataType::Utf8, false));
111                        columns.push(Arc::new(StringArray::from(vec![v])));
112                    }
113                    Value::Array(v) => {
114                        fields.push(Field::new(&key, DataType::Utf8, false));
115                        if let Ok(x) = serde_json::to_string(&v) {
116                            columns.push(Arc::new(StringArray::from(vec![x])));
117                        } else {
118                            columns.push(Arc::new(StringArray::from(vec!["[]".to_string()])));
119                        }
120                    }
121                    Value::Object(v) => {
122                        fields.push(Field::new(&key, DataType::Utf8, false));
123                        if let Ok(x) = serde_json::to_string(&v) {
124                            columns.push(Arc::new(StringArray::from(vec![x])));
125                        } else {
126                            columns.push(Arc::new(StringArray::from(vec!["{}".to_string()])));
127                        }
128                    }
129                };
130            }
131
132            // 创建schema和记录批次
133            let schema = Arc::new(Schema::new(fields));
134            RecordBatch::try_new(schema, columns)
135                .map_err(|e| Error::Process(format!("创建Arrow记录批次失败: {}", e)))
136        }
137        _ => Err(Error::Process("输入必须是JSON对象".to_string())),
138    }
139}
140
141/// Convert Arrow format to JSON
142fn arrow_to_json(batch: &RecordBatch) -> Result<Vec<u8>, Error> {
143    // 使用Arrow的JSON序列化功能
144    let mut buf = Vec::new();
145    let mut writer = arrow::json::ArrayWriter::new(&mut buf);
146    writer
147        .write(batch)
148        .map_err(|e| Error::Process(format!("Arrow JSON序列化错误: {}", e)))?;
149    writer
150        .finish()
151        .map_err(|e| Error::Process(format!("Arrow JSON序列化完成错误: {}", e)))?;
152
153    Ok(buf)
154}
155
156pub(crate) struct JsonToArrowProcessorBuilder;
157impl ProcessorBuilder for JsonToArrowProcessorBuilder {
158    fn build(&self, _: &Option<serde_json::Value>) -> Result<Arc<dyn Processor>, Error> {
159        Ok(Arc::new(JsonToArrowProcessor))
160    }
161}
162pub(crate) struct ArrowToJsonProcessorBuilder;
163impl ProcessorBuilder for ArrowToJsonProcessorBuilder {
164    fn build(&self, _: &Option<serde_json::Value>) -> Result<Arc<dyn Processor>, Error> {
165        Ok(Arc::new(ArrowToJsonProcessor))
166    }
167}
168
169pub fn init() {
170    register_processor_builder("arrow_to_json", Arc::new(ArrowToJsonProcessorBuilder));
171    register_processor_builder("json_to_arrow", Arc::new(JsonToArrowProcessorBuilder));
172}
173
#[cfg(test)]
mod tests {
    use super::*;

    /// Column names produced from the object built by `create_test_json`.
    const EXPECTED_FIELDS: [&str; 8] = [
        "null_field",
        "bool_field",
        "int_field",
        "uint_field",
        "float_field",
        "string_field",
        "array_field",
        "object_field",
    ];

    /// Build JSON bytes for one object that exercises every value branch of
    /// `json_to_arrow`.
    fn create_test_json() -> Vec<u8> {
        // The unsigned value must exceed i64::MAX; a smaller one is classified
        // as Int64 and would leave the UInt64 branch untested.
        let above_i64_range = u64::MAX;
        let document = serde_json::json!({
            "null_field": null,
            "bool_field": true,
            "int_field": 42,
            "uint_field": above_i64_range,
            "float_field": 2.5,
            "string_field": "test",
            "array_field": [1],
            "object_field": { "key": "value" }
        });
        serde_json::to_vec(&document).unwrap()
    }

    /// Build a one-row record batch with a single non-nullable Utf8 column.
    fn single_string_batch(name: &str, value: &str) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new(name, DataType::Utf8, false)]));
        let columns: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![value]))];
        RecordBatch::try_new(schema, columns).unwrap()
    }

    /// Parse serialized output that must be a JSON array holding exactly one
    /// object, and return that object.
    fn single_json_object(bytes: &[u8]) -> serde_json::Map<String, Value> {
        let parsed: Value = serde_json::from_slice(bytes).unwrap();
        let mut rows = match parsed {
            Value::Array(rows) => rows,
            other => panic!("Should be a JSON array, got {}", other),
        };
        assert_eq!(rows.len(), 1, "Should have one object in array");
        match rows.pop() {
            Some(Value::Object(object)) => object,
            other => panic!("Should be a JSON object, got {:?}", other),
        }
    }

    /// Look up a column by name and downcast it to the concrete array type `T`.
    fn column_as<'a, T: 'static>(batch: &'a RecordBatch, name: &str) -> &'a T {
        let index = batch.schema().index_of(name).unwrap();
        batch
            .column(index)
            .as_any()
            .downcast_ref::<T>()
            .unwrap()
    }

    #[tokio::test]
    async fn test_json_to_arrow_processor_success() {
        let processor = JsonToArrowProcessor;
        let msg_batch = MessageBatch::new_binary(vec![create_test_json()]);

        let result = processor.process(msg_batch).await.unwrap();

        assert_eq!(result.len(), 1, "Should return one message batch");
        match &result[0].content {
            Content::Arrow(batch) => {
                assert_eq!(batch.num_rows(), 1, "Should have one row");
                assert_eq!(
                    batch.num_columns(),
                    EXPECTED_FIELDS.len(),
                    "Should have 8 columns"
                );

                let schema = batch.schema();
                for &name in EXPECTED_FIELDS.iter() {
                    assert!(
                        schema.field_with_name(name).is_ok(),
                        "Missing column {}",
                        name
                    );
                }
            }
            _ => panic!("Expected Arrow content"),
        }
    }

    #[tokio::test]
    async fn test_json_to_arrow_processor_multiple_messages() {
        // Messages with identical schemas are merged into one multi-row batch.
        let processor = JsonToArrowProcessor;
        let msg_batch = MessageBatch::new_binary(vec![create_test_json(), create_test_json()]);

        let result = processor.process(msg_batch).await.unwrap();

        assert_eq!(result.len(), 1, "Should return one message batch");
        match &result[0].content {
            Content::Arrow(batch) => {
                assert_eq!(batch.num_rows(), 2, "Should have one row per message");
                assert_eq!(
                    batch.num_columns(),
                    EXPECTED_FIELDS.len(),
                    "Should have 8 columns"
                );
            }
            _ => panic!("Expected Arrow content"),
        }
    }

    #[tokio::test]
    async fn test_json_to_arrow_processor_mismatched_types() {
        // The same key with different value types cannot be merged.
        let processor = JsonToArrowProcessor;
        let first = serde_json::to_vec(&serde_json::json!({ "a": 1 })).unwrap();
        let second = serde_json::to_vec(&serde_json::json!({ "a": "x" })).unwrap();
        let msg_batch = MessageBatch::new_binary(vec![first, second]);

        let result = processor.process(msg_batch).await;
        assert!(
            result.is_err(),
            "Should return error for messages with different column types"
        );
    }

    #[tokio::test]
    async fn test_json_to_arrow_processor_empty_input() {
        let processor = JsonToArrowProcessor;
        let msg_batch = MessageBatch::new_binary(vec![]);

        let result = processor.process(msg_batch).await.unwrap();

        assert!(
            result.is_empty(),
            "Should return empty result for empty input"
        );
    }

    #[tokio::test]
    async fn test_json_to_arrow_processor_invalid_input() {
        let processor = JsonToArrowProcessor;
        let invalid_json = b"{invalid json";
        let msg_batch = MessageBatch::new_binary(vec![invalid_json.to_vec()]);

        let result = processor.process(msg_batch).await;
        assert!(result.is_err(), "Should return error for invalid JSON");
    }

    #[tokio::test]
    async fn test_json_to_arrow_processor_non_object_input() {
        // Valid JSON whose top-level value is an array, not an object.
        let processor = JsonToArrowProcessor;
        let array_json = serde_json::to_vec(&[1, 2, 3]).unwrap();
        let msg_batch = MessageBatch::new_binary(vec![array_json]);

        let result = processor.process(msg_batch).await;
        assert!(result.is_err(), "Should return error for non-object JSON");
    }

    #[tokio::test]
    async fn test_json_to_arrow_processor_wrong_content_type() {
        // Arrow content must be rejected by the JSON -> Arrow processor.
        let processor = JsonToArrowProcessor;
        let msg_batch = MessageBatch::new_arrow(single_string_batch("test", "test"));

        let result = processor.process(msg_batch).await;
        assert!(result.is_err(), "Should return error for Arrow content");
    }

    #[tokio::test]
    async fn test_arrow_to_json_processor_success() {
        let processor = ArrowToJsonProcessor;

        let schema = Arc::new(Schema::new(vec![
            Field::new("string_field", DataType::Utf8, false),
            Field::new("int_field", DataType::Int64, false),
            Field::new("bool_field", DataType::Boolean, false),
        ]));
        let columns: Vec<ArrayRef> = vec![
            Arc::new(StringArray::from(vec!["test"])),
            Arc::new(Int64Array::from(vec![42])),
            Arc::new(BooleanArray::from(vec![true])),
        ];
        let record_batch = RecordBatch::try_new(schema, columns).unwrap();
        let msg_batch = MessageBatch::new_arrow(record_batch);

        let result = processor.process(msg_batch).await.unwrap();

        assert_eq!(result.len(), 1, "Should return one message batch");
        match &result[0].content {
            Content::Binary(payloads) => {
                assert_eq!(payloads.len(), 1, "Should have one binary item");

                let object = single_json_object(&payloads[0]);
                assert_eq!(object["string_field"], "test");
                assert_eq!(object["int_field"], 42);
                assert_eq!(object["bool_field"], true);
            }
            _ => panic!("Expected Binary content"),
        }
    }

    #[tokio::test]
    async fn test_arrow_to_json_processor_wrong_content_type() {
        // Binary content must be rejected by the Arrow -> JSON processor.
        let processor = ArrowToJsonProcessor;
        let msg_batch = MessageBatch::new_binary(vec![vec![1, 2, 3]]);

        let result = processor.process(msg_batch).await;
        assert!(result.is_err(), "Should return error for Binary content");
    }

    // `json_to_arrow` is synchronous, so no async runtime is needed here.
    #[test]
    fn test_json_to_arrow_function() {
        let batch = json_to_arrow(&create_test_json()).unwrap();

        assert_eq!(batch.num_rows(), 1, "Should have one row");
        assert_eq!(
            batch.num_columns(),
            EXPECTED_FIELDS.len(),
            "Should have 8 columns"
        );

        // The null value is represented by a Null-typed column.
        let schema = batch.schema();
        assert_eq!(
            schema.field_with_name("null_field").unwrap().data_type(),
            &DataType::Null
        );

        // Every other value lands in a column of the matching concrete type.
        assert!(column_as::<BooleanArray>(&batch, "bool_field").value(0));
        assert_eq!(column_as::<Int64Array>(&batch, "int_field").value(0), 42);
        assert_eq!(
            column_as::<UInt64Array>(&batch, "uint_field").value(0),
            u64::MAX
        );
        assert_eq!(
            column_as::<Float64Array>(&batch, "float_field").value(0),
            2.5
        );
        assert_eq!(
            column_as::<StringArray>(&batch, "string_field").value(0),
            "test"
        );

        // Nested values are stored as their JSON text.
        assert_eq!(
            column_as::<StringArray>(&batch, "array_field").value(0),
            "[1]"
        );
        assert_eq!(
            column_as::<StringArray>(&batch, "object_field").value(0),
            "{\"key\":\"value\"}"
        );
    }

    // `arrow_to_json` is synchronous, so no async runtime is needed here.
    #[test]
    fn test_arrow_to_json_function() {
        let record_batch = single_string_batch("test_field", "test_value");

        let json_bytes = arrow_to_json(&record_batch).unwrap();

        let object = single_json_object(&json_bytes);
        assert_eq!(object["test_field"], "test_value");
    }
}