actix_request_hook/
lib.rs1use std::cell::RefCell;
40use std::collections::HashSet;
41use std::future::{ready, Future, Ready};
42use std::pin::Pin;
43use std::rc::Rc;
44use std::time::Instant;
45
46use actix_web::body::MessageBody;
47use actix_web::dev::{Service, ServiceRequest, ServiceResponse, Transform};
48use actix_web::web::{Buf, BytesMut};
49use actix_web::{Error, HttpMessage};
50use futures_util::task::{Context, Poll};
51use futures_util::StreamExt;
52use regex::RegexSet;
53use uuid::Uuid;
54
55use crate::observer::{Observer, RequestEndData, RequestStartData};
56use crate::util::get_payload;
57
58pub mod observer;
59mod tests;
60mod util;
61
62pub struct RequestHook(Rc<Inner>);
64
65impl Default for RequestHook {
66 fn default() -> Self {
67 Self::new()
68 }
69}
70
71impl RequestHook {
72 pub fn new() -> Self {
73 Self(Rc::new(Inner {
74 exclude: HashSet::new(),
75 exclude_regex: RegexSet::empty(),
76 observers: Vec::new(),
77 }))
78 }
79
80 pub fn exclude<T: Into<String>>(mut self, path: T) -> Self {
82 Rc::get_mut(&mut self.0)
83 .unwrap()
84 .exclude
85 .insert(path.into());
86 self
87 }
88
89 pub fn exclude_regex<T: Into<String>>(mut self, path: T) -> Self {
91 let inner = Rc::get_mut(&mut self.0).unwrap();
92 let mut patterns = inner.exclude_regex.patterns().to_vec();
93 patterns.push(path.into());
94 let regex_set = RegexSet::new(patterns).unwrap();
95 inner.exclude_regex = regex_set;
96 self
97 }
98
99 pub fn register<T: 'static + Observer>(mut self, observer: Rc<T>) -> Self {
101 Rc::get_mut(&mut self.0).unwrap().observers.push(observer);
102 self
103 }
104}
105
106#[derive(Clone)]
113struct Inner {
114 exclude: HashSet<String>,
115 exclude_regex: RegexSet,
116 observers: Vec<Rc<dyn Observer>>,
117}
118
119impl<S: 'static, B> Transform<S, ServiceRequest> for RequestHook
120where
121 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
122 B: MessageBody,
123{
124 type Response = S::Response;
125 type Error = Error;
126 type Transform = RequestHookMiddleware<S>;
127 type InitError = ();
128 type Future = Ready<Result<Self::Transform, Self::InitError>>;
129
130 fn new_transform(&self, service: S) -> Self::Future {
131 ready(Ok(RequestHookMiddleware {
132 service: Rc::new(RefCell::new(service)),
133 inner: self.0.clone(),
134 }))
135 }
136}
137
138pub struct RequestHookMiddleware<S> {
139 inner: Rc<Inner>,
140 service: Rc<RefCell<S>>,
141}
142
143impl<S: 'static, B> Service<ServiceRequest> for RequestHookMiddleware<S>
144where
145 B: MessageBody,
146 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
147{
148 type Response = ServiceResponse<B>;
149 type Error = Error;
150 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
151 fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
152 self.service.poll_ready(ctx)
153 }
154
155 fn call(&self, mut req: ServiceRequest) -> Self::Future {
156 let svc = self.service.clone();
157
158 let excluded = self.inner.exclude.contains(req.path())
159 || self.inner.exclude_regex.is_match(req.path());
160 if excluded {
161 return Box::pin(svc.call(req));
162 }
163
164 let observers = self.inner.observers.clone();
165
166 let start = Instant::now();
167 let request_id = Uuid::new_v4();
168 let uri = req.uri().to_string();
169 let method = req.method().to_string();
170
171 let future_response = async move {
172 let mut payload = req.take_payload();
173 let mut body = BytesMut::new();
174 while let Some(chunk) = payload.next().await {
175 body.extend_from_slice(chunk.unwrap().chunk())
176 }
177
178 let handler_body = body.clone();
179 let repacked_payload = get_payload(body.freeze());
180
181 for observer in &observers {
182 observer.on_request_started(RequestStartData {
183 req: &req,
184 request_id,
185 uri: uri.to_string(),
186 method: method.to_string(),
187 body: handler_body.clone(),
188 })
189 }
190
191 req.set_payload(repacked_payload);
192 let res: Result<ServiceResponse<B>, Error> = svc.call(req).await;
193
194 let elapsed = start.elapsed();
195
196 let (response, status) = match res {
197 Err(err) => {
198 let status = err.error_response().status();
199 (Err(err), status)
200 }
201 Ok(service_response) => {
202 let status = service_response.status();
203
204 (Ok(service_response), status)
205 }
206 };
207 for observer in &observers {
208 observer.on_request_ended(RequestEndData {
209 request_id,
210 elapsed,
211 uri: uri.to_string(),
212 method: method.to_string(),
213 status,
214 })
215 }
216
217 response
218 };
219
220 Box::pin(future_response)
221 }
222}