1use std::{marker::PhantomData, sync::Arc};
2
3use anyhow::Result;
4use async_graphql::{ObjectType, Schema, SubscriptionType};
5use axum::{http::HeaderMap, Router};
6use nestforge_core::{
7 initialize_module_runtime, AuthIdentity, Container, ContainerError, InitializedModule,
8 ModuleDefinition, RequestId,
9};
10use nestforge_graphql::{graphql_router, graphql_router_with_config, GraphQlConfig};
11use nestforge_grpc::GrpcContext;
12use nestforge_microservices::{
13 InProcessMicroserviceClient, MicroserviceContext, MicroserviceRegistry, TransportMetadata,
14};
15use nestforge_websockets::WebSocketContext;
16
/// A queued provider override. [`TestFactory::build`] invokes each one against the
/// freshly created test container before the module runtime is initialized.
type OverrideFn = Box<dyn Fn(&Container) -> Result<()> + Send + Sync>;

/// Builder for a [`TestingModule`] rooted at module `M`.
///
/// Typical use: [`TestFactory::create`], zero or more
/// [`TestFactory::override_provider`] calls, then [`TestFactory::build`].
pub struct TestFactory<M: ModuleDefinition> {
    /// Queued overrides, kept in insertion order; `build` invokes each one once.
    overrides: Vec<OverrideFn>,
    /// Ties the factory to the root module type `M` without storing a value of it.
    _marker: PhantomData<M>,
}
23
24impl<M: ModuleDefinition> TestFactory<M> {
25 pub fn create() -> Self {
26 Self {
27 overrides: Vec::new(),
28 _marker: PhantomData,
29 }
30 }
31
32 pub fn override_provider<T>(mut self, value: T) -> Self
33 where
34 T: Send + Sync + Clone + 'static,
35 {
36 self.overrides.push(Box::new(move |container| {
37 container.override_value(value.clone())?;
38 Ok(())
39 }));
40 self
41 }
42
43 pub fn build(self) -> Result<TestingModule> {
44 let container = Container::new();
45
46 for override_fn in self.overrides {
47 override_fn(&container)?;
48 }
49
50 let runtime = initialize_module_runtime::<M>(&container)?;
51 runtime.run_module_init(&container)?;
52 runtime.run_application_bootstrap(&container)?;
53
54 Ok(TestingModule {
55 container,
56 runtime: Arc::new(runtime),
57 })
58 }
59}
60
/// A module initialized by [`TestFactory::build`].
///
/// It exposes the module's container and helpers to build HTTP/GraphQL routers and
/// gRPC, WebSocket and microservice contexts for tests.
///
/// Clones share the same `InitializedModule` through an `Arc`, and each clone holds its
/// own `Container` clone. NOTE(review): whether clones share provider state depends on
/// `Container`'s `Clone` impl, which is not visible here — confirm.
#[derive(Clone)]
pub struct TestingModule {
    /// Container the module was registered into, with any overrides already applied.
    container: Container,
    /// Runtime returned by `initialize_module_runtime`. It provides the controller
    /// routers and runs the destroy/shutdown lifecycle hooks.
    runtime: Arc<InitializedModule>,
}
66
67impl TestingModule {
68 pub fn container(&self) -> &Container {
69 &self.container
70 }
71
72 pub fn resolve<T>(&self) -> Result<Arc<T>, ContainerError>
73 where
74 T: Send + Sync + 'static,
75 {
76 self.container.resolve::<T>()
77 }
78
79 pub fn http_router(&self) -> Router {
80 let mut app: Router<Container> = Router::new();
81 for controller_router in &self.runtime.controllers {
82 app = app.merge(controller_router.clone());
83 }
84
85 app.with_state(self.container.clone())
86 }
87
88 pub fn graphql_router<Query, Mutation, Subscription>(
89 &self,
90 schema: Schema<Query, Mutation, Subscription>,
91 ) -> Router
92 where
93 Query: ObjectType + Send + Sync + 'static,
94 Mutation: ObjectType + Send + Sync + 'static,
95 Subscription: SubscriptionType + Send + Sync + 'static,
96 {
97 self.http_router()
98 .merge(graphql_router(schema).with_state(self.container.clone()))
99 }
100
101 pub fn graphql_router_with_paths<Query, Mutation, Subscription>(
102 &self,
103 schema: Schema<Query, Mutation, Subscription>,
104 endpoint: impl Into<String>,
105 graphiql_endpoint: Option<String>,
106 ) -> Router
107 where
108 Query: ObjectType + Send + Sync + 'static,
109 Mutation: ObjectType + Send + Sync + 'static,
110 Subscription: SubscriptionType + Send + Sync + 'static,
111 {
112 let config = if let Some(graphiql_endpoint) = graphiql_endpoint {
113 GraphQlConfig::new(endpoint).with_graphiql(graphiql_endpoint)
114 } else {
115 GraphQlConfig::new(endpoint).without_graphiql()
116 };
117
118 self.http_router()
119 .merge(graphql_router_with_config(schema, config).with_state(self.container.clone()))
120 }
121
122 pub fn grpc_context(&self) -> GrpcContext {
123 GrpcContext::new(self.container.clone())
124 }
125
126 pub fn websocket_context(&self) -> WebSocketContext {
127 WebSocketContext::new(self.container.clone(), None, None, HeaderMap::new())
128 }
129
130 pub fn websocket_context_with(
131 &self,
132 request_id: Option<RequestId>,
133 auth_identity: Option<AuthIdentity>,
134 headers: HeaderMap,
135 ) -> WebSocketContext {
136 WebSocketContext::new(self.container.clone(), request_id, auth_identity, headers)
137 }
138
139 pub fn microservice_context(
140 &self,
141 transport: impl Into<String>,
142 pattern: impl Into<String>,
143 ) -> MicroserviceContext {
144 MicroserviceContext::new(
145 self.container.clone(),
146 transport,
147 pattern,
148 TransportMetadata::default(),
149 )
150 }
151
152 pub fn microservice_context_with_metadata(
153 &self,
154 transport: impl Into<String>,
155 pattern: impl Into<String>,
156 metadata: TransportMetadata,
157 ) -> MicroserviceContext {
158 MicroserviceContext::new(self.container.clone(), transport, pattern, metadata)
159 }
160
161 pub fn microservice_client(
162 &self,
163 registry: MicroserviceRegistry,
164 ) -> InProcessMicroserviceClient {
165 InProcessMicroserviceClient::new(self.container.clone(), registry)
166 }
167
168 pub fn microservice_client_with_metadata(
169 &self,
170 registry: MicroserviceRegistry,
171 transport: impl Into<String>,
172 metadata: TransportMetadata,
173 ) -> InProcessMicroserviceClient {
174 InProcessMicroserviceClient::new(self.container.clone(), registry)
175 .with_transport(transport)
176 .with_metadata(metadata)
177 }
178
179 pub fn shutdown(&self) -> Result<()> {
180 self.runtime.run_module_destroy(&self.container)?;
181 self.runtime.run_application_shutdown(&self.container)?;
182 Ok(())
183 }
184}
185
// Tests for `TestFactory` and `TestingModule`. They cover:
// - default and overridden providers;
// - override ordering relative to factory providers;
// - HTTP and GraphQL router assembly;
// - lifecycle hooks on shutdown;
// - gRPC/WebSocket/microservice contexts and in-process clients.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc as StdArc, Mutex};

    use async_graphql::{EmptyMutation, EmptySubscription};
    // Presumably the trait that provides `send` on `InProcessMicroserviceClient`.
    use nestforge_microservices::MicroserviceClient;
    use nestforge_core::{register_provider, ControllerDefinition, LifecycleHook, Provider};
    // Provides `oneshot` for driving a `Router` with a single request.
    use tower::ServiceExt;

    // Provider used as the override target throughout these tests.
    #[derive(Clone, Debug, PartialEq, Eq)]
    struct AppConfig {
        app_name: &'static str,
    }

    // Root module that registers `AppConfig { app_name: "default" }`.
    struct AppModule;
    impl ModuleDefinition for AppModule {
        fn register(container: &Container) -> Result<()> {
            container.register(AppConfig {
                app_name: "default",
            })?;
            Ok(())
        }
    }

    // Without overrides, the module's own registration is what resolves.
    #[test]
    fn builds_testing_module_and_resolves_default_provider() {
        let module = TestFactory::<AppModule>::create()
            .build()
            .expect("test module should build");

        let config = module
            .resolve::<AppConfig>()
            .expect("config should resolve");
        assert_eq!(
            *config,
            AppConfig {
                app_name: "default"
            }
        );
    }

    // The override is applied before `AppModule::register` runs. The test expects
    // the overridden value to win over the module's own registration.
    #[test]
    fn overrides_provider_value() {
        let module = TestFactory::<AppModule>::create()
            .override_provider(AppConfig { app_name: "test" })
            .build()
            .expect("test module should build with overrides");

        let config = module
            .resolve::<AppConfig>()
            .expect("config should resolve");
        assert_eq!(*config, AppConfig { app_name: "test" });
    }

    // Provider built by a factory that depends on `AppConfig`.
    #[derive(Clone, Debug, PartialEq, Eq)]
    struct GreetingService {
        greeting: String,
    }

    // Module that registers only the `GreetingService` factory. `AppConfig` must come
    // from an override.
    struct FactoryModule;
    impl ModuleDefinition for FactoryModule {
        fn register(container: &Container) -> Result<()> {
            register_provider(
                container,
                Provider::factory(|container| {
                    let config = container.resolve::<AppConfig>()?;
                    Ok(GreetingService {
                        greeting: format!("hello {}", config.app_name),
                    })
                }),
            )?;
            Ok(())
        }
    }

    // The factory must observe the overridden `AppConfig`, which proves overrides
    // are in place before providers are resolved.
    #[test]
    fn overrides_are_applied_before_factory_resolution() {
        let module = TestFactory::<FactoryModule>::create()
            .override_provider(AppConfig {
                app_name: "override",
            })
            .build()
            .expect("test module should build with transitive overrides");

        let greeting = module
            .resolve::<GreetingService>()
            .expect("greeting service should resolve");
        assert_eq!(greeting.greeting, "hello override");
    }

    // Controller exposing `GET /health`, which returns `{"ok": true}`.
    struct HttpController;

    impl ControllerDefinition for HttpController {
        fn router() -> Router<Container> {
            Router::new().route(
                "/health",
                axum::routing::get(|| async { axum::Json(serde_json::json!({ "ok": true })) }),
            )
        }
    }

    // Module with no providers and a single controller.
    struct HttpModule;
    impl ModuleDefinition for HttpModule {
        fn register(_container: &Container) -> Result<()> {
            Ok(())
        }

        fn controllers() -> Vec<Router<Container>> {
            vec![HttpController::router()]
        }
    }

    // The controller route collected by the runtime must be reachable through
    // `http_router`.
    #[tokio::test]
    async fn builds_http_router_from_testing_module_runtime() {
        let module = TestFactory::<HttpModule>::create()
            .build()
            .expect("http testing module should build");

        let response = module
            .http_router()
            .oneshot(
                axum::http::Request::builder()
                    .uri("/health")
                    .body(axum::body::Body::empty())
                    .expect("request should build"),
            )
            .await
            .expect("request should succeed");

        assert_eq!(response.status(), axum::http::StatusCode::OK);
    }

    // GraphQL query root that reads `AppConfig` through the `Container` in the
    // request context. The schema below is built without `.data(...)`. The test
    // therefore relies on the nestforge GraphQL router supplying `Container` as
    // context data (presumably from router state — confirm in nestforge_graphql).
    struct QueryRoot;

    #[async_graphql::Object]
    impl QueryRoot {
        async fn app_name(&self, ctx: &async_graphql::Context<'_>) -> &str {
            let config = ctx
                .data::<Container>()
                .expect("container should be present")
                .resolve::<AppConfig>()
                .expect("app config should resolve");

            config.app_name
        }
    }

    // Posts `{ appName }` to the custom `/graphql` endpoint, with GraphiQL disabled.
    // NOTE(review): only the HTTP status is asserted. GraphQL-over-HTTP servers
    // commonly answer 200 even when the body carries `errors`, so consider also
    // asserting that the body contains "graphql".
    #[tokio::test]
    async fn builds_graphql_router_from_testing_module_runtime() {
        let module = TestFactory::<AppModule>::create()
            .override_provider(AppConfig { app_name: "graphql" })
            .build()
            .expect("graphql testing module should build");
        let schema = Schema::build(QueryRoot, EmptyMutation, EmptySubscription).finish();

        let response = module
            .graphql_router_with_paths(schema, "/graphql", None)
            .oneshot(
                axum::http::Request::builder()
                    .method("POST")
                    .uri("/graphql")
                    .header(axum::http::header::CONTENT_TYPE, "application/json")
                    .body(axum::body::Body::from(
                        serde_json::json!({ "query": "{ appName }" }).to_string(),
                    ))
                    .expect("request should build"),
            )
            .await
            .expect("graphql request should succeed");

        assert_eq!(response.status(), axum::http::StatusCode::OK);
    }

    // Mutex-guarded record of which lifecycle hooks ran, in order. Clones share the
    // same `Vec` through the `Arc`.
    #[derive(Clone)]
    struct HookLog(StdArc<Mutex<Vec<&'static str>>>);

    // Module-destroy hook: appends "destroy" to the registered `HookLog`.
    fn record_destroy(container: &Container) -> Result<()> {
        let log = container.resolve::<HookLog>()?;
        log.0.lock().expect("hook log should be writable").push("destroy");
        Ok(())
    }

    // Application-shutdown hook: appends "shutdown" to the registered `HookLog`.
    fn record_shutdown(container: &Container) -> Result<()> {
        let log = container.resolve::<HookLog>()?;
        log.0
            .lock()
            .expect("hook log should be writable")
            .push("shutdown");
        Ok(())
    }

    // Module that registers an empty `HookLog` and wires the two recording hooks.
    struct HookModule;

    impl ModuleDefinition for HookModule {
        fn register(container: &Container) -> Result<()> {
            container.register(HookLog(StdArc::new(Mutex::new(Vec::new()))))?;
            Ok(())
        }

        fn on_module_destroy() -> Vec<LifecycleHook> {
            vec![record_destroy]
        }

        fn on_application_shutdown() -> Vec<LifecycleHook> {
            vec![record_shutdown]
        }
    }

    // `shutdown` must run destroy hooks before shutdown hooks. The container remains
    // usable afterwards, which lets the test read the log.
    #[test]
    fn shutdown_runs_destroy_and_shutdown_hooks() {
        let module = TestFactory::<HookModule>::create()
            .build()
            .expect("hook testing module should build");

        module.shutdown().expect("testing module should shut down");

        let log = module.resolve::<HookLog>().expect("hook log should resolve");
        let entries = log.0.lock().expect("hook log should be readable").clone();
        assert_eq!(entries, vec!["destroy", "shutdown"]);
    }

    // A gRPC context built from the module resolves the overridden provider.
    #[test]
    fn builds_grpc_context_from_testing_module_runtime() {
        let module = TestFactory::<AppModule>::create()
            .override_provider(AppConfig { app_name: "grpc" })
            .build()
            .expect("grpc testing module should build");

        let ctx = module.grpc_context();
        let config = ctx.resolve::<AppConfig>().expect("app config should resolve");

        assert_eq!(config.app_name, "grpc");
    }

    // WebSocket and microservice contexts both resolve the overridden provider. The
    // microservice context also reports the transport and pattern it was created with.
    #[test]
    fn builds_transport_contexts_with_testing_container() {
        let module = TestFactory::<AppModule>::create()
            .override_provider(AppConfig {
                app_name: "transport",
            })
            .build()
            .expect("transport testing module should build");

        let websocket = module.websocket_context();
        let config = websocket
            .resolve::<AppConfig>()
            .expect("app config should resolve from websocket context");
        assert_eq!(config.app_name, "transport");

        let microservice = module.microservice_context("test", "users.count");
        let config = microservice
            .resolve::<AppConfig>()
            .expect("app config should resolve from microservice context");
        assert_eq!(config.app_name, "transport");
        assert_eq!(microservice.transport(), "test");
        assert_eq!(microservice.pattern(), "users.count");
    }

    // An in-process client dispatches "app.name" to the registered handler, which
    // resolves the overridden provider.
    // NOTE(review): the transport name and metadata passed here are not asserted on.
    #[tokio::test]
    async fn builds_in_process_microservice_clients_from_testing_module() {
        let module = TestFactory::<AppModule>::create()
            .override_provider(AppConfig { app_name: "client" })
            .build()
            .expect("client testing module should build");
        let registry = MicroserviceRegistry::builder()
            .message("app.name", |_payload: (), ctx| async move {
                let config = ctx.resolve::<AppConfig>()?;
                Ok(config.app_name.to_string())
            })
            .build();
        let client = module.microservice_client_with_metadata(
            registry,
            "test-client",
            TransportMetadata::new().insert("suite", "testing"),
        );

        let response: String = client
            .send("app.name", ())
            .await
            .expect("response should resolve");

        assert_eq!(response, "client");
    }
}