pingora_ketama/
lib.rs

1// Copyright 2024 Cloudflare, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! # pingora-ketama
16//! A Rust port of the nginx consistent hashing algorithm.
17//!
18//! This crate provides a consistent hashing algorithm which is identical in
19//! behavior to [nginx consistent hashing](https://www.nginx.com/resources/wiki/modules/consistent_hash/).
20//!
21//! Using a consistent hash strategy like this is useful when one wants to
//! minimize the number of requests that need to be rehashed to different nodes
23//! when a node is added or removed.
24//!
25//! Here's a simple example of how one might use it:
26//!
27//! ```
28//! use pingora_ketama::{Bucket, Continuum};
29//!
30//! # #[allow(clippy::needless_doctest_main)]
31//! fn main() {
32//!     // Set up a continuum with a few nodes of various weight.
33//!     let mut buckets = vec![];
34//!     buckets.push(Bucket::new("127.0.0.1:12345".parse().unwrap(), 1));
35//!     buckets.push(Bucket::new("127.0.0.2:12345".parse().unwrap(), 2));
36//!     buckets.push(Bucket::new("127.0.0.3:12345".parse().unwrap(), 3));
37//!     let ring = Continuum::new(&buckets);
38//!
39//!     // Let's see what the result is for a few keys:
40//!     for key in &["some_key", "another_key", "last_key"] {
41//!         let node = ring.node(key.as_bytes()).unwrap();
42//!         println!("{}: {}:{}", key, node.ip(), node.port());
43//!     }
44//! }
45//! ```
46//!
47//! ```bash
48//! # Output:
49//! some_key: 127.0.0.3:12345
50//! another_key: 127.0.0.3:12345
51//! last_key: 127.0.0.2:12345
52//! ```
53//!
54//! We've provided a health-aware example in
55//! `pingora-ketama/examples/health_aware_selector.rs`.
56//!
57//! For a carefully crafted real-world example, see the [`pingora-load-balancing`](https://docs.rs/pingora-load-balancing)
58//! crate.
59
60use std::cmp::Ordering;
61use std::io::Write;
62use std::net::SocketAddr;
63
64use crc32fast::Hasher;
65
/// A [Bucket] represents a server for consistent hashing
///
/// A [Bucket] pairs the server's [SocketAddr] with the weight that decides how much of the
/// key space it receives.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd)]
pub struct Bucket {
    /// Address of the server this bucket stands for.
    // TODO: UDS
    node: SocketAddr,

    /// Relative share of traffic: the larger the weight, the more requests land on `node`.
    weight: u32,
}
79
80impl Bucket {
81    /// Return a new bucket with the given node and weight.
82    ///
83    /// The chance that a [Bucket] is selected is proportional to the relative weight of all [Bucket]s.
84    ///
85    /// # Panics
86    ///
87    /// This will panic if the weight is zero.
88    pub fn new(node: SocketAddr, weight: u32) -> Self {
89        assert!(weight != 0, "weight must be at least one");
90
91        Bucket { node, weight }
92    }
93}
94
/// A single position on the hash ring.
#[derive(Clone, Debug, Eq, PartialEq)]
struct Point {
    /// Index into the continuum's address table naming the server that owns this point.
    node: u32,
    /// Position of this point on the ring.
    hash: u32,
}
102
103// We only want to compare the hash when sorting, so we implement these traits by hand.
104impl Ord for Point {
105    fn cmp(&self, other: &Self) -> Ordering {
106        self.hash.cmp(&other.hash)
107    }
108}
109
110impl PartialOrd for Point {
111    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
112        Some(self.cmp(other))
113    }
114}
115
116impl Point {
117    fn new(node: u32, hash: u32) -> Self {
118        Point { node, hash }
119    }
120}
121
122/// The consistent hashing ring
123///
124/// A [Continuum] represents a ring of buckets where a node is associated with various points on
125/// the ring.
126pub struct Continuum {
127    ring: Box<[Point]>,
128    addrs: Box<[SocketAddr]>,
129}
130
131impl Continuum {
132    /// Create a new [Continuum] with the given list of buckets.
133    pub fn new(buckets: &[Bucket]) -> Self {
134        // This constant is copied from nginx. It will create 160 points per weight unit. For
135        // example, a weight of 2 will create 320 points on the ring.
136        const POINT_MULTIPLE: u32 = 160;
137
138        if buckets.is_empty() {
139            return Continuum {
140                ring: Box::new([]),
141                addrs: Box::new([]),
142            };
143        }
144
145        // The total weight is multiplied by the factor of points to create many points per node.
146        let total_weight: u32 = buckets.iter().fold(0, |sum, b| sum + b.weight);
147        let mut ring = Vec::with_capacity((total_weight * POINT_MULTIPLE) as usize);
148        let mut addrs = Vec::with_capacity(buckets.len());
149
150        for bucket in buckets {
151            let mut hasher = Hasher::new();
152
153            // We only do the following for backwards compatibility with nginx/memcache:
154            // - Convert SocketAddr to string
155            // - The hash input is as follows "HOST EMPTY PORT PREVIOUS_HASH". Spaces are only added
156            //   for readability.
157            // TODO: remove this logic and hash the literal SocketAddr once we no longer
158            // need backwards compatibility
159
160            // with_capacity = max_len(ipv6)(39) + len(null)(1) + max_len(port)(5)
161            let mut hash_bytes = Vec::with_capacity(39 + 1 + 5);
162            write!(&mut hash_bytes, "{}", bucket.node.ip()).unwrap();
163            write!(&mut hash_bytes, "\0").unwrap();
164            write!(&mut hash_bytes, "{}", bucket.node.port()).unwrap();
165            hasher.update(hash_bytes.as_ref());
166
167            // A higher weight will add more points for this node.
168            let num_points = bucket.weight * POINT_MULTIPLE;
169
170            // This is appended to the crc32 hash for each point.
171            let mut prev_hash: u32 = 0;
172            addrs.push(bucket.node);
173            let node = addrs.len() - 1;
174            for _ in 0..num_points {
175                let mut hasher = hasher.clone();
176                hasher.update(&prev_hash.to_le_bytes());
177
178                let hash = hasher.finalize();
179                ring.push(Point::new(node as u32, hash));
180                prev_hash = hash;
181            }
182        }
183
184        // Sort and remove any duplicates.
185        ring.sort_unstable();
186        ring.dedup_by(|a, b| a.hash == b.hash);
187
188        Continuum {
189            ring: ring.into_boxed_slice(),
190            addrs: addrs.into_boxed_slice(),
191        }
192    }
193
194    /// Find the associated index for the given input.
195    pub fn node_idx(&self, input: &[u8]) -> usize {
196        let hash = crc32fast::hash(input);
197
198        // The `Result` returned here is either a match or the error variant returns where the
199        // value would be inserted.
200        match self.ring.binary_search_by(|p| p.hash.cmp(&hash)) {
201            Ok(i) => i,
202            Err(i) => {
203                // We wrap around to the front if this value would be inserted at the end.
204                if i == self.ring.len() {
205                    0
206                } else {
207                    i
208                }
209            }
210        }
211    }
212
213    /// Hash the given `hash_key` to the server address.
214    pub fn node(&self, hash_key: &[u8]) -> Option<SocketAddr> {
215        self.ring
216            .get(self.node_idx(hash_key)) // should we unwrap here?
217            .map(|p| self.addrs[p.node as usize])
218    }
219
220    /// Get an iterator of nodes starting at the original hashed node of the `hash_key`.
221    ///
222    /// This function is useful to find failover servers if the original ones are offline, which is
223    /// cheaper than rebuilding the entire hash ring.
224    pub fn node_iter(&self, hash_key: &[u8]) -> NodeIterator {
225        NodeIterator {
226            idx: self.node_idx(hash_key),
227            continuum: self,
228        }
229    }
230
231    pub fn get_addr(&self, idx: &mut usize) -> Option<&SocketAddr> {
232        let point = self.ring.get(*idx);
233        if point.is_some() {
234            // only update idx for non-empty ring otherwise we will panic on modulo 0
235            *idx = (*idx + 1) % self.ring.len();
236        }
237        point.map(|p| &self.addrs[p.node as usize])
238    }
239}
240
241/// Iterator over a Continuum
242pub struct NodeIterator<'a> {
243    idx: usize,
244    continuum: &'a Continuum,
245}
246
247impl<'a> Iterator for NodeIterator<'a> {
248    type Item = &'a SocketAddr;
249
250    fn next(&mut self) -> Option<Self::Item> {
251        self.continuum.get_addr(&mut self.idx)
252    }
253}
254
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;
    use std::path::Path;

    use super::{Bucket, Continuum};

    fn get_sockaddr(ip: &str) -> SocketAddr {
        ip.parse().unwrap()
    }

    /// One bucket per address in `hosts`, all with the same `weight`.
    fn buckets_for(hosts: &[&str], weight: u32) -> Vec<Bucket> {
        hosts
            .iter()
            .map(|host| Bucket::new(get_sockaddr(host), weight))
            .collect()
    }

    /// Weight-1 buckets for `127.0.0.1:6443` through `127.0.0.{last}:6443`.
    fn local_buckets(last: u8) -> Vec<Bucket> {
        (1..=last)
            .map(|octet| Bucket::new(get_sockaddr(&format!("127.0.0.{octet}:6443")), 1))
            .collect()
    }

    /// Pull `expected.len()` items from `iter`, checking each against `expected` in order.
    fn assert_yields<'a>(iter: &mut impl Iterator<Item = &'a SocketAddr>, expected: &[&str]) {
        for (step, want) in expected.iter().enumerate() {
            assert_eq!(
                iter.next(),
                Some(&get_sockaddr(want)),
                "mismatch at step {step}"
            );
        }
    }

    #[test]
    fn consistency_after_adding_host() {
        fn assert_hosts(c: &Continuum) {
            assert_eq!(c.node(b"a"), Some(get_sockaddr("127.0.0.10:6443")));
            assert_eq!(c.node(b"b"), Some(get_sockaddr("127.0.0.5:6443")));
        }

        // Ten hosts first...
        assert_hosts(&Continuum::new(&local_buckets(10)));
        // ...then an eleventh: the keys checked above must not be reshuffled.
        assert_hosts(&Continuum::new(&local_buckets(11)));
    }

    #[test]
    fn matches_nginx_sample() {
        let c = Continuum::new(&buckets_for(&["127.0.0.1:7777", "127.0.0.1:7778"], 1));

        let expectations = [
            ("/some/path", "127.0.0.1:7778"),
            ("/some/longer/path", "127.0.0.1:7777"),
            ("/sad/zaidoon", "127.0.0.1:7778"),
            ("/g", "127.0.0.1:7777"),
            (
                "/pingora/team/is/cool/and/this/is/a/long/uri",
                "127.0.0.1:7778",
            ),
            ("/i/am/not/confident/in/this/code", "127.0.0.1:7777"),
        ];
        for &(key, want) in &expectations {
            assert_eq!(c.node(key.as_bytes()), Some(get_sockaddr(want)), "key {key}");
        }
    }

    #[test]
    fn matches_nginx_sample_data() {
        let hosts = [
            "10.0.0.1:443",
            "10.0.0.2:443",
            "10.0.0.3:443",
            "10.0.0.4:443",
            "10.0.0.5:443",
            "10.0.0.6:443",
            "10.0.0.7:443",
            "10.0.0.8:443",
            "10.0.0.9:443",
        ];
        let c = Continuum::new(&buckets_for(&hosts, 100));

        let csv_path = Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("test-data")
            .join("sample-nginx-upstream.csv");
        let mut reader = csv::ReaderBuilder::new()
            .has_headers(false)
            .from_path(csv_path)
            .unwrap();

        // Each record is `uri,upstream` as chosen by nginx.
        for record in reader.records() {
            let record = record.unwrap();
            let uri = record.get(0).unwrap();
            let upstream = record.get(1).unwrap();
            assert_eq!(c.node(uri.as_bytes()).unwrap(), get_sockaddr(upstream));
        }
    }

    #[test]
    fn node_iter() {
        let c = Continuum::new(&buckets_for(
            &["127.0.0.1:7777", "127.0.0.1:7778", "127.0.0.1:7779"],
            1,
        ));
        assert_yields(
            &mut c.node_iter(b"doghash"),
            &[
                "127.0.0.1:7778",
                "127.0.0.1:7779",
                "127.0.0.1:7779",
                "127.0.0.1:7777",
                "127.0.0.1:7777",
                "127.0.0.1:7778",
                "127.0.0.1:7778",
                "127.0.0.1:7779",
            ],
        );

        // Remove 127.0.0.1:7778: its points vanish from the walk, and the sequence above
        // continues with those entries skipped.
        let c = Continuum::new(&buckets_for(&["127.0.0.1:7777", "127.0.0.1:7779"], 1));
        assert_yields(
            &mut c.node_iter(b"doghash"),
            &[
                "127.0.0.1:7779",
                "127.0.0.1:7779",
                "127.0.0.1:7777",
                "127.0.0.1:7777",
                "127.0.0.1:7779",
            ],
        );

        // A single-bucket ring is walked as an endless cycle: after one full lap the iterator
        // is back at its starting index.
        let c = Continuum::new(&[Bucket::new(get_sockaddr("127.0.0.1:7777"), 1)]);
        let mut iter = c.node_iter(b"doghash");
        let start_idx = iter.idx;
        for _ in 0..c.ring.len() {
            assert!(iter.next().is_some());
        }
        assert_eq!(start_idx, iter.idx);
    }

    #[test]
    fn test_empty() {
        let c = Continuum::new(&[]);
        assert!(c.node(b"doghash").is_none());

        // An iterator over an empty ring stays exhausted however often it is polled.
        let mut iter = c.node_iter(b"doghash");
        for _ in 0..3 {
            assert!(iter.next().is_none());
        }
    }

    #[test]
    fn test_ipv6_ring() {
        let c = Continuum::new(&buckets_for(
            &["[::1]:7777", "[::1]:7778", "[::1]:7779"],
            1,
        ));
        assert_yields(
            &mut c.node_iter(b"doghash"),
            &[
                "[::1]:7777",
                "[::1]:7778",
                "[::1]:7777",
                "[::1]:7778",
                "[::1]:7778",
                "[::1]:7777",
                "[::1]:7779",
            ],
        );
    }
}