Skip to main content

rsigma_runtime/sources/
refresh.rs

1//! Background refresh scheduler for dynamic pipeline sources.
2//!
3//! Manages per-source refresh loops based on `RefreshPolicy`:
4//! - `Interval(duration)`: re-fetches on a timer
5//! - `Watch`: uses file system notifications (via `notify`)
6//! - `Push`: receives updates from external triggers (NATS)
7//! - `OnDemand`: only refreshes when explicitly triggered via API/signal
8
9use std::collections::HashMap;
10use std::sync::Arc;
11
12use rsigma_eval::pipeline::sources::{DynamicSource, RefreshPolicy, SourceType};
13use tokio::sync::{mpsc, watch};
14
15use super::{SourceResolver, resolve_all};
16
/// A request asking the refresh scheduler to (re-)resolve dynamic sources.
///
/// Values of this type travel over the scheduler's trigger channel; they are
/// produced by interval timers, file watchers, the NATS control loop, or any
/// other holder of a trigger sender.
#[derive(Clone, Debug)]
pub enum RefreshTrigger {
    /// Re-resolve every configured source.
    All,
    /// Re-resolve only the source whose ID equals the contained string.
    Single(String),
    /// Data pushed over NATS for one source. It has already been parsed by
    /// the subscriber, so the scheduler publishes it without re-resolving.
    #[cfg(feature = "nats")]
    NatsPush {
        /// ID of the source this payload belongs to.
        source_id: String,
        /// The parsed message payload.
        data: serde_json::Value,
    },
}
31
32/// Notification sent when sources have been refreshed.
33#[derive(Debug, Clone)]
34pub struct RefreshResult {
35    /// The newly resolved source data (source_id -> value).
36    pub resolved: HashMap<String, serde_json::Value>,
37}
38
39/// Manages background refresh tasks for dynamic sources.
40///
41/// The scheduler spawns per-source tasks based on their refresh policy and
42/// sends `RefreshResult` notifications whenever source data changes.
43pub struct RefreshScheduler {
44    /// Channel for on-demand refresh triggers (from API, SIGHUP, NATS control).
45    trigger_tx: mpsc::Sender<RefreshTrigger>,
46    /// Receiver for on-demand triggers (consumed by the run loop).
47    trigger_rx: Option<mpsc::Receiver<RefreshTrigger>>,
48    /// Watch channel sender for notifying consumers of updated source data.
49    result_tx: watch::Sender<Option<RefreshResult>>,
50    /// Watch channel receiver for consumers.
51    result_rx: watch::Receiver<Option<RefreshResult>>,
52}
53
54impl RefreshScheduler {
55    /// Create a new scheduler.
56    pub fn new() -> Self {
57        let (trigger_tx, trigger_rx) = mpsc::channel(32);
58        let (result_tx, result_rx) = watch::channel(None);
59        Self {
60            trigger_tx,
61            trigger_rx: Some(trigger_rx),
62            result_tx,
63            result_rx,
64        }
65    }
66
67    /// Get a sender for triggering on-demand resolution.
68    pub fn trigger_sender(&self) -> mpsc::Sender<RefreshTrigger> {
69        self.trigger_tx.clone()
70    }
71
72    /// Get a receiver that is notified when sources are refreshed.
73    pub fn result_receiver(&self) -> watch::Receiver<Option<RefreshResult>> {
74        self.result_rx.clone()
75    }
76
77    /// Start the scheduler background loop.
78    ///
79    /// Takes ownership of the trigger receiver and spawns per-source interval tasks.
80    /// Returns a `JoinHandle` for the main coordination task.
81    ///
82    /// When a refresh occurs (via interval timer or on-demand trigger), all sources
83    /// are re-resolved and the result is published on the watch channel.
84    pub fn run(
85        mut self,
86        sources: Vec<DynamicSource>,
87        resolver: Arc<dyn SourceResolver>,
88    ) -> tokio::task::JoinHandle<()> {
89        let trigger_rx = self
90            .trigger_rx
91            .take()
92            .expect("run() can only be called once");
93
94        tokio::spawn(async move {
95            Self::run_loop(
96                sources,
97                resolver,
98                trigger_rx,
99                self.trigger_tx,
100                self.result_tx,
101            )
102            .await;
103        })
104    }
105
106    async fn run_loop(
107        sources: Vec<DynamicSource>,
108        resolver: Arc<dyn SourceResolver>,
109        mut trigger_rx: mpsc::Receiver<RefreshTrigger>,
110        trigger_tx: mpsc::Sender<RefreshTrigger>,
111        result_tx: watch::Sender<Option<RefreshResult>>,
112    ) {
113        // Spawn interval timers
114        for source in &sources {
115            if let RefreshPolicy::Interval(duration) = &source.refresh {
116                let tx = trigger_tx.clone();
117                let id = source.id.clone();
118                let interval = if *duration < super::MIN_REFRESH_INTERVAL {
119                    tracing::warn!(
120                        source_id = %id,
121                        configured = ?duration,
122                        clamped_to = ?super::MIN_REFRESH_INTERVAL,
123                        "Refresh interval below minimum, clamping to floor"
124                    );
125                    super::MIN_REFRESH_INTERVAL
126                } else {
127                    *duration
128                };
129                tokio::spawn(async move {
130                    let mut timer = tokio::time::interval(interval);
131                    timer.tick().await; // skip immediate first tick
132                    loop {
133                        timer.tick().await;
134                        if tx.send(RefreshTrigger::Single(id.clone())).await.is_err() {
135                            break;
136                        }
137                    }
138                });
139            }
140        }
141
142        // Spawn NATS push subscriptions
143        #[cfg(feature = "nats")]
144        for source in &sources {
145            if source.refresh == RefreshPolicy::Push
146                && let SourceType::Nats {
147                    url,
148                    subject,
149                    format,
150                    extract: extract_expr,
151                } = &source.source_type
152            {
153                let tx = trigger_tx.clone();
154                let id = source.id.clone();
155                let url = url.clone();
156                let subject = subject.clone();
157                let format = *format;
158                let extract_expr = extract_expr.clone();
159                tokio::spawn(async move {
160                    if let Err(e) =
161                        nats_push_loop(&url, &subject, format, extract_expr.as_ref(), &id, &tx)
162                            .await
163                    {
164                        tracing::error!(
165                            source_id = %id,
166                            error = %e,
167                            "NATS push subscription failed"
168                        );
169                    }
170                });
171            }
172        }
173
174        // Spawn file watchers for Watch policy sources
175        for source in &sources {
176            if source.refresh == RefreshPolicy::Watch
177                && let SourceType::File { path, .. } = &source.source_type
178            {
179                let tx = trigger_tx.clone();
180                let id = source.id.clone();
181                let path = path.clone();
182                tokio::spawn(async move {
183                    file_watch_loop(&path, &id, &tx).await;
184                });
185            }
186        }
187
188        // Main loop: wait for triggers and resolve
189        while let Some(trigger) = trigger_rx.recv().await {
190            // Handle NATS push with pre-parsed data (no re-resolution needed)
191            #[cfg(feature = "nats")]
192            if let RefreshTrigger::NatsPush { source_id, data } = trigger {
193                let mut resolved = HashMap::new();
194                resolved.insert(source_id, data);
195                let _ = result_tx.send(Some(RefreshResult { resolved }));
196                continue;
197            }
198
199            let to_resolve: Vec<&DynamicSource> = match &trigger {
200                RefreshTrigger::All => sources.iter().collect(),
201                RefreshTrigger::Single(id) => sources.iter().filter(|s| s.id == *id).collect(),
202                #[cfg(feature = "nats")]
203                RefreshTrigger::NatsPush { .. } => unreachable!(),
204            };
205
206            if to_resolve.is_empty() {
207                continue;
208            }
209
210            let refresh_count = to_resolve.len();
211            let refresh_start = std::time::Instant::now();
212            match resolve_all(
213                resolver.as_ref(),
214                &to_resolve.iter().map(|s| (*s).clone()).collect::<Vec<_>>(),
215            )
216            .await
217            {
218                Ok(resolved) => {
219                    tracing::debug!(
220                        sources = refresh_count,
221                        duration_ms = refresh_start.elapsed().as_millis() as u64,
222                        "Scheduled refresh completed",
223                    );
224                    let _ = result_tx.send(Some(RefreshResult { resolved }));
225                }
226                Err(e) => {
227                    tracing::warn!(
228                        error = %e,
229                        sources = refresh_count,
230                        duration_ms = refresh_start.elapsed().as_millis() as u64,
231                        "Background source refresh failed",
232                    );
233                }
234            }
235        }
236    }
237}
238
239impl Default for RefreshScheduler {
240    fn default() -> Self {
241        Self::new()
242    }
243}
244
/// Subscribe to a NATS subject and forward each parsed message as a
/// `RefreshTrigger::NatsPush` for `source_id`.
///
/// Payloads are decoded with `super::nats::parse_nats_message` using `format`
/// and the optional `extract_expr`. Messages that fail to parse are logged and
/// skipped, so one bad message does not end the subscription.
///
/// # Errors
///
/// Returns an error string if connecting to `url` or subscribing to `subject`
/// fails. Returns `Ok(())` once the trigger channel is closed (logged at debug
/// level) or the subscription stream ends. The latter is logged as a warning,
/// because this source will receive no further pushed updates.
#[cfg(feature = "nats")]
async fn nats_push_loop(
    url: &str,
    subject: &str,
    format: rsigma_eval::pipeline::sources::DataFormat,
    extract_expr: Option<&rsigma_eval::pipeline::sources::ExtractExpr>,
    source_id: &str,
    trigger_tx: &mpsc::Sender<RefreshTrigger>,
) -> Result<(), String> {
    use futures::StreamExt;

    let client = async_nats::connect(url)
        .await
        .map_err(|e| format!("NATS connect failed: {e}"))?;

    let mut subscriber = client
        .subscribe(subject.to_string())
        .await
        .map_err(|e| format!("NATS subscribe failed: {e}"))?;

    tracing::info!(
        source_id = %source_id,
        subject = %subject,
        "NATS push subscription active"
    );

    while let Some(msg) = subscriber.next().await {
        let data = match super::nats::parse_nats_message(&msg.payload, format, extract_expr) {
            Ok(data) => data,
            Err(e) => {
                tracing::warn!(
                    source_id = %source_id,
                    error = %e,
                    "Failed to parse NATS push message"
                );
                continue;
            }
        };

        let trigger = RefreshTrigger::NatsPush {
            source_id: source_id.to_string(),
            data,
        };
        if trigger_tx.send(trigger).await.is_err() {
            // The scheduler is gone; this is a normal shutdown path.
            tracing::debug!(
                source_id = %source_id,
                "NATS push loop: trigger channel closed, exiting"
            );
            return Ok(());
        }
    }

    // Without this log the source would silently stop receiving updates.
    tracing::warn!(
        source_id = %source_id,
        subject = %subject,
        "NATS push subscription ended; no further updates will be received"
    );
    Ok(())
}
295
/// Default NATS subject used to request source re-resolution. Publish an empty
/// message here to refresh every source, or a source ID to refresh just that one.
pub const NATS_CONTROL_SUBJECT: &str = "rsigma.control.resolve";
298
/// Subscribe to the NATS control subject and forward re-resolution triggers.
///
/// Each payload is decoded as UTF-8 (lossily) and trimmed:
/// - an empty payload triggers re-resolution of all sources (`RefreshTrigger::All`);
/// - any other payload is treated as a source ID to re-resolve
///   (`RefreshTrigger::Single`).
///
/// # Errors
///
/// Returns an error string if connecting to `url` or subscribing to `subject`
/// fails. Returns `Ok(())` once the trigger channel is closed (logged at debug
/// level) or the subscription stream ends. The latter is logged as a warning,
/// because no further control messages will be handled.
#[cfg(feature = "nats")]
pub async fn nats_control_loop(
    url: &str,
    subject: &str,
    trigger_tx: mpsc::Sender<RefreshTrigger>,
) -> Result<(), String> {
    use futures::StreamExt;

    let client = async_nats::connect(url)
        .await
        .map_err(|e| format!("NATS control connect failed: {e}"))?;

    let mut subscriber = client
        .subscribe(subject.to_string())
        .await
        .map_err(|e| format!("NATS control subscribe failed: {e}"))?;

    tracing::info!(
        subject = %subject,
        "NATS control subscription active for source re-resolution"
    );

    while let Some(msg) = subscriber.next().await {
        let payload = String::from_utf8_lossy(&msg.payload);
        let payload = payload.trim();

        let trigger = if payload.is_empty() {
            tracing::debug!("NATS control: triggering all sources");
            RefreshTrigger::All
        } else {
            tracing::debug!(source_id = %payload, "NATS control: triggering single source");
            RefreshTrigger::Single(payload.to_string())
        };

        if trigger_tx.send(trigger).await.is_err() {
            // The scheduler is gone; this is a normal shutdown path.
            tracing::debug!("NATS control loop: trigger channel closed, exiting");
            return Ok(());
        }
    }

    // Without this log, control requests would silently stop being handled.
    tracing::warn!(
        subject = %subject,
        "NATS control subscription ended; re-resolution requests will no longer be received"
    );
    Ok(())
}
345
346/// Watch a file for changes and send refresh triggers.
347async fn file_watch_loop(
348    path: &std::path::Path,
349    source_id: &str,
350    trigger_tx: &mpsc::Sender<RefreshTrigger>,
351) {
352    use notify::{Event, EventKind, RecommendedWatcher, Watcher};
353    use tokio::sync::mpsc as tokio_mpsc;
354
355    let (notify_tx, mut notify_rx) = tokio_mpsc::channel::<()>(4);
356
357    let _watcher = {
358        let tx = notify_tx.clone();
359        match RecommendedWatcher::new(
360            move |res: Result<Event, notify::Error>| {
361                if let Ok(event) = res
362                    && matches!(event.kind, EventKind::Create(_) | EventKind::Modify(_))
363                {
364                    let _ = tx.try_send(());
365                }
366            },
367            notify::Config::default(),
368        ) {
369            Ok(mut w) => {
370                if let Err(e) = w.watch(path, notify::RecursiveMode::NonRecursive) {
371                    tracing::warn!(
372                        source_id = %source_id,
373                        path = %path.display(),
374                        error = %e,
375                        "Could not watch source file"
376                    );
377                    return;
378                }
379                tracing::info!(
380                    source_id = %source_id,
381                    path = %path.display(),
382                    "Watching source file for changes"
383                );
384                Some(w)
385            }
386            Err(e) => {
387                tracing::warn!(
388                    source_id = %source_id,
389                    error = %e,
390                    "Could not create file watcher for source"
391                );
392                return;
393            }
394        }
395    };
396
397    while notify_rx.recv().await.is_some() {
398        // Debounce: wait a short period for additional changes
399        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
400        // Drain any queued notifications
401        while notify_rx.try_recv().is_ok() {}
402
403        if trigger_tx
404            .send(RefreshTrigger::Single(source_id.to_string()))
405            .await
406            .is_err()
407        {
408            break;
409        }
410    }
411}