aranet_core/
reconnect.rs

1//! Automatic reconnection handling for Aranet devices.
2//!
3//! This module provides a wrapper around Device that automatically
4//! handles reconnection when the connection is lost.
5//!
6//! [`ReconnectingDevice`] implements the [`AranetDevice`] trait,
7//! allowing it to be used interchangeably with regular devices in generic code.
8
9use std::sync::Arc;
10use std::sync::atomic::{AtomicBool, Ordering};
11use std::time::Duration;
12
13use async_trait::async_trait;
14use tokio::sync::RwLock;
15use tokio::time::sleep;
16use tracing::{info, warn};
17
18use aranet_types::{CurrentReading, DeviceInfo, DeviceType, HistoryRecord};
19
20use crate::device::Device;
21use crate::error::{Error, Result};
22use crate::events::{DeviceEvent, DeviceId, EventSender};
23use crate::history::{HistoryInfo, HistoryOptions};
24use crate::settings::{CalibrationData, MeasurementInterval};
25use crate::traits::AranetDevice;
26
27/// Options for automatic reconnection.
28#[derive(Debug, Clone)]
29pub struct ReconnectOptions {
30    /// Maximum number of reconnection attempts (None = unlimited).
31    pub max_attempts: Option<u32>,
32    /// Initial delay before first reconnection attempt.
33    pub initial_delay: Duration,
34    /// Maximum delay between attempts (for exponential backoff).
35    pub max_delay: Duration,
36    /// Multiplier for exponential backoff.
37    pub backoff_multiplier: f64,
38    /// Whether to use exponential backoff.
39    pub use_exponential_backoff: bool,
40}
41
42impl Default for ReconnectOptions {
43    fn default() -> Self {
44        Self {
45            max_attempts: Some(5),
46            initial_delay: Duration::from_secs(1),
47            max_delay: Duration::from_secs(60),
48            backoff_multiplier: 2.0,
49            use_exponential_backoff: true,
50        }
51    }
52}
53
54impl ReconnectOptions {
55    /// Create new reconnect options with defaults.
56    pub fn new() -> Self {
57        Self::default()
58    }
59
60    /// Create options with unlimited retry attempts.
61    pub fn unlimited() -> Self {
62        Self {
63            max_attempts: None,
64            ..Default::default()
65        }
66    }
67
68    /// Create options with a fixed delay (no backoff).
69    pub fn fixed_delay(delay: Duration) -> Self {
70        Self {
71            initial_delay: delay,
72            use_exponential_backoff: false,
73            ..Default::default()
74        }
75    }
76
77    /// Set maximum number of reconnection attempts.
78    pub fn max_attempts(mut self, attempts: u32) -> Self {
79        self.max_attempts = Some(attempts);
80        self
81    }
82
83    /// Set initial delay before first reconnection attempt.
84    pub fn initial_delay(mut self, delay: Duration) -> Self {
85        self.initial_delay = delay;
86        self
87    }
88
89    /// Set maximum delay between attempts.
90    pub fn max_delay(mut self, delay: Duration) -> Self {
91        self.max_delay = delay;
92        self
93    }
94
95    /// Set backoff multiplier for exponential backoff.
96    pub fn backoff_multiplier(mut self, multiplier: f64) -> Self {
97        self.backoff_multiplier = multiplier;
98        self
99    }
100
101    /// Enable or disable exponential backoff.
102    pub fn exponential_backoff(mut self, enabled: bool) -> Self {
103        self.use_exponential_backoff = enabled;
104        self
105    }
106
107    /// Calculate delay for a given attempt number.
108    pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
109        if !self.use_exponential_backoff {
110            return self.initial_delay;
111        }
112
113        let delay_ms =
114            self.initial_delay.as_millis() as f64 * self.backoff_multiplier.powi(attempt as i32);
115        let delay = Duration::from_millis(delay_ms as u64);
116
117        delay.min(self.max_delay)
118    }
119
120    /// Validate the options and return an error if invalid.
121    ///
122    /// Checks that:
123    /// - `backoff_multiplier` is >= 1.0
124    /// - `initial_delay` is > 0
125    /// - `max_delay` >= `initial_delay`
126    pub fn validate(&self) -> Result<()> {
127        if self.backoff_multiplier < 1.0 {
128            return Err(Error::InvalidConfig(
129                "backoff_multiplier must be >= 1.0".to_string(),
130            ));
131        }
132        if self.initial_delay.is_zero() {
133            return Err(Error::InvalidConfig(
134                "initial_delay must be > 0".to_string(),
135            ));
136        }
137        if self.max_delay < self.initial_delay {
138            return Err(Error::InvalidConfig(
139                "max_delay must be >= initial_delay".to_string(),
140            ));
141        }
142        Ok(())
143    }
144}
145
/// Lifecycle state of a [`ReconnectingDevice`], as returned by its `state()` method.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// Set after a successful connect or reconnect.
    Connected,
    /// No connection: explicitly disconnected, or a reconnection was cancelled.
    Disconnected,
    /// A reconnection loop is currently running.
    Reconnecting,
    /// The reconnection loop gave up after exhausting `max_attempts`.
    Failed,
}
158
159/// A device wrapper that automatically handles reconnection.
160///
161/// This wrapper caches the device name and type upon initial connection so they
162/// can be accessed synchronously via the [`AranetDevice`] trait, even while
163/// reconnecting.
164pub struct ReconnectingDevice {
165    identifier: String,
166    /// The connected device, wrapped in Arc to allow concurrent access.
167    device: RwLock<Option<Arc<Device>>>,
168    options: ReconnectOptions,
169    state: RwLock<ConnectionState>,
170    event_sender: Option<EventSender>,
171    attempt_count: RwLock<u32>,
172    /// Cancellation flag for stopping reconnection attempts.
173    cancelled: Arc<AtomicBool>,
174    /// Cached device name (populated on first connection).
175    cached_name: std::sync::OnceLock<String>,
176    /// Cached device type (populated on first connection).
177    cached_device_type: std::sync::OnceLock<DeviceType>,
178}
179
180impl ReconnectingDevice {
181    /// Create a new reconnecting device wrapper.
182    pub async fn connect(identifier: &str, options: ReconnectOptions) -> Result<Self> {
183        let device = Arc::new(Device::connect(identifier).await?);
184
185        // Cache the name and device type for synchronous access
186        let cached_name = std::sync::OnceLock::new();
187        if let Some(name) = device.name() {
188            let _ = cached_name.set(name.to_string());
189        }
190
191        let cached_device_type = std::sync::OnceLock::new();
192        if let Some(device_type) = device.device_type() {
193            let _ = cached_device_type.set(device_type);
194        }
195
196        Ok(Self {
197            identifier: identifier.to_string(),
198            device: RwLock::new(Some(device)),
199            options,
200            state: RwLock::new(ConnectionState::Connected),
201            event_sender: None,
202            attempt_count: RwLock::new(0),
203            cancelled: Arc::new(AtomicBool::new(false)),
204            cached_name,
205            cached_device_type,
206        })
207    }
208
209    /// Create with an event sender for notifications.
210    pub async fn connect_with_events(
211        identifier: &str,
212        options: ReconnectOptions,
213        event_sender: EventSender,
214    ) -> Result<Self> {
215        let mut this = Self::connect(identifier, options).await?;
216        this.event_sender = Some(event_sender);
217        Ok(this)
218    }
219
220    /// Cancel any ongoing reconnection attempts.
221    ///
222    /// This will cause the reconnect loop to exit on its next iteration.
223    pub fn cancel_reconnect(&self) {
224        self.cancelled.store(true, Ordering::SeqCst);
225    }
226
227    /// Check if reconnection has been cancelled.
228    pub fn is_cancelled(&self) -> bool {
229        self.cancelled.load(Ordering::SeqCst)
230    }
231
232    /// Reset the cancellation flag (call before starting a new reconnection).
233    fn reset_cancellation(&self) {
234        self.cancelled.store(false, Ordering::SeqCst);
235    }
236
237    /// Get the current connection state.
238    pub async fn state(&self) -> ConnectionState {
239        *self.state.read().await
240    }
241
242    /// Check if currently connected.
243    pub async fn is_connected(&self) -> bool {
244        let guard = self.device.read().await;
245        if let Some(device) = guard.as_ref() {
246            device.is_connected().await
247        } else {
248            false
249        }
250    }
251
252    /// Get the identifier.
253    pub fn identifier(&self) -> &str {
254        &self.identifier
255    }
256
257    /// Execute an operation, reconnecting if necessary.
258    ///
259    /// The closure is called with a reference to the device. If the operation
260    /// fails due to a connection issue, the device will attempt to reconnect
261    /// and retry the operation.
262    ///
263    /// # Example
264    ///
265    /// ```ignore
266    /// let reading = device.with_device(|d| async { d.read_current().await }).await?;
267    /// ```
268    pub async fn with_device<F, Fut, T>(&self, f: F) -> Result<T>
269    where
270        F: Fn(&Device) -> Fut,
271        Fut: std::future::Future<Output = Result<T>>,
272    {
273        // Try the operation if already connected
274        {
275            let guard = self.device.read().await;
276            if let Some(device) = guard.as_ref()
277                && device.is_connected().await
278            {
279                match f(device).await {
280                    Ok(result) => return Ok(result),
281                    Err(e) => {
282                        warn!("Operation failed: {}", e);
283                        // Fall through to reconnect
284                    }
285                }
286            }
287        }
288
289        // Need to reconnect
290        self.reconnect().await?;
291
292        // Retry the operation after reconnection
293        let guard = self.device.read().await;
294        if let Some(device) = guard.as_ref() {
295            f(device).await
296        } else {
297            Err(Error::NotConnected)
298        }
299    }
300
301    /// Internal helper that executes an operation with automatic reconnection using boxed futures.
302    ///
303    /// This method uses explicit HRTB (Higher-Rank Trait Bounds) to handle the complex
304    /// lifetime requirements when returning futures from closures. It's used internally
305    /// by the `AranetDevice` trait implementation.
306    ///
307    /// Note: We cannot consolidate this with `with_device` due to Rust's async closure
308    /// lifetime limitations. The `with_device` method provides a more ergonomic API for
309    /// callers, while this method handles the trait implementation requirements.
310    async fn run_with_reconnect<'a, T, F>(&'a self, f: F) -> Result<T>
311    where
312        F: for<'b> Fn(
313                &'b Device,
314            ) -> std::pin::Pin<
315                Box<dyn std::future::Future<Output = Result<T>> + Send + 'b>,
316            > + Send
317            + Sync,
318        T: Send,
319    {
320        // Try the operation if already connected
321        {
322            let guard = self.device.read().await;
323            if let Some(device) = guard.as_ref()
324                && device.is_connected().await
325            {
326                match f(device).await {
327                    Ok(result) => return Ok(result),
328                    Err(e) => {
329                        warn!("Operation failed: {}", e);
330                        // Fall through to reconnect
331                    }
332                }
333            }
334        }
335
336        // Need to reconnect
337        self.reconnect().await?;
338
339        // Retry the operation after reconnection
340        let guard = self.device.read().await;
341        if let Some(device) = guard.as_ref() {
342            f(device).await
343        } else {
344            Err(Error::NotConnected)
345        }
346    }
347
348    /// Attempt to reconnect to the device.
349    ///
350    /// This loop can be cancelled by calling `cancel_reconnect()` from another task.
351    /// When cancelled, returns `Error::Cancelled`.
352    pub async fn reconnect(&self) -> Result<()> {
353        // Reset cancellation flag at the start
354        self.reset_cancellation();
355
356        *self.state.write().await = ConnectionState::Reconnecting;
357        *self.attempt_count.write().await = 0;
358
359        loop {
360            // Check for cancellation at the start of each iteration
361            if self.is_cancelled() {
362                *self.state.write().await = ConnectionState::Disconnected;
363                info!("Reconnection cancelled for {}", self.identifier);
364                return Err(Error::Cancelled);
365            }
366
367            let attempt = {
368                let mut count = self.attempt_count.write().await;
369                *count += 1;
370                *count
371            };
372
373            // Check if we've exceeded max attempts
374            if let Some(max) = self.options.max_attempts
375                && attempt > max
376            {
377                *self.state.write().await = ConnectionState::Failed;
378                return Err(Error::Timeout {
379                    operation: format!("reconnect to '{}'", self.identifier),
380                    duration: self.options.max_delay * max,
381                });
382            }
383
384            // Send reconnect started event
385            if let Some(sender) = &self.event_sender {
386                let _ = sender.send(DeviceEvent::ReconnectStarted {
387                    device: DeviceId::new(&self.identifier),
388                    attempt,
389                });
390            }
391
392            info!("Reconnection attempt {} for {}", attempt, self.identifier);
393
394            // Wait before attempting (check cancellation during sleep)
395            let delay = self.options.delay_for_attempt(attempt - 1);
396            sleep(delay).await;
397
398            // Check for cancellation after sleep
399            if self.is_cancelled() {
400                *self.state.write().await = ConnectionState::Disconnected;
401                info!("Reconnection cancelled for {}", self.identifier);
402                return Err(Error::Cancelled);
403            }
404
405            // Try to connect
406            match Device::connect(&self.identifier).await {
407                Ok(new_device) => {
408                    *self.device.write().await = Some(Arc::new(new_device));
409                    *self.state.write().await = ConnectionState::Connected;
410
411                    // Send reconnect succeeded event
412                    if let Some(sender) = &self.event_sender {
413                        let _ = sender.send(DeviceEvent::ReconnectSucceeded {
414                            device: DeviceId::new(&self.identifier),
415                            attempts: attempt,
416                        });
417                    }
418
419                    info!("Reconnected successfully after {} attempts", attempt);
420                    return Ok(());
421                }
422                Err(e) => {
423                    warn!("Reconnection attempt {} failed: {}", attempt, e);
424                }
425            }
426        }
427    }
428
429    /// Disconnect from the device.
430    pub async fn disconnect(&self) -> Result<()> {
431        let mut guard = self.device.write().await;
432        if let Some(device) = guard.take() {
433            device.disconnect().await?;
434        }
435        *self.state.write().await = ConnectionState::Disconnected;
436        Ok(())
437    }
438
439    /// Get the number of reconnection attempts made.
440    pub async fn attempt_count(&self) -> u32 {
441        *self.attempt_count.read().await
442    }
443
444    /// Get the device name, if available and connected.
445    pub async fn name(&self) -> Option<String> {
446        let guard = self.device.read().await;
447        guard.as_ref().and_then(|d| d.name().map(|s| s.to_string()))
448    }
449
450    /// Get the device address (returns identifier if not connected).
451    pub async fn address(&self) -> String {
452        let guard = self.device.read().await;
453        guard
454            .as_ref()
455            .map(|d| d.address().to_string())
456            .unwrap_or_else(|| self.identifier.clone())
457    }
458
459    /// Get the detected device type, if available.
460    pub async fn device_type(&self) -> Option<DeviceType> {
461        let guard = self.device.read().await;
462        guard.as_ref().and_then(|d| d.device_type())
463    }
464}
465
466// Implement the AranetDevice trait for ReconnectingDevice
467#[async_trait]
468impl AranetDevice for ReconnectingDevice {
469    async fn is_connected(&self) -> bool {
470        ReconnectingDevice::is_connected(self).await
471    }
472
473    async fn connect(&self) -> Result<()> {
474        // If already connected, this is a no-op
475        if self.is_connected().await {
476            return Ok(());
477        }
478        // Otherwise, attempt to reconnect
479        self.reconnect().await
480    }
481
482    async fn disconnect(&self) -> Result<()> {
483        ReconnectingDevice::disconnect(self).await
484    }
485
486    fn name(&self) -> Option<&str> {
487        self.cached_name.get().map(|s| s.as_str())
488    }
489
490    fn address(&self) -> &str {
491        &self.identifier
492    }
493
494    fn device_type(&self) -> Option<DeviceType> {
495        self.cached_device_type.get().copied()
496    }
497
498    async fn read_current(&self) -> Result<CurrentReading> {
499        self.run_with_reconnect(|d| Box::pin(d.read_current()))
500            .await
501    }
502
503    async fn read_device_info(&self) -> Result<DeviceInfo> {
504        self.run_with_reconnect(|d| Box::pin(d.read_device_info()))
505            .await
506    }
507
508    async fn read_rssi(&self) -> Result<i16> {
509        self.run_with_reconnect(|d| Box::pin(d.read_rssi())).await
510    }
511
512    async fn read_battery(&self) -> Result<u8> {
513        self.run_with_reconnect(|d| Box::pin(d.read_battery()))
514            .await
515    }
516
517    async fn get_history_info(&self) -> Result<HistoryInfo> {
518        self.run_with_reconnect(|d| Box::pin(d.get_history_info()))
519            .await
520    }
521
522    async fn download_history(&self) -> Result<Vec<HistoryRecord>> {
523        self.run_with_reconnect(|d| Box::pin(d.download_history()))
524            .await
525    }
526
527    async fn download_history_with_options(
528        &self,
529        options: HistoryOptions,
530    ) -> Result<Vec<HistoryRecord>> {
531        let opts = options.clone();
532        self.run_with_reconnect(move |d| {
533            let opts = opts.clone();
534            Box::pin(async move { d.download_history_with_options(opts).await })
535        })
536        .await
537    }
538
539    async fn get_interval(&self) -> Result<MeasurementInterval> {
540        self.run_with_reconnect(|d| Box::pin(d.get_interval()))
541            .await
542    }
543
544    async fn set_interval(&self, interval: MeasurementInterval) -> Result<()> {
545        self.run_with_reconnect(move |d| Box::pin(d.set_interval(interval)))
546            .await
547    }
548
549    async fn get_calibration(&self) -> Result<CalibrationData> {
550        self.run_with_reconnect(|d| Box::pin(d.get_calibration()))
551            .await
552    }
553}
554
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_reconnect_options_default() {
        let opts = ReconnectOptions::default();
        assert_eq!(opts.max_attempts, Some(5));
        assert!(opts.use_exponential_backoff);
    }

    #[test]
    fn test_reconnect_options_unlimited() {
        let opts = ReconnectOptions::unlimited();
        assert!(opts.max_attempts.is_none());
    }

    #[test]
    fn test_delay_calculation() {
        let opts = ReconnectOptions {
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(60),
            backoff_multiplier: 2.0,
            use_exponential_backoff: true,
            ..Default::default()
        };

        assert_eq!(opts.delay_for_attempt(0), Duration::from_secs(1));
        assert_eq!(opts.delay_for_attempt(1), Duration::from_secs(2));
        assert_eq!(opts.delay_for_attempt(2), Duration::from_secs(4));
        assert_eq!(opts.delay_for_attempt(3), Duration::from_secs(8));
    }

    #[test]
    fn test_delay_capped_at_max() {
        let opts = ReconnectOptions {
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(10),
            backoff_multiplier: 2.0,
            use_exponential_backoff: true,
            ..Default::default()
        };

        // 2^10 = 1024 seconds, but capped at 10
        assert_eq!(opts.delay_for_attempt(10), Duration::from_secs(10));
    }

    #[test]
    fn test_fixed_delay() {
        let opts = ReconnectOptions::fixed_delay(Duration::from_secs(5));
        assert_eq!(opts.delay_for_attempt(0), Duration::from_secs(5));
        assert_eq!(opts.delay_for_attempt(5), Duration::from_secs(5));
    }

    #[test]
    fn test_builder_methods() {
        let opts = ReconnectOptions::new()
            .max_attempts(3)
            .initial_delay(Duration::from_millis(250))
            .max_delay(Duration::from_secs(5))
            .backoff_multiplier(3.0)
            .exponential_backoff(true);

        assert_eq!(opts.max_attempts, Some(3));
        assert_eq!(opts.initial_delay, Duration::from_millis(250));
        assert_eq!(opts.max_delay, Duration::from_secs(5));
        assert_eq!(opts.backoff_multiplier, 3.0);
        assert!(opts.use_exponential_backoff);
        assert_eq!(opts.delay_for_attempt(1), Duration::from_millis(750));
    }

    #[test]
    fn test_disabling_backoff_uses_initial_delay() {
        let opts = ReconnectOptions::new().exponential_backoff(false);
        assert_eq!(opts.delay_for_attempt(0), Duration::from_secs(1));
        assert_eq!(opts.delay_for_attempt(4), Duration::from_secs(1));
    }

    #[test]
    fn test_validate() {
        assert!(ReconnectOptions::default().validate().is_ok());
        assert!(ReconnectOptions::unlimited().validate().is_ok());
        assert!(ReconnectOptions::new()
            .backoff_multiplier(0.5)
            .validate()
            .is_err());
        assert!(ReconnectOptions::new()
            .initial_delay(Duration::ZERO)
            .validate()
            .is_err());
        assert!(ReconnectOptions::new()
            .initial_delay(Duration::from_secs(10))
            .max_delay(Duration::from_secs(5))
            .validate()
            .is_err());
    }
}