use crate::{
	errors::CatBridgeError,
	net::{
		errors::CommonNetAPIError,
		models::{Request, Response},
		server::request_handlers::{Handler, HandlerAsService},
	},
};
use fnv::FnvHashMap;
use std::{
	convert::Infallible,
	hash::BuildHasherDefault,
	pin::Pin,
	task::{Context, Poll},
};
use tower::{Layer, Service, ServiceExt, util::BoxCloneService};
use tracing::{debug, warn};
use wide::u8x16;
109
/// A boxed, cloneable service that a matched route (or the fallback) dispatches to.
type RoutableService<State> = BoxCloneService<Request<State>, Response, CatBridgeError>;
/// The routes sharing one first byte: each entry pairs the zero-padded SIMD form of a route
/// prefix with the raw prefix bytes it was built from.
type RoutingInfo = Vec<(u8x16, Vec<u8>)>;

/// A layer-4 router that dispatches each request to a service chosen by a byte prefix of the
/// request body.
///
/// Prefixes are matched starting at byte 0 of the body, or at a fixed offset when built with
/// [`Router::new_with_offset`]. Requests that match no route go to the optional fallback.
#[derive(Clone, Debug)]
pub struct Router<State: Clone + Send + Sync + 'static = ()> {
	/// Service used when no route matches; when unset, unmatched requests get an empty response.
	fallback: Option<RoutableService<State>>,
	/// Byte offset into the body at which prefix matching starts; `None` is treated as offset 0.
	offset_at: Option<usize>,
	/// Routes keyed by the first byte of their prefix. The two vectors are kept parallel: entry
	/// `i` of the routing info describes the prefix of service `i`.
	route_table: FnvHashMap<u8, (RoutingInfo, Vec<RoutableService<State>>)>,
}
138
139impl<State: Clone + Send + Sync + 'static> Router<State> {
140	#[must_use]
142	pub fn new() -> Self {
143		Self {
144			fallback: None,
145			offset_at: None,
146			route_table: FnvHashMap::with_capacity_and_hasher(0, BuildHasherDefault::default()),
147		}
148	}
149
150	#[must_use]
152	pub fn new_with_offset(offset_at: usize) -> Self {
153		Self {
154			fallback: None,
155			offset_at: Some(offset_at),
156			route_table: FnvHashMap::with_capacity_and_hasher(0, BuildHasherDefault::default()),
157		}
158	}
159
160	pub fn add_route<HandlerTy, HandlerParamsTy>(
172		&mut self,
173		packet_start: &[u8],
174		handle: HandlerTy,
175	) -> Result<(), CommonNetAPIError>
176	where
177		HandlerTy: Handler<HandlerParamsTy, State> + Clone + Send + 'static,
178		HandlerParamsTy: Send + 'static,
179	{
180		let boxed = BoxCloneService::new(HandlerAsService::new(handle));
181		let widened = Self::to_wide_simd(packet_start)?;
182
183		if let Some(routes) = self.route_table.get_mut(&packet_start[0]) {
184			for (wide, _) in &routes.0 {
185				if *wide == widened {
186					return Err(CommonNetAPIError::DuplicateRoute(Vec::from(packet_start)));
187				}
188			}
189			routes.0.push((widened, Vec::from(packet_start)));
190			routes.1.push(boxed);
191		} else {
192			self.route_table.insert(
193				packet_start[0],
194				(vec![(widened, Vec::from(packet_start))], vec![boxed]),
195			);
196		}
197
198		Ok(())
199	}
200
201	pub fn add_route_service<ServiceTy>(
212		&mut self,
213		packet_start: &[u8],
214		service: ServiceTy,
215	) -> Result<(), CommonNetAPIError>
216	where
217		ServiceTy: Service<Request<State>, Response = Response, Error = CatBridgeError>
218			+ Clone
219			+ Send
220			+ 'static,
221		ServiceTy::Future: Send,
222	{
223		let boxed = BoxCloneService::new(service);
224		let widened = Self::to_wide_simd(packet_start)?;
225
226		if let Some(routes) = self.route_table.get_mut(&packet_start[0]) {
227			for (wide, _) in &routes.0 {
228				if *wide == widened {
229					return Err(CommonNetAPIError::DuplicateRoute(Vec::from(packet_start)));
230				}
231			}
232			routes.0.push((widened, Vec::from(packet_start)));
233			routes.1.push(boxed);
234		} else {
235			self.route_table.insert(
236				packet_start[0],
237				(vec![(widened, Vec::from(packet_start))], vec![boxed]),
238			);
239		}
240
241		Ok(())
242	}
243
244	pub fn layer<LayerTy, ServiceTy>(&mut self, layer: LayerTy)
250	where
251		LayerTy:
252			Layer<BoxCloneService<Request<State>, Response, CatBridgeError>, Service = ServiceTy>,
253		ServiceTy: Service<Request<State>, Response = Response, Error = CatBridgeError>
254			+ Clone
255			+ Send
256			+ 'static,
257		<LayerTy::Service as Service<Request<State>>>::Future: Send + 'static,
258	{
259		for route_table in self.route_table.values_mut() {
260			for route in &mut route_table.1 {
261				*route = BoxCloneService::new(layer.layer(route.clone()));
262			}
263		}
264	}
265
266	pub fn fallback_handler<HandlerTy, HandlerParamsTy>(
272		&mut self,
273		handle: HandlerTy,
274	) -> Result<(), CommonNetAPIError>
275	where
276		HandlerTy: Handler<HandlerParamsTy, State> + Clone + Send + 'static,
277		HandlerParamsTy: Send + 'static,
278	{
279		if self.fallback.is_some() {
280			return Err(CommonNetAPIError::DuplicateFallbackHandler);
281		}
282		self.fallback = Some(BoxCloneService::new(HandlerAsService::new(handle)));
283		Ok(())
284	}
285
286	pub fn fallback_handler_service<ServiceTy>(
292		&mut self,
293		service: ServiceTy,
294	) -> Result<(), CommonNetAPIError>
295	where
296		ServiceTy: Service<Request<State>, Response = Response, Error = CatBridgeError>
297			+ Clone
298			+ Send
299			+ 'static,
300		ServiceTy::Future: Send,
301	{
302		if self.fallback.is_some() {
303			return Err(CommonNetAPIError::DuplicateFallbackHandler);
304		}
305		self.fallback = Some(BoxCloneService::new(service));
306		Ok(())
307	}
308
309	fn to_wide_simd(pre: &[u8]) -> Result<u8x16, CommonNetAPIError> {
310		if pre.is_empty() {
311			return Err(CommonNetAPIError::RouterNeedsSomeBytesToMatchOn);
312		}
313		if pre.len() > 16 {
314			return Err(CommonNetAPIError::RouteTooLongToMatchOn(Vec::from(pre)));
315		}
316
317		let mut data = [0_u8; 16];
318		for (idx, byte) in pre.iter().enumerate() {
319			data[idx] = *byte;
320		}
321		Ok(u8x16::new(data))
322	}
323
324	fn to_wide_simd_offset(packet: &[u8], offset: usize) -> u8x16 {
325		let mut data = [0_u8; 16];
326		data[..std::cmp::min(packet.len() - offset, 16)]
327			.copy_from_slice(&packet[offset..offset + std::cmp::min(packet.len() - offset, 16)]);
328		u8x16::new(data)
329	}
330}
331
332impl<State: Clone + Send + Sync + 'static> Default for Router<State> {
333	fn default() -> Self {
334		Self::new()
335	}
336}
337
338impl<State: Clone + Send + Sync + 'static> Service<Request<State>> for Router<State> {
339	type Response = Response;
340	type Error = Infallible;
341	type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
342
343	fn poll_ready(&mut self, _ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
344		Poll::Ready(Ok(()))
345	}
346
347	fn call(&mut self, req: Request<State>) -> Self::Future {
348		let mut handler = None;
349
350		let body = req.body();
352		let offset = self.offset_at.unwrap_or_default();
353		if offset >= body.len() {
354			handler.clone_from(&self.fallback);
355		} else if let Some(possible_routes) = self.route_table.get(&body[offset]) {
356			let padded = Self::to_wide_simd_offset(body, offset);
357
358			for (idx, (possible_route, array)) in possible_routes.0.iter().enumerate() {
359				let mut result = u8x16::new([0_u8; 16]);
360				result |= possible_route;
361				result |= padded;
362				if result == padded && body[offset..].starts_with(array) {
363					handler = Some(possible_routes.1[idx].clone());
364					break;
365				}
366			}
367
368			if handler.is_none() {
369				handler.clone_from(&self.fallback);
370			}
371		} else {
372			handler.clone_from(&self.fallback);
373		}
374
375		Box::pin(async move {
376			if handler.is_none() {
377				debug!(
378					request.body = format!("{:02x?}", req.body()),
379					"unknown handler called for router!",
380				);
381				return Ok(Response::new_empty());
382			}
383			let mut hndl = handler.unwrap();
384
385			match hndl.call(req).await {
386				Ok(resp) => Ok(resp),
387				Err(cause) => {
388					warn!(?cause, "handler for l4 router failed");
389					Ok(Response::empty_close())
390				}
391			}
392		})
393	}
394}
395
#[cfg(test)]
pub mod test_helpers {
	// Not every helper is exercised by the tests in this file.
	#![allow(unused)]

	use super::*;
	use bytes::Bytes;
	use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};

	/// Build a request carrying `packet` from `127.0.0.1:0` with the given `state`.
	fn localhost_request<State: Clone + Send + Sync + 'static>(
		packet: &'static [u8],
		state: State,
	) -> Request<State> {
		Request::new_with_state(
			Bytes::from_static(packet),
			SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)),
			state,
			None,
		)
	}

	/// Send `packet` through `router` and return its response, panicking if the call fails.
	async fn call_router<State: Clone + Send + Sync + 'static>(
		router: &mut Router<State>,
		packet: &'static [u8],
		state: State,
	) -> Response {
		router
			.call(localhost_request(packet, state))
			.await
			.expect("Failed to call router!")
	}

	/// Assert that routing `packet` yields an empty body without requesting a close.
	pub async fn router_empty_no_close(router: &mut Router, packet: &'static [u8]) {
		router_empty_no_close_with_state(router, packet, ()).await
	}

	/// Stateful variant of [`router_empty_no_close`].
	pub async fn router_empty_no_close_with_state<State: Clone + Send + Sync + 'static>(
		router: &mut Router<State>,
		packet: &'static [u8],
		state: State,
	) {
		let resp = call_router(router, packet, state).await;

		assert!(
			!resp.request_connection_close(),
			"Response indicated that it should be closed.",
		);

		if let Some(body) = resp.body() {
			assert!(
				body.is_empty(),
				"Response body was not empty! Was:\n  hex: {:02x?}\n  str: {}",
				body,
				String::from_utf8_lossy(&body),
			);
		}
	}

	/// Route `packet` and return the response body, asserting the connection stays open.
	pub async fn router_body_no_close(router: &mut Router, packet: &'static [u8]) -> Vec<u8> {
		router_body_no_close_with_state(router, packet, ()).await
	}

	/// Stateful variant of [`router_body_no_close`].
	pub async fn router_body_no_close_with_state<State: Clone + Send + Sync + 'static>(
		router: &mut Router<State>,
		packet: &'static [u8],
		state: State,
	) -> Vec<u8> {
		let resp = call_router(router, packet, state).await;

		assert!(
			!resp.request_connection_close(),
			"Response indicated that it should be closed.",
		);

		resp.body()
			.expect("Failed to find response body from called route!")
			.to_vec()
	}

	/// Assert that routing `packet` requests a connection close with no body.
	pub async fn router_empty_close(router: &mut Router, packet: &'static [u8]) {
		router_empty_close_with_state(router, packet, ()).await
	}

	/// Stateful variant of [`router_empty_close`].
	pub async fn router_empty_close_with_state<State: Clone + Send + Sync + 'static>(
		router: &mut Router<State>,
		packet: &'static [u8],
		state: State,
	) {
		let resp = call_router(router, packet, state).await;

		assert!(resp.request_connection_close());
		assert!(resp.body().is_none());
	}
}
501
#[cfg(test)]
mod unit_tests {
	use super::{test_helpers::*, *};
	use crate::net::server::requestable::State;
	use bytes::{Bytes, BytesMut};
	use std::sync::{
		Arc,
		atomic::{AtomicU8, Ordering},
	};

	#[tokio::test]
	pub async fn route_to_handler_full() {
		let mut router = Router::new();
		router
			.add_route(&[0x3, 0x4, 0x5, 0x6, 0x7], || async {
				Ok::<Response, CatBridgeError>(Response::from(Bytes::from_static(b"hello world")))
			})
			.expect("Failed to add route!");

		assert_eq!(
			router_body_no_close(
				&mut router,
				&[0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0x10, 0x11, 0x12]
			)
			.await,
			b"hello world"
		);
	}

	#[tokio::test]
	pub async fn route_to_handler_into() {
		let mut router = Router::new();

		async fn static_byte() -> &'static [u8] {
			b"hello world static"
		}

		router
			.add_route(&[0x1], || async { Bytes::from_static(b"hello world") })
			.expect("Failed to add route!");
		router
			.add_route(&[0x2], || async {
				let mut bytes = BytesMut::with_capacity(0);
				bytes.extend_from_slice(b"hello world bytes_mut");
				bytes
			})
			.expect("Failed to add route!");
		router
			.add_route(&[0x3], || async { "hello world string".to_owned() })
			.expect("Failed to add route!");
		router
			.add_route(&[0x4], || async { b"hello world vec".to_vec() })
			.expect("Failed to add route!");
		router
			.add_route(&[0x5], static_byte)
			.expect("Failed to add route!");
		router
			.add_route(&[0x6], || async { "hello world static str" })
			.expect("Failed to add route!");
		router
			.add_route(&[0x7], || async {
				Err::<String, CatBridgeError>(CatBridgeError::UnsupportedBitsPerCore)
			})
			.expect("Failed to add route!");

		assert_eq!(
			router_body_no_close(&mut router, &[0x1, 0x2, 0x3]).await,
			b"hello world",
		);
		assert_eq!(
			router_body_no_close(&mut router, &[0x2, 0x2, 0x3]).await,
			b"hello world bytes_mut",
		);
		assert_eq!(
			router_body_no_close(&mut router, &[0x3, 0x2, 0x3]).await,
			b"hello world string",
		);
		assert_eq!(
			router_body_no_close(&mut router, &[0x4, 0x2, 0x3]).await,
			b"hello world vec",
		);
		assert_eq!(
			router_body_no_close(&mut router, &[0x5, 0x2, 0x3]).await,
			b"hello world static",
		);
		assert_eq!(
			router_body_no_close(&mut router, &[0x6, 0x2, 0x3]).await,
			b"hello world static str",
		);

		// A handler error is turned into an empty response that asks for the connection to close.
		router_empty_close(&mut router, &[0x7, 0x2, 0x3]).await;
	}

	#[tokio::test]
	pub async fn route_to_fallback() {
		async fn base_handler() -> Result<Bytes, CatBridgeError> {
			Ok(Bytes::from_static(b"base"))
		}

		let mut router = Router::new();
		router
			.add_route(&[0x1, 0x2, 0x3], base_handler)
			.expect("Failed to add!");
		router
			.add_route(&[0x2, 0x2, 0x3], base_handler)
			.expect("Failed to add!");
		router
			.add_route(&[0x3, 0x2, 0x3], base_handler)
			.expect("Failed to add!");
		router
			.add_route(&[0x4, 0x2, 0x3], base_handler)
			.expect("Failed to add!");
		router
			.add_route(&[0x6, 0x2, 0x3], base_handler)
			.expect("Failed to add!");
		router
			.add_route(&[0x7, 0x2, 0x3], base_handler)
			.expect("Failed to add!");
		router
			.add_route(&[0x8, 0x2, 0x3], base_handler)
			.expect("Failed to add!");
		router
			.add_route(&[0x9, 0x2, 0x3], base_handler)
			.expect("Failed to add!");
		router
			.fallback_handler(|| async { Ok(Bytes::from_static(b"fallback")) })
			.expect("Failed to register fallback_handler!");

		// No route starts with `0x5`.
		assert_eq!(
			router_body_no_close(&mut router, &[0x5, 0x6, 0x7]).await,
			b"fallback"
		);
		// First byte matches `[0x1, 0x2, 0x3]`, but the rest of the prefix does not. This also
		// passes the SIMD bit pre-filter, so the exact prefix comparison is what rejects it.
		assert_eq!(
			router_body_no_close(&mut router, &[0x1, 0x6, 0x7]).await,
			b"fallback"
		);
		assert_eq!(
			router_body_no_close(&mut router, &[0x1, 0x2, 0x3]).await,
			b"base"
		);
	}

	#[tokio::test]
	pub async fn test_route_at_offset() {
		#[derive(Clone)]
		struct TestState {
			fallbacks: Arc<AtomicU8>,
			requests: Arc<AtomicU8>,
		}

		let my_state = TestState {
			fallbacks: Arc::new(AtomicU8::new(0)),
			requests: Arc::new(AtomicU8::new(0)),
		};

		async fn hit(State(test_state): State<TestState>) -> Bytes {
			test_state.requests.fetch_add(1, Ordering::SeqCst);
			Bytes::from_static(b"hewwo mw pwizzwa mwan")
		}
		async fn fallback(State(test_state): State<TestState>) -> Bytes {
			test_state.fallbacks.fetch_add(1, Ordering::SeqCst);
			Bytes::from_static(b"bye pwizza man")
		}

		let mut router = Router::<TestState>::new_with_offset(4);
		router
			.add_route(&[0x1], hit)
			.expect("Failed to add route to router!");
		router
			.add_route(&[0x2, 0x3, 0x4], hit)
			.expect("Failed to add route to router!");
		router
			.add_route(&[0x5, 0x6, 0x7], hit)
			.expect("Failed to add route to router!");
		router
			.fallback_handler(fallback)
			.expect("Failed to set fallback handler!");

		// Bodies no longer than the offset never reach the route table: fallbacks #1 and #2.
		_ = router_body_no_close_with_state(&mut router, &[], my_state.clone()).await;
		_ = router_body_no_close_with_state(&mut router, &[0x1; 4], my_state.clone()).await;
		// Byte 4 is `0x1`, matching the `[0x1]` route: hit #1.
		_ = router_body_no_close_with_state(&mut router, &[0x1; 5], my_state.clone()).await;
		// Multi-byte prefixes starting at the offset: hits #2 and #3.
		_ = router_body_no_close_with_state(
			&mut router,
			&[0x1, 0x1, 0x1, 0x1, 0x2, 0x3, 0x4],
			my_state.clone(),
		)
		.await;
		_ = router_body_no_close_with_state(
			&mut router,
			&[0x1, 0x1, 0x1, 0x1, 0x5, 0x6, 0x7],
			my_state.clone(),
		)
		.await;
		// No route starts with `0x9`: fallback #3.
		_ = router_body_no_close_with_state(
			&mut router,
			&[0x1, 0x1, 0x1, 0x1, 0x9],
			my_state.clone(),
		)
		.await;

		assert_eq!(
			my_state.fallbacks.load(Ordering::SeqCst),
			3,
			"Fallback did not get hit required amount of times!",
		);
		assert_eq!(
			my_state.requests.load(Ordering::SeqCst),
			3,
			"Request Handler did not get hit required amount of times!",
		);
	}

	#[test]
	pub fn rejects_invalid_and_duplicate_registrations() {
		async fn ok_handler() -> Bytes {
			Bytes::from_static(b"ok")
		}

		let mut router: Router = Router::new();

		assert!(matches!(
			router.add_route(&[], ok_handler),
			Err(CommonNetAPIError::RouterNeedsSomeBytesToMatchOn),
		));
		assert!(matches!(
			router.add_route(&[0_u8; 17], ok_handler),
			Err(CommonNetAPIError::RouteTooLongToMatchOn(_)),
		));

		router
			.add_route(&[0x1, 0x2], ok_handler)
			.expect("Failed to add route!");
		assert!(matches!(
			router.add_route(&[0x1, 0x2], ok_handler),
			Err(CommonNetAPIError::DuplicateRoute(_)),
		));

		router
			.fallback_handler(ok_handler)
			.expect("Failed to set fallback handler!");
		assert!(matches!(
			router.fallback_handler(ok_handler),
			Err(CommonNetAPIError::DuplicateFallbackHandler),
		));
	}
}