use crate::viewsets::actions::Action;
use crate::viewsets::filtering_support::{FilterConfig, FilterableViewSet, OrderingConfig};
use crate::viewsets::handler::ModelViewSetHandler;
use crate::viewsets::metadata::{ActionMetadata, get_actions_for_viewset};
use crate::viewsets::middleware::ViewSetMiddleware;
use crate::viewsets::pagination_support::{PaginatedViewSet, PaginationConfig};
use async_trait::async_trait;
use hyper::Method;
use reinhardt_auth::Permission;
use reinhardt_db::orm::{Model, query_types::DbBackend};
use reinhardt_http::{Request, Response, Result};
use reinhardt_rest::filters::FilterBackend;
use reinhardt_rest::serializers::Serializer;
use serde::Serialize;
use serde::de::DeserializeOwned;
use std::collections::HashMap;
use std::marker::PhantomData;
use std::sync::Arc;
fn extract_pk(request: &Request, lookup_field: &str) -> Result<serde_json::Value> {
request
.path_params
.get(lookup_field)
.map(|v| serde_json::Value::String(v.clone()))
.ok_or_else(|| {
reinhardt_core::exception::Error::Http(format!(
"Missing path parameter: {}",
lookup_field
))
})
}
fn method_not_allowed(method: &Method) -> reinhardt_core::exception::Error {
reinhardt_core::exception::Error::MethodNotAllowed(format!("Method {} not allowed", method))
}
#[async_trait]
pub trait ViewSet: Send + Sync {
fn get_basename(&self) -> &str;
fn get_lookup_field(&self) -> &str {
"id"
}
async fn dispatch(&self, request: Request, action: Action) -> Result<Response>;
fn get_extra_actions(&self) -> Vec<ActionMetadata> {
let viewset_type = std::any::type_name::<Self>();
let mut actions = get_actions_for_viewset(viewset_type);
let manual_actions = crate::viewsets::registry::get_registered_actions(viewset_type);
actions.extend(manual_actions);
actions
}
fn get_extra_action_url_map(&self) -> HashMap<String, String> {
HashMap::new()
}
fn get_current_base_url(&self) -> Option<String> {
None
}
fn reverse_action(&self, _action_name: &str, _args: &[&str]) -> Result<String> {
Err(reinhardt_core::exception::Error::NotFound(
"ViewSet not bound to router".to_string(),
))
}
fn get_middleware(&self) -> Option<Arc<dyn ViewSetMiddleware>> {
None
}
fn requires_login(&self) -> bool {
false
}
fn get_required_permissions(&self) -> Vec<String> {
Vec::new()
}
}
/// A minimal viewset that only carries a basename and an arbitrary handler value.
///
/// It has no CRUD behavior: its `ViewSet::dispatch` implementation always
/// returns a `NotFound` error. Use `ModelViewSet` / `ReadOnlyModelViewSet` for
/// model-backed CRUD, or implement `ViewSet` on your own type for custom logic.
// `handler` is stored but never read by code in this file, which is why the
// `dead_code` lint is silenced. NOTE(review): confirm whether other modules read it.
#[allow(dead_code)]
#[derive(Clone)]
pub struct GenericViewSet<T> {
    /// Name returned by `get_basename`.
    basename: String,
    /// Opaque handler value supplied to `GenericViewSet::new`.
    handler: T,
}
impl<T: 'static> GenericViewSet<T> {
    /// Creates a viewset named `basename` that wraps `handler`.
    pub fn new(basename: impl Into<String>, handler: T) -> Self {
        let basename: String = basename.into();
        Self { basename, handler }
    }

    /// Consumes the viewset and returns a `ViewSetBuilder` for it.
    pub fn as_view(self) -> crate::viewsets::builder::ViewSetBuilder<Self>
    where
        T: Send + Sync,
    {
        use crate::viewsets::builder::ViewSetBuilder;
        ViewSetBuilder::new(self)
    }
}
#[async_trait]
impl<T: Send + Sync> ViewSet for GenericViewSet<T> {
fn get_basename(&self) -> &str {
&self.basename
}
async fn dispatch(&self, _request: Request, action: Action) -> Result<Response> {
Err(reinhardt_core::exception::Error::NotFound(format!(
"GenericViewSet has no built-in CRUD logic for action {:?}. \
For real CRUD, use ModelViewSet<M, S> or ReadOnlyModelViewSet<M, S>. \
To implement custom logic, define your own struct and \
`impl ViewSet for YourType` with a hand-written dispatch().",
action.action_type
)))
}
}
/// Full CRUD viewset for model `M`.
///
/// Routes list/create/retrieve/update/destroy to an internal
/// `ModelViewSetHandler<M>`. `S` is a type-level serializer marker; it is held
/// only as `PhantomData` in this struct.
pub struct ModelViewSet<M, S>
where
    M: Model + Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
    S: Send + Sync + 'static,
{
    /// Name returned by `get_basename`.
    basename: String,
    /// Path parameter used to extract the primary key on detail routes ("id" by default).
    lookup_field: String,
    /// Pagination settings exposed via `PaginatedViewSet`; `None` disables pagination.
    pagination_config: Option<PaginationConfig>,
    /// Filter settings exposed via `FilterableViewSet`.
    filter_config: Option<FilterConfig>,
    /// Ordering settings exposed via `FilterableViewSet`.
    ordering_config: Option<OrderingConfig>,
    /// Performs the actual CRUD operations for `dispatch`.
    handler: ModelViewSetHandler<M>,
    /// Ties the serializer type parameter to the struct without storing a value.
    _serializer: PhantomData<S>,
}
impl<M, S> FilterableViewSet for ModelViewSet<M, S>
where
    M: Model + Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
    S: Send + Sync + 'static,
{
    /// Returns a copy of the configured filter settings, if any.
    fn get_filter_config(&self) -> Option<FilterConfig> {
        Option::clone(&self.filter_config)
    }

    /// Returns a copy of the configured ordering settings, if any.
    fn get_ordering_config(&self) -> Option<OrderingConfig> {
        Option::clone(&self.ordering_config)
    }
}
impl<M, S> ModelViewSet<M, S>
where
M: Model + Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
S: Send + Sync + 'static,
{
pub fn new(basename: impl Into<String>) -> Self {
Self {
basename: basename.into(),
lookup_field: "id".to_string(),
pagination_config: Some(PaginationConfig::default()),
filter_config: None,
ordering_config: None,
handler: ModelViewSetHandler::<M>::new(),
_serializer: PhantomData,
}
}
pub fn with_lookup_field(mut self, field: impl Into<String>) -> Self {
self.lookup_field = field.into();
self
}
pub fn with_pagination(mut self, config: PaginationConfig) -> Self {
self.pagination_config = Some(config);
self
}
pub fn without_pagination(mut self) -> Self {
self.pagination_config = None;
self
}
pub fn with_filters(mut self, config: FilterConfig) -> Self {
self.filter_config = Some(config);
self
}
pub fn with_ordering(mut self, config: OrderingConfig) -> Self {
self.ordering_config = Some(config);
self
}
pub fn with_pool(mut self, pool: Arc<sqlx::AnyPool>) -> Self {
self.handler = std::mem::take(&mut self.handler).with_pool(pool);
self
}
pub fn with_db_backend(mut self, backend: DbBackend) -> Self {
self.handler = std::mem::take(&mut self.handler).with_db_backend(backend);
self
}
pub fn with_serializer(
mut self,
serializer: Arc<dyn Serializer<Input = M, Output = String> + Send + Sync>,
) -> Self {
self.handler = std::mem::take(&mut self.handler).with_serializer(serializer);
self
}
pub fn with_queryset(mut self, items: Vec<M>) -> Self {
self.handler = std::mem::take(&mut self.handler).with_queryset(items);
self
}
pub fn add_permission(mut self, permission: Arc<dyn Permission>) -> Self {
self.handler = std::mem::take(&mut self.handler).add_permission(permission);
self
}
pub fn add_filter_backend(mut self, backend: Arc<dyn FilterBackend>) -> Self {
self.handler = std::mem::take(&mut self.handler).add_filter_backend(backend);
self
}
pub fn as_view(self) -> crate::viewsets::builder::ViewSetBuilder<Self> {
crate::viewsets::builder::ViewSetBuilder::new(self)
}
}
#[async_trait]
impl<M, S> ViewSet for ModelViewSet<M, S>
where
    M: Model + Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
    S: Send + Sync + 'static,
{
    fn get_basename(&self) -> &str {
        self.basename.as_str()
    }

    fn get_lookup_field(&self) -> &str {
        self.lookup_field.as_str()
    }

    /// Routes the request to the handler based on HTTP method and whether the
    /// action targets a single object.
    ///
    /// - Collection routes: GET → `list`, POST → `create`.
    /// - Detail routes: GET → `retrieve`, PUT/PATCH → `update`, DELETE → `destroy`.
    /// - Anything else → `MethodNotAllowed`.
    ///
    /// On detail routes the method is checked before the primary key is
    /// extracted, so an unsupported method reports `MethodNotAllowed` rather
    /// than a missing path parameter.
    // NOTE(review): PUT and PATCH share `update`; partial-update semantics, if
    // any, must be decided inside the handler — confirm.
    async fn dispatch(&self, request: Request, action: Action) -> Result<Response> {
        let method = request.method.clone();
        if action.detail {
            match method {
                Method::GET => {
                    let pk = extract_pk(&request, &self.lookup_field)?;
                    self.handler.retrieve(&request, pk).await.map_err(Into::into)
                }
                Method::PUT | Method::PATCH => {
                    let pk = extract_pk(&request, &self.lookup_field)?;
                    self.handler.update(&request, pk).await.map_err(Into::into)
                }
                Method::DELETE => {
                    let pk = extract_pk(&request, &self.lookup_field)?;
                    self.handler.destroy(&request, pk).await.map_err(Into::into)
                }
                _ => Err(method_not_allowed(&request.method)),
            }
        } else {
            match method {
                Method::GET => self.handler.list(&request).await.map_err(Into::into),
                Method::POST => self.handler.create(&request).await.map_err(Into::into),
                _ => Err(method_not_allowed(&request.method)),
            }
        }
    }
}
impl<M, S> PaginatedViewSet for ModelViewSet<M, S>
where
    M: Model + Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
    S: Send + Sync + 'static,
{
    /// Returns a copy of the pagination settings; `None` means pagination is disabled.
    fn get_pagination_config(&self) -> Option<PaginationConfig> {
        Option::clone(&self.pagination_config)
    }
}
/// Read-only viewset for model `M`: only list and retrieve are routed.
///
/// It has the same configuration surface as `ModelViewSet`, but its `dispatch`
/// rejects every non-GET method. `S` is held only as `PhantomData`.
pub struct ReadOnlyModelViewSet<M, S>
where
    M: Model + Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
    S: Send + Sync + 'static,
{
    /// Name returned by `get_basename`.
    basename: String,
    /// Path parameter used to extract the primary key on detail routes ("id" by default).
    lookup_field: String,
    /// Pagination settings exposed via `PaginatedViewSet`; `None` disables pagination.
    pagination_config: Option<PaginationConfig>,
    /// Filter settings exposed via `FilterableViewSet`.
    filter_config: Option<FilterConfig>,
    /// Ordering settings exposed via `FilterableViewSet`.
    ordering_config: Option<OrderingConfig>,
    /// Performs the list/retrieve operations for `dispatch`.
    handler: ModelViewSetHandler<M>,
    /// Ties the serializer type parameter to the struct without storing a value.
    _serializer: PhantomData<S>,
}
impl<M, S> ReadOnlyModelViewSet<M, S>
where
M: Model + Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
S: Send + Sync + 'static,
{
pub fn new(basename: impl Into<String>) -> Self {
Self {
basename: basename.into(),
lookup_field: "id".to_string(),
pagination_config: Some(PaginationConfig::default()),
filter_config: None,
ordering_config: None,
handler: ModelViewSetHandler::<M>::new(),
_serializer: PhantomData,
}
}
pub fn with_lookup_field(mut self, field: impl Into<String>) -> Self {
self.lookup_field = field.into();
self
}
pub fn with_pagination(mut self, config: PaginationConfig) -> Self {
self.pagination_config = Some(config);
self
}
pub fn without_pagination(mut self) -> Self {
self.pagination_config = None;
self
}
pub fn with_filters(mut self, config: FilterConfig) -> Self {
self.filter_config = Some(config);
self
}
pub fn with_ordering(mut self, config: OrderingConfig) -> Self {
self.ordering_config = Some(config);
self
}
pub fn with_pool(mut self, pool: Arc<sqlx::AnyPool>) -> Self {
self.handler = std::mem::take(&mut self.handler).with_pool(pool);
self
}
pub fn with_db_backend(mut self, backend: DbBackend) -> Self {
self.handler = std::mem::take(&mut self.handler).with_db_backend(backend);
self
}
pub fn with_serializer(
mut self,
serializer: Arc<dyn Serializer<Input = M, Output = String> + Send + Sync>,
) -> Self {
self.handler = std::mem::take(&mut self.handler).with_serializer(serializer);
self
}
pub fn with_queryset(mut self, items: Vec<M>) -> Self {
self.handler = std::mem::take(&mut self.handler).with_queryset(items);
self
}
pub fn add_permission(mut self, permission: Arc<dyn Permission>) -> Self {
self.handler = std::mem::take(&mut self.handler).add_permission(permission);
self
}
pub fn add_filter_backend(mut self, backend: Arc<dyn FilterBackend>) -> Self {
self.handler = std::mem::take(&mut self.handler).add_filter_backend(backend);
self
}
pub fn as_view(self) -> crate::viewsets::builder::ViewSetBuilder<Self> {
crate::viewsets::builder::ViewSetBuilder::new(self)
}
}
#[async_trait]
impl<M, S> ViewSet for ReadOnlyModelViewSet<M, S>
where
    M: Model + Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
    S: Send + Sync + 'static,
{
    fn get_basename(&self) -> &str {
        self.basename.as_str()
    }

    fn get_lookup_field(&self) -> &str {
        self.lookup_field.as_str()
    }

    /// Serves GET only: collection routes go to `list`, detail routes go to
    /// `retrieve`. Every other method yields `MethodNotAllowed` before any
    /// primary-key extraction is attempted.
    async fn dispatch(&self, request: Request, action: Action) -> Result<Response> {
        if !matches!(request.method, Method::GET) {
            return Err(method_not_allowed(&request.method));
        }
        if action.detail {
            let pk = extract_pk(&request, &self.lookup_field)?;
            self.handler.retrieve(&request, pk).await.map_err(Into::into)
        } else {
            self.handler.list(&request).await.map_err(Into::into)
        }
    }
}
impl<M, S> PaginatedViewSet for ReadOnlyModelViewSet<M, S>
where
    M: Model + Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
    S: Send + Sync + 'static,
{
    /// Returns a copy of the pagination settings; `None` means pagination is disabled.
    fn get_pagination_config(&self) -> Option<PaginationConfig> {
        Option::clone(&self.pagination_config)
    }
}
impl<M, S> FilterableViewSet for ReadOnlyModelViewSet<M, S>
where
    M: Model + Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
    S: Send + Sync + 'static,
{
    /// Returns a copy of the configured filter settings, if any.
    fn get_filter_config(&self) -> Option<FilterConfig> {
        Option::clone(&self.filter_config)
    }

    /// Returns a copy of the configured ordering settings, if any.
    fn get_ordering_config(&self) -> Option<OrderingConfig> {
        Option::clone(&self.ordering_config)
    }
}
// Manual unwind-safety assertions for the model viewsets.
//
// NOTE(review): presumably these exist because the handler receives
// `Arc<dyn Permission>`, `Arc<dyn FilterBackend>` and `Arc<dyn Serializer>`
// trait objects (see the `add_permission` / `add_filter_backend` /
// `with_serializer` builders). Such trait objects are not automatically
// `RefUnwindSafe`, so the auto-derived impls would not apply. These impls
// assert, without compiler checking, that observing a viewset after a caught
// panic cannot expose broken invariants — confirm that no interior-mutable
// state in the handler violates this.
impl<M, S> std::panic::UnwindSafe for ModelViewSet<M, S>
where
    M: Model + Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
    S: Send + Sync + 'static,
{
}
impl<M, S> std::panic::RefUnwindSafe for ModelViewSet<M, S>
where
    M: Model + Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
    S: Send + Sync + 'static,
{
}
// Same assertions for the read-only variant, which has an identical field layout.
impl<M, S> std::panic::UnwindSafe for ReadOnlyModelViewSet<M, S>
where
    M: Model + Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
    S: Send + Sync + 'static,
{
}
impl<M, S> std::panic::RefUnwindSafe for ReadOnlyModelViewSet<M, S>
where
    M: Model + Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
    S: Send + Sync + 'static,
{
}
#[cfg(test)]
mod tests {
    use super::*;
    use hyper::Method;
    use reinhardt_db::orm::{FieldSelector, Model};
    use serde::{Deserialize, Serialize};
    use std::collections::HashMap;

    /// Minimal model used to instantiate the generic viewsets under test.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct DummyModel {
        id: Option<i64>,
    }

    /// Field-selector stub required by the `Model` trait.
    #[derive(Clone)]
    struct DummyFields;

    impl FieldSelector for DummyFields {
        fn with_alias(self, _alias: &str) -> Self {
            self
        }
    }

    impl Model for DummyModel {
        type PrimaryKey = i64;
        type Fields = DummyFields;
        fn table_name() -> &'static str {
            "dummy"
        }
        fn primary_key(&self) -> Option<Self::PrimaryKey> {
            self.id
        }
        fn set_primary_key(&mut self, value: Self::PrimaryKey) {
            self.id = Some(value);
        }
        fn new_fields() -> Self::Fields {
            DummyFields
        }
    }

    /// Builds a single-entry action map that routes GET to `list`.
    fn list_actions() -> HashMap<Method, String> {
        let mut actions = HashMap::new();
        actions.insert(Method::GET, "list".to_string());
        actions
    }

    #[tokio::test]
    async fn test_viewset_builder_validation_empty_actions() {
        let viewset = ModelViewSet::<DummyModel, ()>::new("test");
        let builder = viewset.as_view();
        let result = builder.build();
        match result {
            Err(e) => assert!(
                e.to_string()
                    .contains("The `actions` argument must be provided")
            ),
            Ok(_) => panic!("Expected error but got success"),
        }
    }

    #[tokio::test]
    async fn test_viewset_builder_name_suffix_mutual_exclusivity() {
        let viewset = ModelViewSet::<DummyModel, ()>::new("test");
        let builder = viewset.as_view();
        let result = builder
            .with_name("test_name")
            .and_then(|b| b.with_suffix("test_suffix"));
        match result {
            Err(e) => assert!(e.to_string().contains("received both `name` and `suffix`")),
            Ok(_) => panic!("Expected error but got success"),
        }
    }

    #[tokio::test]
    async fn test_viewset_builder_successful_build() {
        let viewset = ModelViewSet::<DummyModel, ()>::new("test");
        let builder = viewset.as_view();
        let result = builder.with_actions(list_actions()).build();
        // The previous `Arc::strong_count(&handler) > 0` check was always true;
        // assert the build itself succeeded and surface the error if not.
        if let Err(e) = result {
            panic!("Expected successful build but got error: {}", e);
        }
    }

    #[tokio::test]
    async fn test_viewset_builder_with_name() {
        let viewset = ModelViewSet::<DummyModel, ()>::new("test");
        let builder = viewset.as_view();
        let result = builder
            .with_actions(list_actions())
            .with_name("test_view")
            .and_then(|b| b.build());
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_viewset_builder_with_suffix() {
        let viewset = ModelViewSet::<DummyModel, ()>::new("test");
        let builder = viewset.as_view();
        let result = builder
            .with_actions(list_actions())
            .with_suffix("_list")
            .and_then(|b| b.build());
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_model_viewset_lookup_field_override() {
        let default_viewset = ModelViewSet::<DummyModel, ()>::new("test");
        assert_eq!(default_viewset.get_basename(), "test");
        assert_eq!(default_viewset.get_lookup_field(), "id");

        let custom = ModelViewSet::<DummyModel, ()>::new("test").with_lookup_field("slug");
        assert_eq!(custom.get_lookup_field(), "slug");
    }

    #[tokio::test]
    async fn test_pagination_defaults_and_opt_out() {
        let paginated = ModelViewSet::<DummyModel, ()>::new("test");
        assert!(paginated.get_pagination_config().is_some());

        let unpaginated =
            ReadOnlyModelViewSet::<DummyModel, ()>::new("test").without_pagination();
        assert!(unpaginated.get_pagination_config().is_none());
    }
}