use crate::viewsets::{FilterConfig, PaginationConfig};
use async_trait::async_trait;
use hyper::Method;
use reinhardt_core::exception::{Error, Result};
use reinhardt_db::orm::{Filter, FilterOperator, FilterValue, Manager, Model, QuerySet};
use reinhardt_http::{Request, Response};
use reinhardt_rest::serializers::{Serializer, ValidatorConfig};
use serde::{Deserialize, Serialize};
use std::marker::PhantomData;
use crate::core::View;
/// Generic collection view: lists `M` objects on `GET`/`HEAD` and creates a
/// new `M` on `POST`, rendering every object through the serializer `S`.
///
/// Listing honours the optional ordering, query-string filters and pagination
/// stored here; creation optionally runs the configured async validators.
pub struct ListCreateAPIView<
    M: Model + Serialize + for<'de> Deserialize<'de> + Send + Sync + Clone,
    S: Serializer<Input = M, Output = String> + Send + Sync,
> {
    /// Base queryset; `None` falls back to `QuerySet::default()`.
    queryset: Option<QuerySet<M>>,
    /// Pagination strategy for list responses; `None` returns a bare JSON array.
    pagination_config: Option<PaginationConfig>,
    /// Fields that may be filtered (exact match) from query parameters.
    filter_config: Option<FilterConfig>,
    /// Field names passed to `QuerySet::order_by` for list queries.
    ordering: Option<Vec<String>>,
    /// Async validators run against the request body before creation.
    validation_config: Option<ValidatorConfig<M>>,
    /// Ties the serializer type to the view without storing an instance.
    _serializer: PhantomData<S>,
}
impl<M, S> ListCreateAPIView<M, S>
where
M: Model + Serialize + for<'de> Deserialize<'de> + Send + Sync + Clone + 'static,
S: Serializer<Input = M, Output = String> + Send + Sync + 'static,
{
pub fn new() -> Self {
Self {
queryset: None,
pagination_config: None,
filter_config: None,
ordering: None,
validation_config: None,
_serializer: PhantomData,
}
}
pub fn with_queryset(mut self, queryset: QuerySet<M>) -> Self {
self.queryset = Some(queryset);
self
}
pub fn with_paginate_by(mut self, page_size: usize) -> Self {
self.pagination_config = Some(PaginationConfig::page_number(page_size, Some(100)));
self
}
pub fn with_filter_config(mut self, filter_config: FilterConfig) -> Self {
self.filter_config = Some(filter_config);
self
}
pub fn with_ordering(mut self, ordering: Vec<String>) -> Self {
self.ordering = Some(ordering);
self
}
fn get_queryset(&self) -> QuerySet<M> {
self.queryset.clone().unwrap_or_default()
}
fn get_filtered_queryset(&self, request: &Request) -> QuerySet<M> {
let mut queryset = self.get_queryset();
if let Some(ref ordering) = self.ordering {
let order_fields: Vec<&str> = ordering.iter().map(|s| s.as_str()).collect();
queryset = queryset.order_by(&order_fields);
}
if let Some(ref filter_config) = self.filter_config {
for field in &filter_config.filterable_fields {
if let Some(value) = request.query_params.get(field) {
let filter = Filter::new(
field.clone(),
FilterOperator::Eq,
FilterValue::String(value.clone()),
);
queryset = queryset.filter(filter);
}
}
}
queryset
}
async fn get_objects(&self, request: &Request) -> Result<Vec<M>> {
let mut queryset = self.get_filtered_queryset(request);
if let Some(ref pagination) = self.pagination_config {
match pagination {
PaginationConfig::PageNumber { page_size, .. } => {
let page = request
.query_params
.get("page")
.and_then(|p| p.parse::<usize>().ok())
.unwrap_or(1);
queryset = queryset.paginate(page, *page_size);
}
PaginationConfig::LimitOffset {
default_limit,
max_limit,
} => {
let limit = request
.query_params
.get("limit")
.and_then(|l| l.parse::<usize>().ok())
.unwrap_or(*default_limit)
.min(max_limit.unwrap_or(usize::MAX));
let offset = request
.query_params
.get("offset")
.and_then(|o| o.parse::<usize>().ok())
.unwrap_or(0);
queryset = queryset.offset(offset).limit(limit);
}
PaginationConfig::Cursor { page_size, .. } => {
queryset = queryset.limit(*page_size);
}
PaginationConfig::None => {
}
}
}
queryset.all().await.map_err(|e| Error::Http(e.to_string()))
}
}
impl<M, S> Default for ListCreateAPIView<M, S>
where
M: Model + Serialize + for<'de> Deserialize<'de> + Send + Sync + Clone + 'static,
S: Serializer<Input = M, Output = String> + Send + Sync + 'static,
{
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl<M, S> View for ListCreateAPIView<M, S>
where
M: Model + Serialize + for<'de> Deserialize<'de> + Send + Sync + Clone + 'static,
S: Serializer<Input = M, Output = String> + Send + Sync + 'static + Default,
{
async fn dispatch(&self, request: Request) -> Result<Response> {
match request.method {
Method::GET | Method::HEAD => {
let objects = self.get_objects(&request).await?;
let serializer = S::default();
let serialized = objects
.iter()
.map(|obj| {
serializer
.serialize(obj)
.map_err(|e| Error::Http(e.to_string()))
})
.collect::<Result<Vec<_>>>()?;
let results: Vec<serde_json::Value> = serialized
.iter()
.filter_map(|s| serde_json::from_str::<serde_json::Value>(s).ok())
.collect();
let response_body = if let Some(ref pagination) = self.pagination_config {
let total_count = self
.get_filtered_queryset(&request)
.count()
.await
.map_err(|e| Error::Http(e.to_string()))?;
match pagination {
PaginationConfig::PageNumber { page_size, .. } => {
let page = request
.query_params
.get("page")
.and_then(|p| p.parse::<usize>().ok())
.unwrap_or(1);
let has_next = page.saturating_mul(*page_size) < total_count;
serde_json::json!({
"count": total_count,
"page": page,
"page_size": page_size,
"next": if has_next { Some(format!("?page={}", page + 1)) } else { None::<String> },
"previous": if page > 1 { Some(format!("?page={}", page - 1)) } else { None::<String> },
"results": results
})
}
PaginationConfig::LimitOffset { .. } => {
let offset = request
.query_params
.get("offset")
.and_then(|o| o.parse::<usize>().ok())
.unwrap_or(0);
let limit = request
.query_params
.get("limit")
.and_then(|l| l.parse::<usize>().ok())
.unwrap_or(10);
let has_next = offset.saturating_add(limit) < total_count;
serde_json::json!({
"count": total_count,
"offset": offset,
"limit": limit,
"next": if has_next { Some(format!("?offset={}&limit={}", offset.saturating_add(limit), limit)) } else { None::<String> },
"previous": if offset > 0 { Some(format!("?offset={}&limit={}", offset.saturating_sub(limit), limit)) } else { None::<String> },
"results": results
})
}
_ => {
serde_json::json!({
"count": total_count,
"results": results
})
}
}
} else {
serde_json::json!(results)
};
Response::ok().with_json(&response_body)
}
Method::POST => {
let data: M = request
.json()
.map_err(|e| Error::Http(format!("Invalid request body: {}", e)))?;
if let Some(ref validators) = self.validation_config
&& let Some(di_ctx) =
request.get_di_context::<std::sync::Arc<reinhardt_di::InjectionContext>>()
{
use reinhardt_db::DatabaseConnection;
use reinhardt_di::Depends;
let conn = Depends::<DatabaseConnection>::resolve(&di_ctx, true)
.await
.map_err(|e| Error::Internal(format!("Failed to resolve DB: {:?}", e)))?;
validators
.validate_async(conn.into_inner().inner(), &data, None)
.await?;
}
let queryset = self.get_queryset();
let created = queryset
.create(data)
.await
.map_err(|e| Error::Http(format!("Failed to create: {}", e)))?;
let serializer = S::default();
let serialized = serializer
.serialize(&created)
.map_err(|e| Error::Http(e.to_string()))?;
let json_value: serde_json::Value = serde_json::from_str(&serialized)
.map_err(|e| Error::Http(format!("Serialization error: {}", e)))?;
Response::created().with_json(&json_value)
}
_ => Err(Error::MethodNotAllowed(format!(
"Method {} not allowed",
request.method
))),
}
}
fn allowed_methods(&self) -> Vec<&'static str> {
vec!["GET", "HEAD", "POST", "OPTIONS"]
}
}
// Manual unwind-safety markers for the view (relevant to
// `std::panic::catch_unwind`). In this file the view is only built through
// `mut self` builders and read through `&self` by `dispatch`.
// NOTE(review): these impls assert unwind safety for every field type
// (`QuerySet<M>`, `ValidatorConfig<M>`, ...) — confirm that assumption holds.
impl<
    M: Model + Serialize + for<'de> Deserialize<'de> + Send + Sync + Clone,
    S: Serializer<Input = M, Output = String> + Send + Sync,
> std::panic::UnwindSafe for ListCreateAPIView<M, S>
{
}

impl<
    M: Model + Serialize + for<'de> Deserialize<'de> + Send + Sync + Clone,
    S: Serializer<Input = M, Output = String> + Send + Sync,
> std::panic::RefUnwindSafe for ListCreateAPIView<M, S>
{
}
/// Generic detail view that retrieves (`GET`/`HEAD`) and updates (`PUT`,
/// `PATCH`) a single `M` selected by a path parameter.
pub struct RetrieveUpdateAPIView<
    M: Model + Serialize + for<'de> Deserialize<'de> + Send + Sync + Clone,
    S: Serializer<Input = M, Output = String> + Send + Sync,
> {
    /// Base queryset the lookup filter is applied to; `None` means `QuerySet::default()`.
    queryset: Option<QuerySet<M>>,
    /// Name of both the path parameter and the model field used for lookup.
    lookup_field: String,
    /// Ties the serializer type to the view without storing an instance.
    _serializer: PhantomData<S>,
}
impl<M, S> RetrieveUpdateAPIView<M, S>
where
M: Model + Serialize + for<'de> Deserialize<'de> + Send + Sync + Clone + 'static,
S: Serializer<Input = M, Output = String> + Send + Sync + 'static,
{
pub fn new() -> Self {
Self {
queryset: None,
lookup_field: "pk".to_string(),
_serializer: PhantomData,
}
}
pub fn with_queryset(mut self, queryset: QuerySet<M>) -> Self {
self.queryset = Some(queryset);
self
}
pub fn with_lookup_field(mut self, field: String) -> Self {
self.lookup_field = field;
self
}
fn get_queryset(&self) -> QuerySet<M> {
self.queryset.clone().unwrap_or_default()
}
async fn get_object(&self, request: &Request) -> Result<M>
where
M: serde::de::DeserializeOwned,
{
let lookup_value = request.path_params.get(&self.lookup_field).ok_or_else(|| {
Error::Http(format!(
"Missing lookup field '{}' in path parameters",
self.lookup_field
))
})?;
let filter_value = if let Ok(int_value) = lookup_value.parse::<i64>() {
FilterValue::Integer(int_value)
} else {
FilterValue::String(lookup_value.clone())
};
let filter = Filter::new(self.lookup_field.clone(), FilterOperator::Eq, filter_value);
self.get_queryset()
.filter(filter)
.get()
.await
.map_err(|e| Error::Http(format!("Object not found: {}", e)))
}
}
impl<M, S> Default for RetrieveUpdateAPIView<M, S>
where
M: Model + Serialize + for<'de> Deserialize<'de> + Send + Sync + Clone + 'static,
S: Serializer<Input = M, Output = String> + Send + Sync + 'static,
{
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl<M, S> View for RetrieveUpdateAPIView<M, S>
where
    M: Model + Serialize + for<'de> Deserialize<'de> + Send + Sync + Clone + 'static,
    S: Serializer<Input = M, Output = String> + Send + Sync + 'static + Default,
{
    /// Retrieves (`GET`/`HEAD`), replaces (`PUT`) or partially updates
    /// (`PATCH`) the object selected by the lookup path parameter.
    ///
    /// For both `PUT` and `PATCH`, the stored object's primary key is written
    /// onto the outgoing model. A request body therefore cannot redirect the
    /// update to a different row.
    ///
    /// # Errors
    /// - `Error::Http` when the lookup fails, the body is invalid, the object
    ///   has no primary key, (de)serialization or patch merging fails, or the
    ///   update fails.
    /// - `Error::MethodNotAllowed` for any other method.
    async fn dispatch(&self, request: Request) -> Result<Response> {
        match request.method {
            Method::GET | Method::HEAD => {
                let object = self.get_object(&request).await?;
                let serializer = S::default();
                let serialized = serializer
                    .serialize(&object)
                    .map_err(|e| Error::Http(e.to_string()))?;
                let json_value: serde_json::Value = serde_json::from_str(&serialized)
                    .map_err(|e| Error::Http(format!("Serialization error: {}", e)))?;
                Response::ok().with_json(&json_value)
            }
            Method::PUT => {
                let existing = self.get_object(&request).await?;
                let mut object: M = request
                    .json()
                    .map_err(|e| Error::Http(format!("Invalid request body: {}", e)))?;
                // The row to update is chosen by the URL, not the body:
                // overwrite whatever primary key the body carried.
                let pk = existing
                    .primary_key()
                    .ok_or_else(|| Error::Http("Object has no primary key".to_string()))?;
                object.set_primary_key(pk);
                let manager = Manager::<M>::new();
                let updated = manager
                    .update(&object)
                    .await
                    .map_err(|e| Error::Http(format!("Failed to update: {}", e)))?;
                let serializer = S::default();
                let serialized = serializer
                    .serialize(&updated)
                    .map_err(|e| Error::Http(e.to_string()))?;
                let json_value: serde_json::Value = serde_json::from_str(&serialized)
                    .map_err(|e| Error::Http(format!("Serialization error: {}", e)))?;
                Response::ok().with_json(&json_value)
            }
            Method::PATCH => {
                let object = self.get_object(&request).await?;
                let serializer = S::default();
                let current_json = serializer
                    .serialize(&object)
                    .map_err(|e| Error::Http(e.to_string()))?;
                let mut current: serde_json::Value = serde_json::from_str(&current_json)
                    .map_err(|e| Error::Http(format!("Serialization error: {}", e)))?;
                let patch_data: serde_json::Value = request
                    .json()
                    .map_err(|e| Error::Http(format!("Invalid request body: {}", e)))?;
                crate::generic::patch_utils::merge_patch_object_into(&mut current, &patch_data)
                    .map_err(Error::Http)?;
                let mut merged: M = serde_json::from_value(current)
                    .map_err(|e| Error::Http(format!("Failed to merge patch: {}", e)))?;
                // As with PUT, a patch must not change which row is updated:
                // restore the stored object's primary key after merging.
                let pk = object
                    .primary_key()
                    .ok_or_else(|| Error::Http("Object has no primary key".to_string()))?;
                merged.set_primary_key(pk);
                let manager = Manager::<M>::new();
                let updated = manager
                    .update(&merged)
                    .await
                    .map_err(|e| Error::Http(format!("Failed to update: {}", e)))?;
                let serialized = serializer
                    .serialize(&updated)
                    .map_err(|e| Error::Http(e.to_string()))?;
                let json_value: serde_json::Value = serde_json::from_str(&serialized)
                    .map_err(|e| Error::Http(format!("Serialization error: {}", e)))?;
                Response::ok().with_json(&json_value)
            }
            _ => Err(Error::MethodNotAllowed(format!(
                "Method {} not allowed",
                request.method
            ))),
        }
    }

    /// Methods advertised by this view.
    fn allowed_methods(&self) -> Vec<&'static str> {
        vec!["GET", "HEAD", "PUT", "PATCH", "OPTIONS"]
    }
}
/// Generic detail view that retrieves (`GET`/`HEAD`) and deletes (`DELETE`)
/// a single `M` selected by a path parameter.
pub struct RetrieveDestroyAPIView<
    M: Model + Serialize + for<'de> Deserialize<'de> + Send + Sync + Clone,
    S: Serializer<Input = M, Output = String> + Send + Sync,
> {
    /// Base queryset the lookup filter is applied to; `None` means `QuerySet::default()`.
    queryset: Option<QuerySet<M>>,
    /// Name of both the path parameter and the model field used for lookup.
    lookup_field: String,
    /// Ties the serializer type to the view without storing an instance.
    _serializer: PhantomData<S>,
}
impl<M, S> RetrieveDestroyAPIView<M, S>
where
M: Model + Serialize + for<'de> Deserialize<'de> + Send + Sync + Clone + 'static,
S: Serializer<Input = M, Output = String> + Send + Sync + 'static,
{
pub fn new() -> Self {
Self {
queryset: None,
lookup_field: "pk".to_string(),
_serializer: PhantomData,
}
}
pub fn with_queryset(mut self, queryset: QuerySet<M>) -> Self {
self.queryset = Some(queryset);
self
}
pub fn with_lookup_field(mut self, field: String) -> Self {
self.lookup_field = field;
self
}
fn get_queryset(&self) -> QuerySet<M> {
self.queryset.clone().unwrap_or_default()
}
async fn get_object(&self, request: &Request) -> Result<M>
where
M: serde::de::DeserializeOwned,
{
let lookup_value = request.path_params.get(&self.lookup_field).ok_or_else(|| {
Error::Http(format!(
"Missing lookup field '{}' in path parameters",
self.lookup_field
))
})?;
let filter_value = if let Ok(int_value) = lookup_value.parse::<i64>() {
FilterValue::Integer(int_value)
} else {
FilterValue::String(lookup_value.clone())
};
let filter = Filter::new(self.lookup_field.clone(), FilterOperator::Eq, filter_value);
self.get_queryset()
.filter(filter)
.get()
.await
.map_err(|e| Error::Http(format!("Object not found: {}", e)))
}
}
impl<M, S> Default for RetrieveDestroyAPIView<M, S>
where
M: Model + Serialize + for<'de> Deserialize<'de> + Send + Sync + Clone + 'static,
S: Serializer<Input = M, Output = String> + Send + Sync + 'static,
{
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl<M, S> View for RetrieveDestroyAPIView<M, S>
where
M: Model + Serialize + for<'de> Deserialize<'de> + Send + Sync + Clone + 'static,
S: Serializer<Input = M, Output = String> + Send + Sync + 'static + Default,
{
async fn dispatch(&self, request: Request) -> Result<Response> {
match request.method {
Method::GET | Method::HEAD => {
let object = self.get_object(&request).await?;
let serializer = S::default();
let serialized = serializer
.serialize(&object)
.map_err(|e| Error::Http(e.to_string()))?;
let json_value: serde_json::Value = serde_json::from_str(&serialized)
.map_err(|e| Error::Http(format!("Serialization error: {}", e)))?;
Response::ok().with_json(&json_value)
}
Method::DELETE => {
let object = self.get_object(&request).await?;
let pk = object
.primary_key()
.ok_or_else(|| Error::Http("Object has no primary key".to_string()))?;
let manager = Manager::<M>::new();
manager
.delete(pk)
.await
.map_err(|e| Error::Http(format!("Failed to delete: {}", e)))?;
Ok(Response::no_content())
}
_ => Err(Error::MethodNotAllowed(format!(
"Method {} not allowed",
request.method
))),
}
}
fn allowed_methods(&self) -> Vec<&'static str> {
vec!["GET", "HEAD", "DELETE", "OPTIONS"]
}
}
/// Generic detail view that retrieves (`GET`/`HEAD`), updates (`PUT`,
/// `PATCH`) and deletes (`DELETE`) a single `M` selected by a path parameter.
pub struct RetrieveUpdateDestroyAPIView<
    M: Model + Serialize + for<'de> Deserialize<'de> + Send + Sync + Clone,
    S: Serializer<Input = M, Output = String> + Send + Sync,
> {
    /// Base queryset the lookup filter is applied to; `None` means `QuerySet::default()`.
    queryset: Option<QuerySet<M>>,
    /// Name of both the path parameter and the model field used for lookup.
    lookup_field: String,
    /// Ties the serializer type to the view without storing an instance.
    _serializer: PhantomData<S>,
}
impl<M, S> RetrieveUpdateDestroyAPIView<M, S>
where
M: Model + Serialize + for<'de> Deserialize<'de> + Send + Sync + Clone + 'static,
S: Serializer<Input = M, Output = String> + Send + Sync + 'static,
{
pub fn new() -> Self {
Self {
queryset: None,
lookup_field: "pk".to_string(),
_serializer: PhantomData,
}
}
pub fn with_queryset(mut self, queryset: QuerySet<M>) -> Self {
self.queryset = Some(queryset);
self
}
pub fn with_lookup_field(mut self, field: String) -> Self {
self.lookup_field = field;
self
}
fn get_queryset(&self) -> QuerySet<M> {
self.queryset.clone().unwrap_or_default()
}
async fn get_object(&self, request: &Request) -> Result<M>
where
M: serde::de::DeserializeOwned,
{
let lookup_value = request.path_params.get(&self.lookup_field).ok_or_else(|| {
Error::Http(format!(
"Missing lookup field '{}' in path parameters",
self.lookup_field
))
})?;
let filter_value = if let Ok(int_value) = lookup_value.parse::<i64>() {
FilterValue::Integer(int_value)
} else {
FilterValue::String(lookup_value.clone())
};
let filter = Filter::new(self.lookup_field.clone(), FilterOperator::Eq, filter_value);
self.get_queryset()
.filter(filter)
.get()
.await
.map_err(|e| Error::Http(format!("Object not found: {}", e)))
}
}
impl<M, S> Default for RetrieveUpdateDestroyAPIView<M, S>
where
M: Model + Serialize + for<'de> Deserialize<'de> + Send + Sync + Clone + 'static,
S: Serializer<Input = M, Output = String> + Send + Sync + 'static,
{
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl<M, S> View for RetrieveUpdateDestroyAPIView<M, S>
where
M: Model + Serialize + for<'de> Deserialize<'de> + Send + Sync + Clone + 'static,
S: Serializer<Input = M, Output = String> + Send + Sync + 'static + Default,
{
async fn dispatch(&self, request: Request) -> Result<Response> {
match request.method {
Method::GET | Method::HEAD => {
let object = self.get_object(&request).await?;
let serializer = S::default();
let serialized = serializer
.serialize(&object)
.map_err(|e| Error::Http(e.to_string()))?;
let json_value: serde_json::Value = serde_json::from_str(&serialized)
.map_err(|e| Error::Http(format!("Serialization error: {}", e)))?;
Response::ok().with_json(&json_value)
}
Method::PUT => {
let mut object = self.get_object(&request).await?;
let update_data: M = request
.json()
.map_err(|e| Error::Http(format!("Invalid request body: {}", e)))?;
let pk = object
.primary_key()
.ok_or_else(|| Error::Http("Object has no primary key".to_string()))?;
object = update_data;
object.set_primary_key(pk);
let manager = Manager::<M>::new();
let updated = manager
.update(&object)
.await
.map_err(|e| Error::Http(format!("Failed to update: {}", e)))?;
let serializer = S::default();
let serialized = serializer
.serialize(&updated)
.map_err(|e| Error::Http(e.to_string()))?;
let json_value: serde_json::Value = serde_json::from_str(&serialized)
.map_err(|e| Error::Http(format!("Serialization error: {}", e)))?;
Response::ok().with_json(&json_value)
}
Method::PATCH => {
let object = self.get_object(&request).await?;
let serializer = S::default();
let current_json = serializer
.serialize(&object)
.map_err(|e| Error::Http(e.to_string()))?;
let mut current: serde_json::Value = serde_json::from_str(¤t_json)
.map_err(|e| Error::Http(format!("Serialization error: {}", e)))?;
let patch_data: serde_json::Value = request
.json()
.map_err(|e| Error::Http(format!("Invalid request body: {}", e)))?;
crate::generic::patch_utils::merge_patch_object_into(&mut current, &patch_data)
.map_err(Error::Http)?;
let merged: M = serde_json::from_value(current)
.map_err(|e| Error::Http(format!("Failed to merge patch: {}", e)))?;
let manager = Manager::<M>::new();
let updated = manager
.update(&merged)
.await
.map_err(|e| Error::Http(format!("Failed to update: {}", e)))?;
let serialized = serializer
.serialize(&updated)
.map_err(|e| Error::Http(e.to_string()))?;
let json_value: serde_json::Value = serde_json::from_str(&serialized)
.map_err(|e| Error::Http(format!("Serialization error: {}", e)))?;
Response::ok().with_json(&json_value)
}
Method::DELETE => {
let object = self.get_object(&request).await?;
let pk = object
.primary_key()
.ok_or_else(|| Error::Http("Object has no primary key".to_string()))?;
let manager = Manager::<M>::new();
manager
.delete(pk)
.await
.map_err(|e| Error::Http(format!("Failed to delete: {}", e)))?;
Ok(Response::no_content())
}
_ => Err(Error::MethodNotAllowed(format!(
"Method {} not allowed",
request.method
))),
}
}
fn allowed_methods(&self) -> Vec<&'static str> {
vec!["GET", "HEAD", "PUT", "PATCH", "DELETE", "OPTIONS"]
}
}