// abol_server/lib.rs

1use abol_core::packet::{MAX_PACKET_SIZE, Packet};
2use abol_core::{Cidr, Request, Response};
3use abol_rt::net::AsyncUdpSocket;
4use abol_rt::{Executor, Runtime, YieldNow};
5use anyhow::{Context, anyhow};
6use bytes::Bytes;
7use dashmap::DashSet;
8use moka::future::Cache;
9use std::error::Error;
10use std::future::Future;
11use std::net::{IpAddr, SocketAddr};
12use std::pin::Pin;
13use std::sync::Arc;
14use std::sync::atomic::{AtomicUsize, Ordering};
15use std::task::Poll;
16
17pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
18pub type BoxError = Box<dyn std::error::Error + Send + Sync>;
19pub type SecretResult = Result<Vec<(Cidr, Vec<u8>)>, BoxError>;
20
21type HandlerResult<T> = Result<T, Box<dyn Error + Send + Sync>>;
22
23type SecretBytes = Arc<[u8]>;
24type SecretEntry = (Cidr, SecretBytes);
25type SecretList = Vec<SecretEntry>;
26type SharedSecrets = Arc<SecretList>;
27type SecretCache = Cache<(), SharedSecrets>;
28
/// Identifies a request for deduplication: the sender's socket address
/// combined with the identifier taken from the parsed packet.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct RequestKey {
    /// Address the datagram was received from.
    addr: SocketAddr,
    /// Identifier field of the parsed packet.
    identifier: u8,
}
34
35/// A provider that maps incoming client IP ranges to their shared secrets.
36///
37/// In the RADIUS protocol, the server identifies which password (shared secret)
38/// to use based on the source IP address of the Network Access Server (NAS).
39///
40/// Implement this trait to define how your server looks up these secrets—whether
41/// from a static list, a configuration file, or a database.
42pub trait SecretSource: Send + Sync + 'static {
43    /// Retrieves a complete list of IP networks and their associated secrets.
44    ///
45    /// The server's secret manager will call this periodically to refresh its
46    /// internal cache.
47    ///
48    /// ### Returns
49    /// A list of tuples where:
50    /// * `Cidr` - The IP range (e.g., 192.168.1.0/24) allowed to connect.
51    /// * `Vec<u8>` - The shared secret used to sign and encrypt packets for that range.
52    fn get_all_secrets(
53        &self,
54    ) -> impl Future<Output = Result<Vec<(Cidr, Vec<u8>)>, BoxError>> + Send;
55}
56
57pub trait SecretSourceExt: Send + Sync + 'static {
58    fn get_all_secrets_boxed(&self) -> BoxFuture<'_, SecretResult>;
59}
60
61impl<T: SecretSource> SecretSourceExt for T {
62    fn get_all_secrets_boxed(&self) -> BoxFuture<'_, SecretResult> {
63        Box::pin(self.get_all_secrets())
64    }
65}
66
/// Resolves the shared secret to use for a single client address.
pub trait SecretProvider: Send + Sync + 'static {
    /// Looks up the secret for `client_ip`.
    ///
    /// Resolves to `None` when no secret is available for that address.
    fn get_secret(&self, client_ip: IpAddr) -> impl Future<Output = Option<Arc<[u8]>>> + Send;
}
70
71pub trait SecretProviderExt: Send + Sync + 'static {
72    fn get_secret_boxed(&self, client_ip: IpAddr) -> BoxFuture<'_, Option<Arc<[u8]>>>;
73}
74
75impl<T: SecretProvider> SecretProviderExt for T {
76    fn get_secret_boxed(&self, client_ip: IpAddr) -> BoxFuture<'_, Option<Arc<[u8]>>> {
77        Box::pin(self.get_secret(client_ip))
78    }
79}
80
81/// High-level manager responsible for coordinating secret retrieval and caching.
82///
83/// The `SecretManager` wraps a [SecretProvider] (stored as `source`) and adds a
84/// caching layer to reduce the overhead of frequent lookups, such as
85/// database queries or remote API calls.
86pub struct SecretManager {
87    /// In-memory cache for frequently used secrets.
88    cache: SecretCache,
89    /// The underlying source of secrets, stored as an object-safe trait.
90    source: Arc<dyn SecretSourceExt>,
91}
92
93impl SecretManager {
94    /// Creates a new `SecretManager` with the specified secret source and cache expiration.
95    ///
96    /// This constructor initializes an internal [SecretCache] that helps avoid
97    /// redundant lookups for the same client IP, improving performance during
98    /// high-frequency RADIUS requests.
99    ///
100    /// # Arguments
101    ///
102    /// * `source` - An [Arc]-wrapped trait object that provides the actual shared secrets.
103    /// * `cache_ttl_secs` - The duration (in seconds) that a secret remains valid in the
104    ///   cache before being refreshed from the source.
105    pub fn new(source: Arc<dyn SecretSourceExt>, cache_ttl_secs: u64) -> Self {
106        let cache = Cache::builder()
107            .max_capacity(1000) // Note: Usually larger than 1 for production
108            .time_to_live(std::time::Duration::from_secs(cache_ttl_secs))
109            .build();
110        Self { cache, source }
111    }
112    async fn get_secret_internal(&self, client_ip: IpAddr) -> Option<Arc<[u8]>> {
113        let table = self
114            .cache
115            .get_with((), async { self.reload().await.unwrap_or_default() })
116            .await;
117
118        table
119            .iter()
120            .find(|(cidr, _)| cidr.contains(&client_ip))
121            .map(|(_, secret)| secret.clone())
122    }
123
124    async fn reload(&self) -> Result<Arc<Vec<(Cidr, Arc<[u8]>)>>, BoxError> {
125        let entries = self.source.get_all_secrets_boxed().await?;
126        let arc_entries = entries
127            .into_iter()
128            .map(|(cidr, secret)| (cidr, Arc::from(secret)))
129            .collect();
130        Ok(Arc::new(arc_entries))
131    }
132}
133
134impl SecretProvider for SecretManager {
135    async fn get_secret(&self, client_ip: IpAddr) -> Option<Arc<[u8]>> {
136        self.get_secret_internal(client_ip).await
137    }
138}
139
140/// A trait for processing incoming RADIUS requests and producing responses.
141///
142/// `Handler` is the primary entry point for your application logic. It receives
143/// a [Request], which contains the decrypted RADIUS packet and client metadata,
144/// and must return a [HandlerResult] containing the [Response].
145///
146/// Since the trait is `Send + Sync + 'static`, it is safe to share across
147/// multiple threads and asynchronous tasks.
148pub trait Handler: Send + Sync + 'static {
149    /// Processes a RADIUS request asynchronously.
150    ///
151    /// # Parameters
152    /// - `request`: The incoming RADIUS request and client information.
153    ///
154    /// # Returns
155    /// A future that resolves to a `HandlerResult<Response>`.
156    fn handle(
157        &self,
158        request: Request,
159    ) -> impl std::future::Future<Output = HandlerResult<Response>> + Send;
160}
161
/// Adapter that turns a plain function or closure into a [Handler].
///
/// The wrapped callable receives the [Request] by value and returns a future
/// resolving to the handler result.
pub struct HandlerFn<F>(pub F);
165
166impl<F, Fut> Handler for HandlerFn<F>
167where
168    F: Fn(Request) -> Fut + Send + Sync + 'static,
169    Fut: std::future::Future<Output = HandlerResult<Response>> + Send + 'static,
170{
171    /// Executes the wrapped closure to process the request.
172    async fn handle(&self, request: Request) -> HandlerResult<Response> {
173        (self.0)(request).await
174    }
175}
176
177pub(crate) struct ServerContext<S, H>
178where
179    S: SecretProvider,
180    H: Handler,
181{
182    pub secret_provider: Arc<S>,
183    pub handler: Arc<H>,
184    pub undergoing_requests: Arc<DashSet<RequestKey>>,
185    pub active_tasks: Arc<AtomicUsize>,
186}
187
188/// A high-performance RADIUS server instance.
189///
190/// The `Server` acts as the central orchestrator of the RADIUS stack, managing
191/// the network lifecycle, request deduplication, and task execution. It is
192/// generic over the secret retrieval logic, the request handler, and the
193/// underlying async runtime.
194///
195/// ### Type Parameters
196/// * `S`: Implements [SecretProvider] to look up shared secrets for incoming packets.
197/// * `H`: Implements [Handler] to define how the server responds to RADIUS requests.
198/// * `R`: Implements [Runtime] to abstract over different async executors (e.g., Tokio or Smol).
199pub struct Server<S, H, R>
200where
201    S: SecretProvider,
202    H: Handler,
203    R: Runtime,
204{
205    /// The async runtime executor used for spawning tasks and timers.
206    runtime: Arc<R>,
207    /// The abstract socket (UDP) provided by the runtime.
208    socket: Arc<R::Socket>,
209    /// The provider used to authenticate incoming packets via shared secrets.
210    secret_provider: Arc<S>,
211    /// The user-defined logic for processing RADIUS attributes and codes.
212    handler: Arc<H>,
213    /// A set used for request deduplication to prevent processing the same
214    /// Packet ID/Source IP combination multiple times concurrently.
215    undergoing_requests: Arc<DashSet<RequestKey>>,
216    /// A counter tracking the number of currently executing asynchronous tasks.
217    active_tasks: Arc<AtomicUsize>,
218    /// An optional future that, when resolved, triggers a graceful shutdown.
219    shutdown_signal: Option<BoxFuture<'static, ()>>,
220}
221
222impl<S, H, R> Server<S, H, R>
223where
224    S: SecretProvider + 'static,
225    H: Handler + 'static,
226    R: Runtime + 'static,
227{
228    pub fn new(runtime: R, socket: R::Socket, secret_provider: S, handler: H) -> Self {
229        Self {
230            runtime: Arc::new(runtime),
231            socket: Arc::new(socket),
232            secret_provider: Arc::new(secret_provider),
233            handler: Arc::new(handler),
234            undergoing_requests: Arc::new(DashSet::new()),
235            active_tasks: Arc::new(AtomicUsize::new(0)),
236            shutdown_signal: None,
237        }
238    }
239
240    pub fn local_addr(&self) -> anyhow::Result<SocketAddr> {
241        self.socket.local_addr().map_err(|e| anyhow!(e))
242    }
243
244    pub fn with_graceful_shutdown<F>(mut self, shutdown: F) -> Self
245    where
246        F: std::future::Future<Output = ()> + Send + 'static,
247    {
248        self.shutdown_signal = Some(Box::pin(shutdown));
249        self
250    }
251
252    /// Starts the RADIUS server and begins listening for incoming packets.
253    ///
254    /// This is the main entry point of the server. It will run indefinitely until:
255    /// 1. An unrecoverable network error occurs.
256    /// 2. The `shutdown_signal` (if provided) is triggered.
257    ///
258    /// ### Graceful Shutdown
259    /// When a shutdown signal is received, the server stops accepting new packets
260    /// immediately but waits for all currently processing requests (active tasks)
261    /// to finish before returning. This ensures no client requests are "dropped"
262    /// mid-processing.
263    ///
264    /// # Errors
265    /// Returns an error if the server fails to retrieve the local address or if
266    /// the internal run loop encounters a fatal exception.
267    pub async fn listen_and_serve(self) -> anyhow::Result<()> {
268        let local_addr_str = self.socket.local_addr()?.to_string();
269        let context = Arc::new(ServerContext {
270            secret_provider: Arc::clone(&self.secret_provider),
271            handler: Arc::clone(&self.handler),
272            undergoing_requests: Arc::clone(&self.undergoing_requests),
273            active_tasks: Arc::clone(&self.active_tasks),
274        });
275
276        let mut shutdown = self
277            .shutdown_signal
278            .unwrap_or_else(|| Box::pin(std::future::pending()));
279
280        let mut run_loop_fut = Box::pin(Self::run_loop(
281            Arc::clone(&context),
282            Arc::clone(&self.socket),
283            Arc::clone(&self.runtime),
284            local_addr_str,
285        ));
286
287        let result = std::future::poll_fn(|cx| {
288            if shutdown.as_mut().poll(cx).is_ready() {
289                return Poll::Ready(Ok(()));
290            }
291            if let Poll::Ready(res) = run_loop_fut.as_mut().poll(cx) {
292                return Poll::Ready(res);
293            }
294            Poll::Pending
295        })
296        .await;
297
298        while context.active_tasks.load(Ordering::SeqCst) > 0 {
299            YieldNow::new().await;
300        }
301        result
302    }
303
304    async fn run_loop(
305        context: Arc<ServerContext<S, H>>,
306        socket: Arc<R::Socket>,
307        runtime: Arc<R>,
308        local_addr: String,
309    ) -> anyhow::Result<()> {
310        loop {
311            let mut buf = [0u8; MAX_PACKET_SIZE];
312            let (len, peer_addr) = socket
313                .recv_from(&mut buf)
314                .await
315                .context("Failed to receive")?;
316
317            let data = Bytes::copy_from_slice(&buf[..len]);
318            let ctx = Arc::clone(&context);
319            let sock = Arc::clone(&socket);
320            let l_addr = local_addr.clone();
321            let rt = Arc::clone(&runtime);
322
323            ctx.active_tasks.fetch_add(1, Ordering::SeqCst);
324
325            rt.executor().execute(Box::pin(async move {
326                let _guard = TaskGuard::new(Arc::clone(&ctx));
327
328                let secret = match ctx.secret_provider.get_secret(peer_addr.ip()).await {
329                    Some(s) => s,
330                    None => return,
331                };
332
333                let packet = match Packet::parse_packet(data, Arc::clone(&secret)) {
334                    Ok(p) => p,
335                    Err(_) => return,
336                };
337
338                let key = RequestKey {
339                    addr: peer_addr,
340                    identifier: packet.identifier,
341                };
342                if !ctx.undergoing_requests.insert(key.clone()) {
343                    return;
344                }
345
346                let _ = Self::process(&ctx, packet, l_addr, peer_addr, sock).await;
347                ctx.undergoing_requests.remove(&key);
348            }));
349        }
350    }
351
352    async fn process(
353        ctx: &ServerContext<S, H>,
354        packet: Packet,
355        local_addr: String,
356        peer_addr: SocketAddr,
357        socket: Arc<R::Socket>,
358    ) -> anyhow::Result<()> {
359        if !packet.verify_request() {
360            return Err(anyhow!("Invalid authenticator from {}", peer_addr));
361        }
362
363        let request = Request {
364            local_addr,
365            remote_addr: peer_addr.to_string(),
366            packet,
367        };
368
369        let response = ctx
370            .handler
371            .handle(request)
372            .await
373            .map_err(|e| anyhow!("Handler error: {:?}", e))?;
374
375        let encoded = response.packet.encode().context("Encoding failed")?;
376
377        socket
378            .send_to(&encoded, peer_addr)
379            .await
380            .context("UDP send failed")?;
381
382        Ok(())
383    }
384}
385struct TaskGuard<S, H>
386where
387    S: SecretProvider,
388    H: Handler,
389{
390    context: Arc<ServerContext<S, H>>,
391}
392
393impl<S, H> TaskGuard<S, H>
394where
395    S: SecretProvider,
396    H: Handler,
397{
398    fn new(context: Arc<ServerContext<S, H>>) -> Self {
399        Self { context }
400    }
401}
402
403impl<S, H> Drop for TaskGuard<S, H>
404where
405    S: SecretProvider,
406    H: Handler,
407{
408    fn drop(&mut self) {
409        self.context.active_tasks.fetch_sub(1, Ordering::SeqCst);
410    }
411}