datafusion-distributed 2.0.0

Framework for enhancing Apache DataFusion with distributed capabilities
Documentation
use datafusion::arrow::error::ArrowError;

use crate::protobuf::errors::io_error::IoErrorProto;

/// Protobuf envelope for an [`ArrowError`] together with an optional context string.
///
/// Build one with [`ArrowErrorProto::from_arrow_error`] and turn it back into an
/// [`ArrowError`] with [`ArrowErrorProto::to_arrow_error`].
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ArrowErrorProto {
    /// Optional context string carried next to the error; absent when `None`.
    #[prost(string, optional, tag = "1")]
    pub ctx: Option<String>,
    /// The encoded error variant. `None` means no known variant was present in the
    /// message; `to_arrow_error` reports that case as a malformed message.
    //
    // The `tags` list must name every tag used by a variant of `ArrowErrorInnerProto`,
    // including 21 (`AvroError`). prost dispatches decoding on this list, so a tag left
    // out of it is skipped as an unknown field and `inner` would decode as `None`.
    #[prost(
        oneof = "ArrowErrorInnerProto",
        tags = "2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21"
    )]
    pub inner: Option<ArrowErrorInnerProto>,
}

/// Oneof payload stored in `ArrowErrorProto::inner`, with one variant per `ArrowError`
/// variant handled by `ArrowErrorProto::from_arrow_error`.
///
/// Every tag used here must also appear in the `tags` list on `ArrowErrorProto::inner`.
///
/// * `String` variants carry the error's message text.
/// * `bool` variants stand in for unit variants of `ArrowError`: the encoder always
///   writes `true` and the decoder ignores the value.
/// * `IoError` carries a nested `IoErrorProto`.
/// * `OffsetOverflowError` carries the offset widened to `u64`.
#[derive(Clone, PartialEq, prost::Oneof)]
pub enum ArrowErrorInnerProto {
    #[prost(string, tag = "2")]
    NotYetImplemented(String),
    /// Holds only the `Display` text of the original boxed error.
    #[prost(string, tag = "3")]
    ExternalError(String),
    #[prost(string, tag = "4")]
    CastError(String),
    #[prost(string, tag = "5")]
    MemoryError(String),
    #[prost(string, tag = "6")]
    ParseError(String),
    #[prost(string, tag = "7")]
    SchemaError(String),
    #[prost(string, tag = "8")]
    ComputeError(String),
    #[prost(bool, tag = "9")]
    DivideByZero(bool),
    #[prost(string, tag = "10")]
    ArithmeticOverflow(String),
    #[prost(string, tag = "11")]
    CsvError(String),
    #[prost(string, tag = "12")]
    JsonError(String),
    #[prost(message, tag = "13")]
    IoError(IoErrorProto),
    // NOTE(review): tags 14-17 are declared as `message` although their payload is a
    // `String`, unlike the other text variants which use `string`. Presumably this
    // relies on prost's `Message` impl for `String` — confirm. Switching them to
    // `string` would change how these fields are encoded, so treat it as a
    // wire-compatibility decision rather than a cosmetic cleanup.
    #[prost(message, tag = "14")]
    IpcError(String),
    #[prost(message, tag = "15")]
    InvalidArgumentError(String),
    #[prost(message, tag = "16")]
    ParquetError(String),
    #[prost(message, tag = "17")]
    CDataInterface(String),
    #[prost(bool, tag = "18")]
    DictionaryKeyOverflowError(bool),
    #[prost(bool, tag = "19")]
    RunEndIndexOverflowError(bool),
    #[prost(uint64, tag = "20")]
    OffsetOverflowError(u64),
    #[prost(string, tag = "21")]
    AvroError(String),
}

impl ArrowErrorProto {
    pub fn from_arrow_error(err: &ArrowError, ctx: Option<&String>) -> Self {
        match err {
            ArrowError::NotYetImplemented(msg) => ArrowErrorProto {
                inner: Some(ArrowErrorInnerProto::NotYetImplemented(msg.to_string())),
                ctx: ctx.cloned(),
            },
            ArrowError::ExternalError(msg) => ArrowErrorProto {
                inner: Some(ArrowErrorInnerProto::ExternalError(msg.to_string())),
                ctx: ctx.cloned(),
            },
            ArrowError::CastError(msg) => ArrowErrorProto {
                inner: Some(ArrowErrorInnerProto::CastError(msg.to_string())),
                ctx: ctx.cloned(),
            },
            ArrowError::MemoryError(msg) => ArrowErrorProto {
                inner: Some(ArrowErrorInnerProto::MemoryError(msg.to_string())),
                ctx: ctx.cloned(),
            },
            ArrowError::ParseError(msg) => ArrowErrorProto {
                inner: Some(ArrowErrorInnerProto::ParseError(msg.to_string())),
                ctx: ctx.cloned(),
            },
            ArrowError::SchemaError(msg) => ArrowErrorProto {
                inner: Some(ArrowErrorInnerProto::SchemaError(msg.to_string())),
                ctx: ctx.cloned(),
            },
            ArrowError::ComputeError(msg) => ArrowErrorProto {
                inner: Some(ArrowErrorInnerProto::ComputeError(msg.to_string())),
                ctx: ctx.cloned(),
            },
            ArrowError::DivideByZero => ArrowErrorProto {
                inner: Some(ArrowErrorInnerProto::DivideByZero(true)),
                ctx: ctx.cloned(),
            },
            ArrowError::ArithmeticOverflow(msg) => ArrowErrorProto {
                inner: Some(ArrowErrorInnerProto::ArithmeticOverflow(msg.to_string())),
                ctx: ctx.cloned(),
            },
            ArrowError::CsvError(msg) => ArrowErrorProto {
                inner: Some(ArrowErrorInnerProto::CsvError(msg.to_string())),
                ctx: ctx.cloned(),
            },
            ArrowError::JsonError(msg) => ArrowErrorProto {
                inner: Some(ArrowErrorInnerProto::JsonError(msg.to_string())),
                ctx: ctx.cloned(),
            },
            ArrowError::IoError(msg, err) => ArrowErrorProto {
                inner: Some(ArrowErrorInnerProto::IoError(IoErrorProto::from_io_error(
                    msg, err,
                ))),
                ctx: ctx.cloned(),
            },
            ArrowError::IpcError(msg) => ArrowErrorProto {
                inner: Some(ArrowErrorInnerProto::IpcError(msg.to_string())),
                ctx: ctx.cloned(),
            },
            ArrowError::InvalidArgumentError(msg) => ArrowErrorProto {
                inner: Some(ArrowErrorInnerProto::InvalidArgumentError(msg.to_string())),
                ctx: ctx.cloned(),
            },
            ArrowError::ParquetError(msg) => ArrowErrorProto {
                inner: Some(ArrowErrorInnerProto::ParquetError(msg.to_string())),
                ctx: ctx.cloned(),
            },
            ArrowError::CDataInterface(msg) => ArrowErrorProto {
                inner: Some(ArrowErrorInnerProto::CDataInterface(msg.to_string())),
                ctx: ctx.cloned(),
            },
            ArrowError::DictionaryKeyOverflowError => ArrowErrorProto {
                inner: Some(ArrowErrorInnerProto::DictionaryKeyOverflowError(true)),
                ctx: ctx.cloned(),
            },
            ArrowError::RunEndIndexOverflowError => ArrowErrorProto {
                inner: Some(ArrowErrorInnerProto::RunEndIndexOverflowError(true)),
                ctx: ctx.cloned(),
            },
            ArrowError::OffsetOverflowError(offset) => ArrowErrorProto {
                inner: Some(ArrowErrorInnerProto::OffsetOverflowError(*offset as u64)),
                ctx: ctx.cloned(),
            },
            ArrowError::AvroError(msg) => ArrowErrorProto {
                inner: Some(ArrowErrorInnerProto::AvroError(msg.to_string())),
                ctx: ctx.cloned(),
            },
        }
    }

    pub fn to_arrow_error(&self) -> (ArrowError, Option<String>) {
        let Some(ref inner) = self.inner else {
            return (
                ArrowError::ExternalError(Box::from("Malformed protobuf message".to_string())),
                None,
            );
        };
        let err = match inner {
            ArrowErrorInnerProto::NotYetImplemented(msg) => {
                ArrowError::NotYetImplemented(msg.to_string())
            }
            ArrowErrorInnerProto::ExternalError(msg) => {
                ArrowError::ExternalError(Box::from(msg.to_string()))
            }
            ArrowErrorInnerProto::CastError(msg) => ArrowError::CastError(msg.to_string()),
            ArrowErrorInnerProto::MemoryError(msg) => ArrowError::MemoryError(msg.to_string()),
            ArrowErrorInnerProto::ParseError(msg) => ArrowError::ParseError(msg.to_string()),
            ArrowErrorInnerProto::SchemaError(msg) => ArrowError::SchemaError(msg.to_string()),
            ArrowErrorInnerProto::ComputeError(msg) => ArrowError::ComputeError(msg.to_string()),
            ArrowErrorInnerProto::DivideByZero(_) => ArrowError::DivideByZero,
            ArrowErrorInnerProto::ArithmeticOverflow(msg) => {
                ArrowError::ArithmeticOverflow(msg.to_string())
            }
            ArrowErrorInnerProto::CsvError(msg) => ArrowError::CsvError(msg.to_string()),
            ArrowErrorInnerProto::JsonError(msg) => ArrowError::JsonError(msg.to_string()),
            ArrowErrorInnerProto::IoError(msg) => {
                let (msg, err) = msg.to_io_error();
                ArrowError::IoError(err, msg)
            }
            ArrowErrorInnerProto::IpcError(msg) => ArrowError::IpcError(msg.to_string()),
            ArrowErrorInnerProto::InvalidArgumentError(msg) => {
                ArrowError::InvalidArgumentError(msg.to_string())
            }
            ArrowErrorInnerProto::ParquetError(msg) => ArrowError::ParquetError(msg.to_string()),
            ArrowErrorInnerProto::CDataInterface(msg) => {
                ArrowError::CDataInterface(msg.to_string())
            }
            ArrowErrorInnerProto::DictionaryKeyOverflowError(_) => {
                ArrowError::DictionaryKeyOverflowError
            }
            ArrowErrorInnerProto::RunEndIndexOverflowError(_) => {
                ArrowError::RunEndIndexOverflowError
            }
            ArrowErrorInnerProto::OffsetOverflowError(offset) => {
                ArrowError::OffsetOverflowError(*offset as usize)
            }
            ArrowErrorInnerProto::AvroError(msg) => ArrowError::AvroError(msg.to_string()),
        };
        (err, self.ctx.clone())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use prost::Message;
    use std::io::{Error as IoError, ErrorKind};

    /// Encodes `proto` to bytes and decodes it again, so that assertions exercise the
    /// serialized form rather than the in-memory struct.
    fn via_wire(proto: ArrowErrorProto) -> ArrowErrorProto {
        let bytes = proto.encode_to_vec();
        ArrowErrorProto::decode(bytes.as_slice()).unwrap()
    }

    /// Every listed error must survive encode -> decode with the same `Display` text,
    /// both with and without a context string attached.
    #[test]
    fn test_arrow_error_roundtrip() {
        // NOTE(review): `ArrowError::AvroError` is not part of this list.
        let test_cases = vec![
            ArrowError::NotYetImplemented("test not implemented".to_string()),
            ArrowError::ExternalError(Box::new(std::io::Error::other("external error"))),
            ArrowError::CastError("cast error".to_string()),
            ArrowError::MemoryError("memory error".to_string()),
            ArrowError::ParseError("parse error".to_string()),
            ArrowError::SchemaError("schema error".to_string()),
            ArrowError::ComputeError("compute error".to_string()),
            ArrowError::DivideByZero,
            ArrowError::ArithmeticOverflow("overflow".to_string()),
            ArrowError::CsvError("csv error".to_string()),
            ArrowError::JsonError("json error".to_string()),
            ArrowError::IoError(
                "io message".to_string(),
                IoError::new(ErrorKind::NotFound, "file not found"),
            ),
            ArrowError::IpcError("ipc error".to_string()),
            ArrowError::InvalidArgumentError("invalid arg".to_string()),
            ArrowError::ParquetError("parquet error".to_string()),
            ArrowError::CDataInterface("cdata error".to_string()),
            ArrowError::DictionaryKeyOverflowError,
            ArrowError::RunEndIndexOverflowError,
            ArrowError::OffsetOverflowError(12345),
        ];
        let context = "test context".to_string();

        for original_error in &test_cases {
            let expected_text = original_error.to_string();

            // Round trip with a context attached.
            let with_ctx = ArrowErrorProto::from_arrow_error(original_error, Some(&context));
            let (recovered_error, recovered_ctx) = via_wire(with_ctx).to_arrow_error();
            let recovered_text = recovered_error.to_string();

            if expected_text != recovered_text {
                println!("original error: {original_error}");
                println!("recovered error: {recovered_error}");
            }

            assert_eq!(expected_text, recovered_text);
            assert_eq!(recovered_ctx, Some("test context".to_string()));

            // Round trip without a context.
            let without_ctx = ArrowErrorProto::from_arrow_error(original_error, None);
            let (recovered_error_no_ctx, recovered_ctx_no_ctx) =
                via_wire(without_ctx).to_arrow_error();

            assert_eq!(expected_text, recovered_error_no_ctx.to_string());
            assert_eq!(recovered_ctx_no_ctx, None);
        }
    }

    /// A message with no `inner` payload is reported as an `ExternalError`.
    #[test]
    fn test_malformed_protobuf_message() {
        let empty_proto = ArrowErrorProto {
            ctx: None,
            inner: None,
        };
        let (recovered_error, _) = empty_proto.to_arrow_error();
        assert!(matches!(recovered_error, ArrowError::ExternalError(_)));
    }
}