use super::connection::{DatabaseBackend, DatabaseConnection};
use crate::orm::Model;
use reinhardt_query::prelude::{
Alias, BinOper, ColumnRef, Expr, Func, MySqlQueryBuilder, PostgresQueryBuilder, Query,
QueryBuilder, SelectStatement, SqliteQueryBuilder,
};
use serde::{Serialize, de::DeserializeOwned};
use std::marker::PhantomData;
fn build_select_sql(
stmt: &SelectStatement,
backend: DatabaseBackend,
) -> (String, reinhardt_query::prelude::Values) {
match backend {
DatabaseBackend::Postgres => PostgresQueryBuilder.build_select(stmt),
DatabaseBackend::MySql => MySqlQueryBuilder.build_select(stmt),
DatabaseBackend::Sqlite => SqliteQueryBuilder.build_select(stmt),
}
}
/// Handle for the rows of `T` whose foreign-key column points at one `S` row.
///
/// `S` is the referenced ("source") model and `T` the referencing ("target")
/// model. The accessor stores only the source primary key, the name of the
/// foreign-key column and optional paging settings.
pub struct ReverseAccessor<S: Model, T: Model + Serialize + DeserializeOwned> {
    /// Primary key of the source row that the foreign-key column is compared against.
    source_id: S::PrimaryKey,
    /// Name of the column used in the equality filter.
    foreign_key_field: String,
    /// Connection the queries are executed on.
    db: DatabaseConnection,
    /// Maximum number of rows to fetch; `None` means no `LIMIT` is applied.
    limit: Option<usize>,
    /// Number of rows to skip; `None` means no `OFFSET` is applied.
    offset: Option<usize>,
    /// Ties the accessor to the source model type without storing a value of it.
    _phantom_source: PhantomData<S>,
    /// Ties the accessor to the target model type without storing a value of it.
    _phantom_target: PhantomData<T>,
}
impl<S, T> ReverseAccessor<S, T>
where
S: Model,
T: Model + Serialize + DeserializeOwned,
{
pub fn new(source: &S, foreign_key_field: &str, db: DatabaseConnection) -> Self {
let source_id = source
.primary_key()
.expect("Source model must have primary key")
.clone();
Self {
source_id,
foreign_key_field: foreign_key_field.to_string(),
db,
limit: None,
offset: None,
_phantom_source: PhantomData,
_phantom_target: PhantomData,
}
}
pub async fn all(&self) -> Result<Vec<T>, String> {
let mut query = Query::select();
query
.from(Alias::new(T::table_name()))
.column(ColumnRef::table_asterisk(Alias::new(T::table_name())))
.and_where(
Expr::col(Alias::new(&self.foreign_key_field))
.binary(BinOper::Equal, Expr::val(self.source_id.to_string())),
);
if let Some(limit) = self.limit {
query.limit(limit as u64);
}
if let Some(offset) = self.offset {
query.offset(offset as u64);
}
let query = query.to_owned();
let (sql, _values) = build_select_sql(&query, self.db.backend());
let rows = self
.db
.query(&sql, vec![])
.await
.map_err(|e| e.to_string())?;
rows.into_iter()
.map(|row| serde_json::from_value(row.data).map_err(|e| e.to_string()))
.collect()
}
pub async fn count(&self) -> Result<usize, String> {
let query = Query::select()
.from(Alias::new(T::table_name()))
.expr(Func::count(Expr::asterisk().into_simple_expr()))
.and_where(
Expr::col(Alias::new(&self.foreign_key_field))
.binary(BinOper::Equal, Expr::val(self.source_id.to_string())),
)
.to_owned();
let (sql, _) = build_select_sql(&query, self.db.backend());
let rows = self
.db
.query(&sql, vec![])
.await
.map_err(|e| e.to_string())?;
if let Some(row) = rows.first()
&& let Some(count_value) = row.data.get("count")
&& let Some(count) = count_value.as_i64()
{
return Ok(count as usize);
}
Ok(0)
}
pub fn limit(mut self, limit: usize) -> Self {
self.limit = Some(limit);
self
}
pub fn offset(mut self, offset: usize) -> Self {
self.offset = Some(offset);
self
}
pub fn paginate(self, page: usize, page_size: usize) -> Self {
let offset = page.saturating_sub(1) * page_size;
self.offset(offset).limit(page_size)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use reinhardt_query::prelude::QueryStatementBuilder;

    /// A "select every column, filtered by foreign key" statement renders the
    /// table and the filter column for PostgreSQL.
    #[test]
    fn test_sql_generation_all() {
        let mut stmt = Query::select();
        stmt.from(Alias::new("tweets"))
            .column(ColumnRef::table_asterisk(Alias::new("tweets")))
            .and_where(
                Expr::col(Alias::new("user_id")).binary(BinOper::Equal, Expr::val("user-123")),
            );

        let (rendered, _) = stmt.build(PostgresQueryBuilder);

        assert!(rendered.contains("SELECT"));
        assert!(rendered.contains("tweets"));
        assert!(rendered.contains("user_id"));
    }

    /// A `COUNT(*)` statement with the same filter renders the aggregate as well.
    #[test]
    fn test_sql_generation_count() {
        let mut stmt = Query::select();
        stmt.from(Alias::new("tweets"))
            .expr(Func::count(Expr::asterisk().into_simple_expr()))
            .and_where(
                Expr::col(Alias::new("user_id")).binary(BinOper::Equal, Expr::val("user-123")),
            );

        let (rendered, _) = stmt.build(PostgresQueryBuilder);

        assert!(rendered.contains("SELECT"));
        assert!(rendered.contains("COUNT"));
        assert!(rendered.contains("tweets"));
        assert!(rendered.contains("user_id"));
    }

    /// Limit and offset set on the builder show up in the rendered SQL.
    #[test]
    fn test_sql_generation_with_limit_offset() {
        let stmt = Query::select()
            .from(Alias::new("tweets"))
            .column(ColumnRef::table_asterisk(Alias::new("tweets")))
            .and_where(
                Expr::col(Alias::new("user_id")).binary(BinOper::Equal, Expr::val("user-123")),
            )
            .limit(10)
            .offset(20)
            .to_owned();

        let (rendered, _) = stmt.build(PostgresQueryBuilder);

        assert!(rendered.contains("LIMIT"));
        assert!(rendered.contains("OFFSET"));
    }

    /// Fixture for the referenced side of the relation, stored in `users`.
    // No test in this module constructs the fixture, hence the lint allowance.
    #[allow(dead_code)]
    #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
    struct TestUser {
        id: String,
        username: String,
    }

    /// Field selector for [`TestUser`]; carries no state.
    #[derive(Debug, Clone)]
    struct TestUserFields;

    impl crate::orm::model::FieldSelector for TestUserFields {
        /// The alias is ignored; the selector is returned unchanged.
        fn with_alias(self, _alias: &str) -> Self {
            self
        }
    }

    impl Model for TestUser {
        type PrimaryKey = String;
        type Fields = TestUserFields;

        fn table_name() -> &'static str {
            "users"
        }

        fn app_label() -> &'static str {
            "auth"
        }

        /// Always `Some`, holding a copy of `id`.
        fn primary_key(&self) -> Option<Self::PrimaryKey> {
            Some(self.id.clone())
        }

        fn set_primary_key(&mut self, value: Self::PrimaryKey) {
            self.id = value;
        }

        fn primary_key_field() -> &'static str {
            "id"
        }

        fn new_fields() -> Self::Fields {
            TestUserFields
        }
    }

    /// Fixture for the referencing side of the relation, stored in `tweets`.
    // No test in this module constructs the fixture, hence the lint allowance.
    #[allow(dead_code)]
    #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
    struct TestTweet {
        id: String,
        user_id: String,
        content: String,
    }

    /// Field selector for [`TestTweet`]; carries no state.
    #[derive(Debug, Clone)]
    struct TestTweetFields;

    impl crate::orm::model::FieldSelector for TestTweetFields {
        /// The alias is ignored; the selector is returned unchanged.
        fn with_alias(self, _alias: &str) -> Self {
            self
        }
    }

    impl Model for TestTweet {
        type PrimaryKey = String;
        type Fields = TestTweetFields;

        fn table_name() -> &'static str {
            "tweets"
        }

        fn app_label() -> &'static str {
            "twitter"
        }

        /// Always `Some`, holding a copy of `id`.
        fn primary_key(&self) -> Option<Self::PrimaryKey> {
            Some(self.id.clone())
        }

        fn set_primary_key(&mut self, value: Self::PrimaryKey) {
            self.id = value;
        }

        fn primary_key_field() -> &'static str {
            "id"
        }

        fn new_fields() -> Self::Fields {
            TestTweetFields
        }
    }
}