rustapi-core 0.1.450

The core engine of the RustAPI framework. Provides the hyper-based HTTP server, router, extraction logic, and foundational traits.
Documentation
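//! Request ID middleware
//!
//! Generates a unique UUID-formatted ID for each request, makes it available to handlers via
//! the `RequestId` extractor, and echoes it back to the client in the `x-request-id`
//! response header.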

use super::layer::{BoxedNext, MiddlewareLayer};
use crate::error::{ApiError, Result};
use crate::extract::FromRequestParts;
use crate::request::Request;
use crate::response::Response;
use std::future::Future;
use std::pin::Pin;

/// Opaque identifier attached to a single request.
///
/// Wraps the textual ID; values produced by [`RequestId::new`] are UUID-formatted strings.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RequestId(pub String);

impl RequestId {
    /// Build an identifier from a freshly generated UUID string.
    pub fn new() -> Self {
        let id = generate_uuid();
        RequestId(id)
    }

    /// Wrap an already-existing identifier string as-is (no validation is performed).
    pub fn from_string(id: String) -> Self {
        RequestId(id)
    }

    /// Borrow the identifier text.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}

impl Default for RequestId {
    fn default() -> Self {
        Self::new()
    }
}

impl std::fmt::Display for RequestId {
    /// Writes the raw identifier text with no surrounding decoration.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

/// Pulls the [`RequestId`] stored in the request extensions by `RequestIdLayer`.
///
/// # Errors
///
/// Fails with an internal `ApiError` when no `RequestId` is present in the extensions,
/// which happens when `RequestIdLayer` was not installed for the route.
///
/// # Example
///
/// ```rust,ignore
/// use rustapi_core::middleware::RequestId;
///
/// async fn handler(request_id: RequestId) -> impl IntoResponse {
///     format!("Request ID: {}", request_id)
/// }
/// ```
impl FromRequestParts for RequestId {
    fn from_request_parts(req: &Request) -> Result<Self> {
        match req.extensions().get::<RequestId>() {
            Some(found) => Ok(found.clone()),
            None => Err(ApiError::internal(
                "RequestId not found. Did you forget to add RequestIdLayer middleware?",
            )),
        }
    }
}

/// Middleware layer that attaches a freshly generated `RequestId` to every request.
///
/// The layer holds no configuration, so `new()` and `default()` are interchangeable.
#[derive(Clone, Default)]
pub struct RequestIdLayer;

impl RequestIdLayer {
    /// Construct the layer.
    pub fn new() -> Self {
        RequestIdLayer
    }
}

impl MiddlewareLayer for RequestIdLayer {
    /// Tag the request with a new `RequestId`, run the rest of the chain, then copy the ID
    /// into the response's `x-request-id` header.
    ///
    /// The ID is stored in the request extensions before `next` runs, so downstream
    /// handlers can read it through the `RequestId` extractor. Any `x-request-id` header
    /// already set on the response is replaced.
    fn call(
        &self,
        mut req: Request,
        next: BoxedNext,
    ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
        Box::pin(async move {
            let id = RequestId::new();
            req.extensions_mut().insert(id.clone());

            let mut response = next(req).await;

            // Header-value conversion is fallible in principle; skip the header rather than
            // failing the response if the ID text were ever not a valid header value.
            if let Ok(value) = id.as_str().parse() {
                response.headers_mut().insert("x-request-id", value);
            }

            response
        })
    }

    /// Box a copy of this (stateless) layer for storage in a layer stack.
    fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
        Box::new(self.clone())
    }
}

/// Generate a random-looking UUID string in RFC 4122 / RFC 9562 version-4 layout
/// (`xxxxxxxx-xxxx-4xxx-Vxxx-xxxxxxxxxxxx`, with `V` in `8`..=`b`), using only `std`.
///
/// All non-fixed bits come from two SipHash digests produced by a freshly constructed
/// [`RandomState`], which the standard library documents as initialised with random keys.
/// Each digest hashes the following values, so distinct calls always hash distinct inputs
/// even if the wall clock stalls or steps backwards:
///
/// - a lane tag (0 or 1),
/// - a process-wide counter,
/// - the current wall-clock time,
/// - the calling thread's id.
///
/// NOTE(review): on targets without an OS RNG the `RandomState` keys may be weak, in which
/// case uniqueness falls back to the counter/time/thread inputs. These IDs are for
/// correlation only and must not be used as secrets or security tokens.
fn generate_uuid() -> String {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hash, Hasher};
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    // Monotonic per-process sequence number; only uniqueness matters, so Relaxed suffices.
    static COUNTER: AtomicU64 = AtomicU64::new(0);

    let count = COUNTER.fetch_add(1, Ordering::Relaxed);
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos();
    let thread = std::thread::current().id();

    // One keyed hasher factory per call; two lanes yield 128 bits of digest output.
    let keys = RandomState::new();
    let digest = |lane: u8| -> u64 {
        let mut hasher = keys.build_hasher();
        lane.hash(&mut hasher);
        count.hash(&mut hasher);
        nanos.hash(&mut hasher);
        thread.hash(&mut hasher);
        hasher.finish()
    };

    let mut bytes = [0u8; 16];
    bytes[..8].copy_from_slice(&digest(0).to_be_bytes());
    bytes[8..].copy_from_slice(&digest(1).to_be_bytes());

    // Set version (4) and variant (RFC 4122, binary 10) bits. They overwrite only digest
    // bits, never a structured field, so no uniqueness-carrying input is lost.
    bytes[6] = (bytes[6] & 0x0f) | 0x40;
    bytes[8] = (bytes[8] & 0x3f) | 0x80;

    format!(
        "{:02x}{:02x}{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}{:02x}{:02x}{:02x}{:02x}",
        bytes[0], bytes[1], bytes[2], bytes[3],
        bytes[4], bytes[5],
        bytes[6], bytes[7],
        bytes[8], bytes[9],
        bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15]
    )
}

#[cfg(test)]
mod tests {
    // Unit and property tests for `RequestId` generation, the `RequestIdLayer` middleware,
    // and the `RequestId` extractor.
    use super::*;
    use crate::middleware::layer::{BoxedNext, LayerStack};
    use crate::path_params::PathParams;
    use bytes::Bytes;
    use http::{Extensions, Method, StatusCode};
    use proptest::prelude::*;
    use proptest::test_runner::TestCaseError;
    use std::collections::HashSet;
    use std::sync::Arc;

    /// Create a test request with the given method and path
    ///
    /// The request has an empty buffered body, a fresh (empty) `Extensions`, and no path
    /// parameters, so no `RequestId` is present until a `RequestIdLayer` inserts one.
    fn create_test_request(method: Method, path: &str) -> Request {
        let uri: http::Uri = path.parse().unwrap();
        let builder = http::Request::builder().method(method).uri(uri);

        let req = builder.body(()).unwrap();
        let (parts, _) = req.into_parts();

        Request::new(
            parts,
            crate::request::BodyVariant::Buffered(Bytes::new()),
            Arc::new(Extensions::new()),
            PathParams::new(),
        )
    }

    // Two consecutive IDs differ and both have the 36-character hyphenated UUID length.
    #[test]
    fn test_request_id_generation() {
        let id1 = RequestId::new();
        let id2 = RequestId::new();

        // IDs should be different
        assert_ne!(id1.0, id2.0);

        // IDs should be valid UUID format (36 chars with hyphens)
        assert_eq!(id1.0.len(), 36);
        assert_eq!(id2.0.len(), 36);
    }

    // `Display` prints the wrapped string verbatim.
    #[test]
    fn test_request_id_display() {
        let id = RequestId::from_string("test-id-123".to_string());
        assert_eq!(format!("{}", id), "test-id-123");
    }

    // **Feature: phase3-batteries-included, Property 3: Request ID uniqueness**
    //
    // For any set of N concurrent requests processed with RequestIdLayer enabled,
    // the System SHALL generate N distinct UUID values, each accessible via the
    // `RequestId` extractor.
    //
    // NOTE(review): despite the wording above, this test drives the N requests one after
    // another (each `execute` is awaited before the next starts), so it checks sequential
    // uniqueness only. It also reads the ID straight from `req.extensions()` rather than
    // through `RequestId::from_request_parts`; `test_request_id_extractor` covers that path.
    //
    // **Validates: Requirements 1.3**
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_request_id_uniqueness(
            num_requests in 1usize..100usize,
        ) {
            // A fresh runtime per proptest case; `block_on` returns the case's result.
            let rt = tokio::runtime::Runtime::new().unwrap();
            let result: Result<(), TestCaseError> = rt.block_on(async {
                let mut stack = LayerStack::new();
                stack.push(Box::new(RequestIdLayer::new()));

                // Collect all generated request IDs
                let collected_ids = Arc::new(std::sync::Mutex::new(Vec::new()));

                // Process multiple requests through the middleware
                for _ in 0..num_requests {
                    let ids = collected_ids.clone();

                    // Create a handler that extracts and stores the request ID
                    let handler: BoxedNext = Arc::new(move |req: Request| {
                        let ids = ids.clone();
                        Box::pin(async move {
                            // Extract the request ID from extensions
                            if let Some(request_id) = req.extensions().get::<RequestId>() {
                                ids.lock().unwrap().push(request_id.0.clone());
                            }

                            http::Response::builder()
                                .status(StatusCode::OK)
                                .body(crate::response::Body::from("ok"))
                                .unwrap()
                        }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
                    });

                    let request = create_test_request(Method::GET, "/test");
                    let _response = stack.execute(request, handler).await;
                }

                // Verify all IDs are unique
                // (the length check also proves every handler saw an ID in its extensions)
                let ids = collected_ids.lock().unwrap();
                prop_assert_eq!(ids.len(), num_requests, "Should have collected {} IDs", num_requests);

                let unique_ids: HashSet<_> = ids.iter().collect();
                prop_assert_eq!(
                    unique_ids.len(),
                    num_requests,
                    "All {} request IDs should be unique, but found {} unique IDs",
                    num_requests,
                    unique_ids.len()
                );

                // Verify all IDs are valid UUID format (36 chars with hyphens)
                for id in ids.iter() {
                    prop_assert_eq!(id.len(), 36, "Request ID should be 36 characters (UUID format)");
                    // Check UUID format: 8-4-4-4-12
                    let parts: Vec<&str> = id.split('-').collect();
                    prop_assert_eq!(parts.len(), 5, "UUID should have 5 parts separated by hyphens");
                    prop_assert_eq!(parts[0].len(), 8);
                    prop_assert_eq!(parts[1].len(), 4);
                    prop_assert_eq!(parts[2].len(), 4);
                    prop_assert_eq!(parts[3].len(), 4);
                    prop_assert_eq!(parts[4].len(), 12);
                }

                Ok(())
            });
            result?;
        }
    }

    // The `FromRequestParts` extractor returns the ID inserted by `RequestIdLayer`.
    #[test]
    fn test_request_id_extractor() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let mut stack = LayerStack::new();
            stack.push(Box::new(RequestIdLayer::new()));

            // Shared slot the handler writes the extracted ID into.
            let extracted_id = Arc::new(std::sync::Mutex::new(None));
            let extracted_id_clone = extracted_id.clone();

            let handler: BoxedNext = Arc::new(move |req: Request| {
                let extracted_id = extracted_id_clone.clone();
                Box::pin(async move {
                    // Use the FromRequestParts implementation
                    if let Ok(request_id) = RequestId::from_request_parts(&req) {
                        *extracted_id.lock().unwrap() = Some(request_id.0.clone());
                    }

                    http::Response::builder()
                        .status(StatusCode::OK)
                        .body(crate::response::Body::from("ok"))
                        .unwrap()
                }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
            });

            let request = create_test_request(Method::GET, "/test");
            let _response = stack.execute(request, handler).await;

            // Verify the request ID was extracted
            let id = extracted_id.lock().unwrap();
            assert!(id.is_some(), "Request ID should have been extracted");
            assert_eq!(
                id.as_ref().unwrap().len(),
                36,
                "Request ID should be UUID format"
            );
        });
    }

    #[test]
    fn test_request_id_extractor_without_middleware() {
        // Test that extractor returns error when middleware is not applied
        let request = create_test_request(Method::GET, "/test");
        let result = RequestId::from_request_parts(&request);
        assert!(
            result.is_err(),
            "Should return error when RequestIdLayer is not applied"
        );
    }
}