use crate::routing::{Shard, ShardCount, Sharder, Token};
use crate::transport::errors::QueryError;
use crate::transport::{
connection,
connection::{Connection, ConnectionConfig, ErrorReceiver, VerifiedKeyspaceName},
};
use arc_swap::ArcSwap;
use futures::{future::RemoteHandle, stream::FuturesUnordered, Future, FutureExt, StreamExt};
use rand::Rng;
use std::convert::TryInto;
use std::io::ErrorKind;
use std::net::{IpAddr, SocketAddr};
use std::num::NonZeroUsize;
use std::pin::Pin;
use std::sync::{Arc, Weak};
use std::time::Duration;
use tokio::sync::{mpsc, Notify};
use tracing::{debug, trace, warn};
/// The target size of a per-node connection pool.
#[derive(Debug, Clone)]
pub enum PoolSize {
    /// Keep this many connections to the node in total.
    PerHost(NonZeroUsize),
    /// Keep this many connections to each shard of the node. An unsharded
    /// node is treated as having a single shard, so this is then the total.
    PerShard(NonZeroUsize),
}

/// Defaults to one connection per shard.
impl Default for PoolSize {
    fn default() -> Self {
        PoolSize::PerShard(NonZeroUsize::new(1).unwrap())
    }
}
/// Configuration of a single node's connection pool.
#[derive(Clone)]
pub struct PoolConfig {
    /// Settings applied to every connection opened by the pool.
    pub connection_config: ConnectionConfig,
    /// Desired pool size (per host or per shard).
    pub pool_size: PoolSize,
    /// Whether the pool may use the node's shard-aware port, which lets it
    /// open a connection directly to a chosen shard.
    pub can_use_shard_aware_port: bool,
}

impl Default for PoolConfig {
    fn default() -> Self {
        Self {
            connection_config: Default::default(),
            pool_size: Default::default(),
            can_use_shard_aware_port: true,
        }
    }
}
/// The pool state as published by the background refiller.
enum MaybePoolConnections {
    /// The first fill attempt has not finished yet.
    Initializing,
    /// The pool currently holds no live connections.
    Broken,
    /// The pool holds at least one live connection.
    Ready(PoolConnections),
}

#[derive(Clone)]
enum PoolConnections {
    /// Connections to a node without sharding info: one flat list.
    NotSharded(Vec<Arc<Connection>>),
    /// Connections to a sharded node: one connection list per shard,
    /// indexed by shard id.
    Sharded {
        sharder: Sharder,
        connections: Vec<Vec<Arc<Connection>>>,
    },
}
/// A per-node pool of connections, kept filled by a background worker task.
pub struct NodeConnectionPool {
    // Connection set shared with the refiller, which swaps in updates.
    conns: Arc<ArcSwap<MaybePoolConnections>>,
    // Channel used to ask the refiller to switch the current keyspace.
    use_keyspace_request_sender: mpsc::Sender<UseKeyspaceRequest>,
    // Dropping this handle cancels the spawned refiller task, tying the
    // worker's lifetime to the pool's.
    _refiller_handle: RemoteHandle<()>,
    // Notified by the refiller every time the shared connection set changes.
    pool_updated_notify: Arc<Notify>,
}
impl NodeConnectionPool {
/// Creates the pool and spawns its background worker on the current Tokio
/// runtime. The worker opens and maintains the connections; the returned
/// pool only reads the connection set the worker publishes.
pub fn new(
    address: IpAddr,
    port: u16,
    pool_config: PoolConfig,
    current_keyspace: Option<VerifiedKeyspaceName>,
) -> Self {
    // Small buffer; senders simply wait until the worker picks a request up.
    let (use_keyspace_request_sender, use_keyspace_request_receiver) = mpsc::channel(1);
    let pool_updated_notify = Arc::new(Notify::new());
    let refiller = PoolRefiller::new(
        address,
        port,
        pool_config,
        current_keyspace,
        pool_updated_notify.clone(),
    );
    let conns = refiller.get_shared_connections();
    // remote_handle ties the spawned task to `_refiller_handle`: dropping
    // the pool drops the handle, which cancels the worker.
    let (fut, handle) = refiller.run(use_keyspace_request_receiver).remote_handle();
    tokio::spawn(fut);
    Self {
        conns,
        use_keyspace_request_sender,
        _refiller_handle: handle,
        pool_updated_notify,
    }
}
/// Returns a connection to the shard owning `token`, falling back to other
/// shards if that shard has no connections. Errors when the pool is still
/// initializing or has no live connections at all.
pub fn connection_for_token(&self, token: Token) -> Result<Arc<Connection>, QueryError> {
    self.with_connections(|pool_conns| match pool_conns {
        PoolConnections::NotSharded(conns) => {
            // `Ready` pools are non-empty, so a connection must exist.
            Self::choose_random_connection_from_slice(conns).unwrap()
        }
        PoolConnections::Sharded {
            sharder,
            connections,
        } => {
            let shard: u16 = sharder
                .shard_of(token)
                .try_into()
                .expect("Shard number doesn't fit in u16");
            Self::connection_for_shard(shard, sharder.nr_shards, connections.as_slice())
        }
    })
}
/// Returns a random connection: on sharded nodes a random shard is picked
/// first (falling back to other shards if it is empty). Errors when the
/// pool is still initializing or has no live connections at all.
pub fn random_connection(&self) -> Result<Arc<Connection>, QueryError> {
    self.with_connections(|pool_conns| match pool_conns {
        PoolConnections::NotSharded(conns) => {
            // `Ready` pools are non-empty, so a connection must exist.
            Self::choose_random_connection_from_slice(conns).unwrap()
        }
        PoolConnections::Sharded {
            sharder,
            connections,
        } => {
            let shard: u16 = rand::thread_rng().gen_range(0..sharder.nr_shards.get());
            Self::connection_for_shard(shard, sharder.nr_shards, connections.as_slice())
        }
    })
}
/// Picks a random connection to `shard`; if that shard's pool is empty,
/// tries the remaining shards in random order. Panics only when every
/// shard is empty — callers reach this via a `Ready` pool, which is only
/// published when at least one connection exists.
fn connection_for_shard(
    shard: u16,
    nr_shards: ShardCount,
    shard_conns: &[Vec<Arc<Connection>>],
) -> Arc<Connection> {
    // Fast path: the preferred shard has at least one connection.
    if let Some(conn) = Self::choose_random_connection_from_slice(&shard_conns[shard as usize])
    {
        return conn;
    }
    // Fall back to all other shards, visited in random order.
    let mut shards_to_try: Vec<u16> = (0..shard).chain(shard + 1..nr_shards.get()).collect();
    while !shards_to_try.is_empty() {
        let idx = rand::thread_rng().gen_range(0..shards_to_try.len());
        let shard = shards_to_try.swap_remove(idx);
        if let Some(conn) =
            Self::choose_random_connection_from_slice(&shard_conns[shard as usize])
        {
            return conn;
        }
    }
    unreachable!("could not find any connection in supposedly non-empty pool")
}
/// Asks the pool's background worker to switch all connections (current
/// and future) to the given keyspace, and waits for the outcome.
///
/// # Errors
/// Propagates the worker's combined result: the first non-IO error from
/// any connection, or the last IO error when no connection succeeded.
pub async fn use_keyspace(
    &self,
    keyspace_name: VerifiedKeyspaceName,
) -> Result<(), QueryError> {
    let (response_sender, response_receiver) = tokio::sync::oneshot::channel();
    self.use_keyspace_request_sender
        .send(UseKeyspaceRequest {
            keyspace_name,
            response_sender,
        })
        .await
        // The worker only stops after this sender is dropped, so the
        // channel cannot be closed while we hold `&self`.
        .expect("Bug in NodeConnectionPool::use_keyspace sending");
    // The worker always answers each request before dropping it.
    response_receiver.await.unwrap()
}
pub async fn wait_until_initialized(&self) {
let notified = self.pool_updated_notify.notified();
if let MaybePoolConnections::Initializing = **self.conns.load() {
notified.await;
}
}
/// Returns every live connection in the pool as one flat list. Errors when
/// the pool is still initializing or has no connections.
pub fn get_working_connections(&self) -> Result<Vec<Arc<Connection>>, QueryError> {
    self.with_connections(|pool_conns| match pool_conns {
        PoolConnections::NotSharded(conns) => conns.to_vec(),
        // Flatten the per-shard vectors into a single cloned list.
        PoolConnections::Sharded { connections, .. } => connections.concat(),
    })
}
/// Picks a uniformly random element of `v`, or `None` when `v` is empty.
fn choose_random_connection_from_slice(v: &[Arc<Connection>]) -> Option<Arc<Connection>> {
    match v.len() {
        0 => None,
        // Single candidate: no need to consult the RNG.
        1 => Some(v[0].clone()),
        len => {
            let idx = rand::thread_rng().gen_range(0..len);
            Some(v[idx].clone())
        }
    }
}
/// Runs `f` over the current connection set, or returns an IO error when
/// the pool is not in the `Ready` state (still initializing, or broken).
fn with_connections<T>(&self, f: impl FnOnce(&PoolConnections) -> T) -> Result<T, QueryError> {
    // load_full keeps the snapshot alive for the duration of `f`.
    let conns = self.conns.load_full();
    match &*conns {
        MaybePoolConnections::Ready(pool_connections) => Ok(f(pool_connections)),
        _ => Err(QueryError::IoError(Arc::new(std::io::Error::new(
            ErrorKind::Other,
            "No connections in the pool",
        )))),
    }
}
}
// Maximum number of stashed excess connections kept, per shard.
const EXCESS_CONNECTION_BOUND_PER_SHARD_MULTIPLIER: usize = 10;
// Exponential backoff parameters for pool refill attempts.
const MIN_FILL_BACKOFF: Duration = Duration::from_millis(50);
const MAX_FILL_BACKOFF: Duration = Duration::from_secs(10);
const FILL_BACKOFF_MULTIPLIER: u32 = 2;

/// Exponential backoff between pool refill attempts: the delay doubles
/// (up to a cap) after an errored fill and resets after a clean one.
struct RefillDelayStrategy {
    delay: Duration,
}

impl RefillDelayStrategy {
    fn new() -> Self {
        Self {
            delay: MIN_FILL_BACKOFF,
        }
    }

    /// Delay to wait before the next refill attempt.
    fn get_delay(&self) -> Duration {
        self.delay
    }

    /// A fill finished without errors: start over from the minimum delay.
    fn on_successful_fill(&mut self) {
        self.delay = MIN_FILL_BACKOFF;
    }

    /// A fill hit an error: double the delay, saturating at the maximum.
    fn on_fill_error(&mut self) {
        self.delay = (self.delay * FILL_BACKOFF_MULTIPLIER).min(MAX_FILL_BACKOFF);
    }
}
/// Background worker that keeps a `NodeConnectionPool` filled: it opens
/// connections, publishes the live set through `shared_conns`, watches
/// connections for errors and schedules refills with backoff.
struct PoolRefiller {
    // Endpoint of the node this pool connects to.
    address: IpAddr,
    regular_port: u16,
    pool_config: PoolConfig,
    // Learned from opened connections; `None` until the node reports one.
    shard_aware_port: Option<u16>,
    // Sharding info of the node; `None` for unsharded nodes.
    sharder: Option<Sharder>,
    // The connection set shared with `NodeConnectionPool` readers.
    shared_conns: Arc<ArcSwap<MaybePoolConnections>>,
    // Internal view of the pool: one vector per shard (a single vector
    // when the node is unsharded).
    conns: Vec<Vec<Arc<Connection>>>,
    // Set when an opening attempt fails; decides backoff vs. reset.
    had_error_since_last_refill: bool,
    refill_delay_strategy: RefillDelayStrategy,
    // In-flight connection-opening (and keyspace-setup) futures.
    ready_connections:
        FuturesUnordered<Pin<Box<dyn Future<Output = OpenedConnectionEvent> + Send + 'static>>>,
    // Futures that resolve when an established connection breaks.
    connection_errors:
        FuturesUnordered<Pin<Box<dyn Future<Output = BrokenConnectionEvent> + Send + 'static>>>,
    // Connections that landed on an already-full shard; kept temporarily
    // and cleared when the pool fills up or the stash exceeds its limit.
    excess_connections: Vec<Arc<Connection>>,
    current_keyspace: Option<VerifiedKeyspaceName>,
    // Pinged after every publish of `shared_conns`.
    pool_updated_notify: Arc<Notify>,
}
/// A request, sent to the refiller worker, to switch the pool's keyspace.
#[derive(Debug)]
struct UseKeyspaceRequest {
    keyspace_name: VerifiedKeyspaceName,
    // Receives the outcome of the keyspace switch.
    response_sender: tokio::sync::oneshot::Sender<Result<(), QueryError>>,
}
impl PoolRefiller {
/// Creates a refiller in the `Initializing` state with no connections.
/// Nothing is opened until `run` is awaited.
pub fn new(
    address: IpAddr,
    port: u16,
    pool_config: PoolConfig,
    current_keyspace: Option<VerifiedKeyspaceName>,
    pool_updated_notify: Arc<Notify>,
) -> Self {
    // Start with a single shard; `maybe_reshard` resizes this once the
    // node's real shard count is known.
    let conns = vec![Vec::new()];
    let shared_conns = Arc::new(ArcSwap::new(Arc::new(MaybePoolConnections::Initializing)));
    Self {
        address,
        regular_port: port,
        pool_config,
        shard_aware_port: None,
        sharder: None,
        shared_conns,
        conns,
        had_error_since_last_refill: false,
        refill_delay_strategy: RefillDelayStrategy::new(),
        ready_connections: FuturesUnordered::new(),
        connection_errors: FuturesUnordered::new(),
        excess_connections: Vec::new(),
        current_keyspace,
        pool_updated_notify,
    }
}
pub fn get_shared_connections(&self) -> Arc<ArcSwap<MaybePoolConnections>> {
self.shared_conns.clone()
}
/// The worker's main loop. Reacts to: refill timers, finished connection
/// attempts, broken-connection reports and keyspace-change requests.
/// Exits when the owning `NodeConnectionPool` drops its request channel.
pub async fn run(
    mut self,
    mut use_keyspace_request_receiver: mpsc::Receiver<UseKeyspaceRequest>,
) {
    debug!("[{}] Started asynchronous pool worker", self.address);
    // The first refill fires immediately.
    let mut next_refill_time = tokio::time::Instant::now();
    let mut refill_scheduled = true;
    loop {
        tokio::select! {
            _ = tokio::time::sleep_until(next_refill_time), if refill_scheduled => {
                self.had_error_since_last_refill = false;
                self.start_filling();
                refill_scheduled = false;
            }
            // A connection attempt finished (successfully or not).
            evt = self.ready_connections.select_next_some(), if !self.ready_connections.is_empty() => {
                self.handle_ready_connection(evt);
                if self.is_full() {
                    debug!(
                        "[{}] Pool is full, clearing {} excess connections",
                        self.address,
                        self.excess_connections.len()
                    );
                    self.excess_connections.clear();
                }
            }
            // An established connection reported an error; drop it from the pool.
            evt = self.connection_errors.select_next_some(), if !self.connection_errors.is_empty() => {
                if let Some(conn) = evt.connection.upgrade() {
                    debug!("[{}] Got error for connection {:p}: {:?}", self.address, Arc::as_ptr(&conn), evt.error);
                    self.remove_connection(conn);
                }
            }
            req = use_keyspace_request_receiver.recv() => {
                if let Some(req) = req {
                    debug!("[{}] Requested keyspace change: {}", self.address, req.keyspace_name.as_str());
                    self.use_keyspace(&req.keyspace_name, req.response_sender);
                } else {
                    // The pool owning this worker was dropped.
                    trace!("[{}] Keyspace request channel dropped, stopping asynchronous pool worker", self.address);
                    return;
                }
            }
        }
        // After handling an event: publish the current state and, if the
        // pool still needs connections, schedule the next refill using the
        // backoff strategy (grown on error, reset on success).
        if !refill_scheduled && self.need_filling() {
            self.update_shared_conns();
            if self.had_error_since_last_refill {
                self.refill_delay_strategy.on_fill_error();
            } else {
                self.refill_delay_strategy.on_successful_fill();
            }
            let delay = self.refill_delay_strategy.get_delay();
            debug!(
                "[{}] Scheduling next refill in {} ms",
                self.address,
                delay.as_millis(),
            );
            next_refill_time = tokio::time::Instant::now() + delay;
            refill_scheduled = true;
        }
    }
}
/// True while any connection-opening futures are still in flight.
fn is_filling(&self) -> bool {
    !self.ready_connections.is_empty()
}

/// True when the pool has reached its configured target size.
fn is_full(&self) -> bool {
    match self.pool_config.pool_size {
        PoolSize::PerHost(target) => self.active_connection_count() >= target.get(),
        PoolSize::PerShard(target) => {
            self.conns.iter().all(|conns| conns.len() >= target.get())
        }
    }
}

/// True when the pool holds no connections at all.
fn is_empty(&self) -> bool {
    self.conns.iter().all(|conns| conns.is_empty())
}

/// True when a new refill round should be started.
fn need_filling(&self) -> bool {
    !self.is_filling() && !self.is_full()
}

/// Shard-aware connecting needs the node's sharder and shard-aware port,
/// and must be allowed by the pool configuration.
fn can_use_shard_aware_port(&self) -> bool {
    self.sharder.is_some()
        && self.shard_aware_port.is_some()
        && self.pool_config.can_use_shard_aware_port
}
/// Kicks off enough connection-opening futures to bring the pool up to its
/// target size. Does not wait for them; their results arrive through
/// `ready_connections`.
fn start_filling(&mut self) {
    if self.is_empty() {
        // An empty pool knows nothing about the node yet (shard count,
        // shard-aware port), so open a single probe connection first.
        trace!(
            "[{}] Will open the first connection to the node",
            self.address
        );
        self.start_opening_connection(None);
        return;
    }
    // Prefer targeted per-shard connecting when the node supports it and
    // the pool size is expressed per shard.
    if self.can_use_shard_aware_port() {
        if let PoolSize::PerShard(target) = self.pool_config.pool_size {
            for (shard_id, shard_conns) in self.conns.iter().enumerate() {
                let to_open_count = target.get().saturating_sub(shard_conns.len());
                if to_open_count == 0 {
                    continue;
                }
                trace!(
                    "[{}] Will open {} connections to shard {}",
                    self.address,
                    to_open_count,
                    shard_id,
                );
                for _ in 0..to_open_count {
                    self.start_opening_connection(Some(shard_id as Shard));
                }
            }
            return;
        }
    }
    // Fallback: connect through the regular port and let the node assign
    // shards; open as many as the deficit across the whole pool.
    let to_open_count = match self.pool_config.pool_size {
        PoolSize::PerHost(target) => {
            target.get().saturating_sub(self.active_connection_count())
        }
        PoolSize::PerShard(target) => self
            .conns
            .iter()
            .map(|conns| target.get().saturating_sub(conns.len()))
            .sum::<usize>(),
    };
    trace!(
        "[{}] Will open {} non-shard-aware connections",
        self.address,
        to_open_count,
    );
    for _ in 0..to_open_count {
        self.start_opening_connection(None);
    }
}
/// Handles the outcome of one finished connection attempt.
///
/// Failures on the shard-aware port are retried through the regular port;
/// other failures mark the refill round as errored. Successful connections
/// have their sharding info absorbed and the current keyspace applied,
/// then are filed into their shard's pool, retried, or stashed as excess.
fn handle_ready_connection(&mut self, evt: OpenedConnectionEvent) {
    match evt.result {
        Err(err) => {
            if evt.requested_shard.is_some() {
                // A shard-aware attempt failed; fall back to the regular
                // port instead of counting this as a fill error.
                debug!(
                    "[{}] Failed to open connection to the shard-aware port: {:?}, will retry with regular port",
                    self.address,
                    err,
                );
                self.start_opening_connection(None);
            } else {
                // Regular-port failure: trigger backoff for the next refill.
                self.had_error_since_last_refill = true;
                debug!(
                    "[{}] Failed to open connection to the non-shard-aware port: {:?}",
                    self.address, err,
                );
            }
        }
        Ok((connection, error_receiver)) => {
            // Shard the node actually assigned this connection to
            // (0 for unsharded nodes).
            let shard_info = connection.get_shard_info().as_ref();
            let sharder = shard_info.map(|s| s.get_sharder());
            let shard_id = shard_info.map_or(0, |s| s.shard as usize);
            // Reset the pool if the node's sharding changed.
            self.maybe_reshard(sharder);
            // Keep the advertised shard-aware port up to date.
            if self.shard_aware_port != connection.get_shard_aware_port() {
                debug!(
                    "[{}] Updating shard aware port: {:?}",
                    self.address,
                    connection.get_shard_aware_port(),
                );
                self.shard_aware_port = connection.get_shard_aware_port();
            }
            // A connection not yet switched to the pool's keyspace goes
            // through keyspace setup and re-enters this handler afterwards.
            if let Some(keyspace) = &self.current_keyspace {
                if evt.keyspace_name.as_ref() != Some(keyspace) {
                    self.start_setting_keyspace_for_connection(
                        connection,
                        error_receiver,
                        evt.requested_shard,
                    );
                    return;
                }
            }
            let can_be_accepted = match self.pool_config.pool_size {
                PoolSize::PerHost(target) => self.active_connection_count() < target.get(),
                PoolSize::PerShard(target) => self.conns[shard_id].len() < target.get(),
            };
            if can_be_accepted {
                // Accept: watch its error channel and publish the new set.
                let conn = Arc::new(connection);
                trace!(
                    "[{}] Adding connection {:p} to shard {} pool, now there are {} for the shard, total {}",
                    self.address,
                    Arc::as_ptr(&conn),
                    shard_id,
                    self.conns[shard_id].len() + 1,
                    self.active_connection_count() + 1,
                );
                self.connection_errors
                    .push(wait_for_error(Arc::downgrade(&conn), error_receiver).boxed());
                self.conns[shard_id].push(conn);
                self.update_shared_conns();
            } else if evt.requested_shard.is_some() {
                // A shard-aware connection for an already-full shard is
                // dropped and replaced by a regular-port attempt.
                debug!(
                    "[{}] Excess shard-aware port connection for shard {}; will retry with non-shard-aware port",
                    self.address,
                    shard_id,
                );
                self.start_opening_connection(None);
            } else {
                // Stash the unwanted connection; the excess list is
                // cleared when the pool fills up or it outgrows its limit.
                let conn = Arc::new(connection);
                trace!(
                    "[{}] Storing excess connection {:p} for shard {}",
                    self.address,
                    Arc::as_ptr(&conn),
                    shard_id,
                );
                self.connection_errors
                    .push(wait_for_error(Arc::downgrade(&conn), error_receiver).boxed());
                self.excess_connections.push(conn);
                let excess_connection_limit = self.excess_connection_limit();
                if self.excess_connections.len() > excess_connection_limit {
                    debug!(
                        "[{}] Excess connection pool exceeded limit of {} connections - clearing",
                        self.address,
                        excess_connection_limit,
                    );
                    self.excess_connections.clear();
                }
            }
        }
    }
}
/// Pushes a future opening one connection into `ready_connections`.
///
/// The shard-aware port is used when a target `shard` is requested and
/// both the sharder and the shard-aware port are already known; otherwise
/// the regular port is used and the node picks the shard.
fn start_opening_connection(&self, shard: Option<Shard>) {
    let cfg = self.pool_config.connection_config.clone();
    let fut = match (self.sharder.clone(), self.shard_aware_port, shard) {
        (Some(sharder), Some(port), Some(shard)) => {
            let shard_aware_address = (self.address, port).into();
            async move {
                let result = open_connection_to_shard_aware_port(
                    shard_aware_address,
                    shard,
                    sharder.clone(),
                    &cfg,
                )
                .await;
                OpenedConnectionEvent {
                    result,
                    requested_shard: Some(shard),
                    keyspace_name: None,
                }
            }
            .boxed()
        }
        _ => {
            let non_shard_aware_address = (self.address, self.regular_port).into();
            async move {
                let result =
                    connection::open_connection(non_shard_aware_address, None, cfg).await;
                OpenedConnectionEvent {
                    result,
                    requested_shard: None,
                    keyspace_name: None,
                }
            }
            .boxed()
        }
    };
    self.ready_connections.push(fut);
}
/// Adopts the sharder reported by a new connection. When it differs from
/// the current one, all pooled and excess connections are dropped and the
/// per-shard vectors are rebuilt to match the new shard count.
fn maybe_reshard(&mut self, new_sharder: Option<Sharder>) {
    if self.sharder == new_sharder {
        return;
    }
    debug!(
        "[{}] New sharder: {:?}, clearing all connections",
        self.address, new_sharder,
    );
    self.sharder = new_sharder.clone();
    // One connection vector per shard; unsharded nodes get a single one.
    self.conns.clear();
    let shard_count = new_sharder.map_or(1, |s| s.nr_shards.get() as usize);
    self.conns.resize_with(shard_count, Vec::new);
    self.excess_connections.clear();
}
/// Publishes a snapshot of the current connections for pool readers and
/// wakes everyone waiting on the update notification. An empty pool is
/// published as `Broken`.
fn update_shared_conns(&mut self) {
    let new_conns = if !self.has_connections() {
        Arc::new(MaybePoolConnections::Broken)
    } else {
        let new_conns = if let Some(sharder) = self.sharder.as_ref() {
            // maybe_reshard keeps `conns` sized to the shard count.
            debug_assert_eq!(self.conns.len(), sharder.nr_shards.get() as usize);
            PoolConnections::Sharded {
                sharder: sharder.clone(),
                connections: self.conns.clone(),
            }
        } else {
            debug_assert_eq!(self.conns.len(), 1);
            PoolConnections::NotSharded(self.conns[0].clone())
        };
        Arc::new(MaybePoolConnections::Ready(new_conns))
    };
    self.shared_conns.store(new_conns);
    self.pool_updated_notify.notify_waiters();
}
/// Removes a broken connection from its shard's pool (or from the excess
/// list) and republishes the shared connection set when the pool changed.
/// A no-op if the connection was already removed.
fn remove_connection(&mut self, connection: Arc<Connection>) {
    let ptr = Arc::as_ptr(&connection);
    // Removes `connection` (compared by pointer identity) from `v`,
    // returning whether it was found.
    let maybe_remove_in_vec = |v: &mut Vec<Arc<Connection>>| -> bool {
        let maybe_idx = v
            .iter()
            .position(|other_conn| Arc::ptr_eq(&connection, other_conn));
        match maybe_idx {
            Some(idx) => {
                // Order within a shard does not matter, so O(1) removal.
                v.swap_remove(idx);
                true
            }
            None => false,
        }
    };
    // Look in the shard the connection claims to belong to; guard against
    // a shard id that exceeds the current shard count (e.g. after reshard).
    let shard_id = connection
        .get_shard_info()
        .as_ref()
        .map_or(0, |s| s.shard as usize);
    if shard_id < self.conns.len() && maybe_remove_in_vec(&mut self.conns[shard_id]) {
        trace!(
            "[{}] Connection {:p} removed from shard {} pool, now there is {} for the shard, total {}",
            self.address,
            ptr,
            shard_id,
            self.conns[shard_id].len(),
            self.active_connection_count(),
        );
        self.update_shared_conns();
        return;
    }
    // Not in the active pool - it may be an excess connection.
    if maybe_remove_in_vec(&mut self.excess_connections) {
        trace!(
            "[{}] Connection {:p} removed from excess connection pool",
            self.address,
            ptr,
        );
        return;
    }
    trace!(
        "[{}] Connection {:p} was already removed",
        self.address,
        ptr,
    );
}
/// Records `keyspace_name` as the pool's current keyspace and switches all
/// existing connections to it on a background task, reporting the combined
/// outcome through `response_sender`.
///
/// Combined result: Ok if at least one connection succeeded (or there were
/// none); a non-IO error aborts immediately; otherwise the last observed
/// IO error is reported.
fn use_keyspace(
    &mut self,
    keyspace_name: &VerifiedKeyspaceName,
    response_sender: tokio::sync::oneshot::Sender<Result<(), QueryError>>,
) {
    // Remember the keyspace first so connections opened from now on get
    // switched as part of their setup (see handle_ready_connection).
    self.current_keyspace = Some(keyspace_name.clone());
    let mut conns = self.conns.clone();
    let keyspace_name = keyspace_name.clone();
    let address = self.address;
    let fut = async move {
        // Issue the keyspace switch on every pooled connection at once.
        let mut use_keyspace_futures = Vec::new();
        for shard_conns in conns.iter_mut() {
            for conn in shard_conns.iter_mut() {
                let fut = conn.use_keyspace(&keyspace_name);
                use_keyspace_futures.push(fut);
            }
        }
        if use_keyspace_futures.is_empty() {
            return Ok(());
        }
        let use_keyspace_results: Vec<Result<(), QueryError>> =
            futures::future::join_all(use_keyspace_futures).await;
        // Tolerate per-connection IO errors as long as one connection
        // succeeded; only remember the last IO error seen.
        let mut was_ok: bool = false;
        let mut io_error: Option<Arc<std::io::Error>> = None;
        for result in use_keyspace_results {
            match result {
                Ok(()) => was_ok = true,
                Err(err) => match err {
                    QueryError::IoError(io_err) => io_error = Some(io_err),
                    // A non-IO error is treated as definitive - report it
                    // right away.
                    _ => return Err(err),
                },
            }
        }
        if was_ok {
            return Ok(());
        }
        Err(QueryError::IoError(io_error.unwrap()))
    };
    // Run the switch on a separate task so the worker loop stays
    // responsive; the requester is answered when the task finishes.
    tokio::task::spawn(async move {
        let res = fut.await;
        match &res {
            Ok(()) => debug!("[{}] Successfully changed current keyspace", address),
            Err(err) => warn!("[{}] Failed to change keyspace: {:?}", address, err),
        }
        let _ = response_sender.send(res);
    });
}
/// Wraps a freshly opened connection in a future that first switches it to
/// the pool's current keyspace and then re-emits it as a ready connection.
/// A keyspace failure is only logged; the connection is still returned.
fn start_setting_keyspace_for_connection(
    &mut self,
    connection: Connection,
    error_receiver: ErrorReceiver,
    requested_shard: Option<Shard>,
) {
    // Callers only invoke this when a current keyspace is set.
    let keyspace_name = self.current_keyspace.as_ref().cloned().unwrap();
    self.ready_connections.push(
        async move {
            let result = connection.use_keyspace(&keyspace_name).await;
            if let Err(err) = result {
                warn!(
                    "[{}] Failed to set keyspace for new connection: {}",
                    connection.get_connect_address().ip(),
                    err,
                );
            }
            // Filling in `keyspace_name` marks the connection as having
            // gone through keyspace setup, so it is not re-processed.
            OpenedConnectionEvent {
                result: Ok((connection, error_receiver)),
                requested_shard,
                keyspace_name: Some(keyspace_name),
            }
        }
        .boxed(),
    );
}
/// True when at least one shard holds a connection.
fn has_connections(&self) -> bool {
    self.conns.iter().any(|v| !v.is_empty())
}

/// Total number of pooled connections across all shards.
fn active_connection_count(&self) -> usize {
    self.conns.iter().map(Vec::len).sum::<usize>()
}

/// Upper bound on stashed excess connections before they are all dropped.
fn excess_connection_limit(&self) -> usize {
    match self.pool_config.pool_size {
        PoolSize::PerShard(_) => {
            EXCESS_CONNECTION_BOUND_PER_SHARD_MULTIPLIER
                * self
                    .sharder
                    .as_ref()
                    .map_or(1, |s| s.nr_shards.get() as usize)
        }
        // With a per-host target, no excess connections are retained.
        PoolSize::PerHost(_) => 0,
    }
}
}
/// Emitted when an established connection reports a fatal error.
struct BrokenConnectionEvent {
    // Weak so the event does not keep an already-removed connection alive.
    connection: Weak<Connection>,
    error: QueryError,
}

/// Waits for `error_receiver` to yield the connection's fatal error and
/// wraps it in a `BrokenConnectionEvent`. If the sending side is dropped
/// without an error, a generic "Connection broken" IO error is substituted.
async fn wait_for_error(
    connection: Weak<Connection>,
    error_receiver: ErrorReceiver,
) -> BrokenConnectionEvent {
    BrokenConnectionEvent {
        connection,
        error: error_receiver.await.unwrap_or_else(|_| {
            QueryError::IoError(Arc::new(std::io::Error::new(
                ErrorKind::Other,
                "Connection broken",
            )))
        }),
    }
}
/// The outcome of a single connection-opening attempt.
struct OpenedConnectionEvent {
    result: Result<(Connection, ErrorReceiver), QueryError>,
    // The shard requested via the shard-aware port; `None` when the
    // attempt went through the regular port.
    requested_shard: Option<Shard>,
    // Keyspace the connection has been switched to, if any.
    keyspace_name: Option<VerifiedKeyspaceName>,
}
/// Opens a connection to `address` (the node's shard-aware port) that
/// lands on the given `shard`, by trying source ports which the node's
/// sharding algorithm maps to that shard.
///
/// Returns the first connection error that is not "source port
/// unavailable", or an `AddrInUse` IO error when every candidate source
/// port was taken.
async fn open_connection_to_shard_aware_port(
    address: SocketAddr,
    shard: Shard,
    sharder: Sharder,
    connection_config: &ConnectionConfig,
) -> Result<(Connection, ErrorReceiver), QueryError> {
    // Candidate source ports the node will route to `shard`.
    let source_port_iter = sharder.iter_source_ports_for_shard(shard);
    for port in source_port_iter {
        let connect_result =
            connection::open_connection(address, Some(port), connection_config.clone()).await;
        match connect_result {
            // This source port is already taken - try the next candidate.
            Err(err) if err.is_address_unavailable_for_use() => continue,
            // Success, or a real error: stop probing ports either way.
            result => return result,
        }
    }
    Err(QueryError::IoError(Arc::new(std::io::Error::new(
        std::io::ErrorKind::AddrInUse,
        "Could not find free source port for shard",
    ))))
}
#[cfg(test)]
mod tests {
    use super::open_connection_to_shard_aware_port;
    use crate::routing::{ShardCount, Sharder};
    use crate::transport::connection::ConnectionConfig;
    use std::net::{SocketAddr, ToSocketAddrs};

    /// Integration test: opens many shard-aware connections concurrently
    /// against a live node (address from `SCYLLA_URI`, defaulting to
    /// 127.0.0.1:9042) and expects every attempt to succeed.
    #[tokio::test]
    async fn many_connections() {
        let connections_number = 512;
        let connect_address: SocketAddr = std::env::var("SCYLLA_URI")
            .unwrap_or_else(|_| "127.0.0.1:9042".to_string())
            .to_socket_addrs()
            .unwrap()
            .next()
            .unwrap();
        let connection_config = ConnectionConfig {
            compression: None,
            tcp_nodelay: true,
            #[cfg(feature = "ssl")]
            ssl_context: None,
            ..Default::default()
        };
        // NOTE(review): assumes the target node runs 3 shards with 12
        // ignored msb bits - confirm against the test environment.
        let sharder = Sharder::new(ShardCount::new(3).unwrap(), 12);
        let mut conns = Vec::new();
        for _ in 0..connections_number {
            conns.push(open_connection_to_shard_aware_port(
                connect_address,
                0,
                sharder.clone(),
                &connection_config,
            ));
        }
        let joined = futures::future::join_all(conns).await;
        for res in joined {
            res.unwrap();
        }
    }
}