1pub mod config;
18mod conn_pool;
19pub mod error;
20mod handle;
21#[macro_use]
22pub mod log;
23mod cache;
24#[cfg(test)]
25mod e2e;
26mod listener;
27#[cfg(feature = "tls")]
28pub mod tls;
29
30#[cfg_attr(feature = "logging", macro_use(info, error, debug, trace))]
31#[cfg(feature = "logging")]
32extern crate tracing;
33
34use std::net::SocketAddr;
35use std::sync::Arc;
36
37use cache::Cache;
38use config::Upstream;
39use conn_pool::ConnPool;
40use hyper::body::Incoming;
41use hyper::service::service_fn;
42use hyper::Request;
43use hyper_util::rt::{TokioExecutor, TokioIo};
44use listener::Listener;
45#[cfg(feature = "tls")]
46use tls::stream::TlsStream;
47use tokio::io::{AsyncRead, AsyncWrite};
48use tokio::sync::{OwnedSemaphorePermit, Semaphore};
49
50pub use config::{CacheSettings, Config, Rule};
51pub use error::Error;
52
/// An upstream definition paired with the connection pool built for its address.
type UpstreamAndConnPool = (Arc<Upstream>, ConnPool);
/// Every configured upstream with its pool; an upstream's position in this
/// vector equals the `key` assigned to it by `init_upstreams`.
type Upstreams = Vec<UpstreamAndConnPool>;
56
/// The proxy server: a bound listener plus the shared state every connection uses.
///
/// Build one with [`Server::new`] and drive it with [`Server::run`].
pub struct Server {
    /// Configuration shared with every connection; its rules are sorted by `path`.
    config: Arc<Config>,
    /// Cache built from the configuration by `Cache::from_config`.
    cache: Arc<Cache>,
    /// Upstreams and their connection pools, indexed by each upstream's `key`.
    upstreams: Arc<Upstreams>,
    /// Source of incoming connections.
    listener: Listener,
    /// Caps concurrent connections at `config.max_connections`; one permit is
    /// held for the whole lifetime of each accepted connection.
    semaphore: Arc<Semaphore>,
}
79
80impl Server {
81 fn common_config(
83 mut config: Config,
84 ) -> Result<(Arc<Config>, Arc<Cache>, Arc<Upstreams>, Listener), Error> {
85 let upstreams = Arc::new(init_upstreams(&mut config));
86 let cache = Arc::new(Cache::from_config(&mut config));
87
88 config.rules.sort_by(|a, b| a.path.cmp(&b.path));
89 let config = Arc::new(config);
90
91 cfg_logging! {debug!("Starting with config: {:#?}", *config);}
92
93 Ok((
94 config.clone(),
95 cache,
96 upstreams,
97 Listener::from_config(&config)?,
98 ))
99 }
100
101 pub fn new(config: Config) -> Result<Self, Error> {
102 let (config, cache, conn_pools, listener) = Self::common_config(config)?;
103
104 cfg_logging! {
105 info!("Motorx proxy listening on http://{}", {
106 listener.local_addr().unwrap()
107 });
108 }
109
110 Ok(Self {
111 semaphore: Arc::new(Semaphore::new(config.max_connections)),
112 cache,
113 upstreams: conn_pools,
114 config,
115 listener,
116 })
117 }
118
119 pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
120 self.listener.local_addr()
121 }
122
123 pub async fn run(mut self) -> Result<(), hyper::Error> {
124 loop {
125 if let Ok(permit) = self.semaphore.clone().acquire_owned().await {
126 match self.listener.accept().await {
127 Ok((stream, peer_addr)) => {
128 cfg_logging! {
129 trace!("Accepted connection from {}", peer_addr);
130 }
131
132 handle_connection(
133 stream,
134 peer_addr,
135 Arc::clone(&self.config),
136 Arc::clone(&self.cache),
137 Arc::clone(&self.upstreams),
138 permit,
139 );
140 }
141 Err(e) => {
142 cfg_logging! {
143 error!("Error connecting, {:?}", e);
144 }
145 }
146 }
147 }
148 }
149 }
150}
151
152#[cfg_attr(
153 feature = "logging",
154 tracing::instrument(skip(stream, config, cache, permit))
155)]
156fn handle_connection<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
157 stream: S,
158 peer_addr: SocketAddr,
159 config: Arc<Config>,
160 cache: Arc<Cache>,
161 conn_pools: Arc<Upstreams>,
162 permit: OwnedSemaphorePermit,
163) {
164 let service = service_fn(move |req: Request<Incoming>| {
165 let config = config.clone();
166 let cache = cache.clone();
167 let conn_pools = conn_pools.clone();
168
169 async move {
170 let res = handle::handle_req(
171 req,
172 peer_addr,
173 Arc::clone(&config),
174 Arc::clone(&cache),
175 Arc::clone(&conn_pools),
176 )
177 .await;
178
179 cfg_logging! {
180 trace!("Responded to req from {}", peer_addr);
181 }
182
183 res
184 }
185 });
186
187 tokio::spawn(async move {
188 cfg_logging! {
189 trace!("Handling connection from {}", peer_addr);
190 }
191 let conn_build = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
192 if let Err(err) = conn_build
193 .serve_connection_with_upgrades(TokioIo::new(stream), service)
194 .await
195 {
196 cfg_logging! {trace!("Error handling connection: {err:?}");}
197 };
198
199 cfg_logging! {
200 trace!("Closing connection to {}", peer_addr);
201 }
202
203 drop(permit);
204 });
205}
206
207#[inline]
208fn tcp_listener(addr: SocketAddr) -> std::io::Result<tokio::net::TcpListener> {
209 let std_listener = std::net::TcpListener::bind(addr)?;
210 std_listener.set_nonblocking(true)?;
211 tokio::net::TcpListener::from_std(std_listener)
212}
213
214#[inline]
215async fn tcp_connect(
216 addr: impl tokio::net::ToSocketAddrs,
217) -> std::io::Result<tokio::net::TcpStream> {
218 tokio::net::TcpStream::connect(addr).await
219}
220
221fn init_upstreams(config: &mut Config) -> Upstreams {
222 let mut upstreams = Vec::with_capacity(config.upstreams.len());
223
224 let mut upstream_order = Vec::new();
225
226 for upstream_name in config.upstreams.keys() {
227 upstream_order.push(upstream_name.clone());
228 }
229
230 for (key, upstream_name) in upstream_order.iter().enumerate() {
231 for (_, upstream) in &mut config.upstreams {
233 if let Some(auth) = Arc::get_mut(upstream).unwrap().authentication.as_mut() {
234 match &mut auth.source {
235 config::authentication::AuthenticationSource::Upstream {
236 name: _,
237 path: _,
238 key: upstream_key,
239 } => *upstream_key = key,
240 config::authentication::AuthenticationSource::Path(_) => {}
241 }
242 }
243 }
244
245 for rule in &mut config.rules {
247 if rule.upstream == *upstream_name {
248 rule.upstream_key = key;
249 }
250 }
251 }
252
253 for (key, upstream_name) in upstream_order.iter().enumerate() {
255 let upstream = config.upstreams.get_mut(upstream_name).unwrap();
256 Arc::get_mut(upstream).unwrap().key = key;
257 upstreams.push((
258 Arc::clone(upstream),
259 ConnPool::new(upstream.addr.clone(), upstream.max_connections),
260 ));
261 }
262
263 upstreams.shrink_to_fit();
264
265 upstreams
266}