use bytes::Bytes;
use futures::{future::RemoteHandle, FutureExt};
use tokio::io::{split, AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::net::{TcpSocket, TcpStream};
use tokio::sync::{mpsc, oneshot};
use tracing::{error, warn};
use uuid::Uuid;
#[cfg(feature = "ssl")]
use openssl::ssl::{Ssl, SslContext};
#[cfg(feature = "ssl")]
use std::pin::Pin;
#[cfg(feature = "ssl")]
use tokio_openssl::SslStream;
use std::collections::HashMap;
use std::convert::TryFrom;
use std::io::ErrorKind;
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use std::{
cmp::Ordering,
net::{Ipv4Addr, Ipv6Addr},
};
use super::errors::{BadKeyspaceName, BadQuery, DbError, QueryError};
use crate::batch::{Batch, BatchStatement};
use crate::frame::{
self,
request::{self, batch, execute, query, register, Request},
response::{event::Event, result, result::ColumnSpec, Response, ResponseOpcode},
server_event_type::EventType,
value::{BatchValues, ValueList},
FrameParams, SerializedRequest,
};
use crate::query::Query;
use crate::routing::ShardInfo;
use crate::statement::prepared_statement::PreparedStatement;
use crate::statement::Consistency;
use crate::transport::session::IntoTypedRows;
use crate::transport::Authenticator;
use crate::transport::Authenticator::{
AllowAllAuthenticator, CassandraAllowAllAuthenticator, CassandraPasswordAuthenticator,
PasswordAuthenticator, ScyllaTransitionalAuthenticator,
};
use crate::transport::Compression;
/// CQL query reading the schema version reported by the node this connection
/// talks to; used by `Connection::fetch_schema_version`.
const LOCAL_VERSION: &str = "SELECT schema_version FROM system.local WHERE key='local'";
/// A single CQL connection to a node.
///
/// Callers serialize requests and submit them over `submit_channel` to a
/// background router task, which assigns stream ids, writes the frames to the
/// socket and routes each response back to the submitting caller.
pub struct Connection {
/// Sending side of the request queue consumed by the router's writer.
submit_channel: mpsc::Sender<Task>,
/// Handle to the spawned router task, held so that the task lives as long as
/// the connection (dropping a futures `RemoteHandle` cancels the remote future).
_worker_handle: RemoteHandle<()>,
/// Address this connection was opened to.
connect_address: SocketAddr,
/// Local port of the underlying TCP socket.
source_port: u16,
/// Sharding information parsed from the server's SUPPORTED response, if any.
shard_info: Option<ShardInfo>,
config: ConnectionConfig,
/// Shard-aware port advertised by the server, if any.
shard_aware_port: Option<u16>,
}
/// One-shot channel on which the router delivers the outcome of a single
/// request back to the caller waiting in `send_request`.
type ResponseHandler = oneshot::Sender<Result<TaskResponse, QueryError>>;
/// A serialized request submitted to the router, paired with the channel on
/// which its response (or an error) is delivered.
struct Task {
/// Request frame; the writer assigns its stream id just before sending.
serialized_request: SerializedRequest,
response_handler: ResponseHandler,
}
/// A raw response frame as read from the socket, before body decoding.
struct TaskResponse {
/// Frame header parameters (flags, stream id, ...).
params: FrameParams,
opcode: ResponseOpcode,
/// Undecoded body; may still be compressed and carry extensions (warnings,
/// tracing id), which `Connection::parse_response` strips.
body: Bytes,
}
/// A parsed server response together with the frame extensions that came
/// with it.
pub struct QueryResponse {
pub response: Response,
/// Tracing session id, if the response frame carried one.
pub tracing_id: Option<Uuid>,
/// Warnings sent by the server along with this response.
pub warnings: Vec<String>,
}
/// Result of a query, possibly accumulated over several pages.
#[derive(Default, Debug)]
pub struct QueryResult {
    /// Returned rows; `None` when the server answered with a non-Rows result.
    pub rows: Option<Vec<result::Row>>,
    /// Server warnings (accumulated across pages when merged).
    pub warnings: Vec<String>,
    /// Tracing id of the (most recent) request, if any.
    pub tracing_id: Option<Uuid>,
    /// Paging state for fetching the next page; `None` when no more pages.
    pub paging_state: Option<Bytes>,
    /// Column specifications from the result metadata.
    pub col_specs: Vec<ColumnSpec>,
}

impl QueryResult {
    /// Finds the column called `name`, returning its index and specification.
    pub fn get_column_spec<'a>(&'a self, name: &str) -> Option<(usize, &'a ColumnSpec)> {
        for (index, spec) in self.col_specs.iter().enumerate() {
            if spec.name == name {
                return Some((index, spec));
            }
        }
        None
    }

    /// Folds the result of the following page into `self`.
    ///
    /// Rows and warnings are appended; tracing id, paging state and column
    /// specifications are replaced by those of `other`.
    pub(crate) fn merge_with_next_page_res(&mut self, other: QueryResult) {
        let QueryResult {
            rows: next_rows,
            warnings: next_warnings,
            tracing_id: next_tracing_id,
            paging_state: next_paging_state,
            col_specs: next_col_specs,
        } = other;

        if let Some(page_rows) = next_rows {
            self.rows.get_or_insert_with(Vec::new).extend(page_rows);
        }
        self.warnings.extend(next_warnings);
        self.tracing_id = next_tracing_id;
        self.paging_state = next_paging_state;
        self.col_specs = next_col_specs;
    }
}
/// Outcome of a successfully executed BATCH request.
pub struct BatchResult {
/// Warnings sent by the server with the response.
pub warnings: Vec<String>,
/// Tracing id, if the response frame carried one.
pub tracing_id: Option<Uuid>,
}
impl QueryResponse {
pub fn as_set_keyspace(&self) -> Option<&result::SetKeyspace> {
match &self.response {
Response::Result(result::Result::SetKeyspace(sk)) => Some(sk),
_ => None,
}
}
pub fn into_query_result(self) -> Result<QueryResult, QueryError> {
let (rows, paging_state, col_specs) = match self.response {
Response::Error(err) => return Err(err.into()),
Response::Result(result::Result::Rows(rs)) => (
Some(rs.rows),
rs.metadata.paging_state,
rs.metadata.col_specs,
),
Response::Result(_) => (None, None, vec![]),
_ => {
return Err(QueryError::ProtocolError(
"Unexpected server response, expected Result or Error",
))
}
};
Ok(QueryResult {
rows,
warnings: self.warnings,
tracing_id: self.tracing_id,
paging_state,
col_specs,
})
}
}
/// Settings used when opening and running a `Connection`.
#[derive(Clone)]
pub struct ConnectionConfig {
/// Compression to request. If the server does not list it as supported
/// during the handshake, the connection falls back to no compression.
pub compression: Option<Compression>,
/// Value applied with `set_nodelay` on the TCP socket.
pub tcp_nodelay: bool,
/// When set, the connection is wrapped in TLS using this context.
#[cfg(feature = "ssl")]
pub ssl_context: Option<SslContext>,
/// Credentials sent if the server requests authentication.
pub auth_username: Option<String>,
pub auth_password: Option<String>,
/// Upper bound for establishing the TCP connection.
pub connect_timeout: std::time::Duration,
/// When set, the connection registers for topology, status and schema change
/// events and forwards received events here.
pub event_sender: Option<mpsc::Sender<Event>>,
/// Consistency passed to each statement's `determine_consistency`;
/// presumably used when the statement sets none (TODO confirm).
pub default_consistency: Consistency,
}
impl Default for ConnectionConfig {
fn default() -> Self {
Self {
compression: None,
tcp_nodelay: true,
event_sender: None,
#[cfg(feature = "ssl")]
ssl_context: None,
auth_username: None,
auth_password: None,
connect_timeout: std::time::Duration::from_secs(5),
default_consistency: Default::default(),
}
}
}
impl ConnectionConfig {
#[cfg(feature = "ssl")]
pub fn is_ssl(&self) -> bool {
self.ssl_context.is_some()
}
#[cfg(not(feature = "ssl"))]
pub fn is_ssl(&self) -> bool {
false
}
}
/// Receives the error that terminated a connection's router task. If the
/// router finishes without an error, the sender is dropped without sending.
pub type ErrorReceiver = tokio::sync::oneshot::Receiver<QueryError>;
impl Connection {
    /// Establishes a TCP connection to `addr` (wrapped in TLS when configured)
    /// and spawns the background router task that multiplexes requests over it.
    ///
    /// * `source_port` - when `Some`, the local socket is bound to this port
    ///   before connecting; otherwise the OS chooses one.
    /// * `config` - `connect_timeout` bounds the TCP connect and `tcp_nodelay`
    ///   is applied to the socket.
    ///
    /// No CQL handshake (OPTIONS/STARTUP) is performed here.
    ///
    /// On success, also returns a receiver that is sent the error which
    /// terminated the router task.
    ///
    /// # Errors
    /// `QueryError::TimeoutError` if connecting exceeds `config.connect_timeout`.
    /// Otherwise an I/O error if connecting, configuring the socket or setting up
    /// TLS fails.
    pub async fn new(
        addr: SocketAddr,
        source_port: Option<u16>,
        config: ConnectionConfig,
    ) -> Result<(Self, ErrorReceiver), QueryError> {
        let stream_connector = match source_port {
            Some(p) => {
                tokio::time::timeout(config.connect_timeout, connect_with_source_port(addr, p))
                    .await
            }
            None => tokio::time::timeout(config.connect_timeout, TcpStream::connect(addr)).await,
        };
        let stream = match stream_connector {
            Ok(stream) => stream?,
            Err(_) => {
                return Err(QueryError::TimeoutError);
            }
        };
        // The port actually in use (relevant when the OS picked it).
        let source_port = stream.local_addr()?.port();
        stream.set_nodelay(config.tcp_nodelay)?;

        let (sender, receiver) = mpsc::channel(128);
        let (error_sender, error_receiver) = tokio::sync::oneshot::channel();

        let _worker_handle =
            Self::run_router(config.clone(), stream, receiver, error_sender).await?;

        let connection = Connection {
            submit_channel: sender,
            _worker_handle,
            source_port,
            connect_address: addr,
            shard_info: None,
            config,
            shard_aware_port: None,
        };

        Ok((connection, error_receiver))
    }

    /// Sends a STARTUP request with `options` and returns the raw server
    /// response. The request is sent uncompressed and untraced.
    pub async fn startup(&self, options: HashMap<String, String>) -> Result<Response, QueryError> {
        Ok(self
            .send_request(&request::Startup { options }, false, false)
            .await?
            .response)
    }

    /// Sends an OPTIONS request and returns the raw server response.
    pub async fn get_options(&self) -> Result<Response, QueryError> {
        Ok(self
            .send_request(&request::Options {}, false, false)
            .await?
            .response)
    }

    /// Prepares `query` on the server.
    ///
    /// The returned statement holds the server-assigned id and metadata, the
    /// query text and its page size. If the response carried a tracing id, that
    /// id is pushed to `prepare_tracing_ids`.
    ///
    /// # Errors
    /// Returns the server's error, or `QueryError::ProtocolError` if the
    /// response is neither `Prepared` nor an error.
    pub async fn prepare(&self, query: &Query) -> Result<PreparedStatement, QueryError> {
        let query_response = self
            .send_request(
                &request::Prepare {
                    query: &query.contents,
                },
                true,
                query.config.tracing,
            )
            .await?;

        let mut prepared_statement = match query_response.response {
            Response::Error(err) => return Err(err.into()),
            Response::Result(result::Result::Prepared(p)) => PreparedStatement::new(
                p.id,
                p.prepared_metadata,
                query.contents.clone(),
                query.get_page_size(),
            ),
            _ => {
                return Err(QueryError::ProtocolError(
                    "PREPARE: Unexpected server response",
                ))
            }
        };

        if let Some(tracing_id) = query_response.tracing_id {
            prepared_statement.prepare_tracing_ids.push(tracing_id);
        }
        Ok(prepared_statement)
    }

    /// Sends an AUTH_RESPONSE carrying the given credentials for
    /// `authenticator` and returns the parsed server response.
    pub async fn authenticate_response(
        &self,
        username: Option<String>,
        password: Option<String>,
        authenticator: Authenticator,
    ) -> Result<QueryResponse, QueryError> {
        self.send_request(
            &request::AuthResponse {
                username,
                password,
                authenticator,
            },
            false,
            false,
        )
        .await
    }

    /// Runs an unprepared query with `values` and returns only its first page.
    /// No paging state is sent.
    pub async fn query_single_page(
        &self,
        query: impl Into<Query>,
        values: impl ValueList,
    ) -> Result<QueryResult, QueryError> {
        let query: Query = query.into();
        self.query(&query, &values, None).await?.into_query_result()
    }

    /// Like `query_single_page`, but borrows the query and values and resumes
    /// from `paging_state` when given.
    pub async fn query_single_page_by_ref(
        &self,
        query: &Query,
        values: &impl ValueList,
        paging_state: Option<Bytes>,
    ) -> Result<QueryResult, QueryError> {
        self.query(query, values, paging_state)
            .await?
            .into_query_result()
    }

    /// Sends a QUERY request and returns the parsed response without
    /// interpreting it.
    ///
    /// Consistency comes from `query.config.determine_consistency`, given the
    /// connection's default consistency. The request is compressed when
    /// compression is configured, and traced when `query.config.tracing` is set.
    pub async fn query(
        &self,
        query: &Query,
        values: impl ValueList,
        paging_state: Option<Bytes>,
    ) -> Result<QueryResponse, QueryError> {
        let serialized_values = values.serialized()?;

        let query_frame = query::Query {
            contents: &query.contents,
            parameters: query::QueryParameters {
                consistency: query
                    .config
                    .determine_consistency(self.config.default_consistency),
                serial_consistency: query.get_serial_consistency(),
                values: &serialized_values,
                page_size: query.get_page_size(),
                paging_state,
                timestamp: query.get_timestamp(),
            },
        };

        self.send_request(&query_frame, true, query.config.tracing)
            .await
    }

    /// Fetches every page of `query` and merges them into a single result.
    ///
    /// Each page is requested with the paging state returned by the previous one.
    ///
    /// # Errors
    /// `QueryError::ProtocolError` if `query` has no page size set. Otherwise
    /// the first error returned for any page.
    pub async fn query_all(
        &self,
        query: &Query,
        values: impl ValueList,
    ) -> Result<QueryResult, QueryError> {
        if query.get_page_size().is_none() {
            // Without a page size the server does not page, so looping on
            // paging state would be meaningless.
            return Err(QueryError::ProtocolError(
                "Called Connection::query_all without page size set!",
            ));
        }

        let mut final_result = QueryResult::default();

        // Serialize once and reuse for every page.
        let serialized_values = values.serialized()?;
        let mut paging_state: Option<Bytes> = None;

        loop {
            let mut cur_result: QueryResult = self
                .query(query, &serialized_values, paging_state)
                .await?
                .into_query_result()?;

            paging_state = cur_result.paging_state.take();

            final_result.merge_with_next_page_res(cur_result);

            if paging_state.is_none() {
                return Ok(final_result);
            }
        }
    }

    /// Executes `prepared_statement` and returns a single page of its result,
    /// resuming from `paging_state` when given.
    pub async fn execute_single_page(
        &self,
        prepared_statement: &PreparedStatement,
        values: impl ValueList,
        paging_state: Option<Bytes>,
    ) -> Result<QueryResult, QueryError> {
        self.execute(prepared_statement, values, paging_state)
            .await?
            .into_query_result()
    }

    /// Sends an EXECUTE request for `prepared_statement` and returns the parsed
    /// response.
    ///
    /// If the server answers with `DbError::Unprepared`, the statement is
    /// re-prepared on this connection and the request is retried once.
    ///
    /// # Errors
    /// `QueryError::ProtocolError` if re-preparing yields a different statement
    /// id. Otherwise any error from serialization, re-preparation or transport.
    pub async fn execute(
        &self,
        prepared_statement: &PreparedStatement,
        values: impl ValueList,
        paging_state: Option<Bytes>,
    ) -> Result<QueryResponse, QueryError> {
        let serialized_values = values.serialized()?;

        let execute_frame = execute::Execute {
            id: prepared_statement.get_id().to_owned(),
            parameters: query::QueryParameters {
                consistency: prepared_statement
                    .config
                    .determine_consistency(self.config.default_consistency),
                serial_consistency: prepared_statement.get_serial_consistency(),
                values: &serialized_values,
                page_size: prepared_statement.get_page_size(),
                timestamp: prepared_statement.get_timestamp(),
                paging_state,
            },
        };

        let query_response = self
            .send_request(&execute_frame, true, prepared_statement.config.tracing)
            .await?;

        if let Response::Error(err) = &query_response.response {
            if err.error == DbError::Unprepared {
                // This node does not know the statement: prepare it here and retry.
                let reprepare_query: Query = prepared_statement.get_statement().into();
                let reprepared = self.prepare(&reprepare_query).await?;
                // The id is derived from the statement text, so it must not change.
                if reprepared.get_id() != prepared_statement.get_id() {
                    return Err(QueryError::ProtocolError(
                        "Prepared statement Id changed, md5 sum should stay the same",
                    ));
                }

                return self
                    .send_request(&execute_frame, true, prepared_statement.config.tracing)
                    .await;
            }
        }

        Ok(query_response)
    }

    /// Fetches every page of `prepared_statement` and merges them into a
    /// single result.
    ///
    /// # Errors
    /// `QueryError::ProtocolError` if the statement has no page size set.
    /// Otherwise the first error returned for any page.
    pub async fn execute_all(
        &self,
        prepared_statement: &PreparedStatement,
        values: impl ValueList,
    ) -> Result<QueryResult, QueryError> {
        if prepared_statement.get_page_size().is_none() {
            return Err(QueryError::ProtocolError(
                "Called Connection::execute_all without page size set!",
            ));
        }

        let mut final_result = QueryResult::default();

        // Serialize once and reuse for every page.
        let serialized_values = values.serialized()?;
        let mut paging_state: Option<Bytes> = None;

        loop {
            let mut cur_result: QueryResult = self
                .execute_single_page(prepared_statement, &serialized_values, paging_state)
                .await?;

            paging_state = cur_result.paging_state.take();

            final_result.merge_with_next_page_res(cur_result);

            if paging_state.is_none() {
                return Ok(final_result);
            }
        }
    }

    /// Sends a BATCH request built from `batch` and `values`.
    ///
    /// # Errors
    /// `BadQuery::ValueLenMismatch` when the number of value sets differs from
    /// the number of statements. Returns the server's error on an Error
    /// response, and `QueryError::ProtocolError` on any other non-Result response.
    pub async fn batch(
        &self,
        batch: &Batch,
        values: impl BatchValues,
    ) -> Result<BatchResult, QueryError> {
        let statements_count = batch.statements.len();
        if statements_count != values.len() {
            return Err(QueryError::BadQuery(BadQuery::ValueLenMismatch(
                values.len(),
                statements_count,
            )));
        }

        let statements_iter = batch.statements.iter().map(|s| match s {
            BatchStatement::Query(q) => batch::BatchStatement::Query { text: &q.contents },
            BatchStatement::PreparedStatement(s) => {
                batch::BatchStatement::Prepared { id: s.get_id() }
            }
        });

        let batch_frame = batch::Batch {
            statements: statements_iter,
            statements_count,
            values,
            batch_type: batch.get_type(),
            consistency: batch
                .config
                .determine_consistency(self.config.default_consistency),
            serial_consistency: batch.get_serial_consistency(),
            timestamp: batch.get_timestamp(),
        };

        let query_response = self
            .send_request(&batch_frame, true, batch.config.tracing)
            .await?;

        match query_response.response {
            Response::Error(err) => Err(err.into()),
            Response::Result(_) => Ok(BatchResult {
                warnings: query_response.warnings,
                tracing_id: query_response.tracing_id,
            }),
            _ => Err(QueryError::ProtocolError(
                "BATCH: Unexpected server response",
            )),
        }
    }

    /// Issues `USE <keyspace>` on this connection.
    ///
    /// The name is double-quoted when it is marked case-sensitive. The server's
    /// confirmation is checked against the requested name, ignoring case.
    ///
    /// # Errors
    /// Returns the server's error, or `QueryError::ProtocolError` on an
    /// unexpected response or a mismatched keyspace name.
    pub async fn use_keyspace(
        &self,
        keyspace_name: &VerifiedKeyspaceName,
    ) -> Result<(), QueryError> {
        // Quoting preserves case; VerifiedKeyspaceName only admits [A-Za-z0-9_],
        // so interpolating it into the statement is safe.
        let query: Query = match keyspace_name.is_case_sensitive {
            true => format!("USE \"{}\"", keyspace_name.as_str()).into(),
            false => format!("USE {}", keyspace_name.as_str()).into(),
        };

        let query_response = self.query(&query, (), None).await?;

        match query_response.response {
            Response::Result(result::Result::SetKeyspace(set_keyspace)) => {
                if set_keyspace.keyspace_name.to_lowercase()
                    != keyspace_name.as_str().to_lowercase()
                {
                    return Err(QueryError::ProtocolError(
                        "USE <keyspace_name> returned response with different keyspace name",
                    ));
                }

                Ok(())
            }
            Response::Error(err) => Err(err.into()),
            _ => Err(QueryError::ProtocolError(
                "USE <keyspace_name> returned unexpected response",
            )),
        }
    }

    /// Registers this connection for the given server event types. The server
    /// must answer with READY.
    async fn register(
        &self,
        event_types_to_register_for: Vec<EventType>,
    ) -> Result<(), QueryError> {
        let register_frame = register::Register {
            event_types_to_register_for,
        };

        match self
            .send_request(&register_frame, true, false)
            .await?
            .response
        {
            Response::Ready => Ok(()),
            Response::Error(err) => Err(err.into()),
            _ => Err(QueryError::ProtocolError(
                "Unexpected response to REGISTER message",
            )),
        }
    }

    /// Reads `schema_version` from `system.local` on the connected node.
    ///
    /// # Errors
    /// `QueryError::ProtocolError` if the query returns no rows or a value that
    /// is not a UUID. Query errors are propagated.
    pub async fn fetch_schema_version(&self) -> Result<Uuid, QueryError> {
        let (version_id,): (Uuid,) = self
            .query_single_page(LOCAL_VERSION, &[])
            .await?
            .rows
            .ok_or(QueryError::ProtocolError("Version query returned no rows"))?
            .into_typed::<(Uuid,)>()
            .next()
            .ok_or(QueryError::ProtocolError("Admin table returned empty rows"))?
            .map_err(|_| QueryError::ProtocolError("Row is not uuid type as it should be"))?;
        Ok(version_id)
    }

    /// Serializes `request`, hands it to the router task and waits for the
    /// parsed response.
    ///
    /// * `compress` - use the configured compression for this request.
    /// * `tracing` - request server-side tracing.
    ///
    /// # Errors
    /// Serialization and parse errors are propagated. A "Connection broken" I/O
    /// error is returned if the router is gone or dropped this request's handler.
    /// Any error the router reports for this request is also propagated.
    async fn send_request<R: Request>(
        &self,
        request: &R,
        compress: bool,
        tracing: bool,
    ) -> Result<QueryResponse, QueryError> {
        let compression = if compress {
            self.config.compression
        } else {
            None
        };

        let serialized_request = SerializedRequest::make(request, compression, tracing)?;

        let (sender, receiver) = oneshot::channel();

        self.submit_channel
            .send(Task {
                serialized_request,
                response_handler: sender,
            })
            .await
            .map_err(|_| {
                QueryError::IoError(Arc::new(std::io::Error::new(
                    ErrorKind::Other,
                    "Connection broken",
                )))
            })?;

        let task_response = receiver.await.map_err(|_| {
            QueryError::IoError(Arc::new(std::io::Error::new(
                ErrorKind::Other,
                "Connection broken",
            )))
        })??;

        Self::parse_response(task_response, self.config.compression)
    }

    /// Decodes a raw response frame.
    ///
    /// The body's extensions (compression, warnings, tracing id) are handled
    /// first; each warning is logged. The remaining body is then deserialized
    /// according to the frame's opcode.
    fn parse_response(
        task_response: TaskResponse,
        compression: Option<Compression>,
    ) -> Result<QueryResponse, QueryError> {
        let body_with_ext = frame::parse_response_body_extensions(
            task_response.params.flags,
            compression,
            task_response.body,
        )?;

        for warn_description in &body_with_ext.warnings {
            warn!(warning = warn_description.as_str());
        }

        let response = Response::deserialize(task_response.opcode, &mut &*body_with_ext.body)?;

        Ok(QueryResponse {
            response,
            warnings: body_with_ext.warnings,
            tracing_id: body_with_ext.trace_id,
        })
    }

    /// Spawns the router on `stream`. When `config.ssl_context` is set, the
    /// stream is first wrapped in TLS and the client handshake is completed.
    ///
    /// # Errors
    /// An I/O error if creating the TLS session or performing the handshake fails.
    #[cfg(feature = "ssl")]
    async fn run_router(
        config: ConnectionConfig,
        stream: TcpStream,
        receiver: mpsc::Receiver<Task>,
        error_sender: tokio::sync::oneshot::Sender<QueryError>,
    ) -> Result<RemoteHandle<()>, std::io::Error> {
        let res = match config.ssl_context {
            Some(ref context) => {
                let ssl = Ssl::new(context)?;
                let mut stream = SslStream::new(ssl, stream)?;
                // Fail fast on a failed handshake instead of handing a broken
                // TLS stream to the router.
                Pin::new(&mut stream)
                    .connect()
                    .await
                    .map_err(|e| std::io::Error::new(ErrorKind::Other, e))?;
                Self::run_router_spawner(stream, receiver, error_sender, config)
            }
            None => Self::run_router_spawner(stream, receiver, error_sender, config),
        };
        Ok(res)
    }

    /// Spawns the router directly on the plain TCP stream.
    #[cfg(not(feature = "ssl"))]
    async fn run_router(
        config: ConnectionConfig,
        stream: TcpStream,
        receiver: mpsc::Receiver<Task>,
        error_sender: tokio::sync::oneshot::Sender<QueryError>,
    ) -> Result<RemoteHandle<()>, std::io::Error> {
        Ok(Self::run_router_spawner(
            stream,
            receiver,
            error_sender,
            config,
        ))
    }

    /// Spawns `router` on the tokio runtime and returns its `RemoteHandle`.
    /// Dropping the handle cancels the router.
    fn run_router_spawner(
        stream: (impl AsyncRead + AsyncWrite + Send + 'static),
        receiver: mpsc::Receiver<Task>,
        error_sender: tokio::sync::oneshot::Sender<QueryError>,
        config: ConnectionConfig,
    ) -> RemoteHandle<()> {
        let (task, handle) = Self::router(stream, receiver, error_sender, config).remote_handle();
        tokio::task::spawn(task);
        handle
    }

    /// Drives the reader and writer over a single stream concurrently.
    ///
    /// If either half fails, every pending response handler receives a clone of
    /// the error, and the error is then sent on `error_sender`. If both halves
    /// finish successfully, nothing is sent.
    async fn router(
        stream: (impl AsyncRead + AsyncWrite),
        receiver: mpsc::Receiver<Task>,
        error_sender: tokio::sync::oneshot::Sender<QueryError>,
        config: ConnectionConfig,
    ) {
        let (read_half, write_half) = split(stream);
        // Reader and writer run in the same task and never hold the lock across
        // an `.await`, so a std mutex (with `try_lock`) suffices.
        let handler_map = StdMutex::new(ResponseHandlerMap::new());

        let r = Self::reader(read_half, &handler_map, config);
        let w = Self::writer(write_half, &handler_map, receiver);

        let result = futures::try_join!(r, w);

        let error: QueryError = match result {
            Ok(_) => return,
            Err(err) => err,
        };

        // Fail every request still waiting for a response.
        let response_handlers: HashMap<i16, ResponseHandler> =
            handler_map.into_inner().unwrap().into_handlers();
        for (_, handler) in response_handlers {
            // The caller may have given up already; ignore send failures.
            let _ = handler.send(Err(error.clone()));
        }

        let _ = error_sender.send(error);
    }

    /// Reads response frames and dispatches them by stream id.
    ///
    /// * Stream id `-1`: a server event, forwarded via `config.event_sender`
    ///   when one is configured.
    /// * Other negative ids: ignored.
    /// * Non-negative ids: delivered to the handler registered for that id.
    ///
    /// # Errors
    /// Read and event-handling errors are returned. A response with an unknown
    /// stream id yields `QueryError::ProtocolError`, which terminates the
    /// connection.
    async fn reader(
        mut read_half: (impl AsyncRead + Unpin),
        handler_map: &StdMutex<ResponseHandlerMap>,
        config: ConnectionConfig,
    ) -> Result<(), QueryError> {
        loop {
            let (params, opcode, body) = frame::read_response_frame(&mut read_half).await?;
            let response = TaskResponse {
                params,
                opcode,
                body,
            };

            match params.stream.cmp(&-1) {
                Ordering::Less => {
                    continue;
                }
                Ordering::Equal => {
                    if let Some(event_sender) = config.event_sender.as_ref() {
                        Self::handle_event(response, config.compression, event_sender).await?;
                    }
                    continue;
                }
                _ => {}
            }

            let handler = {
                // Cannot be contended: see `router`.
                let mut lock = handler_map.try_lock().unwrap();
                lock.take(params.stream)
            };

            if let Some(handler) = handler {
                // The caller may have stopped waiting; ignore send failures.
                let _ = handler.send(Ok(response));
            } else {
                return Err(QueryError::ProtocolError(
                    "Received response with unexpected StreamId",
                ));
            }
        }
    }

    /// Takes submitted tasks, assigns each a free stream id and writes it to
    /// the socket.
    ///
    /// If no stream id is free, the task is logged and dropped. Its response
    /// handler is dropped with it, so the waiting caller sees a
    /// "Connection broken" error from `send_request`.
    ///
    /// Returns `Ok(())` once the task channel is closed.
    async fn writer(
        mut write_half: (impl AsyncWrite + Unpin),
        handler_map: &StdMutex<ResponseHandlerMap>,
        mut task_receiver: mpsc::Receiver<Task>,
    ) -> Result<(), QueryError> {
        while let Some(task) = task_receiver.recv().await {
            let stream_id = {
                // Cannot be contended: see `router`.
                let mut lock = handler_map.try_lock().unwrap();

                if let Some(stream_id) = lock.allocate(task.response_handler) {
                    stream_id
                } else {
                    error!("Could not allocate stream id");
                    continue;
                }
            };

            let mut req = task.serialized_request;
            req.set_stream(stream_id);
            write_half.write_all(req.get_data()).await?;
        }

        Ok(())
    }

    /// Parses an event frame and forwards the event to `event_sender`.
    /// Non-event responses are logged and ignored.
    ///
    /// # Errors
    /// Parse errors are returned. A "Connection broken" I/O error is returned if
    /// the event receiver has been dropped, which ends the reader.
    async fn handle_event(
        task_response: TaskResponse,
        compression: Option<Compression>,
        event_sender: &mpsc::Sender<Event>,
    ) -> Result<(), QueryError> {
        let response = Self::parse_response(task_response, compression)?.response;
        let event = match response {
            Response::Event(e) => e,
            _ => {
                warn!("Expected to receive Event response, got {:?}", response);
                return Ok(());
            }
        };

        event_sender.send(event).await.map_err(|_| {
            QueryError::IoError(Arc::new(std::io::Error::new(
                ErrorKind::Other,
                "Connection broken",
            )))
        })
    }

    /// Sharding information reported by the server, if any.
    pub fn get_shard_info(&self) -> &Option<ShardInfo> {
        &self.shard_info
    }

    /// `true` when this connection was made to the server's advertised
    /// shard-aware port.
    pub fn get_is_shard_aware(&self) -> bool {
        Some(self.connect_address.port()) == self.shard_aware_port
    }

    /// Shard-aware port advertised by the server, if any.
    pub fn get_shard_aware_port(&self) -> Option<u16> {
        self.shard_aware_port
    }

    /// Local port of the underlying socket.
    pub fn get_source_port(&self) -> u16 {
        self.source_port
    }

    fn set_shard_info(&mut self, shard_info: Option<ShardInfo>) {
        self.shard_info = shard_info
    }

    fn set_shard_aware_port(&mut self, shard_aware_port: Option<u16>) {
        self.shard_aware_port = shard_aware_port;
    }

    /// Address this connection was opened to.
    pub fn get_connect_address(&self) -> SocketAddr {
        self.connect_address
    }
}
/// Opens a fully initialized connection, identifying the client as
/// `scylla-rust-driver` in the STARTUP options.
///
/// See `open_named_connection` for the handshake details and errors.
pub async fn open_connection(
    addr: SocketAddr,
    source_port: Option<u16>,
    config: ConnectionConfig,
) -> Result<(Connection, ErrorReceiver), QueryError> {
    let driver_name = String::from("scylla-rust-driver");
    open_named_connection(addr, source_port, config, Some(driver_name)).await
}
pub async fn open_named_connection(
addr: SocketAddr,
source_port: Option<u16>,
config: ConnectionConfig,
driver_name: Option<String>,
) -> Result<(Connection, ErrorReceiver), QueryError> {
let (mut connection, error_receiver) =
Connection::new(addr, source_port, config.clone()).await?;
let options_result = connection.get_options().await?;
let shard_aware_port_key = match config.is_ssl() {
true => "SCYLLA_SHARD_AWARE_PORT_SSL",
false => "SCYLLA_SHARD_AWARE_PORT",
};
let (shard_info, supported_compression, shard_aware_port) = match options_result {
Response::Supported(mut supported) => {
let shard_info = ShardInfo::try_from(&supported.options).ok();
let supported_compression = supported
.options
.remove("COMPRESSION")
.unwrap_or_else(Vec::new);
let shard_aware_port = supported
.options
.remove(shard_aware_port_key)
.unwrap_or_else(Vec::new)
.into_iter()
.next()
.and_then(|p| p.parse::<u16>().ok());
(shard_info, supported_compression, shard_aware_port)
}
_ => (None, Vec::new(), None),
};
connection.set_shard_info(shard_info);
connection.set_shard_aware_port(shard_aware_port);
let mut options = HashMap::new();
options.insert("CQL_VERSION".to_string(), "4.0.0".to_string()); if let Some(name) = driver_name {
options.insert("DRIVER_NAME".to_string(), name);
}
if let Some(compression) = &config.compression {
let compression_str = compression.to_string();
if supported_compression.iter().any(|c| c == &compression_str) {
options.insert("COMPRESSION".to_string(), compression.to_string());
} else {
connection.config.compression = None;
}
}
let result = connection.startup(options).await?;
match result {
Response::Ready => {}
Response::Authenticate(authenticate) => {
let authenticator: Authenticator = match &authenticate.authenticator_name as &str {
"AllowAllAuthenticator" => AllowAllAuthenticator,
"PasswordAuthenticator" => PasswordAuthenticator,
"org.apache.cassandra.auth.PasswordAuthenticator" => CassandraPasswordAuthenticator,
"org.apache.cassandra.auth.AllowAllAuthenticator" => CassandraAllowAllAuthenticator,
"com.scylladb.auth.TransitionalAuthenticator" => ScyllaTransitionalAuthenticator,
_ => unimplemented!(
"Authenticator not supported, {}",
authenticate.authenticator_name
),
};
let username = connection.config.auth_username.to_owned();
let password = connection.config.auth_password.to_owned();
let auth_result = connection
.authenticate_response(username, password, authenticator)
.await?;
match auth_result.response {
Response::AuthChallenge(authenticate_challenge) => {
let challenge_message = authenticate_challenge.authenticate_message;
unimplemented!(
"Auth Challenge not implemented yet, {:?}",
challenge_message
)
}
Response::AuthSuccess(_authenticate_success) => {
}
Response::Error(err) => {
return Err(err.into());
}
_ => {
return Err(QueryError::ProtocolError(
"Unexpected response to Authenticate Response message",
))
}
}
}
_ => {
return Err(QueryError::ProtocolError(
"Unexpected response to STARTUP message",
))
}
}
if connection.config.event_sender.is_some() {
let all_event_types = vec![
EventType::TopologyChange,
EventType::StatusChange,
EventType::SchemaChange,
];
connection.register(all_event_types).await?;
}
Ok((connection, error_receiver))
}
/// Connects to `addr` from the local port `source_port`, bound on the
/// unspecified address of the same IP family as `addr`.
async fn connect_with_source_port(
    addr: SocketAddr,
    source_port: u16,
) -> Result<TcpStream, std::io::Error> {
    let (socket, any_local_ip) = match addr {
        SocketAddr::V4(_) => (
            TcpSocket::new_v4()?,
            std::net::IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)),
        ),
        SocketAddr::V6(_) => (
            TcpSocket::new_v6()?,
            std::net::IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)),
        ),
    };
    socket.bind(SocketAddr::new(any_local_ip, source_port))?;
    Ok(socket.connect(addr).await?)
}
/// Pending requests of one connection, keyed by the stream id each was sent on.
struct ResponseHandlerMap {
    stream_set: StreamIdSet,
    handlers: HashMap<i16, ResponseHandler>,
}

impl ResponseHandlerMap {
    /// Creates an empty map with every stream id free.
    pub fn new() -> Self {
        ResponseHandlerMap {
            stream_set: StreamIdSet::new(),
            handlers: HashMap::new(),
        }
    }

    /// Reserves a free stream id for `response_handler`.
    ///
    /// Returns `None` (dropping the handler) when every id is in use.
    pub fn allocate(&mut self, response_handler: ResponseHandler) -> Option<i16> {
        let stream_id = match self.stream_set.allocate() {
            Some(id) => id,
            None => return None,
        };

        // A freshly allocated id must not already have a handler.
        let prev_handler = self.handlers.insert(stream_id, response_handler);
        assert!(prev_handler.is_none());

        Some(stream_id)
    }

    /// Releases `stream_id` and returns its handler, if one was registered.
    pub fn take(&mut self, stream_id: i16) -> Option<ResponseHandler> {
        self.stream_set.free(stream_id);
        let handler = self.handlers.remove(&stream_id);
        handler
    }

    /// Consumes the map, yielding all still-pending handlers.
    pub fn into_handlers(self) -> HashMap<i16, ResponseHandler> {
        let ResponseHandlerMap { handlers, .. } = self;
        handlers
    }
}
/// Bitmap of the stream ids `0..=i16::MAX` currently in use (bit set = used).
struct StreamIdSet {
    used_bitmap: Box<[u64]>,
}

impl StreamIdSet {
    /// Creates a set with every id free.
    pub fn new() -> Self {
        // One bit per non-negative i16 value, packed 64 per word.
        const WORD_COUNT: usize = (i16::MAX as usize + 1) / 64;
        StreamIdSet {
            used_bitmap: vec![0u64; WORD_COUNT].into_boxed_slice(),
        }
    }

    /// Marks the lowest free id as used and returns it, or `None` if all are taken.
    pub fn allocate(&mut self) -> Option<i16> {
        let (word_index, word) = self
            .used_bitmap
            .iter_mut()
            .enumerate()
            .find(|(_, word)| **word != u64::MAX)?;

        // Trailing ones count the used low bits, so this is the first free bit.
        let bit = word.trailing_ones();
        *word |= 1u64 << bit;

        Some(word_index as i16 * 64 + bit as i16)
    }

    /// Marks `stream_id` as free again.
    pub fn free(&mut self, stream_id: i16) {
        let id = stream_id as usize;
        self.used_bitmap[id / 64] &= !(1u64 << (id % 64));
    }
}
/// A keyspace name already checked to be non-empty, at most 48 characters
/// long and made only of ASCII letters, digits and underscores.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct VerifiedKeyspaceName {
    name: Arc<String>,
    /// Whether the name must be quoted to preserve its case.
    pub is_case_sensitive: bool,
}

impl VerifiedKeyspaceName {
    /// Validates `keyspace_name` and wraps it.
    ///
    /// # Errors
    /// `BadKeyspaceName::Empty`, `TooLong` or `IllegalCharacter`, checked in
    /// that order.
    pub fn new(keyspace_name: String, case_sensitive: bool) -> Result<Self, BadKeyspaceName> {
        Self::verify_keyspace_name_is_valid(&keyspace_name).map(|()| VerifiedKeyspaceName {
            name: Arc::new(keyspace_name),
            is_case_sensitive: case_sensitive,
        })
    }

    /// The keyspace name as given (never quoted).
    pub fn as_str(&self) -> &str {
        self.name.as_str()
    }

    /// Checks the naming rules. The length is counted in characters, and the
    /// first offending character (if any) is reported.
    fn verify_keyspace_name_is_valid(keyspace_name: &str) -> Result<(), BadKeyspaceName> {
        if keyspace_name.is_empty() {
            return Err(BadKeyspaceName::Empty);
        }

        let char_count = keyspace_name.chars().count();
        if char_count > 48 {
            return Err(BadKeyspaceName::TooLong(
                keyspace_name.to_string(),
                char_count,
            ));
        }

        let first_illegal = keyspace_name
            .chars()
            .find(|c| !(c.is_ascii_alphanumeric() || *c == '_'));

        match first_illegal {
            Some(character) => Err(BadKeyspaceName::IllegalCharacter(
                keyspace_name.to_string(),
                character,
            )),
            None => Ok(()),
        }
    }
}
// Integration tests for `Connection`; they need a reachable node at
// $SCYLLA_URI (default 127.0.0.1:9042).
#[cfg(test)]
mod tests {
use super::super::errors::QueryError;
use super::ConnectionConfig;
use crate::query::Query;
use crate::IntoTypedRows;
use std::net::SocketAddr;
// Resolves `hostname` to its first address. If the lookup fails (e.g. the
// string has no port), it retries with the default CQL port 9042.
async fn resolve_hostname(hostname: &str) -> SocketAddr {
match tokio::net::lookup_host(hostname).await {
Ok(mut addrs) => addrs.next().unwrap(),
Err(_) => {
tokio::net::lookup_host((hostname, 9042)) .await
.unwrap()
.next()
.unwrap()
}
}
}
// Exercises query_all/execute_all on an empty table, a table with 100 rows
// (page size 7, so many pages), non-SELECT statements (no rows), and the
// error path for statements without a page size.
#[tokio::test]
async fn connection_query_all_execute_all_test() {
let uri = std::env::var("SCYLLA_URI").unwrap_or_else(|_| "127.0.0.1:9042".to_string());
let addr: SocketAddr = resolve_hostname(&uri).await;
let (connection, _) = super::open_connection(addr, None, ConnectionConfig::default())
.await
.unwrap();
// Start from a fresh, empty table.
connection.query_single_page("CREATE KEYSPACE IF NOT EXISTS ks WITH REPLICATION = {'class' : 'SimpleStrategy', 'replication_factor' : 1}", &[]).await.unwrap();
connection
.query_single_page("DROP TABLE IF EXISTS ks.connection_query_all_tab", &[])
.await
.unwrap();
connection
.query_single_page(
"CREATE TABLE IF NOT EXISTS ks.connection_query_all_tab (p int primary key)",
&[],
)
.await
.unwrap();
// Empty table: both paths yield Some(empty rows).
let select_query =
Query::new("SELECT p FROM ks.connection_query_all_tab").with_page_size(7);
let empty_res = connection.query_all(&select_query, &[]).await.unwrap();
assert!(empty_res.rows.unwrap().is_empty());
let mut prepared_select = connection.prepare(&select_query).await.unwrap();
prepared_select.set_page_size(7);
let empty_res_prepared = connection.execute_all(&prepared_select, &[]).await.unwrap();
assert!(empty_res_prepared.rows.unwrap().is_empty());
// Insert 0..100 concurrently over the same connection.
let values: Vec<i32> = (0..100).collect();
let mut insert_futures = Vec::new();
let insert_query =
Query::new("INSERT INTO ks.connection_query_all_tab (p) VALUES (?)").with_page_size(7);
for v in &values {
insert_futures.push(connection.query_single_page(insert_query.clone(), (v,)));
}
futures::future::try_join_all(insert_futures).await.unwrap();
// All pages must be merged; row order is unspecified, hence the sort.
let mut results: Vec<i32> = connection
.query_all(&select_query, &[])
.await
.unwrap()
.rows
.unwrap()
.into_typed::<(i32,)>()
.map(|r| r.unwrap().0)
.collect();
results.sort_unstable(); assert_eq!(results, values);
let mut results2: Vec<i32> = connection
.execute_all(&prepared_select, &[])
.await
.unwrap()
.rows
.unwrap()
.into_typed::<(i32,)>()
.map(|r| r.unwrap().0)
.collect();
results2.sort_unstable();
assert_eq!(results2, values);
// Non-SELECT statements produce no rows at all (None, not empty).
let insert_res1 = connection.query_all(&insert_query, (0,)).await.unwrap();
assert!(insert_res1.rows.is_none());
let prepared_insert = connection.prepare(&insert_query).await.unwrap();
let insert_res2 = connection
.execute_all(&prepared_insert, (0,))
.await
.unwrap();
assert!(insert_res2.rows.is_none(),);
// Without a page size both *_all functions refuse with ProtocolError.
let no_page_size_query = Query::new("SELECT p FROM ks.connection_query_all_tab");
let no_page_res = connection.query_all(&no_page_size_query, &[]).await;
assert!(matches!(no_page_res, Err(QueryError::ProtocolError(_))));
let prepared_no_page_size_query = connection.prepare(&no_page_size_query).await.unwrap();
let prepared_no_page_res = connection
.execute_all(&prepared_no_page_size_query, &[])
.await;
assert!(matches!(
prepared_no_page_res,
Err(QueryError::ProtocolError(_))
));
}
}