1use async_trait::async_trait;
6use datafusion::arrow;
7use datafusion::arrow::array::{
8 Array, ArrayRef, BinaryArray, BooleanArray, Float32Array, Float64Array, Int32Array, Int64Array,
9 StringArray, UInt32Array, UInt64Array,
10};
11use datafusion::arrow::datatypes::{DataType, Field, Schema};
12use datafusion::arrow::record_batch::RecordBatch;
13use datafusion::parquet::data_type::AsBytes;
14use prost_reflect::prost::Message;
15use prost_reflect::prost_types::FileDescriptorSet;
16use prost_reflect::{DynamicMessage, MessageDescriptor, Value};
17use serde::{Deserialize, Serialize};
18use std::path::Path;
19use std::sync::Arc;
20use std::{fs, io};
21
22use arkflow_core::processor::{register_processor_builder, Processor, ProcessorBuilder};
23use arkflow_core::{Content, Error, MessageBatch};
24use protobuf::Message as ProtobufMessage;
25
/// Configuration for the Protobuf processor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProtobufProcessorConfig {
    /// Directories scanned (non-recursively) for `.proto` files.
    pub proto_inputs: Vec<String>,
    /// Include paths handed to the proto parser; defaults to `proto_inputs` when unset.
    pub proto_includes: Option<Vec<String>>,
    /// Fully-qualified Protobuf message name, e.g. `"test.TestMessage"`.
    pub message_type: String,
}
35
/// Processor that converts between Protobuf-encoded binary payloads and Arrow record batches.
pub struct ProtobufProcessor {
    // Kept after construction but not read again (hence the leading underscore).
    _config: ProtobufProcessorConfig,
    // Descriptor of the configured message type; drives both decode and encode.
    descriptor: MessageDescriptor,
}
41
42impl ProtobufProcessor {
43 pub fn new(config: ProtobufProcessorConfig) -> Result<Self, Error> {
45 let file_descriptor_set = Self::parse_proto_file(&config)?;
47
48 let descriptor_pool = prost_reflect::DescriptorPool::from_file_descriptor_set(
49 file_descriptor_set,
50 )
51 .map_err(|e| Error::Config(format!("Unable to create Protobuf descriptor pool: {}", e)))?;
52
53 let message_descriptor = descriptor_pool
54 .get_message_by_name(&config.message_type)
55 .ok_or_else(|| {
56 Error::Config(format!(
57 "The message type could not be found: {}",
58 config.message_type
59 ))
60 })?;
61
62 Ok(Self {
63 _config: config.clone(),
64 descriptor: message_descriptor,
65 })
66 }
67
68 fn parse_proto_file(c: &ProtobufProcessorConfig) -> Result<FileDescriptorSet, Error> {
70 let mut proto_inputs: Vec<String> = vec![];
71 for x in &c.proto_inputs {
72 let files_in_dir_result = list_files_in_dir(x)
73 .map_err(|e| Error::Config(format!("Failed to list proto files: {}", e)))?;
74 proto_inputs.extend(
75 files_in_dir_result
76 .iter()
77 .filter(|path| path.ends_with(".proto"))
78 .map(|path| format!("{}/{}", x, path))
79 .collect::<Vec<_>>(),
80 )
81 }
82 let proto_includes = c.proto_includes.clone().unwrap_or(c.proto_inputs.clone());
83
84 let file_descriptor_protos = protobuf_parse::Parser::new()
86 .pure()
87 .inputs(proto_inputs)
88 .includes(proto_includes)
89 .parse_and_typecheck()
90 .map_err(|e| Error::Config(format!("Failed to parse the proto file: {}", e)))?
91 .file_descriptors;
92
93 if file_descriptor_protos.is_empty() {
94 return Err(Error::Config(
95 "Parsing the proto file does not yield any descriptors".to_string(),
96 ));
97 }
98
99 let mut file_descriptor_set = FileDescriptorSet { file: Vec::new() };
101
102 for proto in file_descriptor_protos {
103 let proto_bytes = proto.write_to_bytes().map_err(|e| {
105 Error::Config(format!("Failed to serialize FileDescriptorProto: {}", e))
106 })?;
107
108 let prost_proto =
109 prost_reflect::prost_types::FileDescriptorProto::decode(proto_bytes.as_slice())
110 .map_err(|e| {
111 Error::Config(format!("Failed to convert FileDescriptorProto: {}", e))
112 })?;
113
114 file_descriptor_set.file.push(prost_proto);
115 }
116
117 Ok(file_descriptor_set)
118 }
119
120 fn protobuf_to_arrow(&self, data: &[u8]) -> Result<RecordBatch, Error> {
122 let proto_msg = DynamicMessage::decode(self.descriptor.clone(), data)
124 .map_err(|e| Error::Process(format!("Protobuf message parsing failed: {}", e)))?;
125
126 let mut fields = Vec::new();
128 let mut columns: Vec<ArrayRef> = Vec::new();
129
130 for field in self.descriptor.fields() {
132 let field_name = field.name();
133
134 let field_value_opt = proto_msg.get_field_by_name(field_name);
135 if field_value_opt.is_none() {
136 continue;
137 }
138 let field_value = field_value_opt.unwrap();
139 match field_value.as_ref() {
140 Value::Bool(value) => {
141 fields.push(Field::new(field_name, DataType::Boolean, false));
142 columns.push(Arc::new(BooleanArray::from(vec![value.clone()])));
143 }
144 Value::I32(value) => {
145 fields.push(Field::new(field_name, DataType::Int32, false));
146 columns.push(Arc::new(Int32Array::from(vec![value.clone()])));
147 }
148 Value::I64(value) => {
149 fields.push(Field::new(field_name, DataType::Int64, false));
150 columns.push(Arc::new(Int64Array::from(vec![value.clone()])));
151 }
152 Value::U32(value) => {
153 fields.push(Field::new(field_name, DataType::UInt32, false));
154 columns.push(Arc::new(UInt32Array::from(vec![value.clone()])));
155 }
156 Value::U64(value) => {
157 fields.push(Field::new(field_name, DataType::UInt64, false));
158 columns.push(Arc::new(UInt64Array::from(vec![value.clone()])));
159 }
160 Value::F32(value) => {
161 fields.push(Field::new(field_name, DataType::Float32, false));
162 columns.push(Arc::new(Float32Array::from(vec![value.clone()])))
163 }
164 Value::F64(value) => {
165 fields.push(Field::new(field_name, DataType::Float64, false));
166 columns.push(Arc::new(Float64Array::from(vec![value.clone()])));
167 }
168 Value::String(value) => {
169 fields.push(Field::new(field_name, DataType::Utf8, false));
170 columns.push(Arc::new(StringArray::from(vec![value.clone()])));
171 }
172 Value::Bytes(value) => {
173 fields.push(Field::new(field_name, DataType::Binary, false));
174 columns.push(Arc::new(BinaryArray::from(vec![value.as_bytes()])));
175 }
176 Value::EnumNumber(value) => {
177 fields.push(Field::new(field_name, DataType::Int32, false));
178 columns.push(Arc::new(Int32Array::from(vec![value.clone()])));
179 }
180 _ => {
181 return Err(Error::Process(format!(
182 "Unsupported field type: {}",
183 field_name
184 )));
185 } }
189 }
190
191 let schema = Arc::new(Schema::new(fields));
193 RecordBatch::try_new(schema, columns)
194 .map_err(|e| Error::Process(format!("Creating an Arrow record batch failed: {}", e)))
195 }
196
197 fn arrow_to_protobuf(&self, batch: &RecordBatch) -> Result<Vec<u8>, Error> {
199 let mut proto_msg = DynamicMessage::new(self.descriptor.clone());
201
202 let schema = batch.schema();
204
205 if batch.num_rows() != 1 {
207 return Err(Error::Process(
208 "Only supports single-line Arrow data conversion to Protobuf.".to_string(),
209 ));
210 }
211
212 for (i, field) in schema.fields().iter().enumerate() {
213 let field_name = field.name();
214
215 if let Some(proto_field) = self.descriptor.get_field_by_name(field_name) {
216 let column = batch.column(i);
217
218 match proto_field.kind() {
219 prost_reflect::Kind::Bool => {
220 if let Some(value) = column.as_any().downcast_ref::<BooleanArray>() {
221 if value.len() > 0 {
222 proto_msg
223 .set_field_by_name(field_name, Value::Bool(value.value(0)));
224 }
225 }
226 }
227 prost_reflect::Kind::Int32
228 | prost_reflect::Kind::Sint32
229 | prost_reflect::Kind::Sfixed32 => {
230 if let Some(value) = column.as_any().downcast_ref::<Int32Array>() {
231 if value.len() > 0 {
232 proto_msg.set_field_by_name(field_name, Value::I32(value.value(0)));
233 }
234 }
235 }
236 prost_reflect::Kind::Int64
237 | prost_reflect::Kind::Sint64
238 | prost_reflect::Kind::Sfixed64 => {
239 if let Some(value) = column.as_any().downcast_ref::<Int64Array>() {
240 if value.len() > 0 {
241 proto_msg.set_field_by_name(field_name, Value::I64(value.value(0)));
242 }
243 }
244 }
245 prost_reflect::Kind::Uint32 | prost_reflect::Kind::Fixed32 => {
246 if let Some(value) = column.as_any().downcast_ref::<UInt32Array>() {
247 if value.len() > 0 {
248 proto_msg.set_field_by_name(field_name, Value::U32(value.value(0)));
249 }
250 }
251 }
252 prost_reflect::Kind::Uint64 | prost_reflect::Kind::Fixed64 => {
253 if let Some(value) = column.as_any().downcast_ref::<UInt64Array>() {
254 if value.len() > 0 {
255 proto_msg.set_field_by_name(field_name, Value::U64(value.value(0)));
256 }
257 }
258 }
259 prost_reflect::Kind::Float => {
260 if let Some(value) = column.as_any().downcast_ref::<Float32Array>() {
261 if value.len() > 0 {
262 proto_msg.set_field_by_name(field_name, Value::F32(value.value(0)));
263 }
264 }
265 }
266 prost_reflect::Kind::Double => {
267 if let Some(value) = column.as_any().downcast_ref::<Float64Array>() {
268 if value.len() > 0 {
269 proto_msg.set_field_by_name(field_name, Value::F64(value.value(0)));
270 }
271 }
272 }
273 prost_reflect::Kind::String => {
274 if let Some(value) = column.as_any().downcast_ref::<StringArray>() {
275 if value.len() > 0 {
276 proto_msg.set_field_by_name(
277 field_name,
278 Value::String(value.value(0).to_string()),
279 );
280 }
281 }
282 }
283 prost_reflect::Kind::Bytes => {
284 if let Some(value) = column.as_any().downcast_ref::<BinaryArray>() {
285 if value.len() > 0 {
286 proto_msg.set_field_by_name(
287 field_name,
288 Value::Bytes(value.value(0).to_vec().into()),
289 );
290 }
291 }
292 }
293 prost_reflect::Kind::Enum(_) => {
294 if let Some(value) = column.as_any().downcast_ref::<Int32Array>() {
295 if value.len() > 0 {
296 proto_msg.set_field_by_name(
297 field_name,
298 Value::EnumNumber(value.value(0)),
299 );
300 }
301 }
302 }
303 _ => {
304 return Err(Error::Process(format!(
305 "Unsupported Protobuf type: {:?}",
306 proto_field.kind()
307 )))
308 }
309 }
310 }
311 }
312
313 let mut buf = Vec::new();
314 proto_msg
315 .encode(&mut buf)
316 .map_err(|e| Error::Process(format!("Protobuf encoding failed: {}", e)))?;
317
318 Ok(buf)
319 }
320}
321
322#[async_trait]
323impl Processor for ProtobufProcessor {
324 async fn process(&self, msg: MessageBatch) -> Result<Vec<MessageBatch>, Error> {
325 if msg.is_empty() {
326 return Ok(vec![]);
327 }
328 match msg.content {
329 Content::Arrow(v) => {
330 let proto_data = self.arrow_to_protobuf(&v)?;
332 let new_msg = MessageBatch::new_binary(vec![proto_data]);
333
334 Ok(vec![new_msg])
335 }
336 Content::Binary(v) => {
337 if v.is_empty() {
338 return Ok(vec![]);
339 }
340 let mut batches = Vec::with_capacity(v.len());
341 for x in v {
342 let batch = self.protobuf_to_arrow(&x)?;
344 batches.push(batch)
345 }
346
347 let schema = batches[0].schema();
348 let batch = arrow::compute::concat_batches(&schema, &batches)
349 .map_err(|e| Error::Process(format!("Batch merge failed: {}", e)))?;
350 Ok(vec![MessageBatch::new_arrow(batch)])
351 }
352 }
353 }
354
355 async fn close(&self) -> Result<(), Error> {
356 Ok(())
357 }
358}
359
/// Returns the names (not paths) of all regular files directly inside `dir`.
///
/// A path that is not a directory yields an empty list; entries whose names
/// are not valid UTF-8 are skipped.
fn list_files_in_dir<P: AsRef<Path>>(dir: P) -> io::Result<Vec<String>> {
    let dir = dir.as_ref();
    if !dir.is_dir() {
        return Ok(Vec::new());
    }
    let mut names = Vec::new();
    for entry in fs::read_dir(dir)? {
        let path = entry?.path();
        if !path.is_file() {
            continue;
        }
        if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
            names.push(name.to_owned());
        }
    }
    Ok(names)
}
377
378pub(crate) struct ProtobufProcessorBuilder;
379impl ProcessorBuilder for ProtobufProcessorBuilder {
380 fn build(&self, config: &Option<serde_json::Value>) -> Result<Arc<dyn Processor>, Error> {
381 if config.is_none() {
382 return Err(Error::Config(
383 "Batch processor configuration is missing".to_string(),
384 ));
385 }
386 let config: ProtobufProcessorConfig = serde_json::from_value(config.clone().unwrap())?;
387 Ok(Arc::new(ProtobufProcessor::new(config)?))
388 }
389}
390
/// Registers this processor under the "protobuf" name with the global registry.
pub fn init() {
    register_processor_builder("protobuf", Arc::new(ProtobufProcessorBuilder));
}
394
395#[cfg(test)]
396mod tests {
397 use super::*;
398 use std::fs::File;
399 use std::io::Write;
400 use tempfile::tempdir;
401
    /// Writes a `test.proto` covering every scalar field kind the processor
    /// supports and returns (temp-dir guard, proto directory path, message type name).
    /// The caller must keep the `TempDir` alive while the files are in use.
    fn create_test_proto_file() -> (tempfile::TempDir, String, String) {
        let temp_dir = tempdir().unwrap();
        let proto_dir = temp_dir.path().to_str().unwrap().to_string();

        let proto_file_path = temp_dir.path().join("test.proto");
        let mut file = File::create(&proto_file_path).unwrap();

        let proto_content = r#"syntax = "proto3";

package test;

message TestMessage {
    bool bool_field = 1;
    int32 int32_field = 2;
    int64 int64_field = 3;
    uint32 uint32_field = 4;
    uint64 uint64_field = 5;
    float float_field = 6;
    double double_field = 7;
    string string_field = 8;
    bytes bytes_field = 9;
    enum TestEnum {
        UNKNOWN = 0;
        VALUE1 = 1;
        VALUE2 = 2;
    }
    TestEnum enum_field = 10;
}
"#;

        file.write_all(proto_content.as_bytes()).unwrap();

        (temp_dir, proto_dir, "test.TestMessage".to_string())
    }
440
    /// Encodes a `TestMessage` with one known value per supported field kind.
    fn create_test_protobuf_message(descriptor: &MessageDescriptor) -> Vec<u8> {
        let mut message = DynamicMessage::new(descriptor.clone());

        message.set_field_by_name("bool_field", Value::Bool(true));
        message.set_field_by_name("int32_field", Value::I32(42));
        message.set_field_by_name("int64_field", Value::I64(1234567890));
        message.set_field_by_name("uint32_field", Value::U32(42));
        message.set_field_by_name("uint64_field", Value::U64(1234567890));
        message.set_field_by_name("float_field", Value::F32(3.14));
        message.set_field_by_name("double_field", Value::F64(2.71828));
        message.set_field_by_name("string_field", Value::String("test string".to_string()));
        message.set_field_by_name("bytes_field", Value::Bytes(vec![1, 2, 3, 4].into()));
        message.set_field_by_name("enum_field", Value::EnumNumber(1));

        // Serialize to the Protobuf wire format.
        let mut buf = Vec::new();
        message.encode(&mut buf).unwrap();
        buf
    }
462
    /// Builds a one-row Arrow batch mirroring the values produced by
    /// `create_test_protobuf_message`, in the same field order as `TestMessage`.
    fn create_test_arrow_batch() -> RecordBatch {
        let fields = vec![
            Field::new("bool_field", DataType::Boolean, false),
            Field::new("int32_field", DataType::Int32, false),
            Field::new("int64_field", DataType::Int64, false),
            Field::new("uint32_field", DataType::UInt32, false),
            Field::new("uint64_field", DataType::UInt64, false),
            Field::new("float_field", DataType::Float32, false),
            Field::new("double_field", DataType::Float64, false),
            Field::new("string_field", DataType::Utf8, false),
            Field::new("bytes_field", DataType::Binary, false),
            Field::new("enum_field", DataType::Int32, false),
        ];

        let columns: Vec<ArrayRef> = vec![
            Arc::new(BooleanArray::from(vec![true])),
            Arc::new(Int32Array::from(vec![42])),
            Arc::new(Int64Array::from(vec![1234567890])),
            Arc::new(UInt32Array::from(vec![42])),
            Arc::new(UInt64Array::from(vec![1234567890])),
            Arc::new(Float32Array::from(vec![3.14])),
            Arc::new(Float64Array::from(vec![2.71828])),
            Arc::new(StringArray::from(vec!["test string"])),
            Arc::new(BinaryArray::from(vec![&[1, 2, 3, 4][..]])),
            Arc::new(Int32Array::from(vec![1])),
        ];

        let schema = Arc::new(Schema::new(fields));
        RecordBatch::try_new(schema, columns).unwrap()
    }
497
    /// The processor should build successfully from a valid proto directory.
    #[tokio::test]
    async fn test_protobuf_processor_creation() {
        let (temp_dir, proto_dir, message_type) = create_test_proto_file();

        let config = ProtobufProcessorConfig {
            proto_inputs: vec![proto_dir],
            proto_includes: None,
            message_type,
        };

        let processor = ProtobufProcessor::new(config);
        assert!(
            processor.is_ok(),
            "Failed to create ProtobufProcessor: {:?}",
            processor.err()
        );

        // Keep the temp dir alive until the processor has been built.
        drop(temp_dir);
    }
521
    /// A decoded test message should yield a one-row, ten-column Arrow batch
    /// with values matching what `create_test_protobuf_message` encoded.
    #[tokio::test]
    async fn test_protobuf_to_arrow_conversion() {
        let (temp_dir, proto_dir, message_type) = create_test_proto_file();

        let config = ProtobufProcessorConfig {
            proto_inputs: vec![proto_dir],
            proto_includes: None,
            message_type,
        };

        let processor = ProtobufProcessor::new(config).unwrap();

        let proto_data = create_test_protobuf_message(&processor.descriptor);

        let arrow_batch = processor.protobuf_to_arrow(&proto_data);
        assert!(
            arrow_batch.is_ok(),
            "Failed to convert Protobuf to Arrow: {:?}",
            arrow_batch.err()
        );

        let batch = arrow_batch.unwrap();

        assert_eq!(batch.num_rows(), 1, "Expected 1 row in the Arrow batch");
        assert_eq!(
            batch.num_columns(),
            10,
            "Expected 10 columns in the Arrow batch"
        );

        // Spot-check a few columns; column order follows the descriptor's field order.
        let bool_array = batch
            .column(0)
            .as_any()
            .downcast_ref::<BooleanArray>()
            .unwrap();
        assert_eq!(bool_array.value(0), true);

        let int32_array = batch
            .column(1)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        assert_eq!(int32_array.value(0), 42);

        let string_array = batch
            .column(7)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(string_array.value(0), "test string");

        drop(temp_dir);
    }
583
    /// Round-trip: Arrow -> Protobuf -> Arrow should preserve the values.
    #[tokio::test]
    async fn test_arrow_to_protobuf_conversion() {
        let (temp_dir, proto_dir, message_type) = create_test_proto_file();

        let config = ProtobufProcessorConfig {
            proto_inputs: vec![proto_dir],
            proto_includes: None,
            message_type,
        };

        let processor = ProtobufProcessor::new(config).unwrap();

        let arrow_batch = create_test_arrow_batch();

        let proto_data = processor.arrow_to_protobuf(&arrow_batch);
        assert!(
            proto_data.is_ok(),
            "Failed to convert Arrow to Protobuf: {:?}",
            proto_data.err()
        );

        // Decode the produced bytes back into Arrow to verify the round trip.
        let proto_bytes = proto_data.unwrap();
        let arrow_batch_2 = processor.protobuf_to_arrow(&proto_bytes);
        assert!(
            arrow_batch_2.is_ok(),
            "Failed to convert back to Arrow: {:?}",
            arrow_batch_2.err()
        );

        let batch = arrow_batch_2.unwrap();

        let bool_array = batch
            .column(0)
            .as_any()
            .downcast_ref::<BooleanArray>()
            .unwrap();
        assert_eq!(bool_array.value(0), true);

        let int32_array = batch
            .column(1)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        assert_eq!(int32_array.value(0), 42);

        let string_array = batch
            .column(7)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(string_array.value(0), "test string");

        drop(temp_dir);
    }
646
    /// An empty binary batch should produce an empty result, not an error.
    #[tokio::test]
    async fn test_process_empty_batch() {
        let (temp_dir, proto_dir, message_type) = create_test_proto_file();

        let config = ProtobufProcessorConfig {
            proto_inputs: vec![proto_dir],
            proto_includes: None,
            message_type,
        };

        let processor = ProtobufProcessor::new(config).unwrap();

        let empty_batch = MessageBatch::new_binary(vec![]);
        let result = processor.process(empty_batch).await;

        assert!(
            result.is_ok(),
            "Failed to process empty batch: {:?}",
            result.err()
        );
        assert!(
            result.unwrap().is_empty(),
            "Expected empty result for empty batch"
        );

        drop(temp_dir);
    }
679
    /// `process` on binary content should yield one Arrow batch with one row
    /// and all ten columns of the test message.
    #[tokio::test]
    async fn test_process_binary_to_arrow() {
        let (temp_dir, proto_dir, message_type) = create_test_proto_file();

        let config = ProtobufProcessorConfig {
            proto_inputs: vec![proto_dir],
            proto_includes: None,
            message_type,
        };

        let processor = ProtobufProcessor::new(config).unwrap();

        let proto_data = create_test_protobuf_message(&processor.descriptor);

        let msg_batch = MessageBatch::new_binary(vec![proto_data]);

        let result = processor.process(msg_batch).await;
        assert!(
            result.is_ok(),
            "Failed to process binary to arrow: {:?}",
            result.err()
        );

        let batches = result.unwrap();
        assert_eq!(batches.len(), 1, "Expected 1 message batch");

        match &batches[0].content {
            Content::Arrow(batch) => {
                assert_eq!(batch.num_rows(), 1, "Expected 1 row");
                assert_eq!(batch.num_columns(), 10, "Expected 10 columns");
            }
            _ => panic!("Expected Arrow content"),
        }

        drop(temp_dir);
    }
724
    /// `process` on Arrow content should yield one binary batch containing a
    /// single encoded Protobuf payload.
    #[tokio::test]
    async fn test_process_arrow_to_binary() {
        let (temp_dir, proto_dir, message_type) = create_test_proto_file();

        let config = ProtobufProcessorConfig {
            proto_inputs: vec![proto_dir],
            proto_includes: None,
            message_type,
        };

        let processor = ProtobufProcessor::new(config).unwrap();

        let arrow_batch = create_test_arrow_batch();

        let msg_batch = MessageBatch::new_arrow(arrow_batch);

        let result = processor.process(msg_batch).await;
        assert!(
            result.is_ok(),
            "Failed to process arrow to binary: {:?}",
            result.err()
        );

        let batches = result.unwrap();
        assert_eq!(batches.len(), 1, "Expected 1 message batch");

        match &batches[0].content {
            Content::Binary(data) => {
                assert_eq!(data.len(), 1, "Expected 1 binary message");
            }
            _ => panic!("Expected Binary content"),
        }

        drop(temp_dir);
    }
768}