use arc_swap::ArcSwapOption;
use cassandra_protocol::compression::Compression;
use cassandra_protocol::consistency::Consistency;
use cassandra_protocol::error;
use cassandra_protocol::events::ServerEvent;
use cassandra_protocol::frame::message_error::{ErrorType, UnpreparedError};
use cassandra_protocol::frame::message_query::BodyReqQuery;
use cassandra_protocol::frame::message_response::ResponseBody;
use cassandra_protocol::frame::message_result::{BodyResResultPrepared, TableSpec};
use cassandra_protocol::frame::{Envelope, Flags, Serialize, Version};
use cassandra_protocol::query::{PreparedQuery, QueryBatch, QueryValues};
use cassandra_protocol::types::value::Value;
use cassandra_protocol::types::{CBytesShort, CIntShort, SHORT_LEN};
use derivative::Derivative;
use futures::stream::FuturesUnordered;
use futures::{FutureExt, StreamExt};
use itertools::Itertools;
use std::io::{Cursor, Write};
use std::marker::PhantomData;
use std::net::SocketAddr;
use std::sync::{Arc, LazyLock, Mutex};
use thiserror::Error;
use tokio::sync::broadcast::{channel, Receiver, Sender};
use tokio::sync::watch;
use tokio::task::JoinHandle;
use tokio::time::sleep;
use tokio::{pin, select};
use tracing::*;
use crate::cluster::connection_manager::ConnectionManager;
use crate::cluster::connection_pool::{ConnectionPoolConfig, ConnectionPoolFactory};
use crate::cluster::control_connection::ControlConnection;
#[cfg(feature = "rust-tls")]
use crate::cluster::rustls_connection_manager::RustlsConnectionManager;
use crate::cluster::send_envelope::send_envelope;
use crate::cluster::tcp_connection_manager::TcpConnectionManager;
use crate::cluster::topology::{Node, NodeDistance, NodeState};
use crate::cluster::Murmur3Token;
#[cfg(feature = "rust-tls")]
use crate::cluster::NodeRustlsConfig;
use crate::cluster::{ClusterMetadata, ClusterMetadataManager, SessionContext};
use crate::cluster::{GenericClusterConfig, KeyspaceHolder};
use crate::cluster::{NodeTcpConfig, SessionPager};
use crate::frame_encoding::{FrameEncodingFactory, ProtocolFrameEncodingFactory};
use crate::future::BoxFuture;
use crate::load_balancing::node_distance_evaluator::AllLocalNodeDistanceEvaluator;
use crate::load_balancing::node_distance_evaluator::NodeDistanceEvaluator;
use crate::load_balancing::{
InitializingWrapperLoadBalancingStrategy, LoadBalancingStrategy, QueryPlan, Request,
};
use crate::retry::{
DefaultRetryPolicy, ExponentialReconnectionPolicy, ReconnectionPolicy, RetryPolicy,
};
use crate::speculative_execution::{Context, SpeculativeExecutionPolicy};
use crate::statement::{StatementParams, StatementParamsBuilder};
#[cfg(feature = "rust-tls")]
use crate::transport::TransportRustls;
use crate::transport::{CdrsTransport, TransportTcp};
pub const DEFAULT_TRANSPORT_BUFFER_SIZE: usize = 1024;
const DEFAULT_EVENT_CHANNEL_CAPACITY: usize = 128;
const MAX_REPREPARE_ATTEMPTS: usize = 5;
static DEFAULT_STATEMENT_PARAMETERS: LazyLock<StatementParams> = LazyLock::new(Default::default);
/// Extracts the PREPARED result from a response body, failing if the body is of another kind.
#[inline]
fn convert_to_prepared(body: ResponseBody) -> error::Result<BodyResResultPrepared> {
    match body.into_prepared() {
        Some(prepared) => Ok(prepared),
        None => Err("Cannot convert envelope into prepare response!".into()),
    }
}
/// Builds the envelope flags from the tracing, warning and beta-protocol switches.
#[inline]
fn prepare_flags(with_tracing: bool, with_warnings: bool, beta_protocol: bool) -> Flags {
    let switches = [
        (with_tracing, Flags::TRACING),
        (with_warnings, Flags::WARNING),
        (beta_protocol, Flags::BETA),
    ];

    switches
        .iter()
        .filter(|(enabled, _)| *enabled)
        .fold(Flags::empty(), |accumulated, (_, flag)| accumulated | *flag)
}
/// Creates a keyspace holder together with the receiving end of its keyspace watch channel.
///
/// The channel starts with no keyspace selected.
fn create_keyspace_holder() -> (Arc<KeyspaceHolder>, watch::Receiver<Option<String>>) {
    let (sender, receiver) = watch::channel(None);
    let holder = Arc::new(KeyspaceHolder::new(sender));
    (holder, receiver)
}
/// Rejects configurations combining Snappy compression with protocol version 5 or newer.
///
/// # Errors
/// Returns [`SessionBuildError::CompressionTypeNotSupported`] for that combination.
fn verify_compression_configuration(
    version: Version,
    compression: Compression,
) -> Result<(), SessionBuildError> {
    if version >= Version::V5 && compression == Compression::Snappy {
        return Err(SessionBuildError::CompressionTypeNotSupported);
    }

    Ok(())
}
/// Appends one component of a composite routing key at the cursor.
///
/// Layout: a short length placeholder, the serialized value, then a trailing `0` byte. After
/// the value is written, the placeholder is back-patched with the number of bytes it occupied.
fn serialize_routing_value(cursor: &mut Cursor<&mut Vec<u8>>, value: &Vec<u8>, version: Version) {
    let placeholder: CIntShort = 0;
    placeholder.serialize(cursor, version);

    let value_start = cursor.position();
    value.serialize(cursor, version);
    let value_end = cursor.position();

    // Jump back over the placeholder and overwrite it with the real length.
    let length = (value_end - value_start) as CIntShort;
    cursor.set_position(value_start - SHORT_LEN as u64);
    length.serialize(cursor, version);

    cursor.set_position(value_end);
    let _ = cursor.write(&[0]);
}
/// Serializes the routing key from the partition key columns named by `pk_indexes`.
///
/// - No indexes: `None`.
/// - One index: the raw bytes of that value, or `None` if it is missing or not set.
/// - Several indexes: a composite key built with `serialize_routing_value`, or `None` if any
///   index falls outside `values`. Components without a value are skipped.
fn serialize_routing_key_with_indexes(
    values: &[Value],
    pk_indexes: &[i16],
    version: Version,
) -> Option<Vec<u8>> {
    match pk_indexes {
        [] => None,
        [index] => match values.get(*index as usize)? {
            Value::Some(bytes) => Some(bytes.serialize_to_vec(version)),
            _ => None,
        },
        indexes => {
            let mut routing_key = Vec::new();
            let mut cursor = Cursor::new(&mut routing_key);

            for index in indexes {
                // Bail out on the first index that does not refer to a supplied value.
                if let Value::Some(bytes) = values.get(*index as usize)? {
                    serialize_routing_value(&mut cursor, bytes, version);
                }
            }

            Some(routing_key)
        }
    }
}
/// Serializes an explicit list of routing key components.
///
/// A single component is used as-is; several components form a composite key via
/// `serialize_routing_value`. Components without a value are skipped, and an empty or
/// valueless single input yields an empty key.
fn serialize_routing_key(values: &[Value], version: Version) -> Vec<u8> {
    match values {
        [] => vec![],
        [Value::Some(bytes)] => bytes.serialize_to_vec(version),
        [_] => vec![],
        _ => {
            let mut routing_key = Vec::new();
            let mut cursor = Cursor::new(&mut routing_key);

            values
                .iter()
                .filter_map(|value| match value {
                    Value::Some(bytes) => Some(bytes),
                    _ => None,
                })
                .for_each(|bytes| serialize_routing_value(&mut cursor, bytes, version));

            routing_key
        }
    }
}
/// A session against a cluster: routes requests to nodes chosen by the load balancing
/// strategy and applies the configured retry and speculative execution policies.
#[derive(Derivative)]
#[derivative(Debug)]
pub struct Session<
T: CdrsTransport + 'static,
CM: ConnectionManager<T> + 'static,
LB: LoadBalancingStrategy<T, CM> + Send + Sync,
> {
/// User load balancing strategy, wrapped together with the initial contact points.
#[derivative(Debug = "ignore")]
load_balancing: Arc<InitializingWrapperLoadBalancingStrategy<T, CM, LB>>,
/// Holds the session's current keyspace (see `current_keyspace`).
keyspace_holder: Arc<KeyspaceHolder>,
/// Session-wide retry policy; statement parameters may override it per request.
#[derivative(Debug = "ignore")]
retry_policy: Box<dyn RetryPolicy + Send + Sync>,
/// Session-wide speculative execution policy; only used for idempotent requests.
#[derivative(Debug = "ignore")]
speculative_execution_policy: Option<Box<dyn SpeculativeExecutionPolicy + Send + Sync>>,
/// Handle of the spawned control connection task; aborted when the session is dropped.
control_connection_handle: JoinHandle<()>,
/// Broadcast sender for server events; `create_event_receiver` subscribes to it.
event_sender: Sender<ServerEvent>,
/// Owner of the cluster metadata (nodes, topology) used for routing and re-preparation.
#[derivative(Debug = "ignore")]
cluster_metadata_manager: Arc<ClusterMetadataManager<T, CM>>,
/// Type marker for the transport.
#[derivative(Debug = "ignore")]
_transport: PhantomData<T>,
/// Type marker for the connection manager.
#[derivative(Debug = "ignore")]
_connection_manager: PhantomData<CM>,
/// Protocol version used when building envelopes.
version: Version,
}
impl<T, CM, LB> Drop for Session<T, CM, LB>
where
    T: CdrsTransport + 'static,
    CM: ConnectionManager<T>,
    LB: LoadBalancingStrategy<T, CM> + Send + Sync,
{
    /// Aborts the background control connection task. Dropping a tokio `JoinHandle` alone
    /// would only detach the task, not stop it.
    fn drop(&mut self) {
        let control_connection = &self.control_connection_handle;
        control_connection.abort();
    }
}
impl<
T: CdrsTransport + 'static,
CM: ConnectionManager<T> + Send + Sync + 'static,
LB: LoadBalancingStrategy<T, CM> + Send + Sync + 'static,
> Session<T, CM, LB>
{
pub fn paged(&self, page_size: i32) -> SessionPager<'_, T, CM, LB> {
SessionPager::new(self, page_size)
}
pub fn exec_with_params<'a, 'b: 'a>(
&'a self,
prepared: &'b PreparedQuery,
parameters: &'b StatementParams,
) -> BoxFuture<'a, error::Result<Envelope>> {
async move {
let consistency = parameters.query_params.consistency;
let flags = prepare_flags(
parameters.tracing,
parameters.warnings,
parameters.beta_protocol,
);
let result_metadata_id = prepared
.result_metadata_id
.load()
.as_ref()
.map(|metadata| (**metadata).clone());
let envelope = Envelope::new_req_execute(
&prepared.id,
result_metadata_id.as_ref(),
¶meters.query_params,
flags,
self.version,
);
let keyspace = prepared
.keyspace
.as_deref()
.or(parameters.keyspace.as_deref());
let routing_key =
parameters
.query_params
.values
.as_ref()
.and_then(|values| match values {
QueryValues::SimpleValues(values) => serialize_routing_key_with_indexes(
values,
&prepared.pk_indexes,
self.version,
)
.or_else(|| {
parameters
.routing_key
.as_ref()
.map(|values| serialize_routing_key(values, self.version))
}),
QueryValues::NamedValues(_) => None,
});
let mut attempts_remaining = MAX_REPREPARE_ATTEMPTS;
let result = loop {
let result = self
.send_envelope(
&envelope,
parameters.is_idempotent,
keyspace,
parameters.token,
routing_key.as_deref(),
Some(consistency),
parameters.speculative_execution_policy.as_ref(),
parameters.retry_policy.as_ref(),
)
.await;
if let Err(error::Error::Server { body: error, addr }) = &result {
if let ErrorType::Unprepared(_) = error.ty {
if attempts_remaining > 0 {
attempts_remaining -= 1;
if self
.reprepare(
&prepared.id,
prepared.query.clone(),
keyspace.map(|keyspace| keyspace.to_string()),
parameters,
*addr,
)
.await
.is_ok()
{
continue;
}
}
}
}
break result;
};
let response = result
.as_ref()
.map_err(|error| error.clone())
.and_then(|result| result.response_body());
let new_metadata_id = response.as_ref().map(|result| {
result
.as_rows_metadata()
.and_then(|metadata| metadata.new_metadata_id.as_ref())
});
if let Ok(Some(new_metadata_id)) = new_metadata_id {
prepared
.result_metadata_id
.swap(Some(Arc::new(new_metadata_id.clone())));
}
result
}
.boxed()
}
pub async fn exec_with_values<V: Into<QueryValues>>(
&self,
prepared: &PreparedQuery,
values: V,
) -> error::Result<Envelope> {
self.exec_with_params(
prepared,
&StatementParamsBuilder::new()
.with_values(values.into())
.build(),
)
.await
}
#[inline]
pub async fn exec(&self, prepared: &PreparedQuery) -> error::Result<Envelope> {
self.exec_with_params(prepared, &DEFAULT_STATEMENT_PARAMETERS)
.await
}
pub async fn prepare_raw_tw<Q: ToString>(
&self,
query: Q,
keyspace: Option<String>,
with_tracing: bool,
with_warnings: bool,
beta_protocol: bool,
) -> error::Result<BodyResResultPrepared> {
self.prepare_raw_tw_with_query_plan(
query,
keyspace,
with_tracing,
with_warnings,
beta_protocol,
None,
)
.await
}
pub async fn prepare_raw_tw_with_query_plan<Q: ToString>(
&self,
query: Q,
keyspace: Option<String>,
with_tracing: bool,
with_warnings: bool,
beta_protocol: bool,
query_plan: Option<QueryPlan<T, CM>>,
) -> error::Result<BodyResResultPrepared> {
let flags = prepare_flags(with_tracing, with_warnings, beta_protocol);
let envelope = Envelope::new_req_prepare(query.to_string(), keyspace, flags, self.version);
let response = match query_plan {
None => {
self.send_envelope(&envelope, true, None, None, None, None, None, None)
.await
}
Some(query_plan) => send_envelope(
query_plan.nodes.into_iter(),
&envelope,
true,
self.retry_policy.as_ref().new_session(),
)
.await
.unwrap_or_else(|| Err("No response for prepare!".into())),
};
response
.and_then(|response| response.response_body())
.and_then(convert_to_prepared)
}
#[inline]
pub async fn prepare_raw<Q: ToString>(&self, query: Q) -> error::Result<BodyResResultPrepared> {
self.prepare_raw_tw(query, None, false, false, false).await
}
pub async fn prepare_tw<Q: ToString>(
&self,
query: Q,
keyspace: Option<String>,
with_tracing: bool,
with_warnings: bool,
beta_protocol: bool,
) -> error::Result<PreparedQuery> {
let s = query.to_string();
self.prepare_raw_tw(query, keyspace, with_tracing, with_warnings, beta_protocol)
.await
.map(|result| PreparedQuery {
id: result.id,
query: s,
keyspace: result
.metadata
.global_table_spec
.map(|TableSpec { ks_name, .. }| ks_name),
pk_indexes: result.metadata.pk_indexes,
result_metadata_id: ArcSwapOption::new(result.result_metadata_id.map(Arc::new)),
})
}
#[inline]
pub async fn prepare<Q: ToString>(&self, query: Q) -> error::Result<PreparedQuery> {
self.prepare_tw(query, None, false, false, false).await
}
#[inline]
pub async fn batch(&self, batch: QueryBatch) -> error::Result<Envelope> {
self.batch_with_params(batch, &DEFAULT_STATEMENT_PARAMETERS)
.await
}
pub fn batch_with_params<'a, 'b: 'a>(
&'a self,
batch: QueryBatch,
parameters: &'b StatementParams,
) -> BoxFuture<'a, error::Result<Envelope>> {
async move {
let flags = prepare_flags(
parameters.tracing,
parameters.warnings,
parameters.beta_protocol,
);
let consistency = batch.request.consistency;
let envelope = Envelope::new_req_batch(batch.request.clone(), flags, self.version);
let mut attempts_remaining = MAX_REPREPARE_ATTEMPTS;
loop {
let result = self
.send_envelope(
&envelope,
parameters.is_idempotent,
parameters.keyspace.as_deref(),
None,
None,
Some(consistency),
parameters.speculative_execution_policy.as_ref(),
parameters.retry_policy.as_ref(),
)
.await;
if let Err(error::Error::Server { body: error, addr }) = &result {
if let ErrorType::Unprepared(UnpreparedError { id }) = &error.ty {
if attempts_remaining == 0 {
return result;
}
let query = match batch.prepared_queries.get(id) {
None => {
warn!(
?id,
"Cannot find prepared query for unprepared statement in a batch!"
);
return result;
}
Some(query) => query,
};
attempts_remaining -= 1;
let prepare_result = self
.reprepare(
id,
query.query.clone(),
query.keyspace.clone(),
parameters,
*addr,
)
.await;
if prepare_result.is_ok() {
continue;
}
}
}
return result;
}
}
.boxed()
}
async fn reprepare(
&self,
id: &CBytesShort,
query: String,
keyspace: Option<String>,
parameters: &StatementParams,
node_broadcast_rpc_address: SocketAddr,
) -> error::Result<()> {
debug!("Re-preparing statement.");
let flags = prepare_flags(
parameters.tracing,
parameters.warnings,
parameters.beta_protocol,
);
let node = self
.cluster_metadata_manager
.find_node_by_rpc_address(node_broadcast_rpc_address)
.ok_or_else(|| {
error::Error::from(format!(
"Cannot find node {node_broadcast_rpc_address} for statement re-preparation!"
))
})?;
let prepare_envelope = Envelope::new_req_prepare(query, keyspace, flags, self.version);
let retry_policy = self.effective_retry_policy(parameters.retry_policy.as_ref());
let prepare_result = send_envelope(
[node].iter().cloned(),
&prepare_envelope,
true,
retry_policy.new_session(),
)
.await
.unwrap_or_else(|| Err("No response for re-prepare statement!".into()))
.and_then(|response| response.response_body())
.and_then(convert_to_prepared)?;
if id != &prepare_result.id {
return Err("Re-preparing an unprepared statement resulted in a different id - probably schema changed on the server.".into());
}
Ok(())
}
#[inline]
pub async fn query<Q: ToString>(&self, query: Q) -> error::Result<Envelope> {
self.query_with_params(query, DEFAULT_STATEMENT_PARAMETERS.clone())
.await
}
#[inline]
pub async fn query_with_values<Q: ToString, V: Into<QueryValues>>(
&self,
query: Q,
values: V,
) -> error::Result<Envelope> {
self.query_with_params(
query,
StatementParamsBuilder::new()
.with_values(values.into())
.build(),
)
.await
}
pub async fn query_with_params<Q: ToString>(
&self,
query: Q,
parameters: StatementParams,
) -> error::Result<Envelope> {
let is_idempotent = parameters.is_idempotent;
let consistency = parameters.query_params.consistency;
let keyspace = parameters.keyspace;
let token = parameters.token;
let routing_key = parameters
.routing_key
.as_ref()
.map(|values| serialize_routing_key(values, self.version));
let query = BodyReqQuery {
query: query.to_string(),
query_params: parameters.query_params,
};
let flags = prepare_flags(
parameters.tracing,
parameters.warnings,
parameters.beta_protocol,
);
let envelope = Envelope::new_query(query, flags, self.version);
self.send_envelope(
&envelope,
is_idempotent,
keyspace.as_deref(),
token,
routing_key.as_deref(),
Some(consistency),
parameters.speculative_execution_policy.as_ref(),
parameters.retry_policy.as_ref(),
)
.await
}
#[inline]
pub fn current_keyspace(&self) -> Option<Arc<String>> {
self.keyspace_holder.current_keyspace()
}
#[inline]
pub fn cluster_metadata(&self) -> Arc<ClusterMetadata<T, CM>> {
self.cluster_metadata_manager.metadata()
}
#[inline]
pub fn query_plan(&self, request: Option<Request>) -> QueryPlan<T, CM> {
self.load_balancing
.query_plan(request, self.cluster_metadata().as_ref())
}
#[inline]
pub fn create_event_receiver(&self) -> Receiver<ServerEvent> {
self.event_sender.subscribe()
}
#[inline]
pub fn retry_policy(&self) -> &dyn RetryPolicy {
self.retry_policy.as_ref()
}
#[allow(clippy::too_many_arguments)]
async fn send_envelope(
&self,
envelope: &Envelope,
is_idempotent: bool,
keyspace: Option<&str>,
token: Option<Murmur3Token>,
routing_key: Option<&[u8]>,
consistency: Option<Consistency>,
speculative_execution_policy: Option<&Arc<dyn SpeculativeExecutionPolicy + Send + Sync>>,
retry_policy: Option<&Arc<dyn RetryPolicy + Send + Sync>>,
) -> error::Result<Envelope> {
let current_keyspace = self.current_keyspace();
let request = Request::new(
keyspace.or_else(|| current_keyspace.as_ref().map(|keyspace| &***keyspace)),
token,
routing_key,
consistency,
);
let query_plan = self.query_plan(Some(request));
struct SharedQueryPlan<
T: CdrsTransport + 'static,
CM: ConnectionManager<T> + 'static,
I: Iterator<Item = Arc<Node<T, CM>>>,
> {
current_node: Mutex<I>,
}
impl<
T: CdrsTransport + 'static,
CM: ConnectionManager<T> + 'static,
I: Iterator<Item = Arc<Node<T, CM>>>,
> SharedQueryPlan<T, CM, I>
{
fn new(current_node: I) -> Self {
SharedQueryPlan {
current_node: Mutex::new(current_node),
}
}
}
impl<
T: CdrsTransport + 'static,
CM: ConnectionManager<T> + 'static,
I: Iterator<Item = Arc<Node<T, CM>>>,
> Iterator for &SharedQueryPlan<T, CM, I>
{
type Item = Arc<Node<T, CM>>;
fn next(&mut self) -> Option<Self::Item> {
self.current_node.lock().unwrap().next()
}
}
let speculative_execution_policy = speculative_execution_policy
.map(|speculative_execution_policy| speculative_execution_policy.as_ref())
.or(self.speculative_execution_policy.as_deref());
let retry_policy = self.effective_retry_policy(retry_policy);
match speculative_execution_policy {
Some(speculative_execution_policy) if is_idempotent => {
let shared_query_plan = SharedQueryPlan::new(query_plan.nodes.into_iter());
let mut context = Context::new(1);
let mut async_tasks = FuturesUnordered::new();
async_tasks.push(send_envelope(
&shared_query_plan,
envelope,
is_idempotent,
retry_policy.new_session(),
));
let sleep_fut = sleep(
speculative_execution_policy
.execution_interval(&context)
.unwrap_or_default(),
)
.fuse();
pin!(sleep_fut);
let mut last_error = None;
loop {
select! {
_ = &mut sleep_fut => {
if let Some(interval) =
speculative_execution_policy.execution_interval(&context)
{
context.running_executions += 1;
async_tasks.push(send_envelope(
&shared_query_plan,
envelope,
is_idempotent,
retry_policy.new_session(),
));
sleep_fut.set(sleep(interval).fuse());
}
}
result = async_tasks.select_next_some() => {
match result {
Some(result) => {
match result {
Err(error::Error::Io(_)) | Err(error::Error::Timeout(_)) => {
last_error = Some(result);
},
_ => return result,
}
}
None => {
if async_tasks.is_empty() {
return last_error.unwrap_or_else(|| Err("No nodes available in query plan!".into()));
}
}
}
}
}
}
}
_ => send_envelope(
query_plan.nodes.into_iter(),
envelope,
is_idempotent,
retry_policy.new_session(),
)
.await
.unwrap_or_else(|| Err("No nodes available in query plan!".into())),
}
}
#[inline]
fn effective_retry_policy<'a, 'b: 'a>(
&'a self,
retry_policy: Option<&'b Arc<dyn RetryPolicy + Send + Sync>>,
) -> &'a (dyn RetryPolicy + Send + Sync) {
retry_policy
.map(|retry_policy| retry_policy.as_ref())
.unwrap_or_else(|| self.retry_policy.as_ref())
}
#[allow(clippy::too_many_arguments)]
async fn new(
load_balancing: LB,
keyspace_holder: Arc<KeyspaceHolder>,
keyspace_receiver: watch::Receiver<Option<String>>,
retry_policy: Box<dyn RetryPolicy + Send + Sync>,
reconnection_policy: Arc<dyn ReconnectionPolicy + Send + Sync>,
node_distance_evaluator: Box<dyn NodeDistanceEvaluator + Send + Sync>,
speculative_execution_policy: Option<Box<dyn SpeculativeExecutionPolicy + Send + Sync>>,
contact_points: Vec<SocketAddr>,
connection_manager: CM,
event_channel_capacity: usize,
version: Version,
connection_pool_config: ConnectionPoolConfig,
beta_protocol: bool,
) -> Result<Self, SessionBuildError> {
let connection_pool_factory = Arc::new(ConnectionPoolFactory::new(
connection_pool_config,
version,
connection_manager,
keyspace_receiver,
reconnection_policy.clone(),
));
let contact_points = contact_points
.into_iter()
.map(|contact_point| {
Arc::new(Node::new_with_state(
connection_pool_factory.clone(),
contact_point,
None,
None,
Some(NodeDistance::Local),
NodeState::Up,
Default::default(),
"".into(),
"".into(),
))
})
.collect_vec();
let load_balancing = Arc::new(InitializingWrapperLoadBalancingStrategy::new(
load_balancing,
contact_points.clone(),
));
let (event_sender, event_receiver) = channel(event_channel_capacity);
let session_context = Arc::new(SessionContext::default());
let cluster_metadata_manager = Arc::new(ClusterMetadataManager::new(
contact_points.clone(),
connection_pool_factory,
session_context.clone(),
node_distance_evaluator,
version,
beta_protocol,
));
cluster_metadata_manager.listen_to_events(event_receiver);
let control_connection = ControlConnection::new(
load_balancing.clone(),
contact_points,
reconnection_policy.clone(),
cluster_metadata_manager.clone(),
event_sender.clone(),
session_context,
version,
);
let (init_complete_sender, init_complete_receiver) = tokio::sync::oneshot::channel();
let control_connection_handle =
AbortOnDropHandle::new(tokio::spawn(control_connection.run(init_complete_sender)));
if init_complete_receiver.await.is_err() {
return Err(SessionBuildError::SessionInitFailed);
}
Ok(Session {
load_balancing,
keyspace_holder,
retry_policy,
speculative_execution_policy,
control_connection_handle: control_connection_handle.into_inner(),
event_sender,
cluster_metadata_manager,
_transport: Default::default(),
_connection_manager: Default::default(),
version,
})
}
}
/// Newtype around a boxed [`RetryPolicy`], accepted by [`connect_generic`].
#[repr(transparent)]
pub struct RetryPolicyWrapper(pub Box<dyn RetryPolicy + Send + Sync>);
/// Newtype around a shared [`ReconnectionPolicy`], accepted by [`connect_generic`].
#[repr(transparent)]
pub struct ReconnectionPolicyWrapper(pub Arc<dyn ReconnectionPolicy + Send + Sync>);
/// Newtype around a boxed [`NodeDistanceEvaluator`], accepted by [`connect_generic`].
#[repr(transparent)]
pub struct NodeDistanceEvaluatorWrapper(pub Box<dyn NodeDistanceEvaluator + Send + Sync>);
/// Newtype around a boxed [`SpeculativeExecutionPolicy`], accepted by [`connect_generic`].
#[repr(transparent)]
pub struct SpeculativeExecutionPolicyWrapper(pub Box<dyn SpeculativeExecutionPolicy + Send + Sync>);
pub async fn connect_generic<T, C, A, CM, LB>(
config: &C,
initial_nodes: A,
load_balancing: LB,
retry_policy: RetryPolicyWrapper,
reconnection_policy: ReconnectionPolicyWrapper,
node_distance_evaluator: NodeDistanceEvaluatorWrapper,
speculative_execution_policy: Option<SpeculativeExecutionPolicyWrapper>,
) -> error::Result<Session<T, CM, LB>>
where
A: IntoIterator<Item = SocketAddr>,
T: CdrsTransport + 'static,
CM: ConnectionManager<T> + Send + Sync + 'static,
C: GenericClusterConfig<T, CM>,
LB: LoadBalancingStrategy<T, CM> + Sized + Send + Sync + 'static,
{
let (keyspace_holder, keyspace_receiver) = create_keyspace_holder();
let connection_manager = config.create_manager(keyspace_holder.clone()).await?;
Session::new(
load_balancing,
keyspace_holder,
keyspace_receiver,
retry_policy.0,
reconnection_policy.0,
node_distance_evaluator.0,
speculative_execution_policy.map(|policy| policy.0),
initial_nodes.into_iter().collect(),
connection_manager,
config.event_channel_capacity(),
config.version(),
config.connection_pool_config(),
config.beta_protocol(),
)
.await
.map_err(|e| error::Error::General(e.to_string()))
}
/// Settings collected by the session builders; converted into a [`Session`] by `into_session`.
struct SessionConfig<
T: CdrsTransport,
CM: ConnectionManager<T>,
LB: LoadBalancingStrategy<T, CM> + Send + Sync,
> {
/// Compression passed to the connection manager.
compression: Compression,
/// Transport buffer size passed to the connection manager.
transport_buffer_size: usize,
/// TCP_NODELAY setting passed to the connection manager.
tcp_nodelay: bool,
/// User load balancing strategy.
load_balancing: LB,
/// Session-wide retry policy.
retry_policy: Box<dyn RetryPolicy + Send + Sync>,
/// Reconnection policy shared by connection pools and the control connection.
reconnection_policy: Arc<dyn ReconnectionPolicy + Send + Sync>,
/// Evaluator assigning distances to discovered nodes.
node_distance_evaluator: Box<dyn NodeDistanceEvaluator + Send + Sync>,
/// Optional session-wide speculative execution policy.
speculative_execution_policy: Option<Box<dyn SpeculativeExecutionPolicy + Send + Sync>>,
/// Capacity of the server event broadcast channel.
event_channel_capacity: usize,
/// Connection pool configuration.
connection_pool_config: ConnectionPoolConfig,
/// Initial keyspace seeded into the keyspace holder, if any.
keyspace: Option<String>,
/// Type marker for the connection manager.
_connection_manager: PhantomData<CM>,
/// Type marker for the transport.
_transport: PhantomData<T>,
}
impl<
T: CdrsTransport + 'static,
CM: ConnectionManager<T> + 'static,
LB: LoadBalancingStrategy<T, CM> + Send + Sync + 'static,
> SessionConfig<T, CM, LB>
{
fn new(load_balancing: LB) -> Self {
SessionConfig {
compression: Compression::None,
transport_buffer_size: DEFAULT_TRANSPORT_BUFFER_SIZE,
tcp_nodelay: true,
load_balancing,
retry_policy: Box::<DefaultRetryPolicy>::default(),
reconnection_policy: Arc::new(ExponentialReconnectionPolicy::default()),
node_distance_evaluator: Box::<AllLocalNodeDistanceEvaluator>::default(),
speculative_execution_policy: None,
event_channel_capacity: DEFAULT_EVENT_CHANNEL_CAPACITY,
connection_pool_config: Default::default(),
keyspace: None,
_connection_manager: Default::default(),
_transport: Default::default(),
}
}
async fn into_session(
self,
keyspace_holder: Arc<KeyspaceHolder>,
keyspace_receiver: watch::Receiver<Option<String>>,
contact_points: Vec<SocketAddr>,
connection_manager: CM,
version: Version,
beta_protocol: bool,
) -> Result<Session<T, CM, LB>, SessionBuildError> {
if let Some(keyspace) = self.keyspace {
keyspace_holder.update_current_keyspace_without_notification(keyspace);
}
Session::new(
self.load_balancing,
keyspace_holder,
keyspace_receiver,
self.retry_policy,
self.reconnection_policy,
self.node_distance_evaluator,
self.speculative_execution_policy,
contact_points,
connection_manager,
self.event_channel_capacity,
version,
self.connection_pool_config,
beta_protocol,
)
.await
}
}
/// Errors that can occur while building a [`Session`].
#[derive(Error, Debug, Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)]
pub enum SessionBuildError {
/// The configured compression cannot be used with the selected protocol version
/// (Snappy with protocol V5 or newer; see `verify_compression_configuration`).
#[error("Given compression type is not supported for selected protocol!")]
CompressionTypeNotSupported,
/// The control connection dropped its initialization signal before reporting completion.
#[error("Session control connection died before completing initialization")]
SessionInitFailed,
}
/// Builder interface shared by the transport-specific session builders.
pub trait SessionBuilder<
T: CdrsTransport + 'static,
CM: ConnectionManager<T>,
LB: LoadBalancingStrategy<T, CM> + Send + Sync + 'static,
>
{
/// Sets the compression used by connections. Snappy is rejected by `build` for protocol V5+.
#[must_use]
fn with_compression(self, compression: Compression) -> Self;
/// Sets the session-wide retry policy.
#[must_use]
fn with_retry_policy(self, retry_policy: Box<dyn RetryPolicy + Send + Sync>) -> Self;
/// Sets the reconnection policy.
#[must_use]
fn with_reconnection_policy(
self,
reconnection_policy: Arc<dyn ReconnectionPolicy + Send + Sync>,
) -> Self;
/// Sets the factory for frame encoders used by connections.
#[must_use]
fn with_frame_encoder_factory(
self,
frame_encoder_factory: Box<dyn FrameEncodingFactory + Send + Sync>,
) -> Self;
/// Sets the evaluator that assigns distances to nodes.
#[must_use]
fn with_node_distance_evaluator(
self,
node_distance_evaluator: Box<dyn NodeDistanceEvaluator + Send + Sync>,
) -> Self;
/// Sets the session-wide speculative execution policy.
#[must_use]
fn with_speculative_execution_policy(
self,
speculative_execution_policy: Box<dyn SpeculativeExecutionPolicy + Send + Sync>,
) -> Self;
/// Sets the transport buffer size passed to the connection manager.
#[must_use]
fn with_transport_buffer_size(self, transport_buffer_size: usize) -> Self;
/// Sets the TCP_NODELAY option passed to the connection manager.
#[must_use]
fn with_tcp_nodelay(self, tcp_nodelay: bool) -> Self;
/// Sets the capacity of the server event broadcast channel.
#[must_use]
fn with_event_channel_capacity(self, event_channel_capacity: usize) -> Self;
/// Sets the connection pool configuration.
#[must_use]
fn with_connection_pool_config(self, connection_pool_config: ConnectionPoolConfig) -> Self;
/// Sets the keyspace the session starts with.
#[must_use]
fn with_keyspace(self, keyspace: String) -> Self;
/// Enables or disables the beta protocol flag.
#[must_use]
fn with_beta_protocol(self, beta_protocol: bool) -> Self;
/// Validates the configuration and builds the session.
fn build(self) -> BoxFuture<'static, Result<Session<T, CM, LB>, SessionBuildError>>;
}
/// Builder for sessions that connect over plain TCP.
pub struct TcpSessionBuilder<
LB: LoadBalancingStrategy<TransportTcp, TcpConnectionManager> + Send + Sync,
> {
/// Transport-independent session settings.
config: SessionConfig<TransportTcp, TcpConnectionManager, LB>,
/// TCP node settings (contact points, version, authentication, ...).
node_config: NodeTcpConfig,
/// Factory for frame encoders handed to the connection manager.
frame_encoder_factory: Box<dyn FrameEncodingFactory + Send + Sync>,
}
impl<LB: LoadBalancingStrategy<TransportTcp, TcpConnectionManager> + Send + Sync + 'static>
    TcpSessionBuilder<LB>
{
    /// Creates a TCP session builder with default settings and the protocol frame encoder.
    pub fn new(load_balancing: LB, node_config: NodeTcpConfig) -> Self {
        let config = SessionConfig::new(load_balancing);
        Self {
            config,
            node_config,
            frame_encoder_factory: Box::<ProtocolFrameEncodingFactory>::default(),
        }
    }
}
impl<LB: LoadBalancingStrategy<TransportTcp, TcpConnectionManager> + Send + Sync + 'static>
SessionBuilder<TransportTcp, TcpConnectionManager, LB> for TcpSessionBuilder<LB>
{
fn with_compression(mut self, compression: Compression) -> Self {
self.config.compression = compression;
self
}
fn with_retry_policy(mut self, retry_policy: Box<dyn RetryPolicy + Send + Sync>) -> Self {
self.config.retry_policy = retry_policy;
self
}
fn with_reconnection_policy(
mut self,
reconnection_policy: Arc<dyn ReconnectionPolicy + Send + Sync>,
) -> Self {
self.config.reconnection_policy = reconnection_policy;
self
}
fn with_frame_encoder_factory(
mut self,
frame_encoder_factory: Box<dyn FrameEncodingFactory + Send + Sync>,
) -> Self {
self.frame_encoder_factory = frame_encoder_factory;
self
}
fn with_node_distance_evaluator(
mut self,
node_distance_evaluator: Box<dyn NodeDistanceEvaluator + Send + Sync>,
) -> Self {
self.config.node_distance_evaluator = node_distance_evaluator;
self
}
fn with_speculative_execution_policy(
mut self,
speculative_execution_policy: Box<dyn SpeculativeExecutionPolicy + Send + Sync>,
) -> Self {
self.config.speculative_execution_policy = Some(speculative_execution_policy);
self
}
fn with_transport_buffer_size(mut self, transport_buffer_size: usize) -> Self {
self.config.transport_buffer_size = transport_buffer_size;
self
}
fn with_tcp_nodelay(mut self, tcp_nodelay: bool) -> Self {
self.config.tcp_nodelay = tcp_nodelay;
self
}
fn with_event_channel_capacity(mut self, event_channel_capacity: usize) -> Self {
self.config.event_channel_capacity = event_channel_capacity;
self
}
fn with_connection_pool_config(mut self, connection_pool_config: ConnectionPoolConfig) -> Self {
self.config.connection_pool_config = connection_pool_config;
self
}
fn with_keyspace(mut self, keyspace: String) -> Self {
self.config.keyspace = Some(keyspace);
self
}
fn with_beta_protocol(mut self, beta_protocol: bool) -> Self {
self.node_config.beta_protocol = beta_protocol;
self
}
fn build(
self,
) -> BoxFuture<
'static,
Result<Session<TransportTcp, TcpConnectionManager, LB>, SessionBuildError>,
> {
async move {
match verify_compression_configuration(
self.node_config.version,
self.config.compression,
) {
Ok(()) => {
let (keyspace_holder, keyspace_receiver) = create_keyspace_holder();
let connection_manager = TcpConnectionManager::new(
self.node_config.authenticator_provider,
keyspace_holder.clone(),
self.frame_encoder_factory,
self.config.compression,
self.config.transport_buffer_size,
self.config.tcp_nodelay,
self.node_config.version,
#[cfg(feature = "http-proxy")]
self.node_config.http_proxy,
);
self.config
.into_session(
keyspace_holder,
keyspace_receiver,
self.node_config.contact_points,
connection_manager,
self.node_config.version,
self.node_config.beta_protocol,
)
.await
}
Err(err) => Err(err),
}
}
.boxed()
}
}
#[cfg(feature = "rust-tls")]
/// Builder for sessions that connect over TLS using rustls.
pub struct RustlsSessionBuilder<
LB: LoadBalancingStrategy<TransportRustls, RustlsConnectionManager> + Send + Sync + 'static,
> {
/// Transport-independent session settings.
config: SessionConfig<TransportRustls, RustlsConnectionManager, LB>,
/// TLS node settings (contact points, DNS name, rustls config, version, ...).
node_config: NodeRustlsConfig,
/// Factory for frame encoders handed to the connection manager.
frame_encoder_factory: Box<dyn FrameEncodingFactory + Send + Sync>,
}
#[cfg(feature = "rust-tls")]
impl<LB: LoadBalancingStrategy<TransportRustls, RustlsConnectionManager> + Send + Sync + 'static>
    RustlsSessionBuilder<LB>
{
    /// Creates a rustls session builder with default settings and the protocol frame encoder.
    ///
    /// `LB: 'static` matches the bound on [`RustlsSessionBuilder`] itself and the one required
    /// by `SessionConfig::new`, mirroring the TCP builder.
    pub fn new(load_balancing: LB, node_config: NodeRustlsConfig) -> Self {
        RustlsSessionBuilder {
            config: SessionConfig::new(load_balancing),
            node_config,
            frame_encoder_factory: Box::<ProtocolFrameEncodingFactory>::default(),
        }
    }
}
#[cfg(feature = "rust-tls")]
impl<
        LB: LoadBalancingStrategy<TransportRustls, RustlsConnectionManager> + Send + Sync + 'static,
    > SessionBuilder<TransportRustls, RustlsConnectionManager, LB> for RustlsSessionBuilder<LB>
{
    fn with_compression(self, compression: Compression) -> Self {
        Self {
            config: SessionConfig {
                compression,
                ..self.config
            },
            ..self
        }
    }

    fn with_retry_policy(self, retry_policy: Box<dyn RetryPolicy + Send + Sync>) -> Self {
        Self {
            config: SessionConfig {
                retry_policy,
                ..self.config
            },
            ..self
        }
    }

    fn with_reconnection_policy(
        self,
        reconnection_policy: Arc<dyn ReconnectionPolicy + Send + Sync>,
    ) -> Self {
        Self {
            config: SessionConfig {
                reconnection_policy,
                ..self.config
            },
            ..self
        }
    }

    fn with_frame_encoder_factory(
        self,
        frame_encoder_factory: Box<dyn FrameEncodingFactory + Send + Sync>,
    ) -> Self {
        Self {
            frame_encoder_factory,
            ..self
        }
    }

    fn with_node_distance_evaluator(
        self,
        node_distance_evaluator: Box<dyn NodeDistanceEvaluator + Send + Sync>,
    ) -> Self {
        Self {
            config: SessionConfig {
                node_distance_evaluator,
                ..self.config
            },
            ..self
        }
    }

    fn with_speculative_execution_policy(
        self,
        speculative_execution_policy: Box<dyn SpeculativeExecutionPolicy + Send + Sync>,
    ) -> Self {
        Self {
            config: SessionConfig {
                speculative_execution_policy: Some(speculative_execution_policy),
                ..self.config
            },
            ..self
        }
    }

    fn with_transport_buffer_size(self, transport_buffer_size: usize) -> Self {
        Self {
            config: SessionConfig {
                transport_buffer_size,
                ..self.config
            },
            ..self
        }
    }

    fn with_tcp_nodelay(self, tcp_nodelay: bool) -> Self {
        Self {
            config: SessionConfig {
                tcp_nodelay,
                ..self.config
            },
            ..self
        }
    }

    fn with_event_channel_capacity(self, event_channel_capacity: usize) -> Self {
        Self {
            config: SessionConfig {
                event_channel_capacity,
                ..self.config
            },
            ..self
        }
    }

    fn with_connection_pool_config(self, connection_pool_config: ConnectionPoolConfig) -> Self {
        Self {
            config: SessionConfig {
                connection_pool_config,
                ..self.config
            },
            ..self
        }
    }

    fn with_keyspace(self, keyspace: String) -> Self {
        Self {
            config: SessionConfig {
                keyspace: Some(keyspace),
                ..self.config
            },
            ..self
        }
    }

    fn with_beta_protocol(mut self, beta_protocol: bool) -> Self {
        // The flag lives on the node config, whose definition is not part of this module.
        self.node_config.beta_protocol = beta_protocol;
        self
    }

    /// Validates the compression/version combination, creates the rustls connection manager
    /// and builds the session.
    fn build(
        self,
    ) -> BoxFuture<
        'static,
        Result<Session<TransportRustls, RustlsConnectionManager, LB>, SessionBuildError>,
    > {
        async move {
            let RustlsSessionBuilder {
                config: session_config,
                node_config,
                frame_encoder_factory,
            } = self;

            verify_compression_configuration(node_config.version, session_config.compression)?;

            let (keyspace_holder, keyspace_receiver) = create_keyspace_holder();
            let connection_manager = RustlsConnectionManager::new(
                node_config.dns_name,
                node_config.authenticator_provider,
                node_config.config,
                keyspace_holder.clone(),
                frame_encoder_factory,
                session_config.compression,
                session_config.transport_buffer_size,
                session_config.tcp_nodelay,
                node_config.version,
                #[cfg(feature = "http-proxy")]
                node_config.http_proxy,
            );

            session_config
                .into_session(
                    keyspace_holder,
                    keyspace_receiver,
                    node_config.contact_points,
                    connection_manager,
                    node_config.version,
                    node_config.beta_protocol,
                )
                .await
        }
        .boxed()
    }
}
/// Guard owning a spawned task's [`JoinHandle`].
///
/// When dropped, it aborts the task, unless the handle was released first through
/// `into_inner`.
struct AbortOnDropHandle(Option<JoinHandle<()>>);

impl AbortOnDropHandle {
    /// Takes ownership of `handle`; the task is aborted if this guard is dropped.
    fn new(handle: JoinHandle<()>) -> Self {
        AbortOnDropHandle(Some(handle))
    }

    /// Releases the handle so the task keeps running once the guard is gone.
    fn into_inner(mut self) -> JoinHandle<()> {
        // The inner handle is only ever taken here, and `self` is consumed, so it is present.
        let released = self.0.take();
        released.expect("AbortOnDropHandle inner cannot be None")
    }
}

impl Drop for AbortOnDropHandle {
    fn drop(&mut self) {
        // Empty after `into_inner`, in which case nothing is aborted.
        if let Some(handle) = self.0.as_ref() {
            handle.abort();
        }
    }
}
// Unit tests for envelope flag construction and the `AbortOnDropHandle` guard.
#[cfg(test)]
mod tests {
use crate::cluster::session::{prepare_flags, AbortOnDropHandle};
use cassandra_protocol::frame::Flags;
use tokio::task::JoinHandle;
// Each switch maps to its own flag, and all three can be combined.
#[test]
fn prepare_flags_test() {
assert!(prepare_flags(true, false, false).contains(Flags::TRACING));
assert!(prepare_flags(false, true, false).contains(Flags::WARNING));
assert!(prepare_flags(false, false, true).contains(Flags::BETA));
let all = prepare_flags(true, true, true);
assert!(all.contains(Flags::TRACING));
assert!(all.contains(Flags::WARNING));
assert!(all.contains(Flags::BETA));
}
// Dropping the guard must abort the never-ending task. The yield and short sleep give the
// runtime a chance to process the abort before it is observed.
#[tokio::test]
async fn abort_on_drop_handle_aborts_when_dropped() {
let task: JoinHandle<()> = tokio::spawn(async {
std::future::pending::<()>().await;
});
let abort_observer = task.abort_handle();
{
let _guard = AbortOnDropHandle::new(task);
}
tokio::task::yield_now().await;
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
assert!(
abort_observer.is_finished(),
"task must be aborted after AbortOnDropHandle drops"
);
}
// Releasing the handle via `into_inner` must leave the task running; the test aborts it
// itself at the end to clean up.
#[tokio::test]
async fn abort_on_drop_handle_does_not_abort_when_released() {
let task: JoinHandle<()> = tokio::spawn(async {
std::future::pending::<()>().await;
});
let abort_observer = task.abort_handle();
let guard = AbortOnDropHandle::new(task);
let released = guard.into_inner();
tokio::task::yield_now().await;
assert!(
!abort_observer.is_finished(),
"task must keep running after into_inner"
);
released.abort();
}
}