use super::execution_profile::ExecutionProfileHandle;
use super::session::{Session, SessionConfig};
use super::{Compression, PoolSize, SelfIdentity, WriteCoalescingDelay};
use crate::authentication::{AuthenticatorProvider, PlainTextAuthenticator};
use crate::client::session::TlsContext;
use crate::errors::NewSessionError;
use crate::policies::address_translator::AddressTranslator;
use crate::policies::host_filter::HostFilter;
use crate::policies::timestamp_generator::TimestampGenerator;
use crate::routing::ShardAwarePortRange;
use crate::statement::Consistency;
use std::borrow::Borrow;
use std::marker::PhantomData;
use std::net::{IpAddr, SocketAddr};
use std::num::NonZeroU32;
use std::sync::Arc;
use std::time::Duration;
use tracing::warn;
/// Private home of the `Sealed` supertrait, which keeps crates outside this
/// one from implementing [`SessionBuilderKind`].
mod sealed {
    // `Sealed` must be `pub` to appear in a public supertrait bound, yet it
    // lives in a private module, so it cannot be named from outside.
    #[expect(unnameable_types)]
    pub trait Sealed {}
}

/// Marker trait for the mode type parameter of [`GenericSessionBuilder`].
///
/// The trait is sealed: only types defined in this crate can implement it.
pub trait SessionBuilderKind: Clone + sealed::Sealed {}

/// Mode marker used by the ordinary [`SessionBuilder`].
///
/// The enum has no variants; it only exists at the type level.
#[derive(Clone)]
pub enum DefaultMode {}

impl SessionBuilderKind for DefaultMode {}

impl sealed::Sealed for DefaultMode {}

/// Session builder operating in [`DefaultMode`].
pub type SessionBuilder = GenericSessionBuilder<DefaultMode>;
/// Mode marker for builders that use a client routes configuration.
///
/// Like [`DefaultMode`], this enum has no variants and only exists at the
/// type level.
#[cfg(feature = "unstable-client-routes")]
#[derive(Clone)]
pub enum ClientRoutesMode {}

#[cfg(feature = "unstable-client-routes")]
impl SessionBuilderKind for ClientRoutesMode {}

#[cfg(feature = "unstable-client-routes")]
impl sealed::Sealed for ClientRoutesMode {}

/// Session builder operating in client-routes mode.
#[cfg(feature = "unstable-client-routes")]
pub type ClientRoutesSessionBuilder = GenericSessionBuilder<ClientRoutesMode>;
/// Builder for a [`Session`], parameterized by a mode marker `Kind`.
///
/// `Kind` (see [`SessionBuilderKind`]) selects which configuration methods
/// are available on the builder; it carries no runtime data.
#[derive(Clone)]
pub struct GenericSessionBuilder<Kind>
where
    Kind: SessionBuilderKind,
{
    /// Configuration handed to `Session::connect` when the session is built.
    pub config: SessionConfig,
    /// Type-level record of the builder mode; stores nothing.
    kind: PhantomData<Kind>,
}
impl GenericSessionBuilder<DefaultMode> {
    /// Creates a default-mode builder starting from `SessionConfig::new()`.
    ///
    /// A freshly created builder has no known nodes and no compression set.
    pub fn new() -> Self {
        let config = SessionConfig::new();
        Self {
            config,
            kind: PhantomData,
        }
    }
}
#[cfg(feature = "unstable-client-routes")]
impl GenericSessionBuilder<ClientRoutesMode> {
    /// Creates a client-routes-mode builder from `config`.
    ///
    /// Starts from `SessionConfig::new()`, stores `config` as the client
    /// routes configuration and sets `disallow_shard_aware_port` to `true`.
    pub fn new(config: super::client_routes::ClientRoutesConfig) -> Self {
        let mut session_config = SessionConfig::new();
        session_config.client_routes_config = Some(config);
        // This mode starts with the shard-aware port disallowed.
        session_config.disallow_shard_aware_port = true;
        Self {
            config: session_config,
            kind: PhantomData,
        }
    }
}
/// Marker trait for builder modes that accept an explicit list of known nodes.
pub trait SessionBuilderKindSupportsKnownNodes: SessionBuilderKind {}

impl SessionBuilderKindSupportsKnownNodes for DefaultMode {}

#[cfg(feature = "unstable-client-routes")]
impl SessionBuilderKindSupportsKnownNodes for ClientRoutesMode {}

/// Known-node configuration, available in modes implementing
/// [`SessionBuilderKindSupportsKnownNodes`].
impl<K> GenericSessionBuilder<K>
where
    K: SessionBuilderKindSupportsKnownNodes,
{
    /// Appends one known node given by hostname, e.g. `"127.0.0.1:9042"`.
    pub fn known_node(mut self, hostname: impl AsRef<str>) -> Self {
        self.config.add_known_node(hostname);
        self
    }

    /// Appends one known node given by socket address.
    pub fn known_node_addr(mut self, node_addr: SocketAddr) -> Self {
        self.config.add_known_node_addr(node_addr);
        self
    }

    /// Appends several known nodes given by hostname, preserving their order.
    pub fn known_nodes(mut self, hostnames: impl IntoIterator<Item = impl AsRef<str>>) -> Self {
        self.config.add_known_nodes(hostnames);
        self
    }

    /// Appends several known nodes given by socket address, preserving their
    /// order. Accepts owned addresses or references to them.
    pub fn known_nodes_addr(
        mut self,
        node_addrs: impl IntoIterator<Item = impl Borrow<SocketAddr>>,
    ) -> Self {
        self.config.add_known_nodes_addr(node_addrs);
        self
    }
}
/// Marker trait for builder modes that allow a custom [`AddressTranslator`].
pub trait SessionBuilderKindSupportsAddressTranslation: SessionBuilderKind {}

impl SessionBuilderKindSupportsAddressTranslation for DefaultMode {}

/// Address translation, available in modes implementing
/// [`SessionBuilderKindSupportsAddressTranslation`].
impl<K> GenericSessionBuilder<K>
where
    K: SessionBuilderKindSupportsAddressTranslation,
{
    /// Installs `translator`, replacing any previously configured one.
    pub fn address_translator(mut self, translator: Arc<dyn AddressTranslator>) -> Self {
        let translator = Some(translator);
        self.config.address_translator = translator;
        self
    }
}
/// Marker trait for builder modes that allow configuring TLS through
/// `tls_context`. It is implemented for [`DefaultMode`].
pub trait SessionBuilderKindSupportsTls: SessionBuilderKind {}
impl SessionBuilderKindSupportsTls for DefaultMode {}
/// TLS configuration, available in modes implementing
/// [`SessionBuilderKindSupportsTls`].
impl<K: SessionBuilderKindSupportsTls> GenericSessionBuilder<K> {
/// Stores the given TLS context in the session configuration.
///
/// Any value convertible into [`TlsContext`] is accepted; passing `None`
/// clears a previously configured context.
#[cfg_attr(
feature = "openssl-010",
doc = r#"
# Example
```
# async fn example() -> Result<(), Box<dyn std::error::Error>> {
use std::fs;
use std::path::PathBuf;
use scylla::client::session::Session;
use scylla::client::session_builder::SessionBuilder;
use openssl::ssl::{SslContextBuilder, SslVerifyMode, SslMethod, SslFiletype};
let certdir = fs::canonicalize(PathBuf::from("./examples/certs/scylla.crt"))?;
let mut context_builder = SslContextBuilder::new(SslMethod::tls())?;
context_builder.set_certificate_file(certdir.as_path(), SslFiletype::PEM)?;
context_builder.set_verify(SslVerifyMode::NONE);
let session: Session = SessionBuilder::new()
.known_node("127.0.0.1:9042")
.tls_context(Some(context_builder.build()))
.build()
.await?;
# Ok(())
# }
```
"#
)]
pub fn tls_context(mut self, tls_context: Option<impl Into<TlsContext>>) -> Self {
// The assignment sits in its own block so that the lint attribute below
// applies only to it. With no TLS backend feature enabled, the compiler
// reports this conversion as unreachable code, so the `allow` is gated on
// that configuration. NOTE(review): presumably `TlsContext` has no values
// when no backend is enabled — confirm.
#[cfg_attr(
not(any(feature = "openssl-010", feature = "rustls-023")),
// TODO: make this expect() once MSRV is 1.92+.
allow(unreachable_code)
)]
{
self.config.tls_context = tls_context.map(|t| t.into());
}
self
}
}
/// Configuration methods available in every builder mode.
impl<K: SessionBuilderKind> GenericSessionBuilder<K> {
    /// Authenticates with a plain-text username and password.
    ///
    /// Replaces any previously configured authenticator with a
    /// [`PlainTextAuthenticator`] built from `username` and `passwd`.
    pub fn user(mut self, username: impl Into<String>, passwd: impl Into<String>) -> Self {
        self.config.authenticator = Some(Arc::new(PlainTextAuthenticator::new(
            username.into(),
            passwd.into(),
        )));
        self
    }

    /// Uses a custom [`AuthenticatorProvider`], replacing any previously
    /// configured authenticator, including one set via [`user`](Self::user).
    pub fn authenticator_provider(
        mut self,
        authenticator_provider: Arc<dyn AuthenticatorProvider>,
    ) -> Self {
        self.config.authenticator = Some(authenticator_provider);
        self
    }

    /// Sets the local IP address for the session's connections.
    ///
    /// Passing `None` clears a previously configured address.
    pub fn local_ip_address(mut self, local_ip_address: Option<impl Into<IpAddr>>) -> Self {
        self.config.local_ip_address = local_ip_address.map(Into::into);
        self
    }

    /// Sets the local port range used for shard-aware connections.
    pub fn shard_aware_local_port_range(mut self, port_range: ShardAwarePortRange) -> Self {
        self.config.shard_aware_local_port_range = port_range;
        self
    }

    /// Sets the compression algorithm.
    ///
    /// `None` means no compression is configured, which is also the default.
    pub fn compression(mut self, compression: Option<Compression>) -> Self {
        self.config.compression = compression;
        self
    }

    /// Sets the interval used while awaiting schema agreement.
    pub fn schema_agreement_interval(mut self, interval: Duration) -> Self {
        self.config.schema_agreement_interval = interval;
        self
    }

    /// Replaces the handle of the default execution profile.
    pub fn default_execution_profile_handle(
        mut self,
        profile_handle: ExecutionProfileHandle,
    ) -> Self {
        self.config.default_execution_profile_handle = profile_handle;
        self
    }

    /// Enables or disables TCP_NODELAY. It is enabled by default.
    pub fn tcp_nodelay(mut self, nodelay: bool) -> Self {
        self.config.tcp_nodelay = nodelay;
        self
    }

    /// Sets the TCP keepalive interval.
    ///
    /// Logs a warning when `interval` is 1 second or less.
    pub fn tcp_keepalive_interval(mut self, interval: Duration) -> Self {
        if interval <= Duration::from_secs(1) {
            warn!(
                "Setting the TCP keepalive interval to low values ({:?}) is not recommended as it can have a negative impact on performance. Consider setting it above 1 second.",
                interval
            );
        }
        self.config.tcp_keepalive_interval = Some(interval);
        self
    }

    /// Sets the TCP receive buffer size in bytes.
    pub fn tcp_recv_buffer_size(mut self, size: usize) -> Self {
        self.config.tcp_recv_buffer_size = Some(size);
        self
    }

    /// Sets the TCP send buffer size in bytes.
    pub fn tcp_send_buffer_size(mut self, size: usize) -> Self {
        self.config.tcp_send_buffer_size = Some(size);
        self
    }

    /// Enables or disables TCP address reuse.
    pub fn tcp_reuse_address(mut self, reuse: bool) -> Self {
        self.config.tcp_reuse_address = Some(reuse);
        self
    }

    /// Sets the TCP linger duration.
    pub fn tcp_linger(mut self, duration: Duration) -> Self {
        self.config.tcp_linger = Some(duration);
        self
    }

    /// Sets the keyspace the session should use.
    ///
    /// `case_sensitive` is stored together with the name and controls how
    /// the keyspace name is interpreted.
    pub fn use_keyspace(mut self, keyspace_name: impl Into<String>, case_sensitive: bool) -> Self {
        self.config.used_keyspace = Some(keyspace_name.into());
        self.config.keyspace_case_sensitive = case_sensitive;
        self
    }

    /// Connects and returns a new [`Session`].
    ///
    /// The builder is only borrowed and its configuration is cloned, so the
    /// same builder can be used to create further sessions.
    ///
    /// # Errors
    ///
    /// Returns a [`NewSessionError`] if connecting the session fails.
    pub async fn build(&self) -> Result<Session, NewSessionError> {
        Session::connect(self.config.clone()).await
    }

    /// Sets the connection timeout. The default is 5 seconds.
    pub fn connection_timeout(mut self, duration: Duration) -> Self {
        self.config.connect_timeout = duration;
        self
    }

    /// Sets the connection pool size.
    pub fn pool_size(mut self, size: PoolSize) -> Self {
        self.config.connection_pool_size = size;
        self
    }

    /// Allows or disallows the use of the shard-aware port.
    pub fn disallow_shard_aware_port(mut self, disallow: bool) -> Self {
        self.config.disallow_shard_aware_port = disallow;
        self
    }

    /// Installs a [`TimestampGenerator`], replacing any previous one.
    pub fn timestamp_generator(mut self, timestamp_generator: Arc<dyn TimestampGenerator>) -> Self {
        self.config.timestamp_generator = Some(timestamp_generator);
        self
    }

    /// Sets the keyspaces whose metadata should be fetched, replacing any
    /// previously configured list.
    pub fn keyspaces_to_fetch(
        mut self,
        keyspaces: impl IntoIterator<Item = impl Into<String>>,
    ) -> Self {
        self.config.keyspaces_to_fetch = keyspaces.into_iter().map(Into::into).collect();
        self
    }

    /// Enables or disables fetching schema metadata. It is enabled by default.
    pub fn fetch_schema_metadata(mut self, fetch: bool) -> Self {
        self.config.fetch_schema_metadata = fetch;
        self
    }

    /// Sets the server-side timeout for metadata requests.
    pub fn metadata_request_serverside_timeout(mut self, timeout: Duration) -> Self {
        self.config.metadata_request_serverside_timeout = Some(timeout);
        self
    }

    /// Sets the keepalive interval.
    ///
    /// Logs a warning when `interval` is 1 second or less.
    pub fn keepalive_interval(mut self, interval: Duration) -> Self {
        if interval <= Duration::from_secs(1) {
            warn!(
                "Setting the keepalive interval to low values ({:?}) is not recommended as it can have a negative impact on performance. Consider setting it above 1 second.",
                interval
            );
        }
        self.config.keepalive_interval = Some(interval);
        self
    }

    /// Sets the keepalive timeout.
    ///
    /// Logs a warning when `timeout` is 5 seconds or less.
    pub fn keepalive_timeout(mut self, timeout: Duration) -> Self {
        // The warning recommends values above 5 seconds, so it must fire for
        // every value at or below that bound. Previously it fired only at
        // <= 1s, so e.g. 3s was accepted silently despite the advice.
        if timeout <= Duration::from_secs(5) {
            warn!(
                "Setting the keepalive timeout to low values ({:?}) is not recommended as it may aggressively close connections. Consider setting it above 5 seconds.",
                timeout
            );
        }
        self.config.keepalive_timeout = Some(timeout);
        self
    }

    /// Sets the timeout used while awaiting schema agreement.
    pub fn schema_agreement_timeout(mut self, timeout: Duration) -> Self {
        self.config.schema_agreement_timeout = timeout;
        self
    }

    /// Enables or disables automatically awaiting schema agreement.
    pub fn auto_await_schema_agreement(mut self, enabled: bool) -> Self {
        self.config.schema_agreement_automatic_waiting = enabled;
        self
    }

    /// Sets the hostname resolution timeout; `None` stores no timeout.
    pub fn hostname_resolution_timeout(mut self, duration: Option<Duration>) -> Self {
        self.config.hostname_resolution_timeout = duration;
        self
    }

    /// Installs a [`HostFilter`], replacing any previous one.
    pub fn host_filter(mut self, filter: Arc<dyn HostFilter>) -> Self {
        self.config.host_filter = Some(filter);
        self
    }

    /// Sets whether metadata is refreshed after automatic schema agreement.
    pub fn refresh_metadata_on_auto_schema_agreement(mut self, refresh_metadata: bool) -> Self {
        self.config.refresh_metadata_on_auto_schema_agreement = refresh_metadata;
        self
    }

    /// Sets how many attempts are made to fetch tracing info.
    pub fn tracing_info_fetch_attempts(mut self, attempts: NonZeroU32) -> Self {
        self.config.tracing_info_fetch_attempts = attempts;
        self
    }

    /// Sets the interval between tracing-info fetch attempts.
    pub fn tracing_info_fetch_interval(mut self, interval: Duration) -> Self {
        self.config.tracing_info_fetch_interval = interval;
        self
    }

    /// Sets the consistency used when fetching tracing info.
    pub fn tracing_info_fetch_consistency(mut self, consistency: Consistency) -> Self {
        self.config.tracing_info_fetch_consistency = consistency;
        self
    }

    /// Enables or disables write coalescing.
    pub fn write_coalescing(mut self, enable: bool) -> Self {
        self.config.enable_write_coalescing = enable;
        self
    }

    /// Sets the write coalescing delay.
    pub fn write_coalescing_delay(mut self, delay: WriteCoalescingDelay) -> Self {
        self.config.write_coalescing_delay = delay;
        self
    }

    /// Sets the cluster metadata refresh interval. The default is 60 seconds.
    pub fn cluster_metadata_refresh_interval(mut self, interval: Duration) -> Self {
        self.config.cluster_metadata_refresh_interval = interval;
        self
    }

    /// Replaces the session's [`SelfIdentity`].
    pub fn custom_identity(mut self, identity: SelfIdentity<'static>) -> Self {
        self.config.identity = identity;
        self
    }
}
impl Default for SessionBuilder {
fn default() -> Self {
SessionBuilder::new()
}
}
#[cfg(test)]
mod tests {
    use super::super::Compression;
    use super::SessionBuilder;
    use crate::client::execution_profile::{ExecutionProfile, defaults};
    use crate::cluster::node::KnownNode;
    use crate::test_utils::setup_tracing;
    use scylla_cql::Consistency;
    use scylla_cql::frame::types::SerialConsistency;
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};
    use std::time::Duration;

    #[test]
    fn default_session_builder() {
        setup_tracing();
        let builder = SessionBuilder::new();
        assert!(builder.config.known_nodes.is_empty());
        assert_eq!(builder.config.compression, None);
    }

    #[test]
    fn add_known_node() {
        setup_tracing();
        let mut builder = SessionBuilder::new();
        builder = builder.known_node("test_hostname");
        assert_eq!(
            builder.config.known_nodes,
            vec![KnownNode::Hostname("test_hostname".into())]
        );
        assert_eq!(builder.config.compression, None);
    }

    #[test]
    fn add_known_node_addr() {
        setup_tracing();
        let mut builder = SessionBuilder::new();
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(172, 17, 0, 3)), 1357);
        builder = builder.known_node_addr(addr);
        assert_eq!(builder.config.known_nodes, vec![KnownNode::Address(addr)]);
        assert_eq!(builder.config.compression, None);
    }

    #[test]
    fn add_known_nodes() {
        setup_tracing();
        let mut builder = SessionBuilder::new();
        builder = builder.known_nodes(["test_hostname1", "test_hostname2"]);
        assert_eq!(
            builder.config.known_nodes,
            vec![
                KnownNode::Hostname("test_hostname1".into()),
                KnownNode::Hostname("test_hostname2".into())
            ]
        );
        assert_eq!(builder.config.compression, None);
    }

    #[test]
    fn add_known_nodes_addr() {
        setup_tracing();
        let mut builder = SessionBuilder::new();
        let addr1 = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(172, 17, 0, 3)), 1357);
        let addr2 = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(172, 17, 0, 4)), 9090);
        builder = builder.known_nodes_addr([addr1, addr2]);
        assert_eq!(
            builder.config.known_nodes,
            vec![KnownNode::Address(addr1), KnownNode::Address(addr2)]
        );
        assert_eq!(builder.config.compression, None);
    }

    #[test]
    fn compression() {
        setup_tracing();
        let mut builder = SessionBuilder::new();
        assert_eq!(builder.config.compression, None);
        builder = builder.compression(Some(Compression::Lz4));
        assert_eq!(builder.config.compression, Some(Compression::Lz4));
        builder = builder.compression(Some(Compression::Snappy));
        assert_eq!(builder.config.compression, Some(Compression::Snappy));
        builder = builder.compression(None);
        assert_eq!(builder.config.compression, None);
    }

    #[test]
    fn tcp_nodelay() {
        setup_tracing();
        let mut builder = SessionBuilder::new();
        assert!(builder.config.tcp_nodelay);
        builder = builder.tcp_nodelay(false);
        assert!(!builder.config.tcp_nodelay);
        builder = builder.tcp_nodelay(true);
        assert!(builder.config.tcp_nodelay);
    }

    #[test]
    fn use_keyspace() {
        setup_tracing();
        let mut builder = SessionBuilder::new();
        assert_eq!(builder.config.used_keyspace, None);
        assert!(!builder.config.keyspace_case_sensitive);
        builder = builder.use_keyspace("ks_name_1", true);
        assert_eq!(builder.config.used_keyspace, Some("ks_name_1".to_string()));
        assert!(builder.config.keyspace_case_sensitive);
        builder = builder.use_keyspace("ks_name_2", false);
        assert_eq!(builder.config.used_keyspace, Some("ks_name_2".to_string()));
        assert!(!builder.config.keyspace_case_sensitive);
    }

    #[test]
    fn connection_timeout() {
        setup_tracing();
        let mut builder = SessionBuilder::new();
        assert_eq!(builder.config.connect_timeout, Duration::from_secs(5));
        builder = builder.connection_timeout(Duration::from_secs(10));
        assert_eq!(builder.config.connect_timeout, Duration::from_secs(10));
    }

    #[test]
    fn fetch_schema_metadata() {
        setup_tracing();
        let mut builder = SessionBuilder::new();
        assert!(builder.config.fetch_schema_metadata);
        builder = builder.fetch_schema_metadata(false);
        assert!(!builder.config.fetch_schema_metadata);
        builder = builder.fetch_schema_metadata(true);
        assert!(builder.config.fetch_schema_metadata);
    }

    /// The keepalive-related setters store their values as `Some(_)`.
    #[test]
    fn keepalive_settings() {
        setup_tracing();
        let builder = SessionBuilder::new()
            .tcp_keepalive_interval(Duration::from_secs(20))
            .keepalive_interval(Duration::from_secs(30))
            .keepalive_timeout(Duration::from_secs(40));
        assert_eq!(
            builder.config.tcp_keepalive_interval,
            Some(Duration::from_secs(20))
        );
        assert_eq!(
            builder.config.keepalive_interval,
            Some(Duration::from_secs(30))
        );
        assert_eq!(
            builder.config.keepalive_timeout,
            Some(Duration::from_secs(40))
        );
    }

    /// The schema agreement setters write the corresponding config fields.
    #[test]
    fn schema_agreement_settings() {
        setup_tracing();
        let builder = SessionBuilder::new()
            .schema_agreement_interval(Duration::from_millis(250))
            .schema_agreement_timeout(Duration::from_secs(30))
            .auto_await_schema_agreement(false);
        assert_eq!(
            builder.config.schema_agreement_interval,
            Duration::from_millis(250)
        );
        assert_eq!(
            builder.config.schema_agreement_timeout,
            Duration::from_secs(30)
        );
        assert!(!builder.config.schema_agreement_automatic_waiting);
    }

    #[test]
    fn disallow_shard_aware_port() {
        setup_tracing();
        let mut builder = SessionBuilder::new();
        builder = builder.disallow_shard_aware_port(true);
        assert!(builder.config.disallow_shard_aware_port);
        builder = builder.disallow_shard_aware_port(false);
        assert!(!builder.config.disallow_shard_aware_port);
    }

    #[tokio::test]
    async fn execution_profile() {
        setup_tracing();
        let default_builder = SessionBuilder::new();
        let default_execution_profile = default_builder
            .config
            .default_execution_profile_handle
            .access();
        assert_eq!(
            default_execution_profile.consistency,
            defaults::consistency()
        );
        assert_eq!(
            default_execution_profile.serial_consistency,
            defaults::serial_consistency()
        );
        assert_eq!(
            default_execution_profile.request_timeout,
            defaults::request_timeout()
        );
        assert_eq!(
            default_execution_profile.load_balancing_policy.name(),
            defaults::load_balancing_policy().name()
        );

        let custom_consistency = Consistency::Any;
        let custom_serial_consistency = Some(SerialConsistency::Serial);
        let custom_timeout = Some(Duration::from_secs(1));
        let execution_profile_handle = ExecutionProfile::builder()
            .consistency(custom_consistency)
            .serial_consistency(custom_serial_consistency)
            .request_timeout(custom_timeout)
            .build()
            .into_handle();
        let builder_with_profile =
            default_builder.default_execution_profile_handle(execution_profile_handle.clone());
        let execution_profile = execution_profile_handle.access();

        let profile_in_builder = builder_with_profile
            .config
            .default_execution_profile_handle
            .access();
        assert_eq!(
            profile_in_builder.consistency,
            execution_profile.consistency
        );
        assert_eq!(
            profile_in_builder.serial_consistency,
            execution_profile.serial_consistency
        );
        assert_eq!(
            profile_in_builder.request_timeout,
            execution_profile.request_timeout
        );
        assert_eq!(
            profile_in_builder.load_balancing_policy.name(),
            execution_profile.load_balancing_policy.name()
        );
    }

    /// Checks both the default refresh interval and that the setter applies.
    #[test]
    fn cluster_metadata_refresh_interval() {
        setup_tracing();
        let builder = SessionBuilder::new();
        assert_eq!(
            builder.config.cluster_metadata_refresh_interval,
            Duration::from_secs(60)
        );
        let builder = builder.cluster_metadata_refresh_interval(Duration::from_secs(15));
        assert_eq!(
            builder.config.cluster_metadata_refresh_interval,
            Duration::from_secs(15)
        );
    }

    #[test]
    fn all_features() {
        setup_tracing();
        let mut builder = SessionBuilder::new();

        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(172, 17, 0, 3)), 8465);
        let addr1 = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(172, 17, 0, 3)), 1357);
        let addr2 = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(172, 17, 0, 4)), 9090);

        builder = builder.known_node("hostname_test");
        builder = builder.known_node_addr(addr);
        builder = builder.known_nodes(["hostname_test1", "hostname_test2"]);
        builder = builder.known_nodes_addr([addr1, addr2]);
        builder = builder.compression(Some(Compression::Snappy));
        builder = builder.tcp_nodelay(true);
        builder = builder.use_keyspace("ks_name", true);
        builder = builder.fetch_schema_metadata(false);
        builder = builder.cluster_metadata_refresh_interval(Duration::from_secs(1));

        assert_eq!(
            builder.config.known_nodes,
            vec![
                KnownNode::Hostname("hostname_test".into()),
                KnownNode::Address(addr),
                KnownNode::Hostname("hostname_test1".into()),
                KnownNode::Hostname("hostname_test2".into()),
                KnownNode::Address(addr1),
                KnownNode::Address(addr2),
            ]
        );
        assert_eq!(builder.config.compression, Some(Compression::Snappy));
        assert!(builder.config.tcp_nodelay);
        assert_eq!(
            builder.config.cluster_metadata_refresh_interval,
            Duration::from_secs(1)
        );
        assert_eq!(builder.config.used_keyspace, Some("ks_name".to_string()));
        assert!(builder.config.keyspace_case_sensitive);
        assert!(!builder.config.fetch_schema_metadata);
    }

    /// Compile-only check: the builder's known-node methods and the matching
    /// `SessionConfig` methods accept the same argument types.
    fn _check_known_nodes_compatibility(
        hostnames: &[impl AsRef<str>],
        host_addresses: &[SocketAddr],
    ) {
        let mut sb: SessionBuilder = SessionBuilder::new();
        sb = sb.known_nodes(hostnames);
        sb = sb.known_nodes_addr(host_addresses);

        let mut config = sb.config;
        config.add_known_nodes(hostnames);
        config.add_known_nodes_addr(host_addresses);
    }
}