1use std::collections::HashMap;
10use std::sync::Arc;
11
12use rsigma_eval::pipeline::sources::{DynamicSource, RefreshPolicy, SourceType};
13use tokio::sync::{mpsc, watch};
14
15use super::{SourceResolver, resolve_all};
16
/// A request asking the scheduler loop to publish fresh source data.
///
/// Triggers are sent over the scheduler's bounded `mpsc` channel.
#[derive(Debug, Clone)]
pub enum RefreshTrigger {
    /// Re-resolve every configured source.
    All,
    /// Re-resolve only the source whose id equals the contained string.
    Single(String),
    /// Data that was already received and parsed from a NATS message for one
    /// source; the scheduler publishes it directly instead of calling the
    /// resolver.
    #[cfg(feature = "nats")]
    NatsPush {
        /// Id of the source the pushed data belongs to.
        source_id: String,
        /// Parsed payload of the pushed message.
        data: serde_json::Value,
    },
}
31
32#[derive(Debug, Clone)]
34pub struct RefreshResult {
35 pub resolved: HashMap<String, serde_json::Value>,
37}
38
39pub struct RefreshScheduler {
44 trigger_tx: mpsc::Sender<RefreshTrigger>,
46 trigger_rx: Option<mpsc::Receiver<RefreshTrigger>>,
48 result_tx: watch::Sender<Option<RefreshResult>>,
50 result_rx: watch::Receiver<Option<RefreshResult>>,
52}
53
54impl RefreshScheduler {
55 pub fn new() -> Self {
57 let (trigger_tx, trigger_rx) = mpsc::channel(32);
58 let (result_tx, result_rx) = watch::channel(None);
59 Self {
60 trigger_tx,
61 trigger_rx: Some(trigger_rx),
62 result_tx,
63 result_rx,
64 }
65 }
66
67 pub fn trigger_sender(&self) -> mpsc::Sender<RefreshTrigger> {
69 self.trigger_tx.clone()
70 }
71
72 pub fn result_receiver(&self) -> watch::Receiver<Option<RefreshResult>> {
74 self.result_rx.clone()
75 }
76
77 pub fn subscribe(
87 self,
88 sources: Vec<DynamicSource>,
89 resolver: Arc<dyn SourceResolver>,
90 ) -> SourceSubscription {
91 let results = self.result_receiver();
92 let trigger = self.trigger_sender();
93 let handle = self.run(sources, resolver);
94 SourceSubscription {
95 handle,
96 results,
97 trigger,
98 }
99 }
100
101 pub fn run(
109 mut self,
110 sources: Vec<DynamicSource>,
111 resolver: Arc<dyn SourceResolver>,
112 ) -> tokio::task::JoinHandle<()> {
113 let trigger_rx = self
114 .trigger_rx
115 .take()
116 .expect("run() can only be called once");
117
118 tokio::spawn(async move {
119 Self::run_loop(
120 sources,
121 resolver,
122 trigger_rx,
123 self.trigger_tx,
124 self.result_tx,
125 )
126 .await;
127 })
128 }
129
130 async fn run_loop(
131 sources: Vec<DynamicSource>,
132 resolver: Arc<dyn SourceResolver>,
133 mut trigger_rx: mpsc::Receiver<RefreshTrigger>,
134 trigger_tx: mpsc::Sender<RefreshTrigger>,
135 result_tx: watch::Sender<Option<RefreshResult>>,
136 ) {
137 for source in &sources {
139 if let RefreshPolicy::Interval(duration) = &source.refresh {
140 let tx = trigger_tx.clone();
141 let id = source.id.clone();
142 let interval = if *duration < super::MIN_REFRESH_INTERVAL {
143 tracing::warn!(
144 source_id = %id,
145 configured = ?duration,
146 clamped_to = ?super::MIN_REFRESH_INTERVAL,
147 "Refresh interval below minimum, clamping to floor"
148 );
149 super::MIN_REFRESH_INTERVAL
150 } else {
151 *duration
152 };
153 tokio::spawn(async move {
154 let mut timer = tokio::time::interval(interval);
155 timer.tick().await; loop {
157 timer.tick().await;
158 if tx.send(RefreshTrigger::Single(id.clone())).await.is_err() {
159 break;
160 }
161 }
162 });
163 }
164 }
165
166 #[cfg(feature = "nats")]
168 for source in &sources {
169 if source.refresh == RefreshPolicy::Push
170 && let SourceType::Nats {
171 url,
172 subject,
173 format,
174 extract: extract_expr,
175 } = &source.source_type
176 {
177 let tx = trigger_tx.clone();
178 let id = source.id.clone();
179 let url = url.clone();
180 let subject = subject.clone();
181 let format = *format;
182 let extract_expr = extract_expr.clone();
183 tokio::spawn(async move {
184 if let Err(e) =
185 nats_push_loop(&url, &subject, format, extract_expr.as_ref(), &id, &tx)
186 .await
187 {
188 tracing::error!(
189 source_id = %id,
190 error = %e,
191 "NATS push subscription failed"
192 );
193 }
194 });
195 }
196 }
197
198 for source in &sources {
200 if source.refresh == RefreshPolicy::Watch
201 && let SourceType::File { path, .. } = &source.source_type
202 {
203 let tx = trigger_tx.clone();
204 let id = source.id.clone();
205 let path = path.clone();
206 tokio::spawn(async move {
207 file_watch_loop(&path, &id, &tx).await;
208 });
209 }
210 }
211
212 while let Some(trigger) = trigger_rx.recv().await {
214 #[cfg(feature = "nats")]
216 if let RefreshTrigger::NatsPush { source_id, data } = trigger {
217 let mut resolved = HashMap::new();
218 resolved.insert(source_id, data);
219 let _ = result_tx.send(Some(RefreshResult { resolved }));
220 continue;
221 }
222
223 let to_resolve: Vec<&DynamicSource> = match &trigger {
224 RefreshTrigger::All => sources.iter().collect(),
225 RefreshTrigger::Single(id) => sources.iter().filter(|s| s.id == *id).collect(),
226 #[cfg(feature = "nats")]
227 RefreshTrigger::NatsPush { .. } => unreachable!(),
228 };
229
230 if to_resolve.is_empty() {
231 continue;
232 }
233
234 let refresh_count = to_resolve.len();
235 let refresh_start = std::time::Instant::now();
236 match resolve_all(
237 resolver.as_ref(),
238 &to_resolve.iter().map(|s| (*s).clone()).collect::<Vec<_>>(),
239 )
240 .await
241 {
242 Ok(resolved) => {
243 tracing::debug!(
244 sources = refresh_count,
245 duration_ms = refresh_start.elapsed().as_millis() as u64,
246 "Scheduled refresh completed",
247 );
248 let _ = result_tx.send(Some(RefreshResult { resolved }));
249 }
250 Err(e) => {
251 tracing::warn!(
252 error = %e,
253 sources = refresh_count,
254 duration_ms = refresh_start.elapsed().as_millis() as u64,
255 "Background source refresh failed",
256 );
257 }
258 }
259 }
260 }
261}
262
263impl Default for RefreshScheduler {
264 fn default() -> Self {
265 Self::new()
266 }
267}
268
269pub struct SourceSubscription {
276 pub handle: tokio::task::JoinHandle<()>,
278 pub results: watch::Receiver<Option<RefreshResult>>,
280 pub trigger: mpsc::Sender<RefreshTrigger>,
282}
283
/// Subscribes to `subject` on the NATS server at `url` and forwards every
/// successfully parsed message to the scheduler as a
/// [`RefreshTrigger::NatsPush`] for `source_id`.
///
/// Messages that fail to parse are logged and skipped. The loop ends when the
/// subscription stream finishes or the trigger channel is closed.
///
/// # Errors
///
/// Returns a descriptive message if connecting or subscribing fails.
#[cfg(feature = "nats")]
async fn nats_push_loop(
    url: &str,
    subject: &str,
    format: rsigma_eval::pipeline::sources::DataFormat,
    extract_expr: Option<&rsigma_eval::pipeline::sources::ExtractExpr>,
    source_id: &str,
    trigger_tx: &mpsc::Sender<RefreshTrigger>,
) -> Result<(), String> {
    use futures::StreamExt;

    let client = match async_nats::connect(url).await {
        Ok(connected) => connected,
        Err(e) => return Err(format!("NATS connect failed: {e}")),
    };

    let mut subscriber = match client.subscribe(subject.to_string()).await {
        Ok(subscription) => subscription,
        Err(e) => return Err(format!("NATS subscribe failed: {e}")),
    };

    tracing::info!(
        source_id = %source_id,
        subject = %subject,
        "NATS push subscription active"
    );

    while let Some(msg) = subscriber.next().await {
        let data = match super::nats::parse_nats_message(&msg.payload, format, extract_expr) {
            Ok(parsed) => parsed,
            Err(e) => {
                tracing::warn!(
                    source_id = %source_id,
                    error = %e,
                    "Failed to parse NATS push message"
                );
                continue;
            }
        };

        let pushed = RefreshTrigger::NatsPush {
            source_id: source_id.to_string(),
            data,
        };
        // A send error means the scheduler's receiver is gone: stop.
        if trigger_tx.send(pushed).await.is_err() {
            break;
        }
    }

    Ok(())
}
334
/// NATS subject name for source re-resolution control messages.
///
/// NOTE(review): presumably the default `subject` passed to the NATS control
/// loop — confirm against callers.
pub const NATS_CONTROL_SUBJECT: &str = "rsigma.control.resolve";
337
/// Listens on a NATS control `subject` and converts each message into a
/// refresh trigger.
///
/// The payload is decoded lossily as UTF-8 and trimmed: an empty payload
/// requests a refresh of all sources, anything else is treated as the id of a
/// single source. The loop ends when the subscription stream finishes or the
/// trigger channel is closed.
///
/// # Errors
///
/// Returns a descriptive message if connecting or subscribing fails.
#[cfg(feature = "nats")]
pub async fn nats_control_loop(
    url: &str,
    subject: &str,
    trigger_tx: mpsc::Sender<RefreshTrigger>,
) -> Result<(), String> {
    use futures::StreamExt;

    let client = match async_nats::connect(url).await {
        Ok(connected) => connected,
        Err(e) => return Err(format!("NATS control connect failed: {e}")),
    };

    let mut subscriber = match client.subscribe(subject.to_string()).await {
        Ok(subscription) => subscription,
        Err(e) => return Err(format!("NATS control subscribe failed: {e}")),
    };

    tracing::info!(
        subject = %subject,
        "NATS control subscription active for source re-resolution"
    );

    while let Some(msg) = subscriber.next().await {
        let raw = String::from_utf8_lossy(&msg.payload);

        let trigger = match raw.trim() {
            "" => {
                tracing::debug!("NATS control: triggering all sources");
                RefreshTrigger::All
            }
            id => {
                tracing::debug!(source_id = %id, "NATS control: triggering single source");
                RefreshTrigger::Single(id.to_string())
            }
        };

        if trigger_tx.send(trigger).await.is_err() {
            tracing::debug!("NATS control loop: trigger channel closed, exiting");
            break;
        }
    }

    Ok(())
}
384
385async fn file_watch_loop(
387 path: &std::path::Path,
388 source_id: &str,
389 trigger_tx: &mpsc::Sender<RefreshTrigger>,
390) {
391 use notify::{Event, EventKind, RecommendedWatcher, Watcher};
392 use tokio::sync::mpsc as tokio_mpsc;
393
394 let (notify_tx, mut notify_rx) = tokio_mpsc::channel::<()>(4);
395
396 let _watcher = {
397 let tx = notify_tx.clone();
398 match RecommendedWatcher::new(
399 move |res: Result<Event, notify::Error>| {
400 if let Ok(event) = res
401 && matches!(event.kind, EventKind::Create(_) | EventKind::Modify(_))
402 {
403 let _ = tx.try_send(());
404 }
405 },
406 notify::Config::default(),
407 ) {
408 Ok(mut w) => {
409 if let Err(e) = w.watch(path, notify::RecursiveMode::NonRecursive) {
410 tracing::warn!(
411 source_id = %source_id,
412 path = %path.display(),
413 error = %e,
414 "Could not watch source file"
415 );
416 return;
417 }
418 tracing::info!(
419 source_id = %source_id,
420 path = %path.display(),
421 "Watching source file for changes"
422 );
423 Some(w)
424 }
425 Err(e) => {
426 tracing::warn!(
427 source_id = %source_id,
428 error = %e,
429 "Could not create file watcher for source"
430 );
431 return;
432 }
433 }
434 };
435
436 while notify_rx.recv().await.is_some() {
437 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
439 while notify_rx.try_recv().is_ok() {}
441
442 if trigger_tx
443 .send(RefreshTrigger::Single(source_id.to_string()))
444 .await
445 .is_err()
446 {
447 break;
448 }
449 }
450}
451
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sources::DefaultSourceResolver;
    use rsigma_eval::pipeline::sources::{DataFormat, ErrorPolicy, SourceType};

    /// Builds a required, on-demand JSON file source with the given id and path.
    fn file_source(id: &str, path: std::path::PathBuf) -> DynamicSource {
        let source_type = SourceType::File {
            path,
            format: DataFormat::Json,
            extract: None,
        };
        DynamicSource {
            id: id.to_string(),
            source_type,
            refresh: RefreshPolicy::OnDemand,
            timeout: None,
            on_error: ErrorPolicy::Fail,
            required: true,
            default: None,
        }
    }

    /// An explicit `All` trigger must publish the parsed file contents under
    /// the source's id on the result channel.
    #[tokio::test]
    async fn subscribe_delivers_decoded_payload() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("dispositions.json");
        std::fs::write(
            &path,
            r#"{"rules": [{"rule_id": "r1", "verdict": "false_positive"}]}"#,
        )
        .unwrap();

        let resolver = Arc::new(DefaultSourceResolver::new());
        let SourceSubscription {
            mut results,
            trigger,
            ..
        } = RefreshScheduler::new().subscribe(vec![file_source("d", path)], resolver);

        trigger.send(RefreshTrigger::All).await.unwrap();
        results.changed().await.unwrap();

        let payload = results.borrow().clone().expect("a refresh result");
        let data = payload.resolved.get("d").expect("source d resolved");
        let first_rule = &data["rules"][0];
        assert_eq!(first_rule["rule_id"], "r1");
        assert_eq!(first_rule["verdict"], "false_positive");
    }
}