use async_trait::async_trait;
use hyper::Method;
use reinhardt_core::exception::{Error, Result};
use reinhardt_db::orm::{Filter, FilterOperator, FilterValue, Model, QuerySet};
use reinhardt_http::{Request, Response};
use reinhardt_rest::serializers::Serializer;
use serde::{Deserialize, Serialize};
use std::marker::PhantomData;
use crate::core::View;
/// Generic read-only endpoint that fetches a single `M` instance addressed by a
/// path parameter and renders it through the serializer `S`.
///
/// Build one with [`RetrieveAPIView::new`] (or `Default`) and configure it with
/// [`RetrieveAPIView::with_queryset`] / [`RetrieveAPIView::with_lookup_field`].
pub struct RetrieveAPIView<M, S>
where
    S: Serializer<Input = M, Output = String> + Send + Sync,
    M: Model + Serialize + for<'de> Deserialize<'de> + Send + Sync + Clone,
{
    /// Base queryset the lookup filter is applied to; when `None`, an empty
    /// `QuerySet::default()` is used instead.
    queryset: Option<QuerySet<M>>,
    /// Name of both the path parameter read from the request and the model
    /// field filtered on; `"pk"` unless overridden.
    lookup_field: String,
    /// Ties the serializer type to the view without storing a serializer.
    _serializer: PhantomData<S>,
}
impl<M, S> RetrieveAPIView<M, S>
where
M: Model + Serialize + for<'de> Deserialize<'de> + Send + Sync + Clone + 'static,
S: Serializer<Input = M, Output = String> + Send + Sync + 'static,
{
pub fn new() -> Self {
Self {
queryset: None,
lookup_field: "pk".to_string(),
_serializer: PhantomData,
}
}
pub fn with_queryset(mut self, queryset: QuerySet<M>) -> Self {
self.queryset = Some(queryset);
self
}
pub fn with_lookup_field(mut self, field: String) -> Self {
self.lookup_field = field;
self
}
fn get_queryset(&self) -> QuerySet<M> {
self.queryset.clone().unwrap_or_default()
}
async fn get_object(&self, request: &Request) -> Result<M>
where
M: serde::de::DeserializeOwned,
{
let lookup_value = request.path_params.get(&self.lookup_field).ok_or_else(|| {
Error::Http(format!(
"Missing lookup field '{}' in path parameters",
self.lookup_field
))
})?;
let filter_value = if let Ok(int_value) = lookup_value.parse::<i64>() {
FilterValue::Integer(int_value)
} else {
FilterValue::String(lookup_value.clone())
};
let filter = Filter::new(self.lookup_field.clone(), FilterOperator::Eq, filter_value);
self.get_queryset()
.filter(filter)
.get()
.await
.map_err(|e| Error::Http(format!("Object not found: {}", e)))
}
}
impl<M, S> Default for RetrieveAPIView<M, S>
where
M: Model + Serialize + for<'de> Deserialize<'de> + Send + Sync + Clone + 'static,
S: Serializer<Input = M, Output = String> + Send + Sync + 'static,
{
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl<M, S> View for RetrieveAPIView<M, S>
where
M: Model + Serialize + for<'de> Deserialize<'de> + Send + Sync + Clone + 'static,
S: Serializer<Input = M, Output = String> + Send + Sync + 'static + Default,
{
async fn dispatch(&self, request: Request) -> Result<Response> {
match request.method {
Method::GET | Method::HEAD => {
let object = self.get_object(&request).await?;
let serializer = S::default();
let serialized = serializer
.serialize(&object)
.map_err(|e| Error::Http(e.to_string()))?;
let json_value: serde_json::Value = serde_json::from_str(&serialized)
.map_err(|e| Error::Http(format!("Serialization error: {}", e)))?;
Response::ok().with_json(&json_value)
}
_ => Err(Error::MethodNotAllowed(format!(
"Method {} not allowed",
request.method
))),
}
}
fn allowed_methods(&self) -> Vec<&'static str> {
vec!["GET", "HEAD", "OPTIONS"]
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    /// Pins the integer-vs-string heuristic used when building the lookup filter.
    ///
    /// Inputs that parse as `i64` must become `FilterValue::Integer`. Everything
    /// else stays `FilterValue::String`, including values past `i64::MAX`,
    /// decimals, mixed text, and the empty string.
    ///
    /// NOTE(review): this re-implements the parsing from `get_object` instead of
    /// calling production code, so the two can drift apart.
    #[rstest]
    #[case("42", true)]
    #[case("0", true)]
    #[case("-1", true)]
    #[case("9999999999", true)]
    #[case("abc", false)]
    #[case("12abc", false)]
    #[case("", false)]
    #[case("3.14", false)]
    #[case("9223372036854775808", false)]
    fn test_lookup_value_integer_parsing(#[case] input: &str, #[case] should_be_integer: bool) {
        let raw = input.to_string();
        let parsed = match raw.parse::<i64>() {
            Ok(number) => FilterValue::Integer(number),
            Err(_) => FilterValue::String(raw.clone()),
        };

        match parsed {
            FilterValue::Integer(_) => {
                assert!(
                    should_be_integer,
                    "Expected String variant for input '{}'",
                    input
                )
            }
            FilterValue::String(_) => {
                assert!(
                    !should_be_integer,
                    "Expected Integer variant for input '{}'",
                    input
                )
            }
            _ => panic!("Unexpected FilterValue variant"),
        }
    }
}