use super::{SerializerError, ValidatorError};
use async_trait::async_trait;
use reinhardt_db::orm::{Model, query::*};
use serde::{Serialize, de::DeserializeOwned};
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;
/// Object-safe signature of a to-many resolver: given a parent `&T`, return a
/// boxed `Send` future that yields the related items or an error message.
type ManyResolverDyn<T, RelatedItem> =
    dyn Fn(&T) -> Pin<Box<dyn Future<Output = Result<Vec<RelatedItem>, String>> + Send>>
        + Send
        + Sync;

/// Shared, reference-counted handle to a to-many resolver closure.
type ManyResolverFn<T, RelatedItem> = Arc<ManyResolverDyn<T, RelatedItem>>;

/// Cloneable wrapper around an async closure that loads the to-many related
/// items of a `T`.
///
/// Clones share the same underlying closure through an [`Arc`]. The wrapper
/// dereferences to the closure itself, so it can be invoked as
/// `(*resolver)(&parent)`.
pub struct ManyRelationResolverFn<T, RelatedItem> {
    inner: ManyResolverFn<T, RelatedItem>,
    _phantom: PhantomData<RelatedItem>,
}

impl<T, RelatedItem> ManyRelationResolverFn<T, RelatedItem> {
    /// Wraps `func` in a shareable resolver handle.
    pub fn new<F>(func: F) -> Self
    where
        F: Fn(&T) -> Pin<Box<dyn Future<Output = Result<Vec<RelatedItem>, String>> + Send>>
            + Send
            + Sync
            + 'static,
    {
        let inner: ManyResolverFn<T, RelatedItem> = Arc::new(func);
        Self {
            inner,
            _phantom: PhantomData,
        }
    }
}

impl<T, RelatedItem> std::ops::Deref for ManyRelationResolverFn<T, RelatedItem> {
    type Target = ManyResolverDyn<T, RelatedItem>;

    fn deref(&self) -> &Self::Target {
        self.inner.as_ref()
    }
}

impl<T, RelatedItem> Clone for ManyRelationResolverFn<T, RelatedItem> {
    /// Produces another handle to the same closure (reference-count bump only).
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
            _phantom: PhantomData,
        }
    }
}
/// Object-safe signature of a to-one resolver: given a parent `&T`, return a
/// boxed `Send` future that yields the single related item or an error message.
type SingleResolverDyn<T, RelatedItem> =
    dyn Fn(&T) -> Pin<Box<dyn Future<Output = Result<RelatedItem, String>> + Send>> + Send + Sync;

/// Shared, reference-counted handle to a to-one resolver closure.
type SingleResolverFn<T, RelatedItem> = Arc<SingleResolverDyn<T, RelatedItem>>;

/// Cloneable wrapper around an async closure that loads the single related
/// item of a `T`.
///
/// Clones share the same underlying closure through an [`Arc`]. The wrapper
/// dereferences to the closure itself, so it can be invoked as
/// `(*resolver)(&parent)`.
pub struct SingleRelationResolverFn<T, RelatedItem> {
    inner: SingleResolverFn<T, RelatedItem>,
    _phantom: PhantomData<RelatedItem>,
}

impl<T, RelatedItem> SingleRelationResolverFn<T, RelatedItem> {
    /// Wraps `func` in a shareable resolver handle.
    pub fn new<F>(func: F) -> Self
    where
        F: Fn(&T) -> Pin<Box<dyn Future<Output = Result<RelatedItem, String>> + Send>>
            + Send
            + Sync
            + 'static,
    {
        let inner: SingleResolverFn<T, RelatedItem> = Arc::new(func);
        Self {
            inner,
            _phantom: PhantomData,
        }
    }
}

impl<T, RelatedItem> std::ops::Deref for SingleRelationResolverFn<T, RelatedItem> {
    type Target = SingleResolverDyn<T, RelatedItem>;

    fn deref(&self) -> &Self::Target {
        self.inner.as_ref()
    }
}

impl<T, RelatedItem> Clone for SingleRelationResolverFn<T, RelatedItem> {
    /// Produces another handle to the same closure (reference-count bump only).
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
            _phantom: PhantomData,
        }
    }
}
/// Relation field that identifies related `T` instances by primary key and
/// looks them up through the ORM.
///
/// The struct stores no `T` values; `T` is tracked only via `PhantomData`.
#[derive(Debug, Clone)]
pub struct PrimaryKeyRelatedFieldORM<
    T: Model + Serialize + DeserializeOwned + Clone + Send + Sync,
> {
    _phantom: PhantomData<T>,
    /// Whether a null relation value is permitted for this field.
    pub allow_null: bool,
    /// Optional additional filter that restricts the candidate rows.
    pub queryset_filter: Option<Filter>,
}
impl<T> PrimaryKeyRelatedFieldORM<T>
where
T: Model + Serialize + DeserializeOwned + Clone + Send + Sync,
{
pub fn new() -> Self {
Self {
_phantom: PhantomData,
allow_null: false,
queryset_filter: None,
}
}
pub fn with_allow_null(mut self, allow_null: bool) -> Self {
self.allow_null = allow_null;
self
}
pub fn with_queryset_filter(mut self, filter: Filter) -> Self {
self.queryset_filter = Some(filter);
self
}
pub async fn validate_exists(&self, pk: &T::PrimaryKey) -> Result<(), SerializerError>
where
T: Model + Serialize + DeserializeOwned + Clone + Send + Sync,
T::PrimaryKey: std::fmt::Display + Clone + Send + Sync,
{
use reinhardt_db::orm::QuerySet;
let mut queryset = QuerySet::<T>::new();
let pk_field = T::primary_key_field();
let filter = Filter::new(
pk_field.to_string(),
FilterOperator::Eq,
FilterValue::String(pk.to_string()),
);
queryset = queryset.filter(filter);
if let Some(ref custom_filter) = self.queryset_filter {
queryset = queryset.filter(custom_filter.clone());
}
let count = queryset.count().await.map_err(|e| SerializerError::Other {
message: format!("Failed to validate existence: {}", e),
})?;
if count == 0 {
return Err(SerializerError::Validation(ValidatorError::Custom {
message: format!("Instance with pk {} does not exist", pk),
}));
}
Ok(())
}
pub async fn get_instance(&self, pk: &T::PrimaryKey) -> Result<T, SerializerError>
where
T: Model + Serialize + DeserializeOwned + Clone + Send + Sync,
T::PrimaryKey: std::fmt::Display + Clone + Send + Sync,
{
use reinhardt_db::orm::QuerySet;
let mut queryset = QuerySet::<T>::new();
let pk_field = T::primary_key_field();
let filter = Filter::new(
pk_field.to_string(),
FilterOperator::Eq,
FilterValue::String(pk.to_string()),
);
queryset = queryset.filter(filter);
if let Some(ref custom_filter) = self.queryset_filter {
queryset = queryset.filter(custom_filter.clone());
}
let instance = queryset
.first()
.await
.map_err(|e| SerializerError::Other {
message: format!("Failed to fetch instance: {}", e),
})?
.ok_or_else(|| {
SerializerError::Validation(ValidatorError::Custom {
message: format!("Instance with pk {} not found", pk),
})
})?;
Ok(instance)
}
pub async fn get_instances(&self, pks: Vec<T::PrimaryKey>) -> Result<Vec<T>, SerializerError>
where
T: Model + Serialize + DeserializeOwned + Clone + Send + Sync,
T::PrimaryKey: std::fmt::Display + Clone + Send + Sync,
{
use reinhardt_db::orm::QuerySet;
if pks.is_empty() {
return Ok(Vec::new());
}
let mut queryset = QuerySet::<T>::new();
let pk_field = T::primary_key_field();
let pk_values: Vec<String> = pks.iter().map(|pk| pk.to_string()).collect();
let filter = Filter::new(
pk_field.to_string(),
FilterOperator::In,
FilterValue::Array(pk_values),
);
queryset = queryset.filter(filter);
if let Some(ref custom_filter) = self.queryset_filter {
queryset = queryset.filter(custom_filter.clone());
}
let instances = queryset.all().await.map_err(|e| SerializerError::Other {
message: format!("Failed to fetch instances: {}", e),
})?;
Ok(instances)
}
}
impl<T> Default for PrimaryKeyRelatedFieldORM<T>
where
T: Model + Serialize + DeserializeOwned + Clone + Send + Sync,
{
fn default() -> Self {
Self::new()
}
}
/// Relation field that identifies related `T` instances by the value of a
/// named (slug) column and looks them up through the ORM.
///
/// The struct stores no `T` values; `T` is tracked only via `PhantomData`.
#[derive(Debug, Clone)]
pub struct SlugRelatedFieldORM<
    T: Model + Serialize + DeserializeOwned + Clone + Send + Sync,
> {
    _phantom: PhantomData<T>,
    /// Name of the column used as the lookup key.
    pub slug_field: String,
    /// Whether a null relation value is permitted for this field.
    pub allow_null: bool,
    /// Optional additional filter that restricts the candidate rows.
    pub queryset_filter: Option<Filter>,
}
impl<T> SlugRelatedFieldORM<T>
where
    T: Model + Serialize + DeserializeOwned + Clone + Send + Sync,
{
    /// Creates a field keyed on the column `slug_field`, with
    /// `allow_null = false` and no extra queryset filter.
    pub fn new(slug_field: impl Into<String>) -> Self {
        Self {
            _phantom: PhantomData,
            slug_field: slug_field.into(),
            allow_null: false,
            queryset_filter: None,
        }
    }

    /// Builder: sets whether a null relation value is permitted.
    pub fn with_allow_null(mut self, allow_null: bool) -> Self {
        self.allow_null = allow_null;
        self
    }

    /// Builder: sets an additional filter that every lookup performed by this
    /// field must also satisfy.
    pub fn with_queryset_filter(mut self, filter: Filter) -> Self {
        self.queryset_filter = Some(filter);
        self
    }

    /// Builds a `T` queryset restricted by `lookup` and then, when configured,
    /// by `self.queryset_filter`.
    fn filtered_queryset(&self, lookup: Filter) -> reinhardt_db::orm::QuerySet<T> {
        let queryset = reinhardt_db::orm::QuerySet::<T>::new().filter(lookup);
        match &self.queryset_filter {
            Some(custom_filter) => queryset.filter(custom_filter.clone()),
            None => queryset,
        }
    }

    /// Builds a `<slug_field> = slug` filter.
    fn slug_eq_filter(&self, slug: &str) -> Filter {
        Filter::new(
            self.slug_field.clone(),
            FilterOperator::Eq,
            FilterValue::String(slug.to_string()),
        )
    }

    /// Checks that a `T` row whose `slug_field` equals `slug` exists and also
    /// matches the optional queryset filter.
    ///
    /// # Errors
    ///
    /// - [`SerializerError::Other`] if the count query fails.
    /// - [`SerializerError::Validation`] if no matching row is found.
    pub async fn validate_exists(&self, slug: &str) -> Result<(), SerializerError> {
        let count = self
            .filtered_queryset(self.slug_eq_filter(slug))
            .count()
            .await
            .map_err(|e| SerializerError::Other {
                message: format!("Failed to validate existence: {}", e),
            })?;
        if count == 0 {
            return Err(SerializerError::Validation(ValidatorError::Custom {
                message: format!(
                    "Instance with {} '{}' does not exist",
                    self.slug_field, slug
                ),
            }));
        }
        Ok(())
    }

    /// Fetches the first `T` row whose `slug_field` equals `slug` and that also
    /// matches the optional queryset filter.
    ///
    /// # Errors
    ///
    /// - [`SerializerError::Other`] if the query fails.
    /// - [`SerializerError::Validation`] if no matching row is found.
    pub async fn get_instance(&self, slug: &str) -> Result<T, SerializerError> {
        self.filtered_queryset(self.slug_eq_filter(slug))
            .first()
            .await
            .map_err(|e| SerializerError::Other {
                message: format!("Failed to fetch instance: {}", e),
            })?
            .ok_or_else(|| {
                SerializerError::Validation(ValidatorError::Custom {
                    message: format!("Instance with {} '{}' not found", self.slug_field, slug),
                })
            })
    }

    /// Fetches every `T` row whose `slug_field` is one of `slugs` and that also
    /// matches the optional queryset filter.
    ///
    /// An empty `slugs` returns an empty vector without querying. Slugs that
    /// match no row are not reported as errors, so the result may be shorter
    /// than `slugs`.
    ///
    /// # Errors
    ///
    /// [`SerializerError::Other`] if the query fails.
    pub async fn get_instances(&self, slugs: Vec<String>) -> Result<Vec<T>, SerializerError> {
        if slugs.is_empty() {
            return Ok(Vec::new());
        }
        // `slugs` is already owned, so it is moved into the filter instead of copied.
        let lookup = Filter::new(
            self.slug_field.clone(),
            FilterOperator::In,
            FilterValue::Array(slugs),
        );
        let instances = self
            .filtered_queryset(lookup)
            .all()
            .await
            .map_err(|e| SerializerError::Other {
                message: format!("Failed to fetch instances: {}", e),
            })?;
        Ok(instances)
    }
}
/// Eager-loading hints (`select_related` / `prefetch_related` relation names)
/// to attach to a queryset.
#[derive(Debug, Clone, Default)]
pub struct QueryOptimizer {
    /// Relation names passed to `QuerySet::select_related`.
    pub select_related: Vec<String>,
    /// Relation names passed to `QuerySet::prefetch_related`.
    pub prefetch_related: Vec<String>,
}
impl QueryOptimizer {
    /// Creates an optimizer with no eager-loading hints.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder: replaces the `select_related` relation list with `fields`.
    pub fn with_select_related(mut self, fields: Vec<String>) -> Self {
        self.select_related = fields;
        self
    }

    /// Builder: replaces the `prefetch_related` relation list with `fields`.
    pub fn with_prefetch_related(mut self, fields: Vec<String>) -> Self {
        self.prefetch_related = fields;
        self
    }

    /// Applies the configured hints to `queryset` and returns the result.
    ///
    /// `select_related` is applied before `prefetch_related`. An empty list
    /// leaves the queryset untouched for that hint.
    pub fn apply<T>(&self, queryset: QuerySet<T>) -> QuerySet<T>
    where
        T: Model,
    {
        let mut optimized = queryset;
        if !self.select_related.is_empty() {
            let relation_names: Vec<&str> =
                self.select_related.iter().map(String::as_str).collect();
            optimized = optimized.select_related(&relation_names);
        }
        if !self.prefetch_related.is_empty() {
            let relation_names: Vec<&str> =
                self.prefetch_related.iter().map(String::as_str).collect();
            optimized = optimized.prefetch_related(&relation_names);
        }
        optimized
    }
}
/// A relation field that can carry a [`QueryOptimizer`] for its ORM lookups.
#[async_trait]
pub trait OptimizableRelationField {
    /// Returns the optimizer currently attached to this field, if any.
    fn get_optimizer(&self) -> Option<&QueryOptimizer>;

    /// Stores `optimizer` on this field.
    fn set_optimizer(&mut self, optimizer: QueryOptimizer);
}
#[cfg(test)]
mod tests {
    use super::*;
    use reinhardt_db::orm::{FieldSelector, Model};
    use serde::{Deserialize, Serialize};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestUser {
        id: Option<i64>,
        username: String,
        email: String,
    }

    #[derive(Debug, Clone)]
    struct TestUserFields;

    impl FieldSelector for TestUserFields {
        fn with_alias(self, _alias: &str) -> Self {
            self
        }
    }

    impl Model for TestUser {
        type PrimaryKey = i64;
        type Fields = TestUserFields;
        fn table_name() -> &'static str {
            "test_users"
        }
        fn new_fields() -> Self::Fields {
            TestUserFields
        }
        fn primary_key(&self) -> Option<Self::PrimaryKey> {
            self.id
        }
        fn set_primary_key(&mut self, value: Self::PrimaryKey) {
            self.id = Some(value);
        }
    }

    /// Shared `is_active = true` filter used by the filter-builder tests.
    fn active_filter() -> Filter {
        Filter::new(
            "is_active".to_string(),
            FilterOperator::Eq,
            FilterValue::Boolean(true),
        )
    }

    /// Small fixture user for resolver tests.
    fn sample_user() -> TestUser {
        TestUser {
            id: Some(1),
            username: "alice".to_string(),
            email: "alice@example.com".to_string(),
        }
    }

    #[test]
    fn test_pk_related_field_creation() {
        let field = PrimaryKeyRelatedFieldORM::<TestUser>::new();
        assert!(!field.allow_null);
        assert!(field.queryset_filter.is_none());
    }

    #[test]
    fn test_pk_related_field_default_matches_new() {
        let field = PrimaryKeyRelatedFieldORM::<TestUser>::default();
        assert!(!field.allow_null);
        assert!(field.queryset_filter.is_none());
    }

    #[test]
    fn test_pk_related_field_with_allow_null() {
        let field = PrimaryKeyRelatedFieldORM::<TestUser>::new().with_allow_null(true);
        assert!(field.allow_null);
    }

    #[test]
    fn test_pk_related_field_with_filter() {
        let field =
            PrimaryKeyRelatedFieldORM::<TestUser>::new().with_queryset_filter(active_filter());
        assert!(field.queryset_filter.is_some());
    }

    #[test]
    fn test_slug_related_field_creation() {
        let field = SlugRelatedFieldORM::<TestUser>::new("username");
        assert_eq!(field.slug_field, "username");
        assert!(!field.allow_null);
        assert!(field.queryset_filter.is_none());
    }

    #[test]
    fn test_slug_related_field_with_allow_null() {
        let field = SlugRelatedFieldORM::<TestUser>::new("username").with_allow_null(true);
        assert!(field.allow_null);
    }

    #[test]
    fn test_slug_related_field_with_filter() {
        let field =
            SlugRelatedFieldORM::<TestUser>::new("username").with_queryset_filter(active_filter());
        assert!(field.queryset_filter.is_some());
        // Setting a filter must not disturb the other configuration.
        assert_eq!(field.slug_field, "username");
        assert!(!field.allow_null);
    }

    #[test]
    fn test_query_optimizer_creation() {
        let optimizer = QueryOptimizer::new();
        assert!(optimizer.select_related.is_empty());
        assert!(optimizer.prefetch_related.is_empty());
    }

    #[test]
    fn test_query_optimizer_default_matches_new() {
        let from_new = QueryOptimizer::new();
        let from_default = QueryOptimizer::default();
        assert_eq!(from_new.select_related, from_default.select_related);
        assert_eq!(from_new.prefetch_related, from_default.prefetch_related);
    }

    #[test]
    fn test_query_optimizer_with_select_related() {
        let optimizer = QueryOptimizer::new()
            .with_select_related(vec!["author".to_string(), "category".to_string()]);
        assert_eq!(optimizer.select_related.len(), 2);
        assert!(optimizer.select_related.contains(&"author".to_string()));
        assert!(optimizer.select_related.contains(&"category".to_string()));
    }

    #[test]
    fn test_query_optimizer_with_prefetch_related() {
        let optimizer = QueryOptimizer::new()
            .with_prefetch_related(vec!["comments".to_string(), "tags".to_string()]);
        assert_eq!(optimizer.prefetch_related.len(), 2);
        assert!(optimizer.prefetch_related.contains(&"comments".to_string()));
        assert!(optimizer.prefetch_related.contains(&"tags".to_string()));
    }

    #[test]
    fn test_query_optimizer_builders_replace_previous_fields() {
        let optimizer = QueryOptimizer::new()
            .with_select_related(vec!["author".to_string()])
            .with_select_related(vec!["category".to_string()])
            .with_prefetch_related(vec!["comments".to_string()])
            .with_prefetch_related(vec!["tags".to_string()]);
        assert_eq!(optimizer.select_related, vec!["category".to_string()]);
        assert_eq!(optimizer.prefetch_related, vec!["tags".to_string()]);
    }

    #[test]
    fn test_query_optimizer_apply() {
        let optimizer = QueryOptimizer::new()
            .with_select_related(vec!["author".to_string()])
            .with_prefetch_related(vec!["comments".to_string()]);
        let queryset = QuerySet::<TestUser>::new();
        let optimized = optimizer.apply(queryset);
        drop(optimized);
    }

    #[test]
    fn test_many_relation_resolver_clone_shares_closure() {
        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&calls);
        let resolver = ManyRelationResolverFn::<TestUser, TestUser>::new(move |_user: &TestUser| {
            counter.fetch_add(1, Ordering::SeqCst);
            Box::pin(async { Ok::<Vec<TestUser>, String>(Vec::new()) })
        });
        let cloned = resolver.clone();
        let user = sample_user();
        // Invoking the resolver runs the closure body synchronously; the
        // returned future is dropped without being polled.
        drop((*resolver)(&user));
        drop((*cloned)(&user));
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_single_relation_resolver_clone_shares_closure() {
        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&calls);
        let resolver =
            SingleRelationResolverFn::<TestUser, TestUser>::new(move |user: &TestUser| {
                counter.fetch_add(1, Ordering::SeqCst);
                let related = user.clone();
                Box::pin(async move { Ok::<TestUser, String>(related) })
            });
        let cloned = resolver.clone();
        let user = sample_user();
        drop((*resolver)(&user));
        drop((*cloned)(&user));
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }
}