use std::{
    collections::HashMap,
    sync::{Arc, Mutex},
    time::Duration,
};

use engula_api::{server::v1::*, v1::*};
use tokio_stream::StreamExt;
use tonic::Streaming;
use tracing::{info, warn};

use crate::RootClient;
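
/// A client-side routing table that mirrors cluster metadata published by
/// the root service: node addresses, database and collection descriptors,
/// shard-to-group bindings, and per-group leader states. Cloning a `Router`
/// is cheap; all clones share the same underlying [`State`].
///
/// A minimal usage sketch (assuming a connected `RootClient` named
/// `root_client` and a `CollectionDesc` named `co_desc` are in scope):
///
/// ```ignore
/// let router = Router::new(root_client).await;
/// let (group_id, shard) = router.find_shard(co_desc, b"some-key")?;
/// let group_state = router.find_group(group_id)?;
/// ```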
#[derive(Debug, Clone)]
pub struct Router {
    state: Arc<Mutex<State>>,
}
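
/// The materialized routing state: a set of lookup tables rebuilt
/// incrementally from the root service's watch events.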
#[derive(Debug, Clone, Default)]
pub struct State {
    node_id_lookup: HashMap<u64, String /* ip:port */>,
    db_id_lookup: HashMap<u64, DatabaseDesc>,
    db_name_lookup: HashMap<String, u64>,
    co_id_lookup: HashMap<u64, CollectionDesc>,
    co_name_lookup: HashMap<(u64 /* db */, String), u64>,
    co_shards_lookup: HashMap<u64 /* co */, Vec<ShardDesc>>,
    shard_group_lookup: HashMap<u64 /* shard */, u64 /* group */>,
    group_id_lookup: HashMap<u64 /* group */, RouterGroupState>,
}
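
/// A snapshot of one group's routing info: its epoch, replicas, and, when
/// known, the leader's replica id and term.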
#[derive(Debug, Clone, Default)]
pub struct RouterGroupState {
    pub id: u64,
    pub epoch: u64,
    pub leader_state: Option<(/* id */ u64, /* term */ u64)>,
    pub replicas: HashMap<u64, ReplicaDesc>,
}

impl Router {
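    /// Creates a `Router` and spawns a background task that keeps its
    /// state in sync with the root service.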
    pub async fn new(root_client: RootClient) -> Self {
        let state = Arc::new(Mutex::new(State::default()));
        let state_clone = state.clone();
        tokio::spawn(async move {
            state_main(state_clone, root_client).await;
        });
        Self { state }
    }
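
    /// Resolves `key` to the shard that owns it plus the id of the group
    /// the shard currently lives in. Hash-partitioned collections are
    /// routed via `crc32(key) % slots`; range-partitioned collections via
    /// a scan of the shards' `[start, end)` ranges, where an empty `end`
    /// means unbounded. Returns `NotFound` when the routing info is
    /// missing or stale.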
    pub fn find_shard(
        &self,
        desc: CollectionDesc,
        key: &[u8],
    ) -> Result<(u64, ShardDesc), crate::Error> {
        if let Some(collection_desc::Partition::Hash(collection_desc::HashPartition { slots })) =
            desc.partition
        {
            // Hash partition: the owning shard is `crc32(key) % slots`.
            let crc = crc32fast::hash(key);
            let slot = crc % (slots as u32);
            let state = self.state.lock().unwrap();
            let shards = state
                .co_shards_lookup
                .get(&desc.id)
                .ok_or_else(|| crate::Error::NotFound(format!("shard (key={key:?})")))?;
            if slots != shards.len() as u32 {
                // The cached shard list disagrees with the collection's
                // slot count, so the routing info must be stale.
                return Err(crate::Error::NotFound("expired shard info".into()));
            }
            // Counts match, so a shard for this slot is guaranteed to exist.
            let shard = shards
                .iter()
                .find(|s| {
                    if let shard_desc::Partition::Hash(p) = s.partition.as_ref().unwrap() {
                        if p.slot_id == slot {
                            return true;
                        }
                    }
                    false
                })
                .unwrap();
            let group_id = state
                .shard_group_lookup
                .get(&shard.id)
                .cloned()
                .ok_or_else(|| crate::Error::NotFound(format!("shard (key={key:?}) group")))?;
            return Ok((group_id, shard.clone()));
        }
        // Range partition: pick the shard whose `[start, end)` range covers
        // the key; an empty `end` marks the right-unbounded shard.
        let state = self.state.lock().unwrap();
        let shards = state
            .co_shards_lookup
            .get(&desc.id)
            .ok_or_else(|| crate::Error::NotFound(format!("shard (key={key:?})")))?;
        for shard in shards {
            if let Some(shard_desc::Partition::Range(shard_desc::RangePartition { start, end })) =
                shard.partition.clone()
            {
                if start.as_slice() > key {
                    continue;
                }
                // The key belongs to this shard if it sits below the
                // exclusive upper bound, or the shard is unbounded.
                if end.as_slice() > key || end.is_empty() {
                    let group_id = state
                        .shard_group_lookup
                        .get(&shard.id)
                        .cloned()
                        .ok_or_else(|| {
                            crate::Error::NotFound(format!("shard (key={key:?}) group"))
                        })?;
                    return Ok((group_id, shard.clone()));
                }
            }
        }
        Err(crate::Error::NotFound(format!("shard (key={key:?})")))
    }
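
    /// Returns the routing state of the group that currently hosts `shard`.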
    pub fn find_group_by_shard(&self, shard: u64) -> Result<RouterGroupState, crate::Error> {
        let state = self.state.lock().unwrap();
        let group = state
            .shard_group_lookup
            .get(&shard)
            .and_then(|id| state.group_id_lookup.get(id))
            .cloned();
        group.ok_or_else(|| crate::Error::NotFound(format!("group (shard={shard:?})")))
    }
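
    /// Returns the routing state of the group with the given id.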
    pub fn find_group(&self, id: u64) -> Result<RouterGroupState, crate::Error> {
        let state = self.state.lock().unwrap();
        let group = state.group_id_lookup.get(&id).cloned();
        group.ok_or_else(|| crate::Error::NotFound(format!("group (id={id:?})")))
    }
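
    /// Returns the advertised address (`ip:port`) of the node with the
    /// given id.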
    pub fn find_node_addr(&self, id: u64) -> Result<String, crate::Error> {
        let state = self.state.lock().unwrap();
        let addr = state.node_id_lookup.get(&id).cloned();
        addr.ok_or_else(|| crate::Error::NotFound(format!("node_addr (node_id={id:?})")))
    }
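
    /// Returns the number of nodes currently known to the router.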
    pub fn total_nodes(&self) -> usize {
        self.state.lock().unwrap().node_id_lookup.len()
    }
}
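
/// Background sync loop: repeatedly opens a watch stream against the root
/// service, resuming from the group epochs already seen, and applies the
/// received events. Failed watch attempts are retried with exponential
/// backoff (1ms, doubling up to 1s).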
async fn state_main(state: Arc<Mutex<State>>, root_client: RootClient) {
    info!("start watching events...");
    let mut interval = 1; // retry backoff in milliseconds
    loop {
        // Tell the root which group epochs we already hold, so it only
        // streams newer descriptors.
        let cur_group_epochs = {
            let state = state.lock().unwrap();
            state
                .group_id_lookup
                .iter()
                .map(|(id, s)| (*id, s.epoch))
                .collect()
        };
        let events = match root_client.watch(cur_group_epochs).await {
            Ok(events) => events,
            Err(e) => {
                warn!(err = ?e, "watch events");
                tokio::time::sleep(Duration::from_millis(interval)).await;
                // Exponential backoff, capped at one second.
                interval = std::cmp::min(interval * 2, 1000);
                continue;
            }
        };
        interval = 1;
        watch_events(state.as_ref(), events).await;
    }
}
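
/// Drains a watch stream and applies each update/delete event to `state`.
///
/// `GroupState` events can arrive before the `Group` descriptor they refer
/// to; such events are buffered in `cached_group_states` and merged once
/// the descriptor shows up.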
async fn watch_events(state: &Mutex<State>, mut events: Streaming<WatchResponse>) {
    use watch_response::{delete_event::Event as DeleteEvent, update_event::Event as UpdateEvent};

    let mut cached_group_states: HashMap<u64, GroupState> = HashMap::default();
    while let Some(event) = events.next().await {
        let (updates, deletes) = match event {
            Ok(resp) => (resp.updates, resp.deletes),
            Err(status) => {
                warn!("WatchEvent error: {}", status);
                continue;
            }
        };
        for update in updates {
            let event = match update.event {
                None => continue,
                Some(event) => event,
            };
            let mut state = state.lock().unwrap();
            match event {
                UpdateEvent::Node(node_desc) => {
                    state.node_id_lookup.insert(node_desc.id, node_desc.addr);
                }
                UpdateEvent::Group(group_desc) => {
                    let (id, epoch) = (group_desc.id, group_desc.epoch);
                    let (shards, replicas) = (group_desc.shards, group_desc.replicas);
                    let replicas = replicas
                        .into_iter()
                        .map(|d| (d.id, d))
                        .collect::<HashMap<u64, ReplicaDesc>>();
                    let mut group_state = RouterGroupState {
                        id,
                        epoch,
                        leader_state: None,
                        replicas,
                    };
                    // Carry over any leader already known for this group,
                    // from either the previous descriptor or a GroupState
                    // event that arrived ahead of it.
                    if let Some(old_state) = state.group_id_lookup.get(&id) {
                        group_state.leader_state = old_state.leader_state;
                    } else if let Some(cached_state) = cached_group_states.remove(&id) {
                        group_state.leader_state = leader_state(&cached_state);
                    }
                    state.group_id_lookup.insert(id, group_state);
                    // Re-bind each shard to this group and refresh the
                    // per-collection shard lists.
                    for shard in shards {
                        state.shard_group_lookup.insert(shard.id, id);
                        let co_shards_lookup = &mut state.co_shards_lookup;
                        match co_shards_lookup.get_mut(&shard.collection_id) {
                            None => {
                                co_shards_lookup.insert(shard.collection_id, vec![shard]);
                            }
                            Some(shards) => {
                                shards.retain(|s| s.id != shard.id);
                                shards.push(shard);
                            }
                        }
                    }
                }
                UpdateEvent::GroupState(group_state) => {
                    let id = group_state.group_id;
                    if let Some(group) = state.group_id_lookup.get_mut(&id) {
                        group.leader_state = leader_state(&group_state);
                    } else {
                        // No descriptor yet; park the state until the
                        // matching Group event arrives.
                        cached_group_states.insert(id, group_state);
                    }
                }
                UpdateEvent::Database(db_desc) => {
                    let desc = db_desc.clone();
                    let (id, name) = (db_desc.id, db_desc.name);
                    if let Some(old_desc) = state.db_id_lookup.insert(id, desc) {
                        if old_desc.name != name {
                            // The database was renamed; drop the stale name
                            // mapping.
                            state.db_name_lookup.remove(&old_desc.name);
                        }
                    }
                    state.db_name_lookup.insert(name, id);
                }
                UpdateEvent::Collection(co_desc) => {
                    let desc = co_desc.clone();
                    let (id, name, db) = (co_desc.id, co_desc.name, co_desc.db);
                    if let Some(old_desc) = state.co_id_lookup.insert(id, desc) {
                        if old_desc.name != name {
                            // The collection was renamed; drop the stale
                            // (db, name) mapping.
                            state.co_name_lookup.remove(&(db, old_desc.name));
                        }
                    }
                    state.co_name_lookup.insert((db, name), id);
                }
            }
        }
        for delete in deletes {
            let event = match delete.event {
                None => continue,
                Some(event) => event,
            };
            let mut state = state.lock().unwrap();
            match event {
                DeleteEvent::Node(node) => {
                    state.node_id_lookup.remove(&node);
                }
                // Group and group-state deletion are not applied yet.
                DeleteEvent::Group(_) => todo!(),
                DeleteEvent::GroupState(_) => todo!(),
                DeleteEvent::Database(db) => {
                    if let Some(desc) = state.db_id_lookup.remove(&db) {
                        state.db_name_lookup.remove(desc.name.as_str());
                    }
                }
                DeleteEvent::Collection(co) => {
                    if let Some(desc) = state.co_id_lookup.remove(&co) {
                        state.co_name_lookup.remove(&(desc.db, desc.name));
                    }
                }
            }
        }
    }
}
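
/// Extracts the `(leader_id, term)` pair from a reported group state, if
/// the leader appears in its replica list.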
#[inline]
fn leader_state(group_state: &GroupState) -> Option<(u64, u64)> {
    if let Some(leader_id) = group_state.leader_id {
        group_state
            .replicas
            .iter()
            .find(|r| r.replica_id == leader_id)
            .map(|r| (leader_id, r.term))
    } else {
        None
    }
}