use crate::frame::response::event::{Event, StatusChangeEvent};
use crate::routing::Token;
use crate::transport::connection::{Connection, VerifiedKeyspaceName};
use crate::transport::connection_pool::PoolConfig;
use crate::transport::errors::QueryError;
use crate::transport::node::Node;
use crate::transport::topology::{Keyspace, TopologyInfo, TopologyReader};
use arc_swap::ArcSwap;
use futures::future::join_all;
use futures::{future::RemoteHandle, FutureExt};
use itertools::Itertools;
use std::collections::{BTreeMap, HashMap};
use std::net::SocketAddr;
use std::sync::Arc;
use tracing::{debug, warn};
/// Manages up-to-date cluster metadata and per-node connection pools.
///
/// The current state lives in `data`; a background `ClusterWorker` task
/// refreshes it and communicates with this handle over channels.
pub struct Cluster {
    /// Latest cluster snapshot, atomically swapped by the worker on refresh.
    data: Arc<ArcSwap<ClusterData>>,
    /// Sends on-demand topology refresh requests to the worker.
    refresh_channel: tokio::sync::mpsc::Sender<RefreshRequest>,
    /// Sends `USE keyspace` requests to the worker.
    use_keyspace_channel: tokio::sync::mpsc::Sender<UseKeyspaceRequest>,
    /// Keeps the worker task alive; dropping this handle cancels the worker.
    _worker_handle: RemoteHandle<()>,
}
/// A datacenter's nodes together with its rack statistics.
#[derive(Clone)]
pub struct Datacenter {
    /// Nodes belonging to this datacenter.
    pub nodes: Vec<Arc<Node>>,
    /// Number of distinct racks among `nodes` (computed after topology reads).
    pub rack_count: usize,
}
/// Immutable snapshot of cluster topology and schema metadata.
#[derive(Clone)]
pub struct ClusterData {
    /// All known nodes, keyed by address.
    pub known_peers: HashMap<SocketAddr, Arc<Node>>,
    /// Token ring: maps each token to the node that owns it.
    pub ring: BTreeMap<Token, Arc<Node>>,
    /// Keyspace metadata, keyed by keyspace name.
    pub keyspaces: HashMap<String, Keyspace>,
    /// Every node in the cluster.
    pub all_nodes: Vec<Arc<Node>>,
    /// Nodes grouped by datacenter name.
    pub datacenters: HashMap<String, Datacenter>,
}
/// State of the background task that refreshes topology and services
/// requests arriving from the `Cluster` handle.
struct ClusterWorker {
    /// Shared cluster state, replaced wholesale on each refresh.
    cluster_data: Arc<ArcSwap<ClusterData>>,
    /// Reads cluster topology (peers, tokens, keyspaces) from the database.
    topology_reader: TopologyReader,
    /// Configuration applied to connection pools of newly created nodes.
    pool_config: PoolConfig,
    /// Receives on-demand refresh requests from `Cluster::refresh_topology`.
    refresh_channel: tokio::sync::mpsc::Receiver<RefreshRequest>,
    /// Receives `USE keyspace` requests from `Cluster::use_keyspace`.
    use_keyspace_channel: tokio::sync::mpsc::Receiver<UseKeyspaceRequest>,
    /// Receives server-pushed events (topology / status changes).
    server_events_channel: tokio::sync::mpsc::Receiver<Event>,
    /// Keyspace applied to newly created nodes; updated on each request.
    used_keyspace: Option<VerifiedKeyspaceName>,
}
/// Request for an immediate topology refresh; the outcome is reported
/// back through `response_chan`.
#[derive(Debug)]
struct RefreshRequest {
    response_chan: tokio::sync::oneshot::Sender<Result<(), QueryError>>,
}
/// Request to switch the active keyspace on all node connection pools;
/// the outcome is reported back through `response_chan`.
#[derive(Debug)]
struct UseKeyspaceRequest {
    keyspace_name: VerifiedKeyspaceName,
    response_chan: tokio::sync::oneshot::Sender<Result<(), QueryError>>,
}
impl Cluster {
pub async fn new(
initial_peers: &[SocketAddr],
pool_config: PoolConfig,
) -> Result<Cluster, QueryError> {
let cluster_data = Arc::new(ArcSwap::from(Arc::new(ClusterData {
known_peers: HashMap::new(),
ring: BTreeMap::new(),
keyspaces: HashMap::new(),
all_nodes: Vec::new(),
datacenters: HashMap::new(),
})));
let (refresh_sender, refresh_receiver) = tokio::sync::mpsc::channel(32);
let (use_keyspace_sender, use_keyspace_receiver) = tokio::sync::mpsc::channel(32);
let (server_events_sender, server_events_receiver) = tokio::sync::mpsc::channel(32);
let worker = ClusterWorker {
cluster_data: cluster_data.clone(),
topology_reader: TopologyReader::new(
initial_peers,
pool_config.connection_config.clone(),
server_events_sender,
),
pool_config,
refresh_channel: refresh_receiver,
server_events_channel: server_events_receiver,
use_keyspace_channel: use_keyspace_receiver,
used_keyspace: None,
};
let (fut, worker_handle) = worker.work().remote_handle();
tokio::spawn(fut);
let result = Cluster {
data: cluster_data,
refresh_channel: refresh_sender,
use_keyspace_channel: use_keyspace_sender,
_worker_handle: worker_handle,
};
result.refresh_topology().await?;
Ok(result)
}
pub fn get_data(&self) -> Arc<ClusterData> {
self.data.load_full()
}
pub async fn refresh_topology(&self) -> Result<(), QueryError> {
let (response_sender, response_receiver) = tokio::sync::oneshot::channel();
self.refresh_channel
.send(RefreshRequest {
response_chan: response_sender,
})
.await
.expect("Bug in Cluster::refresh_topology sending");
response_receiver
.await
.expect("Bug in Cluster::refresh_topology receiving")
}
pub async fn use_keyspace(
&self,
keyspace_name: VerifiedKeyspaceName,
) -> Result<(), QueryError> {
let (response_sender, response_receiver) = tokio::sync::oneshot::channel();
self.use_keyspace_channel
.send(UseKeyspaceRequest {
keyspace_name,
response_chan: response_sender,
})
.await
.expect("Bug in Cluster::use_keyspace sending");
response_receiver.await.unwrap() }
pub async fn get_working_connections(&self) -> Result<Vec<Arc<Connection>>, QueryError> {
let cluster_data: Arc<ClusterData> = self.get_data();
let peers = &cluster_data.known_peers;
let mut result: Vec<Arc<Connection>> = Vec::with_capacity(peers.len());
let mut last_error: Option<QueryError> = None;
for node in peers.values() {
match node.get_working_connections() {
Ok(conns) => result.extend(conns),
Err(e) => last_error = Some(e),
}
}
if result.is_empty() {
return Err(last_error.unwrap()); }
Ok(result)
}
}
impl ClusterData {
    /// Returns an iterator over ring owners starting at token `t` and
    /// wrapping around to the ring's beginning, visiting each ring entry
    /// exactly once.
    pub fn ring_range<'a>(&'a self, t: &Token) -> impl Iterator<Item = Arc<Node>> + 'a {
        let before_wrap = self.ring.range(t..).map(|(_token, node)| node.clone());
        let after_wrap = self.ring.values().cloned();

        // `take(ring.len())` stops after one full traversal: after yielding
        // the entries at or after `t`, only the entries before `t` are taken
        // from `after_wrap`.
        before_wrap.chain(after_wrap).take(self.ring.len())
    }

    /// Recomputes `rack_count` for every datacenter as the number of
    /// distinct rack names among its nodes (nodes without a rack are
    /// ignored).
    fn update_rack_count(datacenters: &mut HashMap<String, Datacenter>) {
        for datacenter in datacenters.values_mut() {
            datacenter.rack_count = datacenter
                .nodes
                .iter()
                .filter_map(|node| node.rack.clone())
                .unique()
                .count();
        }
    }

    /// Waits until every node's connection pool has finished initializing.
    pub async fn wait_until_all_pools_are_initialized(&self) {
        for node in self.all_nodes.iter() {
            node.wait_until_pool_initialized().await;
        }
    }

    /// Builds a fresh `ClusterData` from newly read topology info.
    ///
    /// Nodes present in `known_peers` are reused (keeping their connection
    /// pools) when their datacenter and rack are unchanged; otherwise a new
    /// `Node` is created with `pool_config` and `used_keyspace`.
    pub fn new(
        info: TopologyInfo,
        pool_config: &PoolConfig,
        known_peers: &HashMap<SocketAddr, Arc<Node>>,
        used_keyspace: &Option<VerifiedKeyspaceName>,
    ) -> Self {
        let mut new_known_peers: HashMap<SocketAddr, Arc<Node>> =
            HashMap::with_capacity(info.peers.len());
        let mut ring: BTreeMap<Token, Arc<Node>> = BTreeMap::new();
        let mut datacenters: HashMap<String, Datacenter> = HashMap::new();
        let mut all_nodes: Vec<Arc<Node>> = Vec::with_capacity(info.peers.len());

        for peer in info.peers {
            // Reuse the existing Arc<Node> when possible; a changed
            // datacenter or rack is treated as a node replacement.
            let node: Arc<Node> = match known_peers.get(&peer.address) {
                Some(node) if node.datacenter == peer.datacenter && node.rack == peer.rack => {
                    node.clone()
                }
                _ => Arc::new(Node::new(
                    peer.address,
                    pool_config.clone(),
                    peer.datacenter,
                    peer.rack,
                    used_keyspace.clone(),
                )),
            };

            new_known_peers.insert(peer.address, node.clone());

            if let Some(dc) = &node.datacenter {
                // Entry API: single lookup instead of get_mut + insert.
                datacenters
                    .entry(dc.clone())
                    .or_insert_with(|| Datacenter {
                        nodes: Vec::new(),
                        rack_count: 0, // filled in by update_rack_count below
                    })
                    .nodes
                    .push(node.clone());
            }

            for token in peer.tokens {
                ring.insert(token, node.clone());
            }

            all_nodes.push(node);
        }

        Self::update_rack_count(&mut datacenters);

        ClusterData {
            known_peers: new_known_peers,
            ring,
            keyspaces: info.keyspaces,
            all_nodes,
            datacenters,
        }
    }
}
impl ClusterWorker {
    /// Worker main loop: refreshes topology periodically (every 60 s) and
    /// services refresh requests, `USE keyspace` requests and server-pushed
    /// events. Returns when all `Cluster`-side channels are closed.
    pub async fn work(mut self) {
        use tokio::time::{Duration, Instant};

        // A full topology refresh happens at least this often.
        let refresh_duration = Duration::from_secs(60);
        let mut last_refresh_time = Instant::now();

        loop {
            let mut cur_request: Option<RefreshRequest> = None;

            // Saturate instead of overflowing if the deadline can't be
            // represented; refreshing immediately is the safe fallback.
            let sleep_until: Instant = last_refresh_time
                .checked_add(refresh_duration)
                .unwrap_or_else(Instant::now);

            let sleep_future = tokio::time::sleep_until(sleep_until);
            tokio::pin!(sleep_future);

            tokio::select! {
                _ = sleep_future => {}, // Periodic refresh - fall through.
                recv_res = self.refresh_channel.recv() => {
                    match recv_res {
                        Some(request) => cur_request = Some(request),
                        // Channel closed => Cluster was dropped, stop working.
                        None => return,
                    }
                }
                recv_res = self.server_events_channel.recv() => {
                    if let Some(event) = recv_res {
                        debug!("Received server event: {:?}", event);
                        match event {
                            // Topology changed - fall through to refresh.
                            Event::TopologyChange(_) => (),
                            Event::StatusChange(status) => {
                                // Up/Down only flips the node's down marker;
                                // no full refresh needed.
                                match status {
                                    StatusChangeEvent::Down(addr) => self.change_node_down_marker(addr, true),
                                    StatusChangeEvent::Up(addr) => self.change_node_down_marker(addr, false),
                                }
                                continue;
                            },
                            _ => continue, // Other events don't trigger a refresh.
                        }
                    } else {
                        // Channel closed => TopologyReader was dropped, stop working.
                        return;
                    }
                }
                recv_res = self.use_keyspace_channel.recv() => {
                    match recv_res {
                        Some(request) => {
                            // Remember the keyspace so newly created nodes get it too.
                            self.used_keyspace = Some(request.keyspace_name.clone());

                            let cluster_data = self.cluster_data.load_full();
                            let use_keyspace_future = Self::handle_use_keyspace_request(cluster_data, request);
                            // Run separately so slow nodes don't block this loop.
                            tokio::spawn(use_keyspace_future);
                        },
                        // Channel closed => Cluster was dropped, stop working.
                        None => return,
                    }

                    continue; // No refresh needed, wait for the next event.
                }
            }

            debug!("Requesting topology refresh");
            last_refresh_time = Instant::now();
            let refresh_res = self.perform_refresh().await;

            // Report the result if this refresh was explicitly requested;
            // ignore send errors - the requester may have given up waiting.
            if let Some(request) = cur_request {
                let _ = request.response_chan.send(refresh_res);
            }
        }
    }

    /// Sets or clears the down marker on the node with the given address.
    /// Unknown addresses are logged and ignored.
    fn change_node_down_marker(&mut self, addr: SocketAddr, is_down: bool) {
        let cluster_data = self.cluster_data.load_full();

        let node = match cluster_data.known_peers.get(&addr) {
            Some(node) => node,
            None => {
                warn!("Unknown node address {}", addr);
                return;
            }
        };

        node.change_down_marker(is_down);
    }

    /// Runs `send_use_keyspace` and reports the result back to the
    /// requester; a dropped receiver is ignored.
    async fn handle_use_keyspace_request(
        cluster_data: Arc<ClusterData>,
        request: UseKeyspaceRequest,
    ) {
        let result = Self::send_use_keyspace(cluster_data, &request.keyspace_name).await;
        let _ = request.response_chan.send(result);
    }

    /// Sends `USE keyspace` to every known node concurrently.
    ///
    /// Succeeds if at least one node succeeded - nodes failing with an
    /// `IoError` (e.g. temporarily down) will pick up the keyspace on
    /// reconnect. Any non-I/O error is returned immediately; if every node
    /// failed with an `IoError`, the last one is returned.
    async fn send_use_keyspace(
        cluster_data: Arc<ClusterData>,
        keyspace_name: &VerifiedKeyspaceName,
    ) -> Result<(), QueryError> {
        let mut use_keyspace_futures = Vec::new();

        for node in cluster_data.known_peers.values() {
            let fut = node.use_keyspace(keyspace_name.clone());
            use_keyspace_futures.push(fut);
        }

        let use_keyspace_results: Vec<Result<(), QueryError>> =
            join_all(use_keyspace_futures).await;

        let mut was_ok: bool = false;
        let mut io_error: Option<Arc<std::io::Error>> = None;

        for result in use_keyspace_results {
            match result {
                Ok(()) => was_ok = true,
                Err(err) => match err {
                    QueryError::IoError(io_err) => io_error = Some(io_err),
                    _ => return Err(err),
                },
            }
        }

        if was_ok {
            return Ok(());
        }

        // If nothing succeeded, at least one IoError must have been recorded
        // (known_peers is non-empty after the initial topology refresh).
        Err(QueryError::IoError(io_error.expect(
            "No successes and no IoErrors - is known_peers empty?",
        )))
    }

    /// Reads fresh topology, builds a new `ClusterData` (reusing unchanged
    /// nodes), waits for pools to initialize, then swaps in the snapshot.
    async fn perform_refresh(&mut self) -> Result<(), QueryError> {
        let topo_info = self.topology_reader.read_topology_info().await?;
        let cluster_data: Arc<ClusterData> = self.cluster_data.load_full();

        let new_cluster_data = Arc::new(ClusterData::new(
            topo_info,
            &self.pool_config,
            &cluster_data.known_peers,
            &self.used_keyspace,
        ));

        new_cluster_data
            .wait_until_all_pools_are_initialized()
            .await;

        self.update_cluster_data(new_cluster_data);

        Ok(())
    }

    /// Atomically publishes a new cluster snapshot.
    fn update_cluster_data(&mut self, new_cluster_data: Arc<ClusterData>) {
        self.cluster_data.store(new_cluster_data);
    }
}