1use std::collections::VecDeque;
8use std::path::PathBuf;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::Duration;
12
13use tokio::sync::{oneshot, Mutex};
14use tracing::info;
15
16use apiary_core::config::NodeConfig;
17use apiary_core::error::ApiaryError;
18use apiary_core::types::{BeeId, TaskId};
19use apiary_core::Result;
20
/// Lifecycle state of a single worker bee.
#[derive(Debug, Clone, PartialEq)]
pub enum BeeState {
    /// Ready to accept a new task.
    Idle,
    /// Currently executing the task with the given id.
    Busy(TaskId),
}
29
/// Per-bee resource envelope: memory accounting plus an isolated scratch
/// directory on disk. One chamber belongs to exactly one `Bee`.
pub struct MasonChamber {
    /// Owning bee's identifier (used in error reports).
    pub bee_id: BeeId,
    /// Hard memory budget in bytes; `request_memory` refuses to go past it.
    pub memory_budget: u64,
    /// Bytes currently reserved. Atomic so tasks can account without a lock.
    pub memory_used: AtomicU64,
    /// Private on-disk scratch space; wiped and recreated by `reset`.
    pub scratch_dir: PathBuf,
    /// Per-task wall-clock limit associated with this chamber.
    // NOTE(review): not read anywhere in this file — the pool uses its own
    // `default_timeout`; confirm whether this field is consumed elsewhere.
    pub task_timeout: Duration,
}
43
44impl MasonChamber {
45 pub fn new(
47 bee_id: BeeId,
48 memory_budget: u64,
49 scratch_dir: PathBuf,
50 task_timeout: Duration,
51 ) -> Self {
52 Self {
53 bee_id,
54 memory_budget,
55 memory_used: AtomicU64::new(0),
56 scratch_dir,
57 task_timeout,
58 }
59 }
60
61 pub fn request_memory(&self, bytes: u64) -> Result<()> {
63 let current = self.memory_used.load(Ordering::Relaxed);
64 if current + bytes > self.memory_budget {
65 return Err(ApiaryError::MemoryExceeded {
66 bee_id: self.bee_id.clone(),
67 budget: self.memory_budget,
68 requested: bytes,
69 });
70 }
71 self.memory_used.fetch_add(bytes, Ordering::Relaxed);
72 Ok(())
73 }
74
75 pub fn release_memory(&self, bytes: u64) {
77 self.memory_used.fetch_sub(bytes, Ordering::Relaxed);
78 }
79
80 pub fn utilisation(&self) -> f64 {
82 self.memory_used.load(Ordering::Relaxed) as f64 / self.memory_budget as f64
83 }
84
85 pub fn reset(&self) {
87 self.memory_used.store(0, Ordering::Relaxed);
88 if self.scratch_dir.exists() {
89 let _ = std::fs::remove_dir_all(&self.scratch_dir);
90 }
91 let _ = std::fs::create_dir_all(&self.scratch_dir);
92 }
93}
94
/// One worker: a stable identifier, its current state, and its resource
/// chamber.
pub struct Bee {
    /// Stable identifier, e.g. "bee-0".
    pub id: BeeId,
    /// Idle/Busy state. Guarded by an async mutex because it is read and
    /// written from spawned tokio tasks as well as pool methods.
    pub state: Mutex<BeeState>,
    /// Memory accounting and scratch-directory management for this bee.
    pub chamber: MasonChamber,
}
104
/// Point-in-time, stringly-rendered snapshot of one bee, as reported by
/// `BeePool::status` (suitable for logging or an admin endpoint).
#[derive(Debug, Clone)]
pub struct BeeStatus {
    /// Rendered bee identifier.
    pub bee_id: String,
    /// "idle" or "busy(<task-id>)".
    pub state: String,
    /// Bytes currently reserved in the bee's chamber.
    pub memory_used: u64,
    /// The bee's memory budget in bytes.
    pub memory_budget: u64,
}
113
/// A one-shot unit of work, run on a blocking thread, producing Arrow record
/// batches or an `ApiaryError`. Boxed so directly-submitted and queued tasks
/// share one concrete type.
type TaskFn = Box<
    dyn FnOnce() -> std::result::Result<Vec<arrow::record_batch::RecordBatch>, ApiaryError>
        + Send
        + 'static,
>;
120
/// A task parked because every bee was busy at submission time, together with
/// the channel through which its eventual result is delivered to the caller.
struct QueuedTask {
    task_id: TaskId,
    func: TaskFn,
    // Completion side of the oneshot returned (indirectly) by `submit`.
    tx: oneshot::Sender<std::result::Result<Vec<arrow::record_batch::RecordBatch>, ApiaryError>>,
}
127
/// Fixed-size pool of bees plus a FIFO overflow queue for tasks that arrive
/// while every bee is busy.
pub struct BeePool {
    /// One bee per configured core; created once, never resized.
    bees: Vec<Arc<Bee>>,
    /// Tasks waiting for a free bee, drained one at a time as tasks finish.
    queue: Arc<Mutex<VecDeque<QueuedTask>>>,
    /// Wall-clock limit applied to every task.
    default_timeout: Duration,
}
134
/// Default wall-clock limit applied to every task unless the pool's
/// `default_timeout` is overridden.
const DEFAULT_TASK_TIMEOUT: Duration = Duration::from_secs(30);
137
138impl BeePool {
139 pub fn new(config: &NodeConfig) -> Self {
141 let mut bees = Vec::with_capacity(config.cores);
142 for i in 0..config.cores {
143 let bee_id = BeeId::new(format!("bee-{i}"));
144 let scratch_dir = config.cache_dir.join("scratch").join(format!("bee_{i}"));
145 let _ = std::fs::create_dir_all(&scratch_dir);
146 let chamber = MasonChamber::new(
147 bee_id.clone(),
148 config.memory_per_bee,
149 scratch_dir,
150 DEFAULT_TASK_TIMEOUT,
151 );
152 bees.push(Arc::new(Bee {
153 id: bee_id,
154 state: Mutex::new(BeeState::Idle),
155 chamber,
156 }));
157 }
158 info!(
159 bees = config.cores,
160 memory_per_bee = config.memory_per_bee,
161 "BeePool created"
162 );
163 Self {
164 bees,
165 queue: Arc::new(Mutex::new(VecDeque::new())),
166 default_timeout: DEFAULT_TASK_TIMEOUT,
167 }
168 }
169
170 pub fn bee_count(&self) -> usize {
172 self.bees.len()
173 }
174
175 pub async fn status(&self) -> Vec<BeeStatus> {
177 let mut result = Vec::with_capacity(self.bees.len());
178 for bee in &self.bees {
179 let state = bee.state.lock().await;
180 result.push(BeeStatus {
181 bee_id: bee.id.to_string(),
182 state: match &*state {
183 BeeState::Idle => "idle".to_string(),
184 BeeState::Busy(tid) => format!("busy({})", tid),
185 },
186 memory_used: bee.chamber.memory_used.load(Ordering::Relaxed),
187 memory_budget: bee.chamber.memory_budget,
188 });
189 }
190 result
191 }
192
193 pub async fn queue_size(&self) -> usize {
195 let q = self.queue.lock().await;
196 q.len()
197 }
198
199 pub async fn busy_count(&self) -> usize {
201 let mut count = 0;
202 for bee in &self.bees {
203 let state = bee.state.lock().await;
204 if *state != BeeState::Idle {
205 count += 1;
206 }
207 }
208 count
209 }
210
211 pub fn avg_memory_utilisation(&self) -> f64 {
213 if self.bees.is_empty() {
214 return 0.0;
215 }
216 let sum: f64 = self.bees.iter().map(|b| b.chamber.utilisation()).sum();
217 sum / self.bees.len() as f64
218 }
219
220 pub async fn submit<F>(
226 &self,
227 func: F,
228 ) -> tokio::task::JoinHandle<
229 std::result::Result<Vec<arrow::record_batch::RecordBatch>, ApiaryError>,
230 >
231 where
232 F: FnOnce() -> std::result::Result<Vec<arrow::record_batch::RecordBatch>, ApiaryError>
233 + Send
234 + 'static,
235 {
236 let task_id = TaskId::generate();
237
238 let total = self.bees.len() as f64;
240 if total > 0.0 {
241 let busy = self.busy_count().await as f64;
242 let cpu_util = busy / total;
243 let mem_pressure = self.avg_memory_utilisation();
244 let queue_size = {
245 let q = self.queue.lock().await;
246 q.len() as f64
247 };
248 let queue_pressure = (queue_size / (total * 2.0)).min(1.0);
250 let temperature = 0.4 * cpu_util + 0.4 * mem_pressure + 0.2 * queue_pressure;
251 if temperature > 0.95 {
252 return tokio::task::spawn(async {
253 Err(ApiaryError::Internal {
254 message: "Colony temperature critical (> 0.95). Task rejected to protect system stability.".to_string(),
255 })
256 });
257 }
258 }
259
260 for bee in &self.bees {
262 let mut state = bee.state.lock().await;
263 if *state == BeeState::Idle {
264 *state = BeeState::Busy(task_id.clone());
265 drop(state);
266 return self.spawn_on_bee(Arc::clone(bee), task_id, Box::new(func));
267 }
268 }
269
270 let (tx, rx) = oneshot::channel();
272 {
273 let mut q = self.queue.lock().await;
274 q.push_back(QueuedTask {
275 task_id,
276 func: Box::new(func),
277 tx,
278 });
279 }
280 info!("All bees busy, task queued");
281
282 tokio::task::spawn(async move {
284 rx.await.unwrap_or_else(|_| {
285 Err(ApiaryError::Internal {
286 message: "Task channel closed before result".to_string(),
287 })
288 })
289 })
290 }
291
292 fn spawn_on_bee(
294 &self,
295 bee: Arc<Bee>,
296 task_id: TaskId,
297 func: TaskFn,
298 ) -> tokio::task::JoinHandle<
299 std::result::Result<Vec<arrow::record_batch::RecordBatch>, ApiaryError>,
300 > {
301 let timeout = self.default_timeout;
302 let queue = Arc::clone(&self.queue);
303 let bees = self.bees.clone();
304 let default_timeout = self.default_timeout;
305
306 tokio::task::spawn(async move {
307 let result = tokio::time::timeout(timeout, tokio::task::spawn_blocking(func)).await;
309
310 bee.chamber.reset();
312
313 {
315 let mut state = bee.state.lock().await;
316 *state = BeeState::Idle;
317 }
318
319 drain_queue_once(queue, bees, default_timeout).await;
321
322 match result {
323 Ok(Ok(task_result)) => task_result,
324 Ok(Err(join_err)) => Err(ApiaryError::Internal {
325 message: format!("Task panicked: {join_err}"),
326 }),
327 Err(_) => Err(ApiaryError::TaskTimeout {
328 message: format!("Task {task_id} exceeded {timeout:?} timeout"),
329 }),
330 }
331 })
332 }
333}
334
335fn drain_queue_once(
338 queue: Arc<Mutex<VecDeque<QueuedTask>>>,
339 bees: Vec<Arc<Bee>>,
340 timeout: Duration,
341) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>> {
342 Box::pin(async move {
343 let mut idle_bee = None;
345 for bee in &bees {
346 let state = bee.state.lock().await;
347 if *state == BeeState::Idle {
348 idle_bee = Some(Arc::clone(bee));
349 break;
350 }
351 }
352
353 let Some(bee) = idle_bee else {
354 return;
355 };
356
357 let queued = {
359 let mut q = queue.lock().await;
360 q.pop_front()
361 };
362
363 let Some(queued) = queued else {
364 return;
365 };
366
367 {
369 let mut state = bee.state.lock().await;
370 *state = BeeState::Busy(queued.task_id.clone());
371 }
372
373 let task_id = queued.task_id;
374 let func = queued.func;
375 let tx = queued.tx;
376 let queue_clone = Arc::clone(&queue);
377 let bees_clone = bees.clone();
378
379 tokio::task::spawn(async move {
381 let result = tokio::time::timeout(timeout, tokio::task::spawn_blocking(func)).await;
382
383 bee.chamber.reset();
384 {
385 let mut state = bee.state.lock().await;
386 *state = BeeState::Idle;
387 }
388
389 let final_result = match result {
390 Ok(Ok(task_result)) => task_result,
391 Ok(Err(join_err)) => Err(ApiaryError::Internal {
392 message: format!("Task panicked: {join_err}"),
393 }),
394 Err(_) => Err(ApiaryError::TaskTimeout {
395 message: format!("Task {task_id} exceeded {timeout:?} timeout"),
396 }),
397 };
398
399 let _ = tx.send(final_result);
400
401 drain_queue_once(queue_clone, bees_clone, timeout).await;
403 });
404 })
405}
406
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicUsize;

    /// Builds a NodeConfig with `cores` bees, a 1 MiB per-bee budget, and a
    /// temp cache dir. The TempDir is returned so it outlives the test.
    fn test_config(cores: usize) -> (NodeConfig, tempfile::TempDir) {
        let tmp = tempfile::TempDir::new().unwrap();
        let mut config = NodeConfig::detect("local://test");
        config.cores = cores;
        config.memory_per_bee = 1024 * 1024;
        config.cache_dir = tmp.path().to_path_buf();
        (config, tmp)
    }

    #[tokio::test]
    async fn test_bee_pool_creates_correct_number_of_bees() {
        let (config, _tmp) = test_config(4);
        let pool = BeePool::new(&config);
        assert_eq!(pool.bee_count(), 4);
        let status = pool.status().await;
        assert_eq!(status.len(), 4);
        for s in &status {
            assert_eq!(s.state, "idle");
            assert_eq!(s.memory_budget, 1024 * 1024);
            assert_eq!(s.memory_used, 0);
        }
    }

    #[tokio::test]
    async fn test_bee_pool_executes_task() {
        let (config, _tmp) = test_config(2);
        let pool = BeePool::new(&config);

        let handle = pool.submit(|| Ok(vec![])).await;
        let result = handle.await.unwrap();
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_bee_returns_to_idle_after_task() {
        let (config, _tmp) = test_config(1);
        let pool = BeePool::new(&config);

        let handle = pool.submit(|| Ok(vec![])).await;
        handle.await.unwrap().unwrap();

        // Give the completion handler time to flip the bee back to Idle.
        tokio::time::sleep(Duration::from_millis(50)).await;

        let status = pool.status().await;
        assert_eq!(status[0].state, "idle");
    }

    #[tokio::test]
    async fn test_memory_enforcement() {
        let bee_id = BeeId::new("test-bee");
        let tmp = tempfile::TempDir::new().unwrap();
        let chamber = MasonChamber::new(
            bee_id.clone(),
            1000,
            tmp.path().to_path_buf(),
            Duration::from_secs(10),
        );

        assert!(chamber.request_memory(500).is_ok());
        assert_eq!(chamber.memory_used.load(Ordering::Relaxed), 500);

        assert!(chamber.request_memory(400).is_ok());
        assert_eq!(chamber.memory_used.load(Ordering::Relaxed), 900);

        // 900 + 200 > 1000: must be rejected without changing usage.
        let err = chamber.request_memory(200);
        assert!(err.is_err());
        match err.unwrap_err() {
            ApiaryError::MemoryExceeded {
                bee_id: id,
                budget,
                requested,
            } => {
                assert_eq!(id, bee_id);
                assert_eq!(budget, 1000);
                assert_eq!(requested, 200);
            }
            other => panic!("Expected MemoryExceeded, got: {:?}", other),
        }

        assert_eq!(chamber.memory_used.load(Ordering::Relaxed), 900);

        chamber.release_memory(400);
        assert_eq!(chamber.memory_used.load(Ordering::Relaxed), 500);
        assert!(chamber.request_memory(200).is_ok());
    }

    #[tokio::test]
    async fn test_memory_exceeded_does_not_affect_other_bees() {
        let (config, _tmp) = test_config(2);
        let pool = BeePool::new(&config);

        let budget = config.memory_per_bee;

        let handle = pool
            .submit(move || {
                Err(ApiaryError::MemoryExceeded {
                    bee_id: BeeId::new("bee-0"),
                    budget,
                    requested: budget + 1,
                })
            })
            .await;

        let result = handle.await.unwrap();
        assert!(result.is_err());

        // Let the failed task's completion handler free its bee.
        tokio::time::sleep(Duration::from_millis(50)).await;

        let handle2 = pool.submit(|| Ok(vec![])).await;
        let result2 = handle2.await.unwrap();
        assert!(result2.is_ok());
    }

    #[tokio::test]
    async fn test_task_timeout() {
        let (config, _tmp) = test_config(1);
        let mut pool = BeePool::new(&config);
        pool.default_timeout = Duration::from_millis(100);

        let handle = pool
            .submit(|| {
                std::thread::sleep(Duration::from_secs(5));
                Ok(vec![])
            })
            .await;

        let result = handle.await.unwrap();
        assert!(result.is_err());
        match result.unwrap_err() {
            ApiaryError::TaskTimeout { message } => {
                assert!(message.contains("timeout"), "Got: {message}");
            }
            other => panic!("Expected TaskTimeout, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_scratch_directory_isolated_and_cleaned() {
        let tmp = tempfile::TempDir::new().unwrap();
        let mut config = NodeConfig::detect("local://test");
        config.cores = 2;
        config.memory_per_bee = 1024 * 1024;
        config.cache_dir = tmp.path().to_path_buf();

        let pool = BeePool::new(&config);

        let scratch_0 = tmp.path().join("scratch").join("bee_0");
        let scratch_1 = tmp.path().join("scratch").join("bee_1");
        assert!(scratch_0.exists());
        assert!(scratch_1.exists());

        let scratch_0_clone = scratch_0.clone();
        let handle = pool
            .submit(move || {
                std::fs::write(scratch_0_clone.join("test.tmp"), b"hello").unwrap();
                Ok(vec![])
            })
            .await;
        handle.await.unwrap().unwrap();

        // Chamber reset (wipe + recreate) runs in the completion handler.
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert!(scratch_0.exists());
        let entries: Vec<_> = std::fs::read_dir(&scratch_0).unwrap().collect();
        assert!(
            entries.is_empty(),
            "Scratch dir should be cleaned after task"
        );
    }

    #[tokio::test]
    async fn test_concurrent_tasks_on_separate_bees() {
        let (config, _tmp) = test_config(3);
        let pool = Arc::new(BeePool::new(&config));

        let counter = Arc::new(AtomicUsize::new(0));

        let mut handles = vec![];
        for _ in 0..3 {
            let c = counter.clone();
            let h = pool
                .submit(move || {
                    c.fetch_add(1, Ordering::Relaxed);
                    std::thread::sleep(Duration::from_millis(50));
                    Ok(vec![])
                })
                .await;
            handles.push(h);
        }

        for h in handles {
            h.await.unwrap().unwrap();
        }

        assert_eq!(counter.load(Ordering::Relaxed), 3);
    }

    #[tokio::test]
    async fn test_tasks_queue_when_all_bees_busy() {
        let (config, _tmp) = test_config(1);
        let pool = Arc::new(BeePool::new(&config));

        let counter = Arc::new(AtomicUsize::new(0));

        // With a single bee, tasks 2 and 3 must go through the queue.
        let mut handles = vec![];
        for _ in 0..3 {
            let c = counter.clone();
            let h = pool
                .submit(move || {
                    c.fetch_add(1, Ordering::Relaxed);
                    std::thread::sleep(Duration::from_millis(20));
                    Ok(vec![])
                })
                .await;
            handles.push(h);
        }

        for h in handles {
            h.await.unwrap().unwrap();
        }

        assert_eq!(counter.load(Ordering::Relaxed), 3);
    }

    #[tokio::test]
    async fn test_chamber_utilisation() {
        let bee_id = BeeId::new("test-bee");
        let tmp = tempfile::TempDir::new().unwrap();
        let chamber = MasonChamber::new(
            bee_id,
            1000,
            tmp.path().to_path_buf(),
            Duration::from_secs(10),
        );

        assert!((chamber.utilisation() - 0.0).abs() < f64::EPSILON);
        chamber.request_memory(500).unwrap();
        assert!((chamber.utilisation() - 0.5).abs() < f64::EPSILON);
        chamber.request_memory(500).unwrap();
        assert!((chamber.utilisation() - 1.0).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn test_chamber_reset() {
        let bee_id = BeeId::new("test-bee");
        let tmp = tempfile::TempDir::new().unwrap();
        let scratch = tmp.path().join("scratch");
        std::fs::create_dir_all(&scratch).unwrap();
        std::fs::write(scratch.join("leftover.tmp"), b"data").unwrap();

        let chamber = MasonChamber::new(bee_id, 1000, scratch.clone(), Duration::from_secs(10));
        chamber.request_memory(800).unwrap();
        assert_eq!(chamber.memory_used.load(Ordering::Relaxed), 800);

        chamber.reset();
        assert_eq!(chamber.memory_used.load(Ordering::Relaxed), 0);
        assert!(scratch.exists());
        let entries: Vec<_> = std::fs::read_dir(&scratch).unwrap().collect();
        assert!(entries.is_empty());
    }

    #[tokio::test]
    async fn test_busy_count_and_avg_memory_utilisation() {
        let (config, _tmp) = test_config(3);
        let pool = BeePool::new(&config);

        assert_eq!(pool.busy_count().await, 0);
        assert!((pool.avg_memory_utilisation() - 0.0).abs() < f64::EPSILON);

        // Park a task on a bee until we release it via the oneshot.
        let (tx, rx) = tokio::sync::oneshot::channel::<()>();
        let handle = pool
            .submit(move || {
                let _ = rx.blocking_recv();
                Ok(vec![])
            })
            .await;

        tokio::time::sleep(Duration::from_millis(50)).await;

        assert_eq!(pool.busy_count().await, 1);

        let _ = tx.send(());
        handle.await.unwrap().unwrap();
        tokio::time::sleep(Duration::from_millis(50)).await;

        assert_eq!(pool.busy_count().await, 0);
    }
}