Skip to main content

forge_core/testing/context/
job.rs

1//! Test context for job functions.
2
3#![allow(clippy::unwrap_used, clippy::indexing_slicing)]
4
5use std::collections::HashMap;
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::sync::{Arc, RwLock};
8
9use sqlx::PgPool;
10use uuid::Uuid;
11
12use serde::Serialize;
13
14use super::super::mock_http::{MockHttp, MockRequest, MockResponse};
15use super::build_test_auth;
16use crate::Result;
17use crate::env::{EnvAccess, EnvProvider, MockEnvProvider};
18use crate::function::AuthContext;
19
/// Progress update recorded during testing.
///
/// One entry is appended for every call to `TestJobContext::progress`.
/// `PartialEq`/`Eq` are derived so tests can compare whole updates with
/// `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TestProgressUpdate {
    /// Reported completion percentage, clamped to at most 100 when recorded.
    pub percent: u8,
    /// Status message supplied alongside the percentage.
    pub message: String,
}
26
27/// Test context for job functions.
28pub struct TestJobContext {
29    pub job_id: Uuid,
30    pub job_type: String,
31    pub attempt: u32,
32    pub max_attempts: u32,
33    pub auth: AuthContext,
34    pool: Option<PgPool>,
35    http: Arc<MockHttp>,
36    progress_updates: Arc<RwLock<Vec<TestProgressUpdate>>>,
37    env_provider: Arc<MockEnvProvider>,
38    saved_data: Arc<RwLock<serde_json::Value>>,
39    cancel_requested: Arc<AtomicBool>,
40    dispatched_jobs: Arc<RwLock<Vec<(String, serde_json::Value, Uuid)>>>,
41    started_workflows: Arc<RwLock<Vec<(String, serde_json::Value, Uuid)>>>,
42}
43
44impl TestJobContext {
45    /// Create a new builder.
46    pub fn builder(job_type: impl Into<String>) -> TestJobContextBuilder {
47        TestJobContextBuilder::new(job_type)
48    }
49
50    /// Get the database pool (if available).
51    pub fn db(&self) -> Option<&PgPool> {
52        self.pool.as_ref()
53    }
54
55    /// Get the mock HTTP client.
56    pub fn http(&self) -> &MockHttp {
57        &self.http
58    }
59
60    /// Report job progress.
61    pub fn progress(&self, percent: u8, message: impl Into<String>) -> Result<()> {
62        let update = TestProgressUpdate {
63            percent: percent.min(100),
64            message: message.into(),
65        };
66        self.progress_updates.write().unwrap().push(update);
67        Ok(())
68    }
69
70    /// Get all progress updates.
71    pub fn progress_updates(&self) -> Vec<TestProgressUpdate> {
72        self.progress_updates.read().unwrap().clone()
73    }
74
75    /// Get all saved job data.
76    ///
77    /// Returns the in-memory data that was written via [`save()`](Self::save).
78    pub fn saved(&self) -> serde_json::Value {
79        self.saved_data.read().unwrap().clone()
80    }
81
82    /// Save a key-value pair to job data.
83    ///
84    /// Merges `key` into the saved data object. Use [`saved()`](Self::saved)
85    /// to read it back in assertions.
86    pub fn save(&self, key: &str, value: serde_json::Value) -> Result<()> {
87        let mut guard = self.saved_data.write().unwrap();
88        if let Some(map) = guard.as_object_mut() {
89            map.insert(key.to_string(), value);
90        } else {
91            let mut map = serde_json::Map::new();
92            map.insert(key.to_string(), value);
93            *guard = serde_json::Value::Object(map);
94        }
95        Ok(())
96    }
97
98    /// Check if this is a retry attempt.
99    pub fn is_retry(&self) -> bool {
100        self.attempt > 1
101    }
102
103    /// Check if this is the last attempt.
104    pub fn is_last_attempt(&self) -> bool {
105        self.attempt >= self.max_attempts
106    }
107
108    /// Simulate heartbeat (no-op in tests, but records the intent).
109    pub async fn heartbeat(&self) -> Result<()> {
110        Ok(())
111    }
112
113    /// Check if cancellation has been requested.
114    pub fn is_cancel_requested(&self) -> Result<bool> {
115        Ok(self.cancel_requested.load(Ordering::SeqCst))
116    }
117
118    /// Return an error if cancellation has been requested.
119    ///
120    /// Use this in job handlers to check for cancellation and exit early.
121    pub fn check_cancelled(&self) -> Result<()> {
122        if self.cancel_requested.load(Ordering::SeqCst) {
123            Err(crate::ForgeError::JobCancelled(
124                "Job cancellation requested".to_string(),
125            ))
126        } else {
127            Ok(())
128        }
129    }
130
131    /// Request cancellation (for testing cancellation flows).
132    ///
133    /// After calling this, `is_cancel_requested()` returns `true` and
134    /// `check_cancelled()` returns an error.
135    pub fn request_cancellation(&self) {
136        self.cancel_requested.store(true, Ordering::SeqCst);
137    }
138
139    /// Buffer a sub-job dispatch (mirrors `JobContext::dispatch_job`).
140    pub fn dispatch_job<T: Serialize>(&self, job_type: &str, args: &T) -> Result<Uuid> {
141        let id = Uuid::new_v4();
142        let json = serde_json::to_value(args)
143            .map_err(|e| crate::ForgeError::Serialization(e.to_string()))?;
144        self.dispatched_jobs
145            .write()
146            .unwrap()
147            .push((job_type.to_string(), json, id));
148        Ok(id)
149    }
150
151    /// Buffer a workflow start (mirrors `JobContext::start_workflow`).
152    pub fn start_workflow<T: Serialize>(&self, workflow_name: &str, args: &T) -> Result<Uuid> {
153        let id = Uuid::new_v4();
154        let json = serde_json::to_value(args)
155            .map_err(|e| crate::ForgeError::Serialization(e.to_string()))?;
156        self.started_workflows
157            .write()
158            .unwrap()
159            .push((workflow_name.to_string(), json, id));
160        Ok(id)
161    }
162
163    /// Get all dispatched sub-jobs for assertions.
164    pub fn dispatched_jobs(&self) -> Vec<(String, serde_json::Value, Uuid)> {
165        self.dispatched_jobs.read().unwrap().clone()
166    }
167
168    /// Get all started workflows for assertions.
169    pub fn started_workflows(&self) -> Vec<(String, serde_json::Value, Uuid)> {
170        self.started_workflows.read().unwrap().clone()
171    }
172
173    /// Get the mock env provider for verification.
174    pub fn env_mock(&self) -> &MockEnvProvider {
175        &self.env_provider
176    }
177}
178
179impl EnvAccess for TestJobContext {
180    fn env_provider(&self) -> &dyn EnvProvider {
181        self.env_provider.as_ref()
182    }
183}
184
185/// Builder for TestJobContext.
186pub struct TestJobContextBuilder {
187    job_id: Option<Uuid>,
188    job_type: String,
189    attempt: u32,
190    max_attempts: u32,
191    user_id: Option<Uuid>,
192    roles: Vec<String>,
193    claims: HashMap<String, serde_json::Value>,
194    pool: Option<PgPool>,
195    http: MockHttp,
196    env_vars: HashMap<String, String>,
197    cancel_requested: bool,
198}
199
200impl TestJobContextBuilder {
201    /// Create a new builder with job type.
202    pub fn new(job_type: impl Into<String>) -> Self {
203        Self {
204            job_id: None,
205            job_type: job_type.into(),
206            attempt: 1,
207            max_attempts: 1,
208            user_id: None,
209            roles: Vec::new(),
210            claims: HashMap::new(),
211            pool: None,
212            http: MockHttp::new(),
213            env_vars: HashMap::new(),
214            cancel_requested: false,
215        }
216    }
217
218    /// Set a specific job ID.
219    pub fn with_job_id(mut self, id: Uuid) -> Self {
220        self.job_id = Some(id);
221        self
222    }
223
224    /// Set as a retry (attempt > 1).
225    pub fn as_retry(mut self, attempt: u32) -> Self {
226        self.attempt = attempt.max(1);
227        self
228    }
229
230    /// Set the maximum attempts.
231    pub fn with_max_attempts(mut self, max: u32) -> Self {
232        self.max_attempts = max.max(1);
233        self
234    }
235
236    /// Set as the last attempt.
237    pub fn as_last_attempt(mut self) -> Self {
238        self.attempt = 3;
239        self.max_attempts = 3;
240        self
241    }
242
243    /// Set the authenticated user with a UUID.
244    pub fn as_user(mut self, id: Uuid) -> Self {
245        self.user_id = Some(id);
246        self
247    }
248
249    /// For non-UUID auth providers (Firebase, Clerk, etc.).
250    pub fn as_subject(mut self, subject: impl Into<String>) -> Self {
251        self.claims
252            .insert("sub".to_string(), serde_json::json!(subject.into()));
253        self
254    }
255
256    /// Add a role.
257    pub fn with_role(mut self, role: impl Into<String>) -> Self {
258        self.roles.push(role.into());
259        self
260    }
261
262    /// Add multiple roles.
263    pub fn with_roles(mut self, roles: Vec<String>) -> Self {
264        self.roles.extend(roles);
265        self
266    }
267
268    /// Add a custom claim.
269    pub fn with_claim(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
270        self.claims.insert(key.into(), value);
271        self
272    }
273
274    /// Set the database pool.
275    pub fn with_pool(mut self, pool: PgPool) -> Self {
276        self.pool = Some(pool);
277        self
278    }
279
280    /// Add an HTTP mock with a custom handler.
281    pub fn mock_http<F>(self, pattern: &str, handler: F) -> Self
282    where
283        F: Fn(&MockRequest) -> MockResponse + Send + Sync + 'static,
284    {
285        self.http.add_mock_sync(pattern, handler);
286        self
287    }
288
289    /// Add an HTTP mock that returns a JSON response.
290    pub fn mock_http_json<T: serde::Serialize>(self, pattern: &str, response: T) -> Self {
291        let json = serde_json::to_value(response).unwrap_or(serde_json::Value::Null);
292        self.mock_http(pattern, move |_| MockResponse::json(json.clone()))
293    }
294
295    /// Set a single environment variable.
296    pub fn with_env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
297        self.env_vars.insert(key.into(), value.into());
298        self
299    }
300
301    /// Set multiple environment variables.
302    pub fn with_envs(mut self, vars: HashMap<String, String>) -> Self {
303        self.env_vars.extend(vars);
304        self
305    }
306
307    /// Start with cancellation already requested.
308    ///
309    /// Use this to test how jobs handle cancellation from the start.
310    pub fn with_cancellation_requested(mut self) -> Self {
311        self.cancel_requested = true;
312        self
313    }
314
315    /// Build the test context.
316    pub fn build(self) -> TestJobContext {
317        TestJobContext {
318            job_id: self.job_id.unwrap_or_else(Uuid::new_v4),
319            job_type: self.job_type,
320            attempt: self.attempt,
321            max_attempts: self.max_attempts,
322            auth: build_test_auth(self.user_id, self.roles, self.claims),
323            pool: self.pool,
324            http: Arc::new(self.http),
325            progress_updates: Arc::new(RwLock::new(Vec::new())),
326            env_provider: Arc::new(MockEnvProvider::with_vars(self.env_vars)),
327            saved_data: Arc::new(RwLock::new(crate::job::empty_saved_data())),
328            cancel_requested: Arc::new(AtomicBool::new(self.cancel_requested)),
329            dispatched_jobs: Arc::new(RwLock::new(Vec::new())),
330            started_workflows: Arc::new(RwLock::new(Vec::new())),
331        }
332    }
333}
334
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_job_context_creation() {
        let ctx = TestJobContext::builder("export_users").build();

        assert_eq!(ctx.job_type, "export_users");
        assert_eq!(ctx.attempt, 1);
        assert!(!ctx.is_retry());
        assert!(ctx.is_last_attempt()); // 1 of 1
    }

    #[test]
    fn test_explicit_job_id() {
        let id = Uuid::new_v4();
        let ctx = TestJobContext::builder("test").with_job_id(id).build();

        assert_eq!(ctx.job_id, id);
    }

    #[test]
    fn test_retry_detection() {
        let ctx = TestJobContext::builder("test")
            .as_retry(3)
            .with_max_attempts(5)
            .build();

        assert!(ctx.is_retry());
        assert!(!ctx.is_last_attempt());
    }

    #[test]
    fn test_last_attempt() {
        let ctx = TestJobContext::builder("test").as_last_attempt().build();

        assert!(ctx.is_retry());
        assert!(ctx.is_last_attempt());
    }

    #[test]
    fn test_progress_tracking() {
        let ctx = TestJobContext::builder("test").build();

        ctx.progress(25, "Step 1 complete").unwrap();
        ctx.progress(50, "Step 2 complete").unwrap();
        ctx.progress(100, "Done").unwrap();

        let updates = ctx.progress_updates();
        assert_eq!(updates.len(), 3);
        assert_eq!(updates[0].percent, 25);
        assert_eq!(updates[2].percent, 100);
    }

    #[test]
    fn test_progress_percent_is_clamped() {
        let ctx = TestJobContext::builder("test").build();

        ctx.progress(250, "Overshoot").unwrap();

        let updates = ctx.progress_updates();
        assert_eq!(updates.len(), 1);
        assert_eq!(updates[0].percent, 100);
        assert_eq!(updates[0].message, "Overshoot");
    }

    #[test]
    fn test_save_and_saved() {
        let ctx = TestJobContext::builder("test").build();
        ctx.save("charge_id", serde_json::json!("ch_123")).unwrap();
        ctx.save("amount", serde_json::json!(100)).unwrap();

        let saved = ctx.saved();
        assert_eq!(saved["charge_id"], "ch_123");
        assert_eq!(saved["amount"], 100);
    }

    #[test]
    fn test_save_overwrites_existing_key() {
        let ctx = TestJobContext::builder("test").build();
        ctx.save("status", serde_json::json!("pending")).unwrap();
        ctx.save("status", serde_json::json!("done")).unwrap();

        let saved = ctx.saved();
        assert_eq!(saved["status"], "done");
    }

    #[test]
    fn test_dispatch_job_is_buffered() {
        let ctx = TestJobContext::builder("test").build();
        let id = ctx
            .dispatch_job("send_email", &serde_json::json!({ "to": "a@example.com" }))
            .unwrap();

        let jobs = ctx.dispatched_jobs();
        assert_eq!(jobs.len(), 1);
        assert_eq!(jobs[0].0, "send_email");
        assert_eq!(jobs[0].1["to"], "a@example.com");
        assert_eq!(jobs[0].2, id);
        assert!(ctx.started_workflows().is_empty());
    }

    #[test]
    fn test_start_workflow_is_buffered() {
        let ctx = TestJobContext::builder("test").build();
        let id = ctx
            .start_workflow("onboarding", &serde_json::json!({ "user": 7 }))
            .unwrap();

        let workflows = ctx.started_workflows();
        assert_eq!(workflows.len(), 1);
        assert_eq!(workflows[0].0, "onboarding");
        assert_eq!(workflows[0].1["user"], 7);
        assert_eq!(workflows[0].2, id);
        assert!(ctx.dispatched_jobs().is_empty());
    }

    #[test]
    fn test_cancellation_not_requested() {
        let ctx = TestJobContext::builder("test").build();

        assert!(!ctx.is_cancel_requested().unwrap());
        assert!(ctx.check_cancelled().is_ok());
    }

    #[test]
    fn test_cancellation_requested_at_build() {
        let ctx = TestJobContext::builder("test")
            .with_cancellation_requested()
            .build();

        assert!(ctx.is_cancel_requested().unwrap());
        assert!(ctx.check_cancelled().is_err());
    }

    #[test]
    fn test_cancellation_error_variant() {
        let ctx = TestJobContext::builder("test")
            .with_cancellation_requested()
            .build();

        assert!(matches!(
            ctx.check_cancelled(),
            Err(crate::ForgeError::JobCancelled(_))
        ));
    }

    #[test]
    fn test_request_cancellation_mid_test() {
        let ctx = TestJobContext::builder("test").build();

        assert!(!ctx.is_cancel_requested().unwrap());
        ctx.request_cancellation();
        assert!(ctx.is_cancel_requested().unwrap());
        assert!(ctx.check_cancelled().is_err());
    }
}