Skip to main content

rumq_broker/
lib.rs

1#[macro_use]
2extern crate log;
3
4use futures_util::future::join_all;
5use tokio_util::codec::Framed;
6use tokio::net::TcpListener;
7use tokio::sync::mpsc::{channel, Sender, Receiver};
8use tokio::task;
9use tokio::time::{self, Elapsed};
10use tokio_rustls::rustls::internal::pemfile::{certs, rsa_private_keys};
11use tokio_rustls::rustls::TLSError;
12use tokio_rustls::rustls::{AllowAnyAuthenticatedClient, NoClientAuth, RootCertStore, ServerConfig};
13use tokio_rustls::TlsAcceptor;
14use futures_util::sink::Sink;
15use futures_util::stream::Stream;
16use rumq_core::mqtt4::{codec, Packet};
17
18use serde::Deserialize;
19
20use std::fs::File;
21use std::io::{self, BufReader};
22use std::path::Path;
23use std::sync::Arc;
24use std::time::Duration;
25use std::thread;
26
27mod connection;
28mod state;
29mod router;
30
31pub use rumq_core as core;
32pub use router::{RouterMessage, Connection};
33
34#[derive(Debug, thiserror::Error)]
35pub enum Error {
36    #[error("I/O")]
37    Io(#[from] io::Error),
38    #[error("MQTT protocol error")]
39    Mqtt(#[from] rumq_core::Error),
40    #[error("Timeout")]
41    Timeout(#[from] Elapsed),
42    #[error("Broker State")]
43    State(#[from] state::Error),
44    #[error("TLS")]
45    Tls(#[from] TLSError),
46    #[error("No server cert")]
47    NoServerCert,
48    #[error("No server private key")]
49    NoServerPrivateKey,
50    #[error("No ca file")]
51    NoCAFile,
52    #[error("No server cert file")]
53    NoServerCertFile,
54    #[error("No server key file")]
55    NoServerKeyFile,
56    #[error("Disconnected")]
57    Disconnected,
58}
59
60#[derive(Debug, Deserialize, Clone)]
61pub struct Config {
62    servers:    Vec<ServerSettings>,
63}
64
65#[derive(Debug, Deserialize, Clone)]
66pub struct ServerSettings {
67    pub port: u16,
68    pub connection_timeout_ms: u16,
69    pub next_connection_delay_ms: u64,
70    pub max_client_id_len: usize,
71    pub max_connections: usize,
72    pub disk_persistence: bool,
73    pub throttle_delay_ms: u64,
74    pub disk_retention_size: usize,
75    pub disk_retention_time_sec: usize,
76    pub auto_save_interval_sec: u16,
77    pub max_payload_size: usize,
78    pub max_inflight_topic_size: usize,
79    pub ca_path: Option<String>,
80    pub cert_path: Option<String>,
81    pub key_path: Option<String>,
82    pub username: Option<String>,
83    pub password: Option<String>,
84}
85
86async fn tls_connection<P: AsRef<Path>>(ca_path: Option<P>, cert_path: P, key_path: P) -> Result<TlsAcceptor, Error> {
87    // client authentication with a CA. CA isn't required otherwise
88    let mut server_config = if let Some(ca_path) = ca_path {
89        let mut root_cert_store = RootCertStore::empty();
90        root_cert_store.add_pem_file(&mut BufReader::new(File::open(ca_path)?)).map_err(|_| Error::NoCAFile)?;
91        ServerConfig::new(AllowAnyAuthenticatedClient::new(root_cert_store))
92    } else {
93        ServerConfig::new(NoClientAuth::new())
94    };
95
96    let certs = certs(&mut BufReader::new(File::open(cert_path)?)).map_err(|_| Error::NoServerCertFile)?;
97    let mut keys = rsa_private_keys(&mut BufReader::new(File::open(key_path)?)).map_err(|_| Error::NoServerKeyFile)?;
98
99    server_config.set_single_cert(certs, keys.remove(0))?;
100    let acceptor = TlsAcceptor::from(Arc::new(server_config));
101    Ok(acceptor)
102}
103
104async fn accept_loop(config: Arc<ServerSettings>, router_tx: Sender<(String, router::RouterMessage)>) -> Result<(), Error> {
105    let addr = format!("0.0.0.0:{}", config.port);
106    let connection_config = config.clone();
107
108    let acceptor = if let Some(cert_path) = config.cert_path.clone() {
109        let key_path = config.key_path.clone().ok_or(Error::NoServerPrivateKey)?;
110        Some(tls_connection(config.ca_path.clone(), cert_path, key_path).await?)
111    } else {
112        None
113    };
114
115    info!("Waiting for connections on {}", addr);
116    // eventloop which accepts connections
117    let mut listener = TcpListener::bind(addr).await?;
118    let accept_loop_delay = Duration::from_millis(config.next_connection_delay_ms);
119    loop {
120        let (stream, addr) = match listener.accept().await {
121            Ok(s) => s,
122            Err(e) => {
123                error!("Tcp connection error = {:?}", e);
124                continue;
125            }
126        };
127
128        info!("Accepting from: {}", addr);
129
130        let config = connection_config.clone();
131        let router_tx = router_tx.clone();
132
133        if let Some(acceptor) = &acceptor {
134            let stream = match acceptor.accept(stream).await {
135                Ok(s) => s,
136                Err(e) => {
137                    error!("Tls connection error = {:?}", e);
138                    continue;
139                }
140            };
141
142            let framed = Framed::new(stream, codec::MqttCodec::new(config.max_payload_size));
143            task::spawn( async {
144                match connection::eventloop(config, framed, router_tx).await {
145                    Ok(id) => info!("Connection eventloop done!!. Id = {:?}", id),
146                    Err(e) => error!("Connection eventloop error = {:?}", e),
147                }
148            });
149        } else {
150            let framed = Framed::new(stream, codec::MqttCodec::new(config.max_payload_size));
151            task::spawn( async {
152                match connection::eventloop(config, framed, router_tx).await {
153                    Ok(id) => info!("Connection eventloop done!!. Id = {:?}", id),
154                    Err(e) => error!("Connection eventloop error = {:?}", e),
155                }
156            });
157        };
158
159        time::delay_for(accept_loop_delay).await;
160    }
161}
162
163
164pub trait Network: Stream<Item = Result<Packet, rumq_core::Error>> + Sink<Packet, Error = io::Error> + Unpin + Send {}
165impl<T> Network for T where T: Stream<Item = Result<Packet, rumq_core::Error>> + Sink<Packet, Error = io::Error> + Unpin + Send {}
166
167#[tokio::main(core_threads = 1)]
168async fn router(rx: Receiver<(String, router::RouterMessage)>) {
169    let mut router = router::Router::new(rx);
170    if let Err(e) = router.start().await {
171        error!("Router stopped. Error = {:?}", e);
172    }
173}
174
175pub struct Broker {
176    config: Config,
177    router_handle: Sender<(String, router::RouterMessage)>,
178}
179
180pub fn new(config: Config) -> Broker {
181    let (router_tx, router_rx) = channel(100);
182
183    thread::spawn(move || {
184        router(router_rx)
185    });
186
187    Broker {
188        config,
189        router_handle: router_tx
190    }
191}
192
193impl Broker {
194    pub fn new_router_handle(&self) -> Sender<(String, router::RouterMessage)> {
195        self.router_handle.clone()
196    }
197
198    pub async fn start(&mut self) -> Vec<Result<(), task::JoinError>> {
199        let mut servers = Vec::new();
200        let server_configs = self.config.servers.split_off(0);
201
202        for server in server_configs.into_iter() {
203            let config = Arc::new(server);
204            let fut = accept_loop(config, self.router_handle.clone());
205            let o = task::spawn(async {
206                error!("Accept loop returned = {:?}", fut.await);
207            });
208
209            servers.push(o);
210        }
211
212        join_all(servers).await
213    }
214}
215
#[cfg(test)]
mod test {
    // These are placeholders: their bodies are empty, so running them would report
    // a pass without checking anything. They are ignored until implemented so the
    // test summary does not overstate coverage of these behaviors.

    #[test]
    #[ignore = "not implemented yet"]
    fn accept_loop_rate_limits_incoming_connections() {}

    #[test]
    #[ignore = "not implemented yet"]
    fn accept_loop_should_not_allow_more_than_maximum_connections() {}

    #[test]
    #[ignore = "not implemented yet"]
    fn accept_loop_should_accept_new_connection_when_a_client_disconnects_after_max_connections() {}

    #[test]
    #[ignore = "not implemented yet"]
    fn client_loop_should_error_if_connect_packet_is_not_received_in_time() {}
}