use super::{FromRequest, IntoResponse, ResponseError, StatusCode};
use crate::router::CoapumRequest;
use async_trait::async_trait;
use coap_lite::ObserveOption;
use std::{fmt, net::SocketAddr};
/// Extractor yielding the identity string attached to the incoming request.
///
/// The value is copied from `CoapumRequest::identity` during extraction. The
/// inner `String` is reachable directly through `Deref`/`DerefMut`, or via the
/// public `.0` field.
#[derive(Debug, Clone)]
pub struct Identity(pub String);

impl std::ops::Deref for Identity {
    type Target = String;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl std::ops::DerefMut for Identity {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}
#[async_trait]
impl<S> FromRequest<S> for Identity {
    // Copying a string out of the request cannot fail.
    type Rejection = std::convert::Infallible;

    /// Produces an `Identity` holding an owned copy of `req.identity`.
    /// The router state is not consulted.
    async fn from_request(
        req: &CoapumRequest<SocketAddr>,
        _state: &S,
    ) -> Result<Self, Self::Rejection> {
        let identity = req.identity.clone();
        Ok(Identity(identity))
    }
}
/// Extractor yielding the socket address the request was received from.
///
/// When the request carries no source address, extraction yields the
/// unspecified address `0.0.0.0:0`. `SocketAddr` accessors such as `port()`
/// and `ip()` are available through `Deref`.
#[derive(Debug, Clone, Copy)]
pub struct Source(pub SocketAddr);

impl std::ops::Deref for Source {
    type Target = SocketAddr;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
#[async_trait]
impl<S> FromRequest<S> for Source {
    type Rejection = std::convert::Infallible;

    /// Returns the request's source address, or the unspecified address
    /// `0.0.0.0:0` when `req.source` is `None`. This extraction never fails.
    ///
    /// NOTE(review): handlers cannot tell a missing source apart from a real
    /// `0.0.0.0:0` peer; consider exposing the `Option` if that matters.
    async fn from_request(
        req: &CoapumRequest<SocketAddr>,
        _state: &S,
    ) -> Result<Self, Self::Rejection> {
        // Build the fallback directly rather than parsing a string literal,
        // which avoids a runtime parse and an `unwrap` panic path.
        let addr = req
            .source
            .unwrap_or_else(|| SocketAddr::from(([0, 0, 0, 0], 0)));
        Ok(Source(addr))
    }
}
pub struct ObserveFlag(pub Option<ObserveOption>);
impl fmt::Debug for ObserveFlag {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("ObserveFlag").field(&self.0).finish()
}
}
impl Clone for ObserveFlag {
fn clone(&self) -> Self {
*self
}
}
impl Copy for ObserveFlag {}
impl std::ops::Deref for ObserveFlag {
type Target = Option<ObserveOption>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
#[async_trait]
impl<S> FromRequest<S> for ObserveFlag {
type Rejection = std::convert::Infallible;
async fn from_request(
req: &CoapumRequest<SocketAddr>,
_state: &S,
) -> Result<Self, Self::Rejection> {
Ok(ObserveFlag(*req.get_observe_flag()))
}
}
/// Extractor yielding a clone of a `T` taken from the router's shared state.
///
/// Extraction requires the router state type `S` to implement `AsRef<T>`.
/// The inner value is reachable through `Deref`/`DerefMut`, or via `.0`.
// `derive` adds the same `T: Debug` / `T: Clone` bounds the manual impls had.
#[derive(Debug, Clone)]
pub struct State<T>(pub T);

impl<T> std::ops::Deref for State<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl<T> std::ops::DerefMut for State<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}
/// Rejection type for the `State` extractor.
///
/// Its `Display` output is `State extraction failed: <message>`.
///
/// NOTE(review): `message` is private and nothing in this file constructs a
/// `StateRejection`; the `State` extractor as written always succeeds.
#[derive(Debug)]
pub struct StateRejection {
    message: String,
}

impl fmt::Display for StateRejection {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let StateRejection { message } = self;
        write!(f, "State extraction failed: {}", message)
    }
}

// Relies on the default `source()`: there is no underlying cause to expose.
impl std::error::Error for StateRejection {}
impl IntoResponse for StateRejection {
    /// Converts any state-extraction failure into an Internal Server Error
    /// response. The rejection's message is not included in the response.
    fn into_response(self) -> Result<crate::CoapResponse, ResponseError> {
        let status = StatusCode::InternalServerError;
        status.into_response()
    }
}
#[async_trait]
impl<T, S> FromRequest<S> for State<T>
where
    T: Clone + Send + Sync + 'static,
    S: AsRef<T> + Send + Sync,
{
    type Rejection = StateRejection;

    /// Borrows the `T` that `state` exposes via `AsRef<T>` and returns a
    /// clone of it. The request itself is ignored.
    ///
    /// As written this always returns `Ok`.
    async fn from_request(
        _req: &CoapumRequest<SocketAddr>,
        state: &S,
    ) -> Result<Self, Self::Rejection> {
        let inner: &T = state.as_ref();
        Ok(State(T::clone(inner)))
    }
}
/// Extractor yielding an owned clone of the entire incoming request.
///
/// Extraction clones the whole `CoapumRequest`. The request's accessors are
/// reachable through `Deref`/`DerefMut`, or via `.0`.
#[derive(Clone)]
pub struct FullRequest(pub CoapumRequest<SocketAddr>);

impl fmt::Debug for FullRequest {
    /// Prints only the request path, as `FullRequest("CoapumRequest(<path>)")`,
    /// rather than the full request contents.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_tuple("FullRequest")
            .field(&format!("CoapumRequest({})", self.0.get_path()))
            .finish()
    }
}

impl std::ops::Deref for FullRequest {
    type Target = CoapumRequest<SocketAddr>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl std::ops::DerefMut for FullRequest {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}
#[async_trait]
impl<S> FromRequest<S> for FullRequest {
type Rejection = std::convert::Infallible;
async fn from_request(
req: &CoapumRequest<SocketAddr>,
_state: &S,
) -> Result<Self, Self::Rejection> {
Ok(FullRequest(req.clone()))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{CoapRequest, Packet};
    use coap_lite::RequestType;
    use std::net::{Ipv4Addr, SocketAddrV4};

    /// Builds a `GET test` request from 127.0.0.1:8080 whose identity is
    /// `test_client`.
    fn create_test_request() -> CoapumRequest<SocketAddr> {
        let mut request = CoapRequest::from_packet(
            Packet::new(),
            SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 8080)),
        );
        request.set_method(RequestType::Get);
        request.set_path("test");
        let mut coap_request: CoapumRequest<SocketAddr> = request.into();
        coap_request.identity = "test_client".to_string();
        coap_request
    }

    #[tokio::test]
    async fn test_identity_extraction() {
        let req = create_test_request();
        let result = Identity::from_request(&req, &()).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().0, "test_client");
    }

    #[tokio::test]
    async fn test_source_extraction() {
        let req = create_test_request();
        let result = Source::from_request(&req, &()).await;
        assert!(result.is_ok());
        let source = result.unwrap();
        assert_eq!(source.port(), 8080);
    }

    /// Without a recorded source, the extractor falls back to `0.0.0.0:0`.
    #[tokio::test]
    async fn test_source_extraction_falls_back_when_missing() {
        let mut req = create_test_request();
        req.source = None;
        let source = Source::from_request(&req, &()).await.unwrap();
        assert_eq!(
            *source,
            SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))
        );
    }

    #[tokio::test]
    async fn test_observe_flag_extraction() {
        let req = create_test_request();
        let result = ObserveFlag::from_request(&req, &()).await;
        assert!(result.is_ok());
        let observe_flag = result.unwrap();
        assert!(observe_flag.is_none());
    }

    #[tokio::test]
    async fn test_state_extraction() {
        #[derive(Clone, Debug, PartialEq)]
        struct TestState {
            value: i32,
        }
        impl AsRef<TestState> for TestState {
            fn as_ref(&self) -> &TestState {
                self
            }
        }
        let req = create_test_request();
        let state = TestState { value: 42 };
        let result = State::<TestState>::from_request(&req, &state).await;
        assert!(result.is_ok());
        let extracted_state = result.unwrap();
        assert_eq!(extracted_state.value, 42);
    }

    /// `State<T>` can pull a sub-component out of a larger app state via `AsRef`.
    #[tokio::test]
    async fn test_state_extraction_from_composite_state() {
        #[derive(Clone, Debug, PartialEq)]
        struct Config {
            name: &'static str,
        }
        struct AppState {
            config: Config,
        }
        impl AsRef<Config> for AppState {
            fn as_ref(&self) -> &Config {
                &self.config
            }
        }
        let req = create_test_request();
        let app = AppState {
            config: Config { name: "demo" },
        };
        let State(config) = State::<Config>::from_request(&req, &app).await.unwrap();
        assert_eq!(config, Config { name: "demo" });
    }

    #[tokio::test]
    async fn test_full_request_extraction() {
        let req = create_test_request();
        let result = FullRequest::from_request(&req, &()).await;
        assert!(result.is_ok());
        let full_request = result.unwrap();
        assert_eq!(full_request.get_path(), "test");
        assert_eq!(*full_request.get_method(), RequestType::Get);
        assert_eq!(full_request.identity, "test_client");
    }
}