pg-upsert 0.1.1

PostgreSQL UPSERT operations using sqlx
Documentation
use crate::error::{Result, UpsertError};
use crate::types::{Field, UpsertOptions};

/// A generated upsert statement.
///
/// `sql` holds the statement text with PostgreSQL positional placeholders (`$1`, `$2`, ...);
/// the values to bind are not stored here and must be supplied by the caller.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UpsertQuery {
    /// The complete `INSERT ... ON CONFLICT ...` statement text.
    pub sql: String,
}

/// Renders `identifier` as a PostgreSQL delimited (double-quoted) identifier.
///
/// The name is wrapped in `"` and every `"` inside it is doubled. An embedded quote
/// therefore cannot terminate the identifier early, whatever other characters the name
/// contains.
fn quote_identifier(identifier: &str) -> String {
    // Two bytes for the surrounding quotes; doubled inner quotes may grow it further.
    let mut quoted = String::with_capacity(identifier.len() + 2);
    quoted.push('"');
    for ch in identifier.chars() {
        if ch == '"' {
            // Escape by doubling, per the SQL quoted-identifier rules.
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}

/// Builds a parameterized multi-row `INSERT ... ON CONFLICT` statement for PostgreSQL.
///
/// The table name, column names, conflict-target columns and version column are all passed
/// through `quote_identifier`. They are emitted as quoted identifiers rather than spliced in
/// raw.
///
/// Values are never embedded in the text. Each of the `row_count` rows contributes one tuple
/// of positional placeholders, numbered row-major from `$1`. With `k = fields.len()`, row 0
/// uses `$1..=$k`, row 1 uses `$k+1..=$2k`, and so on. Values must be bound in that order.
/// This function reads only each field's `name`; binding values is left to the caller.
///
/// The conflict action is:
/// - `DO NOTHING` when `options.do_nothing_on_conflict` is set.
/// - Otherwise, `DO UPDATE SET col = EXCLUDED.col` for every field that is not a conflict
///   target. When `options.version_field` is set, this is guarded by
///   `WHERE <table>.<version> < EXCLUDED.<version>`, so only a strictly greater incoming
///   version overwrites the stored row.
/// - `DO NOTHING` as a fallback when every field is also a conflict target. There is nothing
///   left to update, and an empty `SET` list is not valid SQL.
///
/// NOTE: the total number of placeholders (`row_count * fields.len()`) is not checked here.
/// PostgreSQL limits the number of bind parameters per statement (65535 in current
/// versions), so callers inserting large batches should split them.
///
/// # Errors
///
/// - `UpsertError::EmptyFields` if `fields` is empty.
/// - `UpsertError::EmptyConflictFields` if `conflict_fields` is empty.
/// - `UpsertError::EmptyValues` if `row_count` is zero.
/// - `UpsertError::VersionFieldNotFound` if `options.version_field` names a column that is
///   not present in `fields`.
pub fn build_upsert_sql(
    table: &str,
    fields: &[Field],
    conflict_fields: &[&str],
    row_count: usize,
    options: &UpsertOptions,
) -> Result<UpsertQuery> {
    use std::fmt::Write as _;

    if fields.is_empty() {
        return Err(UpsertError::EmptyFields);
    }

    if conflict_fields.is_empty() {
        return Err(UpsertError::EmptyConflictFields);
    }

    if row_count == 0 {
        return Err(UpsertError::EmptyValues);
    }

    // The version guard compares against EXCLUDED.<version>, so the column must be inserted.
    if let Some(ref version_field) = options.version_field {
        if !fields.iter().any(|f| f.name == *version_field) {
            return Err(UpsertError::VersionFieldNotFound(version_field.clone()));
        }
    }

    let table_ident = quote_identifier(table);

    let columns = fields
        .iter()
        .map(|f| quote_identifier(&f.name))
        .collect::<Vec<_>>()
        .join(", ");

    // "($1, $2, ...), ($k+1, ...)": one tuple per row, parameters numbered row-major.
    let mut values = String::new();
    let mut next_param: usize = 1;
    for row in 0..row_count {
        if row > 0 {
            values.push_str(", ");
        }
        values.push('(');
        for col in 0..fields.len() {
            if col > 0 {
                values.push_str(", ");
            }
            // Formatting into a `String` cannot fail.
            let _ = write!(values, "${}", next_param);
            next_param += 1;
        }
        values.push(')');
    }

    let conflict_cols = conflict_fields
        .iter()
        .map(|c| quote_identifier(c))
        .collect::<Vec<_>>()
        .join(", ");

    let action = if options.do_nothing_on_conflict {
        String::from("DO NOTHING")
    } else {
        // Conflict-target columns already match the existing row, so only the rest are set.
        let assignments = fields
            .iter()
            .map(|f| f.name.as_str())
            .filter(|name| !conflict_fields.contains(name))
            .map(|name| {
                let col = quote_identifier(name);
                format!("{} = EXCLUDED.{}", col, col)
            })
            .collect::<Vec<_>>()
            .join(", ");

        if assignments.is_empty() {
            // Every inserted column is part of the conflict target, so there is nothing to
            // overwrite. `DO UPDATE SET ` with an empty list would be a syntax error, and
            // skipping the conflicting row is the equivalent no-op.
            String::from("DO NOTHING")
        } else {
            // Version guard: the stored row (referenced by its table name) is only replaced
            // when the incoming version is strictly greater.
            let guard = match options.version_field {
                Some(ref version_field) => {
                    let version_col = quote_identifier(version_field);
                    format!(" WHERE {}.{} < EXCLUDED.{}", table_ident, version_col, version_col)
                }
                None => String::new(),
            };
            format!("DO UPDATE SET {}{}", assignments, guard)
        }
    };

    let sql = format!(
        "INSERT INTO {} ({}) VALUES {} ON CONFLICT ({}) {}",
        table_ident, columns, values, conflict_cols, action
    );

    Ok(UpsertQuery { sql })
}

// Unit tests for `build_upsert_sql`: each test pins the exact SQL text (or error variant)
// produced for a given set of fields, conflict targets, row count and options.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fields;

    // Single row, single conflict column: every non-conflict column is overwritten from
    // EXCLUDED, and all identifiers come out double-quoted.
    #[test]
    fn test_build_basic_upsert_sql() {
        let fields = fields![
            "id" => 1_i32,
            "name" => "Alice",
            "email" => "alice@example.com",
        ];

        let result =
            build_upsert_sql("users", &fields, &["id"], 1, &UpsertOptions::default()).unwrap();

        assert_eq!(
            result.sql,
            "INSERT INTO \"users\" (\"id\", \"name\", \"email\") VALUES ($1, $2, $3) ON CONFLICT (\"id\") DO UPDATE SET \"name\" = EXCLUDED.\"name\", \"email\" = EXCLUDED.\"email\""
        );
    }

    // `do_nothing_on_conflict` replaces the UPDATE clause with `DO NOTHING`.
    #[test]
    fn test_build_upsert_with_do_nothing() {
        let fields = fields![
            "id" => 1_i32,
            "name" => "Alice",
        ];

        let options = UpsertOptions {
            do_nothing_on_conflict: true,
            ..Default::default()
        };

        let result = build_upsert_sql("users", &fields, &["id"], 1, &options).unwrap();

        assert_eq!(
            result.sql,
            "INSERT INTO \"users\" (\"id\", \"name\") VALUES ($1, $2) ON CONFLICT (\"id\") DO NOTHING"
        );
    }

    // `version_field` appends a WHERE guard so the stored row is only updated when the
    // incoming version is strictly greater.
    #[test]
    fn test_build_upsert_with_version_field() {
        let fields = fields![
            "id" => 1_i32,
            "name" => "Alice",
            "version" => 5_i32,
        ];

        let options = UpsertOptions {
            version_field: Some("version".into()),
            ..Default::default()
        };

        let result = build_upsert_sql("users", &fields, &["id"], 1, &options).unwrap();

        assert_eq!(
            result.sql,
            "INSERT INTO \"users\" (\"id\", \"name\", \"version\") VALUES ($1, $2, $3) ON CONFLICT (\"id\") DO UPDATE SET \"name\" = EXCLUDED.\"name\", \"version\" = EXCLUDED.\"version\" WHERE \"users\".\"version\" < EXCLUDED.\"version\""
        );
    }

    // Composite conflict target: both key columns appear in ON CONFLICT and only the
    // remaining column is updated.
    #[test]
    fn test_build_upsert_multiple_conflict_fields() {
        let fields = fields![
            "warehouse_id" => 1_i32,
            "product_id" => 100_i32,
            "quantity" => 50_i32,
        ];

        let result = build_upsert_sql(
            "inventory",
            &fields,
            &["warehouse_id", "product_id"],
            1,
            &UpsertOptions::default(),
        )
        .unwrap();

        assert_eq!(
            result.sql,
            "INSERT INTO \"inventory\" (\"warehouse_id\", \"product_id\", \"quantity\") VALUES ($1, $2, $3) ON CONFLICT (\"warehouse_id\", \"product_id\") DO UPDATE SET \"quantity\" = EXCLUDED.\"quantity\""
        );
    }

    // Placeholders keep counting across rows: 3 rows x 2 columns -> $1..$6.
    #[test]
    fn test_build_upsert_multiple_rows() {
        let fields = fields![
            "id" => 1_i32,
            "name" => "Alice",
        ];

        let result =
            build_upsert_sql("users", &fields, &["id"], 3, &UpsertOptions::default()).unwrap();

        assert_eq!(
            result.sql,
            "INSERT INTO \"users\" (\"id\", \"name\") VALUES ($1, $2), ($3, $4), ($5, $6) ON CONFLICT (\"id\") DO UPDATE SET \"name\" = EXCLUDED.\"name\""
        );
    }

    // Validation: no fields.
    #[test]
    fn test_build_upsert_empty_fields_error() {
        let fields: Vec<Field> = vec![];
        let result = build_upsert_sql("users", &fields, &["id"], 1, &UpsertOptions::default());

        assert!(matches!(result, Err(UpsertError::EmptyFields)));
    }

    // Validation: no conflict target.
    #[test]
    fn test_build_upsert_empty_conflict_fields_error() {
        let fields = fields!["id" => 1_i32];
        let result = build_upsert_sql("users", &fields, &[], 1, &UpsertOptions::default());

        assert!(matches!(result, Err(UpsertError::EmptyConflictFields)));
    }

    // Validation: zero rows.
    #[test]
    fn test_build_upsert_empty_values_error() {
        let fields = fields!["id" => 1_i32];
        let result = build_upsert_sql("users", &fields, &["id"], 0, &UpsertOptions::default());

        assert!(matches!(result, Err(UpsertError::EmptyValues)));
    }

    // Validation: the version column must be one of the inserted fields.
    #[test]
    fn test_build_upsert_version_field_not_found_error() {
        let fields = fields!["id" => 1_i32, "name" => "Alice"];
        let options = UpsertOptions {
            version_field: Some("version".into()),
            ..Default::default()
        };

        let result = build_upsert_sql("users", &fields, &["id"], 1, &options);

        assert!(matches!(result, Err(UpsertError::VersionFieldNotFound(_))));
    }

    // A table name containing SQL metacharacters is emitted as one quoted identifier.
    #[test]
    fn test_sql_injection_protection_table_name() {
        let fields = fields!["id" => 1_i32, "name" => "Alice"];
        let malicious_table = "users; DROP TABLE users--";

        let result = build_upsert_sql(
            malicious_table,
            &fields,
            &["id"],
            1,
            &UpsertOptions::default(),
        )
        .unwrap();

        assert_eq!(
            result.sql,
            "INSERT INTO \"users; DROP TABLE users--\" (\"id\", \"name\") VALUES ($1, $2) ON CONFLICT (\"id\") DO UPDATE SET \"name\" = EXCLUDED.\"name\""
        );
    }

    // A column name with an embedded double quote has that quote doubled, so it cannot
    // close the identifier early.
    #[test]
    fn test_sql_injection_protection_column_name() {
        let malicious_field = Field::new("id\"; DROP TABLE users--", 1_i32);
        let fields = vec![malicious_field, Field::new("name", "Alice")];

        let result =
            build_upsert_sql("users", &fields, &["id"], 1, &UpsertOptions::default()).unwrap();

        assert_eq!(
            result.sql,
            "INSERT INTO \"users\" (\"id\"\"; DROP TABLE users--\", \"name\") VALUES ($1, $2) ON CONFLICT (\"id\") DO UPDATE SET \"id\"\"; DROP TABLE users--\" = EXCLUDED.\"id\"\"; DROP TABLE users--\", \"name\" = EXCLUDED.\"name\""
        );
    }

    // The SQL text depends only on field names, so chrono-typed values produce the same
    // statement shape as any other value type.
    #[test]
    fn test_build_upsert_with_date_time_fields() {
        use chrono::{DateTime, NaiveDate, Utc};

        let date = NaiveDate::from_ymd_opt(2025, 12, 26).unwrap();
        let timestamp = DateTime::<Utc>::from_timestamp(1735225800, 0).unwrap();

        let fields = vec![
            Field::new("id", 1_i32),
            Field::new("created_at", timestamp),
            Field::new("birth_date", date),
        ];

        let result =
            build_upsert_sql("events", &fields, &["id"], 1, &UpsertOptions::default()).unwrap();

        assert_eq!(
            result.sql,
            "INSERT INTO \"events\" (\"id\", \"created_at\", \"birth_date\") VALUES ($1, $2, $3) ON CONFLICT (\"id\") DO UPDATE SET \"created_at\" = EXCLUDED.\"created_at\", \"birth_date\" = EXCLUDED.\"birth_date\""
        );
    }
}