use std::marker::PhantomData;
use serde::{Deserialize, Serialize};
use super::Model;
/// A typed reference to a row of model `T`, stored as that row's primary key.
///
/// Optionally carries a cached copy of the referenced row (`resolved`).
/// Equality (`PartialEq`/`Eq`) compares only the primary key and ignores the
/// cached row.
#[derive(Debug, Clone)]
pub struct ForeignKey<T: Model> {
/// Primary key of the referenced `T` row.
raw: T::PrimaryKey,
/// Cached copy of the referenced row: set by `set_resolved` or when
/// deserializing from a nested object, `None` otherwise.
resolved: Option<Box<T>>,
/// NOTE(review): `resolved` already mentions `T`, so this marker looks
/// redundant; it is kept because every constructor in this file sets it.
_phantom: PhantomData<T>,
}
/// Two foreign keys are equal exactly when they reference the same primary key.
/// Any cached resolved row is deliberately left out of the comparison.
impl<T> PartialEq for ForeignKey<T>
where
    T: Model,
    T::PrimaryKey: PartialEq,
{
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.raw, &other.raw);
        PartialEq::eq(lhs, rhs)
    }
}

/// Total equality is available whenever the primary key type supports it.
impl<T> Eq for ForeignKey<T>
where
    T: Model,
    T::PrimaryKey: Eq,
{
}
/// Builds an unresolved foreign key that holds the default primary key value.
impl<T> Default for ForeignKey<T>
where
    T: Model,
    T::PrimaryKey: Default,
{
    fn default() -> Self {
        let raw: T::PrimaryKey = Default::default();
        Self::new(raw)
    }
}
impl<T: Model> ForeignKey<T> {
    /// Creates an unresolved foreign key pointing at the row whose primary key is `raw`.
    pub fn new(raw: T::PrimaryKey) -> Self {
        Self {
            raw,
            resolved: None,
            _phantom: PhantomData,
        }
    }

    /// Returns an owned copy of the referenced primary key.
    ///
    /// Use [`ForeignKey::id_ref`] to borrow it without cloning.
    pub fn id(&self) -> T::PrimaryKey {
        self.raw.clone()
    }

    /// Borrows the referenced primary key.
    pub fn id_ref(&self) -> &T::PrimaryKey {
        &self.raw
    }

    /// Re-points this foreign key at the row whose primary key is `raw`.
    ///
    /// Any cached resolved row is discarded. It belongs to the previously
    /// referenced key, and keeping it would make `resolved()`, the `resolve*`
    /// methods and serialization return data for the wrong row.
    pub fn set(&mut self, raw: T::PrimaryKey) {
        self.raw = raw;
        self.resolved = None;
    }

    /// Returns the cached referenced row, if one has been attached.
    pub fn resolved(&self) -> Option<&T> {
        self.resolved.as_deref()
    }

    /// Caches `row` as the resolved target of this key.
    ///
    /// The stored primary key is NOT changed. The caller is responsible for
    /// passing a row whose primary key matches [`ForeignKey::id_ref`].
    pub fn set_resolved(&mut self, row: T) {
        self.resolved = Some(Box::new(row));
    }
}
impl<T: Model> ForeignKey<T> {
    /// Name of `T`'s primary-key column.
    ///
    /// This is the first entry of `T::FIELDS` flagged `primary_key`, or `"id"`
    /// when no field is flagged.
    fn pk_column_name() -> &'static str {
        for field in T::FIELDS.iter() {
            if field.primary_key {
                return field.name;
            }
        }
        "id"
    }
}
impl<T: Model> ForeignKey<T>
where
    T::PrimaryKey: for<'q> sqlx::Encode<'q, sqlx::Sqlite> + sqlx::Type<sqlx::Sqlite>,
{
    /// Loads the referenced row from SQLite.
    ///
    /// If a row is already cached on this key, a clone of it is returned and no
    /// query is run. Otherwise this executes
    /// `SELECT <T::FIELDS> FROM <T::TABLE> WHERE <pk column> = ? LIMIT 1`, with
    /// the key bound as a parameter.
    ///
    /// The fetched row is not stored back into the cache, because this method
    /// only borrows `self`. Table and column names are interpolated unquoted;
    /// they come from `T::TABLE`/`T::FIELDS`, not from caller input.
    ///
    /// # Errors
    /// Returns any `sqlx::Error` from the query. This includes the error that
    /// `fetch_one` produces when no row matches the key.
    pub async fn resolve(&self, pool: &sqlx::SqlitePool) -> Result<T, sqlx::Error>
    where
        T: for<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow> + Clone,
    {
        if let Some(cached) = self.resolved.as_deref() {
            // Clone the row itself; cloning the `Box` first would allocate a
            // throwaway box only to move out of it.
            return Ok(cached.clone());
        }
        let columns: Vec<&str> = T::FIELDS.iter().map(|f| f.name).collect();
        let col_list = columns.join(", ");
        let pk_col = Self::pk_column_name();
        let sql = format!(
            "SELECT {} FROM {} WHERE {} = ? LIMIT 1",
            col_list,
            T::TABLE,
            pk_col
        );
        sqlx::query_as::<sqlx::Sqlite, T>(&sql)
            .bind(self.raw.clone())
            .fetch_one(pool)
            .await
    }
}
impl<T: Model> ForeignKey<T>
where
    T::PrimaryKey: for<'q> sqlx::Encode<'q, sqlx::Postgres> + sqlx::Type<sqlx::Postgres>,
{
    /// Loads the referenced row from PostgreSQL.
    ///
    /// If a row is already cached on this key, a clone of it is returned and no
    /// query is run. Otherwise this executes
    /// `SELECT <T::FIELDS> FROM <T::TABLE> WHERE <pk column> = $1 LIMIT 1`, with
    /// the key bound as a parameter.
    ///
    /// The fetched row is not stored back into the cache, because this method
    /// only borrows `self`. Table and column names are interpolated unquoted;
    /// they come from `T::TABLE`/`T::FIELDS`, not from caller input.
    ///
    /// # Errors
    /// Returns any `sqlx::Error` from the query. This includes the error that
    /// `fetch_one` produces when no row matches the key.
    pub async fn resolve_pg(&self, pool: &sqlx::PgPool) -> Result<T, sqlx::Error>
    where
        T: for<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow> + Clone,
    {
        if let Some(cached) = self.resolved.as_deref() {
            // Clone the row itself; cloning the `Box` first would allocate a
            // throwaway box only to move out of it.
            return Ok(cached.clone());
        }
        let columns: Vec<&str> = T::FIELDS.iter().map(|f| f.name).collect();
        let col_list = columns.join(", ");
        let pk_col = Self::pk_column_name();
        let sql = format!(
            "SELECT {} FROM {} WHERE {} = $1 LIMIT 1",
            col_list,
            T::TABLE,
            pk_col
        );
        sqlx::query_as::<sqlx::Postgres, T>(&sql)
            .bind(self.raw.clone())
            .fetch_one(pool)
            .await
    }
}
/// Serializes the cached row in full when one is present; otherwise serializes
/// just the primary key.
impl<T> Serialize for ForeignKey<T>
where
    T: Model + Serialize,
    T::PrimaryKey: Serialize,
{
    fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
        match self.resolved.as_deref() {
            Some(row) => row.serialize(s),
            None => self.raw.serialize(s),
        }
    }
}
/// Accepts either a bare primary key or a full nested `T` object.
///
/// - A non-object value is parsed as `T::PrimaryKey` and yields an unresolved key.
/// - An object is parsed as a complete `T` and cached as the resolved row. The
///   key is read from the object's primary-key field, as chosen by
///   `pk_column_name`.
///
/// The input is first buffered into a `serde_json::Value`. NOTE(review): this
/// likely limits the impl to self-describing formats; confirm before using it
/// with non-JSON-like serializers.
impl<'de, T: Model + serde::de::DeserializeOwned> Deserialize<'de> for ForeignKey<T>
where
    T::PrimaryKey: serde::de::DeserializeOwned,
{
    fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
        use serde::de::Error;
        let v = serde_json::Value::deserialize(d)?;
        match v {
            serde_json::Value::Object(_) => {
                // Share the primary-key-column rule (including the `"id"`
                // fallback) with the SQL paths instead of duplicating it here.
                let pk_name = Self::pk_column_name();
                let pk_v = v.get(pk_name).ok_or_else(|| {
                    D::Error::custom(format!(
                        "ForeignKey<{}>: nested object missing pk field `{pk_name}`",
                        T::NAME
                    ))
                })?;
                // Deserialize straight from the borrowed sub-value; no clone needed.
                let raw = <T::PrimaryKey as Deserialize>::deserialize(pk_v)
                    .map_err(D::Error::custom)?;
                let resolved: T = serde_json::from_value(v).map_err(D::Error::custom)?;
                Ok(Self {
                    raw,
                    resolved: Some(Box::new(resolved)),
                    _phantom: PhantomData,
                })
            }
            other => {
                let raw: T::PrimaryKey = serde_json::from_value(other).map_err(D::Error::custom)?;
                Ok(Self::new(raw))
            }
        }
    }
}
/// A foreign key has the SQLite column type of the primary key it wraps.
impl<T> sqlx::Type<sqlx::Sqlite> for ForeignKey<T>
where
    T: Model,
    T::PrimaryKey: sqlx::Type<sqlx::Sqlite>,
{
    fn type_info() -> <sqlx::Sqlite as sqlx::Database>::TypeInfo {
        <<T as Model>::PrimaryKey as sqlx::Type<sqlx::Sqlite>>::type_info()
    }

    fn compatible(ty: &<sqlx::Sqlite as sqlx::Database>::TypeInfo) -> bool {
        <<T as Model>::PrimaryKey as sqlx::Type<sqlx::Sqlite>>::compatible(ty)
    }
}
/// A foreign key has the Postgres column type of the primary key it wraps.
impl<T> sqlx::Type<sqlx::Postgres> for ForeignKey<T>
where
    T: Model,
    T::PrimaryKey: sqlx::Type<sqlx::Postgres>,
{
    fn type_info() -> <sqlx::Postgres as sqlx::Database>::TypeInfo {
        <<T as Model>::PrimaryKey as sqlx::Type<sqlx::Postgres>>::type_info()
    }

    fn compatible(ty: &<sqlx::Postgres as sqlx::Database>::TypeInfo) -> bool {
        <<T as Model>::PrimaryKey as sqlx::Type<sqlx::Postgres>>::compatible(ty)
    }
}
/// Decodes a SQLite column into an unresolved foreign key by decoding the
/// primary key and wrapping it; no row is attached.
impl<'r, T> sqlx::Decode<'r, sqlx::Sqlite> for ForeignKey<T>
where
    T: Model,
    T::PrimaryKey: sqlx::Decode<'r, sqlx::Sqlite>,
{
    fn decode(
        value: sqlx::sqlite::SqliteValueRef<'r>,
    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
        <T::PrimaryKey as sqlx::Decode<'r, sqlx::Sqlite>>::decode(value).map(Self::new)
    }
}
/// Decodes a Postgres column into an unresolved foreign key by decoding the
/// primary key and wrapping it; no row is attached.
impl<'r, T> sqlx::Decode<'r, sqlx::Postgres> for ForeignKey<T>
where
    T: Model,
    T::PrimaryKey: sqlx::Decode<'r, sqlx::Postgres>,
{
    fn decode(
        value: sqlx::postgres::PgValueRef<'r>,
    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
        <T::PrimaryKey as sqlx::Decode<'r, sqlx::Postgres>>::decode(value).map(Self::new)
    }
}
/// Encodes a foreign key for SQLite exactly as its primary key would be encoded.
///
/// Every encoding hook is forwarded to the key, not just `encode_by_ref`. That
/// keeps this impl consistent with the fully delegating `sqlx::Type` impl:
/// - `produces` passes through any value-dependent type info from the key
///   instead of the trait default.
/// - `size_hint` measures the key rather than the whole wrapper.
///
/// The cached resolved row is never encoded.
impl<'q, T: Model> sqlx::Encode<'q, sqlx::Sqlite> for ForeignKey<T>
where
    T::PrimaryKey: sqlx::Encode<'q, sqlx::Sqlite>,
{
    fn encode_by_ref(
        &self,
        buf: &mut <sqlx::Sqlite as sqlx::Database>::ArgumentBuffer<'q>,
    ) -> Result<sqlx::encode::IsNull, Box<dyn std::error::Error + Send + Sync>> {
        <T::PrimaryKey as sqlx::Encode<'q, sqlx::Sqlite>>::encode_by_ref(&self.raw, buf)
    }

    fn produces(&self) -> Option<<sqlx::Sqlite as sqlx::Database>::TypeInfo> {
        <T::PrimaryKey as sqlx::Encode<'q, sqlx::Sqlite>>::produces(&self.raw)
    }

    fn size_hint(&self) -> usize {
        <T::PrimaryKey as sqlx::Encode<'q, sqlx::Sqlite>>::size_hint(&self.raw)
    }
}
/// Encodes a foreign key for Postgres exactly as its primary key would be encoded.
///
/// Every encoding hook is forwarded to the key, not just `encode_by_ref`. That
/// keeps this impl consistent with the fully delegating `sqlx::Type` impl:
/// - `produces` passes through any value-dependent type info from the key
///   instead of the trait default.
/// - `size_hint` measures the key rather than the whole wrapper.
///
/// The cached resolved row is never encoded.
impl<'q, T: Model> sqlx::Encode<'q, sqlx::Postgres> for ForeignKey<T>
where
    T::PrimaryKey: sqlx::Encode<'q, sqlx::Postgres>,
{
    fn encode_by_ref(
        &self,
        buf: &mut <sqlx::Postgres as sqlx::Database>::ArgumentBuffer<'q>,
    ) -> Result<sqlx::encode::IsNull, Box<dyn std::error::Error + Send + Sync>> {
        <T::PrimaryKey as sqlx::Encode<'q, sqlx::Postgres>>::encode_by_ref(&self.raw, buf)
    }

    fn produces(&self) -> Option<<sqlx::Postgres as sqlx::Database>::TypeInfo> {
        <T::PrimaryKey as sqlx::Encode<'q, sqlx::Postgres>>::produces(&self.raw)
    }

    fn size_hint(&self) -> usize {
        <T::PrimaryKey as sqlx::Encode<'q, sqlx::Postgres>>::size_hint(&self.raw)
    }
}