Skip to main content

rsigma_runtime/sources/
refresh.rs

1//! Background refresh scheduler for dynamic pipeline sources.
2//!
3//! Manages per-source refresh loops based on `RefreshPolicy`:
4//! - `Interval(duration)`: re-fetches on a timer
5//! - `Watch`: uses file system notifications (via `notify`)
6//! - `Push`: receives updates from external triggers (NATS)
7//! - `OnDemand`: only refreshes when explicitly triggered via API/signal
8
9use std::collections::HashMap;
10use std::sync::Arc;
11
12use rsigma_eval::pipeline::sources::{DynamicSource, RefreshPolicy, SourceType};
13use tokio::sync::{mpsc, watch};
14
15use super::{SourceResolver, resolve_all};
16
/// A request asking the refresh scheduler to update source data.
#[derive(Debug, Clone)]
pub enum RefreshTrigger {
    /// Re-resolve every configured source.
    All,
    /// Re-resolve only the source whose ID is given.
    Single(String),
    /// Data for one source delivered by a NATS push subscription, already parsed.
    ///
    /// The scheduler publishes this data as-is instead of invoking the resolver.
    #[cfg(feature = "nats")]
    NatsPush {
        /// ID of the source this data belongs to.
        source_id: String,
        /// Parsed message payload (after the optional extract expression, if any).
        data: serde_json::Value,
    },
}
31
32/// Notification sent when sources have been refreshed.
33#[derive(Debug, Clone)]
34pub struct RefreshResult {
35    /// The newly resolved source data (source_id -> value).
36    pub resolved: HashMap<String, serde_json::Value>,
37}
38
39/// Manages background refresh tasks for dynamic sources.
40///
41/// The scheduler spawns per-source tasks based on their refresh policy and
42/// sends `RefreshResult` notifications whenever source data changes.
43pub struct RefreshScheduler {
44    /// Channel for on-demand refresh triggers (from API, SIGHUP, NATS control).
45    trigger_tx: mpsc::Sender<RefreshTrigger>,
46    /// Receiver for on-demand triggers (consumed by the run loop).
47    trigger_rx: Option<mpsc::Receiver<RefreshTrigger>>,
48    /// Watch channel sender for notifying consumers of updated source data.
49    result_tx: watch::Sender<Option<RefreshResult>>,
50    /// Watch channel receiver for consumers.
51    result_rx: watch::Receiver<Option<RefreshResult>>,
52}
53
54impl RefreshScheduler {
55    /// Create a new scheduler.
56    pub fn new() -> Self {
57        let (trigger_tx, trigger_rx) = mpsc::channel(32);
58        let (result_tx, result_rx) = watch::channel(None);
59        Self {
60            trigger_tx,
61            trigger_rx: Some(trigger_rx),
62            result_tx,
63            result_rx,
64        }
65    }
66
67    /// Get a sender for triggering on-demand resolution.
68    pub fn trigger_sender(&self) -> mpsc::Sender<RefreshTrigger> {
69        self.trigger_tx.clone()
70    }
71
72    /// Get a receiver that is notified when sources are refreshed.
73    pub fn result_receiver(&self) -> watch::Receiver<Option<RefreshResult>> {
74        self.result_rx.clone()
75    }
76
77    /// Start the scheduler background loop.
78    ///
79    /// Takes ownership of the trigger receiver and spawns per-source interval tasks.
80    /// Returns a `JoinHandle` for the main coordination task.
81    ///
82    /// When a refresh occurs (via interval timer or on-demand trigger), all sources
83    /// are re-resolved and the result is published on the watch channel.
84    pub fn run(
85        mut self,
86        sources: Vec<DynamicSource>,
87        resolver: Arc<dyn SourceResolver>,
88    ) -> tokio::task::JoinHandle<()> {
89        let trigger_rx = self
90            .trigger_rx
91            .take()
92            .expect("run() can only be called once");
93
94        tokio::spawn(async move {
95            Self::run_loop(
96                sources,
97                resolver,
98                trigger_rx,
99                self.trigger_tx,
100                self.result_tx,
101            )
102            .await;
103        })
104    }
105
106    async fn run_loop(
107        sources: Vec<DynamicSource>,
108        resolver: Arc<dyn SourceResolver>,
109        mut trigger_rx: mpsc::Receiver<RefreshTrigger>,
110        trigger_tx: mpsc::Sender<RefreshTrigger>,
111        result_tx: watch::Sender<Option<RefreshResult>>,
112    ) {
113        // Spawn interval timers
114        for source in &sources {
115            if let RefreshPolicy::Interval(duration) = &source.refresh {
116                let tx = trigger_tx.clone();
117                let id = source.id.clone();
118                let interval = *duration;
119                tokio::spawn(async move {
120                    let mut timer = tokio::time::interval(interval);
121                    timer.tick().await; // skip immediate first tick
122                    loop {
123                        timer.tick().await;
124                        if tx.send(RefreshTrigger::Single(id.clone())).await.is_err() {
125                            break;
126                        }
127                    }
128                });
129            }
130        }
131
132        // Spawn NATS push subscriptions
133        #[cfg(feature = "nats")]
134        for source in &sources {
135            if source.refresh == RefreshPolicy::Push
136                && let SourceType::Nats {
137                    url,
138                    subject,
139                    format,
140                    extract: extract_expr,
141                } = &source.source_type
142            {
143                let tx = trigger_tx.clone();
144                let id = source.id.clone();
145                let url = url.clone();
146                let subject = subject.clone();
147                let format = *format;
148                let extract_expr = extract_expr.clone();
149                tokio::spawn(async move {
150                    if let Err(e) =
151                        nats_push_loop(&url, &subject, format, extract_expr.as_ref(), &id, &tx)
152                            .await
153                    {
154                        tracing::error!(
155                            source_id = %id,
156                            error = %e,
157                            "NATS push subscription failed"
158                        );
159                    }
160                });
161            }
162        }
163
164        // Spawn file watchers for Watch policy sources
165        for source in &sources {
166            if source.refresh == RefreshPolicy::Watch
167                && let SourceType::File { path, .. } = &source.source_type
168            {
169                let tx = trigger_tx.clone();
170                let id = source.id.clone();
171                let path = path.clone();
172                tokio::spawn(async move {
173                    file_watch_loop(&path, &id, &tx).await;
174                });
175            }
176        }
177
178        // Main loop: wait for triggers and resolve
179        while let Some(trigger) = trigger_rx.recv().await {
180            // Handle NATS push with pre-parsed data (no re-resolution needed)
181            #[cfg(feature = "nats")]
182            if let RefreshTrigger::NatsPush { source_id, data } = trigger {
183                let mut resolved = HashMap::new();
184                resolved.insert(source_id, data);
185                let _ = result_tx.send(Some(RefreshResult { resolved }));
186                continue;
187            }
188
189            let to_resolve: Vec<&DynamicSource> = match &trigger {
190                RefreshTrigger::All => sources.iter().collect(),
191                RefreshTrigger::Single(id) => sources.iter().filter(|s| s.id == *id).collect(),
192                #[cfg(feature = "nats")]
193                RefreshTrigger::NatsPush { .. } => unreachable!(),
194            };
195
196            if to_resolve.is_empty() {
197                continue;
198            }
199
200            match resolve_all(
201                resolver.as_ref(),
202                &to_resolve.iter().map(|s| (*s).clone()).collect::<Vec<_>>(),
203            )
204            .await
205            {
206                Ok(resolved) => {
207                    let _ = result_tx.send(Some(RefreshResult { resolved }));
208                }
209                Err(e) => {
210                    tracing::warn!(error = %e, "Background source refresh failed");
211                }
212            }
213        }
214    }
215}
216
217impl Default for RefreshScheduler {
218    fn default() -> Self {
219        Self::new()
220    }
221}
222
/// Subscribe to a NATS subject and forward each parsed message as a
/// `RefreshTrigger::NatsPush` carrying the already-parsed data for `source_id`.
///
/// Messages that fail to parse are logged and skipped; the subscription keeps running.
///
/// # Errors
///
/// Returns an error string if connecting to `url` or subscribing to `subject` fails.
/// Returns `Ok(())` when the trigger channel is closed or the subscription stream
/// ends. Both cases are logged, because push updates for this source stop afterwards.
#[cfg(feature = "nats")]
async fn nats_push_loop(
    url: &str,
    subject: &str,
    format: rsigma_eval::pipeline::sources::DataFormat,
    extract_expr: Option<&rsigma_eval::pipeline::sources::ExtractExpr>,
    source_id: &str,
    trigger_tx: &mpsc::Sender<RefreshTrigger>,
) -> Result<(), String> {
    use futures::StreamExt;

    let client = async_nats::connect(url)
        .await
        .map_err(|e| format!("NATS connect failed: {e}"))?;

    let mut subscriber = client
        .subscribe(subject.to_string())
        .await
        .map_err(|e| format!("NATS subscribe failed: {e}"))?;

    tracing::info!(
        source_id = %source_id,
        subject = %subject,
        "NATS push subscription active"
    );

    while let Some(msg) = subscriber.next().await {
        let data = match super::nats::parse_nats_message(&msg.payload, format, extract_expr) {
            Ok(data) => data,
            Err(e) => {
                tracing::warn!(
                    source_id = %source_id,
                    error = %e,
                    "Failed to parse NATS push message"
                );
                continue;
            }
        };

        let trigger = RefreshTrigger::NatsPush {
            source_id: source_id.to_string(),
            data,
        };
        if trigger_tx.send(trigger).await.is_err() {
            tracing::debug!(
                source_id = %source_id,
                "NATS push loop: trigger channel closed, exiting"
            );
            return Ok(());
        }
    }

    // Previously silent: without this, push updates would just stop with no trace.
    tracing::warn!(
        source_id = %source_id,
        subject = %subject,
        "NATS push subscription ended; no further push updates for this source"
    );
    Ok(())
}
273
/// Default NATS subject on which control messages request source re-resolution
/// (payload format is described on `nats_control_loop`).
pub const NATS_CONTROL_SUBJECT: &str = "rsigma.control.resolve";
276
/// Listen on a NATS control subject and turn each message into a refresh trigger.
///
/// The payload is decoded as UTF-8 (lossily) and trimmed. An empty result requests
/// re-resolution of all sources (`RefreshTrigger::All`); any other text is taken as
/// the ID of a single source to re-resolve (`RefreshTrigger::Single`).
///
/// # Errors
///
/// Returns an error string if connecting or subscribing fails. Returns `Ok(())` when
/// the subscription stream ends or the trigger channel is closed.
#[cfg(feature = "nats")]
pub async fn nats_control_loop(
    url: &str,
    subject: &str,
    trigger_tx: mpsc::Sender<RefreshTrigger>,
) -> Result<(), String> {
    use futures::StreamExt;

    let connection = async_nats::connect(url)
        .await
        .map_err(|e| format!("NATS control connect failed: {e}"))?;
    let mut messages = connection
        .subscribe(subject.to_string())
        .await
        .map_err(|e| format!("NATS control subscribe failed: {e}"))?;

    tracing::info!(
        subject = %subject,
        "NATS control subscription active for source re-resolution"
    );

    while let Some(message) = messages.next().await {
        let text = String::from_utf8_lossy(&message.payload);

        let trigger = match text.trim() {
            "" => {
                tracing::debug!("NATS control: triggering all sources");
                RefreshTrigger::All
            }
            requested_id => {
                tracing::debug!(source_id = %requested_id, "NATS control: triggering single source");
                RefreshTrigger::Single(requested_id.to_owned())
            }
        };

        if trigger_tx.send(trigger).await.is_err() {
            tracing::debug!("NATS control loop: trigger channel closed, exiting");
            break;
        }
    }

    Ok(())
}
323
324/// Watch a file for changes and send refresh triggers.
325async fn file_watch_loop(
326    path: &std::path::Path,
327    source_id: &str,
328    trigger_tx: &mpsc::Sender<RefreshTrigger>,
329) {
330    use notify::{Event, EventKind, RecommendedWatcher, Watcher};
331    use tokio::sync::mpsc as tokio_mpsc;
332
333    let (notify_tx, mut notify_rx) = tokio_mpsc::channel::<()>(4);
334
335    let _watcher = {
336        let tx = notify_tx.clone();
337        match RecommendedWatcher::new(
338            move |res: Result<Event, notify::Error>| {
339                if let Ok(event) = res
340                    && matches!(event.kind, EventKind::Create(_) | EventKind::Modify(_))
341                {
342                    let _ = tx.try_send(());
343                }
344            },
345            notify::Config::default(),
346        ) {
347            Ok(mut w) => {
348                if let Err(e) = w.watch(path, notify::RecursiveMode::NonRecursive) {
349                    tracing::warn!(
350                        source_id = %source_id,
351                        path = %path.display(),
352                        error = %e,
353                        "Could not watch source file"
354                    );
355                    return;
356                }
357                tracing::info!(
358                    source_id = %source_id,
359                    path = %path.display(),
360                    "Watching source file for changes"
361                );
362                Some(w)
363            }
364            Err(e) => {
365                tracing::warn!(
366                    source_id = %source_id,
367                    error = %e,
368                    "Could not create file watcher for source"
369                );
370                return;
371            }
372        }
373    };
374
375    while notify_rx.recv().await.is_some() {
376        // Debounce: wait a short period for additional changes
377        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
378        // Drain any queued notifications
379        while notify_rx.try_recv().is_ok() {}
380
381        if trigger_tx
382            .send(RefreshTrigger::Single(source_id.to_string()))
383            .await
384            .is_err()
385        {
386            break;
387        }
388    }
389}