1#![allow(deprecated)]
5
6use crate::{Action, ViewSet};
8use async_trait::async_trait;
9use hyper::Method;
10use parking_lot::RwLock;
11use reinhardt_auth::{Permission, PermissionContext};
12use reinhardt_db::orm::{Model, query_types::DbBackend};
13use reinhardt_http::{Handler, Request, Response, Result};
14use reinhardt_rest::filters::FilterBackend;
15use reinhardt_rest::serializers::{ModelSerializer, Serializer};
16use serde::Serialize;
17use serde::de::DeserializeOwned;
18use std::collections::HashMap;
19use std::marker::PhantomData;
20use std::sync::Arc;
21use tracing;
22
/// Adapts a [`ViewSet`] to the [`Handler`] interface by mapping HTTP methods to action names.
///
/// The `RwLock`-guarded fields record state written by `handle` on every call. If one handler
/// instance serves several requests, they reflect whichever request wrote last.
pub struct ViewSetHandler<V: ViewSet> {
    /// The wrapped viewset that receives the resolved action via `dispatch`.
    viewset: Arc<V>,
    /// HTTP method -> action name; the name is converted with `Action::from_name` in `handle`.
    action_map: HashMap<Method, String>,
    // Stored from the constructor but not read by the code visible here; the allowance
    // silences the unused-field warning.
    #[allow(dead_code)]
    name: Option<String>,
    // Stored from the constructor but not read by the code visible here; the allowance
    // silences the unused-field warning.
    #[allow(dead_code)]
    suffix: Option<String>,

    /// Positional path arguments; `handle` sets this to `Some(empty vec)`.
    args: RwLock<Option<Vec<String>>>,
    /// Keyword path arguments produced by `extract_path_params` during `handle`.
    kwargs: RwLock<Option<HashMap<String, String>>>,
    /// Set to `true` at the start of every `handle` call.
    has_handled_request: RwLock<bool>,
}
40
// `parking_lot::RwLock` presumably stores its data in an `UnsafeCell`, which is
// `!RefUnwindSafe`, so this type would not be `RefUnwindSafe` automatically. parking_lot locks
// do not poison after a panic (exercised by `test_parking_lot_rwlock_does_not_poison_after_panic`),
// so the handler is asserted to remain usable across `catch_unwind` boundaries.
// NOTE(review): state written before a panic (e.g. `has_handled_request`) stays visible afterwards.
impl<V: ViewSet> std::panic::RefUnwindSafe for ViewSetHandler<V> {}
44
45impl<V: ViewSet> ViewSetHandler<V> {
46 pub fn new(
48 viewset: Arc<V>,
49 action_map: HashMap<Method, String>,
50 name: Option<String>,
51 suffix: Option<String>,
52 ) -> Self {
53 Self {
54 viewset,
55 action_map,
56 name,
57 suffix,
58 args: RwLock::new(None),
59 kwargs: RwLock::new(None),
60 has_handled_request: RwLock::new(false),
61 }
62 }
63
64 pub fn has_args(&self) -> bool {
66 self.args.read().is_some()
67 }
68
69 pub fn has_kwargs(&self) -> bool {
71 self.kwargs.read().is_some()
72 }
73
74 pub fn has_request(&self) -> bool {
76 *self.has_handled_request.read()
77 }
78
79 pub fn has_action_map(&self) -> bool {
81 !self.action_map.is_empty()
82 }
83}
84
85#[async_trait]
86impl<V: ViewSet + 'static> Handler for ViewSetHandler<V> {
87 async fn handle(&self, mut request: Request) -> Result<Response> {
88 *self.has_handled_request.write() = true;
90 *self.args.write() = Some(Vec::new());
91
92 let kwargs = extract_path_params(&request);
94 *self.kwargs.write() = Some(kwargs);
95
96 if let Some(middleware) = self.viewset.get_middleware()
98 && let Some(response) = middleware.process_request(&mut request).await?
99 {
100 return Ok(response);
101 }
102
103 let action_name = match self.action_map.get(&request.method) {
105 Some(name) => name,
106 None => {
107 let allowed: Vec<String> = self.action_map.keys().map(|m| m.to_string()).collect();
108 let mut response = Response::new(hyper::StatusCode::METHOD_NOT_ALLOWED);
109 match allowed.join(", ").parse() {
110 Ok(header_value) => {
111 response.headers.insert(hyper::header::ALLOW, header_value);
112 }
113 Err(e) => {
114 tracing::warn!(
115 error = %e,
116 "Failed to parse allowed methods as header value"
117 );
118 }
119 }
120 return Ok(response);
121 }
122 };
123
124 let action = Action::from_name(action_name);
126
127 let response = self.viewset.dispatch(request, action).await?;
129
130 Ok(response)
132 }
133}
134
135fn extract_path_params(request: &Request) -> HashMap<String, String> {
138 let mut params = HashMap::new();
139
140 let path = request.uri.path();
143 let segments: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
144
145 if segments.len() >= 2 {
148 params.insert("id".to_string(), segments[1].to_string());
149 }
150
151 params
152}
153
/// Errors produced by [`ModelViewSetHandler`] operations.
///
/// Each variant carries a human-readable detail message. `Clone`/`PartialEq`/`Eq` let
/// callers and tests compare or retain errors without string matching.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ViewError {
    /// Serializing or deserializing a payload failed.
    Serialization(String),
    /// A permission check rejected the request.
    Permission(String),
    /// The requested object does not exist.
    NotFound(String),
    /// The request was malformed (e.g. a non-UTF-8 body or an invalid patch).
    BadRequest(String),
    /// An unexpected internal failure.
    Internal(String),
    /// A database session, query, or transaction step failed.
    DatabaseError(String),
}
170
171impl std::fmt::Display for ViewError {
172 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173 match self {
174 ViewError::Serialization(msg) => write!(f, "Serialization error: {}", msg),
175 ViewError::Permission(msg) => write!(f, "Permission denied: {}", msg),
176 ViewError::NotFound(msg) => write!(f, "Not found: {}", msg),
177 ViewError::BadRequest(msg) => write!(f, "Bad request: {}", msg),
178 ViewError::Internal(msg) => write!(f, "Internal error: {}", msg),
179 ViewError::DatabaseError(msg) => write!(f, "Database error: {}", msg),
180 }
181 }
182}
183
// Every `ViewError` variant holds only a message `String` and wraps no underlying error, so the
// default `source()` (which returns `None`) is kept.
impl std::error::Error for ViewError {}
185
186impl From<ViewError> for reinhardt_core::exception::Error {
199 fn from(value: ViewError) -> Self {
200 match value {
201 ViewError::Serialization(m) => Self::Serialization(m),
202 ViewError::Permission(m) => Self::Authorization(m),
203 ViewError::NotFound(m) => Self::NotFound(m),
204 ViewError::BadRequest(m) => Self::Http(m),
205 ViewError::Internal(m) => Self::Internal(m),
206 ViewError::DatabaseError(m) => Self::Database(m),
207 }
208 }
209}
210
/// A generic CRUD handler (`list`/`retrieve`/`create`/`update`/`destroy`) for model type `T`.
///
/// When a database pool is configured, operations go through a `reinhardt_db` session.
/// Otherwise they operate on the in-memory `queryset`.
pub struct ModelViewSetHandler<T>
where
    T: Model + Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
{
    /// In-memory objects used when no pool is configured.
    queryset: Option<Vec<T>>,
    /// Custom serializer; when `None`, a `ModelSerializer<T>` is created on demand.
    serializer_class: Option<Arc<dyn Serializer<Input = T, Output = String> + Send + Sync>>,
    /// Permissions that must all grant access before an operation runs.
    permission_classes: Vec<Arc<dyn Permission>>,
    // NOTE(review): populated by `add_filter_backend` but not applied by any operation in the
    // visible code.
    filter_backends: Vec<Arc<dyn FilterBackend>>,
    // NOTE(review): populated by `with_pagination` but `list` does not paginate in the visible code.
    pagination_class: Option<reinhardt_core::pagination::PaginatorImpl>,
    /// Database pool; when set, operations use a session instead of `queryset`.
    pool: Option<Arc<sqlx::AnyPool>>,
    /// SQL dialect passed to each session (`new` defaults it to Postgres).
    db_backend: DbBackend,
    // `T` already appears in `queryset`/`serializer_class`, so this marker is not strictly needed.
    _phantom: PhantomData<T>,
}
263
264impl<T> ModelViewSetHandler<T>
265where
266 T: Model + Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
267{
268 pub fn new() -> Self {
301 Self {
302 queryset: None,
303 serializer_class: None,
304 permission_classes: Vec::new(),
305 filter_backends: Vec::new(),
306 pagination_class: None,
307 pool: None,
308 db_backend: DbBackend::Postgres, _phantom: PhantomData,
310 }
311 }
312
313 pub fn with_queryset(mut self, queryset: Vec<T>) -> Self {
351 self.queryset = Some(queryset);
352 self
353 }
354
355 pub fn with_serializer(
392 mut self,
393 serializer: Arc<dyn Serializer<Input = T, Output = String> + Send + Sync>,
394 ) -> Self {
395 self.serializer_class = Some(serializer);
396 self
397 }
398
399 pub fn with_pool(mut self, pool: Arc<sqlx::AnyPool>) -> Self {
439 self.pool = Some(pool);
440 self
441 }
442
443 pub fn with_db_backend(mut self, db_backend: DbBackend) -> Self {
477 self.db_backend = db_backend;
478 self
479 }
480
481 pub fn add_permission(mut self, permission: Arc<dyn Permission>) -> Self {
517 self.permission_classes.push(permission);
518 self
519 }
520
521 pub fn add_filter_backend(mut self, backend: Arc<dyn FilterBackend>) -> Self {
523 self.filter_backends.push(backend);
524 self
525 }
526
527 pub fn with_pagination(
529 mut self,
530 pagination: reinhardt_core::pagination::PaginatorImpl,
531 ) -> Self {
532 self.pagination_class = Some(pagination);
533 self
534 }
535
536 fn get_queryset(&self) -> &[T] {
538 self.queryset.as_deref().unwrap_or(&[])
539 }
540
541 fn get_serializer(&self) -> Arc<dyn Serializer<Input = T, Output = String> + Send + Sync> {
543 self.serializer_class
544 .clone()
545 .unwrap_or_else(|| Arc::new(ModelSerializer::<T>::new()))
546 }
547
548 async fn check_permissions(&self, request: &Request) -> std::result::Result<(), ViewError> {
550 let user_id_string: Option<String> = request.extensions.get::<String>().or_else(|| {
566 request
567 .extensions
568 .get::<uuid::Uuid>()
569 .map(|id| id.to_string())
570 });
571
572 let is_authenticated = user_id_string.is_some();
574
575 let (is_admin, is_active, user_obj) = if let (Some(user_id_str), Some(_pool)) =
577 (user_id_string.as_ref(), self.pool.as_ref())
578 {
579 #[cfg(feature = "argon2-hasher")]
581 match uuid::Uuid::parse_str(user_id_str) {
582 Ok(user_uuid) => {
583 use reinhardt_db::orm::manager::get_connection;
585 match get_connection().await {
586 Ok(conn) => {
587 use reinhardt_auth::DefaultUser;
589 use reinhardt_db::orm::{
590 Alias, ColumnRef, DatabaseBackend, Expr, ExprTrait, Model,
591 MySqlQueryBuilder, PostgresQueryBuilder, Query,
592 QueryStatementBuilder, SqliteQueryBuilder,
593 };
594
595 let table_name = DefaultUser::table_name();
596 let pk_field = DefaultUser::primary_key_field();
597
598 let stmt = Query::select()
600 .column(ColumnRef::Asterisk)
601 .from(Alias::new(table_name))
602 .and_where(
603 Expr::col(Alias::new(pk_field))
604 .eq(Expr::value(user_uuid.to_string())),
605 )
606 .to_owned();
607
608 let sql = match conn.backend() {
609 DatabaseBackend::Postgres => stmt.to_string(PostgresQueryBuilder),
610 DatabaseBackend::MySql => stmt.to_string(MySqlQueryBuilder),
611 DatabaseBackend::Sqlite => stmt.to_string(SqliteQueryBuilder),
612 };
613
614 match conn.query_optional(&sql, vec![]).await {
615 Ok(Some(row)) => {
616 match serde_json::from_value::<DefaultUser>(row.data) {
618 Ok(user) => {
619 use reinhardt_auth::User;
620 let is_admin = user.is_admin();
622 let is_active = user.is_active();
623 let boxed_user: Box<dyn User> = Box::new(user);
625 (is_admin, is_active, Some(boxed_user))
626 }
627 Err(_) => {
628 (false, true, None)
630 }
631 }
632 }
633 Ok(None) => {
634 (false, true, None)
636 }
637 Err(_) => {
638 (false, true, None)
640 }
641 }
642 }
643 Err(_) => {
644 (false, true, None)
646 }
647 }
648 }
649 Err(_) => {
650 (false, true, None)
652 }
653 }
654
655 #[cfg(not(feature = "argon2-hasher"))]
658 {
659 let _ = user_id_str; (false, true, None)
661 }
662 } else {
663 (false, true, None)
665 };
666
667 let context = PermissionContext {
668 request,
669 is_authenticated,
670 is_admin,
671 is_active,
672 user: user_obj,
673 };
674
675 for permission in &self.permission_classes {
677 if !permission.has_permission(&context).await {
678 return Err(ViewError::Permission(format!(
680 "Permission denied by {}",
681 std::any::type_name_of_val(&**permission)
682 )));
683 }
684 }
685
686 Ok(())
687 }
688
689 pub async fn list(&self, request: &Request) -> std::result::Result<Response, ViewError> {
737 self.check_permissions(request).await?;
738
739 let serializer = self.get_serializer();
740
741 let items: Vec<T> = if let Some(pool) = &self.pool {
743 let session = reinhardt_db::prelude::Session::new(pool.clone(), self.db_backend)
745 .await
746 .map_err(|e| {
747 ViewError::DatabaseError(format!("Failed to create session: {}", e))
748 })?;
749
750 session
751 .list_all()
752 .await
753 .map_err(|e| ViewError::DatabaseError(format!("Failed to list objects: {}", e)))?
754 } else {
755 self.get_queryset().to_vec()
757 };
758
759 let mut serialized_items = Vec::new();
761 for item in &items {
762 let json = serializer
763 .serialize(item)
764 .map_err(|e| ViewError::Serialization(e.to_string()))?;
765 serialized_items.push(json);
766 }
767
768 let response_body = format!("[{}]", serialized_items.join(","));
770
771 Ok(Response::ok().with_body(response_body))
772 }
773
774 pub async fn retrieve(
824 &self,
825 request: &Request,
826 pk: serde_json::Value,
827 ) -> std::result::Result<Response, ViewError> {
828 self.check_permissions(request).await?;
829
830 let serializer = self.get_serializer();
831
832 let item: T = if let Some(pool) = &self.pool {
834 let session = reinhardt_db::prelude::Session::new(pool.clone(), self.db_backend)
836 .await
837 .map_err(|e| {
838 ViewError::DatabaseError(format!("Failed to create session: {}", e))
839 })?;
840
841 let items: Vec<T> = session
842 .list_all()
843 .await
844 .map_err(|e| ViewError::DatabaseError(format!("Failed to query objects: {}", e)))?;
845
846 let pk_str = pk.to_string();
848 let pk_str = pk_str.trim_matches('"');
849
850 items
851 .into_iter()
852 .find(|item| {
853 if let Some(item_pk) = item.primary_key() {
854 item_pk.to_string() == pk_str
855 } else {
856 false
857 }
858 })
859 .ok_or_else(|| ViewError::NotFound(format!("Object with pk={} not found", pk)))?
860 } else {
861 let queryset = self.get_queryset();
863 let pk_str = pk.to_string();
864 let pk_str = pk_str.trim_matches('"');
865 queryset
866 .iter()
867 .find(|item| {
868 if let Some(item_pk) = item.primary_key() {
869 item_pk.to_string() == pk_str
870 } else {
871 false
872 }
873 })
874 .cloned()
875 .ok_or_else(|| ViewError::NotFound(format!("Object with pk={} not found", pk)))?
876 };
877
878 let json = serializer
879 .serialize(&item)
880 .map_err(|e| ViewError::Serialization(e.to_string()))?;
881
882 Ok(Response::ok().with_body(json))
883 }
884
885 pub async fn create(&self, request: &Request) -> std::result::Result<Response, ViewError> {
933 self.check_permissions(request).await?;
934
935 let serializer = self.get_serializer();
936
937 let body_str = String::from_utf8(request.body().to_vec())
939 .map_err(|e| ViewError::BadRequest(format!("Invalid UTF-8: {}", e)))?;
940
941 let item = serializer
943 .deserialize(&body_str)
944 .map_err(|e| ViewError::Serialization(e.to_string()))?;
945
946 if let Some(pool) = &self.pool {
948 let mut session = reinhardt_db::prelude::Session::new(pool.clone(), self.db_backend)
950 .await
951 .map_err(|e| {
952 ViewError::DatabaseError(format!("Failed to create session: {}", e))
953 })?;
954
955 session.begin().await.map_err(|e| {
957 ViewError::DatabaseError(format!("Failed to begin transaction: {}", e))
958 })?;
959
960 session
962 .add(item.clone())
963 .await
964 .map_err(|e| ViewError::DatabaseError(format!("Failed to add object: {}", e)))?;
965
966 session
968 .flush()
969 .await
970 .map_err(|e| ViewError::DatabaseError(format!("Failed to flush: {}", e)))?;
971
972 let generated_id = session.get_generated_ids().first().map(|(_, id)| *id);
974
975 session
977 .commit()
978 .await
979 .map_err(|e| ViewError::DatabaseError(format!("Failed to commit: {}", e)))?;
980
981 if let Some(id) = generated_id {
984 let fetch_session =
985 reinhardt_db::prelude::Session::new(pool.clone(), self.db_backend)
986 .await
987 .map_err(|e| {
988 ViewError::DatabaseError(format!("Failed to create session: {}", e))
989 })?;
990
991 let items: Vec<T> = fetch_session.list_all().await.map_err(|e| {
993 ViewError::DatabaseError(format!("Failed to fetch objects: {}", e))
994 })?;
995
996 let created_item = items
997 .into_iter()
998 .find(|i| {
999 i.primary_key()
1000 .map(|pk| pk.to_string() == id.to_string())
1001 .unwrap_or(false)
1002 })
1003 .ok_or_else(|| {
1004 ViewError::DatabaseError("Failed to find created object".to_string())
1005 })?;
1006
1007 let response_body = serializer
1009 .serialize(&created_item)
1010 .map_err(|e| ViewError::Serialization(e.to_string()))?;
1011
1012 return Ok(Response::created().with_body(response_body));
1013 }
1014 }
1015
1016 let response_body = serializer
1018 .serialize(&item)
1019 .map_err(|e| ViewError::Serialization(e.to_string()))?;
1020
1021 Ok(Response::created().with_body(response_body))
1022 }
1023
1024 pub async fn update(
1074 &self,
1075 request: &Request,
1076 pk: serde_json::Value,
1077 ) -> std::result::Result<Response, ViewError> {
1078 self.check_permissions(request).await?;
1079
1080 let serializer = self.get_serializer();
1081
1082 let existing_obj: T = if let Some(pool) = &self.pool {
1084 let session = reinhardt_db::prelude::Session::new(pool.clone(), self.db_backend)
1085 .await
1086 .map_err(|e| {
1087 ViewError::DatabaseError(format!("Failed to create session: {}", e))
1088 })?;
1089
1090 let items: Vec<T> = session
1091 .list_all()
1092 .await
1093 .map_err(|e| ViewError::DatabaseError(format!("Failed to list objects: {}", e)))?;
1094
1095 let pk_str = pk.to_string().replace('"', "");
1096 items
1097 .into_iter()
1098 .find(|item| {
1099 if let Some(item_pk) = item.primary_key() {
1100 item_pk.to_string() == pk_str
1101 } else {
1102 false
1103 }
1104 })
1105 .ok_or_else(|| {
1106 ViewError::NotFound(format!("Object with pk {} not found", pk_str))
1107 })?
1108 } else {
1109 let pk_str = pk.to_string().replace('"', "");
1111 self.get_queryset()
1112 .iter()
1113 .find(|item| {
1114 if let Some(item_pk) = item.primary_key() {
1115 item_pk.to_string() == pk_str
1116 } else {
1117 false
1118 }
1119 })
1120 .cloned()
1121 .ok_or_else(|| {
1122 ViewError::NotFound(format!("Object with pk {} not found", pk_str))
1123 })?
1124 };
1125
1126 let body_str = String::from_utf8(request.body().to_vec())
1128 .map_err(|e| ViewError::BadRequest(format!("Invalid UTF-8: {}", e)))?;
1129
1130 let patch_data: serde_json::Value = serde_json::from_str(&body_str)
1132 .map_err(|e| ViewError::Serialization(format!("Invalid JSON: {}", e)))?;
1133
1134 let existing_json = serializer
1136 .serialize(&existing_obj)
1137 .map_err(|e| ViewError::Serialization(e.to_string()))?;
1138 let mut existing_value: serde_json::Value = serde_json::from_str(&existing_json)
1139 .map_err(|e| ViewError::Serialization(format!("Failed to parse existing: {}", e)))?;
1140
1141 crate::generic::patch_utils::merge_patch_object_into(&mut existing_value, &patch_data)
1143 .map_err(ViewError::BadRequest)?;
1144
1145 let merged_json = serde_json::to_string(&existing_value)
1147 .map_err(|e| ViewError::Serialization(format!("Failed to serialize merged: {}", e)))?;
1148 let updated_item: T = serializer
1149 .deserialize(&merged_json)
1150 .map_err(|e| ViewError::Serialization(e.to_string()))?;
1151
1152 if let Some(pool) = &self.pool {
1154 let mut session = reinhardt_db::prelude::Session::new(pool.clone(), self.db_backend)
1156 .await
1157 .map_err(|e| {
1158 ViewError::DatabaseError(format!("Failed to create session: {}", e))
1159 })?;
1160
1161 session.begin().await.map_err(|e| {
1163 ViewError::DatabaseError(format!("Failed to begin transaction: {}", e))
1164 })?;
1165
1166 session
1168 .add(updated_item.clone())
1169 .await
1170 .map_err(|e| ViewError::DatabaseError(format!("Failed to add object: {}", e)))?;
1171
1172 session
1174 .flush()
1175 .await
1176 .map_err(|e| ViewError::DatabaseError(format!("Failed to flush: {}", e)))?;
1177
1178 session
1180 .commit()
1181 .await
1182 .map_err(|e| ViewError::DatabaseError(format!("Failed to commit: {}", e)))?;
1183 }
1184
1185 Ok(Response::ok().with_body(merged_json))
1187 }
1188
1189 pub async fn destroy(
1239 &self,
1240 request: &Request,
1241 pk: serde_json::Value,
1242 ) -> std::result::Result<Response, ViewError> {
1243 self.check_permissions(request).await?;
1244
1245 let serializer = self.get_serializer();
1246
1247 let response = self.retrieve(request, pk).await?;
1249
1250 let body_str = String::from_utf8(response.body.to_vec())
1252 .map_err(|e| ViewError::BadRequest(format!("Invalid UTF-8: {}", e)))?;
1253
1254 let item = serializer
1256 .deserialize(&body_str)
1257 .map_err(|e| ViewError::Serialization(e.to_string()))?;
1258
1259 if let Some(pool) = &self.pool {
1261 let mut session = reinhardt_db::prelude::Session::new(pool.clone(), self.db_backend)
1263 .await
1264 .map_err(|e| {
1265 ViewError::DatabaseError(format!("Failed to create session: {}", e))
1266 })?;
1267
1268 session.begin().await.map_err(|e| {
1270 ViewError::DatabaseError(format!("Failed to begin transaction: {}", e))
1271 })?;
1272
1273 session.delete(item).await.map_err(|e| {
1275 ViewError::DatabaseError(format!("Failed to mark object for deletion: {}", e))
1276 })?;
1277
1278 session
1280 .flush()
1281 .await
1282 .map_err(|e| ViewError::DatabaseError(format!("Failed to flush: {}", e)))?;
1283
1284 session
1286 .commit()
1287 .await
1288 .map_err(|e| ViewError::DatabaseError(format!("Failed to commit: {}", e)))?;
1289 }
1290
1291 Ok(Response::no_content())
1292 }
1293}
1294
1295impl<T> Default for ModelViewSetHandler<T>
1296where
1297 T: Model + Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
1298{
1299 fn default() -> Self {
1300 Self::new()
1301 }
1302}
1303
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method, Version};
    use reinhardt_http::Request;
    use rstest::rstest;
    use std::thread;

    /// Builds a bodiless HTTP/1.1 GET request for `uri`.
    fn build_request(uri: &str) -> Request {
        Request::builder()
            .method(Method::GET)
            .uri(uri)
            .version(Version::HTTP_11)
            .headers(HeaderMap::new())
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    #[rstest]
    fn test_parking_lot_rwlock_does_not_poison_after_panic() {
        // Arrange
        let lock = RwLock::new(42);

        // Act: panic in a scoped thread while holding the write guard.
        let lock_ref = &lock;
        let result = thread::scope(|s| {
            let handle = s.spawn(|| {
                let mut guard = lock_ref.write();
                *guard = 100;
                panic!("intentional panic while holding write lock");
            });
            // Joining explicitly also keeps `thread::scope` from re-raising the panic.
            assert!(handle.join().is_err(), "spawned thread must have panicked");
            let value = *lock_ref.read();
            value
        });

        // Assert: the lock is still usable and keeps the write made before the panic.
        assert_eq!(result, 100);
    }

    #[rstest]
    fn test_rwlock_concurrent_read_access() {
        let lock = RwLock::new(String::from("test_value"));

        let guard1 = lock.read();
        let guard2 = lock.read();

        assert_eq!(*guard1, "test_value");
        assert_eq!(*guard2, "test_value");
    }

    #[rstest]
    fn test_extract_path_params_numeric_segment_treated_as_id() {
        let request = build_request("/resource/123/");

        let params = extract_path_params(&request);

        assert_eq!(params.get("id"), Some(&"123".to_string()));
    }

    #[rstest]
    fn test_extract_path_params_non_numeric_segment_treated_as_id() {
        let request = build_request("/resource/username/");

        let params = extract_path_params(&request);

        assert_eq!(params.get("id"), Some(&"username".to_string()));
    }

    #[rstest]
    fn test_extract_path_params_slug_segment_treated_as_id() {
        let request = build_request("/resource/my-slug/");

        let params = extract_path_params(&request);

        assert_eq!(params.get("id"), Some(&"my-slug".to_string()));
    }

    #[rstest]
    fn test_extract_path_params_uuid_segment_treated_as_id() {
        let request = build_request("/resource/550e8400-e29b-41d4-a716-446655440000/");

        let params = extract_path_params(&request);

        assert_eq!(
            params.get("id"),
            Some(&"550e8400-e29b-41d4-a716-446655440000".to_string())
        );
    }

    #[rstest]
    fn test_extract_path_params_single_segment_no_id() {
        let request = build_request("/resource/");

        let params = extract_path_params(&request);

        assert_eq!(params.get("id"), None);
    }

    /// Viewset whose `dispatch` always answers 200 OK.
    struct MockViewSet;

    #[async_trait]
    impl ViewSet for MockViewSet {
        fn get_basename(&self) -> &str {
            "mock"
        }

        async fn dispatch(
            &self,
            _request: reinhardt_http::Request,
            _action: crate::Action,
        ) -> reinhardt_http::Result<reinhardt_http::Response> {
            Ok(reinhardt_http::Response::ok())
        }
    }

    /// Builds a handler that maps each of `methods` to the same mock action.
    fn build_handler(methods: Vec<Method>) -> ViewSetHandler<MockViewSet> {
        let mut action_map = HashMap::new();
        for method in methods {
            action_map.insert(method, "mock_action".to_string());
        }
        ViewSetHandler::new(Arc::new(MockViewSet), action_map, None, None)
    }

    /// Builds a bodiless HTTP/1.1 request to `/mock/` with the given method.
    fn build_method_request(method: Method) -> reinhardt_http::Request {
        reinhardt_http::Request::builder()
            .method(method)
            .uri("/mock/")
            .version(hyper::Version::HTTP_11)
            .headers(hyper::HeaderMap::new())
            .body(bytes::Bytes::new())
            .build()
            .unwrap()
    }

    #[rstest]
    #[tokio::test]
    async fn test_unregistered_method_returns_405() {
        let handler = build_handler(vec![Method::GET]);
        let request = build_method_request(Method::DELETE);

        let response = Handler::handle(&handler, request).await.unwrap();

        assert_eq!(response.status, hyper::StatusCode::METHOD_NOT_ALLOWED);
    }

    #[rstest]
    #[tokio::test]
    async fn test_405_response_allow_header_contains_registered_methods() {
        let handler = build_handler(vec![Method::GET, Method::POST]);
        let request = build_method_request(Method::DELETE);

        let response = Handler::handle(&handler, request).await.unwrap();

        assert_eq!(response.status, hyper::StatusCode::METHOD_NOT_ALLOWED);
        let allow_header = response
            .headers
            .get(hyper::header::ALLOW)
            .expect("Allow header must be present");
        let allow_str = allow_header.to_str().unwrap();
        assert!(allow_str.contains("GET"), "Allow header must contain GET");
        assert!(allow_str.contains("POST"), "Allow header must contain POST");
    }

    #[rstest]
    #[tokio::test]
    async fn test_405_response_allow_header_comma_separated_format() {
        let handler = build_handler(vec![Method::GET, Method::PUT]);
        let request = build_method_request(Method::PATCH);

        let response = Handler::handle(&handler, request).await.unwrap();

        assert_eq!(response.status, hyper::StatusCode::METHOD_NOT_ALLOWED);
        let allow_header = response
            .headers
            .get(hyper::header::ALLOW)
            .expect("Allow header must be present");
        let allow_str = allow_header.to_str().unwrap();
        let methods: Vec<&str> = allow_str.split(", ").collect();
        assert_eq!(
            methods.len(),
            2,
            "Allow header must contain exactly 2 methods"
        );
        for method in &methods {
            assert!(
                *method == "GET" || *method == "PUT",
                "Unexpected method in Allow header: {}",
                method
            );
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_registered_method_does_not_return_405() {
        let handler = build_handler(vec![Method::GET]);
        let request = build_method_request(Method::GET);

        let response = Handler::handle(&handler, request).await.unwrap();

        assert_eq!(response.status, hyper::StatusCode::OK);
    }

    #[derive(Debug, Clone, Serialize, serde::Deserialize, PartialEq)]
    struct TestItem {
        id: Option<i64>,
        name: String,
    }

    #[derive(Clone)]
    struct TestItemFields;

    impl reinhardt_db::orm::FieldSelector for TestItemFields {
        fn with_alias(self, _alias: &str) -> Self {
            self
        }
    }

    impl reinhardt_db::orm::Model for TestItem {
        type PrimaryKey = i64;
        type Fields = TestItemFields;

        fn table_name() -> &'static str {
            "test_items"
        }

        fn primary_key(&self) -> Option<Self::PrimaryKey> {
            self.id
        }

        fn set_primary_key(&mut self, value: Self::PrimaryKey) {
            self.id = Some(value);
        }

        fn new_fields() -> Self::Fields {
            TestItemFields
        }
    }

    /// Builds a pool-less handler backed by the given in-memory items.
    fn build_model_handler(items: Vec<TestItem>) -> ModelViewSetHandler<TestItem> {
        ModelViewSetHandler::<TestItem>::new().with_queryset(items)
    }

    #[rstest]
    #[tokio::test]
    async fn test_retrieve_strips_quotes_from_numeric_pk() {
        let items = vec![
            TestItem {
                id: Some(1),
                name: "first".to_string(),
            },
            TestItem {
                id: Some(2),
                name: "second".to_string(),
            },
        ];
        let handler = build_model_handler(items);
        let request = build_request("/items/1/");

        let pk = serde_json::json!("1");
        let result = handler.retrieve(&request, pk).await;

        assert!(result.is_ok(), "retrieve should succeed with quoted pk");
        let response = result.unwrap();
        assert_eq!(response.status, hyper::StatusCode::OK);
        let body: TestItem =
            serde_json::from_slice(&response.body.to_vec()).expect("response should be valid JSON");
        assert_eq!(body.name, "first");
        assert_eq!(body.id, Some(1));
    }

    #[rstest]
    #[tokio::test]
    async fn test_retrieve_works_with_unquoted_numeric_pk() {
        let items = vec![TestItem {
            id: Some(42),
            name: "answer".to_string(),
        }];
        let handler = build_model_handler(items);
        let request = build_request("/items/42/");

        let pk = serde_json::json!(42);
        let result = handler.retrieve(&request, pk).await;

        assert!(result.is_ok(), "retrieve should succeed with numeric pk");
        let response = result.unwrap();
        assert_eq!(response.status, hyper::StatusCode::OK);
        let body: TestItem =
            serde_json::from_slice(&response.body.to_vec()).expect("response should be valid JSON");
        assert_eq!(body.name, "answer");
        assert_eq!(body.id, Some(42));
    }

    #[rstest]
    #[tokio::test]
    async fn test_retrieve_returns_not_found_for_nonexistent_pk() {
        let items = vec![TestItem {
            id: Some(1),
            name: "only".to_string(),
        }];
        let handler = build_model_handler(items);
        let request = build_request("/items/999/");

        let pk = serde_json::json!(999);
        let result = handler.retrieve(&request, pk).await;

        assert!(result.is_err(), "retrieve should fail for nonexistent pk");
        let err = result.unwrap_err();
        assert!(
            matches!(err, ViewError::NotFound(_)),
            "error should be NotFound, got: {:?}",
            err
        );
    }
}