1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use tokio::sync::{RwLock, broadcast, mpsc};
5use uuid::Uuid;
6
7use forge_core::cluster::NodeId;
8use forge_core::realtime::{Change, ReadSet, SessionId, SubscriptionId};
9
10use super::invalidation::{InvalidationConfig, InvalidationEngine};
11use super::listener::{ChangeListener, ListenerConfig};
12use super::manager::SubscriptionManager;
13use super::message::{
14 JobData, RealtimeConfig, RealtimeMessage, SessionServer, WorkflowData, WorkflowStepData,
15};
16use crate::function::{FunctionEntry, FunctionRegistry};
17
/// Top-level configuration for the [`Reactor`] and its components.
#[derive(Debug, Clone)]
pub struct ReactorConfig {
    /// Settings for the database change listener.
    pub listener: ListenerConfig,
    /// Settings for the invalidation engine.
    pub invalidation: InvalidationConfig,
    /// Realtime/session settings (also caps subscriptions per session).
    pub realtime: RealtimeConfig,
    /// How many times a failed change listener is restarted before
    /// real-time updates are disabled for good.
    pub max_listener_restarts: u32,
    /// Base restart delay in milliseconds; doubled on each attempt
    /// (exponential backoff).
    pub listener_restart_delay_ms: u64,
}
27
impl Default for ReactorConfig {
    /// Component defaults plus a restart policy of 5 attempts starting
    /// from a 1-second backoff.
    fn default() -> Self {
        Self {
            listener: ListenerConfig::default(),
            invalidation: InvalidationConfig::default(),
            realtime: RealtimeConfig::default(),
            max_listener_restarts: 5,
            listener_restart_delay_ms: 1000,
        }
    }
}
39
/// State kept for one live query subscription.
#[derive(Debug, Clone)]
pub struct ActiveSubscription {
    #[allow(dead_code)]
    pub subscription_id: SubscriptionId,
    /// Session the query results are pushed to.
    pub session_id: SessionId,
    #[allow(dead_code)]
    pub client_sub_id: String,
    /// Name of the registered query to (re-)execute on invalidation.
    pub query_name: String,
    /// Arguments the query was subscribed with; replayed on re-execution.
    pub args: serde_json::Value,
    /// Hash of the last result sent; used to suppress unchanged pushes.
    pub last_result_hash: Option<String>,
    #[allow(dead_code)]
    pub read_set: ReadSet,
    /// Auth context captured at subscribe time; reused when the query is
    /// re-executed for this subscriber.
    pub auth_context: forge_core::function::AuthContext,
}
56
57#[derive(Debug, Clone)]
59pub struct JobSubscription {
60 #[allow(dead_code)]
61 pub subscription_id: SubscriptionId,
62 pub session_id: SessionId,
63 pub client_sub_id: String,
64 #[allow(dead_code)]
65 pub job_id: Uuid, pub auth_context: forge_core::function::AuthContext,
67}
68
69#[derive(Debug, Clone)]
71pub struct WorkflowSubscription {
72 #[allow(dead_code)]
73 pub subscription_id: SubscriptionId,
74 pub session_id: SessionId,
75 pub client_sub_id: String,
76 #[allow(dead_code)]
77 pub workflow_id: Uuid, pub auth_context: forge_core::function::AuthContext,
79}
80
/// Drives real-time updates for one node: executes subscribed queries,
/// consumes database changes, and pushes query/job/workflow updates to
/// connected sessions.
pub struct Reactor {
    node_id: NodeId,
    db_pool: sqlx::PgPool,
    /// Registered functions; only `Query` entries are executed here.
    registry: FunctionRegistry,
    /// Tracks subscriptions and their read sets for invalidation.
    subscription_manager: Arc<SubscriptionManager>,
    /// Routes outbound messages to connected client sessions.
    session_server: Arc<SessionServer>,
    /// Streams database change notifications into the reactor task.
    change_listener: Arc<ChangeListener>,
    /// Turns incoming changes into pending subscription invalidations.
    invalidation_engine: Arc<InvalidationEngine>,
    /// Live query subscriptions keyed by subscription id.
    active_subscriptions: Arc<RwLock<HashMap<SubscriptionId, ActiveSubscription>>>,
    /// Job-progress subscribers keyed by job id.
    job_subscriptions: Arc<RwLock<HashMap<Uuid, Vec<JobSubscription>>>>,
    /// Workflow subscribers keyed by workflow run id.
    workflow_subscriptions: Arc<RwLock<HashMap<Uuid, Vec<WorkflowSubscription>>>>,
    /// Broadcast used to tell the background task to stop.
    shutdown_tx: broadcast::Sender<()>,
    /// Listener restart budget before real-time updates are disabled.
    max_listener_restarts: u32,
    /// Base backoff delay (ms) between listener restarts.
    listener_restart_delay_ms: u64,
}
102
103impl Reactor {
104 pub fn new(
106 node_id: NodeId,
107 db_pool: sqlx::PgPool,
108 registry: FunctionRegistry,
109 config: ReactorConfig,
110 ) -> Self {
111 let subscription_manager = Arc::new(SubscriptionManager::new(
112 config.realtime.max_subscriptions_per_session,
113 ));
114 let session_server = Arc::new(SessionServer::new(node_id, config.realtime.clone()));
115 let change_listener = Arc::new(ChangeListener::new(db_pool.clone(), config.listener));
116 let invalidation_engine = Arc::new(InvalidationEngine::new(
117 subscription_manager.clone(),
118 config.invalidation,
119 ));
120 let (shutdown_tx, _) = broadcast::channel(1);
121
122 Self {
123 node_id,
124 db_pool,
125 registry,
126 subscription_manager,
127 session_server,
128 change_listener,
129 invalidation_engine,
130 active_subscriptions: Arc::new(RwLock::new(HashMap::new())),
131 job_subscriptions: Arc::new(RwLock::new(HashMap::new())),
132 workflow_subscriptions: Arc::new(RwLock::new(HashMap::new())),
133 shutdown_tx,
134 max_listener_restarts: config.max_listener_restarts,
135 listener_restart_delay_ms: config.listener_restart_delay_ms,
136 }
137 }
138
    /// Returns the cluster node this reactor runs on.
    pub fn node_id(&self) -> NodeId {
        self.node_id
    }
143
144 pub fn session_server(&self) -> Arc<SessionServer> {
146 self.session_server.clone()
147 }
148
149 pub fn subscription_manager(&self) -> Arc<SubscriptionManager> {
151 self.subscription_manager.clone()
152 }
153
    /// Subscribes to the reactor's shutdown broadcast channel.
    pub fn shutdown_receiver(&self) -> broadcast::Receiver<()> {
        self.shutdown_tx.subscribe()
    }
158
    /// Registers the outbound message channel for a connected session.
    ///
    /// `sender` is the channel over which realtime messages for this
    /// session will be delivered.
    pub async fn register_session(
        &self,
        session_id: SessionId,
        sender: mpsc::Sender<RealtimeMessage>,
    ) {
        self.session_server
            .register_connection(session_id, sender)
            .await;
        tracing::trace!(?session_id, "Session registered");
    }
170
171 pub async fn remove_session(&self, session_id: SessionId) {
173 if let Some(subscription_ids) = self.session_server.remove_connection(session_id).await {
174 for sub_id in subscription_ids {
176 self.subscription_manager.remove_subscription(sub_id).await;
177 self.active_subscriptions.write().await.remove(&sub_id);
178 }
179 }
180
181 {
183 let mut job_subs = self.job_subscriptions.write().await;
184 for subscribers in job_subs.values_mut() {
185 subscribers.retain(|s| s.session_id != session_id);
186 }
187 job_subs.retain(|_, v| !v.is_empty());
189 }
190
191 {
193 let mut workflow_subs = self.workflow_subscriptions.write().await;
194 for subscribers in workflow_subs.values_mut() {
195 subscribers.retain(|s| s.session_id != session_id);
196 }
197 workflow_subs.retain(|_, v| !v.is_empty());
199 }
200
201 tracing::trace!(?session_id, "Session removed");
202 }
203
    /// Creates a live query subscription for a session.
    ///
    /// Executes the query once, records its read set and result hash for
    /// later invalidation, and returns the new subscription id along with
    /// the initial data. On failure at any step, partially-created state is
    /// rolled back before the error is returned.
    pub async fn subscribe(
        &self,
        session_id: SessionId,
        client_sub_id: String,
        query_name: String,
        args: serde_json::Value,
        auth_context: forge_core::function::AuthContext,
    ) -> forge_core::Result<(SubscriptionId, serde_json::Value)> {
        let sub_info = self
            .subscription_manager
            .create_subscription(session_id, &query_name, args.clone())
            .await?;

        let subscription_id = sub_info.id;

        // Attach to the session; undo the manager entry if that fails.
        if let Err(error) = self
            .session_server
            .add_subscription(session_id, subscription_id)
            .await
        {
            self.subscription_manager
                .remove_subscription(subscription_id)
                .await;
            return Err(error);
        }

        // Initial execution; on failure, unsubscribe cleans up both the
        // session server and the manager.
        let (data, read_set) = match self.execute_query(&query_name, &args, &auth_context).await {
            Ok(result) => result,
            Err(error) => {
                self.unsubscribe(subscription_id).await;
                return Err(error);
            }
        };

        let result_hash = Self::compute_hash(&data);

        tracing::trace!(
            ?subscription_id,
            query = %query_name,
            tables = ?read_set.tables.iter().collect::<Vec<_>>(),
            "Subscription read set"
        );

        // Store the read set + hash so change-based invalidation works.
        self.subscription_manager
            .update_subscription(subscription_id, read_set.clone(), result_hash.clone())
            .await;

        let active = ActiveSubscription {
            subscription_id,
            session_id,
            client_sub_id,
            query_name,
            args,
            last_result_hash: Some(result_hash),
            read_set,
            auth_context,
        };
        self.active_subscriptions
            .write()
            .await
            .insert(subscription_id, active);

        tracing::trace!(?subscription_id, "Subscription created");

        Ok((subscription_id, data))
    }
272
    /// Removes a query subscription from everywhere it is tracked: the
    /// session server, the subscription manager, and the local
    /// active-subscription map, in that order.
    pub async fn unsubscribe(&self, subscription_id: SubscriptionId) {
        self.session_server
            .remove_subscription(subscription_id)
            .await;
        self.subscription_manager
            .remove_subscription(subscription_id)
            .await;
        self.active_subscriptions
            .write()
            .await
            .remove(&subscription_id);
        tracing::trace!(?subscription_id, "Subscription removed");
    }
287
288 pub async fn subscribe_job(
290 &self,
291 session_id: SessionId,
292 client_sub_id: String,
293 job_id: Uuid, auth_context: &forge_core::function::AuthContext,
295 ) -> forge_core::Result<JobData> {
296 let subscription_id = SubscriptionId::new();
297
298 Self::ensure_job_access(&self.db_pool, job_id, auth_context).await?;
299
300 let job_data = self.fetch_job_data(job_id).await?;
302
303 let subscription = JobSubscription {
305 subscription_id,
306 session_id,
307 client_sub_id: client_sub_id.clone(),
308 job_id,
309 auth_context: auth_context.clone(),
310 };
311
312 let mut subs = self.job_subscriptions.write().await;
313 subs.entry(job_id).or_default().push(subscription);
314
315 tracing::trace!(
316 ?subscription_id,
317 %job_id,
318 "Job subscription created"
319 );
320
321 Ok(job_data)
322 }
323
324 pub async fn unsubscribe_job(&self, session_id: SessionId, client_sub_id: &str) {
326 let mut subs = self.job_subscriptions.write().await;
327
328 for subscribers in subs.values_mut() {
330 subscribers
331 .retain(|s| !(s.session_id == session_id && s.client_sub_id == client_sub_id));
332 }
333
334 subs.retain(|_, v| !v.is_empty());
336
337 tracing::trace!(client_id = %client_sub_id, "Job subscription removed");
338 }
339
340 pub async fn subscribe_workflow(
342 &self,
343 session_id: SessionId,
344 client_sub_id: String,
345 workflow_id: Uuid, auth_context: &forge_core::function::AuthContext,
347 ) -> forge_core::Result<WorkflowData> {
348 let subscription_id = SubscriptionId::new();
349
350 Self::ensure_workflow_access(&self.db_pool, workflow_id, auth_context).await?;
351
352 let workflow_data = self.fetch_workflow_data(workflow_id).await?;
354
355 let subscription = WorkflowSubscription {
357 subscription_id,
358 session_id,
359 client_sub_id: client_sub_id.clone(),
360 workflow_id,
361 auth_context: auth_context.clone(),
362 };
363
364 let mut subs = self.workflow_subscriptions.write().await;
365 subs.entry(workflow_id).or_default().push(subscription);
366
367 tracing::trace!(
368 ?subscription_id,
369 %workflow_id,
370 "Workflow subscription created"
371 );
372
373 Ok(workflow_data)
374 }
375
376 pub async fn unsubscribe_workflow(&self, session_id: SessionId, client_sub_id: &str) {
378 let mut subs = self.workflow_subscriptions.write().await;
379
380 for subscribers in subs.values_mut() {
382 subscribers
383 .retain(|s| !(s.session_id == session_id && s.client_sub_id == client_sub_id));
384 }
385
386 subs.retain(|_, v| !v.is_empty());
388
389 tracing::trace!(client_id = %client_sub_id, "Workflow subscription removed");
390 }
391
392 #[allow(clippy::type_complexity)]
394 async fn fetch_job_data(&self, job_id: Uuid) -> forge_core::Result<JobData> {
395 let row: Option<(
396 String,
397 Option<i32>,
398 Option<String>,
399 Option<serde_json::Value>,
400 Option<String>,
401 )> = sqlx::query_as(
402 r#"
403 SELECT status, progress_percent, progress_message, output,
404 COALESCE(cancel_reason, last_error) as error
405 FROM forge_jobs WHERE id = $1
406 "#,
407 )
408 .bind(job_id)
409 .fetch_optional(&self.db_pool)
410 .await
411 .map_err(forge_core::ForgeError::Sql)?;
412
413 match row {
414 Some((status, progress_percent, progress_message, output, error)) => Ok(JobData {
415 job_id: job_id.to_string(),
416 status,
417 progress_percent,
418 progress_message,
419 output,
420 error,
421 }),
422 None => Err(forge_core::ForgeError::NotFound(format!(
423 "Job {} not found",
424 job_id
425 ))),
426 }
427 }
428
429 #[allow(clippy::type_complexity)]
431 async fn fetch_workflow_data(&self, workflow_id: Uuid) -> forge_core::Result<WorkflowData> {
432 let row: Option<(
434 String,
435 Option<String>,
436 Option<serde_json::Value>,
437 Option<String>,
438 )> = sqlx::query_as(
439 r#"
440 SELECT status, current_step, output, error
441 FROM forge_workflow_runs WHERE id = $1
442 "#,
443 )
444 .bind(workflow_id)
445 .fetch_optional(&self.db_pool)
446 .await
447 .map_err(forge_core::ForgeError::Sql)?;
448
449 let (status, current_step, output, error) = match row {
450 Some(r) => r,
451 None => {
452 return Err(forge_core::ForgeError::NotFound(format!(
453 "Workflow {} not found",
454 workflow_id
455 )));
456 }
457 };
458
459 let step_rows: Vec<(String, String, Option<String>)> = sqlx::query_as(
461 r#"
462 SELECT step_name, status, error
463 FROM forge_workflow_steps
464 WHERE workflow_run_id = $1
465 ORDER BY started_at ASC NULLS LAST
466 "#,
467 )
468 .bind(workflow_id)
469 .fetch_all(&self.db_pool)
470 .await
471 .map_err(forge_core::ForgeError::Sql)?;
472
473 let steps = step_rows
474 .into_iter()
475 .map(|(name, status, error)| WorkflowStepData {
476 name,
477 status,
478 error,
479 })
480 .collect();
481
482 Ok(WorkflowData {
483 workflow_id: workflow_id.to_string(),
484 status,
485 current_step,
486 steps,
487 output,
488 error,
489 })
490 }
491
    /// Resolves `query_name` in the registry, enforces auth, executes the
    /// query, and derives the read set used for change-based invalidation.
    ///
    /// Returns the query result plus the set of tables it depends on.
    async fn execute_query(
        &self,
        query_name: &str,
        args: &serde_json::Value,
        auth_context: &forge_core::function::AuthContext,
    ) -> forge_core::Result<(serde_json::Value, ReadSet)> {
        match self.registry.get(query_name) {
            Some(FunctionEntry::Query { info, handler }) => {
                Self::check_query_auth(info, auth_context)?;
                // Identity/tenant-scoped args must match the caller; the
                // scope requirement is only enforced for non-public queries.
                Self::check_identity_args(query_name, args, auth_context, !info.is_public)?;

                let ctx = forge_core::function::QueryContext::new(
                    self.db_pool.clone(),
                    auth_context.clone(),
                    forge_core::function::RequestMetadata::new(),
                );

                // Treat an empty object `{}` the same as no arguments.
                let normalized_args = match args {
                    v if v.as_object().is_some_and(|o| o.is_empty()) => serde_json::Value::Null,
                    v => v.clone(),
                };

                let data = handler(&ctx, normalized_args).await?;

                let mut read_set = ReadSet::new();

                if info.table_dependencies.is_empty() {
                    // No declared dependencies: fall back to guessing the
                    // table from the query name (e.g. `list_users` -> `users`).
                    let table_name = Self::extract_table_name(query_name);
                    read_set.add_table(&table_name);
                    tracing::trace!(
                        query = %query_name,
                        fallback_table = %table_name,
                        "Using naming convention fallback for table dependency"
                    );
                } else {
                    for table in info.table_dependencies {
                        read_set.add_table(*table);
                    }
                }

                Ok((data, read_set))
            }
            // Registered, but as a mutation/action rather than a query.
            Some(_) => Err(forge_core::ForgeError::Validation(format!(
                "'{}' is not a query",
                query_name
            ))),
            None => Err(forge_core::ForgeError::Validation(format!(
                "Query '{}' not found",
                query_name
            ))),
        }
    }
550
551 fn compute_hash(data: &serde_json::Value) -> String {
553 use std::collections::hash_map::DefaultHasher;
554 use std::hash::{Hash, Hasher};
555
556 let json = serde_json::to_string(data).unwrap_or_default();
557 let mut hasher = DefaultHasher::new();
558 json.hash(&mut hasher);
559 format!("{:x}", hasher.finish())
560 }
561
    /// Re-executes every invalidated subscription's query and pushes the new
    /// result to its session — but only when the result hash changed.
    ///
    /// Runs on the spawned reactor task (see `start`), so it takes its
    /// collaborators as arguments instead of borrowing `self`.
    async fn flush_invalidations(
        invalidation_engine: &Arc<InvalidationEngine>,
        active_subscriptions: &Arc<RwLock<HashMap<SubscriptionId, ActiveSubscription>>>,
        session_server: &Arc<SessionServer>,
        registry: &FunctionRegistry,
        db_pool: &sqlx::PgPool,
    ) {
        let invalidated = invalidation_engine.check_pending().await;
        if invalidated.is_empty() {
            return;
        }

        tracing::trace!(count = invalidated.len(), "Invalidating subscriptions");

        // Copy the needed fields out so the read lock is not held across
        // the query re-executions and sends below.
        let subs_to_process: Vec<_> = {
            let subscriptions = active_subscriptions.read().await;
            invalidated
                .iter()
                .filter_map(|sub_id| {
                    subscriptions.get(sub_id).map(|active| {
                        (
                            *sub_id,
                            active.session_id,
                            active.client_sub_id.clone(),
                            active.query_name.clone(),
                            active.args.clone(),
                            active.last_result_hash.clone(),
                            active.auth_context.clone(),
                        )
                    })
                })
                .collect()
        };

        // (subscription, new hash) pairs to write back afterwards.
        let mut updates: Vec<(SubscriptionId, String)> = Vec::new();

        for (sub_id, session_id, client_sub_id, query_name, args, last_hash, auth_context) in
            subs_to_process
        {
            match Self::execute_query_static(registry, db_pool, &query_name, &args, &auth_context)
                .await
            {
                Ok((new_data, _read_set)) => {
                    let new_hash = Self::compute_hash(&new_data);

                    // Skip the push when the result did not actually change.
                    if last_hash.as_ref() != Some(&new_hash) {
                        let message = RealtimeMessage::Data {
                            subscription_id: client_sub_id.clone(),
                            data: new_data,
                        };

                        // Only record the hash on a successful send, so a
                        // failed push is retried on the next invalidation.
                        if let Err(e) = session_server.send_to_session(session_id, message).await {
                            tracing::debug!(client_id = %client_sub_id, error = %e, "Failed to send update");
                        } else {
                            tracing::trace!(client_id = %client_sub_id, "Pushed update to client");
                            updates.push((sub_id, new_hash));
                        }
                    }
                }
                Err(e) => {
                    tracing::warn!(client_id = %client_sub_id, error = %e, "Failed to re-execute query");
                }
            }
        }

        // Persist the new hashes under a single short write lock.
        if !updates.is_empty() {
            let mut subscriptions = active_subscriptions.write().await;
            for (sub_id, new_hash) in updates {
                if let Some(active) = subscriptions.get_mut(&sub_id) {
                    active.last_result_hash = Some(new_hash);
                }
            }
        }
    }
637
    /// Spawns the reactor's background task and returns immediately.
    ///
    /// The task:
    /// - forwards database changes to `handle_change`,
    /// - flushes pending query invalidations on a 25 ms tick,
    /// - restarts the change listener with exponential backoff on failure,
    ///   giving up after `max_listener_restarts` attempts,
    /// - exits when the change channel closes or shutdown is signalled.
    pub async fn start(&self) -> forge_core::Result<()> {
        let listener = self.change_listener.clone();
        let invalidation_engine = self.invalidation_engine.clone();
        let active_subscriptions = self.active_subscriptions.clone();
        let job_subscriptions = self.job_subscriptions.clone();
        let workflow_subscriptions = self.workflow_subscriptions.clone();
        let session_server = self.session_server.clone();
        let registry = self.registry.clone();
        let db_pool = self.db_pool.clone();
        let mut shutdown_rx = self.shutdown_tx.subscribe();
        let max_restarts = self.max_listener_restarts;
        let base_delay_ms = self.listener_restart_delay_ms;

        // Subscribe to the change stream before the listener task is
        // spawned so no notification can be missed.
        let mut change_rx = listener.subscribe();

        tokio::spawn(async move {
            tracing::debug!("Reactor listening for changes");

            let mut restart_count: u32 = 0;
            // The listener task reports fatal errors over this channel so
            // the select loop below can restart it.
            let (listener_error_tx, mut listener_error_rx) = mpsc::channel::<String>(1);

            let listener_clone = listener.clone();
            let error_tx = listener_error_tx.clone();
            let mut listener_handle = Some(tokio::spawn(async move {
                if let Err(e) = listener_clone.run().await {
                    let _ = error_tx.send(format!("Change listener error: {}", e)).await;
                }
            }));

            // Invalidations are batched onto a short tick rather than
            // flushed per change; skip (don't burst) missed ticks.
            let mut flush_interval = tokio::time::interval(std::time::Duration::from_millis(25));
            flush_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

            loop {
                tokio::select! {
                    result = change_rx.recv() => {
                        match result {
                            Ok(change) => {
                                Self::handle_change(
                                    &change,
                                    &invalidation_engine,
                                    &job_subscriptions,
                                    &workflow_subscriptions,
                                    &session_server,
                                    &db_pool,
                                ).await;
                            }
                            Err(broadcast::error::RecvError::Lagged(n)) => {
                                tracing::warn!("Reactor lagged by {} messages", n);
                            }
                            Err(broadcast::error::RecvError::Closed) => {
                                tracing::debug!("Change channel closed");
                                break;
                            }
                        }
                    }
                    _ = flush_interval.tick() => {
                        Self::flush_invalidations(
                            &invalidation_engine,
                            &active_subscriptions,
                            &session_server,
                            &registry,
                            &db_pool,
                        ).await;
                    }
                    Some(error_msg) = listener_error_rx.recv() => {
                        if restart_count >= max_restarts {
                            tracing::error!(
                                attempts = restart_count,
                                last_error = %error_msg,
                                "Change listener failed permanently, real-time updates disabled"
                            );
                            break;
                        }

                        restart_count += 1;
                        // Exponential backoff: base * 2^(attempt - 1).
                        let delay = base_delay_ms * 2u64.saturating_pow(restart_count - 1);
                        tracing::warn!(
                            attempt = restart_count,
                            max = max_restarts,
                            delay_ms = delay,
                            error = %error_msg,
                            "Change listener restarting"
                        );

                        tokio::time::sleep(std::time::Duration::from_millis(delay)).await;

                        // Abort the old task and re-subscribe to the change
                        // stream before spawning the replacement listener.
                        let listener_clone = listener.clone();
                        let error_tx = listener_error_tx.clone();
                        if let Some(handle) = listener_handle.take() {
                            handle.abort();
                        }
                        change_rx = listener.subscribe();
                        listener_handle = Some(tokio::spawn(async move {
                            if let Err(e) = listener_clone.run().await {
                                let _ = error_tx.send(format!("Change listener error: {}", e)).await;
                            }
                        }));
                    }
                    _ = shutdown_rx.recv() => {
                        tracing::debug!("Reactor shutting down");
                        break;
                    }
                }
            }

            // Don't leave the listener task running once the loop exits.
            if let Some(handle) = listener_handle {
                handle.abort();
            }
        });

        Ok(())
    }
756
757 #[allow(clippy::too_many_arguments)]
759 async fn handle_change(
760 change: &Change,
761 invalidation_engine: &Arc<InvalidationEngine>,
762 job_subscriptions: &Arc<RwLock<HashMap<Uuid, Vec<JobSubscription>>>>,
763 workflow_subscriptions: &Arc<RwLock<HashMap<Uuid, Vec<WorkflowSubscription>>>>,
764 session_server: &Arc<SessionServer>,
765 db_pool: &sqlx::PgPool,
766 ) {
767 tracing::trace!(table = %change.table, op = ?change.operation, row_id = ?change.row_id, "Processing change");
768
769 match change.table.as_str() {
771 "forge_jobs" => {
772 if let Some(job_id) = change.row_id {
773 Self::handle_job_change(job_id, job_subscriptions, session_server, db_pool)
774 .await;
775 }
776 return; }
778 "forge_workflow_runs" => {
779 if let Some(workflow_id) = change.row_id {
780 Self::handle_workflow_change(
781 workflow_id,
782 workflow_subscriptions,
783 session_server,
784 db_pool,
785 )
786 .await;
787 }
788 return; }
790 "forge_workflow_steps" => {
791 if let Some(step_id) = change.row_id {
793 Self::handle_workflow_step_change(
794 step_id,
795 workflow_subscriptions,
796 session_server,
797 db_pool,
798 )
799 .await;
800 }
801 return; }
803 _ => {}
804 }
805
806 invalidation_engine.process_change(change.clone()).await;
808 }
809
    /// Pushes the latest state of `job_id` to every subscribed session,
    /// re-checking ownership per subscriber and evicting any that are no
    /// longer authorized.
    async fn handle_job_change(
        job_id: Uuid,
        job_subscriptions: &Arc<RwLock<HashMap<Uuid, Vec<JobSubscription>>>>,
        session_server: &Arc<SessionServer>,
        db_pool: &sqlx::PgPool,
    ) {
        // Clone the subscriber list so the read lock is not held across
        // the database fetches and sends below.
        let subs = job_subscriptions.read().await;
        let subscribers = match subs.get(&job_id) {
            Some(s) if !s.is_empty() => s.clone(),
            _ => return,
        };
        drop(subs);

        let job_data = match Self::fetch_job_data_static(job_id, db_pool).await {
            Ok(data) => data,
            Err(e) => {
                tracing::debug!(%job_id, error = %e, "Failed to fetch job data");
                return;
            }
        };

        let owner_subject = match Self::fetch_job_owner_subject_static(job_id, db_pool).await {
            Ok(owner) => owner,
            Err(e) => {
                tracing::debug!(%job_id, error = %e, "Failed to fetch job owner");
                return;
            }
        };

        // Subscribers that fail the ownership check get collected here and
        // removed below instead of receiving the update.
        let mut unauthorized_subscribers: HashSet<(SessionId, String)> = HashSet::new();

        for sub in subscribers {
            if Self::check_owner_access(owner_subject.clone(), &sub.auth_context).is_err() {
                unauthorized_subscribers.insert((sub.session_id, sub.client_sub_id.clone()));
                continue;
            }

            let message = RealtimeMessage::JobUpdate {
                client_sub_id: sub.client_sub_id.clone(),
                job: job_data.clone(),
            };

            if let Err(e) = session_server
                .send_to_session(sub.session_id, message)
                .await
            {
                tracing::trace!(%job_id, error = %e, "Failed to send job update");
            } else {
                tracing::trace!(%job_id, "Job update sent");
            }
        }

        // Evict unauthorized subscribers and prune empty entries.
        if !unauthorized_subscribers.is_empty() {
            let mut subs = job_subscriptions.write().await;
            if let Some(entries) = subs.get_mut(&job_id) {
                entries.retain(|entry| {
                    !unauthorized_subscribers
                        .contains(&(entry.session_id, entry.client_sub_id.clone()))
                });
            }
            subs.retain(|_, v| !v.is_empty());
        }
    }
876
    /// Pushes the latest state of `workflow_id` to every subscribed session,
    /// re-checking ownership per subscriber and evicting any that are no
    /// longer authorized. Mirrors `handle_job_change`.
    async fn handle_workflow_change(
        workflow_id: Uuid,
        workflow_subscriptions: &Arc<RwLock<HashMap<Uuid, Vec<WorkflowSubscription>>>>,
        session_server: &Arc<SessionServer>,
        db_pool: &sqlx::PgPool,
    ) {
        // Clone the subscriber list so the read lock is not held across
        // the database fetches and sends below.
        let subs = workflow_subscriptions.read().await;
        let subscribers = match subs.get(&workflow_id) {
            Some(s) if !s.is_empty() => s.clone(),
            _ => return,
        };
        drop(subs);

        let workflow_data = match Self::fetch_workflow_data_static(workflow_id, db_pool).await {
            Ok(data) => data,
            Err(e) => {
                tracing::debug!(%workflow_id, error = %e, "Failed to fetch workflow data");
                return;
            }
        };

        let owner_subject =
            match Self::fetch_workflow_owner_subject_static(workflow_id, db_pool).await {
                Ok(owner) => owner,
                Err(e) => {
                    tracing::debug!(%workflow_id, error = %e, "Failed to fetch workflow owner");
                    return;
                }
            };

        // Subscribers that fail the ownership check get collected here and
        // removed below instead of receiving the update.
        let mut unauthorized_subscribers: HashSet<(SessionId, String)> = HashSet::new();

        for sub in subscribers {
            if Self::check_owner_access(owner_subject.clone(), &sub.auth_context).is_err() {
                unauthorized_subscribers.insert((sub.session_id, sub.client_sub_id.clone()));
                continue;
            }

            let message = RealtimeMessage::WorkflowUpdate {
                client_sub_id: sub.client_sub_id.clone(),
                workflow: workflow_data.clone(),
            };

            if let Err(e) = session_server
                .send_to_session(sub.session_id, message)
                .await
            {
                tracing::trace!(%workflow_id, error = %e, "Failed to send workflow update");
            } else {
                tracing::trace!(%workflow_id, "Workflow update sent");
            }
        }

        // Evict unauthorized subscribers and prune empty entries.
        if !unauthorized_subscribers.is_empty() {
            let mut subs = workflow_subscriptions.write().await;
            if let Some(entries) = subs.get_mut(&workflow_id) {
                entries.retain(|entry| {
                    !unauthorized_subscribers
                        .contains(&(entry.session_id, entry.client_sub_id.clone()))
                });
            }
            subs.retain(|_, v| !v.is_empty());
        }
    }
944
945 async fn handle_workflow_step_change(
947 step_id: Uuid,
948 workflow_subscriptions: &Arc<RwLock<HashMap<Uuid, Vec<WorkflowSubscription>>>>,
949 session_server: &Arc<SessionServer>,
950 db_pool: &sqlx::PgPool,
951 ) {
952 let workflow_id: Option<Uuid> = match sqlx::query_scalar(
954 "SELECT workflow_run_id FROM forge_workflow_steps WHERE id = $1",
955 )
956 .bind(step_id)
957 .fetch_optional(db_pool)
958 .await
959 {
960 Ok(id) => id,
961 Err(e) => {
962 tracing::debug!(%step_id, error = %e, "Failed to look up workflow for step");
963 return;
964 }
965 };
966
967 if let Some(wf_id) = workflow_id {
968 Self::handle_workflow_change(wf_id, workflow_subscriptions, session_server, db_pool)
970 .await;
971 }
972 }
973
974 #[allow(clippy::type_complexity)]
976 async fn fetch_job_data_static(
977 job_id: Uuid,
978 db_pool: &sqlx::PgPool,
979 ) -> forge_core::Result<JobData> {
980 let row: Option<(
981 String,
982 Option<i32>,
983 Option<String>,
984 Option<serde_json::Value>,
985 Option<String>,
986 )> = sqlx::query_as(
987 r#"
988 SELECT status, progress_percent, progress_message, output, last_error
989 FROM forge_jobs WHERE id = $1
990 "#,
991 )
992 .bind(job_id)
993 .fetch_optional(db_pool)
994 .await
995 .map_err(forge_core::ForgeError::Sql)?;
996
997 match row {
998 Some((status, progress_percent, progress_message, output, error)) => Ok(JobData {
999 job_id: job_id.to_string(),
1000 status,
1001 progress_percent,
1002 progress_message,
1003 output,
1004 error,
1005 }),
1006 None => Err(forge_core::ForgeError::NotFound(format!(
1007 "Job {} not found",
1008 job_id
1009 ))),
1010 }
1011 }
1012
1013 async fn fetch_job_owner_subject_static(
1014 job_id: Uuid,
1015 db_pool: &sqlx::PgPool,
1016 ) -> forge_core::Result<Option<String>> {
1017 let owner_subject: Option<Option<String>> =
1018 sqlx::query_scalar("SELECT owner_subject FROM forge_jobs WHERE id = $1")
1019 .bind(job_id)
1020 .fetch_optional(db_pool)
1021 .await
1022 .map_err(forge_core::ForgeError::Sql)?;
1023
1024 owner_subject
1025 .ok_or_else(|| forge_core::ForgeError::NotFound(format!("Job {} not found", job_id)))
1026 }
1027
    /// Loads a workflow run's status plus all of its step rows, ordered by
    /// start time with never-started steps last. This is the canonical
    /// implementation used by the change-handling task.
    #[allow(clippy::type_complexity)]
    async fn fetch_workflow_data_static(
        workflow_id: Uuid,
        db_pool: &sqlx::PgPool,
    ) -> forge_core::Result<WorkflowData> {
        // Row shape: (status, current_step, output, error).
        let row: Option<(
            String,
            Option<String>,
            Option<serde_json::Value>,
            Option<String>,
        )> = sqlx::query_as(
            r#"
            SELECT status, current_step, output, error
            FROM forge_workflow_runs WHERE id = $1
            "#,
        )
        .bind(workflow_id)
        .fetch_optional(db_pool)
        .await
        .map_err(forge_core::ForgeError::Sql)?;

        let (status, current_step, output, error) = match row {
            Some(r) => r,
            None => {
                return Err(forge_core::ForgeError::NotFound(format!(
                    "Workflow {} not found",
                    workflow_id
                )));
            }
        };

        // Step rows: (step_name, status, error) per step of the run.
        let step_rows: Vec<(String, String, Option<String>)> = sqlx::query_as(
            r#"
            SELECT step_name, status, error
            FROM forge_workflow_steps
            WHERE workflow_run_id = $1
            ORDER BY started_at ASC NULLS LAST
            "#,
        )
        .bind(workflow_id)
        .fetch_all(db_pool)
        .await
        .map_err(forge_core::ForgeError::Sql)?;

        let steps = step_rows
            .into_iter()
            .map(|(name, status, error)| WorkflowStepData {
                name,
                status,
                error,
            })
            .collect();

        Ok(WorkflowData {
            workflow_id: workflow_id.to_string(),
            status,
            current_step,
            steps,
            output,
            error,
        })
    }
1091
1092 async fn fetch_workflow_owner_subject_static(
1093 workflow_id: Uuid,
1094 db_pool: &sqlx::PgPool,
1095 ) -> forge_core::Result<Option<String>> {
1096 let owner_subject: Option<Option<String>> =
1097 sqlx::query_scalar("SELECT owner_subject FROM forge_workflow_runs WHERE id = $1")
1098 .bind(workflow_id)
1099 .fetch_optional(db_pool)
1100 .await
1101 .map_err(forge_core::ForgeError::Sql)?;
1102
1103 owner_subject.ok_or_else(|| {
1104 forge_core::ForgeError::NotFound(format!("Workflow {} not found", workflow_id))
1105 })
1106 }
1107
    /// Static twin of `execute_query` for use from the background reactor
    /// task (which cannot borrow `self`). Same auth checks, argument
    /// normalization, and read-set derivation; the only difference is that
    /// lookup failures collapse into a single validation error.
    async fn execute_query_static(
        registry: &FunctionRegistry,
        db_pool: &sqlx::PgPool,
        query_name: &str,
        args: &serde_json::Value,
        auth_context: &forge_core::function::AuthContext,
    ) -> forge_core::Result<(serde_json::Value, ReadSet)> {
        match registry.get(query_name) {
            Some(FunctionEntry::Query { info, handler }) => {
                Self::check_query_auth(info, auth_context)?;
                // Identity/tenant-scoped args must match the caller; the
                // scope requirement is only enforced for non-public queries.
                Self::check_identity_args(query_name, args, auth_context, !info.is_public)?;

                let ctx = forge_core::function::QueryContext::new(
                    db_pool.clone(),
                    auth_context.clone(),
                    forge_core::function::RequestMetadata::new(),
                );

                // Treat an empty object `{}` the same as no arguments.
                let normalized_args = match args {
                    v if v.as_object().is_some_and(|o| o.is_empty()) => serde_json::Value::Null,
                    v => v.clone(),
                };

                let data = handler(&ctx, normalized_args).await?;

                let mut read_set = ReadSet::new();

                if info.table_dependencies.is_empty() {
                    // No declared dependencies: fall back to guessing the
                    // table from the query name (e.g. `list_users` -> `users`).
                    let table_name = Self::extract_table_name(query_name);
                    read_set.add_table(&table_name);
                    tracing::trace!(
                        query = %query_name,
                        fallback_table = %table_name,
                        "Using naming convention fallback for table dependency"
                    );
                } else {
                    for table in info.table_dependencies {
                        read_set.add_table(*table);
                    }
                }

                Ok((data, read_set))
            }
            _ => Err(forge_core::ForgeError::Validation(format!(
                "Query '{}' not found or not a query",
                query_name
            ))),
        }
    }
1160
1161 fn extract_table_name(query_name: &str) -> String {
1163 if let Some(rest) = query_name.strip_prefix("get_") {
1164 rest.to_string()
1165 } else if let Some(rest) = query_name.strip_prefix("list_") {
1166 rest.to_string()
1167 } else if let Some(rest) = query_name.strip_prefix("find_") {
1168 rest.to_string()
1169 } else if let Some(rest) = query_name.strip_prefix("fetch_") {
1170 rest.to_string()
1171 } else {
1172 query_name.to_string()
1173 }
1174 }
1175
1176 fn check_query_auth(
1177 info: &forge_core::function::FunctionInfo,
1178 auth: &forge_core::function::AuthContext,
1179 ) -> forge_core::Result<()> {
1180 if info.is_public {
1181 return Ok(());
1182 }
1183
1184 if !auth.is_authenticated() {
1185 return Err(forge_core::ForgeError::Unauthorized(
1186 "Authentication required".into(),
1187 ));
1188 }
1189
1190 if let Some(role) = info.required_role
1191 && !auth.has_role(role)
1192 {
1193 return Err(forge_core::ForgeError::Forbidden(format!(
1194 "Role '{}' required",
1195 role
1196 )));
1197 }
1198
1199 Ok(())
1200 }
1201
    /// Validate that identity- and tenant-scoped arguments match the
    /// authenticated principal.
    ///
    /// Admins bypass every check. For non-admins, any recognized identity key
    /// (`user_id`, `sub`, `owner_subject`, ...) must equal one of the caller's
    /// own principal values, and any tenant key must equal the caller's
    /// `tenant_id` claim. When `enforce_scope` is true, an authenticated
    /// caller must supply at least one such scope key.
    ///
    /// Returns `Unauthorized` for anonymous callers passing scoped keys,
    /// `Forbidden` on mismatches or missing scope, `InvalidArgument` for
    /// non-string scoped values.
    fn check_identity_args(
        function_name: &str,
        args: &serde_json::Value,
        auth: &forge_core::function::AuthContext,
        enforce_scope: bool,
    ) -> forge_core::Result<()> {
        // Admins may query any principal's or tenant's data.
        if auth.is_admin() {
            return Ok(());
        }

        // Non-object args cannot carry scope keys at all; when scoping is
        // enforced for an authenticated caller, that is itself a violation.
        // NOTE(review): anonymous callers with non-object args pass here —
        // presumably auth is enforced elsewhere (e.g. check_query_auth).
        let Some(obj) = args.as_object() else {
            if enforce_scope && auth.is_authenticated() {
                return Err(forge_core::ForgeError::Forbidden(format!(
                    "Function '{function_name}' must include identity or tenant scope arguments"
                )));
            }
            return Ok(());
        };

        // Collect every identifier the caller may legitimately claim:
        // the UUID user id (if any) plus the token subject / principal id,
        // deduplicated.
        let mut principal_values: Vec<String> = Vec::new();
        if let Some(user_id) = auth.user_id().map(|id| id.to_string()) {
            principal_values.push(user_id);
        }
        if let Some(subject) = auth.principal_id()
            && !principal_values.iter().any(|v| v == &subject)
        {
            principal_values.push(subject);
        }

        // Tracks whether the caller supplied any identity OR tenant key,
        // for the enforce_scope check at the end.
        let mut has_scope_key = false;

        // Identity-scoped keys: each supplied value must name the caller.
        for key in [
            "user_id",
            "userId",
            "owner_id",
            "ownerId",
            "owner_subject",
            "ownerSubject",
            "subject",
            "sub",
            "principal_id",
            "principalId",
        ] {
            let Some(value) = obj.get(key) else {
                continue;
            };
            has_scope_key = true;

            // Anonymous callers may not pass identity-scoped arguments.
            if !auth.is_authenticated() {
                return Err(forge_core::ForgeError::Unauthorized(format!(
                    "Function '{function_name}' requires authentication for identity-scoped argument '{key}'"
                )));
            }

            // Only plain strings are acceptable principal identifiers.
            let serde_json::Value::String(actual) = value else {
                return Err(forge_core::ForgeError::InvalidArgument(format!(
                    "Function '{function_name}' argument '{key}' must be a non-empty string"
                )));
            };

            // Empty/whitespace values and values naming someone else are both
            // rejected as cross-principal access attempts.
            if actual.trim().is_empty() || !principal_values.iter().any(|v| v == actual) {
                return Err(forge_core::ForgeError::Forbidden(format!(
                    "Function '{function_name}' argument '{key}' does not match authenticated principal"
                )));
            }
        }

        // Tenant-scoped keys: each supplied value must equal the caller's
        // own tenant_id claim.
        for key in ["tenant_id", "tenantId"] {
            let Some(value) = obj.get(key) else {
                continue;
            };
            has_scope_key = true;

            if !auth.is_authenticated() {
                return Err(forge_core::ForgeError::Unauthorized(format!(
                    "Function '{function_name}' requires authentication for tenant-scoped argument '{key}'"
                )));
            }

            // A principal without a tenant_id claim may not use tenant
            // scoping at all.
            let expected = auth
                .claim("tenant_id")
                .and_then(|v| v.as_str())
                .ok_or_else(|| {
                    forge_core::ForgeError::Forbidden(format!(
                        "Function '{function_name}' argument '{key}' is not allowed for this principal"
                    ))
                })?;

            let serde_json::Value::String(actual) = value else {
                return Err(forge_core::ForgeError::InvalidArgument(format!(
                    "Function '{function_name}' argument '{key}' must be a non-empty string"
                )));
            };

            if actual.trim().is_empty() || actual != expected {
                return Err(forge_core::ForgeError::Forbidden(format!(
                    "Function '{function_name}' argument '{key}' does not match authenticated tenant"
                )));
            }
        }

        // With enforcement on, authenticated callers must have scoped the
        // call by identity or tenant; unscoped queries are rejected.
        if enforce_scope && auth.is_authenticated() && !has_scope_key {
            return Err(forge_core::ForgeError::Forbidden(format!(
                "Function '{function_name}' must include identity or tenant scope arguments"
            )));
        }

        Ok(())
    }
1311
1312 async fn ensure_job_access(
1313 db_pool: &sqlx::PgPool,
1314 job_id: Uuid,
1315 auth: &forge_core::function::AuthContext,
1316 ) -> forge_core::Result<()> {
1317 let owner_subject_row: Option<(Option<String>,)> = sqlx::query_as(
1318 r#"
1319 SELECT owner_subject
1320 FROM forge_jobs
1321 WHERE id = $1
1322 "#,
1323 )
1324 .bind(job_id)
1325 .fetch_optional(db_pool)
1326 .await
1327 .map_err(forge_core::ForgeError::Sql)?;
1328
1329 let owner_subject = owner_subject_row
1330 .ok_or_else(|| forge_core::ForgeError::NotFound(format!("Job {} not found", job_id)))?
1331 .0;
1332
1333 Self::check_owner_access(owner_subject, auth)
1334 }
1335
1336 async fn ensure_workflow_access(
1337 db_pool: &sqlx::PgPool,
1338 workflow_id: Uuid,
1339 auth: &forge_core::function::AuthContext,
1340 ) -> forge_core::Result<()> {
1341 let owner_subject_row: Option<(Option<String>,)> = sqlx::query_as(
1342 r#"
1343 SELECT owner_subject
1344 FROM forge_workflow_runs
1345 WHERE id = $1
1346 "#,
1347 )
1348 .bind(workflow_id)
1349 .fetch_optional(db_pool)
1350 .await
1351 .map_err(forge_core::ForgeError::Sql)?;
1352
1353 let owner_subject = owner_subject_row
1354 .ok_or_else(|| {
1355 forge_core::ForgeError::NotFound(format!("Workflow {} not found", workflow_id))
1356 })?
1357 .0;
1358
1359 Self::check_owner_access(owner_subject, auth)
1360 }
1361
1362 fn check_owner_access(
1363 owner_subject: Option<String>,
1364 auth: &forge_core::function::AuthContext,
1365 ) -> forge_core::Result<()> {
1366 if auth.is_admin() {
1367 return Ok(());
1368 }
1369
1370 let principal = auth.principal_id().ok_or_else(|| {
1371 forge_core::ForgeError::Unauthorized("Authentication required".to_string())
1372 })?;
1373
1374 match owner_subject {
1375 Some(owner) if owner == principal => Ok(()),
1376 Some(_) => Err(forge_core::ForgeError::Forbidden(
1377 "Not authorized to access this resource".to_string(),
1378 )),
1379 None => Err(forge_core::ForgeError::Forbidden(
1380 "Resource has no owner; admin role required".to_string(),
1381 )),
1382 }
1383 }
1384
    /// Signal shutdown to background tasks and stop the change listener.
    pub fn stop(&self) {
        // Send errors are ignored: there may be no live receivers if the
        // background tasks have already exited.
        let _ = self.shutdown_tx.send(());
        self.change_listener.stop();
    }
1390
1391 pub async fn stats(&self) -> ReactorStats {
1393 let session_stats = self.session_server.stats().await;
1394 let inv_stats = self.invalidation_engine.stats().await;
1395
1396 ReactorStats {
1397 connections: session_stats.connections,
1398 subscriptions: session_stats.subscriptions,
1399 pending_invalidations: inv_stats.pending_subscriptions,
1400 listener_running: self.change_listener.is_running(),
1401 }
1402 }
1403}
1404
/// Point-in-time snapshot of reactor health, returned by `Reactor::stats`.
#[derive(Debug, Clone)]
pub struct ReactorStats {
    // Number of active client connections on the session server.
    pub connections: usize,
    // Number of active subscriptions across all sessions.
    pub subscriptions: usize,
    // Subscriptions currently pending re-evaluation in the invalidation engine.
    pub pending_invalidations: usize,
    // Whether the Postgres change-listener task is currently running.
    pub listener_running: bool,
}
1413
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    // Default config values must stay stable: callers tune restarts/delays
    // relative to these documented defaults.
    #[test]
    fn test_reactor_config_default() {
        let config = ReactorConfig::default();
        assert_eq!(config.listener.channel, "forge_changes");
        assert_eq!(config.invalidation.debounce_ms, 50);
        assert_eq!(config.max_listener_restarts, 5);
        assert_eq!(config.listener_restart_delay_ms, 1000);
    }

    // Hashing must be deterministic for equal JSON values and distinguish
    // different values — it drives "did the query result change" detection.
    #[test]
    fn test_compute_hash() {
        let data1 = serde_json::json!({"name": "test"});
        let data2 = serde_json::json!({"name": "test"});
        let data3 = serde_json::json!({"name": "different"});

        let hash1 = Reactor::compute_hash(&data1);
        let hash2 = Reactor::compute_hash(&data2);
        let hash3 = Reactor::compute_hash(&data3);

        assert_eq!(hash1, hash2);
        assert_ne!(hash1, hash3);
    }

    // A non-admin passing another user's id as an identity-scoped argument
    // must be rejected with Forbidden.
    #[test]
    fn test_check_identity_args_rejects_cross_user() {
        let user_id = uuid::Uuid::new_v4();
        let auth = forge_core::function::AuthContext::authenticated(
            user_id,
            vec!["user".to_string()],
            HashMap::from([(
                "sub".to_string(),
                serde_json::Value::String(user_id.to_string()),
            )]),
        );

        let result = Reactor::check_identity_args(
            "list_orders",
            // Random UUID: guaranteed to differ from the caller's own id.
            &serde_json::json!({"user_id": uuid::Uuid::new_v4().to_string()}),
            &auth,
            true,
        );
        assert!(matches!(result, Err(forge_core::ForgeError::Forbidden(_))));
    }

    // With enforce_scope on, an authenticated caller supplying no identity
    // or tenant keys at all must also be rejected.
    #[test]
    fn test_check_identity_args_requires_scope_for_non_public_queries() {
        let user_id = uuid::Uuid::new_v4();
        let auth = forge_core::function::AuthContext::authenticated(
            user_id,
            vec!["user".to_string()],
            HashMap::from([(
                "sub".to_string(),
                serde_json::Value::String(user_id.to_string()),
            )]),
        );

        let result =
            Reactor::check_identity_args("list_orders", &serde_json::json!({}), &auth, true);
        assert!(matches!(result, Err(forge_core::ForgeError::Forbidden(_))));
    }

    // Admins bypass ownership checks even when the resource belongs to
    // another principal.
    #[test]
    fn test_check_owner_access_allows_admin() {
        let auth = forge_core::function::AuthContext::authenticated_without_uuid(
            vec!["admin".to_string()],
            HashMap::from([(
                "sub".to_string(),
                serde_json::Value::String("admin-1".to_string()),
            )]),
        );

        let result = Reactor::check_owner_access(Some("other-user".to_string()), &auth);
        assert!(result.is_ok());
    }
}