Skip to main content

iec104/
server.rs

1use std::{collections::HashMap, net::SocketAddr, sync::Arc};
2
3use crate::{
4	Connection, asdu::Asdu, config::ServerConfig, error::Error,
5	receive_handler::ReceiveHandlerCallback, server::connection_handler::ConnectionHandler,
6};
7
8mod connection_handler;
9pub mod error;
10
11use snafu::{ResultExt as _, whatever};
12use tokio::{
13	net::TcpListener,
14	sync::{mpsc, oneshot},
15};
16use tokio_native_tls::{TlsAcceptor, native_tls::Identity};
17
/// Opaque handle the server assigns to every accepted connection.
///
/// The server keys its connection table by this value, which is how a
/// connection is found again (and dropped from the table) once it terminates.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ConnectionId(u64);

impl ConnectionId {
	/// Returns the first identifier of the sequence, which is zero.
	const fn new() -> Self {
		Self(0)
	}

	/// Returns the identifier that follows `self`.
	///
	/// The counter wraps back to zero after `u64::MAX` instead of overflowing.
	const fn next(self) -> Self {
		let Self(current) = self;
		Self(current.wrapping_add(1))
	}
}
31
32#[async_trait::async_trait]
33pub trait ServerCallback {
34	async fn on_new_objects(&self, asdu: Asdu, id: ConnectionId, address: SocketAddr);
35	async fn on_new_connection(&self, address: SocketAddr) {
36		tracing::debug!("New connection from {address}");
37	}
38	async fn on_connection_started(&self, id: ConnectionId, address: SocketAddr) {
39		tracing::debug!("Connection {id:?} from {address} started");
40	}
41	async fn on_connection_stopped(&self, id: ConnectionId, address: SocketAddr) {
42		tracing::debug!("Connection {id:?} from {address} stopped");
43	}
44	async fn on_error(&self, error: &Error) {
45		tracing::debug!("Error: {error}");
46	}
47}
48
49struct InnerServerCallback<C: ServerCallback + Send + Sync> {
50	id: ConnectionId,
51	address: SocketAddr,
52	callback: Arc<C>,
53}
54
55#[async_trait::async_trait]
56impl<C: ServerCallback + Send + Sync> ReceiveHandlerCallback for InnerServerCallback<C> {
57	async fn on_new_objects(&self, asdu: Asdu) {
58		self.callback.on_new_objects(asdu, self.id, self.address).await;
59	}
60}
61
62impl<C: ServerCallback + Send + Sync> InnerServerCallback<C> {
63	async fn on_connection_started(&self, id: ConnectionId, address: SocketAddr) {
64		self.callback.on_connection_started(id, address).await;
65	}
66	async fn on_error(&self, error: &Error) {
67		self.callback.on_error(error).await;
68	}
69}
70
71/// Commands for the server to send to the inner server
72#[derive(Debug)]
73pub enum ServerCommand {
74	SingleConnection {
75		id: ConnectionId,
76		asdu: Asdu,
77		tx: oneshot::Sender<Result<(), error::ServerError>>,
78	},
79	Broadcast {
80		asdu: Asdu,
81		tx: oneshot::Sender<Result<(), error::ServerError>>,
82	},
83}
84
85#[derive(Debug, Clone)]
86pub struct Server {
87	tx: mpsc::Sender<ServerCommand>,
88}
89
90struct InnerServer<C: ServerCallback + Send + Sync + 'static> {
91	callback: Arc<C>,
92	config: ServerConfig,
93	connections: HashMap<ConnectionId, ConnectionHandler>,
94	rx: mpsc::Receiver<ServerCommand>,
95	acceptor: Option<TlsAcceptor>,
96	listener: TcpListener,
97}
98
99impl Server {
100	pub async fn start<C: ServerCallback + Send + Sync + 'static>(
101		config: ServerConfig,
102		callback: C,
103	) -> Result<Self, Error> {
104		let (tx, rx) = mpsc::channel(1024);
105		let inner_server = InnerServer::connect(config, Arc::new(callback), rx).await?;
106		tokio::spawn(async move {
107			let _ = inner_server
108				.start()
109				.await
110				.inspect_err(|e| tracing::error!("Error in running server: {e}"));
111		});
112		Ok(Self { tx })
113	}
114
115	/// Like [`Self::start`], but also returns the socket address the TCP
116	/// listener bound to (for example after setting
117	/// [`crate::config::ServerConfig::port`] to `0` for an ephemeral port).
118	pub async fn start_with_listen_addr<C: ServerCallback + Send + Sync + 'static>(
119		config: ServerConfig,
120		callback: C,
121	) -> Result<(Self, SocketAddr), Error> {
122		let (tx, rx) = mpsc::channel(1024);
123		let inner_server = InnerServer::connect(config, Arc::new(callback), rx).await?;
124		let listen_addr = inner_server
125			.listen_local_addr()
126			.with_whatever_context(|e| format!("Unable to read listener local address: {e}"))?;
127		tokio::spawn(async move {
128			let _ = inner_server
129				.start()
130				.await
131				.inspect_err(|e| tracing::error!("Error in running server: {e}"));
132		});
133		Ok((Self { tx }, listen_addr))
134	}
135
136	pub async fn send_asdu(&self, id: ConnectionId, asdu: Asdu) -> Result<(), error::ServerError> {
137		let (tx, rx) = oneshot::channel();
138		self.tx
139			.send(ServerCommand::SingleConnection { id, asdu, tx })
140			.await
141			.context(error::SendCommand)?;
142		rx.await.context(error::ReceiveResponse)?
143	}
144	pub async fn broadcast_asdu(&self, asdu: Asdu) -> Result<(), error::ServerError> {
145		let (tx, rx) = oneshot::channel();
146		self.tx.send(ServerCommand::Broadcast { asdu, tx }).await.context(error::SendCommand)?;
147		rx.await.context(error::ReceiveResponse)?
148	}
149}
150
151impl<C: ServerCallback + Send + Sync + 'static> InnerServer<C> {
152	async fn connect(
153		config: ServerConfig,
154		callback: Arc<C>,
155		rx: mpsc::Receiver<ServerCommand>,
156	) -> Result<Self, Error> {
157		let bind_addr = format!("{}:{}", config.address, config.port);
158		let listener = TcpListener::bind(&bind_addr)
159			.await
160			.with_whatever_context(|_| format!("Unable to bind to address '{bind_addr}'"))?;
161
162		let acceptor = config
163			.tls
164			.as_ref()
165			.map(|tls| {
166				let identity = Identity::from_pkcs8(
167					std::fs::read(tls.server_certificate.clone())
168						.whatever_context::<_, Error>("Failed to read server certificate")?
169						.as_slice(),
170					std::fs::read(tls.server_key.clone())
171						.whatever_context::<_, Error>("Failed to read server key")?
172						.as_slice(),
173				)
174				.whatever_context("Error creating server TLS identity")?;
175				tokio_native_tls::native_tls::TlsAcceptor::new(identity)
176					.map(TlsAcceptor::from)
177					.whatever_context("Error crating TLS acceptor")
178			})
179			.transpose()?;
180		Ok(Self { callback, config, connections: HashMap::new(), rx, acceptor, listener })
181	}
182
183	fn listen_local_addr(&self) -> std::io::Result<SocketAddr> {
184		self.listener.local_addr()
185	}
186
187	async fn start(mut self) -> Result<(), Error> {
188		let (closed_tx, mut closed_rx) = mpsc::unbounded_channel::<ConnectionId>();
189		let mut next_id = ConnectionId::new();
190
191		loop {
192			tokio::select! {
193				accept_result = self.listener.accept() => {
194					let (socket, address) =
195						accept_result.whatever_context("Error accepting a new connection")?;
196					let connection = if let Some(ref acceptor) = self.acceptor {
197						Connection::Tls(
198							acceptor
199								.accept(socket)
200								.await
201								.whatever_context("Error doing the TLS handshake")?,
202						)
203					} else {
204						Connection::Tcp(socket)
205					};
206					// TODO: Do we want to accept or decline the connection based on the callback response?
207					self.callback.on_new_connection(address).await;
208
209					// TODO: Do want to have a max number of connections?
210
211					// Gets the next available connection id
212					while self.connections.contains_key(&next_id) {
213						next_id = next_id.next();
214					}
215
216					let id = next_id;
217					next_id = next_id.next();
218					let callback = InnerServerCallback { id, address, callback: self.callback.clone() };
219					let connection_handler = ConnectionHandler::start(id, address, connection, self.config.clone(), callback, closed_tx.clone());
220					self.connections.insert(id, connection_handler);
221				}
222				closed_id = closed_rx.recv() => {
223					if let Some(id) = closed_id {
224						let connection_handler = self.connections.remove(&id);
225						match connection_handler {
226							Some(connection_handler) => {
227								self.callback.on_connection_stopped(id, connection_handler.address).await;
228							}
229							None => {
230								tracing::error!("Connection {id:?} not found");
231							}
232						}
233					}
234				}
235				command = self.rx.recv() => {
236					let Some(command) = command else {
237						tracing::error!("Error receiving command. Aborting...");
238						whatever!("");
239					};
240					match command {
241						ServerCommand::SingleConnection { id, asdu, tx } => {
242							let connection_handler = self.connections.get(&id);
243							let res = match connection_handler {
244								Some(connection_handler) => {
245									connection_handler.send_asdu(asdu).await
246								}
247								None => {
248									error::ConnectionNotFound { id }.fail()
249								}
250							};
251							let _ = tx.send(res).inspect_err(|e| tracing::error!("Error sending response to single connection command: {e:?}"));
252						},
253						ServerCommand::Broadcast { asdu, tx } => {
254							let mut res = Ok(());
255							for connection_handler in self.connections.values().filter(|c| c.is_started()) {
256								if let Err(e) = connection_handler.send_asdu(asdu.clone()).await {
257									res = Err(e);
258									break;
259								}
260							}
261							let _ = tx.send(res).inspect_err(|e| tracing::error!("Error sending response to broadcast command: {e:?}"));
262						},
263					}
264				}
265			}
266		}
267	}
268}