use phago_core::types::{DocumentId, NodeData, NodeId, Tick};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use thiserror::Error;
/// Identifier of a shard: a newtype over a raw `u32`.
///
/// Equality, hashing and ordering are all derived from the wrapped integer,
/// and the type is `Copy`. The inner value is public.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct ShardId(pub u32);
impl ShardId {
pub fn new(id: u32) -> Self {
Self(id)
}
pub fn as_u32(&self) -> u32 {
self.0
}
}
impl std::fmt::Display for ShardId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "shard-{}", self.0)
}
}
/// Host/port pair naming a network endpoint.
///
/// Two addresses are equal (and hash alike) when both `host` and `port`
/// match; hosts are compared as plain strings.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct NodeAddress {
    /// Host part of the address.
    pub host: String,
    /// Port part of the address.
    pub port: u16,
}
impl NodeAddress {
    /// Builds an address from anything convertible into a `String` host and
    /// a port number.
    pub fn new(host: impl Into<String>, port: u16) -> Self {
        Self {
            host: host.into(),
            port,
        }
    }

    /// Returns the address as a `host:port` string.
    ///
    /// The output is identical to the [`Display`](std::fmt::Display)
    /// rendering: a host that contains `:` (an IPv6 literal such as `::1`)
    /// and is not already bracketed is wrapped in `[...]`, giving
    /// `[::1]:8080` rather than the ambiguous `::1:8080`.
    pub fn to_socket_addr(&self) -> String {
        self.to_string()
    }

    /// Reports whether `host` must be wrapped in square brackets to keep the
    /// `host:port` form unambiguous.
    fn host_needs_brackets(&self) -> bool {
        // A plain hostname or IPv4 literal never contains ':', so a colon
        // means an IPv6 literal; leave hosts that are already bracketed alone.
        self.host.contains(':') && !self.host.starts_with('[')
    }
}

impl std::fmt::Display for NodeAddress {
    /// Writes `host:port`, bracketing the host when it contains `:` and is
    /// not already bracketed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if self.host_needs_brackets() {
            write!(f, "[{}]:{}", self.host, self.port)
        } else {
            write!(f, "{}:{}", self.host, self.port)
        }
    }
}
/// Tunable settings for the distributed layer.
///
/// The `Default` implementation in this file supplies 3 shards, a
/// replication factor of 2, a 5000 ms RPC timeout and 150 virtual nodes per
/// shard.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedConfig {
    /// Number of shards.
    pub num_shards: u32,
    /// Replication factor.
    /// NOTE(review): presumably the number of copies kept of each item — confirm.
    pub replication_factor: u32,
    /// RPC timeout; the `_ms` suffix indicates milliseconds.
    pub rpc_timeout_ms: u64,
    /// Virtual nodes per shard.
    /// NOTE(review): presumably used for a consistent-hashing ring — confirm.
    pub virtual_nodes_per_shard: u32,
}
impl Default for DistributedConfig {
fn default() -> Self {
Self {
num_shards: 3,
replication_factor: 2,
rpc_timeout_ms: 5000,
virtual_nodes_per_shard: 150,
}
}
}
/// State a shard can be in.
///
/// NOTE(review): nothing in this file reads or sets these states, so the
/// per-variant descriptions below are inferred from the names — confirm
/// against the code that consumes this enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ShardStatus {
    /// Presumably: the shard is up and serving.
    Online,
    /// Presumably: the shard is unreachable or shut down.
    Offline,
    /// Presumably: the shard is coming back and not yet fully available.
    Recovering,
    /// Presumably: the shard is being emptied before removal.
    Draining,
}
/// Descriptive record for one shard: its id, address and size counters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardInfo {
    /// Identifier of the shard this record describes.
    pub id: ShardId,
    /// Address of the shard as a plain string (the tests in this file use a
    /// `host:port` form).
    pub address: String,
    /// Number of nodes recorded for the shard.
    pub node_count: usize,
    /// Number of edges recorded for the shard.
    pub edge_count: usize,
    /// Number of documents recorded for the shard.
    pub document_count: usize,
    /// Last heartbeat marker; `0` for a freshly constructed record.
    /// NOTE(review): unit (tick vs. wall-clock timestamp) is not visible here — confirm.
    pub last_heartbeat: u64,
}
impl ShardInfo {
pub fn new(id: ShardId, address: String) -> Self {
Self {
id,
address,
node_count: 0,
edge_count: 0,
document_count: 0,
last_heartbeat: 0,
}
}
}
/// Errors reported by the distributed layer.
///
/// The `Display` text of each variant comes from its `#[error(...)]`
/// attribute; payloads are rendered with their `Debug` formatting except
/// for `RpcError`, whose string is shown as-is.
#[derive(Error, Debug, Clone)]
pub enum DistributedError {
    /// The referenced shard is not known.
    #[error("Shard {0:?} not found")]
    ShardNotFound(ShardId),

    /// The coordinator could not be reached.
    #[error("Coordinator unavailable")]
    CoordinatorUnavailable,

    /// An RPC failed; the payload carries the failure description.
    #[error("RPC error: {0}")]
    RpcError(String),

    /// Waiting for the given tick phase timed out.
    #[error("Timeout waiting for phase {0:?}")]
    PhaseTimeout(TickPhase),

    /// Routing failed for the given document.
    #[error("Document routing failed for {0:?}")]
    RoutingFailed(DocumentId),

    /// A cross-shard edge could not be resolved.
    #[error("Cross-shard edge resolution failed")]
    EdgeResolutionFailed,

    /// No ghost node exists for the given node id.
    #[error("Ghost node not found: {0:?}")]
    GhostNodeNotFound(NodeId),

    /// Barrier synchronization failed.
    #[error("Barrier synchronization failed")]
    BarrierFailed,
}

/// `Result` alias whose error type is fixed to [`DistributedError`].
pub type DistributedResult<T> = Result<T, DistributedError>;
/// Named phase of a tick.
///
/// NOTE(review): the variants are declared in the order Sense, Act, Decay,
/// Advance; whether that is also the execution order is not shown in this
/// file — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum TickPhase {
    /// The "Sense" phase.
    Sense,
    /// The "Act" phase.
    Act,
    /// The "Decay" phase.
    Decay,
    /// The "Advance" phase.
    Advance,
}
impl std::fmt::Display for TickPhase {
    /// Writes the variant name (`Sense`, `Act`, `Decay` or `Advance`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            TickPhase::Sense => "Sense",
            TickPhase::Act => "Act",
            TickPhase::Decay => "Decay",
            TickPhase::Advance => "Advance",
        };
        f.write_str(name)
    }
}
/// Record tying a shard, a tick phase and a tick number to a list of
/// cross-shard edges and node/edge counts.
///
/// NOTE(review): presumably what a shard reports after finishing a phase —
/// confirm against the producer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PhaseResult {
    /// Shard this result belongs to.
    pub shard_id: ShardId,
    /// Phase this result refers to.
    pub phase: TickPhase,
    /// Tick this result refers to.
    pub tick: Tick,
    /// Cross-shard edges carried with the result.
    pub cross_shard_edges: Vec<CrossShardEdge>,
    /// Node count carried with the result.
    pub node_count: usize,
    /// Edge count carried with the result.
    pub edge_count: usize,
}
/// Weighted edge whose target node is tagged with the shard holding it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrossShardEdge {
    /// Source node of the edge.
    pub from_node: NodeId,
    /// Target node of the edge.
    pub to_node: NodeId,
    /// Shard associated with the target node.
    pub to_shard: ShardId,
    /// Edge weight.
    pub weight: f64,
}
/// Query request carrying search terms, a result cap and a per-term count
/// map.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LocalQueryRequest {
    /// Terms to search for.
    pub query_terms: Vec<String>,
    /// Maximum number of results requested.
    pub max_results: usize,
    /// Map from term to a `u64` count.
    /// NOTE(review): the name suggests global document frequencies — confirm.
    pub global_df: HashMap<String, u64>,
}
/// Query response from one shard: scored nodes plus a per-term count map.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LocalQueryResult {
    /// Shard that produced this result.
    pub shard_id: ShardId,
    /// Scored nodes returned by the shard.
    pub results: Vec<ScoredNode>,
    /// Map from term to a `u64` count.
    /// NOTE(review): presumably this shard's local term frequencies — confirm.
    pub term_frequencies: HashMap<String, u64>,
}
/// A node paired with a score and the shard it came from.
///
/// The comparison traits are implemented by hand in this file: equality
/// looks at `node_id` and `shard_id`, while ordering looks at `score`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoredNode {
    /// Identifier of the scored node.
    pub node_id: NodeId,
    /// Label of the node.
    pub label: String,
    /// Score assigned to the node.
    pub score: f64,
    /// Shard the node belongs to.
    pub shard_id: ShardId,
}
impl PartialEq for ScoredNode {
    /// Two scored nodes are equal when they name the same node on the same
    /// shard; `label` and `score` are ignored.
    ///
    /// NOTE(review): the `Ord` impl for this type compares `score`, so
    /// `a == b` and `a.cmp(&b) == Equal` can disagree.
    fn eq(&self, other: &Self) -> bool {
        let same_node = self.node_id == other.node_id;
        same_node && self.shard_id == other.shard_id
    }
}

impl Eq for ScoredNode {}
impl PartialOrd for ScoredNode {
    /// Always returns `Some`, deferring to the `Ord` implementation.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
impl Ord for ScoredNode {
    /// Orders by `score`, highest first, so an ascending sort yields the
    /// best-scored nodes at the front.
    ///
    /// A NaN score compares equal to another NaN and greater than every
    /// non-NaN score, so NaN-scored nodes sort last. Treating NaN as equal to
    /// everything (the previous behaviour) is not a total order and can make
    /// sorting misbehave.
    ///
    /// NOTE(review): only `score` takes part in the ordering, whereas `==`
    /// for this type compares `node_id` and `shard_id`; the two can disagree.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Operands are swapped on purpose: a larger score must compare `Less`.
        match other.score.partial_cmp(&self.score) {
            Some(ordering) => ordering,
            // `partial_cmp` yields `None` only when at least one side is NaN;
            // `false < true` places the NaN side after the other one.
            None => self.score.is_nan().cmp(&other.score.is_nan()),
        }
    }
}
/// Health snapshot of a shard: a healthy flag plus load, memory and backlog
/// figures.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardHealth {
    /// Shard this snapshot describes.
    pub shard_id: ShardId,
    /// Whether the shard is considered healthy.
    pub healthy: bool,
    /// Load figure.
    /// NOTE(review): scale (e.g. 0.0–1.0) is not visible here — confirm.
    pub load: f64,
    /// Memory usage; the `_mb` suffix indicates megabytes.
    pub memory_usage_mb: u64,
    /// Number of pending operations.
    pub pending_operations: usize,
}
impl ShardHealth {
pub fn healthy(shard_id: ShardId) -> Self {
Self {
shard_id,
healthy: true,
load: 0.0,
memory_usage_mb: 0,
pending_operations: 0,
}
}
pub fn unhealthy(shard_id: ShardId) -> Self {
Self {
shard_id,
healthy: false,
load: 0.0,
memory_usage_mb: 0,
pending_operations: 0,
}
}
}
/// Lightweight reference to a node, carrying its id, shard and label, with
/// the full node data attached only once it has been resolved.
///
/// NOTE(review): presumably a local placeholder for a node owned by another
/// shard — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GhostNode {
    /// Identifier of the referenced node.
    pub node_id: NodeId,
    /// Shard associated with the referenced node.
    pub shard_id: ShardId,
    /// Label of the referenced node.
    pub label: String,
    /// Full node data; `None` until `resolve` has been called.
    pub full_data: Option<NodeData>,
}
impl GhostNode {
pub fn new(node_id: NodeId, shard_id: ShardId, label: String) -> Self {
Self {
node_id,
shard_id,
label,
full_data: None,
}
}
pub fn is_resolved(&self) -> bool {
self.full_data.is_some()
}
pub fn resolve(&mut self, data: NodeData) {
self.full_data = Some(data);
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use phago_core::types::Position;

    /// `ShardId` exposes its raw value and displays as `shard-<n>`.
    #[test]
    fn test_shard_id() {
        let id = ShardId::new(42);

        assert_eq!(id.0, 42);
        assert_eq!(id.as_u32(), 42);
        assert_eq!(id.to_string(), "shard-42");
    }

    /// `NodeAddress` keeps host and port and renders as `host:port`.
    #[test]
    fn test_node_address() {
        let address = NodeAddress::new("127.0.0.1", 8080);

        assert_eq!(address.host, "127.0.0.1");
        assert_eq!(address.port, 8080);
        assert_eq!(address.to_socket_addr(), "127.0.0.1:8080");
        assert_eq!(address.to_string(), "127.0.0.1:8080");
    }

    /// The default configuration carries the documented stock values.
    #[test]
    fn test_distributed_config_default() {
        let defaults = DistributedConfig::default();

        assert_eq!(defaults.num_shards, 3);
        assert_eq!(defaults.replication_factor, 2);
        assert_eq!(defaults.rpc_timeout_ms, 5000);
        assert_eq!(defaults.virtual_nodes_per_shard, 150);
    }

    /// Every `TickPhase` variant displays as its own name.
    #[test]
    fn test_tick_phase_display() {
        assert_eq!(TickPhase::Sense.to_string(), "Sense");
        assert_eq!(TickPhase::Act.to_string(), "Act");
        assert_eq!(TickPhase::Decay.to_string(), "Decay");
        assert_eq!(TickPhase::Advance.to_string(), "Advance");
    }

    /// A higher score compares as "less", i.e. sorts first.
    #[test]
    fn test_scored_node_ordering() {
        let high = ScoredNode {
            node_id: NodeId::from_seed(1),
            label: "high".to_string(),
            score: 0.9,
            shard_id: ShardId::new(0),
        };
        let low = ScoredNode {
            node_id: NodeId::from_seed(2),
            label: "low".to_string(),
            score: 0.1,
            shard_id: ShardId::new(0),
        };

        assert!(high < low);
    }

    /// A ghost node starts unresolved and becomes resolved once data is attached.
    #[test]
    fn test_ghost_node_resolution() {
        let mut ghost =
            GhostNode::new(NodeId::from_seed(1), ShardId::new(1), "test".to_string());
        assert!(!ghost.is_resolved());

        let full = NodeData {
            id: NodeId::from_seed(1),
            label: "test".to_string(),
            node_type: phago_core::types::NodeType::Concept,
            position: Position::new(0.0, 0.0),
            access_count: 0,
            created_tick: 0,
            embedding: None,
        };
        ghost.resolve(full);

        assert!(ghost.is_resolved());
    }

    /// The two `ShardHealth` constructors differ in the `healthy` flag.
    #[test]
    fn test_shard_health() {
        let up = ShardHealth::healthy(ShardId::new(0));
        assert!(up.healthy);
        assert_eq!(up.load, 0.0);

        let down = ShardHealth::unhealthy(ShardId::new(1));
        assert!(!down.healthy);
    }

    /// A new `ShardInfo` keeps its id and address and starts with zero counts.
    #[test]
    fn test_shard_info_new() {
        let record = ShardInfo::new(ShardId::new(5), "127.0.0.1:8085".to_string());

        assert_eq!(record.id, ShardId::new(5));
        assert_eq!(record.address, "127.0.0.1:8085");
        assert_eq!(record.node_count, 0);
        assert_eq!(record.edge_count, 0);
        assert_eq!(record.document_count, 0);
    }

    /// `PhaseResult` can be built field by field and read back.
    #[test]
    fn test_phase_result() {
        let outcome = PhaseResult {
            shard_id: ShardId::new(0),
            phase: TickPhase::Sense,
            tick: 42,
            cross_shard_edges: vec![],
            node_count: 100,
            edge_count: 250,
        };

        assert_eq!(outcome.tick, 42);
        assert_eq!(outcome.node_count, 100);
    }

    /// `CrossShardEdge` keeps its target shard and weight.
    #[test]
    fn test_cross_shard_edge() {
        let link = CrossShardEdge {
            from_node: NodeId::from_seed(1),
            to_node: NodeId::from_seed(2),
            to_shard: ShardId::new(1),
            weight: 0.75,
        };

        assert_eq!(link.to_shard, ShardId::new(1));
        assert!((link.weight - 0.75).abs() < f64::EPSILON);
    }

    /// `LocalQueryRequest` keeps its terms and result cap.
    #[test]
    fn test_local_query_request() {
        let mut doc_freqs = HashMap::new();
        doc_freqs.insert("rust".to_string(), 100);
        doc_freqs.insert("programming".to_string(), 200);

        let query = LocalQueryRequest {
            query_terms: vec!["rust".to_string(), "programming".to_string()],
            max_results: 10,
            global_df: doc_freqs,
        };

        assert_eq!(query.query_terms.len(), 2);
        assert_eq!(query.max_results, 10);
    }
}