1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
//! Middleware types
//!
//! # Examples
//! ```no_run
//! use crux_http::middleware::{Next, Middleware};
//! use crux_http::{client::Client, Request, ResponseAsync, Result};
//! use std::time;
//! use std::sync::Arc;
//!
//! /// Log each request's duration
//! #[derive(Debug)]
//! pub struct Logger;
//!
//! #[async_trait::async_trait]
//! impl Middleware for Logger {
//!     async fn handle(
//!         &self,
//!         req: Request,
//!         client: Client,
//!         next: Next<'_>,
//!     ) -> Result<ResponseAsync> {
//!         println!("sending request to {}", req.url());
//!         let now = time::Instant::now();
//!         let res = next.run(req, client).await?;
//!         println!("request completed ({:?})", now.elapsed());
//!         Ok(res)
//!     }
//! }
//! ```
//! `Middleware` can also be written as a plain function: the trait is implemented for every
//! `Fn(Request, Client, Next<'a>) -> BoxFuture<'a, Result<ResponseAsync>>` that is also
//! `Send + Sync + 'static`.
//!
//! ```no_run
//! use futures_util::future::BoxFuture;
//! use crux_http::middleware::{Next, Middleware};
//! use crux_http::{client::Client, Request, ResponseAsync, Result};
//! use std::time;
//! use std::sync::Arc;
//!
//! fn logger<'a>(req: Request, client: Client, next: Next<'a>) -> BoxFuture<'a, Result<ResponseAsync>> {
//!     Box::pin(async move {
//!         println!("sending request to {}", req.url());
//!         let now = time::Instant::now();
//!         let res = next.run(req, client).await?;
//!         println!("request completed ({:?})", now.elapsed());
//!         Ok(res)
//!     })
//! }
//! ```

use std::sync::Arc;

use crate::{Client, Request, ResponseAsync, Result};

mod redirect;

pub use redirect::Redirect;

use async_trait::async_trait;
use futures_util::future::BoxFuture;

/// A layer in the HTTP middleware chain.
///
/// Each middleware receives the outgoing [`Request`], the [`Client`] it is being sent
/// through, and a [`Next`] handle to the rest of the chain. Calling [`Next::run`]
/// forwards the request onward, so an implementation can act on the request before
/// forwarding it and on the response afterwards (see the module-level example).
#[async_trait]
pub trait Middleware: Send + Sync + 'static {
    /// Handle `req`, normally by delegating to `next`, and yield the resulting response.
    async fn handle(&self, req: Request, client: Client, next: Next<'_>) -> Result<ResponseAsync>;
}

// Blanket impl: any function or closure with the right shape can be registered directly as
// middleware. The free-function example in the module docs relies on this.
#[async_trait]
impl<F> Middleware for F
where
    F: 'static
        + Send
        + Sync
        + for<'a> Fn(Request, Client, Next<'a>) -> BoxFuture<'a, Result<ResponseAsync>>,
{
    async fn handle(&self, req: Request, client: Client, next: Next<'_>) -> Result<ResponseAsync> {
        // The wrapped function already hands back a boxed future; simply drive it.
        let pending = (self)(req, client, next);
        pending.await
    }
}

/// The remainder of a middleware chain, including the endpoint.
///
/// A `Next` borrows the middleware that has not run yet, together with the terminal
/// endpoint that is invoked once no middleware is left.
pub struct Next<'a> {
    /// Middleware still to run, in order; the first element is the one that runs next.
    next_middleware: &'a [Arc<dyn Middleware>],
    /// Called with the request once `next_middleware` is exhausted.
    endpoint: &'a (dyn (Fn(Request, Client) -> BoxFuture<'static, Result<ResponseAsync>>)
             + Send
             + Sync
             + 'static),
}

// Neither `dyn Middleware` nor the endpoint closure implements `Debug`, so `Debug` cannot be
// derived. This hand-written impl reports how many middleware remain and marks the rest as
// elided. `std::fmt::Result` is spelled out because `Result` here is the crate's alias.
impl std::fmt::Debug for Next<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Next")
            .field("remaining_middleware", &self.next_middleware.len())
            .finish_non_exhaustive()
    }
}

// `Next` holds only shared references, so it is trivially copyable; `Clone` simply
// defers to that bitwise copy (the canonical form for a `Copy` type).
impl<'a> Copy for Next<'a> {}

impl<'a> Clone for Next<'a> {
    fn clone(&self) -> Next<'a> {
        *self
    }
}

impl<'a> Next<'a> {
    /// Create a new instance
    pub fn new(
        next: &'a [Arc<dyn Middleware>],
        endpoint: &'a (dyn (Fn(Request, Client) -> BoxFuture<'static, Result<ResponseAsync>>)
                 + Send
                 + Sync
                 + 'static),
    ) -> Self {
        Self {
            endpoint,
            next_middleware: next,
        }
    }

    /// Asynchronously execute the remaining middleware chain.
    pub fn run(mut self, req: Request, client: Client) -> BoxFuture<'a, Result<ResponseAsync>> {
        if let Some((current, next)) = self.next_middleware.split_first() {
            self.next_middleware = next;
            current.handle(req, client, self)
        } else {
            (self.endpoint)(req, client)
        }
    }
}