Skip to main content

nestforge_testing/
lib.rs

1use std::{marker::PhantomData, sync::Arc};
2
3use anyhow::Result;
4use async_graphql::{ObjectType, Schema, SubscriptionType};
5use axum::{http::HeaderMap, Router};
6use nestforge_core::{
7    initialize_module_runtime, AuthIdentity, Container, ContainerError, InitializedModule,
8    ModuleDefinition, RequestId,
9};
10use nestforge_graphql::{graphql_router, graphql_router_with_config, GraphQlConfig};
11use nestforge_grpc::GrpcContext;
12use nestforge_microservices::{
13    InProcessMicroserviceClient, MicroserviceContext, MicroserviceRegistry, TransportMetadata,
14};
15use nestforge_websockets::WebSocketContext;
16
17type OverrideFn = Box<dyn Fn(&Container) -> Result<()> + Send + Sync>;
18
19pub struct TestFactory<M: ModuleDefinition> {
20    overrides: Vec<OverrideFn>,
21    _marker: PhantomData<M>,
22}
23
24impl<M: ModuleDefinition> TestFactory<M> {
25    pub fn create() -> Self {
26        Self {
27            overrides: Vec::new(),
28            _marker: PhantomData,
29        }
30    }
31
32    pub fn override_provider<T>(mut self, value: T) -> Self
33    where
34        T: Send + Sync + Clone + 'static,
35    {
36        self.overrides.push(Box::new(move |container| {
37            container.override_value(value.clone())?;
38            Ok(())
39        }));
40        self
41    }
42
43    pub fn build(self) -> Result<TestingModule> {
44        let container = Container::new();
45
46        for override_fn in self.overrides {
47            override_fn(&container)?;
48        }
49
50        let runtime = initialize_module_runtime::<M>(&container)?;
51        runtime.run_module_init(&container)?;
52        runtime.run_application_bootstrap(&container)?;
53
54        Ok(TestingModule {
55            container,
56            runtime: Arc::new(runtime),
57        })
58    }
59}
60
61#[derive(Clone)]
62pub struct TestingModule {
63    container: Container,
64    runtime: Arc<InitializedModule>,
65}
66
67impl TestingModule {
68    pub fn container(&self) -> &Container {
69        &self.container
70    }
71
72    pub fn resolve<T>(&self) -> Result<Arc<T>, ContainerError>
73    where
74        T: Send + Sync + 'static,
75    {
76        self.container.resolve::<T>()
77    }
78
79    pub fn http_router(&self) -> Router {
80        let mut app: Router<Container> = Router::new();
81        for controller_router in &self.runtime.controllers {
82            app = app.merge(controller_router.clone());
83        }
84
85        app.with_state(self.container.clone())
86    }
87
88    pub fn graphql_router<Query, Mutation, Subscription>(
89        &self,
90        schema: Schema<Query, Mutation, Subscription>,
91    ) -> Router
92    where
93        Query: ObjectType + Send + Sync + 'static,
94        Mutation: ObjectType + Send + Sync + 'static,
95        Subscription: SubscriptionType + Send + Sync + 'static,
96    {
97        self.http_router()
98            .merge(graphql_router(schema).with_state(self.container.clone()))
99    }
100
101    pub fn graphql_router_with_paths<Query, Mutation, Subscription>(
102        &self,
103        schema: Schema<Query, Mutation, Subscription>,
104        endpoint: impl Into<String>,
105        graphiql_endpoint: Option<String>,
106    ) -> Router
107    where
108        Query: ObjectType + Send + Sync + 'static,
109        Mutation: ObjectType + Send + Sync + 'static,
110        Subscription: SubscriptionType + Send + Sync + 'static,
111    {
112        let config = if let Some(graphiql_endpoint) = graphiql_endpoint {
113            GraphQlConfig::new(endpoint).with_graphiql(graphiql_endpoint)
114        } else {
115            GraphQlConfig::new(endpoint).without_graphiql()
116        };
117
118        self.http_router()
119            .merge(graphql_router_with_config(schema, config).with_state(self.container.clone()))
120    }
121
122    pub fn grpc_context(&self) -> GrpcContext {
123        GrpcContext::new(self.container.clone())
124    }
125
126    pub fn websocket_context(&self) -> WebSocketContext {
127        WebSocketContext::new(self.container.clone(), None, None, HeaderMap::new())
128    }
129
130    pub fn websocket_context_with(
131        &self,
132        request_id: Option<RequestId>,
133        auth_identity: Option<AuthIdentity>,
134        headers: HeaderMap,
135    ) -> WebSocketContext {
136        WebSocketContext::new(self.container.clone(), request_id, auth_identity, headers)
137    }
138
139    pub fn microservice_context(
140        &self,
141        transport: impl Into<String>,
142        pattern: impl Into<String>,
143    ) -> MicroserviceContext {
144        MicroserviceContext::new(
145            self.container.clone(),
146            transport,
147            pattern,
148            TransportMetadata::default(),
149        )
150    }
151
152    pub fn microservice_context_with_metadata(
153        &self,
154        transport: impl Into<String>,
155        pattern: impl Into<String>,
156        metadata: TransportMetadata,
157    ) -> MicroserviceContext {
158        MicroserviceContext::new(self.container.clone(), transport, pattern, metadata)
159    }
160
161    pub fn microservice_client(
162        &self,
163        registry: MicroserviceRegistry,
164    ) -> InProcessMicroserviceClient {
165        InProcessMicroserviceClient::new(self.container.clone(), registry)
166    }
167
168    pub fn microservice_client_with_metadata(
169        &self,
170        registry: MicroserviceRegistry,
171        transport: impl Into<String>,
172        metadata: TransportMetadata,
173    ) -> InProcessMicroserviceClient {
174        InProcessMicroserviceClient::new(self.container.clone(), registry)
175            .with_transport(transport)
176            .with_metadata(metadata)
177    }
178
179    pub fn shutdown(&self) -> Result<()> {
180        self.runtime.run_module_destroy(&self.container)?;
181        self.runtime.run_application_shutdown(&self.container)?;
182        Ok(())
183    }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc as StdArc, Mutex};

    use async_graphql::{EmptyMutation, EmptySubscription};
    use nestforge_core::{register_provider, ControllerDefinition, LifecycleHook, Provider};
    use nestforge_microservices::MicroserviceClient;
    use tower::ServiceExt;

    #[derive(Clone, Debug, PartialEq, Eq)]
    struct AppConfig {
        app_name: &'static str,
    }

    struct AppModule;
    impl ModuleDefinition for AppModule {
        fn register(container: &Container) -> Result<()> {
            container.register(AppConfig {
                app_name: "default",
            })?;
            Ok(())
        }
    }

    #[test]
    fn builds_testing_module_and_resolves_default_provider() {
        let module = TestFactory::<AppModule>::create()
            .build()
            .expect("test module should build");

        let config = module
            .resolve::<AppConfig>()
            .expect("config should resolve");
        assert_eq!(
            *config,
            AppConfig {
                app_name: "default"
            }
        );
    }

    #[test]
    fn overrides_provider_value() {
        let module = TestFactory::<AppModule>::create()
            .override_provider(AppConfig { app_name: "test" })
            .build()
            .expect("test module should build with overrides");

        let config = module
            .resolve::<AppConfig>()
            .expect("config should resolve");
        assert_eq!(*config, AppConfig { app_name: "test" });
    }

    #[derive(Clone, Debug, PartialEq, Eq)]
    struct GreetingService {
        greeting: String,
    }

    struct FactoryModule;
    impl ModuleDefinition for FactoryModule {
        fn register(container: &Container) -> Result<()> {
            register_provider(
                container,
                Provider::factory(|container| {
                    let config = container.resolve::<AppConfig>()?;
                    Ok(GreetingService {
                        greeting: format!("hello {}", config.app_name),
                    })
                }),
            )?;
            Ok(())
        }
    }

    #[test]
    fn overrides_are_applied_before_factory_resolution() {
        let module = TestFactory::<FactoryModule>::create()
            .override_provider(AppConfig {
                app_name: "override",
            })
            .build()
            .expect("test module should build with transitive overrides");

        let greeting = module
            .resolve::<GreetingService>()
            .expect("greeting service should resolve");
        assert_eq!(greeting.greeting, "hello override");
    }

    struct HttpController;

    impl ControllerDefinition for HttpController {
        fn router() -> Router<Container> {
            Router::new().route(
                "/health",
                axum::routing::get(|| async { axum::Json(serde_json::json!({ "ok": true })) }),
            )
        }
    }

    struct HttpModule;
    impl ModuleDefinition for HttpModule {
        fn register(_container: &Container) -> Result<()> {
            Ok(())
        }

        fn controllers() -> Vec<Router<Container>> {
            vec![HttpController::router()]
        }
    }

    #[tokio::test]
    async fn builds_http_router_from_testing_module_runtime() {
        let module = TestFactory::<HttpModule>::create()
            .build()
            .expect("http testing module should build");

        let response = module
            .http_router()
            .oneshot(
                axum::http::Request::builder()
                    .uri("/health")
                    .body(axum::body::Body::empty())
                    .expect("request should build"),
            )
            .await
            .expect("request should succeed");

        assert_eq!(response.status(), axum::http::StatusCode::OK);
    }

    struct QueryRoot;

    #[async_graphql::Object]
    impl QueryRoot {
        async fn app_name(&self, ctx: &async_graphql::Context<'_>) -> &str {
            let config = ctx
                .data::<Container>()
                .expect("container should be present")
                .resolve::<AppConfig>()
                .expect("app config should resolve");

            config.app_name
        }
    }

    #[tokio::test]
    async fn builds_graphql_router_from_testing_module_runtime() {
        let module = TestFactory::<AppModule>::create()
            .override_provider(AppConfig { app_name: "graphql" })
            .build()
            .expect("graphql testing module should build");
        let schema = Schema::build(QueryRoot, EmptyMutation, EmptySubscription).finish();

        let response = module
            .graphql_router_with_paths(schema, "/graphql", None)
            .oneshot(
                axum::http::Request::builder()
                    .method("POST")
                    .uri("/graphql")
                    .header(axum::http::header::CONTENT_TYPE, "application/json")
                    .body(axum::body::Body::from(
                        serde_json::json!({ "query": "{ appName }" }).to_string(),
                    ))
                    .expect("request should build"),
            )
            .await
            .expect("graphql request should succeed");

        assert_eq!(response.status(), axum::http::StatusCode::OK);

        // GraphQL servers report resolver errors inside a 200 response, so the
        // status alone cannot show that the overridden provider reached the
        // resolver; inspect the payload as well.
        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("graphql response body should be readable");
        let payload: serde_json::Value =
            serde_json::from_slice(&body).expect("graphql response should be JSON");
        assert_eq!(
            payload["data"]["appName"], "graphql",
            "unexpected graphql payload: {payload}"
        );
    }

    #[derive(Clone)]
    struct HookLog(StdArc<Mutex<Vec<&'static str>>>);

    fn record_destroy(container: &Container) -> Result<()> {
        let log = container.resolve::<HookLog>()?;
        log.0
            .lock()
            .expect("hook log should be writable")
            .push("destroy");
        Ok(())
    }

    fn record_shutdown(container: &Container) -> Result<()> {
        let log = container.resolve::<HookLog>()?;
        log.0
            .lock()
            .expect("hook log should be writable")
            .push("shutdown");
        Ok(())
    }

    struct HookModule;

    impl ModuleDefinition for HookModule {
        fn register(container: &Container) -> Result<()> {
            container.register(HookLog(StdArc::new(Mutex::new(Vec::new()))))?;
            Ok(())
        }

        fn on_module_destroy() -> Vec<LifecycleHook> {
            vec![record_destroy]
        }

        fn on_application_shutdown() -> Vec<LifecycleHook> {
            vec![record_shutdown]
        }
    }

    #[test]
    fn shutdown_runs_destroy_and_shutdown_hooks() {
        let module = TestFactory::<HookModule>::create()
            .build()
            .expect("hook testing module should build");

        module.shutdown().expect("testing module should shut down");

        let log = module.resolve::<HookLog>().expect("hook log should resolve");
        let entries = log.0.lock().expect("hook log should be readable").clone();
        assert_eq!(entries, vec!["destroy", "shutdown"]);
    }

    #[test]
    fn builds_grpc_context_from_testing_module_runtime() {
        let module = TestFactory::<AppModule>::create()
            .override_provider(AppConfig { app_name: "grpc" })
            .build()
            .expect("grpc testing module should build");

        let ctx = module.grpc_context();
        let config = ctx.resolve::<AppConfig>().expect("app config should resolve");

        assert_eq!(config.app_name, "grpc");
    }

    #[test]
    fn builds_transport_contexts_with_testing_container() {
        let module = TestFactory::<AppModule>::create()
            .override_provider(AppConfig {
                app_name: "transport",
            })
            .build()
            .expect("transport testing module should build");

        let websocket = module.websocket_context();
        let config = websocket
            .resolve::<AppConfig>()
            .expect("app config should resolve from websocket context");
        assert_eq!(config.app_name, "transport");

        let microservice = module.microservice_context("test", "users.count");
        let config = microservice
            .resolve::<AppConfig>()
            .expect("app config should resolve from microservice context");
        assert_eq!(config.app_name, "transport");
        assert_eq!(microservice.transport(), "test");
        assert_eq!(microservice.pattern(), "users.count");
    }

    #[tokio::test]
    async fn builds_in_process_microservice_clients_from_testing_module() {
        let module = TestFactory::<AppModule>::create()
            .override_provider(AppConfig { app_name: "client" })
            .build()
            .expect("client testing module should build");
        let registry = MicroserviceRegistry::builder()
            .message("app.name", |_payload: (), ctx| async move {
                let config = ctx.resolve::<AppConfig>()?;
                Ok(config.app_name.to_string())
            })
            .build();
        let client = module.microservice_client_with_metadata(
            registry,
            "test-client",
            TransportMetadata::new().insert("suite", "testing"),
        );

        let response: String = client
            .send("app.name", ())
            .await
            .expect("response should resolve");

        assert_eq!(response, "client");
    }
}
466            .expect("response should resolve");
467
468        assert_eq!(response, "client");
469    }
470}