// nestforge_testing/lib.rs

1use std::{marker::PhantomData, sync::Arc};
2
3use anyhow::Result;
4use async_graphql::{ObjectType, Schema, SubscriptionType};
5use axum::{http::HeaderMap, Router};
6use nestforge_core::{
7    initialize_module_runtime, AuthIdentity, Container, ContainerError, InitializedModule,
8    ModuleDefinition, RequestId,
9};
10use nestforge_graphql::{graphql_router, graphql_router_with_config, GraphQlConfig};
11use nestforge_grpc::GrpcContext;
12use nestforge_microservices::{
13    InProcessMicroserviceClient, MicroserviceContext, MicroserviceRegistry, TransportMetadata,
14};
15use nestforge_websockets::WebSocketContext;
16
17type OverrideFn = Box<dyn Fn(&Container) -> Result<()> + Send + Sync>;
18
19/**
20 * TestFactory
21 *
22 * A factory for creating test instances of NestForge modules.
23 * Provides a convenient way to set up modules for testing with
24 * provider overrides.
25 *
26 * # Type Parameters
27 * - `M`: The module type to create a test instance for
28 *
29 * # Example
30 * ```rust
31 * let module = TestFactory::<AppModule>::create()
32 *     .override_provider(MyService::mock())
33 *     .build()
34 *     .unwrap();
35 * ```
36 */
37pub struct TestFactory<M: ModuleDefinition> {
38    overrides: Vec<OverrideFn>,
39    _marker: PhantomData<M>,
40}
41
42impl<M: ModuleDefinition> TestFactory<M> {
43    /**
44     * Creates a new TestFactory for the given module.
45     */
46    pub fn create() -> Self {
47        Self {
48            overrides: Vec::new(),
49            _marker: PhantomData,
50        }
51    }
52
53    /**
54     * Overrides a provider with a test value.
55     *
56     * # Type Parameters
57     * - `T`: The type to override
58     *
59     * # Arguments
60     * - `value`: The test value to use instead of the real provider
61     */
62    pub fn override_provider<T>(mut self, value: T) -> Self
63    where
64        T: Send + Sync + Clone + 'static,
65    {
66        self.overrides.push(Box::new(move |container| {
67            container.override_value(value.clone())?;
68            Ok(())
69        }));
70        self
71    }
72
73    /**
74     * Builds the TestingModule with all configured overrides.
75     */
76    pub fn build(self) -> Result<TestingModule> {
77        let container = Container::new();
78
79        for override_fn in self.overrides {
80            override_fn(&container)?;
81        }
82
83        let runtime = initialize_module_runtime::<M>(&container)?;
84        runtime.run_module_init(&container)?;
85        runtime.run_application_bootstrap(&container)?;
86
87        Ok(TestingModule {
88            container,
89            runtime: Arc::new(runtime),
90        })
91    }
92}
93
94/**
95 * TestingModule
96 *
97 * A fully initialized module for testing purposes.
98 * Provides access to the container, routers, and test utilities.
99 */
100#[derive(Clone)]
101pub struct TestingModule {
102    container: Container,
103    runtime: Arc<InitializedModule>,
104}
105
106impl TestingModule {
107    /**
108     * Returns a reference to the DI container.
109     */
110    pub fn container(&self) -> &Container {
111        &self.container
112    }
113
114    /**
115     * Resolves a service from the container.
116     *
117     * # Type Parameters
118     * - `T`: The type to resolve
119     */
120    pub fn resolve<T>(&self) -> Result<Arc<T>, ContainerError>
121    where
122        T: Send + Sync + 'static,
123    {
124        self.container.resolve::<T>()
125    }
126
127    /**
128     * Returns an HTTP router with all controllers merged.
129     */
130    pub fn http_router(&self) -> Router {
131        let mut app: Router<Container> = Router::new();
132        for controller_router in &self.runtime.controllers {
133            app = app.merge(controller_router.clone());
134        }
135
136        app.with_state(self.container.clone())
137    }
138
139    pub fn graphql_router<Query, Mutation, Subscription>(
140        &self,
141        schema: Schema<Query, Mutation, Subscription>,
142    ) -> Router
143    where
144        Query: ObjectType + Send + Sync + 'static,
145        Mutation: ObjectType + Send + Sync + 'static,
146        Subscription: SubscriptionType + Send + Sync + 'static,
147    {
148        self.http_router()
149            .merge(graphql_router(schema).with_state(self.container.clone()))
150    }
151
152    pub fn graphql_router_with_paths<Query, Mutation, Subscription>(
153        &self,
154        schema: Schema<Query, Mutation, Subscription>,
155        endpoint: impl Into<String>,
156        graphiql_endpoint: Option<String>,
157    ) -> Router
158    where
159        Query: ObjectType + Send + Sync + 'static,
160        Mutation: ObjectType + Send + Sync + 'static,
161        Subscription: SubscriptionType + Send + Sync + 'static,
162    {
163        let config = if let Some(graphiql_endpoint) = graphiql_endpoint {
164            GraphQlConfig::new(endpoint).with_graphiql(graphiql_endpoint)
165        } else {
166            GraphQlConfig::new(endpoint).without_graphiql()
167        };
168
169        self.http_router()
170            .merge(graphql_router_with_config(schema, config).with_state(self.container.clone()))
171    }
172
173    pub fn grpc_context(&self) -> GrpcContext {
174        GrpcContext::new(self.container.clone())
175    }
176
177    pub fn websocket_context(&self) -> WebSocketContext {
178        WebSocketContext::new(self.container.clone(), None, None, HeaderMap::new())
179    }
180
181    pub fn websocket_context_with(
182        &self,
183        request_id: Option<RequestId>,
184        auth_identity: Option<AuthIdentity>,
185        headers: HeaderMap,
186    ) -> WebSocketContext {
187        WebSocketContext::new(self.container.clone(), request_id, auth_identity, headers)
188    }
189
190    pub fn microservice_context(
191        &self,
192        transport: impl Into<String>,
193        pattern: impl Into<String>,
194    ) -> MicroserviceContext {
195        MicroserviceContext::new(
196            self.container.clone(),
197            transport,
198            pattern,
199            TransportMetadata::default(),
200        )
201    }
202
203    pub fn microservice_context_with_metadata(
204        &self,
205        transport: impl Into<String>,
206        pattern: impl Into<String>,
207        metadata: TransportMetadata,
208    ) -> MicroserviceContext {
209        MicroserviceContext::new(self.container.clone(), transport, pattern, metadata)
210    }
211
212    pub fn microservice_client(
213        &self,
214        registry: MicroserviceRegistry,
215    ) -> InProcessMicroserviceClient {
216        InProcessMicroserviceClient::new(self.container.clone(), registry)
217    }
218
219    pub fn microservice_client_with_metadata(
220        &self,
221        registry: MicroserviceRegistry,
222        transport: impl Into<String>,
223        metadata: TransportMetadata,
224    ) -> InProcessMicroserviceClient {
225        InProcessMicroserviceClient::new(self.container.clone(), registry)
226            .with_transport(transport)
227            .with_metadata(metadata)
228    }
229
230    pub fn shutdown(&self) -> Result<()> {
231        self.runtime.run_module_destroy(&self.container)?;
232        self.runtime.run_application_shutdown(&self.container)?;
233        Ok(())
234    }
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc as StdArc, Mutex};

    use async_graphql::{EmptyMutation, EmptySubscription};
    use nestforge_core::{register_provider, ControllerDefinition, LifecycleHook, Provider};
    use nestforge_microservices::MicroserviceClient;
    use tower::ServiceExt;

    /// Value provider whose `app_name` shows which registration won.
    #[derive(Clone, Debug, PartialEq, Eq)]
    struct AppConfig {
        app_name: &'static str,
    }

    /// Module that registers an [`AppConfig`] named `"default"`.
    struct AppModule;
    impl ModuleDefinition for AppModule {
        fn register(container: &Container) -> Result<()> {
            let defaults = AppConfig {
                app_name: "default",
            };
            container.register(defaults)?;
            Ok(())
        }
    }

    #[test]
    fn builds_testing_module_and_resolves_default_provider() {
        let factory = TestFactory::<AppModule>::create();
        let module = factory.build().expect("test module should build");

        let resolved = module
            .resolve::<AppConfig>()
            .expect("config should resolve");
        let expected = AppConfig {
            app_name: "default",
        };
        assert_eq!(*resolved, expected);
    }

    #[test]
    fn overrides_provider_value() {
        let replacement = AppConfig { app_name: "test" };
        let module = TestFactory::<AppModule>::create()
            .override_provider(replacement.clone())
            .build()
            .expect("test module should build with overrides");

        let resolved = module
            .resolve::<AppConfig>()
            .expect("config should resolve");
        assert_eq!(*resolved, replacement);
    }

    /// Service produced by a factory provider that reads [`AppConfig`].
    #[derive(Clone, Debug, PartialEq, Eq)]
    struct GreetingService {
        greeting: String,
    }

    /// Module that registers only a factory for [`GreetingService`]; the
    /// [`AppConfig`] it resolves has to come from somewhere else.
    struct FactoryModule;
    impl ModuleDefinition for FactoryModule {
        fn register(container: &Container) -> Result<()> {
            let greeting_provider = Provider::factory(|scope| {
                let config = scope.resolve::<AppConfig>()?;
                let greeting = format!("hello {}", config.app_name);
                Ok(GreetingService { greeting })
            });
            register_provider(container, greeting_provider)?;
            Ok(())
        }
    }

    #[test]
    fn overrides_are_applied_before_factory_resolution() {
        let factory = TestFactory::<FactoryModule>::create().override_provider(AppConfig {
            app_name: "override",
        });
        let module = factory
            .build()
            .expect("test module should build with transitive overrides");

        let service = module
            .resolve::<GreetingService>()
            .expect("greeting service should resolve");
        assert_eq!(service.greeting, "hello override");
    }

    /// Controller exposing a single `GET /health` route.
    struct HttpController;

    impl ControllerDefinition for HttpController {
        fn router() -> Router<Container> {
            let health_handler = || async { axum::Json(serde_json::json!({ "ok": true })) };
            Router::new().route("/health", axum::routing::get(health_handler))
        }
    }

    /// Module with no providers and one controller.
    struct HttpModule;
    impl ModuleDefinition for HttpModule {
        fn register(_container: &Container) -> Result<()> {
            Ok(())
        }

        fn controllers() -> Vec<Router<Container>> {
            let health = HttpController::router();
            vec![health]
        }
    }

    #[tokio::test]
    async fn builds_http_router_from_testing_module_runtime() {
        let module = TestFactory::<HttpModule>::create()
            .build()
            .expect("http testing module should build");

        let router = module.http_router();
        let request = axum::http::Request::builder()
            .uri("/health")
            .body(axum::body::Body::empty())
            .expect("request should build");

        let response = router
            .oneshot(request)
            .await
            .expect("request should succeed");

        assert_eq!(response.status(), axum::http::StatusCode::OK);
    }

    /// GraphQL query root that reads [`AppConfig`] from the container stored
    /// in the request context.
    struct QueryRoot;

    #[async_graphql::Object]
    impl QueryRoot {
        async fn app_name(&self, ctx: &async_graphql::Context<'_>) -> &str {
            let container = ctx
                .data::<Container>()
                .expect("container should be present");
            let config = container
                .resolve::<AppConfig>()
                .expect("app config should resolve");

            config.app_name
        }
    }

    #[tokio::test]
    async fn builds_graphql_router_from_testing_module_runtime() {
        let module = TestFactory::<AppModule>::create()
            .override_provider(AppConfig {
                app_name: "graphql",
            })
            .build()
            .expect("graphql testing module should build");
        let schema = Schema::build(QueryRoot, EmptyMutation, EmptySubscription).finish();

        let router = module.graphql_router_with_paths(schema, "/graphql", None);
        let payload = serde_json::json!({ "query": "{ appName }" }).to_string();
        let request = axum::http::Request::builder()
            .method("POST")
            .uri("/graphql")
            .header(axum::http::header::CONTENT_TYPE, "application/json")
            .body(axum::body::Body::from(payload))
            .expect("request should build");

        let response = router
            .oneshot(request)
            .await
            .expect("graphql request should succeed");

        assert_eq!(response.status(), axum::http::StatusCode::OK);
    }

    /// Shared log that lifecycle hooks append their names to.
    #[derive(Clone)]
    struct HookLog(StdArc<Mutex<Vec<&'static str>>>);

    /// Appends `entry` to the [`HookLog`] resolved from `container`.
    fn append_hook_entry(container: &Container, entry: &'static str) -> Result<()> {
        let log = container.resolve::<HookLog>()?;
        let mut entries = log.0.lock().expect("hook log should be writable");
        entries.push(entry);
        Ok(())
    }

    fn record_destroy(container: &Container) -> Result<()> {
        append_hook_entry(container, "destroy")
    }

    fn record_shutdown(container: &Container) -> Result<()> {
        append_hook_entry(container, "shutdown")
    }

    /// Module that registers a [`HookLog`] and records its teardown hooks.
    struct HookModule;

    impl ModuleDefinition for HookModule {
        fn register(container: &Container) -> Result<()> {
            let empty_log = HookLog(StdArc::new(Mutex::new(Vec::new())));
            container.register(empty_log)?;
            Ok(())
        }

        fn on_module_destroy() -> Vec<LifecycleHook> {
            vec![record_destroy]
        }

        fn on_application_shutdown() -> Vec<LifecycleHook> {
            vec![record_shutdown]
        }
    }

    #[test]
    fn shutdown_runs_destroy_and_shutdown_hooks() {
        let module = TestFactory::<HookModule>::create()
            .build()
            .expect("hook testing module should build");

        module.shutdown().expect("testing module should shut down");

        let log = module
            .resolve::<HookLog>()
            .expect("hook log should resolve");
        let entries = log.0.lock().expect("hook log should be readable").clone();
        assert_eq!(entries, ["destroy", "shutdown"]);
    }

    #[test]
    fn builds_grpc_context_from_testing_module_runtime() {
        let module = TestFactory::<AppModule>::create()
            .override_provider(AppConfig { app_name: "grpc" })
            .build()
            .expect("grpc testing module should build");

        let resolved = module
            .grpc_context()
            .resolve::<AppConfig>()
            .expect("app config should resolve");

        assert_eq!(resolved.app_name, "grpc");
    }

    #[test]
    fn builds_transport_contexts_with_testing_container() {
        let module = TestFactory::<AppModule>::create()
            .override_provider(AppConfig {
                app_name: "transport",
            })
            .build()
            .expect("transport testing module should build");

        // WebSocket context resolves against the testing container.
        let ws_ctx = module.websocket_context();
        let ws_config = ws_ctx
            .resolve::<AppConfig>()
            .expect("app config should resolve from websocket context");
        assert_eq!(ws_config.app_name, "transport");

        // Microservice context does the same and keeps its transport/pattern.
        let ms_ctx = module.microservice_context("test", "users.count");
        let ms_config = ms_ctx
            .resolve::<AppConfig>()
            .expect("app config should resolve from microservice context");
        assert_eq!(ms_config.app_name, "transport");
        assert_eq!(ms_ctx.transport(), "test");
        assert_eq!(ms_ctx.pattern(), "users.count");
    }

    #[tokio::test]
    async fn builds_in_process_microservice_clients_from_testing_module() {
        let module = TestFactory::<AppModule>::create()
            .override_provider(AppConfig { app_name: "client" })
            .build()
            .expect("client testing module should build");
        let registry = MicroserviceRegistry::builder()
            .message("app.name", |_payload: (), ctx| async move {
                let config = ctx.resolve::<AppConfig>()?;
                Ok(String::from(config.app_name))
            })
            .build();
        let metadata = TransportMetadata::new().insert("suite", "testing");
        let client = module.microservice_client_with_metadata(registry, "test-client", metadata);

        let reply: String = client
            .send("app.name", ())
            .await
            .expect("response should resolve");

        assert_eq!(reply, "client");
    }
}