1use std::collections::HashMap;
10use std::sync::Arc;
11
12use rsigma_eval::pipeline::sources::{DynamicSource, RefreshPolicy, SourceType};
13use tokio::sync::{mpsc, watch};
14
15use super::{SourceResolver, resolve_all};
16
/// A request, delivered over the scheduler's trigger channel, to refresh
/// dynamic source data.
#[derive(Debug, Clone)]
pub enum RefreshTrigger {
    /// Re-resolve every source the scheduler was started with.
    All,
    /// Re-resolve only the sources whose id equals the contained string.
    /// An id that matches no source results in no work being done.
    Single(String),
    /// Data that arrived through a NATS push subscription. The scheduler
    /// publishes `data` for `source_id` as-is, without calling the resolver.
    #[cfg(feature = "nats")]
    NatsPush {
        /// Id of the source the pushed data belongs to.
        source_id: String,
        /// Payload after parsing (and optional extraction) of the NATS message.
        data: serde_json::Value,
    },
}
31
/// Outcome of one refresh pass, published on the scheduler's result channel.
#[derive(Debug, Clone)]
pub struct RefreshResult {
    /// Freshly obtained values. For NATS pushes the key is the source id;
    /// presumably resolver results are keyed by source id as well (verify
    /// against `resolve_all`). Only the sources touched by the triggering
    /// request are present, so this is not necessarily a full snapshot.
    pub resolved: HashMap<String, serde_json::Value>,
}
38
/// Background scheduler that turns refresh triggers into resolved source
/// data and publishes the results on a `watch` channel.
pub struct RefreshScheduler {
    /// Sending half of the trigger queue; cloned out to callers and to the
    /// timer / watcher / subscription tasks.
    trigger_tx: mpsc::Sender<RefreshTrigger>,
    /// Receiving half of the trigger queue. Wrapped in `Option` so that
    /// `run` can take ownership of it.
    trigger_rx: Option<mpsc::Receiver<RefreshTrigger>>,
    /// Publishes the latest refresh result (`None` until the first one).
    result_tx: watch::Sender<Option<RefreshResult>>,
    /// Template receiver that `result_receiver` clones for callers.
    result_rx: watch::Receiver<Option<RefreshResult>>,
}
53
54impl RefreshScheduler {
55 pub fn new() -> Self {
57 let (trigger_tx, trigger_rx) = mpsc::channel(32);
58 let (result_tx, result_rx) = watch::channel(None);
59 Self {
60 trigger_tx,
61 trigger_rx: Some(trigger_rx),
62 result_tx,
63 result_rx,
64 }
65 }
66
67 pub fn trigger_sender(&self) -> mpsc::Sender<RefreshTrigger> {
69 self.trigger_tx.clone()
70 }
71
72 pub fn result_receiver(&self) -> watch::Receiver<Option<RefreshResult>> {
74 self.result_rx.clone()
75 }
76
77 pub fn run(
85 mut self,
86 sources: Vec<DynamicSource>,
87 resolver: Arc<dyn SourceResolver>,
88 ) -> tokio::task::JoinHandle<()> {
89 let trigger_rx = self
90 .trigger_rx
91 .take()
92 .expect("run() can only be called once");
93
94 tokio::spawn(async move {
95 Self::run_loop(
96 sources,
97 resolver,
98 trigger_rx,
99 self.trigger_tx,
100 self.result_tx,
101 )
102 .await;
103 })
104 }
105
106 async fn run_loop(
107 sources: Vec<DynamicSource>,
108 resolver: Arc<dyn SourceResolver>,
109 mut trigger_rx: mpsc::Receiver<RefreshTrigger>,
110 trigger_tx: mpsc::Sender<RefreshTrigger>,
111 result_tx: watch::Sender<Option<RefreshResult>>,
112 ) {
113 for source in &sources {
115 if let RefreshPolicy::Interval(duration) = &source.refresh {
116 let tx = trigger_tx.clone();
117 let id = source.id.clone();
118 let interval = *duration;
119 tokio::spawn(async move {
120 let mut timer = tokio::time::interval(interval);
121 timer.tick().await; loop {
123 timer.tick().await;
124 if tx.send(RefreshTrigger::Single(id.clone())).await.is_err() {
125 break;
126 }
127 }
128 });
129 }
130 }
131
132 #[cfg(feature = "nats")]
134 for source in &sources {
135 if source.refresh == RefreshPolicy::Push
136 && let SourceType::Nats {
137 url,
138 subject,
139 format,
140 extract: extract_expr,
141 } = &source.source_type
142 {
143 let tx = trigger_tx.clone();
144 let id = source.id.clone();
145 let url = url.clone();
146 let subject = subject.clone();
147 let format = *format;
148 let extract_expr = extract_expr.clone();
149 tokio::spawn(async move {
150 if let Err(e) =
151 nats_push_loop(&url, &subject, format, extract_expr.as_ref(), &id, &tx)
152 .await
153 {
154 tracing::error!(
155 source_id = %id,
156 error = %e,
157 "NATS push subscription failed"
158 );
159 }
160 });
161 }
162 }
163
164 for source in &sources {
166 if source.refresh == RefreshPolicy::Watch
167 && let SourceType::File { path, .. } = &source.source_type
168 {
169 let tx = trigger_tx.clone();
170 let id = source.id.clone();
171 let path = path.clone();
172 tokio::spawn(async move {
173 file_watch_loop(&path, &id, &tx).await;
174 });
175 }
176 }
177
178 while let Some(trigger) = trigger_rx.recv().await {
180 #[cfg(feature = "nats")]
182 if let RefreshTrigger::NatsPush { source_id, data } = trigger {
183 let mut resolved = HashMap::new();
184 resolved.insert(source_id, data);
185 let _ = result_tx.send(Some(RefreshResult { resolved }));
186 continue;
187 }
188
189 let to_resolve: Vec<&DynamicSource> = match &trigger {
190 RefreshTrigger::All => sources.iter().collect(),
191 RefreshTrigger::Single(id) => sources.iter().filter(|s| s.id == *id).collect(),
192 #[cfg(feature = "nats")]
193 RefreshTrigger::NatsPush { .. } => unreachable!(),
194 };
195
196 if to_resolve.is_empty() {
197 continue;
198 }
199
200 match resolve_all(
201 resolver.as_ref(),
202 &to_resolve.iter().map(|s| (*s).clone()).collect::<Vec<_>>(),
203 )
204 .await
205 {
206 Ok(resolved) => {
207 let _ = result_tx.send(Some(RefreshResult { resolved }));
208 }
209 Err(e) => {
210 tracing::warn!(error = %e, "Background source refresh failed");
211 }
212 }
213 }
214 }
215}
216
217impl Default for RefreshScheduler {
218 fn default() -> Self {
219 Self::new()
220 }
221}
222
/// Subscribes to `subject` on the NATS server at `url` and forwards every
/// successfully parsed message as a [`RefreshTrigger::NatsPush`] for
/// `source_id`.
///
/// Messages that fail to parse are logged and skipped. The function returns
/// `Ok(())` when the subscription stream ends or the trigger channel is
/// closed.
///
/// # Errors
///
/// Returns a descriptive string if connecting or subscribing fails.
#[cfg(feature = "nats")]
async fn nats_push_loop(
    url: &str,
    subject: &str,
    format: rsigma_eval::pipeline::sources::DataFormat,
    extract_expr: Option<&rsigma_eval::pipeline::sources::ExtractExpr>,
    source_id: &str,
    trigger_tx: &mpsc::Sender<RefreshTrigger>,
) -> Result<(), String> {
    use futures::StreamExt;

    let connection = match async_nats::connect(url).await {
        Ok(connection) => connection,
        Err(e) => return Err(format!("NATS connect failed: {e}")),
    };

    let mut messages = match connection.subscribe(subject.to_string()).await {
        Ok(messages) => messages,
        Err(e) => return Err(format!("NATS subscribe failed: {e}")),
    };

    tracing::info!(
        source_id = %source_id,
        subject = %subject,
        "NATS push subscription active"
    );

    loop {
        let Some(msg) = messages.next().await else {
            break;
        };

        let data = match super::nats::parse_nats_message(&msg.payload, format, extract_expr) {
            Ok(data) => data,
            Err(e) => {
                tracing::warn!(
                    source_id = %source_id,
                    error = %e,
                    "Failed to parse NATS push message"
                );
                continue;
            }
        };

        let push = RefreshTrigger::NatsPush {
            source_id: source_id.to_string(),
            data,
        };
        // A closed trigger channel means the scheduler is gone; stop forwarding.
        if trigger_tx.send(push).await.is_err() {
            break;
        }
    }

    Ok(())
}
273
/// NATS subject name for control messages that request source re-resolution.
/// Presumably the default `subject` passed to `nats_control_loop` (confirm
/// at the call site; nothing in this file reads it).
pub const NATS_CONTROL_SUBJECT: &str = "rsigma.control.resolve";
276
/// Listens on a NATS control subject and converts each message into a
/// refresh trigger.
///
/// The payload is decoded lossily as UTF-8 and trimmed. An empty payload
/// requests a refresh of all sources; anything else is treated as the id of
/// a single source to refresh. Returns `Ok(())` when the subscription stream
/// ends or the trigger channel is closed.
///
/// # Errors
///
/// Returns a descriptive string if connecting or subscribing fails.
#[cfg(feature = "nats")]
pub async fn nats_control_loop(
    url: &str,
    subject: &str,
    trigger_tx: mpsc::Sender<RefreshTrigger>,
) -> Result<(), String> {
    use futures::StreamExt;

    let connection = match async_nats::connect(url).await {
        Ok(connection) => connection,
        Err(e) => return Err(format!("NATS control connect failed: {e}")),
    };

    let mut messages = match connection.subscribe(subject.to_string()).await {
        Ok(messages) => messages,
        Err(e) => return Err(format!("NATS control subscribe failed: {e}")),
    };

    tracing::info!(
        subject = %subject,
        "NATS control subscription active for source re-resolution"
    );

    loop {
        let Some(msg) = messages.next().await else {
            break;
        };

        let decoded = String::from_utf8_lossy(&msg.payload);
        let trigger = match decoded.trim() {
            "" => {
                tracing::debug!("NATS control: triggering all sources");
                RefreshTrigger::All
            }
            target => {
                tracing::debug!(source_id = %target, "NATS control: triggering single source");
                RefreshTrigger::Single(target.to_string())
            }
        };

        if trigger_tx.send(trigger).await.is_err() {
            tracing::debug!("NATS control loop: trigger channel closed, exiting");
            break;
        }
    }

    Ok(())
}
323
324async fn file_watch_loop(
326 path: &std::path::Path,
327 source_id: &str,
328 trigger_tx: &mpsc::Sender<RefreshTrigger>,
329) {
330 use notify::{Event, EventKind, RecommendedWatcher, Watcher};
331 use tokio::sync::mpsc as tokio_mpsc;
332
333 let (notify_tx, mut notify_rx) = tokio_mpsc::channel::<()>(4);
334
335 let _watcher = {
336 let tx = notify_tx.clone();
337 match RecommendedWatcher::new(
338 move |res: Result<Event, notify::Error>| {
339 if let Ok(event) = res
340 && matches!(event.kind, EventKind::Create(_) | EventKind::Modify(_))
341 {
342 let _ = tx.try_send(());
343 }
344 },
345 notify::Config::default(),
346 ) {
347 Ok(mut w) => {
348 if let Err(e) = w.watch(path, notify::RecursiveMode::NonRecursive) {
349 tracing::warn!(
350 source_id = %source_id,
351 path = %path.display(),
352 error = %e,
353 "Could not watch source file"
354 );
355 return;
356 }
357 tracing::info!(
358 source_id = %source_id,
359 path = %path.display(),
360 "Watching source file for changes"
361 );
362 Some(w)
363 }
364 Err(e) => {
365 tracing::warn!(
366 source_id = %source_id,
367 error = %e,
368 "Could not create file watcher for source"
369 );
370 return;
371 }
372 }
373 };
374
375 while notify_rx.recv().await.is_some() {
376 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
378 while notify_rx.try_recv().is_ok() {}
380
381 if trigger_tx
382 .send(RefreshTrigger::Single(source_id.to_string()))
383 .await
384 .is_err()
385 {
386 break;
387 }
388 }
389}