1use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::{SystemTime, UNIX_EPOCH};
11
12use arrow::ipc::reader::FileReader;
13use arrow::ipc::writer::FileWriter;
14use arrow::record_batch::RecordBatch;
15use serde::{Deserialize, Serialize};
16use tracing::{info, warn};
17use uuid::Uuid;
18
19use apiary_core::error::ApiaryError;
20use apiary_core::storage::StorageBackend;
21use apiary_core::types::NodeId;
22use apiary_core::Result;
23
/// Liveness state of a node as seen in the world view.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum NodeState {
    /// Node is responding normally.
    Alive,
    /// Node is suspected of having failed (e.g. missed heartbeats) —
    /// NOTE(review): exact transition criteria live outside this file; confirm.
    Suspect,
    /// Node is considered gone.
    Dead,
}
34
/// A unit of scannable data ("cell") that the planner assigns to nodes.
#[derive(Clone, Debug)]
pub struct CellInfo {
    /// Key under which the cell's data lives in the storage backend.
    pub storage_key: String,
    /// Size of the cell in bytes; used for memory-budget planning.
    pub bytes: u64,
    /// Partition column/value pairs describing this cell —
    /// NOTE(review): assumed (column, value) ordering; confirm against writer.
    pub partition: Vec<(String, String)>,
}
45
/// Planner-visible snapshot of one node's identity, capacity, and load.
#[derive(Clone, Debug)]
pub struct NodeInfo {
    /// Unique identifier of the node.
    pub node_id: NodeId,
    /// Current liveness state.
    pub state: NodeState,
    /// Number of CPU cores on the node.
    pub cores: usize,
    /// Total memory on the node, in bytes.
    pub memory_bytes: u64,
    /// Memory budget for a single bee (worker); the planner uses this as the
    /// per-assignment size limit (see `leafcutter_split_assignments`).
    pub memory_per_bee: u64,
    /// Preferred cell size in bytes — NOTE(review): not read by this module;
    /// presumably used by the ingestion/compaction path.
    pub target_cell_size: u64,
    /// Total number of bees (workers) on the node.
    pub bees_total: usize,
    /// Bees currently executing work.
    pub bees_busy: usize,
    /// Bees available to take new work; nodes with zero idle bees are
    /// skipped during cell assignment.
    pub idle_bees: usize,
    /// Cells already cached on the node, keyed by storage key —
    /// NOTE(review): value is presumably the cached size in bytes; confirm.
    pub cached_cells: HashMap<String, u64>,
}
61
/// Output of `plan_query`: where the query's cells will be scanned.
#[derive(Clone, Debug)]
pub enum QueryPlan {
    /// Execute entirely on the local node over the given cells.
    Local { cells: Vec<CellInfo> },
    /// Fan the query out across the swarm.
    Distributed {
        /// Cells each node is responsible for scanning.
        assignments: HashMap<NodeId, Vec<CellInfo>>,
    },
}
72
/// One node's share of a distributed query.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PlannedTask {
    /// Unique identifier for this task.
    pub task_id: String,
    /// Node that should execute the task.
    pub node_id: NodeId,
    /// Storage keys of the cells the node must scan.
    pub cells: Vec<String>,
    /// SQL the node runs over its cells (currently the original SQL verbatim;
    /// see `generate_sql_fragment`).
    pub sql_fragment: String,
}
85
/// Durable description of a distributed query, written to shared storage so
/// participating nodes can discover their tasks (see `write_manifest` /
/// `read_manifest`).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct QueryManifest {
    /// Unique identifier for the query (UUID v4 string).
    pub query_id: String,
    /// The SQL as submitted by the client.
    pub original_sql: String,
    /// Per-node tasks making up the query.
    pub tasks: Vec<PlannedTask>,
    /// Optional SQL used to merge partial results on the coordinator.
    pub merge_sql: Option<String>,
    /// How long participants should wait before giving up, in seconds.
    pub timeout_secs: u64,
    /// UNIX timestamp (seconds) when the manifest was created.
    pub created_at: u64,
}
102
103pub fn plan_query(
110 cells: Vec<CellInfo>,
111 nodes: Vec<NodeInfo>,
112 local_node_id: &NodeId,
113) -> Result<QueryPlan> {
114 if nodes.is_empty() {
115 return Err(ApiaryError::Internal {
116 message: "No alive nodes in world view".into(),
117 });
118 }
119
120 let local_node = nodes
122 .iter()
123 .find(|n| &n.node_id == local_node_id)
124 .ok_or_else(|| ApiaryError::Internal {
125 message: format!("Local node {} not found in world view", local_node_id),
126 })?;
127
128 let total_size: u64 = cells.iter().map(|c| c.bytes).sum();
130
131 if total_size < local_node.memory_per_bee && local_node.idle_bees > 0 {
133 info!(
134 total_size_mb = total_size / (1024 * 1024),
135 bee_budget_mb = local_node.memory_per_bee / (1024 * 1024),
136 "Query fits in one bee, executing locally"
137 );
138 return Ok(QueryPlan::Local { cells });
139 }
140
141 if nodes.len() == 1 {
143 info!("Only one node available, executing locally");
144 return Ok(QueryPlan::Local { cells });
145 }
146
147 info!(
149 total_cells = cells.len(),
150 total_size_mb = total_size / (1024 * 1024),
151 alive_nodes = nodes.len(),
152 "Distributing query across swarm"
153 );
154
155 let assignments = assign_cells_to_nodes(cells, &nodes);
156
157 if assignments.is_empty() {
158 return Err(ApiaryError::Internal {
159 message: "Failed to assign cells to any node".into(),
160 });
161 }
162
163 Ok(QueryPlan::Distributed { assignments })
164}
165
166pub fn extract_alive_nodes<T>(
169 world_view_nodes: &HashMap<NodeId, T>,
170 node_extractor: impl Fn(&T) -> Option<NodeInfo>,
171) -> Vec<NodeInfo> {
172 world_view_nodes
173 .values()
174 .filter_map(node_extractor)
175 .collect()
176}
177
178fn assign_cells_to_nodes(
180 cells: Vec<CellInfo>,
181 nodes: &[NodeInfo],
182) -> HashMap<NodeId, Vec<CellInfo>> {
183 let mut assignments: HashMap<NodeId, Vec<CellInfo>> = HashMap::new();
184
185 for cell in cells {
186 let caching_node = nodes
188 .iter()
189 .filter(|n| n.idle_bees > 0)
190 .find(|n| n.cached_cells.contains_key(&cell.storage_key));
191
192 if let Some(node) = caching_node {
193 assignments
195 .entry(node.node_id.clone())
196 .or_default()
197 .push(cell);
198 continue;
199 }
200
201 if let Some(best_node) = nodes
203 .iter()
204 .filter(|n| n.idle_bees > 0)
205 .max_by_key(|n| n.idle_bees)
206 {
207 assignments
208 .entry(best_node.node_id.clone())
209 .or_default()
210 .push(cell);
211 }
212 }
213
214 leafcutter_split_assignments(&mut assignments, nodes);
216
217 assignments
218}
219
220fn leafcutter_split_assignments(
223 assignments: &mut HashMap<NodeId, Vec<CellInfo>>,
224 nodes: &[NodeInfo],
225) {
226 let mut overflow = Vec::new();
227
228 for (node_id, cells) in assignments.iter_mut() {
230 if let Some(node) = nodes.iter().find(|n| &n.node_id == node_id) {
231 let total: u64 = cells.iter().map(|c| c.bytes).sum();
232 if total > node.memory_per_bee && cells.len() > 1 {
233 let mut kept_size: u64 = 0;
235 let mut keep = Vec::new();
236 for cell in cells.drain(..) {
237 if kept_size + cell.bytes <= node.memory_per_bee || keep.is_empty() {
240 kept_size += cell.bytes;
241 keep.push(cell);
242 } else {
243 overflow.push(cell);
244 }
245 }
246 *cells = keep;
247 }
248 }
249 }
250
251 for cell in overflow {
253 let best = nodes
254 .iter()
255 .filter(|n| n.idle_bees > 0)
256 .filter(|n| {
257 let current: u64 = assignments
258 .get(&n.node_id)
259 .map(|c| c.iter().map(|ci| ci.bytes).sum())
260 .unwrap_or(0);
261 current + cell.bytes <= n.memory_per_bee
262 })
263 .max_by_key(|n| n.idle_bees);
264
265 if let Some(node) = best {
266 assignments
267 .entry(node.node_id.clone())
268 .or_default()
269 .push(cell);
270 } else {
271 if let Some(node) = nodes
273 .iter()
274 .filter(|n| n.idle_bees > 0)
275 .max_by_key(|n| n.idle_bees)
276 {
277 assignments
278 .entry(node.node_id.clone())
279 .or_default()
280 .push(cell);
281 }
282 }
283 }
284}
285
/// Produce the per-node SQL fragment and an optional coordinator-side merge
/// query. Currently a passthrough: every node runs the original SQL verbatim
/// and no merge step is generated, regardless of `_is_aggregation`.
pub fn generate_sql_fragment(
    original_sql: &str,
    _is_aggregation: bool,
) -> (String, Option<String>) {
    let fragment = original_sql.to_owned();
    (fragment, None)
}
298
/// Storage key under which a query's manifest JSON is kept.
pub fn manifest_path(query_id: &str) -> String {
    ["_queries/", query_id, "/manifest.json"].concat()
}
303
304pub fn partial_result_path(query_id: &str, node_id: &NodeId) -> String {
306 format!("_queries/{}/partial_{}.arrow", query_id, node_id)
307}
308
309pub fn create_manifest(
311 original_sql: &str,
312 tasks: Vec<PlannedTask>,
313 merge_sql: Option<String>,
314 timeout_secs: u64,
315) -> QueryManifest {
316 let query_id = Uuid::new_v4().to_string();
317 let created_at = SystemTime::now()
318 .duration_since(UNIX_EPOCH)
319 .unwrap()
320 .as_secs();
321
322 QueryManifest {
323 query_id,
324 original_sql: original_sql.to_string(),
325 tasks,
326 merge_sql,
327 timeout_secs,
328 created_at,
329 }
330}
331
332pub async fn write_manifest(
334 storage: &Arc<dyn StorageBackend>,
335 manifest: &QueryManifest,
336) -> Result<()> {
337 let path = manifest_path(&manifest.query_id);
338 let json = serde_json::to_vec(manifest).map_err(|e| ApiaryError::Internal {
339 message: format!("Failed to serialize manifest: {}", e),
340 })?;
341
342 storage.put(&path, json.into()).await?;
343 info!(query_id = %manifest.query_id, "Query manifest written");
344 Ok(())
345}
346
347pub async fn read_manifest(
349 storage: &Arc<dyn StorageBackend>,
350 query_id: &str,
351) -> Result<QueryManifest> {
352 let path = manifest_path(query_id);
353 let bytes = storage.get(&path).await?;
354 let manifest = serde_json::from_slice(&bytes).map_err(|e| ApiaryError::Internal {
355 message: format!("Failed to deserialize manifest: {}", e),
356 })?;
357 Ok(manifest)
358}
359
360pub async fn write_partial_result(
362 storage: &Arc<dyn StorageBackend>,
363 query_id: &str,
364 node_id: &NodeId,
365 batches: &[RecordBatch],
366) -> Result<()> {
367 if batches.is_empty() {
368 return Err(ApiaryError::Internal {
369 message: "Cannot write empty partial result".into(),
370 });
371 }
372
373 let path = partial_result_path(query_id, node_id);
374
375 let mut buf = Vec::new();
377 {
378 let mut writer = FileWriter::try_new(&mut buf, &batches[0].schema()).map_err(|e| {
379 ApiaryError::Internal {
380 message: format!("Failed to create Arrow writer: {}", e),
381 }
382 })?;
383
384 for batch in batches {
385 writer.write(batch).map_err(|e| ApiaryError::Internal {
386 message: format!("Failed to write batch: {}", e),
387 })?;
388 }
389
390 writer.finish().map_err(|e| ApiaryError::Internal {
391 message: format!("Failed to finish Arrow writer: {}", e),
392 })?;
393 }
394
395 storage.put(&path, buf.into()).await?;
396 info!(query_id = %query_id, node_id = %node_id, "Partial result written");
397 Ok(())
398}
399
400pub async fn read_partial_result(
402 storage: &Arc<dyn StorageBackend>,
403 query_id: &str,
404 node_id: &NodeId,
405) -> Result<Vec<RecordBatch>> {
406 let path = partial_result_path(query_id, node_id);
407 let bytes = storage.get(&path).await?;
408
409 let cursor = std::io::Cursor::new(bytes.to_vec());
410 let reader = FileReader::try_new(cursor, None).map_err(|e| ApiaryError::Internal {
411 message: format!("Failed to create Arrow reader: {}", e),
412 })?;
413
414 let batches: Result<Vec<_>> = reader
415 .map(|result| {
416 result.map_err(|e| ApiaryError::Internal {
417 message: format!("Failed to read batch: {}", e),
418 })
419 })
420 .collect();
421
422 batches
423}
424
425pub async fn cleanup_query(storage: &Arc<dyn StorageBackend>, query_id: &str) -> Result<()> {
427 let prefix = format!("_queries/{}/", query_id);
428 let keys = storage.list(&prefix).await?;
429
430 for key in keys {
431 if let Err(e) = storage.delete(&key).await {
432 warn!(key = %key, error = %e, "Failed to delete query file");
433 }
434 }
435
436 info!(query_id = %query_id, "Query files cleaned up");
437 Ok(())
438}
439
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `CellInfo` with no partition values.
    fn mock_cell_info(key: &str, bytes: u64) -> CellInfo {
        CellInfo {
            storage_key: key.to_string(),
            bytes,
            partition: vec![],
        }
    }

    /// Build an alive 4-core, 4-bee node with the given per-bee budget,
    /// idle-bee count, and cache contents.
    fn mock_node(
        id: &str,
        memory_per_bee: u64,
        idle_bees: usize,
        cached_cells: HashMap<String, u64>,
    ) -> NodeInfo {
        NodeInfo {
            node_id: NodeId::from(id),
            state: NodeState::Alive,
            cores: 4,
            memory_bytes: 4_000_000_000,
            memory_per_bee,
            target_cell_size: 256_000_000,
            bees_total: 4,
            bees_busy: 4 - idle_bees,
            idle_bees,
            cached_cells,
        }
    }

    #[test]
    fn test_assign_cells_prefers_cache_locality() {
        let cells = vec![
            mock_cell_info("cell1", 100_000_000),
            mock_cell_info("cell2", 100_000_000),
        ];

        // node1 caches cell1; node2 caches nothing.
        let mut cached = HashMap::new();
        cached.insert("cell1".to_string(), 100_000_000);

        let nodes = vec![
            mock_node("node1", 1_000_000_000, 4, cached),
            mock_node("node2", 1_000_000_000, 4, HashMap::new()),
        ];

        let assignments = assign_cells_to_nodes(cells, &nodes);

        // cell1 must land on node1 for cache locality; cell2 goes elsewhere.
        assert!(assignments.contains_key(&NodeId::from("node1")));
        let node1_cells = assignments.get(&NodeId::from("node1")).unwrap();
        assert_eq!(node1_cells.len(), 1);
        assert_eq!(node1_cells[0].storage_key, "cell1");
    }

    #[test]
    fn test_assign_cells_distributes_to_idle_nodes() {
        let cells = vec![
            mock_cell_info("cell1", 100_000_000),
            mock_cell_info("cell2", 100_000_000),
        ];

        // node2 has far more idle bees than node1.
        let nodes = vec![
            mock_node("node1", 1_000_000_000, 1, HashMap::new()),
            mock_node("node2", 1_000_000_000, 4, HashMap::new()),
        ];

        let assignments = assign_cells_to_nodes(cells, &nodes);

        // With no cache hits, both cells go to the most-idle node.
        assert!(assignments.contains_key(&NodeId::from("node2")));
        let node2_cells = assignments.get(&NodeId::from("node2")).unwrap();
        assert_eq!(node2_cells.len(), 2);
    }

    #[test]
    fn test_leafcutter_split_redistributes_excess() {
        let node1_id = NodeId::from("node1");
        let node2_id = NodeId::from("node2");

        // node1 starts with 300 MB of cells against a 200 MB bee budget.
        let mut assignments = HashMap::new();
        assignments.insert(
            node1_id.clone(),
            vec![
                mock_cell_info("c1", 100_000_000),
                mock_cell_info("c2", 100_000_000),
                mock_cell_info("c3", 100_000_000),
            ],
        );

        let nodes = vec![
            mock_node("node1", 200_000_000, 4, HashMap::new()),
            mock_node("node2", 200_000_000, 4, HashMap::new()),
        ];

        leafcutter_split_assignments(&mut assignments, &nodes);

        let node1_total: u64 = assignments
            .get(&node1_id)
            .map(|c| c.iter().map(|ci| ci.bytes).sum())
            .unwrap_or(0);
        assert!(
            node1_total <= 200_000_000,
            "node1 should not exceed its budget"
        );

        let node2_cells = assignments.get(&node2_id).unwrap();
        assert!(
            !node2_cells.is_empty(),
            "node2 should receive overflow cells"
        );

        // No cell may be lost during the split.
        let total_cells: usize = assignments.values().map(|c| c.len()).sum();
        assert_eq!(total_cells, 3);
    }
}