// rsigma_runtime/sources/refresh.rs

//! Background refresh scheduler for dynamic pipeline sources.
//!
//! Manages per-source refresh loops based on `RefreshPolicy`:
//! - `Interval(duration)`: re-fetches the source on a timer
//! - `Watch`: re-fetches when file system notifications (via `notify`) report a change
//! - `Push`: receives updates from external triggers (NATS)
//! - `OnDemand`: only refreshes when explicitly triggered via API/signal
8
use std::collections::HashMap;
use std::sync::Arc;

use rsigma_eval::pipeline::sources::{DynamicSource, RefreshPolicy, SourceType};
use tokio::sync::{mpsc, watch};

use super::{SourceResolver, resolve_all};
16
/// A message requesting source re-resolution.
///
/// Triggers are sent over the scheduler's bounded trigger channel and are
/// consumed one at a time by the scheduler's run loop.
#[derive(Debug, Clone)]
pub enum RefreshTrigger {
    /// Re-resolve all sources.
    All,
    /// Re-resolve a specific source by ID.
    ///
    /// An ID that matches no configured source selects nothing, and the run
    /// loop skips the trigger without publishing a result.
    Single(String),
    /// A NATS push message arrived with pre-parsed data for a specific source.
    ///
    /// The run loop publishes `data` under `source_id` as-is; the resolver is
    /// not invoked for this trigger.
    #[cfg(feature = "nats")]
    NatsPush {
        /// ID of the source the pushed payload belongs to.
        source_id: String,
        /// Payload already parsed by the NATS push subscription task.
        data: serde_json::Value,
    },
}
31
/// Notification sent when sources have been refreshed.
///
/// Published on the scheduler's `watch` channel after each successful
/// refresh. Because a `watch` channel only retains its latest value, a slow
/// consumer observes the most recent result, not every intermediate one.
#[derive(Debug, Clone)]
pub struct RefreshResult {
    /// The newly resolved source data (source_id -> value).
    ///
    /// Covers only the sources handled by the refresh that produced this
    /// result: for a NATS push that is exactly one entry; otherwise it is
    /// whatever `resolve_all` returned for the sources selected by the trigger.
    pub resolved: HashMap<String, serde_json::Value>,
}
38
/// Manages background refresh tasks for dynamic sources.
///
/// The scheduler spawns per-source tasks based on their refresh policy and
/// publishes a `RefreshResult` notification after every successful refresh
/// (failed refreshes are logged and publish nothing).
pub struct RefreshScheduler {
    /// Channel for on-demand refresh triggers (from API, SIGHUP, NATS control).
    /// Also cloned into every per-source helper task the run loop spawns.
    trigger_tx: mpsc::Sender<RefreshTrigger>,
    /// Receiver for on-demand triggers (consumed by the run loop).
    /// Wrapped in `Option` so `run` can move it out of the scheduler.
    trigger_rx: Option<mpsc::Receiver<RefreshTrigger>>,
    /// Watch channel sender for notifying consumers of updated source data.
    result_tx: watch::Sender<Option<RefreshResult>>,
    /// Watch channel receiver for consumers; cloned by `result_receiver`.
    /// Holds `None` until the first result is published.
    result_rx: watch::Receiver<Option<RefreshResult>>,
}
53
54impl RefreshScheduler {
55    /// Create a new scheduler.
56    pub fn new() -> Self {
57        let (trigger_tx, trigger_rx) = mpsc::channel(32);
58        let (result_tx, result_rx) = watch::channel(None);
59        Self {
60            trigger_tx,
61            trigger_rx: Some(trigger_rx),
62            result_tx,
63            result_rx,
64        }
65    }
66
67    /// Get a sender for triggering on-demand resolution.
68    pub fn trigger_sender(&self) -> mpsc::Sender<RefreshTrigger> {
69        self.trigger_tx.clone()
70    }
71
72    /// Get a receiver that is notified when sources are refreshed.
73    pub fn result_receiver(&self) -> watch::Receiver<Option<RefreshResult>> {
74        self.result_rx.clone()
75    }
76
77    /// Start the scheduler for a detached, non-pipeline consumer.
78    ///
79    /// Returns a [`SourceSubscription`] bundling the spawned coordination task,
80    /// a [`watch`] receiver of decoded source payloads, and a trigger sender for
81    /// on-demand re-resolution. This is the seam used by consumers that want a
82    /// source's decoded payload directly (rather than bound into a pipeline
83    /// `${source.*}` namespace): it reuses the same per-source file, HTTP, and
84    /// NATS fetch and refresh machinery as the pipeline binder via [`Self::run`],
85    /// so that logic is never duplicated.
86    pub fn subscribe(
87        self,
88        sources: Vec<DynamicSource>,
89        resolver: Arc<dyn SourceResolver>,
90    ) -> SourceSubscription {
91        let results = self.result_receiver();
92        let trigger = self.trigger_sender();
93        let handle = self.run(sources, resolver);
94        SourceSubscription {
95            handle,
96            results,
97            trigger,
98        }
99    }
100
101    /// Start the scheduler background loop.
102    ///
103    /// Takes ownership of the trigger receiver and spawns per-source interval tasks.
104    /// Returns a `JoinHandle` for the main coordination task.
105    ///
106    /// When a refresh occurs (via interval timer or on-demand trigger), all sources
107    /// are re-resolved and the result is published on the watch channel.
108    pub fn run(
109        mut self,
110        sources: Vec<DynamicSource>,
111        resolver: Arc<dyn SourceResolver>,
112    ) -> tokio::task::JoinHandle<()> {
113        let trigger_rx = self
114            .trigger_rx
115            .take()
116            .expect("run() can only be called once");
117
118        tokio::spawn(async move {
119            Self::run_loop(
120                sources,
121                resolver,
122                trigger_rx,
123                self.trigger_tx,
124                self.result_tx,
125            )
126            .await;
127        })
128    }
129
130    async fn run_loop(
131        sources: Vec<DynamicSource>,
132        resolver: Arc<dyn SourceResolver>,
133        mut trigger_rx: mpsc::Receiver<RefreshTrigger>,
134        trigger_tx: mpsc::Sender<RefreshTrigger>,
135        result_tx: watch::Sender<Option<RefreshResult>>,
136    ) {
137        // Spawn interval timers
138        for source in &sources {
139            if let RefreshPolicy::Interval(duration) = &source.refresh {
140                let tx = trigger_tx.clone();
141                let id = source.id.clone();
142                let interval = if *duration < super::MIN_REFRESH_INTERVAL {
143                    tracing::warn!(
144                        source_id = %id,
145                        configured = ?duration,
146                        clamped_to = ?super::MIN_REFRESH_INTERVAL,
147                        "Refresh interval below minimum, clamping to floor"
148                    );
149                    super::MIN_REFRESH_INTERVAL
150                } else {
151                    *duration
152                };
153                tokio::spawn(async move {
154                    let mut timer = tokio::time::interval(interval);
155                    timer.tick().await; // skip immediate first tick
156                    loop {
157                        timer.tick().await;
158                        if tx.send(RefreshTrigger::Single(id.clone())).await.is_err() {
159                            break;
160                        }
161                    }
162                });
163            }
164        }
165
166        // Spawn NATS push subscriptions
167        #[cfg(feature = "nats")]
168        for source in &sources {
169            if source.refresh == RefreshPolicy::Push
170                && let SourceType::Nats {
171                    url,
172                    subject,
173                    format,
174                    extract: extract_expr,
175                } = &source.source_type
176            {
177                let tx = trigger_tx.clone();
178                let id = source.id.clone();
179                let url = url.clone();
180                let subject = subject.clone();
181                let format = *format;
182                let extract_expr = extract_expr.clone();
183                tokio::spawn(async move {
184                    if let Err(e) =
185                        nats_push_loop(&url, &subject, format, extract_expr.as_ref(), &id, &tx)
186                            .await
187                    {
188                        tracing::error!(
189                            source_id = %id,
190                            error = %e,
191                            "NATS push subscription failed"
192                        );
193                    }
194                });
195            }
196        }
197
198        // Spawn file watchers for Watch policy sources
199        for source in &sources {
200            if source.refresh == RefreshPolicy::Watch
201                && let SourceType::File { path, .. } = &source.source_type
202            {
203                let tx = trigger_tx.clone();
204                let id = source.id.clone();
205                let path = path.clone();
206                tokio::spawn(async move {
207                    file_watch_loop(&path, &id, &tx).await;
208                });
209            }
210        }
211
212        // Main loop: wait for triggers and resolve
213        while let Some(trigger) = trigger_rx.recv().await {
214            // Handle NATS push with pre-parsed data (no re-resolution needed)
215            #[cfg(feature = "nats")]
216            if let RefreshTrigger::NatsPush { source_id, data } = trigger {
217                let mut resolved = HashMap::new();
218                resolved.insert(source_id, data);
219                let _ = result_tx.send(Some(RefreshResult { resolved }));
220                continue;
221            }
222
223            let to_resolve: Vec<&DynamicSource> = match &trigger {
224                RefreshTrigger::All => sources.iter().collect(),
225                RefreshTrigger::Single(id) => sources.iter().filter(|s| s.id == *id).collect(),
226                #[cfg(feature = "nats")]
227                RefreshTrigger::NatsPush { .. } => unreachable!(),
228            };
229
230            if to_resolve.is_empty() {
231                continue;
232            }
233
234            let refresh_count = to_resolve.len();
235            let refresh_start = std::time::Instant::now();
236            match resolve_all(
237                resolver.as_ref(),
238                &to_resolve.iter().map(|s| (*s).clone()).collect::<Vec<_>>(),
239            )
240            .await
241            {
242                Ok(resolved) => {
243                    tracing::debug!(
244                        sources = refresh_count,
245                        duration_ms = refresh_start.elapsed().as_millis() as u64,
246                        "Scheduled refresh completed",
247                    );
248                    let _ = result_tx.send(Some(RefreshResult { resolved }));
249                }
250                Err(e) => {
251                    tracing::warn!(
252                        error = %e,
253                        sources = refresh_count,
254                        duration_ms = refresh_start.elapsed().as_millis() as u64,
255                        "Background source refresh failed",
256                    );
257                }
258            }
259        }
260    }
261}
262
263impl Default for RefreshScheduler {
264    fn default() -> Self {
265        Self::new()
266    }
267}
268
/// A detached source subscription returned by [`RefreshScheduler::subscribe`].
///
/// Bundles the spawned coordination task, a receiver of decoded source payloads
/// (the latest [`RefreshResult`] per refresh), and a trigger sender for
/// on-demand re-resolution and hot-reload. Dropping `handle` does not stop the
/// loop; hold it (or detach it) for the lifetime of the consumer.
pub struct SourceSubscription {
    /// The scheduler coordination task.
    pub handle: tokio::task::JoinHandle<()>,
    /// Latest decoded source payloads, updated on every successful refresh.
    ///
    /// Holds `None` until the first result is published. Each result only
    /// contains the sources covered by the refresh that produced it.
    pub results: watch::Receiver<Option<RefreshResult>>,
    /// Trigger channel for on-demand re-resolution (`All` / `Single`).
    pub trigger: mpsc::Sender<RefreshTrigger>,
}
283
/// Forward messages from a NATS subject to the scheduler as push triggers.
///
/// Connects to `url`, subscribes to `subject`, and turns every message that
/// parses successfully into a `RefreshTrigger::NatsPush` for `source_id`.
/// Messages that fail to parse are logged and dropped.
///
/// Returns `Ok(())` once the subscription stream ends or the trigger channel
/// is closed.
///
/// # Errors
///
/// Returns a description of the failure when connecting or subscribing fails.
#[cfg(feature = "nats")]
async fn nats_push_loop(
    url: &str,
    subject: &str,
    format: rsigma_eval::pipeline::sources::DataFormat,
    extract_expr: Option<&rsigma_eval::pipeline::sources::ExtractExpr>,
    source_id: &str,
    trigger_tx: &mpsc::Sender<RefreshTrigger>,
) -> Result<(), String> {
    use futures::StreamExt;

    let connection = match async_nats::connect(url).await {
        Ok(connection) => connection,
        Err(err) => return Err(format!("NATS connect failed: {err}")),
    };

    let mut messages = match connection.subscribe(subject.to_owned()).await {
        Ok(messages) => messages,
        Err(err) => return Err(format!("NATS subscribe failed: {err}")),
    };

    tracing::info!(
        source_id = %source_id,
        subject = %subject,
        "NATS push subscription active"
    );

    loop {
        // The stream yielding `None` means the subscription is finished.
        let Some(message) = messages.next().await else {
            break;
        };

        let data = match super::nats::parse_nats_message(&message.payload, format, extract_expr) {
            Ok(data) => data,
            Err(err) => {
                // One bad message must not take the subscription down.
                tracing::warn!(
                    source_id = %source_id,
                    error = %err,
                    "Failed to parse NATS push message"
                );
                continue;
            }
        };

        let push = RefreshTrigger::NatsPush {
            source_id: source_id.to_owned(),
            data,
        };
        // The receiver is gone once the scheduler loop has stopped.
        if trigger_tx.send(push).await.is_err() {
            break;
        }
    }

    Ok(())
}
334
/// The default NATS control subject for triggering source re-resolution.
///
/// NOTE(review): not referenced elsewhere in this file; presumably callers pass
/// it as the `subject` argument of `nats_control_loop` — confirm against the
/// call site.
pub const NATS_CONTROL_SUBJECT: &str = "rsigma.control.resolve";
337
/// Listen on a NATS control subject and translate messages into refresh triggers.
///
/// The payload is decoded as (lossy) UTF-8 and trimmed. An empty payload
/// requests re-resolution of all sources; anything else is taken verbatim as
/// the ID of the single source to re-resolve.
///
/// Returns `Ok(())` once the subscription stream ends or the trigger channel
/// is closed.
///
/// # Errors
///
/// Returns a description of the failure when connecting or subscribing fails.
#[cfg(feature = "nats")]
pub async fn nats_control_loop(
    url: &str,
    subject: &str,
    trigger_tx: mpsc::Sender<RefreshTrigger>,
) -> Result<(), String> {
    use futures::StreamExt;

    let connection = match async_nats::connect(url).await {
        Ok(connection) => connection,
        Err(err) => return Err(format!("NATS control connect failed: {err}")),
    };

    let mut messages = match connection.subscribe(subject.to_owned()).await {
        Ok(messages) => messages,
        Err(err) => return Err(format!("NATS control subscribe failed: {err}")),
    };

    tracing::info!(
        subject = %subject,
        "NATS control subscription active for source re-resolution"
    );

    loop {
        // The stream yielding `None` means the subscription is finished.
        let Some(message) = messages.next().await else {
            break;
        };

        let text = String::from_utf8_lossy(&message.payload);
        let trigger = match text.trim() {
            "" => {
                tracing::debug!("NATS control: triggering all sources");
                RefreshTrigger::All
            }
            id => {
                tracing::debug!(source_id = %id, "NATS control: triggering single source");
                RefreshTrigger::Single(id.to_owned())
            }
        };

        if trigger_tx.send(trigger).await.is_err() {
            tracing::debug!("NATS control loop: trigger channel closed, exiting");
            break;
        }
    }

    Ok(())
}
384
385/// Watch a file for changes and send refresh triggers.
386async fn file_watch_loop(
387    path: &std::path::Path,
388    source_id: &str,
389    trigger_tx: &mpsc::Sender<RefreshTrigger>,
390) {
391    use notify::{Event, EventKind, RecommendedWatcher, Watcher};
392    use tokio::sync::mpsc as tokio_mpsc;
393
394    let (notify_tx, mut notify_rx) = tokio_mpsc::channel::<()>(4);
395
396    let _watcher = {
397        let tx = notify_tx.clone();
398        match RecommendedWatcher::new(
399            move |res: Result<Event, notify::Error>| {
400                if let Ok(event) = res
401                    && matches!(event.kind, EventKind::Create(_) | EventKind::Modify(_))
402                {
403                    let _ = tx.try_send(());
404                }
405            },
406            notify::Config::default(),
407        ) {
408            Ok(mut w) => {
409                if let Err(e) = w.watch(path, notify::RecursiveMode::NonRecursive) {
410                    tracing::warn!(
411                        source_id = %source_id,
412                        path = %path.display(),
413                        error = %e,
414                        "Could not watch source file"
415                    );
416                    return;
417                }
418                tracing::info!(
419                    source_id = %source_id,
420                    path = %path.display(),
421                    "Watching source file for changes"
422                );
423                Some(w)
424            }
425            Err(e) => {
426                tracing::warn!(
427                    source_id = %source_id,
428                    error = %e,
429                    "Could not create file watcher for source"
430                );
431                return;
432            }
433        }
434    };
435
436    while notify_rx.recv().await.is_some() {
437        // Debounce: wait a short period for additional changes
438        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
439        // Drain any queued notifications
440        while notify_rx.try_recv().is_ok() {}
441
442        if trigger_tx
443            .send(RefreshTrigger::Single(source_id.to_string()))
444            .await
445            .is_err()
446        {
447            break;
448        }
449    }
450}
451
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sources::DefaultSourceResolver;
    use rsigma_eval::pipeline::sources::{DataFormat, ErrorPolicy, SourceType};

    /// Upper bound on how long a test waits for a refresh result.
    const RESULT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5);

    /// Build an on-demand, required JSON file source with the given ID.
    fn file_source(id: &str, path: std::path::PathBuf) -> DynamicSource {
        DynamicSource {
            id: id.to_string(),
            source_type: SourceType::File {
                path,
                format: DataFormat::Json,
                extract: None,
            },
            refresh: RefreshPolicy::OnDemand,
            timeout: None,
            on_error: ErrorPolicy::Fail,
            required: true,
            default: None,
        }
    }

    // A detached consumer subscribes to a single file source and receives its
    // decoded payload on the watch channel after an on-demand trigger, reusing
    // the same fetch machinery as the pipeline binder.
    #[tokio::test]
    async fn subscribe_delivers_decoded_payload() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("dispositions.json");
        std::fs::write(
            &path,
            r#"{"rules": [{"rule_id": "r1", "verdict": "false_positive"}]}"#,
        )
        .unwrap();

        let scheduler = RefreshScheduler::new();
        let sub = scheduler.subscribe(
            vec![file_source("d", path)],
            Arc::new(DefaultSourceResolver::new()),
        );
        let mut results = sub.results;

        sub.trigger.send(RefreshTrigger::All).await.unwrap();

        // A failed resolution is only logged by the run loop and never
        // published, so an unbounded wait would hang the test instead of
        // failing it.
        tokio::time::timeout(RESULT_TIMEOUT, results.changed())
            .await
            .expect("no refresh result published before the timeout")
            .unwrap();

        let payload = results.borrow().clone().expect("a refresh result");
        let data = payload.resolved.get("d").expect("source d resolved");
        assert_eq!(data["rules"][0]["rule_id"], "r1");
        assert_eq!(data["rules"][0]["verdict"], "false_positive");
    }
}