Skip to main content

rsigma_runtime/sources/
refresh.rs

1//! Background refresh scheduler for dynamic pipeline sources.
2//!
3//! Manages per-source refresh loops based on `RefreshPolicy`:
4//! - `Interval(duration)`: re-fetches on a timer
5//! - `Watch`: uses file system notifications (via `notify`)
6//! - `Push`: receives updates from external triggers (NATS)
7//! - `OnDemand`: only refreshes when explicitly triggered via API/signal
8
9use std::collections::HashMap;
10use std::sync::Arc;
11
12use rsigma_eval::pipeline::sources::{DynamicSource, RefreshPolicy, SourceType};
13use tokio::sync::{mpsc, watch};
14
15use super::{SourceResolver, resolve_all};
16
/// Request asking the scheduler to re-resolve one or more dynamic sources.
///
/// Triggers come from interval timers, file watchers, NATS subscriptions and
/// on-demand callers (API, SIGHUP, NATS control subject).
#[derive(Clone, Debug)]
pub enum RefreshTrigger {
    /// Resolve every configured source again.
    All,
    /// Resolve only the source whose ID equals the contained string.
    Single(String),
    /// Data for one source delivered by a NATS push subscription. The payload is
    /// already parsed, so the scheduler publishes it without calling the resolver.
    #[cfg(feature = "nats")]
    NatsPush {
        /// ID of the source the payload belongs to.
        source_id: String,
        /// Parsed message payload.
        data: serde_json::Value,
    },
}
31
32/// Notification sent when sources have been refreshed.
33#[derive(Debug, Clone)]
34pub struct RefreshResult {
35    /// The newly resolved source data (source_id -> value).
36    pub resolved: HashMap<String, serde_json::Value>,
37}
38
39/// Manages background refresh tasks for dynamic sources.
40///
41/// The scheduler spawns per-source tasks based on their refresh policy and
42/// sends `RefreshResult` notifications whenever source data changes.
43pub struct RefreshScheduler {
44    /// Channel for on-demand refresh triggers (from API, SIGHUP, NATS control).
45    trigger_tx: mpsc::Sender<RefreshTrigger>,
46    /// Receiver for on-demand triggers (consumed by the run loop).
47    trigger_rx: Option<mpsc::Receiver<RefreshTrigger>>,
48    /// Watch channel sender for notifying consumers of updated source data.
49    result_tx: watch::Sender<Option<RefreshResult>>,
50    /// Watch channel receiver for consumers.
51    result_rx: watch::Receiver<Option<RefreshResult>>,
52}
53
54impl RefreshScheduler {
55    /// Create a new scheduler.
56    pub fn new() -> Self {
57        let (trigger_tx, trigger_rx) = mpsc::channel(32);
58        let (result_tx, result_rx) = watch::channel(None);
59        Self {
60            trigger_tx,
61            trigger_rx: Some(trigger_rx),
62            result_tx,
63            result_rx,
64        }
65    }
66
67    /// Get a sender for triggering on-demand resolution.
68    pub fn trigger_sender(&self) -> mpsc::Sender<RefreshTrigger> {
69        self.trigger_tx.clone()
70    }
71
72    /// Get a receiver that is notified when sources are refreshed.
73    pub fn result_receiver(&self) -> watch::Receiver<Option<RefreshResult>> {
74        self.result_rx.clone()
75    }
76
77    /// Start the scheduler background loop.
78    ///
79    /// Takes ownership of the trigger receiver and spawns per-source interval tasks.
80    /// Returns a `JoinHandle` for the main coordination task.
81    ///
82    /// When a refresh occurs (via interval timer or on-demand trigger), all sources
83    /// are re-resolved and the result is published on the watch channel.
84    pub fn run(
85        mut self,
86        sources: Vec<DynamicSource>,
87        resolver: Arc<dyn SourceResolver>,
88    ) -> tokio::task::JoinHandle<()> {
89        let trigger_rx = self
90            .trigger_rx
91            .take()
92            .expect("run() can only be called once");
93
94        tokio::spawn(async move {
95            Self::run_loop(
96                sources,
97                resolver,
98                trigger_rx,
99                self.trigger_tx,
100                self.result_tx,
101            )
102            .await;
103        })
104    }
105
106    async fn run_loop(
107        sources: Vec<DynamicSource>,
108        resolver: Arc<dyn SourceResolver>,
109        mut trigger_rx: mpsc::Receiver<RefreshTrigger>,
110        trigger_tx: mpsc::Sender<RefreshTrigger>,
111        result_tx: watch::Sender<Option<RefreshResult>>,
112    ) {
113        // Spawn interval timers
114        for source in &sources {
115            if let RefreshPolicy::Interval(duration) = &source.refresh {
116                let tx = trigger_tx.clone();
117                let id = source.id.clone();
118                let interval = if *duration < super::MIN_REFRESH_INTERVAL {
119                    tracing::warn!(
120                        source_id = %id,
121                        configured = ?duration,
122                        clamped_to = ?super::MIN_REFRESH_INTERVAL,
123                        "Refresh interval below minimum, clamping to floor"
124                    );
125                    super::MIN_REFRESH_INTERVAL
126                } else {
127                    *duration
128                };
129                tokio::spawn(async move {
130                    let mut timer = tokio::time::interval(interval);
131                    timer.tick().await; // skip immediate first tick
132                    loop {
133                        timer.tick().await;
134                        if tx.send(RefreshTrigger::Single(id.clone())).await.is_err() {
135                            break;
136                        }
137                    }
138                });
139            }
140        }
141
142        // Spawn NATS push subscriptions
143        #[cfg(feature = "nats")]
144        for source in &sources {
145            if source.refresh == RefreshPolicy::Push
146                && let SourceType::Nats {
147                    url,
148                    subject,
149                    format,
150                    extract: extract_expr,
151                } = &source.source_type
152            {
153                let tx = trigger_tx.clone();
154                let id = source.id.clone();
155                let url = url.clone();
156                let subject = subject.clone();
157                let format = *format;
158                let extract_expr = extract_expr.clone();
159                tokio::spawn(async move {
160                    if let Err(e) =
161                        nats_push_loop(&url, &subject, format, extract_expr.as_ref(), &id, &tx)
162                            .await
163                    {
164                        tracing::error!(
165                            source_id = %id,
166                            error = %e,
167                            "NATS push subscription failed"
168                        );
169                    }
170                });
171            }
172        }
173
174        // Spawn file watchers for Watch policy sources
175        for source in &sources {
176            if source.refresh == RefreshPolicy::Watch
177                && let SourceType::File { path, .. } = &source.source_type
178            {
179                let tx = trigger_tx.clone();
180                let id = source.id.clone();
181                let path = path.clone();
182                tokio::spawn(async move {
183                    file_watch_loop(&path, &id, &tx).await;
184                });
185            }
186        }
187
188        // Main loop: wait for triggers and resolve
189        while let Some(trigger) = trigger_rx.recv().await {
190            // Handle NATS push with pre-parsed data (no re-resolution needed)
191            #[cfg(feature = "nats")]
192            if let RefreshTrigger::NatsPush { source_id, data } = trigger {
193                let mut resolved = HashMap::new();
194                resolved.insert(source_id, data);
195                let _ = result_tx.send(Some(RefreshResult { resolved }));
196                continue;
197            }
198
199            let to_resolve: Vec<&DynamicSource> = match &trigger {
200                RefreshTrigger::All => sources.iter().collect(),
201                RefreshTrigger::Single(id) => sources.iter().filter(|s| s.id == *id).collect(),
202                #[cfg(feature = "nats")]
203                RefreshTrigger::NatsPush { .. } => unreachable!(),
204            };
205
206            if to_resolve.is_empty() {
207                continue;
208            }
209
210            match resolve_all(
211                resolver.as_ref(),
212                &to_resolve.iter().map(|s| (*s).clone()).collect::<Vec<_>>(),
213            )
214            .await
215            {
216                Ok(resolved) => {
217                    let _ = result_tx.send(Some(RefreshResult { resolved }));
218                }
219                Err(e) => {
220                    tracing::warn!(error = %e, "Background source refresh failed");
221                }
222            }
223        }
224    }
225}
226
227impl Default for RefreshScheduler {
228    fn default() -> Self {
229        Self::new()
230    }
231}
232
/// Subscribe to `subject` on the NATS server at `url` and forward each parsed
/// message to the scheduler as a [`RefreshTrigger::NatsPush`] for `source_id`.
///
/// Each payload is decoded by `super::nats::parse_nats_message` using `format`
/// and the optional `extract_expr`. Messages that fail to parse are logged and
/// skipped.
///
/// # Errors
///
/// Returns an error if connecting or subscribing fails, or if the subscription
/// stream ends while the scheduler is still running. Returns `Ok(())` once the
/// scheduler's trigger channel has been closed.
#[cfg(feature = "nats")]
async fn nats_push_loop(
    url: &str,
    subject: &str,
    format: rsigma_eval::pipeline::sources::DataFormat,
    extract_expr: Option<&rsigma_eval::pipeline::sources::ExtractExpr>,
    source_id: &str,
    trigger_tx: &mpsc::Sender<RefreshTrigger>,
) -> Result<(), String> {
    use futures::StreamExt;

    // `client` is kept bound for the whole function so the connection handle
    // outlives the subscription loop.
    let client = async_nats::connect(url)
        .await
        .map_err(|e| format!("NATS connect failed: {e}"))?;

    let mut subscriber = client
        .subscribe(subject.to_string())
        .await
        .map_err(|e| format!("NATS subscribe failed: {e}"))?;

    tracing::info!(
        source_id = %source_id,
        subject = %subject,
        "NATS push subscription active"
    );

    while let Some(msg) = subscriber.next().await {
        match super::nats::parse_nats_message(&msg.payload, format, extract_expr) {
            Ok(data) => {
                let trigger = RefreshTrigger::NatsPush {
                    source_id: source_id.to_string(),
                    data,
                };
                if trigger_tx.send(trigger).await.is_err() {
                    // Scheduler is gone: stopping here is the normal shutdown path.
                    return Ok(());
                }
            }
            Err(e) => {
                tracing::warn!(
                    source_id = %source_id,
                    error = %e,
                    "Failed to parse NATS push message"
                );
            }
        }
    }

    // The stream ended while the scheduler still wants updates. Report it so the
    // caller logs it, rather than letting this source silently stop refreshing.
    Err(format!("NATS subscription on '{subject}' ended"))
}
283
/// Default NATS subject for control messages that request source re-resolution.
pub const NATS_CONTROL_SUBJECT: &str = "rsigma.control.resolve";
286
/// Subscribe to a NATS control subject and forward re-resolution requests to the scheduler.
///
/// Each payload is decoded lossily as UTF-8 and trimmed:
/// - An empty result requests [`RefreshTrigger::All`].
/// - Anything else is treated as the ID of a single source to re-resolve.
///
/// # Errors
///
/// Returns an error if connecting or subscribing fails, or if the subscription
/// stream ends while the scheduler is still running. Returns `Ok(())` once the
/// scheduler's trigger channel has been closed.
#[cfg(feature = "nats")]
pub async fn nats_control_loop(
    url: &str,
    subject: &str,
    trigger_tx: mpsc::Sender<RefreshTrigger>,
) -> Result<(), String> {
    use futures::StreamExt;

    // Kept bound for the whole function so the connection handle outlives the loop.
    let client = async_nats::connect(url)
        .await
        .map_err(|e| format!("NATS control connect failed: {e}"))?;

    let mut subscriber = client
        .subscribe(subject.to_string())
        .await
        .map_err(|e| format!("NATS control subscribe failed: {e}"))?;

    tracing::info!(
        subject = %subject,
        "NATS control subscription active for source re-resolution"
    );

    while let Some(msg) = subscriber.next().await {
        let payload = String::from_utf8_lossy(&msg.payload);
        let source_id = payload.trim();

        let trigger = if source_id.is_empty() {
            tracing::debug!("NATS control: triggering all sources");
            RefreshTrigger::All
        } else {
            tracing::debug!(source_id = %source_id, "NATS control: triggering single source");
            RefreshTrigger::Single(source_id.to_string())
        };

        if trigger_tx.send(trigger).await.is_err() {
            tracing::debug!("NATS control loop: trigger channel closed, exiting");
            return Ok(());
        }
    }

    // Report an ended stream instead of returning silently, so the caller
    // notices that on-demand re-resolution over NATS is no longer available.
    Err(format!("NATS control subscription on '{subject}' ended"))
}
333
334/// Watch a file for changes and send refresh triggers.
335async fn file_watch_loop(
336    path: &std::path::Path,
337    source_id: &str,
338    trigger_tx: &mpsc::Sender<RefreshTrigger>,
339) {
340    use notify::{Event, EventKind, RecommendedWatcher, Watcher};
341    use tokio::sync::mpsc as tokio_mpsc;
342
343    let (notify_tx, mut notify_rx) = tokio_mpsc::channel::<()>(4);
344
345    let _watcher = {
346        let tx = notify_tx.clone();
347        match RecommendedWatcher::new(
348            move |res: Result<Event, notify::Error>| {
349                if let Ok(event) = res
350                    && matches!(event.kind, EventKind::Create(_) | EventKind::Modify(_))
351                {
352                    let _ = tx.try_send(());
353                }
354            },
355            notify::Config::default(),
356        ) {
357            Ok(mut w) => {
358                if let Err(e) = w.watch(path, notify::RecursiveMode::NonRecursive) {
359                    tracing::warn!(
360                        source_id = %source_id,
361                        path = %path.display(),
362                        error = %e,
363                        "Could not watch source file"
364                    );
365                    return;
366                }
367                tracing::info!(
368                    source_id = %source_id,
369                    path = %path.display(),
370                    "Watching source file for changes"
371                );
372                Some(w)
373            }
374            Err(e) => {
375                tracing::warn!(
376                    source_id = %source_id,
377                    error = %e,
378                    "Could not create file watcher for source"
379                );
380                return;
381            }
382        }
383    };
384
385    while notify_rx.recv().await.is_some() {
386        // Debounce: wait a short period for additional changes
387        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
388        // Drain any queued notifications
389        while notify_rx.try_recv().is_ok() {}
390
391        if trigger_tx
392            .send(RefreshTrigger::Single(source_id.to_string()))
393            .await
394            .is_err()
395        {
396            break;
397        }
398    }
399}