use crate::types::{
GhostNode, LocalQueryRequest, LocalQueryResult, PhaseResult, ShardHealth, ShardId, ShardInfo,
TickPhase,
};
use phago_core::types::{Document, DocumentId, NodeData, NodeId};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Result alias used by the RPC methods declared in this file: the success
/// value is `T` and the error is always [`RpcError`].
pub type RpcResult<T> = Result<T, RpcError>;
/// Error type carried in [`RpcResult`] by the shard and coordinator services.
///
/// Derives `Serialize`/`Deserialize` so it can be encoded with serde.
/// NOTE(review): with index-based serde formats the variant order may be part
/// of the wire format — confirm before reordering or inserting variants.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RpcError {
/// No shard matched the requested id; the payload is that id
/// (displayed as "Shard {id} not found").
ShardNotFound(u32),
/// The coordinator could not be reached (displayed as "Coordinator unavailable").
CoordinatorUnavailable,
/// An RPC call failed; the payload is a free-form message.
RpcFailed(String),
/// A phase timed out; the payload is displayed after "Phase timeout: "
/// (presumably the phase name — confirm against the code that builds it).
PhaseTimeout(String),
/// Document routing failed.
RoutingFailed,
/// Edge resolution failed.
EdgeResolutionFailed,
/// A ghost node could not be found.
GhostNodeNotFound,
/// Barrier synchronization failed.
BarrierFailed,
/// Any other failure; the payload is a free-form message.
Internal(String),
}
impl std::fmt::Display for RpcError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
RpcError::ShardNotFound(id) => write!(f, "Shard {} not found", id),
RpcError::CoordinatorUnavailable => write!(f, "Coordinator unavailable"),
RpcError::RpcFailed(msg) => write!(f, "RPC failed: {}", msg),
RpcError::PhaseTimeout(phase) => write!(f, "Phase timeout: {}", phase),
RpcError::RoutingFailed => write!(f, "Document routing failed"),
RpcError::EdgeResolutionFailed => write!(f, "Edge resolution failed"),
RpcError::GhostNodeNotFound => write!(f, "Ghost node not found"),
RpcError::BarrierFailed => write!(f, "Barrier synchronization failed"),
RpcError::Internal(msg) => write!(f, "Internal error: {}", msg),
}
}
}
// Marks `RpcError` as a standard error type. The body is empty, so every
// `std::error::Error` method uses its default implementation.
impl std::error::Error for RpcError {}
/// RPC interface of a shard, declared through `#[tarpc::service]`.
///
/// Every method returns [`RpcResult`], so failures surface as [`RpcError`].
/// NOTE(review): the per-method notes below are derived from names and
/// signatures only; the implementations are not in this file — confirm.
#[tarpc::service]
pub trait ShardService {
/// Takes a `Document` and yields a `DocumentId` on success
/// (presumably the id of the ingested document).
async fn ingest_document(doc: Document) -> RpcResult<DocumentId>;
/// Takes a `TickPhase` and a tick number and yields a `PhaseResult`
/// (presumably runs that phase of the given tick on this shard).
async fn tick_phase(phase: TickPhase, tick: u64) -> RpcResult<PhaseResult>;
/// Takes a `LocalQueryRequest` and yields a `LocalQueryResult`.
async fn local_query(req: LocalQueryRequest) -> RpcResult<LocalQueryResult>;
/// Takes a list of terms and yields a map from term to a `u64` count
/// (presumably this shard's frequency for each term).
async fn get_term_frequencies(terms: Vec<String>) -> RpcResult<HashMap<String, u64>>;
/// Looks up a node by id; the success value is `Option<NodeData>`
/// (presumably `None` when the shard does not hold the node).
async fn get_node(id: NodeId) -> RpcResult<Option<NodeData>>;
/// Takes no arguments and yields a `ShardHealth`.
async fn health_check() -> RpcResult<ShardHealth>;
/// Takes a list of node ids and yields a list of `GhostNode` values.
async fn resolve_ghost_nodes(node_ids: Vec<NodeId>) -> RpcResult<Vec<GhostNode>>;
/// Takes a node id and yields a list of node ids
/// (presumably the neighbors of that node).
async fn get_neighbors(node_id: NodeId) -> RpcResult<Vec<NodeId>>;
/// Takes a batch of `CrossShardSignal` values; success carries no data.
async fn receive_signals(signals: Vec<crate::rpc::messages::CrossShardSignal>)
-> RpcResult<()>;
}
/// RPC interface of the coordinator, declared through `#[tarpc::service]`.
///
/// Most methods return [`RpcResult`]; `route_document`, `route_node`,
/// `current_tick` and `list_shards` return a bare value and therefore have no
/// way to report an [`RpcError`] through their return type.
/// NOTE(review): the per-method notes below are derived from names and
/// signatures only; the implementations are not in this file — confirm.
#[tarpc::service]
pub trait CoordinatorService {
/// Takes a `ShardInfo` and yields a `ShardId` on success
/// (presumably the id assigned to the registering shard).
async fn register(info: ShardInfo) -> RpcResult<ShardId>;
/// Takes a `ShardId`; success carries no data.
async fn unregister(shard_id: ShardId) -> RpcResult<()>;
/// Takes a shard id, a phase and a tick number; success carries no data
/// (presumably reports that the shard finished that phase of the tick).
async fn phase_complete(shard_id: ShardId, phase: TickPhase, tick: u64) -> RpcResult<()>;
/// Maps a `DocumentId` to a `ShardId`. The return type is not a `Result`.
/// NOTE(review): `RpcError::RoutingFailed` exists but cannot be returned here.
async fn route_document(doc_id: DocumentId) -> ShardId;
/// Maps a `NodeId` to a `ShardId`. The return type is not a `Result`.
async fn route_node(node_id: NodeId) -> ShardId;
/// Takes a list of terms and yields a map from term to a `u64` count
/// (presumably the global document frequency of each term).
async fn get_global_df(terms: Vec<String>) -> RpcResult<HashMap<String, u64>>;
/// Takes a shard id, a phase and a tick number and yields a `bool`.
/// NOTE(review): the meaning of the boolean is not visible here — confirm.
async fn barrier_ready(shard_id: ShardId, phase: TickPhase, tick: u64) -> RpcResult<bool>;
/// Yields a `u64` tick number. The return type is not a `Result`.
async fn current_tick() -> u64;
/// Yields a list of `ShardInfo` values. The return type is not a `Result`.
async fn list_shards() -> Vec<ShardInfo>;
/// Takes no arguments and yields a `u64` on success
/// (presumably the number of the tick that was started).
async fn start_tick() -> RpcResult<u64>;
/// Takes no arguments and yields a [`TickStatus`] on success.
async fn tick_status() -> RpcResult<TickStatus>;
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TickStatus {
pub tick: u64,
pub phase: TickPhase,
pub completed_shards: Vec<ShardId>,
pub pending_shards: Vec<ShardId>,
pub tick_complete: bool,
}