1use abol_core::packet::{MAX_PACKET_SIZE, Packet};
2use abol_core::{Cidr, Request, Response};
3use abol_rt::net::AsyncUdpSocket;
4use abol_rt::{Executor, Runtime, YieldNow};
5use anyhow::{Context, anyhow};
6use bytes::Bytes;
7use dashmap::DashSet;
8use moka::future::Cache;
9use std::error::Error;
10use std::future::Future;
11use std::net::{IpAddr, SocketAddr};
12use std::pin::Pin;
13use std::sync::Arc;
14use std::sync::atomic::{AtomicUsize, Ordering};
15use std::task::Poll;
16
17pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
18pub type BoxError = Box<dyn std::error::Error + Send + Sync>;
19pub type SecretResult = Result<Vec<(Cidr, Vec<u8>)>, BoxError>;
20
21type HandlerResult<T> = Result<T, Box<dyn Error + Send + Sync>>;
22
23type SecretBytes = Arc<[u8]>;
24type SecretEntry = (Cidr, SecretBytes);
25type SecretList = Vec<SecretEntry>;
26type SharedSecrets = Arc<SecretList>;
27type SecretCache = Cache<(), SharedSecrets>;
28
/// Identifies a request that is currently being processed.
///
/// It combines the sender's socket address with the packet identifier byte.
/// A second datagram with the same key is dropped while the first one is still
/// in flight.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RequestKey {
    /// Address the datagram was received from.
    addr: SocketAddr,
    /// Identifier taken from the parsed packet.
    identifier: u8,
}
34
35pub trait SecretSource: Send + Sync + 'static {
43 fn get_all_secrets(
53 &self,
54 ) -> impl Future<Output = Result<Vec<(Cidr, Vec<u8>)>, BoxError>> + Send;
55}
56
57pub trait SecretSourceExt: Send + Sync + 'static {
58 fn get_all_secrets_boxed(&self) -> BoxFuture<'_, SecretResult>;
59}
60
61impl<T: SecretSource> SecretSourceExt for T {
62 fn get_all_secrets_boxed(&self) -> BoxFuture<'_, SecretResult> {
63 Box::pin(self.get_all_secrets())
64 }
65}
66
/// Resolves the shared secret to use for a particular client IP address.
pub trait SecretProvider: Send + Sync + 'static {
    /// Returns the secret that applies to `client_ip`.
    ///
    /// Returns `None` when no secret is configured for that address.
    fn get_secret(
        &self,
        client_ip: IpAddr,
    ) -> impl Future<Output = Option<Arc<[u8]>>> + Send;
}
70
71pub trait SecretProviderExt: Send + Sync + 'static {
72 fn get_secret_boxed(&self, client_ip: IpAddr) -> BoxFuture<'_, Option<Arc<[u8]>>>;
73}
74
75impl<T: SecretProvider> SecretProviderExt for T {
76 fn get_secret_boxed(&self, client_ip: IpAddr) -> BoxFuture<'_, Option<Arc<[u8]>>> {
77 Box::pin(self.get_secret(client_ip))
78 }
79}
80
81pub struct SecretManager {
87 cache: SecretCache,
89 source: Arc<dyn SecretSourceExt>,
91}
92
93impl SecretManager {
94 pub fn new(source: Arc<dyn SecretSourceExt>, cache_ttl_secs: u64) -> Self {
106 let cache = Cache::builder()
107 .max_capacity(1000) .time_to_live(std::time::Duration::from_secs(cache_ttl_secs))
109 .build();
110 Self { cache, source }
111 }
112 async fn get_secret_internal(&self, client_ip: IpAddr) -> Option<Arc<[u8]>> {
113 let table = self
114 .cache
115 .get_with((), async { self.reload().await.unwrap_or_default() })
116 .await;
117
118 table
119 .iter()
120 .find(|(cidr, _)| cidr.contains(&client_ip))
121 .map(|(_, secret)| secret.clone())
122 }
123
124 async fn reload(&self) -> Result<Arc<Vec<(Cidr, Arc<[u8]>)>>, BoxError> {
125 let entries = self.source.get_all_secrets_boxed().await?;
126 let arc_entries = entries
127 .into_iter()
128 .map(|(cidr, secret)| (cidr, Arc::from(secret)))
129 .collect();
130 Ok(Arc::new(arc_entries))
131 }
132}
133
134impl SecretProvider for SecretManager {
135 async fn get_secret(&self, client_ip: IpAddr) -> Option<Arc<[u8]>> {
136 self.get_secret_internal(client_ip).await
137 }
138}
139
140pub trait Handler: Send + Sync + 'static {
149 fn handle(
157 &self,
158 request: Request,
159 ) -> impl std::future::Future<Output = HandlerResult<Response>> + Send;
160}
161
162pub struct HandlerFn<F>(pub F);
165
166impl<F, Fut> Handler for HandlerFn<F>
167where
168 F: Fn(Request) -> Fut + Send + Sync + 'static,
169 Fut: std::future::Future<Output = HandlerResult<Response>> + Send + 'static,
170{
171 async fn handle(&self, request: Request) -> HandlerResult<Response> {
173 (self.0)(request).await
174 }
175}
176
177pub(crate) struct ServerContext<S, H>
178where
179 S: SecretProvider,
180 H: Handler,
181{
182 pub secret_provider: Arc<S>,
183 pub handler: Arc<H>,
184 pub undergoing_requests: Arc<DashSet<RequestKey>>,
185 pub active_tasks: Arc<AtomicUsize>,
186}
187
188pub struct Server<S, H, R>
200where
201 S: SecretProvider,
202 H: Handler,
203 R: Runtime,
204{
205 runtime: Arc<R>,
207 socket: Arc<R::Socket>,
209 secret_provider: Arc<S>,
211 handler: Arc<H>,
213 undergoing_requests: Arc<DashSet<RequestKey>>,
216 active_tasks: Arc<AtomicUsize>,
218 shutdown_signal: Option<BoxFuture<'static, ()>>,
220}
221
222impl<S, H, R> Server<S, H, R>
223where
224 S: SecretProvider + 'static,
225 H: Handler + 'static,
226 R: Runtime + 'static,
227{
228 pub fn new(runtime: R, socket: R::Socket, secret_provider: S, handler: H) -> Self {
229 Self {
230 runtime: Arc::new(runtime),
231 socket: Arc::new(socket),
232 secret_provider: Arc::new(secret_provider),
233 handler: Arc::new(handler),
234 undergoing_requests: Arc::new(DashSet::new()),
235 active_tasks: Arc::new(AtomicUsize::new(0)),
236 shutdown_signal: None,
237 }
238 }
239
240 pub fn local_addr(&self) -> anyhow::Result<SocketAddr> {
241 self.socket.local_addr().map_err(|e| anyhow!(e))
242 }
243
244 pub fn with_graceful_shutdown<F>(mut self, shutdown: F) -> Self
245 where
246 F: std::future::Future<Output = ()> + Send + 'static,
247 {
248 self.shutdown_signal = Some(Box::pin(shutdown));
249 self
250 }
251
252 pub async fn listen_and_serve(self) -> anyhow::Result<()> {
268 let local_addr_str = self.socket.local_addr()?.to_string();
269 let context = Arc::new(ServerContext {
270 secret_provider: Arc::clone(&self.secret_provider),
271 handler: Arc::clone(&self.handler),
272 undergoing_requests: Arc::clone(&self.undergoing_requests),
273 active_tasks: Arc::clone(&self.active_tasks),
274 });
275
276 let mut shutdown = self
277 .shutdown_signal
278 .unwrap_or_else(|| Box::pin(std::future::pending()));
279
280 let mut run_loop_fut = Box::pin(Self::run_loop(
281 Arc::clone(&context),
282 Arc::clone(&self.socket),
283 Arc::clone(&self.runtime),
284 local_addr_str,
285 ));
286
287 let result = std::future::poll_fn(|cx| {
288 if shutdown.as_mut().poll(cx).is_ready() {
289 return Poll::Ready(Ok(()));
290 }
291 if let Poll::Ready(res) = run_loop_fut.as_mut().poll(cx) {
292 return Poll::Ready(res);
293 }
294 Poll::Pending
295 })
296 .await;
297
298 while context.active_tasks.load(Ordering::SeqCst) > 0 {
299 YieldNow::new().await;
300 }
301 result
302 }
303
304 async fn run_loop(
305 context: Arc<ServerContext<S, H>>,
306 socket: Arc<R::Socket>,
307 runtime: Arc<R>,
308 local_addr: String,
309 ) -> anyhow::Result<()> {
310 loop {
311 let mut buf = [0u8; MAX_PACKET_SIZE];
312 let (len, peer_addr) = socket
313 .recv_from(&mut buf)
314 .await
315 .context("Failed to receive")?;
316
317 let data = Bytes::copy_from_slice(&buf[..len]);
318 let ctx = Arc::clone(&context);
319 let sock = Arc::clone(&socket);
320 let l_addr = local_addr.clone();
321 let rt = Arc::clone(&runtime);
322
323 ctx.active_tasks.fetch_add(1, Ordering::SeqCst);
324
325 rt.executor().execute(Box::pin(async move {
326 let _guard = TaskGuard::new(Arc::clone(&ctx));
327
328 let secret = match ctx.secret_provider.get_secret(peer_addr.ip()).await {
329 Some(s) => s,
330 None => return,
331 };
332
333 let packet = match Packet::parse_packet(data, Arc::clone(&secret)) {
334 Ok(p) => p,
335 Err(_) => return,
336 };
337
338 let key = RequestKey {
339 addr: peer_addr,
340 identifier: packet.identifier,
341 };
342 if !ctx.undergoing_requests.insert(key.clone()) {
343 return;
344 }
345
346 let _ = Self::process(&ctx, packet, l_addr, peer_addr, sock).await;
347 ctx.undergoing_requests.remove(&key);
348 }));
349 }
350 }
351
352 async fn process(
353 ctx: &ServerContext<S, H>,
354 packet: Packet,
355 local_addr: String,
356 peer_addr: SocketAddr,
357 socket: Arc<R::Socket>,
358 ) -> anyhow::Result<()> {
359 if !packet.verify_request() {
360 return Err(anyhow!("Invalid authenticator from {}", peer_addr));
361 }
362
363 let request = Request {
364 local_addr,
365 remote_addr: peer_addr.to_string(),
366 packet,
367 };
368
369 let response = ctx
370 .handler
371 .handle(request)
372 .await
373 .map_err(|e| anyhow!("Handler error: {:?}", e))?;
374
375 let encoded = response.packet.encode().context("Encoding failed")?;
376
377 socket
378 .send_to(&encoded, peer_addr)
379 .await
380 .context("UDP send failed")?;
381
382 Ok(())
383 }
384}
385struct TaskGuard<S, H>
386where
387 S: SecretProvider,
388 H: Handler,
389{
390 context: Arc<ServerContext<S, H>>,
391}
392
393impl<S, H> TaskGuard<S, H>
394where
395 S: SecretProvider,
396 H: Handler,
397{
398 fn new(context: Arc<ServerContext<S, H>>) -> Self {
399 Self { context }
400 }
401}
402
403impl<S, H> Drop for TaskGuard<S, H>
404where
405 S: SecretProvider,
406 H: Handler,
407{
408 fn drop(&mut self) {
409 self.context.active_tasks.fetch_sub(1, Ordering::SeqCst);
410 }
411}