selium_server/
server.rs

1use crate::args::{LogArgs, UserArgs};
2use crate::quic::{load_root_store, read_certs, server_config, ConfigOptions};
3use crate::topic::config::TopicConfig;
4use crate::topic::{pubsub, reqrep, Sender, Socket};
5use anyhow::{anyhow, bail, Context, Result};
6use futures::{future::join_all, stream::FuturesUnordered, SinkExt, StreamExt};
7use log::{error, info};
8use quinn::{Connecting, Connection, Endpoint, IdleTimeout, VarInt};
9use selium_log::config::{FlushPolicy, LogConfig};
10use selium_log::MessageLog;
11use selium_protocol::error_codes::INVALID_TOPIC_NAME;
12use selium_protocol::{error_codes, BiStream, ErrorPayload, Frame, TopicName};
13use std::net::SocketAddr;
14use std::time::Duration;
15use std::{collections::HashMap, sync::Arc};
16use tokio::{sync::Mutex, task::JoinHandle};
17
18pub(crate) type SharedTopics = Arc<Mutex<HashMap<TopicName, Sender>>>;
19type SharedTopicHandles = Arc<Mutex<FuturesUnordered<JoinHandle<()>>>>;
20
21pub struct Server {
22    topics: SharedTopics,
23    topic_handles: SharedTopicHandles,
24    log_args: Arc<LogArgs>,
25    endpoint: Endpoint,
26}
27
28impl Server {
29    pub async fn listen(&self) -> Result<()> {
30        loop {
31            tokio::select! {
32                Some(conn) = self.endpoint.accept() => {
33                    self.connect(conn).await?;
34                },
35                Ok(()) = tokio::signal::ctrl_c() => {
36                    self.shutdown().await?;
37                    break;
38                }
39            }
40        }
41
42        Ok(())
43    }
44
45    pub fn addr(&self) -> Result<SocketAddr> {
46        let addr = self.endpoint.local_addr()?;
47        Ok(addr)
48    }
49
50    async fn connect(&self, conn: Connecting) -> Result<()> {
51        info!("connection incoming");
52        let topics_clone = self.topics.clone();
53        let topic_handles = self.topic_handles.clone();
54        let log_args = self.log_args.clone();
55
56        tokio::spawn(async move {
57            if let Err(e) = handle_connection(topics_clone, topic_handles, conn, log_args).await {
58                error!("connection failed: {:?}", e);
59            }
60        });
61
62        Ok(())
63    }
64
65    async fn shutdown(&self) -> Result<()> {
66        info!("Shutdown signal received: preparing to gracefully shutdown.");
67        self.endpoint.reject_new_connections();
68
69        let mut topics = self.topics.lock().await;
70        let mut topic_handles = self.topic_handles.lock().await;
71
72        topics.values_mut().for_each(|t| t.close_channel());
73        join_all(topic_handles.iter_mut()).await;
74
75        self.endpoint.close(
76            VarInt::from_u32(error_codes::SHUTDOWN),
77            b"Scheduled shutdown.",
78        );
79        self.endpoint.wait_idle().await;
80
81        Ok(())
82    }
83}
84
85impl TryFrom<UserArgs> for Server {
86    type Error = anyhow::Error;
87
88    fn try_from(args: UserArgs) -> Result<Self, Self::Error> {
89        let root_store = load_root_store(args.cert.ca)?;
90        let (certs, key) = read_certs(args.cert.cert, args.cert.key)?;
91        let log_args = Arc::new(args.log);
92
93        let opts = ConfigOptions {
94            keylog: args.keylog,
95            stateless_retry: args.stateless_retry,
96            max_idle_timeout: IdleTimeout::from(VarInt::from_u32(args.max_idle_timeout)),
97        };
98
99        let config = server_config(root_store, certs, key, opts)?;
100        let endpoint = Endpoint::server(config, args.bind_addr)?;
101
102        // Create hash to store message ordering data
103        let topics = Arc::new(Mutex::new(HashMap::new()));
104        let topic_handles = Arc::new(Mutex::new(FuturesUnordered::new()));
105
106        Ok(Self {
107            topics,
108            topic_handles,
109            log_args,
110            endpoint,
111        })
112    }
113}
114
115async fn handle_connection(
116    topics: SharedTopics,
117    topic_handles: SharedTopicHandles,
118    conn: quinn::Connecting,
119    log_args: Arc<LogArgs>,
120) -> Result<()> {
121    let connection = conn.await?;
122    info!(
123        "Connection {} - {}",
124        connection.remote_address(),
125        connection
126            .handshake_data()
127            .unwrap()
128            .downcast::<quinn::crypto::rustls::HandshakeData>()
129            .unwrap()
130            .protocol
131            .map_or_else(
132                || "<none>".into(),
133                |x| String::from_utf8_lossy(&x).into_owned()
134            )
135    );
136
137    loop {
138        let connection = connection.clone();
139        let stream = connection.accept_bi().await;
140        let stream = match stream {
141            Err(quinn::ConnectionError::ApplicationClosed { .. }) => {
142                info!("Connection closed ({})", connection.remote_address());
143                return Ok(());
144            }
145            Err(e) => {
146                bail!(e)
147            }
148            Ok(stream) => BiStream::from(stream),
149        };
150
151        let topics_clone = topics.clone();
152        let topic_handles_clone = topic_handles.clone();
153        let log_args = log_args.clone();
154
155        tokio::spawn(async move {
156            if let Err(e) = handle_stream(
157                topics_clone,
158                topic_handles_clone,
159                stream,
160                connection,
161                log_args,
162            )
163            .await
164            {
165                error!("Request failed: {:?}", e);
166            }
167        });
168    }
169}
170
171async fn handle_stream(
172    topics: SharedTopics,
173    topic_handles: SharedTopicHandles,
174    mut stream: BiStream,
175    _connection: Connection,
176    log_args: Arc<LogArgs>,
177) -> Result<()> {
178    // Receive header
179    if let Some(result) = stream.next().await {
180        let frame = result?;
181        let topic = frame.get_topic().ok_or(anyhow!("Expected header frame"))?;
182
183        #[cfg(feature = "__cloud")]
184        {
185            use crate::cloud::do_cloud_auth;
186            use log::debug;
187            use selium_protocol::error_codes::CLOUD_AUTH_FAILED;
188
189            match do_cloud_auth(&_connection, topic, &topics).await {
190                Ok(_) => stream.send(Frame::Ok).await?,
191                Err(e) => {
192                    debug!("Cloud authentication error: {e:?}");
193
194                    let payload = ErrorPayload {
195                        code: CLOUD_AUTH_FAILED,
196                        message: e.to_string().into(),
197                    };
198
199                    stream.send(Frame::Error(payload)).await?;
200
201                    return Ok(());
202                }
203            }
204        }
205        #[cfg(not(feature = "__cloud"))]
206        {
207            // Note this can only occur if someone circumvents the client lib
208            if !topic.is_valid() {
209                let payload = ErrorPayload {
210                    code: INVALID_TOPIC_NAME,
211                    message: "Invalid topic name".into(),
212                };
213                stream.send(Frame::Error(payload)).await?;
214                return Ok(());
215            }
216            stream.send(Frame::Ok).await?;
217        }
218
219        let mut ts = topics.lock().await;
220
221        // Spawn new topic if it doesn't exist yet
222        if !ts.contains_key(topic) {
223            match frame {
224                Frame::RegisterPublisher(_) | Frame::RegisterSubscriber(_) => {
225                    let retention_period = frame.retention_policy().unwrap();
226                    let topic_path = topic.to_string();
227                    let segments_path = log_args
228                        .log_segments_directory
229                        .join(topic_path.trim_matches('/'));
230
231                    let mut flush_policy = FlushPolicy::default()
232                        .interval(Duration::from_millis(log_args.flush_policy_interval));
233
234                    if let Some(num_writes) = log_args.flush_policy_num_writes {
235                        flush_policy = flush_policy.number_of_writes(num_writes);
236                    }
237
238                    let topic_config = Arc::new(TopicConfig::new(Duration::from_millis(
239                        log_args.subscriber_polling_interval,
240                    )));
241
242                    let log_config = Arc::new(
243                        LogConfig::from_path(segments_path)
244                            .max_index_entries(log_args.log_maximum_entries)
245                            .retention_period(Duration::from_millis(retention_period))
246                            .cleaner_interval(Duration::from_millis(log_args.log_cleaner_interval))
247                            .flush_policy(flush_policy),
248                    );
249
250                    let log = MessageLog::open(log_config).await?;
251                    let (mut fut, tx) = pubsub::Topic::pair(log, topic_config);
252
253                    let handle = tokio::spawn(async move {
254                        fut.run().await.unwrap();
255                    });
256
257                    topic_handles.lock().await.push(handle);
258                    ts.insert(topic.clone(), Sender::Pubsub(tx));
259                }
260                Frame::RegisterReplier(_) | Frame::RegisterRequestor(_) => {
261                    let (fut, tx) = reqrep::Topic::pair();
262                    let handle = tokio::spawn(fut);
263
264                    topic_handles.lock().await.push(handle);
265                    ts.insert(topic.clone(), Sender::ReqRep(tx));
266                }
267                _ => unreachable!(), // because of `topic` instantiation
268            };
269        }
270
271        let tx = ts.get_mut(topic).unwrap();
272
273        match frame {
274            Frame::RegisterPublisher(_) => {
275                let (_, read) = stream.split();
276                tx.send(Socket::Pubsub(pubsub::Socket::Stream(Box::pin(read))))
277                    .await
278                    .context("Failed to add Publisher stream")?;
279            }
280            Frame::RegisterSubscriber(payload) => {
281                let (write, _) = stream.split();
282                tx.send(Socket::Pubsub(pubsub::Socket::Sink(
283                    Box::pin(write),
284                    payload.offset,
285                )))
286                .await
287                .context("Failed to add Subscriber sink")?;
288            }
289            Frame::RegisterReplier(_) => {
290                let (si, st) = stream.split();
291                tx.send(Socket::Reqrep(reqrep::Socket::Server((
292                    Box::pin(si),
293                    Box::pin(st),
294                ))))
295                .await
296                .context("Failed to add Replier")?;
297            }
298            Frame::RegisterRequestor(_) => {
299                let (si, st) = stream.split();
300                tx.send(Socket::Reqrep(reqrep::Socket::Client((
301                    Box::pin(si),
302                    Box::pin(st),
303                ))))
304                .await
305                .context("Failed to add Requestor")?;
306            }
307            _ => unreachable!(), // because of `topic` instantiation
308        }
309    } else {
310        info!("Stream closed");
311    }
312
313    Ok(())
314}