use super::cache::{SharedLruCache, new_shared_cache};
use super::topology::ClusterTopology;
use ahash::AHasher;
use spire_proto::spiredb::cluster::{
GetTableRegionsRequest, Region, RegionList, cluster_service_client::ClusterServiceClient,
};
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use tonic::transport::Channel;
/// Region-cache capacity used by `RegionRouter::new` (passed on to `with_capacity`).
const DEFAULT_REGION_CACHE_CAPACITY: usize = 128;
/// Routing metadata for one region of a table, converted from the protobuf
/// `Region` message returned by the cluster service.
#[derive(Debug, Clone)]
pub struct RegionInfo {
/// Region identifier (copied from `Region::id`).
pub region_id: u64,
/// First key covered by the region. NOTE(review): `get_regions_for_range`
/// treats regions as half-open `[start_key, end_key)` — confirm with the server.
pub start_key: Vec<u8>,
/// End key of the region; an empty value is treated as "no upper bound"
/// by `get_regions_for_range`.
pub end_key: Vec<u8>,
/// Store id of the region's leader, as reported by the cluster service.
pub leader_store_id: u64,
/// Store ids of all peers of the region (one entry per `Region::peers` item).
// NOTE(review): presumably not read anywhere in this crate yet, hence the allow.
#[allow(dead_code)]
pub peer_store_ids: Vec<u64>,
}
impl From<Region> for RegionInfo {
fn from(r: Region) -> Self {
Self {
region_id: r.id,
start_key: r.start_key,
end_key: r.end_key,
leader_store_id: r.leader_store_id,
peer_store_ids: r.peers.into_iter().map(|p| p.store_id).collect(),
}
}
}
/// Value stored in the region cache: a shared snapshot of one table's regions.
/// Holding the list behind an `Arc` lets lookups return an `Arc` handle
/// instead of an owned copy of the vector.
#[derive(Clone)]
struct CachedRegions {
regions: Arc<Vec<RegionInfo>>,
}
/// Resolves table names to their regions and store ids to addresses.
///
/// Region lists are fetched from the cluster service on demand and kept in an
/// LRU cache keyed by a 64-bit hash of the table name; store addresses come
/// from the [`ClusterTopology`].
pub struct RegionRouter {
/// gRPC client used for `GetTableRegions` requests on cache misses.
cluster_client: ClusterServiceClient<Channel>,
/// Store-id → address mapping (held in an `Arc`, so it may be shared).
topology: Arc<ClusterTopology>,
/// Per-table region cache, keyed by `RegionRouter::hash_key` of the table name.
region_cache: SharedLruCache<CachedRegions>,
}
impl std::fmt::Debug for RegionRouter {
    /// Prints a compact summary instead of the client and cache internals:
    /// the number of region-cache entries and the number of known stores.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let cache_entries = self.region_cache.len();
        let store_total = self.topology.store_count();
        let mut summary = f.debug_struct("RegionRouter");
        summary.field("region_cache_len", &cache_entries);
        summary.field("topology_stores", &store_total);
        summary.finish()
    }
}
impl RegionRouter {
    /// Creates a router whose region cache is built with
    /// [`DEFAULT_REGION_CACHE_CAPACITY`].
    pub fn new(
        cluster_client: ClusterServiceClient<Channel>,
        topology: Arc<ClusterTopology>,
    ) -> Self {
        Self::with_capacity(cluster_client, topology, DEFAULT_REGION_CACHE_CAPACITY)
    }

    /// Creates a router with an explicitly sized region cache.
    ///
    /// # Parameters
    /// - `cluster_client`: client used to fetch table regions on cache misses.
    /// - `topology`: source of store-id → address lookups.
    /// - `region_cache_capacity`: capacity handed to `new_shared_cache`.
    pub fn with_capacity(
        cluster_client: ClusterServiceClient<Channel>,
        topology: Arc<ClusterTopology>,
        region_cache_capacity: usize,
    ) -> Self {
        Self {
            cluster_client,
            topology,
            region_cache: new_shared_cache(region_cache_capacity),
        }
    }

    /// Hashes `key` with a default-constructed [`AHasher`]; the result is used
    /// as the region-cache key.
    ///
    /// The `?Sized` bound lets callers pass a `&str` directly. Passing a
    /// `&&str` produces the same value, because `Hash for &T` forwards to `T`.
    ///
    /// NOTE(review): cache entries are keyed only by this 64-bit hash. Two
    /// table names that collide would share one cache slot and could be served
    /// each other's regions. Storing the table name in `CachedRegions` and
    /// comparing it on lookup would remove that risk.
    fn hash_key<T: Hash + ?Sized>(key: &T) -> u64 {
        let mut hasher = AHasher::default();
        key.hash(&mut hasher);
        hasher.finish()
    }

    /// Returns all regions of `table`, serving them from the region cache when
    /// an entry exists and otherwise fetching (and caching) them from the
    /// cluster service.
    ///
    /// # Errors
    /// Propagates the `tonic::Status` from the `GetTableRegions` RPC on a
    /// cache miss.
    pub async fn get_table_regions(
        &self,
        table: &str,
    ) -> Result<Arc<Vec<RegionInfo>>, tonic::Status> {
        let hash = Self::hash_key(table);
        if let Some(cached) = self.region_cache.get_and_touch(hash) {
            return Ok(cached.regions);
        }
        self.refresh_table_regions(table).await
    }

    /// Fetches `table`'s regions from the cluster service without consulting
    /// the cache, stores the result in the cache (replacing any previous entry
    /// under the same key), and returns it.
    ///
    /// # Errors
    /// Propagates the `tonic::Status` returned by the RPC; the cache is left
    /// untouched in that case.
    async fn refresh_table_regions(
        &self,
        table: &str,
    ) -> Result<Arc<Vec<RegionInfo>>, tonic::Status> {
        let hash = Self::hash_key(table);
        let request = GetTableRegionsRequest {
            table_name: table.to_string(),
        };
        // Generated tonic clients take `&mut self` per call; cloning gives this
        // call its own handle so `self` can stay `&self`.
        let mut client = self.cluster_client.clone();
        let response = client.get_table_regions(request).await?;
        let region_list: RegionList = response.into_inner();
        let regions: Vec<RegionInfo> = region_list
            .regions
            .into_iter()
            .map(RegionInfo::from)
            .collect();
        let regions = Arc::new(regions);
        let cached = CachedRegions {
            regions: Arc::clone(&regions),
        };
        self.region_cache.insert(hash, cached);
        log::debug!(
            "Cached {} regions for table '{}' (hash: {})",
            regions.len(),
            table,
            hash
        );
        Ok(regions)
    }

    /// Returns the address recorded for `store_id` in the cluster topology.
    ///
    /// # Errors
    /// Returns `tonic::Status::not_found` when the topology has no entry for
    /// the store.
    pub fn get_store_address(&self, store_id: u64) -> Result<String, tonic::Status> {
        self.topology.get_store_address(store_id).ok_or_else(|| {
            tonic::Status::not_found(format!("Store {} not found in topology", store_id))
        })
    }

    /// Returns the regions of `table` that overlap the half-open key range
    /// `[start_key, end_key)`.
    ///
    /// An empty `end_key`, on either the query or a region, means "no upper
    /// bound". Matching regions keep the order in which the cluster service
    /// listed them.
    ///
    /// # Errors
    /// Propagates any error from [`Self::get_table_regions`].
    // NOTE(review): presumably no caller in this crate yet, hence the allow.
    #[allow(dead_code)]
    pub async fn get_regions_for_range(
        &self,
        table: &str,
        start_key: &[u8],
        end_key: &[u8],
    ) -> Result<Vec<RegionInfo>, tonic::Status> {
        let all_regions = self.get_table_regions(table).await?;
        let matching: Vec<RegionInfo> = all_regions
            .iter()
            .filter(|r| {
                // Half-open overlap test: region.start < query.end && query.start < region.end,
                // with an empty end key standing in for +infinity.
                (r.end_key.is_empty() || r.end_key.as_slice() > start_key)
                    && (end_key.is_empty() || r.start_key.as_slice() < end_key)
            })
            .cloned()
            .collect();
        Ok(matching)
    }

    /// Drops the cached region list for `table`; the next lookup refetches it.
    // NOTE(review): presumably no caller in this crate yet, hence the allow.
    #[allow(dead_code)]
    pub fn invalidate_table(&self, table: &str) {
        let hash = Self::hash_key(table);
        self.region_cache.remove(hash);
    }

    /// Returns a snapshot of the region-cache size and capacity and the number
    /// of stores known to the topology.
    // NOTE(review): presumably no caller in this crate yet, hence the allow.
    #[allow(dead_code)]
    pub fn cache_stats(&self) -> CacheStats {
        CacheStats {
            region_cache_size: self.region_cache.len(),
            region_cache_capacity: self.region_cache.capacity(),
            topology_store_count: self.topology.store_count(),
        }
    }
}
/// Point-in-time router cache metrics, as produced by `RegionRouter::cache_stats`.
#[derive(Debug, Clone)]
// NOTE(review): presumably unused within this crate so far, hence the allow.
#[allow(dead_code)]
pub struct CacheStats {
/// Number of entries currently held in the region cache.
pub region_cache_size: usize,
/// Capacity reported by the region cache.
pub region_cache_capacity: usize,
/// Number of stores known to the cluster topology.
pub topology_store_count: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `hash_key` must be stable for equal table names and tell different
    /// table names apart.
    #[test]
    fn test_hash_key() {
        let users_first = RegionRouter::hash_key(&"users");
        let users_again = RegionRouter::hash_key(&"users");
        let orders = RegionRouter::hash_key(&"orders");

        assert_eq!(users_first, users_again);
        assert_ne!(users_first, orders);
    }
}