1use crate::error::{DbxError, DbxResult};
2use crate::grid::protocol::{GridMessage, QueryMessage, StorageMessage};
3use crate::grid::quic::{GridMessageWrapper, QuicChannel};
4use crate::sql::executor::local_executor::LocalExecutor;
5use crate::sql::planner::types::PhysicalPlan;
6use crate::storage::erasure_coding::distributed_store::DistributedErasureCodingStore;
7use dashmap::DashMap;
8use std::sync::Arc;
9use tokio::sync::mpsc;
10use tracing::{error, info, warn};
11
12type QuerySender = mpsc::Sender<DbxResult<Option<Vec<u8>>>>;
13type QueryStreamMap = Arc<DashMap<(String, usize), QuerySender>>;
14
15pub struct GridManager {
19 quic_channel: Arc<QuicChannel>,
20 ec_store: Arc<DistributedErasureCodingStore>,
21 receiver: mpsc::Receiver<GridMessageWrapper>,
22 query_streams: QueryStreamMap,
23 stage_barriers: Arc<DashMap<(String, usize, std::net::SocketAddr), mpsc::Sender<()>>>,
24 local_executor: Option<Arc<LocalExecutor>>,
26 node_id: u32,
28}
29
30impl GridManager {
31 pub fn new(
32 quic_channel: Arc<QuicChannel>,
33 ec_store: Arc<DistributedErasureCodingStore>,
34 receiver: mpsc::Receiver<GridMessageWrapper>,
35 ) -> Self {
36 Self::with_node_id(quic_channel, ec_store, receiver, 0)
37 }
38
39 pub fn with_node_id(
40 quic_channel: Arc<QuicChannel>,
41 ec_store: Arc<DistributedErasureCodingStore>,
42 receiver: mpsc::Receiver<GridMessageWrapper>,
43 node_id: u32,
44 ) -> Self {
45 Self {
46 quic_channel,
47 ec_store,
48 receiver,
49 query_streams: Arc::new(DashMap::new()),
50 stage_barriers: Arc::new(DashMap::new()),
51 local_executor: None,
52 node_id,
53 }
54 }
55
56 pub fn with_local_executor(mut self, executor: Arc<LocalExecutor>) -> Self {
58 self.local_executor = Some(executor);
59 self
60 }
61
62 pub fn get_query_streams(&self) -> QueryStreamMap {
63 Arc::clone(&self.query_streams)
64 }
65
66 pub fn get_stage_barriers(
67 &self,
68 ) -> Arc<DashMap<(String, usize, std::net::SocketAddr), mpsc::Sender<()>>> {
69 Arc::clone(&self.stage_barriers)
70 }
71
72 pub async fn run(mut self) {
74 info!(
75 "GridManager receiver loop started on {}",
76 self.quic_channel.local_addr
77 );
78
79 while let Some(wrapper) = self.receiver.recv().await {
80 let ec_store = Arc::clone(&self.ec_store);
81 let query_streams = Arc::clone(&self.query_streams);
82 let stage_barriers = Arc::clone(&self.stage_barriers);
83 let quic_channel = Arc::clone(&self.quic_channel);
84 let local_executor = self.local_executor.clone();
85 let node_id = self.node_id;
86
87 tokio::spawn(async move {
89 if let Err(e) = Self::handle_message(
90 ec_store,
91 query_streams,
92 stage_barriers,
93 quic_channel,
94 local_executor,
95 node_id,
96 wrapper,
97 )
98 .await
99 {
100 error!("Error handling GridMessage: {:?}", e);
101 }
102 });
103 }
104
105 info!("GridManager receiver loop terminated");
106 }
107
108 async fn handle_message(
110 ec_store: Arc<DistributedErasureCodingStore>,
111 query_streams: QueryStreamMap,
112 stage_barriers: Arc<DashMap<(String, usize, std::net::SocketAddr), mpsc::Sender<()>>>,
113 quic_channel: Arc<QuicChannel>,
114 local_executor: Option<Arc<LocalExecutor>>,
115 node_id: u32,
116 wrapper: GridMessageWrapper,
117 ) -> DbxResult<()> {
118 let GridMessageWrapper { msg, mut stream } = wrapper;
119
120 match msg {
121 GridMessage::Storage(storage_msg) => {
122 Self::handle_storage_message(ec_store, storage_msg, &mut stream).await
123 }
124 GridMessage::Query(query_msg) => {
125 Self::handle_query_message(
126 query_streams,
127 stage_barriers,
128 quic_channel,
129 local_executor,
130 node_id,
131 query_msg,
132 )
133 .await
134 }
135 GridMessage::Lock(_) => {
136 warn!("LockMessage received but not implemented yet");
137 Ok(())
138 }
139 GridMessage::Replication(_) => {
140 warn!("ReplicationMessage received but not implemented yet");
141 Ok(())
142 }
143 }
144 }
145
146 async fn handle_query_message(
148 query_streams: QueryStreamMap,
149 stage_barriers: Arc<DashMap<(String, usize, std::net::SocketAddr), mpsc::Sender<()>>>,
150 quic_channel: Arc<QuicChannel>,
151 local_executor: Option<Arc<LocalExecutor>>,
152 node_id: u32,
153 msg: QueryMessage,
154 ) -> DbxResult<()> {
155 match msg {
156 QueryMessage::ExecuteFragment {
157 execution_id,
158 stage_id,
159 plans_bytes,
160 coordinator_addr,
161 } => {
162 info!(
163 "ExecuteFragment received for ID: {}, Stage: {} from coordinator: {}",
164 execution_id, stage_id, coordinator_addr
165 );
166
167 let executor = match local_executor {
168 Some(e) => e,
169 None => {
170 warn!(
171 "ExecuteFragment received but no LocalExecutor configured — ignoring"
172 );
173 return Ok(());
174 }
175 };
176
177 let coord_addr: std::net::SocketAddr = match coordinator_addr.parse() {
179 Ok(a) => a,
180 Err(e) => {
181 return Err(DbxError::Network(format!(
182 "Invalid coordinator addr: {}",
183 e
184 )));
185 }
186 };
187
188 let exec_id = execution_id.clone();
189 let query_streams = Arc::clone(&query_streams);
190
191 let mut plans: Vec<PhysicalPlan> = Vec::new();
193 for bytes in plans_bytes {
194 let plan = bincode::deserialize(&bytes)
195 .map_err(|e| DbxError::Serialization(e.to_string()))?;
196 plans.push(plan);
197 }
198
199 let quic_master = Arc::clone(&quic_channel);
200 tokio::spawn(async move {
201 info!(
202 "Worker spawning execution for exec_id: {}, stage_id: {}",
203 exec_id, stage_id
204 );
205
206 let mut join_set = tokio::task::JoinSet::new();
207
208 for worker_plan in plans {
209 let executor_ref = Arc::clone(&executor);
210 let quic_ref = Arc::clone(&quic_master);
211 let exec_id_ref = exec_id.clone();
212 let q_streams = Arc::clone(&query_streams);
213
214 join_set.spawn(async move {
215 let (batches, channels) = match tokio::task::spawn_blocking(move || {
217 let mut chs = crate::sql::executor::local_executor::DistributedChannels::default();
218 let b = executor_ref.execute_collect_distributed(&worker_plan, &mut chs)?;
219 Ok::<(Vec<arrow::array::RecordBatch>, _), DbxError>((b, chs))
220 }).await {
221 Ok(Ok(res)) => res,
222 Ok(Err(e)) => {
223 error!("Worker execution error for {}: {:?}", exec_id_ref, e);
224 let eof_msg = GridMessage::Query(QueryMessage::ExchangeData {
225 execution_id: exec_id_ref.clone(),
226 exchange_id: 0,
227 node_id,
228 is_eof: true,
229 batch_data: vec![],
230 });
231 let _ = quic_ref.send_message(coord_addr, eof_msg).await;
232 return;
233 }
234 Err(e) => {
235 error!("Worker spawn_blocking panic: {:?}", e);
236 return;
237 }
238 };
239
240 for (e_id, tx) in channels.exchanges {
242 q_streams.insert((exec_id_ref.clone(), e_id), tx);
243 }
244
245 let mut shuffle_join_set = tokio::task::JoinSet::new();
247 for (e_id, receivers) in channels.shuffles {
248 for (target_addr, mut rx) in receivers {
249 let quic_sub = Arc::clone(&quic_ref);
250 let exec_sub = exec_id_ref.clone();
251 shuffle_join_set.spawn(async move {
252 while let Some(Ok(Some(batch_bytes))) = rx.recv().await {
253 let msg = GridMessage::Query(QueryMessage::ExchangeData {
254 execution_id: exec_sub.clone(),
255 exchange_id: e_id,
256 node_id,
257 is_eof: false,
258 batch_data: batch_bytes,
259 });
260 let _ = quic_sub.send_message(target_addr, msg).await;
261 }
262 let eof_msg = GridMessage::Query(QueryMessage::ExchangeData {
263 execution_id: exec_sub,
264 exchange_id: e_id,
265 node_id,
266 is_eof: true,
267 batch_data: vec![],
268 });
269 let _ = quic_sub.send_message(target_addr, eof_msg).await;
270 });
271 }
272 }
273
274 for batch in batches {
276 let ipc_bytes = match crate::grid::protocol::serialize_batch_to_ipc(&batch) {
277 Ok(b) => b,
278 Err(_) => continue,
279 };
280 let msg = GridMessage::Query(QueryMessage::ExchangeData {
281 execution_id: exec_id_ref.clone(),
282 exchange_id: 0,
283 node_id,
284 is_eof: false,
285 batch_data: ipc_bytes,
286 });
287 let _ = quic_ref.send_message(coord_addr, msg).await;
288 }
289
290 let _ = quic_ref.send_message(coord_addr, GridMessage::Query(QueryMessage::ExchangeData {
292 execution_id: exec_id_ref.clone(),
293 exchange_id: 0,
294 node_id,
295 is_eof: true,
296 batch_data: vec![],
297 })).await;
298
299 while shuffle_join_set.join_next().await.is_some() {}
301 });
302 }
303
304 while join_set.join_next().await.is_some() {}
306
307 let complete_msg = GridMessage::Query(QueryMessage::FragmentCompleted {
309 execution_id: exec_id.clone(),
310 stage_id,
311 });
312 let _ = quic_master.send_message(coord_addr, complete_msg).await;
313 info!(
314 "Worker completed all plans for exec_id: {}, stage_id: {}",
315 exec_id, stage_id
316 );
317 });
318
319 Ok(())
320 }
321 QueryMessage::FragmentCompleted {
322 execution_id,
323 stage_id,
324 } => {
325 let barriers = stage_barriers;
326 let matched_keys: Vec<_> = barriers
327 .iter()
328 .filter(|entry| entry.key().0 == execution_id && entry.key().1 == stage_id)
329 .map(|entry| entry.key().clone())
330 .collect();
331
332 for key in matched_keys {
333 if let Some(sender) = barriers.get_mut(&key) {
334 let _ = sender.try_send(());
335 }
336 }
337 Ok(())
338 }
339 QueryMessage::ExchangeData {
340 execution_id,
341 exchange_id,
342 node_id: _,
343 is_eof,
344 batch_data,
345 } => {
346 let sender_opt = query_streams
349 .get(&(execution_id.clone(), exchange_id))
350 .map(|kv| kv.value().clone());
351
352 if let Some(sender) = sender_opt {
353 if is_eof {
354 let _ = sender.send(Ok(None)).await;
355 } else {
356 let _ = sender.send(Ok(Some(batch_data))).await;
357 }
358 } else {
359 warn!("ExchangeData for unknown execution_id: {}", execution_id);
360 }
361 Ok(())
362 }
363 }
364 }
365
366 async fn handle_storage_message(
368 ec_store: Arc<DistributedErasureCodingStore>,
369 msg: StorageMessage,
370 stream: &mut Option<s2n_quic::stream::BidirectionalStream>,
371 ) -> DbxResult<()> {
372 match msg {
373 StorageMessage::StoreShard {
374 key,
375 shard_id,
376 data,
377 } => {
378 info!("Storing shard {}:{} locally", key, shard_id);
379 ec_store.local_store_shard(&key, shard_id, &data)?;
380 Ok(())
381 }
382 StorageMessage::FetchShard { key, shard_id } => {
383 info!("Fetching shard {}:{} for remote request", key, shard_id);
384 let shard_data = ec_store.local_fetch_shard(&key, shard_id)?;
385
386 if let Some(s) = stream {
388 let reply = GridMessage::Storage(StorageMessage::ShardResponse {
389 key: key.clone(),
390 shard_id,
391 data: shard_data,
392 });
393 ::tracing::debug!(
394 "Sending ShardResponse for {}:{} on stream...",
395 key,
396 shard_id
397 );
398 if let Err(e) = QuicChannel::send_response(s, reply).await {
399 ::tracing::error!(
400 "Failed to send ShardResponse for {}:{}: {:?}",
401 key,
402 shard_id,
403 e
404 );
405 } else {
406 ::tracing::debug!(
407 "Successfully sent ShardResponse for {}:{}",
408 key,
409 shard_id
410 );
411 }
412 }
413 Ok(())
414 }
415 StorageMessage::ShardResponse { .. } => {
416 warn!("Received unexpected ShardResponse in main handler loop");
419 Ok(())
420 }
421 }
422 }
423}