monolake_services/common/detect.rs

1use std::{future::Future, io, io::Cursor};
2
3use monoio::{
4    buf::IoBufMut,
5    io::{AsyncReadRent, AsyncReadRentExt, PrefixedReadIo},
6};
7use service_async::Service;
8
/// Detect is a trait for detecting a certain pattern in an input stream.
///
/// `detect` takes ownership of the input stream and resolves to a tuple of the detection
/// result (`DetOut`) and the stream to continue with (`IOOut`). Detection usually has to
/// consume some bytes, so `IOOut` is typically a `PrefixedReadIo` that carries those bytes as
/// a prefix; the implementation decides whether to add that prefix data.
///
/// The `io::Result` is reserved for I/O failures. If the stream is readable but the pattern is
/// not found, the implementation should represent that inside `DetOut` (for example,
/// `PrefixDetector` reports `false`) instead of returning an error.
pub trait Detect<IO> {
    /// Outcome of the detection (e.g. `bool` for a simple match / no-match detector).
    type DetOut;
    /// Stream handed on after detection, usually the input wrapped with the consumed bytes.
    type IOOut;

    /// Runs detection on `io` and returns the detection result and the stream to continue with.
    fn detect(&self, io: IO) -> impl Future<Output = io::Result<(Self::DetOut, Self::IOOut)>>;
}
21
/// Service adapter that runs a detector on every incoming stream before calling `inner`.
///
/// Requests are `(io, cx)` pairs: `detector` consumes `io`, and `inner` is then called with
/// `(detection_result, wrapped_io, cx)`.
pub struct DetectService<D, S> {
    /// Detector applied to each incoming stream.
    pub detector: D,
    /// Downstream service receiving the detection result, the wrapped stream and the context.
    pub inner: S,
}
28
29#[derive(thiserror::Error, Debug)]
30pub enum DetectError<E> {
31    #[error("service error: {0:?}")]
32    Svc(E),
33    #[error("io error: {0:?}")]
34    Io(std::io::Error),
35}
36
37impl<R, S, D, CX> Service<(R, CX)> for DetectService<D, S>
38where
39    D: Detect<R>,
40    S: Service<(D::DetOut, D::IOOut, CX)>,
41{
42    type Response = S::Response;
43    type Error = DetectError<S::Error>;
44
45    async fn call(&self, (io, cx): (R, CX)) -> Result<Self::Response, Self::Error> {
46        let (det, io) = self.detector.detect(io).await.map_err(DetectError::Io)?;
47        self.inner
48            .call((det, io, cx))
49            .await
50            .map_err(DetectError::Svc)
51    }
52}
53
/// Detector that reads exactly `N` bytes from the input stream and passes them to the
/// closure `F`.
///
/// The closure's return value becomes the detection result. The `N` bytes are then placed,
/// as the closure left them (it receives them mutably), in front of the returned
/// `PrefixedReadIo`.
pub struct FixedLengthDetector<const N: usize, F>(pub F);
56
57impl<const N: usize, F, IO, DetOut> Detect<IO> for FixedLengthDetector<N, F>
58where
59    F: Fn(&mut [u8]) -> DetOut,
60    IO: AsyncReadRent,
61{
62    type DetOut = DetOut;
63    type IOOut = PrefixedReadIo<IO, Cursor<Vec<u8>>>;
64
65    async fn detect(&self, mut io: IO) -> io::Result<(Self::DetOut, Self::IOOut)> {
66        let buf = Vec::with_capacity(N).slice_mut(..N);
67        let (r, buf) = io.read_exact(buf).await;
68        r?;
69
70        let mut buf = buf.into_inner();
71        let r = (self.0)(&mut buf);
72        Ok((r, PrefixedReadIo::new(io, Cursor::new(buf))))
73    }
74}
75
/// PrefixDetector detects whether the input stream starts with a given byte prefix.
///
/// Bytes are read until the whole prefix has been received, a read chunk disagrees with the
/// prefix, or the stream reaches EOF. At most `self.0.len()` bytes are consumed.
///
/// If the prefix matches, detection returns `true`. Otherwise it returns `false`. In that case
/// the consumed data may be shorter than the prefix (EOF), or its last chunk may contain the
/// mismatching bytes. In both cases every consumed byte is kept as the prefix of the returned
/// `PrefixedReadIo`.
pub struct PrefixDetector(pub &'static [u8]);
82
83impl<IO> Detect<IO> for PrefixDetector
84where
85    IO: AsyncReadRent,
86{
87    type DetOut = bool;
88    type IOOut = PrefixedReadIo<IO, Cursor<Vec<u8>>>;
89
90    async fn detect(&self, mut io: IO) -> io::Result<(Self::DetOut, Self::IOOut)> {
91        let l = self.0.len();
92        let mut written = 0;
93        let mut buf: Vec<u8> = Vec::with_capacity(l);
94        let mut eq = true;
95        loop {
96            // # Safety
97            // The buf must have enough capacity to write the data.
98            let buf_slice = unsafe { buf.slice_mut_unchecked(written..l) };
99            let (result, buf_slice) = io.read(buf_slice).await;
100            buf = buf_slice.into_inner();
101            match result? {
102                0 => {
103                    break;
104                }
105                n => {
106                    let curr = written;
107                    written += n;
108                    if self.0[curr..written] != buf[curr..written] {
109                        eq = false;
110                        break;
111                    }
112                }
113            }
114        }
115        let io = PrefixedReadIo::new(io, Cursor::new(buf));
116        Ok((eq && written == l, io))
117    }
118}