reinhardt_views/viewsets/handler/model_view_set_handler.rs
1// The `User` trait and `DefaultUser` struct are deprecated in favour of the new
2// `#[model]`-based user macro system. This file references them during the
3// transition period until viewsets are migrated to `AuthIdentity`.
4#![allow(deprecated)]
5
6//! `ModelViewSetHandler` — Django REST Framework-style CRUD handler.
7//!
8//! Provides the standard list/retrieve/create/update/destroy actions with
9//! permission checks, optional pagination, and serialization for `Model`
10//! types. The response rendering for each action lives next to the action
11//! itself in this module.
12
13use super::error::ViewError;
14use reinhardt_auth::{Permission, PermissionContext};
15use reinhardt_db::orm::{Model, query_types::DbBackend};
16use reinhardt_http::{Request, Response};
17use reinhardt_rest::filters::FilterBackend;
18use reinhardt_rest::serializers::{ModelSerializer, Serializer};
19use serde::Serialize;
20use serde::de::DeserializeOwned;
21use std::marker::PhantomData;
22use std::sync::Arc;
23
24/// Django REST Framework-style ViewSet handler for models.
25///
26/// Provides automatic CRUD operations with permission checks, filtering,
27/// pagination, and serialization for Model types.
28///
29/// # Examples
30///
31/// ```no_run
32/// # use reinhardt_views::viewsets::ModelViewSetHandler;
33/// # use reinhardt_db::orm::Model;
34/// # use serde::{Serialize, Deserialize};
35/// #
36/// # #[derive(Serialize, Deserialize, Clone, Debug)]
37/// # struct User {
38/// # id: Option<i64>,
39/// # username: String,
40/// # }
41/// #
42/// # #[derive(Clone)]
43/// # struct UserFields;
44/// #
45/// # impl reinhardt_db::orm::FieldSelector for UserFields {
46/// # fn with_alias(self, _alias: &str) -> Self { self }
47/// # }
48/// #
49/// # impl Model for User {
50/// # type PrimaryKey = i64;
51/// # type Fields = UserFields;
52/// # fn table_name() -> &'static str { "users" }
53/// # fn primary_key(&self) -> Option<Self::PrimaryKey> { self.id }
54/// # fn set_primary_key(&mut self, value: Self::PrimaryKey) { self.id = Some(value); }
55/// # fn new_fields() -> Self::Fields { UserFields }
56/// # }
57/// #
58/// # async fn example() {
59/// let handler = ModelViewSetHandler::<User>::new();
60/// # }
61/// ```
pub struct ModelViewSetHandler<T>
where
    T: Model + Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
{
    /// In-memory objects used by the actions when no database pool is configured
    /// (set via `with_queryset`); `None` behaves like an empty queryset.
    queryset: Option<Vec<T>>,
    /// Serializer used by every action; when unset, a `ModelSerializer<T>` is built on demand.
    serializer_class: Option<Arc<dyn Serializer<Input = T, Output = String> + Send + Sync>>,
    /// Permission classes evaluated in registration order before each action; the first
    /// denial aborts the request.
    permission_classes: Vec<Arc<dyn Permission>>,
    /// Registered filter backends.
    /// NOTE(review): the `list` action in this module does not consult these — confirm
    /// whether filtering is applied elsewhere.
    filter_backends: Vec<Arc<dyn FilterBackend>>,
    /// Optional paginator.
    /// NOTE(review): the `list` action in this module does not apply it — confirm.
    pagination_class: Option<reinhardt_core::pagination::PaginatorImpl>,
    /// Database pool; when set, list/retrieve/create/update go to the database instead of
    /// `queryset`, and the permission check may load the current user.
    pool: Option<Arc<sqlx::AnyPool>>,
    /// Database backend type (default: PostgreSQL)
    db_backend: DbBackend,
    /// Marker tying the handler to the model type `T`.
    _phantom: PhantomData<T>,
}
76
77impl<T> ModelViewSetHandler<T>
78where
79 T: Model + Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
80{
81 /// Create a new ModelViewSetHandler
82 ///
83 /// # Examples
84 ///
85 /// ```
86 /// # use reinhardt_views::viewsets::ModelViewSetHandler;
87 /// # use reinhardt_db::orm::Model;
88 /// # use serde::{Serialize, Deserialize};
89 /// #
90 /// # #[derive(Debug, Clone, Serialize, Deserialize)]
91 /// # struct User {
92 /// # id: Option<i64>,
93 /// # username: String,
94 /// # }
95 /// #
96 /// # #[derive(Clone)]
97 /// # struct UserFields;
98 /// #
99 /// # impl reinhardt_db::orm::FieldSelector for UserFields {
100 /// # fn with_alias(self, _alias: &str) -> Self { self }
101 /// # }
102 /// #
103 /// # impl Model for User {
104 /// # type PrimaryKey = i64;
105 /// # type Fields = UserFields;
106 /// # fn table_name() -> &'static str { "users" }
107 /// # fn primary_key(&self) -> Option<Self::PrimaryKey> { self.id }
108 /// # fn set_primary_key(&mut self, value: Self::PrimaryKey) { self.id = Some(value); }
109 /// # fn new_fields() -> Self::Fields { UserFields }
110 /// # }
111 /// let handler = ModelViewSetHandler::<User>::new();
112 /// ```
113 pub fn new() -> Self {
114 Self {
115 queryset: None,
116 serializer_class: None,
117 permission_classes: Vec::new(),
118 filter_backends: Vec::new(),
119 pagination_class: None,
120 pool: None,
121 db_backend: DbBackend::Postgres, // Default to PostgreSQL
122 _phantom: PhantomData,
123 }
124 }
125
126 /// Set the queryset (in-memory data) for this handler
127 ///
128 /// # Examples
129 ///
130 /// ```
131 /// # use reinhardt_views::viewsets::ModelViewSetHandler;
132 /// # use reinhardt_db::orm::Model;
133 /// # use serde::{Serialize, Deserialize};
134 /// #
135 /// # #[derive(Debug, Clone, Serialize, Deserialize)]
136 /// # struct User {
137 /// # id: Option<i64>,
138 /// # username: String,
139 /// # }
140 /// #
141 /// # #[derive(Clone)]
142 /// # struct UserFields;
143 /// #
144 /// # impl reinhardt_db::orm::FieldSelector for UserFields {
145 /// # fn with_alias(self, _alias: &str) -> Self { self }
146 /// # }
147 /// #
148 /// # impl Model for User {
149 /// # type PrimaryKey = i64;
150 /// # type Fields = UserFields;
151 /// # fn table_name() -> &'static str { "users" }
152 /// # fn primary_key(&self) -> Option<Self::PrimaryKey> { self.id }
153 /// # fn set_primary_key(&mut self, value: Self::PrimaryKey) { self.id = Some(value); }
154 /// # fn new_fields() -> Self::Fields { UserFields }
155 /// # }
156 /// let users = vec![
157 /// User { id: Some(1), username: "alice".to_string() },
158 /// User { id: Some(2), username: "bob".to_string() },
159 /// ];
160 /// let handler = ModelViewSetHandler::<User>::new()
161 /// .with_queryset(users);
162 /// ```
163 pub fn with_queryset(mut self, queryset: Vec<T>) -> Self {
164 self.queryset = Some(queryset);
165 self
166 }
167
168 /// Set the serializer class for this handler
169 ///
170 /// # Examples
171 ///
172 /// ```
173 /// # use reinhardt_views::viewsets::ModelViewSetHandler;
174 /// # use reinhardt_rest::serializers::ModelSerializer;
175 /// # use reinhardt_db::orm::Model;
176 /// # use serde::{Serialize, Deserialize};
177 /// # use std::sync::Arc;
178 /// #
179 /// # #[derive(Debug, Clone, Serialize, Deserialize)]
180 /// # struct User {
181 /// # id: Option<i64>,
182 /// # username: String,
183 /// # }
184 /// #
185 /// # #[derive(Clone)]
186 /// # struct UserFields;
187 /// #
188 /// # impl reinhardt_db::orm::FieldSelector for UserFields {
189 /// # fn with_alias(self, _alias: &str) -> Self { self }
190 /// # }
191 /// #
192 /// # impl Model for User {
193 /// # type PrimaryKey = i64;
194 /// # type Fields = UserFields;
195 /// # fn table_name() -> &'static str { "users" }
196 /// # fn primary_key(&self) -> Option<Self::PrimaryKey> { self.id }
197 /// # fn set_primary_key(&mut self, value: Self::PrimaryKey) { self.id = Some(value); }
198 /// # fn new_fields() -> Self::Fields { UserFields }
199 /// # }
200 /// let serializer = Arc::new(ModelSerializer::<User>::new());
201 /// let handler = ModelViewSetHandler::<User>::new()
202 /// .with_serializer(serializer);
203 /// ```
204 pub fn with_serializer(
205 mut self,
206 serializer: Arc<dyn Serializer<Input = T, Output = String> + Send + Sync>,
207 ) -> Self {
208 self.serializer_class = Some(serializer);
209 self
210 }
211
212 /// Set the database connection pool for this handler
213 ///
214 /// # Examples
215 ///
216 /// ```no_run
217 /// # use reinhardt_views::viewsets::ModelViewSetHandler;
218 /// # use reinhardt_db::orm::Model;
219 /// # use serde::{Serialize, Deserialize};
220 /// # use sqlx::AnyPool;
221 /// # use std::sync::Arc;
222 /// #
223 /// # #[derive(Debug, Clone, Serialize, Deserialize)]
224 /// # struct User {
225 /// # id: Option<i64>,
226 /// # username: String,
227 /// # }
228 /// #
229 /// # #[derive(Clone)]
230 /// # struct UserFields;
231 /// #
232 /// # impl reinhardt_db::orm::FieldSelector for UserFields {
233 /// # fn with_alias(self, _alias: &str) -> Self { self }
234 /// # }
235 /// #
236 /// # impl Model for User {
237 /// # type PrimaryKey = i64;
238 /// # type Fields = UserFields;
239 /// # fn table_name() -> &'static str { "users" }
240 /// # fn primary_key(&self) -> Option<Self::PrimaryKey> { self.id }
241 /// # fn set_primary_key(&mut self, value: Self::PrimaryKey) { self.id = Some(value); }
242 /// # fn new_fields() -> Self::Fields { UserFields }
243 /// # }
244 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
245 /// let pool = Arc::new(AnyPool::connect("postgres://localhost/mydb").await?);
246 /// let handler = ModelViewSetHandler::<User>::new()
247 /// .with_pool(pool);
248 /// # Ok(())
249 /// # }
250 /// ```
251 pub fn with_pool(mut self, pool: Arc<sqlx::AnyPool>) -> Self {
252 self.pool = Some(pool);
253 self
254 }
255
256 /// Set the database backend type for this handler
257 ///
258 /// # Examples
259 ///
260 /// ```
261 /// # use reinhardt_views::viewsets::ModelViewSetHandler;
262 /// # use reinhardt_db::orm::{Model, query_types::DbBackend};
263 /// # use serde::{Serialize, Deserialize};
264 /// #
265 /// # #[derive(Debug, Clone, Serialize, Deserialize)]
266 /// # struct User {
267 /// # id: Option<i64>,
268 /// # username: String,
269 /// # }
270 /// #
271 /// # #[derive(Clone)]
272 /// # struct UserFields;
273 /// #
274 /// # impl reinhardt_db::orm::FieldSelector for UserFields {
275 /// # fn with_alias(self, _alias: &str) -> Self { self }
276 /// # }
277 /// #
278 /// # impl Model for User {
279 /// # type PrimaryKey = i64;
280 /// # type Fields = UserFields;
281 /// # fn table_name() -> &'static str { "users" }
282 /// # fn primary_key(&self) -> Option<Self::PrimaryKey> { self.id }
283 /// # fn set_primary_key(&mut self, value: Self::PrimaryKey) { self.id = Some(value); }
284 /// # fn new_fields() -> Self::Fields { UserFields }
285 /// # }
286 /// let handler = ModelViewSetHandler::<User>::new()
287 /// .with_db_backend(DbBackend::Sqlite);
288 /// ```
289 pub fn with_db_backend(mut self, db_backend: DbBackend) -> Self {
290 self.db_backend = db_backend;
291 self
292 }
293
294 /// Add a permission class to this handler
295 ///
296 /// # Examples
297 ///
298 /// ```
299 /// # use reinhardt_views::viewsets::ModelViewSetHandler;
300 /// # use reinhardt_auth::IsAuthenticated;
301 /// # use reinhardt_db::orm::Model;
302 /// # use serde::{Serialize, Deserialize};
303 /// # use std::sync::Arc;
304 /// #
305 /// # #[derive(Debug, Clone, Serialize, Deserialize)]
306 /// # struct User {
307 /// # id: Option<i64>,
308 /// # username: String,
309 /// # }
310 /// #
311 /// # #[derive(Clone)]
312 /// # struct UserFields;
313 /// #
314 /// # impl reinhardt_db::orm::FieldSelector for UserFields {
315 /// # fn with_alias(self, _alias: &str) -> Self { self }
316 /// # }
317 /// #
318 /// # impl Model for User {
319 /// # type PrimaryKey = i64;
320 /// # type Fields = UserFields;
321 /// # fn table_name() -> &'static str { "users" }
322 /// # fn primary_key(&self) -> Option<Self::PrimaryKey> { self.id }
323 /// # fn set_primary_key(&mut self, value: Self::PrimaryKey) { self.id = Some(value); }
324 /// # fn new_fields() -> Self::Fields { UserFields }
325 /// # }
326 /// let handler = ModelViewSetHandler::<User>::new()
327 /// .add_permission(Arc::new(IsAuthenticated));
328 /// ```
329 pub fn add_permission(mut self, permission: Arc<dyn Permission>) -> Self {
330 self.permission_classes.push(permission);
331 self
332 }
333
334 /// Add a filter backend to this handler
335 pub fn add_filter_backend(mut self, backend: Arc<dyn FilterBackend>) -> Self {
336 self.filter_backends.push(backend);
337 self
338 }
339
340 /// Set the pagination class for this handler
341 pub fn with_pagination(
342 mut self,
343 pagination: reinhardt_core::pagination::PaginatorImpl,
344 ) -> Self {
345 self.pagination_class = Some(pagination);
346 self
347 }
348
349 /// Get the queryset for this handler
350 fn get_queryset(&self) -> &[T] {
351 self.queryset.as_deref().unwrap_or(&[])
352 }
353
354 /// Get the serializer for this handler
355 fn get_serializer(&self) -> Arc<dyn Serializer<Input = T, Output = String> + Send + Sync> {
356 self.serializer_class
357 .clone()
358 .unwrap_or_else(|| Arc::new(ModelSerializer::<T>::new()))
359 }
360
361 /// Check permissions for the request
362 async fn check_permissions(&self, request: &Request) -> std::result::Result<(), ViewError> {
363 // Extract authentication information from request extensions
364 // The session middleware stores authenticated user_id in extensions
365 //
366 // Expected usage:
367 // 1. Session middleware extracts session from cookie/token
368 // 2. Middleware validates session and extracts user_id
369 // 3. Middleware stores user_id in request.extensions using a dedicated type
370 //
371 // Example middleware implementation:
372 // if let Some(user_id) = session.get::<i64>("user_id").ok().flatten() {
373 // request.extensions.insert(AuthenticatedUserId(user_id));
374 // }
375
376 // Try to extract user_id from extensions
377 // Support both String and UUID formats
378 let user_id_string: Option<String> = request.extensions.get::<String>().or_else(|| {
379 request
380 .extensions
381 .get::<uuid::Uuid>()
382 .map(|id| id.to_string())
383 });
384
385 // Determine authentication status based on user_id presence
386 let is_authenticated = user_id_string.is_some();
387
388 // Load user from database if authenticated and pool is available
389 let (is_admin, is_active, user_obj) = if let (Some(user_id_str), Some(_pool)) =
390 (user_id_string.as_ref(), self.pool.as_ref())
391 {
392 // Parse user_id as UUID
393 #[cfg(feature = "argon2-hasher")]
394 match uuid::Uuid::parse_str(user_id_str) {
395 Ok(user_uuid) => {
396 // Get database connection
397 use reinhardt_db::orm::manager::get_connection;
398 match get_connection().await {
399 Ok(conn) => {
400 // Build SQL query using reinhardt-query for type-safe query construction
401 use reinhardt_auth::DefaultUser;
402 use reinhardt_db::orm::{
403 Alias, ColumnRef, DatabaseBackend, Expr, ExprTrait, Model,
404 MySqlQueryBuilder, PostgresQueryBuilder, Query,
405 QueryStatementBuilder, SqliteQueryBuilder,
406 };
407
408 let table_name = DefaultUser::table_name();
409 let pk_field = DefaultUser::primary_key_field();
410
411 // Build SELECT * query using reinhardt-query
412 let stmt = Query::select()
413 .column(ColumnRef::Asterisk)
414 .from(Alias::new(table_name))
415 .and_where(
416 Expr::col(Alias::new(pk_field))
417 .eq(Expr::value(user_uuid.to_string())),
418 )
419 .to_owned();
420
421 let sql = match conn.backend() {
422 DatabaseBackend::Postgres => stmt.to_string(PostgresQueryBuilder),
423 DatabaseBackend::MySql => stmt.to_string(MySqlQueryBuilder),
424 DatabaseBackend::Sqlite => stmt.to_string(SqliteQueryBuilder),
425 };
426
427 match conn.query_optional(&sql, vec![]).await {
428 Ok(Some(row)) => {
429 // Deserialize user from query result
430 match serde_json::from_value::<DefaultUser>(row.data) {
431 Ok(user) => {
432 use reinhardt_auth::User;
433 // Extract admin and active status from loaded user
434 let is_admin = user.is_admin();
435 let is_active = user.is_active();
436 // Box the user object to store in PermissionContext
437 let boxed_user: Box<dyn User> = Box::new(user);
438 (is_admin, is_active, Some(boxed_user))
439 }
440 Err(_) => {
441 // Deserialization failed, use defaults
442 (false, true, None)
443 }
444 }
445 }
446 Ok(None) => {
447 // User not found, use defaults
448 (false, true, None)
449 }
450 Err(_) => {
451 // Database query failed, use defaults
452 (false, true, None)
453 }
454 }
455 }
456 Err(_) => {
457 // Connection failed, use defaults
458 (false, true, None)
459 }
460 }
461 }
462 Err(_) => {
463 // UUID parse failed, use defaults
464 (false, true, None)
465 }
466 }
467
468 // When argon2-hasher feature is disabled, DefaultUser is not available
469 // Return default values to indicate user retrieval is not supported
470 #[cfg(not(feature = "argon2-hasher"))]
471 {
472 let _ = user_id_str; // Suppress unused variable warning
473 (false, true, None)
474 }
475 } else {
476 // Not authenticated or no pool, use defaults
477 (false, true, None)
478 };
479
480 let context = PermissionContext {
481 request,
482 is_authenticated,
483 is_admin,
484 is_active,
485 user: user_obj,
486 };
487
488 // Check all registered permission classes
489 for permission in &self.permission_classes {
490 if !permission.has_permission(&context).await {
491 // Permission denied - return specific error
492 return Err(ViewError::Permission(format!(
493 "Permission denied by {}",
494 std::any::type_name_of_val(&**permission)
495 )));
496 }
497 }
498
499 Ok(())
500 }
501
502 /// List all objects with optional filtering and pagination
503 ///
504 /// # Examples
505 ///
506 /// ```no_run
507 /// # use reinhardt_views::viewsets::ModelViewSetHandler;
508 /// # use reinhardt_http::Request;
509 /// # use reinhardt_db::orm::Model;
510 /// # use serde::{Serialize, Deserialize};
511 /// # use bytes::Bytes;
512 /// # use hyper::{Method, Version, HeaderMap};
513 /// #
514 /// # #[derive(Debug, Clone, Serialize, Deserialize)]
515 /// # struct User {
516 /// # id: Option<i64>,
517 /// # username: String,
518 /// # }
519 /// #
520 /// # #[derive(Clone)]
521 /// # struct UserFields;
522 /// #
523 /// # impl reinhardt_db::orm::FieldSelector for UserFields {
524 /// # fn with_alias(self, _alias: &str) -> Self { self }
525 /// # }
526 /// #
527 /// # impl Model for User {
528 /// # type PrimaryKey = i64;
529 /// # type Fields = UserFields;
530 /// # fn table_name() -> &'static str { "users" }
531 /// # fn primary_key(&self) -> Option<Self::PrimaryKey> { self.id }
532 /// # fn set_primary_key(&mut self, value: Self::PrimaryKey) { self.id = Some(value); }
533 /// # fn new_fields() -> Self::Fields { UserFields }
534 /// # }
535 /// #
536 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
537 /// let handler = ModelViewSetHandler::<User>::new();
538 /// let request = Request::builder()
539 /// .method(Method::GET)
540 /// .uri("/users/")
541 /// .version(Version::HTTP_11)
542 /// .headers(HeaderMap::new())
543 /// .body(Bytes::new())
544 /// .build()?;
545 /// let response = handler.list(&request).await?;
546 /// # Ok(())
547 /// # }
548 /// ```
549 pub async fn list(&self, request: &Request) -> std::result::Result<Response, ViewError> {
550 self.check_permissions(request).await?;
551
552 let serializer = self.get_serializer();
553
554 // Get items from database if pool is available, otherwise use in-memory queryset
555 let items: Vec<T> = if let Some(pool) = &self.pool {
556 // Query database for all objects
557 let session = reinhardt_db::prelude::Session::new(pool.clone(), self.db_backend)
558 .await
559 .map_err(|e| {
560 ViewError::DatabaseError(format!("Failed to create session: {}", e))
561 })?;
562
563 session
564 .list_all()
565 .await
566 .map_err(|e| ViewError::DatabaseError(format!("Failed to list objects: {}", e)))?
567 } else {
568 // Use in-memory queryset
569 self.get_queryset().to_vec()
570 };
571
572 // Serialize all objects
573 let mut serialized_items = Vec::new();
574 for item in &items {
575 let json = serializer
576 .serialize(item)
577 .map_err(|e| ViewError::Serialization(e.to_string()))?;
578 serialized_items.push(json);
579 }
580
581 // Create response body
582 let response_body = format!("[{}]", serialized_items.join(","));
583
584 Ok(Response::ok().with_body(response_body))
585 }
586
587 /// Retrieve a single object by primary key
588 ///
589 /// # Examples
590 ///
591 /// ```no_run
592 /// # use reinhardt_views::viewsets::ModelViewSetHandler;
593 /// # use reinhardt_http::Request;
594 /// # use reinhardt_db::orm::Model;
595 /// # use serde::{Serialize, Deserialize};
596 /// # use serde_json::Value;
597 /// # use bytes::Bytes;
598 /// # use hyper::{Method, Version, HeaderMap};
599 /// #
600 /// # #[derive(Debug, Clone, Serialize, Deserialize)]
601 /// # struct User {
602 /// # id: Option<i64>,
603 /// # username: String,
604 /// # }
605 /// #
606 /// # #[derive(Clone)]
607 /// # struct UserFields;
608 /// #
609 /// # impl reinhardt_db::orm::FieldSelector for UserFields {
610 /// # fn with_alias(self, _alias: &str) -> Self { self }
611 /// # }
612 /// #
613 /// # impl Model for User {
614 /// # type PrimaryKey = i64;
615 /// # type Fields = UserFields;
616 /// # fn table_name() -> &'static str { "users" }
617 /// # fn primary_key(&self) -> Option<Self::PrimaryKey> { self.id }
618 /// # fn set_primary_key(&mut self, value: Self::PrimaryKey) { self.id = Some(value); }
619 /// # fn new_fields() -> Self::Fields { UserFields }
620 /// # }
621 /// #
622 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
623 /// let handler = ModelViewSetHandler::<User>::new();
624 /// let request = Request::builder()
625 /// .method(Method::GET)
626 /// .uri("/users/1/")
627 /// .version(Version::HTTP_11)
628 /// .headers(HeaderMap::new())
629 /// .body(Bytes::new())
630 /// .build()?;
631 /// let pk = serde_json::json!(1);
632 /// let response = handler.retrieve(&request, pk).await?;
633 /// # Ok(())
634 /// # }
635 /// ```
636 pub async fn retrieve(
637 &self,
638 request: &Request,
639 pk: serde_json::Value,
640 ) -> std::result::Result<Response, ViewError> {
641 self.check_permissions(request).await?;
642
643 let serializer = self.get_serializer();
644
645 // Get item from database if pool is available, otherwise use in-memory queryset
646 let item: T = if let Some(pool) = &self.pool {
647 // Query database for all objects and find by pk
648 let session = reinhardt_db::prelude::Session::new(pool.clone(), self.db_backend)
649 .await
650 .map_err(|e| {
651 ViewError::DatabaseError(format!("Failed to create session: {}", e))
652 })?;
653
654 let items: Vec<T> = session
655 .list_all()
656 .await
657 .map_err(|e| ViewError::DatabaseError(format!("Failed to query objects: {}", e)))?;
658
659 // Normalize pk: strip surrounding quotes from JSON string PKs for comparison
660 let pk_str = pk.to_string();
661 let pk_str = pk_str.trim_matches('"');
662
663 items
664 .into_iter()
665 .find(|item| {
666 if let Some(item_pk) = item.primary_key() {
667 item_pk.to_string() == pk_str
668 } else {
669 false
670 }
671 })
672 .ok_or_else(|| ViewError::NotFound(format!("Object with pk={} not found", pk)))?
673 } else {
674 // Use in-memory queryset
675 let queryset = self.get_queryset();
676 let pk_str = pk.to_string();
677 let pk_str = pk_str.trim_matches('"');
678 queryset
679 .iter()
680 .find(|item| {
681 if let Some(item_pk) = item.primary_key() {
682 item_pk.to_string() == pk_str
683 } else {
684 false
685 }
686 })
687 .cloned()
688 .ok_or_else(|| ViewError::NotFound(format!("Object with pk={} not found", pk)))?
689 };
690
691 let json = serializer
692 .serialize(&item)
693 .map_err(|e| ViewError::Serialization(e.to_string()))?;
694
695 Ok(Response::ok().with_body(json))
696 }
697
698 /// Create a new object
699 ///
700 /// # Examples
701 ///
702 /// ```no_run
703 /// # use reinhardt_views::viewsets::ModelViewSetHandler;
704 /// # use reinhardt_http::Request;
705 /// # use reinhardt_db::orm::Model;
706 /// # use serde::{Serialize, Deserialize};
707 /// # use bytes::Bytes;
708 /// # use hyper::{Method, Version, HeaderMap};
709 /// #
710 /// # #[derive(Debug, Clone, Serialize, Deserialize)]
711 /// # struct User {
712 /// # id: Option<i64>,
713 /// # username: String,
714 /// # }
715 /// #
716 /// # #[derive(Clone)]
717 /// # struct UserFields;
718 /// #
719 /// # impl reinhardt_db::orm::FieldSelector for UserFields {
720 /// # fn with_alias(self, _alias: &str) -> Self { self }
721 /// # }
722 /// #
723 /// # impl Model for User {
724 /// # type PrimaryKey = i64;
725 /// # type Fields = UserFields;
726 /// # fn table_name() -> &'static str { "users" }
727 /// # fn primary_key(&self) -> Option<Self::PrimaryKey> { self.id }
728 /// # fn set_primary_key(&mut self, value: Self::PrimaryKey) { self.id = Some(value); }
729 /// # fn new_fields() -> Self::Fields { UserFields }
730 /// # }
731 /// #
732 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
733 /// let handler = ModelViewSetHandler::<User>::new();
734 /// let request = Request::builder()
735 /// .method(Method::POST)
736 /// .uri("/users/")
737 /// .version(Version::HTTP_11)
738 /// .headers(HeaderMap::new())
739 /// .body(Bytes::from(r#"{"username":"alice"}"#))
740 /// .build()?;
741 /// let response = handler.create(&request).await?;
742 /// # Ok(())
743 /// # }
744 /// ```
745 pub async fn create(&self, request: &Request) -> std::result::Result<Response, ViewError> {
746 self.check_permissions(request).await?;
747
748 let serializer = self.get_serializer();
749
750 // Parse request body
751 let body_str = String::from_utf8(request.body().to_vec())
752 .map_err(|e| ViewError::BadRequest(format!("Invalid UTF-8: {}", e)))?;
753
754 // Deserialize into model
755 let item = serializer
756 .deserialize(&body_str)
757 .map_err(|e| ViewError::Serialization(e.to_string()))?;
758
759 // Save to database if pool is available
760 if let Some(pool) = &self.pool {
761 // Create a new session for this request
762 let mut session = reinhardt_db::prelude::Session::new(pool.clone(), self.db_backend)
763 .await
764 .map_err(|e| {
765 ViewError::DatabaseError(format!("Failed to create session: {}", e))
766 })?;
767
768 // Begin transaction
769 session.begin().await.map_err(|e| {
770 ViewError::DatabaseError(format!("Failed to begin transaction: {}", e))
771 })?;
772
773 // Add object to session
774 session
775 .add(item.clone())
776 .await
777 .map_err(|e| ViewError::DatabaseError(format!("Failed to add object: {}", e)))?;
778
779 // Flush changes to database (generates and executes INSERT)
780 session
781 .flush()
782 .await
783 .map_err(|e| ViewError::DatabaseError(format!("Failed to flush: {}", e)))?;
784
785 // Get the generated ID from the session
786 let generated_id = session.get_generated_ids().first().map(|(_, id)| *id);
787
788 // Commit transaction
789 session
790 .commit()
791 .await
792 .map_err(|e| ViewError::DatabaseError(format!("Failed to commit: {}", e)))?;
793
794 // Re-fetch the created object from the database to get all auto-populated fields
795 // (e.g., created_at which is set by database DEFAULT)
796 if let Some(id) = generated_id {
797 let fetch_session =
798 reinhardt_db::prelude::Session::new(pool.clone(), self.db_backend)
799 .await
800 .map_err(|e| {
801 ViewError::DatabaseError(format!("Failed to create session: {}", e))
802 })?;
803
804 // Fetch all objects and find the one with matching ID
805 let items: Vec<T> = fetch_session.list_all().await.map_err(|e| {
806 ViewError::DatabaseError(format!("Failed to fetch objects: {}", e))
807 })?;
808
809 let created_item = items
810 .into_iter()
811 .find(|i| {
812 i.primary_key()
813 .map(|pk| pk.to_string() == id.to_string())
814 .unwrap_or(false)
815 })
816 .ok_or_else(|| {
817 ViewError::DatabaseError("Failed to find created object".to_string())
818 })?;
819
820 // Serialize the complete object (including auto-populated fields)
821 let response_body = serializer
822 .serialize(&created_item)
823 .map_err(|e| ViewError::Serialization(e.to_string()))?;
824
825 return Ok(Response::created().with_body(response_body));
826 }
827 }
828
829 // Fallback: return the original item if no database pool
830 let response_body = serializer
831 .serialize(&item)
832 .map_err(|e| ViewError::Serialization(e.to_string()))?;
833
834 Ok(Response::created().with_body(response_body))
835 }
836
837 /// Update an existing object
838 ///
839 /// # Examples
840 ///
841 /// ```no_run
842 /// # use reinhardt_views::viewsets::ModelViewSetHandler;
843 /// # use reinhardt_http::Request;
844 /// # use reinhardt_db::orm::Model;
845 /// # use serde::{Serialize, Deserialize};
846 /// # use serde_json::Value;
847 /// # use bytes::Bytes;
848 /// # use hyper::{Method, Version, HeaderMap};
849 /// #
850 /// # #[derive(Debug, Clone, Serialize, Deserialize)]
851 /// # struct User {
852 /// # id: Option<i64>,
853 /// # username: String,
854 /// # }
855 /// #
856 /// # #[derive(Clone)]
857 /// # struct UserFields;
858 /// #
859 /// # impl reinhardt_db::orm::FieldSelector for UserFields {
860 /// # fn with_alias(self, _alias: &str) -> Self { self }
861 /// # }
862 /// #
863 /// # impl Model for User {
864 /// # type PrimaryKey = i64;
865 /// # type Fields = UserFields;
866 /// # fn table_name() -> &'static str { "users" }
867 /// # fn primary_key(&self) -> Option<Self::PrimaryKey> { self.id }
868 /// # fn set_primary_key(&mut self, value: Self::PrimaryKey) { self.id = Some(value); }
869 /// # fn new_fields() -> Self::Fields { UserFields }
870 /// # }
871 /// #
872 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
873 /// let handler = ModelViewSetHandler::<User>::new();
874 /// let request = Request::builder()
875 /// .method(Method::PUT)
876 /// .uri("/users/1/")
877 /// .version(Version::HTTP_11)
878 /// .headers(HeaderMap::new())
879 /// .body(Bytes::from(r#"{"username":"alice_updated"}"#))
880 /// .build()?;
881 /// let pk = serde_json::json!(1);
882 /// let response = handler.update(&request, pk).await?;
883 /// # Ok(())
884 /// # }
885 /// ```
886 pub async fn update(
887 &self,
888 request: &Request,
889 pk: serde_json::Value,
890 ) -> std::result::Result<Response, ViewError> {
891 self.check_permissions(request).await?;
892
893 let serializer = self.get_serializer();
894
895 // Get existing object from database
896 let existing_obj: T = if let Some(pool) = &self.pool {
897 let session = reinhardt_db::prelude::Session::new(pool.clone(), self.db_backend)
898 .await
899 .map_err(|e| {
900 ViewError::DatabaseError(format!("Failed to create session: {}", e))
901 })?;
902
903 let items: Vec<T> = session
904 .list_all()
905 .await
906 .map_err(|e| ViewError::DatabaseError(format!("Failed to list objects: {}", e)))?;
907
908 // Normalize pk: strip surrounding quotes only (consistent with retrieve()).
909 let pk_str_owned = pk.to_string();
910 let pk_str = pk_str_owned.trim_matches('"');
911 items
912 .into_iter()
913 .find(|item| {
914 if let Some(item_pk) = item.primary_key() {
915 item_pk.to_string() == pk_str
916 } else {
917 false
918 }
919 })
920 .ok_or_else(|| {
921 ViewError::NotFound(format!("Object with pk {} not found", pk_str))
922 })?
923 } else {
924 // Fall back to queryset for non-database mode
925 // Normalize pk: strip surrounding quotes only (consistent with retrieve()).
926 let pk_str_owned = pk.to_string();
927 let pk_str = pk_str_owned.trim_matches('"');
928 self.get_queryset()
929 .iter()
930 .find(|item| {
931 if let Some(item_pk) = item.primary_key() {
932 item_pk.to_string() == pk_str
933 } else {
934 false
935 }
936 })
937 .cloned()
938 .ok_or_else(|| {
939 ViewError::NotFound(format!("Object with pk {} not found", pk_str))
940 })?
941 };
942
943 // Parse request body as JSON for partial update (PATCH semantics)
944 let body_str = String::from_utf8(request.body().to_vec())
945 .map_err(|e| ViewError::BadRequest(format!("Invalid UTF-8: {}", e)))?;
946
947 // Parse patch data as JSON
948 let patch_data: serde_json::Value = serde_json::from_str(&body_str)
949 .map_err(|e| ViewError::Serialization(format!("Invalid JSON: {}", e)))?;
950
951 // Serialize existing object to JSON and merge with patch data
952 let existing_json = serializer
953 .serialize(&existing_obj)
954 .map_err(|e| ViewError::Serialization(e.to_string()))?;
955 let mut existing_value: serde_json::Value = serde_json::from_str(&existing_json)
956 .map_err(|e| ViewError::Serialization(format!("Failed to parse existing: {}", e)))?;
957
958 // Validate and merge patch data into existing object (only overwrites provided fields)
959 crate::generic::patch_utils::merge_patch_object_into(&mut existing_value, &patch_data)
960 .map_err(ViewError::BadRequest)?;
961
962 // Deserialize merged object back to model type
963 let merged_json = serde_json::to_string(&existing_value)
964 .map_err(|e| ViewError::Serialization(format!("Failed to serialize merged: {}", e)))?;
965 let updated_item: T = serializer
966 .deserialize(&merged_json)
967 .map_err(|e| ViewError::Serialization(e.to_string()))?;
968
969 // Update database if pool is available
970 if let Some(pool) = &self.pool {
971 // Create a new session for this request
972 let mut session = reinhardt_db::prelude::Session::new(pool.clone(), self.db_backend)
973 .await
974 .map_err(|e| {
975 ViewError::DatabaseError(format!("Failed to create session: {}", e))
976 })?;
977
978 // Begin transaction
979 session.begin().await.map_err(|e| {
980 ViewError::DatabaseError(format!("Failed to begin transaction: {}", e))
981 })?;
982
983 // Add updated object to session (marks as dirty for UPDATE)
984 session
985 .add(updated_item.clone())
986 .await
987 .map_err(|e| ViewError::DatabaseError(format!("Failed to add object: {}", e)))?;
988
989 // Flush changes to database (generates and executes UPDATE)
990 session
991 .flush()
992 .await
993 .map_err(|e| ViewError::DatabaseError(format!("Failed to flush: {}", e)))?;
994
995 // Commit transaction
996 session
997 .commit()
998 .await
999 .map_err(|e| ViewError::DatabaseError(format!("Failed to commit: {}", e)))?;
1000 }
1001
1002 // Return the complete merged/updated object
1003 Ok(Response::ok().with_body(merged_json))
1004 }
1005
1006 /// Delete an object
1007 ///
1008 /// # Examples
1009 ///
1010 /// ```no_run
1011 /// # use reinhardt_views::viewsets::ModelViewSetHandler;
1012 /// # use reinhardt_http::Request;
1013 /// # use reinhardt_db::orm::Model;
1014 /// # use serde::{Serialize, Deserialize};
1015 /// # use serde_json::Value;
1016 /// # use bytes::Bytes;
1017 /// # use hyper::{Method, Version, HeaderMap};
1018 /// #
1019 /// # #[derive(Debug, Clone, Serialize, Deserialize)]
1020 /// # struct User {
1021 /// # id: Option<i64>,
1022 /// # username: String,
1023 /// # }
1024 /// #
1025 /// # #[derive(Clone)]
1026 /// # struct UserFields;
1027 /// #
1028 /// # impl reinhardt_db::orm::FieldSelector for UserFields {
1029 /// # fn with_alias(self, _alias: &str) -> Self { self }
1030 /// # }
1031 /// #
1032 /// # impl Model for User {
1033 /// # type PrimaryKey = i64;
1034 /// # type Fields = UserFields;
1035 /// # fn table_name() -> &'static str { "users" }
1036 /// # fn primary_key(&self) -> Option<Self::PrimaryKey> { self.id }
1037 /// # fn set_primary_key(&mut self, value: Self::PrimaryKey) { self.id = Some(value); }
1038 /// # fn new_fields() -> Self::Fields { UserFields }
1039 /// # }
1040 /// #
1041 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
1042 /// let handler = ModelViewSetHandler::<User>::new();
1043 /// let request = Request::builder()
1044 /// .method(Method::DELETE)
1045 /// .uri("/users/1/")
1046 /// .version(Version::HTTP_11)
1047 /// .headers(HeaderMap::new())
1048 /// .body(Bytes::new())
1049 /// .build()?;
1050 /// let pk = serde_json::json!(1);
1051 /// let response = handler.destroy(&request, pk).await?;
1052 /// # Ok(())
1053 /// # }
1054 /// ```
1055 pub async fn destroy(
1056 &self,
1057 request: &Request,
1058 pk: serde_json::Value,
1059 ) -> std::result::Result<Response, ViewError> {
1060 self.check_permissions(request).await?;
1061
1062 let serializer = self.get_serializer();
1063
1064 // Verify object exists and get it for deletion
1065 let response = self.retrieve(request, pk).await?;
1066
1067 // Extract the object from response body
1068 let body_str = String::from_utf8(response.body.to_vec())
1069 .map_err(|e| ViewError::BadRequest(format!("Invalid UTF-8: {}", e)))?;
1070
1071 // Deserialize into model
1072 let item = serializer
1073 .deserialize(&body_str)
1074 .map_err(|e| ViewError::Serialization(e.to_string()))?;
1075
1076 // Delete from database if pool is available
1077 if let Some(pool) = &self.pool {
1078 // Create a new session for this request
1079 let mut session = reinhardt_db::prelude::Session::new(pool.clone(), self.db_backend)
1080 .await
1081 .map_err(|e| {
1082 ViewError::DatabaseError(format!("Failed to create session: {}", e))
1083 })?;
1084
1085 // Begin transaction
1086 session.begin().await.map_err(|e| {
1087 ViewError::DatabaseError(format!("Failed to begin transaction: {}", e))
1088 })?;
1089
1090 // Mark object for deletion
1091 session.delete(item).await.map_err(|e| {
1092 ViewError::DatabaseError(format!("Failed to mark object for deletion: {}", e))
1093 })?;
1094
1095 // Flush changes to database (generates and executes DELETE)
1096 session
1097 .flush()
1098 .await
1099 .map_err(|e| ViewError::DatabaseError(format!("Failed to flush: {}", e)))?;
1100
1101 // Commit transaction
1102 session
1103 .commit()
1104 .await
1105 .map_err(|e| ViewError::DatabaseError(format!("Failed to commit: {}", e)))?;
1106 }
1107
1108 Ok(Response::no_content())
1109 }
1110}
1111
1112impl<T> Default for ModelViewSetHandler<T>
1113where
1114 T: Model + Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
1115{
1116 fn default() -> Self {
1117 Self::new()
1118 }
1119}
1120
1121#[cfg(test)]
1122mod tests {
1123 use super::*;
1124 use bytes::Bytes;
1125 use hyper::{HeaderMap, Method, Version};
1126 use reinhardt_http::Request;
1127 use rstest::rstest;
1128
1129 fn build_request(uri: &str) -> Request {
1130 Request::builder()
1131 .method(Method::GET)
1132 .uri(uri)
1133 .version(Version::HTTP_11)
1134 .headers(HeaderMap::new())
1135 .body(Bytes::new())
1136 .build()
1137 .unwrap()
1138 }
1139
1140 // -----------------------------------------------------------------------
1141 // Test model for retrieve PK tests
1142 // -----------------------------------------------------------------------
1143
1144 #[derive(Debug, Clone, Serialize, serde::Deserialize, PartialEq)]
1145 struct TestItem {
1146 id: Option<i64>,
1147 name: String,
1148 }
1149
1150 #[derive(Clone)]
1151 struct TestItemFields;
1152
1153 impl reinhardt_db::orm::FieldSelector for TestItemFields {
1154 fn with_alias(self, _alias: &str) -> Self {
1155 self
1156 }
1157 }
1158
1159 impl reinhardt_db::orm::Model for TestItem {
1160 type PrimaryKey = i64;
1161 type Fields = TestItemFields;
1162
1163 fn table_name() -> &'static str {
1164 "test_items"
1165 }
1166
1167 fn primary_key(&self) -> Option<Self::PrimaryKey> {
1168 self.id
1169 }
1170
1171 fn set_primary_key(&mut self, value: Self::PrimaryKey) {
1172 self.id = Some(value);
1173 }
1174
1175 fn new_fields() -> Self::Fields {
1176 TestItemFields
1177 }
1178 }
1179
1180 /// Helper to build a ModelViewSetHandler with in-memory queryset
1181 fn build_model_handler(items: Vec<TestItem>) -> ModelViewSetHandler<TestItem> {
1182 ModelViewSetHandler::<TestItem>::new().with_queryset(items)
1183 }
1184
1185 #[rstest]
1186 #[tokio::test]
1187 async fn test_retrieve_strips_quotes_from_numeric_pk() {
1188 // Arrange
1189 let items = vec![
1190 TestItem {
1191 id: Some(1),
1192 name: "first".to_string(),
1193 },
1194 TestItem {
1195 id: Some(2),
1196 name: "second".to_string(),
1197 },
1198 ];
1199 let handler = build_model_handler(items);
1200 let request = build_request("/items/1/");
1201
1202 // Act - pass pk with surrounding quotes (as JSON string value)
1203 let pk = serde_json::json!("1");
1204 let result = handler.retrieve(&request, pk).await;
1205
1206 // Assert - should find the item despite quotes in pk
1207 assert!(result.is_ok(), "retrieve should succeed with quoted pk");
1208 let response = result.unwrap();
1209 assert_eq!(response.status, hyper::StatusCode::OK);
1210 let body: TestItem =
1211 serde_json::from_slice(&response.body).expect("response should be valid JSON");
1212 assert_eq!(body.name, "first");
1213 assert_eq!(body.id, Some(1));
1214 }
1215
1216 #[rstest]
1217 #[tokio::test]
1218 async fn test_retrieve_works_with_unquoted_numeric_pk() {
1219 // Arrange
1220 let items = vec![TestItem {
1221 id: Some(42),
1222 name: "answer".to_string(),
1223 }];
1224 let handler = build_model_handler(items);
1225 let request = build_request("/items/42/");
1226
1227 // Act - pass pk as JSON number (no quotes)
1228 let pk = serde_json::json!(42);
1229 let result = handler.retrieve(&request, pk).await;
1230
1231 // Assert
1232 assert!(result.is_ok(), "retrieve should succeed with numeric pk");
1233 let response = result.unwrap();
1234 assert_eq!(response.status, hyper::StatusCode::OK);
1235 let body: TestItem =
1236 serde_json::from_slice(&response.body).expect("response should be valid JSON");
1237 assert_eq!(body.name, "answer");
1238 assert_eq!(body.id, Some(42));
1239 }
1240
1241 #[rstest]
1242 #[tokio::test]
1243 async fn test_retrieve_returns_not_found_for_nonexistent_pk() {
1244 // Arrange
1245 let items = vec![TestItem {
1246 id: Some(1),
1247 name: "only".to_string(),
1248 }];
1249 let handler = build_model_handler(items);
1250 let request = build_request("/items/999/");
1251
1252 // Act
1253 let pk = serde_json::json!(999);
1254 let result = handler.retrieve(&request, pk).await;
1255
1256 // Assert
1257 assert!(result.is_err(), "retrieve should fail for nonexistent pk");
1258 let err = result.unwrap_err();
1259 assert!(
1260 matches!(err, ViewError::NotFound(_)),
1261 "error should be NotFound, got: {:?}",
1262 err
1263 );
1264 }
1265}