use std::{collections::HashMap, sync::Arc};
use async_trait::async_trait;
use ceresdbproto::storage::{self, RouteRequest};
use dashmap::DashMap;
use crate::{
errors::Result,
model::route::Endpoint,
rpc_client::{RpcClient, RpcContext},
Error,
};
#[async_trait]
pub trait Router: Send + Sync {
async fn route(&self, tables: &[String], ctx: &RpcContext) -> Result<Vec<Option<Endpoint>>>;
fn evict(&self, tables: &[String]);
}
pub struct RouterImpl {
default_endpoint: Endpoint,
cache: DashMap<String, Endpoint>,
rpc_client: Arc<dyn RpcClient>,
}
impl RouterImpl {
pub fn new(default_endpoint: Endpoint, rpc_client: Arc<dyn RpcClient>) -> Self {
Self {
default_endpoint,
cache: DashMap::new(),
rpc_client,
}
}
}
#[async_trait]
impl Router for RouterImpl {
async fn route(&self, tables: &[String], ctx: &RpcContext) -> Result<Vec<Option<Endpoint>>> {
assert!(ctx.database.is_some());
let mut target_endpoints = vec![Some(self.default_endpoint.clone()); tables.len()];
let misses = {
let mut misses = HashMap::new();
for (idx, table) in tables.iter().enumerate() {
match self.cache.get(table) {
Some(pair) => {
target_endpoints[idx] = Some(pair.value().clone());
}
None => {
misses.insert(table.clone(), idx);
}
}
}
misses
};
let req_ctx = storage::RequestContext {
database: ctx.database.clone().unwrap(),
};
let miss_tables = misses.keys().cloned().collect();
let req = RouteRequest {
context: Some(req_ctx),
tables: miss_tables,
};
let resp = self.rpc_client.route(ctx, req).await?;
for route in resp.routes {
if route.endpoint.is_none() {
continue;
}
let idx = misses.get(&route.table).ok_or_else(|| {
Error::Unknown(format!("Unknown table:{} in response", route.table))
})?;
let endpoint: Endpoint = route.endpoint.unwrap().into();
self.cache.insert(route.table, endpoint.clone());
target_endpoints[*idx] = Some(endpoint);
}
Ok(target_endpoints)
}
fn evict(&self, tables: &[String]) {
tables.iter().for_each(|e| {
self.cache.remove(e.as_str());
})
}
}
#[cfg(test)]
mod test {
use std::sync::Arc;
use dashmap::DashMap;
use super::{Router, RouterImpl};
use crate::{
model::route::Endpoint,
rpc_client::{MockRpcClient, RpcContext},
};
#[tokio::test]
async fn test_basic_flow() {
let table1 = "table1".to_string();
let table2 = "table2".to_string();
let table3 = "table3".to_string();
let table4 = "table4".to_string();
let endpoint1 = Endpoint::new("192.168.0.1".to_string(), 11);
let endpoint2 = Endpoint::new("192.168.0.2".to_string(), 12);
let endpoint3 = Endpoint::new("192.168.0.3".to_string(), 13);
let endpoint4 = Endpoint::new("192.168.0.4".to_string(), 14);
let default_endpoint = Endpoint::new("192.168.0.5".to_string(), 15);
let route_table = Arc::new(DashMap::default());
let mock_rpc_client = MockRpcClient {
route_table: route_table.clone(),
};
mock_rpc_client
.route_table
.insert(table1.clone(), endpoint1.clone());
mock_rpc_client
.route_table
.insert(table2.clone(), endpoint2.clone());
let ctx = RpcContext {
database: Some("db".to_string()),
timeout: None,
};
let tables = vec![table1.clone(), table2.clone()];
let route_client = RouterImpl::new(default_endpoint.clone(), Arc::new(mock_rpc_client));
let route_res1 = route_client.route(&tables, &ctx).await.unwrap();
assert_eq!(&endpoint1, route_res1.get(0).unwrap().as_ref().unwrap());
assert_eq!(&endpoint2, route_res1.get(1).unwrap().as_ref().unwrap());
route_table.insert(table1.clone(), endpoint3.clone());
route_table.insert(table2.clone(), endpoint4.clone());
let route_res2 = route_client.route(&tables, &ctx).await.unwrap();
assert_eq!(&endpoint1, route_res2.get(0).unwrap().as_ref().unwrap());
assert_eq!(&endpoint2, route_res2.get(1).unwrap().as_ref().unwrap());
route_client.evict(&[table1.clone(), table2.clone()]);
let route_res3 = route_client.route(&tables, &ctx).await.unwrap();
assert_eq!(&endpoint3, route_res3.get(0).unwrap().as_ref().unwrap());
assert_eq!(&endpoint4, route_res3.get(1).unwrap().as_ref().unwrap());
let route_res4 = route_client.route(&[table3, table4], &ctx).await.unwrap();
assert_eq!(
&default_endpoint,
route_res4.get(0).unwrap().as_ref().unwrap()
);
assert_eq!(
&default_endpoint,
route_res4.get(1).unwrap().as_ref().unwrap()
);
}
}