1use crate::{
92 errors::CatBridgeError,
93 net::{
94 errors::CommonNetAPIError,
95 models::{Request, Response},
96 server::request_handlers::{Handler, HandlerAsService},
97 },
98};
99use fnv::FnvHashMap;
100use std::{
101 convert::Infallible,
102 hash::BuildHasherDefault,
103 pin::Pin,
104 task::{Context, Poll},
105};
106use tower::{Layer, Service, util::BoxCloneService};
107use tracing::{debug, warn};
108use wide::u8x16;
109
110type RoutableService<State> = BoxCloneService<Request<State>, Response, CatBridgeError>;
111type RoutingInfo = Vec<(u8x16, Vec<u8>)>;
112
113#[derive(Clone, Debug)]
123pub struct Router<State: Clone + Send + Sync + 'static = ()> {
124 fallback: Option<RoutableService<State>>,
126 offset_at: Option<usize>,
128 route_table: FnvHashMap<u8, (RoutingInfo, Vec<RoutableService<State>>)>,
137}
138
139impl<State: Clone + Send + Sync + 'static> Router<State> {
140 #[must_use]
142 pub fn new() -> Self {
143 Self {
144 fallback: None,
145 offset_at: None,
146 route_table: FnvHashMap::with_capacity_and_hasher(0, BuildHasherDefault::default()),
147 }
148 }
149
150 #[must_use]
152 pub fn new_with_offset(offset_at: usize) -> Self {
153 Self {
154 fallback: None,
155 offset_at: Some(offset_at),
156 route_table: FnvHashMap::with_capacity_and_hasher(0, BuildHasherDefault::default()),
157 }
158 }
159
160 pub fn add_route<HandlerTy, HandlerParamsTy>(
172 &mut self,
173 packet_start: &[u8],
174 handle: HandlerTy,
175 ) -> Result<(), CommonNetAPIError>
176 where
177 HandlerTy: Handler<HandlerParamsTy, State> + Clone + Send + 'static,
178 HandlerParamsTy: Send + 'static,
179 {
180 let boxed = BoxCloneService::new(HandlerAsService::new(handle));
181 let widened = Self::to_wide_simd(packet_start)?;
182
183 if let Some(routes) = self.route_table.get_mut(&packet_start[0]) {
184 for (wide, _) in &routes.0 {
185 if *wide == widened {
186 return Err(CommonNetAPIError::DuplicateRoute(Vec::from(packet_start)));
187 }
188 }
189 routes.0.push((widened, Vec::from(packet_start)));
190 routes.1.push(boxed);
191 } else {
192 self.route_table.insert(
193 packet_start[0],
194 (vec![(widened, Vec::from(packet_start))], vec![boxed]),
195 );
196 }
197
198 Ok(())
199 }
200
201 pub fn add_route_service<ServiceTy>(
212 &mut self,
213 packet_start: &[u8],
214 service: ServiceTy,
215 ) -> Result<(), CommonNetAPIError>
216 where
217 ServiceTy: Service<Request<State>, Response = Response, Error = CatBridgeError>
218 + Clone
219 + Send
220 + 'static,
221 ServiceTy::Future: Send,
222 {
223 let boxed = BoxCloneService::new(service);
224 let widened = Self::to_wide_simd(packet_start)?;
225
226 if let Some(routes) = self.route_table.get_mut(&packet_start[0]) {
227 for (wide, _) in &routes.0 {
228 if *wide == widened {
229 return Err(CommonNetAPIError::DuplicateRoute(Vec::from(packet_start)));
230 }
231 }
232 routes.0.push((widened, Vec::from(packet_start)));
233 routes.1.push(boxed);
234 } else {
235 self.route_table.insert(
236 packet_start[0],
237 (vec![(widened, Vec::from(packet_start))], vec![boxed]),
238 );
239 }
240
241 Ok(())
242 }
243
244 pub fn layer<LayerTy, ServiceTy>(&mut self, layer: LayerTy)
250 where
251 LayerTy:
252 Layer<BoxCloneService<Request<State>, Response, CatBridgeError>, Service = ServiceTy>,
253 ServiceTy: Service<Request<State>, Response = Response, Error = CatBridgeError>
254 + Clone
255 + Send
256 + 'static,
257 <LayerTy::Service as Service<Request<State>>>::Future: Send + 'static,
258 {
259 for route_table in self.route_table.values_mut() {
260 for route in &mut route_table.1 {
261 *route = BoxCloneService::new(layer.layer(route.clone()));
262 }
263 }
264 }
265
266 pub fn fallback_handler<HandlerTy, HandlerParamsTy>(
272 &mut self,
273 handle: HandlerTy,
274 ) -> Result<(), CommonNetAPIError>
275 where
276 HandlerTy: Handler<HandlerParamsTy, State> + Clone + Send + 'static,
277 HandlerParamsTy: Send + 'static,
278 {
279 if self.fallback.is_some() {
280 return Err(CommonNetAPIError::DuplicateFallbackHandler);
281 }
282 self.fallback = Some(BoxCloneService::new(HandlerAsService::new(handle)));
283 Ok(())
284 }
285
286 pub fn fallback_handler_service<ServiceTy>(
292 &mut self,
293 service: ServiceTy,
294 ) -> Result<(), CommonNetAPIError>
295 where
296 ServiceTy: Service<Request<State>, Response = Response, Error = CatBridgeError>
297 + Clone
298 + Send
299 + 'static,
300 ServiceTy::Future: Send,
301 {
302 if self.fallback.is_some() {
303 return Err(CommonNetAPIError::DuplicateFallbackHandler);
304 }
305 self.fallback = Some(BoxCloneService::new(service));
306 Ok(())
307 }
308
309 fn to_wide_simd(pre: &[u8]) -> Result<u8x16, CommonNetAPIError> {
310 if pre.is_empty() {
311 return Err(CommonNetAPIError::RouterNeedsSomeBytesToMatchOn);
312 }
313 if pre.len() > 16 {
314 return Err(CommonNetAPIError::RouteTooLongToMatchOn(Vec::from(pre)));
315 }
316
317 let mut data = [0_u8; 16];
318 for (idx, byte) in pre.iter().enumerate() {
319 data[idx] = *byte;
320 }
321 Ok(u8x16::new(data))
322 }
323
324 fn to_wide_simd_offset(packet: &[u8], offset: usize) -> u8x16 {
325 let mut data = [0_u8; 16];
326 data[..std::cmp::min(packet.len() - offset, 16)]
327 .copy_from_slice(&packet[offset..offset + std::cmp::min(packet.len() - offset, 16)]);
328 u8x16::new(data)
329 }
330}
331
332impl<State: Clone + Send + Sync + 'static> Default for Router<State> {
333 fn default() -> Self {
334 Self::new()
335 }
336}
337
338impl<State: Clone + Send + Sync + 'static> Service<Request<State>> for Router<State> {
339 type Response = Response;
340 type Error = Infallible;
341 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
342
343 fn poll_ready(&mut self, _ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
344 Poll::Ready(Ok(()))
345 }
346
347 fn call(&mut self, req: Request<State>) -> Self::Future {
348 let mut handler = None;
349
350 let body = req.body();
352 let offset = self.offset_at.unwrap_or_default();
353 if offset >= body.len() {
354 handler.clone_from(&self.fallback);
355 } else if let Some(possible_routes) = self.route_table.get(&body[offset]) {
356 let padded = Self::to_wide_simd_offset(body, offset);
357
358 for (idx, (possible_route, array)) in possible_routes.0.iter().enumerate() {
359 let mut result = u8x16::new([0_u8; 16]);
360 result |= possible_route;
361 result |= padded;
362 if result == padded && body[offset..].starts_with(array) {
363 handler = Some(possible_routes.1[idx].clone());
364 break;
365 }
366 }
367
368 if handler.is_none() {
369 handler.clone_from(&self.fallback);
370 }
371 } else {
372 handler.clone_from(&self.fallback);
373 }
374
375 Box::pin(async move {
376 if handler.is_none() {
377 debug!(
378 request.body = format!("{:02x?}", req.body()),
379 "unknown handler called for router!",
380 );
381 return Ok(Response::new_empty());
382 }
383 let mut hndl = handler.unwrap();
384
385 match hndl.call(req).await {
386 Ok(resp) => Ok(resp),
387 Err(cause) => {
388 warn!(
389 ?cause,
390 lisa.force_combine_fields = true,
391 "handler for l4 router failed",
392 );
393 Ok(Response::empty_close())
394 }
395 }
396 })
397 }
398}
399
#[cfg(test)]
pub mod test_helpers {
    #![allow(unused)]

    use super::*;
    use bytes::Bytes;
    use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};

    /// Send `packet` through `router` with `state` and a placeholder
    /// `127.0.0.1:0` address, panicking if the router call itself fails.
    async fn call_router<State: Clone + Send + Sync + 'static>(
        router: &mut Router<State>,
        packet: &'static [u8],
        state: State,
    ) -> Response {
        router
            .call(Request::new_with_state(
                Bytes::from_static(packet),
                SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)),
                state,
                None,
            ))
            .await
            .expect("Failed to call router!")
    }

    /// Assert that `packet` yields an empty body and does not request a close.
    pub async fn router_empty_no_close(router: &mut Router, packet: &'static [u8]) {
        router_empty_no_close_with_state(router, packet, ()).await;
    }

    /// Stateful variant of [`router_empty_no_close`].
    pub async fn router_empty_no_close_with_state<State: Clone + Send + Sync + 'static>(
        router: &mut Router<State>,
        packet: &'static [u8],
        state: State,
    ) {
        let resp = call_router(router, packet, state).await;

        assert!(
            !resp.request_connection_close(),
            "Response indicated that it should be closed.",
        );

        if let Some(body) = resp.body() {
            assert!(
                body.is_empty(),
                "Response body was not empty! Was:\n hex: {:02x?}\n str: {}",
                body,
                String::from_utf8_lossy(&body),
            );
        }
    }

    /// Return the body produced for `packet`, asserting no close was requested.
    pub async fn router_body_no_close(router: &mut Router, packet: &'static [u8]) -> Vec<u8> {
        router_body_no_close_with_state(router, packet, ()).await
    }

    /// Stateful variant of [`router_body_no_close`].
    pub async fn router_body_no_close_with_state<State: Clone + Send + Sync + 'static>(
        router: &mut Router<State>,
        packet: &'static [u8],
        state: State,
    ) -> Vec<u8> {
        let resp = call_router(router, packet, state).await;

        assert!(
            !resp.request_connection_close(),
            "Response indicated that it should be closed.",
        );

        resp.body()
            .expect("Failed to find response body from called route!")
            .to_vec()
    }

    /// Assert that `packet` yields no body and requests a connection close.
    pub async fn router_empty_close(router: &mut Router, packet: &'static [u8]) {
        router_empty_close_with_state(router, packet, ()).await;
    }

    /// Stateful variant of [`router_empty_close`].
    pub async fn router_empty_close_with_state<State: Clone + Send + Sync + 'static>(
        router: &mut Router<State>,
        packet: &'static [u8],
        state: State,
    ) {
        let resp = call_router(router, packet, state).await;

        assert!(
            resp.request_connection_close(),
            "Response did not indicate that the connection should be closed.",
        );
        assert!(resp.body().is_none(), "Response unexpectedly had a body.");
    }
}
505
#[cfg(test)]
mod unit_tests {
    use super::{test_helpers::*, *};
    use crate::net::server::requestable::State;
    use bytes::{Bytes, BytesMut};
    use std::sync::{
        Arc,
        atomic::{AtomicU8, Ordering},
    };

    #[tokio::test]
    pub async fn route_to_handler_full() {
        let mut router = Router::new();
        router
            .add_route(&[0x3, 0x4, 0x5, 0x6, 0x7], || async {
                Ok::<Response, CatBridgeError>(Response::from(Bytes::from_static(b"hello world")))
            })
            .expect("Failed to add route!");

        assert_eq!(
            router_body_no_close(
                &mut router,
                &[0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0x10, 0x11, 0x12]
            )
            .await,
            b"hello world"
        );
    }

    #[tokio::test]
    pub async fn route_to_handler_into() {
        let mut router = Router::new();

        async fn static_byte() -> &'static [u8] {
            b"hello world static"
        }

        router
            .add_route(&[0x1], || async { Bytes::from_static(b"hello world") })
            .expect("Failed to add route!");
        router
            .add_route(&[0x2], || async {
                let mut bytes = BytesMut::with_capacity(0);
                bytes.extend_from_slice(b"hello world bytes_mut");
                bytes
            })
            .expect("Failed to add route!");
        router
            .add_route(&[0x3], || async { "hello world string".to_owned() })
            .expect("Failed to add route!");
        router
            .add_route(&[0x4], || async { b"hello world vec".to_vec() })
            .expect("Failed to add route!");
        router
            .add_route(&[0x5], static_byte)
            .expect("Failed to add route!");
        router
            .add_route(&[0x6], || async { "hello world static str" })
            .expect("Failed to add route!");
        router
            .add_route(&[0x7], || async {
                Err::<String, CatBridgeError>(CatBridgeError::UnsupportedBitsPerCore)
            })
            .expect("Failed to add route!");

        assert_eq!(
            router_body_no_close(&mut router, &[0x1, 0x2, 0x3]).await,
            b"hello world",
        );
        assert_eq!(
            router_body_no_close(&mut router, &[0x2, 0x2, 0x3]).await,
            b"hello world bytes_mut",
        );
        assert_eq!(
            router_body_no_close(&mut router, &[0x3, 0x2, 0x3]).await,
            b"hello world string",
        );
        assert_eq!(
            router_body_no_close(&mut router, &[0x4, 0x2, 0x3]).await,
            b"hello world vec",
        );
        assert_eq!(
            router_body_no_close(&mut router, &[0x5, 0x2, 0x3]).await,
            b"hello world static",
        );
        assert_eq!(
            router_body_no_close(&mut router, &[0x6, 0x2, 0x3]).await,
            b"hello world static str",
        );

        // A handler error is turned into an empty "close connection" response.
        router_empty_close(&mut router, &[0x7, 0x2, 0x3]).await;
    }

    #[tokio::test]
    pub async fn route_to_fallback() {
        async fn base_handler() -> Result<Bytes, CatBridgeError> {
            Ok(Bytes::from_static(b"base"))
        }

        let mut router = Router::new();
        for first_byte in [0x1_u8, 0x2, 0x3, 0x4, 0x6, 0x7, 0x8, 0x9] {
            router
                .add_route(&[first_byte, 0x2, 0x3], base_handler)
                .expect("Failed to add!");
        }
        router
            .fallback_handler(|| async { Ok(Bytes::from_static(b"fallback")) })
            .expect("Failed to register fallback_handler!");

        // No bucket for the first byte at all.
        assert_eq!(
            router_body_no_close(&mut router, &[0x5, 0x6, 0x7]).await,
            b"fallback"
        );
        // Bucket exists, but no prefix in it matches.
        assert_eq!(
            router_body_no_close(&mut router, &[0x1, 0x6, 0x7]).await,
            b"fallback"
        );
        assert_eq!(
            router_body_no_close(&mut router, &[0x1, 0x2, 0x3]).await,
            b"base"
        );
    }

    #[tokio::test]
    pub async fn rejects_invalid_registrations() {
        let handler = || async { Bytes::from_static(b"registered") };
        let mut router: Router = Router::new();

        assert!(matches!(
            router.add_route(&[], handler),
            Err(CommonNetAPIError::RouterNeedsSomeBytesToMatchOn),
        ));
        assert!(matches!(
            router.add_route(&[0x1; 17], handler),
            Err(CommonNetAPIError::RouteTooLongToMatchOn(_)),
        ));

        router
            .add_route(&[0x1, 0x2], handler)
            .expect("Failed to add route!");
        assert!(matches!(
            router.add_route(&[0x1, 0x2], handler),
            Err(CommonNetAPIError::DuplicateRoute(_)),
        ));

        router
            .fallback_handler(handler)
            .expect("Failed to register fallback_handler!");
        assert!(matches!(
            router.fallback_handler(handler),
            Err(CommonNetAPIError::DuplicateFallbackHandler),
        ));

        // The rejected registrations must not have disturbed the valid route.
        assert_eq!(
            router_body_no_close(&mut router, &[0x1, 0x2, 0x3]).await,
            b"registered"
        );
    }

    #[tokio::test]
    pub async fn route_at_offset() {
        #[derive(Clone)]
        struct TestState {
            fallbacks: Arc<AtomicU8>,
            requests: Arc<AtomicU8>,
        }

        let my_state = TestState {
            fallbacks: Arc::new(AtomicU8::new(0)),
            requests: Arc::new(AtomicU8::new(0)),
        };

        async fn hit(State(test_state): State<TestState>) -> Bytes {
            test_state.requests.fetch_add(1, Ordering::SeqCst);
            Bytes::from_static(b"hewwo mw pwizzwa mwan")
        }
        async fn fallback(State(test_state): State<TestState>) -> Bytes {
            test_state.fallbacks.fetch_add(1, Ordering::SeqCst);
            Bytes::from_static(b"bye pwizza man")
        }

        let mut router = Router::<TestState>::new_with_offset(4);
        router
            .add_route(&[0x1], hit)
            .expect("Failed to add route to router!");
        router
            .add_route(&[0x2, 0x3, 0x4], hit)
            .expect("Failed to add route to router!");
        router
            .add_route(&[0x5, 0x6, 0x7], hit)
            .expect("Failed to add route to router!");
        router
            .fallback_handler(fallback)
            .expect("Failed to set fallback handler!");

        // Bodies too short to reach the offset go to the fallback.
        _ = router_body_no_close_with_state(&mut router, &[], my_state.clone()).await;
        _ = router_body_no_close_with_state(&mut router, &[0x1; 4], my_state.clone()).await;
        // Matches on the byte(s) at offset 4.
        _ = router_body_no_close_with_state(&mut router, &[0x1; 5], my_state.clone()).await;
        _ = router_body_no_close_with_state(
            &mut router,
            &[0x1, 0x1, 0x1, 0x1, 0x2, 0x3, 0x4],
            my_state.clone(),
        )
        .await;
        _ = router_body_no_close_with_state(
            &mut router,
            &[0x1, 0x1, 0x1, 0x1, 0x5, 0x6, 0x7],
            my_state.clone(),
        )
        .await;
        // No route registered for 0x9 at the offset.
        _ = router_body_no_close_with_state(
            &mut router,
            &[0x1, 0x1, 0x1, 0x1, 0x9],
            my_state.clone(),
        )
        .await;

        assert_eq!(
            my_state.fallbacks.load(Ordering::SeqCst),
            3,
            "Fallback did not get hit required amount of times!",
        );
        assert_eq!(
            my_state.requests.load(Ordering::SeqCst),
            3,
            "Request Handler did not get hit required amount of times!",
        );
    }
}