Skip to main content

laser_dac/
session.rs

1//! Reconnecting session wrapper for automatic reconnection.
2
3use std::sync::atomic::{AtomicBool, Ordering};
4use std::sync::{Arc, Mutex};
5use std::time::Duration;
6
7use crate::backend::{Error, Result};
8use crate::discovery::DacDiscovery;
9use crate::stream::{Dac, StreamControl};
10use crate::types::{ChunkRequest, DacInfo, LaserPoint, RunExit, StreamConfig};
11
12type DisconnectCallback = Box<dyn FnMut(&Error) + Send + 'static>;
13type ReconnectCallback = Box<dyn FnMut(&DacInfo) + Send + 'static>;
14type DiscoveryFactory = Box<dyn Fn() -> DacDiscovery + Send + 'static>;
15
16// =============================================================================
17// Session Control
18// =============================================================================
19
20/// Control handle for a [`ReconnectingSession`].
21///
22/// This mirrors `StreamControl`, but survives reconnections by attaching
23/// to each new stream as it is created.
24#[derive(Clone)]
25pub struct SessionControl {
26    inner: Arc<SessionControlInner>,
27}
28
29struct SessionControlInner {
30    armed: AtomicBool,
31    stop_requested: AtomicBool,
32    current: Mutex<Option<StreamControl>>,
33}
34
35impl SessionControl {
36    fn new() -> Self {
37        Self {
38            inner: Arc::new(SessionControlInner {
39                armed: AtomicBool::new(false),
40                stop_requested: AtomicBool::new(false),
41                current: Mutex::new(None),
42            }),
43        }
44    }
45
46    fn attach(&self, control: StreamControl) {
47        *self.inner.current.lock().unwrap() = Some(control.clone());
48
49        if self.inner.stop_requested.load(Ordering::SeqCst) {
50            let _ = control.stop();
51            return;
52        }
53
54        if self.inner.armed.load(Ordering::SeqCst) {
55            let _ = control.arm();
56        } else {
57            let _ = control.disarm();
58        }
59    }
60
61    fn detach(&self) {
62        *self.inner.current.lock().unwrap() = None;
63    }
64
65    /// Arm the output (allow laser to fire).
66    pub fn arm(&self) -> Result<()> {
67        self.inner.armed.store(true, Ordering::SeqCst);
68        if let Some(control) = self.inner.current.lock().unwrap().as_ref() {
69            let _ = control.arm();
70        }
71        Ok(())
72    }
73
74    /// Disarm the output (force laser off).
75    pub fn disarm(&self) -> Result<()> {
76        self.inner.armed.store(false, Ordering::SeqCst);
77        if let Some(control) = self.inner.current.lock().unwrap().as_ref() {
78            let _ = control.disarm();
79        }
80        Ok(())
81    }
82
83    /// Check if the output is armed.
84    pub fn is_armed(&self) -> bool {
85        self.inner.armed.load(Ordering::SeqCst)
86    }
87
88    /// Request the session to stop.
89    pub fn stop(&self) -> Result<()> {
90        self.inner.stop_requested.store(true, Ordering::SeqCst);
91        if let Some(control) = self.inner.current.lock().unwrap().as_ref() {
92            let _ = control.stop();
93        }
94        Ok(())
95    }
96
97    /// Check if a stop has been requested.
98    pub fn is_stop_requested(&self) -> bool {
99        self.inner.stop_requested.load(Ordering::SeqCst)
100    }
101}
102
103// =============================================================================
104// Reconnecting Session
105// =============================================================================
106
107/// A reconnecting wrapper around the streaming API.
108///
109/// This helper reconnects to a device by ID and restarts streaming
110/// automatically when a disconnection occurs.
111///
112/// By default this uses `open_device()` internally. To use custom DAC
113/// backends, call [`with_discovery`](Self::with_discovery) with a factory
114/// function that creates a configured [`DacDiscovery`].
115///
116/// # Example
117///
118/// ```no_run
119/// use laser_dac::{ReconnectingSession, StreamConfig};
120/// use std::time::Duration;
121///
122/// let mut session = ReconnectingSession::new("my-device", StreamConfig::new(30_000))
123///     .with_max_retries(5)
124///     .with_backoff(Duration::from_secs(1))
125///     .on_disconnect(|err| eprintln!("Lost connection: {}", err))
126///     .on_reconnect(|info| println!("Reconnected to {}", info.name));
127///
128/// session.control().arm()?;
129///
130/// session.run(
131///     |req| Some(vec![laser_dac::LaserPoint::blanked(0.0, 0.0); req.n_points]),
132///     |err| eprintln!("Stream error: {}", err),
133/// )?;
134/// # Ok::<(), laser_dac::Error>(())
135/// ```
136pub struct ReconnectingSession {
137    device_id: String,
138    config: StreamConfig,
139    max_retries: Option<u32>,
140    backoff: Duration,
141    on_disconnect: Arc<Mutex<Option<DisconnectCallback>>>,
142    on_reconnect: Option<ReconnectCallback>,
143    control: SessionControl,
144    discovery_factory: Option<DiscoveryFactory>,
145}
146
147impl ReconnectingSession {
148    /// Create a new reconnecting session for a device ID.
149    pub fn new(device_id: impl Into<String>, config: StreamConfig) -> Self {
150        Self {
151            device_id: device_id.into(),
152            config,
153            max_retries: None,
154            backoff: Duration::from_secs(1),
155            on_disconnect: Arc::new(Mutex::new(None)),
156            on_reconnect: None,
157            control: SessionControl::new(),
158            discovery_factory: None,
159        }
160    }
161
162    /// Set the maximum number of reconnect attempts.
163    ///
164    /// `None` (default) retries forever. `Some(0)` disables retries.
165    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
166        self.max_retries = Some(max_retries);
167        self
168    }
169
170    /// Set a fixed backoff duration between reconnect attempts.
171    pub fn with_backoff(mut self, backoff: Duration) -> Self {
172        self.backoff = backoff;
173        self
174    }
175
176    /// Register a callback invoked when a disconnect is detected.
177    pub fn on_disconnect<F>(self, f: F) -> Self
178    where
179        F: FnMut(&Error) + Send + 'static,
180    {
181        *self.on_disconnect.lock().unwrap() = Some(Box::new(f));
182        self
183    }
184
185    /// Register a callback invoked after a successful reconnect.
186    pub fn on_reconnect<F>(mut self, f: F) -> Self
187    where
188        F: FnMut(&DacInfo) + Send + 'static,
189    {
190        self.on_reconnect = Some(Box::new(f));
191        self
192    }
193
194    /// Use a custom discovery factory for opening devices.
195    ///
196    /// This allows using custom DAC backends by providing a factory function
197    /// that creates a [`DacDiscovery`] with external discoverers registered.
198    ///
199    /// # Example
200    ///
201    /// ```no_run
202    /// use laser_dac::{DacDiscovery, EnabledDacTypes, ReconnectingSession, StreamConfig};
203    ///
204    /// let session = ReconnectingSession::new("custom:my-device", StreamConfig::new(30_000))
205    ///     .with_discovery(|| {
206    ///         let mut discovery = DacDiscovery::new(EnabledDacTypes::all());
207    ///         // discovery.register(my_custom_discoverer);
208    ///         discovery
209    ///     });
210    /// ```
211    pub fn with_discovery<F>(mut self, factory: F) -> Self
212    where
213        F: Fn() -> DacDiscovery + Send + 'static,
214    {
215        self.discovery_factory = Some(Box::new(factory));
216        self
217    }
218
219    /// Returns a control handle for arm/disarm/stop.
220    pub fn control(&self) -> SessionControl {
221        self.control.clone()
222    }
223
224    /// Run the stream, automatically reconnecting on disconnection.
225    pub fn run<F, E>(&mut self, producer: F, on_error: E) -> Result<RunExit>
226    where
227        F: FnMut(ChunkRequest) -> Option<Vec<LaserPoint>> + Send + 'static,
228        E: FnMut(Error) + Send + 'static,
229    {
230        let producer = Arc::new(Mutex::new(producer));
231        let on_error = Arc::new(Mutex::new(on_error));
232        let on_disconnect = Arc::clone(&self.on_disconnect);
233        let mut connected_once = false;
234        let mut retries = 0u32;
235
236        loop {
237            if self.control.is_stop_requested() {
238                return Ok(RunExit::Stopped);
239            }
240
241            if let Some(max) = self.max_retries {
242                if retries >= max {
243                    return Ok(RunExit::Disconnected);
244                }
245            }
246
247            let device = match self.open_device() {
248                Ok(device) => device,
249                Err(err) => {
250                    if !Self::is_retriable_connect_error(&err) {
251                        return Err(err);
252                    }
253                    {
254                        let mut handler = on_error.lock().unwrap();
255                        handler(err);
256                    }
257                    retries = retries.saturating_add(1);
258                    if let Some(max) = self.max_retries {
259                        if retries >= max {
260                            return Ok(RunExit::Disconnected);
261                        }
262                    }
263                    if self.sleep_with_stop(self.backoff) {
264                        return Ok(RunExit::Stopped);
265                    }
266                    continue;
267                }
268            };
269
270            let (stream, info) = match device.start_stream(self.config.clone()) {
271                Ok(result) => result,
272                Err(err) => {
273                    if !Self::is_retriable_connect_error(&err) {
274                        return Err(err);
275                    }
276                    {
277                        let mut handler = on_error.lock().unwrap();
278                        handler(err);
279                    }
280                    retries = retries.saturating_add(1);
281                    if let Some(max) = self.max_retries {
282                        if retries >= max {
283                            return Ok(RunExit::Disconnected);
284                        }
285                    }
286                    if self.sleep_with_stop(self.backoff) {
287                        return Ok(RunExit::Stopped);
288                    }
289                    continue;
290                }
291            };
292
293            if connected_once {
294                if let Some(cb) = self.on_reconnect.as_mut() {
295                    cb(&info);
296                }
297            }
298            connected_once = true;
299            retries = 0;
300
301            self.control.attach(stream.control());
302
303            let producer_handle = Arc::clone(&producer);
304            let on_error_handle = Arc::clone(&on_error);
305            let on_disconnect_handle = Arc::clone(&on_disconnect);
306            let exit = match stream.run(
307                move |req| {
308                    let mut handler = producer_handle.lock().unwrap();
309                    handler(req)
310                },
311                move |err| {
312                    if err.is_disconnected() {
313                        if let Some(cb) = on_disconnect_handle.lock().unwrap().as_mut() {
314                            cb(&err);
315                        }
316                    }
317                    let mut handler = on_error_handle.lock().unwrap();
318                    handler(err)
319                },
320            ) {
321                Ok(exit) => exit,
322                Err(err) => {
323                    self.control.detach();
324                    return Err(err);
325                }
326            };
327
328            self.control.detach();
329
330            match exit {
331                RunExit::Disconnected => {
332                    if let Some(max) = self.max_retries {
333                        if retries >= max {
334                            return Ok(RunExit::Disconnected);
335                        }
336                    }
337                    if self.sleep_with_stop(self.backoff) {
338                        return Ok(RunExit::Stopped);
339                    }
340                    continue;
341                }
342                other => return Ok(other),
343            }
344        }
345    }
346
347    fn open_device(&mut self) -> Result<Dac> {
348        if let Some(factory) = &self.discovery_factory {
349            let mut discovery = factory();
350            let discovered = discovery.scan();
351
352            let device = discovered
353                .into_iter()
354                .find(|d| d.info().stable_id() == self.device_id)
355                .ok_or_else(|| {
356                    Error::disconnected(format!("device not found: {}", self.device_id))
357                })?;
358
359            let info = device.info();
360            let name = info.name();
361            let dac_type = device.dac_type();
362            let stream_backend = discovery.connect(device)?;
363
364            let device_info = crate::types::DacInfo {
365                id: self.device_id.clone(),
366                name,
367                kind: dac_type,
368                caps: stream_backend.caps().clone(),
369            };
370
371            Ok(Dac::new(device_info, stream_backend))
372        } else {
373            crate::open_device(&self.device_id)
374        }
375    }
376
377    fn is_retriable_connect_error(err: &Error) -> bool {
378        !matches!(err, Error::InvalidConfig(_) | Error::Stopped)
379    }
380
381    fn sleep_with_stop(&self, duration: Duration) -> bool {
382        const SLICE: Duration = Duration::from_millis(50);
383        let mut remaining = duration;
384        while remaining > Duration::ZERO {
385            if self.control.is_stop_requested() {
386                return true;
387            }
388            let slice = remaining.min(SLICE);
389            std::thread::sleep(slice);
390            remaining = remaining.saturating_sub(slice);
391        }
392        self.control.is_stop_requested()
393    }
394}