1use std::marker::PhantomData;
4use std::sync::Arc;
5
6use async_recursion::async_recursion;
7use async_trait::async_trait;
8use futures::future::try_join_all;
9use futures::prelude::*;
10use log::debug;
11use log::info;
12use tokio::sync::Semaphore;
13use tokio::time::sleep;
14
15use crate::backoff::Backoff;
16use crate::pd::PdClient;
17use crate::proto::errorpb;
18use crate::proto::errorpb::EpochNotMatch;
19use crate::proto::kvrpcpb;
20use crate::proto::pdpb::Timestamp;
21use crate::region::StoreId;
22use crate::region::{RegionVerId, RegionWithLeader};
23use crate::request::shard::HasNextBatch;
24use crate::request::NextBatch;
25use crate::request::Shardable;
26use crate::request::{KvRequest, StoreRequest};
27use crate::stats::tikv_stats;
28use crate::store::HasRegionError;
29use crate::store::HasRegionErrors;
30use crate::store::KvClient;
31use crate::store::RegionStore;
32use crate::store::{HasKeyErrors, Store};
33use crate::transaction::resolve_locks;
34use crate::transaction::HasLocks;
35use crate::transaction::ResolveLocksContext;
36use crate::transaction::ResolveLocksOptions;
37use crate::util::iter::FlatMapOkIterExt;
38use crate::Error;
39use crate::Result;
40
41use super::keyspace::Keyspace;
42
43#[async_trait]
46pub trait Plan: Sized + Clone + Sync + Send + 'static {
47 type Result: Send;
49
50 async fn execute(&self) -> Result<Self::Result>;
52}
53
54#[derive(Clone)]
56pub struct Dispatch<Req: KvRequest> {
57 pub request: Req,
58 pub kv_client: Option<Arc<dyn KvClient + Send + Sync>>,
59}
60
61#[async_trait]
62impl<Req: KvRequest> Plan for Dispatch<Req> {
63 type Result = Req::Response;
64
65 async fn execute(&self) -> Result<Self::Result> {
66 let stats = tikv_stats(self.request.label());
67 let result = self
68 .kv_client
69 .as_ref()
70 .expect("Unreachable: kv_client has not been initialised in Dispatch")
71 .dispatch(&self.request)
72 .await;
73 let result = stats.done(result);
74 result.map(|r| {
75 *r.downcast()
76 .expect("Downcast failed: request and response type mismatch")
77 })
78 }
79}
80
81impl<Req: KvRequest + StoreRequest> StoreRequest for Dispatch<Req> {
82 fn apply_store(&mut self, store: &Store) {
83 self.kv_client = Some(store.client.clone());
84 self.request.apply_store(store);
85 }
86}
87
/// Maximum number of shard requests one `RetryableMultiRegion` execution runs
/// at once. The permit pool is shared by all shards and their retries.
const MULTI_REGION_CONCURRENCY: usize = 16;
/// Maximum number of concurrent per-store requests issued by `RetryableAllStores`.
const MULTI_STORES_CONCURRENCY: usize = 16;
90
91pub(crate) fn is_grpc_error(e: &Error) -> bool {
92 matches!(e, Error::GrpcAPI(_) | Error::Grpc(_))
93}
94
95pub struct RetryableMultiRegion<P: Plan, PdC: PdClient> {
96 pub(super) inner: P,
97 pub pd_client: Arc<PdC>,
98 pub backoff: Backoff,
99
100 pub preserve_region_results: bool,
104}
105
106impl<P: Plan + Shardable, PdC: PdClient> RetryableMultiRegion<P, PdC>
107where
108 P::Result: HasKeyErrors + HasRegionError,
109{
110 #[async_recursion]
112 async fn single_plan_handler(
113 pd_client: Arc<PdC>,
114 current_plan: P,
115 backoff: Backoff,
116 permits: Arc<Semaphore>,
117 preserve_region_results: bool,
118 ) -> Result<<Self as Plan>::Result> {
119 let shards = current_plan.shards(&pd_client).collect::<Vec<_>>().await;
120 debug!("single_plan_handler, shards: {}", shards.len());
121 let mut handles = Vec::with_capacity(shards.len());
122 for shard in shards {
123 let (shard, region) = shard?;
124 let clone = current_plan.clone_then_apply_shard(shard);
125 let handle = tokio::spawn(Self::single_shard_handler(
126 pd_client.clone(),
127 clone,
128 region,
129 backoff.clone(),
130 permits.clone(),
131 preserve_region_results,
132 ));
133 handles.push(handle);
134 }
135
136 let results = try_join_all(handles).await?;
137 if preserve_region_results {
138 Ok(results
139 .into_iter()
140 .flat_map_ok(|x| x)
141 .map(|x| match x {
142 Ok(r) => r,
143 Err(e) => Err(e),
144 })
145 .collect())
146 } else {
147 Ok(results
148 .into_iter()
149 .collect::<Result<Vec<_>>>()?
150 .into_iter()
151 .flatten()
152 .collect())
153 }
154 }
155
156 #[async_recursion]
157 async fn single_shard_handler(
158 pd_client: Arc<PdC>,
159 mut plan: P,
160 region: RegionWithLeader,
161 mut backoff: Backoff,
162 permits: Arc<Semaphore>,
163 preserve_region_results: bool,
164 ) -> Result<<Self as Plan>::Result> {
165 debug!("single_shard_handler");
166 let region_store = match pd_client
167 .clone()
168 .map_region_to_store(region)
169 .await
170 .and_then(|region_store| {
171 plan.apply_store(®ion_store)?;
172 Ok(region_store)
173 }) {
174 Ok(region_store) => region_store,
175 Err(Error::LeaderNotFound { region }) => {
176 debug!(
177 "single_shard_handler::sharding: leader not found: {:?}",
178 region
179 );
180 return Self::handle_other_error(
181 pd_client,
182 plan,
183 region.clone(),
184 None,
185 backoff,
186 permits,
187 preserve_region_results,
188 Error::LeaderNotFound { region },
189 )
190 .await;
191 }
192 Err(err) => {
193 debug!("single_shard_handler::sharding, error: {:?}", err);
194 return Err(err);
195 }
196 };
197
198 let permit = permits.acquire().await.unwrap();
200 let res = plan.execute().await;
201 drop(permit);
202
203 let mut resp = match res {
204 Ok(resp) => resp,
205 Err(e) if is_grpc_error(&e) => {
206 debug!("single_shard_handler:execute: grpc error: {:?}", e);
207 return Self::handle_other_error(
208 pd_client,
209 plan,
210 region_store.region_with_leader.ver_id(),
211 region_store.region_with_leader.get_store_id().ok(),
212 backoff,
213 permits,
214 preserve_region_results,
215 e,
216 )
217 .await;
218 }
219 Err(e) => {
220 debug!("single_shard_handler:execute: error: {:?}", e);
221 return Err(e);
222 }
223 };
224
225 if let Some(e) = resp.key_errors() {
226 debug!("single_shard_handler:execute: key errors: {:?}", e);
227 Ok(vec![Err(Error::MultipleKeyErrors(e))])
228 } else if let Some(e) = resp.region_error() {
229 debug!("single_shard_handler:execute: region error: {:?}", e);
230 match backoff.next_delay_duration() {
231 Some(duration) => {
232 let region_error_resolved =
233 handle_region_error(pd_client.clone(), e, region_store).await?;
234 if !region_error_resolved {
236 sleep(duration).await;
237 }
238 Self::single_plan_handler(
239 pd_client,
240 plan,
241 backoff,
242 permits,
243 preserve_region_results,
244 )
245 .await
246 }
247 None => Err(Error::RegionError(Box::new(e))),
248 }
249 } else {
250 Ok(vec![Ok(resp)])
251 }
252 }
253
254 #[allow(clippy::too_many_arguments)]
255 async fn handle_other_error(
256 pd_client: Arc<PdC>,
257 plan: P,
258 region: RegionVerId,
259 store: Option<StoreId>,
260 mut backoff: Backoff,
261 permits: Arc<Semaphore>,
262 preserve_region_results: bool,
263 e: Error,
264 ) -> Result<<Self as Plan>::Result> {
265 debug!("handle_other_error: {:?}", e);
266 pd_client.invalidate_region_cache(region).await;
267 if is_grpc_error(&e) {
268 if let Some(store_id) = store {
269 pd_client.invalidate_store_cache(store_id).await;
270 }
271 }
272 match backoff.next_delay_duration() {
273 Some(duration) => {
274 sleep(duration).await;
275 Self::single_plan_handler(
276 pd_client,
277 plan,
278 backoff,
279 permits,
280 preserve_region_results,
281 )
282 .await
283 }
284 None => Err(e),
285 }
286 }
287}
288
289pub(crate) async fn handle_region_error<PdC: PdClient>(
294 pd_client: Arc<PdC>,
295 e: errorpb::Error,
296 region_store: RegionStore,
297) -> Result<bool> {
298 debug!("handle_region_error: {:?}", e);
299 let ver_id = region_store.region_with_leader.ver_id();
300 let store_id = region_store.region_with_leader.get_store_id();
301 if let Some(not_leader) = e.not_leader {
302 if let Some(leader) = not_leader.leader {
303 match pd_client
304 .update_leader(region_store.region_with_leader.ver_id(), leader)
305 .await
306 {
307 Ok(_) => Ok(true),
308 Err(e) => {
309 pd_client.invalidate_region_cache(ver_id).await;
310 Err(e)
311 }
312 }
313 } else {
314 pd_client.invalidate_region_cache(ver_id).await;
319 Ok(false)
320 }
321 } else if e.store_not_match.is_some() {
322 pd_client.invalidate_region_cache(ver_id).await;
323 if let Ok(store_id) = store_id {
324 pd_client.invalidate_store_cache(store_id).await;
325 }
326 Ok(false)
327 } else if e.epoch_not_match.is_some() {
328 on_region_epoch_not_match(pd_client.clone(), region_store, e.epoch_not_match.unwrap()).await
329 } else if e.stale_command.is_some() || e.region_not_found.is_some() {
330 pd_client.invalidate_region_cache(ver_id).await;
331 Ok(false)
332 } else if e.server_is_busy.is_some()
333 || e.raft_entry_too_large.is_some()
334 || e.max_timestamp_not_synced.is_some()
335 {
336 Err(Error::RegionError(Box::new(e)))
337 } else {
338 pd_client.invalidate_region_cache(ver_id).await;
341 if let Ok(store_id) = store_id {
342 pd_client.invalidate_store_cache(store_id).await;
343 }
344 Ok(false)
345 }
346}
347
348pub(crate) async fn on_region_epoch_not_match<PdC: PdClient>(
353 pd_client: Arc<PdC>,
354 region_store: RegionStore,
355 error: EpochNotMatch,
356) -> Result<bool> {
357 let ver_id = region_store.region_with_leader.ver_id();
358 if error.current_regions.is_empty() {
359 pd_client.invalidate_region_cache(ver_id).await;
360 return Ok(true);
361 }
362
363 for r in error.current_regions {
364 if r.id == region_store.region_with_leader.id() {
365 let region_epoch = r.region_epoch.unwrap();
366 let returned_conf_ver = region_epoch.conf_ver;
367 let returned_version = region_epoch.version;
368 let current_region_epoch = region_store
369 .region_with_leader
370 .region
371 .region_epoch
372 .clone()
373 .unwrap();
374 let current_conf_ver = current_region_epoch.conf_ver;
375 let current_version = current_region_epoch.version;
376
377 if returned_conf_ver < current_conf_ver || returned_version < current_version {
379 return Ok(false);
380 }
381 }
382 }
383 pd_client.invalidate_region_cache(ver_id).await;
385 Ok(false)
386}
387
388impl<P: Plan, PdC: PdClient> Clone for RetryableMultiRegion<P, PdC> {
389 fn clone(&self) -> Self {
390 RetryableMultiRegion {
391 inner: self.inner.clone(),
392 pd_client: self.pd_client.clone(),
393 backoff: self.backoff.clone(),
394 preserve_region_results: self.preserve_region_results,
395 }
396 }
397}
398
399#[async_trait]
400impl<P: Plan + Shardable, PdC: PdClient> Plan for RetryableMultiRegion<P, PdC>
401where
402 P::Result: HasKeyErrors + HasRegionError,
403{
404 type Result = Vec<Result<P::Result>>;
405
406 async fn execute(&self) -> Result<Self::Result> {
407 let concurrency_permits = Arc::new(Semaphore::new(MULTI_REGION_CONCURRENCY));
411 Self::single_plan_handler(
412 self.pd_client.clone(),
413 self.inner.clone(),
414 self.backoff.clone(),
415 concurrency_permits.clone(),
416 self.preserve_region_results,
417 )
418 .await
419 }
420}
421
422pub struct RetryableAllStores<P: Plan, PdC: PdClient> {
423 pub(super) inner: P,
424 pub pd_client: Arc<PdC>,
425 pub backoff: Backoff,
426}
427
428impl<P: Plan, PdC: PdClient> Clone for RetryableAllStores<P, PdC> {
429 fn clone(&self) -> Self {
430 RetryableAllStores {
431 inner: self.inner.clone(),
432 pd_client: self.pd_client.clone(),
433 backoff: self.backoff.clone(),
434 }
435 }
436}
437
438#[async_trait]
443impl<P: Plan + StoreRequest, PdC: PdClient> Plan for RetryableAllStores<P, PdC>
444where
445 P::Result: HasKeyErrors + HasRegionError,
446{
447 type Result = Vec<Result<P::Result>>;
448
449 async fn execute(&self) -> Result<Self::Result> {
450 let concurrency_permits = Arc::new(Semaphore::new(MULTI_STORES_CONCURRENCY));
451 let stores = self.pd_client.clone().all_stores().await?;
452 let mut handles = Vec::with_capacity(stores.len());
453 for store in stores {
454 let mut clone = self.inner.clone();
455 clone.apply_store(&store);
456 let handle = tokio::spawn(Self::single_store_handler(
457 clone,
458 self.backoff.clone(),
459 concurrency_permits.clone(),
460 ));
461 handles.push(handle);
462 }
463 let results = try_join_all(handles).await?;
464 Ok(results.into_iter().collect::<Vec<_>>())
465 }
466}
467
468impl<P: Plan, PdC: PdClient> RetryableAllStores<P, PdC>
469where
470 P::Result: HasKeyErrors + HasRegionError,
471{
472 async fn single_store_handler(
473 plan: P,
474 mut backoff: Backoff,
475 permits: Arc<Semaphore>,
476 ) -> Result<P::Result> {
477 loop {
478 let permit = permits.acquire().await.unwrap();
479 let res = plan.execute().await;
480 drop(permit);
481
482 match res {
483 Ok(mut resp) => {
484 if let Some(e) = resp.key_errors() {
485 return Err(Error::MultipleKeyErrors(e));
486 } else if let Some(e) = resp.region_error() {
487 return Err(Error::RegionError(Box::new(e)));
489 } else {
490 return Ok(resp);
491 }
492 }
493 Err(e) if is_grpc_error(&e) => match backoff.next_delay_duration() {
494 Some(duration) => {
495 sleep(duration).await;
496 continue;
497 }
498 None => return Err(e),
499 },
500 Err(e) => return Err(e),
501 }
502 }
503 }
504}
505
506pub trait Merge<In>: Sized + Clone + Send + Sync + 'static {
508 type Out: Send;
509
510 fn merge(&self, input: Vec<Result<In>>) -> Result<Self::Out>;
511}
512
513#[derive(Clone)]
514pub struct MergeResponse<P: Plan, In, M: Merge<In>> {
515 pub inner: P,
516 pub merge: M,
517 pub phantom: PhantomData<In>,
518}
519
520#[async_trait]
521impl<In: Clone + Send + Sync + 'static, P: Plan<Result = Vec<Result<In>>>, M: Merge<In>> Plan
522 for MergeResponse<P, In, M>
523{
524 type Result = M::Out;
525
526 async fn execute(&self) -> Result<Self::Result> {
527 self.merge.merge(self.inner.execute().await?)
528 }
529}
530
/// Merger that combines all results into one value. What "combine" means is
/// defined per input type by its `Merge` impls.
#[derive(Clone, Copy)]
pub struct Collect;

/// Merger for plans that must yield exactly one result (see `collect_single!`).
#[derive(Clone, Copy)]
pub struct CollectSingle;
539
#[doc(hidden)]
#[macro_export]
macro_rules! collect_single {
    ($type_: ty) => {
        // Implements `Merge<$type_>` for `CollectSingle`: the input must hold
        // exactly one result, which is returned as-is (panics otherwise).
        impl Merge<$type_> for CollectSingle {
            type Out = $type_;

            fn merge(&self, input: Vec<Result<$type_>>) -> Result<Self::Out> {
                assert!(input.len() == 1);
                input.into_iter().next().unwrap()
            }
        }
    };
}
554
555#[derive(Clone, Debug)]
559pub struct CollectWithShard;
560
561#[derive(Clone, Copy)]
564pub struct CollectError;
565
566impl<T: Send> Merge<T> for CollectError {
567 type Out = Vec<T>;
568
569 fn merge(&self, input: Vec<Result<T>>) -> Result<Self::Out> {
570 input.into_iter().collect()
571 }
572}
573
574pub trait Process<In>: Sized + Clone + Send + Sync + 'static {
576 type Out: Send;
577
578 fn process(&self, input: Result<In>) -> Result<Self::Out>;
579}
580
581#[derive(Clone)]
582pub struct ProcessResponse<P: Plan, Pr: Process<P::Result>> {
583 pub inner: P,
584 pub processor: Pr,
585}
586
587#[async_trait]
588impl<P: Plan, Pr: Process<P::Result>> Plan for ProcessResponse<P, Pr> {
589 type Result = Pr::Out;
590
591 async fn execute(&self) -> Result<Self::Result> {
592 self.processor.process(self.inner.execute().await)
593 }
594}
595
/// Stock `Process` implementation; its impls are defined outside this section.
#[derive(Clone, Copy, Debug)]
pub struct DefaultProcessor;
598
599pub struct ResolveLock<P: Plan, PdC: PdClient> {
600 pub inner: P,
601 pub timestamp: Timestamp,
602 pub pd_client: Arc<PdC>,
603 pub backoff: Backoff,
604 pub keyspace: Keyspace,
605}
606
607impl<P: Plan, PdC: PdClient> Clone for ResolveLock<P, PdC> {
608 fn clone(&self) -> Self {
609 ResolveLock {
610 inner: self.inner.clone(),
611 timestamp: self.timestamp.clone(),
612 pd_client: self.pd_client.clone(),
613 backoff: self.backoff.clone(),
614 keyspace: self.keyspace,
615 }
616 }
617}
618
619#[async_trait]
620impl<P: Plan, PdC: PdClient> Plan for ResolveLock<P, PdC>
621where
622 P::Result: HasLocks,
623{
624 type Result = P::Result;
625
626 async fn execute(&self) -> Result<Self::Result> {
627 let mut result = self.inner.execute().await?;
628 let mut clone = self.clone();
629 loop {
630 let locks = result.take_locks();
631 if locks.is_empty() {
632 return Ok(result);
633 }
634
635 if self.backoff.is_none() {
636 return Err(Error::ResolveLockError(locks));
637 }
638
639 let pd_client = self.pd_client.clone();
640 let live_locks = resolve_locks(
641 locks,
642 self.timestamp.clone(),
643 pd_client.clone(),
644 self.keyspace,
645 )
646 .await?;
647 if live_locks.is_empty() {
648 result = self.inner.execute().await?;
649 } else {
650 match clone.backoff.next_delay_duration() {
651 None => return Err(Error::ResolveLockError(live_locks)),
652 Some(delay_duration) => {
653 sleep(delay_duration).await;
654 result = clone.inner.execute().await?;
655 }
656 }
657 }
658 }
659 }
660}
661
662#[derive(Debug, Default)]
663pub struct CleanupLocksResult {
664 pub region_error: Option<errorpb::Error>,
665 pub key_error: Option<Vec<Error>>,
666 pub resolved_locks: usize,
667}
668
669impl Clone for CleanupLocksResult {
670 fn clone(&self) -> Self {
671 Self {
672 resolved_locks: self.resolved_locks,
673 ..Default::default() }
675 }
676}
677
678impl HasRegionError for CleanupLocksResult {
679 fn region_error(&mut self) -> Option<errorpb::Error> {
680 self.region_error.take()
681 }
682}
683
684impl HasKeyErrors for CleanupLocksResult {
685 fn key_errors(&mut self) -> Option<Vec<Error>> {
686 self.key_error.take()
687 }
688}
689
690impl Merge<CleanupLocksResult> for Collect {
691 type Out = CleanupLocksResult;
692
693 fn merge(&self, input: Vec<Result<CleanupLocksResult>>) -> Result<Self::Out> {
694 input
695 .into_iter()
696 .try_fold(CleanupLocksResult::default(), |acc, x| {
697 Ok(CleanupLocksResult {
698 resolved_locks: acc.resolved_locks + x?.resolved_locks,
699 ..Default::default()
700 })
701 })
702 }
703}
704
705pub struct CleanupLocks<P: Plan, PdC: PdClient> {
706 pub inner: P,
707 pub ctx: ResolveLocksContext,
708 pub options: ResolveLocksOptions,
709 pub store: Option<RegionStore>,
710 pub pd_client: Arc<PdC>,
711 pub keyspace: Keyspace,
712 pub backoff: Backoff,
713}
714
715impl<P: Plan, PdC: PdClient> Clone for CleanupLocks<P, PdC> {
716 fn clone(&self) -> Self {
717 CleanupLocks {
718 inner: self.inner.clone(),
719 ctx: self.ctx.clone(),
720 options: self.options,
721 store: None,
722 pd_client: self.pd_client.clone(),
723 keyspace: self.keyspace,
724 backoff: self.backoff.clone(),
725 }
726 }
727}
728
729#[async_trait]
730impl<P: Plan + Shardable + NextBatch, PdC: PdClient> Plan for CleanupLocks<P, PdC>
731where
732 P::Result: HasLocks + HasNextBatch + HasKeyErrors + HasRegionError,
733{
734 type Result = CleanupLocksResult;
735
736 async fn execute(&self) -> Result<Self::Result> {
737 let mut result = CleanupLocksResult::default();
738 let mut inner = self.inner.clone();
739 let mut lock_resolver = crate::transaction::LockResolver::new(self.ctx.clone());
740 let region = &self.store.as_ref().unwrap().region_with_leader;
741 let mut has_more_batch = true;
742
743 while has_more_batch {
744 let mut scan_lock_resp = inner.execute().await?;
745
746 if let Some(e) = scan_lock_resp.key_errors() {
748 info!("CleanupLocks::execute, inner key errors:{:?}", e);
749 result.key_error = Some(e);
750 return Ok(result);
751 } else if let Some(e) = scan_lock_resp.region_error() {
752 info!("CleanupLocks::execute, inner region error:{}", e.message);
753 result.region_error = Some(e);
754 return Ok(result);
755 }
756
757 match scan_lock_resp.has_next_batch() {
759 Some(range) if region.contains(range.0.as_ref()) => {
760 debug!("CleanupLocks::execute, next range:{:?}", range);
761 inner.next_batch(range);
762 }
763 _ => has_more_batch = false,
764 }
765
766 let mut locks = scan_lock_resp.take_locks();
767 if locks.is_empty() {
768 break;
769 }
770 if locks.len() < self.options.batch_size as usize {
771 has_more_batch = false;
772 }
773
774 if self.options.async_commit_only {
775 locks = locks
776 .into_iter()
777 .filter(|l| l.use_async_commit)
778 .collect::<Vec<_>>();
779 }
780 debug!("CleanupLocks::execute, meet locks:{}", locks.len());
781
782 let lock_size = locks.len();
783 match lock_resolver
784 .cleanup_locks(
785 self.store.clone().unwrap(),
786 locks,
787 self.pd_client.clone(),
788 self.keyspace,
789 )
790 .await
791 {
792 Ok(()) => {
793 result.resolved_locks += lock_size;
794 }
795 Err(Error::ExtractedErrors(mut errors)) => {
796 if let Error::RegionError(e) = errors.pop().unwrap() {
798 result.region_error = Some(*e);
799 } else {
800 result.key_error = Some(errors);
801 }
802 return Ok(result);
803 }
804 Err(e) => {
805 return Err(e);
806 }
807 }
808
809 }
814
815 Ok(result)
816 }
817}
818
819pub struct ExtractError<P: Plan> {
828 pub inner: P,
829}
830
831impl<P: Plan> Clone for ExtractError<P> {
832 fn clone(&self) -> Self {
833 ExtractError {
834 inner: self.inner.clone(),
835 }
836 }
837}
838
839#[async_trait]
840impl<P: Plan> Plan for ExtractError<P>
841where
842 P::Result: HasKeyErrors + HasRegionErrors,
843{
844 type Result = P::Result;
845
846 async fn execute(&self) -> Result<Self::Result> {
847 let mut result = self.inner.execute().await?;
848 if let Some(errors) = result.key_errors() {
849 Err(Error::ExtractedErrors(errors))
850 } else if let Some(errors) = result.region_errors() {
851 Err(Error::ExtractedErrors(
852 errors
853 .into_iter()
854 .map(|e| Error::RegionError(Box::new(e)))
855 .collect(),
856 ))
857 } else {
858 Ok(result)
859 }
860 }
861}
862
863pub struct PreserveShard<P: Plan + Shardable> {
869 pub inner: P,
870 pub shard: Option<P::Shard>,
871}
872
873impl<P: Plan + Shardable> Clone for PreserveShard<P> {
874 fn clone(&self) -> Self {
875 PreserveShard {
876 inner: self.inner.clone(),
877 shard: None,
878 }
879 }
880}
881
882#[async_trait]
883impl<P> Plan for PreserveShard<P>
884where
885 P: Plan + Shardable,
886{
887 type Result = ResponseWithShard<P::Result, P::Shard>;
888
889 async fn execute(&self) -> Result<Self::Result> {
890 let res = self.inner.execute().await?;
891 let shard = self
892 .shard
893 .as_ref()
894 .expect("Unreachable: Shardable::apply_shard() is not called before executing PreserveShard")
895 .clone();
896 Ok(ResponseWithShard(res, shard))
897 }
898}
899
900#[derive(Debug, Clone)]
902pub struct ResponseWithShard<Resp, Shard>(pub Resp, pub Shard);
903
904impl<Resp: HasKeyErrors, Shard> HasKeyErrors for ResponseWithShard<Resp, Shard> {
905 fn key_errors(&mut self) -> Option<Vec<Error>> {
906 self.0.key_errors()
907 }
908}
909
910impl<Resp: HasLocks, Shard> HasLocks for ResponseWithShard<Resp, Shard> {
911 fn take_locks(&mut self) -> Vec<kvrpcpb::LockInfo> {
912 self.0.take_locks()
913 }
914}
915
916impl<Resp: HasRegionError, Shard> HasRegionError for ResponseWithShard<Resp, Shard> {
917 fn region_error(&mut self) -> Option<errorpb::Error> {
918 self.0.region_error()
919 }
920}
921
#[cfg(test)]
mod test {
    use futures::stream::BoxStream;
    use futures::stream::{self};

    use super::*;
    use crate::mock::MockPdClient;
    use crate::proto::kvrpcpb::BatchGetResponse;

    /// Plan whose execution and sharding both always fail.
    #[derive(Clone)]
    struct ErrPlan;

    #[async_trait]
    impl Plan for ErrPlan {
        type Result = BatchGetResponse;

        async fn execute(&self) -> Result<Self::Result> {
            Err(Error::Unimplemented)
        }
    }

    impl Shardable for ErrPlan {
        type Shard = ();

        /// Yields three sharding errors.
        fn shards(
            &self,
            _: &Arc<impl crate::pd::PdClient>,
        ) -> BoxStream<'static, crate::Result<(Self::Shard, RegionWithLeader)>> {
            stream::iter(1..=3)
                .map(|_| Err(Error::Unimplemented))
                .boxed()
        }

        fn apply_shard(&mut self, _: Self::Shard) {}

        fn apply_store(&mut self, _: &crate::store::RegionStore) -> Result<()> {
            Ok(())
        }
    }

    /// A sharding failure must surface as an error from the retrying plan.
    #[tokio::test]
    async fn test_err() {
        let resolve_lock = ResolveLock {
            inner: ErrPlan,
            timestamp: Timestamp::default(),
            backoff: Backoff::no_backoff(),
            pd_client: Arc::new(MockPdClient::default()),
            keyspace: Keyspace::Disable,
        };
        let plan = RetryableMultiRegion {
            inner: resolve_lock,
            pd_client: Arc::new(MockPdClient::default()),
            backoff: Backoff::no_backoff(),
            preserve_region_results: false,
        };
        let outcome = plan.execute().await;
        assert!(outcome.is_err());
    }
}
976}