use arc_swap::{ArcSwap, Guard};
use bytes::{Bytes, BytesMut};
use scylla_cql::frame::response::result::{
ColumnSpec, PartitionKeyIndex, ResultMetadata, TableSpec,
};
use scylla_cql::frame::types::RawValue;
use scylla_cql::serialize::SerializationError;
use scylla_cql::serialize::row::{RowSerializationContext, SerializeRow, SerializedValues};
use smallvec::{SmallVec, smallvec};
use std::convert::TryInto;
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use uuid::Uuid;
use super::{PageSize, StatementConfig};
use crate::client::execution_profile::ExecutionProfileHandle;
use crate::errors::{BadQuery, ExecutionError};
use crate::frame::response::result::{self, PreparedMetadata};
use crate::frame::types::{Consistency, SerialConsistency};
use crate::observability::history::HistoryListener;
use crate::policies::load_balancing::LoadBalancingPolicy;
use crate::policies::retry::RetryPolicy;
use crate::response::query_result::ColumnSpecs;
use crate::routing::Token;
use crate::routing::partitioner::{Partitioner, PartitionerHasher, PartitionerName};
use crate::statement::Statement;
/// The raw outcome of preparing a statement on the server, before it is
/// converted into a full [`PreparedStatement`].
///
/// Borrows the original [`Statement`] so that its text and configuration can
/// be copied into the final prepared statement.
pub(crate) struct RawPreparedStatement<'statement> {
    // The statement that was prepared; source of `contents` and `config`.
    statement: &'statement Statement,
    // The deserialized PREPARED response received from the server.
    prepared_response: result::Prepared,
    // Whether this statement was determined to be a lightweight transaction
    // (value supplied by the caller that performed the preparation).
    is_lwt: bool,
    // Tracing id of the PREPARE request, if tracing was enabled for it.
    tracing_id: Option<Uuid>,
}
impl<'statement> RawPreparedStatement<'statement> {
    /// Bundles a PREPARED response together with the statement it originated
    /// from and the metadata of the PREPARE request.
    pub(crate) fn new(
        statement: &'statement Statement,
        prepared_response: result::Prepared,
        is_lwt: bool,
        tracing_id: Option<Uuid>,
    ) -> Self {
        Self {
            statement,
            prepared_response,
            is_lwt,
            tracing_id,
        }
    }

    /// Returns the statement id assigned by the server in the PREPARED response.
    pub(crate) fn get_id(&self) -> &Bytes {
        &self.prepared_response.id
    }

    /// Returns the tracing id of the PREPARE request, if tracing was enabled.
    pub(crate) fn tracing_id(&self) -> Option<Uuid> {
        self.tracing_id
    }

    /// Consumes `self`, yielding the raw PREPARED response.
    pub(crate) fn into_response(self) -> result::Prepared {
        self.prepared_response
    }
}
impl RawPreparedStatement<'_> {
pub(crate) fn into_prepared_statement(self) -> PreparedStatement {
let Self {
statement,
prepared_response,
is_lwt,
tracing_id,
} = self;
let mut prepared_statement = PreparedStatement::new(
prepared_response.id,
is_lwt,
prepared_response.prepared_metadata,
Arc::new(prepared_response.result_metadata),
statement.contents.clone(),
statement.get_validated_page_size(),
statement.config.clone(),
);
if let Some(tracing_id) = tracing_id {
prepared_statement.prepare_tracing_ids.push(tracing_id);
}
prepared_statement
}
}
/// Represents a statement prepared on the server.
#[derive(Debug)]
pub struct PreparedStatement {
    // Per-statement execution configuration (consistency, policies, timeouts, ...).
    pub(crate) config: StatementConfig,
    /// Tracing ids of the PREPARE requests that produced this statement.
    /// Note: clones of a `PreparedStatement` start with an empty list.
    pub prepare_tracing_ids: Vec<Uuid>,
    // Data shared between all clones of this prepared statement.
    shared: Arc<PreparedStatementSharedData>,
    page_size: PageSize,
    // Partitioner used for token calculation of this statement's table.
    partitioner_name: PartitionerName,
}
/// The part of a prepared statement that is shared (behind an `Arc`)
/// between all of its clones.
#[derive(Debug)]
struct PreparedStatementSharedData {
    // Server-assigned statement id, sent on EXECUTE instead of the query text.
    id: Bytes,
    // Metadata of the bind variables (column specs and partition key indexes).
    metadata: PreparedMetadata,
    // Result metadata as received at PREPARE time. May become outdated
    // (see the deprecation note on `get_result_set_col_specs`).
    initial_result_metadata: Arc<ResultMetadata<'static>>,
    // Most recently observed result metadata. Kept in an `ArcSwap` so it can
    // be refreshed without exclusive (`&mut`) access to the statement.
    current_result_metadata: ArcSwap<ResultMetadata<'static>>,
    // The original CQL statement text.
    statement: String,
    // True iff the server confirmed this statement is a lightweight transaction.
    is_confirmed_lwt: bool,
}
// Manual `Clone` impl: a clone shares the prepared data and configuration,
// but deliberately starts with an empty `prepare_tracing_ids` list — tracing
// ids of past PREPARE requests are not inherited by clones.
impl Clone for PreparedStatement {
    fn clone(&self) -> Self {
        Self {
            config: self.config.clone(),
            // Intentionally fresh — see the comment on the impl.
            prepare_tracing_ids: Vec::new(),
            shared: self.shared.clone(),
            page_size: self.page_size,
            partitioner_name: self.partitioner_name.clone(),
        }
    }
}
/// A guard over the current result metadata of a [`PreparedStatement`],
/// allowing borrowed access to its column specs while keeping the metadata
/// alive (see [`arc_swap::Guard`]).
#[derive(Debug)]
pub struct ColumnSpecsGuard {
    result: Guard<Arc<ResultMetadata<'static>>>,
}
impl ColumnSpecsGuard {
    /// Returns the column specs of the result set; the borrow is valid for
    /// as long as this guard lives.
    pub fn get(&self) -> ColumnSpecs<'_, 'static> {
        ColumnSpecs::new(self.result.col_specs())
    }
}
impl PreparedStatement {
    /// Assembles a prepared statement from the data of a PREPARED response
    /// together with the statement's text, page size and configuration.
    fn new(
        id: Bytes,
        is_lwt: bool,
        metadata: PreparedMetadata,
        result_metadata: Arc<ResultMetadata<'static>>,
        statement: String,
        page_size: PageSize,
        config: StatementConfig,
    ) -> Self {
        Self {
            shared: Arc::new(PreparedStatementSharedData {
                id,
                metadata,
                // The same metadata serves as both the initial snapshot and
                // the starting value of the refreshable current metadata.
                initial_result_metadata: Arc::clone(&result_metadata),
                current_result_metadata: ArcSwap::from(result_metadata),
                statement,
                is_confirmed_lwt: is_lwt,
            }),
            prepare_tracing_ids: Vec::new(),
            page_size,
            partitioner_name: Default::default(),
            config,
        }
    }

    /// Returns the server-assigned id of this prepared statement.
    pub fn get_id(&self) -> &Bytes {
        &self.shared.id
    }

    /// Returns the original CQL statement text.
    pub fn get_statement(&self) -> &str {
        &self.shared.statement
    }

    /// Sets the page size for paged queries executed with this statement.
    ///
    /// # Panics
    ///
    /// Panics if `page_size` cannot be converted into a valid [`PageSize`].
    pub fn set_page_size(&mut self, page_size: i32) {
        self.page_size = page_size
            .try_into()
            .unwrap_or_else(|err| panic!("PreparedStatement::set_page_size: {err}"));
    }

    /// Returns the page size as the validated [`PageSize`] wrapper type.
    pub(crate) fn get_validated_page_size(&self) -> PageSize {
        self.page_size
    }

    /// Returns the page size as a plain integer.
    pub fn get_page_size(&self) -> i32 {
        self.page_size.inner()
    }

    /// Returns the tracing ids of the PREPARE requests that produced this statement.
    pub fn get_prepare_tracing_ids(&self) -> &[Uuid] {
        &self.prepare_tracing_ids
    }

    /// Returns whether the driver can compute this statement's partition key
    /// (i.e. whether the prepared metadata contains any partition key indexes).
    pub fn is_token_aware(&self) -> bool {
        !self.get_prepared_metadata().pk_indexes.is_empty()
    }

    /// Returns whether the server confirmed this statement is a lightweight
    /// transaction.
    pub fn is_confirmed_lwt(&self) -> bool {
        self.shared.is_confirmed_lwt
    }

    /// Serializes the given bound values and returns the encoded partition key
    /// as a single byte blob.
    ///
    /// # Errors
    ///
    /// Fails if serialization fails, if a partition key column has no bound
    /// value, or if a partition key value exceeds the encodable length.
    pub fn compute_partition_key(
        &self,
        bound_values: &impl SerializeRow,
    ) -> Result<Bytes, PartitionKeyError> {
        let serialized = self.serialize_values(bound_values)?;
        let partition_key = self.extract_partition_key(&serialized)?;
        let mut buf = BytesMut::new();
        let mut writer = |chunk: &[u8]| buf.extend_from_slice(chunk);
        partition_key.write_encoded_partition_key(&mut writer)?;
        Ok(buf.freeze())
    }

    /// Extracts the partition key values out of already-serialized bound values.
    pub(crate) fn extract_partition_key<'ps>(
        &'ps self,
        bound_values: &'ps SerializedValues,
    ) -> Result<PartitionKey<'ps>, PartitionKeyExtractionError> {
        PartitionKey::new(self.get_prepared_metadata(), bound_values)
    }

    /// Extracts the partition key and computes its token with the given
    /// partitioner. Returns `Ok(None)` if the statement is not token-aware.
    pub(crate) fn extract_partition_key_and_calculate_token<'ps>(
        &'ps self,
        partitioner_name: &'ps PartitionerName,
        serialized_values: &'ps SerializedValues,
    ) -> Result<Option<(PartitionKey<'ps>, Token)>, PartitionKeyError> {
        if !self.is_token_aware() {
            return Ok(None);
        }
        let partition_key = self.extract_partition_key(serialized_values)?;
        let token = partition_key.calculate_token(partitioner_name)?;
        Ok(Some((partition_key, token)))
    }

    /// Serializes the given values and computes the token of the resulting
    /// partition key. Returns `Ok(None)` if the statement is not token-aware.
    pub fn calculate_token(
        &self,
        values: &impl SerializeRow,
    ) -> Result<Option<Token>, PartitionKeyError> {
        self.calculate_token_untyped(&self.serialize_values(values)?)
    }

    // As above, but operates on already-serialized values ("untyped" because
    // no type checking against the prepared metadata happens here).
    pub(crate) fn calculate_token_untyped(
        &self,
        values: &SerializedValues,
    ) -> Result<Option<Token>, PartitionKeyError> {
        self.extract_partition_key_and_calculate_token(&self.partitioner_name, values)
            .map(|opt| opt.map(|(_pk, token)| token))
    }

    /// Returns the table spec this statement operates on, taken from the first
    /// bind variable's column spec (`None` if there are no bind variables).
    pub fn get_table_spec(&self) -> Option<&TableSpec<'_>> {
        self.get_prepared_metadata()
            .col_specs
            .first()
            .map(|spec| spec.table_spec())
    }

    /// Returns the keyspace name this statement operates on, if known
    /// (derived from the first bind variable's column spec).
    pub fn get_keyspace_name(&self) -> Option<&str> {
        self.get_prepared_metadata()
            .col_specs
            .first()
            .map(|col_spec| col_spec.table_spec().ks_name())
    }

    /// Returns the table name this statement operates on, if known
    /// (derived from the first bind variable's column spec).
    pub fn get_table_name(&self) -> Option<&str> {
        self.get_prepared_metadata()
            .col_specs
            .first()
            .map(|col_spec| col_spec.table_spec().table_name())
    }

    /// Sets the consistency to use for this statement.
    pub fn set_consistency(&mut self, c: Consistency) {
        self.config.consistency = Some(c);
    }

    /// Unsets the per-statement consistency, deferring to defaults.
    pub fn unset_consistency(&mut self) {
        self.config.consistency = None;
    }

    /// Returns the per-statement consistency, if one was set.
    pub fn get_consistency(&self) -> Option<Consistency> {
        self.config.consistency
    }

    /// Sets the serial consistency for this statement. Note that
    /// `set_serial_consistency(None)` explicitly stores "no serial
    /// consistency", which is distinct from [`Self::unset_serial_consistency`]
    /// (defer to defaults).
    pub fn set_serial_consistency(&mut self, sc: Option<SerialConsistency>) {
        self.config.serial_consistency = Some(sc);
    }

    /// Unsets the per-statement serial consistency, deferring to defaults.
    pub fn unset_serial_consistency(&mut self) {
        self.config.serial_consistency = None;
    }

    /// Returns the effective per-statement serial consistency, flattening the
    /// "explicitly none" and "not set" cases into `None`.
    pub fn get_serial_consistency(&self) -> Option<SerialConsistency> {
        self.config.serial_consistency.flatten()
    }

    /// Marks this statement as idempotent (safe to retry speculatively) or not.
    pub fn set_is_idempotent(&mut self, is_idempotent: bool) {
        self.config.is_idempotent = is_idempotent;
    }

    /// Returns whether this statement is marked idempotent.
    pub fn get_is_idempotent(&self) -> bool {
        self.config.is_idempotent
    }

    /// Enables or disables tracing for executions of this statement.
    pub fn set_tracing(&mut self, should_trace: bool) {
        self.config.tracing = should_trace;
    }

    /// Returns whether tracing is enabled for executions of this statement.
    pub fn get_tracing(&self) -> bool {
        self.config.tracing
    }

    /// When enabled, asks the server to skip result metadata in responses,
    /// relying on the metadata cached at PREPARE time instead.
    pub fn set_use_cached_result_metadata(&mut self, use_cached_metadata: bool) {
        self.config.skip_result_metadata = use_cached_metadata;
    }

    /// Returns whether cached result metadata is used instead of per-response
    /// metadata.
    pub fn get_use_cached_result_metadata(&self) -> bool {
        self.config.skip_result_metadata
    }

    /// Sets (or unsets) the client-side timestamp for this statement.
    pub fn set_timestamp(&mut self, timestamp: Option<i64>) {
        self.config.timestamp = timestamp
    }

    /// Returns the client-side timestamp, if one was set.
    pub fn get_timestamp(&self) -> Option<i64> {
        self.config.timestamp
    }

    /// Sets (or unsets) the per-statement request timeout.
    pub fn set_request_timeout(&mut self, timeout: Option<Duration>) {
        self.config.request_timeout = timeout
    }

    /// Returns the per-statement request timeout, if one was set.
    pub fn get_request_timeout(&self) -> Option<Duration> {
        self.config.request_timeout
    }

    /// Sets the partitioner used for token calculation of this statement.
    pub(crate) fn set_partitioner_name(&mut self, partitioner_name: PartitionerName) {
        self.partitioner_name = partitioner_name;
    }

    /// Returns the metadata of the bind variables.
    pub(crate) fn get_prepared_metadata(&self) -> &PreparedMetadata {
        &self.shared.metadata
    }

    /// Returns the column specifications of the bind variables.
    pub fn get_variable_col_specs(&self) -> ColumnSpecs<'_, 'static> {
        ColumnSpecs::new(&self.shared.metadata.col_specs)
    }

    /// Returns the partition key indexes of the bind variables.
    pub fn get_variable_pk_indexes(&self) -> &[PartitionKeyIndex] {
        &self.shared.metadata.pk_indexes
    }

    /// Returns an owning handle to the most recently observed result metadata.
    pub(crate) fn get_current_result_metadata(&self) -> Arc<ResultMetadata<'static>> {
        self.shared.current_result_metadata.load_full()
    }

    /// Atomically replaces the current result metadata with a newer snapshot.
    #[allow(dead_code)]
    pub(crate) fn update_current_result_metadata(
        &self,
        new_metadata: Arc<ResultMetadata<'static>>,
    ) {
        self.shared.current_result_metadata.store(new_metadata);
    }

    /// Returns the result-set column specs as cached at PREPARE time.
    #[deprecated(
        since = "1.4.0",
        note = "This method may return outdated metadata. Use get_current_result_set_col_specs() instead."
    )]
    pub fn get_result_set_col_specs(&self) -> ColumnSpecs<'_, 'static> {
        ColumnSpecs::new(self.shared.initial_result_metadata.col_specs())
    }

    /// Returns a guard over the most recently observed result-set column specs.
    pub fn get_current_result_set_col_specs(&self) -> ColumnSpecsGuard {
        ColumnSpecsGuard {
            result: self.shared.current_result_metadata.load(),
        }
    }

    /// Returns the partitioner used for token calculation of this statement.
    pub fn get_partitioner_name(&self) -> &PartitionerName {
        &self.partitioner_name
    }

    /// Sets (or unsets) the per-statement retry policy.
    #[inline]
    pub fn set_retry_policy(&mut self, retry_policy: Option<Arc<dyn RetryPolicy>>) {
        self.config.retry_policy = retry_policy;
    }

    /// Returns the per-statement retry policy, if one was set.
    #[inline]
    pub fn get_retry_policy(&self) -> Option<&Arc<dyn RetryPolicy>> {
        self.config.retry_policy.as_ref()
    }

    /// Sets (or unsets) the per-statement load balancing policy.
    #[inline]
    pub fn set_load_balancing_policy(
        &mut self,
        load_balancing_policy: Option<Arc<dyn LoadBalancingPolicy>>,
    ) {
        self.config.load_balancing_policy = load_balancing_policy;
    }

    /// Returns the per-statement load balancing policy, if one was set.
    #[inline]
    pub fn get_load_balancing_policy(&self) -> Option<&Arc<dyn LoadBalancingPolicy>> {
        self.config.load_balancing_policy.as_ref()
    }

    /// Sets a history listener to be notified about this statement's executions.
    pub fn set_history_listener(&mut self, history_listener: Arc<dyn HistoryListener>) {
        self.config.history_listener = Some(history_listener);
    }

    /// Removes and returns the history listener, if one was set.
    pub fn remove_history_listener(&mut self) -> Option<Arc<dyn HistoryListener>> {
        self.config.history_listener.take()
    }

    /// Sets (or unsets) the execution profile handle for this statement.
    pub fn set_execution_profile_handle(&mut self, profile_handle: Option<ExecutionProfileHandle>) {
        self.config.execution_profile_handle = profile_handle;
    }

    /// Returns the execution profile handle, if one was set.
    pub fn get_execution_profile_handle(&self) -> Option<&ExecutionProfileHandle> {
        self.config.execution_profile_handle.as_ref()
    }

    /// Serializes the given row values against this statement's prepared
    /// metadata (performing type checking in the process).
    pub(crate) fn serialize_values(
        &self,
        values: &impl SerializeRow,
    ) -> Result<SerializedValues, SerializationError> {
        let ctx = RowSerializationContext::from_prepared(self.get_prepared_metadata());
        SerializedValues::from_serializable(&ctx, values)
    }

    /// Produces a configuration-free handle sharing this statement's prepared
    /// data; see [`UnconfiguredPreparedStatement`].
    pub(crate) fn make_unconfigured_handle(&self) -> UnconfiguredPreparedStatement {
        UnconfiguredPreparedStatement {
            shared: Arc::clone(&self.shared),
            partitioner_name: self.get_partitioner_name().clone(),
        }
    }
}
/// A handle to the shared part of a prepared statement, stripped of any
/// per-statement configuration (consistency, policies, timeouts, page size).
/// Can be re-combined with a configuration via
/// [`UnconfiguredPreparedStatement::make_configured_handle`].
#[derive(Debug)]
pub(crate) struct UnconfiguredPreparedStatement {
    shared: Arc<PreparedStatementSharedData>,
    partitioner_name: PartitionerName,
}
impl UnconfiguredPreparedStatement {
    /// Rebuilds a fully usable [`PreparedStatement`] by combining the shared
    /// prepared data with the given configuration and page size.
    /// The resulting statement starts with no prepare tracing ids.
    pub(crate) fn make_configured_handle(
        &self,
        config: StatementConfig,
        page_size: PageSize,
    ) -> PreparedStatement {
        PreparedStatement {
            shared: Arc::clone(&self.shared),
            prepare_tracing_ids: Vec::new(),
            page_size,
            partitioner_name: self.partitioner_name.clone(),
            config,
        }
    }
}
/// An error returned when extracting partition key values from the
/// serialized bound values fails.
#[derive(Clone, Debug, Error, PartialEq, Eq, PartialOrd, Ord)]
#[non_exhaustive]
pub enum PartitionKeyExtractionError {
    /// A partition key index pointed past the end of the bound values.
    /// Fields: the offending pk index and the number of bound values.
    #[error("No value with given pk_index! pk_index: {0}, values.len(): {1}")]
    NoPkIndexValue(u16, u16),
}
/// An error returned when token calculation over a partition key fails.
#[derive(Clone, Debug, Error, PartialEq, Eq, PartialOrd, Ord)]
#[non_exhaustive]
pub enum TokenCalculationError {
    /// A partition key value's length did not fit in the `u16` length prefix
    /// used by the composite key encoding. Field: the value's actual length.
    #[error("Value bytes too long to create partition key, max 65 535 allowed! value.len(): {0}")]
    ValueTooLong(usize),
}
/// Umbrella error for computing a statement's partition key / token:
/// covers serialization of the bound values, extraction of the key columns,
/// and the token calculation itself.
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum PartitionKeyError {
    /// Failed to locate the partition key values among the bound values.
    #[error(transparent)]
    PartitionKeyExtraction(#[from] PartitionKeyExtractionError),
    /// Failed to compute the token from the extracted partition key.
    #[error(transparent)]
    TokenCalculation(#[from] TokenCalculationError),
    /// Failed to serialize the bound values.
    #[error(transparent)]
    Serialization(#[from] SerializationError),
}
impl PartitionKeyError {
    /// Converts this error into the corresponding [`ExecutionError`].
    /// Every variant maps to an [`ExecutionError::BadQuery`].
    pub fn into_execution_error(self) -> ExecutionError {
        let bad_query = match self {
            Self::PartitionKeyExtraction(_) => BadQuery::PartitionKeyExtraction,
            Self::TokenCalculation(TokenCalculationError::ValueTooLong(len)) => {
                // The limit stems from the u16 length prefix in the key encoding.
                BadQuery::ValuesTooLongForKey(len, u16::MAX.into())
            }
            Self::Serialization(err) => BadQuery::SerializationError(err),
        };
        ExecutionError::BadQuery(bad_query)
    }
}
/// A single partition key column: its serialized bytes plus the column's spec.
pub(crate) type PartitionKeyValue<'ps> = (&'ps [u8], &'ps ColumnSpec<'ps>);

/// The partition key of a bound statement, borrowed from the serialized values.
pub(crate) struct PartitionKey<'ps> {
    // Values ordered by partition key `sequence` (the column order within the
    // key), not by bind-marker index. `None` marks a null/unset value.
    pk_values: SmallVec<[Option<PartitionKeyValue<'ps>>; PartitionKey::SMALLVEC_ON_STACK_SIZE]>,
}
impl<'ps> PartitionKey<'ps> {
    // Partition keys are usually short; keep up to this many values inline
    // to avoid a heap allocation in the common case.
    const SMALLVEC_ON_STACK_SIZE: usize = 8;

    /// Picks the partition key values out of the serialized bound values,
    /// ordering them by their `sequence` within the partition key.
    ///
    /// Assumes `prepared_metadata.pk_indexes` is sorted by `index` — the
    /// single forward pass over `bound_values` below relies on it.
    ///
    /// # Errors
    ///
    /// Returns `NoPkIndexValue` if a pk index points past the bound values.
    fn new(
        prepared_metadata: &'ps PreparedMetadata,
        bound_values: &'ps SerializedValues,
    ) -> Result<Self, PartitionKeyExtractionError> {
        // One slot per partition key column, filled in `sequence` order.
        let mut pk_values: SmallVec<[_; PartitionKey::SMALLVEC_ON_STACK_SIZE]> =
            smallvec![None; prepared_metadata.pk_indexes.len()];
        let mut values_iter = bound_values.iter();
        // Position in `bound_values` just past the last consumed value;
        // lets us use `nth` to skip non-key values without restarting.
        let mut values_iter_offset = 0;
        for pk_index in prepared_metadata.pk_indexes.iter().copied() {
            // `nth(k)` advances past k values and yields the (k+1)-th, so this
            // lands exactly on the value at `pk_index.index`.
            let next_val = values_iter
                .nth((pk_index.index - values_iter_offset) as usize)
                .ok_or_else(|| {
                    PartitionKeyExtractionError::NoPkIndexValue(
                        pk_index.index,
                        bound_values.element_count(),
                    )
                })?;
            // Null/unset values leave the slot as `None` (skipped by `iter`).
            if let RawValue::Value(v) = next_val {
                let spec = &prepared_metadata.col_specs[pk_index.index as usize];
                pk_values[pk_index.sequence as usize] = Some((v, spec));
            }
            values_iter_offset = pk_index.index + 1;
        }
        Ok(Self { pk_values })
    }

    /// Iterates over the present (non-null) partition key values, in
    /// partition-key column order.
    pub(crate) fn iter(
        &self,
    ) -> impl Iterator<Item = PartitionKeyValue<'ps>> + Clone + use<'ps, '_> {
        self.pk_values.iter().flatten().copied()
    }

    /// Feeds the encoded partition key to `writer`, chunk by chunk.
    ///
    /// A single-column key is written as the raw value bytes. A multi-column
    /// key uses the composite encoding: for each value, a big-endian `u16`
    /// length prefix, the value bytes, and a trailing `0x00` byte (the format
    /// Cassandra uses for composite partition keys).
    ///
    /// # Errors
    ///
    /// Returns `ValueTooLong` if a value's length does not fit in `u16`.
    fn write_encoded_partition_key(
        &self,
        writer: &mut impl FnMut(&[u8]),
    ) -> Result<(), TokenCalculationError> {
        let mut pk_val_iter = self.iter().map(|(val, _spec)| val);
        if let Some(first_value) = pk_val_iter.next() {
            if let Some(second_value) = pk_val_iter.next() {
                // Two or more values: composite encoding. Re-chain the two
                // values we peeked at in front of the rest of the iterator.
                for value in std::iter::once(first_value)
                    .chain(std::iter::once(second_value))
                    .chain(pk_val_iter)
                {
                    let v_len_u16: u16 = value
                        .len()
                        .try_into()
                        .map_err(|_| TokenCalculationError::ValueTooLong(value.len()))?;
                    writer(&v_len_u16.to_be_bytes());
                    writer(value);
                    writer(&[0u8]);
                }
            } else {
                // Exactly one value: written raw, with no framing.
                writer(first_value);
            }
        }
        Ok(())
    }

    /// Computes the token of this partition key by streaming its encoded form
    /// through the given partitioner's hasher.
    pub(crate) fn calculate_token(
        &self,
        partitioner_name: &PartitionerName,
    ) -> Result<Token, TokenCalculationError> {
        let mut partitioner_hasher = partitioner_name.build_hasher();
        let mut writer = |chunk: &[u8]| partitioner_hasher.write(chunk);
        self.write_encoded_partition_key(&mut writer)?;
        Ok(partitioner_hasher.finish())
    }
}
#[cfg(test)]
mod tests {
    use scylla_cql::frame::response::result::{
        ColumnSpec, ColumnType, NativeType, PartitionKeyIndex, PreparedMetadata, TableSpec,
    };
    use scylla_cql::serialize::row::SerializedValues;

    use crate::statement::prepared::PartitionKey;
    use crate::test_utils::setup_tracing;

    /// Builds a `PreparedMetadata` with columns `col_0..col_n` of the given
    /// types, whose partition key consists of the columns at positions `idx`
    /// (in that order — position in `idx` becomes the pk `sequence`).
    fn make_meta(
        cols: impl IntoIterator<Item = ColumnType<'static>>,
        idx: impl IntoIterator<Item = usize>,
    ) -> PreparedMetadata {
        let table_spec = TableSpec::owned("ks".to_owned(), "t".to_owned());
        let col_specs: Vec<_> = cols
            .into_iter()
            .enumerate()
            .map(|(i, typ)| ColumnSpec::owned(format!("col_{i}"), typ, table_spec.clone()))
            .collect();
        let mut pk_indexes = idx
            .into_iter()
            .enumerate()
            .map(|(sequence, index)| PartitionKeyIndex {
                index: index as u16,
                sequence: sequence as u16,
            })
            .collect::<Vec<_>>();
        // `PartitionKey::new` requires pk_indexes sorted by bind-marker index.
        pk_indexes.sort_unstable_by_key(|pki| pki.index);
        PreparedMetadata {
            flags: 0,
            col_count: col_specs.len(),
            col_specs,
            pk_indexes,
        }
    }

    /// Partition key columns out of order w.r.t. bind markers: extraction
    /// must return them in pk `sequence` order (columns 4, 0, 3).
    #[test]
    fn test_partition_key_multiple_columns_shuffled() {
        setup_tracing();
        let meta = make_meta(
            [
                ColumnType::Native(NativeType::TinyInt),
                ColumnType::Native(NativeType::SmallInt),
                ColumnType::Native(NativeType::Int),
                ColumnType::Native(NativeType::BigInt),
                ColumnType::Native(NativeType::Blob),
            ],
            [4, 0, 3],
        );
        let mut values = SerializedValues::new();
        values
            .add_value(&67i8, &ColumnType::Native(NativeType::TinyInt))
            .unwrap();
        values
            .add_value(&42i16, &ColumnType::Native(NativeType::SmallInt))
            .unwrap();
        values
            .add_value(&23i32, &ColumnType::Native(NativeType::Int))
            .unwrap();
        values
            .add_value(&89i64, &ColumnType::Native(NativeType::BigInt))
            .unwrap();
        values
            .add_value(&[1u8, 2, 3, 4, 5], &ColumnType::Native(NativeType::Blob))
            .unwrap();
        let pk = PartitionKey::new(&meta, &values).unwrap();
        let pk_cols = Vec::from_iter(pk.iter());
        assert_eq!(
            pk_cols,
            vec![
                ([1u8, 2, 3, 4, 5].as_slice(), &meta.col_specs[4]),
                (67i8.to_be_bytes().as_ref(), &meta.col_specs[0]),
                (89i64.to_be_bytes().as_ref(), &meta.col_specs[3]),
            ]
        );
    }

    /// `ColumnSpecsGuard` should be `Debug`-printable and include the
    /// underlying column specs in its output.
    #[test]
    fn test_column_specs_guard_debug() {
        use crate::statement::prepared::PreparedStatement;
        use bytes::Bytes;
        use scylla_cql::frame::response::result::ResultMetadata;

        setup_tracing();
        let meta = make_meta([ColumnType::Native(NativeType::Int)], [0]);
        let table_spec = TableSpec::owned("test_ks".to_owned(), "test_table".to_owned());
        let col_specs = vec![ColumnSpec::owned(
            "test_column_name".to_string(),
            ColumnType::Native(NativeType::Text),
            table_spec,
        )];
        let result_metadata = ResultMetadata::new_for_test(1, col_specs);
        let prepared = PreparedStatement::new(
            Bytes::from_static(b"test_id"),
            false,
            meta,
            std::sync::Arc::new(result_metadata),
            "SELECT * FROM test".to_string(),
            crate::statement::PageSize::new(100).unwrap(),
            Default::default(),
        );
        let guard = prepared.get_current_result_set_col_specs();
        let debug_output = format!("{:?}", guard);
        assert!(debug_output.contains("ColumnSpecsGuard"));
        assert!(debug_output.contains("test_column_name"));
    }
}