1use std::{collections::HashMap, net::SocketAddr, sync::Arc};
2
3use crate::{
4 Connection, asdu::Asdu, config::ServerConfig, error::Error,
5 receive_handler::ReceiveHandlerCallback, server::connection_handler::ConnectionHandler,
6};
7
8mod connection_handler;
9pub mod error;
10
11use snafu::{ResultExt as _, whatever};
12use tokio::{
13 net::TcpListener,
14 sync::{mpsc, oneshot},
15};
16use tokio_native_tls::{TlsAcceptor, native_tls::Identity};
17
/// Identifier the server assigns to each accepted connection.
///
/// The server hands ids out sequentially, starting from zero.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ConnectionId(u64);

impl ConnectionId {
    /// The first id a freshly started server hands out.
    const fn new() -> Self {
        ConnectionId(0)
    }

    /// Returns the id following `self`, wrapping back to zero after `u64::MAX`.
    const fn next(self) -> Self {
        let ConnectionId(raw) = self;
        ConnectionId(raw.wrapping_add(1))
    }
}
31
/// Application hooks invoked by a running [`Server`].
///
/// Only `on_new_objects` has to be implemented; every other hook defaults to emitting a
/// `tracing` debug event.
#[async_trait::async_trait]
pub trait ServerCallback {
    /// Called with each ASDU received on connection `id`, whose peer is at `address`.
    async fn on_new_objects(&self, asdu: Asdu, id: ConnectionId, address: SocketAddr);
    /// Called when a client connection from `address` has been accepted (after the TLS
    /// handshake when TLS is enabled), before the server assigns it a [`ConnectionId`].
    async fn on_new_connection(&self, address: SocketAddr) {
        tracing::debug!("New connection from {address}");
    }
    /// Hook for when connection `id` from `address` has started.
    /// NOTE(review): reached through `InnerServerCallback::on_connection_started`, whose
    /// caller (presumably the connection handler) is outside this file — confirm timing.
    async fn on_connection_started(&self, id: ConnectionId, address: SocketAddr) {
        tracing::debug!("Connection {id:?} from {address} started");
    }
    /// Called after connection `id` from `address` has been removed from the server's
    /// connection table because its handler reported that it closed.
    async fn on_connection_stopped(&self, id: ConnectionId, address: SocketAddr) {
        tracing::debug!("Connection {id:?} from {address} stopped");
    }
    /// Hook for reporting an error.
    /// NOTE(review): reached through `InnerServerCallback::on_error`, whose caller is outside
    /// this file — confirm which errors are reported here.
    async fn on_error(&self, error: &Error) {
        tracing::debug!("Error: {error}");
    }
}
48
49struct InnerServerCallback<C: ServerCallback + Send + Sync> {
50 id: ConnectionId,
51 address: SocketAddr,
52 callback: Arc<C>,
53}
54
55#[async_trait::async_trait]
56impl<C: ServerCallback + Send + Sync> ReceiveHandlerCallback for InnerServerCallback<C> {
57 async fn on_new_objects(&self, asdu: Asdu) {
58 self.callback.on_new_objects(asdu, self.id, self.address).await;
59 }
60}
61
62impl<C: ServerCallback + Send + Sync> InnerServerCallback<C> {
63 async fn on_connection_started(&self, id: ConnectionId, address: SocketAddr) {
64 self.callback.on_connection_started(id, address).await;
65 }
66 async fn on_error(&self, error: &Error) {
67 self.callback.on_error(error).await;
68 }
69}
70
/// Request sent from a [`Server`] handle to the background server task.
///
/// Every variant carries a oneshot sender on which the task reports the outcome.
#[derive(Debug)]
pub enum ServerCommand {
    /// Send `asdu` to the connection identified by `id`.
    SingleConnection {
        id: ConnectionId,
        asdu: Asdu,
        tx: oneshot::Sender<Result<(), error::ServerError>>,
    },
    /// Send `asdu` to every connection whose handler reports it has started.
    Broadcast {
        asdu: Asdu,
        tx: oneshot::Sender<Result<(), error::ServerError>>,
    },
}
84
/// Cloneable handle to a running server.
///
/// The server itself runs on a background task; this handle only holds the sending side of
/// that task's command channel. Once every handle has been dropped, the background task
/// stops.
#[derive(Debug, Clone)]
pub struct Server {
    /// Sender used to submit [`ServerCommand`]s to the background task.
    tx: mpsc::Sender<ServerCommand>,
}
89
90struct InnerServer<C: ServerCallback + Send + Sync + 'static> {
91 callback: Arc<C>,
92 config: ServerConfig,
93 connections: HashMap<ConnectionId, ConnectionHandler>,
94 rx: mpsc::Receiver<ServerCommand>,
95 acceptor: Option<TlsAcceptor>,
96 listener: TcpListener,
97}
98
99impl Server {
100 pub async fn start<C: ServerCallback + Send + Sync + 'static>(
101 config: ServerConfig,
102 callback: C,
103 ) -> Result<Self, Error> {
104 let (tx, rx) = mpsc::channel(1024);
105 let inner_server = InnerServer::connect(config, Arc::new(callback), rx).await?;
106 tokio::spawn(async move {
107 let _ = inner_server
108 .start()
109 .await
110 .inspect_err(|e| tracing::error!("Error in running server: {e}"));
111 });
112 Ok(Self { tx })
113 }
114
115 pub async fn start_with_listen_addr<C: ServerCallback + Send + Sync + 'static>(
119 config: ServerConfig,
120 callback: C,
121 ) -> Result<(Self, SocketAddr), Error> {
122 let (tx, rx) = mpsc::channel(1024);
123 let inner_server = InnerServer::connect(config, Arc::new(callback), rx).await?;
124 let listen_addr = inner_server
125 .listen_local_addr()
126 .with_whatever_context(|e| format!("Unable to read listener local address: {e}"))?;
127 tokio::spawn(async move {
128 let _ = inner_server
129 .start()
130 .await
131 .inspect_err(|e| tracing::error!("Error in running server: {e}"));
132 });
133 Ok((Self { tx }, listen_addr))
134 }
135
136 pub async fn send_asdu(&self, id: ConnectionId, asdu: Asdu) -> Result<(), error::ServerError> {
137 let (tx, rx) = oneshot::channel();
138 self.tx
139 .send(ServerCommand::SingleConnection { id, asdu, tx })
140 .await
141 .context(error::SendCommand)?;
142 rx.await.context(error::ReceiveResponse)?
143 }
144 pub async fn broadcast_asdu(&self, asdu: Asdu) -> Result<(), error::ServerError> {
145 let (tx, rx) = oneshot::channel();
146 self.tx.send(ServerCommand::Broadcast { asdu, tx }).await.context(error::SendCommand)?;
147 rx.await.context(error::ReceiveResponse)?
148 }
149}
150
151impl<C: ServerCallback + Send + Sync + 'static> InnerServer<C> {
152 async fn connect(
153 config: ServerConfig,
154 callback: Arc<C>,
155 rx: mpsc::Receiver<ServerCommand>,
156 ) -> Result<Self, Error> {
157 let bind_addr = format!("{}:{}", config.address, config.port);
158 let listener = TcpListener::bind(&bind_addr)
159 .await
160 .with_whatever_context(|_| format!("Unable to bind to address '{bind_addr}'"))?;
161
162 let acceptor = config
163 .tls
164 .as_ref()
165 .map(|tls| {
166 let identity = Identity::from_pkcs8(
167 std::fs::read(tls.server_certificate.clone())
168 .whatever_context::<_, Error>("Failed to read server certificate")?
169 .as_slice(),
170 std::fs::read(tls.server_key.clone())
171 .whatever_context::<_, Error>("Failed to read server key")?
172 .as_slice(),
173 )
174 .whatever_context("Error creating server TLS identity")?;
175 tokio_native_tls::native_tls::TlsAcceptor::new(identity)
176 .map(TlsAcceptor::from)
177 .whatever_context("Error crating TLS acceptor")
178 })
179 .transpose()?;
180 Ok(Self { callback, config, connections: HashMap::new(), rx, acceptor, listener })
181 }
182
183 fn listen_local_addr(&self) -> std::io::Result<SocketAddr> {
184 self.listener.local_addr()
185 }
186
187 async fn start(mut self) -> Result<(), Error> {
188 let (closed_tx, mut closed_rx) = mpsc::unbounded_channel::<ConnectionId>();
189 let mut next_id = ConnectionId::new();
190
191 loop {
192 tokio::select! {
193 accept_result = self.listener.accept() => {
194 let (socket, address) =
195 accept_result.whatever_context("Error accepting a new connection")?;
196 let connection = if let Some(ref acceptor) = self.acceptor {
197 Connection::Tls(
198 acceptor
199 .accept(socket)
200 .await
201 .whatever_context("Error doing the TLS handshake")?,
202 )
203 } else {
204 Connection::Tcp(socket)
205 };
206 self.callback.on_new_connection(address).await;
208
209 while self.connections.contains_key(&next_id) {
213 next_id = next_id.next();
214 }
215
216 let id = next_id;
217 next_id = next_id.next();
218 let callback = InnerServerCallback { id, address, callback: self.callback.clone() };
219 let connection_handler = ConnectionHandler::start(id, address, connection, self.config.clone(), callback, closed_tx.clone());
220 self.connections.insert(id, connection_handler);
221 }
222 closed_id = closed_rx.recv() => {
223 if let Some(id) = closed_id {
224 let connection_handler = self.connections.remove(&id);
225 match connection_handler {
226 Some(connection_handler) => {
227 self.callback.on_connection_stopped(id, connection_handler.address).await;
228 }
229 None => {
230 tracing::error!("Connection {id:?} not found");
231 }
232 }
233 }
234 }
235 command = self.rx.recv() => {
236 let Some(command) = command else {
237 tracing::error!("Error receiving command. Aborting...");
238 whatever!("");
239 };
240 match command {
241 ServerCommand::SingleConnection { id, asdu, tx } => {
242 let connection_handler = self.connections.get(&id);
243 let res = match connection_handler {
244 Some(connection_handler) => {
245 connection_handler.send_asdu(asdu).await
246 }
247 None => {
248 error::ConnectionNotFound { id }.fail()
249 }
250 };
251 let _ = tx.send(res).inspect_err(|e| tracing::error!("Error sending response to single connection command: {e:?}"));
252 },
253 ServerCommand::Broadcast { asdu, tx } => {
254 let mut res = Ok(());
255 for connection_handler in self.connections.values().filter(|c| c.is_started()) {
256 if let Err(e) = connection_handler.send_asdu(asdu.clone()).await {
257 res = Err(e);
258 break;
259 }
260 }
261 let _ = tx.send(res).inspect_err(|e| tracing::error!("Error sending response to broadcast command: {e:?}"));
262 },
263 }
264 }
265 }
266 }
267 }
268}