1use std::sync::Arc;
2
3use arc_swap::ArcSwap;
4use bytes::Bytes;
5use d_engine_proto::client::ClientReadRequest;
6use d_engine_proto::client::ClientResult;
7use d_engine_proto::client::ClientWriteRequest;
8use d_engine_proto::client::ReadConsistencyPolicy;
9use d_engine_proto::client::WatchRequest;
10use d_engine_proto::client::WatchResponse;
11use d_engine_proto::client::WriteCommand;
12use d_engine_proto::client::raft_client_service_client::RaftClientServiceClient;
13use d_engine_proto::error::ErrorCode;
14use rand::Rng;
15use rand::SeedableRng;
16use rand::rngs::StdRng;
17use tonic::codec::CompressionEncoding;
18use tonic::transport::Channel;
19use tracing::debug;
20use tracing::error;
21use tracing::warn;
22
23use super::ClientInner;
24use crate::ClientApiError;
25use crate::ClientResponseExt;
26use crate::scoped_timer::ScopedTimer;
27use d_engine_core::client::{ClientApi, ClientApiResult};
28
29#[derive(Clone)]
34pub struct GrpcClient {
35 pub(super) client_inner: Arc<ArcSwap<ClientInner>>,
36}
37
38impl GrpcClient {
39 pub(crate) fn new(client_inner: Arc<ArcSwap<ClientInner>>) -> Self {
40 Self { client_inner }
41 }
42
43 pub async fn get_linearizable(
45 &self,
46 key: impl AsRef<[u8]>,
47 ) -> std::result::Result<Option<ClientResult>, ClientApiError> {
48 self.get_with_policy(key, Some(ReadConsistencyPolicy::LinearizableRead)).await
49 }
50
51 pub async fn get_lease(
52 &self,
53 key: impl AsRef<[u8]>,
54 ) -> std::result::Result<Option<ClientResult>, ClientApiError> {
55 self.get_with_policy(key, Some(ReadConsistencyPolicy::LeaseRead)).await
56 }
57
58 pub async fn get_eventual(
59 &self,
60 key: impl AsRef<[u8]>,
61 ) -> std::result::Result<Option<ClientResult>, ClientApiError> {
62 self.get_with_policy(key, Some(ReadConsistencyPolicy::EventualConsistency))
63 .await
64 }
65
66 pub async fn get_with_policy(
75 &self,
76 key: impl AsRef<[u8]>,
77 consistency_policy: Option<ReadConsistencyPolicy>,
78 ) -> std::result::Result<Option<ClientResult>, ClientApiError> {
79 let mut results =
81 self.get_multi_with_policy(std::iter::once(key), consistency_policy).await?;
82
83 Ok(results.pop().unwrap_or(None))
85 }
86
87 pub async fn get_multi_with_policy(
92 &self,
93 keys: impl IntoIterator<Item = impl AsRef<[u8]>>,
94 consistency_policy: Option<ReadConsistencyPolicy>,
95 ) -> std::result::Result<Vec<Option<ClientResult>>, ClientApiError> {
96 let _timer = ScopedTimer::new("client::get_multi");
97
98 let client_inner = self.client_inner.load();
99 let keys: Vec<Bytes> =
101 keys.into_iter().map(|k| Bytes::copy_from_slice(k.as_ref())).collect();
102
103 if keys.is_empty() {
105 warn!("Attempted multi-get with empty key collection");
106 return Err(ErrorCode::InvalidRequest.into());
107 }
108
109 let keys_for_alignment = keys.clone();
111 let request = ClientReadRequest {
112 client_id: client_inner.client_id,
113 keys,
114 consistency_policy: consistency_policy.map(|p| p as i32),
115 };
116
117 let mut client = match consistency_policy {
121 Some(ReadConsistencyPolicy::LinearizableRead)
122 | Some(ReadConsistencyPolicy::LeaseRead)
123 | None => {
124 debug!("Using leader client for explicit consistency policy");
125 self.make_leader_client().await?
126 }
127 Some(ReadConsistencyPolicy::EventualConsistency) => {
128 debug!("Using load-balanced client for cluster default policy");
129 self.make_client().await?
130 }
131 };
132
133 match client.handle_client_read(request).await {
135 Ok(response) => {
136 debug!("Read response: {:?}", response);
137 let sparse = response.into_inner().into_read_results()?;
142 let results_by_key: std::collections::HashMap<bytes::Bytes, _> =
143 sparse.into_iter().filter_map(|opt| opt.map(|r| (r.key.clone(), r))).collect();
144 Ok(keys_for_alignment.iter().map(|k| results_by_key.get(k).cloned()).collect())
145 }
146 Err(status) => {
147 error!("Read request failed: {:?}", status);
148 Err(status.into())
149 }
150 }
151 }
152
153 async fn make_leader_client(
154 &self
155 ) -> std::result::Result<RaftClientServiceClient<Channel>, ClientApiError> {
156 let client_inner = self.client_inner.load();
157
158 let channel = client_inner.pool.get_leader();
159 let mut client = RaftClientServiceClient::new(channel);
160 if client_inner.pool.config.enable_compression {
161 client = client
162 .send_compressed(CompressionEncoding::Gzip)
163 .accept_compressed(CompressionEncoding::Gzip);
164 }
165
166 Ok(client)
167 }
168
169 pub(super) async fn make_client(
170 &self
171 ) -> std::result::Result<RaftClientServiceClient<Channel>, ClientApiError> {
172 let client_inner = self.client_inner.load();
173
174 let mut rng = StdRng::from_entropy();
176 let channels = client_inner.pool.get_all_channels();
177 let i = rng.gen_range(0..channels.len());
178
179 let mut client = RaftClientServiceClient::new(channels[i].clone());
180
181 if client_inner.pool.config.enable_compression {
182 client = client
183 .send_compressed(CompressionEncoding::Gzip)
184 .accept_compressed(CompressionEncoding::Gzip);
185 }
186
187 Ok(client)
188 }
189
190 pub async fn watch(
207 &self,
208 key: impl AsRef<[u8]>,
209 ) -> ClientApiResult<tonic::Streaming<WatchResponse>> {
210 let client_inner = self.client_inner.load();
211
212 let request = WatchRequest {
213 client_id: client_inner.client_id,
214 key: Bytes::copy_from_slice(key.as_ref()),
215 };
216
217 let mut client = self.make_client().await?;
219
220 match client.watch(request).await {
221 Ok(response) => {
222 debug!("Watch stream established");
223 Ok(response.into_inner())
224 }
225 Err(status) => {
226 error!("Watch request failed: {:?}", status);
227 Err(status.into())
228 }
229 }
230 }
231}
232
233#[async_trait::async_trait]
237impl ClientApi for GrpcClient {
238 async fn put(
239 &self,
240 key: impl AsRef<[u8]> + Send,
241 value: impl AsRef<[u8]> + Send,
242 ) -> ClientApiResult<()> {
243 let _timer = ScopedTimer::new("client::put");
245
246 let client_inner = self.client_inner.load();
247
248 let command = WriteCommand::insert(
250 Bytes::copy_from_slice(key.as_ref()),
251 Bytes::copy_from_slice(value.as_ref()),
252 );
253
254 let request = ClientWriteRequest {
255 client_id: client_inner.client_id,
256 command: Some(command),
257 };
258
259 let mut client = self.make_leader_client().await?;
261 match client.handle_client_write(request).await {
262 Ok(response) => {
263 debug!("[:GrpcClient:write] response: {:?}", response);
264 let client_response = response.get_ref();
265 client_response.validate_error()
266 }
267 Err(status) => {
268 error!("[:GrpcClient:write] status: {:?}", status);
269 Err(Into::<ClientApiError>::into(ClientApiError::from(status)))
270 }
271 }
272 }
273
274 async fn put_with_ttl(
275 &self,
276 key: impl AsRef<[u8]> + Send,
277 value: impl AsRef<[u8]> + Send,
278 ttl_secs: u64,
279 ) -> ClientApiResult<()> {
280 let _timer = ScopedTimer::new("client::put_with_ttl");
282
283 let client_inner = self.client_inner.load();
284
285 let command = WriteCommand::insert_with_ttl(
287 Bytes::copy_from_slice(key.as_ref()),
288 Bytes::copy_from_slice(value.as_ref()),
289 ttl_secs,
290 );
291
292 let request = ClientWriteRequest {
293 client_id: client_inner.client_id,
294 command: Some(command),
295 };
296
297 let mut client = self.make_leader_client().await?;
299 match client.handle_client_write(request).await {
300 Ok(response) => {
301 debug!("[:GrpcClient:put_with_ttl] response: {:?}", response);
302 let client_response = response.get_ref();
303 client_response.validate_error()
304 }
305 Err(status) => {
306 error!("[:GrpcClient:put_with_ttl] status: {:?}", status);
307 Err(Into::<ClientApiError>::into(ClientApiError::from(status)))
308 }
309 }
310 }
311
312 async fn get(
313 &self,
314 key: impl AsRef<[u8]> + Send,
315 ) -> ClientApiResult<Option<Bytes>> {
316 let result = self.get_with_policy(key, None).await;
318
319 match result {
320 Ok(Some(client_result)) => Ok(Some(client_result.value)),
321 Ok(None) => Ok(None),
322 Err(e) => Err(Into::<ClientApiError>::into(e)),
323 }
324 }
325
326 async fn get_multi(
327 &self,
328 keys: &[Bytes],
329 ) -> ClientApiResult<Vec<Option<Bytes>>> {
330 let result = self.get_multi_with_policy(keys.iter().cloned(), None).await;
332
333 match result {
334 Ok(results) => {
335 Ok(results.into_iter().map(|opt| opt.map(|r| r.value)).collect())
337 }
338 Err(e) => Err(Into::<ClientApiError>::into(e)),
339 }
340 }
341
342 async fn delete(
343 &self,
344 key: impl AsRef<[u8]> + Send,
345 ) -> ClientApiResult<()> {
346 let client_inner = self.client_inner.load();
347
348 let command = WriteCommand::delete(Bytes::copy_from_slice(key.as_ref()));
350
351 let request = ClientWriteRequest {
352 client_id: client_inner.client_id,
353 command: Some(command),
354 };
355
356 let mut client = self.make_leader_client().await?;
358 match client.handle_client_write(request).await {
359 Ok(response) => {
360 debug!("[:GrpcClient:delete] response: {:?}", response);
361 let client_response = response.get_ref();
362 client_response.validate_error()
363 }
364 Err(status) => {
365 error!("[:GrpcClient:delete] status: {:?}", status);
366 Err(Into::<ClientApiError>::into(ClientApiError::from(status)))
367 }
368 }
369 }
370
371 async fn compare_and_swap(
372 &self,
373 key: impl AsRef<[u8]> + Send,
374 expected_value: Option<impl AsRef<[u8]> + Send>,
375 new_value: impl AsRef<[u8]> + Send,
376 ) -> ClientApiResult<bool> {
377 let client_inner = self.client_inner.load();
378
379 let expected = expected_value.map(|v| Bytes::copy_from_slice(v.as_ref()));
381 let command = WriteCommand::compare_and_swap(
382 Bytes::copy_from_slice(key.as_ref()),
383 expected,
384 Bytes::copy_from_slice(new_value.as_ref()),
385 );
386
387 let request = ClientWriteRequest {
388 client_id: client_inner.client_id,
389 command: Some(command),
390 };
391
392 let mut client = self.make_leader_client().await?;
394 match client.handle_client_write(request).await {
395 Ok(response) => {
396 debug!("[:GrpcClient:compare_and_swap] response: {:?}", response);
397 let client_response = response.get_ref();
398
399 client_response.validate_error()?;
401
402 Ok(client_response.is_write_success())
404 }
405 Err(status) => {
406 error!("[:GrpcClient:compare_and_swap] status: {:?}", status);
407 Err(Into::<ClientApiError>::into(ClientApiError::from(status)))
408 }
409 }
410 }
411
412 async fn list_members(
413 &self
414 ) -> ClientApiResult<Vec<d_engine_proto::server::cluster::NodeMeta>> {
415 let client_inner = self.client_inner.load();
416 Ok(client_inner.pool.get_all_members())
417 }
418
419 async fn get_leader_id(&self) -> ClientApiResult<Option<u32>> {
420 let client_inner = self.client_inner.load();
421 Ok(client_inner.pool.get_leader_id())
422 }
423
424 async fn get_multi_with_policy(
425 &self,
426 keys: &[Bytes],
427 consistency_policy: Option<ReadConsistencyPolicy>,
428 ) -> ClientApiResult<Vec<Option<Bytes>>> {
429 let result =
431 <Self>::get_multi_with_policy(self, keys.iter().cloned(), consistency_policy).await;
432
433 match result {
434 Ok(results) => Ok(results.into_iter().map(|opt| opt.map(|r| r.value)).collect()),
435 Err(e) => Err(e),
436 }
437 }
438
439 async fn get_linearizable(
440 &self,
441 key: impl AsRef<[u8]> + Send,
442 ) -> ClientApiResult<Option<Bytes>> {
443 let result = self.get_with_policy(key, Some(ReadConsistencyPolicy::LinearizableRead)).await;
444
445 match result {
446 Ok(Some(client_result)) => Ok(Some(client_result.value)),
447 Ok(None) => Ok(None),
448 Err(e) => Err(e),
449 }
450 }
451
452 async fn get_lease(
453 &self,
454 key: impl AsRef<[u8]> + Send,
455 ) -> ClientApiResult<Option<Bytes>> {
456 let result = self.get_with_policy(key, Some(ReadConsistencyPolicy::LeaseRead)).await;
457
458 match result {
459 Ok(Some(client_result)) => Ok(Some(client_result.value)),
460 Ok(None) => Ok(None),
461 Err(e) => Err(e),
462 }
463 }
464
465 async fn get_eventual(
466 &self,
467 key: impl AsRef<[u8]> + Send,
468 ) -> ClientApiResult<Option<Bytes>> {
469 let result = self
470 .get_with_policy(key, Some(ReadConsistencyPolicy::EventualConsistency))
471 .await;
472
473 match result {
474 Ok(Some(client_result)) => Ok(Some(client_result.value)),
475 Ok(None) => Ok(None),
476 Err(e) => Err(e),
477 }
478 }
479}