lume 0.13.1

A simple and intuitive Query Builder inspired by Drizzle
Documentation
#![warn(missing_docs)]

//! # Update Module
//!
//! This module provides type-safe SQL `UPDATE` building and execution functionality.
//! It includes the `Update<T, U>` struct for building and executing update statements.

use std::{fmt::Debug, marker::PhantomData, sync::Arc};

#[cfg(feature = "mysql")]
use sqlx::MySqlPool;

#[cfg(feature = "postgres")]
use sqlx::PgPool;

#[cfg(feature = "sqlite")]
use sqlx::SqlitePool;

use crate::filter::Filtered;
use crate::helpers::{StartingSql, bind_value, build_filter_expr, get_starting_sql};
use crate::schema::{UpdateTrait, Value};
use crate::{database::error::DatabaseError, schema::Schema};

#[derive(Debug)]
/// Represents a SQL UPDATE operation for a given table and update struct.
///
/// The `Update<T, U>` struct is used to construct and execute a type-safe
/// SQL UPDATE statement for the table represented by the schema type `T`,
/// using the update struct `U` to specify which columns and values to update.
///
/// # Type Parameters
///
/// * `T` - The schema type representing the table to update. This type must implement [`Schema`].
/// * `U` - The update struct type, typically generated by the `define_schema!` macro, which must implement [`UpdateTrait`].
///
/// # Fields
///
/// - `table`: Marker for the schema type `T`. Used for type safety and SQL generation.
/// - `update_table`: Marker for the update struct type `U`. Used for type safety.
/// - `filters`: A list of filter conditions (implementing [`Filtered`]) to restrict which rows are updated.
/// - `conn`: The database connection pool used to execute the update operation. Its concrete
///   type depends on the enabled backend feature (`mysql`, `postgres`, or `sqlite`). Exactly one
///   of these features is expected to be enabled; enabling more than one would declare `conn` twice.
/// - `update_data`: A vector of (column name, value) pairs representing the columns and values to be updated.
///
/// # Example
///
/// ```no_run
/// use lume::define_schema;
/// use lume::database::Database;
/// use lume::filter::eq_value;
/// use lume::schema::Schema;
/// use lume::schema::ColumnInfo;
///
/// define_schema! {
///     Users {
///         id: u64 [primary_key().not_null().auto_increment()],
///         username: String [not_null()],
///         age: u16,
///     }
/// }
///
/// #[tokio::main]
/// async fn main() -> Result<(), lume::database::error::DatabaseError> {
///     let db = Database::connect("mysql://...").await?;
///     db.update::<Users, UpdateUsers>()
///         .set(UpdateUsers { age: Some(30), ..Default::default() })
///         .filter(eq_value(Users::username(), "alice"))
///         .execute()
///         .await?;
///     Ok(())
/// }
/// ```
pub struct Update<T: Schema + Debug, U: UpdateTrait + Debug> {
    /// Marker for the schema type `T`.
    table: PhantomData<T>,
    /// Marker for the update struct type `U`.
    update_table: PhantomData<U>,
    /// List of filters to apply to the update query; combined with `AND` when executed.
    filters: Vec<Box<dyn Filtered>>,
    /// Database connection pool (MySQL backend).
    #[cfg(feature = "mysql")]
    conn: Arc<MySqlPool>,

    /// Database connection pool (PostgreSQL backend).
    #[cfg(feature = "postgres")]
    conn: Arc<PgPool>,

    /// Database connection pool (SQLite backend).
    #[cfg(feature = "sqlite")]
    conn: Arc<SqlitePool>,

    /// Vector of (column name, value) pairs to be updated.
    /// Populated by `set`, which replaces any previous contents.
    update_data: Vec<(&'static str, Value)>,
}

impl<T: Schema + Debug, U: UpdateTrait + Debug> Update<T, U> {
    /// Creates an empty update builder (no columns set, no filters) that will
    /// run against the given MySQL connection pool.
    #[cfg(feature = "mysql")]
    pub(crate) fn new(conn: Arc<MySqlPool>) -> Self {
        Self {
            table: PhantomData,
            update_table: PhantomData,
            filters: Vec::new(),
            update_data: Vec::new(),
            conn,
        }
    }

    /// Creates an empty update builder (no columns set, no filters) that will
    /// run against the given PostgreSQL connection pool.
    #[cfg(feature = "postgres")]
    pub(crate) fn new(conn: Arc<PgPool>) -> Self {
        Self {
            table: PhantomData,
            update_table: PhantomData,
            filters: Vec::new(),
            update_data: Vec::new(),
            conn,
        }
    }

    /// Creates an empty update builder (no columns set, no filters) that will
    /// run against the given SQLite connection pool.
    #[cfg(feature = "sqlite")]
    pub(crate) fn new(conn: Arc<SqlitePool>) -> Self {
        Self {
            table: PhantomData,
            update_table: PhantomData,
            filters: Vec::new(),
            update_data: Vec::new(),
            conn,
        }
    }

    /// Sets the columns and values to be updated in the query.
    ///
    /// This method takes an update struct (typically generated by the `define_schema!` macro)
    /// and extracts the fields that should be updated. Only fields set to `Some(value)` will
    /// be included in the update statement; fields set to `None` are ignored.
    ///
    /// Calling `set` more than once replaces the previously set columns rather than
    /// merging with them.
    ///
    /// # Arguments
    ///
    /// * `data` - An update struct containing the columns and values to update.
    ///
    /// # Returns
    ///
    /// Returns the updated query builder for further chaining.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use lume::database::Database;
    /// use lume::define_schema;
    /// use lume::schema::ColumnInfo;
    /// use lume::schema::Schema;
    ///
    /// define_schema! {
    ///     Users {
    ///         id: u64 [primary_key().not_null().auto_increment()],
    ///         username: String [not_null()],
    ///         age: u16,
    ///     }
    /// }
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     let db = Database::connect("mysql://...").await.unwrap();
    ///     db.update::<Users, UpdateUsers>()
    ///         .set(UpdateUsers { age: Some(30), ..Default::default() })
    ///         .execute()
    ///         .await
    ///         .unwrap();
    /// }
    /// ```
    pub fn set(mut self, data: U) -> Self {
        self.update_data = data.get_updated();
        self
    }

    /// Adds a filter condition to the update query.
    ///
    /// This method allows you to specify a filter (typically created using filter combinators)
    /// to restrict which rows will be updated. Multiple calls to `filter` will combine filters
    /// using logical AND. Without any filter, every row of the table is updated.
    ///
    /// # Arguments
    ///
    /// * `filter` - A filter condition implementing the [`Filtered`] trait. This can be created
    ///   using filter combinators such as [`eq_value`](crate::filter::eq_value), `and`, `or`, etc.
    ///
    /// # Returns
    ///
    /// Returns the updated query builder for further chaining.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use lume::database::Database;
    /// use lume::define_schema;
    /// use lume::filter::eq_value;
    /// use lume::schema::Schema;
    /// use lume::schema::ColumnInfo;
    ///
    /// define_schema! {
    ///     Users {
    ///         id: u64 [primary_key().not_null().auto_increment()],
    ///         username: String [not_null()],
    ///         age: u16,
    ///     }
    /// }
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     let db = Database::connect("mysql://...").await.unwrap();
    ///     db.update::<Users, UpdateUsers>()
    ///         .set(UpdateUsers { age: Some(30), ..Default::default() })
    ///         .filter(eq_value(Users::username(), "alice"))
    ///         .execute()
    ///         .await
    ///         .unwrap();
    /// }
    /// ```
    pub fn filter<F>(mut self, filter: F) -> Self
    where
        F: Filtered + 'static,
    {
        self.filters.push(Box::new(filter));
        self
    }

    /// Executes the SQL UPDATE operation with the specified update data and filters.
    ///
    /// This method builds the SQL UPDATE statement from the provided update data and filter
    /// conditions. Every SET value and every filter value is sent as a bound parameter rather
    /// than being spliced into the SQL text. The statement is then executed against the database.
    ///
    /// # Returns
    ///
    /// Returns `Ok(())` if the update operation was successful, or a [`DatabaseError`] if an error occurred.
    ///
    /// # Errors
    ///
    /// * [`DatabaseError::ExecutionError`] if there is nothing to update, either because `set`
    ///   was never called or because it produced no columns. This is reported before any
    ///   connection is acquired.
    /// * [`DatabaseError::ConnectionError`] if a connection cannot be acquired from the pool.
    /// * [`DatabaseError::ExecutionError`] if the database fails to execute the statement.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use lume::database::Database;
    /// use lume::define_schema;
    /// use lume::filter::eq_value;
    /// use lume::schema::Schema;
    /// use lume::schema::ColumnInfo;
    ///
    /// define_schema! {
    ///     Users {
    ///         id: u64 [primary_key().not_null().auto_increment()],
    ///         username: String [not_null()],
    ///         age: u16,
    ///     }
    /// }
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     let db = Database::connect("mysql://...").await.unwrap();
    ///     db.update::<Users, UpdateUsers>()
    ///         .set(UpdateUsers { age: Some(30), ..Default::default() })
    ///         .filter(eq_value(Users::username(), "alice"))
    ///         .execute()
    ///         .await
    ///         .unwrap();
    /// }
    /// ```
    pub async fn execute(self) -> Result<(), DatabaseError> {
        // `UPDATE t SET` with no assignments is invalid SQL; fail fast instead of
        // spending a pool round-trip on a guaranteed syntax error.
        if self.update_data.is_empty() {
            return Err(DatabaseError::ExecutionError(
                "update has no columns to set; call `set` with at least one `Some` field"
                    .to_string(),
            ));
        }

        // SET values are pushed before filter values so that the bind order below
        // matches the placeholder order in the generated SQL text.
        let mut params: Vec<Value> = Vec::with_capacity(self.update_data.len());
        let sql = get_starting_sql(StartingSql::Update, T::table_name());
        let sql = Self::set_clause_sql(sql, self.update_data, &mut params);
        let sql = Self::filter_sql(sql, self.filters, &mut params);

        let mut conn = self
            .conn
            .acquire()
            .await
            .map_err(DatabaseError::ConnectionError)?;
        let mut query = sqlx::query(&sql);
        for v in params {
            query = bind_value(query, v);
        }

        query
            .execute(conn.as_mut())
            .await
            .map_err(|e| DatabaseError::ExecutionError(e.to_string()))?;

        Ok(())
    }

    /// Appends the SET assignments for `data` to `sql`, rendering each value inline
    /// through its `Display` implementation.
    ///
    /// Assignments are separated by `", "`; an empty `data` leaves `sql` unchanged.
    /// `execute` does not use this (it binds values via `set_clause_sql`), because
    /// inlined values are neither quoted nor escaped here.
    #[allow(dead_code)] // retained for crate-internal callers that want a literal SET list
    pub(crate) fn update_sql(mut sql: String, data: Vec<(&'static str, Value)>) -> String {
        let assignments: Vec<String> = data
            .iter()
            .map(|(column, value)| format!("{column} = {value}"))
            .collect();
        sql.push_str(&assignments.join(", "));
        sql
    }

    /// Appends the SET assignments for `data` to `sql` as `column = <placeholder>`
    /// pairs separated by `", "`, pushing each value onto `params` in the same order.
    ///
    /// Must be called before anything else adds to `params` (e.g. `filter_sql`) so
    /// that positional placeholders line up with the bind order.
    fn set_clause_sql(
        mut sql: String,
        data: Vec<(&'static str, Value)>,
        params: &mut Vec<Value>,
    ) -> String {
        let assignments: Vec<String> = data
            .into_iter()
            .map(|(column, value)| {
                params.push(value);
                format!("{column} = {}", Self::placeholder(params.len()))
            })
            .collect();
        sql.push_str(&assignments.join(", "));
        sql
    }

    /// Returns the bind placeholder for the `position`-th (1-based) parameter.
    ///
    /// NOTE(review): PostgreSQL uses numbered `$N` placeholders; this assumes
    /// `build_filter_expr` numbers its placeholders from `params.len()` as well — confirm.
    #[cfg(feature = "postgres")]
    fn placeholder(position: usize) -> String {
        format!("${position}")
    }

    /// Returns the bind placeholder for a parameter (`?` for MySQL and SQLite).
    #[cfg(not(feature = "postgres"))]
    fn placeholder(_position: usize) -> String {
        "?".to_string()
    }

    /// Appends a `WHERE` clause built from `filters` (combined with `AND`) to `sql`,
    /// collecting the filters' bound values into `params` in order.
    ///
    /// An empty `filters` list leaves `sql` unchanged (no `WHERE` clause).
    pub(crate) fn filter_sql(
        mut sql: String,
        filters: Vec<Box<dyn Filtered>>,
        params: &mut Vec<Value>,
    ) -> String {
        if filters.is_empty() {
            return sql;
        }

        let conditions: Vec<String> = filters
            .iter()
            .map(|filter| build_filter_expr(filter.as_ref(), params))
            .collect();
        sql.push_str(" WHERE ");
        sql.push_str(&conditions.join(" AND "));
        sql
    }
}