1use crate::{
92 errors::CatBridgeError,
93 net::{
94 errors::CommonNetAPIError,
95 models::{Request, Response},
96 server::request_handlers::{Handler, HandlerAsService},
97 },
98};
99use fnv::FnvHashMap;
100use std::{
101 convert::Infallible,
102 hash::BuildHasherDefault,
103 pin::Pin,
104 task::{Context, Poll},
105};
106use tower::{Layer, Service, util::BoxCloneService};
107use tracing::{debug, warn};
108use wide::u8x16;
109
110type RoutableService<State> = BoxCloneService<Request<State>, Response, CatBridgeError>;
111type RoutingInfo = Vec<(u8x16, Vec<u8>)>;
112
113#[derive(Clone, Debug)]
123pub struct Router<State: Clone + Send + Sync + 'static = ()> {
124 fallback: Option<RoutableService<State>>,
126 offset_at: Option<usize>,
128 route_table: FnvHashMap<u8, (RoutingInfo, Vec<RoutableService<State>>)>,
137}
138
139impl<State: Clone + Send + Sync + 'static> Router<State> {
140 #[must_use]
142 pub fn new() -> Self {
143 Self {
144 fallback: None,
145 offset_at: None,
146 route_table: FnvHashMap::with_capacity_and_hasher(0, BuildHasherDefault::default()),
147 }
148 }
149
150 #[must_use]
152 pub fn new_with_offset(offset_at: usize) -> Self {
153 Self {
154 fallback: None,
155 offset_at: Some(offset_at),
156 route_table: FnvHashMap::with_capacity_and_hasher(0, BuildHasherDefault::default()),
157 }
158 }
159
160 pub fn add_route<HandlerTy, HandlerParamsTy>(
172 &mut self,
173 packet_start: &[u8],
174 handle: HandlerTy,
175 ) -> Result<(), CommonNetAPIError>
176 where
177 HandlerTy: Handler<HandlerParamsTy, State> + Clone + Send + 'static,
178 HandlerParamsTy: Send + 'static,
179 {
180 let boxed = BoxCloneService::new(HandlerAsService::new(handle));
181 let widened = Self::to_wide_simd(packet_start)?;
182
183 if let Some(routes) = self.route_table.get_mut(&packet_start[0]) {
184 for (wide, _) in &routes.0 {
185 if *wide == widened {
186 return Err(CommonNetAPIError::DuplicateRoute(Vec::from(packet_start)));
187 }
188 }
189 routes.0.push((widened, Vec::from(packet_start)));
190 routes.1.push(boxed);
191 } else {
192 self.route_table.insert(
193 packet_start[0],
194 (vec![(widened, Vec::from(packet_start))], vec![boxed]),
195 );
196 }
197
198 Ok(())
199 }
200
201 pub fn add_route_service<ServiceTy>(
212 &mut self,
213 packet_start: &[u8],
214 service: ServiceTy,
215 ) -> Result<(), CommonNetAPIError>
216 where
217 ServiceTy: Service<Request<State>, Response = Response, Error = CatBridgeError>
218 + Clone
219 + Send
220 + 'static,
221 ServiceTy::Future: Send,
222 {
223 let boxed = BoxCloneService::new(service);
224 let widened = Self::to_wide_simd(packet_start)?;
225
226 if let Some(routes) = self.route_table.get_mut(&packet_start[0]) {
227 for (wide, _) in &routes.0 {
228 if *wide == widened {
229 return Err(CommonNetAPIError::DuplicateRoute(Vec::from(packet_start)));
230 }
231 }
232 routes.0.push((widened, Vec::from(packet_start)));
233 routes.1.push(boxed);
234 } else {
235 self.route_table.insert(
236 packet_start[0],
237 (vec![(widened, Vec::from(packet_start))], vec![boxed]),
238 );
239 }
240
241 Ok(())
242 }
243
244 pub fn layer<LayerTy, ServiceTy>(&mut self, layer: LayerTy)
250 where
251 LayerTy:
252 Layer<BoxCloneService<Request<State>, Response, CatBridgeError>, Service = ServiceTy>,
253 ServiceTy: Service<Request<State>, Response = Response, Error = CatBridgeError>
254 + Clone
255 + Send
256 + 'static,
257 <LayerTy::Service as Service<Request<State>>>::Future: Send + 'static,
258 {
259 for route_table in self.route_table.values_mut() {
260 for route in &mut route_table.1 {
261 *route = BoxCloneService::new(layer.layer(route.clone()));
262 }
263 }
264 }
265
266 pub fn fallback_handler<HandlerTy, HandlerParamsTy>(
272 &mut self,
273 handle: HandlerTy,
274 ) -> Result<(), CommonNetAPIError>
275 where
276 HandlerTy: Handler<HandlerParamsTy, State> + Clone + Send + 'static,
277 HandlerParamsTy: Send + 'static,
278 {
279 if self.fallback.is_some() {
280 return Err(CommonNetAPIError::DuplicateFallbackHandler);
281 }
282 self.fallback = Some(BoxCloneService::new(HandlerAsService::new(handle)));
283 Ok(())
284 }
285
286 pub fn fallback_handler_service<ServiceTy>(
292 &mut self,
293 service: ServiceTy,
294 ) -> Result<(), CommonNetAPIError>
295 where
296 ServiceTy: Service<Request<State>, Response = Response, Error = CatBridgeError>
297 + Clone
298 + Send
299 + 'static,
300 ServiceTy::Future: Send,
301 {
302 if self.fallback.is_some() {
303 return Err(CommonNetAPIError::DuplicateFallbackHandler);
304 }
305 self.fallback = Some(BoxCloneService::new(service));
306 Ok(())
307 }
308
309 fn to_wide_simd(pre: &[u8]) -> Result<u8x16, CommonNetAPIError> {
310 if pre.is_empty() {
311 return Err(CommonNetAPIError::RouterNeedsSomeBytesToMatchOn);
312 }
313 if pre.len() > 16 {
314 return Err(CommonNetAPIError::RouteTooLongToMatchOn(Vec::from(pre)));
315 }
316
317 let mut data = [0_u8; 16];
318 for (idx, byte) in pre.iter().enumerate() {
319 data[idx] = *byte;
320 }
321 Ok(u8x16::new(data))
322 }
323
324 fn to_wide_simd_offset(packet: &[u8], offset: usize) -> u8x16 {
325 let mut data = [0_u8; 16];
326 data[..std::cmp::min(packet.len() - offset, 16)]
327 .copy_from_slice(&packet[offset..offset + std::cmp::min(packet.len() - offset, 16)]);
328 u8x16::new(data)
329 }
330}
331
332impl<State: Clone + Send + Sync + 'static> Default for Router<State> {
333 fn default() -> Self {
334 Self::new()
335 }
336}
337
338impl<State: Clone + Send + Sync + 'static> Service<Request<State>> for Router<State> {
339 type Response = Response;
340 type Error = Infallible;
341 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
342
343 fn poll_ready(&mut self, _ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
344 Poll::Ready(Ok(()))
345 }
346
347 fn call(&mut self, req: Request<State>) -> Self::Future {
348 let mut handler = None;
349
350 let body = req.body();
352 let offset = self.offset_at.unwrap_or_default();
353 if offset >= body.len() {
354 handler.clone_from(&self.fallback);
355 } else if let Some(possible_routes) = self.route_table.get(&body[offset]) {
356 let padded = Self::to_wide_simd_offset(body, offset);
357
358 for (idx, (possible_route, array)) in possible_routes.0.iter().enumerate() {
359 let mut result = u8x16::new([0_u8; 16]);
360 result |= possible_route;
361 result |= padded;
362 if result == padded && body[offset..].starts_with(array) {
363 handler = Some(possible_routes.1[idx].clone());
364 break;
365 }
366 }
367
368 if handler.is_none() {
369 handler.clone_from(&self.fallback);
370 }
371 } else {
372 handler.clone_from(&self.fallback);
373 }
374
375 Box::pin(async move {
376 if handler.is_none() {
377 debug!(
378 request.body = format!("{:02x?}", req.body()),
379 "unknown handler called for router!",
380 );
381 return Ok(Response::new_empty());
382 }
383 let mut hndl = handler.unwrap();
384
385 match hndl.call(req).await {
386 Ok(resp) => Ok(resp),
387 Err(cause) => {
388 warn!(?cause, "handler for l4 router failed");
389 Ok(Response::empty_close())
390 }
391 }
392 })
393 }
394}
395
#[cfg(test)]
pub mod test_helpers {
    #![allow(unused)]

    use super::*;
    use bytes::Bytes;
    use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};

    /// Send `packet` (with `state`) through `router` as if it came from
    /// `127.0.0.1:0`, and return the router's response.
    ///
    /// # Panics
    ///
    /// If the router returns an error.
    async fn call_router<State: Clone + Send + Sync + 'static>(
        router: &mut Router<State>,
        packet: &'static [u8],
        state: State,
    ) -> Response {
        router
            .call(Request::new_with_state(
                Bytes::from_static(packet),
                SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)),
                state,
                None,
            ))
            .await
            .expect("Failed to call router!")
    }

    /// Assert that a stateless `router` answers `packet` with an empty body
    /// and keeps the connection open.
    pub async fn router_empty_no_close(router: &mut Router, packet: &'static [u8]) {
        router_empty_no_close_with_state(router, packet, ()).await
    }

    /// Assert that `router` answers `packet` (with `state`) with an empty or
    /// absent body and keeps the connection open.
    pub async fn router_empty_no_close_with_state<State: Clone + Send + Sync + 'static>(
        router: &mut Router<State>,
        packet: &'static [u8],
        state: State,
    ) {
        let resp = call_router(router, packet, state).await;

        assert!(
            !resp.request_connection_close(),
            "Response indicated that it should be closed.",
        );

        if let Some(body) = resp.body() {
            assert!(
                body.is_empty(),
                "Response body was not empty! Was:\n hex: {:02x?}\n str: {}",
                body,
                String::from_utf8_lossy(&body),
            );
        }
    }

    /// Return the body a stateless `router` produces for `packet`, asserting
    /// that the connection is kept open.
    pub async fn router_body_no_close(router: &mut Router, packet: &'static [u8]) -> Vec<u8> {
        router_body_no_close_with_state(router, packet, ()).await
    }

    /// Return the body `router` produces for `packet` (with `state`).
    ///
    /// # Panics
    ///
    /// If the response requests a connection close or has no body.
    pub async fn router_body_no_close_with_state<State: Clone + Send + Sync + 'static>(
        router: &mut Router<State>,
        packet: &'static [u8],
        state: State,
    ) -> Vec<u8> {
        let resp = call_router(router, packet, state).await;

        assert!(
            !resp.request_connection_close(),
            "Response indicated that it should be closed.",
        );

        resp.body()
            .expect("Failed to find response body from called route!")
            .iter()
            .copied()
            .collect::<Vec<_>>()
    }

    /// Assert that a stateless `router` answers `packet` with no body and a
    /// connection close.
    pub async fn router_empty_close(router: &mut Router, packet: &'static [u8]) {
        router_empty_close_with_state(router, packet, ()).await
    }

    /// Assert that `router` answers `packet` (with `state`) with no body and a
    /// connection close.
    pub async fn router_empty_close_with_state<State: Clone + Send + Sync + 'static>(
        router: &mut Router<State>,
        packet: &'static [u8],
        state: State,
    ) {
        let resp = call_router(router, packet, state).await;

        assert!(
            resp.request_connection_close(),
            "Response did not request the connection be closed.",
        );
        assert!(resp.body().is_none(), "Response unexpectedly had a body.");
    }
}
501
#[cfg(test)]
mod unit_tests {
    use super::{test_helpers::*, *};
    use crate::net::server::requestable::State;
    use bytes::{Bytes, BytesMut};
    use std::sync::{
        Arc,
        atomic::{AtomicU8, Ordering},
    };

    #[tokio::test]
    pub async fn route_to_handler_full() {
        let mut router = Router::new();
        router
            .add_route(&[0x3, 0x4, 0x5, 0x6, 0x7], || async {
                Ok::<Response, CatBridgeError>(Response::from(Bytes::from_static(b"hello world")))
            })
            .expect("Failed to add route!");

        assert_eq!(
            router_body_no_close(
                &mut router,
                &[0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0x10, 0x11, 0x12]
            )
            .await,
            b"hello world"
        );
    }

    #[tokio::test]
    pub async fn route_to_handler_into() {
        let mut router = Router::new();

        async fn static_byte() -> &'static [u8] {
            b"hello world static"
        }

        router
            .add_route(&[0x1], || async { Bytes::from_static(b"hello world") })
            .expect("Failed to add route!");
        router
            .add_route(&[0x2], || async {
                let mut bytes = BytesMut::with_capacity(0);
                bytes.extend_from_slice(b"hello world bytes_mut");
                bytes
            })
            .expect("Failed to add route!");
        router
            .add_route(&[0x3], || async { "hello world string".to_owned() })
            .expect("Failed to add route!");
        router
            .add_route(&[0x4], || async { b"hello world vec".to_vec() })
            .expect("Failed to add route!");
        router
            .add_route(&[0x5], static_byte)
            .expect("Failed to add route!");
        router
            .add_route(&[0x6], || async { "hello world static str" })
            .expect("Failed to add route!");
        router
            .add_route(&[0x7], || async {
                Err::<String, CatBridgeError>(CatBridgeError::UnsupportedBitsPerCore)
            })
            .expect("Failed to add route!");

        assert_eq!(
            router_body_no_close(&mut router, &[0x1, 0x2, 0x3]).await,
            b"hello world",
        );
        assert_eq!(
            router_body_no_close(&mut router, &[0x2, 0x2, 0x3]).await,
            b"hello world bytes_mut",
        );
        assert_eq!(
            router_body_no_close(&mut router, &[0x3, 0x2, 0x3]).await,
            b"hello world string",
        );
        assert_eq!(
            router_body_no_close(&mut router, &[0x4, 0x2, 0x3]).await,
            b"hello world vec",
        );
        assert_eq!(
            router_body_no_close(&mut router, &[0x5, 0x2, 0x3]).await,
            b"hello world static",
        );
        assert_eq!(
            router_body_no_close(&mut router, &[0x6, 0x2, 0x3]).await,
            b"hello world static str",
        );

        // A handler error is turned into an empty, connection-closing response.
        router_empty_close(&mut router, &[0x7, 0x2, 0x3]).await;
    }

    #[tokio::test]
    pub async fn route_to_fallback() {
        async fn base_handler() -> Result<Bytes, CatBridgeError> {
            Ok(Bytes::from_static(b"base"))
        }

        let mut router = Router::new();
        router
            .add_route(&[0x1, 0x2, 0x3], base_handler)
            .expect("Failed to add!");
        router
            .add_route(&[0x2, 0x2, 0x3], base_handler)
            .expect("Failed to add!");
        router
            .add_route(&[0x3, 0x2, 0x3], base_handler)
            .expect("Failed to add!");
        router
            .add_route(&[0x4, 0x2, 0x3], base_handler)
            .expect("Failed to add!");
        router
            .add_route(&[0x6, 0x2, 0x3], base_handler)
            .expect("Failed to add!");
        router
            .add_route(&[0x7, 0x2, 0x3], base_handler)
            .expect("Failed to add!");
        router
            .add_route(&[0x8, 0x2, 0x3], base_handler)
            .expect("Failed to add!");
        router
            .add_route(&[0x9, 0x2, 0x3], base_handler)
            .expect("Failed to add!");
        router
            .fallback_handler(|| async { Ok(Bytes::from_static(b"fallback")) })
            .expect("Failed to register fallback_handler!");

        // No bucket for first byte 0x5.
        assert_eq!(
            router_body_no_close(&mut router, &[0x5, 0x6, 0x7]).await,
            b"fallback"
        );
        // Bucket exists, but no prefix in it matches.
        assert_eq!(
            router_body_no_close(&mut router, &[0x1, 0x6, 0x7]).await,
            b"fallback"
        );
        assert_eq!(
            router_body_no_close(&mut router, &[0x1, 0x2, 0x3]).await,
            b"base"
        );
    }

    #[tokio::test]
    pub async fn test_route_at_offset() {
        #[derive(Clone)]
        struct TestState {
            fallbacks: Arc<AtomicU8>,
            requests: Arc<AtomicU8>,
        }

        let my_state = TestState {
            fallbacks: Arc::new(AtomicU8::new(0)),
            requests: Arc::new(AtomicU8::new(0)),
        };

        async fn hit(State(test_state): State<TestState>) -> Bytes {
            test_state.requests.fetch_add(1, Ordering::SeqCst);
            Bytes::from_static(b"hewwo mw pwizzwa mwan")
        }
        async fn fallback(State(test_state): State<TestState>) -> Bytes {
            test_state.fallbacks.fetch_add(1, Ordering::SeqCst);
            Bytes::from_static(b"bye pwizza man")
        }

        let mut router = Router::<TestState>::new_with_offset(4);
        router
            .add_route(&[0x1], hit)
            .expect("Failed to add route to router!");
        router
            .add_route(&[0x2, 0x3, 0x4], hit)
            .expect("Failed to add route to router!");
        router
            .add_route(&[0x5, 0x6, 0x7], hit)
            .expect("Failed to add route to router!");
        router
            .fallback_handler(fallback)
            .expect("Failed to set fallback handler!");

        // Bodies of length <= 4 never reach the offset: fallback (x2).
        _ = router_body_no_close_with_state(&mut router, &[], my_state.clone()).await;
        _ = router_body_no_close_with_state(&mut router, &[0x1; 4], my_state.clone()).await;
        // Byte at offset 4 matches the `[0x1]` route: hit.
        _ = router_body_no_close_with_state(&mut router, &[0x1; 5], my_state.clone()).await;
        // Multi-byte prefixes at offset 4: hit (x2).
        _ = router_body_no_close_with_state(
            &mut router,
            &[0x1, 0x1, 0x1, 0x1, 0x2, 0x3, 0x4],
            my_state.clone(),
        )
        .await;
        _ = router_body_no_close_with_state(
            &mut router,
            &[0x1, 0x1, 0x1, 0x1, 0x5, 0x6, 0x7],
            my_state.clone(),
        )
        .await;
        // No route starts with 0x9: fallback.
        _ = router_body_no_close_with_state(
            &mut router,
            &[0x1, 0x1, 0x1, 0x1, 0x9],
            my_state.clone(),
        )
        .await;

        assert_eq!(
            my_state.fallbacks.load(Ordering::SeqCst),
            3,
            "Fallback did not get hit required amount of times!",
        );
        assert_eq!(
            my_state.requests.load(Ordering::SeqCst),
            3,
            "Request Handler did not get hit required amount of times!",
        );
    }

    #[test]
    pub fn rejects_invalid_and_duplicate_routes() {
        async fn handler() -> Result<Bytes, CatBridgeError> {
            Ok(Bytes::from_static(b"ok"))
        }

        let mut router: Router = Router::new();

        assert!(matches!(
            router.add_route(&[], handler),
            Err(CommonNetAPIError::RouterNeedsSomeBytesToMatchOn)
        ));
        assert!(matches!(
            router.add_route(&[0_u8; 17], handler),
            Err(CommonNetAPIError::RouteTooLongToMatchOn(_))
        ));
        router
            .add_route(&[0x2; 16], handler)
            .expect("A 16-byte route should be accepted!");

        router
            .add_route(&[0x1, 0x2, 0x3], handler)
            .expect("Failed to add route!");
        assert!(matches!(
            router.add_route(&[0x1, 0x2, 0x3], handler),
            Err(CommonNetAPIError::DuplicateRoute(_))
        ));
    }

    #[test]
    pub fn rejects_second_fallback_handler() {
        async fn handler() -> Result<Bytes, CatBridgeError> {
            Ok(Bytes::from_static(b"ok"))
        }

        let mut router: Router = Router::new();
        router
            .fallback_handler(handler)
            .expect("First fallback handler should register!");
        assert!(matches!(
            router.fallback_handler(handler),
            Err(CommonNetAPIError::DuplicateFallbackHandler)
        ));
    }
}