use serde::{Deserialize, Serialize};
use std::marker::PhantomData;
use crate::orm::Model;
/// Handle describing the rows of model `T` attached to a single object through a
/// generic relation, identified by a content-type id and an object id.
///
/// A row belongs to the set when its `ct_field` column equals `content_type_id` and its
/// `fk_field` column equals `object_id`.
///
/// `T` is only a type-level marker; no `T` value is stored. For that reason `Clone` and
/// `Debug` are implemented by hand so that they do not require `T: Clone` / `T: Debug`
/// (a `#[derive]` would add those bounds).
#[derive(Serialize, Deserialize)]
pub struct GenericRelationSet<T> {
    /// Value the `ct_field` column is compared against.
    content_type_id: i64,
    /// Value the `fk_field` column is compared against.
    object_id: i64,
    /// Name of the column holding the content-type id.
    ct_field: String,
    /// Name of the column holding the object id.
    fk_field: String,
    // `fn() -> T` ties the set to `T` without claiming to own a `T`: auto traits
    // (`Send`/`Sync`) and drop-check no longer depend on `T`, while variance stays
    // covariant as with `PhantomData<T>`.
    #[serde(skip)]
    _phantom: PhantomData<fn() -> T>,
}

impl<T> Clone for GenericRelationSet<T> {
    fn clone(&self) -> Self {
        Self {
            content_type_id: self.content_type_id,
            object_id: self.object_id,
            ct_field: self.ct_field.clone(),
            fk_field: self.fk_field.clone(),
            _phantom: PhantomData,
        }
    }
}

impl<T> std::fmt::Debug for GenericRelationSet<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The zero-sized marker carries no data, so it is left out of the output.
        f.debug_struct("GenericRelationSet")
            .field("content_type_id", &self.content_type_id)
            .field("object_id", &self.object_id)
            .field("ct_field", &self.ct_field)
            .field("fk_field", &self.fk_field)
            .finish_non_exhaustive()
    }
}
impl<T> GenericRelationSet<T> {
    /// Creates a set matching rows where `ct_field = content_type_id` and
    /// `fk_field = object_id`.
    ///
    /// `ct_field` and `fk_field` are column names. They are inserted verbatim into the SQL
    /// produced by [`where_clause`](Self::where_clause) and
    /// [`sql_condition`](Self::sql_condition), so they must come from trusted code and
    /// never from user input.
    pub fn new(
        content_type_id: i64,
        object_id: i64,
        ct_field: impl Into<String>,
        fk_field: impl Into<String>,
    ) -> Self {
        Self {
            content_type_id,
            object_id,
            ct_field: ct_field.into(),
            fk_field: fk_field.into(),
            _phantom: PhantomData,
        }
    }

    /// Creates a set using the column names `"content_type_id"` and `"object_id"`.
    pub fn with_defaults(content_type_id: i64, object_id: i64) -> Self {
        Self::new(content_type_id, object_id, "content_type_id", "object_id")
    }

    /// Content-type id the `ct_field` column is matched against.
    pub fn content_type_id(&self) -> i64 {
        self.content_type_id
    }

    /// Object id the `fk_field` column is matched against.
    pub fn object_id(&self) -> i64 {
        self.object_id
    }

    /// Name of the content-type id column.
    pub fn ct_field(&self) -> &str {
        &self.ct_field
    }

    /// Name of the object id column.
    pub fn fk_field(&self) -> &str {
        &self.fk_field
    }

    /// Renders the condition with both ids inlined as integer literals, e.g.
    /// `"ct_id = 1 AND obj_id = 42"`.
    ///
    /// The ids are `i64`, so the inlined values need no escaping. The column names are
    /// inserted verbatim, however. Prefer [`sql_condition`](Self::sql_condition) when
    /// executing queries.
    pub fn where_clause(&self) -> String {
        format!(
            "{} = {} AND {} = {}",
            self.ct_field, self.content_type_id, self.fk_field, self.object_id
        )
    }

    /// Returns the condition with numbered placeholders `$1` and `$2`, plus the values to
    /// bind to them (`[content_type_id, object_id]`).
    ///
    /// Equivalent to `self.sql_condition_starting_at(1)`.
    pub fn sql_condition(&self) -> (String, Vec<i64>) {
        self.sql_condition_starting_at(1)
    }

    /// Like [`sql_condition`](Self::sql_condition), but numbers the placeholders
    /// `$first_placeholder` and `$first_placeholder + 1`.
    ///
    /// This lets the condition be appended to a statement that already binds
    /// `first_placeholder - 1` parameters.
    ///
    /// # Panics
    ///
    /// Panics if `first_placeholder` is 0, because numbered placeholders start at `$1`.
    pub fn sql_condition_starting_at(&self, first_placeholder: usize) -> (String, Vec<i64>) {
        assert!(
            first_placeholder >= 1,
            "numbered SQL placeholders start at $1"
        );
        let sql = format!(
            "{} = ${} AND {} = ${}",
            self.ct_field,
            first_placeholder,
            self.fk_field,
            first_placeholder + 1
        );
        (sql, vec![self.content_type_id, self.object_id])
    }
}
impl<T: Model> GenericRelationSet<T> {
    /// Returns a query set over all `T` rows, narrowed by two equality filters:
    /// `ct_field == content_type_id` and `fk_field == object_id`.
    pub fn query(&self) -> super::query::QuerySet<T> {
        use crate::orm::query::{Filter, FilterOperator, FilterValue};

        // Both restrictions are "column equals integer"; build them the same way.
        let equals = |column: &str, value: i64| {
            Filter::new(
                column.to_owned(),
                FilterOperator::Eq,
                FilterValue::Integer(value),
            )
        };
        let by_content_type = equals(self.ct_field.as_str(), self.content_type_id);
        let by_object = equals(self.fk_field.as_str(), self.object_id);

        T::objects()
            .all()
            .filter(by_content_type)
            .filter(by_object)
    }

    /// Loads every related row by awaiting `all()` on [`query`](Self::query).
    pub async fn all(&self) -> reinhardt_core::exception::Result<Vec<T>> {
        self.query().all().await
    }

    /// Number of related rows, obtained from `count()` on [`query`](Self::query).
    pub async fn count(&self) -> reinhardt_core::exception::Result<usize> {
        self.query().count().await
    }

    /// `true` when at least one related row exists. Errors from
    /// [`count`](Self::count) are propagated unchanged.
    ///
    /// NOTE(review): this counts every matching row; if `QuerySet` offers an
    /// exists/limit-1 query it would likely be cheaper. Confirm against the query API.
    pub async fn exists(&self) -> reinhardt_core::exception::Result<bool> {
        self.count().await.map(|total| total > 0)
    }

    /// First related row, if any, as returned by `first()` on [`query`](Self::query).
    pub async fn first(&self) -> reinhardt_core::exception::Result<Option<T>> {
        self.query().first().await
    }
}
/// Configuration describing a generic relation to `related_model`.
///
/// Create it with [`GenericRelationConfig::new`], adjust it with the chainable setters
/// [`ct_field`](Self::ct_field), [`fk_field`](Self::fk_field) and
/// [`related_name`](Self::related_name), and read it back with the `get_*` accessors.
///
/// ```ignore
/// let config = GenericRelationConfig::new("Comment")
///     .ct_field("ct_id")
///     .fk_field("obj_id")
///     .related_name("comments");
/// ```
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenericRelationConfig {
    /// Related model identifier, as passed to [`GenericRelationConfig::new`].
    related_model: String,
    /// Content-type id column name; defaults to `"content_type_id"`.
    ct_field: String,
    /// Object id column name; defaults to `"object_id"`.
    fk_field: String,
    /// Optional related name; `None` until [`related_name`](Self::related_name) is called.
    related_name: Option<String>,
}

impl GenericRelationConfig {
    /// Creates a configuration for `related_model` with the default column names
    /// `"content_type_id"` / `"object_id"` and no related name.
    #[must_use]
    pub fn new(related_model: impl Into<String>) -> Self {
        Self {
            related_model: related_model.into(),
            ct_field: "content_type_id".to_string(),
            fk_field: "object_id".to_string(),
            related_name: None,
        }
    }

    /// Replaces the content-type id column name and returns the updated config.
    // The setters consume `self`; ignoring the result would silently drop the change.
    #[must_use = "builder methods return the updated config; the original is consumed"]
    pub fn ct_field(mut self, field: impl Into<String>) -> Self {
        self.ct_field = field.into();
        self
    }

    /// Replaces the object id column name and returns the updated config.
    #[must_use = "builder methods return the updated config; the original is consumed"]
    pub fn fk_field(mut self, field: impl Into<String>) -> Self {
        self.fk_field = field.into();
        self
    }

    /// Sets the related name and returns the updated config.
    #[must_use = "builder methods return the updated config; the original is consumed"]
    pub fn related_name(mut self, name: impl Into<String>) -> Self {
        self.related_name = Some(name.into());
        self
    }

    /// Related model identifier given to [`GenericRelationConfig::new`].
    pub fn related_model(&self) -> &str {
        &self.related_model
    }

    /// Content-type id column name.
    pub fn get_ct_field(&self) -> &str {
        &self.ct_field
    }

    /// Object id column name.
    pub fn get_fk_field(&self) -> &str {
        &self.fk_field
    }

    /// Related name, if one was set.
    pub fn get_related_name(&self) -> Option<&str> {
        self.related_name.as_deref()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generic_relation_set_new() {
        let relation: GenericRelationSet<()> =
            GenericRelationSet::new(1, 42, "content_type_id", "object_id");
        assert_eq!(relation.content_type_id(), 1);
        assert_eq!(relation.object_id(), 42);
        assert_eq!(relation.ct_field(), "content_type_id");
        assert_eq!(relation.fk_field(), "object_id");
    }

    #[test]
    fn test_generic_relation_set_with_defaults() {
        let relation: GenericRelationSet<()> = GenericRelationSet::with_defaults(5, 100);
        assert_eq!(relation.content_type_id(), 5);
        assert_eq!(relation.object_id(), 100);
        assert_eq!(relation.ct_field(), "content_type_id");
        assert_eq!(relation.fk_field(), "object_id");
    }

    #[test]
    fn test_generic_relation_set_clone_preserves_fields() {
        let original: GenericRelationSet<()> = GenericRelationSet::new(3, 7, "ct", "obj");
        let copy = original.clone();
        assert_eq!(copy.content_type_id(), 3);
        assert_eq!(copy.object_id(), 7);
        assert_eq!(copy.ct_field(), "ct");
        assert_eq!(copy.fk_field(), "obj");
    }

    #[test]
    fn test_generic_relation_set_where_clause() {
        let relation: GenericRelationSet<()> = GenericRelationSet::new(1, 42, "ct_id", "obj_id");
        assert_eq!(relation.where_clause(), "ct_id = 1 AND obj_id = 42");
    }

    #[test]
    fn test_generic_relation_set_sql_condition() {
        let relation: GenericRelationSet<()> =
            GenericRelationSet::new(1, 42, "content_type_id", "object_id");
        let (sql, values) = relation.sql_condition();
        assert_eq!(sql, "content_type_id = $1 AND object_id = $2");
        assert_eq!(values, vec![1_i64, 42_i64]);
    }

    #[test]
    fn test_generic_relation_set_sql_condition_custom_columns() {
        let relation: GenericRelationSet<()> = GenericRelationSet::new(7, 9, "ct_id", "obj_id");
        let (sql, values) = relation.sql_condition();
        assert_eq!(sql, "ct_id = $1 AND obj_id = $2");
        assert_eq!(values, vec![7_i64, 9_i64]);
    }

    #[test]
    fn test_generic_relation_config_new() {
        let config = GenericRelationConfig::new("Comment");
        assert_eq!(config.related_model(), "Comment");
        assert_eq!(config.get_ct_field(), "content_type_id");
        assert_eq!(config.get_fk_field(), "object_id");
        assert!(config.get_related_name().is_none());
    }

    #[test]
    fn test_generic_relation_config_builder() {
        let config = GenericRelationConfig::new("Comment")
            .ct_field("ct_id")
            .fk_field("obj_id")
            .related_name("comments");
        assert_eq!(config.related_model(), "Comment");
        assert_eq!(config.get_ct_field(), "ct_id");
        assert_eq!(config.get_fk_field(), "obj_id");
        assert_eq!(config.get_related_name(), Some("comments"));
    }

    #[test]
    fn test_generic_relation_set_serialization() {
        let relation: GenericRelationSet<()> = GenericRelationSet::new(1, 42, "ct", "obj");
        let serialized = serde_json::to_string(&relation).unwrap();

        // Inspect the JSON structurally; substring checks like `contains("1")` would also
        // match unrelated digits and never notice a missing or extra field.
        let value: serde_json::Value = serde_json::from_str(&serialized).unwrap();
        let object = value
            .as_object()
            .expect("relation serializes to a JSON object");
        assert_eq!(
            object.len(),
            4,
            "only the four data fields should be serialized: {}",
            serialized
        );
        assert!(!object.contains_key("_phantom"));
        assert_eq!(value["content_type_id"].as_i64(), Some(1));
        assert_eq!(value["object_id"].as_i64(), Some(42));
        assert_eq!(value["ct_field"].as_str(), Some("ct"));
        assert_eq!(value["fk_field"].as_str(), Some("obj"));

        let deserialized: GenericRelationSet<()> = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized.content_type_id(), 1);
        assert_eq!(deserialized.object_id(), 42);
        assert_eq!(deserialized.ct_field(), "ct");
        assert_eq!(deserialized.fk_field(), "obj");
    }
}