graphql_operation_server_harness/adapters/gateways/async_graphql/
server.rs1use async_trait::async_trait;
2use axum::{
3 extract::State,
4 http::StatusCode,
5 response::IntoResponse,
6 routing::post,
7 Router,
8};
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use std::collections::HashMap;
12use std::net::SocketAddr;
13use std::sync::atomic::{AtomicUsize, Ordering};
14use std::sync::Arc;
15use tokio::sync::{oneshot, Mutex as TokioMutex};
16
17use crate::entities::{CollectedRequest, Handler, Operation, OperationType, RequestContext};
18use crate::error::HarnessError;
19use crate::use_cases::ports::{Collector, Server};
20
/// GraphQL test-server gateway built on axum.
///
/// Holds only the socket address the server will try to bind to when it is
/// run; no socket is opened until then.
#[derive(Debug, Clone)]
pub struct AsyncGraphQL {
    /// Address passed to the TCP listener when the server starts.
    addr: SocketAddr,
}

impl AsyncGraphQL {
    /// Creates a gateway that will bind to exactly `addr`.
    pub fn new(addr: SocketAddr) -> Self {
        Self { addr }
    }

    /// Convenience constructor accepting anything convertible into a
    /// [`SocketAddr`], e.g. `([127, 0, 0, 1], 8080)`.
    pub fn bind(addr: impl Into<SocketAddr>) -> Self {
        Self::new(addr.into())
    }
}

impl Default for AsyncGraphQL {
    /// Binds to `127.0.0.1` with port `0`, leaving the concrete port to be
    /// discovered from the listener once it is bound.
    fn default() -> Self {
        Self::new(SocketAddr::from(([127, 0, 0, 1], 0)))
    }
}
42
43#[derive(Debug, Deserialize)]
44struct GraphQLRequest {
45 query: String,
46 #[serde(rename = "operationName")]
47 operation_name: Option<String>,
48 variables: Option<Value>,
49}
50
51#[derive(Debug, Serialize)]
52struct GraphQLResponse {
53 data: Option<Value>,
54 #[serde(skip_serializing_if = "Option::is_none")]
55 errors: Option<Vec<Value>>,
56}
57
58#[derive(Clone)]
60struct CompletionTracker {
61 total_handlers: usize,
62 handlers_called: Arc<AtomicUsize>,
63 shutdown_tx: Arc<TokioMutex<Option<oneshot::Sender<()>>>>,
64}
65
66impl CompletionTracker {
67 fn new(total_handlers: usize, shutdown_tx: oneshot::Sender<()>) -> Self {
68 Self {
69 total_handlers,
70 handlers_called: Arc::new(AtomicUsize::new(0)),
71 shutdown_tx: Arc::new(TokioMutex::new(Some(shutdown_tx))),
72 }
73 }
74
75 async fn handler_called(&self) {
76 let called = self.handlers_called.fetch_add(1, Ordering::SeqCst) + 1;
77 if called >= self.total_handlers {
78 if let Some(tx) = self.shutdown_tx.lock().await.take() {
79 let _ = tx.send(());
80 }
81 }
82 }
83}
84
85trait InternalCollector: Send + Sync {
87 fn collect(&self, request: CollectedRequest);
88}
89
90impl<C: Collector> InternalCollector for C {
91 fn collect(&self, request: CollectedRequest) {
92 Collector::collect(self, request);
93 }
94}
95
96#[derive(Clone)]
98struct ServerState {
99 query_handlers: Arc<HashMap<String, FieldState>>,
101 mutation_handlers: Arc<HashMap<String, FieldState>>,
102 collector: Arc<dyn InternalCollector>,
103 completion_tracker: CompletionTracker,
104}
105
106#[derive(Clone)]
107struct FieldState {
108 handlers: Vec<Handler>,
109 call_count: Arc<AtomicUsize>,
110}
111
112async fn handle_graphql(
113 State(state): State<ServerState>,
114 body: String,
115) -> impl IntoResponse {
116 let request: GraphQLRequest = match serde_json::from_str(&body) {
117 Ok(req) => req,
118 Err(e) => {
119 let response = GraphQLResponse {
120 data: None,
121 errors: Some(vec![serde_json::json!({"message": e.to_string()})]),
122 };
123 return (StatusCode::OK, axum::Json(response));
124 }
125 };
126
127 let mut collected = CollectedRequest::new(&request.query);
129 if let Some(op_name) = &request.operation_name {
130 collected = collected.with_operation_name(op_name);
131 }
132 if let Some(vars) = &request.variables {
133 collected = collected.with_variables(vars.clone());
134 }
135 state.collector.collect(collected);
136
137 let query = request.query.trim();
139 let (handlers_map, _op_type) = if query.starts_with("mutation") {
140 (&state.mutation_handlers, "mutation")
141 } else {
142 (&state.query_handlers, "query")
143 };
144
145 let mut response_data = serde_json::Map::new();
147 let mut errors: Vec<Value> = Vec::new();
148
149 for (field_name, field_state) in handlers_map.iter() {
150 if query.contains(field_name) {
151 let call_index = field_state.call_count.fetch_add(1, Ordering::SeqCst);
152 let handler_count = field_state.handlers.len();
153 let handler_index = call_index.min(handler_count.saturating_sub(1));
154
155 if call_index < handler_count {
157 state.completion_tracker.handler_called().await;
158 }
159
160 if let Some(handler) = field_state.handlers.get(handler_index) {
161 let mut ctx = RequestContext::new(field_name).with_query(&request.query);
162 if let Some(op_name) = &request.operation_name {
163 ctx = ctx.with_operation_name(op_name);
164 }
165 if let Some(vars) = &request.variables {
166 ctx = ctx.with_variables(vars.clone());
167 }
168
169 let handler_response = handler.respond(&ctx);
170 if let Some(obj) = handler_response.data.as_object() {
171 for (k, v) in obj {
172 response_data.insert(k.clone(), v.clone());
173 }
174 } else {
175 response_data.insert(field_name.clone(), handler_response.data.clone());
176 }
177 if let Some(errs) = &handler_response.errors {
178 for err in errs {
179 let mut err_val = serde_json::json!({"message": err.message});
180 if let Some(path) = &err.path {
181 err_val["path"] = serde_json::json!(path);
182 }
183 errors.push(err_val);
184 }
185 }
186 }
187 }
188 }
189
190 let response = GraphQLResponse {
191 data: Some(Value::Object(response_data)),
192 errors: if errors.is_empty() { None } else { Some(errors) },
193 };
194
195 (StatusCode::OK, axum::Json(response))
196}
197
198#[async_trait]
199impl Server for AsyncGraphQL {
200 async fn run<C, F>(
201 &self,
202 operations: Vec<Operation>,
203 collector: C,
204 on_ready: Option<F>,
205 ) -> Result<C::Output, HarnessError>
206 where
207 C: Collector + 'static,
208 F: FnOnce(SocketAddr) + Send + 'static,
209 {
210 let collector_arc: Arc<C> = Arc::new(collector);
211
212 let total_handlers: usize = operations
214 .iter()
215 .map(|op| {
216 op.fields
217 .iter()
218 .map(|f| f.handlers.len().max(1))
219 .sum::<usize>()
220 })
221 .sum();
222
223 let (auto_shutdown_tx, auto_shutdown_rx) = oneshot::channel();
225 let completion_tracker = CompletionTracker::new(total_handlers, auto_shutdown_tx);
226
227 let mut query_handlers = HashMap::new();
228 let mut mutation_handlers = HashMap::new();
229
230 for operation in operations {
231 let handlers_map = match operation.operation_type {
232 OperationType::Query => &mut query_handlers,
233 OperationType::Mutation => &mut mutation_handlers,
234 OperationType::Subscription => continue, };
236
237 for field in operation.fields {
238 handlers_map.insert(
239 field.name,
240 FieldState {
241 handlers: field.handlers,
242 call_count: Arc::new(AtomicUsize::new(0)),
243 },
244 );
245 }
246 }
247
248 let state = ServerState {
249 query_handlers: Arc::new(query_handlers),
250 mutation_handlers: Arc::new(mutation_handlers),
251 collector: collector_arc.clone(),
252 completion_tracker,
253 };
254
255 let router = Router::new()
256 .route("/graphql", post(handle_graphql))
257 .with_state(state);
258
259 let listener = tokio::net::TcpListener::bind(self.addr)
260 .await
261 .map_err(|e| HarnessError::ServerError(e.to_string()))?;
262
263 let addr = listener
264 .local_addr()
265 .map_err(|e| HarnessError::ServerError(e.to_string()))?;
266
267 if let Some(callback) = on_ready {
269 callback(addr);
270 }
271
272 axum::serve(listener, router)
274 .with_graceful_shutdown(async {
275 let _ = auto_shutdown_rx.await;
276 })
277 .await
278 .map_err(|e| HarnessError::ServerError(e.to_string()))?;
279
280 let collector = Arc::try_unwrap(collector_arc)
282 .map_err(|_| HarnessError::ServerError("Failed to unwrap collector".to_string()))?;
283 Ok(collector.into_output())
284 }
285}