// dojo-orm 0.2.2
//
// A simple ORM for Rust.
// Documentation
mod async_graphql;

use std::error::Error;
use std::fmt::{Debug, Display};
use std::marker::PhantomData;

use crate::order_by::Direction;
use crate::types::{accepts, to_sql_checked, IsNull, ToSql, Type};
use crate::Model;
use anyhow::Result;
use base64ct::{Base64, Encoding};
use bytes::BytesMut;
use chrono::NaiveDateTime;
use serde::{Deserialize, Serialize};
use tracing::debug;

pub use async_graphql::*;

/// Supplies the default sort columns for a type and renders them as an
/// `ORDER BY` column list.
pub trait DefaultSortKeys {
    /// Returns the sort columns, most significant first.
    fn keys() -> Vec<String>;

    /// Renders the sort columns as a comma-separated `ORDER BY` list.
    ///
    /// Only the first column follows `direction` (`ASC` for
    /// `Direction::Asc`, `DESC` otherwise); every later column is always
    /// `ASC`. An empty key list produces an empty string.
    fn order_by_stmt(direction: Direction) -> String {
        let leading_suffix = if direction == Direction::Asc {
            " ASC"
        } else {
            " DESC"
        };

        Self::keys()
            .iter()
            .enumerate()
            .map(|(position, key)| {
                // Only the leading key honours the requested direction.
                let suffix = if position == 0 { leading_suffix } else { " ASC" };
                format!("{}{}", key, suffix)
            })
            .collect::<Vec<_>>()
            .join(", ")
    }
}

/// A single column/value pair that forms one component of a [`Cursor`].
///
/// Derives `Serialize`/`Deserialize` so a list of rows can be packed into
/// an encoded cursor string.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Row {
    /// Column name; emitted verbatim into the SQL fragments built by `Cursor`.
    pub column: String,
    /// Value bound as the query parameter compared against `column`.
    pub value: crate::model::Value,
}

impl Row {
    pub fn new(column: String, value: crate::model::Value) -> Self {
        Self { column, value }
    }
}

/// A position in a sorted result set, expressed as an ordered list of
/// column/value pairs.
///
/// Order matters: the first entry is the leading sort key and is the only
/// one whose sort direction follows the caller's requested direction.
#[derive(Debug, Clone)]
pub struct Cursor {
    /// Column/value pairs in sort-key order.
    pub values: Vec<Row>,
}

impl Cursor {
    /// Builds a row-comparison predicate and its bound parameters.
    ///
    /// The statement has the shape `(col_a, col_b) > ($1, $2)`; the operator
    /// is `>` for `Direction::Asc` and `<` otherwise. Placeholders are always
    /// numbered from `$1`, and the returned parameters borrow from `self` in
    /// the same order as the columns.
    ///
    /// NOTE(review): for the non-`Asc` direction the whole tuple is compared
    /// with `<`, while the ORDER BY helpers below sort only the first key
    /// descending and the rest ascending — confirm this is intended when
    /// rows tie on the first key.
    pub fn to_where_stmt(&self, direction: Direction) -> (String, Vec<&(dyn ToSql + Sync)>) {
        let column_list = self
            .values
            .iter()
            .map(|row| row.column.as_str())
            .collect::<Vec<_>>()
            .join(", ");

        let placeholder_list = (1..=self.values.len())
            .map(|ordinal| format!("${}", ordinal))
            .collect::<Vec<_>>()
            .join(", ");

        let params: Vec<&(dyn ToSql + Sync)> = self
            .values
            .iter()
            .map(|row| &row.value as &(dyn ToSql + Sync))
            .collect();

        let operator = if direction == Direction::Asc { '>' } else { '<' };

        (
            format!("({}) {} ({})", column_list, operator, placeholder_list),
            params,
        )
    }

    /// Renders an `ORDER BY` column list from this cursor's own columns.
    ///
    /// See [`Cursor::order_by_stmt_by_keys`] for the formatting rules.
    pub fn to_order_by_stmt(&self, direction: Direction) -> String {
        let columns: Vec<String> = self.values.iter().map(|row| row.column.clone()).collect();
        Self::order_by_stmt_by_keys(&columns, direction)
    }

    /// Renders `keys` as a comma-separated `ORDER BY` column list.
    ///
    /// The first key is suffixed with ` ASC` or ` DESC` according to
    /// `direction`; every remaining key is always ` ASC`. An empty slice
    /// yields an empty string.
    pub fn order_by_stmt_by_keys(keys: &[String], direction: Direction) -> String {
        let leading_suffix = if direction == Direction::Asc {
            " ASC"
        } else {
            " DESC"
        };

        match keys.split_first() {
            None => String::new(),
            Some((head, tail)) => {
                let mut stmt = format!("{}{}", head, leading_suffix);
                for key in tail {
                    stmt.push_str(", ");
                    stmt.push_str(key);
                    stmt.push_str(" ASC");
                }
                stmt
            }
        }
    }
}

impl Cursor {
    pub fn new(values: Vec<Row>) -> Self {
        Self { values }
    }

    pub fn encode(&self) -> String {
        // it's safe, trust me bro.
        let buf = bincode::serialize(&self.values).unwrap();
        Base64::encode_string(&buf)
    }

    pub fn decode(encoded: &str) -> Result<Self> {
        let decoded = Base64::decode_vec(encoded).unwrap();
        let values: Vec<Row> = bincode::deserialize(&decoded[..])?;
        Ok(Self { values })
    }
}

/// One page of `T` models together with the paging arguments that produced
/// it and the flags derived from them.
#[derive(Debug)]
pub struct Pagination<T>
where
    T: Model + Debug,
{
    /// Items on this page.
    pub items: Vec<T>,
    /// The `before` cursor argument, stored as passed to `Pagination::new`.
    pub before: Option<Cursor>,
    /// The `after` cursor argument, stored as passed to `Pagination::new`.
    pub after: Option<Cursor>,
    /// The `first` (forward page size) argument, stored as passed in.
    pub first: Option<i64>,
    /// The `last` (backward page size) argument, stored as passed in.
    pub last: Option<i64>,
    /// Total count supplied by the caller; presumably the number of matching
    /// rows across all pages — TODO confirm against the caller.
    pub total_nodes: i64,
    /// Set by `Pagination::new` when `first` was given and one row beyond the
    /// limit was fetched.
    pub has_next: bool,
    /// Set by `Pagination::new` when `last` was given and one row beyond the
    /// limit was fetched.
    pub has_previous: bool,
}

impl<T> Pagination<T>
where
    T: Model + Debug,
{
    /// Builds a page from the fetched rows and the paging arguments.
    ///
    /// The caller is expected to fetch one row more than the requested limit.
    /// When exactly `limit + 1` rows are supplied, the trailing row is treated
    /// as a look-ahead sentinel: it is removed from `items`, and `has_next`
    /// (if `first` was given) and/or `has_previous` (if `last` was given) is
    /// set.
    ///
    /// The limit is `first` when present, otherwise `last`. When neither is
    /// given there is no limit, so no row is trimmed and both flags are
    /// `false`.
    ///
    /// `after`, `before` and `total_nodes` are stored unchanged.
    pub fn new(
        items: Vec<T>,
        first: Option<i64>,
        after: Option<Cursor>,
        last: Option<i64>,
        before: Option<Cursor>,
        total_nodes: i64,
    ) -> Self {
        let fetched = items.len() as i64;
        // Only look for a sentinel row when a limit was requested; without
        // one, a lone result must not be mistaken for the look-ahead row.
        // `checked_add` keeps `i64::MAX` from overflowing.
        let has_next_or_previous = first
            .or(last)
            .and_then(|limit| limit.checked_add(1))
            == Some(fetched);
        let has_next = first.is_some() && has_next_or_previous;
        let has_previous = last.is_some() && has_next_or_previous;

        let mut items = items;
        if has_next_or_previous {
            // Drop the look-ahead row so the page holds at most `limit` items.
            items.pop();
        }

        Self {
            items,
            before,
            after,
            first,
            last,
            total_nodes,
            has_next,
            has_previous,
        }
    }

    /// Returns the cursor of the last item on the page, or `None` when the
    /// page is empty.
    pub fn end_cursor(&self) -> Option<Cursor> {
        self.items.last().map(|item| item.cursor())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use chrono::NaiveDateTime;
    use dojo_macros::Model;
    use googletest::prelude::*;
    use uuid::Uuid;

    /// A single-key cursor compares one column against one placeholder.
    #[test]
    fn test_cursor_to_sql_with_1_key() -> anyhow::Result<()> {
        let created_at = NaiveDateTime::parse_from_str("2024-01-07 12:34:56", "%Y-%m-%d %H:%M:%S")?;
        let cursor_value = Row {
            column: "created_at".to_string(),
            value: crate::model::Value::NaiveDateTime(created_at),
        };
        let cursor = Cursor::new(vec![cursor_value]);
        let (sql, params) = cursor.to_where_stmt(Direction::Asc);
        assert_that!(sql.as_str(), eq("(created_at) > ($1)"));
        assert_that!(params.len(), eq(1));

        Ok(())
    }

    /// A two-key cursor compares the column tuple against numbered
    /// placeholders, and the operator flips with the direction.
    #[test]
    fn test_cursor_to_sql_with_2_key() -> anyhow::Result<()> {
        let created_at = NaiveDateTime::parse_from_str("2024-01-07 12:34:56", "%Y-%m-%d %H:%M:%S")?;
        let uuid = Uuid::parse_str("ce2087a7-bdbc-4453-9fb8-d4dff3584f3e")?;
        let cursor = Cursor::new(vec![
            Row {
                column: "created_at".to_string(),
                value: crate::model::Value::NaiveDateTime(created_at),
            },
            Row {
                column: "id".to_string(),
                value: crate::model::Value::Uuid(uuid),
            },
        ]);
        let (sql, params) = cursor.to_where_stmt(Direction::Asc);
        assert_that!(sql.as_str(), eq("(created_at, id) > ($1, $2)"));
        assert_that!(params.len(), eq(2));

        let (desc_sql, desc_params) = cursor.to_where_stmt(Direction::Desc);
        assert_that!(desc_sql.as_str(), eq("(created_at, id) < ($1, $2)"));
        assert_that!(desc_params.len(), eq(2));

        Ok(())
    }

    /// Encoding then decoding a cursor yields the original rows.
    #[test]
    fn test_decode_cursor() -> anyhow::Result<()> {
        let created_at = NaiveDateTime::parse_from_str("2024-01-07 12:34:56", "%Y-%m-%d %H:%M:%S")?;
        let cursor_value = Row {
            column: "created_at".to_string(),
            value: crate::model::Value::NaiveDateTime(created_at),
        };
        let cursor = Cursor::new(vec![cursor_value]);
        let encoded = cursor.encode();

        let decoded = Cursor::decode(&encoded)?;
        assert_that!(
            decoded,
            pat!(Cursor {
                values: contains_each![pat!(Row {
                    column: eq("created_at".to_string()),
                    value: eq(crate::model::Value::NaiveDateTime(created_at)),
                }),],
            })
        );

        Ok(())
    }
}