use async_trait::async_trait;
use sea_orm::{ConnectionTrait, DatabaseBackend, Statement};
use serde_json::Value;
use crate::database::DB;
use crate::validation::translate_validation;
use super::async_rule::AsyncRule;
/// Asynchronous validation rule asserting that a value is not already stored
/// in a given database column.
///
/// Build one with [`unique`] and optionally refine it with
/// [`Unique::ignore`] or [`Unique::ignore_on`].
#[derive(Debug, Clone)]
pub struct Unique {
    /// Unquoted name of the table that is searched.
    table: String,
    /// Unquoted name of the column whose values must be unique.
    col: String,
    /// Optional `(key column, key value)` pair naming one row that is left out
    /// of the duplicate check.
    ignore: Option<(String, sea_orm::Value)>,
}
/// Creates a [`Unique`] rule that looks for duplicates of the validated value
/// in column `col` of `table`.
///
/// The returned rule excludes no row from the check.
pub fn unique(table: impl Into<String>, col: impl Into<String>) -> Unique {
    let table: String = table.into();
    let col: String = col.into();
    Unique {
        table,
        col,
        ignore: None,
    }
}
impl Unique {
    /// Leaves the row whose `id` column equals `id` out of the duplicate check.
    ///
    /// Shorthand for [`Unique::ignore_on`] with the key column `"id"`.
    pub fn ignore(self, id: impl Into<sea_orm::Value>) -> Self {
        self.ignore_on("id", id)
    }

    /// Leaves the row whose `pk_col` column equals `id` out of the duplicate
    /// check.
    pub fn ignore_on(mut self, pk_col: impl Into<String>, id: impl Into<sea_orm::Value>) -> Self {
        self.ignore = Some((pk_col.into(), id.into()));
        self
    }

    /// Accepts only non-empty identifiers consisting of ASCII letters, ASCII
    /// digits and `_`.
    ///
    /// # Errors
    ///
    /// Returns `Invalid SQL identifier: <ident>` for anything else.
    fn validate_identifier(ident: &str) -> Result<(), String> {
        let well_formed =
            !ident.is_empty() && ident.chars().all(|c| c.is_ascii_alphanumeric() || c == '_');
        if well_formed {
            Ok(())
        } else {
            Err(format!("Invalid SQL identifier: {ident:?}"))
        }
    }

    /// Wraps `ident` in ANSI double quotes.
    ///
    /// The identifier is not escaped; callers must have passed it through
    /// [`Self::validate_identifier`] first.
    fn quote_ident(ident: &str) -> String {
        format!("\"{ident}\"")
    }

    /// Quotes `ident` with the delimiter understood by `backend`.
    ///
    /// MySQL reads double-quoted text as a string literal unless the
    /// `ANSI_QUOTES` SQL mode is enabled, so it gets backticks; every other
    /// backend gets ANSI double quotes.
    fn quote_ident_for(backend: DatabaseBackend, ident: &str) -> String {
        match backend {
            DatabaseBackend::MySql => format!("`{ident}`"),
            _ => Self::quote_ident(ident),
        }
    }

    /// Runs [`Self::validate_identifier`] over the table, the column and, when
    /// set, the ignore key column, in that order.
    ///
    /// # Errors
    ///
    /// Returns the first failure prefixed with `Unique rule misconfigured: `.
    fn check_identifiers(&self) -> Result<(), String> {
        let check = |ident: &str| {
            Self::validate_identifier(ident).map_err(|e| format!("Unique rule misconfigured: {e}"))
        };
        check(self.table.as_str())?;
        check(self.col.as_str())?;
        if let Some((pk_col, _)) = &self.ignore {
            check(pk_col.as_str())?;
        }
        Ok(())
    }

    /// Builds the `SELECT COUNT(*)` statement used for the duplicate check.
    ///
    /// `table` and `col` must already be quoted for `backend`. The ignore key
    /// column, when present, is quoted here. Postgres gets numbered `$n`
    /// placeholders, every other backend gets `?`.
    fn build_sql(&self, backend: DatabaseBackend, table: &str, col: &str) -> String {
        let placeholder = |position: usize| match backend {
            DatabaseBackend::Postgres => format!("${position}"),
            _ => String::from("?"),
        };
        let mut sql = format!(
            "SELECT COUNT(*) AS count FROM {table} WHERE {col} = {}",
            placeholder(1)
        );
        if let Some((pk_col, _)) = &self.ignore {
            let pk = Self::quote_ident_for(backend, pk_col);
            sql.push_str(&format!(" AND {pk} <> {}", placeholder(2)));
        }
        sql
    }
}

#[async_trait]
impl AsyncRule for Unique {
    /// Fails when `value` already occurs in the configured column.
    ///
    /// # Errors
    ///
    /// * `Unique rule misconfigured: …` when the table, column or ignore key
    ///   column is not a plain identifier. This is checked before a
    ///   connection is requested.
    /// * `__infra_error__: …` when the connection, the query or the decoding
    ///   of the count fails.
    /// * The translated `validation.unique` message, or an English fallback,
    ///   when at least one matching row exists.
    async fn validate(&self, field: &str, value: &Value, _data: &Value) -> Result<(), String> {
        self.check_identifiers()?;

        let db = DB::connection().map_err(|e| format!("__infra_error__: {e}"))?;
        let backend = db.get_database_backend();

        // Quote only once the backend is known, so MySQL gets backticks.
        let table = Self::quote_ident_for(backend, &self.table);
        let col = Self::quote_ident_for(backend, &self.col);
        let sql = self.build_sql(backend, &table, &col);

        // Bind order matches the placeholder order emitted by `build_sql`.
        let mut values = vec![json_value_to_sea_value(value)];
        if let Some((_, pk_val)) = &self.ignore {
            values.push(pk_val.clone());
        }

        let stmt = Statement::from_sql_and_values(backend, sql, values);
        let row = db
            .query_one(stmt)
            .await
            .map_err(|e| format!("__infra_error__: {e}"))?;
        let count: i64 = match row {
            Some(r) => r
                .try_get::<i64>("", "count")
                .map_err(|e| format!("__infra_error__: {e}"))?,
            None => return Err("__infra_error__: uniqueness COUNT returned no row".to_string()),
        };

        if count > 0 {
            Err(
                translate_validation("validation.unique", &[("attribute", field)])
                    .unwrap_or_else(|| format!("The {field} has already been taken.")),
            )
        } else {
            Ok(())
        }
    }

    /// Name under which this rule is identified.
    fn name(&self) -> &'static str {
        "unique"
    }
}
/// Converts a JSON value into a SeaORM value suitable for binding to a query.
///
/// * `null` becomes a string-typed SQL `NULL`.
/// * Booleans and strings map to their direct counterparts.
/// * Numbers representable as `i64` become `BigInt`; other numbers
///   representable as `f64` become `Double`; anything else is bound as its
///   textual form.
/// * Arrays and objects are bound as their serialized JSON text.
pub(crate) fn json_value_to_sea_value(v: &serde_json::Value) -> sea_orm::Value {
    let text = |s: String| sea_orm::Value::String(Some(Box::new(s)));
    match v {
        serde_json::Value::Null => sea_orm::Value::String(None),
        serde_json::Value::Bool(flag) => sea_orm::Value::Bool(Some(*flag)),
        serde_json::Value::String(s) => text(s.clone()),
        // The integer representation wins whenever it is available.
        serde_json::Value::Number(num) => match (num.as_i64(), num.as_f64()) {
            (Some(int), _) => sea_orm::Value::BigInt(Some(int)),
            (None, Some(float)) => sea_orm::Value::Double(Some(float)),
            (None, None) => text(num.to_string()),
        },
        other => text(other.to_string()),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use sea_orm::DatabaseBackend;
    use serde_json::json;
    use serial_test::serial;

    /// Initialises `DB` with an in-memory SQLite URL and makes sure the
    /// `widgets` scratch table exists.
    async fn init_test_db() {
        use crate::database::{DatabaseConfig, DB};
        use sea_orm::{ConnectionTrait, Statement};

        let settings = DatabaseConfig::builder().url("sqlite::memory:").build();
        DB::init_with(settings)
            .await
            .expect("init in-memory sqlite");

        let conn = DB::connection().expect("connection after init");
        let ddl = Statement::from_string(
            conn.get_database_backend(),
            String::from("CREATE TABLE IF NOT EXISTS widgets (id INTEGER PRIMARY KEY, slug TEXT)"),
        );
        conn.execute(ddl)
            .await
            .expect("create widgets scratch table");
    }

    /// Inserts one `(id, slug)` row into the `widgets` scratch table.
    async fn seed_widget(id: i64, slug: &str) {
        use crate::database::DB;
        use sea_orm::{ConnectionTrait, Statement};

        let conn = DB::connection().expect("connection for seed_widget");
        let insert = Statement::from_string(
            conn.get_database_backend(),
            format!("INSERT INTO widgets (id, slug) VALUES ({id}, '{slug}')"),
        );
        conn.execute(insert).await.expect("seed widget row");
    }

    // ---- identifier helpers -------------------------------------------------

    #[test]
    fn validate_identifier_accepts_valid_names() {
        for name in ["slug", "my_table", "Table123", "a"] {
            assert!(Unique::validate_identifier(name).is_ok());
        }
    }

    #[test]
    fn validate_identifier_rejects_invalid_names() {
        for name in ["", "a;b", "a b", "a.b", "a-b", "a'b"] {
            assert!(Unique::validate_identifier(name).is_err());
        }
    }

    #[test]
    fn quote_ident_wraps_in_double_quotes() {
        for (raw, quoted) in [("slug", "\"slug\""), ("my_col", "\"my_col\"")] {
            assert_eq!(Unique::quote_ident(raw), quoted);
        }
    }

    // ---- builder ------------------------------------------------------------

    #[test]
    fn ignore_sets_default_id_pk() {
        let rule = unique("widgets", "slug").ignore(5_i64);
        let (key_column, _) = rule.ignore.expect("ignore should be set");
        assert_eq!(key_column, "id");
    }

    #[test]
    fn ignore_on_sets_custom_pk() {
        let rule = unique("widgets", "slug").ignore_on("uuid", "abc");
        let (key_column, _) = rule.ignore.expect("ignore should be set");
        assert_eq!(key_column, "uuid");
    }

    // ---- JSON conversion ----------------------------------------------------

    #[test]
    fn json_value_to_sea_value_string() {
        let converted = json_value_to_sea_value(&json!("hello"));
        assert!(matches!(converted, sea_orm::Value::String(Some(s)) if s.as_str() == "hello"));
    }

    #[test]
    fn json_value_to_sea_value_integer() {
        let converted = json_value_to_sea_value(&json!(7_i64));
        assert!(matches!(converted, sea_orm::Value::BigInt(Some(7))));
    }

    #[test]
    fn json_value_to_sea_value_bool() {
        let converted = json_value_to_sea_value(&json!(true));
        assert!(matches!(converted, sea_orm::Value::Bool(Some(true))));
    }

    #[test]
    fn json_value_to_sea_value_null_binds_sql_null() {
        let converted = json_value_to_sea_value(&json!(null));
        assert!(matches!(converted, sea_orm::Value::String(None)));
    }

    // ---- SQL generation -----------------------------------------------------

    #[test]
    fn unique_postgres_sql_uses_dollar_placeholders() {
        let rule = unique("widgets", "slug");
        let sql = rule.build_sql(DatabaseBackend::Postgres, "\"widgets\"", "\"slug\"");
        assert!(sql.contains("$1"), "expected $1 in: {sql}");
        assert!(
            !sql.contains('?'),
            "should not have ? in postgres sql: {sql}"
        );
    }

    #[test]
    fn unique_postgres_sql_with_ignore_uses_dollar_two() {
        let rule = unique("widgets", "slug").ignore(1_i64);
        let sql = rule.build_sql(DatabaseBackend::Postgres, "\"widgets\"", "\"slug\"");
        assert!(sql.contains("$1"), "expected $1 in: {sql}");
        assert!(sql.contains("$2"), "expected $2 in: {sql}");
    }

    #[test]
    fn unique_sqlite_sql_uses_question_placeholder() {
        let rule = unique("widgets", "slug");
        let sql = rule.build_sql(DatabaseBackend::Sqlite, "\"widgets\"", "\"slug\"");
        assert!(sql.contains('?'), "expected ? in sqlite sql: {sql}");
        assert!(
            !sql.contains("$1"),
            "should not have $1 in sqlite sql: {sql}"
        );
    }

    // ---- misconfiguration is reported before any database access -------------

    #[tokio::test]
    async fn unique_rejects_bad_identifier_before_db() {
        let payload = json!({});
        let outcome = unique("bad;name", "slug")
            .validate("slug", &json!("value"), &payload)
            .await;
        assert!(outcome.is_err());
        let msg = outcome.unwrap_err();
        assert!(
            msg.starts_with("Unique rule misconfigured"),
            "expected 'Unique rule misconfigured' prefix, got: {msg}"
        );
    }

    #[tokio::test]
    async fn unique_rejects_bad_column_identifier_before_db() {
        let payload = json!({});
        let outcome = unique("widgets", "bad col")
            .validate("slug", &json!("value"), &payload)
            .await;
        assert!(outcome.is_err());
        let msg = outcome.unwrap_err();
        assert!(msg.starts_with("Unique rule misconfigured"));
    }

    // ---- database-backed behaviour (serialised: they share `DB`) -------------

    #[tokio::test]
    #[serial]
    async fn unique_detects_existing_value() {
        init_test_db().await;
        seed_widget(1, "taken").await;

        let payload = json!({});
        let outcome = unique("widgets", "slug")
            .validate("slug", &json!("taken"), &payload)
            .await;
        assert!(outcome.is_err(), "expected Err for duplicate slug");
        let msg = outcome.unwrap_err();
        assert!(
            !msg.starts_with("__infra_error__"),
            "must not be an infra error: {msg}"
        );
    }

    #[tokio::test]
    #[serial]
    async fn unique_passes_on_no_match() {
        init_test_db().await;

        let payload = json!({});
        let outcome = unique("widgets", "slug")
            .validate("slug", &json!("free"), &payload)
            .await;
        assert!(outcome.is_ok(), "expected Ok for non-duplicate slug");
    }

    #[tokio::test]
    #[serial]
    async fn unique_ignore_excludes_self() {
        init_test_db().await;
        seed_widget(1, "taken").await;
        let payload = json!({});

        // Ignoring the row that owns the value: no conflict.
        let own_row = unique("widgets", "slug")
            .ignore(1_i64)
            .validate("slug", &json!("taken"), &payload)
            .await;
        assert!(own_row.is_ok(), "expected Ok when ignoring own row");

        // Ignoring some other row: the owning row still conflicts.
        let other_row = unique("widgets", "slug")
            .ignore(2_i64)
            .validate("slug", &json!("taken"), &payload)
            .await;
        assert!(
            other_row.is_err(),
            "expected Err when ignoring a different row"
        );
    }

    #[tokio::test]
    #[serial]
    async fn unique_ignore_on_custom_pk() {
        init_test_db().await;

        let db = crate::database::DB::connection().expect("connection");
        // Each statement is paired with the panic message used if it fails.
        let setup = [
            (
                "CREATE TABLE IF NOT EXISTS items (uid INTEGER PRIMARY KEY, code TEXT)",
                "create items scratch table",
            ),
            ("DELETE FROM items", "clear items"),
            (
                "INSERT INTO items (uid, code) VALUES (10, 'ABC')",
                "seed item",
            ),
        ];
        for (sql, failure) in setup {
            let stmt = sea_orm::Statement::from_string(db.get_database_backend(), sql.to_owned());
            db.execute(stmt).await.expect(failure);
        }

        let payload = json!({});
        let own_row = unique("items", "code")
            .ignore_on("uid", 10_i64)
            .validate("code", &json!("ABC"), &payload)
            .await;
        assert!(
            own_row.is_ok(),
            "expected Ok when ignoring own row via custom PK"
        );

        let other_row = unique("items", "code")
            .ignore_on("uid", 99_i64)
            .validate("code", &json!("ABC"), &payload)
            .await;
        assert!(
            other_row.is_err(),
            "expected Err when different row owns the value"
        );
    }
}