1use std::{marker::PhantomData, sync::Arc};
2
3use anyhow::Result;
4use async_graphql::{ObjectType, Schema, SubscriptionType};
5use axum::{http::HeaderMap, Router};
6use nestforge_core::{
7 initialize_module_runtime, AuthIdentity, Container, ContainerError, InitializedModule,
8 ModuleDefinition, RequestId,
9};
10use nestforge_graphql::{graphql_router, graphql_router_with_config, GraphQlConfig};
11use nestforge_grpc::GrpcContext;
12use nestforge_microservices::{
13 InProcessMicroserviceClient, MicroserviceContext, MicroserviceRegistry, TransportMetadata,
14};
15use nestforge_websockets::WebSocketContext;
16
17type OverrideFn = Box<dyn Fn(&Container) -> Result<()> + Send + Sync>;
18
19pub struct TestFactory<M: ModuleDefinition> {
38 overrides: Vec<OverrideFn>,
39 _marker: PhantomData<M>,
40}
41
42impl<M: ModuleDefinition> TestFactory<M> {
43 pub fn create() -> Self {
47 Self {
48 overrides: Vec::new(),
49 _marker: PhantomData,
50 }
51 }
52
53 pub fn override_provider<T>(mut self, value: T) -> Self
63 where
64 T: Send + Sync + Clone + 'static,
65 {
66 self.overrides.push(Box::new(move |container| {
67 container.override_value(value.clone())?;
68 Ok(())
69 }));
70 self
71 }
72
73 pub fn build(self) -> Result<TestingModule> {
77 let container = Container::new();
78
79 for override_fn in self.overrides {
80 override_fn(&container)?;
81 }
82
83 let runtime = initialize_module_runtime::<M>(&container)?;
84 runtime.run_module_init(&container)?;
85 runtime.run_application_bootstrap(&container)?;
86
87 Ok(TestingModule {
88 container,
89 runtime: Arc::new(runtime),
90 })
91 }
92}
93
94#[derive(Clone)]
101pub struct TestingModule {
102 container: Container,
103 runtime: Arc<InitializedModule>,
104}
105
106impl TestingModule {
107 pub fn container(&self) -> &Container {
111 &self.container
112 }
113
114 pub fn resolve<T>(&self) -> Result<Arc<T>, ContainerError>
121 where
122 T: Send + Sync + 'static,
123 {
124 self.container.resolve::<T>()
125 }
126
127 pub fn http_router(&self) -> Router {
131 let mut app: Router<Container> = Router::new();
132 for controller_router in &self.runtime.controllers {
133 app = app.merge(controller_router.clone());
134 }
135
136 app.with_state(self.container.clone())
137 }
138
139 pub fn graphql_router<Query, Mutation, Subscription>(
140 &self,
141 schema: Schema<Query, Mutation, Subscription>,
142 ) -> Router
143 where
144 Query: ObjectType + Send + Sync + 'static,
145 Mutation: ObjectType + Send + Sync + 'static,
146 Subscription: SubscriptionType + Send + Sync + 'static,
147 {
148 self.http_router()
149 .merge(graphql_router(schema).with_state(self.container.clone()))
150 }
151
152 pub fn graphql_router_with_paths<Query, Mutation, Subscription>(
153 &self,
154 schema: Schema<Query, Mutation, Subscription>,
155 endpoint: impl Into<String>,
156 graphiql_endpoint: Option<String>,
157 ) -> Router
158 where
159 Query: ObjectType + Send + Sync + 'static,
160 Mutation: ObjectType + Send + Sync + 'static,
161 Subscription: SubscriptionType + Send + Sync + 'static,
162 {
163 let config = if let Some(graphiql_endpoint) = graphiql_endpoint {
164 GraphQlConfig::new(endpoint).with_graphiql(graphiql_endpoint)
165 } else {
166 GraphQlConfig::new(endpoint).without_graphiql()
167 };
168
169 self.http_router()
170 .merge(graphql_router_with_config(schema, config).with_state(self.container.clone()))
171 }
172
173 pub fn grpc_context(&self) -> GrpcContext {
174 GrpcContext::new(self.container.clone())
175 }
176
177 pub fn websocket_context(&self) -> WebSocketContext {
178 WebSocketContext::new(self.container.clone(), None, None, HeaderMap::new())
179 }
180
181 pub fn websocket_context_with(
182 &self,
183 request_id: Option<RequestId>,
184 auth_identity: Option<AuthIdentity>,
185 headers: HeaderMap,
186 ) -> WebSocketContext {
187 WebSocketContext::new(self.container.clone(), request_id, auth_identity, headers)
188 }
189
190 pub fn microservice_context(
191 &self,
192 transport: impl Into<String>,
193 pattern: impl Into<String>,
194 ) -> MicroserviceContext {
195 MicroserviceContext::new(
196 self.container.clone(),
197 transport,
198 pattern,
199 TransportMetadata::default(),
200 )
201 }
202
203 pub fn microservice_context_with_metadata(
204 &self,
205 transport: impl Into<String>,
206 pattern: impl Into<String>,
207 metadata: TransportMetadata,
208 ) -> MicroserviceContext {
209 MicroserviceContext::new(self.container.clone(), transport, pattern, metadata)
210 }
211
212 pub fn microservice_client(
213 &self,
214 registry: MicroserviceRegistry,
215 ) -> InProcessMicroserviceClient {
216 InProcessMicroserviceClient::new(self.container.clone(), registry)
217 }
218
219 pub fn microservice_client_with_metadata(
220 &self,
221 registry: MicroserviceRegistry,
222 transport: impl Into<String>,
223 metadata: TransportMetadata,
224 ) -> InProcessMicroserviceClient {
225 InProcessMicroserviceClient::new(self.container.clone(), registry)
226 .with_transport(transport)
227 .with_metadata(metadata)
228 }
229
230 pub fn shutdown(&self) -> Result<()> {
231 self.runtime.run_module_destroy(&self.container)?;
232 self.runtime.run_application_shutdown(&self.container)?;
233 Ok(())
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc as StdArc, Mutex};

    use async_graphql::{EmptyMutation, EmptySubscription};
    use nestforge_core::{register_provider, ControllerDefinition, LifecycleHook, Provider};
    use nestforge_microservices::MicroserviceClient;
    use tower::ServiceExt;

    /// Simple value provider; its `app_name` shows which registration (default or override) won.
    #[derive(Clone, Debug, PartialEq, Eq)]
    struct AppConfig {
        app_name: &'static str,
    }

    /// Module that registers a default `AppConfig` with `app_name: "default"`.
    struct AppModule;
    impl ModuleDefinition for AppModule {
        fn register(container: &Container) -> Result<()> {
            container.register(AppConfig {
                app_name: "default",
            })?;
            Ok(())
        }
    }

    // Without overrides, the module's own registration is what resolves.
    #[test]
    fn builds_testing_module_and_resolves_default_provider() {
        let module = TestFactory::<AppModule>::create()
            .build()
            .expect("test module should build");

        let config = module
            .resolve::<AppConfig>()
            .expect("config should resolve");
        assert_eq!(
            *config,
            AppConfig {
                app_name: "default"
            }
        );
    }

    // A queued override must take precedence over the value `AppModule::register` provides.
    #[test]
    fn overrides_provider_value() {
        let module = TestFactory::<AppModule>::create()
            .override_provider(AppConfig { app_name: "test" })
            .build()
            .expect("test module should build with overrides");

        let config = module
            .resolve::<AppConfig>()
            .expect("config should resolve");
        assert_eq!(*config, AppConfig { app_name: "test" });
    }

    /// Service built by a factory provider from whichever `AppConfig` is registered.
    #[derive(Clone, Debug, PartialEq, Eq)]
    struct GreetingService {
        greeting: String,
    }

    /// Module that registers only a `GreetingService` factory. It does not register
    /// `AppConfig` itself, so the factory depends on a value supplied from outside the module.
    struct FactoryModule;
    impl ModuleDefinition for FactoryModule {
        fn register(container: &Container) -> Result<()> {
            register_provider(
                container,
                Provider::factory(|container| {
                    let config = container.resolve::<AppConfig>()?;
                    Ok(GreetingService {
                        greeting: format!("hello {}", config.app_name),
                    })
                }),
            )?;
            Ok(())
        }
    }

    // `FactoryModule` never registers `AppConfig`. The factory can only produce
    // "hello override" if the override is in the container by the time the factory runs.
    #[test]
    fn overrides_are_applied_before_factory_resolution() {
        let module = TestFactory::<FactoryModule>::create()
            .override_provider(AppConfig {
                app_name: "override",
            })
            .build()
            .expect("test module should build with transitive overrides");

        let greeting = module
            .resolve::<GreetingService>()
            .expect("greeting service should resolve");
        assert_eq!(greeting.greeting, "hello override");
    }

    /// Controller exposing a single `GET /health` route that returns `{"ok": true}`.
    struct HttpController;

    impl ControllerDefinition for HttpController {
        fn router() -> Router<Container> {
            Router::new().route(
                "/health",
                axum::routing::get(|| async { axum::Json(serde_json::json!({ "ok": true })) }),
            )
        }
    }

    /// Module with no providers whose only controller is `HttpController`.
    struct HttpModule;
    impl ModuleDefinition for HttpModule {
        fn register(_container: &Container) -> Result<()> {
            Ok(())
        }

        fn controllers() -> Vec<Router<Container>> {
            vec![HttpController::router()]
        }
    }

    // Drives the router in-process with `oneshot`; no listener is started.
    #[tokio::test]
    async fn builds_http_router_from_testing_module_runtime() {
        let module = TestFactory::<HttpModule>::create()
            .build()
            .expect("http testing module should build");

        let response = module
            .http_router()
            .oneshot(
                axum::http::Request::builder()
                    .uri("/health")
                    .body(axum::body::Body::empty())
                    .expect("request should build"),
            )
            .await
            .expect("request should succeed");

        assert_eq!(response.status(), axum::http::StatusCode::OK);
    }

    /// GraphQL query root. `appName` reads `AppConfig` from a `Container` that it expects in
    /// the GraphQL context data, and panics if the container is absent.
    struct QueryRoot;

    #[async_graphql::Object]
    impl QueryRoot {
        async fn app_name(&self, ctx: &async_graphql::Context<'_>) -> &str {
            let config = ctx
                .data::<Container>()
                .expect("container should be present")
                .resolve::<AppConfig>()
                .expect("app config should resolve");

            config.app_name
        }
    }

    // Posts `{ appName }` to the custom `/graphql` endpoint, with GraphiQL disabled.
    // NOTE(review): only the HTTP status is asserted. GraphQL servers commonly answer 200 even
    // when a resolver fails, so consider also asserting that the body contains "graphql".
    #[tokio::test]
    async fn builds_graphql_router_from_testing_module_runtime() {
        let module = TestFactory::<AppModule>::create()
            .override_provider(AppConfig {
                app_name: "graphql",
            })
            .build()
            .expect("graphql testing module should build");
        let schema = Schema::build(QueryRoot, EmptyMutation, EmptySubscription).finish();

        let response = module
            .graphql_router_with_paths(schema, "/graphql", None)
            .oneshot(
                axum::http::Request::builder()
                    .method("POST")
                    .uri("/graphql")
                    .header(axum::http::header::CONTENT_TYPE, "application/json")
                    .body(axum::body::Body::from(
                        serde_json::json!({ "query": "{ appName }" }).to_string(),
                    ))
                    .expect("request should build"),
            )
            .await
            .expect("graphql request should succeed");

        assert_eq!(response.status(), axum::http::StatusCode::OK);
    }

    /// Shared log that lifecycle hooks append to, used to check which hooks ran and in what order.
    #[derive(Clone)]
    struct HookLog(StdArc<Mutex<Vec<&'static str>>>);

    /// Module-destroy hook: appends "destroy" to the registered `HookLog`.
    fn record_destroy(container: &Container) -> Result<()> {
        let log = container.resolve::<HookLog>()?;
        log.0
            .lock()
            .expect("hook log should be writable")
            .push("destroy");
        Ok(())
    }

    /// Application-shutdown hook: appends "shutdown" to the registered `HookLog`.
    fn record_shutdown(container: &Container) -> Result<()> {
        let log = container.resolve::<HookLog>()?;
        log.0
            .lock()
            .expect("hook log should be writable")
            .push("shutdown");
        Ok(())
    }

    /// Module that registers an empty `HookLog` and wires `record_destroy` and
    /// `record_shutdown` as its module-destroy and application-shutdown hooks.
    struct HookModule;

    impl ModuleDefinition for HookModule {
        fn register(container: &Container) -> Result<()> {
            container.register(HookLog(StdArc::new(Mutex::new(Vec::new()))))?;
            Ok(())
        }

        fn on_module_destroy() -> Vec<LifecycleHook> {
            vec![record_destroy]
        }

        fn on_application_shutdown() -> Vec<LifecycleHook> {
            vec![record_shutdown]
        }
    }

    // `shutdown` must run the destroy hooks before the shutdown hooks, each exactly once.
    #[test]
    fn shutdown_runs_destroy_and_shutdown_hooks() {
        let module = TestFactory::<HookModule>::create()
            .build()
            .expect("hook testing module should build");

        module.shutdown().expect("testing module should shut down");

        let log = module
            .resolve::<HookLog>()
            .expect("hook log should resolve");
        let entries = log.0.lock().expect("hook log should be readable").clone();
        assert_eq!(entries, vec!["destroy", "shutdown"]);
    }

    // The gRPC context must resolve from the same container, so the override is visible.
    #[test]
    fn builds_grpc_context_from_testing_module_runtime() {
        let module = TestFactory::<AppModule>::create()
            .override_provider(AppConfig { app_name: "grpc" })
            .build()
            .expect("grpc testing module should build");

        let ctx = module.grpc_context();
        let config = ctx
            .resolve::<AppConfig>()
            .expect("app config should resolve");

        assert_eq!(config.app_name, "grpc");
    }

    // WebSocket and microservice contexts must share the testing container, and the
    // microservice context must keep the transport and pattern it was created with.
    #[test]
    fn builds_transport_contexts_with_testing_container() {
        let module = TestFactory::<AppModule>::create()
            .override_provider(AppConfig {
                app_name: "transport",
            })
            .build()
            .expect("transport testing module should build");

        let websocket = module.websocket_context();
        let config = websocket
            .resolve::<AppConfig>()
            .expect("app config should resolve from websocket context");
        assert_eq!(config.app_name, "transport");

        let microservice = module.microservice_context("test", "users.count");
        let config = microservice
            .resolve::<AppConfig>()
            .expect("app config should resolve from microservice context");
        assert_eq!(config.app_name, "transport");
        assert_eq!(microservice.transport(), "test");
        assert_eq!(microservice.pattern(), "users.count");
    }

    // The handler resolves `AppConfig` through the context the client builds, so the
    // "client" reply shows that messages are dispatched against the testing container.
    #[tokio::test]
    async fn builds_in_process_microservice_clients_from_testing_module() {
        let module = TestFactory::<AppModule>::create()
            .override_provider(AppConfig { app_name: "client" })
            .build()
            .expect("client testing module should build");
        let registry = MicroserviceRegistry::builder()
            .message("app.name", |_payload: (), ctx| async move {
                let config = ctx.resolve::<AppConfig>()?;
                Ok(config.app_name.to_string())
            })
            .build();
        let client = module.microservice_client_with_metadata(
            registry,
            "test-client",
            TransportMetadata::new().insert("suite", "testing"),
        );

        let response: String = client
            .send("app.name", ())
            .await
            .expect("response should resolve");

        assert_eq!(response, "client");
    }
}