rama_haproxy/server/
layer.rs1use crate::protocol::{HeaderResult, PartialResult, v1, v2};
2use rama_core::{
3 Context, Layer, Service,
4 error::{BoxError, ErrorExt},
5};
6use rama_net::{
7 forwarded::{Forwarded, ForwardedElement},
8 stream::{ChainReader, HeapReader, Stream},
9};
10use std::{fmt, net::SocketAddr};
11use tokio::io::AsyncReadExt;
12
/// Layer that wraps an inner service in a [`HaProxyService`], which consumes
/// a PROXY protocol header from each incoming stream before delegating.
#[derive(Clone, Debug, Default)]
#[non_exhaustive]
pub struct HaProxyLayer;
17
18impl HaProxyLayer {
19 pub const fn new() -> Self {
21 HaProxyLayer
22 }
23}
24
25impl<S> Layer<S> for HaProxyLayer {
26 type Service = HaProxyService<S>;
27
28 fn layer(&self, inner: S) -> Self::Service {
29 HaProxyService { inner }
30 }
31}
32
/// Service that reads a PROXY protocol header (v1 or v2) from the start of an
/// incoming stream, records the advertised source address in the
/// [`Forwarded`] context extension, and hands the remaining stream to `inner`.
pub struct HaProxyService<S> {
    /// Wrapped service that receives the stream once the header is consumed.
    inner: S,
}
40
41impl<S> HaProxyService<S> {
42 pub const fn new(inner: S) -> Self {
44 HaProxyService { inner }
45 }
46}
47
48impl<S: fmt::Debug> fmt::Debug for HaProxyService<S> {
49 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50 f.debug_struct("HaProxyService")
51 .field("inner", &self.inner)
52 .finish()
53 }
54}
55
56impl<S: Clone> Clone for HaProxyService<S> {
57 fn clone(&self) -> Self {
58 HaProxyService {
59 inner: self.inner.clone(),
60 }
61 }
62}
63
64impl<State, S, IO> Service<State, IO> for HaProxyService<S>
65where
66 State: Clone + Send + Sync + 'static,
67 S: Service<
68 State,
69 tokio::io::Join<
70 ChainReader<HeapReader, tokio::io::ReadHalf<IO>>,
71 tokio::io::WriteHalf<IO>,
72 >,
73 Error: Into<BoxError>,
74 >,
75 IO: Stream + Unpin,
76{
77 type Response = S::Response;
78 type Error = BoxError;
79
80 async fn serve(
81 &self,
82 mut ctx: Context<State>,
83 mut stream: IO,
84 ) -> Result<Self::Response, Self::Error> {
85 let mut buffer = [0; 512];
86 let mut read = 0;
87 let header = loop {
88 let n = stream.read(&mut buffer[read..]).await?;
89 read += n;
90
91 let header = HeaderResult::parse(&buffer[..read]);
92 if header.is_complete() {
93 break header;
94 }
95
96 if n == 0 {
97 return Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof)
98 .context("HaProxy header incomplete")
99 .into_boxed());
100 }
101
102 tracing::debug!("Incomplete header. Read {} bytes so far.", read);
103 };
104
105 let consumed = match header {
106 HeaderResult::V1(Ok(header)) => {
107 match header.addresses {
108 v1::Addresses::Tcp4(info) => {
109 let peer_addr: SocketAddr = (info.source_address, info.source_port).into();
110 let el = ForwardedElement::forwarded_for(peer_addr);
111 match ctx.get_mut::<Forwarded>() {
112 Some(forwarded) => {
113 forwarded.append(el);
114 }
115 None => {
116 let forwarded = Forwarded::new(el);
117 ctx.insert(forwarded);
118 }
119 }
120 }
121 v1::Addresses::Tcp6(info) => {
122 let peer_addr: SocketAddr = (info.source_address, info.source_port).into();
123 let el = ForwardedElement::forwarded_for(peer_addr);
124 match ctx.get_mut::<Forwarded>() {
125 Some(forwarded) => {
126 forwarded.append(el);
127 }
128 None => {
129 let forwarded = Forwarded::new(el);
130 ctx.insert(forwarded);
131 }
132 }
133 }
134 v1::Addresses::Unknown => (),
135 };
136 header.header.len()
137 }
138 HeaderResult::V2(Ok(header)) => {
139 match header.addresses {
140 v2::Addresses::IPv4(info) => {
141 let peer_addr: SocketAddr = (info.source_address, info.source_port).into();
142 let el = ForwardedElement::forwarded_for(peer_addr);
143 match ctx.get_mut::<Forwarded>() {
144 Some(forwarded) => {
145 forwarded.append(el);
146 }
147 None => {
148 let forwarded = Forwarded::new(el);
149 ctx.insert(forwarded);
150 }
151 }
152 }
153 v2::Addresses::IPv6(info) => {
154 let peer_addr: SocketAddr = (info.source_address, info.source_port).into();
155 let el = ForwardedElement::forwarded_for(peer_addr);
156 match ctx.get_mut::<Forwarded>() {
157 Some(forwarded) => {
158 forwarded.append(el);
159 }
160 None => {
161 let forwarded = Forwarded::new(el);
162 ctx.insert(forwarded);
163 }
164 }
165 }
166 v2::Addresses::Unix(_) | v2::Addresses::Unspecified => (),
167 };
168 header.header.len()
169 }
170 HeaderResult::V1(Err(error)) => {
171 return Err(error.into());
172 }
173 HeaderResult::V2(Err(error)) => {
174 return Err(error.into());
175 }
176 };
177
178 let (r, w) = tokio::io::split(stream);
180 let mem: HeapReader = buffer[consumed..read].into();
181 let r = ChainReader::new(mem, r);
182 let stream = tokio::io::join(r, w);
183
184 match self.inner.serve(ctx, stream).await {
186 Ok(response) => Ok(response),
187 Err(error) => Err(error.into()),
188 }
189 }
190}