1#[macro_use]
2extern crate log;
3
4use futures_util::future::join_all;
5use tokio_util::codec::Framed;
6use tokio::net::TcpListener;
7use tokio::sync::mpsc::{channel, Sender, Receiver};
8use tokio::task;
9use tokio::time::{self, Elapsed};
10use tokio_rustls::rustls::internal::pemfile::{certs, rsa_private_keys};
11use tokio_rustls::rustls::TLSError;
12use tokio_rustls::rustls::{AllowAnyAuthenticatedClient, NoClientAuth, RootCertStore, ServerConfig};
13use tokio_rustls::TlsAcceptor;
14use futures_util::sink::Sink;
15use futures_util::stream::Stream;
16use rumq_core::mqtt4::{codec, Packet};
17
18use serde::Deserialize;
19
20use std::fs::File;
21use std::io::{self, BufReader};
22use std::path::Path;
23use std::sync::Arc;
24use std::time::Duration;
25use std::thread;
26
27mod connection;
28mod state;
29mod router;
30
31pub use rumq_core as core;
32pub use router::{RouterMessage, Connection};
33
34#[derive(Debug, thiserror::Error)]
35pub enum Error {
36 #[error("I/O")]
37 Io(#[from] io::Error),
38 #[error("MQTT protocol error")]
39 Mqtt(#[from] rumq_core::Error),
40 #[error("Timeout")]
41 Timeout(#[from] Elapsed),
42 #[error("Broker State")]
43 State(#[from] state::Error),
44 #[error("TLS")]
45 Tls(#[from] TLSError),
46 #[error("No server cert")]
47 NoServerCert,
48 #[error("No server private key")]
49 NoServerPrivateKey,
50 #[error("No ca file")]
51 NoCAFile,
52 #[error("No server cert file")]
53 NoServerCertFile,
54 #[error("No server key file")]
55 NoServerKeyFile,
56 #[error("Disconnected")]
57 Disconnected,
58}
59
60#[derive(Debug, Deserialize, Clone)]
61pub struct Config {
62 servers: Vec<ServerSettings>,
63}
64
65#[derive(Debug, Deserialize, Clone)]
66pub struct ServerSettings {
67 pub port: u16,
68 pub connection_timeout_ms: u16,
69 pub next_connection_delay_ms: u64,
70 pub max_client_id_len: usize,
71 pub max_connections: usize,
72 pub disk_persistence: bool,
73 pub throttle_delay_ms: u64,
74 pub disk_retention_size: usize,
75 pub disk_retention_time_sec: usize,
76 pub auto_save_interval_sec: u16,
77 pub max_payload_size: usize,
78 pub max_inflight_topic_size: usize,
79 pub ca_path: Option<String>,
80 pub cert_path: Option<String>,
81 pub key_path: Option<String>,
82 pub username: Option<String>,
83 pub password: Option<String>,
84}
85
86async fn tls_connection<P: AsRef<Path>>(ca_path: Option<P>, cert_path: P, key_path: P) -> Result<TlsAcceptor, Error> {
87 let mut server_config = if let Some(ca_path) = ca_path {
89 let mut root_cert_store = RootCertStore::empty();
90 root_cert_store.add_pem_file(&mut BufReader::new(File::open(ca_path)?)).map_err(|_| Error::NoCAFile)?;
91 ServerConfig::new(AllowAnyAuthenticatedClient::new(root_cert_store))
92 } else {
93 ServerConfig::new(NoClientAuth::new())
94 };
95
96 let certs = certs(&mut BufReader::new(File::open(cert_path)?)).map_err(|_| Error::NoServerCertFile)?;
97 let mut keys = rsa_private_keys(&mut BufReader::new(File::open(key_path)?)).map_err(|_| Error::NoServerKeyFile)?;
98
99 server_config.set_single_cert(certs, keys.remove(0))?;
100 let acceptor = TlsAcceptor::from(Arc::new(server_config));
101 Ok(acceptor)
102}
103
104async fn accept_loop(config: Arc<ServerSettings>, router_tx: Sender<(String, router::RouterMessage)>) -> Result<(), Error> {
105 let addr = format!("0.0.0.0:{}", config.port);
106 let connection_config = config.clone();
107
108 let acceptor = if let Some(cert_path) = config.cert_path.clone() {
109 let key_path = config.key_path.clone().ok_or(Error::NoServerPrivateKey)?;
110 Some(tls_connection(config.ca_path.clone(), cert_path, key_path).await?)
111 } else {
112 None
113 };
114
115 info!("Waiting for connections on {}", addr);
116 let mut listener = TcpListener::bind(addr).await?;
118 let accept_loop_delay = Duration::from_millis(config.next_connection_delay_ms);
119 loop {
120 let (stream, addr) = match listener.accept().await {
121 Ok(s) => s,
122 Err(e) => {
123 error!("Tcp connection error = {:?}", e);
124 continue;
125 }
126 };
127
128 info!("Accepting from: {}", addr);
129
130 let config = connection_config.clone();
131 let router_tx = router_tx.clone();
132
133 if let Some(acceptor) = &acceptor {
134 let stream = match acceptor.accept(stream).await {
135 Ok(s) => s,
136 Err(e) => {
137 error!("Tls connection error = {:?}", e);
138 continue;
139 }
140 };
141
142 let framed = Framed::new(stream, codec::MqttCodec::new(config.max_payload_size));
143 task::spawn( async {
144 match connection::eventloop(config, framed, router_tx).await {
145 Ok(id) => info!("Connection eventloop done!!. Id = {:?}", id),
146 Err(e) => error!("Connection eventloop error = {:?}", e),
147 }
148 });
149 } else {
150 let framed = Framed::new(stream, codec::MqttCodec::new(config.max_payload_size));
151 task::spawn( async {
152 match connection::eventloop(config, framed, router_tx).await {
153 Ok(id) => info!("Connection eventloop done!!. Id = {:?}", id),
154 Err(e) => error!("Connection eventloop error = {:?}", e),
155 }
156 });
157 };
158
159 time::delay_for(accept_loop_delay).await;
160 }
161}
162
163
164pub trait Network: Stream<Item = Result<Packet, rumq_core::Error>> + Sink<Packet, Error = io::Error> + Unpin + Send {}
165impl<T> Network for T where T: Stream<Item = Result<Packet, rumq_core::Error>> + Sink<Packet, Error = io::Error> + Unpin + Send {}
166
167#[tokio::main(core_threads = 1)]
168async fn router(rx: Receiver<(String, router::RouterMessage)>) {
169 let mut router = router::Router::new(rx);
170 if let Err(e) = router.start().await {
171 error!("Router stopped. Error = {:?}", e);
172 }
173}
174
175pub struct Broker {
176 config: Config,
177 router_handle: Sender<(String, router::RouterMessage)>,
178}
179
180pub fn new(config: Config) -> Broker {
181 let (router_tx, router_rx) = channel(100);
182
183 thread::spawn(move || {
184 router(router_rx)
185 });
186
187 Broker {
188 config,
189 router_handle: router_tx
190 }
191}
192
193impl Broker {
194 pub fn new_router_handle(&self) -> Sender<(String, router::RouterMessage)> {
195 self.router_handle.clone()
196 }
197
198 pub async fn start(&mut self) -> Vec<Result<(), task::JoinError>> {
199 let mut servers = Vec::new();
200 let server_configs = self.config.servers.split_off(0);
201
202 for server in server_configs.into_iter() {
203 let config = Arc::new(server);
204 let fut = accept_loop(config, self.router_handle.clone());
205 let o = task::spawn(async {
206 error!("Accept loop returned = {:?}", fut.await);
207 });
208
209 servers.push(o);
210 }
211
212 join_all(servers).await
213 }
214}
215
#[cfg(test)]
mod test {
    // NOTE: every test below has an empty body. They record intended
    // coverage by name only and pass without exercising any code.

    /// Placeholder: not implemented yet.
    #[test]
    fn accept_loop_rate_limits_incoming_connections() {
    }

    /// Placeholder: not implemented yet.
    #[test]
    fn accept_loop_should_not_allow_more_than_maximum_connections() {
    }

    /// Placeholder: not implemented yet.
    #[test]
    fn accept_loop_should_accept_new_connection_when_a_client_disconnects_after_max_connections() {
    }

    /// Placeholder: not implemented yet.
    #[test]
    fn client_loop_should_error_if_connect_packet_is_not_received_in_time() {
    }
}