1use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::Duration;
11
12use arrow::record_batch::RecordBatch;
13use tokio::sync::RwLock;
14use tracing::info;
15
16use apiary_core::config::NodeConfig;
17use apiary_core::error::ApiaryError;
18use apiary_core::registry_manager::RegistryManager;
19use apiary_core::storage::StorageBackend;
20use apiary_core::{CellSizingPolicy, FrameSchema, LedgerAction, Result, WriteResult};
21use apiary_query::ApiaryQueryContext;
22use apiary_storage::cell_reader::CellReader;
23use apiary_storage::cell_writer::CellWriter;
24use apiary_storage::ledger::Ledger;
25use apiary_storage::local::LocalBackend;
26use apiary_storage::s3::S3Backend;
27
28use crate::bee::{BeePool, BeeStatus};
29use crate::behavioral::{AbandonmentTracker, ColonyThermometer};
30use crate::cache::CellCache;
31use crate::heartbeat::{HeartbeatWriter, NodeState, WorldView, WorldViewBuilder};
32
/// A single node in an Apiary swarm.
///
/// Owns the storage backend, the registry, the local worker ("bee") pool and
/// cell cache, and the handles to the background tasks (heartbeat writer,
/// world-view builder, distributed-query poller) spawned by [`ApiaryNode::start`].
pub struct ApiaryNode {
    /// Static node configuration (identity, core count, sizing, paths).
    pub config: NodeConfig,

    /// Shared storage backend (S3 or local filesystem, chosen from the URI).
    pub storage: Arc<dyn StorageBackend>,

    /// Catalog manager for hives/boxes/frames, persisted via `storage`.
    pub registry: Arc<RegistryManager>,

    /// SQL execution context; behind a tokio `Mutex` because queries need
    /// `&mut` access, so concurrent queries are serialized on this lock.
    pub query_ctx: Arc<tokio::sync::Mutex<ApiaryQueryContext>>,

    /// Pool of worker bees used to run CPU-bound work such as SQL queries.
    pub bee_pool: Arc<BeePool>,

    /// Local cache of cell files fetched from `storage`.
    pub cell_cache: Arc<CellCache>,

    /// Measures the colony "temperature" (load signal) from the bee pool.
    pub thermometer: ColonyThermometer,

    // Behavioral abandonment tracking. NOTE(review): not used anywhere in this
    // file — presumably consumed by callers holding the node; confirm.
    pub abandonment_tracker: Arc<AbandonmentTracker>,

    // Writes this node's periodic heartbeat; also deletes it on shutdown.
    heartbeat_writer: Arc<HeartbeatWriter>,

    // Read-mostly snapshot of the swarm, refreshed by the world-view builder task.
    world_view: Arc<RwLock<WorldView>>,

    // Held only to keep the builder alive for its background refresh task.
    #[allow(dead_code)]
    world_view_builder: Arc<WorldViewBuilder>,

    // Cancellation broadcast: sending `true` stops all background tasks.
    cancel_tx: tokio::sync::watch::Sender<bool>,
}
77
impl ApiaryNode {
    /// Boots a node: resolves the storage backend from `config.storage_uri`,
    /// loads (or creates) the registry, builds the bee pool and cell cache,
    /// writes an initial heartbeat, and spawns the heartbeat, world-view, and
    /// distributed-query background tasks.
    ///
    /// Accepted URI forms: `s3://…`, `local://<path>` (with `~/` expansion),
    /// or a bare filesystem path.
    ///
    /// # Errors
    /// Fails if the backend cannot be constructed, the registry cannot be
    /// loaded/created, the cache directory cannot be set up, or the initial
    /// heartbeat / world-view poll fails.
    pub async fn start(config: NodeConfig) -> Result<Self> {
        let storage: Arc<dyn StorageBackend> = if config.storage_uri.starts_with("s3://") {
            Arc::new(S3Backend::new(&config.storage_uri)?)
        } else {
            // Everything that is not `s3://` is treated as a local path, with
            // or without the `local://` scheme prefix.
            let path = config
                .storage_uri
                .strip_prefix("local://")
                .unwrap_or(&config.storage_uri);

            // Expand a leading `~/` (or `~\` for Windows-style paths) using
            // the user's home directory; `~` is ASCII so `path[2..]` is safe.
            let expanded = if path.starts_with("~/") || path.starts_with("~\\") {
                let home = home_dir().ok_or_else(|| ApiaryError::Config {
                    message: "Cannot determine home directory".to_string(),
                })?;
                home.join(&path[2..])
            } else {
                std::path::PathBuf::from(path)
            };

            Arc::new(LocalBackend::new(expanded).await?)
        };

        info!(
            node_id = %config.node_id,
            cores = config.cores,
            memory_mb = config.memory_bytes / (1024 * 1024),
            memory_per_bee_mb = config.memory_per_bee / (1024 * 1024),
            target_cell_size_mb = config.target_cell_size / (1024 * 1024),
            storage_uri = %config.storage_uri,
            "Apiary node started"
        );

        // First boot creates the registry; the returned snapshot is discarded
        // because the manager is queried through its own API afterwards.
        let registry = Arc::new(RegistryManager::new(Arc::clone(&storage)));
        let _ = registry.load_or_create().await?;
        info!("Registry loaded");

        let query_ctx = Arc::new(tokio::sync::Mutex::new(ApiaryQueryContext::with_node_id(
            Arc::clone(&storage),
            Arc::clone(&registry),
            config.node_id.clone(),
        )));

        let bee_pool = Arc::new(BeePool::new(&config));
        info!(bees = config.cores, "Bee pool initialized");

        let cache_dir = config.cache_dir.join("cells");
        let cell_cache =
            Arc::new(CellCache::new(cache_dir, config.max_cache_size, Arc::clone(&storage)).await?);
        info!(
            max_cache_mb = config.max_cache_size / (1024 * 1024),
            "Cell cache initialized"
        );

        let heartbeat_writer = Arc::new(HeartbeatWriter::new(
            Arc::clone(&storage),
            &config,
            Arc::clone(&bee_pool),
            Arc::clone(&cell_cache),
        ));

        let world_view_builder = Arc::new(WorldViewBuilder::new(
            Arc::clone(&storage),
            config.heartbeat_interval, config.dead_threshold,
        ));
        let world_view = world_view_builder.world_view();

        // Write one heartbeat and build one world view synchronously so this
        // node is visible to peers (and sees them) before `start` returns.
        heartbeat_writer.write_once().await?;
        world_view_builder.poll_once().await?;
        info!("Initial heartbeat written and world view built");

        // All background tasks share one watch channel for cancellation.
        let (cancel_tx, cancel_rx) = tokio::sync::watch::channel(false);

        // Periodic heartbeat writer task.
        {
            let writer = Arc::clone(&heartbeat_writer);
            let rx = cancel_rx.clone();
            tokio::spawn(async move {
                writer.run(rx).await;
            });
        }

        // Periodic world-view refresher task.
        {
            let builder = Arc::clone(&world_view_builder);
            let rx = cancel_rx.clone();
            tokio::spawn(async move {
                builder.run(rx).await;
            });
        }

        // Distributed-query task poller (see `run_query_worker_poller`).
        {
            let storage = Arc::clone(&storage);
            let query_ctx = Arc::clone(&query_ctx);
            let node_id = config.node_id.clone();
            let rx = cancel_rx.clone();
            tokio::spawn(async move {
                run_query_worker_poller(storage, query_ctx, node_id, rx).await;
            });
        }

        info!("Heartbeat, world view, and query worker background tasks started");

        Ok(Self {
            config,
            storage,
            registry,
            query_ctx,
            bee_pool,
            cell_cache,
            thermometer: ColonyThermometer::default(),
            abandonment_tracker: Arc::new(AbandonmentTracker::default()),
            heartbeat_writer,
            world_view,
            world_view_builder,
            cancel_tx,
        })
    }

    /// Graceful shutdown: signals background tasks to stop, then deletes this
    /// node's heartbeat so peers see a deliberate departure rather than a
    /// dead node.
    pub async fn shutdown(&self) {
        info!(node_id = %self.config.node_id, "Apiary node shutting down");

        // A send error just means all receivers are already gone — harmless.
        let _ = self.cancel_tx.send(true);

        // Short grace period for the spawned tasks to observe the flag; tasks
        // are not joined, so this is best-effort.
        tokio::time::sleep(Duration::from_millis(100)).await;

        if let Err(e) = self.heartbeat_writer.delete_heartbeat().await {
            tracing::warn!(error = %e, "Failed to delete heartbeat during shutdown");
        } else {
            info!(node_id = %self.config.node_id, "Heartbeat deleted (graceful departure)");
        }
    }

    /// Appends `batch` to the frame at `hive/box_name/frame_name`.
    ///
    /// Writes the batch out as one or more cells sized by this node's
    /// cell-sizing policy, then commits an `AddCells` ledger action.
    ///
    /// # Errors
    /// Fails if the frame is not registered, its schema is invalid, or any
    /// storage / ledger operation fails.
    pub async fn write_to_frame(
        &self,
        hive: &str,
        box_name: &str,
        frame_name: &str,
        batch: &RecordBatch,
    ) -> Result<WriteResult> {
        let start = std::time::Instant::now();

        let frame = self.registry.get_frame(hive, box_name, frame_name).await?;
        let schema = FrameSchema::from_json_value(&frame.schema)?;
        let frame_path = format!("{}/{}/{}", hive, box_name, frame_name);

        // The first write to a frame lazily creates its ledger.
        // NOTE(review): *any* open error (not just "missing") falls through to
        // `create` — confirm that is intended for, e.g., transient I/O errors.
        let mut ledger = match Ledger::open(Arc::clone(&self.storage), &frame_path).await {
            Ok(l) => l,
            Err(_) => {
                Ledger::create(
                    Arc::clone(&self.storage),
                    &frame_path,
                    schema.clone(),
                    frame.partition_by.clone(),
                    &self.config.node_id,
                )
                .await?
            }
        };

        let sizing = CellSizingPolicy::new(
            self.config.target_cell_size,
            self.config.max_cell_size,
            self.config.min_cell_size,
        );

        let writer = CellWriter::new(
            Arc::clone(&self.storage),
            frame_path,
            schema,
            frame.partition_by.clone(),
            sizing,
        );

        let cells = writer.write(batch).await?;

        // Capture stats before `cells` is moved into the commit below.
        let cells_written = cells.len();
        let rows_written: u64 = cells.iter().map(|c| c.rows).sum();
        let bytes_written: u64 = cells.iter().map(|c| c.bytes).sum();

        let version = ledger
            .commit(LedgerAction::AddCells { cells }, &self.config.node_id)
            .await?;

        let duration_ms = start.elapsed().as_millis() as u64;
        let temperature = self.thermometer.measure(&self.bee_pool).await;

        Ok(WriteResult {
            version,
            cells_written,
            rows_written,
            bytes_written,
            duration_ms,
            temperature,
        })
    }

    /// Reads a frame's active cells merged into a single batch.
    ///
    /// Returns `Ok(None)` when the frame has no ledger yet or when no cell
    /// survives `partition_filter`.
    pub async fn read_from_frame(
        &self,
        hive: &str,
        box_name: &str,
        frame_name: &str,
        partition_filter: Option<&HashMap<String, String>>,
    ) -> Result<Option<RecordBatch>> {
        let frame_path = format!("{}/{}/{}", hive, box_name, frame_name);

        // A missing ledger simply means the frame has no data yet.
        let ledger = match Ledger::open(Arc::clone(&self.storage), &frame_path).await {
            Ok(l) => l,
            Err(ApiaryError::NotFound { .. }) => return Ok(None),
            Err(e) => return Err(e),
        };

        // Partition-prune when a filter is supplied; otherwise take all
        // active cells.
        let cells = if let Some(filter) = partition_filter {
            ledger.prune_cells(filter, &HashMap::new())
        } else {
            ledger.active_cells().iter().collect()
        };

        if cells.is_empty() {
            return Ok(None);
        }

        let reader = CellReader::new(Arc::clone(&self.storage), frame_path);
        reader.read_cells_merged(&cells, None).await
    }

    /// Replaces a frame's entire contents with `batch`.
    ///
    /// Writes the new cells first, then commits a single `RewriteCells` action
    /// that atomically swaps out all previously active cells.
    /// NOTE(review): if the commit fails after `writer.write`, the new cell
    /// files are left orphaned in storage — confirm a GC/compaction path
    /// reclaims them.
    ///
    /// # Errors
    /// Fails if the frame or its ledger does not exist, or any storage /
    /// ledger operation fails.
    pub async fn overwrite_frame(
        &self,
        hive: &str,
        box_name: &str,
        frame_name: &str,
        batch: &RecordBatch,
    ) -> Result<WriteResult> {
        let start = std::time::Instant::now();

        let frame = self.registry.get_frame(hive, box_name, frame_name).await?;
        let schema = FrameSchema::from_json_value(&frame.schema)?;
        let frame_path = format!("{}/{}/{}", hive, box_name, frame_name);

        // Unlike `write_to_frame`, an overwrite requires an existing ledger.
        let mut ledger = Ledger::open(Arc::clone(&self.storage), &frame_path).await?;

        let sizing = CellSizingPolicy::new(
            self.config.target_cell_size,
            self.config.max_cell_size,
            self.config.min_cell_size,
        );

        let writer = CellWriter::new(
            Arc::clone(&self.storage),
            frame_path,
            schema,
            frame.partition_by.clone(),
            sizing,
        );

        let new_cells = writer.write(batch).await?;

        // Capture stats before `new_cells` is moved into the commit below.
        let cells_written = new_cells.len();
        let rows_written: u64 = new_cells.iter().map(|c| c.rows).sum();
        let bytes_written: u64 = new_cells.iter().map(|c| c.bytes).sum();

        // Every currently active cell is superseded by the rewrite.
        let removed: Vec<_> = ledger.active_cells().iter().map(|c| c.id.clone()).collect();

        let version = ledger
            .commit(
                LedgerAction::RewriteCells {
                    removed,
                    added: new_cells,
                },
                &self.config.node_id,
            )
            .await?;

        let duration_ms = start.elapsed().as_millis() as u64;
        let temperature = self.thermometer.measure(&self.bee_pool).await;

        Ok(WriteResult {
            version,
            cells_written,
            rows_written,
            bytes_written,
            duration_ms,
            temperature,
        })
    }

    /// Eagerly creates the ledger for a registered frame (normally created
    /// lazily by the first `write_to_frame`).
    ///
    /// # Errors
    /// Fails if the frame is not registered, its schema is invalid, or the
    /// ledger cannot be created.
    pub async fn init_frame_ledger(
        &self,
        hive: &str,
        box_name: &str,
        frame_name: &str,
    ) -> Result<()> {
        let frame = self.registry.get_frame(hive, box_name, frame_name).await?;
        let schema = FrameSchema::from_json_value(&frame.schema)?;
        let frame_path = format!("{}/{}/{}", hive, box_name, frame_name);

        Ledger::create(
            Arc::clone(&self.storage),
            &frame_path,
            schema,
            frame.partition_by.clone(),
            &self.config.node_id,
        )
        .await?;

        Ok(())
    }

    /// Runs a SQL query on a bee worker and returns all result batches.
    ///
    /// The closure drives the async query to completion on the bee via
    /// `Handle::block_on`.
    /// NOTE(review): this assumes bee-pool workers are dedicated OS threads —
    /// `block_on` panics when called from a tokio runtime worker; confirm.
    /// The shared `query_ctx` mutex also serializes concurrent queries (and
    /// contends with the distributed-query poller).
    pub async fn sql(&self, query: &str) -> Result<Vec<RecordBatch>> {
        let query_ctx = Arc::clone(&self.query_ctx);
        let query_owned = query.to_string();
        let rt_handle = tokio::runtime::Handle::current();

        let handle = self
            .bee_pool
            .submit(move || {
                rt_handle.block_on(async {
                    let mut ctx = query_ctx.lock().await;
                    ctx.sql(&query_owned).await
                })
            })
            .await;

        // Outer `?` surfaces the join error; the inner Result is the query's.
        handle.await.map_err(|e| ApiaryError::Internal {
            message: format!("Task join error: {e}"),
        })?
    }

    /// Snapshot of each bee's current status.
    pub async fn bee_status(&self) -> Vec<BeeStatus> {
        self.bee_pool.status().await
    }

    /// Clone of the latest world view (peers seen via heartbeats).
    pub async fn world_view(&self) -> WorldView {
        self.world_view.read().await.clone()
    }

    /// Builds a swarm-wide summary from the current world view: one entry per
    /// known node (sorted by id) plus aggregate bee counts.
    pub async fn swarm_status(&self) -> SwarmStatus {
        let view = self.world_view.read().await;
        let mut nodes = Vec::new();

        for status in view.nodes.values() {
            nodes.push(SwarmNodeInfo {
                node_id: status.node_id.as_str().to_string(),
                state: match status.state {
                    NodeState::Alive => "alive".to_string(),
                    NodeState::Suspect => "suspect".to_string(),
                    NodeState::Dead => "dead".to_string(),
                },
                bees: status.heartbeat.load.bees_total,
                idle_bees: status.heartbeat.load.bees_idle,
                memory_pressure: status.heartbeat.load.memory_pressure,
                colony_temperature: status.heartbeat.load.colony_temperature,
            });
        }

        // Deterministic ordering for display and tests.
        nodes.sort_by(|a, b| a.node_id.cmp(&b.node_id));

        let total_bees: usize = nodes.iter().map(|n| n.bees).sum();
        let total_idle_bees: usize = nodes.iter().map(|n| n.idle_bees).sum();

        SwarmStatus {
            nodes,
            total_bees,
            total_idle_bees,
        }
    }

    /// Current local thermoregulation status (temperature, regulation mode,
    /// and setpoint) measured from the bee pool.
    pub async fn colony_status(&self) -> ColonyStatus {
        let temperature = self.thermometer.measure(&self.bee_pool).await;
        let regulation = self.thermometer.regulation(temperature);

        ColonyStatus {
            temperature,
            regulation: regulation.as_str().to_string(),
            setpoint: self.thermometer.setpoint(),
        }
    }

    /// Distributed SQL entry point; currently delegates to local [`Self::sql`].
    #[allow(dead_code)] // kept as the public seam for future multi-node execution
    pub async fn sql_distributed(&self, query: &str) -> Result<Vec<RecordBatch>> {
        self.sql(query).await
    }
}
524
/// Aggregated snapshot of the swarm, built by `ApiaryNode::swarm_status`.
#[derive(Debug, Clone)]
pub struct SwarmStatus {
    /// Per-node summaries, sorted by node id.
    pub nodes: Vec<SwarmNodeInfo>,
    /// Sum of `bees` across all nodes.
    pub total_bees: usize,
    /// Sum of `idle_bees` across all nodes.
    pub total_idle_bees: usize,
}
535
/// One node's entry in a swarm status snapshot, derived from its last
/// observed heartbeat.
#[derive(Debug, Clone)]
pub struct SwarmNodeInfo {
    /// Node identifier (stringified).
    pub node_id: String,
    /// Liveness state: `"alive"`, `"suspect"`, or `"dead"`.
    pub state: String,
    /// Total bees (workers) reported in the node's heartbeat.
    pub bees: usize,
    /// Idle bees reported in the node's heartbeat.
    pub idle_bees: usize,
    /// Memory pressure reported in the node's heartbeat.
    pub memory_pressure: f64,
    /// Colony temperature reported in the node's heartbeat.
    pub colony_temperature: f64,
}
546
/// Local thermoregulation snapshot, built by `ApiaryNode::colony_status`.
#[derive(Debug, Clone)]
pub struct ColonyStatus {
    /// Currently measured colony temperature.
    pub temperature: f64,
    /// Regulation mode as a display string.
    pub regulation: String,
    /// Target temperature the colony regulates toward.
    pub setpoint: f64,
}
557
/// Best-effort lookup of the current user's home directory.
///
/// Reads `USERPROFILE` on Windows and `HOME` everywhere else; returns `None`
/// when the variable is unset or not valid Unicode.
fn home_dir() -> Option<std::path::PathBuf> {
    #[cfg(target_os = "windows")]
    let var = "USERPROFILE";
    #[cfg(not(target_os = "windows"))]
    let var = "HOME";

    std::env::var(var).ok().map(std::path::PathBuf::from)
}
571
/// Background task: polls shared storage for distributed-query manifests and
/// executes any tasks assigned to this node, writing a partial result back to
/// storage for the coordinator to collect.
///
/// Polls every 500 ms and runs until `cancel` flips to `true`. All failures
/// are logged and skipped — the loop itself never errors out.
async fn run_query_worker_poller(
    storage: Arc<dyn StorageBackend>,
    query_ctx: Arc<tokio::sync::Mutex<ApiaryQueryContext>>,
    node_id: apiary_core::types::NodeId,
    cancel: tokio::sync::watch::Receiver<bool>,
) {
    use apiary_query::distributed;

    info!(node_id = %node_id, "Query worker poller started");

    let poll_interval = Duration::from_millis(500);

    loop {
        tokio::select! {
            _ = tokio::time::sleep(poll_interval) => {
                // Each pending query lives at `_queries/<query_id>/manifest.json`.
                match storage.list("_queries/").await {
                    Ok(keys) => {
                        for key in keys {
                            if !key.ends_with("/manifest.json") {
                                continue;
                            }

                            // Extract <query_id> from the key's second path segment.
                            let parts: Vec<&str> = key.split('/').collect();
                            if parts.len() < 3 {
                                continue;
                            }
                            let query_id = parts[1];

                            match distributed::read_manifest(&storage, query_id).await {
                                Ok(manifest) => {
                                    // Only tasks addressed to this node are ours.
                                    let my_tasks: Vec<_> = manifest.tasks.iter()
                                        .filter(|t| t.node_id == node_id)
                                        .collect();

                                    if my_tasks.is_empty() {
                                        continue;
                                    }

                                    // Idempotence guard: skip queries whose partial
                                    // result this node already uploaded.
                                    let partial_path = distributed::partial_result_path(query_id, &node_id);
                                    if storage.get(&partial_path).await.is_ok() {
                                        continue;
                                    }

                                    info!(
                                        query_id = %query_id,
                                        tasks = my_tasks.len(),
                                        "Executing distributed query tasks"
                                    );

                                    let mut results = Vec::new();
                                    // NOTE(review): this lock is held across the whole
                                    // task batch and the upload below, blocking
                                    // `ApiaryNode::sql` in the meantime — confirm intended.
                                    let ctx = query_ctx.lock().await;

                                    for task in my_tasks {
                                        match ctx.execute_task(&task.sql_fragment, &task.cells).await {
                                            Ok(batches) => {
                                                results.extend(batches);
                                            }
                                            Err(e) => {
                                                // Best-effort: a failed task is logged and
                                                // skipped rather than aborting the batch.
                                                tracing::warn!(
                                                    task_id = %task.task_id,
                                                    error = %e,
                                                    "Task execution failed"
                                                );
                                            }
                                        }
                                    }

                                    if !results.is_empty() {
                                        if let Err(e) = distributed::write_partial_result(
                                            &storage,
                                            query_id,
                                            &node_id,
                                            &results,
                                        ).await {
                                            tracing::warn!(
                                                query_id = %query_id,
                                                error = %e,
                                                "Failed to write partial result"
                                            );
                                        }
                                    }
                                }
                                Err(_) => {
                                    // Manifest may be mid-write or already cleaned up;
                                    // try again on the next poll.
                                    continue;
                                }
                            }
                        }
                    }
                    Err(e) => {
                        tracing::warn!(error = %e, "Failed to list query manifests");
                    }
                }
            }
            _ = wait_for_cancel(&cancel) => {
                tracing::debug!(node_id = %node_id, "Query worker poller stopping");
                break;
            }
        }
    }
}
683
684async fn wait_for_cancel(cancel: &tokio::sync::watch::Receiver<bool>) {
686 let mut rx = cancel.clone();
687 let _ = rx.wait_for(|&v| v).await;
688}
689
#[cfg(test)]
mod tests {
    use super::*;

    /// A `local://` URI should boot a node with auto-detected resources.
    #[tokio::test]
    async fn test_start_local_node() {
        let dir = tempfile::TempDir::new().unwrap();
        let mut cfg = NodeConfig::detect("local://test");
        cfg.storage_uri = format!("local://{}", dir.path().display());

        let node = ApiaryNode::start(cfg).await.unwrap();
        assert!(node.config.cores > 0);
        node.shutdown().await;
    }

    /// A bare filesystem path (no scheme prefix) must also be accepted.
    #[tokio::test]
    async fn test_start_with_raw_path() {
        let dir = tempfile::TempDir::new().unwrap();
        let mut cfg = NodeConfig::detect("test");
        cfg.storage_uri = dir.path().to_string_lossy().to_string();

        let node = ApiaryNode::start(cfg).await.unwrap();
        node.shutdown().await;
    }
}