1use crate::args::{LogArgs, UserArgs};
2use crate::quic::{load_root_store, read_certs, server_config, ConfigOptions};
3use crate::topic::config::TopicConfig;
4use crate::topic::{pubsub, reqrep, Sender, Socket};
5use anyhow::{anyhow, bail, Context, Result};
6use futures::{future::join_all, stream::FuturesUnordered, SinkExt, StreamExt};
7use log::{error, info};
8use quinn::{Connecting, Connection, Endpoint, IdleTimeout, VarInt};
9use selium_log::config::{FlushPolicy, LogConfig};
10use selium_log::MessageLog;
11use selium_protocol::error_codes::INVALID_TOPIC_NAME;
12use selium_protocol::{error_codes, BiStream, ErrorPayload, Frame, TopicName};
13use std::net::SocketAddr;
14use std::time::Duration;
15use std::{collections::HashMap, sync::Arc};
16use tokio::{sync::Mutex, task::JoinHandle};
17
/// Registry of live topics keyed by topic name, shared (behind a tokio `Mutex`) by every
/// connection and stream task. Each value is the `Sender` used to hand new sockets
/// (publishers, subscribers, repliers, requestors) to that topic's task.
pub(crate) type SharedTopics = Arc<Mutex<HashMap<TopicName, Sender>>>;
/// Join handles of the spawned per-topic tasks, collected so that server shutdown can
/// await them after closing the topics' channels.
type SharedTopicHandles = Arc<Mutex<FuturesUnordered<JoinHandle<()>>>>;
20
/// QUIC server that accepts client connections and routes their streams to topics.
///
/// Built from command-line arguments via `TryFrom<UserArgs>`; run with `Server::listen`.
pub struct Server {
    /// All topics created so far, shared with every connection handler.
    topics: SharedTopics,
    /// Handles of the spawned topic tasks; awaited during graceful shutdown.
    topic_handles: SharedTopicHandles,
    /// Message-log settings used when a new pubsub topic is created.
    log_args: Arc<LogArgs>,
    /// The bound QUIC endpoint that incoming connections are accepted from.
    endpoint: Endpoint,
}
27
28impl Server {
29 pub async fn listen(&self) -> Result<()> {
30 loop {
31 tokio::select! {
32 Some(conn) = self.endpoint.accept() => {
33 self.connect(conn).await?;
34 },
35 Ok(()) = tokio::signal::ctrl_c() => {
36 self.shutdown().await?;
37 break;
38 }
39 }
40 }
41
42 Ok(())
43 }
44
45 pub fn addr(&self) -> Result<SocketAddr> {
46 let addr = self.endpoint.local_addr()?;
47 Ok(addr)
48 }
49
50 async fn connect(&self, conn: Connecting) -> Result<()> {
51 info!("connection incoming");
52 let topics_clone = self.topics.clone();
53 let topic_handles = self.topic_handles.clone();
54 let log_args = self.log_args.clone();
55
56 tokio::spawn(async move {
57 if let Err(e) = handle_connection(topics_clone, topic_handles, conn, log_args).await {
58 error!("connection failed: {:?}", e);
59 }
60 });
61
62 Ok(())
63 }
64
65 async fn shutdown(&self) -> Result<()> {
66 info!("Shutdown signal received: preparing to gracefully shutdown.");
67 self.endpoint.reject_new_connections();
68
69 let mut topics = self.topics.lock().await;
70 let mut topic_handles = self.topic_handles.lock().await;
71
72 topics.values_mut().for_each(|t| t.close_channel());
73 join_all(topic_handles.iter_mut()).await;
74
75 self.endpoint.close(
76 VarInt::from_u32(error_codes::SHUTDOWN),
77 b"Scheduled shutdown.",
78 );
79 self.endpoint.wait_idle().await;
80
81 Ok(())
82 }
83}
84
85impl TryFrom<UserArgs> for Server {
86 type Error = anyhow::Error;
87
88 fn try_from(args: UserArgs) -> Result<Self, Self::Error> {
89 let root_store = load_root_store(args.cert.ca)?;
90 let (certs, key) = read_certs(args.cert.cert, args.cert.key)?;
91 let log_args = Arc::new(args.log);
92
93 let opts = ConfigOptions {
94 keylog: args.keylog,
95 stateless_retry: args.stateless_retry,
96 max_idle_timeout: IdleTimeout::from(VarInt::from_u32(args.max_idle_timeout)),
97 };
98
99 let config = server_config(root_store, certs, key, opts)?;
100 let endpoint = Endpoint::server(config, args.bind_addr)?;
101
102 let topics = Arc::new(Mutex::new(HashMap::new()));
104 let topic_handles = Arc::new(Mutex::new(FuturesUnordered::new()));
105
106 Ok(Self {
107 topics,
108 topic_handles,
109 log_args,
110 endpoint,
111 })
112 }
113}
114
115async fn handle_connection(
116 topics: SharedTopics,
117 topic_handles: SharedTopicHandles,
118 conn: quinn::Connecting,
119 log_args: Arc<LogArgs>,
120) -> Result<()> {
121 let connection = conn.await?;
122 info!(
123 "Connection {} - {}",
124 connection.remote_address(),
125 connection
126 .handshake_data()
127 .unwrap()
128 .downcast::<quinn::crypto::rustls::HandshakeData>()
129 .unwrap()
130 .protocol
131 .map_or_else(
132 || "<none>".into(),
133 |x| String::from_utf8_lossy(&x).into_owned()
134 )
135 );
136
137 loop {
138 let connection = connection.clone();
139 let stream = connection.accept_bi().await;
140 let stream = match stream {
141 Err(quinn::ConnectionError::ApplicationClosed { .. }) => {
142 info!("Connection closed ({})", connection.remote_address());
143 return Ok(());
144 }
145 Err(e) => {
146 bail!(e)
147 }
148 Ok(stream) => BiStream::from(stream),
149 };
150
151 let topics_clone = topics.clone();
152 let topic_handles_clone = topic_handles.clone();
153 let log_args = log_args.clone();
154
155 tokio::spawn(async move {
156 if let Err(e) = handle_stream(
157 topics_clone,
158 topic_handles_clone,
159 stream,
160 connection,
161 log_args,
162 )
163 .await
164 {
165 error!("Request failed: {:?}", e);
166 }
167 });
168 }
169}
170
171async fn handle_stream(
172 topics: SharedTopics,
173 topic_handles: SharedTopicHandles,
174 mut stream: BiStream,
175 _connection: Connection,
176 log_args: Arc<LogArgs>,
177) -> Result<()> {
178 if let Some(result) = stream.next().await {
180 let frame = result?;
181 let topic = frame.get_topic().ok_or(anyhow!("Expected header frame"))?;
182
183 #[cfg(feature = "__cloud")]
184 {
185 use crate::cloud::do_cloud_auth;
186 use log::debug;
187 use selium_protocol::error_codes::CLOUD_AUTH_FAILED;
188
189 match do_cloud_auth(&_connection, topic, &topics).await {
190 Ok(_) => stream.send(Frame::Ok).await?,
191 Err(e) => {
192 debug!("Cloud authentication error: {e:?}");
193
194 let payload = ErrorPayload {
195 code: CLOUD_AUTH_FAILED,
196 message: e.to_string().into(),
197 };
198
199 stream.send(Frame::Error(payload)).await?;
200
201 return Ok(());
202 }
203 }
204 }
205 #[cfg(not(feature = "__cloud"))]
206 {
207 if !topic.is_valid() {
209 let payload = ErrorPayload {
210 code: INVALID_TOPIC_NAME,
211 message: "Invalid topic name".into(),
212 };
213 stream.send(Frame::Error(payload)).await?;
214 return Ok(());
215 }
216 stream.send(Frame::Ok).await?;
217 }
218
219 let mut ts = topics.lock().await;
220
221 if !ts.contains_key(topic) {
223 match frame {
224 Frame::RegisterPublisher(_) | Frame::RegisterSubscriber(_) => {
225 let retention_period = frame.retention_policy().unwrap();
226 let topic_path = topic.to_string();
227 let segments_path = log_args
228 .log_segments_directory
229 .join(topic_path.trim_matches('/'));
230
231 let mut flush_policy = FlushPolicy::default()
232 .interval(Duration::from_millis(log_args.flush_policy_interval));
233
234 if let Some(num_writes) = log_args.flush_policy_num_writes {
235 flush_policy = flush_policy.number_of_writes(num_writes);
236 }
237
238 let topic_config = Arc::new(TopicConfig::new(Duration::from_millis(
239 log_args.subscriber_polling_interval,
240 )));
241
242 let log_config = Arc::new(
243 LogConfig::from_path(segments_path)
244 .max_index_entries(log_args.log_maximum_entries)
245 .retention_period(Duration::from_millis(retention_period))
246 .cleaner_interval(Duration::from_millis(log_args.log_cleaner_interval))
247 .flush_policy(flush_policy),
248 );
249
250 let log = MessageLog::open(log_config).await?;
251 let (mut fut, tx) = pubsub::Topic::pair(log, topic_config);
252
253 let handle = tokio::spawn(async move {
254 fut.run().await.unwrap();
255 });
256
257 topic_handles.lock().await.push(handle);
258 ts.insert(topic.clone(), Sender::Pubsub(tx));
259 }
260 Frame::RegisterReplier(_) | Frame::RegisterRequestor(_) => {
261 let (fut, tx) = reqrep::Topic::pair();
262 let handle = tokio::spawn(fut);
263
264 topic_handles.lock().await.push(handle);
265 ts.insert(topic.clone(), Sender::ReqRep(tx));
266 }
267 _ => unreachable!(), };
269 }
270
271 let tx = ts.get_mut(topic).unwrap();
272
273 match frame {
274 Frame::RegisterPublisher(_) => {
275 let (_, read) = stream.split();
276 tx.send(Socket::Pubsub(pubsub::Socket::Stream(Box::pin(read))))
277 .await
278 .context("Failed to add Publisher stream")?;
279 }
280 Frame::RegisterSubscriber(payload) => {
281 let (write, _) = stream.split();
282 tx.send(Socket::Pubsub(pubsub::Socket::Sink(
283 Box::pin(write),
284 payload.offset,
285 )))
286 .await
287 .context("Failed to add Subscriber sink")?;
288 }
289 Frame::RegisterReplier(_) => {
290 let (si, st) = stream.split();
291 tx.send(Socket::Reqrep(reqrep::Socket::Server((
292 Box::pin(si),
293 Box::pin(st),
294 ))))
295 .await
296 .context("Failed to add Replier")?;
297 }
298 Frame::RegisterRequestor(_) => {
299 let (si, st) = stream.split();
300 tx.send(Socket::Reqrep(reqrep::Socket::Client((
301 Box::pin(si),
302 Box::pin(st),
303 ))))
304 .await
305 .context("Failed to add Requestor")?;
306 }
307 _ => unreachable!(), }
309 } else {
310 info!("Stream closed");
311 }
312
313 Ok(())
314}