1use std::time::Duration;
19
20#[cfg(feature = "watch")]
21use std::sync::Arc;
22
23use bytes::Bytes;
24use d_engine_core::MaybeCloneOneshot;
25use d_engine_core::RaftEvent;
26use d_engine_core::RaftOneshot;
27use d_engine_core::client::{ClientApi, ClientApiError, ClientApiResult};
28use d_engine_proto::client::ClientReadRequest;
29use d_engine_proto::client::ClientWriteRequest;
30use d_engine_proto::client::ReadConsistencyPolicy;
31use d_engine_proto::client::WriteCommand;
32use d_engine_proto::common::LeaderHint;
33use d_engine_proto::error::ErrorCode;
34use tokio::sync::mpsc;
35
36#[cfg(feature = "watch")]
37use d_engine_core::watch::WatchRegistry;
38
39fn channel_closed_error() -> ClientApiError {
44 ClientApiError::Network {
45 code: ErrorCode::ConnectionTimeout,
46 message: "Channel closed, node may be shutting down".to_string(),
47 retry_after_ms: None,
48 leader_hint: None,
49 }
50}
51
52fn timeout_error(duration: Duration) -> ClientApiError {
53 ClientApiError::Network {
54 code: ErrorCode::ConnectionTimeout,
55 message: format!("Operation timed out after {duration:?}"),
56 retry_after_ms: Some(1000),
57 leader_hint: None,
58 }
59}
60
61fn not_leader_error(
62 leader_id: Option<String>,
63 leader_address: Option<String>,
64) -> ClientApiError {
65 let message = match (&leader_address, &leader_id) {
66 (Some(addr), _) => format!("Not leader, try leader at: {addr}"),
67 (None, Some(id)) => format!("Not leader, leader_id: {id}"),
68 (None, None) => "Not leader".to_string(),
69 };
70
71 let leader_hint = match (&leader_id, &leader_address) {
72 (Some(id_str), Some(addr)) => id_str.parse::<u32>().ok().map(|id| LeaderHint {
73 leader_id: id,
74 address: addr.clone(),
75 }),
76 _ => None,
77 };
78
79 ClientApiError::Network {
80 code: ErrorCode::NotLeader,
81 message,
82 retry_after_ms: Some(100),
83 leader_hint,
84 }
85}
86
87fn server_error(msg: String) -> ClientApiError {
88 ClientApiError::Business {
89 code: ErrorCode::Uncategorized,
90 message: msg,
91 required_action: None,
92 }
93}
94
95#[derive(Clone)]
99pub struct EmbeddedClient {
100 event_tx: mpsc::Sender<RaftEvent>,
101 cmd_tx: mpsc::UnboundedSender<d_engine_core::ClientCmd>,
102 client_id: u32,
103 timeout: Duration,
104 #[cfg(feature = "watch")]
105 watch_registry: Option<Arc<WatchRegistry>>,
106}
107
108impl EmbeddedClient {
109 pub(crate) fn new_internal(
111 event_tx: mpsc::Sender<RaftEvent>,
112 cmd_tx: mpsc::UnboundedSender<d_engine_core::ClientCmd>,
113 client_id: u32,
114 timeout: Duration,
115 ) -> Self {
116 Self {
117 event_tx,
118 cmd_tx,
119 client_id,
120 timeout,
121 #[cfg(feature = "watch")]
122 watch_registry: None,
123 }
124 }
125
126 #[cfg(feature = "watch")]
128 pub(crate) fn with_watch_registry(
129 mut self,
130 registry: Arc<WatchRegistry>,
131 ) -> Self {
132 self.watch_registry = Some(registry);
133 self
134 }
135
136 fn map_error_response(
138 error_code: i32,
139 metadata: Option<d_engine_proto::error::ErrorMetadata>,
140 ) -> ClientApiError {
141 match ErrorCode::try_from(error_code) {
142 Ok(ErrorCode::NotLeader) => {
143 let (leader_id, leader_address) = if let Some(meta) = metadata {
144 (meta.leader_id, meta.leader_address)
145 } else {
146 (None, None)
147 };
148 not_leader_error(leader_id, leader_address)
149 }
150 _ => server_error(format!("Error code: {error_code}")),
151 }
152 }
153
154 pub async fn put(
156 &self,
157 key: impl AsRef<[u8]>,
158 value: impl AsRef<[u8]>,
159 ) -> ClientApiResult<()> {
160 let command = WriteCommand::insert(
161 Bytes::copy_from_slice(key.as_ref()),
162 Bytes::copy_from_slice(value.as_ref()),
163 );
164
165 let request = ClientWriteRequest {
166 client_id: self.client_id,
167 command: Some(command),
168 };
169
170 let (resp_tx, resp_rx) = MaybeCloneOneshot::new();
171
172 self.cmd_tx
173 .send(d_engine_core::ClientCmd::Propose(request, resp_tx))
174 .map_err(|_| channel_closed_error())?;
175
176 let result = tokio::time::timeout(self.timeout, resp_rx)
177 .await
178 .map_err(|_| timeout_error(self.timeout))?
179 .map_err(|_| channel_closed_error())?;
180
181 let response =
182 result.map_err(|status| server_error(format!("RPC error: {}", status.message())))?;
183
184 if response.error != ErrorCode::Success as i32 {
185 return Err(Self::map_error_response(response.error, response.metadata));
186 }
187
188 Ok(())
189 }
190
191 pub async fn get_linearizable(
209 &self,
210 key: impl AsRef<[u8]>,
211 ) -> ClientApiResult<Option<Bytes>> {
212 self.get_with_consistency(key, ReadConsistencyPolicy::LinearizableRead).await
213 }
214
215 pub async fn get_eventual(
235 &self,
236 key: impl AsRef<[u8]>,
237 ) -> ClientApiResult<Option<Bytes>> {
238 self.get_with_consistency(key, ReadConsistencyPolicy::EventualConsistency).await
239 }
240
241 pub async fn get_with_consistency(
260 &self,
261 key: impl AsRef<[u8]>,
262 consistency: ReadConsistencyPolicy,
263 ) -> ClientApiResult<Option<Bytes>> {
264 let request = ClientReadRequest {
265 client_id: self.client_id,
266 keys: vec![Bytes::copy_from_slice(key.as_ref())],
267 consistency_policy: Some(consistency as i32),
268 };
269
270 let (resp_tx, resp_rx) = MaybeCloneOneshot::new();
271
272 self.cmd_tx
273 .send(d_engine_core::ClientCmd::Read(request, resp_tx))
274 .map_err(|_| channel_closed_error())?;
275
276 let result = tokio::time::timeout(self.timeout, resp_rx)
277 .await
278 .map_err(|_| timeout_error(self.timeout))?
279 .map_err(|_| channel_closed_error())?;
280
281 let response =
282 result.map_err(|status| server_error(format!("RPC error: {}", status.message())))?;
283
284 if response.error != ErrorCode::Success as i32 {
285 return Err(Self::map_error_response(response.error, response.metadata));
286 }
287
288 match response.success_result {
289 Some(d_engine_proto::client::client_response::SuccessResult::ReadData(
290 read_results,
291 )) => {
292 Ok(read_results.results.first().map(|r| r.value.clone()))
295 }
296 _ => Ok(None),
297 }
298 }
299
300 pub async fn get_multi_linearizable(
310 &self,
311 keys: &[Bytes],
312 ) -> ClientApiResult<Vec<Option<Bytes>>> {
313 self.get_multi_with_consistency(keys, ReadConsistencyPolicy::LinearizableRead)
314 .await
315 }
316
317 pub async fn get_multi_eventual(
327 &self,
328 keys: &[Bytes],
329 ) -> ClientApiResult<Vec<Option<Bytes>>> {
330 self.get_multi_with_consistency(keys, ReadConsistencyPolicy::EventualConsistency)
331 .await
332 }
333
334 pub async fn get_multi_with_consistency(
336 &self,
337 keys: &[Bytes],
338 consistency: ReadConsistencyPolicy,
339 ) -> ClientApiResult<Vec<Option<Bytes>>> {
340 let request = ClientReadRequest {
341 client_id: self.client_id,
342 keys: keys.to_vec(),
343 consistency_policy: Some(consistency as i32),
344 };
345
346 let (resp_tx, resp_rx) = MaybeCloneOneshot::new();
347
348 self.cmd_tx
349 .send(d_engine_core::ClientCmd::Read(request, resp_tx))
350 .map_err(|_| channel_closed_error())?;
351
352 let result = tokio::time::timeout(self.timeout, resp_rx)
353 .await
354 .map_err(|_| timeout_error(self.timeout))?
355 .map_err(|_| channel_closed_error())?;
356
357 let response =
358 result.map_err(|status| server_error(format!("RPC error: {}", status.message())))?;
359
360 if response.error != ErrorCode::Success as i32 {
361 return Err(Self::map_error_response(response.error, response.metadata));
362 }
363
364 match response.success_result {
365 Some(d_engine_proto::client::client_response::SuccessResult::ReadData(
366 read_results,
367 )) => {
368 let results_by_key: std::collections::HashMap<_, _> =
372 read_results.results.into_iter().map(|r| (r.key, r.value)).collect();
373
374 Ok(keys.iter().map(|k| results_by_key.get(k).cloned()).collect())
375 }
376 _ => Ok(vec![None; keys.len()]),
377 }
378 }
379
380 pub async fn delete(
382 &self,
383 key: impl AsRef<[u8]>,
384 ) -> ClientApiResult<()> {
385 let command = WriteCommand::delete(Bytes::copy_from_slice(key.as_ref()));
386
387 let request = ClientWriteRequest {
388 client_id: self.client_id,
389 command: Some(command),
390 };
391
392 let (resp_tx, resp_rx) = MaybeCloneOneshot::new();
393
394 self.cmd_tx
395 .send(d_engine_core::ClientCmd::Propose(request, resp_tx))
396 .map_err(|_| channel_closed_error())?;
397
398 let result = tokio::time::timeout(self.timeout, resp_rx)
399 .await
400 .map_err(|_| timeout_error(self.timeout))?
401 .map_err(|_| channel_closed_error())?;
402
403 let response =
404 result.map_err(|status| server_error(format!("RPC error: {}", status.message())))?;
405
406 if response.error != ErrorCode::Success as i32 {
407 return Err(Self::map_error_response(response.error, response.metadata));
408 }
409
410 Ok(())
411 }
412
413 pub fn client_id(&self) -> u32 {
415 self.client_id
416 }
417
418 pub fn timeout(&self) -> Duration {
420 self.timeout
421 }
422
423 pub fn node_id(&self) -> u32 {
425 self.client_id
426 }
427
428 #[cfg(feature = "watch")]
455 pub fn watch(
456 &self,
457 key: impl AsRef<[u8]>,
458 ) -> ClientApiResult<d_engine_core::watch::WatcherHandle> {
459 let registry = self.watch_registry.as_ref().ok_or_else(|| ClientApiError::Business {
460 code: ErrorCode::Uncategorized,
461 message: "Watch feature disabled (WatchRegistry not initialized)".to_string(),
462 required_action: None,
463 })?;
464
465 let key_bytes = Bytes::copy_from_slice(key.as_ref());
466 Ok(registry.register(key_bytes))
467 }
468
469 async fn get_cluster_membership(
471 &self
472 ) -> ClientApiResult<d_engine_proto::server::cluster::ClusterMembership> {
473 let request = d_engine_proto::server::cluster::MetadataRequest {};
474
475 let (resp_tx, resp_rx) = MaybeCloneOneshot::new();
476
477 self.event_tx
478 .send(RaftEvent::ClusterConf(request, resp_tx))
479 .await
480 .map_err(|_| channel_closed_error())?;
481
482 let result = tokio::time::timeout(self.timeout, resp_rx)
483 .await
484 .map_err(|_| timeout_error(self.timeout))?
485 .map_err(|_| channel_closed_error())?;
486
487 result.map_err(|status| server_error(format!("ClusterConf error: {}", status.message())))
488 }
489}
490
491impl std::fmt::Debug for EmbeddedClient {
492 fn fmt(
493 &self,
494 f: &mut std::fmt::Formatter<'_>,
495 ) -> std::fmt::Result {
496 f.debug_struct("EmbeddedClient")
497 .field("client_id", &self.client_id)
498 .field("timeout", &self.timeout)
499 .finish()
500 }
501}
502
503#[async_trait::async_trait]
505impl ClientApi for EmbeddedClient {
506 async fn put(
507 &self,
508 key: impl AsRef<[u8]> + Send,
509 value: impl AsRef<[u8]> + Send,
510 ) -> ClientApiResult<()> {
511 self.put(key, value).await
512 }
513
514 async fn put_with_ttl(
515 &self,
516 key: impl AsRef<[u8]> + Send,
517 value: impl AsRef<[u8]> + Send,
518 ttl_secs: u64,
519 ) -> ClientApiResult<()> {
520 let command = WriteCommand::insert_with_ttl(
521 Bytes::copy_from_slice(key.as_ref()),
522 Bytes::copy_from_slice(value.as_ref()),
523 ttl_secs,
524 );
525
526 let request = ClientWriteRequest {
527 client_id: self.client_id,
528 command: Some(command),
529 };
530
531 let (resp_tx, resp_rx) = MaybeCloneOneshot::new();
532
533 self.cmd_tx
534 .send(d_engine_core::ClientCmd::Propose(request, resp_tx))
535 .map_err(|_| channel_closed_error())?;
536
537 let result = tokio::time::timeout(self.timeout, resp_rx)
538 .await
539 .map_err(|_| timeout_error(self.timeout))?
540 .map_err(|_| channel_closed_error())?;
541
542 let response =
543 result.map_err(|status| server_error(format!("RPC error: {}", status.message())))?;
544
545 if response.error != ErrorCode::Success as i32 {
546 return Err(Self::map_error_response(response.error, response.metadata));
547 }
548
549 Ok(())
550 }
551
552 async fn get(
553 &self,
554 key: impl AsRef<[u8]> + Send,
555 ) -> ClientApiResult<Option<Bytes>> {
556 self.get_linearizable(key).await
557 }
558
559 async fn get_multi(
560 &self,
561 keys: &[Bytes],
562 ) -> ClientApiResult<Vec<Option<Bytes>>> {
563 self.get_multi_linearizable(keys).await
564 }
565
566 async fn delete(
567 &self,
568 key: impl AsRef<[u8]> + Send,
569 ) -> ClientApiResult<()> {
570 self.delete(key).await
571 }
572
573 async fn compare_and_swap(
574 &self,
575 key: impl AsRef<[u8]> + Send,
576 expected_value: Option<impl AsRef<[u8]> + Send>,
577 new_value: impl AsRef<[u8]> + Send,
578 ) -> ClientApiResult<bool> {
579 let command = WriteCommand::compare_and_swap(
580 Bytes::copy_from_slice(key.as_ref()),
581 expected_value.map(|v| Bytes::copy_from_slice(v.as_ref())),
582 Bytes::copy_from_slice(new_value.as_ref()),
583 );
584
585 let request = ClientWriteRequest {
586 client_id: self.client_id,
587 command: Some(command),
588 };
589
590 let (resp_tx, resp_rx) = MaybeCloneOneshot::new();
591
592 self.cmd_tx
593 .send(d_engine_core::ClientCmd::Propose(request, resp_tx))
594 .map_err(|_| channel_closed_error())?;
595
596 let result = tokio::time::timeout(self.timeout, resp_rx)
597 .await
598 .map_err(|_| timeout_error(self.timeout))?
599 .map_err(|_| channel_closed_error())?;
600
601 let response =
602 result.map_err(|status| server_error(format!("RPC error: {}", status.message())))?;
603
604 if response.error != ErrorCode::Success as i32 {
605 return Err(Self::map_error_response(response.error, response.metadata));
606 }
607
608 match response.success_result {
609 Some(d_engine_proto::client::client_response::SuccessResult::WriteResult(result)) => {
610 Ok(result.succeeded)
611 }
612 _ => Err(server_error("Invalid CAS response".to_string())),
613 }
614 }
615
616 async fn list_members(
617 &self
618 ) -> ClientApiResult<Vec<d_engine_proto::server::cluster::NodeMeta>> {
619 let cluster_membership = self.get_cluster_membership().await?;
620 Ok(cluster_membership.nodes)
621 }
622
623 async fn get_leader_id(&self) -> ClientApiResult<Option<u32>> {
624 let cluster_membership = self.get_cluster_membership().await?;
625 Ok(cluster_membership.current_leader_id)
626 }
627
628 async fn get_multi_with_policy(
629 &self,
630 keys: &[Bytes],
631 consistency_policy: Option<ReadConsistencyPolicy>,
632 ) -> ClientApiResult<Vec<Option<Bytes>>> {
633 self.get_multi_with_consistency(
634 keys,
635 consistency_policy.unwrap_or(ReadConsistencyPolicy::LinearizableRead),
636 )
637 .await
638 }
639
640 async fn get_linearizable(
641 &self,
642 key: impl AsRef<[u8]> + Send,
643 ) -> ClientApiResult<Option<Bytes>> {
644 self.get_linearizable(key).await
645 }
646
647 async fn get_lease(
648 &self,
649 key: impl AsRef<[u8]> + Send,
650 ) -> ClientApiResult<Option<Bytes>> {
651 self.get_with_consistency(key, ReadConsistencyPolicy::LeaseRead).await
652 }
653
654 async fn get_eventual(
655 &self,
656 key: impl AsRef<[u8]> + Send,
657 ) -> ClientApiResult<Option<Bytes>> {
658 self.get_eventual(key).await
659 }
660}