lychee_lib/chain/mod.rs
1//! [Chain of responsibility pattern][pattern] implementation.
2//!
3//! lychee is based on a chain of responsibility, where each handler can modify
4//! a request and decide if it should be passed to the next element or not.
5//!
//! The chain is implemented as a vector of [`Handler`] trait objects. It is
//! traversed by the crate-internal `Chain::traverse`, which calls
//! [`Handler::handle`] on each handler in the chain, in order.
9//!
10//! To add external handlers, you can implement the [`Handler`] trait and add
11//! the handler to the chain.
12//!
13//! [pattern]: https://github.com/lpxxn/rust-design-pattern/blob/master/behavioral/chain_of_responsibility.rs
14use crate::Status;
15use async_trait::async_trait;
16use core::fmt::Debug;
17use std::sync::Arc;
18use tokio::sync::Mutex;
19
/// Outcome of a handler invocation.
///
/// The variant tells the chain whether traversal should go on:
///
/// - [`ChainResult::Next`] hands the (possibly modified) value on to the next
///   handler in the chain.
/// - [`ChainResult::Done`] ends the traversal right away: every remaining
///   handler is skipped and the contained result is returned.
#[derive(PartialEq, Debug)]
pub enum ChainResult<T, R> {
    /// Keep traversing: pass the contained value to the next handler.
    Next(T),
    /// Stop traversing and return the contained result.
    Done(R),
}
35}
36
37/// Request chain type
38///
39/// Lychee uses a chain of responsibility pattern to handle requests.
40/// Each handler in the chain can modify the request and decide if it should be
41/// passed to the next handler or not.
42///
43/// The chain is implemented as a vector of handlers. It is traversed by calling
44/// `traverse` on the [`Chain`], which in turn will call [`Handler::handle`] on
45/// each handler in the chain consecutively.
46///
47/// To add external handlers, you can implement the [`Handler`] trait and add your
48/// handler to the chain.
49///
50/// The entire request chain takes a request as input and returns a status.
51///
52/// # Example
53///
54/// ```rust
55/// use async_trait::async_trait;
56/// use lychee_lib::{chain::RequestChain, ChainResult, ClientBuilder, Handler, Result, Status};
57/// use reqwest::{Method, Request, Url};
58///
59/// // Define your own custom handler
60/// #[derive(Debug)]
61/// struct DummyHandler {}
62///
63/// #[async_trait]
64/// impl Handler<Request, Status> for DummyHandler {
65/// async fn handle(&mut self, mut request: Request) -> ChainResult<Request, Status> {
66/// // Modify the request here
67/// // After that, continue to the next handler
68/// ChainResult::Next(request)
69/// }
70/// }
71///
72/// #[tokio::main]
73/// async fn main() -> Result<()> {
74/// // Build a custom request chain with our dummy handler
75/// let chain = RequestChain::new(vec![Box::new(DummyHandler {})]);
76///
77/// let client = ClientBuilder::builder()
78/// .plugin_request_chain(chain)
79/// .build()
80/// .client()?;
81///
82/// let result = client.check("https://wikipedia.org").await;
83/// println!("{:?}", result);
84///
85/// Ok(())
86/// }
87/// ```
88pub type RequestChain = Chain<reqwest::Request, Status>;
89
90/// Inner chain type.
91///
92/// This holds all handlers, which were chained together.
93/// Handlers are traversed in order.
94///
95/// Each handler needs to implement the `Handler` trait and be `Send`, because
96/// the chain is traversed concurrently and the handlers can be sent between
97/// threads.
98pub(crate) type InnerChain<T, R> = Vec<Box<dyn Handler<T, R> + Send>>;
99
100/// The outer chain type.
101///
102/// This is a wrapper around the inner chain type and allows for
103/// concurrent access to the chain.
104#[derive(Debug)]
105pub struct Chain<T, R>(Arc<Mutex<InnerChain<T, R>>>);
106
107impl<T, R> Default for Chain<T, R> {
108 fn default() -> Self {
109 Self(Arc::new(Mutex::new(InnerChain::default())))
110 }
111}
112
113impl<T, R> Clone for Chain<T, R> {
114 fn clone(&self) -> Self {
115 // Cloning the chain is a cheap operation, because the inner chain is
116 // wrapped in an `Arc` and `Mutex`.
117 Self(self.0.clone())
118 }
119}
120
121impl<T, R> Chain<T, R> {
122 /// Create a new chain from a vector of chainable handlers
123 #[must_use]
124 pub fn new(values: InnerChain<T, R>) -> Self {
125 Self(Arc::new(Mutex::new(values)))
126 }
127
128 /// Traverse the chain with the given input.
129 ///
130 /// This will call `chain` on each handler in the chain and return
131 /// the result. If a handler returns `ChainResult::Done`, the chain
132 /// will stop and return.
133 ///
134 /// If no handler returns `ChainResult::Done`, the chain will return
135 /// `ChainResult::Next` with the input.
136 pub(crate) async fn traverse(&self, mut input: T) -> ChainResult<T, R> {
137 use ChainResult::{Done, Next};
138 for e in self.0.lock().await.iter_mut() {
139 match e.handle(input).await {
140 Next(r) => input = r,
141 Done(r) => {
142 return Done(r);
143 }
144 }
145 }
146
147 Next(input)
148 }
149}
150
151/// Handler trait for implementing request handlers
152///
153/// This trait needs to be implemented by all chainable handlers.
154/// It is the only requirement to handle requests in lychee.
155///
156/// It takes an input request and returns a [`ChainResult`], which can be either
157/// [`ChainResult::Next`] to continue to the next handler or
158/// [`ChainResult::Done`] to stop the chain.
159///
160/// The request can be modified by the handler before it is passed to the next
161/// handler. This allows for modifying the request, such as adding headers or
162/// changing the URL (e.g. for remapping or filtering).
163#[async_trait]
164pub trait Handler<T, R>: Debug {
165 /// Given an input request, return a [`ChainResult`] to continue or stop the
166 /// chain.
167 ///
168 /// The input request can be modified by the handler before it is passed to
169 /// the next handler.
170 ///
171 /// # Example
172 ///
173 /// ```
174 /// use lychee_lib::{Handler, ChainResult, Status};
175 /// use reqwest::Request;
176 /// use async_trait::async_trait;
177 ///
178 /// #[derive(Debug)]
179 /// struct AddHeader;
180 ///
181 /// #[async_trait]
182 /// impl Handler<Request, Status> for AddHeader {
183 /// async fn handle(&mut self, mut request: Request) -> ChainResult<Request, Status> {
184 /// // You can modify the request however you like here
185 /// request.headers_mut().append("X-Header", "value".parse().unwrap());
186 ///
187 /// // Pass the request to the next handler
188 /// ChainResult::Next(request)
189 /// }
190 /// }
191 /// ```
192 async fn handle(&mut self, input: T) -> ChainResult<T, R>;
193}
194
195/// Client request chains
196///
197/// This struct holds all request chains.
198///
199/// Usually, this is used to hold the default request chain and the external
200/// plugin request chain.
201#[derive(Debug)]
202pub(crate) struct ClientRequestChains<'a> {
203 chains: Vec<&'a RequestChain>,
204}
205
206impl<'a> ClientRequestChains<'a> {
207 /// Create a new chain of request chains.
208 pub(crate) const fn new(chains: Vec<&'a RequestChain>) -> Self {
209 Self { chains }
210 }
211
212 /// Traverse all request chains and resolve to a status.
213 pub(crate) async fn traverse(&self, mut input: reqwest::Request) -> Status {
214 use ChainResult::{Done, Next};
215
216 for e in &self.chains {
217 match e.traverse(input).await {
218 Next(r) => input = r,
219 Done(r) => {
220 return r;
221 }
222 }
223 }
224
225 // Consider the request to be excluded if no chain element has converted
226 // it to a `ChainResult::Done`
227 Status::Excluded
228 }
229}
230
231mod test {
232 use super::{
233 ChainResult,
234 ChainResult::{Done, Next},
235 Handler,
236 };
237 use async_trait::async_trait;
238
239 #[allow(dead_code)] // work-around
240 #[derive(Debug)]
241 struct Add(usize);
242
243 #[derive(Debug, PartialEq, Eq)]
244 struct Result(usize);
245
246 #[async_trait]
247 impl Handler<Result, Result> for Add {
248 async fn handle(&mut self, req: Result) -> ChainResult<Result, Result> {
249 let added = req.0 + self.0;
250 if added > 100 {
251 Done(Result(req.0))
252 } else {
253 Next(Result(added))
254 }
255 }
256 }
257
258 #[tokio::test]
259 async fn simple_chain() {
260 use super::Chain;
261 let chain: Chain<Result, Result> = Chain::new(vec![Box::new(Add(7)), Box::new(Add(3))]);
262 let result = chain.traverse(Result(0)).await;
263 assert_eq!(result, Next(Result(10)));
264 }
265
266 #[tokio::test]
267 async fn early_exit_chain() {
268 use super::Chain;
269 let chain: Chain<Result, Result> =
270 Chain::new(vec![Box::new(Add(80)), Box::new(Add(30)), Box::new(Add(1))]);
271 let result = chain.traverse(Result(0)).await;
272 assert_eq!(result, Done(Result(80)));
273 }
274}