1use raft_common::types::NodeId;
2use std::collections::HashMap;
3use std::time::Duration;
4use tonic::transport::Channel;
5use tracing::debug;
6
7use crate::proto::kv::kv_service_client::KvServiceClient;
8use crate::proto::kv::{
9 DeleteRequest, DeleteResponse, GetRequest, GetResponse, PutRequest, PutResponse, RangeRequest,
10 RangeResponse,
11};
12
13pub struct KvClient {
15 nodes: HashMap<NodeId, String>,
17 leader_id: Option<NodeId>,
19 connections: HashMap<NodeId, KvServiceClient<Channel>>,
21 max_retries: usize,
23 base_delay: Duration,
25}
26
27impl KvClient {
28 pub fn new(nodes: HashMap<NodeId, String>) -> Self {
29 Self {
30 nodes,
31 leader_id: None,
32 connections: HashMap::new(),
33 max_retries: 5,
34 base_delay: Duration::from_millis(100),
35 }
36 }
37
38 async fn get_connection(
39 &mut self,
40 node_id: NodeId,
41 ) -> Result<&mut KvServiceClient<Channel>, tonic::Status> {
42 if !self.connections.contains_key(&node_id) {
43 let addr = self
44 .nodes
45 .get(&node_id)
46 .ok_or_else(|| tonic::Status::not_found(format!("unknown node {}", node_id)))?
47 .clone();
48 let channel = Channel::from_shared(addr)
49 .map_err(|e| tonic::Status::internal(e.to_string()))?
50 .connect()
51 .await
52 .map_err(|e| tonic::Status::unavailable(e.to_string()))?;
53 self.connections
54 .insert(node_id, KvServiceClient::new(channel));
55 }
56 Ok(self.connections.get_mut(&node_id).unwrap())
57 }
58
59 fn pick_node(&self) -> NodeId {
61 self.leader_id
62 .unwrap_or_else(|| *self.nodes.keys().next().unwrap())
63 }
64
65 fn next_node(&self, current: NodeId) -> NodeId {
67 let ids: Vec<NodeId> = self.nodes.keys().copied().collect();
68 let pos = ids.iter().position(|&id| id == current).unwrap_or(0);
69 ids[(pos + 1) % ids.len()]
70 }
71
72 fn parse_leader_hint(status: &tonic::Status) -> Option<NodeId> {
74 let msg = status.message();
75 if let Some(start) = msg.find("Some(") {
77 let rest = &msg[start + 5..];
78 if let Some(end) = rest.find(')') {
79 return rest[..end].parse().ok();
80 }
81 }
82 None
83 }
84
85 pub async fn get(&mut self, key: &[u8]) -> Result<GetResponse, tonic::Status> {
86 let mut node_id = self.pick_node();
87
88 for attempt in 0..self.max_retries {
89 let client = self.get_connection(node_id).await?;
90 match client
91 .get(GetRequest {
92 key: key.to_vec(),
93 linearizable: false,
94 })
95 .await
96 {
97 Ok(resp) => return Ok(resp.into_inner()),
98 Err(status) => {
99 if status.code() == tonic::Code::FailedPrecondition {
100 if let Some(leader) = Self::parse_leader_hint(&status) {
101 self.leader_id = Some(leader);
102 node_id = leader;
103 continue;
104 }
105 }
106 self.connections.remove(&node_id);
107 node_id = self.next_node(node_id);
108 let delay = self.base_delay * 2u32.pow(attempt as u32);
109 debug!(attempt, delay_ms = delay.as_millis(), "Retrying");
110 tokio::time::sleep(delay).await;
111 }
112 }
113 }
114
115 Err(tonic::Status::unavailable("all retries exhausted"))
116 }
117
118 pub async fn put(&mut self, key: &[u8], value: &[u8]) -> Result<PutResponse, tonic::Status> {
119 self.put_with_options(key, value, 0, 0).await
120 }
121
122 pub async fn put_with_options(
123 &mut self,
124 key: &[u8],
125 value: &[u8],
126 lease_id: i64,
127 ttl_seconds: i64,
128 ) -> Result<PutResponse, tonic::Status> {
129 let mut node_id = self.pick_node();
130
131 for attempt in 0..self.max_retries {
132 let client = self.get_connection(node_id).await?;
133 match client
134 .put(PutRequest {
135 key: key.to_vec(),
136 value: value.to_vec(),
137 lease_id,
138 ttl_seconds,
139 })
140 .await
141 {
142 Ok(resp) => {
143 self.leader_id = Some(node_id);
144 return Ok(resp.into_inner());
145 }
146 Err(status) => {
147 if status.code() == tonic::Code::FailedPrecondition {
148 if let Some(leader) = Self::parse_leader_hint(&status) {
149 self.leader_id = Some(leader);
150 node_id = leader;
151 continue;
152 }
153 }
154 self.connections.remove(&node_id);
155 node_id = self.next_node(node_id);
156 let delay = self.base_delay * 2u32.pow(attempt as u32);
157 debug!(attempt, delay_ms = delay.as_millis(), "Retrying put");
158 tokio::time::sleep(delay).await;
159 }
160 }
161 }
162
163 Err(tonic::Status::unavailable("all retries exhausted"))
164 }
165
166 pub async fn delete(&mut self, key: &[u8]) -> Result<DeleteResponse, tonic::Status> {
167 let mut node_id = self.pick_node();
168
169 for attempt in 0..self.max_retries {
170 let client = self.get_connection(node_id).await?;
171 match client.delete(DeleteRequest { key: key.to_vec() }).await {
172 Ok(resp) => {
173 self.leader_id = Some(node_id);
174 return Ok(resp.into_inner());
175 }
176 Err(status) => {
177 if status.code() == tonic::Code::FailedPrecondition {
178 if let Some(leader) = Self::parse_leader_hint(&status) {
179 self.leader_id = Some(leader);
180 node_id = leader;
181 continue;
182 }
183 }
184 self.connections.remove(&node_id);
185 node_id = self.next_node(node_id);
186 let delay = self.base_delay * 2u32.pow(attempt as u32);
187 tokio::time::sleep(delay).await;
188 }
189 }
190 }
191
192 Err(tonic::Status::unavailable("all retries exhausted"))
193 }
194
195 pub async fn range(
196 &mut self,
197 start_key: &[u8],
198 end_key: &[u8],
199 limit: i64,
200 ) -> Result<RangeResponse, tonic::Status> {
201 let node_id = self.pick_node();
202 let client = self.get_connection(node_id).await?;
203 let resp = client
204 .range(RangeRequest {
205 start_key: start_key.to_vec(),
206 end_key: end_key.to_vec(),
207 limit,
208 })
209 .await?;
210 Ok(resp.into_inner())
211 }
212}
213
#[cfg(test)]
mod tests {
    use super::*;

    /// A well-formed hint yields the embedded node id.
    #[test]
    fn parse_leader_hint_from_error() {
        let status = tonic::Status::failed_precondition("not leader, leader is Some(2)");
        assert_eq!(KvClient::parse_leader_hint(&status), Some(2));
    }

    /// A message without `Some(...)` carries no hint.
    #[test]
    fn parse_leader_hint_none() {
        let status = tonic::Status::failed_precondition("not leader, leader is None");
        assert_eq!(KvClient::parse_leader_hint(&status), None);
    }

    /// Text after the closing parenthesis does not affect the parsed id.
    #[test]
    fn parse_leader_hint_ignores_trailing_text() {
        let status = tonic::Status::failed_precondition("not leader, leader is Some(7), term 3");
        assert_eq!(KvClient::parse_leader_hint(&status), Some(7));
    }

    /// A non-numeric id inside `Some(...)` is rejected rather than panicking.
    #[test]
    fn parse_leader_hint_rejects_non_numeric_id() {
        let status = tonic::Status::failed_precondition("not leader, leader is Some(abc)");
        assert_eq!(KvClient::parse_leader_hint(&status), None);
    }

    /// A hint with no closing parenthesis is rejected.
    #[test]
    fn parse_leader_hint_rejects_unterminated_hint() {
        let status = tonic::Status::failed_precondition("not leader, leader is Some(2");
        assert_eq!(KvClient::parse_leader_hint(&status), None);
    }
}