1#![allow(clippy::unwrap_used, clippy::indexing_slicing)]
4
5use std::sync::RwLock;
6
7use chrono::{DateTime, Utc};
8use uuid::Uuid;
9
10use crate::error::{ForgeError, Result};
11use crate::job::JobStatus;
12use crate::workflow::WorkflowStatus;
13
14#[derive(Debug, Clone)]
16pub struct DispatchedJob {
17 pub id: Uuid,
18 pub job_type: String,
19 pub args: serde_json::Value,
20 pub owner_subject: Option<String>,
21 pub in_connection: bool,
23 pub dispatched_at: DateTime<Utc>,
24 pub scheduled_at: Option<DateTime<Utc>>,
26 pub status: JobStatus,
27 pub cancel_reason: Option<String>,
28}
29
30#[derive(Debug, Clone)]
32pub struct StartedWorkflow {
33 pub run_id: Uuid,
34 pub workflow_name: String,
35 pub input: serde_json::Value,
36 pub started_at: DateTime<Utc>,
37 pub status: WorkflowStatus,
38}
39
40pub struct MockJobDispatch {
42 jobs: RwLock<Vec<DispatchedJob>>,
43}
44
45impl MockJobDispatch {
46 pub fn new() -> Self {
47 Self {
48 jobs: RwLock::new(Vec::new()),
49 }
50 }
51
52 pub async fn dispatch<T: serde::Serialize>(&self, job_type: &str, args: T) -> Result<Uuid> {
53 self.dispatch_inner(job_type, args, None, false, None).await
54 }
55
56 pub async fn dispatch_at<T: serde::Serialize>(
57 &self,
58 job_type: &str,
59 args: T,
60 scheduled_at: DateTime<Utc>,
61 ) -> Result<Uuid> {
62 self.dispatch_inner(job_type, args, None, false, Some(scheduled_at))
63 .await
64 }
65
66 async fn dispatch_inner<T: serde::Serialize>(
67 &self,
68 job_type: &str,
69 args: T,
70 owner_subject: Option<String>,
71 in_connection: bool,
72 scheduled_at: Option<DateTime<Utc>>,
73 ) -> Result<Uuid> {
74 let id = Uuid::new_v4();
75 let args_json =
76 serde_json::to_value(args).map_err(|e| ForgeError::Serialization(e.to_string()))?;
77
78 let job = DispatchedJob {
79 id,
80 job_type: job_type.to_string(),
81 args: args_json,
82 owner_subject,
83 in_connection,
84 dispatched_at: Utc::now(),
85 scheduled_at,
86 status: JobStatus::Pending,
87 cancel_reason: None,
88 };
89
90 self.jobs.write().expect("jobs lock poisoned").push(job);
91 Ok(id)
92 }
93
94 pub fn dispatched_jobs(&self) -> Vec<DispatchedJob> {
95 self.jobs.read().expect("jobs lock poisoned").clone()
96 }
97
98 pub fn jobs_of_type(&self, job_type: &str) -> Vec<DispatchedJob> {
99 self.jobs
100 .read()
101 .expect("jobs lock poisoned")
102 .iter()
103 .filter(|j| j.job_type == job_type)
104 .cloned()
105 .collect()
106 }
107
108 pub fn assert_dispatched(&self, job_type: &str) {
109 let jobs = self.jobs.read().expect("jobs lock poisoned");
110 let found = jobs.iter().any(|j| j.job_type == job_type);
111 assert!(
112 found,
113 "Expected job '{}' to be dispatched, but it wasn't. Dispatched jobs: {:?}",
114 job_type,
115 jobs.iter().map(|j| &j.job_type).collect::<Vec<_>>()
116 );
117 }
118
119 pub fn assert_dispatched_with<F>(&self, job_type: &str, predicate: F)
120 where
121 F: Fn(&serde_json::Value) -> bool,
122 {
123 let jobs = self.jobs.read().expect("jobs lock poisoned");
124 let found = jobs
125 .iter()
126 .any(|j| j.job_type == job_type && predicate(&j.args));
127 assert!(
128 found,
129 "Expected job '{}' with matching args to be dispatched",
130 job_type
131 );
132 }
133
134 pub fn assert_not_dispatched(&self, job_type: &str) {
135 let jobs = self.jobs.read().expect("jobs lock poisoned");
136 let found = jobs.iter().any(|j| j.job_type == job_type);
137 assert!(
138 !found,
139 "Expected job '{}' NOT to be dispatched, but it was",
140 job_type
141 );
142 }
143
144 pub fn assert_dispatch_count(&self, job_type: &str, expected: usize) {
145 let jobs = self.jobs.read().expect("jobs lock poisoned");
146 let count = jobs.iter().filter(|j| j.job_type == job_type).count();
147 assert_eq!(
148 count, expected,
149 "Expected {} dispatches of '{}', but found {}",
150 expected, job_type, count
151 );
152 }
153
154 pub fn clear(&self) {
155 self.jobs.write().expect("jobs lock poisoned").clear();
156 }
157
158 pub fn complete_job(&self, job_id: Uuid) {
159 let mut jobs = self.jobs.write().expect("jobs lock poisoned");
160 if let Some(job) = jobs.iter_mut().find(|j| j.id == job_id) {
161 job.status = JobStatus::Completed;
162 }
163 }
164
165 pub fn fail_job(&self, job_id: Uuid) {
166 let mut jobs = self.jobs.write().expect("jobs lock poisoned");
167 if let Some(job) = jobs.iter_mut().find(|j| j.id == job_id) {
168 job.status = JobStatus::Failed;
169 }
170 }
171
172 pub fn cancel_job(&self, job_id: Uuid, reason: Option<String>) {
173 let mut jobs = self.jobs.write().expect("jobs lock poisoned");
174 if let Some(job) = jobs.iter_mut().find(|j| j.id == job_id) {
175 job.status = JobStatus::Cancelled;
176 job.cancel_reason = reason;
177 }
178 }
179}
180
181impl Default for MockJobDispatch {
182 fn default() -> Self {
183 Self::new()
184 }
185}
186
187impl crate::function::JobDispatch for MockJobDispatch {
188 fn get_info(&self, _job_type: &str) -> Option<crate::job::JobInfo> {
189 None
190 }
191
192 fn dispatch_by_name(
193 &self,
194 job_type: &str,
195 args: serde_json::Value,
196 owner_subject: Option<String>,
197 _tenant_id: Option<Uuid>,
198 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Uuid>> + Send + '_>> {
199 let job_type = job_type.to_string();
200 Box::pin(async move {
201 self.dispatch_inner(&job_type, args, owner_subject, false, None)
202 .await
203 })
204 }
205
206 fn dispatch_by_name_at(
207 &self,
208 job_type: &str,
209 args: serde_json::Value,
210 scheduled_at: DateTime<Utc>,
211 owner_subject: Option<String>,
212 _tenant_id: Option<Uuid>,
213 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Uuid>> + Send + '_>> {
214 let job_type = job_type.to_string();
215 Box::pin(async move {
216 self.dispatch_inner(&job_type, args, owner_subject, false, Some(scheduled_at))
217 .await
218 })
219 }
220
221 fn dispatch_in_conn<'a>(
222 &'a self,
223 _conn: &'a mut sqlx::PgConnection,
224 job_type: &'a str,
225 args: serde_json::Value,
226 owner_subject: Option<String>,
227 _tenant_id: Option<Uuid>,
228 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Uuid>> + Send + 'a>> {
229 Box::pin(async move {
230 self.dispatch_inner(job_type, args, owner_subject, true, None)
231 .await
232 })
233 }
234
235 fn dispatch_in_conn_at<'a>(
236 &'a self,
237 _conn: &'a mut sqlx::PgConnection,
238 job_type: &'a str,
239 args: serde_json::Value,
240 scheduled_at: DateTime<Utc>,
241 owner_subject: Option<String>,
242 _tenant_id: Option<Uuid>,
243 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Uuid>> + Send + 'a>> {
244 Box::pin(async move {
245 self.dispatch_inner(job_type, args, owner_subject, true, Some(scheduled_at))
246 .await
247 })
248 }
249
250 fn cancel(
251 &self,
252 job_id: Uuid,
253 reason: Option<String>,
254 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<bool>> + Send + '_>> {
255 Box::pin(async move {
256 self.cancel_job(job_id, reason);
257 Ok(true)
258 })
259 }
260}
261
262pub struct MockWorkflowDispatch {
264 workflows: RwLock<Vec<StartedWorkflow>>,
265}
266
267impl MockWorkflowDispatch {
268 pub fn new() -> Self {
269 Self {
270 workflows: RwLock::new(Vec::new()),
271 }
272 }
273
274 pub async fn start<T: serde::Serialize>(&self, workflow_name: &str, input: T) -> Result<Uuid> {
275 let run_id = Uuid::new_v4();
276 let input_json =
277 serde_json::to_value(input).map_err(|e| ForgeError::Serialization(e.to_string()))?;
278
279 let workflow = StartedWorkflow {
280 run_id,
281 workflow_name: workflow_name.to_string(),
282 input: input_json,
283 started_at: Utc::now(),
284 status: WorkflowStatus::Pending,
285 };
286
287 self.workflows
288 .write()
289 .expect("workflows lock poisoned")
290 .push(workflow);
291 Ok(run_id)
292 }
293
294 pub fn started_workflows(&self) -> Vec<StartedWorkflow> {
295 self.workflows
296 .read()
297 .expect("workflows lock poisoned")
298 .clone()
299 }
300
301 pub fn workflows_named(&self, name: &str) -> Vec<StartedWorkflow> {
302 self.workflows
303 .read()
304 .expect("workflows lock poisoned")
305 .iter()
306 .filter(|w| w.workflow_name == name)
307 .cloned()
308 .collect()
309 }
310
311 pub fn assert_started(&self, workflow_name: &str) {
312 let workflows = self.workflows.read().expect("workflows lock poisoned");
313 let found = workflows.iter().any(|w| w.workflow_name == workflow_name);
314 assert!(
315 found,
316 "Expected workflow '{}' to be started, but it wasn't. Started workflows: {:?}",
317 workflow_name,
318 workflows
319 .iter()
320 .map(|w| &w.workflow_name)
321 .collect::<Vec<_>>()
322 );
323 }
324
325 pub fn assert_started_with<F>(&self, workflow_name: &str, predicate: F)
326 where
327 F: Fn(&serde_json::Value) -> bool,
328 {
329 let workflows = self.workflows.read().expect("workflows lock poisoned");
330 let found = workflows
331 .iter()
332 .any(|w| w.workflow_name == workflow_name && predicate(&w.input));
333 assert!(
334 found,
335 "Expected workflow '{}' with matching input to be started",
336 workflow_name
337 );
338 }
339
340 pub fn assert_not_started(&self, workflow_name: &str) {
341 let workflows = self.workflows.read().expect("workflows lock poisoned");
342 let found = workflows.iter().any(|w| w.workflow_name == workflow_name);
343 assert!(
344 !found,
345 "Expected workflow '{}' NOT to be started, but it was",
346 workflow_name
347 );
348 }
349
350 pub fn assert_start_count(&self, workflow_name: &str, expected: usize) {
351 let workflows = self.workflows.read().expect("workflows lock poisoned");
352 let count = workflows
353 .iter()
354 .filter(|w| w.workflow_name == workflow_name)
355 .count();
356 assert_eq!(
357 count, expected,
358 "Expected {} starts of '{}', but found {}",
359 expected, workflow_name, count
360 );
361 }
362
363 pub fn clear(&self) {
364 self.workflows
365 .write()
366 .expect("workflows lock poisoned")
367 .clear();
368 }
369
370 pub fn complete_workflow(&self, run_id: Uuid) {
371 let mut workflows = self.workflows.write().expect("workflows lock poisoned");
372 if let Some(workflow) = workflows.iter_mut().find(|w| w.run_id == run_id) {
373 workflow.status = WorkflowStatus::Completed;
374 }
375 }
376
377 pub fn fail_workflow(&self, run_id: Uuid) {
378 let mut workflows = self.workflows.write().expect("workflows lock poisoned");
379 if let Some(workflow) = workflows.iter_mut().find(|w| w.run_id == run_id) {
380 workflow.status = WorkflowStatus::Failed;
381 }
382 }
383}
384
385impl Default for MockWorkflowDispatch {
386 fn default() -> Self {
387 Self::new()
388 }
389}
390
391impl crate::function::WorkflowDispatch for MockWorkflowDispatch {
392 fn get_info(&self, _workflow_name: &str) -> Option<crate::workflow::WorkflowInfo> {
393 None
394 }
395
396 fn start_by_name(
397 &self,
398 workflow_name: &str,
399 input: serde_json::Value,
400 _owner_subject: Option<String>,
401 _trace_id: Option<String>,
402 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Uuid>> + Send + '_>> {
403 let name = workflow_name.to_string();
404 Box::pin(async move { self.start(&name, input).await })
405 }
406
407 fn start_in_conn<'a>(
408 &'a self,
409 _conn: &'a mut sqlx::PgConnection,
410 workflow_name: &'a str,
411 input: serde_json::Value,
412 _owner_subject: Option<String>,
413 _trace_id: Option<String>,
414 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Uuid>> + Send + 'a>> {
415 Box::pin(async move { self.start(workflow_name, input).await })
416 }
417}
418
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A dispatched job is visible to both the positive and the negative type assertion.
    #[tokio::test]
    async fn test_mock_job_dispatch() {
        let mock = MockJobDispatch::new();

        let outcome = mock
            .dispatch("send_email", json!({"to": "test@example.com"}))
            .await;
        let job_id = outcome.unwrap();

        assert!(!job_id.is_nil());
        mock.assert_dispatched("send_email");
        mock.assert_not_dispatched("other_job");
    }

    /// The args predicate sees the JSON payload that was dispatched.
    #[tokio::test]
    async fn test_job_dispatch_with_args() {
        let mock = MockJobDispatch::new();
        let payload = json!({"to": "test@example.com"});

        mock.dispatch("send_email", payload).await.unwrap();

        mock.assert_dispatched_with("send_email", |args| args["to"] == "test@example.com");
    }

    /// Dispatch counts are tracked per job type.
    #[tokio::test]
    async fn test_job_dispatch_count() {
        let mock = MockJobDispatch::new();

        for job_type in ["job_a", "job_b", "job_a"] {
            mock.dispatch(job_type, json!({})).await.unwrap();
        }

        mock.assert_dispatch_count("job_a", 2);
        mock.assert_dispatch_count("job_b", 1);
    }

    /// A started workflow is visible to both the positive and the negative name assertion.
    #[tokio::test]
    async fn test_mock_workflow_dispatch() {
        let mock = MockWorkflowDispatch::new();

        let outcome = mock.start("onboarding", json!({"user_id": "123"})).await;
        let run_id = outcome.unwrap();

        assert!(!run_id.is_nil());
        mock.assert_started("onboarding");
        mock.assert_not_started("other_workflow");
    }

    /// The input predicate sees the JSON payload the workflow was started with.
    #[tokio::test]
    async fn test_workflow_dispatch_with_input() {
        let mock = MockWorkflowDispatch::new();
        let payload = json!({"user_id": "123"});

        mock.start("onboarding", payload).await.unwrap();

        mock.assert_started_with("onboarding", |input| input["user_id"] == "123");
    }

    /// `clear` drops every recorded job.
    #[tokio::test]
    async fn test_clear() {
        let mock = MockJobDispatch::new();
        mock.dispatch("test", json!({})).await.unwrap();
        assert_eq!(mock.dispatched_jobs().len(), 1);

        mock.clear();

        assert!(mock.dispatched_jobs().is_empty());
    }

    /// A job starts as `Pending` and becomes `Completed` after `complete_job`.
    #[tokio::test]
    async fn test_job_status_simulation() {
        let mock = MockJobDispatch::new();
        let job_id = mock.dispatch("test", json!({})).await.unwrap();

        let before = mock.dispatched_jobs();
        assert_eq!(before[0].status, JobStatus::Pending);

        mock.complete_job(job_id);

        let after = mock.dispatched_jobs();
        assert_eq!(after[0].status, JobStatus::Completed);
    }

    /// `cancel_job` flips the status and stores the supplied reason.
    #[tokio::test]
    async fn test_job_cancel_simulation() {
        let mock = MockJobDispatch::new();
        let job_id = mock.dispatch("test", json!({})).await.unwrap();

        mock.cancel_job(job_id, Some(String::from("user request")));

        let recorded = mock.dispatched_jobs();
        let cancelled = &recorded[0];
        assert_eq!(cancelled.status, JobStatus::Cancelled);
        assert_eq!(cancelled.cancel_reason.as_deref(), Some("user request"));
    }
}