1use std::collections::HashMap;
10use std::sync::Arc;
11
12use rsigma_eval::pipeline::sources::{DynamicSource, RefreshPolicy, SourceType};
13use tokio::sync::{mpsc, watch};
14
15use super::{SourceResolver, resolve_all};
16
/// A request delivered to the refresh loop started by [`RefreshScheduler::run`].
#[derive(Debug, Clone)]
pub enum RefreshTrigger {
    /// Re-resolve every configured source.
    All,
    /// Re-resolve only the configured source with this id; ignored if no
    /// configured source has that id.
    Single(String),
    /// Data already received and parsed from a NATS push subscription. It is
    /// published as-is without calling the resolver.
    #[cfg(feature = "nats")]
    NatsPush {
        /// Id of the source this data belongs to.
        source_id: String,
        /// Parsed (and, if configured, extracted) message payload.
        data: serde_json::Value,
    },
}
31
/// A batch of freshly resolved source values published by the scheduler.
///
/// Only the sources covered by the trigger that produced this result are
/// present, so consumers should merge it into their state rather than replace
/// their state with it.
///
/// NOTE(review): results are published on a `watch` channel, which retains only
/// the most recent value; a slow consumer may never observe an intermediate
/// partial result.
#[derive(Debug, Clone)]
pub struct RefreshResult {
    /// Resolved values keyed by source id.
    pub resolved: HashMap<String, serde_json::Value>,
}
38
/// Background scheduler that re-resolves dynamic sources when triggered and
/// publishes each result on a `watch` channel.
///
/// Typical use: create it with [`RefreshScheduler::new`], take a
/// [`trigger_sender`](Self::trigger_sender) and/or
/// [`result_receiver`](Self::result_receiver), then consume it with
/// [`run`](Self::run).
pub struct RefreshScheduler {
    /// Sending half of the trigger queue; cloned for callers and background tasks.
    trigger_tx: mpsc::Sender<RefreshTrigger>,
    /// Receiving half of the trigger queue; taken out by `run`.
    trigger_rx: Option<mpsc::Receiver<RefreshTrigger>>,
    /// Publishes the most recent refresh result.
    result_tx: watch::Sender<Option<RefreshResult>>,
    /// Kept so `result_receiver` can hand out clones before `run` is called.
    result_rx: watch::Receiver<Option<RefreshResult>>,
}
53
54impl RefreshScheduler {
55 pub fn new() -> Self {
57 let (trigger_tx, trigger_rx) = mpsc::channel(32);
58 let (result_tx, result_rx) = watch::channel(None);
59 Self {
60 trigger_tx,
61 trigger_rx: Some(trigger_rx),
62 result_tx,
63 result_rx,
64 }
65 }
66
67 pub fn trigger_sender(&self) -> mpsc::Sender<RefreshTrigger> {
69 self.trigger_tx.clone()
70 }
71
72 pub fn result_receiver(&self) -> watch::Receiver<Option<RefreshResult>> {
74 self.result_rx.clone()
75 }
76
77 pub fn run(
85 mut self,
86 sources: Vec<DynamicSource>,
87 resolver: Arc<dyn SourceResolver>,
88 ) -> tokio::task::JoinHandle<()> {
89 let trigger_rx = self
90 .trigger_rx
91 .take()
92 .expect("run() can only be called once");
93
94 tokio::spawn(async move {
95 Self::run_loop(
96 sources,
97 resolver,
98 trigger_rx,
99 self.trigger_tx,
100 self.result_tx,
101 )
102 .await;
103 })
104 }
105
106 async fn run_loop(
107 sources: Vec<DynamicSource>,
108 resolver: Arc<dyn SourceResolver>,
109 mut trigger_rx: mpsc::Receiver<RefreshTrigger>,
110 trigger_tx: mpsc::Sender<RefreshTrigger>,
111 result_tx: watch::Sender<Option<RefreshResult>>,
112 ) {
113 for source in &sources {
115 if let RefreshPolicy::Interval(duration) = &source.refresh {
116 let tx = trigger_tx.clone();
117 let id = source.id.clone();
118 let interval = if *duration < super::MIN_REFRESH_INTERVAL {
119 tracing::warn!(
120 source_id = %id,
121 configured = ?duration,
122 clamped_to = ?super::MIN_REFRESH_INTERVAL,
123 "Refresh interval below minimum, clamping to floor"
124 );
125 super::MIN_REFRESH_INTERVAL
126 } else {
127 *duration
128 };
129 tokio::spawn(async move {
130 let mut timer = tokio::time::interval(interval);
131 timer.tick().await; loop {
133 timer.tick().await;
134 if tx.send(RefreshTrigger::Single(id.clone())).await.is_err() {
135 break;
136 }
137 }
138 });
139 }
140 }
141
142 #[cfg(feature = "nats")]
144 for source in &sources {
145 if source.refresh == RefreshPolicy::Push
146 && let SourceType::Nats {
147 url,
148 subject,
149 format,
150 extract: extract_expr,
151 } = &source.source_type
152 {
153 let tx = trigger_tx.clone();
154 let id = source.id.clone();
155 let url = url.clone();
156 let subject = subject.clone();
157 let format = *format;
158 let extract_expr = extract_expr.clone();
159 tokio::spawn(async move {
160 if let Err(e) =
161 nats_push_loop(&url, &subject, format, extract_expr.as_ref(), &id, &tx)
162 .await
163 {
164 tracing::error!(
165 source_id = %id,
166 error = %e,
167 "NATS push subscription failed"
168 );
169 }
170 });
171 }
172 }
173
174 for source in &sources {
176 if source.refresh == RefreshPolicy::Watch
177 && let SourceType::File { path, .. } = &source.source_type
178 {
179 let tx = trigger_tx.clone();
180 let id = source.id.clone();
181 let path = path.clone();
182 tokio::spawn(async move {
183 file_watch_loop(&path, &id, &tx).await;
184 });
185 }
186 }
187
188 while let Some(trigger) = trigger_rx.recv().await {
190 #[cfg(feature = "nats")]
192 if let RefreshTrigger::NatsPush { source_id, data } = trigger {
193 let mut resolved = HashMap::new();
194 resolved.insert(source_id, data);
195 let _ = result_tx.send(Some(RefreshResult { resolved }));
196 continue;
197 }
198
199 let to_resolve: Vec<&DynamicSource> = match &trigger {
200 RefreshTrigger::All => sources.iter().collect(),
201 RefreshTrigger::Single(id) => sources.iter().filter(|s| s.id == *id).collect(),
202 #[cfg(feature = "nats")]
203 RefreshTrigger::NatsPush { .. } => unreachable!(),
204 };
205
206 if to_resolve.is_empty() {
207 continue;
208 }
209
210 let refresh_count = to_resolve.len();
211 let refresh_start = std::time::Instant::now();
212 match resolve_all(
213 resolver.as_ref(),
214 &to_resolve.iter().map(|s| (*s).clone()).collect::<Vec<_>>(),
215 )
216 .await
217 {
218 Ok(resolved) => {
219 tracing::debug!(
220 sources = refresh_count,
221 duration_ms = refresh_start.elapsed().as_millis() as u64,
222 "Scheduled refresh completed",
223 );
224 let _ = result_tx.send(Some(RefreshResult { resolved }));
225 }
226 Err(e) => {
227 tracing::warn!(
228 error = %e,
229 sources = refresh_count,
230 duration_ms = refresh_start.elapsed().as_millis() as u64,
231 "Background source refresh failed",
232 );
233 }
234 }
235 }
236 }
237}
238
239impl Default for RefreshScheduler {
240 fn default() -> Self {
241 Self::new()
242 }
243}
244
/// Subscribes to `subject` on the NATS server at `url` and forwards each
/// message as a [`RefreshTrigger::NatsPush`] for `source_id`.
///
/// Payloads are decoded with `format` and, when `extract_expr` is given,
/// narrowed by it via `super::nats::parse_nats_message`. Messages that fail to
/// parse are logged and skipped.
///
/// # Errors
///
/// Returns an error if connecting or subscribing fails, or if the subscription
/// stream ends while the scheduler is still accepting triggers. Without that
/// error, the caller (which only logs `Err`) would let the source stop updating
/// silently. Returns `Ok(())` when the trigger channel is closed, i.e. the
/// scheduler has shut down.
#[cfg(feature = "nats")]
async fn nats_push_loop(
    url: &str,
    subject: &str,
    format: rsigma_eval::pipeline::sources::DataFormat,
    extract_expr: Option<&rsigma_eval::pipeline::sources::ExtractExpr>,
    source_id: &str,
    trigger_tx: &mpsc::Sender<RefreshTrigger>,
) -> Result<(), String> {
    use futures::StreamExt;

    let client = async_nats::connect(url)
        .await
        .map_err(|e| format!("NATS connect failed: {e}"))?;

    let mut subscriber = client
        .subscribe(subject.to_string())
        .await
        .map_err(|e| format!("NATS subscribe failed: {e}"))?;

    tracing::info!(
        source_id = %source_id,
        subject = %subject,
        "NATS push subscription active"
    );

    while let Some(msg) = subscriber.next().await {
        match super::nats::parse_nats_message(&msg.payload, format, extract_expr) {
            Ok(data) => {
                let trigger = RefreshTrigger::NatsPush {
                    source_id: source_id.to_string(),
                    data,
                };
                if trigger_tx.send(trigger).await.is_err() {
                    // The scheduler dropped its receiver: a normal shutdown.
                    return Ok(());
                }
            }
            Err(e) => {
                tracing::warn!(
                    source_id = %source_id,
                    error = %e,
                    "Failed to parse NATS push message"
                );
            }
        }
    }

    // `client` is still alive here, so the stream ending means the subscription
    // was closed underneath us; report it so the caller logs it.
    Err(format!("NATS push subscription to '{subject}' ended unexpectedly"))
}
295
/// NATS subject name for source re-resolution control messages.
///
/// NOTE(review): presumably passed as `subject` to `nats_control_loop`; confirm
/// at the call site. Unlike that function, this constant is not gated on the
/// `nats` feature.
pub const NATS_CONTROL_SUBJECT: &str = "rsigma.control.resolve";
298
/// Listens on `subject` for control messages requesting source re-resolution
/// and forwards them to the scheduler through `trigger_tx`.
///
/// Each payload is decoded as lossy UTF-8 and trimmed. An empty payload
/// requests [`RefreshTrigger::All`]; any other text is taken as a source id and
/// requests [`RefreshTrigger::Single`].
///
/// # Errors
///
/// Returns an error string if connecting to `url` or subscribing fails. Returns
/// `Ok(())` when the subscription stream ends or the trigger channel closes.
#[cfg(feature = "nats")]
pub async fn nats_control_loop(
    url: &str,
    subject: &str,
    trigger_tx: mpsc::Sender<RefreshTrigger>,
) -> Result<(), String> {
    use futures::StreamExt;

    let client = match async_nats::connect(url).await {
        Ok(client) => client,
        Err(e) => return Err(format!("NATS control connect failed: {e}")),
    };

    let mut subscriber = match client.subscribe(subject.to_string()).await {
        Ok(subscriber) => subscriber,
        Err(e) => return Err(format!("NATS control subscribe failed: {e}")),
    };

    tracing::info!(
        subject = %subject,
        "NATS control subscription active for source re-resolution"
    );

    loop {
        let Some(msg) = subscriber.next().await else {
            break;
        };

        let text = String::from_utf8_lossy(&msg.payload);
        let trigger = match text.trim() {
            "" => {
                tracing::debug!("NATS control: triggering all sources");
                RefreshTrigger::All
            }
            requested_id => {
                tracing::debug!(
                    source_id = %requested_id,
                    "NATS control: triggering single source"
                );
                RefreshTrigger::Single(requested_id.to_owned())
            }
        };

        if trigger_tx.send(trigger).await.is_err() {
            tracing::debug!("NATS control loop: trigger channel closed, exiting");
            break;
        }
    }

    Ok(())
}
345
346async fn file_watch_loop(
348 path: &std::path::Path,
349 source_id: &str,
350 trigger_tx: &mpsc::Sender<RefreshTrigger>,
351) {
352 use notify::{Event, EventKind, RecommendedWatcher, Watcher};
353 use tokio::sync::mpsc as tokio_mpsc;
354
355 let (notify_tx, mut notify_rx) = tokio_mpsc::channel::<()>(4);
356
357 let _watcher = {
358 let tx = notify_tx.clone();
359 match RecommendedWatcher::new(
360 move |res: Result<Event, notify::Error>| {
361 if let Ok(event) = res
362 && matches!(event.kind, EventKind::Create(_) | EventKind::Modify(_))
363 {
364 let _ = tx.try_send(());
365 }
366 },
367 notify::Config::default(),
368 ) {
369 Ok(mut w) => {
370 if let Err(e) = w.watch(path, notify::RecursiveMode::NonRecursive) {
371 tracing::warn!(
372 source_id = %source_id,
373 path = %path.display(),
374 error = %e,
375 "Could not watch source file"
376 );
377 return;
378 }
379 tracing::info!(
380 source_id = %source_id,
381 path = %path.display(),
382 "Watching source file for changes"
383 );
384 Some(w)
385 }
386 Err(e) => {
387 tracing::warn!(
388 source_id = %source_id,
389 error = %e,
390 "Could not create file watcher for source"
391 );
392 return;
393 }
394 }
395 };
396
397 while notify_rx.recv().await.is_some() {
398 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
400 while notify_rx.try_recv().is_ok() {}
402
403 if trigger_tx
404 .send(RefreshTrigger::Single(source_id.to_string()))
405 .await
406 .is_err()
407 {
408 break;
409 }
410 }
411}