Skip to main content

raft_client/
client.rs

1use raft_common::types::NodeId;
2use std::collections::HashMap;
3use std::time::Duration;
4use tonic::transport::Channel;
5use tracing::debug;
6
7use crate::proto::kv::kv_service_client::KvServiceClient;
8use crate::proto::kv::{
9    DeleteRequest, DeleteResponse, GetRequest, GetResponse, PutRequest, PutResponse, RangeRequest,
10    RangeResponse,
11};
12
13/// KV client with automatic leader tracking and retry with exponential backoff.
14pub struct KvClient {
15    /// All known node addresses.
16    nodes: HashMap<NodeId, String>,
17    /// Current believed leader.
18    leader_id: Option<NodeId>,
19    /// Cached gRPC connections.
20    connections: HashMap<NodeId, KvServiceClient<Channel>>,
21    /// Maximum number of retries.
22    max_retries: usize,
23    /// Base delay for exponential backoff.
24    base_delay: Duration,
25}
26
27impl KvClient {
28    pub fn new(nodes: HashMap<NodeId, String>) -> Self {
29        Self {
30            nodes,
31            leader_id: None,
32            connections: HashMap::new(),
33            max_retries: 5,
34            base_delay: Duration::from_millis(100),
35        }
36    }
37
38    async fn get_connection(
39        &mut self,
40        node_id: NodeId,
41    ) -> Result<&mut KvServiceClient<Channel>, tonic::Status> {
42        if !self.connections.contains_key(&node_id) {
43            let addr = self
44                .nodes
45                .get(&node_id)
46                .ok_or_else(|| tonic::Status::not_found(format!("unknown node {}", node_id)))?
47                .clone();
48            let channel = Channel::from_shared(addr)
49                .map_err(|e| tonic::Status::internal(e.to_string()))?
50                .connect()
51                .await
52                .map_err(|e| tonic::Status::unavailable(e.to_string()))?;
53            self.connections
54                .insert(node_id, KvServiceClient::new(channel));
55        }
56        Ok(self.connections.get_mut(&node_id).unwrap())
57    }
58
59    /// Pick a node to try: leader if known, otherwise first node.
60    fn pick_node(&self) -> NodeId {
61        self.leader_id
62            .unwrap_or_else(|| *self.nodes.keys().next().unwrap())
63    }
64
65    /// Try the next node (round-robin through known nodes).
66    fn next_node(&self, current: NodeId) -> NodeId {
67        let ids: Vec<NodeId> = self.nodes.keys().copied().collect();
68        let pos = ids.iter().position(|&id| id == current).unwrap_or(0);
69        ids[(pos + 1) % ids.len()]
70    }
71
72    /// Parse leader hint from "not leader" error messages.
73    fn parse_leader_hint(status: &tonic::Status) -> Option<NodeId> {
74        let msg = status.message();
75        // Format: "not leader, leader is Some(2)"
76        if let Some(start) = msg.find("Some(") {
77            let rest = &msg[start + 5..];
78            if let Some(end) = rest.find(')') {
79                return rest[..end].parse().ok();
80            }
81        }
82        None
83    }
84
85    pub async fn get(&mut self, key: &[u8]) -> Result<GetResponse, tonic::Status> {
86        let mut node_id = self.pick_node();
87
88        for attempt in 0..self.max_retries {
89            let client = self.get_connection(node_id).await?;
90            match client
91                .get(GetRequest {
92                    key: key.to_vec(),
93                    linearizable: false,
94                })
95                .await
96            {
97                Ok(resp) => return Ok(resp.into_inner()),
98                Err(status) => {
99                    if status.code() == tonic::Code::FailedPrecondition {
100                        if let Some(leader) = Self::parse_leader_hint(&status) {
101                            self.leader_id = Some(leader);
102                            node_id = leader;
103                            continue;
104                        }
105                    }
106                    self.connections.remove(&node_id);
107                    node_id = self.next_node(node_id);
108                    let delay = self.base_delay * 2u32.pow(attempt as u32);
109                    debug!(attempt, delay_ms = delay.as_millis(), "Retrying");
110                    tokio::time::sleep(delay).await;
111                }
112            }
113        }
114
115        Err(tonic::Status::unavailable("all retries exhausted"))
116    }
117
118    pub async fn put(&mut self, key: &[u8], value: &[u8]) -> Result<PutResponse, tonic::Status> {
119        self.put_with_options(key, value, 0, 0).await
120    }
121
122    pub async fn put_with_options(
123        &mut self,
124        key: &[u8],
125        value: &[u8],
126        lease_id: i64,
127        ttl_seconds: i64,
128    ) -> Result<PutResponse, tonic::Status> {
129        let mut node_id = self.pick_node();
130
131        for attempt in 0..self.max_retries {
132            let client = self.get_connection(node_id).await?;
133            match client
134                .put(PutRequest {
135                    key: key.to_vec(),
136                    value: value.to_vec(),
137                    lease_id,
138                    ttl_seconds,
139                })
140                .await
141            {
142                Ok(resp) => {
143                    self.leader_id = Some(node_id);
144                    return Ok(resp.into_inner());
145                }
146                Err(status) => {
147                    if status.code() == tonic::Code::FailedPrecondition {
148                        if let Some(leader) = Self::parse_leader_hint(&status) {
149                            self.leader_id = Some(leader);
150                            node_id = leader;
151                            continue;
152                        }
153                    }
154                    self.connections.remove(&node_id);
155                    node_id = self.next_node(node_id);
156                    let delay = self.base_delay * 2u32.pow(attempt as u32);
157                    debug!(attempt, delay_ms = delay.as_millis(), "Retrying put");
158                    tokio::time::sleep(delay).await;
159                }
160            }
161        }
162
163        Err(tonic::Status::unavailable("all retries exhausted"))
164    }
165
166    pub async fn delete(&mut self, key: &[u8]) -> Result<DeleteResponse, tonic::Status> {
167        let mut node_id = self.pick_node();
168
169        for attempt in 0..self.max_retries {
170            let client = self.get_connection(node_id).await?;
171            match client.delete(DeleteRequest { key: key.to_vec() }).await {
172                Ok(resp) => {
173                    self.leader_id = Some(node_id);
174                    return Ok(resp.into_inner());
175                }
176                Err(status) => {
177                    if status.code() == tonic::Code::FailedPrecondition {
178                        if let Some(leader) = Self::parse_leader_hint(&status) {
179                            self.leader_id = Some(leader);
180                            node_id = leader;
181                            continue;
182                        }
183                    }
184                    self.connections.remove(&node_id);
185                    node_id = self.next_node(node_id);
186                    let delay = self.base_delay * 2u32.pow(attempt as u32);
187                    tokio::time::sleep(delay).await;
188                }
189            }
190        }
191
192        Err(tonic::Status::unavailable("all retries exhausted"))
193    }
194
195    pub async fn range(
196        &mut self,
197        start_key: &[u8],
198        end_key: &[u8],
199        limit: i64,
200    ) -> Result<RangeResponse, tonic::Status> {
201        let node_id = self.pick_node();
202        let client = self.get_connection(node_id).await?;
203        let resp = client
204            .range(RangeRequest {
205                start_key: start_key.to_vec(),
206                end_key: end_key.to_vec(),
207                limit,
208            })
209            .await?;
210        Ok(resp.into_inner())
211    }
212}
213
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_leader_hint_from_error() {
        let status = tonic::Status::failed_precondition("not leader, leader is Some(2)");
        assert_eq!(KvClient::parse_leader_hint(&status), Some(2));
    }

    #[test]
    fn parse_leader_hint_none() {
        let status = tonic::Status::failed_precondition("not leader, leader is None");
        assert_eq!(KvClient::parse_leader_hint(&status), None);
    }

    /// A hint whose id does not parse must be ignored rather than redirecting somewhere bogus.
    #[test]
    fn parse_leader_hint_rejects_non_numeric_id() {
        let status = tonic::Status::failed_precondition("not leader, leader is Some(abc)");
        assert_eq!(KvClient::parse_leader_hint(&status), None);
    }

    /// A truncated message with no closing parenthesis yields no hint.
    #[test]
    fn parse_leader_hint_rejects_unterminated_hint() {
        let status = tonic::Status::failed_precondition("not leader, leader is Some(2");
        assert_eq!(KvClient::parse_leader_hint(&status), None);
    }

    /// A fresh client starts with no leader guess, no open channels, and the default budget.
    #[test]
    fn new_client_starts_without_cached_state() {
        let nodes = HashMap::from([(1, "http://127.0.0.1:50051".to_string())]);
        let client = KvClient::new(nodes);
        assert!(client.leader_id.is_none());
        assert!(client.connections.is_empty());
        assert_eq!(client.max_retries, 5);
        assert_eq!(client.base_delay, Duration::from_millis(100));
    }
}