mwc_api/
router.rs

1// Copyright 2019 The Grin Developers
2// Copyright 2024 The MWC Developers
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8//     http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use futures::future::{self, Future};
17use hyper::service::Service;
18use hyper::{Body, Method, Request, Response, StatusCode};
19use std::collections::hash_map::DefaultHasher;
20use std::hash::{Hash, Hasher};
21use std::pin::Pin;
22use std::sync::Arc;
23use std::task::{Context, Poll};
24
lazy_static! {
	/// Route key of a `*` segment. While walking children, a node with this key matches
	/// any single path segment.
	static ref WILDCARD_HASH: u64 = calculate_hash(&"*");
	/// Route key of a `**` segment. It matches any segment, and `Router::get` stops walking
	/// the path once it reaches such a node, ignoring the remaining segments.
	static ref WILDCARD_STOP_HASH: u64 = calculate_hash(&"**");
}
29
/// Boxed, pinned, `Send` future resolving to the HTTP response produced by a [`Handler`].
pub type ResponseFuture =
	Pin<Box<dyn Future<Output = Result<Response<Body>, hyper::Error>> + Send>>;
32
33pub trait Handler {
34	fn get(&self, _req: Request<Body>) -> ResponseFuture {
35		not_found()
36	}
37
38	fn post(&self, _req: Request<Body>) -> ResponseFuture {
39		not_found()
40	}
41
42	fn put(&self, _req: Request<Body>) -> ResponseFuture {
43		not_found()
44	}
45
46	fn patch(&self, _req: Request<Body>) -> ResponseFuture {
47		not_found()
48	}
49
50	fn delete(&self, _req: Request<Body>) -> ResponseFuture {
51		not_found()
52	}
53
54	fn head(&self, _req: Request<Body>) -> ResponseFuture {
55		not_found()
56	}
57
58	fn options(&self, _req: Request<Body>) -> ResponseFuture {
59		not_found()
60	}
61
62	fn trace(&self, _req: Request<Body>) -> ResponseFuture {
63		not_found()
64	}
65
66	fn connect(&self, _req: Request<Body>) -> ResponseFuture {
67		not_found()
68	}
69
70	fn call(
71		&self,
72		req: Request<Body>,
73		mut _handlers: Box<dyn Iterator<Item = HandlerObj>>,
74	) -> ResponseFuture {
75		match *req.method() {
76			Method::GET => self.get(req),
77			Method::POST => self.post(req),
78			Method::PUT => self.put(req),
79			Method::DELETE => self.delete(req),
80			Method::PATCH => self.patch(req),
81			Method::OPTIONS => self.options(req),
82			Method::CONNECT => self.connect(req),
83			Method::TRACE => self.trace(req),
84			Method::HEAD => self.head(req),
85			_ => not_found(),
86		}
87	}
88}
89
90#[derive(Clone, thiserror::Error, Eq, Debug, PartialEq, Serialize, Deserialize)]
91pub enum RouterError {
92	#[error("Route {0} already exists")]
93	RouteAlreadyExists(String),
94	#[error("Route {0} not found")]
95	RouteNotFound(String),
96	#[error("Route value not found for {0}")]
97	NoValue(String),
98}
99
/// Path-segment trie mapping URL paths to handler chains.
///
/// Nodes live in a flat arena (`nodes`) and refer to each other by [`NodeId`]. The entry
/// at index 0 is the root.
#[derive(Clone)]
pub struct Router {
	nodes: Vec<Node>,
}

/// Index of a [`Node`] inside `Router::nodes`.
#[derive(Debug, Clone, Copy)]
struct NodeId(usize);

/// Maximum number of children a single node can hold; `Node::add_child` panics beyond it.
const MAX_CHILDREN: usize = 16;

/// Shared, thread-safe handler object (route handler or middleware).
pub type HandlerObj = Arc<dyn Handler + Send + Sync>;

/// One path segment of the routing trie.
#[derive(Clone)]
pub struct Node {
	// Hash of the path segment this node stands for (produced by `calculate_hash`).
	key: u64,
	// Handler registered for the path that ends at this node, if any.
	value: Option<HandlerObj>,
	// Child ids. Only the first `children_count` entries are valid; the rest are the
	// `NodeId(0)` placeholders written by `Node::new`.
	children: [NodeId; MAX_CHILDREN],
	children_count: usize,
	// Middleware that `Router::get` adds to the chain for every path passing through here.
	mws: Option<Vec<HandlerObj>>,
}
120
121impl Router {
122	pub fn new() -> Router {
123		let root = Node::new(calculate_hash(&""), None);
124		let mut nodes = vec![];
125		nodes.push(root);
126		Router { nodes }
127	}
128
129	pub fn add_middleware(&mut self, mw: HandlerObj) {
130		self.node_mut(NodeId(0)).add_middleware(mw);
131	}
132
133	fn root(&self) -> NodeId {
134		NodeId(0)
135	}
136
137	fn node(&self, id: NodeId) -> &Node {
138		&self.nodes[id.0]
139	}
140
141	fn node_mut(&mut self, id: NodeId) -> &mut Node {
142		&mut self.nodes[id.0]
143	}
144
145	fn find(&self, parent: NodeId, key: u64) -> Option<NodeId> {
146		let node = self.node(parent);
147		node.children
148			.iter()
149			.find(|&id| {
150				let node_key = self.node(*id).key;
151				node_key == key || node_key == *WILDCARD_HASH || node_key == *WILDCARD_STOP_HASH
152			})
153			.cloned()
154	}
155
156	fn add_empty_node(&mut self, parent: NodeId, key: u64) -> NodeId {
157		let id = NodeId(self.nodes.len());
158		self.nodes.push(Node::new(key, None));
159		self.node_mut(parent).add_child(id);
160		id
161	}
162
163	pub fn add_route(
164		&mut self,
165		route: &'static str,
166		value: HandlerObj,
167	) -> Result<&mut Node, RouterError> {
168		let keys = generate_path(route);
169		let mut node_id = self.root();
170		for key in keys {
171			node_id = self
172				.find(node_id, key)
173				.unwrap_or_else(|| self.add_empty_node(node_id, key));
174		}
175		match self.node(node_id).value() {
176			None => {
177				let node = self.node_mut(node_id);
178				node.set_value(value);
179				Ok(node)
180			}
181			Some(_) => Err(RouterError::RouteAlreadyExists(route.to_string())),
182		}
183	}
184
185	pub fn get(&self, path: &str) -> Result<impl Iterator<Item = HandlerObj>, RouterError> {
186		let keys = generate_path(path);
187		let mut handlers = vec![];
188		let mut node_id = self.root();
189		collect_node_middleware(&mut handlers, self.node(node_id));
190		for key in keys {
191			node_id = self
192				.find(node_id, key)
193				.ok_or(RouterError::RouteNotFound(path.to_string()))?;
194			let node = self.node(node_id);
195			collect_node_middleware(&mut handlers, self.node(node_id));
196			if node.key == *WILDCARD_STOP_HASH {
197				break;
198			}
199		}
200
201		if let Some(h) = self.node(node_id).value() {
202			handlers.push(h);
203			Ok(handlers.into_iter())
204		} else {
205			Err(RouterError::NoValue(path.to_string()))
206		}
207	}
208}
209
210impl Service<Request<Body>> for Router {
211	type Response = Response<Body>;
212	type Error = hyper::Error;
213	type Future = ResponseFuture;
214
215	fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
216		Poll::Ready(Ok(()))
217	}
218
219	fn call(&mut self, req: Request<Body>) -> Self::Future {
220		match self.get(req.uri().path()) {
221			Err(_) => not_found(),
222			Ok(mut handlers) => match handlers.next() {
223				None => not_found(),
224				Some(h) => h.call(req, Box::new(handlers)),
225			},
226		}
227	}
228}
229
230impl Node {
231	fn new(key: u64, value: Option<HandlerObj>) -> Node {
232		Node {
233			key,
234			value,
235			children: [NodeId(0); MAX_CHILDREN],
236			children_count: 0,
237			mws: None,
238		}
239	}
240
241	pub fn add_middleware(&mut self, mw: HandlerObj) -> &mut Node {
242		if self.mws.is_none() {
243			self.mws = Some(vec![]);
244		}
245		if let Some(ref mut mws) = self.mws {
246			mws.push(mw.clone());
247		}
248		self
249	}
250
251	fn value(&self) -> Option<HandlerObj> {
252		match &self.value {
253			None => None,
254			Some(v) => Some(v.clone()),
255		}
256	}
257
258	fn set_value(&mut self, value: HandlerObj) {
259		self.value = Some(value);
260	}
261
262	fn add_child(&mut self, child_id: NodeId) {
263		if self.children_count == MAX_CHILDREN {
264			panic!("Can't add a route, children limit exceeded");
265		}
266		self.children[self.children_count] = child_id;
267		self.children_count += 1;
268	}
269}
270
271pub fn not_found() -> ResponseFuture {
272	let mut response = Response::new(Body::empty());
273	*response.status_mut() = StatusCode::NOT_FOUND;
274	Box::pin(future::ok(response))
275}
276
/// Hashes `t` with a freshly created `DefaultHasher`.
///
/// Route segments, the root key and the wildcard keys are all produced here, so they can
/// be compared with one another.
fn calculate_hash<T: Hash>(t: &T) -> u64 {
	let mut hasher = DefaultHasher::new();
	Hash::hash(t, &mut hasher);
	hasher.finish()
}
282
283fn generate_path(route: &str) -> Vec<u64> {
284	route
285		.split('/')
286		.skip(1)
287		.map(|path| calculate_hash(&path))
288		.collect()
289}
290
291fn collect_node_middleware(handlers: &mut Vec<HandlerObj>, node: &Node) {
292	if let Some(ref mws) = node.mws {
293		for mw in mws {
294			handlers.push(mw.clone());
295		}
296	}
297}
298
#[cfg(test)]
mod tests {
	use super::*;
	use futures::executor::block_on;

	/// Handler whose `GET` answers with the status code it carries.
	struct HandlerImpl(u16);

	impl Handler for HandlerImpl {
		fn get(&self, _req: Request<Body>) -> ResponseFuture {
			let status = self.0;
			Box::pin(async move {
				let response = Response::builder()
					.status(status)
					.body(Body::default())
					.unwrap();
				Ok(response)
			})
		}
	}

	/// Wraps a status code into a shareable handler object.
	fn handler(code: u16) -> HandlerObj {
		Arc::new(HandlerImpl(code))
	}

	/// Resolves `url` and returns the status produced by the first handler in its chain.
	fn status_of(routes: &Router, url: &str) -> u16 {
		let first = routes.get(url).unwrap().next().unwrap();
		block_on(first.get(Request::new(Body::default())))
			.unwrap()
			.status()
			.as_u16()
	}

	#[test]
	fn test_add_route() {
		let mut routes = Router::new();
		let (h1, h2, h3) = (handler(1), handler(2), handler(3));

		routes.add_route("/v1/users", h1).unwrap();
		// Registering the same path twice must fail.
		assert!(routes.add_route("/v1/users", h2.clone()).is_err());
		routes.add_route("/v1/users/xxx", h3.clone()).unwrap();
		routes.add_route("/v1/users/xxx/yyy", h3.clone()).unwrap();
		routes.add_route("/v1/zzz/*", h3).unwrap();
		// `*` already covers `ccc`, so this collides with the wildcard route.
		assert!(routes.add_route("/v1/zzz/ccc", h2).is_err());
		routes.add_route("/v1/zzz/*/zzz", handler(6)).unwrap();
	}

	#[test]
	fn test_get() {
		let mut routes = Router::new();
		let table = [
			("/v1/users", 101),
			("/v1/users/xxx", 103),
			("/v1/users/xxx/yyy", 103),
			("/v1/zzz/*", 103),
			("/v1/zzz/*/zzz", 106),
		];
		for &(path, code) in table.iter() {
			routes.add_route(path, handler(code)).unwrap();
		}

		assert_eq!(status_of(&routes, "/v1/users"), 101);
		assert_eq!(status_of(&routes, "/v1/users/xxx"), 103);
		assert!(routes.get("/v1/users/yyy").is_err());
		assert_eq!(status_of(&routes, "/v1/users/xxx/yyy"), 103);
		// An intermediate node without a handler is not a route.
		assert!(routes.get("/v1/zzz").is_err());
		assert_eq!(status_of(&routes, "/v1/zzz/1"), 103);
		assert_eq!(status_of(&routes, "/v1/zzz/2"), 103);
		assert_eq!(status_of(&routes, "/v1/zzz/2/zzz"), 106);
	}
}