graphql_operation_server_harness/adapters/gateways/async_graphql/server.rs

1use async_trait::async_trait;
2use axum::{
3    extract::State,
4    http::StatusCode,
5    response::IntoResponse,
6    routing::post,
7    Router,
8};
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use std::collections::HashMap;
12use std::net::SocketAddr;
13use std::sync::atomic::{AtomicUsize, Ordering};
14use std::sync::Arc;
15use tokio::sync::{oneshot, Mutex as TokioMutex};
16
17use crate::entities::{CollectedRequest, Handler, Operation, OperationType, RequestContext};
18use crate::error::HarnessError;
19use crate::use_cases::ports::{Collector, Server};
20
/// AsyncGraphQL-compatible GraphQL server used by the harness.
///
/// Holds only the address to listen on; all per-run state (handler tables,
/// completion tracking, the listener itself) is created inside `run`.
#[derive(Clone, Debug)]
pub struct AsyncGraphQL {
    addr: SocketAddr,
}

impl AsyncGraphQL {
    /// Creates a server that will listen on `addr` when run.
    pub fn new(addr: SocketAddr) -> Self {
        Self { addr }
    }

    /// Creates a server from anything convertible into a [`SocketAddr`],
    /// e.g. `([127, 0, 0, 1], 8080)`.
    ///
    /// Despite the name, no socket is opened here; the address is only
    /// stored and the actual bind happens when the server is run.
    pub fn bind(addr: impl Into<SocketAddr>) -> Self {
        Self::new(addr.into())
    }
}

impl Default for AsyncGraphQL {
    /// Listens on `127.0.0.1` with port `0`, letting the OS choose a free port.
    fn default() -> Self {
        Self::new(([127, 0, 0, 1], 0).into())
    }
}
42
43#[derive(Debug, Deserialize)]
44struct GraphQLRequest {
45    query: String,
46    #[serde(rename = "operationName")]
47    operation_name: Option<String>,
48    variables: Option<Value>,
49}
50
51#[derive(Debug, Serialize)]
52struct GraphQLResponse {
53    data: Option<Value>,
54    #[serde(skip_serializing_if = "Option::is_none")]
55    errors: Option<Vec<Value>>,
56}
57
58/// Shared state for tracking completion
59#[derive(Clone)]
60struct CompletionTracker {
61    total_handlers: usize,
62    handlers_called: Arc<AtomicUsize>,
63    shutdown_tx: Arc<TokioMutex<Option<oneshot::Sender<()>>>>,
64}
65
66impl CompletionTracker {
67    fn new(total_handlers: usize, shutdown_tx: oneshot::Sender<()>) -> Self {
68        Self {
69            total_handlers,
70            handlers_called: Arc::new(AtomicUsize::new(0)),
71            shutdown_tx: Arc::new(TokioMutex::new(Some(shutdown_tx))),
72        }
73    }
74
75    async fn handler_called(&self) {
76        let called = self.handlers_called.fetch_add(1, Ordering::SeqCst) + 1;
77        if called >= self.total_handlers {
78            if let Some(tx) = self.shutdown_tx.lock().await.take() {
79                let _ = tx.send(());
80            }
81        }
82    }
83}
84
85/// Internal trait for collecting requests (without Output type)
86trait InternalCollector: Send + Sync {
87    fn collect(&self, request: CollectedRequest);
88}
89
90impl<C: Collector> InternalCollector for C {
91    fn collect(&self, request: CollectedRequest) {
92        Collector::collect(self, request);
93    }
94}
95
96/// State shared with handlers
97#[derive(Clone)]
98struct ServerState {
99    /// Map from field name to handlers
100    query_handlers: Arc<HashMap<String, FieldState>>,
101    mutation_handlers: Arc<HashMap<String, FieldState>>,
102    collector: Arc<dyn InternalCollector>,
103    completion_tracker: CompletionTracker,
104}
105
106#[derive(Clone)]
107struct FieldState {
108    handlers: Vec<Handler>,
109    call_count: Arc<AtomicUsize>,
110}
111
112async fn handle_graphql(
113    State(state): State<ServerState>,
114    body: String,
115) -> impl IntoResponse {
116    let request: GraphQLRequest = match serde_json::from_str(&body) {
117        Ok(req) => req,
118        Err(e) => {
119            let response = GraphQLResponse {
120                data: None,
121                errors: Some(vec![serde_json::json!({"message": e.to_string()})]),
122            };
123            return (StatusCode::OK, axum::Json(response));
124        }
125    };
126
127    // Collect the request
128    let mut collected = CollectedRequest::new(&request.query);
129    if let Some(op_name) = &request.operation_name {
130        collected = collected.with_operation_name(op_name);
131    }
132    if let Some(vars) = &request.variables {
133        collected = collected.with_variables(vars.clone());
134    }
135    state.collector.collect(collected);
136
137    // Parse the query to find the operation type and field
138    let query = request.query.trim();
139    let (handlers_map, _op_type) = if query.starts_with("mutation") {
140        (&state.mutation_handlers, "mutation")
141    } else {
142        (&state.query_handlers, "query")
143    };
144
145    // Simple field extraction - find field names in the query
146    let mut response_data = serde_json::Map::new();
147    let mut errors: Vec<Value> = Vec::new();
148
149    for (field_name, field_state) in handlers_map.iter() {
150        if query.contains(field_name) {
151            let call_index = field_state.call_count.fetch_add(1, Ordering::SeqCst);
152            let handler_count = field_state.handlers.len();
153            let handler_index = call_index.min(handler_count.saturating_sub(1));
154
155            // Check if this is a new handler being called for the first time
156            if call_index < handler_count {
157                state.completion_tracker.handler_called().await;
158            }
159
160            if let Some(handler) = field_state.handlers.get(handler_index) {
161                let mut ctx = RequestContext::new(field_name).with_query(&request.query);
162                if let Some(op_name) = &request.operation_name {
163                    ctx = ctx.with_operation_name(op_name);
164                }
165                if let Some(vars) = &request.variables {
166                    ctx = ctx.with_variables(vars.clone());
167                }
168
169                let handler_response = handler.respond(&ctx);
170                if let Some(obj) = handler_response.data.as_object() {
171                    for (k, v) in obj {
172                        response_data.insert(k.clone(), v.clone());
173                    }
174                } else {
175                    response_data.insert(field_name.clone(), handler_response.data.clone());
176                }
177                if let Some(errs) = &handler_response.errors {
178                    for err in errs {
179                        let mut err_val = serde_json::json!({"message": err.message});
180                        if let Some(path) = &err.path {
181                            err_val["path"] = serde_json::json!(path);
182                        }
183                        errors.push(err_val);
184                    }
185                }
186            }
187        }
188    }
189
190    let response = GraphQLResponse {
191        data: Some(Value::Object(response_data)),
192        errors: if errors.is_empty() { None } else { Some(errors) },
193    };
194
195    (StatusCode::OK, axum::Json(response))
196}
197
198#[async_trait]
199impl Server for AsyncGraphQL {
200    async fn run<C, F>(
201        &self,
202        operations: Vec<Operation>,
203        collector: C,
204        on_ready: Option<F>,
205    ) -> Result<C::Output, HarnessError>
206    where
207        C: Collector + 'static,
208        F: FnOnce(SocketAddr) + Send + 'static,
209    {
210        let collector_arc: Arc<C> = Arc::new(collector);
211
212        // Count total handlers
213        let total_handlers: usize = operations
214            .iter()
215            .map(|op| {
216                op.fields
217                    .iter()
218                    .map(|f| f.handlers.len().max(1))
219                    .sum::<usize>()
220            })
221            .sum();
222
223        // Create shutdown channel for auto-shutdown
224        let (auto_shutdown_tx, auto_shutdown_rx) = oneshot::channel();
225        let completion_tracker = CompletionTracker::new(total_handlers, auto_shutdown_tx);
226
227        let mut query_handlers = HashMap::new();
228        let mut mutation_handlers = HashMap::new();
229
230        for operation in operations {
231            let handlers_map = match operation.operation_type {
232                OperationType::Query => &mut query_handlers,
233                OperationType::Mutation => &mut mutation_handlers,
234                OperationType::Subscription => continue, // Skip subscriptions for now
235            };
236
237            for field in operation.fields {
238                handlers_map.insert(
239                    field.name,
240                    FieldState {
241                        handlers: field.handlers,
242                        call_count: Arc::new(AtomicUsize::new(0)),
243                    },
244                );
245            }
246        }
247
248        let state = ServerState {
249            query_handlers: Arc::new(query_handlers),
250            mutation_handlers: Arc::new(mutation_handlers),
251            collector: collector_arc.clone(),
252            completion_tracker,
253        };
254
255        let router = Router::new()
256            .route("/graphql", post(handle_graphql))
257            .with_state(state);
258
259        let listener = tokio::net::TcpListener::bind(self.addr)
260            .await
261            .map_err(|e| HarnessError::ServerError(e.to_string()))?;
262
263        let addr = listener
264            .local_addr()
265            .map_err(|e| HarnessError::ServerError(e.to_string()))?;
266
267        // Call on_ready callback if provided
268        if let Some(callback) = on_ready {
269            callback(addr);
270        }
271
272        // Run the server until auto-shutdown
273        axum::serve(listener, router)
274            .with_graceful_shutdown(async {
275                let _ = auto_shutdown_rx.await;
276            })
277            .await
278            .map_err(|e| HarnessError::ServerError(e.to_string()))?;
279
280        // Extract the collector from the Arc and return its output
281        let collector = Arc::try_unwrap(collector_arc)
282            .map_err(|_| HarnessError::ServerError("Failed to unwrap collector".to_string()))?;
283        Ok(collector.into_output())
284    }
285}