use crate::{Action, ViewSet};
use async_trait::async_trait;
use hyper::Method;
use parking_lot::RwLock;
use reinhardt_http::{Handler, Request, Response, Result};
use std::collections::HashMap;
use std::sync::Arc;
/// Adapts a [`ViewSet`] into an HTTP [`Handler`], routing each request's HTTP
/// method to an action name via `action_map`.
///
/// Bookkeeping about handled requests (`args`, `kwargs`, `has_handled_request`)
/// lives behind `parking_lot::RwLock`s so that `handle(&self, ..)` can update it
/// without `&mut self`. These fields are shared by every call on the same handler
/// and hold whatever the most recent write stored.
pub struct ViewSetHandler<V: ViewSet> {
    // Viewset that receives dispatched actions.
    viewset: Arc<V>,
    // HTTP method -> action name. Methods missing from this map are answered
    // with 405 Method Not Allowed by `handle`.
    action_map: HashMap<Method, String>,
    // Stored by `new` but never read in this file, hence the lint allowance.
    // NOTE(review): presumably kept for route naming by callers — confirm.
    #[allow(dead_code)]
    name: Option<String>,
    // Stored by `new` but never read in this file, hence the lint allowance.
    #[allow(dead_code)]
    suffix: Option<String>,
    // Positional args recorded by `handle`; `None` until a request is handled
    // (`handle` currently always records an empty list).
    args: RwLock<Option<Vec<String>>>,
    // Path params recorded by `handle` via `extract_path_params`; `None` until
    // a request is handled.
    kwargs: RwLock<Option<HashMap<String, String>>>,
    // Set to `true` by `handle`; never reset in this file.
    has_handled_request: RwLock<bool>,
}
// `RefUnwindSafe` is a safe auto trait, so this impl needs no `unsafe`. It is
// written by hand because the `parking_lot::RwLock` fields presumably do not get
// it automatically (they wrap an `UnsafeCell`) — confirm against `lock_api`.
// parking_lot locks never poison, so state observed after a panic is whatever was
// last written; `handle` replaces each locked value with a single assignment.
impl<V> std::panic::RefUnwindSafe for ViewSetHandler<V> where
    V: ViewSet + std::panic::RefUnwindSafe
{
}
impl<V: ViewSet> ViewSetHandler<V> {
    /// Creates a handler that dispatches to `viewset` according to `action_map`.
    ///
    /// `name` and `suffix` are stored unchanged. Request bookkeeping starts
    /// empty: no args, no kwargs, and no request handled yet.
    pub fn new(
        viewset: Arc<V>,
        action_map: HashMap<Method, String>,
        name: Option<String>,
        suffix: Option<String>,
    ) -> Self {
        let (args, kwargs) = (RwLock::new(None), RwLock::new(None));
        let has_handled_request = RwLock::new(false);
        Self {
            viewset,
            action_map,
            name,
            suffix,
            args,
            kwargs,
            has_handled_request,
        }
    }

    /// Returns `true` once positional args have been recorded by `handle`.
    pub fn has_args(&self) -> bool {
        matches!(*self.args.read(), Some(_))
    }

    /// Returns `true` once path params (kwargs) have been recorded by `handle`.
    pub fn has_kwargs(&self) -> bool {
        matches!(*self.kwargs.read(), Some(_))
    }

    /// Returns `true` if `handle` has been called on this handler.
    pub fn has_request(&self) -> bool {
        let handled = self.has_handled_request.read();
        *handled
    }

    /// Returns `true` if at least one HTTP method is mapped to an action.
    pub fn has_action_map(&self) -> bool {
        !self.action_map.is_empty()
    }
}
#[async_trait]
impl<V: ViewSet + 'static> Handler for ViewSetHandler<V> {
async fn handle(&self, mut request: Request) -> Result<Response> {
*self.has_handled_request.write() = true;
*self.args.write() = Some(Vec::new());
let kwargs = extract_path_params(&request);
*self.kwargs.write() = Some(kwargs);
if let Some(middleware) = self.viewset.get_middleware()
&& let Some(response) = middleware.process_request(&mut request).await?
{
return Ok(response);
}
let action_name = match self.action_map.get(&request.method) {
Some(name) => name,
None => {
let allowed: Vec<String> = self.action_map.keys().map(|m| m.to_string()).collect();
let mut response = Response::new(hyper::StatusCode::METHOD_NOT_ALLOWED);
match allowed.join(", ").parse() {
Ok(header_value) => {
response.headers.insert(hyper::header::ALLOW, header_value);
}
Err(e) => {
tracing::warn!(
error = %e,
"Failed to parse allowed methods as header value"
);
}
}
return Ok(response);
}
};
let action = Action::from_name(action_name);
let response = self.viewset.dispatch(request, action).await?;
Ok(response)
}
}
/// Extracts path parameters from the request URI path.
///
/// Empty segments (from leading, trailing, or repeated `/`) are skipped. If the
/// path has at least two non-empty segments, the second one is returned under
/// the key `"id"`; for example, `/resource/123/` yields `{"id": "123"}`.
/// Otherwise the map is empty. Segments are copied verbatim, with no
/// percent-decoding.
///
/// NOTE(review): the position is fixed, so a prefixed route such as
/// `/api/resource/123/` would yield `id = "resource"`. Confirm this matches how
/// the router mounts viewsets.
pub(crate) fn extract_path_params(request: &Request) -> HashMap<String, String> {
    request
        .uri
        .path()
        .split('/')
        .filter(|segment| !segment.is_empty())
        .nth(1)
        .map(|id| ("id".to_string(), id.to_string()))
        .into_iter()
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method, Version};
    use reinhardt_http::Request;
    use rstest::rstest;
    use std::thread;

    /// Builds a bodyless HTTP/1.1 request with the given method and URI.
    fn build_request_with(method: Method, uri: &str) -> Request {
        Request::builder()
            .method(method)
            .uri(uri)
            .version(Version::HTTP_11)
            .headers(HeaderMap::new())
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    /// Builds a GET request for `uri`.
    fn build_request(uri: &str) -> Request {
        build_request_with(Method::GET, uri)
    }

    #[rstest]
    fn test_parking_lot_rwlock_does_not_poison_after_panic() {
        let lock = RwLock::new(42);
        let lock_ref = &lock;
        let result = thread::scope(|s| {
            let handle = s.spawn(|| {
                let mut guard = lock_ref.write();
                *guard = 100;
                panic!("intentional panic while holding write lock");
            });
            // Joining explicitly consumes the panic so `scope` does not re-raise it.
            assert!(handle.join().is_err(), "spawned thread must have panicked");
            *lock_ref.read()
        });
        // parking_lot releases the lock during unwinding without poisoning, so
        // the read succeeds and observes the write made before the panic.
        assert_eq!(result, 100);
    }

    #[rstest]
    fn test_rwlock_concurrent_read_access() {
        let lock = RwLock::new(String::from("test_value"));
        let guard1 = lock.read();
        let guard2 = lock.read();
        assert_eq!(*guard1, "test_value");
        assert_eq!(*guard2, "test_value");
    }

    #[rstest]
    fn test_extract_path_params_numeric_segment_treated_as_id() {
        let request = build_request("/resource/123/");
        let params = extract_path_params(&request);
        assert_eq!(params.get("id"), Some(&"123".to_string()));
    }

    #[rstest]
    fn test_extract_path_params_non_numeric_segment_treated_as_id() {
        let request = build_request("/resource/username/");
        let params = extract_path_params(&request);
        assert_eq!(params.get("id"), Some(&"username".to_string()));
    }

    #[rstest]
    fn test_extract_path_params_slug_segment_treated_as_id() {
        let request = build_request("/resource/my-slug/");
        let params = extract_path_params(&request);
        assert_eq!(params.get("id"), Some(&"my-slug".to_string()));
    }

    #[rstest]
    fn test_extract_path_params_uuid_segment_treated_as_id() {
        let request = build_request("/resource/550e8400-e29b-41d4-a716-446655440000/");
        let params = extract_path_params(&request);
        assert_eq!(
            params.get("id"),
            Some(&"550e8400-e29b-41d4-a716-446655440000".to_string())
        );
    }

    #[rstest]
    fn test_extract_path_params_single_segment_no_id() {
        let request = build_request("/resource/");
        let params = extract_path_params(&request);
        assert_eq!(params.get("id"), None);
    }

    /// Viewset stub that answers every dispatched action with `200 OK`.
    struct MockViewSet;

    #[async_trait]
    impl ViewSet for MockViewSet {
        fn get_basename(&self) -> &str {
            "mock"
        }

        async fn dispatch(
            &self,
            _request: reinhardt_http::Request,
            _action: crate::Action,
        ) -> reinhardt_http::Result<reinhardt_http::Response> {
            Ok(reinhardt_http::Response::ok())
        }
    }

    /// Builds a handler mapping each of `methods` to the same mock action.
    fn build_handler(methods: Vec<Method>) -> ViewSetHandler<MockViewSet> {
        let action_map: HashMap<Method, String> = methods
            .into_iter()
            .map(|method| (method, "mock_action".to_string()))
            .collect();
        ViewSetHandler::new(Arc::new(MockViewSet), action_map, None, None)
    }

    /// Builds a request for `/mock/` with the given method.
    fn build_method_request(method: Method) -> reinhardt_http::Request {
        build_request_with(method, "/mock/")
    }

    #[rstest]
    #[tokio::test]
    async fn test_unregistered_method_returns_405() {
        let handler = build_handler(vec![Method::GET]);
        let request = build_method_request(Method::DELETE);
        let response = Handler::handle(&handler, request).await.unwrap();
        assert_eq!(response.status, hyper::StatusCode::METHOD_NOT_ALLOWED);
    }

    #[rstest]
    #[tokio::test]
    async fn test_empty_action_map_returns_405() {
        let handler = build_handler(Vec::new());
        assert!(!handler.has_action_map());
        let request = build_method_request(Method::GET);
        let response = Handler::handle(&handler, request).await.unwrap();
        assert_eq!(response.status, hyper::StatusCode::METHOD_NOT_ALLOWED);
    }

    #[rstest]
    #[tokio::test]
    async fn test_405_response_allow_header_contains_registered_methods() {
        let handler = build_handler(vec![Method::GET, Method::POST]);
        let request = build_method_request(Method::DELETE);
        let response = Handler::handle(&handler, request).await.unwrap();
        assert_eq!(response.status, hyper::StatusCode::METHOD_NOT_ALLOWED);
        let allow_header = response
            .headers
            .get(hyper::header::ALLOW)
            .expect("Allow header must be present");
        let allow_str = allow_header.to_str().unwrap();
        assert!(allow_str.contains("GET"), "Allow header must contain GET");
        assert!(allow_str.contains("POST"), "Allow header must contain POST");
    }

    #[rstest]
    #[tokio::test]
    async fn test_405_response_allow_header_comma_separated_format() {
        let handler = build_handler(vec![Method::GET, Method::PUT]);
        let request = build_method_request(Method::PATCH);
        let response = Handler::handle(&handler, request).await.unwrap();
        assert_eq!(response.status, hyper::StatusCode::METHOD_NOT_ALLOWED);
        let allow_header = response
            .headers
            .get(hyper::header::ALLOW)
            .expect("Allow header must be present");
        let allow_str = allow_header.to_str().unwrap();
        let methods: Vec<&str> = allow_str.split(", ").collect();
        assert_eq!(
            methods.len(),
            2,
            "Allow header must contain exactly 2 methods"
        );
        for method in &methods {
            assert!(
                *method == "GET" || *method == "PUT",
                "Unexpected method in Allow header: {}",
                method
            );
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_registered_method_does_not_return_405() {
        let handler = build_handler(vec![Method::GET]);
        let request = build_method_request(Method::GET);
        let response = Handler::handle(&handler, request).await.unwrap();
        assert_eq!(response.status, hyper::StatusCode::OK);
    }

    #[rstest]
    #[tokio::test]
    async fn test_handle_records_request_state() {
        let handler = build_handler(vec![Method::GET]);
        assert!(!handler.has_request());
        assert!(!handler.has_args());
        assert!(!handler.has_kwargs());
        assert!(handler.has_action_map());

        let request = build_method_request(Method::GET);
        let _response = Handler::handle(&handler, request).await.unwrap();

        assert!(handler.has_request());
        assert!(handler.has_args());
        assert!(handler.has_kwargs());
    }
}