use crate::core::{Model, Op, SqlValue};
use crate::query::QuerySet;
use sqlx::postgres::{PgPool, PgRow};
use sqlx::FromRow;
use super::ExecError;
/// A reference to a `T` identified by a primary key of type `K` (`i64` unless
/// specified otherwise).
///
/// The key is available in both states; the referenced value is only held in
/// the [`ForeignKey::Loaded`] state. Because equality and hashing are derived,
/// an `Unloaded` and a `Loaded` key carrying the same primary key do not
/// compare equal.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ForeignKey<T, K = i64> {
    /// Only the primary key is known.
    Unloaded(K),
    /// The primary key together with the boxed referenced value.
    Loaded {
        /// Primary key of the referenced value.
        pk: K,
        /// The referenced value, heap-allocated.
        value: Box<T>,
    },
}

impl<T, K> ForeignKey<T, K> {
    /// Creates a key-only reference holding `pk`.
    #[must_use]
    pub fn unloaded(pk: K) -> Self {
        ForeignKey::Unloaded(pk)
    }

    /// Creates a reference that already carries `value` alongside `pk`.
    #[must_use]
    pub fn loaded(pk: K, value: T) -> Self {
        let value = Box::new(value);
        ForeignKey::Loaded { pk, value }
    }

    /// Returns `true` when the referenced value is held.
    #[must_use]
    pub fn is_loaded(&self) -> bool {
        self.value().is_some()
    }

    /// Borrows the referenced value, or `None` when only the key is known.
    #[must_use]
    pub fn value(&self) -> Option<&T> {
        if let Self::Loaded { value, .. } = self {
            Some(&**value)
        } else {
            None
        }
    }

    /// Consumes the key and returns the referenced value, or `None` when
    /// only the key is known.
    #[must_use]
    pub fn into_value(self) -> Option<T> {
        if let Self::Loaded { value, .. } = self {
            Some(*value)
        } else {
            None
        }
    }

    /// Borrows the primary key, regardless of state.
    #[must_use]
    pub fn pk_ref(&self) -> &K {
        // Both variants carry a key, so this or-pattern is irrefutable.
        let (Self::Unloaded(pk) | Self::Loaded { pk, .. }) = self;
        pk
    }
}

impl<T, K: Clone> ForeignKey<T, K> {
    /// Returns a clone of the primary key, regardless of state.
    #[must_use]
    pub fn pk(&self) -> K {
        K::clone(self.pk_ref())
    }
}
/// Serializes a foreign key as its bare primary key; a loaded value is not
/// written.
impl<T, K> serde::Serialize for ForeignKey<T, K>
where
    K: serde::Serialize,
{
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let key: &K = self.pk_ref();
        serde::Serialize::serialize(key, serializer)
    }
}
impl<T, K> From<K> for ForeignKey<T, K> {
fn from(pk: K) -> Self {
Self::Unloaded(pk)
}
}
/// Lowers a foreign key to the `SqlValue` of its primary key.
///
/// Works in either state; any loaded value is dropped. The key is moved into
/// the conversion, so `K` only needs `Into<SqlValue>` — no `Clone` bound is
/// required.
impl<T, K: Into<SqlValue>> From<ForeignKey<T, K>> for SqlValue {
    fn from(fk: ForeignKey<T, K>) -> Self {
        match fk {
            ForeignKey::Unloaded(k) | ForeignKey::Loaded { pk: k, .. } => k.into(),
        }
    }
}
impl<'r, T, K> sqlx::Decode<'r, sqlx::Postgres> for ForeignKey<T, K>
where
K: sqlx::Decode<'r, sqlx::Postgres>,
{
fn decode(
value: <sqlx::Postgres as sqlx::Database>::ValueRef<'r>,
) -> Result<Self, sqlx::error::BoxDynError> {
Ok(Self::Unloaded(<K as sqlx::Decode<sqlx::Postgres>>::decode(
value,
)?))
}
}
/// Reports the same Postgres type information as the key type `K`.
impl<T, K> sqlx::Type<sqlx::Postgres> for ForeignKey<T, K>
where
    K: sqlx::Type<sqlx::Postgres>,
{
    /// Forwards to `K`'s type info.
    fn type_info() -> sqlx::postgres::PgTypeInfo {
        K::type_info()
    }

    /// Forwards the compatibility check to `K`.
    fn compatible(ty: &sqlx::postgres::PgTypeInfo) -> bool {
        K::compatible(ty)
    }
}
/// Decodes a MySQL value as `K` and wraps it in [`ForeignKey::Unloaded`].
#[cfg(feature = "mysql")]
impl<'r, T, K> sqlx::Decode<'r, sqlx::MySql> for ForeignKey<T, K>
where
    K: sqlx::Decode<'r, sqlx::MySql>,
{
    fn decode(
        raw: <sqlx::MySql as sqlx::Database>::ValueRef<'r>,
    ) -> Result<Self, sqlx::error::BoxDynError> {
        // Decoding is delegated entirely to the key type.
        let pk = <K as sqlx::Decode<'r, sqlx::MySql>>::decode(raw)?;
        Ok(ForeignKey::Unloaded(pk))
    }
}
/// Reports the same MySQL type information as the key type `K`.
#[cfg(feature = "mysql")]
impl<T, K> sqlx::Type<sqlx::MySql> for ForeignKey<T, K>
where
    K: sqlx::Type<sqlx::MySql>,
{
    /// Forwards to `K`'s type info.
    fn type_info() -> sqlx::mysql::MySqlTypeInfo {
        K::type_info()
    }

    /// Forwards the compatibility check to `K`.
    fn compatible(ty: &sqlx::mysql::MySqlTypeInfo) -> bool {
        K::compatible(ty)
    }
}
impl<T, K> ForeignKey<T, K>
where
    T: Model + for<'r> FromRow<'r, PgRow> + Send + Unpin + crate::sql::LoadRelated,
    K: Clone + Into<SqlValue> + Send + Sync + 'static,
{
    /// Returns the referenced value, fetching it through `pool` first when
    /// this key is still unloaded.
    ///
    /// Thin wrapper around [`ForeignKey::get_on`].
    ///
    /// # Errors
    ///
    /// Propagates every error of [`ForeignKey::get_on`].
    pub async fn get(&mut self, pool: &PgPool) -> Result<&T, ExecError> {
        self.get_on(pool).await
    }

    /// Returns the referenced value, fetching it on `executor` first when
    /// this key is still unloaded.
    ///
    /// When already loaded, no query is issued. Otherwise the target is
    /// selected by filtering `T`'s primary-key column for equality with the
    /// stored key, and on success `self` becomes [`ForeignKey::Loaded`].
    ///
    /// # Errors
    ///
    /// * [`ExecError::MissingPrimaryKey`] when `T::SCHEMA` has no primary key.
    /// * [`ExecError::ForeignKeyTargetMissing`] when the query yields no row.
    /// * Whatever error the query itself returns.
    ///
    /// On any error `self` is left unloaded.
    pub async fn get_on<'c, E>(&mut self, executor: E) -> Result<&T, ExecError>
    where
        E: sqlx::Executor<'c, Database = sqlx::Postgres>,
    {
        if let Self::Unloaded(key) = &*self {
            // Own a copy of the key: it is needed for the query and again
            // for the `Loaded` state written back below.
            let pk = key.clone();

            let Some(pk_field) = T::SCHEMA.primary_key() else {
                return Err(ExecError::MissingPrimaryKey {
                    table: T::SCHEMA.table,
                });
            };

            let mut rows: Vec<T> = QuerySet::<T>::new()
                .filter(pk_field.column, Op::Eq, pk.clone())
                .fetch_on(executor)
                .await?;

            // Take the last returned row as the target.
            let Some(value) = rows.pop() else {
                let sv: SqlValue = pk.clone().into();
                return Err(ExecError::ForeignKeyTargetMissing {
                    table: T::SCHEMA.table,
                    pk: sv.to_display_string(),
                });
            };

            *self = Self::Loaded {
                pk,
                value: Box::new(value),
            };
        }

        let Self::Loaded { value, .. } = self else {
            unreachable!("just transitioned to Loaded above")
        };
        Ok(value)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An unloaded key exposes its pk and holds no value.
    #[test]
    fn unloaded_constructor_and_pk_accessor() {
        let key_only: ForeignKey<()> = ForeignKey::unloaded(42);
        assert!(!key_only.is_loaded());
        assert!(key_only.value().is_none());
        assert_eq!(key_only.pk(), 42);
    }

    /// A loaded key exposes both its pk and the stored value.
    #[test]
    fn loaded_constructor_caches_value() {
        let cached = ForeignKey::loaded(7_i64, "alice".to_string());
        assert!(cached.is_loaded());
        assert_eq!(cached.pk(), 7);
        assert_eq!(cached.value(), Some(&"alice".to_string()));
    }

    /// `From<i64>` produces the `Unloaded` state.
    #[test]
    fn from_i64_yields_unloaded() {
        let fk: ForeignKey<()> = ForeignKey::from(99_i64);
        let ForeignKey::Unloaded(pk) = fk else {
            panic!("expected Unloaded");
        };
        assert_eq!(pk, 99);
    }

    /// Lowering to `SqlValue` yields the pk whether or not a value is held.
    #[test]
    fn into_sqlvalue_gives_i64_in_either_state() {
        let unloaded: ForeignKey<()> = ForeignKey::unloaded(1_i64);
        let lowered_unloaded = SqlValue::from(unloaded);
        assert!(matches!(lowered_unloaded, SqlValue::I64(1)));

        let loaded = ForeignKey::loaded(2_i64, ());
        let lowered_loaded = SqlValue::from(loaded);
        assert!(matches!(lowered_loaded, SqlValue::I64(2)));
    }

    /// `into_value` yields the value only in the loaded state.
    #[test]
    fn into_value_consumes_when_loaded() {
        let with_value = ForeignKey::loaded(3_i64, 100_u32);
        assert_eq!(with_value.into_value(), Some(100));

        let without_value: ForeignKey<u32> = ForeignKey::unloaded(4_i64);
        assert_eq!(without_value.into_value(), None);
    }

    /// A `String` key is returned unchanged by both pk accessors.
    #[test]
    fn string_pk_unloaded_round_trip() {
        let fk: ForeignKey<(), String> = ForeignKey::unloaded(String::from("alice-uuid"));
        assert!(!fk.is_loaded());
        assert_eq!(fk.pk_ref(), "alice-uuid");
        assert_eq!(fk.pk(), "alice-uuid");
    }

    /// A `String` key lowers to `SqlValue::String`.
    #[test]
    fn string_pk_lowers_to_sqlvalue_string() {
        let fk: ForeignKey<(), String> = ForeignKey::unloaded(String::from("k"));
        let lowered = SqlValue::from(fk);
        match lowered {
            SqlValue::String(s) => assert_eq!(s, "k"),
            other => panic!("expected SqlValue::String, got {other:?}"),
        }
    }

    /// A `Uuid` key round-trips through `pk` and lowers to `SqlValue::Uuid`.
    #[test]
    fn uuid_pk_round_trip() {
        let id = uuid::Uuid::nil();
        let fk: ForeignKey<(), uuid::Uuid> = ForeignKey::unloaded(id);
        assert_eq!(fk.pk(), id);
        let lowered = SqlValue::from(fk);
        match lowered {
            SqlValue::Uuid(u) => assert_eq!(u, id),
            other => panic!("expected SqlValue::Uuid, got {other:?}"),
        }
    }

    /// `From<String>` produces the `Unloaded` state.
    #[test]
    fn from_string_yields_unloaded() {
        let fk: ForeignKey<(), String> = ForeignKey::from(String::from("x"));
        assert!(matches!(fk, ForeignKey::Unloaded(ref s) if s == "x"));
    }
}