rama_haproxy/server/layer.rs

1use crate::protocol::{HeaderResult, PartialResult, v1, v2};
2use rama_core::{
3    Context, Layer, Service,
4    error::{BoxError, ErrorExt},
5};
6use rama_net::{
7    forwarded::{Forwarded, ForwardedElement},
8    stream::{ChainReader, HeapReader, Stream},
9};
10use std::{fmt, net::SocketAddr};
11use tokio::io::AsyncReadExt;
12
/// [`Layer`] that decodes the HaProxy Protocol header at the start of each
/// connection by wrapping services in a [`HaProxyService`].
#[non_exhaustive]
#[derive(Clone, Debug, Default)]
pub struct HaProxyLayer;
17
18impl HaProxyLayer {
19    /// Create a new [`HaProxyLayer`].
20    pub const fn new() -> Self {
21        HaProxyLayer
22    }
23}
24
25impl<S> Layer<S> for HaProxyLayer {
26    type Service = HaProxyService<S>;
27
28    fn layer(&self, inner: S) -> Self::Service {
29        HaProxyService { inner }
30    }
31}
32
/// [`Service`] that decodes the HaProxy Protocol header in front of a stream.
///
/// The proxied client's source address (when the header carries one) is
/// recorded in the [`Forwarded`] extension of the context. The inner service
/// then receives the stream with the header bytes removed.
pub struct HaProxyService<S> {
    inner: S,
}
40
41impl<S> HaProxyService<S> {
42    /// Create a new [`HaProxyService`] with the given inner service.
43    pub const fn new(inner: S) -> Self {
44        HaProxyService { inner }
45    }
46}
47
48impl<S: fmt::Debug> fmt::Debug for HaProxyService<S> {
49    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50        f.debug_struct("HaProxyService")
51            .field("inner", &self.inner)
52            .finish()
53    }
54}
55
56impl<S: Clone> Clone for HaProxyService<S> {
57    fn clone(&self) -> Self {
58        HaProxyService {
59            inner: self.inner.clone(),
60        }
61    }
62}
63
64impl<State, S, IO> Service<State, IO> for HaProxyService<S>
65where
66    State: Clone + Send + Sync + 'static,
67    S: Service<
68            State,
69            tokio::io::Join<
70                ChainReader<HeapReader, tokio::io::ReadHalf<IO>>,
71                tokio::io::WriteHalf<IO>,
72            >,
73            Error: Into<BoxError>,
74        >,
75    IO: Stream + Unpin,
76{
77    type Response = S::Response;
78    type Error = BoxError;
79
80    async fn serve(
81        &self,
82        mut ctx: Context<State>,
83        mut stream: IO,
84    ) -> Result<Self::Response, Self::Error> {
85        let mut buffer = [0; 512];
86        let mut read = 0;
87        let header = loop {
88            let n = stream.read(&mut buffer[read..]).await?;
89            read += n;
90
91            let header = HeaderResult::parse(&buffer[..read]);
92            if header.is_complete() {
93                break header;
94            }
95
96            if n == 0 {
97                return Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof)
98                    .context("HaProxy header incomplete")
99                    .into_boxed());
100            }
101
102            tracing::debug!("Incomplete header. Read {} bytes so far.", read);
103        };
104
105        let consumed = match header {
106            HeaderResult::V1(Ok(header)) => {
107                match header.addresses {
108                    v1::Addresses::Tcp4(info) => {
109                        let peer_addr: SocketAddr = (info.source_address, info.source_port).into();
110                        let el = ForwardedElement::forwarded_for(peer_addr);
111                        match ctx.get_mut::<Forwarded>() {
112                            Some(forwarded) => {
113                                forwarded.append(el);
114                            }
115                            None => {
116                                let forwarded = Forwarded::new(el);
117                                ctx.insert(forwarded);
118                            }
119                        }
120                    }
121                    v1::Addresses::Tcp6(info) => {
122                        let peer_addr: SocketAddr = (info.source_address, info.source_port).into();
123                        let el = ForwardedElement::forwarded_for(peer_addr);
124                        match ctx.get_mut::<Forwarded>() {
125                            Some(forwarded) => {
126                                forwarded.append(el);
127                            }
128                            None => {
129                                let forwarded = Forwarded::new(el);
130                                ctx.insert(forwarded);
131                            }
132                        }
133                    }
134                    v1::Addresses::Unknown => (),
135                };
136                header.header.len()
137            }
138            HeaderResult::V2(Ok(header)) => {
139                match header.addresses {
140                    v2::Addresses::IPv4(info) => {
141                        let peer_addr: SocketAddr = (info.source_address, info.source_port).into();
142                        let el = ForwardedElement::forwarded_for(peer_addr);
143                        match ctx.get_mut::<Forwarded>() {
144                            Some(forwarded) => {
145                                forwarded.append(el);
146                            }
147                            None => {
148                                let forwarded = Forwarded::new(el);
149                                ctx.insert(forwarded);
150                            }
151                        }
152                    }
153                    v2::Addresses::IPv6(info) => {
154                        let peer_addr: SocketAddr = (info.source_address, info.source_port).into();
155                        let el = ForwardedElement::forwarded_for(peer_addr);
156                        match ctx.get_mut::<Forwarded>() {
157                            Some(forwarded) => {
158                                forwarded.append(el);
159                            }
160                            None => {
161                                let forwarded = Forwarded::new(el);
162                                ctx.insert(forwarded);
163                            }
164                        }
165                    }
166                    v2::Addresses::Unix(_) | v2::Addresses::Unspecified => (),
167                };
168                header.header.len()
169            }
170            HeaderResult::V1(Err(error)) => {
171                return Err(error.into());
172            }
173            HeaderResult::V2(Err(error)) => {
174                return Err(error.into());
175            }
176        };
177
178        // put back the data that is read too much
179        let (r, w) = tokio::io::split(stream);
180        let mem: HeapReader = buffer[consumed..read].into();
181        let r = ChainReader::new(mem, r);
182        let stream = tokio::io::join(r, w);
183
184        // read the rest of the data
185        match self.inner.serve(ctx, stream).await {
186            Ok(response) => Ok(response),
187            Err(error) => Err(error.into()),
188        }
189    }
190}