use crate::cluster::routing::SlotAddr;
use crate::cluster::topology::SLOT_SIZE;
use crate::cluster::{ClusterConnInner, Connect, InnerCore, RefreshPolicy};
use crate::cmd::cmd;
use crate::connection::ConnectionLike;
use crate::value::{ErrorKind, Error, Result, Value, from_value};
use std::sync::Arc;
use strum_macros::{Display, EnumString};
/// Number of bits held by one word of the scanned-slots bitmap.
// `u64::BITS` is 64, which always fits in a `u16`, so this cast cannot truncate.
const BITS_PER_U64: u16 = u64::BITS as u16;
/// Total number of hash slots tracked by the scan bitmap.
const NUM_OF_SLOTS: u16 = SLOT_SIZE;
/// Number of `u64` words needed to hold one bit per slot.
const BITS_ARRAY_SIZE: u16 = NUM_OF_SLOTS / BITS_PER_U64;
// `BITS_ARRAY_SIZE` is computed with integer division: if the slot count were not an exact
// multiple of 64, the trailing slots would have no bit and marking them would index past the
// end of the bitmap. Reject such a configuration at compile time instead.
const _: () = assert!(
    NUM_OF_SLOTS % BITS_PER_U64 == 0,
    "SLOT_SIZE must be a multiple of 64"
);
/// Sentinel slot value signalling that no unscanned slot remains.
const END_OF_SCAN: u16 = NUM_OF_SLOTS;
/// Bitmap with one bit per hash slot; a set bit marks the slot as already scanned.
type SlotsBitsArray = [u64; BITS_ARRAY_SIZE as usize];
#[derive(Clone, Default)]
pub struct ClusterScanArgs {
pub scan_state_cursor: ScanStateRC,
pub match_pattern: Option<Vec<u8>>,
pub count: Option<u32>,
pub object_type: Option<ObjectType>,
pub allow_non_covered_slots: bool,
}
impl ClusterScanArgs {
pub fn builder() -> ClusterScanArgsBuilder {
ClusterScanArgsBuilder::default()
}
pub(crate) fn set_scan_state_cursor(&mut self, scan_state_cursor: ScanStateRC) {
self.scan_state_cursor = scan_state_cursor;
}
}
/// Fluent builder for [`ClusterScanArgs`]; every option starts unset.
#[derive(Default)]
pub struct ClusterScanArgsBuilder {
    match_pattern: Option<Vec<u8>>,
    count: Option<u32>,
    object_type: Option<ObjectType>,
    // `false` (do not skip uncovered slots) unless explicitly enabled.
    allow_non_covered_slots: bool,
}

impl ClusterScanArgsBuilder {
    /// Sets the `MATCH` pattern.
    pub fn with_match_pattern<T: Into<Vec<u8>>>(self, pattern: T) -> Self {
        Self {
            match_pattern: Some(pattern.into()),
            ..self
        }
    }

    /// Sets the `COUNT` hint.
    pub fn with_count(self, count: u32) -> Self {
        Self {
            count: Some(count),
            ..self
        }
    }

    /// Sets the `TYPE` filter.
    pub fn with_object_type(self, object_type: ObjectType) -> Self {
        Self {
            object_type: Some(object_type),
            ..self
        }
    }

    /// Chooses whether slots without a serving node are skipped rather than treated as errors.
    pub fn allow_non_covered_slots(self, allow: bool) -> Self {
        Self {
            allow_non_covered_slots: allow,
            ..self
        }
    }

    /// Produces the arguments with a fresh (not yet started) scan cursor.
    pub fn build(self) -> ClusterScanArgs {
        let Self {
            match_pattern,
            count,
            object_type,
            allow_non_covered_slots,
        } = self;
        ClusterScanArgs {
            scan_state_cursor: ScanStateRC::new(),
            match_pattern,
            count,
            object_type,
            allow_non_covered_slots,
        }
    }
}
/// Value types accepted by the `TYPE` filter of `SCAN`.
#[derive(Debug, Clone, PartialEq, Display, EnumString)]
pub enum ObjectType {
    String,
    List,
    Set,
    ZSet,
    Hash,
    Stream,
}

impl From<String> for ObjectType {
    /// Case-insensitive conversion from a type name.
    ///
    /// NOTE(review): unrecognised names silently fall back to [`ObjectType::String`].
    fn from(s: String) -> Self {
        let lowered = s.to_lowercase();
        match &*lowered {
            "list" => Self::List,
            "set" => Self::Set,
            "zset" => Self::ZSet,
            "hash" => Self::Hash,
            "stream" => Self::Stream,
            // "string" as well as every unknown name.
            _ => Self::String,
        }
    }
}
/// Lifecycle stage of a cluster scan cursor.
#[derive(Clone, Debug, Default, PartialEq)]
pub enum ScanStateStage {
    /// No node has been scanned yet; this is the default stage.
    #[default]
    Initiating,
    /// A node is currently being scanned.
    InProgress,
    /// The scan has completed.
    Finished,
}
/// Cursor handed back to callers of the cluster scan.
///
/// The scan state sits behind an `Arc`, so cloning the cursor does not copy the slot bitmap.
#[derive(Debug, Clone, Default)]
pub struct ScanStateRC {
    scan_state_rc: Arc<Option<ScanState>>,
    status: ScanStateStage,
}

impl ScanStateRC {
    /// Wraps an active scan state and marks the cursor as in progress.
    fn from_scan_state(scan_state: ScanState) -> Self {
        let scan_state_rc = Arc::new(Some(scan_state));
        Self {
            status: ScanStateStage::InProgress,
            scan_state_rc,
        }
    }

    /// Creates a cursor for a scan that has not started yet.
    pub fn new() -> Self {
        Self::with_empty_state(ScanStateStage::Initiating)
    }

    /// Creates a cursor signalling that the scan has completed.
    fn create_finished() -> Self {
        Self::with_empty_state(ScanStateStage::Finished)
    }

    // Shared constructor for cursors that carry no scan state.
    fn with_empty_state(status: ScanStateStage) -> Self {
        Self {
            scan_state_rc: Arc::new(None),
            status,
        }
    }

    /// Returns `true` once the scan has completed.
    pub fn is_finished(&self) -> bool {
        matches!(self.status, ScanStateStage::Finished)
    }

    /// Returns a copy of the wrapped scan state, or `None` if the scan has not started or has
    /// already finished.
    pub(crate) fn state_from_wrapper(&self) -> Option<ScanState> {
        match self.status {
            ScanStateStage::InProgress => (*self.scan_state_rc).clone(),
            ScanStateStage::Initiating | ScanStateStage::Finished => None,
        }
    }
}
/// Progress of a cluster-wide scan: the node being scanned, its `SCAN` cursor, and which hash
/// slots have already been covered.
#[derive(PartialEq, Debug, Clone)]
pub(crate) struct ScanState {
    // Cursor for the next `SCAN` on `address_in_scan`; states for a newly chosen node start at 0.
    cursor: u64,
    // One bit per hash slot; set bits are slots that count as already scanned.
    scanned_slots_map: SlotsBitsArray,
    // Node currently being scanned.
    pub(crate) address_in_scan: Arc<String>,
    // Value of `core.address_epoch` for `address_in_scan` when this node was chosen. If it no
    // longer matches once the node is done, the node's slots are not marked as scanned.
    address_epoch: u64,
    scan_status: ScanStateStage,
}

impl ScanState {
    /// Builds a scan state from its raw parts.
    pub fn new(
        cursor: u64,
        scanned_slots_map: SlotsBitsArray,
        address_in_scan: Arc<String>,
        address_epoch: u64,
        scan_status: ScanStateStage,
    ) -> Self {
        Self {
            cursor,
            scanned_slots_map,
            address_in_scan,
            address_epoch,
            scan_status,
        }
    }

    /// Returns the terminal state used once no unscanned slot remains.
    fn create_finished_state() -> Self {
        Self {
            cursor: 0,
            scanned_slots_map: [0; BITS_ARRAY_SIZE as usize],
            address_in_scan: Default::default(),
            address_epoch: 0,
            scan_status: ScanStateStage::Finished,
        }
    }

    /// Starts a new scan at slot 0 with an empty bitmap.
    ///
    /// # Errors
    /// Propagates the error from `next_address_to_scan` (e.g. an uncovered slot while
    /// `allow_non_covered_slots` is `false`).
    async fn initiate_scan<C>(
        core: &InnerCore<C>,
        allow_non_covered_slots: bool,
    ) -> Result<ScanState>
    where
        C: ConnectionLike + Connect + Clone + Send + Sync + 'static,
    {
        let mut scanned_slots_map: SlotsBitsArray = [0; BITS_ARRAY_SIZE as usize];
        // `next_address_to_scan` may mark uncovered slots in the bitmap while searching.
        let next =
            next_address_to_scan(core, 0, &mut scanned_slots_map, allow_non_covered_slots).await?;
        match next {
            NextNodeResult::AllSlotsCompleted => Ok(ScanState::create_finished_state()),
            NextNodeResult::Address(address) => {
                let address_epoch = core.address_epoch(&address).await.unwrap_or(0);
                Ok(ScanState::new(
                    0,
                    scanned_slots_map,
                    address,
                    address_epoch,
                    ScanStateStage::InProgress,
                ))
            }
        }
    }

    /// Builds the state for the node owning the first unscanned slot.
    ///
    /// Uses `new_scanned_slots_map` when given, otherwise this state's own bitmap. Returns a
    /// finished state when no unscanned slot remains.
    async fn new_scan_state<C>(
        &self,
        core: Arc<InnerCore<C>>,
        allow_non_covered_slots: bool,
        new_scanned_slots_map: Option<SlotsBitsArray>,
    ) -> Result<ScanState>
    where
        C: ConnectionLike + Connect + Clone + Send + Sync + 'static,
    {
        let mut scanned_slots_map = new_scanned_slots_map.unwrap_or(self.scanned_slots_map);
        let start_slot = next_slot(&scanned_slots_map).unwrap_or(0);
        let next = next_address_to_scan(
            &core,
            start_slot,
            &mut scanned_slots_map,
            allow_non_covered_slots,
        )
        .await?;
        match next {
            NextNodeResult::Address(new_address) => {
                let new_epoch = core.address_epoch(&new_address).await.unwrap_or(0);
                Ok(ScanState::new(
                    0,
                    scanned_slots_map,
                    new_address,
                    new_epoch,
                    ScanStateStage::InProgress,
                ))
            }
            NextNodeResult::AllSlotsCompleted => Ok(ScanState::create_finished_state()),
        }
    }

    /// Advances past the node that has just returned cursor 0.
    ///
    /// Refreshes the topology if it changed. If the node's epoch is unchanged, its slots are
    /// marked as scanned before moving on. Otherwise the bitmap is left as is, so those slots are
    /// visited again. An already finished state is returned unchanged.
    ///
    /// # Errors
    /// Propagates topology-refresh errors and errors from `next_address_to_scan`.
    async fn create_updated_scan_state_for_completed_address<C>(
        &self,
        core: Arc<InnerCore<C>>,
        allow_non_covered_slots: bool,
    ) -> Result<ScanState>
    where
        C: ConnectionLike + Connect + Clone + Send + Sync + 'static,
    {
        // A finished state has an empty address and an all-zero bitmap; advancing it would
        // restart the whole scan from slot 0 and return duplicate keys.
        if self.scan_status == ScanStateStage::Finished {
            return Ok(self.clone());
        }
        ClusterConnInner::check_topology_and_refresh_if_diff(
            core.clone(),
            &RefreshPolicy::NotThrottable,
        )
        .await?;
        let current_epoch = core.address_epoch(&self.address_in_scan).await.unwrap_or(0);
        if current_epoch != self.address_epoch {
            return self
                .new_scan_state(core, allow_non_covered_slots, None)
                .await;
        }
        let mut scanned_slots_map = self.scanned_slots_map;
        for slot in core.slots_of_address(self.address_in_scan.clone()).await {
            mark_slot_as_scanned(&mut scanned_slots_map, slot);
        }
        self.new_scan_state(core, allow_non_covered_slots, Some(scanned_slots_map))
            .await
    }
}
/// Sets the bit for `slot` in `scanned_slots_map`, recording the slot as scanned.
///
/// # Panics
/// Panics if `slot` lies beyond the bitmap's capacity.
fn mark_slot_as_scanned(scanned_slots_map: &mut SlotsBitsArray, slot: u16) {
    let word = usize::from(slot / BITS_PER_U64);
    let bit = slot % BITS_PER_U64;
    scanned_slots_map[word] |= 1u64 << bit;
}
/// Outcome of searching for the next node to scan.
#[derive(Clone, Debug, PartialEq)]
enum NextNodeResult {
    /// Address of the node serving the next slot to scan.
    Address(Arc<String>),
    /// No unscanned slot remains.
    AllSlotsCompleted,
}
/// Finds the node that should be scanned next, starting the search at `slot`.
///
/// A slot without a serving node is either skipped, when `allow_non_covered_slots` is `true`
/// (it is marked in `scanned_slots_map` and the search moves to the next unscanned slot), or
/// reported as an error.
///
/// # Errors
/// Returns `ErrorKind::NotAllSlotsCovered` if a slot has no serving node and
/// `allow_non_covered_slots` is `false`.
async fn next_address_to_scan<C>(
    core: &InnerCore<C>,
    mut slot: u16,
    scanned_slots_map: &mut SlotsBitsArray,
    allow_non_covered_slots: bool,
) -> Result<NextNodeResult>
where
    C: ConnectionLike + Connect + Clone + Send + Sync + 'static,
{
    loop {
        if slot == END_OF_SCAN {
            return Ok(NextNodeResult::AllSlotsCompleted);
        }
        // Doing the lookup in its own `let` statement drops the read-lock guard right away,
        // instead of keeping it alive through the whole `if let` chain below.
        let owner = core
            .conn_lock
            .read()
            .slot_map
            .node_address_for_slot(slot, SlotAddr::ReplicaRequired);
        if let Some(addr) = owner {
            return Ok(NextNodeResult::Address(addr));
        }
        if !allow_non_covered_slots {
            return Err(Error::from((
                ErrorKind::NotAllSlotsCovered,
                "Could not find an address covering a slot, SCAN operation cannot continue.\n\
                 If you want to continue scanning even if some slots are not covered, set \
                 allow_non_covered_slots to true.\n\
                 Note that this may lead to incomplete scanning, and the SCAN operation loses \
                 all its guarantees.",
            )));
        }
        mark_slot_as_scanned(scanned_slots_map, slot);
        slot = next_slot(scanned_slots_map)
            .expect("next_slot yields END_OF_SCAN or the first clear bit, never None");
    }
}
/// Returns the lowest slot whose bit is still clear in `scanned_slots_map`, or
/// `Some(END_OF_SCAN)` once every slot is marked as scanned.
///
/// The result is always `Some`; the `Option` return type is kept for existing callers.
fn next_slot(scanned_slots_map: &SlotsBitsArray) -> Option<u16> {
    // The first word that is not all ones contains the lowest clear bit, and its number of
    // trailing ones is that bit's offset within the word. A single pass replaces the former
    // "all full?" check plus a bit-by-bit loop.
    let first_open = scanned_slots_map
        .iter()
        .enumerate()
        .find(|(_, word)| **word != u64::MAX)
        // `index` < BITS_ARRAY_SIZE and the offset < 64, so both fit in `u16`.
        .map(|(index, word)| index as u16 * BITS_PER_U64 + word.trailing_ones() as u16);
    Some(first_open.unwrap_or(END_OF_SCAN))
}
/// Performs one step of a cluster-wide `SCAN`.
///
/// Issues a single `SCAN` against the node recorded in `cluster_scan_args.scan_state_cursor`,
/// starting a new scan when the cursor carries no state. Returns the updated cursor together
/// with the keys from that step. When the node reports cursor 0, the cursor moves on to the
/// node owning the next unscanned slot. A finished cursor is returned once no slots remain.
///
/// # Errors
/// Propagates errors from initiating the scan, from `try_scan`, and from advancing to the next
/// node (including topology refresh).
pub(crate) async fn cluster_scan<C>(
    core: Arc<InnerCore<C>>,
    cluster_scan_args: ClusterScanArgs,
) -> Result<(ScanStateRC, Vec<Value>)>
where
    C: ConnectionLike + Connect + Clone + Send + Sync + 'static,
{
    let allow_non_covered_slots = cluster_scan_args.allow_non_covered_slots;
    let scan_state = match cluster_scan_args.scan_state_cursor.state_from_wrapper() {
        Some(state) => state,
        None => ScanState::initiate_scan(&core, allow_non_covered_slots).await?,
    };
    let ((new_cursor, new_keys), mut scan_state) =
        try_scan(&scan_state, &cluster_scan_args, core.clone()).await?;
    // Cursor 0 means the current node is exhausted, so move on to the next one. `try_scan` may
    // already return a finished state (when re-targeting after errors found no slot left);
    // advancing that state would restart the whole scan from slot 0, so leave it finished.
    if new_cursor == 0 && scan_state.scan_status != ScanStateStage::Finished {
        scan_state = scan_state
            .create_updated_scan_state_for_completed_address(core, allow_non_covered_slots)
            .await?;
    }
    if scan_state.scan_status == ScanStateStage::Finished {
        return Ok((ScanStateRC::create_finished(), new_keys));
    }
    // After moving to a new node `new_cursor` is 0, which is also the right starting cursor
    // there; otherwise it continues the current node's iteration.
    let scan_state = ScanState::new(
        new_cursor,
        scan_state.scanned_slots_map,
        scan_state.address_in_scan,
        scan_state.address_epoch,
        ScanStateStage::InProgress,
    );
    Ok((ScanStateRC::from_scan_state(scan_state), new_keys))
}
async fn send_scan<C>(
scan_state: &ScanState,
cluster_scan_args: &ClusterScanArgs,
core: Arc<InnerCore<C>>,
) -> Result<Value>
where
C: ConnectionLike + Connect + Clone + Send + Sync + 'static,
{
if let Some(conn_future) = core
.connection_for_address(&scan_state.address_in_scan)
.await
{
let mut conn = conn_future.await;
let mut scan_command = cmd("SCAN");
scan_command.arg(scan_state.cursor);
if let Some(match_pattern) = cluster_scan_args.match_pattern.as_ref() {
scan_command.arg("MATCH").arg(match_pattern);
}
if let Some(count) = cluster_scan_args.count {
scan_command.arg("COUNT").arg(count);
}
if let Some(object_type) = &cluster_scan_args.object_type {
scan_command.arg("TYPE").arg(object_type.to_string());
}
conn.req_packed_command(&scan_command).await
} else {
Err(Error::from((
ErrorKind::ConnectionNotFoundForRoute,
"Cluster scan failed. No connection available for address: ",
format!("{}", scan_state.address_in_scan),
)))
}
}
/// Tells whether a failed `SCAN` should be retried after a topology refresh and re-targeting,
/// rather than failing the cluster scan immediately.
fn is_scanwise_retryable_error(err: &Error) -> bool {
    match err.kind() {
        ErrorKind::IoError
        | ErrorKind::AllConnectionsUnavailable
        | ErrorKind::ConnectionNotFoundForRoute
        | ErrorKind::ClusterDown
        | ErrorKind::FatalSendError => true,
        _ => false,
    }
}
/// Re-targets `scan_state` at the node that now owns its first unscanned slot, with the cursor
/// reset to 0. Returns `Ok(None)` when no unscanned slot remains.
///
/// # Errors
/// Propagates the error from `next_address_to_scan`.
async fn next_scan_state<C>(
    core: &Arc<InnerCore<C>>,
    scan_state: &ScanState,
    cluster_scan_args: &ClusterScanArgs,
) -> Result<Option<ScanState>>
where
    C: ConnectionLike + Connect + Clone + Send + Sync + 'static,
{
    let mut slots_map = scan_state.scanned_slots_map;
    let first_unscanned = next_slot(&slots_map).unwrap_or(0);
    let outcome = next_address_to_scan(
        core,
        first_unscanned,
        &mut slots_map,
        cluster_scan_args.allow_non_covered_slots,
    )
    .await?;
    let new_address = match outcome {
        NextNodeResult::AllSlotsCompleted => return Ok(None),
        NextNodeResult::Address(address) => address,
    };
    let epoch = core.address_epoch(&new_address).await.unwrap_or(0);
    Ok(Some(ScanState::new(
        0,
        slots_map,
        new_address,
        epoch,
        ScanStateStage::InProgress,
    )))
}
/// Runs `SCAN` for `scan_state`, retrying on scan-wise retryable errors.
///
/// Each retry refreshes the topology if it changed, then re-targets the scan at the node owning
/// the first unscanned slot, with its cursor reset to 0. Returns the parsed `(cursor, keys)`
/// reply together with the state that produced it. If re-targeting finds no unscanned slot,
/// returns cursor 0, no keys and a finished state.
///
/// # Errors
/// Returns `ErrorKind::AllConnectionsUnavailable` after more than `MAX_SCAN_RETRIES` retryable
/// failures. Non-retryable `SCAN` errors, reply-parsing errors, topology-refresh errors and
/// re-targeting errors are propagated unchanged.
async fn try_scan<C>(
    scan_state: &ScanState,
    cluster_scan_args: &ClusterScanArgs,
    core: Arc<InnerCore<C>>,
) -> Result<((u64, Vec<Value>), ScanState)>
where
    C: ConnectionLike + Connect + Clone + Send + Sync + 'static,
{
    const MAX_SCAN_RETRIES: usize = 10;
    let mut current_state = scan_state.clone();
    let mut attempt = 0;
    loop {
        let err = match send_scan(&current_state, cluster_scan_args, core.clone()).await {
            Ok(reply) => {
                let parsed = from_value::<(u64, Vec<Value>)>(&reply)?;
                return Ok((parsed, current_state));
            }
            Err(err) if !is_scanwise_retryable_error(&err) => return Err(err),
            Err(err) => err,
        };

        attempt += 1;
        if attempt > MAX_SCAN_RETRIES {
            return Err(Error::from((
                ErrorKind::AllConnectionsUnavailable,
                "Cluster scan exceeded maximum retry count",
                format!(
                    "Failed after {} retries. Last error: {}",
                    MAX_SCAN_RETRIES, err
                ),
            )));
        }

        ClusterConnInner::check_topology_and_refresh_if_diff(
            core.clone(),
            &RefreshPolicy::NotThrottable,
        )
        .await?;
        match next_scan_state(&core, &current_state, cluster_scan_args).await? {
            Some(retargeted) => current_state = retargeted,
            None => return Ok(((0, Vec::new()), ScanState::create_finished_state())),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // None of these tests await anything, so they are plain synchronous tests rather than
    // `#[tokio::test]`s that would each spin up a runtime for nothing.

    #[test]
    fn test_cluster_scan_args_builder() {
        let args = ClusterScanArgs::builder()
            .with_match_pattern("user:*")
            .with_count(100)
            .with_object_type(ObjectType::Hash)
            .allow_non_covered_slots(true)
            .build();
        assert_eq!(args.match_pattern, Some(b"user:*".to_vec()));
        assert_eq!(args.count, Some(100));
        assert_eq!(args.object_type, Some(ObjectType::Hash));
        assert!(args.allow_non_covered_slots);
    }

    #[test]
    fn test_cluster_scan_args_builder_defaults() {
        let args = ClusterScanArgs::builder().build();
        assert_eq!(args.match_pattern, None);
        assert_eq!(args.count, None);
        assert_eq!(args.object_type, None);
        assert!(!args.allow_non_covered_slots);
        assert!(!args.scan_state_cursor.is_finished());
    }

    #[test]
    fn test_scan_state_new() {
        let address = Arc::new("127.0.0.1:6379".to_string());
        let scan_state = ScanState::new(
            0,
            [0; BITS_ARRAY_SIZE as usize],
            address.clone(),
            1,
            ScanStateStage::InProgress,
        );
        assert_eq!(scan_state.cursor, 0);
        assert_eq!(scan_state.scanned_slots_map, [0; BITS_ARRAY_SIZE as usize]);
        assert_eq!(scan_state.address_in_scan, address);
        assert_eq!(scan_state.address_epoch, 1);
        assert_eq!(scan_state.scan_status, ScanStateStage::InProgress);
    }

    #[test]
    fn test_scan_state_create_finished() {
        let scan_state = ScanState::create_finished_state();
        assert_eq!(scan_state.cursor, 0);
        assert_eq!(scan_state.scanned_slots_map, [0; BITS_ARRAY_SIZE as usize]);
        assert_eq!(scan_state.address_in_scan, Arc::new(String::new()));
        assert_eq!(scan_state.address_epoch, 0);
        assert_eq!(scan_state.scan_status, ScanStateStage::Finished);
    }

    #[test]
    fn test_mark_slot_as_scanned() {
        let mut scanned_slots_map = [0; BITS_ARRAY_SIZE as usize];
        mark_slot_as_scanned(&mut scanned_slots_map, 5);
        assert_eq!(scanned_slots_map[0], 1 << 5);
        // Slots beyond the first word land in the following word.
        mark_slot_as_scanned(&mut scanned_slots_map, BITS_PER_U64 + 1);
        assert_eq!(scanned_slots_map[1], 1 << 1);
    }

    #[test]
    fn test_next_slot() {
        let scan_state = ScanState::new(
            0,
            [0; BITS_ARRAY_SIZE as usize],
            Arc::new("127.0.0.1:6379".to_string()),
            1,
            ScanStateStage::InProgress,
        );
        let next_slot = next_slot(&scan_state.scanned_slots_map);
        assert_eq!(next_slot, Some(0));
    }

    #[test]
    fn test_next_slot_skips_scanned_slots() {
        let mut scanned_slots_map = [0; BITS_ARRAY_SIZE as usize];
        for slot in 0..=BITS_PER_U64 {
            mark_slot_as_scanned(&mut scanned_slots_map, slot);
        }
        assert_eq!(next_slot(&scanned_slots_map), Some(BITS_PER_U64 + 1));
    }

    #[test]
    fn test_next_slot_all_scanned() {
        let scanned_slots_map = [u64::MAX; BITS_ARRAY_SIZE as usize];
        assert_eq!(next_slot(&scanned_slots_map), Some(END_OF_SCAN));
    }

    #[test]
    fn test_object_type_from_string() {
        assert_eq!(ObjectType::from("ZSET".to_string()), ObjectType::ZSet);
        assert_eq!(ObjectType::from("stream".to_string()), ObjectType::Stream);
        assert_eq!(ObjectType::from("unknown".to_string()), ObjectType::String);
    }
}