Skip to main content

d_engine_client/
grpc_client.rs

1use std::sync::Arc;
2
3use arc_swap::ArcSwap;
4use bytes::Bytes;
5use d_engine_proto::client::ClientReadRequest;
6use d_engine_proto::client::ClientResult;
7use d_engine_proto::client::ClientWriteRequest;
8use d_engine_proto::client::ReadConsistencyPolicy;
9use d_engine_proto::client::WatchRequest;
10use d_engine_proto::client::WatchResponse;
11use d_engine_proto::client::WriteCommand;
12use d_engine_proto::client::raft_client_service_client::RaftClientServiceClient;
13use d_engine_proto::error::ErrorCode;
14use rand::Rng;
15use rand::SeedableRng;
16use rand::rngs::StdRng;
17use tonic::codec::CompressionEncoding;
18use tonic::transport::Channel;
19use tracing::debug;
20use tracing::error;
21use tracing::warn;
22
23use super::ClientInner;
24use crate::ClientApiError;
25use crate::ClientResponseExt;
26use crate::scoped_timer::ScopedTimer;
27use d_engine_core::client::{ClientApi, ClientApiResult};
28
29/// gRPC-based key-value store client
30///
31/// Implements remote CRUD operations via gRPC protocol.
32/// All write operations use strong consistency.
33#[derive(Clone)]
34pub struct GrpcClient {
35    pub(super) client_inner: Arc<ArcSwap<ClientInner>>,
36}
37
38impl GrpcClient {
39    pub(crate) fn new(client_inner: Arc<ArcSwap<ClientInner>>) -> Self {
40        Self { client_inner }
41    }
42
43    // Convenience methods for explicit consistency levels
44    pub async fn get_linearizable(
45        &self,
46        key: impl AsRef<[u8]>,
47    ) -> std::result::Result<Option<ClientResult>, ClientApiError> {
48        self.get_with_policy(key, Some(ReadConsistencyPolicy::LinearizableRead)).await
49    }
50
51    pub async fn get_lease(
52        &self,
53        key: impl AsRef<[u8]>,
54    ) -> std::result::Result<Option<ClientResult>, ClientApiError> {
55        self.get_with_policy(key, Some(ReadConsistencyPolicy::LeaseRead)).await
56    }
57
58    pub async fn get_eventual(
59        &self,
60        key: impl AsRef<[u8]>,
61    ) -> std::result::Result<Option<ClientResult>, ClientApiError> {
62        self.get_with_policy(key, Some(ReadConsistencyPolicy::EventualConsistency))
63            .await
64    }
65
66    /// Retrieves a single key's value with explicit consistency policy
67    ///
68    /// Allows client to override server's default consistency policy for this specific request.
69    /// If server's allow_client_override is false, the override will be ignored.
70    ///
71    /// # Parameters
72    /// * `key` - The key to retrieve, accepts any type implementing `AsRef<[u8]>`
73    /// * `policy` - Explicit consistency policy for this request
74    pub async fn get_with_policy(
75        &self,
76        key: impl AsRef<[u8]>,
77        consistency_policy: Option<ReadConsistencyPolicy>,
78    ) -> std::result::Result<Option<ClientResult>, ClientApiError> {
79        // Delegate to multi-get implementation
80        let mut results =
81            self.get_multi_with_policy(std::iter::once(key), consistency_policy).await?;
82
83        // Extract single result (safe due to single-key input)
84        Ok(results.pop().unwrap_or(None))
85    }
86
87    /// Fetches multiple keys with explicit consistency policy override
88    ///
89    /// Allows client to override server's default consistency policy for this batch request.
90    /// If server's allow_client_override is false, the override will be ignored.
91    pub async fn get_multi_with_policy(
92        &self,
93        keys: impl IntoIterator<Item = impl AsRef<[u8]>>,
94        consistency_policy: Option<ReadConsistencyPolicy>,
95    ) -> std::result::Result<Vec<Option<ClientResult>>, ClientApiError> {
96        let _timer = ScopedTimer::new("client::get_multi");
97
98        let client_inner = self.client_inner.load();
99        // Convert keys to commands
100        let keys: Vec<Bytes> =
101            keys.into_iter().map(|k| Bytes::copy_from_slice(k.as_ref())).collect();
102
103        // Validate at least one key
104        if keys.is_empty() {
105            warn!("Attempted multi-get with empty key collection");
106            return Err(ErrorCode::InvalidRequest.into());
107        }
108
109        // Build request — keep a reference for result alignment after move
110        let keys_for_alignment = keys.clone();
111        let request = ClientReadRequest {
112            client_id: client_inner.client_id,
113            keys,
114            consistency_policy: consistency_policy.map(|p| p as i32),
115        };
116
117        // Select client based on policy (if specified)
118        // None means "use server default" — server default may be Linearizable,
119        // so we must send to leader to avoid rejection from followers.
120        let mut client = match consistency_policy {
121            Some(ReadConsistencyPolicy::LinearizableRead)
122            | Some(ReadConsistencyPolicy::LeaseRead)
123            | None => {
124                debug!("Using leader client for explicit consistency policy");
125                self.make_leader_client().await?
126            }
127            Some(ReadConsistencyPolicy::EventualConsistency) => {
128                debug!("Using load-balanced client for cluster default policy");
129                self.make_client().await?
130            }
131        };
132
133        // Execute request
134        match client.handle_client_read(request).await {
135            Ok(response) => {
136                debug!("Read response: {:?}", response);
137                // Server returns only results for existing keys (sparse).
138                // Reconstruct aligned vector matching input key order,
139                // filling None for keys not present in the response.
140                // Mirrors embedded_client::get_multi_with_consistency behavior.
141                let sparse = response.into_inner().into_read_results()?;
142                let results_by_key: std::collections::HashMap<bytes::Bytes, _> =
143                    sparse.into_iter().filter_map(|opt| opt.map(|r| (r.key.clone(), r))).collect();
144                Ok(keys_for_alignment.iter().map(|k| results_by_key.get(k).cloned()).collect())
145            }
146            Err(status) => {
147                error!("Read request failed: {:?}", status);
148                Err(status.into())
149            }
150        }
151    }
152
153    async fn make_leader_client(
154        &self
155    ) -> std::result::Result<RaftClientServiceClient<Channel>, ClientApiError> {
156        let client_inner = self.client_inner.load();
157
158        let channel = client_inner.pool.get_leader();
159        let mut client = RaftClientServiceClient::new(channel);
160        if client_inner.pool.config.enable_compression {
161            client = client
162                .send_compressed(CompressionEncoding::Gzip)
163                .accept_compressed(CompressionEncoding::Gzip);
164        }
165
166        Ok(client)
167    }
168
169    pub(super) async fn make_client(
170        &self
171    ) -> std::result::Result<RaftClientServiceClient<Channel>, ClientApiError> {
172        let client_inner = self.client_inner.load();
173
174        // Balance from read clients
175        let mut rng = StdRng::from_entropy();
176        let channels = client_inner.pool.get_all_channels();
177        let i = rng.gen_range(0..channels.len());
178
179        let mut client = RaftClientServiceClient::new(channels[i].clone());
180
181        if client_inner.pool.config.enable_compression {
182            client = client
183                .send_compressed(CompressionEncoding::Gzip)
184                .accept_compressed(CompressionEncoding::Gzip);
185        }
186
187        Ok(client)
188    }
189
190    /// Watch for changes to a specific key
191    ///
192    /// Returns a stream of watch events when the key's value changes.
193    /// The stream will continue until explicitly closed or a connection error occurs.
194    ///
195    /// # Arguments
196    ///
197    /// * `key` - The key to watch
198    ///
199    /// # Returns
200    ///
201    /// A streaming response that yields `WatchResponse` events
202    ///
203    /// # Errors
204    ///
205    /// Returns error if unable to establish watch connection
206    pub async fn watch(
207        &self,
208        key: impl AsRef<[u8]>,
209    ) -> ClientApiResult<tonic::Streaming<WatchResponse>> {
210        let client_inner = self.client_inner.load();
211
212        let request = WatchRequest {
213            client_id: client_inner.client_id,
214            key: Bytes::copy_from_slice(key.as_ref()),
215        };
216
217        // Watch can connect to any node (leader or follower)
218        let mut client = self.make_client().await?;
219
220        match client.watch(request).await {
221            Ok(response) => {
222                debug!("Watch stream established");
223                Ok(response.into_inner())
224            }
225            Err(status) => {
226                error!("Watch request failed: {:?}", status);
227                Err(status.into())
228            }
229        }
230    }
231}
232
233// ==================== Core ClientApi Trait Implementation ====================
234
235// Implement ClientApi trait for GrpcClient
236#[async_trait::async_trait]
237impl ClientApi for GrpcClient {
238    async fn put(
239        &self,
240        key: impl AsRef<[u8]> + Send,
241        value: impl AsRef<[u8]> + Send,
242    ) -> ClientApiResult<()> {
243        // Performance tracking for put operation
244        let _timer = ScopedTimer::new("client::put");
245
246        let client_inner = self.client_inner.load();
247
248        // Build write request with insert command
249        let command = WriteCommand::insert(
250            Bytes::copy_from_slice(key.as_ref()),
251            Bytes::copy_from_slice(value.as_ref()),
252        );
253
254        let request = ClientWriteRequest {
255            client_id: client_inner.client_id,
256            command: Some(command),
257        };
258
259        // Send write request to leader node (strong consistency required)
260        let mut client = self.make_leader_client().await?;
261        match client.handle_client_write(request).await {
262            Ok(response) => {
263                debug!("[:GrpcClient:write] response: {:?}", response);
264                let client_response = response.get_ref();
265                client_response.validate_error()
266            }
267            Err(status) => {
268                error!("[:GrpcClient:write] status: {:?}", status);
269                Err(Into::<ClientApiError>::into(ClientApiError::from(status)))
270            }
271        }
272    }
273
274    async fn put_with_ttl(
275        &self,
276        key: impl AsRef<[u8]> + Send,
277        value: impl AsRef<[u8]> + Send,
278        ttl_secs: u64,
279    ) -> ClientApiResult<()> {
280        // Performance tracking for put_with_ttl operation
281        let _timer = ScopedTimer::new("client::put_with_ttl");
282
283        let client_inner = self.client_inner.load();
284
285        // Build write request with TTL-enabled insert command
286        let command = WriteCommand::insert_with_ttl(
287            Bytes::copy_from_slice(key.as_ref()),
288            Bytes::copy_from_slice(value.as_ref()),
289            ttl_secs,
290        );
291
292        let request = ClientWriteRequest {
293            client_id: client_inner.client_id,
294            command: Some(command),
295        };
296
297        // Send write request to leader node (strong consistency required)
298        let mut client = self.make_leader_client().await?;
299        match client.handle_client_write(request).await {
300            Ok(response) => {
301                debug!("[:GrpcClient:put_with_ttl] response: {:?}", response);
302                let client_response = response.get_ref();
303                client_response.validate_error()
304            }
305            Err(status) => {
306                error!("[:GrpcClient:put_with_ttl] status: {:?}", status);
307                Err(Into::<ClientApiError>::into(ClientApiError::from(status)))
308            }
309        }
310    }
311
312    async fn get(
313        &self,
314        key: impl AsRef<[u8]> + Send,
315    ) -> ClientApiResult<Option<Bytes>> {
316        // Delegate to get_with_policy with server's default consistency policy
317        let result = self.get_with_policy(key, None).await;
318
319        match result {
320            Ok(Some(client_result)) => Ok(Some(client_result.value)),
321            Ok(None) => Ok(None),
322            Err(e) => Err(Into::<ClientApiError>::into(e)),
323        }
324    }
325
326    async fn get_multi(
327        &self,
328        keys: &[Bytes],
329    ) -> ClientApiResult<Vec<Option<Bytes>>> {
330        // Delegate to get_multi_with_policy with server's default consistency policy
331        let result = self.get_multi_with_policy(keys.iter().cloned(), None).await;
332
333        match result {
334            Ok(results) => {
335                // Extract values from ClientResult, preserving None for missing keys
336                Ok(results.into_iter().map(|opt| opt.map(|r| r.value)).collect())
337            }
338            Err(e) => Err(Into::<ClientApiError>::into(e)),
339        }
340    }
341
342    async fn delete(
343        &self,
344        key: impl AsRef<[u8]> + Send,
345    ) -> ClientApiResult<()> {
346        let client_inner = self.client_inner.load();
347
348        // Build delete request
349        let command = WriteCommand::delete(Bytes::copy_from_slice(key.as_ref()));
350
351        let request = ClientWriteRequest {
352            client_id: client_inner.client_id,
353            command: Some(command),
354        };
355
356        // Send delete request to leader node (strong consistency required)
357        let mut client = self.make_leader_client().await?;
358        match client.handle_client_write(request).await {
359            Ok(response) => {
360                debug!("[:GrpcClient:delete] response: {:?}", response);
361                let client_response = response.get_ref();
362                client_response.validate_error()
363            }
364            Err(status) => {
365                error!("[:GrpcClient:delete] status: {:?}", status);
366                Err(Into::<ClientApiError>::into(ClientApiError::from(status)))
367            }
368        }
369    }
370
371    async fn compare_and_swap(
372        &self,
373        key: impl AsRef<[u8]> + Send,
374        expected_value: Option<impl AsRef<[u8]> + Send>,
375        new_value: impl AsRef<[u8]> + Send,
376    ) -> ClientApiResult<bool> {
377        let client_inner = self.client_inner.load();
378
379        // Build CAS request
380        let expected = expected_value.map(|v| Bytes::copy_from_slice(v.as_ref()));
381        let command = WriteCommand::compare_and_swap(
382            Bytes::copy_from_slice(key.as_ref()),
383            expected,
384            Bytes::copy_from_slice(new_value.as_ref()),
385        );
386
387        let request = ClientWriteRequest {
388            client_id: client_inner.client_id,
389            command: Some(command),
390        };
391
392        // Send CAS request to leader node
393        let mut client = self.make_leader_client().await?;
394        match client.handle_client_write(request).await {
395            Ok(response) => {
396                debug!("[:GrpcClient:compare_and_swap] response: {:?}", response);
397                let client_response = response.get_ref();
398
399                // Validate no error occurred
400                client_response.validate_error()?;
401
402                // Extract CAS result (true = succeeded, false = failed comparison)
403                Ok(client_response.is_write_success())
404            }
405            Err(status) => {
406                error!("[:GrpcClient:compare_and_swap] status: {:?}", status);
407                Err(Into::<ClientApiError>::into(ClientApiError::from(status)))
408            }
409        }
410    }
411
412    async fn list_members(
413        &self
414    ) -> ClientApiResult<Vec<d_engine_proto::server::cluster::NodeMeta>> {
415        let client_inner = self.client_inner.load();
416        Ok(client_inner.pool.get_all_members())
417    }
418
419    async fn get_leader_id(&self) -> ClientApiResult<Option<u32>> {
420        let client_inner = self.client_inner.load();
421        Ok(client_inner.pool.get_leader_id())
422    }
423
424    async fn get_multi_with_policy(
425        &self,
426        keys: &[Bytes],
427        consistency_policy: Option<ReadConsistencyPolicy>,
428    ) -> ClientApiResult<Vec<Option<Bytes>>> {
429        // Explicitly call the convenience method on impl block, not trait method
430        let result =
431            <Self>::get_multi_with_policy(self, keys.iter().cloned(), consistency_policy).await;
432
433        match result {
434            Ok(results) => Ok(results.into_iter().map(|opt| opt.map(|r| r.value)).collect()),
435            Err(e) => Err(e),
436        }
437    }
438
439    async fn get_linearizable(
440        &self,
441        key: impl AsRef<[u8]> + Send,
442    ) -> ClientApiResult<Option<Bytes>> {
443        let result = self.get_with_policy(key, Some(ReadConsistencyPolicy::LinearizableRead)).await;
444
445        match result {
446            Ok(Some(client_result)) => Ok(Some(client_result.value)),
447            Ok(None) => Ok(None),
448            Err(e) => Err(e),
449        }
450    }
451
452    async fn get_lease(
453        &self,
454        key: impl AsRef<[u8]> + Send,
455    ) -> ClientApiResult<Option<Bytes>> {
456        let result = self.get_with_policy(key, Some(ReadConsistencyPolicy::LeaseRead)).await;
457
458        match result {
459            Ok(Some(client_result)) => Ok(Some(client_result.value)),
460            Ok(None) => Ok(None),
461            Err(e) => Err(e),
462        }
463    }
464
465    async fn get_eventual(
466        &self,
467        key: impl AsRef<[u8]> + Send,
468    ) -> ClientApiResult<Option<Bytes>> {
469        let result = self
470            .get_with_policy(key, Some(ReadConsistencyPolicy::EventualConsistency))
471            .await;
472
473        match result {
474            Ok(Some(client_result)) => Ok(Some(client_result.value)),
475            Ok(None) => Ok(None),
476            Err(e) => Err(e),
477        }
478    }
479}