1use std::sync::Arc;
10use std::sync::atomic::{AtomicBool, Ordering};
11use std::time::Duration;
12
13use async_trait::async_trait;
14use tokio::sync::RwLock;
15use tokio::time::sleep;
16use tracing::{info, warn};
17
18use aranet_types::{CurrentReading, DeviceInfo, DeviceType, HistoryRecord};
19
20use crate::device::Device;
21use crate::error::{Error, Result};
22use crate::events::{DeviceEvent, DeviceId, EventSender};
23use crate::history::{HistoryInfo, HistoryOptions};
24use crate::settings::{CalibrationData, MeasurementInterval};
25use crate::traits::AranetDevice;
26
/// Retry policy used by [`ReconnectingDevice`] when a connection has to be
/// re-established.
#[derive(Clone, Debug)]
pub struct ReconnectOptions {
    /// Upper bound on reconnection attempts; `None` means retry indefinitely.
    pub max_attempts: Option<u32>,
    /// Delay before the first attempt, and before every attempt when
    /// exponential backoff is disabled.
    pub initial_delay: Duration,
    /// Cap applied to the exponentially growing delay.
    pub max_delay: Duration,
    /// Factor the delay is multiplied by for each successive attempt.
    pub backoff_multiplier: f64,
    /// `true` for exponential backoff, `false` for a fixed `initial_delay`.
    pub use_exponential_backoff: bool,
}
41
42impl Default for ReconnectOptions {
43 fn default() -> Self {
44 Self {
45 max_attempts: Some(5),
46 initial_delay: Duration::from_secs(1),
47 max_delay: Duration::from_secs(60),
48 backoff_multiplier: 2.0,
49 use_exponential_backoff: true,
50 }
51 }
52}
53
54impl ReconnectOptions {
55 pub fn new() -> Self {
57 Self::default()
58 }
59
60 pub fn unlimited() -> Self {
62 Self {
63 max_attempts: None,
64 ..Default::default()
65 }
66 }
67
68 pub fn fixed_delay(delay: Duration) -> Self {
70 Self {
71 initial_delay: delay,
72 use_exponential_backoff: false,
73 ..Default::default()
74 }
75 }
76
77 pub fn max_attempts(mut self, attempts: u32) -> Self {
79 self.max_attempts = Some(attempts);
80 self
81 }
82
83 pub fn initial_delay(mut self, delay: Duration) -> Self {
85 self.initial_delay = delay;
86 self
87 }
88
89 pub fn max_delay(mut self, delay: Duration) -> Self {
91 self.max_delay = delay;
92 self
93 }
94
95 pub fn backoff_multiplier(mut self, multiplier: f64) -> Self {
97 self.backoff_multiplier = multiplier;
98 self
99 }
100
101 pub fn exponential_backoff(mut self, enabled: bool) -> Self {
103 self.use_exponential_backoff = enabled;
104 self
105 }
106
107 pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
109 if !self.use_exponential_backoff {
110 return self.initial_delay;
111 }
112
113 let delay_ms =
114 self.initial_delay.as_millis() as f64 * self.backoff_multiplier.powi(attempt as i32);
115 let delay = Duration::from_millis(delay_ms as u64);
116
117 delay.min(self.max_delay)
118 }
119
120 pub fn validate(&self) -> Result<()> {
127 if self.backoff_multiplier < 1.0 {
128 return Err(Error::InvalidConfig(
129 "backoff_multiplier must be >= 1.0".to_string(),
130 ));
131 }
132 if self.initial_delay.is_zero() {
133 return Err(Error::InvalidConfig(
134 "initial_delay must be > 0".to_string(),
135 ));
136 }
137 if self.max_delay < self.initial_delay {
138 return Err(Error::InvalidConfig(
139 "max_delay must be >= initial_delay".to_string(),
140 ));
141 }
142 Ok(())
143 }
144}
145
/// Connection lifecycle of a [`ReconnectingDevice`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// A connection has been established.
    Connected,
    /// No connection: explicitly disconnected, or a reconnection was cancelled.
    Disconnected,
    /// A reconnection loop is currently running.
    Reconnecting,
    /// Reconnection gave up after exhausting the configured attempts.
    Failed,
}
158
159pub struct ReconnectingDevice {
165 identifier: String,
166 device: RwLock<Option<Arc<Device>>>,
168 options: ReconnectOptions,
169 state: RwLock<ConnectionState>,
170 event_sender: Option<EventSender>,
171 attempt_count: RwLock<u32>,
172 cancelled: Arc<AtomicBool>,
174 cached_name: std::sync::OnceLock<String>,
176 cached_device_type: std::sync::OnceLock<DeviceType>,
178}
179
180impl ReconnectingDevice {
181 pub async fn connect(identifier: &str, options: ReconnectOptions) -> Result<Self> {
183 let device = Arc::new(Device::connect(identifier).await?);
184
185 let cached_name = std::sync::OnceLock::new();
187 if let Some(name) = device.name() {
188 let _ = cached_name.set(name.to_string());
189 }
190
191 let cached_device_type = std::sync::OnceLock::new();
192 if let Some(device_type) = device.device_type() {
193 let _ = cached_device_type.set(device_type);
194 }
195
196 Ok(Self {
197 identifier: identifier.to_string(),
198 device: RwLock::new(Some(device)),
199 options,
200 state: RwLock::new(ConnectionState::Connected),
201 event_sender: None,
202 attempt_count: RwLock::new(0),
203 cancelled: Arc::new(AtomicBool::new(false)),
204 cached_name,
205 cached_device_type,
206 })
207 }
208
209 pub async fn connect_with_events(
211 identifier: &str,
212 options: ReconnectOptions,
213 event_sender: EventSender,
214 ) -> Result<Self> {
215 let mut this = Self::connect(identifier, options).await?;
216 this.event_sender = Some(event_sender);
217 Ok(this)
218 }
219
220 pub fn cancel_reconnect(&self) {
224 self.cancelled.store(true, Ordering::SeqCst);
225 }
226
227 pub fn is_cancelled(&self) -> bool {
229 self.cancelled.load(Ordering::SeqCst)
230 }
231
232 fn reset_cancellation(&self) {
234 self.cancelled.store(false, Ordering::SeqCst);
235 }
236
237 pub async fn state(&self) -> ConnectionState {
239 *self.state.read().await
240 }
241
242 pub async fn is_connected(&self) -> bool {
244 let guard = self.device.read().await;
245 if let Some(device) = guard.as_ref() {
246 device.is_connected().await
247 } else {
248 false
249 }
250 }
251
252 pub fn identifier(&self) -> &str {
254 &self.identifier
255 }
256
257 pub async fn with_device<F, Fut, T>(&self, f: F) -> Result<T>
269 where
270 F: Fn(&Device) -> Fut,
271 Fut: std::future::Future<Output = Result<T>>,
272 {
273 {
275 let guard = self.device.read().await;
276 if let Some(device) = guard.as_ref()
277 && device.is_connected().await
278 {
279 match f(device).await {
280 Ok(result) => return Ok(result),
281 Err(e) => {
282 warn!("Operation failed: {}", e);
283 }
285 }
286 }
287 }
288
289 self.reconnect().await?;
291
292 let guard = self.device.read().await;
294 if let Some(device) = guard.as_ref() {
295 f(device).await
296 } else {
297 Err(Error::NotConnected)
298 }
299 }
300
301 async fn run_with_reconnect<'a, T, F>(&'a self, f: F) -> Result<T>
311 where
312 F: for<'b> Fn(
313 &'b Device,
314 ) -> std::pin::Pin<
315 Box<dyn std::future::Future<Output = Result<T>> + Send + 'b>,
316 > + Send
317 + Sync,
318 T: Send,
319 {
320 {
322 let guard = self.device.read().await;
323 if let Some(device) = guard.as_ref()
324 && device.is_connected().await
325 {
326 match f(device).await {
327 Ok(result) => return Ok(result),
328 Err(e) => {
329 warn!("Operation failed: {}", e);
330 }
332 }
333 }
334 }
335
336 self.reconnect().await?;
338
339 let guard = self.device.read().await;
341 if let Some(device) = guard.as_ref() {
342 f(device).await
343 } else {
344 Err(Error::NotConnected)
345 }
346 }
347
348 pub async fn reconnect(&self) -> Result<()> {
353 self.reset_cancellation();
355
356 *self.state.write().await = ConnectionState::Reconnecting;
357 *self.attempt_count.write().await = 0;
358
359 loop {
360 if self.is_cancelled() {
362 *self.state.write().await = ConnectionState::Disconnected;
363 info!("Reconnection cancelled for {}", self.identifier);
364 return Err(Error::Cancelled);
365 }
366
367 let attempt = {
368 let mut count = self.attempt_count.write().await;
369 *count += 1;
370 *count
371 };
372
373 if let Some(max) = self.options.max_attempts
375 && attempt > max
376 {
377 *self.state.write().await = ConnectionState::Failed;
378 return Err(Error::Timeout {
379 operation: format!("reconnect to '{}'", self.identifier),
380 duration: self.options.max_delay * max,
381 });
382 }
383
384 if let Some(sender) = &self.event_sender {
386 let _ = sender.send(DeviceEvent::ReconnectStarted {
387 device: DeviceId::new(&self.identifier),
388 attempt,
389 });
390 }
391
392 info!("Reconnection attempt {} for {}", attempt, self.identifier);
393
394 let delay = self.options.delay_for_attempt(attempt - 1);
396 sleep(delay).await;
397
398 if self.is_cancelled() {
400 *self.state.write().await = ConnectionState::Disconnected;
401 info!("Reconnection cancelled for {}", self.identifier);
402 return Err(Error::Cancelled);
403 }
404
405 match Device::connect(&self.identifier).await {
407 Ok(new_device) => {
408 *self.device.write().await = Some(Arc::new(new_device));
409 *self.state.write().await = ConnectionState::Connected;
410
411 if let Some(sender) = &self.event_sender {
413 let _ = sender.send(DeviceEvent::ReconnectSucceeded {
414 device: DeviceId::new(&self.identifier),
415 attempts: attempt,
416 });
417 }
418
419 info!("Reconnected successfully after {} attempts", attempt);
420 return Ok(());
421 }
422 Err(e) => {
423 warn!("Reconnection attempt {} failed: {}", attempt, e);
424 }
425 }
426 }
427 }
428
429 pub async fn disconnect(&self) -> Result<()> {
431 let mut guard = self.device.write().await;
432 if let Some(device) = guard.take() {
433 device.disconnect().await?;
434 }
435 *self.state.write().await = ConnectionState::Disconnected;
436 Ok(())
437 }
438
439 pub async fn attempt_count(&self) -> u32 {
441 *self.attempt_count.read().await
442 }
443
444 pub async fn name(&self) -> Option<String> {
446 let guard = self.device.read().await;
447 guard.as_ref().and_then(|d| d.name().map(|s| s.to_string()))
448 }
449
450 pub async fn address(&self) -> String {
452 let guard = self.device.read().await;
453 guard
454 .as_ref()
455 .map(|d| d.address().to_string())
456 .unwrap_or_else(|| self.identifier.clone())
457 }
458
459 pub async fn device_type(&self) -> Option<DeviceType> {
461 let guard = self.device.read().await;
462 guard.as_ref().and_then(|d| d.device_type())
463 }
464}
465
466#[async_trait]
468impl AranetDevice for ReconnectingDevice {
469 async fn is_connected(&self) -> bool {
470 ReconnectingDevice::is_connected(self).await
471 }
472
473 async fn connect(&self) -> Result<()> {
474 if self.is_connected().await {
476 return Ok(());
477 }
478 self.reconnect().await
480 }
481
482 async fn disconnect(&self) -> Result<()> {
483 ReconnectingDevice::disconnect(self).await
484 }
485
486 fn name(&self) -> Option<&str> {
487 self.cached_name.get().map(|s| s.as_str())
488 }
489
490 fn address(&self) -> &str {
491 &self.identifier
492 }
493
494 fn device_type(&self) -> Option<DeviceType> {
495 self.cached_device_type.get().copied()
496 }
497
498 async fn read_current(&self) -> Result<CurrentReading> {
499 self.run_with_reconnect(|d| Box::pin(d.read_current()))
500 .await
501 }
502
503 async fn read_device_info(&self) -> Result<DeviceInfo> {
504 self.run_with_reconnect(|d| Box::pin(d.read_device_info()))
505 .await
506 }
507
508 async fn read_rssi(&self) -> Result<i16> {
509 self.run_with_reconnect(|d| Box::pin(d.read_rssi())).await
510 }
511
512 async fn read_battery(&self) -> Result<u8> {
513 self.run_with_reconnect(|d| Box::pin(d.read_battery()))
514 .await
515 }
516
517 async fn get_history_info(&self) -> Result<HistoryInfo> {
518 self.run_with_reconnect(|d| Box::pin(d.get_history_info()))
519 .await
520 }
521
522 async fn download_history(&self) -> Result<Vec<HistoryRecord>> {
523 self.run_with_reconnect(|d| Box::pin(d.download_history()))
524 .await
525 }
526
527 async fn download_history_with_options(
528 &self,
529 options: HistoryOptions,
530 ) -> Result<Vec<HistoryRecord>> {
531 let opts = options.clone();
532 self.run_with_reconnect(move |d| {
533 let opts = opts.clone();
534 Box::pin(async move { d.download_history_with_options(opts).await })
535 })
536 .await
537 }
538
539 async fn get_interval(&self) -> Result<MeasurementInterval> {
540 self.run_with_reconnect(|d| Box::pin(d.get_interval()))
541 .await
542 }
543
544 async fn set_interval(&self, interval: MeasurementInterval) -> Result<()> {
545 self.run_with_reconnect(move |d| Box::pin(d.set_interval(interval)))
546 .await
547 }
548
549 async fn get_calibration(&self) -> Result<CalibrationData> {
550 self.run_with_reconnect(|d| Box::pin(d.get_calibration()))
551 .await
552 }
553}
554
#[cfg(test)]
mod tests {
    use super::*;

    /// Exponential-backoff options starting at 1 s and doubling, capped at `cap`.
    fn doubling_with_cap(cap: Duration) -> ReconnectOptions {
        ReconnectOptions {
            initial_delay: Duration::from_secs(1),
            max_delay: cap,
            backoff_multiplier: 2.0,
            use_exponential_backoff: true,
            ..Default::default()
        }
    }

    #[test]
    fn test_reconnect_options_default() {
        let defaults = ReconnectOptions::default();
        assert_eq!(defaults.max_attempts, Some(5));
        assert!(defaults.use_exponential_backoff);
    }

    #[test]
    fn test_reconnect_options_unlimited() {
        let unlimited = ReconnectOptions::unlimited();
        assert!(unlimited.max_attempts.is_none());
    }

    #[test]
    fn test_delay_calculation() {
        let opts = doubling_with_cap(Duration::from_secs(60));
        let expected_secs = [1, 2, 4, 8];
        for (attempt, secs) in (0u32..).zip(expected_secs) {
            assert_eq!(opts.delay_for_attempt(attempt), Duration::from_secs(secs));
        }
    }

    #[test]
    fn test_delay_capped_at_max() {
        let cap = Duration::from_secs(10);
        assert_eq!(doubling_with_cap(cap).delay_for_attempt(10), cap);
    }

    #[test]
    fn test_fixed_delay() {
        let opts = ReconnectOptions::fixed_delay(Duration::from_secs(5));
        for attempt in [0, 5] {
            assert_eq!(opts.delay_for_attempt(attempt), Duration::from_secs(5));
        }
    }
}