// dbx_core/grid/manager.rs

1use crate::error::{DbxError, DbxResult};
2use crate::grid::protocol::{GridMessage, QueryMessage, StorageMessage};
3use crate::grid::quic::{GridMessageWrapper, QuicChannel};
4use crate::sql::executor::local_executor::LocalExecutor;
5use crate::sql::planner::types::PhysicalPlan;
6use crate::storage::erasure_coding::distributed_store::DistributedErasureCodingStore;
7use dashmap::DashMap;
8use std::sync::Arc;
9use tokio::sync::mpsc;
10use tracing::{error, info, warn};
11
/// Sender half of a query-data stream: carries serialized batch bytes
/// (`Ok(Some(bytes))`) or an end-of-stream marker (`Ok(None)`).
type QuerySender = mpsc::Sender<DbxResult<Option<Vec<u8>>>>;
/// Shared registry of active query streams, keyed by (execution_id, exchange_id).
type QueryStreamMap = Arc<DashMap<(String, usize), QuerySender>>;
14
/// Grid central controller (thin dispatcher).
///
/// Classifies messages arriving from the network channel and delivers them
/// to the per-domain handlers.
pub struct GridManager {
    /// QUIC transport; used for the local address and for sending
    /// exchange/response messages to peers.
    quic_channel: Arc<QuicChannel>,
    /// Erasure-coding shard store backing the storage messages.
    ec_store: Arc<DistributedErasureCodingStore>,
    /// Inbound queue of wrapped grid messages consumed by `run`.
    receiver: mpsc::Receiver<GridMessageWrapper>,
    /// (execution_id, exchange_id) -> sender that feeds ExchangeData batches
    /// into the local consuming operator.
    query_streams: QueryStreamMap,
    /// (execution_id, stage_id, addr) -> barrier notifier, pinged on
    /// FragmentCompleted. NOTE(review): the addr element is presumably the
    /// worker's address, but the completion handler matches only on the first
    /// two fields — confirm against the barrier's registrar.
    stage_barriers: Arc<DashMap<(String, usize, std::net::SocketAddr), mpsc::Sender<()>>>,
    /// Worker-side local execution engine (used when this node is a worker)
    local_executor: Option<Arc<LocalExecutor>>,
    /// Identifier of this node
    node_id: u32,
}
29
30impl GridManager {
    /// Creates a `GridManager` with the default node id (0).
    ///
    /// Delegates to [`Self::with_node_id`]; use that constructor directly when
    /// the node identifier matters (it is stamped into outgoing ExchangeData).
    pub fn new(
        quic_channel: Arc<QuicChannel>,
        ec_store: Arc<DistributedErasureCodingStore>,
        receiver: mpsc::Receiver<GridMessageWrapper>,
    ) -> Self {
        Self::with_node_id(quic_channel, ec_store, receiver, 0)
    }
38
39    pub fn with_node_id(
40        quic_channel: Arc<QuicChannel>,
41        ec_store: Arc<DistributedErasureCodingStore>,
42        receiver: mpsc::Receiver<GridMessageWrapper>,
43        node_id: u32,
44    ) -> Self {
45        Self {
46            quic_channel,
47            ec_store,
48            receiver,
49            query_streams: Arc::new(DashMap::new()),
50            stage_barriers: Arc::new(DashMap::new()),
51            local_executor: None,
52            node_id,
53        }
54    }
55
56    /// 워커 노드 로컬 실행기 설정
57    pub fn with_local_executor(mut self, executor: Arc<LocalExecutor>) -> Self {
58        self.local_executor = Some(executor);
59        self
60    }
61
62    pub fn get_query_streams(&self) -> QueryStreamMap {
63        Arc::clone(&self.query_streams)
64    }
65
66    pub fn get_stage_barriers(
67        &self,
68    ) -> Arc<DashMap<(String, usize, std::net::SocketAddr), mpsc::Sender<()>>> {
69        Arc::clone(&self.stage_barriers)
70    }
71
    /// Starts the receive loop, consuming the manager.
    ///
    /// Pulls wrapped messages off the inbound channel until every sender is
    /// dropped. Each message is handled in its own spawned task so one slow
    /// handler cannot stall the loop; handler errors are logged, not propagated.
    pub async fn run(mut self) {
        info!(
            "GridManager receiver loop started on {}",
            self.quic_channel.local_addr
        );

        while let Some(wrapper) = self.receiver.recv().await {
            // Clone the shared handles per message so the spawned task owns them.
            let ec_store = Arc::clone(&self.ec_store);
            let query_streams = Arc::clone(&self.query_streams);
            let stage_barriers = Arc::clone(&self.stage_barriers);
            let quic_channel = Arc::clone(&self.quic_channel);
            let local_executor = self.local_executor.clone();
            let node_id = self.node_id;

            // Handle each message asynchronously to avoid a bottleneck.
            tokio::spawn(async move {
                if let Err(e) = Self::handle_message(
                    ec_store,
                    query_streams,
                    stage_barriers,
                    quic_channel,
                    local_executor,
                    node_id,
                    wrapper,
                )
                .await
                {
                    error!("Error handling GridMessage: {:?}", e);
                }
            });
        }

        info!("GridManager receiver loop terminated");
    }
107
108    /// 메시지 종류별 분기 처리
109    async fn handle_message(
110        ec_store: Arc<DistributedErasureCodingStore>,
111        query_streams: QueryStreamMap,
112        stage_barriers: Arc<DashMap<(String, usize, std::net::SocketAddr), mpsc::Sender<()>>>,
113        quic_channel: Arc<QuicChannel>,
114        local_executor: Option<Arc<LocalExecutor>>,
115        node_id: u32,
116        wrapper: GridMessageWrapper,
117    ) -> DbxResult<()> {
118        let GridMessageWrapper { msg, mut stream } = wrapper;
119
120        match msg {
121            GridMessage::Storage(storage_msg) => {
122                Self::handle_storage_message(ec_store, storage_msg, &mut stream).await
123            }
124            GridMessage::Query(query_msg) => {
125                Self::handle_query_message(
126                    query_streams,
127                    stage_barriers,
128                    quic_channel,
129                    local_executor,
130                    node_id,
131                    query_msg,
132                )
133                .await
134            }
135            GridMessage::Lock(_) => {
136                warn!("LockMessage received but not implemented yet");
137                Ok(())
138            }
139            GridMessage::Replication(_) => {
140                warn!("ReplicationMessage received but not implemented yet");
141                Ok(())
142            }
143        }
144    }
145
    /// Handles query (streaming) messages.
    ///
    /// - `ExecuteFragment`: worker side — deserializes the physical plans and
    ///   spawns their execution, streaming results back to the coordinator.
    /// - `FragmentCompleted`: coordinator side — releases stage barriers.
    /// - `ExchangeData`: pushes incoming batch bytes into the registered
    ///   per-(execution, exchange) stream.
    async fn handle_query_message(
        query_streams: QueryStreamMap,
        stage_barriers: Arc<DashMap<(String, usize, std::net::SocketAddr), mpsc::Sender<()>>>,
        quic_channel: Arc<QuicChannel>,
        local_executor: Option<Arc<LocalExecutor>>,
        node_id: u32,
        msg: QueryMessage,
    ) -> DbxResult<()> {
        match msg {
            QueryMessage::ExecuteFragment {
                execution_id,
                stage_id,
                plans_bytes,
                coordinator_addr,
            } => {
                info!(
                    "ExecuteFragment received for ID: {}, Stage: {} from coordinator: {}",
                    execution_id, stage_id, coordinator_addr
                );

                // Only worker nodes carry a LocalExecutor; silently ignore otherwise.
                let executor = match local_executor {
                    Some(e) => e,
                    None => {
                        warn!(
                            "ExecuteFragment received but no LocalExecutor configured — ignoring"
                        );
                        return Ok(());
                    }
                };

                // Parse the coordinator address.
                let coord_addr: std::net::SocketAddr = match coordinator_addr.parse() {
                    Ok(a) => a,
                    Err(e) => {
                        return Err(DbxError::Network(format!(
                            "Invalid coordinator addr: {}",
                            e
                        )));
                    }
                };

                let exec_id = execution_id.clone();
                let query_streams = Arc::clone(&query_streams);

                // Deserialize each plan; a single bad plan aborts the whole fragment.
                let mut plans: Vec<PhysicalPlan> = Vec::new();
                for bytes in plans_bytes {
                    let plan = bincode::deserialize(&bytes)
                        .map_err(|e| DbxError::Serialization(e.to_string()))?;
                    plans.push(plan);
                }

                let quic_master = Arc::clone(&quic_channel);
                tokio::spawn(async move {
                    info!(
                        "Worker spawning execution for exec_id: {}, stage_id: {}",
                        exec_id, stage_id
                    );

                    let mut join_set = tokio::task::JoinSet::new();

                    for worker_plan in plans {
                        let executor_ref = Arc::clone(&executor);
                        let quic_ref = Arc::clone(&quic_master);
                        let exec_id_ref = exec_id.clone();
                        let q_streams = Arc::clone(&query_streams);

                        join_set.spawn(async move {
                            // 1. Hand the CPU-bound query execution off to spawn_blocking.
                            let (batches, channels) = match tokio::task::spawn_blocking(move || {
                                let mut chs = crate::sql::executor::local_executor::DistributedChannels::default();
                                let b = executor_ref.execute_collect_distributed(&worker_plan, &mut chs)?;
                                Ok::<(Vec<arrow::array::RecordBatch>, _), DbxError>((b, chs))
                            }).await {
                                Ok(Ok(res)) => res,
                                Ok(Err(e)) => {
                                    // Execution failed: still send EOF so the coordinator
                                    // does not wait forever on exchange 0.
                                    error!("Worker execution error for {}: {:?}", exec_id_ref, e);
                                    let eof_msg = GridMessage::Query(QueryMessage::ExchangeData {
                                        execution_id: exec_id_ref.clone(),
                                        exchange_id: 0,
                                        node_id,
                                        is_eof: true,
                                        batch_data: vec![],
                                    });
                                    let _ = quic_ref.send_message(coord_addr, eof_msg).await;
                                    return;
                                }
                                Err(e) => {
                                    // NOTE(review): on a spawn_blocking panic no EOF is sent
                                    // to the coordinator — confirm the coordinator has a
                                    // timeout for this path.
                                    error!("Worker spawn_blocking panic: {:?}", e);
                                    return;
                                }
                            };

                            // 2. Register the manually created receive channels (tx) in the
                            //    DashMap (for GridExchange consumption).
                            for (e_id, tx) in channels.exchanges {
                                q_streams.insert((exec_id_ref.clone(), e_id), tx);
                            }

                            // 3. Group the ShuffleWriter outgoing channels (rx) per target
                            //    and spawn one send task per (exchange, target).
                            let mut shuffle_join_set = tokio::task::JoinSet::new();
                            for (e_id, receivers) in channels.shuffles {
                                for (target_addr, mut rx) in receivers {
                                    let quic_sub = Arc::clone(&quic_ref);
                                    let exec_sub = exec_id_ref.clone();
                                    shuffle_join_set.spawn(async move {
                                        // NOTE(review): the loop stops on Err or Ok(None) as
                                        // well as channel close, then falls through to EOF —
                                        // presumably Err/None also mean end-of-stream; confirm
                                        // with the ShuffleWriter producer.
                                        while let Some(Ok(Some(batch_bytes))) = rx.recv().await {
                                            let msg = GridMessage::Query(QueryMessage::ExchangeData {
                                                execution_id: exec_sub.clone(),
                                                exchange_id: e_id,
                                                node_id,
                                                is_eof: false,
                                                batch_data: batch_bytes,
                                            });
                                            let _ = quic_sub.send_message(target_addr, msg).await;
                                        }
                                        // Terminate the remote stream with an explicit EOF.
                                        let eof_msg = GridMessage::Query(QueryMessage::ExchangeData {
                                            execution_id: exec_sub,
                                            exchange_id: e_id,
                                            node_id,
                                            is_eof: true,
                                            batch_data: vec![],
                                        });
                                        let _ = quic_sub.send_message(target_addr, eof_msg).await;
                                    });
                                }
                            }

                            // 4. Stream the top-level returned RecordBatches (e.g. distributed
                            //    aggregation results) to the coordinator on exchange 0.
                            for batch in batches {
                                // Batches that fail IPC serialization are skipped silently.
                                let ipc_bytes = match crate::grid::protocol::serialize_batch_to_ipc(&batch) {
                                    Ok(b) => b,
                                    Err(_) => continue,
                                };
                                let msg = GridMessage::Query(QueryMessage::ExchangeData {
                                    execution_id: exec_id_ref.clone(),
                                    exchange_id: 0,
                                    node_id,
                                    is_eof: false,
                                    batch_data: ipc_bytes,
                                });
                                let _ = quic_ref.send_message(coord_addr, msg).await;
                            }

                            // EOF for the coordinator-facing exchange 0 stream.
                            let _ = quic_ref.send_message(coord_addr, GridMessage::Query(QueryMessage::ExchangeData {
                                execution_id: exec_id_ref.clone(),
                                exchange_id: 0,
                                node_id,
                                is_eof: true,
                                batch_data: vec![],
                            })).await;

                            // Wait until every background shuffle send task has drained.
                            while shuffle_join_set.join_next().await.is_some() {}
                        });
                    }

                    // Wait until all per-plan tasks have finished.
                    while join_set.join_next().await.is_some() {}

                    // Tell the coordinator that every fragment of this stage is done.
                    let complete_msg = GridMessage::Query(QueryMessage::FragmentCompleted {
                        execution_id: exec_id.clone(),
                        stage_id,
                    });
                    let _ = quic_master.send_message(coord_addr, complete_msg).await;
                    info!(
                        "Worker completed all plans for exec_id: {}, stage_id: {}",
                        exec_id, stage_id
                    );
                });

                Ok(())
            }
            QueryMessage::FragmentCompleted {
                execution_id,
                stage_id,
            } => {
                let barriers = stage_barriers;
                // Collect matching keys first so the read iteration is finished
                // before taking mutable guards below (avoids DashMap deadlock).
                let matched_keys: Vec<_> = barriers
                    .iter()
                    .filter(|entry| entry.key().0 == execution_id && entry.key().1 == stage_id)
                    .map(|entry| entry.key().clone())
                    .collect();

                for key in matched_keys {
                    if let Some(sender) = barriers.get_mut(&key) {
                        // NOTE(review): try_send silently drops the signal when the
                        // barrier channel is full or closed — presumably one pending
                        // signal is enough; confirm with the barrier consumer.
                        let _ = sender.try_send(());
                    }
                }
                Ok(())
            }
            QueryMessage::ExchangeData {
                execution_id,
                exchange_id,
                node_id: _,
                is_eof,
                batch_data,
            } => {
                // Push the data into the registered operator queue.
                // Clone the Sender so the DashMap guard is released before awaiting.
                let sender_opt = query_streams
                    .get(&(execution_id.clone(), exchange_id))
                    .map(|kv| kv.value().clone());

                if let Some(sender) = sender_opt {
                    if is_eof {
                        // EOF is signalled downstream as Ok(None).
                        let _ = sender.send(Ok(None)).await;
                    } else {
                        let _ = sender.send(Ok(Some(batch_data))).await;
                    }
                } else {
                    // No stream registered (yet?) for this key; data is dropped.
                    warn!("ExchangeData for unknown execution_id: {}", execution_id);
                }
                Ok(())
            }
        }
    }
365
    /// Handles storage (EC shard) messages.
    ///
    /// `StoreShard` persists a shard locally; `FetchShard` reads a shard and,
    /// when a bidirectional stream is attached, replies with a `ShardResponse`.
    async fn handle_storage_message(
        ec_store: Arc<DistributedErasureCodingStore>,
        msg: StorageMessage,
        stream: &mut Option<s2n_quic::stream::BidirectionalStream>,
    ) -> DbxResult<()> {
        match msg {
            StorageMessage::StoreShard {
                key,
                shard_id,
                data,
            } => {
                info!("Storing shard {}:{} locally", key, shard_id);
                ec_store.local_store_shard(&key, shard_id, &data)?;
                Ok(())
            }
            StorageMessage::FetchShard { key, shard_id } => {
                info!("Fetching shard {}:{} for remote request", key, shard_id);
                // NOTE(review): if this read fails, `?` returns before any reply
                // is written to the stream, so the remote requester receives no
                // response — confirm the requesting side has a timeout.
                let shard_data = ec_store.local_fetch_shard(&key, shard_id)?;

                // Send the response on the same bidirectional stream, if present.
                if let Some(s) = stream {
                    let reply = GridMessage::Storage(StorageMessage::ShardResponse {
                        key: key.clone(),
                        shard_id,
                        data: shard_data,
                    });
                    ::tracing::debug!(
                        "Sending ShardResponse for {}:{} on stream...",
                        key,
                        shard_id
                    );
                    // Send failures are logged but not propagated: the shard read
                    // itself succeeded and there is no retry path here.
                    if let Err(e) = QuicChannel::send_response(s, reply).await {
                        ::tracing::error!(
                            "Failed to send ShardResponse for {}:{}: {:?}",
                            key,
                            shard_id,
                            e
                        );
                    } else {
                        ::tracing::debug!(
                            "Successfully sent ShardResponse for {}:{}",
                            key,
                            shard_id
                        );
                    }
                }
                Ok(())
            }
            StorageMessage::ShardResponse { .. } => {
                // ShardResponse is normally consumed directly by
                // send_request_and_wait; reaching the main handler loop means it
                // is unexpected, so log and ignore.
                warn!("Received unexpected ShardResponse in main handler loop");
                Ok(())
            }
        }
    }
423}