use std::collections::BTreeMap;
use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::Arc;
use tokio::sync::Notify;
use tokio::sync::RwLock;
use crate::common::Error;
use crate::pd::Cluster;
use crate::pd::RetryClient;
use crate::pd::RetryClientTrait;
use crate::proto::metapb::Store;
use crate::proto::metapb;
use crate::region::RegionId;
use crate::region::RegionVerId;
use crate::region::RegionWithLeader;
use crate::region::StoreId;
use crate::Key;
use crate::Result;
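/// How many times `get_region_by_id` waits on a concurrent in-flight PD request before giving up.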
const MAX_RETRY_WAITING_CONCURRENT_REQUEST: usize = 4;
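/// The region cache's internal indexes, kept in sync under a single lock.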
struct RegionCacheMap {
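/// RegionVerId -> region. The authoritative map of cached regions; a ver_id identifies a region uniquely across time.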
ver_id_to_region: HashMap<RegionVerId, RegionWithLeader>,
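/// Start key -> RegionVerId. Invariant: the regions indexed here never intersect.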
key_to_ver_id: BTreeMap<Key, RegionVerId>,
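/// RegionId -> RegionVerId. Regions with the same id are not necessarily the same region: an id can refer to different versions across time.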
id_to_ver_id: HashMap<RegionId, RegionVerId>,
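/// Region ids with a PD request currently in flight. Concurrent lookups wait on the Notify instead of issuing duplicate queries.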
on_my_way_id: HashMap<RegionId, Arc<Notify>>,
}
impl RegionCacheMap {
fn new() -> RegionCacheMap {
RegionCacheMap {
ver_id_to_region: HashMap::new(),
key_to_ver_id: BTreeMap::new(),
id_to_ver_id: HashMap::new(),
on_my_way_id: HashMap::new(),
}
}
}
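/// Caches region and store metadata fetched from PD. Lookups are served from
/// the in-memory maps where possible and read through `inner_client` on a miss.
///
/// A minimal usage sketch (illustrative only; assumes an already-built
/// `RetryClient` and a `key: Key` in scope):
///
/// ```ignore
/// let cache = RegionCache::new(retry_client);
/// let region = cache.get_region_by_key(&key).await?;
/// let store = cache.get_store_by_id(region.leader.unwrap().store_id).await?;
/// ```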
pub struct RegionCache<Client = RetryClient<Cluster>> {
region_cache: RwLock<RegionCacheMap>,
store_cache: RwLock<HashMap<StoreId, Store>>,
inner_client: Arc<Client>,
}
impl<Client> RegionCache<Client> {
pub fn new(inner_client: Arc<Client>) -> RegionCache<Client> {
RegionCache {
region_cache: RwLock::new(RegionCacheMap::new()),
store_cache: RwLock::new(HashMap::new()),
inner_client,
}
}
}
impl<C: RetryClientTrait> RegionCache<C> {
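// Retrieve a region from the cache by key. On a miss or a stale entry, query PD and update the cache.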
pub async fn get_region_by_key(&self, key: &Key) -> Result<RegionWithLeader> {
let region_cache_guard = self.region_cache.read().await;
// The candidate is the cached region with the greatest start key <= `key`;
// it may still end before `key`, so verify containment below.
let res = region_cache_guard
.key_to_ver_id
.range(..=key)
.next_back()
.map(|(_, ver_id)| ver_id.clone());
if let Some(candidate_region_ver_id) = res {
let region = region_cache_guard
.ver_id_to_region
.get(&candidate_region_ver_id)
.unwrap();
if region.contains(key) {
return Ok(region.clone());
}
}
drop(region_cache_guard);
self.read_through_region_by_key(key.clone()).await
}
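// Retrieve a region from the cache by id. On a miss, join a concurrent in-flight request if one exists, otherwise query PD.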
pub async fn get_region_by_id(&self, id: RegionId) -> Result<RegionWithLeader> {
for _ in 0..=MAX_RETRY_WAITING_CONCURRENT_REQUEST {
let region_cache_guard = self.region_cache.read().await;
let ver_id = region_cache_guard.id_to_ver_id.get(&id);
if let Some(ver_id) = ver_id {
let region = region_cache_guard.ver_id_to_region.get(ver_id).unwrap();
return Ok(region.clone());
}
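// Not cached: wait for an in-flight request for this id if there is one, then retry the cache; otherwise read through PD.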
let notify = region_cache_guard.on_my_way_id.get(&id).cloned();
let notified = notify.as_ref().map(|notify| notify.notified());
drop(region_cache_guard);
if let Some(n) = notified {
n.await;
continue;
} else {
return self.read_through_region_by_id(id).await;
}
}
Err(Error::StringError(format!(
"region {id} still unavailable after waiting on concurrent PD requests {MAX_RETRY_WAITING_CONCURRENT_REQUEST} times"
)))
}
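// Retrieve a store from the cache by id, querying PD on a miss.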
pub async fn get_store_by_id(&self, id: StoreId) -> Result<Store> {
let store = self.store_cache.read().await.get(&id).cloned();
match store {
Some(store) => Ok(store),
None => self.read_through_store_by_id(id).await,
}
}
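/// Force a read through: query PD for the region covering `key` and cache the result.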
pub async fn read_through_region_by_key(&self, key: Key) -> Result<RegionWithLeader> {
let region = self.inner_client.clone().get_region(key.into()).await?;
self.add_region(region.clone()).await;
Ok(region)
}
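/// Force a read through: query PD for the region with this id, deduplicating concurrent requests via `on_my_way_id`.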
async fn read_through_region_by_id(&self, id: RegionId) -> Result<RegionWithLeader> {
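// Publish a Notify so concurrent lookups for this id wait instead of duplicating the query.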
let notify = Arc::new(Notify::new());
{
let mut region_cache_guard = self.region_cache.write().await;
region_cache_guard.on_my_way_id.insert(id, notify.clone());
}
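// Query PD and cache the fresh region.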
let region = match self.inner_client.clone().get_region_by_id(id).await {
Ok(region) => region,
Err(err) => {
// If the PD query fails, wake the waiters and clear the marker so they retry instead of waiting forever.
let mut region_cache_guard = self.region_cache.write().await;
notify.notify_waiters();
region_cache_guard.on_my_way_id.remove(&id);
return Err(err);
}
};
self.add_region(region.clone()).await;
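// Wake the waiters and clear the in-flight marker.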
{
let mut region_cache_guard = self.region_cache.write().await;
notify.notify_waiters();
region_cache_guard.on_my_way_id.remove(&id);
}
Ok(region)
}
async fn read_through_store_by_id(&self, id: StoreId) -> Result<Store> {
let store = self.inner_client.clone().get_store(id).await?;
self.store_cache.write().await.insert(id, store.clone());
Ok(store)
}
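/// Insert a region into the cache, evicting any cached regions it supersedes (same id, older version) or overlaps.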
pub async fn add_region(&self, region: RegionWithLeader) {
let mut cache = self.region_cache.write().await;
let end_key = region.end_key();
let mut to_be_removed: HashSet<RegionVerId> = HashSet::new();
if let Some(ver_id) = cache.id_to_ver_id.get(&region.id()) {
if ver_id != &region.ver_id() {
to_be_removed.insert(ver_id.clone());
}
}
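// Scan cached regions whose start key precedes the new region's end key (an empty
// end key means unbounded), from the greatest start key downwards, and mark
// overlapping entries for eviction.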
let mut search_range = {
if end_key.is_empty() {
cache.key_to_ver_id.range(..)
} else {
cache.key_to_ver_id.range(..end_key)
}
};
while let Some((_, ver_id_in_cache)) = search_range.next_back() {
let region_in_cache = cache.ver_id_to_region.get(ver_id_in_cache).unwrap();
if region_in_cache.region.end_key > region.region.start_key {
to_be_removed.insert(ver_id_in_cache.clone());
} else {
break;
}
}
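// Evict the stale regions from all three indexes.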
for ver_id in to_be_removed {
let region_to_remove = cache.ver_id_to_region.remove(&ver_id).unwrap();
cache.key_to_ver_id.remove(&region_to_remove.start_key());
cache.id_to_ver_id.remove(&region_to_remove.id());
}
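// Index the new region by start key, id, and ver_id.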
cache
.key_to_ver_id
.insert(region.start_key(), region.ver_id());
cache.id_to_ver_id.insert(region.id(), region.ver_id());
cache.ver_id_to_region.insert(region.ver_id(), region);
}
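/// Update the cached leader of the region identified by `ver_id`, if the region is still cached.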
pub async fn update_leader(
&self,
ver_id: RegionVerId,
leader: metapb::Peer,
) -> Result<()> {
let mut cache = self.region_cache.write().await;
let region_entry = cache.ver_id_to_region.get_mut(&ver_id);
if let Some(region) = region_entry {
region.leader = Some(leader);
}
Ok(())
}
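/// Remove the region identified by `ver_id` from all three indexes.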
pub async fn invalidate_region_cache(&self, ver_id: RegionVerId) {
let mut cache = self.region_cache.write().await;
let region_entry = cache.ver_id_to_region.get(&ver_id);
if let Some(region) = region_entry {
let id = region.id();
let start_key = region.start_key();
cache.ver_id_to_region.remove(&ver_id);
cache.id_to_ver_id.remove(&id);
cache.key_to_ver_id.remove(&start_key);
}
}
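/// Remove a store from the store cache, forcing the next lookup to query PD.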
pub async fn invalidate_store_cache(&self, store_id: StoreId) {
let mut cache = self.store_cache.write().await;
cache.remove(&store_id);
}
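/// Query PD for all stores, cache those that are valid TiKV stores, and return them.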
pub async fn read_through_all_stores(&self) -> Result<Vec<Store>> {
let stores = self
.inner_client
.clone()
.get_all_stores()
.await?
.into_iter()
.filter(is_valid_tikv_store)
.collect::<Vec<_>>();
for store in &stores {
self.store_cache
.write()
.await
.insert(store.id, store.clone());
}
Ok(stores)
}
}
const ENGINE_LABEL_KEY: &str = "engine";
const ENGINE_LABEL_TIFLASH: &str = "tiflash";
const ENGINE_LABEL_TIFLASH_COMPUTE: &str = "tiflash_compute";
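/// Filter for stores this client can use: tombstoned stores and TiFlash engines
/// (`tiflash`, `tiflash_compute`) are excluded.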
fn is_valid_tikv_store(store: &metapb::Store) -> bool {
// Use matches! so an unrecognized state value is treated as live instead of panicking.
if matches!(
metapb::StoreState::try_from(store.state),
Ok(metapb::StoreState::Tombstone)
) {
return false;
}
let is_tiflash = store.labels.iter().any(|label| {
label.key == ENGINE_LABEL_KEY
&& (label.value == ENGINE_LABEL_TIFLASH || label.value == ENGINE_LABEL_TIFLASH_COMPUTE)
});
!is_tiflash
}
#[cfg(test)]
mod test {
use std::collections::BTreeMap;
use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering::SeqCst;
use std::sync::Arc;
use async_trait::async_trait;
use tokio::sync::Mutex;
use super::RegionCache;
use crate::common::Error;
use crate::pd::RetryClientTrait;
use crate::proto::keyspacepb;
use crate::proto::metapb::RegionEpoch;
use crate::proto::metapb;
use crate::region::RegionId;
use crate::region::RegionWithLeader;
use crate::region_cache::is_valid_tikv_store;
use crate::Key;
use crate::Result;
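/// A stub PD client serving regions from an in-memory map; `get_region_count`
/// tracks PD lookups so tests can assert cache hits.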
#[derive(Default)]
struct MockRetryClient {
pub regions: Mutex<HashMap<RegionId, RegionWithLeader>>,
pub get_region_count: AtomicU64,
}
#[async_trait]
impl RetryClientTrait for MockRetryClient {
async fn get_region(
self: Arc<Self>,
key: Vec<u8>,
) -> Result<crate::region::RegionWithLeader> {
self.get_region_count.fetch_add(1, SeqCst);
let key: Key = key.into();
self.regions
.lock()
.await
.iter()
.filter(|(_, r)| r.contains(&key))
.map(|(_, r)| r.clone())
.next()
.ok_or_else(|| Error::StringError("MockRetryClient: region not found".to_owned()))
}
async fn get_region_by_id(
self: Arc<Self>,
region_id: crate::region::RegionId,
) -> Result<crate::region::RegionWithLeader> {
self.get_region_count.fetch_add(1, SeqCst);
self.regions
.lock()
.await
.iter()
.filter(|(id, _)| id == &&region_id)
.map(|(_, r)| r.clone())
.next()
.ok_or_else(|| Error::StringError("MockRetryClient: region not found".to_owned()))
}
async fn get_store(
self: Arc<Self>,
_id: crate::region::StoreId,
) -> Result<crate::proto::metapb::Store> {
todo!()
}
async fn get_all_stores(self: Arc<Self>) -> Result<Vec<crate::proto::metapb::Store>> {
todo!()
}
async fn get_timestamp(self: Arc<Self>) -> Result<crate::proto::pdpb::Timestamp> {
todo!()
}
async fn update_safepoint(self: Arc<Self>, _safepoint: u64) -> Result<bool> {
todo!()
}
async fn load_keyspace(&self, _keyspace: &str) -> Result<keyspacepb::KeyspaceMeta> {
unimplemented!()
}
}
#[tokio::test]
async fn cache_is_used() -> Result<()> {
let retry_client = Arc::new(MockRetryClient::default());
let cache = RegionCache::new(retry_client.clone());
retry_client.regions.lock().await.insert(
1,
RegionWithLeader {
region: metapb::Region {
id: 1,
start_key: vec![],
end_key: vec![100],
region_epoch: Some(RegionEpoch {
conf_ver: 0,
version: 0,
}),
..Default::default()
},
leader: Some(metapb::Peer {
store_id: 1,
..Default::default()
}),
},
);
retry_client.regions.lock().await.insert(
2,
RegionWithLeader {
region: metapb::Region {
id: 2,
start_key: vec![101],
end_key: vec![],
region_epoch: Some(RegionEpoch {
conf_ver: 0,
version: 0,
}),
..Default::default()
},
leader: Some(metapb::Peer {
store_id: 2,
..Default::default()
}),
},
);
assert_eq!(retry_client.get_region_count.load(SeqCst), 0);
assert_eq!(cache.get_region_by_id(1).await?.end_key(), vec![100].into());
assert_eq!(retry_client.get_region_count.load(SeqCst), 1);
assert_eq!(cache.get_region_by_id(1).await?.end_key(), vec![100].into());
assert_eq!(retry_client.get_region_count.load(SeqCst), 1);
cache
.invalidate_region_cache(cache.get_region_by_id(1).await?.ver_id())
.await;
assert_eq!(cache.get_region_by_id(1).await?.end_key(), vec![100].into());
assert_eq!(retry_client.get_region_count.load(SeqCst), 2);
cache
.update_leader(
cache.get_region_by_id(2).await?.ver_id(),
metapb::Peer {
store_id: 102,
..Default::default()
},
)
.await?;
assert_eq!(
cache.get_region_by_id(2).await?.leader.unwrap().store_id,
102
);
Ok(())
}
#[tokio::test]
async fn test_add_disjoint_regions() {
let retry_client = Arc::new(MockRetryClient::default());
let cache = RegionCache::new(retry_client.clone());
let region1 = region(1, vec![], vec![10]);
let region2 = region(2, vec![10], vec![20]);
let region3 = region(3, vec![30], vec![]);
cache.add_region(region1.clone()).await;
cache.add_region(region2.clone()).await;
cache.add_region(region3.clone()).await;
let mut expected_cache = BTreeMap::new();
expected_cache.insert(vec![].into(), region1);
expected_cache.insert(vec![10].into(), region2);
expected_cache.insert(vec![30].into(), region3);
assert(&cache, &expected_cache).await
}
#[tokio::test]
async fn test_add_intersecting_regions() {
let retry_client = Arc::new(MockRetryClient::default());
let cache = RegionCache::new(retry_client.clone());
cache.add_region(region(1, vec![], vec![10])).await;
cache.add_region(region(2, vec![10], vec![20])).await;
cache.add_region(region(3, vec![30], vec![40])).await;
cache.add_region(region(4, vec![50], vec![60])).await;
cache.add_region(region(5, vec![20], vec![35])).await;
let mut expected_cache: BTreeMap<Key, _> = BTreeMap::new();
expected_cache.insert(vec![].into(), region(1, vec![], vec![10]));
expected_cache.insert(vec![10].into(), region(2, vec![10], vec![20]));
expected_cache.insert(vec![20].into(), region(5, vec![20], vec![35]));
expected_cache.insert(vec![50].into(), region(4, vec![50], vec![60]));
assert(&cache, &expected_cache).await;
cache.add_region(region(6, vec![15], vec![25])).await;
let mut expected_cache = BTreeMap::new();
expected_cache.insert(vec![].into(), region(1, vec![], vec![10]));
expected_cache.insert(vec![15].into(), region(6, vec![15], vec![25]));
expected_cache.insert(vec![50].into(), region(4, vec![50], vec![60]));
assert(&cache, &expected_cache).await;
cache.add_region(region(7, vec![20], vec![])).await;
let mut expected_cache = BTreeMap::new();
expected_cache.insert(vec![].into(), region(1, vec![], vec![10]));
expected_cache.insert(vec![20].into(), region(7, vec![20], vec![]));
assert(&cache, &expected_cache).await;
cache.add_region(region(8, vec![], vec![15])).await;
let mut expected_cache = BTreeMap::new();
expected_cache.insert(vec![].into(), region(8, vec![], vec![15]));
expected_cache.insert(vec![20].into(), region(7, vec![20], vec![]));
assert(&cache, &expected_cache).await;
}
#[tokio::test]
async fn test_get_region_by_key() -> Result<()> {
let retry_client = Arc::new(MockRetryClient::default());
let cache = RegionCache::new(retry_client.clone());
let region1 = region(1, vec![], vec![10]);
let region2 = region(2, vec![10], vec![20]);
let region3 = region(3, vec![30], vec![40]);
let region4 = region(4, vec![50], vec![]);
cache.add_region(region1.clone()).await;
cache.add_region(region2.clone()).await;
cache.add_region(region3.clone()).await;
cache.add_region(region4.clone()).await;
assert_eq!(
cache.get_region_by_key(&vec![].into()).await?,
region1.clone()
);
assert_eq!(
cache.get_region_by_key(&vec![5].into()).await?,
region1.clone()
);
assert_eq!(
cache.get_region_by_key(&vec![10].into()).await?,
region2.clone()
);
assert!(cache.get_region_by_key(&vec![20].into()).await.is_err());
assert!(cache.get_region_by_key(&vec![25].into()).await.is_err());
assert_eq!(cache.get_region_by_key(&vec![60].into()).await?, region4);
Ok(())
}
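// Assert that the cache contains exactly the regions of `expected_cache` (keyed by start key).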
async fn assert(
cache: &RegionCache<MockRetryClient>,
expected_cache: &BTreeMap<Key, RegionWithLeader>,
) {
let guard = cache.region_cache.read().await;
let mut actual_keys = guard.ver_id_to_region.values().collect::<Vec<_>>();
let mut expected_keys = expected_cache.values().collect::<Vec<_>>();
actual_keys.sort_by_cached_key(|r| r.id());
expected_keys.sort_by_cached_key(|r| r.id());
assert_eq!(actual_keys, expected_keys);
assert_eq!(
guard.key_to_ver_id.keys().collect::<HashSet<_>>(),
expected_cache.keys().collect::<HashSet<_>>()
)
}
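// Build a `RegionWithLeader` covering [start_key, end_key) with a zeroed epoch and no leader.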
fn region(id: RegionId, start_key: Vec<u8>, end_key: Vec<u8>) -> RegionWithLeader {
let mut region = RegionWithLeader::default();
region.region.id = id;
region.region.start_key = start_key;
region.region.end_key = end_key;
region.region.region_epoch = Some(RegionEpoch {
conf_ver: 0,
version: 0,
});
region
}
#[test]
fn test_is_valid_tikv_store() {
let mut store = metapb::Store::default();
assert!(is_valid_tikv_store(&store));
store.state = metapb::StoreState::Tombstone.into();
assert!(!is_valid_tikv_store(&store));
store.state = metapb::StoreState::Up.into();
assert!(is_valid_tikv_store(&store));
store.labels.push(metapb::StoreLabel {
key: "some_key".to_owned(),
value: "some_value".to_owned(),
});
assert!(is_valid_tikv_store(&store));
store.labels.push(metapb::StoreLabel {
key: "engine".to_owned(),
value: "tiflash".to_owned(),
});
assert!(!is_valid_tikv_store(&store));
store.labels[1].value = "tiflash_compute".to_string();
assert!(!is_valid_tikv_store(&store));
store.labels[1].value = "other".to_string();
assert!(is_valid_tikv_store(&store));
}
}