use bytes::Bytes;
use futures::{future::RemoteHandle, FutureExt};
use tokio::io::{split, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, BufWriter};
use tokio::net::{TcpSocket, TcpStream};
use tokio::sync::{mpsc, oneshot};
use tokio::time::Instant;
use tracing::{debug, error, trace, warn};
use uuid::Uuid;
#[cfg(feature = "ssl")]
use openssl::ssl::{Ssl, SslContext};
#[cfg(feature = "ssl")]
use std::pin::Pin;
use std::sync::atomic::AtomicU64;
#[cfg(feature = "ssl")]
use tokio_openssl::SslStream;
use crate::authentication::AuthenticatorProvider;
use scylla_cql::frame::response::authenticate::Authenticate;
use std::collections::{BTreeSet, HashMap};
use std::convert::TryFrom;
use std::io::ErrorKind;
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use std::{
cmp::Ordering,
net::{Ipv4Addr, Ipv6Addr},
};
use super::errors::{BadKeyspaceName, BadQuery, DbError, QueryError};
use crate::batch::{Batch, BatchStatement};
use crate::frame::protocol_features::ProtocolFeatures;
use crate::frame::{
self,
request::{self, batch, execute, query, register, Request},
response::{event::Event, result, NonErrorResponse, Response, ResponseOpcode},
server_event_type::EventType,
value::{BatchValues, ValueList},
FrameParams, SerializedRequest,
};
use crate::query::Query;
use crate::routing::ShardInfo;
use crate::statement::prepared_statement::PreparedStatement;
use crate::statement::Consistency;
use crate::transport::session::IntoTypedRows;
use crate::transport::Compression;
pub use crate::QueryResult;
/// Query used to fetch the schema version of the node this connection talks to.
const LOCAL_VERSION: &str = "SELECT schema_version FROM system.local WHERE key='local'";
/// If more orphaned stream ids than this are older than `OLD_AGE_ORPHAN_THRESHOLD`,
/// the connection is considered broken and torn down (see `Connection::orphaner`).
const OLD_ORPHAN_COUNT_THRESHOLD: usize = 1024;
/// Age after which an orphaned stream id counts as "old".
const OLD_AGE_ORPHAN_THRESHOLD: std::time::Duration = std::time::Duration::from_secs(1);
/// A single established connection to a database node.
///
/// Requests are handed to a background "router" task through `submit_channel`;
/// the task is kept alive by `_worker_handle` (dropping the `Connection`
/// cancels it).
pub struct Connection {
// Queue of serialized requests consumed by the router's writer half.
submit_channel: mpsc::Sender<Task>,
// Owning handle for the router task; dropping it stops the task.
_worker_handle: RemoteHandle<()>,
// Address this connection was opened to.
connect_address: SocketAddr,
config: ConnectionConfig,
// Capabilities negotiated with the server (filled in after OPTIONS).
features: ConnectionFeatures,
// Source of unique driver-side request ids (fetch_add, wrapping).
request_id_generator: AtomicU64,
// Used by `OrphanhoodNotifier` to report requests whose callers gave up.
orphan_notification_sender: mpsc::UnboundedSender<RequestId>,
}
/// Server capabilities discovered during connection negotiation.
#[derive(Default)]
pub(crate) struct ConnectionFeatures {
shard_info: Option<ShardInfo>,
shard_aware_port: Option<u16>,
protocol_features: ProtocolFeatures,
}
/// Driver-side identifier of one in-flight request (distinct from the CQL stream id).
type RequestId = u64;
/// Completion channel for a single request, registered with the router
/// under the stream id assigned at write time.
struct ResponseHandler {
response_sender: oneshot::Sender<Result<TaskResponse, QueryError>>,
request_id: RequestId,
}
/// RAII guard that reports a request as orphaned if it is dropped while
/// still enabled (i.e. before a response was received and `disable` called).
struct OrphanhoodNotifier<'a> {
enabled: bool,
request_id: RequestId,
notification_sender: &'a mpsc::UnboundedSender<RequestId>,
}
impl<'a> OrphanhoodNotifier<'a> {
/// Creates an enabled notifier for `request_id`.
fn new(
request_id: RequestId,
notification_sender: &'a mpsc::UnboundedSender<RequestId>,
) -> Self {
Self {
enabled: true,
request_id,
notification_sender,
}
}
/// Consumes the notifier without notifying — call this once the response
/// has actually been received by the requester.
fn disable(mut self) {
self.enabled = false;
}
}
impl<'a> Drop for OrphanhoodNotifier<'a> {
fn drop(&mut self) {
// Still enabled here means the caller abandoned the request (e.g. its
// future was dropped mid-await); tell the router so the stream id can
// be reclaimed. Send errors are ignored: the router may already be gone.
if self.enabled {
let _ = self.notification_sender.send(self.request_id);
}
}
}
/// One unit of work submitted to the router: a serialized frame plus the
/// channel on which its response should be delivered.
struct Task {
serialized_request: SerializedRequest,
response_handler: ResponseHandler,
}
/// Raw response frame as read off the wire, before body parsing.
struct TaskResponse {
params: FrameParams,
opcode: ResponseOpcode,
body: Bytes,
}
/// A parsed server response together with optional tracing id and warnings.
pub struct QueryResponse {
pub response: Response,
pub tracing_id: Option<Uuid>,
pub warnings: Vec<String>,
}
/// Like `QueryResponse`, but guaranteed not to be an Error response.
pub struct NonErrorQueryResponse {
pub response: NonErrorResponse,
pub tracing_id: Option<Uuid>,
pub warnings: Vec<String>,
}
impl QueryResponse {
pub fn into_non_error_query_response(self) -> Result<NonErrorQueryResponse, QueryError> {
Ok(NonErrorQueryResponse {
response: self.response.into_non_error_response()?,
tracing_id: self.tracing_id,
warnings: self.warnings,
})
}
pub fn into_query_result(self) -> Result<QueryResult, QueryError> {
self.into_non_error_query_response()?.into_query_result()
}
}
impl NonErrorQueryResponse {
    /// Returns the payload if this is a `Result::SetKeyspace` response.
    pub fn as_set_keyspace(&self) -> Option<&result::SetKeyspace> {
        if let NonErrorResponse::Result(result::Result::SetKeyspace(sk)) = &self.response {
            Some(sk)
        } else {
            None
        }
    }
    /// Returns the payload if this is a `Result::SchemaChange` response.
    pub fn as_schema_change(&self) -> Option<&result::SchemaChange> {
        if let NonErrorResponse::Result(result::Result::SchemaChange(sc)) = &self.response {
            Some(sc)
        } else {
            None
        }
    }
    /// Converts into a `QueryResult`; responses other than `Result` are a
    /// protocol error. Non-Rows results yield an empty `QueryResult`.
    pub fn into_query_result(self) -> Result<QueryResult, QueryError> {
        let (rows, paging_state, col_specs) = match self.response {
            NonErrorResponse::Result(result::Result::Rows(rs)) => {
                (Some(rs.rows), rs.metadata.paging_state, rs.metadata.col_specs)
            }
            NonErrorResponse::Result(_) => (None, None, Vec::new()),
            _ => {
                return Err(QueryError::ProtocolError(
                    "Unexpected server response, expected Result or Error",
                ))
            }
        };
        Ok(QueryResult {
            rows,
            warnings: self.warnings,
            tracing_id: self.tracing_id,
            paging_state,
            col_specs,
        })
    }
}
/// Per-connection settings applied when opening and driving a connection.
#[derive(Clone)]
pub struct ConnectionConfig {
// Frame compression to negotiate; falls back to none if unsupported by the server.
pub compression: Option<Compression>,
pub tcp_nodelay: bool,
#[cfg(feature = "ssl")]
pub ssl_context: Option<SslContext>,
// Upper bound on the TCP connect attempt.
pub connect_timeout: std::time::Duration,
// When set, the connection REGISTERs for server events and forwards them here.
pub event_sender: Option<mpsc::Sender<Event>>,
pub default_consistency: Consistency,
// Authenticator used if the server requests authentication at STARTUP.
pub authenticator: Option<Arc<dyn AuthenticatorProvider>>,
}
impl Default for ConnectionConfig {
fn default() -> Self {
Self {
compression: None,
tcp_nodelay: true,
event_sender: None,
#[cfg(feature = "ssl")]
ssl_context: None,
connect_timeout: std::time::Duration::from_secs(5),
default_consistency: Default::default(),
authenticator: None,
}
}
}
impl ConnectionConfig {
#[cfg(feature = "ssl")]
pub fn is_ssl(&self) -> bool {
self.ssl_context.is_some()
}
#[cfg(not(feature = "ssl"))]
pub fn is_ssl(&self) -> bool {
false
}
}
/// Receives the single fatal error that eventually shuts a connection down.
pub type ErrorReceiver = tokio::sync::oneshot::Receiver<QueryError>;
impl Connection {
/// Opens a TCP connection to `addr` and spawns the background router task.
///
/// `source_port`, when given, binds the local end of the socket to that port.
/// The whole connect attempt is bounded by `config.connect_timeout`;
/// on expiry `QueryError::TimeoutError` is returned.
///
/// Returns the connection plus an `ErrorReceiver` that will yield the fatal
/// error which eventually terminates the router task.
pub async fn new(
addr: SocketAddr,
source_port: Option<u16>,
config: ConnectionConfig,
) -> Result<(Self, ErrorReceiver), QueryError> {
let stream_connector = match source_port {
Some(p) => {
tokio::time::timeout(config.connect_timeout, connect_with_source_port(addr, p))
.await
}
None => tokio::time::timeout(config.connect_timeout, TcpStream::connect(addr)).await,
};
// Outer Err = timeout elapsed; inner Err (via `?`) = I/O failure.
let stream = match stream_connector {
Ok(stream) => stream?,
Err(_) => {
return Err(QueryError::TimeoutError);
}
};
stream.set_nodelay(config.tcp_nodelay)?;
// Bounded request queue towards the router's writer half.
let (sender, receiver) = mpsc::channel(1024);
let (error_sender, error_receiver) = tokio::sync::oneshot::channel();
// Unbounded so that `OrphanhoodNotifier::drop` can notify without awaiting.
let (orphan_notification_sender, orphan_notification_receiver) = mpsc::unbounded_channel();
let _worker_handle = Self::run_router(
config.clone(),
stream,
receiver,
error_sender,
orphan_notification_receiver,
)
.await?;
let connection = Connection {
submit_channel: sender,
_worker_handle,
config,
features: Default::default(),
connect_address: addr,
request_id_generator: AtomicU64::new(0),
orphan_notification_sender,
};
Ok((connection, error_receiver))
}
/// Sends a STARTUP message with the given options; returns the raw response.
pub async fn startup(&self, options: HashMap<String, String>) -> Result<Response, QueryError> {
    let query_response = self
        .send_request(&request::Startup { options }, false, false)
        .await?;
    Ok(query_response.response)
}
/// Sends an OPTIONS message; returns the raw response (normally SUPPORTED).
pub async fn get_options(&self) -> Result<Response, QueryError> {
    let query_response = self.send_request(&request::Options {}, false, false).await?;
    Ok(query_response.response)
}
/// Prepares `query` on this connection and builds a `PreparedStatement`,
/// carrying over the server-assigned id, metadata and the LWT mark (if the
/// negotiated protocol features support it). Any tracing id returned by the
/// server is recorded on the statement.
pub async fn prepare(&self, query: &Query) -> Result<PreparedStatement, QueryError> {
let query_response = self
.send_request(
&request::Prepare {
query: &query.contents,
},
true,
query.config.tracing,
)
.await?;
let mut prepared_statement = match query_response.response {
Response::Error(err) => return Err(err.into()),
Response::Result(result::Result::Prepared(p)) => PreparedStatement::new(
p.id,
self.features
.protocol_features
.prepared_flags_contain_lwt_mark(p.prepared_metadata.flags as u32),
p.prepared_metadata,
query.contents.clone(),
query.get_page_size(),
query.config.clone(),
),
_ => {
return Err(QueryError::ProtocolError(
"PREPARE: Unexpected server response",
))
}
};
if let Some(tracing_id) = query_response.tracing_id {
prepared_statement.prepare_tracing_ids.push(tracing_id);
}
Ok(prepared_statement)
}
/// Re-prepares a statement (after the server reported it Unprepared) and
/// verifies that the returned id matches the original — the id is an md5 of
/// the statement text, so a mismatch indicates a protocol-level problem.
async fn reprepare(
&self,
query: impl Into<Query>,
previous_prepared: &PreparedStatement,
) -> Result<(), QueryError> {
let reprepare_query: Query = query.into();
let reprepared = self.prepare(&reprepare_query).await?;
if reprepared.get_id() != previous_prepared.get_id() {
Err(QueryError::ProtocolError(
"Prepared statement Id changed, md5 sum should stay the same",
))
} else {
Ok(())
}
}
/// Sends an AUTH_RESPONSE frame carrying the (optional) SASL payload.
pub async fn authenticate_response(
    &self,
    response: Option<Vec<u8>>,
) -> Result<QueryResponse, QueryError> {
    self.send_request(&request::AuthResponse { response }, false, false)
        .await
}
/// Runs a single-page query using the query's own consistency (falling back
/// to the connection default).
pub async fn query_single_page(
    &self,
    query: impl Into<Query>,
    values: impl ValueList,
) -> Result<QueryResult, QueryError> {
    let q: Query = query.into();
    let consistency = q
        .config
        .determine_consistency(self.config.default_consistency);
    self.query_single_page_with_consistency(q, &values, consistency)
        .await
}
/// Runs a single-page query at the explicitly given consistency.
pub async fn query_single_page_with_consistency(
    &self,
    query: impl Into<Query>,
    values: impl ValueList,
    consistency: Consistency,
) -> Result<QueryResult, QueryError> {
    let q: Query = query.into();
    let response = self
        .query_with_consistency(&q, &values, consistency, None)
        .await?;
    response.into_query_result()
}
/// Runs `query` (optionally resuming from `paging_state`) using the query's
/// own consistency, falling back to the connection default.
pub async fn query(
    &self,
    query: &Query,
    values: impl ValueList,
    paging_state: Option<Bytes>,
) -> Result<QueryResponse, QueryError> {
    let consistency = query
        .config
        .determine_consistency(self.config.default_consistency);
    self.query_with_consistency(query, values, consistency, paging_state)
        .await
}
/// Builds and sends a QUERY frame for `query` with the given consistency and
/// optional paging state; returns the raw response.
pub async fn query_with_consistency(
&self,
query: &Query,
values: impl ValueList,
consistency: Consistency,
paging_state: Option<Bytes>,
) -> Result<QueryResponse, QueryError> {
let serialized_values = values.serialized()?;
let query_frame = query::Query {
contents: &query.contents,
parameters: query::QueryParameters {
consistency,
serial_consistency: query.get_serial_consistency(),
values: &serialized_values,
page_size: query.get_page_size(),
paging_state,
timestamp: query.get_timestamp(),
},
};
// Compression enabled; tracing as configured on the query.
self.send_request(&query_frame, true, query.config.tracing)
.await
}
/// Fetches all pages of `query` using the query's own consistency
/// (falling back to the connection default).
pub async fn query_all(
    &self,
    query: &Query,
    values: impl ValueList,
) -> Result<QueryResult, QueryError> {
    let consistency = query
        .config
        .determine_consistency(self.config.default_consistency);
    self.query_all_with_consistency(query, values, consistency)
        .await
}
/// Fetches all pages of `query` at the given consistency, merging them into
/// a single `QueryResult`. Errors if the query has no page size set.
pub async fn query_all_with_consistency(
    &self,
    query: &Query,
    values: impl ValueList,
    consistency: Consistency,
) -> Result<QueryResult, QueryError> {
    if query.get_page_size().is_none() {
        return Err(QueryError::BadQuery(BadQuery::Other(
            "Called Connection::query_all without page size set!".to_string(),
        )));
    }
    let mut final_result = QueryResult::default();
    // Serialize values once; reused for every page request.
    let serialized_values = values.serialized()?;
    let mut paging_state: Option<Bytes> = None;
    // NOTE: a previous no-op `query.config.determine_consistency(..)` call
    // whose result was discarded has been removed — the consistency to use
    // is already given as a parameter.
    loop {
        let mut cur_result: QueryResult = self
            .query_with_consistency(query, &serialized_values, consistency, paging_state)
            .await?
            .into_query_result()?;
        paging_state = cur_result.paging_state.take();
        final_result.merge_with_next_page_res(cur_result);
        if paging_state.is_none() {
            return Ok(final_result);
        }
    }
}
/// Executes one page of a prepared statement and converts the response into
/// a `QueryResult`.
pub async fn execute_single_page(
    &self,
    prepared_statement: &PreparedStatement,
    values: impl ValueList,
    paging_state: Option<Bytes>,
) -> Result<QueryResult, QueryError> {
    let response = self
        .execute(prepared_statement, values, paging_state)
        .await?;
    response.into_query_result()
}
/// Executes a prepared statement using its own consistency (falling back to
/// the connection default).
pub async fn execute(
    &self,
    prepared_statement: &PreparedStatement,
    values: impl ValueList,
    paging_state: Option<Bytes>,
) -> Result<QueryResponse, QueryError> {
    let consistency = prepared_statement
        .config
        .determine_consistency(self.config.default_consistency);
    self.execute_with_consistency(prepared_statement, values, consistency, paging_state)
        .await
}
/// Sends an EXECUTE frame for the prepared statement. If the server replies
/// with `DbError::Unprepared` (e.g. after a node restart evicted the
/// statement), the statement is re-prepared and the execute is retried once.
pub async fn execute_with_consistency(
&self,
prepared_statement: &PreparedStatement,
values: impl ValueList,
consistency: Consistency,
paging_state: Option<Bytes>,
) -> Result<QueryResponse, QueryError> {
let serialized_values = values.serialized()?;
let execute_frame = execute::Execute {
id: prepared_statement.get_id().to_owned(),
parameters: query::QueryParameters {
consistency,
serial_consistency: prepared_statement.get_serial_consistency(),
values: &serialized_values,
page_size: prepared_statement.get_page_size(),
timestamp: prepared_statement.get_timestamp(),
paging_state,
},
};
let query_response = self
.send_request(&execute_frame, true, prepared_statement.config.tracing)
.await?;
match &query_response.response {
Response::Error(frame::response::Error {
error: DbError::Unprepared { statement_id },
..
}) => {
debug!("Connection::execute: Got DbError::Unprepared - repreparing statement with id {:?}", statement_id);
// Reprepare, then retry the same frame exactly once.
self.reprepare(prepared_statement.get_statement(), prepared_statement)
.await?;
self.send_request(&execute_frame, true, prepared_statement.config.tracing)
.await
}
_ => Ok(query_response),
}
}
/// Executes a prepared statement and fetches all pages, merging them into a
/// single `QueryResult`. Errors if the statement has no page size set.
#[allow(dead_code)]
pub async fn execute_all(
&self,
prepared_statement: &PreparedStatement,
values: impl ValueList,
) -> Result<QueryResult, QueryError> {
if prepared_statement.get_page_size().is_none() {
return Err(QueryError::BadQuery(BadQuery::Other(
"Called Connection::execute_all without page size set!".to_string(),
)));
}
let mut final_result = QueryResult::default();
// Serialize values once; reused for every page request.
let serialized_values = values.serialized()?;
let mut paging_state: Option<Bytes> = None;
loop {
let mut cur_result: QueryResult = self
.execute_single_page(prepared_statement, &serialized_values, paging_state)
.await?;
paging_state = cur_result.paging_state.take();
final_result.merge_with_next_page_res(cur_result);
if paging_state.is_none() {
return Ok(final_result);
}
}
}
/// Runs `batch` using the batch's own consistency (falling back to the
/// connection default).
#[allow(dead_code)]
pub async fn batch(
    &self,
    batch: &Batch,
    values: impl BatchValues,
) -> Result<QueryResult, QueryError> {
    let consistency = batch
        .config
        .determine_consistency(self.config.default_consistency);
    self.batch_with_consistency(batch, values, consistency)
        .await
}
/// Sends a BATCH frame. If the server reports one of the contained prepared
/// statements as Unprepared, that statement is re-prepared and the whole
/// batch is retried (looping until success or a different error).
pub async fn batch_with_consistency(
&self,
batch: &Batch,
values: impl BatchValues,
consistency: Consistency,
) -> Result<QueryResult, QueryError> {
let statements_iter = batch.statements.iter().map(|s| match s {
BatchStatement::Query(q) => batch::BatchStatement::Query { text: &q.contents },
BatchStatement::PreparedStatement(s) => {
batch::BatchStatement::Prepared { id: s.get_id() }
}
});
let batch_frame = batch::Batch {
statements_count: statements_iter.len(),
statements: statements_iter,
values,
batch_type: batch.get_type(),
consistency,
serial_consistency: batch.get_serial_consistency(),
timestamp: batch.get_timestamp(),
};
loop {
let query_response = self
.send_request(&batch_frame, true, batch.config.tracing)
.await?;
return match query_response.response {
Response::Error(err) => match err.error {
DbError::Unprepared { statement_id } => {
debug!("Connection::batch: got DbError::Unprepared - repreparing statement with id {:?}", statement_id);
// Find the batch statement matching the id the server reported.
let prepared_statement = batch.statements.iter().find_map(|s| match s {
BatchStatement::PreparedStatement(s) if *s.get_id() == statement_id => {
Some(s)
}
_ => None,
});
if let Some(p) = prepared_statement {
self.reprepare(p.get_statement(), p).await?;
// Retry the batch with the now-reprepared statement.
continue;
} else {
return Err(QueryError::ProtocolError(
"The server returned a prepared statement Id that did not exist in the batch",
));
}
}
_ => Err(err.into()),
},
Response::Result(_) => Ok(query_response.into_query_result()?),
_ => Err(QueryError::ProtocolError(
"BATCH: Unexpected server response",
)),
};
}
}
/// Issues `USE <keyspace>` on this connection, quoting the name when it is
/// case sensitive, and verifies (case-insensitively) that the server
/// switched to the requested keyspace.
pub async fn use_keyspace(
&self,
keyspace_name: &VerifiedKeyspaceName,
) -> Result<(), QueryError> {
// The name was already validated by VerifiedKeyspaceName, so it is safe
// to interpolate into the statement text.
let query: Query = match keyspace_name.is_case_sensitive {
true => format!("USE \"{}\"", keyspace_name.as_str()).into(),
false => format!("USE {}", keyspace_name.as_str()).into(),
};
let query_response = self.query(&query, (), None).await?;
match query_response.response {
Response::Result(result::Result::SetKeyspace(set_keyspace)) => {
if set_keyspace.keyspace_name.to_lowercase()
!= keyspace_name.as_str().to_lowercase()
{
return Err(QueryError::ProtocolError(
"USE <keyspace_name> returned response with different keyspace name",
));
}
Ok(())
}
Response::Error(err) => Err(err.into()),
_ => Err(QueryError::ProtocolError(
"USE <keyspace_name> returned unexpected response",
)),
}
}
/// Subscribes this connection to the given server event types by sending a
/// REGISTER frame; the server must answer READY.
///
/// Fix: the request was passed as `®ister_frame` — a mojibake of
/// `&register_frame` (an encoding corruption that does not compile).
async fn register(
    &self,
    event_types_to_register_for: Vec<EventType>,
) -> Result<(), QueryError> {
    let register_frame = register::Register {
        event_types_to_register_for,
    };
    match self
        .send_request(&register_frame, true, false)
        .await?
        .response
    {
        Response::Ready => Ok(()),
        Response::Error(err) => Err(err.into()),
        _ => Err(QueryError::ProtocolError(
            "Unexpected response to REGISTER message",
        )),
    }
}
/// Queries `system.local` for this node's current schema version.
///
/// Fix: corrected the ungrammatical error message "Version query returned
/// not rows" to "Version query returned no rows".
pub async fn fetch_schema_version(&self) -> Result<Uuid, QueryError> {
    let (version_id,): (Uuid,) = self
        .query_single_page(LOCAL_VERSION, &[])
        .await?
        .rows
        .ok_or(QueryError::ProtocolError("Version query returned no rows"))?
        .into_typed::<(Uuid,)>()
        .next()
        .ok_or(QueryError::ProtocolError("Admin table returned empty rows"))?
        .map_err(|_| QueryError::ProtocolError("Row is not uuid type as it should be"))?;
    Ok(version_id)
}
/// Returns a fresh driver-side request id (relaxed wrapping counter —
/// uniqueness within the window of in-flight requests is all that matters).
fn allocate_request_id(&self) -> RequestId {
self.request_id_generator
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
}
/// Serializes `request`, submits it to the router task and awaits the
/// response frame.
///
/// `compress` enables frame compression (if configured); `tracing` asks the
/// server to trace the request. If this future is dropped before the
/// response arrives, the `OrphanhoodNotifier`'s Drop impl reports the
/// request as orphaned so the router can reclaim its stream id.
async fn send_request<R: Request>(
&self,
request: &R,
compress: bool,
tracing: bool,
) -> Result<QueryResponse, QueryError> {
let compression = if compress {
self.config.compression
} else {
None
};
let serialized_request = SerializedRequest::make(request, compression, tracing)?;
let request_id = self.allocate_request_id();
let (response_sender, receiver) = oneshot::channel();
let response_handler = ResponseHandler {
response_sender,
request_id,
};
// Armed now; disabled only after the response is in hand.
let notifier = OrphanhoodNotifier::new(request_id, &self.orphan_notification_sender);
self.submit_channel
.send(Task {
serialized_request,
response_handler,
})
.await
.map_err(|_| {
// The router task has shut down and dropped its receiver.
QueryError::IoError(Arc::new(std::io::Error::new(
ErrorKind::Other,
"Connection broken",
)))
})?;
let task_response = receiver.await.map_err(|_| {
QueryError::IoError(Arc::new(std::io::Error::new(
ErrorKind::Other,
"Connection broken",
)))
})?;
// Response received — no orphanhood notification needed.
notifier.disable();
Self::parse_response(
task_response?,
self.config.compression,
&self.features.protocol_features,
)
}
/// Decompresses/parses a raw response frame body, logging any server
/// warnings, and deserializes it into a typed `Response`.
fn parse_response(
task_response: TaskResponse,
compression: Option<Compression>,
features: &ProtocolFeatures,
) -> Result<QueryResponse, QueryError> {
let body_with_ext = frame::parse_response_body_extensions(
task_response.params.flags,
compression,
task_response.body,
)?;
for warn_description in &body_with_ext.warnings {
warn!(
warning = warn_description.as_str(),
"Response from the database contains a warning",
);
}
let response =
Response::deserialize(features, task_response.opcode, &mut &*body_with_ext.body)?;
Ok(QueryResponse {
response,
warnings: body_with_ext.warnings,
tracing_id: body_with_ext.trace_id,
})
}
/// Spawns the router task for this connection, optionally wrapping the
/// stream in TLS first (when the "ssl" feature is on and an SslContext is
/// configured). Returns a `RemoteHandle` that cancels the task when dropped.
async fn run_router(
config: ConnectionConfig,
stream: TcpStream,
receiver: mpsc::Receiver<Task>,
error_sender: tokio::sync::oneshot::Sender<QueryError>,
orphan_notification_receiver: mpsc::UnboundedReceiver<RequestId>,
) -> Result<RemoteHandle<()>, std::io::Error> {
#[cfg(feature = "ssl")]
if let Some(context) = &config.ssl_context {
let ssl = Ssl::new(context)?;
let mut stream = SslStream::new(ssl, stream)?;
// NOTE(review): the result of the TLS handshake is discarded here; a
// failed handshake would only surface later as an I/O error from the
// router — confirm whether propagating it was intended.
let _pin = Pin::new(&mut stream).connect().await;
let (task, handle) = Self::router(
config,
stream,
receiver,
error_sender,
orphan_notification_receiver,
)
.remote_handle();
tokio::task::spawn(task);
return Ok(handle);
}
// Plain (non-TLS) path.
let (task, handle) = Self::router(
config,
stream,
receiver,
error_sender,
orphan_notification_receiver,
)
.remote_handle();
tokio::task::spawn(task);
Ok(handle)
}
/// Drives the connection: splits the stream and concurrently runs the
/// reader, writer and orphaner halves on this single task (via try_join!,
/// so the shared handler map is never contended across threads).
///
/// On the first error, every pending request's handler receives a clone of
/// the error and the error is forwarded to the connection's ErrorReceiver.
async fn router(
config: ConnectionConfig,
stream: (impl AsyncRead + AsyncWrite),
receiver: mpsc::Receiver<Task>,
error_sender: tokio::sync::oneshot::Sender<QueryError>,
orphan_notification_receiver: mpsc::UnboundedReceiver<RequestId>,
) {
let (read_half, write_half) = split(stream);
// Maps stream ids to the oneshot senders of in-flight requests.
let handler_map = StdMutex::new(ResponseHandlerMap::new());
let r = Self::reader(
BufReader::with_capacity(8192, read_half),
&handler_map,
config,
);
let w = Self::writer(
BufWriter::with_capacity(8192, write_half),
&handler_map,
receiver,
);
let o = Self::orphaner(&handler_map, orphan_notification_receiver);
let result = futures::try_join!(r, w, o);
let error: QueryError = match result {
Ok(_) => return, Err(err) => err,
};
// Fail every request that was still waiting for a response.
let response_handlers: HashMap<i16, ResponseHandler> =
handler_map.into_inner().unwrap().into_handlers();
for (_, handler) in response_handlers {
let _ = handler.response_sender.send(Err(error.clone()));
}
let _ = error_sender.send(error);
}
/// Read loop: pulls response frames off the wire and dispatches each to the
/// handler registered under its stream id.
///
/// Stream id -1 carries server-pushed events; other negative ids are
/// ignored. Responses for orphaned ids free the id silently; responses for
/// unknown ids are a protocol error that kills the connection.
async fn reader(
mut read_half: (impl AsyncRead + Unpin),
handler_map: &StdMutex<ResponseHandlerMap>,
config: ConnectionConfig,
) -> Result<(), QueryError> {
loop {
let (params, opcode, body) = frame::read_response_frame(&mut read_half).await?;
let response = TaskResponse {
params,
opcode,
body,
};
match params.stream.cmp(&-1) {
Ordering::Less => {
// Negative stream ids below -1 are not expected; skip the frame.
continue;
}
Ordering::Equal => {
// Stream id -1: server event; forward only if someone listens.
if let Some(event_sender) = config.event_sender.as_ref() {
Self::handle_event(response, config.compression, event_sender).await?;
}
continue;
}
_ => {}
}
// try_lock is safe here: reader/writer/orphaner share one task (see
// `router`), so the mutex can never be contended.
let handler_lookup_res = {
let mut handler_map_guard = handler_map.try_lock().unwrap();
handler_map_guard.lookup(params.stream)
};
use HandlerLookupResult::*;
match handler_lookup_res {
Handler(handler) => {
// Receiver may have been dropped (request orphaned in the
// meantime); ignoring the send error is fine.
let _ = handler.response_sender.send(Ok(response));
}
Missing => {
debug!(
"Received response with unexpected StreamId {}",
params.stream
);
return Err(QueryError::ProtocolError(
"Received response with unexpected StreamId",
));
}
Orphaned => {
// Response to an abandoned request; the id was freed by lookup().
}
}
}
}
/// Registers `response_handler` in the handler map and returns the stream id
/// allocated for it. On exhaustion (all 32768 ids in flight) the handler is
/// completed with `UnableToAllocStreamId` and `None` is returned.
fn alloc_stream_id(
handler_map: &StdMutex<ResponseHandlerMap>,
response_handler: ResponseHandler,
) -> Option<i16> {
// Uncontended by construction — all users run on the router's one task.
let mut handler_map_guard = handler_map.try_lock().unwrap();
match handler_map_guard.allocate(response_handler) {
Ok(stream_id) => Some(stream_id),
Err(response_handler) => {
error!("Could not allocate stream id");
let _ = response_handler
.response_sender
.send(Err(QueryError::UnableToAllocStreamId));
None
}
}
}
/// Write loop: sends queued requests, coalescing consecutive tasks into a
/// single flush. After each write it opportunistically grabs the next task
/// (yielding once to let producers enqueue more) and only flushes when the
/// queue momentarily runs dry — batching many requests per syscall.
async fn writer(
mut write_half: (impl AsyncWrite + Unpin),
handler_map: &StdMutex<ResponseHandlerMap>,
mut task_receiver: mpsc::Receiver<Task>,
) -> Result<(), QueryError> {
while let Some(mut task) = task_receiver.recv().await {
let mut num_requests = 0;
let mut total_sent = 0;
// alloc_stream_id returning None means the handler was already
// completed with an error; drop the task and move on.
while let Some(stream_id) = Self::alloc_stream_id(handler_map, task.response_handler) {
let mut req = task.serialized_request;
req.set_stream(stream_id);
let req_data: &[u8] = req.get_data();
total_sent += req_data.len();
num_requests += 1;
write_half.write_all(req_data).await?;
task = match task_receiver.try_recv() {
Ok(t) => t,
Err(_) => {
// Give producers one scheduling slot to push more work
// before deciding to flush.
tokio::task::yield_now().await;
match task_receiver.try_recv() {
Ok(t) => t,
Err(_) => break,
}
}
}
}
trace!("Sending {} requests; {} bytes", num_requests, total_sent);
write_half.flush().await?;
}
Ok(())
}
/// Orphan bookkeeping loop: marks stream ids of abandoned requests as
/// orphaned, and periodically (every `OLD_AGE_ORPHAN_THRESHOLD`) checks how
/// many orphans have grown old. Exceeding `OLD_ORPHAN_COUNT_THRESHOLD`
/// fails the connection — too many ids stuck waiting for responses suggests
/// the server stopped answering.
async fn orphaner(
handler_map: &StdMutex<ResponseHandlerMap>,
mut orphan_receiver: mpsc::UnboundedReceiver<RequestId>,
) -> Result<(), QueryError> {
let mut interval = tokio::time::interval(OLD_AGE_ORPHAN_THRESHOLD);
loop {
tokio::select! {
_ = interval.tick() => {
// try_lock is safe: reader/writer/orphaner share one task.
let handler_map_guard = handler_map.try_lock().unwrap();
let old_orphan_count = handler_map_guard.old_orphans_count();
if old_orphan_count > OLD_ORPHAN_COUNT_THRESHOLD {
warn!(
"Too many old orphaned stream ids: {}",
old_orphan_count,
);
return Err(QueryError::TooManyOrphanedStreamIds(old_orphan_count as u16))
}
}
Some(request_id) = orphan_receiver.recv() => {
trace!(
"Trying to orphan stream id associated with request_id = {}",
request_id,
);
let mut handler_map_guard = handler_map.try_lock().unwrap(); handler_map_guard.orphan(request_id);
}
// All notification senders dropped: connection is shutting down.
else => { break }
}
}
Ok(())
}
/// Parses a frame received on stream id -1 and forwards the contained server
/// event to `event_sender`. Non-Event responses are logged and skipped.
async fn handle_event(
task_response: TaskResponse,
compression: Option<Compression>,
event_sender: &mpsc::Sender<Event>,
) -> Result<(), QueryError> {
// Default features are used for event frames — assumes events don't rely
// on negotiated protocol extensions (TODO confirm).
let features = ProtocolFeatures::default();
let response = Self::parse_response(task_response, compression, &features)?.response;
let event = match response {
Response::Event(e) => e,
_ => {
warn!("Expected to receive Event response, got {:?}", response);
return Ok(());
}
};
event_sender.send(event).await.map_err(|_| {
QueryError::IoError(Arc::new(std::io::Error::new(
ErrorKind::Other,
"Connection broken",
)))
})
}
/// Sharding information negotiated with the server, if any.
pub fn get_shard_info(&self) -> &Option<ShardInfo> {
&self.features.shard_info
}
/// Shard-aware port advertised by the server, if any.
pub fn get_shard_aware_port(&self) -> Option<u16> {
self.features.shard_aware_port
}
/// Stores the features negotiated during connection setup.
fn set_features(&mut self, features: ConnectionFeatures) {
self.features = features;
}
/// Address this connection was opened to.
pub fn get_connect_address(&self) -> SocketAddr {
self.connect_address
}
}
/// Opens a fully negotiated connection, identifying the driver by its
/// default name and the crate version baked in at compile time.
pub async fn open_connection(
    addr: SocketAddr,
    source_port: Option<u16>,
    config: ConnectionConfig,
) -> Result<(Connection, ErrorReceiver), QueryError> {
    let driver_name = Some("scylla-rust-driver".to_string());
    let driver_version = option_env!("CARGO_PKG_VERSION").map(str::to_string);
    open_named_connection(addr, source_port, config, driver_name, driver_version).await
}
/// Opens a connection and performs the full negotiation sequence:
/// OPTIONS (discover shard info, compression, protocol extensions),
/// STARTUP (with driver name/version and negotiated options),
/// optional authentication, and optional event registration.
pub async fn open_named_connection(
addr: SocketAddr,
source_port: Option<u16>,
config: ConnectionConfig,
driver_name: Option<String>,
driver_version: Option<String>,
) -> Result<(Connection, ErrorReceiver), QueryError> {
let (mut connection, error_receiver) =
Connection::new(addr, source_port, config.clone()).await?;
let options_result = connection.get_options().await?;
// The shard-aware port is advertised under a TLS-specific key when SSL is on.
let shard_aware_port_key = match config.is_ssl() {
true => "SCYLLA_SHARD_AWARE_PORT_SSL",
false => "SCYLLA_SHARD_AWARE_PORT",
};
let mut supported = match options_result {
Response::Supported(supported) => supported,
_ => {
return Err(QueryError::ProtocolError(
"Wrong response to OPTIONS message was received",
));
}
};
// Absent/unparsable shard info simply means a non-sharded (or non-Scylla) server.
let shard_info = ShardInfo::try_from(&supported.options).ok();
let supported_compression = supported.options.remove("COMPRESSION").unwrap_or_default();
let shard_aware_port = supported
.options
.remove(shard_aware_port_key)
.unwrap_or_default()
.into_iter()
.next()
.and_then(|p| p.parse::<u16>().ok());
let protocol_features = ProtocolFeatures::parse_from_supported(&supported.options);
let mut options = HashMap::new();
protocol_features.add_startup_options(&mut options);
let features = ConnectionFeatures {
shard_info,
shard_aware_port,
protocol_features,
};
connection.set_features(features);
options.insert("CQL_VERSION".to_string(), "4.0.0".to_string()); if let Some(name) = driver_name {
options.insert("DRIVER_NAME".to_string(), name);
}
if let Some(version) = driver_version {
options.insert("DRIVER_VERSION".to_string(), version);
}
// Request compression only if the server listed it as supported;
// otherwise silently fall back to no compression.
if let Some(compression) = &config.compression {
let compression_str = compression.to_string();
if supported_compression.iter().any(|c| c == &compression_str) {
options.insert("COMPRESSION".to_string(), compression.to_string());
} else {
connection.config.compression = None;
}
}
let result = connection.startup(options).await?;
match result {
Response::Ready => {}
Response::Authenticate(authenticate) => {
perform_authenticate(&mut connection, &authenticate).await?;
}
_ => {
return Err(QueryError::ProtocolError(
"Unexpected response to STARTUP message",
))
}
}
// Register for server events only when someone is listening for them.
if connection.config.event_sender.is_some() {
let all_event_types = vec![
EventType::TopologyChange,
EventType::StatusChange,
EventType::SchemaChange,
];
connection.register(all_event_types).await?;
}
Ok((connection, error_receiver))
}
/// Runs the SASL challenge/response loop after the server answered STARTUP
/// with AUTHENTICATE. Errors if no authenticator provider is configured.
async fn perform_authenticate(
connection: &mut Connection,
authenticate: &Authenticate,
) -> Result<(), QueryError> {
let authenticator = &authenticate.authenticator_name as &str;
match connection.config.authenticator {
Some(ref authenticator_provider) => {
// Initial response + session come from the provider for the
// authenticator class the server named.
let (mut response, mut auth_session) = authenticator_provider
.start_authentication_session(authenticator)
.await
.map_err(QueryError::InvalidMessage)?;
// Exchange AUTH_RESPONSE / AUTH_CHALLENGE frames until AUTH_SUCCESS.
loop {
match connection
.authenticate_response(response)
.await?.response
{
Response::AuthChallenge(challenge) => {
response = auth_session
.evaluate_challenge(
challenge.authenticate_message.as_deref(),
)
.await
.map_err(QueryError::InvalidMessage)?;
}
Response::AuthSuccess(success) => {
auth_session
.success(success.success_message.as_deref())
.await
.map_err(QueryError::InvalidMessage)?;
break;
}
Response::Error(err) => {
return Err(err.into());
}
_ => {
return Err(QueryError::ProtocolError(
"Unexpected response to Authenticate Response message",
))
}
}
}
},
None => return Err(QueryError::InvalidMessage(
"Authentication is required. You can use SessionBuilder::user(\"user\", \"pass\") to provide credentials \
or SessionBuilder::authenticator_provider to provide custom authenticator".to_string(),
)),
}
Ok(())
}
/// Opens a TCP connection to `addr`, first binding the local end of the
/// socket to `source_port` on the unspecified address of the matching IP
/// family (0.0.0.0 or ::).
///
/// Improvements: deduplicated the per-family connect logic, replaced the
/// hand-written all-zero addresses with the `UNSPECIFIED` constants, and
/// dropped the redundant `Ok(... .await?)` wrapping.
async fn connect_with_source_port(
    addr: SocketAddr,
    source_port: u16,
) -> Result<TcpStream, std::io::Error> {
    let socket = match addr {
        SocketAddr::V4(_) => {
            let socket = TcpSocket::new_v4()?;
            socket.bind(SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), source_port))?;
            socket
        }
        SocketAddr::V6(_) => {
            let socket = TcpSocket::new_v6()?;
            socket.bind(SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), source_port))?;
            socket
        }
    };
    socket.connect(addr).await
}
/// Tracks stream ids whose requests were abandoned by their callers,
/// together with the moment each became an orphan.
struct OrphanageTracker {
    // stream id -> the instant it was orphaned
    orphans: HashMap<i16, Instant>,
    // The same pairs ordered by time, so "orphans older than X" is a range count.
    by_orphaning_times: BTreeSet<(Instant, i16)>,
}
impl OrphanageTracker {
    pub fn new() -> Self {
        Self {
            orphans: HashMap::new(),
            by_orphaning_times: BTreeSet::new(),
        }
    }
    /// Marks `stream_id` as orphaned as of now.
    pub fn insert(&mut self, stream_id: i16) {
        let now = Instant::now();
        self.orphans.insert(stream_id, now);
        self.by_orphaning_times.insert((now, stream_id));
    }
    /// Un-orphans `stream_id` (e.g. because its response finally arrived).
    pub fn remove(&mut self, stream_id: i16) {
        if let Some(time) = self.orphans.remove(&stream_id) {
            self.by_orphaning_times.remove(&(time, stream_id));
        }
    }
    pub fn contains(&self, stream_id: i16) -> bool {
        self.orphans.contains_key(&stream_id)
    }
    /// Counts orphans older than `age`.
    pub fn orphans_older_than(&self, age: std::time::Duration) -> usize {
        // Fix: `Instant::now() - age` panics if the monotonic clock has been
        // running for less than `age`. In that case nothing can be old
        // enough, so report zero instead of panicking.
        match Instant::now().checked_sub(age) {
            Some(minimal_age) => self
                .by_orphaning_times
                .range(..(minimal_age, i16::MAX))
                .count(),
            None => 0,
        }
    }
}
/// Bookkeeping shared by the router halves: stream id allocation, in-flight
/// response handlers, the request-id -> stream-id mapping, and orphaned ids.
struct ResponseHandlerMap {
stream_set: StreamIdSet,
handlers: HashMap<i16, ResponseHandler>,
request_to_stream: HashMap<RequestId, i16>,
orphanage_tracker: OrphanageTracker,
}
/// Outcome of looking up the handler for a received stream id.
enum HandlerLookupResult {
// The id belonged to an abandoned request; nobody is waiting for it.
Orphaned,
// The handler registered for this id.
Handler(ResponseHandler),
// Unknown id — the server answered a stream we never used.
Missing,
}
impl ResponseHandlerMap {
pub fn new() -> Self {
Self {
stream_set: StreamIdSet::new(),
handlers: HashMap::new(),
request_to_stream: HashMap::new(),
orphanage_tracker: OrphanageTracker::new(),
}
}
/// Allocates a stream id for `response_handler` and registers it; on id
/// exhaustion the handler is handed back unchanged via `Err`.
pub fn allocate(&mut self, response_handler: ResponseHandler) -> Result<i16, ResponseHandler> {
if let Some(stream_id) = self.stream_set.allocate() {
self.request_to_stream
.insert(response_handler.request_id, stream_id);
let prev_handler = self.handlers.insert(stream_id, response_handler);
// A freshly allocated id must not already have a handler.
assert!(prev_handler.is_none());
Ok(stream_id)
} else {
Err(response_handler)
}
}
/// Moves the stream id of an abandoned request into the orphanage and
/// drops its handler. No-op if the request already completed.
pub fn orphan(&mut self, request_id: RequestId) {
if let Some(stream_id) = self.request_to_stream.get(&request_id) {
debug!(
"Orphaning stream_id = {} associated with request_id = {}",
stream_id, request_id
);
self.orphanage_tracker.insert(*stream_id);
self.handlers.remove(stream_id);
self.request_to_stream.remove(&request_id);
}
}
/// Number of orphans older than `OLD_AGE_ORPHAN_THRESHOLD`.
pub fn old_orphans_count(&self) -> usize {
self.orphanage_tracker
.orphans_older_than(OLD_AGE_ORPHAN_THRESHOLD)
}
/// Resolves the handler for a received response and frees its stream id.
pub fn lookup(&mut self, stream_id: i16) -> HandlerLookupResult {
// The server answered, so the id can be reused in any case.
self.stream_set.free(stream_id);
if self.orphanage_tracker.contains(stream_id) {
self.orphanage_tracker.remove(stream_id);
return HandlerLookupResult::Orphaned;
}
if let Some(handler) = self.handlers.remove(&stream_id) {
self.request_to_stream.remove(&handler.request_id);
HandlerLookupResult::Handler(handler)
} else {
HandlerLookupResult::Missing
}
}
/// Consumes the map, yielding the handlers of all still-pending requests
/// (used to fail them all when the connection dies).
pub fn into_handlers(self) -> HashMap<i16, ResponseHandler> {
self.handlers
}
}
/// Allocator for CQL stream ids, backed by one bit per possible
/// non-negative `i16` value (0..=32767).
struct StreamIdSet {
    used_bitmap: Box<[u64]>,
}
impl StreamIdSet {
    pub fn new() -> Self {
        // 32768 ids / 64 bits per word.
        const BITMAP_SIZE: usize = (std::i16::MAX as usize + 1) / 64;
        StreamIdSet {
            used_bitmap: vec![0u64; BITMAP_SIZE].into_boxed_slice(),
        }
    }
    /// Returns the lowest free stream id and marks it used, or `None` when
    /// every id is in flight.
    pub fn allocate(&mut self) -> Option<i16> {
        for (word_idx, word) in self.used_bitmap.iter_mut().enumerate() {
            if *word == u64::MAX {
                continue;
            }
            let bit = word.trailing_ones();
            *word |= 1u64 << bit;
            return Some(word_idx as i16 * 64 + bit as i16);
        }
        None
    }
    /// Marks `stream_id` as free again.
    pub fn free(&mut self, stream_id: i16) {
        let idx = stream_id as usize;
        self.used_bitmap[idx / 64] &= !(1u64 << (idx % 64));
    }
}
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct VerifiedKeyspaceName {
name: Arc<String>,
pub is_case_sensitive: bool,
}
impl VerifiedKeyspaceName {
    /// Validates `keyspace_name` against CQL naming rules and, on success,
    /// wraps it together with its case sensitivity.
    pub fn new(keyspace_name: String, case_sensitive: bool) -> Result<Self, BadKeyspaceName> {
        Self::verify_keyspace_name_is_valid(&keyspace_name)?;

        Ok(VerifiedKeyspaceName {
            name: Arc::new(keyspace_name),
            is_case_sensitive: case_sensitive,
        })
    }

    /// Returns the keyspace name as a string slice.
    pub fn as_str(&self) -> &str {
        self.name.as_str()
    }

    // A keyspace name must be non-empty, at most 48 characters long, and
    // consist solely of alphanumeric ASCII characters and underscores.
    fn verify_keyspace_name_is_valid(keyspace_name: &str) -> Result<(), BadKeyspaceName> {
        if keyspace_name.is_empty() {
            return Err(BadKeyspaceName::Empty);
        }

        // Length is measured in characters, not bytes.
        let keyspace_name_len = keyspace_name.chars().count();
        if keyspace_name_len > 48 {
            return Err(BadKeyspaceName::TooLong(
                keyspace_name.to_string(),
                keyspace_name_len,
            ));
        }

        // Reject the first character outside [a-zA-Z0-9_], if any.
        let illegal = keyspace_name
            .chars()
            .find(|c| !(c.is_ascii_alphanumeric() || *c == '_'));
        match illegal {
            Some(character) => Err(BadKeyspaceName::IllegalCharacter(
                keyspace_name.to_string(),
                character,
            )),
            None => Ok(()),
        }
    }
}
#[cfg(test)]
mod tests {
use scylla_cql::errors::BadQuery;
use scylla_cql::frame::protocol_features::{
LWT_OPTIMIZATION_META_BIT_MASK_KEY, SCYLLA_LWT_ADD_METADATA_MARK_EXTENSION,
};
use scylla_cql::frame::types;
use scylla_proxy::{
Condition, Node, Proxy, Reaction, RequestFrame, RequestOpcode, RequestReaction,
RequestRule, ResponseFrame,
};
use tokio::select;
use tokio::sync::mpsc;
use super::super::errors::QueryError;
use super::ConnectionConfig;
use crate::query::Query;
use crate::transport::connection::open_connection;
use crate::utils::test_utils::unique_keyspace_name;
use crate::{IntoTypedRows, SessionBuilder};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::str::FromStr;
use std::sync::Arc;
async fn resolve_hostname(hostname: &str) -> SocketAddr {
match tokio::net::lookup_host(hostname).await {
Ok(mut addrs) => addrs.next().unwrap(),
Err(_) => {
tokio::net::lookup_host((hostname, 9042)) .await
.unwrap()
.next()
.unwrap()
}
}
}
#[tokio::test]
async fn connection_query_all_execute_all_test() {
let uri = std::env::var("SCYLLA_URI").unwrap_or_else(|_| "127.0.0.1:9042".to_string());
let addr: SocketAddr = resolve_hostname(&uri).await;
let (connection, _) = super::open_connection(addr, None, ConnectionConfig::default())
.await
.unwrap();
let ks = unique_keyspace_name();
{
let session = SessionBuilder::new()
.known_node_addr(addr)
.build()
.await
.unwrap();
session.query(format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'SimpleStrategy', 'replication_factor' : 1}}", ks.clone()), &[]).await.unwrap();
session.use_keyspace(ks.clone(), false).await.unwrap();
session
.query("DROP TABLE IF EXISTS connection_query_all_tab", &[])
.await
.unwrap();
session
.query(
"CREATE TABLE IF NOT EXISTS connection_query_all_tab (p int primary key)",
&[],
)
.await
.unwrap();
}
connection
.use_keyspace(&super::VerifiedKeyspaceName::new(ks, false).unwrap())
.await
.unwrap();
let select_query = Query::new("SELECT p FROM connection_query_all_tab").with_page_size(7);
let empty_res = connection.query_all(&select_query, &[]).await.unwrap();
assert!(empty_res.rows.unwrap().is_empty());
let mut prepared_select = connection.prepare(&select_query).await.unwrap();
prepared_select.set_page_size(7);
let empty_res_prepared = connection.execute_all(&prepared_select, &[]).await.unwrap();
assert!(empty_res_prepared.rows.unwrap().is_empty());
let values: Vec<i32> = (0..100).collect();
let mut insert_futures = Vec::new();
let insert_query =
Query::new("INSERT INTO connection_query_all_tab (p) VALUES (?)").with_page_size(7);
for v in &values {
insert_futures.push(connection.query_single_page(insert_query.clone(), (v,)));
}
futures::future::try_join_all(insert_futures).await.unwrap();
let mut results: Vec<i32> = connection
.query_all(&select_query, &[])
.await
.unwrap()
.rows
.unwrap()
.into_typed::<(i32,)>()
.map(|r| r.unwrap().0)
.collect();
results.sort_unstable(); assert_eq!(results, values);
let mut results2: Vec<i32> = connection
.execute_all(&prepared_select, &[])
.await
.unwrap()
.rows
.unwrap()
.into_typed::<(i32,)>()
.map(|r| r.unwrap().0)
.collect();
results2.sort_unstable();
assert_eq!(results2, values);
let insert_res1 = connection.query_all(&insert_query, (0,)).await.unwrap();
assert!(insert_res1.rows.is_none());
let prepared_insert = connection.prepare(&insert_query).await.unwrap();
let insert_res2 = connection
.execute_all(&prepared_insert, (0,))
.await
.unwrap();
assert!(insert_res2.rows.is_none(),);
let no_page_size_query = Query::new("SELECT p FROM connection_query_all_tab");
let no_page_res = connection.query_all(&no_page_size_query, &[]).await;
assert!(matches!(
no_page_res,
Err(QueryError::BadQuery(BadQuery::Other(_)))
));
let prepared_no_page_size_query = connection.prepare(&no_page_size_query).await.unwrap();
let prepared_no_page_res = connection
.execute_all(&prepared_no_page_size_query, &[])
.await;
assert!(matches!(
prepared_no_page_res,
Err(QueryError::BadQuery(BadQuery::Other(_)))
));
}
#[tokio::test]
async fn test_lwt_optimisation_mark_negotiation() {
const MASK: &str = "2137";
let lwt_optimisation_entry = format!("{}={}", LWT_OPTIMIZATION_META_BIT_MASK_KEY, MASK);
let proxy_addr = SocketAddr::from_str("127.0.0.54:9042").unwrap();
let config = ConnectionConfig::default();
let (startup_tx, mut startup_rx) = mpsc::unbounded_channel();
let options_without_lwt_optimisation_support = HashMap::<String, Vec<String>>::new();
let options_with_lwt_optimisation_support = [(
SCYLLA_LWT_ADD_METADATA_MARK_EXTENSION.into(),
vec![lwt_optimisation_entry.clone()],
)]
.into_iter()
.collect::<HashMap<String, Vec<String>>>();
let make_rules = |options| {
vec![
RequestRule(
Condition::RequestOpcode(RequestOpcode::Options),
RequestReaction::forge_response(Arc::new(move |frame: RequestFrame| {
ResponseFrame::forged_supported(frame.params, &options).unwrap()
})),
),
RequestRule(
Condition::RequestOpcode(RequestOpcode::Startup),
RequestReaction::drop_frame().with_feedback_when_performed(startup_tx.clone()),
),
]
};
let mut proxy = Proxy::builder()
.with_node(
Node::builder()
.proxy_address(proxy_addr)
.request_rules(make_rules(options_without_lwt_optimisation_support))
.build_dry_mode(),
)
.build()
.run()
.await
.unwrap();
let startup_without_lwt_optimisation = select! {
_ = open_connection(proxy_addr, None, config.clone()) => unreachable!(),
startup = startup_rx.recv() => startup.unwrap(),
};
proxy.running_nodes[0]
.change_request_rules(Some(make_rules(options_with_lwt_optimisation_support)));
let startup_with_lwt_optimisation = select! {
_ = open_connection(proxy_addr, None, config.clone()) => unreachable!(),
startup = startup_rx.recv() => startup.unwrap(),
};
let _ = proxy.finish().await;
let chosen_options =
types::read_string_map(&mut &*startup_without_lwt_optimisation.body).unwrap();
assert!(!chosen_options.contains_key(SCYLLA_LWT_ADD_METADATA_MARK_EXTENSION));
let chosen_options =
types::read_string_map(&mut &startup_with_lwt_optimisation.body[..]).unwrap();
assert!(chosen_options.contains_key(SCYLLA_LWT_ADD_METADATA_MARK_EXTENSION));
assert_eq!(
chosen_options
.get(SCYLLA_LWT_ADD_METADATA_MARK_EXTENSION)
.unwrap(),
&lwt_optimisation_entry
)
}
}