Skip to main content

forge_core/testing/
mock_dispatch.rs

1//! Mock dispatchers for testing job and workflow dispatch.
2
3#![allow(clippy::unwrap_used, clippy::indexing_slicing)]
4
5use std::sync::RwLock;
6
7use chrono::{DateTime, Utc};
8use uuid::Uuid;
9
10use crate::error::{ForgeError, Result};
11use crate::job::JobStatus;
12use crate::workflow::WorkflowStatus;
13
14/// Record of a dispatched job.
15#[derive(Debug, Clone)]
16pub struct DispatchedJob {
17    pub id: Uuid,
18    pub job_type: String,
19    pub args: serde_json::Value,
20    pub owner_subject: Option<String>,
21    /// `true` when dispatched via `dispatch_in_conn`.
22    pub in_connection: bool,
23    pub dispatched_at: DateTime<Utc>,
24    /// `None` for immediately-runnable jobs.
25    pub scheduled_at: Option<DateTime<Utc>>,
26    pub status: JobStatus,
27    pub cancel_reason: Option<String>,
28}
29
/// Record of a started workflow.
///
/// One record is appended per successful start on `MockWorkflowDispatch`.
#[derive(Debug, Clone)]
pub struct StartedWorkflow {
    /// Run id generated with `Uuid::new_v4()` and returned to the caller.
    pub run_id: Uuid,
    /// Workflow name exactly as passed to the start call.
    pub workflow_name: String,
    /// Workflow input, serialized with `serde_json::to_value`.
    pub input: serde_json::Value,
    /// Wall-clock time (`Utc::now()`) at which the start was recorded.
    pub started_at: DateTime<Utc>,
    /// Simulated status; starts as `WorkflowStatus::Pending`.
    pub status: WorkflowStatus,
}
39
40/// Records dispatched jobs for later verification.
41pub struct MockJobDispatch {
42    jobs: RwLock<Vec<DispatchedJob>>,
43}
44
45impl MockJobDispatch {
46    pub fn new() -> Self {
47        Self {
48            jobs: RwLock::new(Vec::new()),
49        }
50    }
51
52    pub async fn dispatch<T: serde::Serialize>(&self, job_type: &str, args: T) -> Result<Uuid> {
53        self.dispatch_inner(job_type, args, None, false, None).await
54    }
55
56    pub async fn dispatch_at<T: serde::Serialize>(
57        &self,
58        job_type: &str,
59        args: T,
60        scheduled_at: DateTime<Utc>,
61    ) -> Result<Uuid> {
62        self.dispatch_inner(job_type, args, None, false, Some(scheduled_at))
63            .await
64    }
65
66    async fn dispatch_inner<T: serde::Serialize>(
67        &self,
68        job_type: &str,
69        args: T,
70        owner_subject: Option<String>,
71        in_connection: bool,
72        scheduled_at: Option<DateTime<Utc>>,
73    ) -> Result<Uuid> {
74        let id = Uuid::new_v4();
75        let args_json =
76            serde_json::to_value(args).map_err(|e| ForgeError::Serialization(e.to_string()))?;
77
78        let job = DispatchedJob {
79            id,
80            job_type: job_type.to_string(),
81            args: args_json,
82            owner_subject,
83            in_connection,
84            dispatched_at: Utc::now(),
85            scheduled_at,
86            status: JobStatus::Pending,
87            cancel_reason: None,
88        };
89
90        self.jobs.write().expect("jobs lock poisoned").push(job);
91        Ok(id)
92    }
93
94    pub fn dispatched_jobs(&self) -> Vec<DispatchedJob> {
95        self.jobs.read().expect("jobs lock poisoned").clone()
96    }
97
98    pub fn jobs_of_type(&self, job_type: &str) -> Vec<DispatchedJob> {
99        self.jobs
100            .read()
101            .expect("jobs lock poisoned")
102            .iter()
103            .filter(|j| j.job_type == job_type)
104            .cloned()
105            .collect()
106    }
107
108    pub fn assert_dispatched(&self, job_type: &str) {
109        let jobs = self.jobs.read().expect("jobs lock poisoned");
110        let found = jobs.iter().any(|j| j.job_type == job_type);
111        assert!(
112            found,
113            "Expected job '{}' to be dispatched, but it wasn't. Dispatched jobs: {:?}",
114            job_type,
115            jobs.iter().map(|j| &j.job_type).collect::<Vec<_>>()
116        );
117    }
118
119    pub fn assert_dispatched_with<F>(&self, job_type: &str, predicate: F)
120    where
121        F: Fn(&serde_json::Value) -> bool,
122    {
123        let jobs = self.jobs.read().expect("jobs lock poisoned");
124        let found = jobs
125            .iter()
126            .any(|j| j.job_type == job_type && predicate(&j.args));
127        assert!(
128            found,
129            "Expected job '{}' with matching args to be dispatched",
130            job_type
131        );
132    }
133
134    pub fn assert_not_dispatched(&self, job_type: &str) {
135        let jobs = self.jobs.read().expect("jobs lock poisoned");
136        let found = jobs.iter().any(|j| j.job_type == job_type);
137        assert!(
138            !found,
139            "Expected job '{}' NOT to be dispatched, but it was",
140            job_type
141        );
142    }
143
144    pub fn assert_dispatch_count(&self, job_type: &str, expected: usize) {
145        let jobs = self.jobs.read().expect("jobs lock poisoned");
146        let count = jobs.iter().filter(|j| j.job_type == job_type).count();
147        assert_eq!(
148            count, expected,
149            "Expected {} dispatches of '{}', but found {}",
150            expected, job_type, count
151        );
152    }
153
154    pub fn clear(&self) {
155        self.jobs.write().expect("jobs lock poisoned").clear();
156    }
157
158    pub fn complete_job(&self, job_id: Uuid) {
159        let mut jobs = self.jobs.write().expect("jobs lock poisoned");
160        if let Some(job) = jobs.iter_mut().find(|j| j.id == job_id) {
161            job.status = JobStatus::Completed;
162        }
163    }
164
165    pub fn fail_job(&self, job_id: Uuid) {
166        let mut jobs = self.jobs.write().expect("jobs lock poisoned");
167        if let Some(job) = jobs.iter_mut().find(|j| j.id == job_id) {
168            job.status = JobStatus::Failed;
169        }
170    }
171
172    pub fn cancel_job(&self, job_id: Uuid, reason: Option<String>) {
173        let mut jobs = self.jobs.write().expect("jobs lock poisoned");
174        if let Some(job) = jobs.iter_mut().find(|j| j.id == job_id) {
175            job.status = JobStatus::Cancelled;
176            job.cancel_reason = reason;
177        }
178    }
179}
180
181impl Default for MockJobDispatch {
182    fn default() -> Self {
183        Self::new()
184    }
185}
186
187impl crate::function::JobDispatch for MockJobDispatch {
188    fn get_info(&self, _job_type: &str) -> Option<crate::job::JobInfo> {
189        None
190    }
191
192    fn dispatch_by_name(
193        &self,
194        job_type: &str,
195        args: serde_json::Value,
196        owner_subject: Option<String>,
197        _tenant_id: Option<Uuid>,
198    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Uuid>> + Send + '_>> {
199        let job_type = job_type.to_string();
200        Box::pin(async move {
201            self.dispatch_inner(&job_type, args, owner_subject, false, None)
202                .await
203        })
204    }
205
206    fn dispatch_by_name_at(
207        &self,
208        job_type: &str,
209        args: serde_json::Value,
210        scheduled_at: DateTime<Utc>,
211        owner_subject: Option<String>,
212        _tenant_id: Option<Uuid>,
213    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Uuid>> + Send + '_>> {
214        let job_type = job_type.to_string();
215        Box::pin(async move {
216            self.dispatch_inner(&job_type, args, owner_subject, false, Some(scheduled_at))
217                .await
218        })
219    }
220
221    fn dispatch_in_conn<'a>(
222        &'a self,
223        _conn: &'a mut sqlx::PgConnection,
224        job_type: &'a str,
225        args: serde_json::Value,
226        owner_subject: Option<String>,
227        _tenant_id: Option<Uuid>,
228    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Uuid>> + Send + 'a>> {
229        Box::pin(async move {
230            self.dispatch_inner(job_type, args, owner_subject, true, None)
231                .await
232        })
233    }
234
235    fn dispatch_in_conn_at<'a>(
236        &'a self,
237        _conn: &'a mut sqlx::PgConnection,
238        job_type: &'a str,
239        args: serde_json::Value,
240        scheduled_at: DateTime<Utc>,
241        owner_subject: Option<String>,
242        _tenant_id: Option<Uuid>,
243    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Uuid>> + Send + 'a>> {
244        Box::pin(async move {
245            self.dispatch_inner(job_type, args, owner_subject, true, Some(scheduled_at))
246                .await
247        })
248    }
249
250    fn cancel(
251        &self,
252        job_id: Uuid,
253        reason: Option<String>,
254    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<bool>> + Send + '_>> {
255        Box::pin(async move {
256            self.cancel_job(job_id, reason);
257            Ok(true)
258        })
259    }
260}
261
262/// Records started workflows for later verification.
263pub struct MockWorkflowDispatch {
264    workflows: RwLock<Vec<StartedWorkflow>>,
265}
266
267impl MockWorkflowDispatch {
268    pub fn new() -> Self {
269        Self {
270            workflows: RwLock::new(Vec::new()),
271        }
272    }
273
274    pub async fn start<T: serde::Serialize>(&self, workflow_name: &str, input: T) -> Result<Uuid> {
275        let run_id = Uuid::new_v4();
276        let input_json =
277            serde_json::to_value(input).map_err(|e| ForgeError::Serialization(e.to_string()))?;
278
279        let workflow = StartedWorkflow {
280            run_id,
281            workflow_name: workflow_name.to_string(),
282            input: input_json,
283            started_at: Utc::now(),
284            status: WorkflowStatus::Pending,
285        };
286
287        self.workflows
288            .write()
289            .expect("workflows lock poisoned")
290            .push(workflow);
291        Ok(run_id)
292    }
293
294    pub fn started_workflows(&self) -> Vec<StartedWorkflow> {
295        self.workflows
296            .read()
297            .expect("workflows lock poisoned")
298            .clone()
299    }
300
301    pub fn workflows_named(&self, name: &str) -> Vec<StartedWorkflow> {
302        self.workflows
303            .read()
304            .expect("workflows lock poisoned")
305            .iter()
306            .filter(|w| w.workflow_name == name)
307            .cloned()
308            .collect()
309    }
310
311    pub fn assert_started(&self, workflow_name: &str) {
312        let workflows = self.workflows.read().expect("workflows lock poisoned");
313        let found = workflows.iter().any(|w| w.workflow_name == workflow_name);
314        assert!(
315            found,
316            "Expected workflow '{}' to be started, but it wasn't. Started workflows: {:?}",
317            workflow_name,
318            workflows
319                .iter()
320                .map(|w| &w.workflow_name)
321                .collect::<Vec<_>>()
322        );
323    }
324
325    pub fn assert_started_with<F>(&self, workflow_name: &str, predicate: F)
326    where
327        F: Fn(&serde_json::Value) -> bool,
328    {
329        let workflows = self.workflows.read().expect("workflows lock poisoned");
330        let found = workflows
331            .iter()
332            .any(|w| w.workflow_name == workflow_name && predicate(&w.input));
333        assert!(
334            found,
335            "Expected workflow '{}' with matching input to be started",
336            workflow_name
337        );
338    }
339
340    pub fn assert_not_started(&self, workflow_name: &str) {
341        let workflows = self.workflows.read().expect("workflows lock poisoned");
342        let found = workflows.iter().any(|w| w.workflow_name == workflow_name);
343        assert!(
344            !found,
345            "Expected workflow '{}' NOT to be started, but it was",
346            workflow_name
347        );
348    }
349
350    pub fn assert_start_count(&self, workflow_name: &str, expected: usize) {
351        let workflows = self.workflows.read().expect("workflows lock poisoned");
352        let count = workflows
353            .iter()
354            .filter(|w| w.workflow_name == workflow_name)
355            .count();
356        assert_eq!(
357            count, expected,
358            "Expected {} starts of '{}', but found {}",
359            expected, workflow_name, count
360        );
361    }
362
363    pub fn clear(&self) {
364        self.workflows
365            .write()
366            .expect("workflows lock poisoned")
367            .clear();
368    }
369
370    pub fn complete_workflow(&self, run_id: Uuid) {
371        let mut workflows = self.workflows.write().expect("workflows lock poisoned");
372        if let Some(workflow) = workflows.iter_mut().find(|w| w.run_id == run_id) {
373            workflow.status = WorkflowStatus::Completed;
374        }
375    }
376
377    pub fn fail_workflow(&self, run_id: Uuid) {
378        let mut workflows = self.workflows.write().expect("workflows lock poisoned");
379        if let Some(workflow) = workflows.iter_mut().find(|w| w.run_id == run_id) {
380            workflow.status = WorkflowStatus::Failed;
381        }
382    }
383}
384
385impl Default for MockWorkflowDispatch {
386    fn default() -> Self {
387        Self::new()
388    }
389}
390
391impl crate::function::WorkflowDispatch for MockWorkflowDispatch {
392    fn get_info(&self, _workflow_name: &str) -> Option<crate::workflow::WorkflowInfo> {
393        None
394    }
395
396    fn start_by_name(
397        &self,
398        workflow_name: &str,
399        input: serde_json::Value,
400        _owner_subject: Option<String>,
401        _trace_id: Option<String>,
402    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Uuid>> + Send + '_>> {
403        let name = workflow_name.to_string();
404        Box::pin(async move { self.start(&name, input).await })
405    }
406
407    fn start_in_conn<'a>(
408        &'a self,
409        _conn: &'a mut sqlx::PgConnection,
410        workflow_name: &'a str,
411        input: serde_json::Value,
412        _owner_subject: Option<String>,
413        _trace_id: Option<String>,
414    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Uuid>> + Send + 'a>> {
415        Box::pin(async move { self.start(workflow_name, input).await })
416    }
417}
418
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_mock_job_dispatch() {
        let dispatch = MockJobDispatch::new();

        let job_id = dispatch
            .dispatch("send_email", serde_json::json!({"to": "test@example.com"}))
            .await
            .unwrap();

        assert!(!job_id.is_nil());
        dispatch.assert_dispatched("send_email");
        dispatch.assert_not_dispatched("other_job");
    }

    #[tokio::test]
    async fn test_job_dispatch_with_args() {
        let dispatch = MockJobDispatch::new();

        dispatch
            .dispatch("send_email", serde_json::json!({"to": "test@example.com"}))
            .await
            .unwrap();

        dispatch.assert_dispatched_with("send_email", |args| args["to"] == "test@example.com");
    }

    #[tokio::test]
    #[should_panic(expected = "Expected job 'send_email' with matching args to be dispatched")]
    async fn test_assert_dispatched_with_panics_on_mismatch() {
        let dispatch = MockJobDispatch::new();

        dispatch
            .dispatch("send_email", serde_json::json!({"to": "a@example.com"}))
            .await
            .unwrap();

        dispatch.assert_dispatched_with("send_email", |args| args["to"] == "b@example.com");
    }

    #[tokio::test]
    async fn test_job_dispatch_count() {
        let dispatch = MockJobDispatch::new();

        dispatch
            .dispatch("job_a", serde_json::json!({}))
            .await
            .unwrap();
        dispatch
            .dispatch("job_b", serde_json::json!({}))
            .await
            .unwrap();
        dispatch
            .dispatch("job_a", serde_json::json!({}))
            .await
            .unwrap();

        dispatch.assert_dispatch_count("job_a", 2);
        dispatch.assert_dispatch_count("job_b", 1);
    }

    #[tokio::test]
    async fn test_dispatch_at_records_schedule() {
        let dispatch = MockJobDispatch::new();
        let when = Utc::now();

        let job_id = dispatch
            .dispatch_at("reminder", serde_json::json!({}), when)
            .await
            .unwrap();

        let jobs = dispatch.dispatched_jobs();
        assert_eq!(jobs.len(), 1);
        assert_eq!(jobs[0].id, job_id);
        assert_eq!(jobs[0].scheduled_at, Some(when));
        assert_eq!(jobs[0].owner_subject, None);
        assert!(!jobs[0].in_connection);
    }

    #[tokio::test]
    async fn test_trait_dispatch_records_owner() {
        use crate::function::JobDispatch as _;

        let dispatch = MockJobDispatch::new();
        let job_id = dispatch
            .dispatch_by_name(
                "send_email",
                serde_json::json!({}),
                Some("user-1".to_string()),
                None,
            )
            .await
            .unwrap();

        let jobs = dispatch.jobs_of_type("send_email");
        assert_eq!(jobs.len(), 1);
        assert_eq!(jobs[0].id, job_id);
        assert_eq!(jobs[0].owner_subject.as_deref(), Some("user-1"));
        assert!(!jobs[0].in_connection);
    }

    #[tokio::test]
    async fn test_trait_cancel_marks_job_cancelled() {
        use crate::function::JobDispatch as _;

        let dispatch = MockJobDispatch::new();
        let job_id = dispatch
            .dispatch("test", serde_json::json!({}))
            .await
            .unwrap();

        let cancelled = dispatch
            .cancel(job_id, Some("timeout".to_string()))
            .await
            .unwrap();

        assert!(cancelled);
        let jobs = dispatch.dispatched_jobs();
        assert_eq!(jobs[0].status, JobStatus::Cancelled);
        assert_eq!(jobs[0].cancel_reason.as_deref(), Some("timeout"));
    }

    #[tokio::test]
    async fn test_mock_workflow_dispatch() {
        let dispatch = MockWorkflowDispatch::new();

        let run_id = dispatch
            .start("onboarding", serde_json::json!({"user_id": "123"}))
            .await
            .unwrap();

        assert!(!run_id.is_nil());
        dispatch.assert_started("onboarding");
        dispatch.assert_not_started("other_workflow");
    }

    #[tokio::test]
    async fn test_workflow_dispatch_with_input() {
        let dispatch = MockWorkflowDispatch::new();

        dispatch
            .start("onboarding", serde_json::json!({"user_id": "123"}))
            .await
            .unwrap();

        dispatch.assert_started_with("onboarding", |input| input["user_id"] == "123");
    }

    #[tokio::test]
    async fn test_workflow_start_count_and_clear() {
        let dispatch = MockWorkflowDispatch::new();

        dispatch.start("a", serde_json::json!({})).await.unwrap();
        dispatch.start("b", serde_json::json!({})).await.unwrap();
        dispatch.start("a", serde_json::json!({})).await.unwrap();

        dispatch.assert_start_count("a", 2);
        assert_eq!(dispatch.workflows_named("b").len(), 1);

        dispatch.clear();
        assert!(dispatch.started_workflows().is_empty());
    }

    #[tokio::test]
    async fn test_workflow_status_simulation() {
        let dispatch = MockWorkflowDispatch::new();
        let run_id = dispatch
            .start("onboarding", serde_json::json!({}))
            .await
            .unwrap();

        assert!(matches!(
            dispatch.started_workflows()[0].status,
            WorkflowStatus::Pending
        ));

        dispatch.complete_workflow(run_id);
        assert!(matches!(
            dispatch.started_workflows()[0].status,
            WorkflowStatus::Completed
        ));

        dispatch.fail_workflow(run_id);
        assert!(matches!(
            dispatch.started_workflows()[0].status,
            WorkflowStatus::Failed
        ));
    }

    #[tokio::test]
    async fn test_clear() {
        let dispatch = MockJobDispatch::new();
        dispatch
            .dispatch("test", serde_json::json!({}))
            .await
            .unwrap();

        assert_eq!(dispatch.dispatched_jobs().len(), 1);
        dispatch.clear();
        assert_eq!(dispatch.dispatched_jobs().len(), 0);
    }

    #[tokio::test]
    async fn test_job_status_simulation() {
        let dispatch = MockJobDispatch::new();
        let job_id = dispatch
            .dispatch("test", serde_json::json!({}))
            .await
            .unwrap();

        let jobs = dispatch.dispatched_jobs();
        assert_eq!(jobs[0].status, JobStatus::Pending);

        dispatch.complete_job(job_id);

        let jobs = dispatch.dispatched_jobs();
        assert_eq!(jobs[0].status, JobStatus::Completed);
    }

    #[tokio::test]
    async fn test_job_fail_simulation() {
        let dispatch = MockJobDispatch::new();
        let job_id = dispatch
            .dispatch("test", serde_json::json!({}))
            .await
            .unwrap();

        dispatch.fail_job(job_id);

        assert_eq!(dispatch.dispatched_jobs()[0].status, JobStatus::Failed);
    }

    #[tokio::test]
    async fn test_status_update_ignores_unknown_id() {
        let dispatch = MockJobDispatch::new();
        dispatch
            .dispatch("test", serde_json::json!({}))
            .await
            .unwrap();

        dispatch.complete_job(Uuid::new_v4());

        assert_eq!(dispatch.dispatched_jobs()[0].status, JobStatus::Pending);
    }

    #[tokio::test]
    async fn test_job_cancel_simulation() {
        let dispatch = MockJobDispatch::new();
        let job_id = dispatch
            .dispatch("test", serde_json::json!({}))
            .await
            .unwrap();

        dispatch.cancel_job(job_id, Some("user request".to_string()));

        let jobs = dispatch.dispatched_jobs();
        assert_eq!(jobs[0].status, JobStatus::Cancelled);
        assert_eq!(jobs[0].cancel_reason.as_deref(), Some("user request"));
    }
}