datafusion-distributed 2.0.0

Framework for enhancing Apache DataFusion with distributed capabilities
Documentation
use datafusion::common::{DataFusionError, internal_datafusion_err};
use datafusion::config::ConfigExtension;
use datafusion::prelude::SessionConfig;
use http::{HeaderMap, HeaderName};
use std::error::Error;
use std::str::FromStr;
use std::sync::Arc;

pub(crate) const FLIGHT_METADATA_CONFIG_PREFIX: &str = "x-datafusion-distributed-config-";

pub(crate) fn set_distributed_option_extension<T: ConfigExtension + Default>(
    cfg: &mut SessionConfig,
    t: T,
) {
    cfg.options_mut().extensions.insert(t);
    let mut propagation_ctx = match cfg.get_extension::<ConfigExtensionPropagationContext>() {
        None => ConfigExtensionPropagationContext::default(),
        Some(prev) => prev.as_ref().clone(),
    };
    propagation_ctx.prefixes.push(T::PREFIX);
    cfg.set_extension(Arc::new(propagation_ctx));
}

pub(crate) fn set_distributed_option_extension_from_headers<'a, T: ConfigExtension + Default>(
    cfg: &'a mut SessionConfig,
    headers: &HeaderMap,
) -> Result<&'a T, DataFusionError> {
    enum MutOrOwned<'a, T> {
        Mut(&'a mut T),
        Owned(T),
    }

    impl<'a, T> MutOrOwned<'a, T> {
        fn as_mut(&mut self) -> &mut T {
            match self {
                MutOrOwned::Mut(v) => v,
                MutOrOwned::Owned(v) => v,
            }
        }
    }

    let mut propagation_ctx = match cfg.get_extension::<ConfigExtensionPropagationContext>() {
        None => ConfigExtensionPropagationContext::default(),
        Some(prev) => prev.as_ref().clone(),
    };
    propagation_ctx.prefixes.push(T::PREFIX);
    cfg.set_extension(Arc::new(propagation_ctx));

    // If the config extension existed before, we want to modify instead of adding a new one from
    // scratch. If not, we'll start from scratch with a new one.
    let mut result = match cfg.options_mut().extensions.get_mut::<T>() {
        Some(v) => MutOrOwned::Mut(v),
        None => MutOrOwned::Owned(T::default()),
    };

    for (k, v) in headers.iter() {
        let key = k.as_str().trim_start_matches(FLIGHT_METADATA_CONFIG_PREFIX);
        let prefix = format!("{}.", T::PREFIX);
        if key.starts_with(&prefix) {
            result.as_mut().set(
                key.trim_start_matches(&prefix),
                v.to_str()
                    .map_err(|err| internal_datafusion_err!("Cannot parse header value: {err}"))?,
            )?;
        }
    }

    // Only insert the extension if it is not already there. If this is otherwise MutOrOwned::Mut it
    // means that the extension was already there, and we already modified it.
    if let MutOrOwned::Owned(v) = result {
        cfg.options_mut().extensions.insert(v);
    }
    cfg.options().extensions.get().ok_or_else(|| {
        internal_datafusion_err!("ProgrammingError: a config option extension was just inserted, but it was not immediately retrievable")
    })
}

#[derive(Clone, Debug, Default)]
struct ConfigExtensionPropagationContext {
    prefixes: Vec<&'static str>,
}

pub(crate) fn get_config_extension_propagation_headers(
    cfg: &SessionConfig,
) -> Result<HeaderMap, DataFusionError> {
    fn parse_err(err: impl Error) -> DataFusionError {
        DataFusionError::Internal(format!("Failed to add config extension: {err}"))
    }
    let prefixes_to_send = cfg
        .get_extension::<ConfigExtensionPropagationContext>()
        .unwrap_or_default();

    if prefixes_to_send.prefixes.is_empty() {
        return Ok(HeaderMap::new());
    }

    let mut headers = HeaderMap::new();

    for (prefix, extension) in cfg.options().extensions.iter() {
        if !prefixes_to_send.prefixes.contains(&prefix) {
            continue;
        }
        for entry in extension.entries() {
            let Some(value) = entry.value else {
                continue;
            };

            let key = entry.key;
            headers.insert(
                HeaderName::from_str(&format!("{FLIGHT_METADATA_CONFIG_PREFIX}{prefix}.{key}"))
                    .map_err(parse_err)?,
                value.parse().map_err(parse_err)?,
            );
        }
    }
    Ok(headers)
}

#[cfg(test)]
mod tests {
    use crate::config_extension_ext::{
        ConfigExtensionPropagationContext, get_config_extension_propagation_headers,
        set_distributed_option_extension, set_distributed_option_extension_from_headers,
    };
    use datafusion::common::extensions_options;
    use datafusion::config::ConfigExtension;
    use datafusion::prelude::SessionConfig;
    use http::{HeaderMap, HeaderName, HeaderValue};
    use std::str::FromStr;

    #[test]
    fn test_propagation() -> Result<(), Box<dyn std::error::Error>> {
        let mut config = SessionConfig::new();

        let opt = CustomExtension {
            foo: "".to_string(),
            bar: 0,
            baz: false,
        };

        set_distributed_option_extension(&mut config, opt);
        let headers = get_config_extension_propagation_headers(&config)?;
        let mut new_config = SessionConfig::new();
        set_distributed_option_extension_from_headers::<CustomExtension>(
            &mut new_config,
            &headers,
        )?;

        let opt = get_ext::<CustomExtension>(&config);
        let new_opt = get_ext::<CustomExtension>(&new_config);

        assert_eq!(new_opt.foo, opt.foo);
        assert_eq!(new_opt.bar, opt.bar);
        assert_eq!(new_opt.baz, opt.baz);

        Ok(())
    }

    #[test]
    fn test_add_extension_with_empty_values() -> Result<(), Box<dyn std::error::Error>> {
        let mut config = SessionConfig::new();
        let opt = CustomExtension::default();

        set_distributed_option_extension(&mut config, opt);

        let flight_metadata = config.get_extension::<ConfigExtensionPropagationContext>();
        assert!(flight_metadata.is_some());

        let headers = get_config_extension_propagation_headers(&config)?;
        assert!(headers.contains_key("x-datafusion-distributed-config-custom.foo"));
        assert!(headers.contains_key("x-datafusion-distributed-config-custom.bar"));
        assert!(headers.contains_key("x-datafusion-distributed-config-custom.baz"));

        let get = |key: &str| headers.get(key).unwrap().to_str().unwrap();
        assert_eq!(get("x-datafusion-distributed-config-custom.foo"), "");
        assert_eq!(get("x-datafusion-distributed-config-custom.bar"), "0");
        assert_eq!(get("x-datafusion-distributed-config-custom.baz"), "false");

        Ok(())
    }

    #[test]
    fn test_new_extension_overwrites_previous() -> Result<(), Box<dyn std::error::Error>> {
        let mut config = SessionConfig::new();

        let opt1 = CustomExtension {
            foo: "first".to_string(),
            ..Default::default()
        };
        set_distributed_option_extension(&mut config, opt1);

        let opt2 = CustomExtension {
            bar: 42,
            ..Default::default()
        };
        set_distributed_option_extension(&mut config, opt2);

        let headers = get_config_extension_propagation_headers(&config)?;

        let get = |key: &str| headers.get(key).unwrap().to_str().unwrap();
        assert_eq!(get("x-datafusion-distributed-config-custom.foo"), "");
        assert_eq!(get("x-datafusion-distributed-config-custom.bar"), "42");
        assert_eq!(get("x-datafusion-distributed-config-custom.baz"), "false");

        Ok(())
    }

    #[test]
    fn test_propagate_no_metadata() -> Result<(), Box<dyn std::error::Error>> {
        let mut config = SessionConfig::new();

        set_distributed_option_extension_from_headers::<CustomExtension>(
            &mut config,
            &Default::default(),
        )?;

        let extension = config
            .options()
            .extensions
            .get::<CustomExtension>()
            .unwrap();
        let default = CustomExtension::default();
        assert_eq!(extension.foo, default.foo);
        assert_eq!(extension.bar, default.bar);
        assert_eq!(extension.baz, default.baz);

        Ok(())
    }

    #[test]
    fn test_propagate_no_matching_prefix() -> Result<(), Box<dyn std::error::Error>> {
        let mut config = SessionConfig::new();
        let mut header_map = HeaderMap::new();
        header_map.insert(
            HeaderName::from_str("x-datafusion-distributed-config-other.setting")?,
            HeaderValue::from_str("value")?,
        );

        set_distributed_option_extension_from_headers::<CustomExtension>(&mut config, &header_map)?;

        let extension = config
            .options()
            .extensions
            .get::<CustomExtension>()
            .unwrap();
        let default = CustomExtension::default();
        assert_eq!(extension.foo, default.foo);
        assert_eq!(extension.bar, default.bar);
        assert_eq!(extension.baz, default.baz);

        Ok(())
    }

    #[test]
    fn test_multiple_extensions_different_prefixes() -> Result<(), Box<dyn std::error::Error>> {
        let mut config = SessionConfig::new();

        let custom_opt = CustomExtension {
            foo: "custom_value".to_string(),
            bar: 123,
            ..Default::default()
        };

        let another_opt = AnotherExtension {
            setting1: "other".to_string(),
            setting2: 456,
            ..Default::default()
        };

        set_distributed_option_extension(&mut config, custom_opt);
        set_distributed_option_extension(&mut config, another_opt);

        let headers = get_config_extension_propagation_headers(&config)?;

        assert!(headers.contains_key("x-datafusion-distributed-config-custom.foo"));
        assert!(headers.contains_key("x-datafusion-distributed-config-custom.bar"));
        assert!(headers.contains_key("x-datafusion-distributed-config-another.setting1"));
        assert!(headers.contains_key("x-datafusion-distributed-config-another.setting2"));

        let get = |key: &str| headers.get(key).unwrap().to_str().unwrap();

        assert_eq!(
            get("x-datafusion-distributed-config-custom.foo"),
            "custom_value"
        );
        assert_eq!(get("x-datafusion-distributed-config-custom.bar"), "123");
        assert_eq!(
            get("x-datafusion-distributed-config-another.setting1"),
            "other"
        );
        assert_eq!(
            get("x-datafusion-distributed-config-another.setting2"),
            "456"
        );

        let mut new_config = SessionConfig::new();
        set_distributed_option_extension_from_headers::<CustomExtension>(
            &mut new_config,
            &headers,
        )?;
        set_distributed_option_extension_from_headers::<AnotherExtension>(
            &mut new_config,
            &headers,
        )?;

        let propagated_custom = get_ext::<CustomExtension>(&new_config);
        let propagated_another = get_ext::<AnotherExtension>(&new_config);

        assert_eq!(propagated_custom.foo, "custom_value");
        assert_eq!(propagated_custom.bar, 123);
        assert_eq!(propagated_another.setting1, "other");
        assert_eq!(propagated_another.setting2, 456);

        Ok(())
    }

    #[test]
    fn test_invalid_header_name() {
        let mut config = SessionConfig::new();
        let extension = InvalidExtension::default();

        set_distributed_option_extension(&mut config, extension);
        let result = get_config_extension_propagation_headers(&config);
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_header_value() {
        let mut config = SessionConfig::new();
        let extension = InvalidValueExtension::default();

        set_distributed_option_extension(&mut config, extension);
        let result = get_config_extension_propagation_headers(&config);
        assert!(result.is_err());
    }

    extensions_options! {
        pub struct CustomExtension {
            pub foo: String, default = "".to_string()
            pub bar: usize, default = 0
            pub baz: bool, default = false
        }
    }

    impl ConfigExtension for CustomExtension {
        const PREFIX: &'static str = "custom";
    }

    extensions_options! {
        pub struct AnotherExtension {
            pub setting1: String, default = "default1".to_string()
            pub setting2: usize, default = 42
        }
    }

    impl ConfigExtension for AnotherExtension {
        const PREFIX: &'static str = "another";
    }

    extensions_options! {
        pub struct InvalidExtension {
            pub key_with_spaces: String, default = "value".to_string()
        }
    }

    impl ConfigExtension for InvalidExtension {
        const PREFIX: &'static str = "invalid key with spaces";
    }

    extensions_options! {
        pub struct InvalidValueExtension {
            pub key: String, default = "\u{0000}invalid\u{0001}".to_string()
        }
    }

    impl ConfigExtension for InvalidValueExtension {
        const PREFIX: &'static str = "invalid_value";
    }

    fn get_ext<T: ConfigExtension>(cfg: &SessionConfig) -> &T {
        cfg.options().extensions.get::<T>().unwrap()
    }
}