use crate::model::Model;
use serde::de::Error;
use serde::Deserialize;
use serde::{Serialize, Serializer};
use sql::query::Criteria;
use sql::query::SelectColumn;
use sql::{Expr, Operation, Where};
use sqlx::{Database, Decode, Encode, Type};
use std::ops::{Deref, DerefMut};
/// Exposes the identifier a model is referenced by when it appears on either side of a join.
pub trait JoinMeta {
    /// Identifier type; must be cloneable, sendable, and usable as a hash key
    /// (`Eq` already implies `PartialEq`).
    type IdType: Clone + Send + Eq + std::hash::Hash;

    /// Returns this object's identifier.
    fn _id(&self) -> Self::IdType;
}
impl<T> JoinMeta for Option<T>
where
    T: JoinMeta,
{
    type IdType = Option<T::IdType>;

    /// `Some(inner)` yields `Some(inner._id())`; `None` yields `None`.
    fn _id(&self) -> Self::IdType {
        let inner: Option<&T> = self.as_ref();
        inner.map(T::_id)
    }
}
impl<T: JoinMeta> JoinMeta for Join<T> {
type IdType = T::IdType;
fn _id(&self) -> Self::IdType {
self.id.clone()
}
}
/// Something that can asynchronously produce a joined object of type `T` on demand.
///
/// `DB` is the sqlx database type of the executor passed to `load`.
pub trait Loadable<DB, T: JoinMeta> {
    /// Uses the executor `db` to obtain the joined object, returning a reference that
    /// borrows from `self`.
    ///
    /// # Errors
    /// Returns the crate's error type if loading fails.
    // The `async_fn_in_trait` lint fires for `async fn` in public traits because callers
    // cannot add bounds (such as `Send`) to the returned future; that limitation is accepted here.
    #[allow(async_fn_in_trait)]
    async fn load<'s, 'e, E>(&'s mut self, db: E) -> crate::error::Result<&'s T>
    where
        T::IdType: 'e + Send + Sync,
        E: 'e + sqlx::Executor<'e, Database = DB>,
        T: 's;
}
/// A relation to another model of type `T`.
///
/// The related object's `id` is always present. The object itself may or may not be held,
/// as tracked by `data`; dereferencing a join whose data has not been queried panics.
#[derive(Debug)]
pub struct Join<T: JoinMeta> {
    /// Id of the related object.
    // NOTE(review): `id` is public and can be reassigned without touching `data`, so the two
    // can disagree — confirm callers never do this.
    pub id: T::IdType,
    /// Load/modification state of the related object.
    data: JoinData<T>,
}
/// State of the object held by a `Join`.
#[derive(Debug)]
pub enum JoinData<T: JoinMeta> {
    /// No object is held; only the join's id is known.
    NotQueried,
    /// An object obtained by loading or deserialization that has not been mutably accessed since.
    QueryResult(T),
    /// An object supplied locally or mutably accessed, which may differ from what is stored.
    Modified(T),
}
impl<T: JoinMeta> Join<T> {
    /// Creates a join that knows only the related object's id; no object is held yet.
    ///
    /// Dereferencing the result panics until an object is supplied (for example by loading it).
    pub fn new_with_id(id: T::IdType) -> Self {
        Self {
            id,
            data: JoinData::NotQueried,
        }
    }

    /// Wraps a locally constructed object and marks it as modified.
    ///
    /// The join's `id` is taken from `obj._id()`.
    pub fn new(obj: T) -> Self {
        Self {
            id: obj._id(),
            data: JoinData::Modified(obj),
        }
    }

    /// Returns `true` when an object is held, whether queried or modified.
    pub fn loaded(&self) -> bool {
        !matches!(self.data, JoinData::NotQueried)
    }

    /// Returns `true` only when the held object is marked as modified.
    pub fn is_modified(&self) -> bool {
        matches!(self.data, JoinData::Modified(_))
    }

    /// Moves a pending modification out of the join, if there is one.
    ///
    /// When the data is modified, returns `Some(obj)` and leaves the join in the
    /// not-queried state. Otherwise returns `None` and leaves the join untouched, so an
    /// already-queried object stays available.
    #[doc(hidden)]
    pub fn _take_modification(&mut self) -> Option<T> {
        match std::mem::replace(&mut self.data, JoinData::NotQueried) {
            JoinData::Modified(obj) => Some(obj),
            // Not a modification: restore the previous state instead of discarding a
            // queried object.
            unmodified @ (JoinData::NotQueried | JoinData::QueryResult(_)) => {
                self.data = unmodified;
                None
            }
        }
    }

    /// Marks the held object as modified and returns a mutable reference to it.
    ///
    /// # Panics
    /// Panics if no object is held (the join has not been queried).
    fn transition_to_modified(&mut self) -> &mut T {
        // Take ownership of the current state so the object can move into `Modified`.
        let obj = match std::mem::replace(&mut self.data, JoinData::NotQueried) {
            JoinData::NotQueried => {
                panic!("Tried to deref_mut a joined object, but it has not been queried.")
            }
            JoinData::QueryResult(obj) | JoinData::Modified(obj) => obj,
        };
        self.data = JoinData::Modified(obj);
        match &mut self.data {
            JoinData::Modified(obj) => obj,
            JoinData::NotQueried | JoinData::QueryResult(_) => unreachable!(),
        }
    }

    /// Wraps an object as an unmodified query result; the join's `id` comes from `obj._id()`.
    #[doc(hidden)]
    pub fn _query_result(obj: T) -> Self {
        Self {
            id: obj._id(),
            data: JoinData::QueryResult(obj),
        }
    }
}
impl<DB, T> Loadable<DB, T> for Join<T>
where
DB: Database,
T: JoinMeta + Model<DB> + Send,
T::IdType: for<'a> Encode<'a, DB> + for<'a> Decode<'a, DB> + Type<DB>,
{
async fn load<'s, 'e, E: sqlx::Executor<'e, Database = DB> + 'e>(
&'s mut self,
conn: E,
) -> crate::error::Result<&'s T>
where
T::IdType: 'e + Send + Sync,
T: 's,
{
let model = T::fetch_one(self.id.clone(), conn).await?;
self.data = JoinData::QueryResult(model);
let s = &*self;
Ok(s.deref())
}
}
impl<T> Deref for Join<T>
where
    T: JoinMeta,
{
    type Target = T;

    /// Borrows the held object, whether it was queried or modified.
    ///
    /// # Panics
    /// Panics if the join has not been queried.
    fn deref(&self) -> &T {
        let (JoinData::QueryResult(obj) | JoinData::Modified(obj)) = &self.data else {
            panic!("Tried to deref a joined object, but it has not been queried.")
        };
        obj
    }
}
impl<T: JoinMeta> DerefMut for Join<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.transition_to_modified()
}
}
/// Compile-time (`'static`) description of a join relation.
///
/// NOTE(review): the field meanings below are inferred from names and from the parameters of
/// `select_columns`/`criteria` in this file; confirm against the code that builds these values.
#[doc(hidden)]
#[derive(Debug, Clone, Copy)]
pub enum JoinDescription {
    /// A many-to-one relation.
    ManyToOne {
        /// Column names to select from the joined table; presumably what is passed to
        /// `select_columns` — TODO confirm.
        columns: &'static [&'static str],
        /// Name of the foreign table.
        foreign_table: &'static str,
        /// Local column holding the reference to the foreign row — TODO confirm.
        local_column: &'static str,
        /// Name of the model field holding the join; presumably also the table alias used by
        /// `select_columns` — TODO confirm.
        field: &'static str,
        /// Foreign-table column that `local_column` is compared against — TODO confirm.
        foreign_key: &'static str,
    },
}
/// Builds the result-set alias for `column` selected through join field `field`:
/// `__{field}__{column}`.
pub fn column_alias(field: &str, column: &str) -> String {
    format!("__{field}__{column}")
}
/// Yields one `SelectColumn` per entry of `columns`.
///
/// Each column is qualified by `field` as its table name and aliased with
/// `column_alias(field, column)`.
pub fn select_columns(
    columns: &'static [&'static str],
    field: &'static str,
) -> impl Iterator<Item = SelectColumn> + 'static {
    // `move` copies the `&'static str` into the closure so the iterator is truly `'static`.
    columns.iter().copied().map(move |name| {
        let selected = SelectColumn::table_column(field, name);
        selected.alias(column_alias(field, name))
    })
}
/// Builds an `ON local_table.local_column = remote_table.remote_column` join criterion.
pub fn criteria(
    local_table: &str,
    local_column: &str,
    remote_table: &str,
    remote_column: &str,
) -> Criteria {
    // Table-qualified column reference with no schema.
    let qualified = |table: &str, name: &str| Expr::Column {
        schema: None,
        table: Some(table.to_string()),
        column: name.to_string(),
    };
    let lhs = qualified(local_table, local_column);
    let rhs = qualified(remote_table, remote_column);
    let equality = Expr::BinOp(Operation::Eq, lhs.into(), rhs.into());
    Criteria::On(Where::Expr(equality))
}
impl<T: JoinMeta + Serialize> Serialize for Join<T> {
    /// Serializes the join as an optional `T`.
    ///
    /// - A held object (queried or modified) is written with `serialize_some`.
    /// - A not-queried join is written with `serialize_none`; its `id` is not emitted.
    ///
    /// Using `serialize_some` rather than serializing `T` directly gives the output exactly
    /// the shape of `Option<T>`, which is what the `Deserialize` impl for `Join` reads back.
    /// For self-describing formats such as JSON, the output is the same as before. Formats
    /// that encode an explicit option tag need the tag for a round-trip to work.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        match &self.data {
            JoinData::QueryResult(data) | JoinData::Modified(data) => {
                serializer.serialize_some(data)
            }
            JoinData::NotQueried => serializer.serialize_none(),
        }
    }
}
impl<'de, T> Deserialize<'de> for Join<T>
where
    T: JoinMeta + Deserialize<'de>,
{
    /// Reads an `Option<T>`.
    ///
    /// A present value becomes a join holding it as a query result, with `id` set to the
    /// value's `_id()`. An absent value is rejected with a custom error, since there is no
    /// id to build the join from.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let Some(model) = Option::<T>::deserialize(deserializer)? else {
            return Err(D::Error::custom("Invalid value"));
        };
        Ok(Join {
            id: model._id(),
            data: JoinData::QueryResult(model),
        })
    }
}