use crate::error::{Result, UpsertError};
use crate::types::{Field, UpsertOptions};
/// SQL text produced by [`build_upsert_sql`].
///
/// Only the statement text is carried here; the values that correspond to the
/// numbered placeholders are not part of this type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UpsertQuery {
    /// The complete statement, using `$1`, `$2`, ... as value placeholders.
    pub sql: String,
}
/// Wraps `identifier` in double quotes so it can be embedded in SQL text as a
/// delimited identifier.
///
/// Every `"` inside the identifier is doubled, so the input cannot terminate
/// the quoted identifier early.
fn quote_identifier(identifier: &str) -> String {
    // Two extra bytes for the surrounding quotes; embedded quotes grow the
    // buffer on demand.
    let mut quoted = String::with_capacity(identifier.len() + 2);
    quoted.push('"');
    for ch in identifier.chars() {
        if ch == '"' {
            // Emit the escaping quote first, then the quote itself below.
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
/// Builds an `INSERT ... ON CONFLICT` statement with numbered (`$1`, `$2`, ...)
/// placeholders for `row_count` rows of `fields.len()` values each.
///
/// * `table` is quoted as a single identifier (it is not split on `.`).
/// * `fields` supplies the column names, in insertion order.
/// * `conflict_fields` names the columns of the `ON CONFLICT (...)` target.
/// * `row_count` is the number of `(...)` value tuples to emit; placeholders
///   are numbered row by row, left to right.
/// * `options.do_nothing_on_conflict` selects `DO NOTHING`; otherwise every
///   column that is not a conflict column is assigned from `EXCLUDED`.
/// * `options.version_field`, when set, adds
///   `WHERE <table>.<version> < EXCLUDED.<version>` to the `DO UPDATE` branch.
///
/// If every inserted column is also a conflict column there is nothing to
/// assign, so the statement falls back to `DO NOTHING` instead of emitting an
/// empty (and syntactically invalid) `DO UPDATE SET` clause.
///
/// # Errors
///
/// * [`UpsertError::EmptyFields`] if `fields` is empty.
/// * [`UpsertError::EmptyConflictFields`] if `conflict_fields` is empty.
/// * [`UpsertError::EmptyValues`] if `row_count` is zero.
/// * [`UpsertError::VersionFieldNotFound`] if `options.version_field` is set
///   but no field carries that name.
pub fn build_upsert_sql(
    table: &str,
    fields: &[Field],
    conflict_fields: &[&str],
    row_count: usize,
    options: &UpsertOptions,
) -> Result<UpsertQuery> {
    if fields.is_empty() {
        return Err(UpsertError::EmptyFields);
    }
    if conflict_fields.is_empty() {
        return Err(UpsertError::EmptyConflictFields);
    }
    if row_count == 0 {
        return Err(UpsertError::EmptyValues);
    }
    if let Some(ref version_field) = options.version_field
        && !fields.iter().any(|f| f.name == *version_field)
    {
        return Err(UpsertError::VersionFieldNotFound(version_field.clone()));
    }

    let quoted_table = quote_identifier(table);

    let columns = fields
        .iter()
        .map(|f| quote_identifier(&f.name))
        .collect::<Vec<_>>()
        .join(", ");

    // Placeholder numbers are derived from the row index in `usize`, matching
    // the type of `row_count`, rather than from an inferred `i32` counter.
    let field_count = fields.len();
    let values = (0..row_count)
        .map(|row| {
            let first = row * field_count + 1;
            let placeholders = (first..first + field_count)
                .map(|idx| format!("${}", idx))
                .collect::<Vec<_>>()
                .join(", ");
            format!("({})", placeholders)
        })
        .collect::<Vec<_>>()
        .join(", ");

    let conflict_cols = conflict_fields
        .iter()
        .map(|field| quote_identifier(field))
        .collect::<Vec<_>>()
        .join(", ");

    // Assignments are only needed for the DO UPDATE branch; conflict columns
    // are excluded from the SET list.
    let update_assignments: Vec<String> = if options.do_nothing_on_conflict {
        Vec::new()
    } else {
        fields
            .iter()
            .map(|f| f.name.as_str())
            .filter(|name| !conflict_fields.contains(name))
            .map(|name| {
                let quoted = quote_identifier(name);
                format!("{} = EXCLUDED.{}", quoted, quoted)
            })
            .collect()
    };

    let insert_head = format!(
        "INSERT INTO {} ({}) VALUES {} ON CONFLICT ({})",
        quoted_table, columns, values, conflict_cols
    );

    let sql = if update_assignments.is_empty() {
        // Either DO NOTHING was requested, or every inserted column is part of
        // the conflict target and an empty SET list would be invalid SQL.
        format!("{} DO NOTHING", insert_head)
    } else {
        let where_clause = match options.version_field {
            Some(ref version_field) => {
                let quoted_version = quote_identifier(version_field);
                format!(
                    " WHERE {}.{} < EXCLUDED.{}",
                    quoted_table, quoted_version, quoted_version
                )
            }
            None => String::new(),
        };
        format!(
            "{} DO UPDATE SET {}{}",
            insert_head,
            update_assignments.join(", "),
            where_clause
        )
    };

    Ok(UpsertQuery { sql })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fields;

    /// Runs the builder and returns only the SQL text, panicking if it fails.
    fn sql_for(
        table: &str,
        columns: &[Field],
        conflict: &[&str],
        rows: usize,
        options: &UpsertOptions,
    ) -> String {
        build_upsert_sql(table, columns, conflict, rows, options)
            .unwrap()
            .sql
    }

    /// Default options: non-conflict columns are updated from `EXCLUDED`.
    #[test]
    fn test_build_basic_upsert_sql() {
        let columns = fields![
            "id" => 1_i32,
            "name" => "Alice",
            "email" => "alice@example.com",
        ];
        let expected = "INSERT INTO \"users\" (\"id\", \"name\", \"email\") VALUES ($1, $2, $3) ON CONFLICT (\"id\") DO UPDATE SET \"name\" = EXCLUDED.\"name\", \"email\" = EXCLUDED.\"email\"";

        let actual = sql_for("users", &columns, &["id"], 1, &UpsertOptions::default());

        assert_eq!(actual, expected);
    }

    /// `do_nothing_on_conflict` replaces the update clause with `DO NOTHING`.
    #[test]
    fn test_build_upsert_with_do_nothing() {
        let columns = fields![
            "id" => 1_i32,
            "name" => "Alice",
        ];
        let skip_on_conflict = UpsertOptions {
            do_nothing_on_conflict: true,
            ..Default::default()
        };
        let expected = "INSERT INTO \"users\" (\"id\", \"name\") VALUES ($1, $2) ON CONFLICT (\"id\") DO NOTHING";

        let actual = sql_for("users", &columns, &["id"], 1, &skip_on_conflict);

        assert_eq!(actual, expected);
    }

    /// A version field adds a `WHERE` guard comparing stored and incoming versions.
    #[test]
    fn test_build_upsert_with_version_field() {
        let columns = fields![
            "id" => 1_i32,
            "name" => "Alice",
            "version" => 5_i32,
        ];
        let versioned = UpsertOptions {
            version_field: Some("version".into()),
            ..Default::default()
        };
        let expected = "INSERT INTO \"users\" (\"id\", \"name\", \"version\") VALUES ($1, $2, $3) ON CONFLICT (\"id\") DO UPDATE SET \"name\" = EXCLUDED.\"name\", \"version\" = EXCLUDED.\"version\" WHERE \"users\".\"version\" < EXCLUDED.\"version\"";

        let actual = sql_for("users", &columns, &["id"], 1, &versioned);

        assert_eq!(actual, expected);
    }

    /// All conflict columns appear in the target and none appear in the SET list.
    #[test]
    fn test_build_upsert_multiple_conflict_fields() {
        let columns = fields![
            "warehouse_id" => 1_i32,
            "product_id" => 100_i32,
            "quantity" => 50_i32,
        ];
        let composite_key = ["warehouse_id", "product_id"];
        let expected = "INSERT INTO \"inventory\" (\"warehouse_id\", \"product_id\", \"quantity\") VALUES ($1, $2, $3) ON CONFLICT (\"warehouse_id\", \"product_id\") DO UPDATE SET \"quantity\" = EXCLUDED.\"quantity\"";

        let actual = sql_for(
            "inventory",
            &columns,
            &composite_key,
            1,
            &UpsertOptions::default(),
        );

        assert_eq!(actual, expected);
    }

    /// Placeholders keep counting across value tuples.
    #[test]
    fn test_build_upsert_multiple_rows() {
        let columns = fields![
            "id" => 1_i32,
            "name" => "Alice",
        ];
        let expected = "INSERT INTO \"users\" (\"id\", \"name\") VALUES ($1, $2), ($3, $4), ($5, $6) ON CONFLICT (\"id\") DO UPDATE SET \"name\" = EXCLUDED.\"name\"";

        let actual = sql_for("users", &columns, &["id"], 3, &UpsertOptions::default());

        assert_eq!(actual, expected);
    }

    /// An empty field list is rejected.
    #[test]
    fn test_build_upsert_empty_fields_error() {
        let no_columns: Vec<Field> = Vec::new();

        let outcome = build_upsert_sql(
            "users",
            &no_columns,
            &["id"],
            1,
            &UpsertOptions::default(),
        );

        assert!(matches!(outcome, Err(UpsertError::EmptyFields)));
    }

    /// An empty conflict target is rejected.
    #[test]
    fn test_build_upsert_empty_conflict_fields_error() {
        let columns = fields!["id" => 1_i32];

        let outcome = build_upsert_sql("users", &columns, &[], 1, &UpsertOptions::default());

        assert!(matches!(outcome, Err(UpsertError::EmptyConflictFields)));
    }

    /// A row count of zero is rejected.
    #[test]
    fn test_build_upsert_empty_values_error() {
        let columns = fields!["id" => 1_i32];

        let outcome = build_upsert_sql("users", &columns, &["id"], 0, &UpsertOptions::default());

        assert!(matches!(outcome, Err(UpsertError::EmptyValues)));
    }

    /// A version field that is not among the inserted fields is rejected.
    #[test]
    fn test_build_upsert_version_field_not_found_error() {
        let columns = fields!["id" => 1_i32, "name" => "Alice"];
        let versioned = UpsertOptions {
            version_field: Some("version".into()),
            ..Default::default()
        };

        let outcome = build_upsert_sql("users", &columns, &["id"], 1, &versioned);

        assert!(matches!(
            outcome,
            Err(UpsertError::VersionFieldNotFound(_))
        ));
    }

    /// A hostile table name ends up inside a single quoted identifier.
    #[test]
    fn test_sql_injection_protection_table_name() {
        let columns = fields!["id" => 1_i32, "name" => "Alice"];
        let hostile_table = "users; DROP TABLE users--";
        let expected = "INSERT INTO \"users; DROP TABLE users--\" (\"id\", \"name\") VALUES ($1, $2) ON CONFLICT (\"id\") DO UPDATE SET \"name\" = EXCLUDED.\"name\"";

        let actual = sql_for(
            hostile_table,
            &columns,
            &["id"],
            1,
            &UpsertOptions::default(),
        );

        assert_eq!(actual, expected);
    }

    /// A double quote inside a column name is doubled rather than closing the identifier.
    #[test]
    fn test_sql_injection_protection_column_name() {
        let columns = vec![
            Field::new("id\"; DROP TABLE users--", 1_i32),
            Field::new("name", "Alice"),
        ];
        let expected = "INSERT INTO \"users\" (\"id\"\"; DROP TABLE users--\", \"name\") VALUES ($1, $2) ON CONFLICT (\"id\") DO UPDATE SET \"id\"\"; DROP TABLE users--\" = EXCLUDED.\"id\"\"; DROP TABLE users--\", \"name\" = EXCLUDED.\"name\"";

        let actual = sql_for("users", &columns, &["id"], 1, &UpsertOptions::default());

        assert_eq!(actual, expected);
    }

    /// Date and timestamp values are bound through placeholders like any other value.
    #[test]
    fn test_build_upsert_with_date_time_fields() {
        use chrono::{DateTime, NaiveDate, Utc};

        let birth_date = NaiveDate::from_ymd_opt(2025, 12, 26).unwrap();
        let created_at = DateTime::<Utc>::from_timestamp(1735225800, 0).unwrap();
        let columns = vec![
            Field::new("id", 1_i32),
            Field::new("created_at", created_at),
            Field::new("birth_date", birth_date),
        ];
        let expected = "INSERT INTO \"events\" (\"id\", \"created_at\", \"birth_date\") VALUES ($1, $2, $3) ON CONFLICT (\"id\") DO UPDATE SET \"created_at\" = EXCLUDED.\"created_at\", \"birth_date\" = EXCLUDED.\"birth_date\"";

        let actual = sql_for("events", &columns, &["id"], 1, &UpsertOptions::default());

        assert_eq!(actual, expected);
    }
}