use async_trait::async_trait;
use hyper::Method;
use reinhardt_core::exception::{Error, Result};
use reinhardt_db::orm::{Model, QuerySet};
use reinhardt_http::{Request, Response};
use reinhardt_rest::serializers::{Serializer, ValidatorConfig};
use serde::{Deserialize, Serialize};
use std::marker::PhantomData;
use crate::core::View;
/// Generic view that handles `POST` by deserializing the JSON request body into `M`,
/// optionally running async validators, and inserting the object through a `QuerySet`.
///
/// `S` is the serializer used to render the created object in the response. Only its
/// type is carried here (via `PhantomData`); an instance is built with `S::default()`
/// when a response is produced.
pub struct CreateAPIView<M, S>
where
    M: Model + Serialize + for<'de> Deserialize<'de> + Send + Sync + Clone,
    S: Serializer<Input = M, Output = String> + Send + Sync,
{
    /// Queryset used to create objects; when `None`, `QuerySet::default()` is used.
    queryset: Option<QuerySet<M>>,
    /// Optional validators, run asynchronously against a DB connection before creation.
    /// They only run when the request carries a DI context.
    validation_config: Option<ValidatorConfig<M>>,
    /// Type-level marker for the serializer `S`; no serializer value is stored.
    _serializer: PhantomData<S>,
}
impl<M, S> CreateAPIView<M, S>
where
    M: Model + Serialize + for<'de> Deserialize<'de> + Send + Sync + Clone + 'static,
    S: Serializer<Input = M, Output = String> + Send + Sync + 'static,
{
    /// Builds a view with no explicit queryset and no validators configured.
    pub fn new() -> Self {
        Self {
            queryset: None,
            validation_config: None,
            _serializer: PhantomData,
        }
    }

    /// Builder method: create objects through `queryset` instead of the default queryset.
    pub fn with_queryset(self, queryset: QuerySet<M>) -> Self {
        Self {
            queryset: Some(queryset),
            ..self
        }
    }

    /// Returns a clone of the configured queryset, or `QuerySet::default()` when none is set.
    fn get_queryset(&self) -> QuerySet<M> {
        self.queryset
            .as_ref()
            .map_or_else(Default::default, Clone::clone)
    }

    /// Builder method: run `config`'s validators before each object is created.
    pub fn with_validation_config(self, config: ValidatorConfig<M>) -> Self {
        Self {
            validation_config: Some(config),
            ..self
        }
    }

    /// Parses the request body as `M`, validates it (when possible), and inserts it.
    ///
    /// # Errors
    /// - `Error::Http` when the body cannot be deserialized into `M`.
    /// - `Error::Internal` when the database connection cannot be resolved from DI.
    /// - Any validator error, converted into `Error` via `?`.
    /// - `Error::Http` when the queryset insert fails.
    async fn perform_create(&self, request: &Request) -> Result<M> {
        // Parse first so malformed bodies are rejected before any DI/DB work happens.
        let payload: M = request
            .json()
            .map_err(|e| Error::Http(format!("Invalid request body: {}", e)))?;

        self.run_validators(request, &payload).await?;

        let target = self.get_queryset();
        target
            .create(payload)
            .await
            .map_err(|e| Error::Http(format!("Failed to create: {}", e)))
    }

    /// Runs the configured validators against `payload`, using a database connection
    /// resolved from the request's DI context.
    ///
    /// Returns `Ok(())` without validating when no validators are configured, or when
    /// the request carries no DI context.
    async fn run_validators(&self, request: &Request, payload: &M) -> Result<()> {
        use reinhardt_db::DatabaseConnection;
        use reinhardt_di::Depends;

        let Some(validators) = self.validation_config.as_ref() else {
            return Ok(());
        };
        let Some(di_ctx) =
            request.get_di_context::<std::sync::Arc<reinhardt_di::InjectionContext>>()
        else {
            // NOTE(review): configured validators are silently skipped here when the
            // request has no DI context; confirm this best-effort behavior is intended.
            return Ok(());
        };

        let conn = Depends::<DatabaseConnection>::resolve(&di_ctx, true)
            .await
            .map_err(|e| Error::Internal(format!("Failed to resolve DB: {:?}", e)))?;
        validators
            .validate_async(conn.into_inner().inner(), payload, None)
            .await?;
        Ok(())
    }
}
impl<M, S> Default for CreateAPIView<M, S>
where
M: Model + Serialize + for<'de> Deserialize<'de> + Send + Sync + Clone + 'static,
S: Serializer<Input = M, Output = String> + Send + Sync + 'static,
{
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl<M, S> View for CreateAPIView<M, S>
where
M: Model + Serialize + for<'de> Deserialize<'de> + Send + Sync + Clone + 'static,
S: Serializer<Input = M, Output = String> + Send + Sync + 'static + Default,
{
async fn dispatch(&self, request: Request) -> Result<Response> {
match request.method {
Method::POST => {
let obj = self.perform_create(&request).await?;
let serializer = S::default();
let serialized = serializer
.serialize(&obj)
.map_err(|e| Error::Http(e.to_string()))?;
let json_value: serde_json::Value = serde_json::from_str(&serialized)
.map_err(|e| Error::Http(format!("Failed to parse serialized data: {}", e)))?;
Response::created()
.with_json(&json_value)
.map_err(|e| Error::Http(e.to_string()))
}
_ => Err(Error::MethodNotAllowed(format!(
"Method {} not allowed",
request.method
))),
}
}
fn allowed_methods(&self) -> Vec<&'static str> {
vec!["POST", "OPTIONS"]
}
}
// Explicitly marks `CreateAPIView` as `UnwindSafe` (e.g. for use with `catch_unwind`).
//
// NOTE(review): `UnwindSafe` is an auto trait, so this manual impl asserts it regardless
// of whether `QuerySet<M>`, `ValidatorConfig<M>` or `S` would qualify on their own.
// Confirm none of them can be observed in a broken state after a caught panic.
impl<M, S> std::panic::UnwindSafe for CreateAPIView<M, S>
where
    M: Model + Serialize + for<'de> Deserialize<'de> + Send + Sync + Clone,
    S: Serializer<Input = M, Output = String> + Send + Sync,
{
}
// Explicitly marks `&CreateAPIView` as unwind-safe to share across a `catch_unwind` boundary.
//
// NOTE(review): like the `UnwindSafe` impl, this manual auto-trait impl overrides automatic
// derivation from the field types. Confirm the fields have no interior mutability whose
// invariants a panic could break.
impl<M, S> std::panic::RefUnwindSafe for CreateAPIView<M, S>
where
    M: Model + Serialize + for<'de> Deserialize<'de> + Send + Sync + Clone,
    S: Serializer<Input = M, Output = String> + Send + Sync,
{
}