1use crate::config::Config;
16use crate::dds_interface::{
17 DataReader, DdsInterface, DiscoveredReader, DiscoveredWriter, DiscoveryCallback,
18};
19use crate::store::{PersistenceStore, RetentionPolicy, Sample};
20use anyhow::Result;
21use std::collections::{HashMap, HashSet};
22use std::sync::Arc;
23use std::time::Duration;
24use tokio::sync::mpsc;
25use tokio::sync::RwLock;
26use tokio::time::interval;
27
28pub struct DurabilitySubscriber<S: PersistenceStore, D: DdsInterface> {
32 config: Config,
33 store: Arc<RwLock<S>>,
34 dds: Arc<D>,
35 readers: HashMap<String, Box<dyn DataReader>>,
37 known_writers: HashSet<[u8; 16]>,
39 subscribed_topics: HashSet<String>,
41 retention_hints: HashMap<String, usize>,
43 stats: SubscriberStats,
45}
46
/// Counters describing what a subscriber has processed since it was created.
///
/// All counters start at zero (`Default`) and only ever increase.
#[derive(Debug, Default, Clone)]
pub struct SubscriberStats {
    /// Samples handed to the subscriber, counted before any topic filtering.
    pub samples_received: u64,
    /// Samples successfully written to the store.
    pub samples_stored: u64,
    /// Samples the store failed to save.
    pub storage_errors: u64,
    /// Distinct DDS writers accepted for persistence.
    pub writers_discovered: u64,
    /// Topics for which a reader was created.
    pub topics_subscribed: u64,
}
61
62#[allow(dead_code)]
63enum DiscoveryEvent {
64 Writer(DiscoveredWriter),
65 Reader(DiscoveredReader),
66}
67
68struct DiscoveryBridge {
69 tx: mpsc::Sender<DiscoveryEvent>,
70}
71
72impl DiscoveryCallback for DiscoveryBridge {
73 fn on_reader_discovered(&self, reader: DiscoveredReader) {
74 if self.tx.try_send(DiscoveryEvent::Reader(reader)).is_err() {
75 tracing::debug!("Dropping reader discovery event (channel full)");
76 }
77 }
78
79 fn on_reader_removed(&self, _guid: [u8; 16]) {}
80
81 fn on_writer_discovered(&self, writer: DiscoveredWriter) {
82 if self.tx.try_send(DiscoveryEvent::Writer(writer)).is_err() {
83 tracing::debug!("Dropping writer discovery event (channel full)");
84 }
85 }
86
87 fn on_writer_removed(&self, _guid: [u8; 16]) {}
88}
89
90impl<S: PersistenceStore + Send + Sync, D: DdsInterface> DurabilitySubscriber<S, D> {
91 pub fn new(config: Config, store: Arc<RwLock<S>>, dds: Arc<D>) -> Self {
93 Self {
94 config,
95 store,
96 dds,
97 readers: HashMap::new(),
98 known_writers: HashSet::new(),
99 subscribed_topics: HashSet::new(),
100 retention_hints: HashMap::new(),
101 stats: SubscriberStats::default(),
102 }
103 }
104
105 pub fn stats(&self) -> &SubscriberStats {
107 &self.stats
108 }
109
110 pub async fn run(mut self) -> Result<()> {
112 tracing::info!(
113 "DurabilitySubscriber started for topics: {}",
114 self.config.topic_filter
115 );
116
117 let (event_tx, mut event_rx) = mpsc::channel(128);
118 let bridge = Arc::new(DiscoveryBridge { tx: event_tx });
119 self.dds.register_discovery_callback(bridge)?;
120
121 self.discover_and_subscribe().await?;
123
124 let mut sample_interval = interval(Duration::from_millis(100));
126 let mut retention_interval = interval(Duration::from_secs(60));
127
128 loop {
129 tokio::select! {
130 _ = sample_interval.tick() => {
132 self.poll_samples().await?;
133 }
134 Some(event) = event_rx.recv() => {
135 if let DiscoveryEvent::Writer(writer) = event {
136 self.handle_writer_discovered(writer)?;
137 }
138 }
139
140 _ = retention_interval.tick() => {
142 self.apply_retention().await?;
143 }
144 }
145 }
146 }
147
148 async fn discover_and_subscribe(&mut self) -> Result<()> {
150 let writers = self.dds.discovered_writers(&self.config.topic_filter)?;
151
152 for writer in writers {
153 self.handle_writer_discovered(writer)?;
154 }
155
156 Ok(())
157 }
158
159 fn handle_writer_discovered(&mut self, writer: DiscoveredWriter) -> Result<()> {
160 if self.known_writers.contains(&writer.guid) {
162 return Ok(());
163 }
164
165 if !writer.durability.is_durable() && !self.config.subscribe_volatile {
167 tracing::debug!("Skipping volatile writer for topic {}", writer.topic);
168 return Ok(());
169 }
170
171 self.known_writers.insert(writer.guid);
172 self.stats.writers_discovered += 1;
173
174 tracing::info!(
175 "Discovered writer for topic: {} (type: {})",
176 writer.topic,
177 writer.type_name
178 );
179
180 if let Some(hint) = writer.retention_hint {
181 self.update_retention_hint(&writer.topic, hint);
182 }
183
184 if !self.subscribed_topics.contains(&writer.topic) {
186 self.create_reader_for_topic(&writer)?;
187 }
188
189 Ok(())
190 }
191
192 fn create_reader_for_topic(&mut self, writer: &DiscoveredWriter) -> Result<()> {
194 let reader = self
195 .dds
196 .create_reader(&writer.topic, &writer.type_name, writer.durability)?;
197
198 tracing::info!(
199 "Created reader for topic: {} (type: {})",
200 writer.topic,
201 writer.type_name
202 );
203
204 self.readers.insert(writer.topic.clone(), reader);
205 self.subscribed_topics.insert(writer.topic.clone());
206 self.stats.topics_subscribed += 1;
207
208 Ok(())
209 }
210
211 async fn poll_samples(&mut self) -> Result<()> {
213 let store = self.store.read().await;
214
215 for (topic, reader) in &self.readers {
216 match reader.take() {
218 Ok(samples) => {
219 for received in samples {
220 self.stats.samples_received += 1;
221
222 let sample = Sample {
223 topic: received.topic,
224 type_name: received.type_name,
225 payload: received.payload,
226 timestamp_ns: received.timestamp_ns,
227 sequence: received.sequence,
228 source_guid: received.writer_guid,
229 };
230
231 match store.save(&sample) {
232 Ok(()) => {
233 self.stats.samples_stored += 1;
234 tracing::trace!(
235 "Stored sample: topic={}, seq={}",
236 sample.topic,
237 sample.sequence
238 );
239 }
240 Err(e) => {
241 self.stats.storage_errors += 1;
242 tracing::error!("Failed to store sample for {}: {}", topic, e);
243 }
244 }
245 }
246 }
247 Err(e) => {
248 tracing::warn!("Failed to take samples from {}: {}", topic, e);
249 }
250 }
251 }
252
253 Ok(())
254 }
255
256 async fn apply_retention(&self) -> Result<()> {
258 let store = self.store.read().await;
259
260 for topic in &self.subscribed_topics {
261 let policy = self.retention_policy_for_topic(topic);
262 if policy.is_noop() {
263 continue;
264 }
265 if let Err(e) = store.apply_retention_policy(topic, &policy) {
266 tracing::warn!("Failed to apply retention to {}: {}", topic, e);
267 }
268 }
269
270 tracing::debug!(
271 "Applied retention policy ({} samples) to {} topics",
272 self.config.retention_count,
273 self.subscribed_topics.len()
274 );
275
276 Ok(())
277 }
278
279 fn retention_policy_for_topic(&self, topic: &str) -> RetentionPolicy {
280 let mut policy = RetentionPolicy {
281 keep_count: self.config.retention_count,
282 max_age_ns: if self.config.retention_time_secs > 0 {
283 Some(
284 self.config
285 .retention_time_secs
286 .saturating_mul(1_000_000_000),
287 )
288 } else {
289 None
290 },
291 max_bytes: if self.config.retention_size_bytes > 0 {
292 Some(self.config.retention_size_bytes)
293 } else {
294 None
295 },
296 };
297
298 if let Some(keep_hint) = self.retention_hints.get(topic).copied() {
299 if keep_hint > 0 {
300 if policy.keep_count == 0 {
301 policy.keep_count = keep_hint;
302 } else {
303 policy.keep_count = policy.keep_count.min(keep_hint);
304 }
305 }
306 }
307
308 policy
309 }
310
311 fn update_retention_hint(&mut self, topic: &str, hint: RetentionPolicy) {
312 if hint.keep_count == 0 {
313 return;
314 }
315
316 let entry = self
317 .retention_hints
318 .entry(topic.to_string())
319 .or_insert(hint.keep_count);
320 *entry = (*entry).max(hint.keep_count);
321 }
322}
323
324pub struct StandaloneSubscriber<S: PersistenceStore> {
330 config: Config,
331 store: Arc<RwLock<S>>,
332 rx: tokio::sync::mpsc::Receiver<Sample>,
333 stats: SubscriberStats,
334}
335
336impl<S: PersistenceStore + Send + Sync> StandaloneSubscriber<S> {
337 pub fn new(config: Config, store: Arc<RwLock<S>>) -> (Self, tokio::sync::mpsc::Sender<Sample>) {
341 let (tx, rx) = tokio::sync::mpsc::channel(1000);
342
343 let subscriber = Self {
344 config,
345 store,
346 rx,
347 stats: SubscriberStats::default(),
348 };
349
350 (subscriber, tx)
351 }
352
353 pub fn stats(&self) -> &SubscriberStats {
355 &self.stats
356 }
357
358 pub async fn run(mut self) -> Result<()> {
360 tracing::info!(
361 "StandaloneSubscriber started for topics: {}",
362 self.config.topic_filter
363 );
364
365 let mut retention_interval = interval(Duration::from_secs(60));
366
367 loop {
368 tokio::select! {
369 Some(sample) = self.rx.recv() => {
371 self.stats.samples_received += 1;
372
373 if !topic_matches(&self.config.topic_filter, &sample.topic) {
375 continue;
376 }
377
378 let store = self.store.read().await;
379 match store.save(&sample) {
380 Ok(()) => {
381 self.stats.samples_stored += 1;
382 tracing::trace!(
383 "Stored sample: topic={}, seq={}",
384 sample.topic,
385 sample.sequence
386 );
387 }
388 Err(e) => {
389 self.stats.storage_errors += 1;
390 tracing::error!("Failed to store sample: {}", e);
391 }
392 }
393 }
394
395 _ = retention_interval.tick() => {
397 self.apply_retention().await?;
398 }
399
400 else => break,
401 }
402 }
403
404 Ok(())
405 }
406
407 async fn apply_retention(&self) -> Result<()> {
409 let policy = RetentionPolicy {
410 keep_count: self.config.retention_count,
411 max_age_ns: if self.config.retention_time_secs > 0 {
412 Some(
413 self.config
414 .retention_time_secs
415 .saturating_mul(1_000_000_000),
416 )
417 } else {
418 None
419 },
420 max_bytes: if self.config.retention_size_bytes > 0 {
421 Some(self.config.retention_size_bytes)
422 } else {
423 None
424 },
425 };
426 if policy.is_noop() {
427 return Ok(());
428 }
429
430 let store = self.store.read().await;
431
432 let all_samples = store.query_range("*", 0, u64::MAX)?;
434 let topics: HashSet<_> = all_samples.iter().map(|s| s.topic.clone()).collect();
435
436 for topic in &topics {
437 if let Err(e) = store.apply_retention_policy(topic, &policy) {
438 tracing::warn!("Failed to apply retention to {}: {}", topic, e);
439 }
440 }
441
442 Ok(())
443 }
444}
445
/// Returns whether `topic` is selected by the subscription `pattern`.
///
/// Supported patterns:
/// - `"*"` matches every topic.
/// - `"<prefix>/*"` matches any topic strictly below `<prefix>/`, at any depth.
///   For example, `"State/*"` matches `"State/Temperature"` and `"State/A/B"`,
///   but not `"StateMachine/Run"`, `"State/"` or `"State"`.
/// - Any other pattern must equal `topic` exactly.
fn topic_matches(pattern: &str, topic: &str) -> bool {
    if pattern == "*" {
        return true;
    }
    if let Some(prefix) = pattern.strip_suffix("/*") {
        // Require the `/` separator so `State/*` does not select `StateMachine/...`.
        return matches!(
            topic
                .strip_prefix(prefix)
                .and_then(|rest| rest.strip_prefix('/')),
            Some(rest) if !rest.is_empty()
        );
    }
    pattern == topic
}
456
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dds_interface::MockDdsInterface;
    use crate::sqlite::SqliteStore;

    /// Returns a fresh in-memory SQLite store, wrapped the way the subscribers expect.
    fn shared_store() -> Arc<RwLock<SqliteStore>> {
        let backing = SqliteStore::new_in_memory().unwrap();
        Arc::new(RwLock::new(backing))
    }

    #[test]
    fn test_topic_matches() {
        let cases = [
            ("*", "any/topic", true),
            ("State/*", "State/Temperature", true),
            ("State/*", "Command/Set", false),
            ("exact", "exact", true),
            ("exact", "other", false),
        ];
        for (pattern, topic, expected) in cases {
            assert_eq!(
                topic_matches(pattern, topic),
                expected,
                "pattern={:?} topic={:?}",
                pattern,
                topic
            );
        }
    }

    #[tokio::test]
    async fn test_standalone_subscriber() {
        let config = Config::builder()
            .topic_filter("State/*")
            .retention_count(100)
            .build();
        let store = shared_store();
        let (subscriber, sender) = StandaloneSubscriber::new(config, Arc::clone(&store));

        // Queue the sample before the subscriber starts; the channel buffers it.
        sender
            .send(Sample {
                topic: "State/Temperature".to_string(),
                type_name: "Temperature".to_string(),
                payload: vec![1, 2, 3],
                timestamp_ns: 1000,
                sequence: 1,
                source_guid: [0xAA; 16],
            })
            .await
            .unwrap();

        // The timeout bounds the test even if `run` does not return on its own.
        let task = tokio::spawn(async move {
            tokio::time::timeout(Duration::from_millis(200), subscriber.run()).await
        });

        tokio::time::sleep(Duration::from_millis(100)).await;
        drop(sender);
        let _ = task.await;

        let guard = store.read().await;
        let stored = guard.load("State/Temperature").unwrap();
        assert_eq!(stored.len(), 1);
        assert_eq!(stored[0].sequence, 1);
    }

    #[test]
    fn test_durability_subscriber_creation() {
        let config = Config::builder()
            .topic_filter("State/*")
            .retention_count(100)
            .build();
        let dds = Arc::new(MockDdsInterface::new());

        let subscriber = DurabilitySubscriber::new(config, shared_store(), dds);

        assert_eq!(subscriber.stats.samples_received, 0);
    }

    #[tokio::test]
    async fn test_discover_and_subscribe() {
        use crate::dds_interface::{DiscoveredWriter, DurabilityKind};

        let config = Config::builder()
            .topic_filter("State/*")
            .subscribe_volatile(true)
            .build();
        let dds = Arc::new(MockDdsInterface::new());
        dds.add_writer(DiscoveredWriter {
            guid: [0x01; 16],
            topic: "State/Temperature".to_string(),
            type_name: "Temperature".to_string(),
            durability: DurabilityKind::TransientLocal,
            retention_hint: None,
        });

        let mut subscriber = DurabilitySubscriber::new(config, shared_store(), dds);
        subscriber.discover_and_subscribe().await.unwrap();

        assert_eq!(subscriber.stats.writers_discovered, 1);
        assert_eq!(subscriber.stats.topics_subscribed, 1);
        assert!(subscriber.subscribed_topics.contains("State/Temperature"));
    }
}