at-jet 0.7.2

High-performance HTTP + Protobuf API framework for mobile services
Documentation
//! Request extractors for AT-Jet
//!
//! Provides Protobuf-aware request body extractors for axum handlers.

use {crate::{content_types::APPLICATION_PROTOBUF,
             error::JetError},
     axum::{async_trait,
            body::Bytes,
            extract::{FromRequest,
                      Request},
            http::header::CONTENT_TYPE},
     prost::Message,
     std::marker::PhantomData};

/// Upper bound, in bytes, on request bodies accepted by `ProtobufRequest` and
/// `OptionalProtobufRequest` (10 MiB = 10 × 2^20 bytes).
const MAX_BODY_SIZE: usize = 10 << 20;

/// Protobuf request body extractor
///
/// Extracts and decodes a Protobuf message from the request body.
///
/// The wrapped message is available as field `.0` or by destructuring the
/// tuple struct. Extraction (the `FromRequest` implementation) requires
/// `T: Message + Default`; the struct itself places no bounds on `T`.
///
/// # Example
///
/// ```rust,ignore
/// use at_jet::prelude::*;
///
/// async fn create_user(
///     ProtobufRequest(request): ProtobufRequest<CreateUserRequest>
/// ) -> ProtobufResponse<User> {
///     // request is already decoded
///     let user = User {
///         id: 1,
///         name: request.name,
///     };
///     ProtobufResponse::ok(user)
/// }
/// ```
pub struct ProtobufRequest<T>(pub T);

#[async_trait]
impl<S, T> FromRequest<S> for ProtobufRequest<T>
where
  S: Send + Sync,
  T: Message + Default,
{
  type Rejection = JetError;

  async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
    // Check content type
    let content_type = req
      .headers()
      .get(CONTENT_TYPE)
      .and_then(|v| v.to_str().ok())
      .unwrap_or("");

    if !content_type.starts_with(APPLICATION_PROTOBUF) {
      return Err(JetError::InvalidContentType {
        expected: APPLICATION_PROTOBUF.to_string(),
        actual:   content_type.to_string(),
      });
    }

    // Extract body
    let bytes = Bytes::from_request(req, state)
      .await
      .map_err(|e| JetError::BadRequest(format!("Failed to read body: {}", e)))?;

    // Check size
    if bytes.len() > MAX_BODY_SIZE {
      return Err(JetError::BodyTooLarge {
        size: bytes.len(),
        max:  MAX_BODY_SIZE,
      });
    }

    // Decode protobuf
    let message = T::decode(bytes)?;

    Ok(ProtobufRequest(message))
  }
}

/// Optional Protobuf request body extractor
///
/// Like `ProtobufRequest`, but yields `None` when the request body is empty
/// instead of attempting to decode it. Unlike `ProtobufRequest`, the
/// `Content-Type` header is not checked.
///
/// Extraction (the `FromRequest` implementation) requires
/// `T: Message + Default`; the struct itself places no bounds on `T`.
pub struct OptionalProtobufRequest<T>(pub Option<T>);

#[async_trait]
impl<S, T> FromRequest<S> for OptionalProtobufRequest<T>
where
  S: Send + Sync,
  T: Message + Default,
{
  type Rejection = JetError;

  async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
    let bytes = Bytes::from_request(req, state)
      .await
      .map_err(|e| JetError::BadRequest(format!("Failed to read body: {}", e)))?;

    if bytes.is_empty() {
      return Ok(OptionalProtobufRequest(None));
    }

    if bytes.len() > MAX_BODY_SIZE {
      return Err(JetError::BodyTooLarge {
        size: bytes.len(),
        max:  MAX_BODY_SIZE,
      });
    }

    let message = T::decode(bytes)?;
    Ok(OptionalProtobufRequest(Some(message)))
  }
}

/// Protobuf request body extractor with a compile-time maximum body size.
///
/// `LIMIT` is the largest accepted body, in bytes. The decoded message is
/// available as field `.0`; the second field is a private marker.
///
/// Extraction (the `FromRequest` implementation) requires
/// `T: Message + Default`; the struct itself places no bounds on `T`.
pub struct ProtobufRequestWithLimit<T, const LIMIT: usize>(pub T, PhantomData<T>);

#[async_trait]
impl<S, T, const LIMIT: usize> FromRequest<S> for ProtobufRequestWithLimit<T, LIMIT>
where
  S: Send + Sync,
  T: Message + Default,
{
  type Rejection = JetError;

  async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
    let bytes = Bytes::from_request(req, state)
      .await
      .map_err(|e| JetError::BadRequest(format!("Failed to read body: {}", e)))?;

    if bytes.len() > LIMIT {
      return Err(JetError::BodyTooLarge {
        size: bytes.len(),
        max:  LIMIT,
      });
    }

    let message = T::decode(bytes)?;
    Ok(ProtobufRequestWithLimit(message, PhantomData))
  }
}