1use crate::coordinator::Coordinator;
8use crate::rpc::messages::CrossShardSignal;
9use crate::rpc::protocol::{CoordinatorService, RpcError, RpcResult, ShardService, TickStatus};
10use crate::shard::ShardedColony;
11use crate::types::*;
12use futures::StreamExt;
13use phago_core::substrate::Substrate;
14use phago_core::topology::TopologyGraph;
15use phago_core::types::{Document, DocumentId, NodeData, NodeId};
16use std::collections::HashMap;
17use std::net::SocketAddr;
18use std::sync::Arc;
19use tarpc::context::Context;
20use tarpc::server::{self, Channel};
21use tokio_serde::formats::Bincode;
22use tokio::sync::RwLock;
23use tracing::{debug, error, info, instrument};
24
25#[derive(Clone)]
46pub struct ShardServer {
47 shard: Arc<RwLock<ShardedColony>>,
48}
49
50impl ShardServer {
51 pub fn new(shard: Arc<RwLock<ShardedColony>>) -> Self {
57 Self { shard }
58 }
59
60 pub async fn start(self, addr: SocketAddr) -> Result<(), std::io::Error> {
73 use crate::rpc::protocol::ShardService;
74 let listener = tarpc::serde_transport::tcp::listen(&addr, Bincode::default).await?;
75 info!("Shard server listening on {}", addr);
76
77 listener
78 .filter_map(|r| futures::future::ready(r.ok()))
79 .map(server::BaseChannel::with_defaults)
80 .for_each_concurrent(10, |channel| {
81 let server = self.clone();
82 async move {
83 channel.execute(server.serve()).for_each(|_| async {}).await
84 }
85 })
86 .await;
87
88 Ok(())
89 }
90}
91
92impl ShardService for ShardServer {
93 #[instrument(skip(self, _ctx), fields(doc_title = %doc.title))]
94 async fn ingest_document(self, _ctx: Context, doc: Document) -> RpcResult<DocumentId> {
95 debug!("Ingesting document: {}", doc.title);
96 let mut shard = self.shard.write().await;
97 let id = shard.ingest_document_direct(&doc.title, &doc.content, doc.position);
98 debug!("Document ingested with ID: {:?}", id);
99 Ok(id)
100 }
101
102 #[instrument(skip(self, _ctx), fields(phase = %phase, tick = tick))]
103 async fn tick_phase(
104 self,
105 _ctx: Context,
106 phase: TickPhase,
107 tick: u64,
108 ) -> RpcResult<PhaseResult> {
109 debug!("Executing tick phase {:?} for tick {}", phase, tick);
110 let mut shard = self.shard.write().await;
111 let result = shard.tick_phase(phase);
112 debug!(
113 "Phase complete: {} nodes, {} edges, {} cross-shard edges",
114 result.node_count,
115 result.edge_count,
116 result.cross_shard_edges.len()
117 );
118 Ok(result)
119 }
120
121 #[instrument(skip(self, _ctx, req), fields(terms = ?req.query_terms, max_results = req.max_results))]
122 async fn local_query(
123 self,
124 _ctx: Context,
125 req: LocalQueryRequest,
126 ) -> RpcResult<LocalQueryResult> {
127 debug!("Executing local query with {} terms", req.query_terms.len());
128 let shard = self.shard.read().await;
129 let result = shard.execute_local_query(&req);
130 debug!("Query returned {} results", result.results.len());
131 Ok(result)
132 }
133
134 #[instrument(skip(self, _ctx), fields(term_count = terms.len()))]
135 async fn get_term_frequencies(
136 self,
137 _ctx: Context,
138 terms: Vec<String>,
139 ) -> RpcResult<HashMap<String, u64>> {
140 debug!("Getting term frequencies for {} terms", terms.len());
141 let shard = self.shard.read().await;
142 let freqs = shard.get_term_frequencies(&terms);
143 debug!("Returned {} term frequencies", freqs.len());
144 Ok(freqs)
145 }
146
147 #[instrument(skip(self, _ctx), fields(node_id = ?id))]
148 async fn get_node(self, _ctx: Context, id: NodeId) -> RpcResult<Option<NodeData>> {
149 debug!("Getting node {:?}", id);
150 let shard = self.shard.read().await;
151 let node = shard.get_node(&id);
152 debug!("Node found: {}", node.is_some());
153 Ok(node)
154 }
155
156 #[instrument(skip(self, _ctx))]
157 async fn health_check(self, _ctx: Context) -> RpcResult<ShardHealth> {
158 debug!("Health check requested");
159 let shard = self.shard.read().await;
160 let health = shard.health();
161 debug!(
162 "Health: healthy={}, load={:.2}",
163 health.healthy, health.load
164 );
165 Ok(health)
166 }
167
168 #[instrument(skip(self, _ctx), fields(node_count = node_ids.len()))]
169 async fn resolve_ghost_nodes(
170 self,
171 _ctx: Context,
172 node_ids: Vec<NodeId>,
173 ) -> RpcResult<Vec<GhostNode>> {
174 debug!("Resolving {} ghost nodes", node_ids.len());
175 let shard = self.shard.read().await;
176 let shard_id = shard.shard_id();
177
178 let mut ghosts = Vec::with_capacity(node_ids.len());
179 for id in node_ids {
180 if let Some(node) = shard.get_node(&id) {
181 let mut ghost = GhostNode::new(id, shard_id, node.label.clone());
182 ghost.resolve(node);
183 ghosts.push(ghost);
184 }
185 }
186
187 debug!("Resolved {} ghost nodes", ghosts.len());
188 Ok(ghosts)
189 }
190
191 #[instrument(skip(self, _ctx), fields(node_id = ?node_id))]
192 async fn get_neighbors(self, _ctx: Context, node_id: NodeId) -> RpcResult<Vec<NodeId>> {
193 debug!("Getting neighbors for node {:?}", node_id);
194 let shard = self.shard.read().await;
195 let graph = shard.local().substrate().graph();
196 let neighbors: Vec<NodeId> = graph
197 .neighbors(&node_id)
198 .into_iter()
199 .map(|(id, _)| id)
200 .collect();
201 debug!("Found {} neighbors", neighbors.len());
202 Ok(neighbors)
203 }
204
205 #[instrument(skip(self, _ctx, signals), fields(signal_count = signals.len()))]
206 async fn receive_signals(self, _ctx: Context, signals: Vec<CrossShardSignal>) -> RpcResult<()> {
207 debug!("Receiving {} cross-shard signals", signals.len());
208
209 let mut shard = self.shard.write().await;
210 for sig in &signals {
211 let local_signal = phago_core::types::Signal {
212 signal_type: sig.signal_type.clone(),
213 intensity: sig.intensity,
214 position: sig.position.clone(),
215 emitter: sig.emitter,
216 tick: sig.tick,
217 };
218 shard.local_mut().substrate_mut().emit_signal(local_signal);
219 }
220
221 debug!("Applied {} signals to substrate", signals.len());
222 Ok(())
223 }
224}
225
226#[derive(Clone)]
247pub struct CoordinatorServer {
248 coordinator: Arc<Coordinator>,
249}
250
251impl CoordinatorServer {
252 pub fn new(coordinator: Arc<Coordinator>) -> Self {
258 Self { coordinator }
259 }
260
261 pub async fn start(self, addr: SocketAddr) -> Result<(), std::io::Error> {
271 use crate::rpc::protocol::CoordinatorService;
272 let listener = tarpc::serde_transport::tcp::listen(&addr, Bincode::default).await?;
273 info!("Coordinator server listening on {}", addr);
274
275 listener
276 .filter_map(|r| futures::future::ready(r.ok()))
277 .map(server::BaseChannel::with_defaults)
278 .for_each_concurrent(10, |channel| {
279 let server = self.clone();
280 async move {
281 channel.execute(server.serve()).for_each(|_| async {}).await
282 }
283 })
284 .await;
285
286 Ok(())
287 }
288}
289
290impl CoordinatorService for CoordinatorServer {
291 #[instrument(skip(self, _ctx), fields(shard_id = ?info.id, address = %info.address))]
292 async fn register(self, _ctx: Context, info: ShardInfo) -> RpcResult<ShardId> {
293 info!("Registering shard at {}", info.address);
294 match self.coordinator.register_shard(info).await {
295 Ok(id) => {
296 info!("Shard registered with ID {:?}", id);
297 Ok(id)
298 }
299 Err(e) => {
300 error!("Failed to register shard: {}", e);
301 Err(RpcError::Internal(e.to_string()))
302 }
303 }
304 }
305
306 #[instrument(skip(self, _ctx), fields(shard_id = ?shard_id))]
307 async fn unregister(self, _ctx: Context, shard_id: ShardId) -> RpcResult<()> {
308 info!("Unregistering shard {:?}", shard_id);
309 match self.coordinator.deregister_shard(shard_id).await {
310 Ok(()) => {
311 info!("Shard {:?} unregistered", shard_id);
312 Ok(())
313 }
314 Err(DistributedError::ShardNotFound(_)) => {
315 Err(RpcError::ShardNotFound(shard_id.as_u32()))
316 }
317 Err(e) => {
318 error!("Failed to unregister shard: {}", e);
319 Err(RpcError::Internal(e.to_string()))
320 }
321 }
322 }
323
324 #[instrument(skip(self, _ctx), fields(shard_id = ?shard_id, phase = %phase, tick = tick))]
325 async fn phase_complete(
326 self,
327 _ctx: Context,
328 shard_id: ShardId,
329 phase: TickPhase,
330 tick: u64,
331 ) -> RpcResult<()> {
332 debug!(
333 "Shard {:?} completed phase {:?} for tick {}",
334 shard_id, phase, tick
335 );
336 match self.coordinator.phase_complete(shard_id, phase, tick).await {
337 Ok(()) => Ok(()),
338 Err(DistributedError::BarrierFailed) => Err(RpcError::BarrierFailed),
339 Err(DistributedError::PhaseTimeout(p)) => Err(RpcError::PhaseTimeout(p.to_string())),
340 Err(e) => Err(RpcError::Internal(e.to_string())),
341 }
342 }
343
344 #[instrument(skip(self, _ctx), fields(doc_id = ?doc_id))]
345 async fn route_document(self, _ctx: Context, doc_id: DocumentId) -> ShardId {
346 let shard = self.coordinator.route_document(&doc_id).await;
347 debug!("Document {:?} routed to shard {:?}", doc_id, shard);
348 shard
349 }
350
351 #[instrument(skip(self, _ctx), fields(node_id = ?node_id))]
352 async fn route_node(self, _ctx: Context, node_id: NodeId) -> ShardId {
353 let doc_id = DocumentId(node_id.0);
356 let shard = self.coordinator.route_document(&doc_id).await;
357 debug!("Node {:?} routed to shard {:?}", node_id, shard);
358 shard
359 }
360
361 #[instrument(skip(self, _ctx), fields(term_count = terms.len()))]
362 async fn get_global_df(
363 self,
364 _ctx: Context,
365 terms: Vec<String>,
366 ) -> RpcResult<HashMap<String, u64>> {
367 debug!("Getting global DF for {} terms", terms.len());
368
369 let global_df = HashMap::new();
375 debug!("Returning {} global DF entries (scatter-gather handled by query engine)", global_df.len());
376 Ok(global_df)
377 }
378
379 #[instrument(skip(self, _ctx), fields(shard_id = ?shard_id, phase = %phase, tick = tick))]
380 async fn barrier_ready(
381 self,
382 _ctx: Context,
383 shard_id: ShardId,
384 phase: TickPhase,
385 tick: u64,
386 ) -> RpcResult<bool> {
387 debug!(
388 "Shard {:?} checking barrier for phase {:?}, tick {}",
389 shard_id, phase, tick
390 );
391
392 match self.coordinator.phase_complete(shard_id, phase, tick).await {
394 Ok(()) => {
395 debug!("Barrier released for phase {:?}", phase);
397 Ok(true)
398 }
399 Err(DistributedError::BarrierFailed) => {
400 debug!("Barrier not ready yet");
401 Ok(false)
402 }
403 Err(e) => Err(RpcError::Internal(e.to_string())),
404 }
405 }
406
407 #[instrument(skip(self, _ctx))]
408 async fn current_tick(self, _ctx: Context) -> u64 {
409 let tick = self.coordinator.current_tick();
410 debug!("Current tick: {}", tick);
411 tick
412 }
413
414 #[instrument(skip(self, _ctx))]
415 async fn list_shards(self, _ctx: Context) -> Vec<ShardInfo> {
416 let shards = self.coordinator.all_shards().await;
417 debug!("Listed {} shards", shards.len());
418 shards
419 }
420
421 #[instrument(skip(self, _ctx))]
422 async fn start_tick(self, _ctx: Context) -> RpcResult<u64> {
423 info!("Starting new tick");
424 let new_tick = self.coordinator.advance_tick().await;
425 info!("Started tick {}", new_tick);
426 Ok(new_tick)
427 }
428
429 #[instrument(skip(self, _ctx))]
430 async fn tick_status(self, _ctx: Context) -> RpcResult<TickStatus> {
431 debug!("Getting tick status");
432 let tick = self.coordinator.current_tick();
433 let all_shards = self.coordinator.all_shards().await;
434 let shard_ids: Vec<ShardId> = all_shards.iter().map(|s| s.id).collect();
435
436 let tick_complete = tick > 0 && shard_ids.is_empty();
443
444 let status = TickStatus {
445 tick,
446 phase: TickPhase::Sense,
447 completed_shards: vec![],
448 pending_shards: shard_ids,
449 tick_complete,
450 };
451
452 debug!(
453 "Tick status: tick={}, shards={}, complete={}",
454 status.tick,
455 status.pending_shards.len(),
456 status.tick_complete,
457 );
458 Ok(status)
459 }
460}
461
/// Unit tests that call the RPC handlers directly, without any network
/// transport in between.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hashing::ConsistentHashRing;
    use phago_core::types::Position;
    use phago_runtime::colony::ColonyConfig;

    /// Builds shard 0 with the default colony config and a fresh hash ring.
    fn create_test_shard() -> Arc<RwLock<ShardedColony>> {
        let hash_ring = Arc::new(RwLock::new(ConsistentHashRing::new(3)));
        Arc::new(RwLock::new(ShardedColony::new(
            ShardId::new(0),
            ColonyConfig::default(),
            hash_ring,
        )))
    }

    /// Builds a coordinator constructed with the argument `3`.
    fn create_test_coordinator() -> Arc<Coordinator> {
        Arc::new(Coordinator::new(3))
    }

    /// A fresh shard reports itself healthy under its own ID.
    #[tokio::test]
    async fn test_shard_server_health_check() {
        let shard = create_test_shard();
        let server = ShardServer::new(shard);

        let ctx = tarpc::context::current();
        let health = server.health_check(ctx).await.unwrap();

        assert!(health.healthy);
        assert_eq!(health.shard_id, ShardId::new(0));
    }

    /// Ingesting one document yields a non-nil ID and a document count of 1.
    #[tokio::test]
    async fn test_shard_server_ingest_document() {
        let shard = create_test_shard();
        let server = ShardServer::new(shard.clone());

        let doc = Document {
            id: DocumentId::new(),
            title: "Test".to_string(),
            content: "Test content".to_string(),
            position: Position::new(0.0, 0.0),
            digested: false,
        };

        let ctx = tarpc::context::current();
        let doc_id = server.ingest_document(ctx, doc).await.unwrap();

        assert!(!doc_id.0.is_nil());

        // The server shares state with `shard`, so the ingest is visible here.
        let shard_guard = shard.read().await;
        assert_eq!(shard_guard.document_count(), 1);
    }

    /// The phase result echoes the shard ID and the requested phase.
    #[tokio::test]
    async fn test_shard_server_tick_phase() {
        let shard = create_test_shard();
        let server = ShardServer::new(shard);

        let ctx = tarpc::context::current();
        let result = server.tick_phase(ctx, TickPhase::Sense, 0).await.unwrap();

        assert_eq!(result.shard_id, ShardId::new(0));
        assert_eq!(result.phase, TickPhase::Sense);
    }

    /// Querying an empty shard returns no results.
    #[tokio::test]
    async fn test_shard_server_local_query() {
        let shard = create_test_shard();
        let server = ShardServer::new(shard);

        let req = LocalQueryRequest {
            query_terms: vec!["test".to_string()],
            max_results: 10,
            global_df: HashMap::new(),
        };

        let ctx = tarpc::context::current();
        let result = server.local_query(ctx, req).await.unwrap();

        assert_eq!(result.shard_id, ShardId::new(0));
        assert!(result.results.is_empty());
    }

    /// An empty shard has no term frequencies to report.
    #[tokio::test]
    async fn test_shard_server_get_term_frequencies() {
        let shard = create_test_shard();
        let server = ShardServer::new(shard);

        let ctx = tarpc::context::current();
        let freqs = server
            .get_term_frequencies(ctx, vec!["test".to_string()])
            .await
            .unwrap();

        assert!(freqs.is_empty());
    }

    /// Looking up an unknown node succeeds with `None`.
    #[tokio::test]
    async fn test_shard_server_get_node_not_found() {
        let shard = create_test_shard();
        let server = ShardServer::new(shard);

        let ctx = tarpc::context::current();
        let node = server.get_node(ctx, NodeId::from_seed(999)).await.unwrap();

        assert!(node.is_none());
    }

    /// Unknown node IDs are skipped rather than reported as errors.
    #[tokio::test]
    async fn test_shard_server_resolve_ghost_nodes_empty() {
        let shard = create_test_shard();
        let server = ShardServer::new(shard);

        let ctx = tarpc::context::current();
        let ghosts = server
            .resolve_ghost_nodes(ctx, vec![NodeId::from_seed(1), NodeId::from_seed(2)])
            .await
            .unwrap();

        assert!(ghosts.is_empty());
    }

    /// A node that is not in the graph has no neighbors.
    #[tokio::test]
    async fn test_shard_server_get_neighbors_empty() {
        let shard = create_test_shard();
        let server = ShardServer::new(shard);

        let ctx = tarpc::context::current();
        let neighbors = server
            .get_neighbors(ctx, NodeId::from_seed(1))
            .await
            .unwrap();

        assert!(neighbors.is_empty());
    }

    /// Receiving a cross-shard signal succeeds.
    #[tokio::test]
    async fn test_shard_server_receive_signals() {
        let shard = create_test_shard();
        let server = ShardServer::new(shard);

        let signals = vec![CrossShardSignal {
            signal_type: phago_core::types::SignalType::Input,
            intensity: 0.5,
            position: Position::new(0.0, 0.0),
            emitter: phago_core::types::AgentId::from_seed(1),
            tick: 0,
            source_shard: ShardId::new(1),
        }];

        let ctx = tarpc::context::current();
        let result = server.receive_signals(ctx, signals).await;

        assert!(result.is_ok());
    }

    /// Registering a shard returns its ID and makes it visible on the
    /// coordinator.
    #[tokio::test]
    async fn test_coordinator_server_register() {
        let coordinator = create_test_coordinator();
        let server = CoordinatorServer::new(coordinator.clone());

        let info = ShardInfo::new(ShardId::new(0), "127.0.0.1:8080".to_string());

        let ctx = tarpc::context::current();
        let shard_id = server.register(ctx, info).await.unwrap();

        assert_eq!(shard_id, ShardId::new(0));

        let shards = coordinator.all_shards().await;
        assert_eq!(shards.len(), 1);
    }

    /// Unregistering a registered shard removes it from the coordinator.
    #[tokio::test]
    async fn test_coordinator_server_unregister() {
        let coordinator = create_test_coordinator();
        let server = CoordinatorServer::new(coordinator.clone());

        let info = ShardInfo::new(ShardId::new(0), "127.0.0.1:8080".to_string());
        let ctx = tarpc::context::current();
        let shard_id = server.clone().register(ctx, info).await.unwrap();

        let ctx = tarpc::context::current();
        let result = server.unregister(ctx, shard_id).await;

        assert!(result.is_ok());

        let shards = coordinator.all_shards().await;
        assert!(shards.is_empty());
    }

    /// Unregistering an unknown shard reports `ShardNotFound` with that ID.
    #[tokio::test]
    async fn test_coordinator_server_unregister_not_found() {
        let coordinator = create_test_coordinator();
        let server = CoordinatorServer::new(coordinator);

        let ctx = tarpc::context::current();
        let result = server.unregister(ctx, ShardId::new(999)).await;

        assert!(matches!(result, Err(RpcError::ShardNotFound(999))));
    }

    /// Routing the same document twice yields the same shard.
    #[tokio::test]
    async fn test_coordinator_server_route_document() {
        let coordinator = create_test_coordinator();
        let server = CoordinatorServer::new(coordinator);

        let doc_id = DocumentId::from_seed(42);

        let ctx = tarpc::context::current();
        let shard1 = server.clone().route_document(ctx, doc_id).await;

        let ctx = tarpc::context::current();
        let shard2 = server.route_document(ctx, doc_id).await;

        assert_eq!(shard1, shard2);
    }

    /// A fresh coordinator starts at tick 0.
    #[tokio::test]
    async fn test_coordinator_server_current_tick() {
        let coordinator = create_test_coordinator();
        let server = CoordinatorServer::new(coordinator);

        let ctx = tarpc::context::current();
        let tick = server.current_tick(ctx).await;

        assert_eq!(tick, 0);
    }

    /// Each `start_tick` call advances the coordinator's tick by one.
    #[tokio::test]
    async fn test_coordinator_server_start_tick() {
        let coordinator = create_test_coordinator();
        let server = CoordinatorServer::new(coordinator.clone());

        let ctx = tarpc::context::current();
        let tick1 = server.clone().start_tick(ctx).await.unwrap();
        assert_eq!(tick1, 1);

        let ctx = tarpc::context::current();
        let tick2 = server.start_tick(ctx).await.unwrap();
        assert_eq!(tick2, 2);

        assert_eq!(coordinator.current_tick(), 2);
    }

    /// Two distinct shards are both listed after registration.
    #[tokio::test]
    async fn test_coordinator_server_list_shards() {
        let coordinator = create_test_coordinator();
        let server = CoordinatorServer::new(coordinator);

        let ctx = tarpc::context::current();
        server
            .clone()
            .register(
                ctx,
                ShardInfo::new(ShardId::new(0), "127.0.0.1:8080".to_string()),
            )
            .await
            .unwrap();

        // The second shard needs its own ID: two distinct shards must not
        // share an identifier if both are expected in the listing.
        let ctx = tarpc::context::current();
        server
            .clone()
            .register(
                ctx,
                ShardInfo::new(ShardId::new(1), "127.0.0.1:8081".to_string()),
            )
            .await
            .unwrap();

        let ctx = tarpc::context::current();
        let shards = server.list_shards(ctx).await;

        assert_eq!(shards.len(), 2);
    }

    /// At tick 0 the status reports tick 0 and an incomplete tick.
    #[tokio::test]
    async fn test_coordinator_server_tick_status() {
        let coordinator = create_test_coordinator();
        let server = CoordinatorServer::new(coordinator);

        let ctx = tarpc::context::current();
        let status = server.tick_status(ctx).await.unwrap();

        assert_eq!(status.tick, 0);
        assert!(!status.tick_complete);
    }

    /// The global-DF endpoint currently returns an empty map.
    #[tokio::test]
    async fn test_coordinator_server_get_global_df() {
        let coordinator = create_test_coordinator();
        let server = CoordinatorServer::new(coordinator);

        let ctx = tarpc::context::current();
        let df = server
            .get_global_df(ctx, vec!["test".to_string()])
            .await
            .unwrap();

        assert!(df.is_empty());
    }
}