1use phago_core::types::{DocumentId, NodeData, NodeId, Tick};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use thiserror::Error;
11
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
14pub struct ShardId(pub u32);
15
16impl ShardId {
17 pub fn new(id: u32) -> Self {
19 Self(id)
20 }
21
22 pub fn as_u32(&self) -> u32 {
24 self.0
25 }
26}
27
28impl std::fmt::Display for ShardId {
29 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30 write!(f, "shard-{}", self.0)
31 }
32}
33
34#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
36pub struct NodeAddress {
37 pub host: String,
39 pub port: u16,
41}
42
43impl NodeAddress {
44 pub fn new(host: impl Into<String>, port: u16) -> Self {
46 Self {
47 host: host.into(),
48 port,
49 }
50 }
51
52 pub fn to_socket_addr(&self) -> String {
54 format!("{}:{}", self.host, self.port)
55 }
56}
57
58impl std::fmt::Display for NodeAddress {
59 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60 write!(f, "{}:{}", self.host, self.port)
61 }
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct DistributedConfig {
67 pub num_shards: u32,
69 pub replication_factor: u32,
71 pub rpc_timeout_ms: u64,
73 pub virtual_nodes_per_shard: u32,
75}
76
77impl Default for DistributedConfig {
78 fn default() -> Self {
79 Self {
80 num_shards: 3,
81 replication_factor: 2,
82 rpc_timeout_ms: 5000,
83 virtual_nodes_per_shard: 150,
84 }
85 }
86}
87
88#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
90pub enum ShardStatus {
91 Online,
93 Offline,
95 Recovering,
97 Draining,
99}
100
101#[derive(Debug, Clone, Serialize, Deserialize)]
103pub struct ShardInfo {
104 pub id: ShardId,
106 pub address: String,
108 pub node_count: usize,
110 pub edge_count: usize,
112 pub document_count: usize,
114 pub last_heartbeat: u64,
116}
117
118impl ShardInfo {
119 pub fn new(id: ShardId, address: String) -> Self {
121 Self {
122 id,
123 address,
124 node_count: 0,
125 edge_count: 0,
126 document_count: 0,
127 last_heartbeat: 0,
128 }
129 }
130}
131
132#[derive(Error, Debug, Clone)]
134pub enum DistributedError {
135 #[error("Shard {0:?} not found")]
136 ShardNotFound(ShardId),
137
138 #[error("Coordinator unavailable")]
139 CoordinatorUnavailable,
140
141 #[error("RPC error: {0}")]
142 RpcError(String),
143
144 #[error("Timeout waiting for phase {0:?}")]
145 PhaseTimeout(TickPhase),
146
147 #[error("Document routing failed for {0:?}")]
148 RoutingFailed(DocumentId),
149
150 #[error("Cross-shard edge resolution failed")]
151 EdgeResolutionFailed,
152
153 #[error("Ghost node not found: {0:?}")]
154 GhostNodeNotFound(NodeId),
155
156 #[error("Barrier synchronization failed")]
157 BarrierFailed,
158}
159
160pub type DistributedResult<T> = Result<T, DistributedError>;
162
163#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
168pub enum TickPhase {
169 Sense,
171 Act,
173 Decay,
175 Advance,
177}
178
179impl std::fmt::Display for TickPhase {
180 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181 match self {
182 TickPhase::Sense => write!(f, "Sense"),
183 TickPhase::Act => write!(f, "Act"),
184 TickPhase::Decay => write!(f, "Decay"),
185 TickPhase::Advance => write!(f, "Advance"),
186 }
187 }
188}
189
190#[derive(Debug, Clone, Serialize, Deserialize)]
192pub struct PhaseResult {
193 pub shard_id: ShardId,
195 pub phase: TickPhase,
197 pub tick: Tick,
199 pub cross_shard_edges: Vec<CrossShardEdge>,
201 pub node_count: usize,
203 pub edge_count: usize,
205}
206
207#[derive(Debug, Clone, Serialize, Deserialize)]
213pub struct CrossShardEdge {
214 pub from_node: NodeId,
216 pub to_node: NodeId,
218 pub to_shard: ShardId,
220 pub weight: f64,
222}
223
224#[derive(Debug, Clone, Serialize, Deserialize)]
229pub struct LocalQueryRequest {
230 pub query_terms: Vec<String>,
232 pub max_results: usize,
234 pub global_df: HashMap<String, u64>,
237}
238
239#[derive(Debug, Clone, Serialize, Deserialize)]
241pub struct LocalQueryResult {
242 pub shard_id: ShardId,
244 pub results: Vec<ScoredNode>,
246 pub term_frequencies: HashMap<String, u64>,
249}
250
251#[derive(Debug, Clone, Serialize, Deserialize)]
253pub struct ScoredNode {
254 pub node_id: NodeId,
256 pub label: String,
258 pub score: f64,
260 pub shard_id: ShardId,
262}
263
264impl PartialEq for ScoredNode {
265 fn eq(&self, other: &Self) -> bool {
266 self.node_id == other.node_id && self.shard_id == other.shard_id
267 }
268}
269
270impl Eq for ScoredNode {}
271
272impl PartialOrd for ScoredNode {
273 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
274 Some(self.cmp(other))
275 }
276}
277
278impl Ord for ScoredNode {
279 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
280 other
282 .score
283 .partial_cmp(&self.score)
284 .unwrap_or(std::cmp::Ordering::Equal)
285 }
286}
287
288#[derive(Debug, Clone, Serialize, Deserialize)]
290pub struct ShardHealth {
291 pub shard_id: ShardId,
293 pub healthy: bool,
295 pub load: f64,
297 pub memory_usage_mb: u64,
299 pub pending_operations: usize,
301}
302
303impl ShardHealth {
304 pub fn healthy(shard_id: ShardId) -> Self {
306 Self {
307 shard_id,
308 healthy: true,
309 load: 0.0,
310 memory_usage_mb: 0,
311 pending_operations: 0,
312 }
313 }
314
315 pub fn unhealthy(shard_id: ShardId) -> Self {
317 Self {
318 shard_id,
319 healthy: false,
320 load: 0.0,
321 memory_usage_mb: 0,
322 pending_operations: 0,
323 }
324 }
325}
326
327#[derive(Debug, Clone, Serialize, Deserialize)]
333pub struct GhostNode {
334 pub node_id: NodeId,
336 pub shard_id: ShardId,
338 pub label: String,
340 pub full_data: Option<NodeData>,
342}
343
344impl GhostNode {
345 pub fn new(node_id: NodeId, shard_id: ShardId, label: String) -> Self {
347 Self {
348 node_id,
349 shard_id,
350 label,
351 full_data: None,
352 }
353 }
354
355 pub fn is_resolved(&self) -> bool {
357 self.full_data.is_some()
358 }
359
360 pub fn resolve(&mut self, data: NodeData) {
362 self.full_data = Some(data);
363 }
364}
365
#[cfg(test)]
mod tests {
    //! Unit tests for the distributed types: constructors, defaults, `Display`
    //! output and ordering.

    use super::*;
    use phago_core::types::{NodeType, Position};

    #[test]
    fn test_shard_id() {
        let id = ShardId::new(42);
        assert_eq!(id.0, 42);
        assert_eq!(id.as_u32(), 42);
        assert_eq!(id.to_string(), "shard-42");
    }

    #[test]
    fn test_node_address() {
        let address = NodeAddress::new("127.0.0.1", 8080);
        assert_eq!(address.host, "127.0.0.1");
        assert_eq!(address.port, 8080);
        assert_eq!(address.to_socket_addr(), "127.0.0.1:8080");
        assert_eq!(address.to_string(), "127.0.0.1:8080");
    }

    #[test]
    fn test_distributed_config_default() {
        let cfg = DistributedConfig::default();
        assert_eq!(
            (
                cfg.num_shards,
                cfg.replication_factor,
                cfg.rpc_timeout_ms,
                cfg.virtual_nodes_per_shard
            ),
            (3, 2, 5000, 150)
        );
    }

    #[test]
    fn test_tick_phase_display() {
        let cases = [
            (TickPhase::Sense, "Sense"),
            (TickPhase::Act, "Act"),
            (TickPhase::Decay, "Decay"),
            (TickPhase::Advance, "Advance"),
        ];
        for (phase, expected) in cases {
            assert_eq!(phase.to_string(), expected);
        }
    }

    #[test]
    fn test_scored_node_ordering() {
        let shard = ShardId::new(0);
        let high = ScoredNode {
            node_id: NodeId::from_seed(1),
            label: "high".into(),
            score: 0.9,
            shard_id: shard,
        };
        let low = ScoredNode {
            node_id: NodeId::from_seed(2),
            label: "low".into(),
            score: 0.1,
            shard_id: shard,
        };

        // Higher scores order first.
        assert!(high < low);
    }

    #[test]
    fn test_ghost_node_resolution() {
        let mut ghost = GhostNode::new(NodeId::from_seed(1), ShardId::new(1), "test".to_string());
        assert!(!ghost.is_resolved());

        ghost.resolve(NodeData {
            id: NodeId::from_seed(1),
            label: "test".to_string(),
            node_type: NodeType::Concept,
            position: Position::new(0.0, 0.0),
            access_count: 0,
            created_tick: 0,
            embedding: None,
        });
        assert!(ghost.is_resolved());
    }

    #[test]
    fn test_shard_health() {
        let up = ShardHealth::healthy(ShardId::new(0));
        assert!(up.healthy);
        assert_eq!(up.load, 0.0);

        let down = ShardHealth::unhealthy(ShardId::new(1));
        assert!(!down.healthy);
    }

    #[test]
    fn test_shard_info_new() {
        let info = ShardInfo::new(ShardId::new(5), "127.0.0.1:8085".to_string());
        assert_eq!(info.id, ShardId::new(5));
        assert_eq!(info.address, "127.0.0.1:8085");
        assert_eq!(info.node_count, 0);
        assert_eq!(info.edge_count, 0);
        assert_eq!(info.document_count, 0);
    }

    #[test]
    fn test_phase_result() {
        let result = PhaseResult {
            shard_id: ShardId::new(0),
            phase: TickPhase::Sense,
            tick: 42,
            cross_shard_edges: Vec::new(),
            node_count: 100,
            edge_count: 250,
        };
        assert_eq!(result.tick, 42);
        assert_eq!(result.node_count, 100);
    }

    #[test]
    fn test_cross_shard_edge() {
        let edge = CrossShardEdge {
            from_node: NodeId::from_seed(1),
            to_node: NodeId::from_seed(2),
            to_shard: ShardId::new(1),
            weight: 0.75,
        };
        assert_eq!(edge.to_shard, ShardId::new(1));
        assert!((edge.weight - 0.75).abs() < f64::EPSILON);
    }

    #[test]
    fn test_local_query_request() {
        let global_df: HashMap<String, u64> = vec![
            ("rust".to_string(), 100),
            ("programming".to_string(), 200),
        ]
        .into_iter()
        .collect();

        let request = LocalQueryRequest {
            query_terms: vec!["rust".to_string(), "programming".to_string()],
            max_results: 10,
            global_df,
        };
        assert_eq!(request.query_terms.len(), 2);
        assert_eq!(request.max_results, 10);
    }
}
502}