datafusion-table-providers 0.11.1

Extend the capabilities of DataFusion to support additional data sources via implementations of the `TableProvider` trait.
use datafusion::arrow::datatypes::SchemaRef;
use itertools::Itertools;
use sea_query::{self, Alias};
use snafu::prelude::*;
use std::fmt::Display;

use super::column_reference::{self, ColumnReference};

/// Errors returned when parsing an [`OnConflict`] from its string form.
#[derive(Debug, Snafu)]
pub enum Error {
    /// The part after the `:` separator could not be parsed as a column reference.
    #[snafu(display("Invalid column reference: {source}"))]
    InvalidColumnReference { source: column_reference::Error },

    /// The part before the `:` separator was neither `do_nothing` nor `upsert`.
    #[snafu(display("Expected do_nothing or upsert, found: {token}"))]
    UnexpectedToken { token: String },

    /// The input did not split into exactly two parts on `:`.
    // NOTE(review): the message says "semicolon" although the parser splits on a colon;
    // the wording is asserted by this module's tests, so it is left as-is here.
    #[snafu(display("Expected semicolon in: {token}"))]
    ExpectedSemicolon { token: String },
}

/// Behavior to apply when an insert conflicts with existing data.
///
/// The string form (see the `Display` and `TryFrom<&str>` impls) is
/// `do_nothing_all`, `do_nothing:<columns>` or `upsert:<columns>`.
#[derive(Debug, Clone, PartialEq)]
pub enum OnConflict {
    /// Emit `ON CONFLICT DO NOTHING` without a conflict target.
    DoNothingAll,
    /// Emit `DO NOTHING` for conflicts on the referenced columns.
    DoNothing(ColumnReference),
    /// On a conflict on the referenced columns, update every other schema column
    /// from the incoming (`EXCLUDED`) row.
    Upsert(ColumnReference),
}

impl OnConflict {
    #[must_use]
    pub fn build_on_conflict_statement(&self, schema: &SchemaRef) -> String {
        match self {
            OnConflict::DoNothingAll => "ON CONFLICT DO NOTHING".to_string(),
            OnConflict::DoNothing(column) => {
                format!(
                    r#"ON CONFLICT ("{}") DO NOTHING"#,
                    column.iter().join(r#"", ""#)
                )
            }
            OnConflict::Upsert(column) => {
                let non_constraint_columns = schema
                    .fields()
                    .iter()
                    .filter(|f| !column.contains(f.name()))
                    .map(|f| f.name().to_string())
                    .collect::<Vec<String>>();
                let mut update_cols = String::new();
                for (i, col) in non_constraint_columns.iter().enumerate() {
                    update_cols.push_str(&format!(r#""{col}" = EXCLUDED."{col}""#));
                    if i < non_constraint_columns.len() - 1 {
                        update_cols.push_str(", ");
                    }
                }
                // This means that all columns are constraint columns, so we should do nothing
                if update_cols.is_empty() {
                    return format!(
                        r#"ON CONFLICT ("{}") DO NOTHING"#,
                        column.iter().join(r#"", ""#)
                    );
                }
                format!(
                    r#"ON CONFLICT ("{}") DO UPDATE SET {update_cols}"#,
                    column.iter().join(r#"", ""#)
                )
            }
        }
    }

    #[must_use]
    pub fn build_sea_query_on_conflict(&self, schema: &SchemaRef) -> sea_query::OnConflict {
        match self {
            OnConflict::DoNothingAll => {
                let mut on_conflict = sea_query::OnConflict::new();
                on_conflict.do_nothing();
                on_conflict
            }
            OnConflict::DoNothing(column) => {
                let mut on_conflict = sea_query::OnConflict::columns::<Vec<Alias>, Alias>(
                    column.iter().map(Alias::new).collect(),
                );
                on_conflict.do_nothing();
                on_conflict
            }
            OnConflict::Upsert(column) => {
                let mut on_conflict = sea_query::OnConflict::columns::<Vec<Alias>, Alias>(
                    column.iter().map(Alias::new).collect(),
                );

                let non_constraint_columns = schema
                    .fields()
                    .iter()
                    .filter(|f| !column.contains(f.name()))
                    .map(|f| Alias::new(f.name()))
                    .collect::<Vec<Alias>>();

                on_conflict.update_columns(non_constraint_columns);

                on_conflict
            }
        }
    }
}

impl Display for OnConflict {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            OnConflict::DoNothingAll => write!(f, "do_nothing_all"),
            OnConflict::DoNothing(column) => write!(f, "do_nothing:{column}"),
            OnConflict::Upsert(column) => write!(f, "upsert:{column}"),
        }
    }
}

impl TryFrom<&str> for OnConflict {
    type Error = Error;

    /// Parses `do_nothing_all`, `do_nothing:<columns>` or `upsert:<columns>`.
    ///
    /// # Errors
    ///
    /// * `ExpectedSemicolon` when the input does not split into exactly two parts on `:`.
    /// * `InvalidColumnReference` when the column part cannot be parsed.
    /// * `UnexpectedToken` when the behavior part is not `do_nothing` or `upsert`.
    fn try_from(value: &str) -> std::result::Result<Self, Self::Error> {
        if value == "do_nothing_all" {
            return Ok(Self::DoNothingAll);
        }

        // Exactly two `:`-separated pieces are accepted: `<behavior>:<columns>`.
        let mut pieces = value.split(':');
        let (behavior, columns) = match (pieces.next(), pieces.next(), pieces.next()) {
            (Some(behavior), Some(columns), None) => (behavior, columns),
            _ => {
                return ExpectedSemicolonSnafu {
                    token: value.to_string(),
                }
                .fail();
            }
        };

        // The column reference is validated before the behavior keyword, so a bad
        // column list is reported even when the keyword is also unknown.
        let column_ref =
            ColumnReference::try_from(columns).context(InvalidColumnReferenceSnafu)?;

        match behavior {
            "do_nothing" => Ok(Self::DoNothing(column_ref)),
            "upsert" => Ok(Self::Upsert(column_ref)),
            other => UnexpectedTokenSnafu {
                token: other.to_string(),
            }
            .fail(),
        }
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datafusion::arrow::datatypes::{DataType, Field, Schema};

    use crate::util::{column_reference::ColumnReference, on_conflict::OnConflict};

    /// Parsing accepts the three documented forms and rejects input without a separator.
    #[test]
    fn test_on_conflict_from_str() {
        let on_conflict = OnConflict::try_from("do_nothing_all").expect("valid on conflict");
        assert_eq!(on_conflict, OnConflict::DoNothingAll);

        let on_conflict = OnConflict::try_from("do_nothing:col1").expect("valid on conflict");
        assert_eq!(
            on_conflict,
            OnConflict::DoNothing(ColumnReference::new(vec!["col1".to_string()]))
        );

        let on_conflict = OnConflict::try_from("upsert:col2").expect("valid on conflict");
        assert_eq!(
            on_conflict,
            OnConflict::Upsert(ColumnReference::new(vec!["col2".to_string()]))
        );

        let err = OnConflict::try_from("do_nothing").expect_err("invalid on conflict");
        assert_eq!(
            err.to_string(),
            "Expected semicolon in: do_nothing".to_string()
        );
    }

    /// Unknown behaviors and inputs with more than one `:` separator are rejected.
    #[test]
    fn test_on_conflict_from_str_errors() {
        let err = OnConflict::try_from("replace:col1").expect_err("unknown behavior");
        assert_eq!(
            err.to_string(),
            "Expected do_nothing or upsert, found: replace".to_string()
        );

        let err = OnConflict::try_from("upsert:col1:col2").expect_err("too many separators");
        assert_eq!(
            err.to_string(),
            "Expected semicolon in: upsert:col1:col2".to_string()
        );
    }

    /// `Display` output parses back into an equal value.
    #[test]
    fn test_roundtrip() {
        let on_conflict = OnConflict::DoNothingAll.to_string();
        assert_eq!(
            OnConflict::try_from(on_conflict.as_str()).expect("valid on conflict"),
            OnConflict::DoNothingAll
        );

        let on_conflict =
            OnConflict::DoNothing(ColumnReference::new(vec!["col1".to_string()])).to_string();
        assert_eq!(
            OnConflict::try_from(on_conflict.as_str()).expect("valid on conflict"),
            OnConflict::DoNothing(ColumnReference::new(vec!["col1".to_string()]))
        );

        let on_conflict =
            OnConflict::Upsert(ColumnReference::new(vec!["col2".to_string()])).to_string();
        assert_eq!(
            OnConflict::try_from(on_conflict.as_str()).expect("valid on conflict"),
            OnConflict::Upsert(ColumnReference::new(vec!["col2".to_string()]))
        );
    }

    /// Each variant renders the expected raw SQL clause.
    #[test]
    fn test_build_on_conflict_statement() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("col1", DataType::Int64, false),
            Field::new("col2", DataType::Int64, false),
        ]));
        let on_conflict = OnConflict::DoNothingAll;
        assert_eq!(
            on_conflict.build_on_conflict_statement(&schema),
            "ON CONFLICT DO NOTHING".to_string()
        );

        let on_conflict = OnConflict::DoNothing(ColumnReference::new(vec!["col1".to_string()]));
        assert_eq!(
            on_conflict.build_on_conflict_statement(&schema),
            r#"ON CONFLICT ("col1") DO NOTHING"#.to_string()
        );

        let on_conflict = OnConflict::Upsert(ColumnReference::new(vec!["col2".to_string()]));
        assert_eq!(
            on_conflict.build_on_conflict_statement(&schema),
            r#"ON CONFLICT ("col2") DO UPDATE SET "col1" = EXCLUDED."col1""#.to_string()
        );
    }

    /// An upsert whose conflict columns cover the whole schema has nothing to update
    /// and must fall back to `DO NOTHING`.
    #[test]
    fn test_upsert_without_updatable_columns_does_nothing() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "col1",
            DataType::Int64,
            false,
        )]));

        let on_conflict = OnConflict::Upsert(ColumnReference::new(vec!["col1".to_string()]));
        assert_eq!(
            on_conflict.build_on_conflict_statement(&schema),
            r#"ON CONFLICT ("col1") DO NOTHING"#.to_string()
        );
    }
}