1use std::collections::HashMap;
10use std::sync::Arc;
11
12use rsigma_eval::pipeline::sources::{DynamicSource, RefreshPolicy, SourceType};
13use tokio::sync::{mpsc, watch};
14
15use super::{SourceResolver, resolve_all};
16
/// A request asking the refresh scheduler to update dynamic source data.
#[derive(Debug, Clone)]
pub enum RefreshTrigger {
    /// Re-resolve every configured source.
    All,
    /// Re-resolve only the source whose id matches this string.
    Single(String),
    /// Data already received from a NATS push subscription. The scheduler
    /// publishes it directly, without invoking the source resolver.
    #[cfg(feature = "nats")]
    NatsPush {
        /// Id of the source this data belongs to.
        source_id: String,
        /// Parsed message payload.
        data: serde_json::Value,
    },
}
31
32#[derive(Debug, Clone)]
34pub struct RefreshResult {
35 pub resolved: HashMap<String, serde_json::Value>,
37}
38
39pub struct RefreshScheduler {
44 trigger_tx: mpsc::Sender<RefreshTrigger>,
46 trigger_rx: Option<mpsc::Receiver<RefreshTrigger>>,
48 result_tx: watch::Sender<Option<RefreshResult>>,
50 result_rx: watch::Receiver<Option<RefreshResult>>,
52}
53
54impl RefreshScheduler {
55 pub fn new() -> Self {
57 let (trigger_tx, trigger_rx) = mpsc::channel(32);
58 let (result_tx, result_rx) = watch::channel(None);
59 Self {
60 trigger_tx,
61 trigger_rx: Some(trigger_rx),
62 result_tx,
63 result_rx,
64 }
65 }
66
67 pub fn trigger_sender(&self) -> mpsc::Sender<RefreshTrigger> {
69 self.trigger_tx.clone()
70 }
71
72 pub fn result_receiver(&self) -> watch::Receiver<Option<RefreshResult>> {
74 self.result_rx.clone()
75 }
76
77 pub fn run(
85 mut self,
86 sources: Vec<DynamicSource>,
87 resolver: Arc<dyn SourceResolver>,
88 ) -> tokio::task::JoinHandle<()> {
89 let trigger_rx = self
90 .trigger_rx
91 .take()
92 .expect("run() can only be called once");
93
94 tokio::spawn(async move {
95 Self::run_loop(
96 sources,
97 resolver,
98 trigger_rx,
99 self.trigger_tx,
100 self.result_tx,
101 )
102 .await;
103 })
104 }
105
106 async fn run_loop(
107 sources: Vec<DynamicSource>,
108 resolver: Arc<dyn SourceResolver>,
109 mut trigger_rx: mpsc::Receiver<RefreshTrigger>,
110 trigger_tx: mpsc::Sender<RefreshTrigger>,
111 result_tx: watch::Sender<Option<RefreshResult>>,
112 ) {
113 for source in &sources {
115 if let RefreshPolicy::Interval(duration) = &source.refresh {
116 let tx = trigger_tx.clone();
117 let id = source.id.clone();
118 let interval = if *duration < super::MIN_REFRESH_INTERVAL {
119 tracing::warn!(
120 source_id = %id,
121 configured = ?duration,
122 clamped_to = ?super::MIN_REFRESH_INTERVAL,
123 "Refresh interval below minimum, clamping to floor"
124 );
125 super::MIN_REFRESH_INTERVAL
126 } else {
127 *duration
128 };
129 tokio::spawn(async move {
130 let mut timer = tokio::time::interval(interval);
131 timer.tick().await; loop {
133 timer.tick().await;
134 if tx.send(RefreshTrigger::Single(id.clone())).await.is_err() {
135 break;
136 }
137 }
138 });
139 }
140 }
141
142 #[cfg(feature = "nats")]
144 for source in &sources {
145 if source.refresh == RefreshPolicy::Push
146 && let SourceType::Nats {
147 url,
148 subject,
149 format,
150 extract: extract_expr,
151 } = &source.source_type
152 {
153 let tx = trigger_tx.clone();
154 let id = source.id.clone();
155 let url = url.clone();
156 let subject = subject.clone();
157 let format = *format;
158 let extract_expr = extract_expr.clone();
159 tokio::spawn(async move {
160 if let Err(e) =
161 nats_push_loop(&url, &subject, format, extract_expr.as_ref(), &id, &tx)
162 .await
163 {
164 tracing::error!(
165 source_id = %id,
166 error = %e,
167 "NATS push subscription failed"
168 );
169 }
170 });
171 }
172 }
173
174 for source in &sources {
176 if source.refresh == RefreshPolicy::Watch
177 && let SourceType::File { path, .. } = &source.source_type
178 {
179 let tx = trigger_tx.clone();
180 let id = source.id.clone();
181 let path = path.clone();
182 tokio::spawn(async move {
183 file_watch_loop(&path, &id, &tx).await;
184 });
185 }
186 }
187
188 while let Some(trigger) = trigger_rx.recv().await {
190 #[cfg(feature = "nats")]
192 if let RefreshTrigger::NatsPush { source_id, data } = trigger {
193 let mut resolved = HashMap::new();
194 resolved.insert(source_id, data);
195 let _ = result_tx.send(Some(RefreshResult { resolved }));
196 continue;
197 }
198
199 let to_resolve: Vec<&DynamicSource> = match &trigger {
200 RefreshTrigger::All => sources.iter().collect(),
201 RefreshTrigger::Single(id) => sources.iter().filter(|s| s.id == *id).collect(),
202 #[cfg(feature = "nats")]
203 RefreshTrigger::NatsPush { .. } => unreachable!(),
204 };
205
206 if to_resolve.is_empty() {
207 continue;
208 }
209
210 match resolve_all(
211 resolver.as_ref(),
212 &to_resolve.iter().map(|s| (*s).clone()).collect::<Vec<_>>(),
213 )
214 .await
215 {
216 Ok(resolved) => {
217 let _ = result_tx.send(Some(RefreshResult { resolved }));
218 }
219 Err(e) => {
220 tracing::warn!(error = %e, "Background source refresh failed");
221 }
222 }
223 }
224 }
225}
226
227impl Default for RefreshScheduler {
228 fn default() -> Self {
229 Self::new()
230 }
231}
232
/// Subscribes to `subject` on the NATS server at `url` and forwards each
/// message as a [`RefreshTrigger::NatsPush`] for `source_id`.
///
/// Messages are decoded with `super::nats::parse_nats_message` using `format`
/// and the optional `extract_expr`. Messages that fail to parse are logged and
/// skipped.
///
/// # Errors
/// Returns an error string if connecting or subscribing fails. Returns `Ok(())`
/// when the trigger channel closes (scheduler shut down). Also returns `Ok(())`
/// when the subscription stream ends, which is logged as a warning because the
/// source stops receiving updates.
#[cfg(feature = "nats")]
async fn nats_push_loop(
    url: &str,
    subject: &str,
    format: rsigma_eval::pipeline::sources::DataFormat,
    extract_expr: Option<&rsigma_eval::pipeline::sources::ExtractExpr>,
    source_id: &str,
    trigger_tx: &mpsc::Sender<RefreshTrigger>,
) -> Result<(), String> {
    use futures::StreamExt;

    let client = async_nats::connect(url)
        .await
        .map_err(|e| format!("NATS connect failed: {e}"))?;

    let mut subscriber = client
        .subscribe(subject.to_string())
        .await
        .map_err(|e| format!("NATS subscribe failed: {e}"))?;

    tracing::info!(
        source_id = %source_id,
        subject = %subject,
        "NATS push subscription active"
    );

    while let Some(msg) = subscriber.next().await {
        let data = match super::nats::parse_nats_message(&msg.payload, format, extract_expr) {
            Ok(data) => data,
            Err(e) => {
                tracing::warn!(
                    source_id = %source_id,
                    error = %e,
                    "Failed to parse NATS push message"
                );
                continue;
            }
        };

        let trigger = RefreshTrigger::NatsPush {
            source_id: source_id.to_string(),
            data,
        };
        if trigger_tx.send(trigger).await.is_err() {
            tracing::debug!(
                source_id = %source_id,
                "NATS push loop: trigger channel closed, exiting"
            );
            return Ok(());
        }
    }

    // The stream ended without a shutdown request: without this log the source
    // would silently stop receiving updates.
    tracing::warn!(
        source_id = %source_id,
        subject = %subject,
        "NATS push subscription ended; source will no longer receive updates"
    );
    Ok(())
}
283
/// NATS subject for control messages that request source re-resolution.
///
/// With `nats_control_loop`, an empty payload refreshes every source;
/// otherwise the trimmed payload names a single source id.
pub const NATS_CONTROL_SUBJECT: &str = "rsigma.control.resolve";
286
/// Listens on a NATS control subject and forwards re-resolution requests to the
/// refresh scheduler.
///
/// Each payload is decoded as UTF-8 (lossily) and trimmed:
/// - an empty payload becomes [`RefreshTrigger::All`];
/// - anything else becomes [`RefreshTrigger::Single`] with the payload as the
///   source id.
///
/// # Errors
/// Returns an error string if connecting to `url` or subscribing to `subject`
/// fails. Returns `Ok(())` when the trigger channel closes. Also returns
/// `Ok(())` when the subscription stream ends, which is logged as a warning.
#[cfg(feature = "nats")]
pub async fn nats_control_loop(
    url: &str,
    subject: &str,
    trigger_tx: mpsc::Sender<RefreshTrigger>,
) -> Result<(), String> {
    use futures::StreamExt;

    let client = async_nats::connect(url)
        .await
        .map_err(|e| format!("NATS control connect failed: {e}"))?;

    let mut subscriber = client
        .subscribe(subject.to_string())
        .await
        .map_err(|e| format!("NATS control subscribe failed: {e}"))?;

    tracing::info!(
        subject = %subject,
        "NATS control subscription active for source re-resolution"
    );

    while let Some(msg) = subscriber.next().await {
        let payload = String::from_utf8_lossy(&msg.payload);
        let payload = payload.trim();

        let trigger = if payload.is_empty() {
            tracing::debug!("NATS control: triggering all sources");
            RefreshTrigger::All
        } else {
            tracing::debug!(source_id = %payload, "NATS control: triggering single source");
            RefreshTrigger::Single(payload.to_string())
        };

        if trigger_tx.send(trigger).await.is_err() {
            tracing::debug!("NATS control loop: trigger channel closed, exiting");
            return Ok(());
        }
    }

    // The stream ended without a shutdown request; make that visible, since
    // control messages will no longer be processed.
    tracing::warn!(
        subject = %subject,
        "NATS control subscription ended; control messages will no longer be received"
    );
    Ok(())
}
333
334async fn file_watch_loop(
336 path: &std::path::Path,
337 source_id: &str,
338 trigger_tx: &mpsc::Sender<RefreshTrigger>,
339) {
340 use notify::{Event, EventKind, RecommendedWatcher, Watcher};
341 use tokio::sync::mpsc as tokio_mpsc;
342
343 let (notify_tx, mut notify_rx) = tokio_mpsc::channel::<()>(4);
344
345 let _watcher = {
346 let tx = notify_tx.clone();
347 match RecommendedWatcher::new(
348 move |res: Result<Event, notify::Error>| {
349 if let Ok(event) = res
350 && matches!(event.kind, EventKind::Create(_) | EventKind::Modify(_))
351 {
352 let _ = tx.try_send(());
353 }
354 },
355 notify::Config::default(),
356 ) {
357 Ok(mut w) => {
358 if let Err(e) = w.watch(path, notify::RecursiveMode::NonRecursive) {
359 tracing::warn!(
360 source_id = %source_id,
361 path = %path.display(),
362 error = %e,
363 "Could not watch source file"
364 );
365 return;
366 }
367 tracing::info!(
368 source_id = %source_id,
369 path = %path.display(),
370 "Watching source file for changes"
371 );
372 Some(w)
373 }
374 Err(e) => {
375 tracing::warn!(
376 source_id = %source_id,
377 error = %e,
378 "Could not create file watcher for source"
379 );
380 return;
381 }
382 }
383 };
384
385 while notify_rx.recv().await.is_some() {
386 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
388 while notify_rx.try_recv().is_ok() {}
390
391 if trigger_tx
392 .send(RefreshTrigger::Single(source_id.to_string()))
393 .await
394 .is_err()
395 {
396 break;
397 }
398 }
399}