
oxibonsai_runtime/distributed.rs

//! Distributed serving infrastructure for OxiBonsai.
//!
//! Provides a consistent hash ring for request routing across multiple
//! inference nodes, a node registry with health tracking, and a multi-node
//! coordinator that manages cluster topology.
//!
//! # Architecture
//!
//! ```text
//!  ┌────────────────────────────────────────────┐
//!  │  DistributedCoordinator                    │
//!  │  ┌──────────────────────────────────────┐  │
//!  │  │  NodeRegistry                        │  │
//!  │  │  ┌────────────────────────────────┐  │  │
//!  │  │  │  ConsistentHashRing            │  │  │
//!  │  │  │  [VNode, VNode, ..., VNode]    │  │  │
//!  │  │  └────────────────────────────────┘  │  │
//!  │  │  HashMap<node_id, NodeInfo>          │  │
//!  │  └──────────────────────────────────────┘  │
//!  └────────────────────────────────────────────┘
//! ```
//!
//! The hash ring uses FNV-1a (64-bit) with virtual nodes for even distribution.
//! All state is in-memory — no actual TCP connections are made.
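//!
//! # Example
//!
//! An illustrative end-to-end flow (in-memory only; the peer address below is
//! a placeholder, no connection is opened):
//!
//! ```rust
//! use oxibonsai_runtime::distributed::{CoordinatorConfig, DistributedCoordinator};
//!
//! let mut coord = DistributedCoordinator::new(CoordinatorConfig::local_default("node-0"));
//! coord.register_self();
//! coord.add_peer("127.0.0.1:8081", "node-1");
//!
//! assert_eq!(coord.cluster_size(), 2);
//! // Consistent hashing picks one of the two registered nodes for each key.
//! let addr = coord.route("request-42").expect("at least one healthy node");
//! assert!(addr == "127.0.0.1:8080" || addr == "127.0.0.1:8081");
//! ```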

use std::collections::HashMap;

// ─── FNV-1a hash ──────────────────────────────────────────────────────────────

/// FNV-1a 64-bit hash — fast, good distribution, no external deps.
///
/// Reference: <http://www.isthe.com/chongo/tech/comp/fnv/>
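///
/// # Example
///
/// A quick illustrative check of determinism; the empty-string value is simply
/// the FNV-1a 64-bit offset basis used below.
///
/// ```rust
/// use oxibonsai_runtime::distributed::fnv1a_hash;
///
/// // Deterministic: the same input always hashes to the same value.
/// assert_eq!(fnv1a_hash("node-0#3"), fnv1a_hash("node-0#3"));
/// // The empty string hashes to the offset basis (no bytes are mixed in).
/// assert_eq!(fnv1a_hash(""), 14_695_981_039_346_656_037);
/// ```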
pub fn fnv1a_hash(input: &str) -> u64 {
    const OFFSET_BASIS: u64 = 14_695_981_039_346_656_037;
    const PRIME: u64 = 1_099_511_628_211;

    let mut hash: u64 = OFFSET_BASIS;
    for byte in input.bytes() {
        hash ^= byte as u64;
        hash = hash.wrapping_mul(PRIME);
    }
    hash
}

// ─── Consistent Hash Ring ─────────────────────────────────────────────────────

/// A virtual node on the ring — maps a hash position to a backend node.
#[derive(Debug, Clone)]
pub struct VNode {
    /// Position on the ring (FNV-1a hash of `"<node_id>#<replica_index>"`).
    pub hash: u64,
    /// The real node this virtual node represents.
    pub node_id: String,
}

/// Consistent hash ring with virtual nodes for even load distribution.
///
/// Virtual nodes (`replicas` per real node) help achieve more uniform key
/// distribution even with small cluster sizes.
///
/// # Example
/// ```rust
/// use oxibonsai_runtime::distributed::ConsistentHashRing;
///
/// let mut ring = ConsistentHashRing::new(150);
/// ring.add_node("node-a");
/// ring.add_node("node-b");
/// let target = ring.get_node("my-request-key");
/// assert!(target.is_some());
/// ```
pub struct ConsistentHashRing {
    /// Virtual nodes sorted by hash — the ring's backbone.
    vnodes: Vec<VNode>,
    /// Number of virtual nodes created per real node.
    replicas: usize,
}

impl ConsistentHashRing {
    /// Create a new empty ring.
    ///
    /// `replicas` controls how many virtual nodes are placed on the ring per
    /// real node. Higher values give better distribution at the cost of memory.
    /// A value of 100–200 is typical.
    pub fn new(replicas: usize) -> Self {
        Self {
            vnodes: Vec::new(),
            replicas: replicas.max(1),
        }
    }

    /// Add a node to the ring by inserting `replicas` virtual nodes.
    ///
    /// Virtual node keys are `"<node_id>#<i>"` for `i` in `0..replicas`.
    /// After insertion the internal slice is re-sorted.
    pub fn add_node(&mut self, node_id: &str) {
        for i in 0..self.replicas {
            let key = format!("{}#{}", node_id, i);
            let hash = fnv1a_hash(&key);
            self.vnodes.push(VNode {
                hash,
                node_id: node_id.to_string(),
            });
        }
        self.vnodes.sort_unstable_by_key(|v| v.hash);
    }

    /// Remove all virtual nodes belonging to `node_id` from the ring.
    pub fn remove_node(&mut self, node_id: &str) {
        self.vnodes.retain(|v| v.node_id != node_id);
    }

    /// Route a key to the first virtual node at or after its hash position.
    ///
    /// Returns `None` if the ring is empty.
    pub fn get_node(&self, key: &str) -> Option<&str> {
        if self.vnodes.is_empty() {
            return None;
        }
        let hash = fnv1a_hash(key);
        // Binary search for the first vnode with hash >= key hash.
        let idx = self.vnodes.partition_point(|v| v.hash < hash) % self.vnodes.len();
        Some(&self.vnodes[idx].node_id)
    }

    /// Route a key and return up to `count` distinct real nodes in ring order.
    ///
    /// Useful for replication — returns the first `count` unique node IDs
    /// encountered walking clockwise from the key's position.
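    ///
    /// # Example
    ///
    /// A small illustrative fan-out for two replicas:
    ///
    /// ```rust
    /// use oxibonsai_runtime::distributed::ConsistentHashRing;
    ///
    /// let mut ring = ConsistentHashRing::new(100);
    /// ring.add_node("node-a");
    /// ring.add_node("node-b");
    /// ring.add_node("node-c");
    ///
    /// // Two distinct real nodes come back, in ring order from the key's position.
    /// let replicas = ring.get_nodes("session-17", 2);
    /// assert_eq!(replicas.len(), 2);
    /// assert_ne!(replicas[0], replicas[1]);
    /// ```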
    pub fn get_nodes(&self, key: &str, count: usize) -> Vec<&str> {
        if self.vnodes.is_empty() || count == 0 {
            return Vec::new();
        }
        let hash = fnv1a_hash(key);
        let start = self.vnodes.partition_point(|v| v.hash < hash) % self.vnodes.len();

        let mut result: Vec<&str> = Vec::with_capacity(count);
        let total = self.vnodes.len();

        for offset in 0..total {
            let idx = (start + offset) % total;
            let node_id = self.vnodes[idx].node_id.as_str();
            if !result.contains(&node_id) {
                result.push(node_id);
            }
            if result.len() >= count {
                break;
            }
        }
        result
    }

    /// Number of distinct real nodes currently on the ring.
    pub fn node_count(&self) -> usize {
        let mut seen: Vec<&str> = Vec::new();
        for v in &self.vnodes {
            let s = v.node_id.as_str();
            if !seen.contains(&s) {
                seen.push(s);
            }
        }
        seen.len()
    }

    /// Total number of virtual nodes on the ring (`replicas × node_count`).
    pub fn vnode_count(&self) -> usize {
        self.vnodes.len()
    }
}

// ─── Node Registry ────────────────────────────────────────────────────────────

/// Runtime information about a single serving node.
#[derive(Clone, Debug)]
pub struct NodeInfo {
    /// Unique node identifier (e.g. `"node-0"`, `"gpu-west-1"`).
    pub id: String,
    /// Network address, `"host:port"` format (informational only).
    pub addr: String,
    /// Whether the node passed its most recent health check.
    pub healthy: bool,
    /// Normalised load factor in `[0.0, 1.0]` — 0 is idle, 1 is saturated.
    pub load: f32,
    /// Epoch-milliseconds timestamp of the last heartbeat/update.
    pub last_seen_ms: u64,
}

impl NodeInfo {
    /// Construct a healthy node with zero load.
    pub fn new(id: impl Into<String>, addr: impl Into<String>) -> Self {
        Self {
            id: id.into(),
            addr: addr.into(),
            healthy: true,
            load: 0.0,
            last_seen_ms: current_time_ms(),
        }
    }
}

/// Cluster membership store backed by a consistent hash ring.
///
/// Maintains a live map of [`NodeInfo`] and mirrors add/remove operations
/// into a [`ConsistentHashRing`] so routing decisions stay in sync.
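///
/// # Example
///
/// A minimal illustrative membership round-trip (addresses are placeholders):
///
/// ```rust
/// use oxibonsai_runtime::distributed::{NodeInfo, NodeRegistry};
///
/// let mut registry = NodeRegistry::new();
/// registry.register(NodeInfo::new("node-a", "10.0.0.1:9000"));
/// registry.register(NodeInfo::new("node-b", "10.0.0.2:9000"));
/// assert_eq!(registry.healthy_nodes().len(), 2);
///
/// // Deregistering removes the node from both the map and the ring.
/// registry.deregister("node-a");
/// assert_eq!(registry.ring().node_count(), 1);
/// ```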
pub struct NodeRegistry {
    nodes: HashMap<String, NodeInfo>,
    ring: ConsistentHashRing,
}

impl NodeRegistry {
    /// Create an empty registry with 150 virtual nodes per real node.
    pub fn new() -> Self {
        Self {
            nodes: HashMap::new(),
            ring: ConsistentHashRing::new(150),
        }
    }

    /// Register a node (or overwrite an existing entry with the same ID).
    pub fn register(&mut self, info: NodeInfo) {
        let id = info.id.clone();
        // If already present, remove its old vnodes before re-adding.
        if self.nodes.contains_key(&id) {
            self.ring.remove_node(&id);
        }
        self.ring.add_node(&id);
        self.nodes.insert(id, info);
    }

    /// Remove a node from the registry and the hash ring.
    pub fn deregister(&mut self, node_id: &str) {
        self.ring.remove_node(node_id);
        self.nodes.remove(node_id);
    }

    /// Update the health status of a node.
    ///
    /// If `healthy` is `false` the node is kept in the registry but excluded
    /// from routing via `route_request` and `healthy_nodes`.
    pub fn mark_healthy(&mut self, node_id: &str, healthy: bool) {
        if let Some(node) = self.nodes.get_mut(node_id) {
            node.healthy = healthy;
            node.last_seen_ms = current_time_ms();
        }
    }

    /// Update the load factor of a node. Clamped to `[0.0, 1.0]`.
    pub fn update_load(&mut self, node_id: &str, load: f32) {
        if let Some(node) = self.nodes.get_mut(node_id) {
            node.load = load.clamp(0.0, 1.0);
            node.last_seen_ms = current_time_ms();
        }
    }

    /// Route a request to a healthy node using consistent hashing.
    ///
    /// Walks the ring starting at `request_key`'s hash position and returns
    /// the first node that exists in the registry **and** is healthy.
    /// Returns `None` if there are no healthy nodes.
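    ///
    /// # Example
    ///
    /// Illustrative failover: with one node marked unhealthy, every key falls
    /// through to the remaining healthy node.
    ///
    /// ```rust
    /// use oxibonsai_runtime::distributed::{NodeInfo, NodeRegistry};
    ///
    /// let mut registry = NodeRegistry::new();
    /// registry.register(NodeInfo::new("node-a", "10.0.0.1:9000"));
    /// registry.register(NodeInfo::new("node-b", "10.0.0.2:9000"));
    ///
    /// registry.mark_healthy("node-a", false);
    /// assert_eq!(registry.route_request("any-key").unwrap().id, "node-b");
    /// ```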
    pub fn route_request(&self, request_key: &str) -> Option<&NodeInfo> {
        // Ask the ring for up to all nodes in order, then pick the first
        // healthy one.
        let candidates = self.ring.get_nodes(request_key, self.nodes.len().max(1));
        for node_id in candidates {
            if let Some(info) = self.nodes.get(node_id) {
                if info.healthy {
                    return Some(info);
                }
            }
        }
        None
    }

    /// Returns references to all healthy nodes (arbitrary order).
    pub fn healthy_nodes(&self) -> Vec<&NodeInfo> {
        self.nodes.values().filter(|n| n.healthy).collect()
    }

    /// Returns references to all registered nodes (arbitrary order).
    pub fn all_nodes(&self) -> Vec<&NodeInfo> {
        self.nodes.values().collect()
    }

    /// Expose a reference to the underlying ring (read-only).
    pub fn ring(&self) -> &ConsistentHashRing {
        &self.ring
    }
}

impl Default for NodeRegistry {
    fn default() -> Self {
        Self::new()
    }
}

// ─── Multi-node Coordinator ───────────────────────────────────────────────────

/// Configuration for a [`DistributedCoordinator`] instance.
#[derive(Debug, Clone)]
pub struct CoordinatorConfig {
    /// Identifier for *this* node.
    pub node_id: String,
    /// Address this node listens on (`"host:port"`).
    pub bind_addr: String,
    /// Peer addresses to seed the cluster with.
    pub peers: Vec<String>,
    /// How often to send heartbeats (milliseconds).
    pub heartbeat_interval_ms: u64,
    /// Age after which a node is considered unhealthy (milliseconds).
    pub health_timeout_ms: u64,
}

impl CoordinatorConfig {
    /// Sensible defaults for a single-node development setup.
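    ///
    /// # Example
    ///
    /// A quick look at the defaults:
    ///
    /// ```rust
    /// use oxibonsai_runtime::distributed::CoordinatorConfig;
    ///
    /// let cfg = CoordinatorConfig::local_default("dev-node");
    /// assert_eq!(cfg.bind_addr, "127.0.0.1:8080");
    /// assert!(cfg.peers.is_empty());
    /// ```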
    pub fn local_default(node_id: impl Into<String>) -> Self {
        Self {
            node_id: node_id.into(),
            bind_addr: "127.0.0.1:8080".to_string(),
            peers: Vec::new(),
            heartbeat_interval_ms: 1_000,
            health_timeout_ms: 5_000,
        }
    }
}

/// In-memory multi-node coordinator.
///
/// Manages cluster topology, routes incoming requests to healthy nodes via
/// consistent hashing, and exposes cluster health information.
///
/// **Note:** This implementation is intentionally in-memory only — no actual
/// TCP connections are established. It is designed for unit testing and
/// single-process simulation. Production deployments would wrap this with a
/// gRPC/HTTP gossip layer.
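///
/// # Example
///
/// Illustrative health-based failover (addresses are placeholders; no
/// connections are opened):
///
/// ```rust
/// use oxibonsai_runtime::distributed::{CoordinatorConfig, DistributedCoordinator};
///
/// let mut coord = DistributedCoordinator::new(CoordinatorConfig::local_default("node-0"));
/// coord.register_self();
/// coord.add_peer("127.0.0.1:8081", "node-1");
///
/// // With the peer marked unhealthy, all keys route back to this node.
/// coord.set_peer_health("node-1", false);
/// assert_eq!(coord.healthy_count(), 1);
/// assert_eq!(coord.route("some-key"), Some("127.0.0.1:8080".to_string()));
/// ```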
pub struct DistributedCoordinator {
    config: CoordinatorConfig,
    registry: NodeRegistry,
}

impl DistributedCoordinator {
    /// Create a new coordinator with the given configuration.
    ///
    /// Does not automatically register `self` — call `register_self` to
    /// add this node to the ring.
    pub fn new(config: CoordinatorConfig) -> Self {
        Self {
            config,
            registry: NodeRegistry::new(),
        }
    }

    /// Register this node in the local registry so it participates in routing.
    pub fn register_self(&mut self) {
        let info = NodeInfo::new(self.config.node_id.clone(), self.config.bind_addr.clone());
        self.registry.register(info);
    }

    /// Add a peer to the registry as a healthy node with zero load.
    ///
    /// `addr` is the peer's `"host:port"` bind address.
    /// `node_id` is the peer's unique identifier.
    pub fn add_peer(&mut self, addr: &str, node_id: &str) {
        let info = NodeInfo::new(node_id, addr);
        self.registry.register(info);
    }

    /// Route `request_key` to a healthy node and return its address.
    ///
    /// Returns `None` if no healthy nodes are available.
    pub fn route(&self, request_key: &str) -> Option<String> {
        self.registry
            .route_request(request_key)
            .map(|n| n.addr.clone())
    }

    /// Total number of nodes registered in the cluster (healthy + unhealthy).
    pub fn cluster_size(&self) -> usize {
        self.registry.all_nodes().len()
    }

    /// Number of nodes currently marked as healthy.
    pub fn healthy_count(&self) -> usize {
        self.registry.healthy_nodes().len()
    }

    /// A human-readable summary of current cluster topology.
    ///
    /// Format (not stable across versions):
    /// ```text
    /// cluster[nodes=3 healthy=2 vnodes=450 self=node-0]
    /// ```
    pub fn topology_summary(&self) -> String {
        let total = self.cluster_size();
        let healthy = self.healthy_count();
        let vnodes = self.registry.ring().vnode_count();
        let self_id = &self.config.node_id;
        format!("cluster[nodes={total} healthy={healthy} vnodes={vnodes} self={self_id}]")
    }

    /// Access the underlying registry (read-only).
    pub fn registry(&self) -> &NodeRegistry {
        &self.registry
    }

    /// Access the coordinator config.
    pub fn config(&self) -> &CoordinatorConfig {
        &self.config
    }

    /// Mark a peer node as healthy or unhealthy.
    pub fn set_peer_health(&mut self, node_id: &str, healthy: bool) {
        self.registry.mark_healthy(node_id, healthy);
    }

    /// Update a peer's reported load factor.
    pub fn update_peer_load(&mut self, node_id: &str, load: f32) {
        self.registry.update_load(node_id, load);
    }
}

// ─── Helpers ──────────────────────────────────────────────────────────────────

/// Return current wall-clock time in milliseconds since the Unix epoch.
///
/// Falls back to `0` if the system clock is before the epoch (unlikely).
fn current_time_ms() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| d.as_millis() as u64)
        .unwrap_or(0)
}

// ─── Unit tests ───────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn fnv1a_deterministic() {
        assert_eq!(fnv1a_hash("hello"), fnv1a_hash("hello"));
        assert_eq!(fnv1a_hash(""), fnv1a_hash(""));
    }

    #[test]
    fn fnv1a_different_inputs() {
        assert_ne!(fnv1a_hash("foo"), fnv1a_hash("bar"));
        assert_ne!(fnv1a_hash("node-0"), fnv1a_hash("node-1"));
    }

    #[test]
    fn hash_ring_empty_returns_none() {
        let ring = ConsistentHashRing::new(10);
        assert!(ring.get_node("any-key").is_none());
    }

    #[test]
    fn hash_ring_single_node_always_routes_there() {
        let mut ring = ConsistentHashRing::new(10);
        ring.add_node("solo");
        for key in &["a", "b", "c", "hello", "world", "12345"] {
            assert_eq!(ring.get_node(key), Some("solo"));
        }
    }

    #[test]
    fn hash_ring_vnode_count_equals_replicas_times_nodes() {
        let mut ring = ConsistentHashRing::new(50);
        ring.add_node("n1");
        assert_eq!(ring.vnode_count(), 50);
        ring.add_node("n2");
        assert_eq!(ring.vnode_count(), 100);
        ring.add_node("n3");
        assert_eq!(ring.vnode_count(), 150);
    }

    #[test]
    fn hash_ring_node_count() {
        let mut ring = ConsistentHashRing::new(10);
        assert_eq!(ring.node_count(), 0);
        ring.add_node("a");
        assert_eq!(ring.node_count(), 1);
        ring.add_node("b");
        assert_eq!(ring.node_count(), 2);
    }
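
    // Additional illustrative checks (a sketch layered on the in-memory API
    // above): removal reroutes keys, and routing skips unhealthy nodes.

    #[test]
    fn hash_ring_remove_node_reroutes_to_remaining() {
        let mut ring = ConsistentHashRing::new(50);
        ring.add_node("n1");
        ring.add_node("n2");

        // Whatever node the key maps to, removing that node must reroute the
        // key to the only remaining node.
        let first = ring.get_node("some-key").unwrap().to_string();
        ring.remove_node(&first);
        let second = ring.get_node("some-key").unwrap();
        assert_ne!(first, second);
        assert_eq!(ring.node_count(), 1);
    }

    #[test]
    fn registry_route_skips_unhealthy_nodes() {
        let mut reg = NodeRegistry::new();
        reg.register(NodeInfo::new("a", "127.0.0.1:1"));
        reg.register(NodeInfo::new("b", "127.0.0.1:2"));
        reg.mark_healthy("a", false);

        // Only "b" is healthy, so every key must route to it.
        for key in &["x", "y", "z"] {
            assert_eq!(reg.route_request(key).unwrap().id, "b");
        }
    }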
}