#![allow(deprecated)]
use super::error::ViewError;
use reinhardt_auth::{Permission, PermissionContext};
use reinhardt_db::orm::{Model, query_types::DbBackend};
use reinhardt_http::{Request, Response};
use reinhardt_rest::filters::FilterBackend;
use reinhardt_rest::serializers::{ModelSerializer, Serializer};
use serde::Serialize;
use serde::de::DeserializeOwned;
use std::marker::PhantomData;
use std::sync::Arc;
/// Request handler implementing list/retrieve/create/update/destroy for a
/// single model type `T`.
///
/// Data comes either from a database pool (when one is configured) or from an
/// in-memory queryset; see the individual CRUD methods for which source is used.
pub struct ModelViewSetHandler<T>
where
T: Model + Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
{
/// In-memory items served when no pool is configured; `None` behaves like an
/// empty list.
queryset: Option<Vec<T>>,
/// Custom serializer; when `None`, a `ModelSerializer<T>` is created on demand.
serializer_class: Option<Arc<dyn Serializer<Input = T, Output = String> + Send + Sync>>,
/// Permission classes evaluated in insertion order before every action.
permission_classes: Vec<Arc<dyn Permission>>,
/// Registered filter backends. NOTE(review): stored by `add_filter_backend`
/// but not read by any method visible in this file section.
filter_backends: Vec<Arc<dyn FilterBackend>>,
/// Configured paginator. NOTE(review): stored by `with_pagination` but not
/// read by any method visible in this file section.
pagination_class: Option<reinhardt_core::pagination::PaginatorImpl>,
/// Database pool; when set, CRUD actions go through a database session
/// instead of the in-memory queryset.
pool: Option<Arc<sqlx::AnyPool>>,
/// Backend passed to every session that is opened; defaults to Postgres in `new`.
db_backend: DbBackend,
/// Marker tying the handler to `T`.
_phantom: PhantomData<T>,
}
impl<T> ModelViewSetHandler<T>
where
T: Model + Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
{
pub fn new() -> Self {
Self {
queryset: None,
serializer_class: None,
permission_classes: Vec::new(),
filter_backends: Vec::new(),
pagination_class: None,
pool: None,
db_backend: DbBackend::Postgres, _phantom: PhantomData,
}
}
/// Sets the in-memory items served when no database pool is configured.
pub fn with_queryset(self, queryset: Vec<T>) -> Self {
    let mut handler = self;
    handler.queryset = Some(queryset);
    handler
}
/// Replaces the default serializer with `serializer`.
pub fn with_serializer(
    self,
    serializer: Arc<dyn Serializer<Input = T, Output = String> + Send + Sync>,
) -> Self {
    let mut handler = self;
    handler.serializer_class = Some(serializer);
    handler
}
/// Configures the database pool; once set, CRUD actions use the database
/// rather than the in-memory queryset.
pub fn with_pool(self, pool: Arc<sqlx::AnyPool>) -> Self {
    let mut handler = self;
    handler.pool = Some(pool);
    handler
}
/// Overrides the database backend handed to every session this handler opens.
pub fn with_db_backend(self, db_backend: DbBackend) -> Self {
    let mut handler = self;
    handler.db_backend = db_backend;
    handler
}
/// Appends a permission class; permissions are evaluated in the order added.
pub fn add_permission(self, permission: Arc<dyn Permission>) -> Self {
    let mut handler = self;
    handler.permission_classes.push(permission);
    handler
}
/// Appends a filter backend to the handler's list of backends.
pub fn add_filter_backend(self, backend: Arc<dyn FilterBackend>) -> Self {
    let mut handler = self;
    handler.filter_backends.push(backend);
    handler
}
/// Stores the paginator to associate with this handler.
pub fn with_pagination(
    self,
    pagination: reinhardt_core::pagination::PaginatorImpl,
) -> Self {
    let mut handler = self;
    handler.pagination_class = Some(pagination);
    handler
}
/// Returns the in-memory queryset, or an empty slice when none was configured.
fn get_queryset(&self) -> &[T] {
    match &self.queryset {
        Some(items) => items.as_slice(),
        None => &[],
    }
}
/// Returns the configured serializer, or a fresh `ModelSerializer<T>` when
/// no custom serializer was set.
fn get_serializer(&self) -> Arc<dyn Serializer<Input = T, Output = String> + Send + Sync> {
    if let Some(custom) = &self.serializer_class {
        return Arc::clone(custom);
    }
    Arc::new(ModelSerializer::<T>::new())
}
/// Runs every configured permission class against `request`.
///
/// The user id is taken from the request extensions: a `String` value if
/// present, otherwise a `uuid::Uuid` rendered as a string. The request counts
/// as authenticated whenever either is present.
///
/// When a pool is configured and the `argon2-hasher` feature is enabled, the
/// id is parsed as a UUID and the matching `DefaultUser` row is loaded to
/// obtain the admin/active flags and the user object. Every failure on that
/// path (unparseable id, no connection, query error, no row, undecodable row)
/// falls back to `is_admin = false`, `is_active = true` and no user object.
/// The same fallback applies when the feature is disabled or when there is no
/// user id or pool.
///
/// # Errors
///
/// Returns `ViewError::Permission` for the first permission class whose
/// `has_permission` returns `false`.
///
/// NOTE(review): `std::any::type_name_of_val(&**permission)` is evaluated on a
/// `dyn Permission`, so the message reports the trait-object type rather than
/// the concrete permission class — confirm whether that is intended.
async fn check_permissions(&self, request: &Request) -> std::result::Result<(), ViewError> {
// Prefer a `String` user id; fall back to a `Uuid` converted to its string form.
let user_id_string: Option<String> = request.extensions.get::<String>().or_else(|| {
request
.extensions
.get::<uuid::Uuid>()
.map(|id| id.to_string())
});
let is_authenticated = user_id_string.is_some();
// The pool only gates whether a lookup is attempted; the connection itself
// comes from `get_connection()` below, not from this pool.
let (is_admin, is_active, user_obj) = if let (Some(user_id_str), Some(_pool)) =
(user_id_string.as_ref(), self.pool.as_ref())
{
// With the feature enabled this `match` is the value of the block.
#[cfg(feature = "argon2-hasher")]
match uuid::Uuid::parse_str(user_id_str) {
Ok(user_uuid) => {
use reinhardt_db::orm::manager::get_connection;
match get_connection().await {
Ok(conn) => {
use reinhardt_auth::DefaultUser;
use reinhardt_db::orm::{
Alias, ColumnRef, DatabaseBackend, Expr, ExprTrait, Model,
MySqlQueryBuilder, PostgresQueryBuilder, Query,
QueryStatementBuilder, SqliteQueryBuilder,
};
let table_name = DefaultUser::table_name();
let pk_field = DefaultUser::primary_key_field();
// SELECT * FROM <user table> WHERE <pk> = <uuid as string>.
// The value is inlined into the SQL text; it is the string form
// of an already-parsed UUID, not raw request input.
let stmt = Query::select()
.column(ColumnRef::Asterisk)
.from(Alias::new(table_name))
.and_where(
Expr::col(Alias::new(pk_field))
.eq(Expr::value(user_uuid.to_string())),
)
.to_owned();
// Render the statement with the builder matching the connection's backend.
let sql = match conn.backend() {
DatabaseBackend::Postgres => stmt.to_string(PostgresQueryBuilder),
DatabaseBackend::MySql => stmt.to_string(MySqlQueryBuilder),
DatabaseBackend::Sqlite => stmt.to_string(SqliteQueryBuilder),
};
match conn.query_optional(&sql, vec![]).await {
Ok(Some(row)) => {
// Decode the row into a `DefaultUser` and read its flags.
match serde_json::from_value::<DefaultUser>(row.data) {
Ok(user) => {
use reinhardt_auth::User;
let is_admin = user.is_admin();
let is_active = user.is_active();
let boxed_user: Box<dyn User> = Box::new(user);
(is_admin, is_active, Some(boxed_user))
}
// Row could not be decoded: use the fallback flags.
Err(_) => {
(false, true, None)
}
}
}
// No user row with this id: use the fallback flags.
Ok(None) => {
(false, true, None)
}
// Query failed: use the fallback flags.
Err(_) => {
(false, true, None)
}
}
}
// No database connection available: use the fallback flags.
Err(_) => {
(false, true, None)
}
}
}
// The user id is not a valid UUID: use the fallback flags.
Err(_) => {
(false, true, None)
}
}
// Without the feature no lookup is performed; the binding is touched only
// so it does not count as unused.
#[cfg(not(feature = "argon2-hasher"))]
{
let _ = user_id_str; (false, true, None)
}
} else {
(false, true, None)
};
let context = PermissionContext {
request,
is_authenticated,
is_admin,
is_active,
user: user_obj,
};
// Stop at the first permission class that denies the request.
for permission in &self.permission_classes {
if !permission.has_permission(&context).await {
return Err(ViewError::Permission(format!(
"Permission denied by {}",
std::any::type_name_of_val(&**permission)
)));
}
}
Ok(())
}
/// Returns every object as a JSON array.
///
/// Objects are read through a database session when a pool is configured,
/// otherwise from the in-memory queryset. Each object is serialized on its
/// own and the results are joined with commas inside `[` and `]`.
///
/// # Errors
///
/// * `ViewError::Permission` when a permission class denies the request.
/// * `ViewError::DatabaseError` when the session cannot be created or the
///   listing fails.
/// * `ViewError::Serialization` for the first object that fails to serialize.
pub async fn list(&self, request: &Request) -> std::result::Result<Response, ViewError> {
    self.check_permissions(request).await?;
    let serializer = self.get_serializer();

    let records: Vec<T> = match &self.pool {
        Some(pool) => {
            let session = reinhardt_db::prelude::Session::new(pool.clone(), self.db_backend)
                .await
                .map_err(|e| {
                    ViewError::DatabaseError(format!("Failed to create session: {}", e))
                })?;
            session
                .list_all()
                .await
                .map_err(|e| ViewError::DatabaseError(format!("Failed to list objects: {}", e)))?
        }
        None => self.get_queryset().to_vec(),
    };

    // Collecting into `Result` stops at the first serialization failure.
    let encoded = records
        .iter()
        .map(|record| {
            serializer
                .serialize(record)
                .map_err(|e| ViewError::Serialization(e.to_string()))
        })
        .collect::<std::result::Result<Vec<String>, ViewError>>()?;

    Ok(Response::ok().with_body(format!("[{}]", encoded.join(","))))
}
/// Returns the serialized object whose primary key matches `pk`.
///
/// `pk` may be a JSON string or any other JSON value. It is compared with
/// each object's primary key rendered via `to_string()`. Objects are read
/// through a database session when a pool is configured, otherwise from the
/// in-memory queryset.
///
/// # Errors
///
/// * `ViewError::Permission` when a permission class denies the request.
/// * `ViewError::DatabaseError` when the session cannot be created or the
///   query fails.
/// * `ViewError::NotFound` when no object has the requested primary key.
/// * `ViewError::Serialization` when the object cannot be serialized.
pub async fn retrieve(
    &self,
    request: &Request,
    pk: serde_json::Value,
) -> std::result::Result<Response, ViewError> {
    self.check_permissions(request).await?;
    let serializer = self.get_serializer();

    // For JSON strings use the raw string content. Rendering the value as
    // JSON and trimming quotes would keep escape sequences (e.g. `\"`, `\\`)
    // and strip quotes that belong to the key, so such keys never matched.
    let pk_text = match pk.as_str() {
        Some(raw) => raw.to_owned(),
        None => pk.to_string(),
    };

    let found: Option<T> = if let Some(pool) = &self.pool {
        let session = reinhardt_db::prelude::Session::new(pool.clone(), self.db_backend)
            .await
            .map_err(|e| {
                ViewError::DatabaseError(format!("Failed to create session: {}", e))
            })?;
        // NOTE(review): this loads every row and filters in memory; a keyed
        // lookup would be cheaper if the session API offers one.
        let items: Vec<T> = session
            .list_all()
            .await
            .map_err(|e| ViewError::DatabaseError(format!("Failed to query objects: {}", e)))?;
        items.into_iter().find(|item| {
            item.primary_key()
                .is_some_and(|key| key.to_string() == pk_text)
        })
    } else {
        self.get_queryset()
            .iter()
            .find(|item| {
                item.primary_key()
                    .is_some_and(|key| key.to_string() == pk_text)
            })
            .cloned()
    };

    let item =
        found.ok_or_else(|| ViewError::NotFound(format!("Object with pk={} not found", pk)))?;
    let json = serializer
        .serialize(&item)
        .map_err(|e| ViewError::Serialization(e.to_string()))?;
    Ok(Response::ok().with_body(json))
}
/// Creates an object from the request body and returns it with a "created"
/// response.
///
/// The body must be UTF-8 and is deserialized with the handler's serializer.
/// With a pool configured, the object is added, flushed and committed inside
/// a transaction. If the session reports a generated id, the stored object is
/// read back and returned; otherwise the deserialized input is returned.
/// Without a pool nothing is persisted and the deserialized input is echoed.
///
/// # Errors
///
/// * `ViewError::Permission` when a permission class denies the request.
/// * `ViewError::BadRequest` when the body is not valid UTF-8.
/// * `ViewError::Serialization` when the body cannot be deserialized or the
///   result cannot be serialized.
/// * `ViewError::DatabaseError` for any session, transaction or read-back
///   failure.
pub async fn create(&self, request: &Request) -> std::result::Result<Response, ViewError> {
    self.check_permissions(request).await?;
    let serializer = self.get_serializer();

    let payload = String::from_utf8(request.body().to_vec())
        .map_err(|e| ViewError::BadRequest(format!("Invalid UTF-8: {}", e)))?;
    let item = serializer
        .deserialize(&payload)
        .map_err(|e| ViewError::Serialization(e.to_string()))?;

    // Holds the object as stored in the database, when it could be read back.
    let mut stored: Option<T> = None;

    if let Some(pool) = &self.pool {
        let mut session = reinhardt_db::prelude::Session::new(pool.clone(), self.db_backend)
            .await
            .map_err(|e| {
                ViewError::DatabaseError(format!("Failed to create session: {}", e))
            })?;
        session.begin().await.map_err(|e| {
            ViewError::DatabaseError(format!("Failed to begin transaction: {}", e))
        })?;
        session
            .add(item.clone())
            .await
            .map_err(|e| ViewError::DatabaseError(format!("Failed to add object: {}", e)))?;
        session
            .flush()
            .await
            .map_err(|e| ViewError::DatabaseError(format!("Failed to flush: {}", e)))?;
        // Capture the generated id after the flush and before the commit.
        let generated_id = session.get_generated_ids().first().map(|(_, id)| *id);
        session
            .commit()
            .await
            .map_err(|e| ViewError::DatabaseError(format!("Failed to commit: {}", e)))?;

        if let Some(id) = generated_id {
            let id_text = id.to_string();
            let fetch_session =
                reinhardt_db::prelude::Session::new(pool.clone(), self.db_backend)
                    .await
                    .map_err(|e| {
                        ViewError::DatabaseError(format!("Failed to create session: {}", e))
                    })?;
            let rows: Vec<T> = fetch_session.list_all().await.map_err(|e| {
                ViewError::DatabaseError(format!("Failed to fetch objects: {}", e))
            })?;
            let created = rows
                .into_iter()
                .find(|row| {
                    row.primary_key()
                        .is_some_and(|key| key.to_string() == id_text)
                })
                .ok_or_else(|| {
                    ViewError::DatabaseError("Failed to find created object".to_string())
                })?;
            stored = Some(created);
        }
    }

    // Prefer the stored object; fall back to the deserialized input.
    let rendered = serializer
        .serialize(stored.as_ref().unwrap_or(&item))
        .map_err(|e| ViewError::Serialization(e.to_string()))?;
    Ok(Response::created().with_body(rendered))
}
pub async fn update(
&self,
request: &Request,
pk: serde_json::Value,
) -> std::result::Result<Response, ViewError> {
self.check_permissions(request).await?;
let serializer = self.get_serializer();
let existing_obj: T = if let Some(pool) = &self.pool {
let session = reinhardt_db::prelude::Session::new(pool.clone(), self.db_backend)
.await
.map_err(|e| {
ViewError::DatabaseError(format!("Failed to create session: {}", e))
})?;
let items: Vec<T> = session
.list_all()
.await
.map_err(|e| ViewError::DatabaseError(format!("Failed to list objects: {}", e)))?;
let pk_str_owned = pk.to_string();
let pk_str = pk_str_owned.trim_matches('"');
items
.into_iter()
.find(|item| {
if let Some(item_pk) = item.primary_key() {
item_pk.to_string() == pk_str
} else {
false
}
})
.ok_or_else(|| {
ViewError::NotFound(format!("Object with pk {} not found", pk_str))
})?
} else {
let pk_str_owned = pk.to_string();
let pk_str = pk_str_owned.trim_matches('"');
self.get_queryset()
.iter()
.find(|item| {
if let Some(item_pk) = item.primary_key() {
item_pk.to_string() == pk_str
} else {
false
}
})
.cloned()
.ok_or_else(|| {
ViewError::NotFound(format!("Object with pk {} not found", pk_str))
})?
};
let body_str = String::from_utf8(request.body().to_vec())
.map_err(|e| ViewError::BadRequest(format!("Invalid UTF-8: {}", e)))?;
let patch_data: serde_json::Value = serde_json::from_str(&body_str)
.map_err(|e| ViewError::Serialization(format!("Invalid JSON: {}", e)))?;
let existing_json = serializer
.serialize(&existing_obj)
.map_err(|e| ViewError::Serialization(e.to_string()))?;
let mut existing_value: serde_json::Value = serde_json::from_str(&existing_json)
.map_err(|e| ViewError::Serialization(format!("Failed to parse existing: {}", e)))?;
crate::generic::patch_utils::merge_patch_object_into(&mut existing_value, &patch_data)
.map_err(ViewError::BadRequest)?;
let merged_json = serde_json::to_string(&existing_value)
.map_err(|e| ViewError::Serialization(format!("Failed to serialize merged: {}", e)))?;
let updated_item: T = serializer
.deserialize(&merged_json)
.map_err(|e| ViewError::Serialization(e.to_string()))?;
if let Some(pool) = &self.pool {
let mut session = reinhardt_db::prelude::Session::new(pool.clone(), self.db_backend)
.await
.map_err(|e| {
ViewError::DatabaseError(format!("Failed to create session: {}", e))
})?;
session.begin().await.map_err(|e| {
ViewError::DatabaseError(format!("Failed to begin transaction: {}", e))
})?;
session
.add(updated_item.clone())
.await
.map_err(|e| ViewError::DatabaseError(format!("Failed to add object: {}", e)))?;
session
.flush()
.await
.map_err(|e| ViewError::DatabaseError(format!("Failed to flush: {}", e)))?;
session
.commit()
.await
.map_err(|e| ViewError::DatabaseError(format!("Failed to commit: {}", e)))?;
}
Ok(Response::ok().with_body(merged_json))
}
pub async fn destroy(
&self,
request: &Request,
pk: serde_json::Value,
) -> std::result::Result<Response, ViewError> {
self.check_permissions(request).await?;
let serializer = self.get_serializer();
let response = self.retrieve(request, pk).await?;
let body_str = String::from_utf8(response.body.to_vec())
.map_err(|e| ViewError::BadRequest(format!("Invalid UTF-8: {}", e)))?;
let item = serializer
.deserialize(&body_str)
.map_err(|e| ViewError::Serialization(e.to_string()))?;
if let Some(pool) = &self.pool {
let mut session = reinhardt_db::prelude::Session::new(pool.clone(), self.db_backend)
.await
.map_err(|e| {
ViewError::DatabaseError(format!("Failed to create session: {}", e))
})?;
session.begin().await.map_err(|e| {
ViewError::DatabaseError(format!("Failed to begin transaction: {}", e))
})?;
session.delete(item).await.map_err(|e| {
ViewError::DatabaseError(format!("Failed to mark object for deletion: {}", e))
})?;
session
.flush()
.await
.map_err(|e| ViewError::DatabaseError(format!("Failed to flush: {}", e)))?;
session
.commit()
.await
.map_err(|e| ViewError::DatabaseError(format!("Failed to commit: {}", e)))?;
}
Ok(Response::no_content())
}
}
impl<T> Default for ModelViewSetHandler<T>
where
T: Model + Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
{
fn default() -> Self {
Self::new()
}
}
// Unit tests for `ModelViewSetHandler::retrieve` using the in-memory queryset
// (no database pool is configured in any of these tests).
#[cfg(test)]
mod tests {
use super::*;
use bytes::Bytes;
use hyper::{HeaderMap, Method, Version};
use reinhardt_http::Request;
use rstest::rstest;
/// Builds an empty-bodied HTTP/1.1 GET request for `uri` with no headers.
fn build_request(uri: &str) -> Request {
Request::builder()
.method(Method::GET)
.uri(uri)
.version(Version::HTTP_11)
.headers(HeaderMap::new())
.body(Bytes::new())
.build()
.unwrap()
}
/// Minimal model with an optional integer primary key and a name.
#[derive(Debug, Clone, Serialize, serde::Deserialize, PartialEq)]
struct TestItem {
id: Option<i64>,
name: String,
}
/// Field selector for `TestItem`; carries no state.
#[derive(Clone)]
struct TestItemFields;
impl reinhardt_db::orm::FieldSelector for TestItemFields {
// Aliasing is a no-op for this stateless selector.
fn with_alias(self, _alias: &str) -> Self {
self
}
}
impl reinhardt_db::orm::Model for TestItem {
type PrimaryKey = i64;
type Fields = TestItemFields;
fn table_name() -> &'static str {
"test_items"
}
fn primary_key(&self) -> Option<Self::PrimaryKey> {
self.id
}
fn set_primary_key(&mut self, value: Self::PrimaryKey) {
self.id = Some(value);
}
fn new_fields() -> Self::Fields {
TestItemFields
}
}
/// Creates a handler that serves `items` from its in-memory queryset.
fn build_model_handler(items: Vec<TestItem>) -> ModelViewSetHandler<TestItem> {
ModelViewSetHandler::<TestItem>::new().with_queryset(items)
}
// A pk given as the JSON string "1" must match the item whose integer id is 1.
#[rstest]
#[tokio::test]
async fn test_retrieve_strips_quotes_from_numeric_pk() {
let items = vec![
TestItem {
id: Some(1),
name: "first".to_string(),
},
TestItem {
id: Some(2),
name: "second".to_string(),
},
];
let handler = build_model_handler(items);
let request = build_request("/items/1/");
let pk = serde_json::json!("1");
let result = handler.retrieve(&request, pk).await;
assert!(result.is_ok(), "retrieve should succeed with quoted pk");
let response = result.unwrap();
assert_eq!(response.status, hyper::StatusCode::OK);
let body: TestItem =
serde_json::from_slice(&response.body).expect("response should be valid JSON");
assert_eq!(body.name, "first");
assert_eq!(body.id, Some(1));
}
// A pk given as the JSON number 42 must match the item whose integer id is 42.
#[rstest]
#[tokio::test]
async fn test_retrieve_works_with_unquoted_numeric_pk() {
let items = vec![TestItem {
id: Some(42),
name: "answer".to_string(),
}];
let handler = build_model_handler(items);
let request = build_request("/items/42/");
let pk = serde_json::json!(42);
let result = handler.retrieve(&request, pk).await;
assert!(result.is_ok(), "retrieve should succeed with numeric pk");
let response = result.unwrap();
assert_eq!(response.status, hyper::StatusCode::OK);
let body: TestItem =
serde_json::from_slice(&response.body).expect("response should be valid JSON");
assert_eq!(body.name, "answer");
assert_eq!(body.id, Some(42));
}
// A pk that matches no item must produce `ViewError::NotFound`.
#[rstest]
#[tokio::test]
async fn test_retrieve_returns_not_found_for_nonexistent_pk() {
let items = vec![TestItem {
id: Some(1),
name: "only".to_string(),
}];
let handler = build_model_handler(items);
let request = build_request("/items/999/");
let pk = serde_json::json!(999);
let result = handler.retrieve(&request, pk).await;
assert!(result.is_err(), "retrieve should fail for nonexistent pk");
let err = result.unwrap_err();
assert!(
matches!(err, ViewError::NotFound(_)),
"error should be NotFound, got: {:?}",
err
);
}
}