Skip to main content

phago_distributed/
types.rs

//! Core types for distributed colony coordination.
//!
//! This module defines the core data structures used across the distributed
//! system, including shard identifiers, tick phases, cross-shard edges,
//! query requests/results, and ghost nodes for remote references.
6
7use phago_core::types::{DocumentId, NodeData, NodeId, Tick};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use thiserror::Error;
11
12/// Unique identifier for a shard.
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
14pub struct ShardId(pub u32);
15
16impl ShardId {
17    /// Create a new shard identifier.
18    pub fn new(id: u32) -> Self {
19        Self(id)
20    }
21
22    /// Get the underlying shard number.
23    pub fn as_u32(&self) -> u32 {
24        self.0
25    }
26}
27
28impl std::fmt::Display for ShardId {
29    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30        write!(f, "shard-{}", self.0)
31    }
32}
33
34/// Address of a node in the distributed cluster.
35#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
36pub struct NodeAddress {
37    /// Host address (IP or hostname).
38    pub host: String,
39    /// Port number.
40    pub port: u16,
41}
42
43impl NodeAddress {
44    /// Create a new node address.
45    pub fn new(host: impl Into<String>, port: u16) -> Self {
46        Self {
47            host: host.into(),
48            port,
49        }
50    }
51
52    /// Format as a socket address string.
53    pub fn to_socket_addr(&self) -> String {
54        format!("{}:{}", self.host, self.port)
55    }
56}
57
58impl std::fmt::Display for NodeAddress {
59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60        write!(f, "{}:{}", self.host, self.port)
61    }
62}
63
64/// Configuration for the distributed colony.
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct DistributedConfig {
67    /// Number of shards to distribute data across.
68    pub num_shards: u32,
69    /// Replication factor for fault tolerance.
70    pub replication_factor: u32,
71    /// Timeout for RPC calls in milliseconds.
72    pub rpc_timeout_ms: u64,
73    /// Number of virtual nodes per shard for consistent hashing.
74    pub virtual_nodes_per_shard: u32,
75}
76
77impl Default for DistributedConfig {
78    fn default() -> Self {
79        Self {
80            num_shards: 3,
81            replication_factor: 2,
82            rpc_timeout_ms: 5000,
83            virtual_nodes_per_shard: 150,
84        }
85    }
86}
87
88/// Status of a shard in the cluster.
89#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
90pub enum ShardStatus {
91    /// Shard is online and accepting requests.
92    Online,
93    /// Shard is offline or unreachable.
94    Offline,
95    /// Shard is recovering/rebalancing data.
96    Recovering,
97    /// Shard is draining (preparing to go offline).
98    Draining,
99}
100
101/// Information about a shard registered with the coordinator.
102#[derive(Debug, Clone, Serialize, Deserialize)]
103pub struct ShardInfo {
104    /// The shard's unique identifier.
105    pub id: ShardId,
106    /// The network address of this shard (e.g., "127.0.0.1:8081").
107    pub address: String,
108    /// Number of nodes on this shard.
109    pub node_count: usize,
110    /// Number of edges on this shard (including ghost edges).
111    pub edge_count: usize,
112    /// Number of documents assigned to this shard.
113    pub document_count: usize,
114    /// Unix timestamp of the last heartbeat from this shard.
115    pub last_heartbeat: u64,
116}
117
118impl ShardInfo {
119    /// Create a new shard info with the given ID and address.
120    pub fn new(id: ShardId, address: String) -> Self {
121        Self {
122            id,
123            address,
124            node_count: 0,
125            edge_count: 0,
126            document_count: 0,
127            last_heartbeat: 0,
128        }
129    }
130}
131
132/// Errors that can occur in distributed operations.
133#[derive(Error, Debug, Clone)]
134pub enum DistributedError {
135    #[error("Shard {0:?} not found")]
136    ShardNotFound(ShardId),
137
138    #[error("Coordinator unavailable")]
139    CoordinatorUnavailable,
140
141    #[error("RPC error: {0}")]
142    RpcError(String),
143
144    #[error("Timeout waiting for phase {0:?}")]
145    PhaseTimeout(TickPhase),
146
147    #[error("Document routing failed for {0:?}")]
148    RoutingFailed(DocumentId),
149
150    #[error("Cross-shard edge resolution failed")]
151    EdgeResolutionFailed,
152
153    #[error("Ghost node not found: {0:?}")]
154    GhostNodeNotFound(NodeId),
155
156    #[error("Barrier synchronization failed")]
157    BarrierFailed,
158}
159
160/// Result type for distributed operations.
161pub type DistributedResult<T> = Result<T, DistributedError>;
162
163/// Phases of a distributed tick.
164///
165/// Each tick is divided into phases that must be synchronized across all shards.
166/// The coordinator ensures all shards complete each phase before proceeding.
167#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
168pub enum TickPhase {
169    /// Agents sense the substrate (read-only phase).
170    Sense,
171    /// Process agent actions (write phase).
172    Act,
173    /// Decay signals, traces, and edges (maintenance phase).
174    Decay,
175    /// Advance tick counter (finalization phase).
176    Advance,
177}
178
179impl std::fmt::Display for TickPhase {
180    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181        match self {
182            TickPhase::Sense => write!(f, "Sense"),
183            TickPhase::Act => write!(f, "Act"),
184            TickPhase::Decay => write!(f, "Decay"),
185            TickPhase::Advance => write!(f, "Advance"),
186        }
187    }
188}
189
190/// Result of completing a tick phase on a shard.
191#[derive(Debug, Clone, Serialize, Deserialize)]
192pub struct PhaseResult {
193    /// The shard that completed this phase.
194    pub shard_id: ShardId,
195    /// The phase that was completed.
196    pub phase: TickPhase,
197    /// The tick number this phase belongs to.
198    pub tick: Tick,
199    /// Cross-shard edges created this phase (need ghost resolution).
200    pub cross_shard_edges: Vec<CrossShardEdge>,
201    /// Local node count after this phase.
202    pub node_count: usize,
203    /// Local edge count after this phase.
204    pub edge_count: usize,
205}
206
207/// A cross-shard edge reference.
208///
209/// When an edge is created that spans two shards, the local shard stores
210/// the edge with a ghost node reference. This struct captures the information
211/// needed to resolve that ghost node on the remote shard.
212#[derive(Debug, Clone, Serialize, Deserialize)]
213pub struct CrossShardEdge {
214    /// The local node (source of the edge).
215    pub from_node: NodeId,
216    /// The remote node (target of the edge).
217    pub to_node: NodeId,
218    /// The shard that owns the target node.
219    pub to_shard: ShardId,
220    /// The edge weight.
221    pub weight: f64,
222}
223
224/// Request for a local query on a shard.
225///
226/// The coordinator sends this to each shard during distributed query execution.
227/// The shard uses the global document frequencies to compute proper TF-IDF scores.
228#[derive(Debug, Clone, Serialize, Deserialize)]
229pub struct LocalQueryRequest {
230    /// The query terms to search for.
231    pub query_terms: Vec<String>,
232    /// Maximum number of results to return from this shard.
233    pub max_results: usize,
234    /// Global document frequencies (from coordinator).
235    /// Used for proper TF-IDF scoring across shards.
236    pub global_df: HashMap<String, u64>,
237}
238
239/// Result of a local query on a shard.
240#[derive(Debug, Clone, Serialize, Deserialize)]
241pub struct LocalQueryResult {
242    /// The shard that produced these results.
243    pub shard_id: ShardId,
244    /// The scored nodes matching the query.
245    pub results: Vec<ScoredNode>,
246    /// Local term frequencies for global DF computation.
247    /// Sent back to coordinator for aggregation.
248    pub term_frequencies: HashMap<String, u64>,
249}
250
251/// A node with its relevance score.
252#[derive(Debug, Clone, Serialize, Deserialize)]
253pub struct ScoredNode {
254    /// The node identifier.
255    pub node_id: NodeId,
256    /// The node's label/content.
257    pub label: String,
258    /// The relevance score (higher is better).
259    pub score: f64,
260    /// The shard this node belongs to.
261    pub shard_id: ShardId,
262}
263
264impl PartialEq for ScoredNode {
265    fn eq(&self, other: &Self) -> bool {
266        self.node_id == other.node_id && self.shard_id == other.shard_id
267    }
268}
269
270impl Eq for ScoredNode {}
271
272impl PartialOrd for ScoredNode {
273    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
274        Some(self.cmp(other))
275    }
276}
277
278impl Ord for ScoredNode {
279    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
280        // Reverse ordering for max-heap behavior (higher scores first)
281        other
282            .score
283            .partial_cmp(&self.score)
284            .unwrap_or(std::cmp::Ordering::Equal)
285    }
286}
287
288/// Health status of a shard.
289#[derive(Debug, Clone, Serialize, Deserialize)]
290pub struct ShardHealth {
291    /// The shard being reported on.
292    pub shard_id: ShardId,
293    /// Whether the shard is healthy and responsive.
294    pub healthy: bool,
295    /// Current load factor (0.0 = idle, 1.0 = fully loaded).
296    pub load: f64,
297    /// Memory usage in megabytes.
298    pub memory_usage_mb: u64,
299    /// Number of pending operations in the queue.
300    pub pending_operations: usize,
301}
302
303impl ShardHealth {
304    /// Create a healthy shard status with default values.
305    pub fn healthy(shard_id: ShardId) -> Self {
306        Self {
307            shard_id,
308            healthy: true,
309            load: 0.0,
310            memory_usage_mb: 0,
311            pending_operations: 0,
312        }
313    }
314
315    /// Create an unhealthy shard status.
316    pub fn unhealthy(shard_id: ShardId) -> Self {
317        Self {
318            shard_id,
319            healthy: false,
320            load: 0.0,
321            memory_usage_mb: 0,
322            pending_operations: 0,
323        }
324    }
325}
326
327/// A ghost node - minimal reference to a node on another shard.
328///
329/// Ghost nodes are placeholders for nodes that exist on remote shards.
330/// They enable local graph traversal to continue even when edges cross
331/// shard boundaries. The full data can be fetched lazily when needed.
332#[derive(Debug, Clone, Serialize, Deserialize)]
333pub struct GhostNode {
334    /// The actual node ID (same as on the owning shard).
335    pub node_id: NodeId,
336    /// The shard that owns this node.
337    pub shard_id: ShardId,
338    /// The node's label (cached for display/search).
339    pub label: String,
340    /// Full data fetched lazily when needed for operations.
341    pub full_data: Option<NodeData>,
342}
343
344impl GhostNode {
345    /// Create a new ghost node reference.
346    pub fn new(node_id: NodeId, shard_id: ShardId, label: String) -> Self {
347        Self {
348            node_id,
349            shard_id,
350            label,
351            full_data: None,
352        }
353    }
354
355    /// Check if the full data has been fetched.
356    pub fn is_resolved(&self) -> bool {
357        self.full_data.is_some()
358    }
359
360    /// Resolve this ghost node with full data from the remote shard.
361    pub fn resolve(&mut self, data: NodeData) {
362        self.full_data = Some(data);
363    }
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use phago_core::types::Position;

    /// `ShardId` exposes its number via the field, the accessor and `Display`.
    #[test]
    fn test_shard_id() {
        let id = ShardId::new(42);
        assert_eq!(id.0, 42);
        assert_eq!(id.as_u32(), 42);
        assert_eq!(id.to_string(), "shard-42");
    }

    /// `NodeAddress` keeps host/port and renders as `host:port`.
    #[test]
    fn test_node_address() {
        let address = NodeAddress::new("127.0.0.1", 8080);
        assert_eq!(address.host, "127.0.0.1");
        assert_eq!(address.port, 8080);
        assert_eq!(address.to_socket_addr(), "127.0.0.1:8080");
        assert_eq!(address.to_string(), "127.0.0.1:8080");
    }

    /// The default configuration carries the documented values.
    #[test]
    fn test_distributed_config_default() {
        let defaults = DistributedConfig::default();
        assert_eq!(defaults.num_shards, 3);
        assert_eq!(defaults.replication_factor, 2);
        assert_eq!(defaults.rpc_timeout_ms, 5000);
        assert_eq!(defaults.virtual_nodes_per_shard, 150);
    }

    /// Each phase displays as its variant name.
    #[test]
    fn test_tick_phase_display() {
        assert_eq!(TickPhase::Sense.to_string(), "Sense");
        assert_eq!(TickPhase::Act.to_string(), "Act");
        assert_eq!(TickPhase::Decay.to_string(), "Decay");
        assert_eq!(TickPhase::Advance.to_string(), "Advance");
    }

    /// Ordering is reversed on score, so the higher-scored node is "less".
    #[test]
    fn test_scored_node_ordering() {
        let node1 = ScoredNode {
            shard_id: ShardId::new(0),
            score: 0.9,
            label: "high".to_string(),
            node_id: NodeId::from_seed(1),
        };
        let node2 = ScoredNode {
            shard_id: ShardId::new(0),
            score: 0.1,
            label: "low".to_string(),
            node_id: NodeId::from_seed(2),
        };

        // Higher score should come first (reverse ordering)
        assert!(node1 < node2);
    }

    /// A ghost node is unresolved until full data is attached.
    #[test]
    fn test_ghost_node_resolution() {
        let mut ghost = GhostNode::new(NodeId::from_seed(1), ShardId::new(1), "test".to_string());
        assert!(!ghost.is_resolved());

        let data = NodeData {
            id: NodeId::from_seed(1),
            label: "test".to_string(),
            node_type: phago_core::types::NodeType::Concept,
            position: Position::new(0.0, 0.0),
            access_count: 0,
            created_tick: 0,
            embedding: None,
        };
        ghost.resolve(data);
        assert!(ghost.is_resolved());
    }

    /// The two `ShardHealth` constructors set the `healthy` flag accordingly.
    #[test]
    fn test_shard_health() {
        let up = ShardHealth::healthy(ShardId::new(0));
        assert!(up.healthy);
        assert_eq!(up.load, 0.0);

        let down = ShardHealth::unhealthy(ShardId::new(1));
        assert!(!down.healthy);
    }

    /// A fresh `ShardInfo` has its ID/address set and zeroed counters.
    #[test]
    fn test_shard_info_new() {
        let info = ShardInfo::new(ShardId::new(5), "127.0.0.1:8085".to_string());
        assert_eq!(info.id, ShardId::new(5));
        assert_eq!(info.address, "127.0.0.1:8085");
        assert_eq!(info.node_count, 0);
        assert_eq!(info.edge_count, 0);
        assert_eq!(info.document_count, 0);
    }

    /// `PhaseResult` is a plain record; fields read back as written.
    #[test]
    fn test_phase_result() {
        let outcome = PhaseResult {
            shard_id: ShardId::new(0),
            phase: TickPhase::Sense,
            tick: 42,
            cross_shard_edges: Vec::new(),
            node_count: 100,
            edge_count: 250,
        };
        assert_eq!(outcome.tick, 42);
        assert_eq!(outcome.node_count, 100);
    }

    /// `CrossShardEdge` keeps its target shard and weight.
    #[test]
    fn test_cross_shard_edge() {
        let edge = CrossShardEdge {
            from_node: NodeId::from_seed(1),
            to_node: NodeId::from_seed(2),
            to_shard: ShardId::new(1),
            weight: 0.75,
        };
        assert_eq!(edge.to_shard, ShardId::new(1));
        assert!((edge.weight - 0.75).abs() < f64::EPSILON);
    }

    /// `LocalQueryRequest` carries the terms, limit and global DF table.
    #[test]
    fn test_local_query_request() {
        let global_df: HashMap<String, u64> = vec![
            ("rust".to_string(), 100),
            ("programming".to_string(), 200),
        ]
        .into_iter()
        .collect();

        let request = LocalQueryRequest {
            query_terms: vec!["rust".to_string(), "programming".to_string()],
            max_results: 10,
            global_df,
        };
        assert_eq!(request.query_terms.len(), 2);
        assert_eq!(request.max_results, 10);
    }
}