Skip to main content

hdds_persistence/
subscriber.rs

1// SPDX-License-Identifier: Apache-2.0 OR MIT
2// Copyright (c) 2025-2026 naskel.com
3
4//! Durability subscriber
5//!
6//! Subscribes to TRANSIENT/PERSISTENT topics and stores samples.
7//!
8//! # Operation
9//!
10//! 1. Discover writers matching topic filter
11//! 2. Create DataReaders for each matching topic
12//! 3. Poll readers and store received samples
13//! 4. Apply retention policy periodically
14
15use crate::config::Config;
16use crate::dds_interface::{
17    DataReader, DdsInterface, DiscoveredReader, DiscoveredWriter, DiscoveryCallback,
18};
19use crate::store::{PersistenceStore, RetentionPolicy, Sample};
20use anyhow::Result;
21use std::collections::{HashMap, HashSet};
22use std::sync::Arc;
23use std::time::Duration;
24use tokio::sync::mpsc;
25use tokio::sync::RwLock;
26use tokio::time::interval;
27
28/// Durability subscriber
29///
30/// Subscribes to topics matching the filter and persists all received samples.
31pub struct DurabilitySubscriber<S: PersistenceStore, D: DdsInterface> {
32    config: Config,
33    store: Arc<RwLock<S>>,
34    dds: Arc<D>,
35    /// Active readers by topic
36    readers: HashMap<String, Box<dyn DataReader>>,
37    /// Known writers (to detect new ones)
38    known_writers: HashSet<[u8; 16]>,
39    /// Topics we've created readers for
40    subscribed_topics: HashSet<String>,
41    /// Per-topic retention hint derived from writer durability settings.
42    retention_hints: HashMap<String, usize>,
43    /// Statistics
44    stats: SubscriberStats,
45}
46
/// Running counters kept by a subscriber.
///
/// Every counter starts at zero (`Default`); the subscribers in this module
/// only ever increment them.
#[derive(Clone, Debug, Default)]
pub struct SubscriberStats {
    /// Samples taken from a reader or channel, counted before filtering/storing.
    pub samples_received: u64,
    /// Samples successfully written to the persistence store.
    pub samples_stored: u64,
    /// Samples whose write to the persistence store failed.
    pub storage_errors: u64,
    /// Writers accepted for persistence (repeat announcements are not counted).
    pub writers_discovered: u64,
    /// Topics for which a reader has been created.
    pub topics_subscribed: u64,
}
61
62#[allow(dead_code)]
63enum DiscoveryEvent {
64    Writer(DiscoveredWriter),
65    Reader(DiscoveredReader),
66}
67
68struct DiscoveryBridge {
69    tx: mpsc::Sender<DiscoveryEvent>,
70}
71
72impl DiscoveryCallback for DiscoveryBridge {
73    fn on_reader_discovered(&self, reader: DiscoveredReader) {
74        if self.tx.try_send(DiscoveryEvent::Reader(reader)).is_err() {
75            tracing::debug!("Dropping reader discovery event (channel full)");
76        }
77    }
78
79    fn on_reader_removed(&self, _guid: [u8; 16]) {}
80
81    fn on_writer_discovered(&self, writer: DiscoveredWriter) {
82        if self.tx.try_send(DiscoveryEvent::Writer(writer)).is_err() {
83            tracing::debug!("Dropping writer discovery event (channel full)");
84        }
85    }
86
87    fn on_writer_removed(&self, _guid: [u8; 16]) {}
88}
89
90impl<S: PersistenceStore + Send + Sync, D: DdsInterface> DurabilitySubscriber<S, D> {
91    /// Create a new durability subscriber
92    pub fn new(config: Config, store: Arc<RwLock<S>>, dds: Arc<D>) -> Self {
93        Self {
94            config,
95            store,
96            dds,
97            readers: HashMap::new(),
98            known_writers: HashSet::new(),
99            subscribed_topics: HashSet::new(),
100            retention_hints: HashMap::new(),
101            stats: SubscriberStats::default(),
102        }
103    }
104
105    /// Get subscriber statistics
106    pub fn stats(&self) -> &SubscriberStats {
107        &self.stats
108    }
109
110    /// Run the subscriber
111    pub async fn run(mut self) -> Result<()> {
112        tracing::info!(
113            "DurabilitySubscriber started for topics: {}",
114            self.config.topic_filter
115        );
116
117        let (event_tx, mut event_rx) = mpsc::channel(128);
118        let bridge = Arc::new(DiscoveryBridge { tx: event_tx });
119        self.dds.register_discovery_callback(bridge)?;
120
121        // Initial snapshot
122        self.discover_and_subscribe().await?;
123
124        // Polling intervals
125        let mut sample_interval = interval(Duration::from_millis(100));
126        let mut retention_interval = interval(Duration::from_secs(60));
127
128        loop {
129            tokio::select! {
130                // Poll for samples (high frequency)
131                _ = sample_interval.tick() => {
132                    self.poll_samples().await?;
133                }
134                Some(event) = event_rx.recv() => {
135                    if let DiscoveryEvent::Writer(writer) = event {
136                        self.handle_writer_discovered(writer)?;
137                    }
138                }
139
140                // Apply retention policy (low frequency)
141                _ = retention_interval.tick() => {
142                    self.apply_retention().await?;
143                }
144            }
145        }
146    }
147
148    /// Discover writers and create readers for matching topics
149    async fn discover_and_subscribe(&mut self) -> Result<()> {
150        let writers = self.dds.discovered_writers(&self.config.topic_filter)?;
151
152        for writer in writers {
153            self.handle_writer_discovered(writer)?;
154        }
155
156        Ok(())
157    }
158
159    fn handle_writer_discovered(&mut self, writer: DiscoveredWriter) -> Result<()> {
160        // Skip if we already know this writer
161        if self.known_writers.contains(&writer.guid) {
162            return Ok(());
163        }
164
165        // Only subscribe to durable writers
166        if !writer.durability.is_durable() && !self.config.subscribe_volatile {
167            tracing::debug!("Skipping volatile writer for topic {}", writer.topic);
168            return Ok(());
169        }
170
171        self.known_writers.insert(writer.guid);
172        self.stats.writers_discovered += 1;
173
174        tracing::info!(
175            "Discovered writer for topic: {} (type: {})",
176            writer.topic,
177            writer.type_name
178        );
179
180        if let Some(hint) = writer.retention_hint {
181            self.update_retention_hint(&writer.topic, hint);
182        }
183
184        // Create reader if we haven't already
185        if !self.subscribed_topics.contains(&writer.topic) {
186            self.create_reader_for_topic(&writer)?;
187        }
188
189        Ok(())
190    }
191
192    /// Create a DataReader for a topic
193    fn create_reader_for_topic(&mut self, writer: &DiscoveredWriter) -> Result<()> {
194        let reader = self
195            .dds
196            .create_reader(&writer.topic, &writer.type_name, writer.durability)?;
197
198        tracing::info!(
199            "Created reader for topic: {} (type: {})",
200            writer.topic,
201            writer.type_name
202        );
203
204        self.readers.insert(writer.topic.clone(), reader);
205        self.subscribed_topics.insert(writer.topic.clone());
206        self.stats.topics_subscribed += 1;
207
208        Ok(())
209    }
210
211    /// Poll all readers for new samples
212    async fn poll_samples(&mut self) -> Result<()> {
213        let store = self.store.read().await;
214
215        for (topic, reader) in &self.readers {
216            // Take samples (removes from reader cache)
217            match reader.take() {
218                Ok(samples) => {
219                    for received in samples {
220                        self.stats.samples_received += 1;
221
222                        let sample = Sample {
223                            topic: received.topic,
224                            type_name: received.type_name,
225                            payload: received.payload,
226                            timestamp_ns: received.timestamp_ns,
227                            sequence: received.sequence,
228                            source_guid: received.writer_guid,
229                        };
230
231                        match store.save(&sample) {
232                            Ok(()) => {
233                                self.stats.samples_stored += 1;
234                                tracing::trace!(
235                                    "Stored sample: topic={}, seq={}",
236                                    sample.topic,
237                                    sample.sequence
238                                );
239                            }
240                            Err(e) => {
241                                self.stats.storage_errors += 1;
242                                tracing::error!("Failed to store sample for {}: {}", topic, e);
243                            }
244                        }
245                    }
246                }
247                Err(e) => {
248                    tracing::warn!("Failed to take samples from {}: {}", topic, e);
249                }
250            }
251        }
252
253        Ok(())
254    }
255
256    /// Apply retention policy to all topics
257    async fn apply_retention(&self) -> Result<()> {
258        let store = self.store.read().await;
259
260        for topic in &self.subscribed_topics {
261            let policy = self.retention_policy_for_topic(topic);
262            if policy.is_noop() {
263                continue;
264            }
265            if let Err(e) = store.apply_retention_policy(topic, &policy) {
266                tracing::warn!("Failed to apply retention to {}: {}", topic, e);
267            }
268        }
269
270        tracing::debug!(
271            "Applied retention policy ({} samples) to {} topics",
272            self.config.retention_count,
273            self.subscribed_topics.len()
274        );
275
276        Ok(())
277    }
278
279    fn retention_policy_for_topic(&self, topic: &str) -> RetentionPolicy {
280        let mut policy = RetentionPolicy {
281            keep_count: self.config.retention_count,
282            max_age_ns: if self.config.retention_time_secs > 0 {
283                Some(
284                    self.config
285                        .retention_time_secs
286                        .saturating_mul(1_000_000_000),
287                )
288            } else {
289                None
290            },
291            max_bytes: if self.config.retention_size_bytes > 0 {
292                Some(self.config.retention_size_bytes)
293            } else {
294                None
295            },
296        };
297
298        if let Some(keep_hint) = self.retention_hints.get(topic).copied() {
299            if keep_hint > 0 {
300                if policy.keep_count == 0 {
301                    policy.keep_count = keep_hint;
302                } else {
303                    policy.keep_count = policy.keep_count.min(keep_hint);
304                }
305            }
306        }
307
308        policy
309    }
310
311    fn update_retention_hint(&mut self, topic: &str, hint: RetentionPolicy) {
312        if hint.keep_count == 0 {
313            return;
314        }
315
316        let entry = self
317            .retention_hints
318            .entry(topic.to_string())
319            .or_insert(hint.keep_count);
320        *entry = (*entry).max(hint.keep_count);
321    }
322}
323
324// ============================================================================
325// Standalone mode (without DDS - for testing and CLI)
326// ============================================================================
327
328/// Standalone durability subscriber that accepts samples via channel
329pub struct StandaloneSubscriber<S: PersistenceStore> {
330    config: Config,
331    store: Arc<RwLock<S>>,
332    rx: tokio::sync::mpsc::Receiver<Sample>,
333    stats: SubscriberStats,
334}
335
336impl<S: PersistenceStore + Send + Sync> StandaloneSubscriber<S> {
337    /// Create a new standalone subscriber
338    ///
339    /// Returns the subscriber and a sender for pushing samples.
340    pub fn new(config: Config, store: Arc<RwLock<S>>) -> (Self, tokio::sync::mpsc::Sender<Sample>) {
341        let (tx, rx) = tokio::sync::mpsc::channel(1000);
342
343        let subscriber = Self {
344            config,
345            store,
346            rx,
347            stats: SubscriberStats::default(),
348        };
349
350        (subscriber, tx)
351    }
352
353    /// Get statistics
354    pub fn stats(&self) -> &SubscriberStats {
355        &self.stats
356    }
357
358    /// Run the standalone subscriber
359    pub async fn run(mut self) -> Result<()> {
360        tracing::info!(
361            "StandaloneSubscriber started for topics: {}",
362            self.config.topic_filter
363        );
364
365        let mut retention_interval = interval(Duration::from_secs(60));
366
367        loop {
368            tokio::select! {
369                // Receive sample from channel
370                Some(sample) = self.rx.recv() => {
371                    self.stats.samples_received += 1;
372
373                    // Check if topic matches filter
374                    if !topic_matches(&self.config.topic_filter, &sample.topic) {
375                        continue;
376                    }
377
378                    let store = self.store.read().await;
379                    match store.save(&sample) {
380                        Ok(()) => {
381                            self.stats.samples_stored += 1;
382                            tracing::trace!(
383                                "Stored sample: topic={}, seq={}",
384                                sample.topic,
385                                sample.sequence
386                            );
387                        }
388                        Err(e) => {
389                            self.stats.storage_errors += 1;
390                            tracing::error!("Failed to store sample: {}", e);
391                        }
392                    }
393                }
394
395                // Apply retention
396                _ = retention_interval.tick() => {
397                    self.apply_retention().await?;
398                }
399
400                else => break,
401            }
402        }
403
404        Ok(())
405    }
406
407    /// Apply retention policy
408    async fn apply_retention(&self) -> Result<()> {
409        let policy = RetentionPolicy {
410            keep_count: self.config.retention_count,
411            max_age_ns: if self.config.retention_time_secs > 0 {
412                Some(
413                    self.config
414                        .retention_time_secs
415                        .saturating_mul(1_000_000_000),
416                )
417            } else {
418                None
419            },
420            max_bytes: if self.config.retention_size_bytes > 0 {
421                Some(self.config.retention_size_bytes)
422            } else {
423                None
424            },
425        };
426        if policy.is_noop() {
427            return Ok(());
428        }
429
430        let store = self.store.read().await;
431
432        // Get unique topics from store
433        let all_samples = store.query_range("*", 0, u64::MAX)?;
434        let topics: HashSet<_> = all_samples.iter().map(|s| s.topic.clone()).collect();
435
436        for topic in &topics {
437            if let Err(e) = store.apply_retention_policy(topic, &policy) {
438                tracing::warn!("Failed to apply retention to {}: {}", topic, e);
439            }
440        }
441
442        Ok(())
443    }
444}
445
/// Check if a topic matches a pattern.
///
/// Supported patterns:
/// - `"*"` matches every topic;
/// - `"prefix/*"` matches topics that start with `prefix/`. For example,
///   `"State/*"` matches `"State/Temperature"` and `"State/A/B"`, but not
///   `"State"`, `"States"` or `"StateMachine/Mode"`;
/// - any other pattern must equal the topic exactly.
fn topic_matches(pattern: &str, topic: &str) -> bool {
    if pattern == "*" {
        return true;
    }
    if let Some(prefix) = pattern.strip_suffix("/*") {
        // Require the `/` separator right after the prefix: a bare
        // `starts_with(prefix)` would also accept sibling topics such as
        // `StateMachine/...` for the pattern `State/*`.
        return matches!(topic.strip_prefix(prefix), Some(rest) if rest.starts_with('/'));
    }
    pattern == topic
}
456
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dds_interface::MockDdsInterface;
    use crate::sqlite::SqliteStore;

    /// Fresh in-memory SQLite store, wrapped the way the subscribers expect.
    fn shared_memory_store() -> Arc<RwLock<SqliteStore>> {
        let backend = SqliteStore::new_in_memory().unwrap();
        Arc::new(RwLock::new(backend))
    }

    #[test]
    fn test_topic_matches() {
        let cases = [
            ("*", "any/topic", true),
            ("State/*", "State/Temperature", true),
            ("State/*", "Command/Set", false),
            ("exact", "exact", true),
            ("exact", "other", false),
        ];
        for &(pattern, topic, expected) in cases.iter() {
            assert_eq!(
                topic_matches(pattern, topic),
                expected,
                "pattern={} topic={}",
                pattern,
                topic
            );
        }
    }

    #[tokio::test]
    async fn test_standalone_subscriber() {
        let config = Config::builder()
            .topic_filter("State/*")
            .retention_count(100)
            .build();

        let shared = shared_memory_store();
        let (subscriber, sender) = StandaloneSubscriber::new(config, Arc::clone(&shared));

        // Queue one matching sample before the subscriber starts.
        let outgoing = Sample {
            topic: "State/Temperature".to_string(),
            type_name: "Temperature".to_string(),
            payload: vec![1, 2, 3],
            timestamp_ns: 1000,
            sequence: 1,
            source_guid: [0xAA; 16],
        };
        sender.send(outgoing).await.unwrap();

        // Run the subscriber in the background, bounded by a timeout.
        let task = tokio::spawn(async move {
            tokio::time::timeout(Duration::from_millis(200), subscriber.run()).await
        });

        // Give it time to drain the channel, then close the channel.
        tokio::time::sleep(Duration::from_millis(100)).await;
        drop(sender);

        let _ = task.await;

        // The sample must have been persisted.
        let guard = shared.read().await;
        let stored = guard.load("State/Temperature").unwrap();
        assert_eq!(stored.len(), 1);
        assert_eq!(stored[0].sequence, 1);
    }

    #[test]
    fn test_durability_subscriber_creation() {
        let config = Config::builder()
            .topic_filter("State/*")
            .retention_count(100)
            .build();

        let shared = shared_memory_store();
        let mock_dds = Arc::new(MockDdsInterface::new());

        let subscriber = DurabilitySubscriber::new(config, shared, mock_dds);
        assert_eq!(subscriber.stats.samples_received, 0);
    }

    #[tokio::test]
    async fn test_discover_and_subscribe() {
        use crate::dds_interface::{DiscoveredWriter, DurabilityKind};

        let config = Config::builder()
            .topic_filter("State/*")
            .subscribe_volatile(true)
            .build();

        let shared = shared_memory_store();
        let mock_dds = Arc::new(MockDdsInterface::new());

        // Register one writer with the mock before discovery runs.
        mock_dds.add_writer(DiscoveredWriter {
            guid: [0x01; 16],
            topic: "State/Temperature".to_string(),
            type_name: "Temperature".to_string(),
            durability: DurabilityKind::TransientLocal,
            retention_hint: None,
        });

        let mut subscriber = DurabilitySubscriber::new(config, shared, mock_dds);

        subscriber.discover_and_subscribe().await.unwrap();

        assert_eq!(subscriber.stats.writers_discovered, 1);
        assert_eq!(subscriber.stats.topics_subscribed, 1);
        assert!(subscriber.subscribed_topics.contains("State/Temperature"));
    }
}