motorx_core/
lib.rs

//! A reverse proxy written in pure Rust, built on hyper, tokio, and rustls.
//! # Motorx
//! ## Basic usage
//!
//! ```ignore
//! #[tokio::main]
//! async fn main() {
//!     // Register a tracing subscriber for logging
//!
//!     // `Server::new` returns a `Result`, so unwrap (or otherwise handle) it first.
//!     let server = motorx_core::Server::new(motorx_core::Config { /* Your config here */ }).unwrap();
//!
//!     // Start the server
//!     server.run().await.unwrap()
//! }
//! ```
16
17pub mod config;
18mod conn_pool;
19pub mod error;
20mod handle;
21#[macro_use]
22pub mod log;
23mod cache;
24#[cfg(test)]
25mod e2e;
26mod listener;
27#[cfg(feature = "tls")]
28pub mod tls;
29
30#[cfg_attr(feature = "logging", macro_use(info, error, debug, trace))]
31#[cfg(feature = "logging")]
32extern crate tracing;
33
34use std::net::SocketAddr;
35use std::sync::Arc;
36
37use cache::Cache;
38use config::Upstream;
39use conn_pool::ConnPool;
40use hyper::body::Incoming;
41use hyper::service::service_fn;
42use hyper::Request;
43use hyper_util::rt::{TokioExecutor, TokioIo};
44use listener::Listener;
45#[cfg(feature = "tls")]
46use tls::stream::TlsStream;
47use tokio::io::{AsyncRead, AsyncWrite};
48use tokio::sync::{OwnedSemaphorePermit, Semaphore};
49
50pub use config::{CacheSettings, Config, Rule};
51pub use error::Error;
52
53// TODO: Consider Boxing this (Or just ConnPool) to improve spacial locality
54type UpstreamAndConnPool = (Arc<Upstream>, ConnPool);
55type Upstreams = Vec<UpstreamAndConnPool>;
56
57/// Motorx proxy server
58///
59/// Usage:
60/// ```ignore
61/// #[tokio::main]
62/// async fn main() {
63///     // Register a tracing subscriber for logging
64///
65///     let server = motorx_core::Server::new(motorx_core::Config { /* Your config here */ });
66///
67///     // start polling and proxying requests
68///     server.run().await.unwrap()
69/// }
70/// ```
71pub struct Server {
72    config: Arc<Config>,
73    cache: Arc<Cache>,
74    upstreams: Arc<Upstreams>,
75    listener: Listener,
76    /// Used to enforce max num of connections to this server
77    semaphore: Arc<Semaphore>,
78}
79
80impl Server {
81    /// Do configuration shared between raw and tls servers
82    fn common_config(
83        mut config: Config,
84    ) -> Result<(Arc<Config>, Arc<Cache>, Arc<Upstreams>, Listener), Error> {
85        let upstreams = Arc::new(init_upstreams(&mut config));
86        let cache = Arc::new(Cache::from_config(&mut config));
87
88        config.rules.sort_by(|a, b| a.path.cmp(&b.path));
89        let config = Arc::new(config);
90
91        cfg_logging! {debug!("Starting with config: {:#?}", *config);}
92
93        Ok((
94            config.clone(),
95            cache,
96            upstreams,
97            Listener::from_config(&config)?,
98        ))
99    }
100
101    pub fn new(config: Config) -> Result<Self, Error> {
102        let (config, cache, conn_pools, listener) = Self::common_config(config)?;
103
104        cfg_logging! {
105            info!("Motorx proxy listening on http://{}", {
106                listener.local_addr().unwrap()
107            });
108        }
109
110        Ok(Self {
111            semaphore: Arc::new(Semaphore::new(config.max_connections)),
112            cache,
113            upstreams: conn_pools,
114            config,
115            listener,
116        })
117    }
118
119    pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
120        self.listener.local_addr()
121    }
122
123    pub async fn run(mut self) -> Result<(), hyper::Error> {
124        loop {
125            if let Ok(permit) = self.semaphore.clone().acquire_owned().await {
126                match self.listener.accept().await {
127                    Ok((stream, peer_addr)) => {
128                        cfg_logging! {
129                            trace!("Accepted connection from {}", peer_addr);
130                        }
131
132                        handle_connection(
133                            stream,
134                            peer_addr,
135                            Arc::clone(&self.config),
136                            Arc::clone(&self.cache),
137                            Arc::clone(&self.upstreams),
138                            permit,
139                        );
140                    }
141                    Err(e) => {
142                        cfg_logging! {
143                            error!("Error connecting, {:?}", e);
144                        }
145                    }
146                }
147            }
148        }
149    }
150}
151
152#[cfg_attr(
153    feature = "logging",
154    tracing::instrument(skip(stream, config, cache, permit))
155)]
156fn handle_connection<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
157    stream: S,
158    peer_addr: SocketAddr,
159    config: Arc<Config>,
160    cache: Arc<Cache>,
161    conn_pools: Arc<Upstreams>,
162    permit: OwnedSemaphorePermit,
163) {
164    let service = service_fn(move |req: Request<Incoming>| {
165        let config = config.clone();
166        let cache = cache.clone();
167        let conn_pools = conn_pools.clone();
168
169        async move {
170            let res = handle::handle_req(
171                req,
172                peer_addr,
173                Arc::clone(&config),
174                Arc::clone(&cache),
175                Arc::clone(&conn_pools),
176            )
177            .await;
178
179            cfg_logging! {
180                trace!("Responded to req from {}", peer_addr);
181            }
182
183            res
184        }
185    });
186
187    tokio::spawn(async move {
188        cfg_logging! {
189            trace!("Handling connection from {}", peer_addr);
190        }
191        let conn_build = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
192        if let Err(err) = conn_build
193            .serve_connection_with_upgrades(TokioIo::new(stream), service)
194            .await
195        {
196            cfg_logging! {trace!("Error handling connection: {err:?}");}
197        };
198
199        cfg_logging! {
200            trace!("Closing connection to {}", peer_addr);
201        }
202
203        drop(permit);
204    });
205}
206
207#[inline]
208fn tcp_listener(addr: SocketAddr) -> std::io::Result<tokio::net::TcpListener> {
209    let std_listener = std::net::TcpListener::bind(addr)?;
210    std_listener.set_nonblocking(true)?;
211    tokio::net::TcpListener::from_std(std_listener)
212}
213
214#[inline]
215async fn tcp_connect(
216    addr: impl tokio::net::ToSocketAddrs,
217) -> std::io::Result<tokio::net::TcpStream> {
218    tokio::net::TcpStream::connect(addr).await
219}
220
221fn init_upstreams(config: &mut Config) -> Upstreams {
222    let mut upstreams = Vec::with_capacity(config.upstreams.len());
223
224    let mut upstream_order = Vec::new();
225
226    for upstream_name in config.upstreams.keys() {
227        upstream_order.push(upstream_name.clone());
228    }
229
230    for (key, upstream_name) in upstream_order.iter().enumerate() {
231        // Find any authentication referencing this upstream and populate their key
232        for (_, upstream) in &mut config.upstreams {
233            if let Some(auth) = Arc::get_mut(upstream).unwrap().authentication.as_mut() {
234                match &mut auth.source {
235                    config::authentication::AuthenticationSource::Upstream {
236                        name: _,
237                        path: _,
238                        key: upstream_key,
239                    } => *upstream_key = key,
240                    config::authentication::AuthenticationSource::Path(_) => {}
241                }
242            }
243        }
244
245        // Find any rules referencing this upstream and populate them with the key
246        for rule in &mut config.rules {
247            if rule.upstream == *upstream_name {
248                rule.upstream_key = key;
249            }
250        }
251    }
252
253    // Now, add upstreams into Vec
254    for (key, upstream_name) in upstream_order.iter().enumerate() {
255        let upstream = config.upstreams.get_mut(upstream_name).unwrap();
256        Arc::get_mut(upstream).unwrap().key = key;
257        upstreams.push((
258            Arc::clone(upstream),
259            ConnPool::new(upstream.addr.clone(), upstream.max_connections),
260        ));
261    }
262
263    upstreams.shrink_to_fit();
264
265    upstreams
266}