use crate::{
dialect::{CurrentDialect, CurrentRow, Db, Dialect},
query::{ImageQuery, TagQuery},
storage::{ImageMetadata, PixelHash},
utils,
};
use chrono::{DateTime, Utc};
use sqlx::{Execute, FromRow, Row};
use std::str::FromStr;
use thiserror::Error;
/// Connection pool alias over the crate's [`Db`] database type.
pub type Pool = sqlx::Pool<Db>;
/// Migrations embedded from `migrations/sqlite`.
///
/// Only compiled when the `sqlite` feature is enabled and `postgres` is not.
#[cfg(all(feature = "sqlite", not(feature = "postgres")))]
pub static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!("migrations/sqlite");
/// Migrations embedded from `migrations/postgres`.
///
/// Only compiled when the `postgres` feature is enabled and `sqlite` is not.
// NOTE(review): with both features (or neither) enabled no `MIGRATOR` is defined here.
#[cfg(all(feature = "postgres", not(feature = "sqlite")))]
pub static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!("migrations/postgres");
pub async fn run_migration(pool: &sqlx::Pool<Db>) -> Result<(), sqlx::Error> {
MIGRATOR.run(pool).await?;
Ok(())
}
impl FromRow<'_, CurrentRow> for ImageMetadata {
fn from_row(row: &CurrentRow) -> Result<Self, sqlx::Error> {
let width: i32 = row.try_get("width")?;
let height: i32 = row.try_get("height")?;
let format: String = row.try_get("format")?;
let color_type: String = row.try_get("color_type")?;
let file_size: i64 = row.try_get("file_size")?;
let created_at: String = row.try_get("created_at")?;
let created_at = DateTime::from_str(&created_at).expect("");
let duration: Option<f64> = row.try_get("duration")?;
Ok(ImageMetadata {
width: width as u32,
height: height as u32,
format,
color_type,
file_size: file_size as u64,
created_at: Some(created_at),
duration,
})
}
}
/// Handle bundling the connection [`Pool`] that the query methods of this type run against.
#[derive(Debug, Clone)]
pub struct Database {
/// Underlying connection pool; public, so it is also reachable by callers directly.
pub pool: Pool,
}
impl Database {
/// Wraps an existing connection pool in a [`Database`] handle.
pub fn new(pool: sqlx::Pool<Db>) -> Self {
Self { pool }
}
/// Runs the embedded migrations against this handle's pool via [`run_migration`].
///
/// # Errors
/// Propagates the [`sqlx::Error`] returned by [`run_migration`].
pub async fn migrate(&self) -> Result<(), sqlx::Error> {
run_migration(&self.pool).await
}
/// Reports whether `hash` is known, using the dialect's `exists_image` statement.
///
/// The lookup is executed through `utils::retry`.
///
/// # Errors
/// Returns [`DatabaseError::QueryFailed`] (tagged [`DbOperation::QueryImages`]) when the
/// lookup fails.
pub async fn image_exists(&self, hash: &PixelHash) -> Result<bool, DatabaseError> {
    let exists_sql = CurrentDialect::exists_image();
    let found = utils::retry(|| async {
        let lookup = sqlx::query_scalar(&exists_sql).bind(hash.to_string());
        let sql_text = lookup.sql();
        lookup
            .fetch_one(&self.pool)
            .await
            .map_err(|source| DatabaseError::QueryFailed {
                operation: DbOperation::QueryImages,
                sql: sql_text.to_string(),
                source,
            })
    })
    .await?;
    Ok(found)
}
/// Makes sure a row for `hash` exists, inserting one only when [`Self::image_exists`]
/// reports it as missing.
///
/// # Errors
/// Propagates the error of the existence check, or returns
/// [`DatabaseError::QueryFailed`] (tagged [`DbOperation::InsertImage`]) when the insert
/// fails.
pub async fn ensure_image(&self, hash: &PixelHash) -> Result<(), DatabaseError> {
    let already_present = self.image_exists(hash).await?;
    if !already_present {
        let insert_sql = CurrentDialect::ensure_image_statement();
        utils::retry(|| async {
            let insert = sqlx::query(&insert_sql).bind(hash.to_string());
            let sql_text = insert.sql();
            insert
                .execute(&self.pool)
                .await
                .map_err(|source| DatabaseError::QueryFailed {
                    operation: DbOperation::InsertImage { hash: hash.clone() },
                    sql: sql_text.to_string(),
                    source,
                })
        })
        .await?;
    }
    Ok(())
}
/// Stores `metadata` for `hash`, creating the image row first via [`Self::ensure_image`].
///
/// When `metadata.created_at` is `None`, the current UTC time is stored instead. The
/// timestamp is bound as an RFC 3339 string.
///
/// # Errors
/// Propagates errors from [`Self::ensure_image`], or returns
/// [`DatabaseError::QueryFailed`] (tagged [`DbOperation::InsertMetadata`]) when the
/// metadata statement fails.
pub async fn ensure_image_has_metadata(
    &self,
    hash: &PixelHash,
    metadata: &ImageMetadata,
) -> Result<(), DatabaseError> {
    self.ensure_image(hash).await?;
    let stmt = CurrentDialect::ensure_metadata_statement();
    // Resolve the timestamp once and lazily: the clock is only read when no timestamp
    // was supplied, and every retry attempt binds the same value.
    let created_at = metadata
        .created_at
        .unwrap_or_else(Utc::now)
        .to_rfc3339();
    utils::retry(|| async {
        let query = sqlx::query(&stmt)
            .bind(hash.to_string())
            .bind(i64::from(metadata.width))
            .bind(i64::from(metadata.height))
            .bind(&metadata.format)
            .bind(&metadata.color_type)
            // u64 -> i64: sizes above i64::MAX would wrap; kept as in the original binding.
            .bind(metadata.file_size as i64)
            .bind(created_at.clone())
            .bind(metadata.duration);
        let sql = query.sql();
        query
            .execute(&self.pool)
            .await
            .map_err(|e| DatabaseError::QueryFailed {
                operation: DbOperation::InsertMetadata {
                    metadata: metadata.clone(),
                },
                sql: sql.to_string(),
                source: e,
            })
    })
    .await?;
    Ok(())
}
/// Runs the dialect's `ensure_tag` statement once per entry of `tags`, all inside one
/// transaction that is committed at the end.
///
/// # Errors
/// Returns [`DatabaseError::TransactionFailed`] when the transaction cannot be started or
/// committed, and [`DatabaseError::QueryFailed`] (tagged [`DbOperation::InsertTag`]) when
/// a statement fails.
pub async fn ensure_tags(&self, tags: &[&str]) -> Result<(), DatabaseError> {
    let insert_sql = CurrentDialect::ensure_tag_statement();
    utils::retry(|| async {
        let mut txn = self
            .pool
            .begin()
            .await
            .map_err(|source| DatabaseError::TransactionFailed { source })?;
        for tag in tags {
            let insert = sqlx::query(&insert_sql).bind(tag);
            let sql_text = insert.sql();
            insert
                .execute(&mut *txn)
                .await
                .map_err(|source| DatabaseError::QueryFailed {
                    operation: DbOperation::InsertTag {
                        tag: tag.to_string(),
                    },
                    sql: sql_text.to_string(),
                    source,
                })?;
        }
        txn.commit()
            .await
            .map_err(|source| DatabaseError::TransactionFailed { source })
    })
    .await?;
    Ok(())
}
/// Links `hash` to every entry of `tags`.
///
/// The image row and the tag rows are ensured first through [`Self::ensure_image`] and
/// [`Self::ensure_tags`]; the link statements then run inside a single transaction.
///
/// # Errors
/// Propagates errors from the two prerequisite calls, returns
/// [`DatabaseError::TransactionFailed`] when the transaction cannot be started or
/// committed, and [`DatabaseError::QueryFailed`] (tagged
/// [`DbOperation::InsertImageTag`]) when a link statement fails.
pub async fn ensure_image_has_tags(
    &self,
    hash: &PixelHash,
    tags: &[&str],
) -> Result<(), DatabaseError> {
    self.ensure_image(hash).await?;
    self.ensure_tags(tags).await?;
    let link_sql = CurrentDialect::ensure_image_tag_statement();
    utils::retry(|| async {
        let mut txn = self
            .pool
            .begin()
            .await
            .map_err(|source| DatabaseError::TransactionFailed { source })?;
        for tag in tags {
            let link = sqlx::query(&link_sql).bind(hash.to_string()).bind(tag);
            let sql_text = link.sql();
            link.execute(&mut *txn)
                .await
                .map_err(|source| DatabaseError::QueryFailed {
                    operation: DbOperation::InsertImageTag {
                        hash: hash.clone(),
                        tag: tag.to_string(),
                    },
                    sql: sql_text.to_string(),
                    source,
                })?;
        }
        txn.commit()
            .await
            .map_err(|source| DatabaseError::TransactionFailed { source })
    })
    .await?;
    Ok(())
}
/// Records `source` for `hash`, creating the image row first via [`Self::ensure_image`].
///
/// # Errors
/// Propagates errors from [`Self::ensure_image`], or returns
/// [`DatabaseError::QueryFailed`] (tagged [`DbOperation::UpdateImageSource`]) when the
/// update statement fails.
pub async fn ensure_image_has_source(
    &self,
    hash: &PixelHash,
    source: &str,
) -> Result<(), DatabaseError> {
    self.ensure_image(hash).await?;
    let update_sql = CurrentDialect::update_source_statement();
    utils::retry(|| async {
        // Bind order matters: the source value comes first, then the hash.
        let update = sqlx::query(&update_sql)
            .bind(source)
            .bind(hash.to_string());
        let sql_text = update.sql();
        update
            .execute(&self.pool)
            .await
            .map_err(|cause| DatabaseError::QueryFailed {
                operation: DbOperation::UpdateImageSource {
                    hash: hash.clone(),
                    source: source.to_string(),
                },
                sql: sql_text.to_string(),
                source: cause,
            })
    })
    .await?;
    Ok(())
}
/// Returns the hashes of all images matching `query`.
///
/// The query is rendered to SQL plus positional parameters, wrapped by the dialect's
/// `query_image_statement`, and executed through `utils::retry`. Returned strings that
/// do not convert into a [`PixelHash`] are silently skipped.
///
/// # Errors
/// Returns [`DatabaseError::QueryFailed`] (tagged [`DbOperation::QueryImages`]) when the
/// query fails.
pub async fn query_image(&self, query: ImageQuery) -> Result<Vec<PixelHash>, DatabaseError> {
    let (sql, params) = query.to_sql();
    let stmt = CurrentDialect::query_image_statement(sql);
    let hashes = utils::retry(|| async {
        let mut q = sqlx::query_scalar::<_, String>(&stmt);
        // Parameters are bound in the order `to_sql` produced them.
        for param in &params {
            q = q.bind(param);
        }
        q.fetch_all(&self.pool)
            .await
            .map_err(|e| DatabaseError::QueryFailed {
                operation: DbOperation::QueryImages,
                sql: stmt.to_string(),
                source: e,
            })
    })
    .await?
    .into_iter()
    .filter_map(|s| PixelHash::try_from(s).ok())
    .collect();
    Ok(hashes)
}
/// Returns how many images match `query`.
///
/// The query is rendered to SQL plus positional parameters, wrapped by the dialect's
/// `count_image_statement`, and executed through `utils::retry`.
///
/// # Errors
/// Returns [`DatabaseError::QueryFailed`] (tagged [`DbOperation::QueryImages`]) when the
/// query fails.
pub async fn count_image(&self, query: ImageQuery) -> Result<u64, DatabaseError> {
    let (sql, params) = query.to_sql();
    let stmt = CurrentDialect::count_image_statement(sql);
    let count = utils::retry(|| async {
        let mut q = sqlx::query_scalar(&stmt);
        // Parameters are bound in the order `to_sql` produced them.
        for param in &params {
            q = q.bind(param);
        }
        let count: i64 =
            q.fetch_one(&self.pool)
                .await
                .map_err(|e| DatabaseError::QueryFailed {
                    operation: DbOperation::QueryImages,
                    sql: stmt.to_string(),
                    source: e,
                })?;
        // The database hands the count back as a signed integer; it is reinterpreted as u64.
        Ok(count as u64)
    })
    .await?;
    Ok(count)
}
/// Returns the count the dialect's `count_image_by_tag` statement yields for `tag`.
///
/// When the statement produces no row, the count is reported as `0`.
///
/// # Errors
/// Returns [`DatabaseError::QueryFailed`] (tagged [`DbOperation::QueryImages`]) when the
/// query fails.
pub async fn count_image_by_tag(&self, tag: &str) -> Result<u64, DatabaseError> {
    let count_sql = CurrentDialect::count_image_by_tag_statement();
    let total = utils::retry(|| async {
        let stored: Option<i64> = sqlx::query_scalar(&count_sql)
            .bind(tag)
            .fetch_optional(&self.pool)
            .await
            .map_err(|source| DatabaseError::QueryFailed {
                operation: DbOperation::QueryImages,
                sql: count_sql.to_string(),
                source,
            })?;
        // No row falls back to the default of zero.
        let value: i64 = stored.unwrap_or_default();
        Ok(value as u64)
    })
    .await?;
    Ok(total)
}
/// Executes every statement returned by the dialect's `refresh_tag_counts_statement`,
/// in order, inside one transaction that is committed at the end.
///
/// # Errors
/// Returns [`DatabaseError::TransactionFailed`] when the transaction cannot be started or
/// committed, and [`DatabaseError::QueryFailed`] (tagged [`DbOperation::QueryImages`])
/// when one of the statements fails.
pub async fn refresh_image_count(&self) -> Result<(), DatabaseError> {
    utils::retry(|| async {
        let mut txn = self
            .pool
            .begin()
            .await
            .map_err(|source| DatabaseError::TransactionFailed { source })?;
        for refresh_sql in CurrentDialect::refresh_tag_counts_statement() {
            sqlx::query(&refresh_sql)
                .execute(&mut *txn)
                .await
                .map_err(|source| DatabaseError::QueryFailed {
                    operation: DbOperation::QueryImages,
                    sql: refresh_sql.to_string(),
                    source,
                })?;
        }
        txn.commit()
            .await
            .map_err(|source| DatabaseError::TransactionFailed { source })
    })
    .await?;
    Ok(())
}
/// Returns the tag names matching `query`, in the order the database yields them.
///
/// The query is rendered to SQL plus positional parameters, wrapped by the dialect's
/// `query_tag_statement`, and executed through `utils::retry`.
///
/// # Errors
/// Returns [`DatabaseError::QueryFailed`] (tagged [`DbOperation::QueryTags`]) when the
/// query fails.
pub async fn query_tags(&self, query: TagQuery) -> Result<Vec<String>, DatabaseError> {
    let (sql, params) = query.to_sql();
    let stmt = CurrentDialect::query_tag_statement(sql);
    let tags = utils::retry(|| async {
        let mut q = sqlx::query_scalar::<_, String>(&stmt);
        // Parameters are bound in the order `to_sql` produced them.
        for param in &params {
            q = q.bind(param);
        }
        q.fetch_all(&self.pool)
            .await
            .map_err(|e| DatabaseError::QueryFailed {
                operation: DbOperation::QueryTags,
                sql: stmt.to_string(),
                source: e,
            })
    })
    .await?;
    Ok(tags)
}
/// Returns the tags attached to `hash`, in the order the database yields them.
///
/// # Errors
/// Returns [`DatabaseError::QueryFailed`] (tagged [`DbOperation::QueryImageTags`] with
/// the requested hash) when the query fails.
pub async fn get_tags(&self, hash: &PixelHash) -> Result<Vec<String>, DatabaseError> {
    let stmt = CurrentDialect::query_tags_by_image_statement();
    let tags = utils::retry(|| async {
        sqlx::query_scalar(&stmt)
            .bind(hash.to_string())
            .fetch_all(&self.pool)
            .await
            .map_err(|e| DatabaseError::QueryFailed {
                // Tag the failure with the dedicated variant so the error names the image.
                operation: DbOperation::QueryImageTags { hash: hash.clone() },
                sql: stmt.to_string(),
                source: e,
            })
    })
    .await?;
    Ok(tags)
}
/// Loads the metadata stored for `hash`, or `None` when the query yields no row.
///
/// # Errors
/// Returns [`DatabaseError::QueryFailed`] (tagged [`DbOperation::QueryImages`]) when the
/// query or the row decoding fails.
pub async fn get_metadata(
    &self,
    hash: &PixelHash,
) -> Result<Option<ImageMetadata>, DatabaseError> {
    let metadata_sql = CurrentDialect::query_metadata_statement();
    let found: Option<ImageMetadata> = utils::retry(|| async {
        let lookup = sqlx::query_as(&metadata_sql).bind(hash.to_string());
        lookup
            .fetch_optional(&self.pool)
            .await
            .map_err(|source| DatabaseError::QueryFailed {
                operation: DbOperation::QueryImages,
                sql: metadata_sql.to_string(),
                source,
            })
    })
    .await?;
    Ok(found)
}
/// Returns the source recorded for `hash`.
///
/// Yields `None` both when the image has no source stored and when the query returns
/// no row at all (unknown image), matching how [`Self::get_metadata`] treats a missing
/// row.
///
/// # Errors
/// Returns [`DatabaseError::QueryFailed`] (tagged [`DbOperation::QueryImages`]) when the
/// query fails.
pub async fn get_source(&self, hash: &PixelHash) -> Result<Option<String>, DatabaseError> {
    let stmt = CurrentDialect::query_source_statement();
    let stored_source: Option<String> = utils::retry(|| async {
        let query = sqlx::query_scalar::<_, Option<String>>(&stmt).bind(hash.to_string());
        let sql = query.sql();
        query
            .fetch_optional(&self.pool)
            .await
            // Outer `None` = no row, inner `None` = NULL column; both mean "no source".
            .map(|row| row.flatten())
            .map_err(|e| DatabaseError::QueryFailed {
                operation: DbOperation::QueryImages,
                sql: sql.to_string(),
                source: e,
            })
    })
    .await?;
    Ok(stored_source)
}
/// Runs the dialect's `delete_image_tag` statement for `hash` and each entry of `tags`,
/// all inside one transaction that is committed at the end.
///
/// # Errors
/// Returns [`DatabaseError::TransactionFailed`] when the transaction cannot be started or
/// committed, and [`DatabaseError::QueryFailed`] (tagged
/// [`DbOperation::DeleteImageTag`]) when a delete statement fails.
pub async fn ensure_tags_removed(
    &self,
    hash: &PixelHash,
    tags: &[&str],
) -> Result<(), DatabaseError> {
    let unlink_sql = CurrentDialect::delete_image_tag_statement();
    utils::retry(|| async {
        let mut txn = self
            .pool
            .begin()
            .await
            .map_err(|source| DatabaseError::TransactionFailed { source })?;
        for tag in tags {
            let unlink = sqlx::query(&unlink_sql).bind(hash.to_string()).bind(tag);
            let sql_text = unlink.sql();
            unlink
                .execute(&mut *txn)
                .await
                .map_err(|source| DatabaseError::QueryFailed {
                    operation: DbOperation::DeleteImageTag {
                        hash: hash.clone(),
                        tag: tag.to_string(),
                    },
                    sql: sql_text.to_string(),
                    source,
                })?;
        }
        txn.commit()
            .await
            .map_err(|source| DatabaseError::TransactionFailed { source })
    })
    .await?;
    Ok(())
}
/// Removes `hash` in one transaction: first the dialect's `delete_tags_by_image`
/// statement, then its `delete_image` statement, followed by a commit.
///
/// # Errors
/// Returns [`DatabaseError::TransactionFailed`] when the transaction cannot be started or
/// committed, and [`DatabaseError::QueryFailed`] tagged [`DbOperation::DeleteImageTags`]
/// or [`DbOperation::DeleteImage`] depending on which statement failed.
pub async fn ensure_image_removed(&self, hash: &PixelHash) -> Result<(), DatabaseError> {
    let drop_tags_sql = CurrentDialect::delete_tags_by_image_statement();
    let drop_image_sql = CurrentDialect::delete_image_statement();
    utils::retry(|| async {
        let mut txn = self
            .pool
            .begin()
            .await
            .map_err(|source| DatabaseError::TransactionFailed { source })?;

        // Step 1: drop the image's tag links.
        let drop_tags = sqlx::query(&drop_tags_sql).bind(hash.to_string());
        drop_tags
            .execute(&mut *txn)
            .await
            .map_err(|source| DatabaseError::QueryFailed {
                operation: DbOperation::DeleteImageTags { hash: hash.clone() },
                sql: drop_tags_sql.to_string(),
                source,
            })?;

        // Step 2: drop the image row itself.
        let drop_image = sqlx::query(&drop_image_sql).bind(hash.to_string());
        drop_image
            .execute(&mut *txn)
            .await
            .map_err(|source| DatabaseError::QueryFailed {
                operation: DbOperation::DeleteImage { hash: hash.clone() },
                sql: drop_image_sql.to_string(),
                source,
            })?;

        txn.commit()
            .await
            .map_err(|source| DatabaseError::TransactionFailed { source })
    })
    .await?;
    Ok(())
}
}
/// Errors returned by the [`Database`] query methods.
#[derive(Debug, Error)]
pub enum DatabaseError {
/// A single statement failed while performing `operation`.
#[error("Query failed during {operation:?}: sql={sql} source={source}")]
QueryFailed {
/// The logical operation that was being performed.
operation: DbOperation,
/// Text of the SQL statement that failed.
sql: String,
/// The underlying sqlx error.
#[source]
source: sqlx::Error,
},
/// Beginning or committing a transaction failed.
#[error("Failed to operate transaction: source={source}")]
TransactionFailed {
/// The underlying sqlx error.
#[source]
source: sqlx::Error,
},
}
/// Describes which operation was running when a [`DatabaseError::QueryFailed`] occurred,
/// together with the values it was operating on.
#[derive(Debug)]
pub enum DbOperation {
/// Inserting the image row for `hash`.
InsertImage {
hash: PixelHash,
},
/// Inserting the tag `tag`.
InsertTag {
tag: String,
},
/// Linking `tag` to the image `hash`.
InsertImageTag {
hash: PixelHash,
tag: String,
},
/// Removing the link between `tag` and the image `hash`.
DeleteImageTag {
hash: PixelHash,
tag: String,
},
/// Deleting the image row for `hash`.
DeleteImage {
hash: PixelHash,
},
/// Deleting all tag links of the image `hash`.
DeleteImageTags {
hash: PixelHash,
},
/// Reading the tags of the image `hash`.
QueryImageTags {
hash: PixelHash,
},
/// Generic image-related query without further detail.
QueryImages,
/// Storing `metadata` for an image.
InsertMetadata {
metadata: ImageMetadata,
},
/// Setting the source of the image `hash` to `source`.
UpdateImageSource {
hash: PixelHash,
source: String,
},
/// Querying tag names.
QueryTags,
}
#[cfg(test)]
mod tests {
use crate::{
database::{Database, MIGRATOR, Pool},
query::{ImageQuery, ImageQueryExpr, ImageQueryKind, TagQuery, TagQueryExpr, TagQueryKind},
storage::{ImageMetadata, PixelHash},
};
use chrono::DateTime;
use std::str::FromStr;
// `ensure_image` succeeds on first insertion and again when the image already exists.
#[sqlx::test(migrator = "MIGRATOR")]
async fn test_ensure_image(pool: Pool) {
let db = Database::new(pool);
let image = PixelHash::try_from("329435e5e66be809").unwrap();
db.ensure_image(&image).await.unwrap();
db.ensure_image(&image).await.unwrap();
}
// A source written with `ensure_image_has_source` is read back by `get_source`.
#[sqlx::test(migrator = "MIGRATOR")]
async fn test_ensure_source(pool: Pool) {
let db = Database::new(pool);
let image = PixelHash::try_from("329435e5e66be809").unwrap();
db.ensure_image_has_source(&image, "src").await.unwrap();
assert_eq!(
Some("src".to_string()),
db.get_source(&image).await.unwrap()
);
}
// Metadata survives a round trip unchanged, and storing it twice does not fail.
#[sqlx::test(migrator = "MIGRATOR")]
async fn test_ensure_metadata(pool: Pool) {
let db = Database::new(pool);
let image = PixelHash::try_from("329435e5e66be809").unwrap();
let metadata = ImageMetadata {
width: 200,
height: 200,
format: "image/png".to_string(),
color_type: "rgba".to_string(),
file_size: 1337,
created_at: Some(DateTime::from_str("2025-05-02T01:18:49.678809123Z").unwrap()),
duration: Some(1.0),
};
db.ensure_image_has_metadata(&image, &metadata)
.await
.unwrap();
// Second call with identical data must also succeed.
db.ensure_image_has_metadata(&image, &metadata)
.await
.unwrap();
assert_eq!(Some(metadata), db.get_metadata(&image).await.unwrap());
}
// Metadata without `created_at` can be stored and is readable afterwards.
#[sqlx::test(migrator = "MIGRATOR")]
async fn test_ensure_metadata_without_created_at(pool: Pool) {
let db = Database::new(pool);
let image = PixelHash::try_from("329435e5e66be809").unwrap();
let metadata = ImageMetadata {
width: 200,
height: 200,
format: "image/png".to_string(),
color_type: "rgba".to_string(),
file_size: 1337,
created_at: None,
duration: None,
};
db.ensure_image_has_metadata(&image, &metadata)
.await
.unwrap();
// Only presence is checked: the stored timestamp is filled in at insert time.
assert!(db.get_metadata(&image).await.unwrap().is_some());
}
// Adding and removing tags is repeatable, and `get_tags` reflects the current set.
#[sqlx::test(migrator = "MIGRATOR")]
async fn test_operate_image_tag(pool: Pool) {
let db = Database::new(pool);
let image = PixelHash::try_from("329435e5e66be809").unwrap();
// "cat" is added twice on purpose to check that re-adding succeeds.
assert!(db.ensure_image_has_tags(&image, &["cat"]).await.is_ok());
assert!(db.ensure_image_has_tags(&image, &["cat"]).await.is_ok());
assert!(db.ensure_image_has_tags(&image, &["dog"]).await.is_ok());
assert_eq!(
vec!["cat".to_string(), "dog".to_string()],
db.get_tags(&image).await.unwrap()
);
// "dog" is removed twice on purpose to check that re-removing succeeds.
db.ensure_tags_removed(&image, &["dog"]).await.unwrap();
db.ensure_tags_removed(&image, &["dog"]).await.unwrap();
assert_eq!(vec!["cat".to_string()], db.get_tags(&image).await.unwrap());
}
// `query_image` returns exactly the images matching a single tag or a tag conjunction.
#[sqlx::test(migrator = "MIGRATOR")]
async fn test_query_image(pool: Pool) {
let db = Database::new(pool);
let image_cat = PixelHash::try_from("329435e5e66be809").unwrap();
let image_dog = PixelHash::try_from("229435e5e66be809").unwrap();
let image_cat_and_dog = PixelHash::try_from("129435e5e66be809").unwrap();
assert!(db.ensure_image_has_tags(&image_cat, &["cat"]).await.is_ok());
assert!(db.ensure_image_has_tags(&image_dog, &["dog"]).await.is_ok());
assert!(
db.ensure_image_has_tags(&image_cat_and_dog, &["cat", "dog"])
.await
.is_ok()
);
let query_cat = ImageQuery::new(ImageQueryKind::Where(ImageQueryExpr::tag("cat")));
let query_dog = ImageQuery::new(ImageQueryKind::Where(ImageQueryExpr::tag("dog")));
let query_cat_and_dog = ImageQuery::new(ImageQueryKind::Where(
ImageQueryExpr::tag("cat").and(ImageQueryExpr::tag("dog")),
));
// Results are sorted before comparison so the assertions do not depend on row order.
let mut res = db.query_image(query_cat).await.unwrap();
res.sort();
assert_eq!(vec![image_cat_and_dog.clone(), image_cat.clone()], res);
let mut res = db.query_image(query_dog).await.unwrap();
res.sort();
assert_eq!(vec![image_cat_and_dog.clone(), image_dog.clone()], res);
let mut res = db.query_image(query_cat_and_dog).await.unwrap();
res.sort();
assert_eq!(vec![image_cat_and_dog], res);
}
// `count_image` reports the number of images matching a single tag or a tag conjunction.
#[sqlx::test(migrator = "MIGRATOR")]
async fn test_count_image(pool: Pool) {
let db = Database::new(pool);
let image_cat = PixelHash::try_from("329435e5e66be809").unwrap();
let image_dog = PixelHash::try_from("229435e5e66be809").unwrap();
let image_cat_and_dog = PixelHash::try_from("129435e5e66be809").unwrap();
assert!(db.ensure_image_has_tags(&image_cat, &["cat"]).await.is_ok());
assert!(db.ensure_image_has_tags(&image_dog, &["dog"]).await.is_ok());
assert!(
db.ensure_image_has_tags(&image_cat_and_dog, &["cat", "dog"])
.await
.is_ok()
);
let query_cat = ImageQuery::new(ImageQueryKind::Where(ImageQueryExpr::tag("cat")));
let query_dog = ImageQuery::new(ImageQueryKind::Where(ImageQueryExpr::tag("dog")));
let query_cat_and_dog = ImageQuery::new(ImageQueryKind::Where(
ImageQueryExpr::tag("cat").and(ImageQueryExpr::tag("dog")),
));
// Two images carry each tag; only one carries both.
assert_eq!(2, db.count_image(query_cat).await.unwrap());
assert_eq!(2, db.count_image(query_dog).await.unwrap());
assert_eq!(1, db.count_image(query_cat_and_dog).await.unwrap());
}
// After `refresh_image_count`, `count_image_by_tag` reports the per-tag image counts.
#[sqlx::test(migrator = "MIGRATOR")]
async fn test_count_image_by_tag(pool: Pool) {
let db = Database::new(pool);
let image_cat = PixelHash::try_from("329435e5e66be809").unwrap();
let image_dog = PixelHash::try_from("229435e5e66be809").unwrap();
let image_cat_and_dog = PixelHash::try_from("129435e5e66be809").unwrap();
assert!(db.ensure_image_has_tags(&image_cat, &["cat"]).await.is_ok());
assert!(db.ensure_image_has_tags(&image_dog, &["dog"]).await.is_ok());
assert!(
db.ensure_image_has_tags(&image_cat_and_dog, &["cat", "dog"])
.await
.is_ok()
);
// The refresh is run explicitly before the counts are read.
db.refresh_image_count().await.unwrap();
assert_eq!(2, db.count_image_by_tag("cat").await.unwrap());
assert_eq!(2, db.count_image_by_tag("dog").await.unwrap());
}
// `query_tags` supports listing all tags, exact matches and substring matches.
#[sqlx::test(migrator = "MIGRATOR")]
async fn test_query_tags(pool: Pool) {
let db = Database::new(pool);
assert!(db.ensure_tags(&["cat"]).await.is_ok());
assert!(db.ensure_tags(&["dog"]).await.is_ok());
let query_cat = TagQuery::new(TagQueryKind::Where(TagQueryExpr::Exact("cat".to_string())));
let query_dog = TagQuery::new(TagQueryKind::Where(TagQueryExpr::Exact("dog".to_string())));
let query_all = TagQuery::new(TagQueryKind::All);
let query_contains_ca = TagQuery::new(TagQueryKind::Where(TagQueryExpr::Contains(
"ca".to_string(),
)));
// NOTE(review): this assertion relies on the order in which the database returns tags.
assert_eq!(
vec!["cat".to_string(), "dog".to_string()],
db.query_tags(query_all).await.unwrap()
);
assert_eq!(
vec!["cat".to_string()],
db.query_tags(query_cat).await.unwrap()
);
assert_eq!(
vec!["dog".to_string()],
db.query_tags(query_dog).await.unwrap()
);
assert_eq!(
vec!["cat".to_string()],
db.query_tags(query_contains_ca).await.unwrap()
);
}
// `get_source` yields `None` for an image without a source and the stored value otherwise.
#[sqlx::test(migrator = "MIGRATOR")]
async fn test_get_source(pool: Pool) {
let db = Database::new(pool);
let image_has_no_source = PixelHash::try_from("329435e5e66be809").unwrap();
db.ensure_image(&image_has_no_source).await.unwrap();
assert_eq!(None, db.get_source(&image_has_no_source).await.unwrap());
let image_has_source = PixelHash::try_from("329435e5e66be800").unwrap();
db.ensure_image_has_source(&image_has_source, "source")
.await
.unwrap();
assert_eq!(
Some("source".to_string()),
db.get_source(&image_has_source).await.unwrap()
);
}
}