1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use tokio::sync::{RwLock, broadcast, mpsc};
5use uuid::Uuid;
6
7use forge_core::cluster::NodeId;
8use forge_core::realtime::{Change, ReadSet, SessionId, SubscriptionId};
9
10use super::invalidation::{InvalidationConfig, InvalidationEngine};
11use super::listener::{ChangeListener, ListenerConfig};
12use super::manager::SubscriptionManager;
13use super::message::{
14 JobData, RealtimeConfig, RealtimeMessage, SessionServer, WorkflowData, WorkflowStepData,
15};
16use crate::function::{FunctionEntry, FunctionRegistry};
17
18#[derive(Debug, Clone)]
19pub struct ReactorConfig {
20 pub listener: ListenerConfig,
21 pub invalidation: InvalidationConfig,
22 pub realtime: RealtimeConfig,
23 pub max_listener_restarts: u32,
24 pub listener_restart_delay_ms: u64,
26}
27
28impl Default for ReactorConfig {
29 fn default() -> Self {
30 Self {
31 listener: ListenerConfig::default(),
32 invalidation: InvalidationConfig::default(),
33 realtime: RealtimeConfig::default(),
34 max_listener_restarts: 5,
35 listener_restart_delay_ms: 1000,
36 }
37 }
38}
39
40#[derive(Debug, Clone)]
42pub struct ActiveSubscription {
43 #[allow(dead_code)]
44 pub subscription_id: SubscriptionId,
45 pub session_id: SessionId,
46 #[allow(dead_code)]
47 pub client_sub_id: String,
48 pub query_name: String,
49 pub args: serde_json::Value,
50 pub last_result_hash: Option<String>,
51 #[allow(dead_code)]
52 pub read_set: ReadSet,
53 pub auth_context: forge_core::function::AuthContext,
55}
56
57#[derive(Debug, Clone)]
59pub struct JobSubscription {
60 #[allow(dead_code)]
61 pub subscription_id: SubscriptionId,
62 pub session_id: SessionId,
63 pub client_sub_id: String,
64 #[allow(dead_code)]
65 pub job_id: Uuid, pub auth_context: forge_core::function::AuthContext,
67}
68
69#[derive(Debug, Clone)]
71pub struct WorkflowSubscription {
72 #[allow(dead_code)]
73 pub subscription_id: SubscriptionId,
74 pub session_id: SessionId,
75 pub client_sub_id: String,
76 #[allow(dead_code)]
77 pub workflow_id: Uuid, pub auth_context: forge_core::function::AuthContext,
79}
80
81pub struct Reactor {
83 node_id: NodeId,
84 db_pool: sqlx::PgPool,
85 registry: FunctionRegistry,
86 subscription_manager: Arc<SubscriptionManager>,
87 session_server: Arc<SessionServer>,
88 change_listener: Arc<ChangeListener>,
89 invalidation_engine: Arc<InvalidationEngine>,
90 active_subscriptions: Arc<RwLock<HashMap<SubscriptionId, ActiveSubscription>>>,
92 job_subscriptions: Arc<RwLock<HashMap<Uuid, Vec<JobSubscription>>>>,
94 workflow_subscriptions: Arc<RwLock<HashMap<Uuid, Vec<WorkflowSubscription>>>>,
96 shutdown_tx: broadcast::Sender<()>,
98 max_listener_restarts: u32,
100 listener_restart_delay_ms: u64,
101}
102
103impl Reactor {
104 pub fn new(
106 node_id: NodeId,
107 db_pool: sqlx::PgPool,
108 registry: FunctionRegistry,
109 config: ReactorConfig,
110 ) -> Self {
111 let subscription_manager = Arc::new(SubscriptionManager::new(
112 config.realtime.max_subscriptions_per_session,
113 ));
114 let session_server = Arc::new(SessionServer::new(node_id, config.realtime.clone()));
115 let change_listener = Arc::new(ChangeListener::new(db_pool.clone(), config.listener));
116 let invalidation_engine = Arc::new(InvalidationEngine::new(
117 subscription_manager.clone(),
118 config.invalidation,
119 ));
120 let (shutdown_tx, _) = broadcast::channel(1);
121
122 Self {
123 node_id,
124 db_pool,
125 registry,
126 subscription_manager,
127 session_server,
128 change_listener,
129 invalidation_engine,
130 active_subscriptions: Arc::new(RwLock::new(HashMap::new())),
131 job_subscriptions: Arc::new(RwLock::new(HashMap::new())),
132 workflow_subscriptions: Arc::new(RwLock::new(HashMap::new())),
133 shutdown_tx,
134 max_listener_restarts: config.max_listener_restarts,
135 listener_restart_delay_ms: config.listener_restart_delay_ms,
136 }
137 }
138
139 pub fn node_id(&self) -> NodeId {
141 self.node_id
142 }
143
144 pub fn session_server(&self) -> Arc<SessionServer> {
146 self.session_server.clone()
147 }
148
149 pub fn subscription_manager(&self) -> Arc<SubscriptionManager> {
151 self.subscription_manager.clone()
152 }
153
154 pub fn shutdown_receiver(&self) -> broadcast::Receiver<()> {
156 self.shutdown_tx.subscribe()
157 }
158
159 pub async fn register_session(
161 &self,
162 session_id: SessionId,
163 sender: mpsc::Sender<RealtimeMessage>,
164 ) {
165 self.session_server
166 .register_connection(session_id, sender)
167 .await;
168 tracing::trace!(?session_id, "Session registered");
169 }
170
171 pub async fn remove_session(&self, session_id: SessionId) {
173 if let Some(subscription_ids) = self.session_server.remove_connection(session_id).await {
174 for sub_id in subscription_ids {
176 self.subscription_manager.remove_subscription(sub_id).await;
177 self.active_subscriptions.write().await.remove(&sub_id);
178 }
179 }
180
181 {
183 let mut job_subs = self.job_subscriptions.write().await;
184 for subscribers in job_subs.values_mut() {
185 subscribers.retain(|s| s.session_id != session_id);
186 }
187 job_subs.retain(|_, v| !v.is_empty());
189 }
190
191 {
193 let mut workflow_subs = self.workflow_subscriptions.write().await;
194 for subscribers in workflow_subs.values_mut() {
195 subscribers.retain(|s| s.session_id != session_id);
196 }
197 workflow_subs.retain(|_, v| !v.is_empty());
199 }
200
201 tracing::trace!(?session_id, "Session removed");
202 }
203
204 pub async fn subscribe(
206 &self,
207 session_id: SessionId,
208 client_sub_id: String,
209 query_name: String,
210 args: serde_json::Value,
211 auth_context: forge_core::function::AuthContext,
212 ) -> forge_core::Result<(SubscriptionId, serde_json::Value)> {
213 let sub_info = self
214 .subscription_manager
215 .create_subscription(session_id, &query_name, args.clone())
216 .await?;
217
218 let subscription_id = sub_info.id;
219
220 if let Err(error) = self
221 .session_server
222 .add_subscription(session_id, subscription_id)
223 .await
224 {
225 self.subscription_manager
226 .remove_subscription(subscription_id)
227 .await;
228 return Err(error);
229 }
230
231 let (data, read_set) = match self.execute_query(&query_name, &args, &auth_context).await {
232 Ok(result) => result,
233 Err(error) => {
234 self.unsubscribe(subscription_id).await;
236 return Err(error);
237 }
238 };
239
240 let result_hash = Self::compute_hash(&data);
241
242 tracing::trace!(
243 ?subscription_id,
244 query = %query_name,
245 tables = ?read_set.tables.iter().collect::<Vec<_>>(),
246 "Subscription read set"
247 );
248
249 self.subscription_manager
250 .update_subscription(subscription_id, read_set.clone(), result_hash.clone())
251 .await;
252
253 let active = ActiveSubscription {
254 subscription_id,
255 session_id,
256 client_sub_id,
257 query_name,
258 args,
259 last_result_hash: Some(result_hash),
260 read_set,
261 auth_context,
262 };
263 self.active_subscriptions
264 .write()
265 .await
266 .insert(subscription_id, active);
267
268 tracing::trace!(?subscription_id, "Subscription created");
269
270 Ok((subscription_id, data))
271 }
272
273 pub async fn unsubscribe(&self, subscription_id: SubscriptionId) {
275 self.session_server
276 .remove_subscription(subscription_id)
277 .await;
278 self.subscription_manager
279 .remove_subscription(subscription_id)
280 .await;
281 self.active_subscriptions
282 .write()
283 .await
284 .remove(&subscription_id);
285 tracing::trace!(?subscription_id, "Subscription removed");
286 }
287
288 pub async fn subscribe_job(
290 &self,
291 session_id: SessionId,
292 client_sub_id: String,
293 job_id: Uuid, auth_context: &forge_core::function::AuthContext,
295 ) -> forge_core::Result<JobData> {
296 let subscription_id = SubscriptionId::new();
297
298 Self::ensure_job_access(&self.db_pool, job_id, auth_context).await?;
299
300 let job_data = self.fetch_job_data(job_id).await?;
302
303 let subscription = JobSubscription {
305 subscription_id,
306 session_id,
307 client_sub_id: client_sub_id.clone(),
308 job_id,
309 auth_context: auth_context.clone(),
310 };
311
312 let mut subs = self.job_subscriptions.write().await;
313 subs.entry(job_id).or_default().push(subscription);
314
315 tracing::trace!(
316 ?subscription_id,
317 %job_id,
318 "Job subscription created"
319 );
320
321 Ok(job_data)
322 }
323
324 pub async fn unsubscribe_job(&self, session_id: SessionId, client_sub_id: &str) {
326 let mut subs = self.job_subscriptions.write().await;
327
328 for subscribers in subs.values_mut() {
330 subscribers
331 .retain(|s| !(s.session_id == session_id && s.client_sub_id == client_sub_id));
332 }
333
334 subs.retain(|_, v| !v.is_empty());
336
337 tracing::trace!(client_id = %client_sub_id, "Job subscription removed");
338 }
339
340 pub async fn subscribe_workflow(
342 &self,
343 session_id: SessionId,
344 client_sub_id: String,
345 workflow_id: Uuid, auth_context: &forge_core::function::AuthContext,
347 ) -> forge_core::Result<WorkflowData> {
348 let subscription_id = SubscriptionId::new();
349
350 Self::ensure_workflow_access(&self.db_pool, workflow_id, auth_context).await?;
351
352 let workflow_data = self.fetch_workflow_data(workflow_id).await?;
354
355 let subscription = WorkflowSubscription {
357 subscription_id,
358 session_id,
359 client_sub_id: client_sub_id.clone(),
360 workflow_id,
361 auth_context: auth_context.clone(),
362 };
363
364 let mut subs = self.workflow_subscriptions.write().await;
365 subs.entry(workflow_id).or_default().push(subscription);
366
367 tracing::trace!(
368 ?subscription_id,
369 %workflow_id,
370 "Workflow subscription created"
371 );
372
373 Ok(workflow_data)
374 }
375
376 pub async fn unsubscribe_workflow(&self, session_id: SessionId, client_sub_id: &str) {
378 let mut subs = self.workflow_subscriptions.write().await;
379
380 for subscribers in subs.values_mut() {
382 subscribers
383 .retain(|s| !(s.session_id == session_id && s.client_sub_id == client_sub_id));
384 }
385
386 subs.retain(|_, v| !v.is_empty());
388
389 tracing::trace!(client_id = %client_sub_id, "Workflow subscription removed");
390 }
391
392 #[allow(clippy::type_complexity)]
394 async fn fetch_job_data(&self, job_id: Uuid) -> forge_core::Result<JobData> {
395 let row: Option<(
396 String,
397 Option<i32>,
398 Option<String>,
399 Option<serde_json::Value>,
400 Option<String>,
401 )> = sqlx::query_as(
402 r#"
403 SELECT status, progress_percent, progress_message, output,
404 COALESCE(cancel_reason, last_error) as error
405 FROM forge_jobs WHERE id = $1
406 "#,
407 )
408 .bind(job_id)
409 .fetch_optional(&self.db_pool)
410 .await
411 .map_err(forge_core::ForgeError::Sql)?;
412
413 match row {
414 Some((status, progress_percent, progress_message, output, error)) => Ok(JobData {
415 job_id: job_id.to_string(),
416 status,
417 progress_percent,
418 progress_message,
419 output,
420 error,
421 }),
422 None => Err(forge_core::ForgeError::NotFound(format!(
423 "Job {} not found",
424 job_id
425 ))),
426 }
427 }
428
429 #[allow(clippy::type_complexity)]
431 async fn fetch_workflow_data(&self, workflow_id: Uuid) -> forge_core::Result<WorkflowData> {
432 let row: Option<(
434 String,
435 Option<String>,
436 Option<serde_json::Value>,
437 Option<String>,
438 )> = sqlx::query_as(
439 r#"
440 SELECT status, current_step, output, error
441 FROM forge_workflow_runs WHERE id = $1
442 "#,
443 )
444 .bind(workflow_id)
445 .fetch_optional(&self.db_pool)
446 .await
447 .map_err(forge_core::ForgeError::Sql)?;
448
449 let (status, current_step, output, error) = match row {
450 Some(r) => r,
451 None => {
452 return Err(forge_core::ForgeError::NotFound(format!(
453 "Workflow {} not found",
454 workflow_id
455 )));
456 }
457 };
458
459 let step_rows: Vec<(String, String, Option<String>)> = sqlx::query_as(
461 r#"
462 SELECT step_name, status, error
463 FROM forge_workflow_steps
464 WHERE workflow_run_id = $1
465 ORDER BY started_at ASC NULLS LAST
466 "#,
467 )
468 .bind(workflow_id)
469 .fetch_all(&self.db_pool)
470 .await
471 .map_err(forge_core::ForgeError::Sql)?;
472
473 let steps = step_rows
474 .into_iter()
475 .map(|(name, status, error)| WorkflowStepData {
476 name,
477 status,
478 error,
479 })
480 .collect();
481
482 Ok(WorkflowData {
483 workflow_id: workflow_id.to_string(),
484 status,
485 current_step,
486 steps,
487 output,
488 error,
489 })
490 }
491
492 async fn execute_query(
494 &self,
495 query_name: &str,
496 args: &serde_json::Value,
497 auth_context: &forge_core::function::AuthContext,
498 ) -> forge_core::Result<(serde_json::Value, ReadSet)> {
499 match self.registry.get(query_name) {
500 Some(FunctionEntry::Query { info, handler }) => {
501 Self::check_query_auth(info, auth_context)?;
502 Self::check_identity_args(query_name, args, auth_context, !info.is_public)?;
503
504 let ctx = forge_core::function::QueryContext::new(
505 self.db_pool.clone(),
506 auth_context.clone(),
507 forge_core::function::RequestMetadata::new(),
508 );
509
510 let normalized_args = match args {
512 v if v.as_object().is_some_and(|o| o.is_empty()) => serde_json::Value::Null,
513 v => v.clone(),
514 };
515
516 let data = handler(&ctx, normalized_args).await?;
517
518 let mut read_set = ReadSet::new();
520
521 if info.table_dependencies.is_empty() {
522 let table_name = Self::extract_table_name(query_name);
525 read_set.add_table(&table_name);
526 tracing::trace!(
527 query = %query_name,
528 fallback_table = %table_name,
529 "Using naming convention fallback for table dependency"
530 );
531 } else {
532 for table in info.table_dependencies {
534 read_set.add_table(*table);
535 }
536 }
537
538 Ok((data, read_set))
539 }
540 Some(_) => Err(forge_core::ForgeError::Validation(format!(
541 "'{}' is not a query",
542 query_name
543 ))),
544 None => Err(forge_core::ForgeError::Validation(format!(
545 "Query '{}' not found",
546 query_name
547 ))),
548 }
549 }
550
551 fn compute_hash(data: &serde_json::Value) -> String {
553 use std::collections::hash_map::DefaultHasher;
554 use std::hash::{Hash, Hasher};
555
556 let json = serde_json::to_string(data).unwrap_or_default();
557 let mut hasher = DefaultHasher::new();
558 json.hash(&mut hasher);
559 format!("{:x}", hasher.finish())
560 }
561
562 pub async fn start(&self) -> forge_core::Result<()> {
564 let listener = self.change_listener.clone();
565 let invalidation_engine = self.invalidation_engine.clone();
566 let active_subscriptions = self.active_subscriptions.clone();
567 let job_subscriptions = self.job_subscriptions.clone();
568 let workflow_subscriptions = self.workflow_subscriptions.clone();
569 let session_server = self.session_server.clone();
570 let registry = self.registry.clone();
571 let db_pool = self.db_pool.clone();
572 let mut shutdown_rx = self.shutdown_tx.subscribe();
573 let max_restarts = self.max_listener_restarts;
574 let base_delay_ms = self.listener_restart_delay_ms;
575
576 let mut change_rx = listener.subscribe();
578
579 tokio::spawn(async move {
581 tracing::debug!("Reactor listening for changes");
582
583 let mut restart_count: u32 = 0;
584 let (listener_error_tx, mut listener_error_rx) = mpsc::channel::<String>(1);
585
586 let listener_clone = listener.clone();
588 let error_tx = listener_error_tx.clone();
589 let mut listener_handle = Some(tokio::spawn(async move {
590 if let Err(e) = listener_clone.run().await {
591 let _ = error_tx.send(format!("Change listener error: {}", e)).await;
592 }
593 }));
594
595 loop {
596 tokio::select! {
597 result = change_rx.recv() => {
598 match result {
599 Ok(change) => {
600 Self::handle_change(
601 &change,
602 &invalidation_engine,
603 &active_subscriptions,
604 &job_subscriptions,
605 &workflow_subscriptions,
606 &session_server,
607 ®istry,
608 &db_pool,
609 ).await;
610 }
611 Err(broadcast::error::RecvError::Lagged(n)) => {
612 tracing::warn!("Reactor lagged by {} messages", n);
613 }
614 Err(broadcast::error::RecvError::Closed) => {
615 tracing::debug!("Change channel closed");
616 break;
617 }
618 }
619 }
620 Some(error_msg) = listener_error_rx.recv() => {
621 if restart_count >= max_restarts {
622 tracing::error!(
623 attempts = restart_count,
624 last_error = %error_msg,
625 "Change listener failed permanently, real-time updates disabled"
626 );
627 break;
628 }
629
630 restart_count += 1;
631 let delay = base_delay_ms * 2u64.saturating_pow(restart_count - 1);
632 tracing::warn!(
633 attempt = restart_count,
634 max = max_restarts,
635 delay_ms = delay,
636 error = %error_msg,
637 "Change listener restarting"
638 );
639
640 tokio::time::sleep(std::time::Duration::from_millis(delay)).await;
641
642 let listener_clone = listener.clone();
644 let error_tx = listener_error_tx.clone();
645 if let Some(handle) = listener_handle.take() {
646 handle.abort();
647 }
648 change_rx = listener.subscribe();
649 listener_handle = Some(tokio::spawn(async move {
650 if let Err(e) = listener_clone.run().await {
651 let _ = error_tx.send(format!("Change listener error: {}", e)).await;
652 }
653 }));
654 }
655 _ = shutdown_rx.recv() => {
656 tracing::debug!("Reactor shutting down");
657 break;
658 }
659 }
660 }
661
662 if let Some(handle) = listener_handle {
663 handle.abort();
664 }
665 });
666
667 Ok(())
668 }
669
670 #[allow(clippy::too_many_arguments)]
672 async fn handle_change(
673 change: &Change,
674 invalidation_engine: &Arc<InvalidationEngine>,
675 active_subscriptions: &Arc<RwLock<HashMap<SubscriptionId, ActiveSubscription>>>,
676 job_subscriptions: &Arc<RwLock<HashMap<Uuid, Vec<JobSubscription>>>>,
677 workflow_subscriptions: &Arc<RwLock<HashMap<Uuid, Vec<WorkflowSubscription>>>>,
678 session_server: &Arc<SessionServer>,
679 registry: &FunctionRegistry,
680 db_pool: &sqlx::PgPool,
681 ) {
682 tracing::trace!(table = %change.table, op = ?change.operation, row_id = ?change.row_id, "Processing change");
683
684 match change.table.as_str() {
686 "forge_jobs" => {
687 if let Some(job_id) = change.row_id {
688 Self::handle_job_change(job_id, job_subscriptions, session_server, db_pool)
689 .await;
690 }
691 return; }
693 "forge_workflow_runs" => {
694 if let Some(workflow_id) = change.row_id {
695 Self::handle_workflow_change(
696 workflow_id,
697 workflow_subscriptions,
698 session_server,
699 db_pool,
700 )
701 .await;
702 }
703 return; }
705 "forge_workflow_steps" => {
706 if let Some(step_id) = change.row_id {
708 Self::handle_workflow_step_change(
709 step_id,
710 workflow_subscriptions,
711 session_server,
712 db_pool,
713 )
714 .await;
715 }
716 return; }
718 _ => {}
719 }
720
721 invalidation_engine.process_change(change.clone()).await;
723
724 let invalidated = invalidation_engine.check_pending().await;
729
730 if invalidated.is_empty() {
731 return;
732 }
733
734 tracing::trace!(count = invalidated.len(), "Invalidating subscriptions");
735
736 let subs_to_process: Vec<_> = {
738 let subscriptions = active_subscriptions.read().await;
739 invalidated
740 .iter()
741 .filter_map(|sub_id| {
742 subscriptions.get(sub_id).map(|active| {
743 (
744 *sub_id,
745 active.session_id,
746 active.client_sub_id.clone(),
747 active.query_name.clone(),
748 active.args.clone(),
749 active.last_result_hash.clone(),
750 active.auth_context.clone(),
751 )
752 })
753 })
754 .collect()
755 };
756
757 let mut updates: Vec<(SubscriptionId, String)> = Vec::new();
759
760 for (sub_id, session_id, client_sub_id, query_name, args, last_hash, auth_context) in
762 subs_to_process
763 {
764 match Self::execute_query_static(registry, db_pool, &query_name, &args, &auth_context)
766 .await
767 {
768 Ok((new_data, _read_set)) => {
769 let new_hash = Self::compute_hash(&new_data);
770
771 if last_hash.as_ref() != Some(&new_hash) {
773 let message = RealtimeMessage::Data {
775 subscription_id: client_sub_id.clone(),
776 data: new_data,
777 };
778
779 if let Err(e) = session_server.send_to_session(session_id, message).await {
780 tracing::debug!(client_id = %client_sub_id, error = %e, "Failed to send update");
781 } else {
782 tracing::trace!(client_id = %client_sub_id, "Pushed update to client");
783 updates.push((sub_id, new_hash));
785 }
786 }
787 }
788 Err(e) => {
789 tracing::warn!(client_id = %client_sub_id, error = %e, "Failed to re-execute query");
790 }
791 }
792 }
793
794 if !updates.is_empty() {
796 let mut subscriptions = active_subscriptions.write().await;
797 for (sub_id, new_hash) in updates {
798 if let Some(active) = subscriptions.get_mut(&sub_id) {
799 active.last_result_hash = Some(new_hash);
800 }
801 }
802 }
803 }
804
805 async fn handle_job_change(
807 job_id: Uuid,
808 job_subscriptions: &Arc<RwLock<HashMap<Uuid, Vec<JobSubscription>>>>,
809 session_server: &Arc<SessionServer>,
810 db_pool: &sqlx::PgPool,
811 ) {
812 let subs = job_subscriptions.read().await;
813 let subscribers = match subs.get(&job_id) {
814 Some(s) if !s.is_empty() => s.clone(),
815 _ => return, };
817 drop(subs); let job_data = match Self::fetch_job_data_static(job_id, db_pool).await {
821 Ok(data) => data,
822 Err(e) => {
823 tracing::debug!(%job_id, error = %e, "Failed to fetch job data");
824 return;
825 }
826 };
827
828 let owner_subject = match Self::fetch_job_owner_subject_static(job_id, db_pool).await {
829 Ok(owner) => owner,
830 Err(e) => {
831 tracing::debug!(%job_id, error = %e, "Failed to fetch job owner");
832 return;
833 }
834 };
835
836 let mut unauthorized_subscribers: HashSet<(SessionId, String)> = HashSet::new();
837
838 for sub in subscribers {
840 if Self::check_owner_access(owner_subject.clone(), &sub.auth_context).is_err() {
841 unauthorized_subscribers.insert((sub.session_id, sub.client_sub_id.clone()));
842 continue;
843 }
844
845 let message = RealtimeMessage::JobUpdate {
846 client_sub_id: sub.client_sub_id.clone(),
847 job: job_data.clone(),
848 };
849
850 if let Err(e) = session_server
851 .send_to_session(sub.session_id, message)
852 .await
853 {
854 tracing::trace!(%job_id, error = %e, "Failed to send job update");
855 } else {
856 tracing::trace!(%job_id, "Job update sent");
857 }
858 }
859
860 if !unauthorized_subscribers.is_empty() {
861 let mut subs = job_subscriptions.write().await;
862 if let Some(entries) = subs.get_mut(&job_id) {
863 entries.retain(|entry| {
864 !unauthorized_subscribers
865 .contains(&(entry.session_id, entry.client_sub_id.clone()))
866 });
867 }
868 subs.retain(|_, v| !v.is_empty());
869 }
870 }
871
872 async fn handle_workflow_change(
874 workflow_id: Uuid,
875 workflow_subscriptions: &Arc<RwLock<HashMap<Uuid, Vec<WorkflowSubscription>>>>,
876 session_server: &Arc<SessionServer>,
877 db_pool: &sqlx::PgPool,
878 ) {
879 let subs = workflow_subscriptions.read().await;
880 let subscribers = match subs.get(&workflow_id) {
881 Some(s) if !s.is_empty() => s.clone(),
882 _ => return, };
884 drop(subs); let workflow_data = match Self::fetch_workflow_data_static(workflow_id, db_pool).await {
888 Ok(data) => data,
889 Err(e) => {
890 tracing::debug!(%workflow_id, error = %e, "Failed to fetch workflow data");
891 return;
892 }
893 };
894
895 let owner_subject =
896 match Self::fetch_workflow_owner_subject_static(workflow_id, db_pool).await {
897 Ok(owner) => owner,
898 Err(e) => {
899 tracing::debug!(%workflow_id, error = %e, "Failed to fetch workflow owner");
900 return;
901 }
902 };
903
904 let mut unauthorized_subscribers: HashSet<(SessionId, String)> = HashSet::new();
905
906 for sub in subscribers {
908 if Self::check_owner_access(owner_subject.clone(), &sub.auth_context).is_err() {
909 unauthorized_subscribers.insert((sub.session_id, sub.client_sub_id.clone()));
910 continue;
911 }
912
913 let message = RealtimeMessage::WorkflowUpdate {
914 client_sub_id: sub.client_sub_id.clone(),
915 workflow: workflow_data.clone(),
916 };
917
918 if let Err(e) = session_server
919 .send_to_session(sub.session_id, message)
920 .await
921 {
922 tracing::trace!(%workflow_id, error = %e, "Failed to send workflow update");
923 } else {
924 tracing::trace!(%workflow_id, "Workflow update sent");
925 }
926 }
927
928 if !unauthorized_subscribers.is_empty() {
929 let mut subs = workflow_subscriptions.write().await;
930 if let Some(entries) = subs.get_mut(&workflow_id) {
931 entries.retain(|entry| {
932 !unauthorized_subscribers
933 .contains(&(entry.session_id, entry.client_sub_id.clone()))
934 });
935 }
936 subs.retain(|_, v| !v.is_empty());
937 }
938 }
939
940 async fn handle_workflow_step_change(
942 step_id: Uuid,
943 workflow_subscriptions: &Arc<RwLock<HashMap<Uuid, Vec<WorkflowSubscription>>>>,
944 session_server: &Arc<SessionServer>,
945 db_pool: &sqlx::PgPool,
946 ) {
947 let workflow_id: Option<Uuid> = match sqlx::query_scalar(
949 "SELECT workflow_run_id FROM forge_workflow_steps WHERE id = $1",
950 )
951 .bind(step_id)
952 .fetch_optional(db_pool)
953 .await
954 {
955 Ok(id) => id,
956 Err(e) => {
957 tracing::debug!(%step_id, error = %e, "Failed to look up workflow for step");
958 return;
959 }
960 };
961
962 if let Some(wf_id) = workflow_id {
963 Self::handle_workflow_change(wf_id, workflow_subscriptions, session_server, db_pool)
965 .await;
966 }
967 }
968
969 #[allow(clippy::type_complexity)]
971 async fn fetch_job_data_static(
972 job_id: Uuid,
973 db_pool: &sqlx::PgPool,
974 ) -> forge_core::Result<JobData> {
975 let row: Option<(
976 String,
977 Option<i32>,
978 Option<String>,
979 Option<serde_json::Value>,
980 Option<String>,
981 )> = sqlx::query_as(
982 r#"
983 SELECT status, progress_percent, progress_message, output, last_error
984 FROM forge_jobs WHERE id = $1
985 "#,
986 )
987 .bind(job_id)
988 .fetch_optional(db_pool)
989 .await
990 .map_err(forge_core::ForgeError::Sql)?;
991
992 match row {
993 Some((status, progress_percent, progress_message, output, error)) => Ok(JobData {
994 job_id: job_id.to_string(),
995 status,
996 progress_percent,
997 progress_message,
998 output,
999 error,
1000 }),
1001 None => Err(forge_core::ForgeError::NotFound(format!(
1002 "Job {} not found",
1003 job_id
1004 ))),
1005 }
1006 }
1007
1008 async fn fetch_job_owner_subject_static(
1009 job_id: Uuid,
1010 db_pool: &sqlx::PgPool,
1011 ) -> forge_core::Result<Option<String>> {
1012 let owner_subject: Option<Option<String>> =
1013 sqlx::query_scalar("SELECT owner_subject FROM forge_jobs WHERE id = $1")
1014 .bind(job_id)
1015 .fetch_optional(db_pool)
1016 .await
1017 .map_err(forge_core::ForgeError::Sql)?;
1018
1019 owner_subject
1020 .ok_or_else(|| forge_core::ForgeError::NotFound(format!("Job {} not found", job_id)))
1021 }
1022
1023 #[allow(clippy::type_complexity)]
1025 async fn fetch_workflow_data_static(
1026 workflow_id: Uuid,
1027 db_pool: &sqlx::PgPool,
1028 ) -> forge_core::Result<WorkflowData> {
1029 let row: Option<(
1030 String,
1031 Option<String>,
1032 Option<serde_json::Value>,
1033 Option<String>,
1034 )> = sqlx::query_as(
1035 r#"
1036 SELECT status, current_step, output, error
1037 FROM forge_workflow_runs WHERE id = $1
1038 "#,
1039 )
1040 .bind(workflow_id)
1041 .fetch_optional(db_pool)
1042 .await
1043 .map_err(forge_core::ForgeError::Sql)?;
1044
1045 let (status, current_step, output, error) = match row {
1046 Some(r) => r,
1047 None => {
1048 return Err(forge_core::ForgeError::NotFound(format!(
1049 "Workflow {} not found",
1050 workflow_id
1051 )));
1052 }
1053 };
1054
1055 let step_rows: Vec<(String, String, Option<String>)> = sqlx::query_as(
1056 r#"
1057 SELECT step_name, status, error
1058 FROM forge_workflow_steps
1059 WHERE workflow_run_id = $1
1060 ORDER BY started_at ASC NULLS LAST
1061 "#,
1062 )
1063 .bind(workflow_id)
1064 .fetch_all(db_pool)
1065 .await
1066 .map_err(forge_core::ForgeError::Sql)?;
1067
1068 let steps = step_rows
1069 .into_iter()
1070 .map(|(name, status, error)| WorkflowStepData {
1071 name,
1072 status,
1073 error,
1074 })
1075 .collect();
1076
1077 Ok(WorkflowData {
1078 workflow_id: workflow_id.to_string(),
1079 status,
1080 current_step,
1081 steps,
1082 output,
1083 error,
1084 })
1085 }
1086
1087 async fn fetch_workflow_owner_subject_static(
1088 workflow_id: Uuid,
1089 db_pool: &sqlx::PgPool,
1090 ) -> forge_core::Result<Option<String>> {
1091 let owner_subject: Option<Option<String>> =
1092 sqlx::query_scalar("SELECT owner_subject FROM forge_workflow_runs WHERE id = $1")
1093 .bind(workflow_id)
1094 .fetch_optional(db_pool)
1095 .await
1096 .map_err(forge_core::ForgeError::Sql)?;
1097
1098 owner_subject.ok_or_else(|| {
1099 forge_core::ForgeError::NotFound(format!("Workflow {} not found", workflow_id))
1100 })
1101 }
1102
1103 async fn execute_query_static(
1105 registry: &FunctionRegistry,
1106 db_pool: &sqlx::PgPool,
1107 query_name: &str,
1108 args: &serde_json::Value,
1109 auth_context: &forge_core::function::AuthContext,
1110 ) -> forge_core::Result<(serde_json::Value, ReadSet)> {
1111 match registry.get(query_name) {
1112 Some(FunctionEntry::Query { info, handler }) => {
1113 Self::check_query_auth(info, auth_context)?;
1114 Self::check_identity_args(query_name, args, auth_context, !info.is_public)?;
1115
1116 let ctx = forge_core::function::QueryContext::new(
1117 db_pool.clone(),
1118 auth_context.clone(),
1119 forge_core::function::RequestMetadata::new(),
1120 );
1121
1122 let normalized_args = match args {
1123 v if v.as_object().is_some_and(|o| o.is_empty()) => serde_json::Value::Null,
1124 v => v.clone(),
1125 };
1126
1127 let data = handler(&ctx, normalized_args).await?;
1128
1129 let mut read_set = ReadSet::new();
1131
1132 if info.table_dependencies.is_empty() {
1133 let table_name = Self::extract_table_name(query_name);
1135 read_set.add_table(&table_name);
1136 tracing::trace!(
1137 query = %query_name,
1138 fallback_table = %table_name,
1139 "Using naming convention fallback for table dependency"
1140 );
1141 } else {
1142 for table in info.table_dependencies {
1143 read_set.add_table(*table);
1144 }
1145 }
1146
1147 Ok((data, read_set))
1148 }
1149 _ => Err(forge_core::ForgeError::Validation(format!(
1150 "Query '{}' not found or not a query",
1151 query_name
1152 ))),
1153 }
1154 }
1155
1156 fn extract_table_name(query_name: &str) -> String {
1158 if let Some(rest) = query_name.strip_prefix("get_") {
1159 rest.to_string()
1160 } else if let Some(rest) = query_name.strip_prefix("list_") {
1161 rest.to_string()
1162 } else if let Some(rest) = query_name.strip_prefix("find_") {
1163 rest.to_string()
1164 } else if let Some(rest) = query_name.strip_prefix("fetch_") {
1165 rest.to_string()
1166 } else {
1167 query_name.to_string()
1168 }
1169 }
1170
1171 fn check_query_auth(
1172 info: &forge_core::function::FunctionInfo,
1173 auth: &forge_core::function::AuthContext,
1174 ) -> forge_core::Result<()> {
1175 if info.is_public {
1176 return Ok(());
1177 }
1178
1179 if !auth.is_authenticated() {
1180 return Err(forge_core::ForgeError::Unauthorized(
1181 "Authentication required".into(),
1182 ));
1183 }
1184
1185 if let Some(role) = info.required_role
1186 && !auth.has_role(role)
1187 {
1188 return Err(forge_core::ForgeError::Forbidden(format!(
1189 "Role '{}' required",
1190 role
1191 )));
1192 }
1193
1194 Ok(())
1195 }
1196
1197 fn check_identity_args(
1198 function_name: &str,
1199 args: &serde_json::Value,
1200 auth: &forge_core::function::AuthContext,
1201 enforce_scope: bool,
1202 ) -> forge_core::Result<()> {
1203 if auth.is_admin() {
1204 return Ok(());
1205 }
1206
1207 let Some(obj) = args.as_object() else {
1208 if enforce_scope && auth.is_authenticated() {
1209 return Err(forge_core::ForgeError::Forbidden(format!(
1210 "Function '{function_name}' must include identity or tenant scope arguments"
1211 )));
1212 }
1213 return Ok(());
1214 };
1215
1216 let mut principal_values: Vec<String> = Vec::new();
1217 if let Some(user_id) = auth.user_id().map(|id| id.to_string()) {
1218 principal_values.push(user_id);
1219 }
1220 if let Some(subject) = auth.principal_id()
1221 && !principal_values.iter().any(|v| v == &subject)
1222 {
1223 principal_values.push(subject);
1224 }
1225
1226 let mut has_scope_key = false;
1227
1228 for key in [
1229 "user_id",
1230 "userId",
1231 "owner_id",
1232 "ownerId",
1233 "owner_subject",
1234 "ownerSubject",
1235 "subject",
1236 "sub",
1237 "principal_id",
1238 "principalId",
1239 ] {
1240 let Some(value) = obj.get(key) else {
1241 continue;
1242 };
1243 has_scope_key = true;
1244
1245 if !auth.is_authenticated() {
1246 return Err(forge_core::ForgeError::Unauthorized(format!(
1247 "Function '{function_name}' requires authentication for identity-scoped argument '{key}'"
1248 )));
1249 }
1250
1251 let serde_json::Value::String(actual) = value else {
1252 return Err(forge_core::ForgeError::InvalidArgument(format!(
1253 "Function '{function_name}' argument '{key}' must be a non-empty string"
1254 )));
1255 };
1256
1257 if actual.trim().is_empty() || !principal_values.iter().any(|v| v == actual) {
1258 return Err(forge_core::ForgeError::Forbidden(format!(
1259 "Function '{function_name}' argument '{key}' does not match authenticated principal"
1260 )));
1261 }
1262 }
1263
1264 for key in ["tenant_id", "tenantId"] {
1265 let Some(value) = obj.get(key) else {
1266 continue;
1267 };
1268 has_scope_key = true;
1269
1270 if !auth.is_authenticated() {
1271 return Err(forge_core::ForgeError::Unauthorized(format!(
1272 "Function '{function_name}' requires authentication for tenant-scoped argument '{key}'"
1273 )));
1274 }
1275
1276 let expected = auth
1277 .claim("tenant_id")
1278 .and_then(|v| v.as_str())
1279 .ok_or_else(|| {
1280 forge_core::ForgeError::Forbidden(format!(
1281 "Function '{function_name}' argument '{key}' is not allowed for this principal"
1282 ))
1283 })?;
1284
1285 let serde_json::Value::String(actual) = value else {
1286 return Err(forge_core::ForgeError::InvalidArgument(format!(
1287 "Function '{function_name}' argument '{key}' must be a non-empty string"
1288 )));
1289 };
1290
1291 if actual.trim().is_empty() || actual != expected {
1292 return Err(forge_core::ForgeError::Forbidden(format!(
1293 "Function '{function_name}' argument '{key}' does not match authenticated tenant"
1294 )));
1295 }
1296 }
1297
1298 if enforce_scope && auth.is_authenticated() && !has_scope_key {
1299 return Err(forge_core::ForgeError::Forbidden(format!(
1300 "Function '{function_name}' must include identity or tenant scope arguments"
1301 )));
1302 }
1303
1304 Ok(())
1305 }
1306
1307 async fn ensure_job_access(
1308 db_pool: &sqlx::PgPool,
1309 job_id: Uuid,
1310 auth: &forge_core::function::AuthContext,
1311 ) -> forge_core::Result<()> {
1312 let owner_subject_row: Option<(Option<String>,)> = sqlx::query_as(
1313 r#"
1314 SELECT owner_subject
1315 FROM forge_jobs
1316 WHERE id = $1
1317 "#,
1318 )
1319 .bind(job_id)
1320 .fetch_optional(db_pool)
1321 .await
1322 .map_err(forge_core::ForgeError::Sql)?;
1323
1324 let owner_subject = owner_subject_row
1325 .ok_or_else(|| forge_core::ForgeError::NotFound(format!("Job {} not found", job_id)))?
1326 .0;
1327
1328 Self::check_owner_access(owner_subject, auth)
1329 }
1330
1331 async fn ensure_workflow_access(
1332 db_pool: &sqlx::PgPool,
1333 workflow_id: Uuid,
1334 auth: &forge_core::function::AuthContext,
1335 ) -> forge_core::Result<()> {
1336 let owner_subject_row: Option<(Option<String>,)> = sqlx::query_as(
1337 r#"
1338 SELECT owner_subject
1339 FROM forge_workflow_runs
1340 WHERE id = $1
1341 "#,
1342 )
1343 .bind(workflow_id)
1344 .fetch_optional(db_pool)
1345 .await
1346 .map_err(forge_core::ForgeError::Sql)?;
1347
1348 let owner_subject = owner_subject_row
1349 .ok_or_else(|| {
1350 forge_core::ForgeError::NotFound(format!("Workflow {} not found", workflow_id))
1351 })?
1352 .0;
1353
1354 Self::check_owner_access(owner_subject, auth)
1355 }
1356
1357 fn check_owner_access(
1358 owner_subject: Option<String>,
1359 auth: &forge_core::function::AuthContext,
1360 ) -> forge_core::Result<()> {
1361 if auth.is_admin() {
1362 return Ok(());
1363 }
1364
1365 let principal = auth.principal_id().ok_or_else(|| {
1366 forge_core::ForgeError::Unauthorized("Authentication required".to_string())
1367 })?;
1368
1369 match owner_subject {
1370 Some(owner) if owner == principal => Ok(()),
1371 Some(_) => Err(forge_core::ForgeError::Forbidden(
1372 "Not authorized to access this resource".to_string(),
1373 )),
1374 None => Err(forge_core::ForgeError::Forbidden(
1375 "Resource has no owner; admin role required".to_string(),
1376 )),
1377 }
1378 }
1379
1380 pub fn stop(&self) {
1382 let _ = self.shutdown_tx.send(());
1383 self.change_listener.stop();
1384 }
1385
1386 pub async fn stats(&self) -> ReactorStats {
1388 let session_stats = self.session_server.stats().await;
1389 let inv_stats = self.invalidation_engine.stats().await;
1390
1391 ReactorStats {
1392 connections: session_stats.connections,
1393 subscriptions: session_stats.subscriptions,
1394 pending_invalidations: inv_stats.pending_subscriptions,
1395 listener_running: self.change_listener.is_running(),
1396 }
1397 }
1398}
1399
/// Snapshot of reactor health, as assembled by `Reactor::stats`.
#[derive(Debug, Clone)]
pub struct ReactorStats {
    /// Connection count reported by the session server.
    pub connections: usize,
    /// Subscription count reported by the session server.
    pub subscriptions: usize,
    /// The invalidation engine's `pending_subscriptions` statistic.
    pub pending_invalidations: usize,
    /// Whether the change listener reported itself as running when sampled.
    pub listener_running: bool,
}
1408
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds an authenticated, non-admin context whose `sub` claim equals `user_id`.
    fn user_auth(user_id: uuid::Uuid) -> forge_core::function::AuthContext {
        forge_core::function::AuthContext::authenticated(
            user_id,
            vec!["user".to_string()],
            HashMap::from([(
                "sub".to_string(),
                serde_json::Value::String(user_id.to_string()),
            )]),
        )
    }

    #[test]
    fn test_reactor_config_default() {
        let config = ReactorConfig::default();
        assert_eq!(config.listener.channel, "forge_changes");
        assert_eq!(config.invalidation.debounce_ms, 50);
        assert_eq!(config.max_listener_restarts, 5);
        assert_eq!(config.listener_restart_delay_ms, 1000);
    }

    #[test]
    fn test_compute_hash() {
        let data1 = serde_json::json!({"name": "test"});
        let data2 = serde_json::json!({"name": "test"});
        let data3 = serde_json::json!({"name": "different"});

        let hash1 = Reactor::compute_hash(&data1);
        let hash2 = Reactor::compute_hash(&data2);
        let hash3 = Reactor::compute_hash(&data3);

        assert_eq!(hash1, hash2);
        assert_ne!(hash1, hash3);
    }

    #[test]
    fn test_check_identity_args_rejects_cross_user() {
        let auth = user_auth(uuid::Uuid::new_v4());

        let result = Reactor::check_identity_args(
            "list_orders",
            &serde_json::json!({"user_id": uuid::Uuid::new_v4().to_string()}),
            &auth,
            true,
        );
        assert!(matches!(result, Err(forge_core::ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_check_identity_args_accepts_own_user_id() {
        let user_id = uuid::Uuid::new_v4();
        let auth = user_auth(user_id);

        let result = Reactor::check_identity_args(
            "list_orders",
            &serde_json::json!({"user_id": user_id.to_string()}),
            &auth,
            true,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_check_identity_args_rejects_non_string_identity() {
        let auth = user_auth(uuid::Uuid::new_v4());

        let result = Reactor::check_identity_args(
            "list_orders",
            &serde_json::json!({"user_id": 42}),
            &auth,
            true,
        );
        assert!(matches!(
            result,
            Err(forge_core::ForgeError::InvalidArgument(_))
        ));
    }

    #[test]
    fn test_check_identity_args_rejects_tenant_without_claim() {
        let auth = user_auth(uuid::Uuid::new_v4());

        let result = Reactor::check_identity_args(
            "list_orders",
            &serde_json::json!({"tenant_id": "tenant-1"}),
            &auth,
            true,
        );
        assert!(matches!(result, Err(forge_core::ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_check_identity_args_requires_scope_for_non_public_queries() {
        let auth = user_auth(uuid::Uuid::new_v4());

        let result =
            Reactor::check_identity_args("list_orders", &serde_json::json!({}), &auth, true);
        assert!(matches!(result, Err(forge_core::ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_check_identity_args_rejects_non_object_args_when_enforced() {
        let auth = user_auth(uuid::Uuid::new_v4());

        let result =
            Reactor::check_identity_args("list_orders", &serde_json::json!(["a"]), &auth, true);
        assert!(matches!(result, Err(forge_core::ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_check_identity_args_allows_missing_scope_when_not_enforced() {
        let auth = user_auth(uuid::Uuid::new_v4());

        let result =
            Reactor::check_identity_args("list_orders", &serde_json::json!({}), &auth, false);
        assert!(result.is_ok());
    }

    #[test]
    fn test_check_owner_access_allows_admin() {
        let auth = forge_core::function::AuthContext::authenticated_without_uuid(
            vec!["admin".to_string()],
            HashMap::from([(
                "sub".to_string(),
                serde_json::Value::String("admin-1".to_string()),
            )]),
        );

        let result = Reactor::check_owner_access(Some("other-user".to_string()), &auth);
        assert!(result.is_ok());
    }

    #[test]
    fn test_check_owner_access_allows_owner() {
        let user_id = uuid::Uuid::new_v4();
        let auth = user_auth(user_id);

        let result = Reactor::check_owner_access(Some(user_id.to_string()), &auth);
        assert!(result.is_ok());
    }

    #[test]
    fn test_check_owner_access_rejects_other_owner() {
        let auth = user_auth(uuid::Uuid::new_v4());

        let result =
            Reactor::check_owner_access(Some(uuid::Uuid::new_v4().to_string()), &auth);
        assert!(matches!(result, Err(forge_core::ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_check_owner_access_rejects_unowned_resource_for_non_admin() {
        let auth = user_auth(uuid::Uuid::new_v4());

        let result = Reactor::check_owner_access(None, &auth);
        assert!(matches!(result, Err(forge_core::ForgeError::Forbidden(_))));
    }
}