1use std::collections::BTreeMap;
2use std::fmt::{Debug, Formatter};
3use std::net::SocketAddr;
4use std::sync::Arc;
5
6use crate::cluster::topology::{Node, NodeMap};
7use crate::cluster::ConnectionManager;
8use crate::cluster::Murmur3Token;
9use crate::transport::CdrsTransport;
10
11pub struct TokenMap<T: CdrsTransport + 'static, CM: ConnectionManager<T> + 'static> {
13 token_ring: BTreeMap<Murmur3Token, Arc<Node<T, CM>>>,
14}
15
16impl<T: CdrsTransport, CM: ConnectionManager<T>> Clone for TokenMap<T, CM> {
17 fn clone(&self) -> Self {
18 TokenMap {
19 token_ring: self.token_ring.clone(),
20 }
21 }
22}
23
24impl<T: CdrsTransport, CM: ConnectionManager<T>> Debug for TokenMap<T, CM> {
25 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
26 f.debug_struct("TokenMap")
27 .field("token_ring", &self.token_ring)
28 .finish()
29 }
30}
31
32impl<T: CdrsTransport, CM: ConnectionManager<T>> Default for TokenMap<T, CM> {
33 fn default() -> Self {
34 TokenMap {
35 token_ring: Default::default(),
36 }
37 }
38}
39
40impl<T: CdrsTransport, CM: ConnectionManager<T>> TokenMap<T, CM> {
41 pub fn new(nodes: &NodeMap<T, CM>) -> Self {
42 TokenMap {
43 token_ring: nodes
44 .iter()
45 .flat_map(|(_, node)| {
46 node.tokens()
47 .iter()
48 .map(move |token| (*token, node.clone()))
49 })
50 .collect(),
51 }
52 }
53
54 pub fn nodes_for_token_capped(
56 &self,
57 token: Murmur3Token,
58 replica_count: usize,
59 ) -> impl Iterator<Item = Arc<Node<T, CM>>> + '_ {
60 self.token_ring
61 .range(token..)
62 .chain(self.token_ring.iter())
63 .take(replica_count)
64 .map(|(_, node)| node.clone())
65 }
66
67 pub fn nodes_for_token(
69 &self,
70 token: Murmur3Token,
71 ) -> impl Iterator<Item = Arc<Node<T, CM>>> + '_ {
72 self.token_ring
73 .range(token..)
74 .chain(self.token_ring.iter())
75 .take(self.token_ring.len())
76 .map(|(_, node)| node.clone())
77 }
78
79 #[must_use]
81 pub fn clone_with_node(&self, node: Arc<Node<T, CM>>) -> Self {
82 let mut map = self.clone();
83 for token in node.tokens() {
84 map.token_ring.insert(*token, node.clone());
85 }
86
87 map
88 }
89
90 #[must_use]
92 pub fn clone_without_node(&self, broadcast_rpc_address: SocketAddr) -> Self {
93 let token_ring = self
94 .token_ring
95 .iter()
96 .filter_map(|(token, node)| {
97 if node.broadcast_rpc_address() == broadcast_rpc_address {
98 None
99 } else {
100 Some((*token, node.clone()))
101 }
102 })
103 .collect();
104
105 TokenMap { token_ring }
106 }
107}
108
#[cfg(test)]
mod tests {
    use cassandra_protocol::frame::Version;
    use itertools::Itertools;
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};
    use std::sync::{Arc, LazyLock};
    use tokio::sync::watch;
    use uuid::Uuid;

    use crate::cluster::connection_manager::MockConnectionManager;
    use crate::cluster::connection_pool::ConnectionPoolFactory;
    use crate::cluster::topology::{Node, NodeMap};
    use crate::cluster::Murmur3Token;
    use crate::cluster::TokenMap;
    use crate::retry::MockReconnectionPolicy;
    use crate::transport::MockCdrsTransport;

    // Random v4 ids; the tests only compare them with each other.
    static HOST_ID_1: LazyLock<Uuid> = LazyLock::new(Uuid::new_v4);
    static HOST_ID_2: LazyLock<Uuid> = LazyLock::new(Uuid::new_v4);
    static HOST_ID_3: LazyLock<Uuid> = LazyLock::new(Uuid::new_v4);

    type TestNodeMap = NodeMap<MockCdrsTransport, MockConnectionManager<MockCdrsTransport>>;

    /// Builds three mock nodes. Once sorted into a ring, their tokens are:
    ///
    /// `-2, -1, 0` -> host 1, `1, 2, 10` -> host 3, `20` -> host 2.
    ///
    /// Every node shares the same address, so nodes are told apart by host id only.
    fn prepare_nodes() -> TestNodeMap {
        let (_, keyspace_receiver) = watch::channel(None);
        let manager = MockConnectionManager::<MockCdrsTransport>::new();
        let policy = MockReconnectionPolicy::new();
        let pool_factory = Arc::new(ConnectionPoolFactory::new(
            Default::default(),
            Version::V4,
            manager,
            keyspace_receiver,
            Arc::new(policy),
        ));

        let address = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2)), 8080);
        let build_node = |host_id: Uuid, tokens: Vec<Murmur3Token>| {
            Arc::new(Node::new(
                pool_factory.clone(),
                address,
                None,
                Some(host_id),
                None,
                tokens,
                "".into(),
                "".into(),
            ))
        };

        // Host 3's tokens are listed out of order; the ring sorts them regardless.
        let layout = [
            (
                *HOST_ID_1,
                vec![
                    Murmur3Token::new(-2),
                    Murmur3Token::new(-1),
                    Murmur3Token::new(0),
                ],
            ),
            (*HOST_ID_2, vec![Murmur3Token::new(20)]),
            (
                *HOST_ID_3,
                vec![
                    Murmur3Token::new(2),
                    Murmur3Token::new(1),
                    Murmur3Token::new(10),
                ],
            ),
        ];

        let mut nodes = TestNodeMap::default();
        for (host_id, tokens) in layout {
            nodes.insert(host_id, build_node(host_id, tokens));
        }
        nodes
    }

    /// Asserts that walking the ring from `token` yields exactly the owners in `host_ids`.
    fn verify_tokens(host_ids: &[Uuid], token: Murmur3Token) {
        let token_map = TokenMap::new(&prepare_nodes());
        let actual = token_map
            .nodes_for_token_capped(token, host_ids.len())
            .map(|node| node.host_id().unwrap())
            .collect_vec();

        assert_eq!(actual, host_ids);
    }

    #[test]
    fn should_return_replicas_in_order() {
        // Ring walk from 0: 0 (h1), 1, 2, 10 (h3), 20 (h2).
        let expected = [*HOST_ID_1, *HOST_ID_3, *HOST_ID_3, *HOST_ID_3, *HOST_ID_2];
        verify_tokens(&expected, Murmur3Token::new(0));
    }

    #[test]
    fn should_return_replicas_in_order_for_non_primary_token() {
        // 3 is not a ring token, so the walk starts at the next one: 10 (h3), 20 (h2).
        let expected = [*HOST_ID_3, *HOST_ID_2];
        verify_tokens(&expected, Murmur3Token::new(3));
    }

    #[test]
    fn should_return_replicas_in_a_ring() {
        // Ring walk from 20 wraps around: 20 (h2), -2, -1, 0 (h1), 1 (h3).
        let expected = [*HOST_ID_2, *HOST_ID_1, *HOST_ID_1, *HOST_ID_1, *HOST_ID_3];
        verify_tokens(&expected, Murmur3Token::new(20));
    }
}