actix_request_hook/
lib.rs

//! Actix Web middleware hooks for request start and end. Register observers to be notified when a request starts and ends, with access to the request data, request ID, elapsed time and response status.
2//!
3//! Setup:
4//! ```
5//! use std::rc::Rc;
6//! use actix_web::{App, Error, HttpServer, web};
7//! use actix_request_hook::observer::{Observer, RequestEndData, RequestStartData};
8//! use actix_request_hook::RequestHook;
9//! struct RequestLogger;
10//!
11//! impl Observer for RequestLogger {
12//!     fn on_request_started(&self, data: RequestStartData) {
13//!         println!("started {}", data.uri)
14//!     }
15//!
16//!     fn on_request_ended(&self, data: RequestEndData) {
17//!         println!("ended {} after {}ms", data.uri, data.elapsed.as_millis())
18//!     }
19//! }
20//!
21//! async fn index() -> Result<String, Error> {
22//!     Ok("Hi there!".to_string())
23//! }
24//!
25//!
//! // You can register any number of observers, e.g. one for logging and
//! // another for notifying other parts of the system.
28//! let request_hook = RequestHook::new()
29//!            .exclude("/bye") // bye route shouldn't be logged
//!            .exclude_regex("^/\\d+$") // excludes any numbered route like "/123456"
31//!            .register(Rc::new(RequestLogger{}));
32//! App::new()
33//!     .wrap(request_hook)
34//!     .route("/bye", web::get().to(index))
35//!     .route("/hey", web::get().to(index));
36//!
37//! ```
38//!
39use std::cell::RefCell;
40use std::collections::HashSet;
41use std::future::{ready, Future, Ready};
42use std::pin::Pin;
43use std::rc::Rc;
44use std::time::Instant;
45
46use actix_web::body::MessageBody;
47use actix_web::dev::{Service, ServiceRequest, ServiceResponse, Transform};
48use actix_web::web::{Buf, BytesMut};
49use actix_web::{Error, HttpMessage};
50use futures_util::task::{Context, Poll};
51use futures_util::StreamExt;
52use regex::RegexSet;
53use uuid::Uuid;
54
55use crate::observer::{Observer, RequestEndData, RequestStartData};
56use crate::util::get_payload;
57
58pub mod observer;
59mod tests;
60mod util;
61
62/// Middleware for subscribing to request start and end. Enables access to request data, id, status and request duration.
63pub struct RequestHook(Rc<Inner>);
64
65impl Default for RequestHook {
66    fn default() -> Self {
67        Self::new()
68    }
69}
70
71impl RequestHook {
72    pub fn new() -> Self {
73        Self(Rc::new(Inner {
74            exclude: HashSet::new(),
75            exclude_regex: RegexSet::empty(),
76            observers: Vec::new(),
77        }))
78    }
79
80    /// Ignore and do not log access info for specified path.
81    pub fn exclude<T: Into<String>>(mut self, path: T) -> Self {
82        Rc::get_mut(&mut self.0)
83            .unwrap()
84            .exclude
85            .insert(path.into());
86        self
87    }
88
89    /// Ignore and do not log access info for paths that match regex.
90    pub fn exclude_regex<T: Into<String>>(mut self, path: T) -> Self {
91        let inner = Rc::get_mut(&mut self.0).unwrap();
92        let mut patterns = inner.exclude_regex.patterns().to_vec();
93        patterns.push(path.into());
94        let regex_set = RegexSet::new(patterns).unwrap();
95        inner.exclude_regex = regex_set;
96        self
97    }
98
99    /// Registers an [Observer].
100    pub fn register<T: 'static + Observer>(mut self, observer: Rc<T>) -> Self {
101        Rc::get_mut(&mut self.0).unwrap().observers.push(observer);
102        self
103    }
104}
105
106/// Contains configuration for [RequestHook]
107///
108/// # Properties
109/// * `exclude` - excluded path is ignored.
110/// * `exclude_regex` - same as `exclude`, just uses regex instead of exact match.
111/// * `observers` - a list of observers for actix request.
112#[derive(Clone)]
113struct Inner {
114    exclude: HashSet<String>,
115    exclude_regex: RegexSet,
116    observers: Vec<Rc<dyn Observer>>,
117}
118
119impl<S: 'static, B> Transform<S, ServiceRequest> for RequestHook
120where
121    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
122    B: MessageBody,
123{
124    type Response = S::Response;
125    type Error = Error;
126    type Transform = RequestHookMiddleware<S>;
127    type InitError = ();
128    type Future = Ready<Result<Self::Transform, Self::InitError>>;
129
130    fn new_transform(&self, service: S) -> Self::Future {
131        ready(Ok(RequestHookMiddleware {
132            service: Rc::new(RefCell::new(service)),
133            inner: self.0.clone(),
134        }))
135    }
136}
137
138pub struct RequestHookMiddleware<S> {
139    inner: Rc<Inner>,
140    service: Rc<RefCell<S>>,
141}
142
143impl<S: 'static, B> Service<ServiceRequest> for RequestHookMiddleware<S>
144where
145    B: MessageBody,
146    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
147{
148    type Response = ServiceResponse<B>;
149    type Error = Error;
150    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
151    fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
152        self.service.poll_ready(ctx)
153    }
154
155    fn call(&self, mut req: ServiceRequest) -> Self::Future {
156        let svc = self.service.clone();
157
158        let excluded = self.inner.exclude.contains(req.path())
159            || self.inner.exclude_regex.is_match(req.path());
160        if excluded {
161            return Box::pin(svc.call(req));
162        }
163
164        let observers = self.inner.observers.clone();
165
166        let start = Instant::now();
167        let request_id = Uuid::new_v4();
168        let uri = req.uri().to_string();
169        let method = req.method().to_string();
170
171        let future_response = async move {
172            let mut payload = req.take_payload();
173            let mut body = BytesMut::new();
174            while let Some(chunk) = payload.next().await {
175                body.extend_from_slice(chunk.unwrap().chunk())
176            }
177
178            let handler_body = body.clone();
179            let repacked_payload = get_payload(body.freeze());
180
181            for observer in &observers {
182                observer.on_request_started(RequestStartData {
183                    req: &req,
184                    request_id,
185                    uri: uri.to_string(),
186                    method: method.to_string(),
187                    body: handler_body.clone(),
188                })
189            }
190
191            req.set_payload(repacked_payload);
192            let res: Result<ServiceResponse<B>, Error> = svc.call(req).await;
193
194            let elapsed = start.elapsed();
195
196            let (response, status) = match res {
197                Err(err) => {
198                    let status = err.error_response().status();
199                    (Err(err), status)
200                }
201                Ok(service_response) => {
202                    let status = service_response.status();
203
204                    (Ok(service_response), status)
205                }
206            };
207            for observer in &observers {
208                observer.on_request_ended(RequestEndData {
209                    request_id,
210                    elapsed,
211                    uri: uri.to_string(),
212                    method: method.to_string(),
213                    status,
214                })
215            }
216
217            response
218        };
219
220        Box::pin(future_response)
221    }
222}