use crate::client::{ClientWrapper, PubSubCommandApplier};
use crate::cluster::routing::{Routable, SingleNodeRoutingInfo};
use crate::cluster::slotmap::SlotMap;
use crate::cmd::{self, Cmd};
use crate::connection::info::{
PubSubChannelOrPattern, PubSubSubscriptionInfo, PubSubSubscriptionKind,
};
use crate::pubsub::synchronizer_trait::PubSubSynchronizer;
use crate::value::{ErrorKind, Error, Result, Value};
use async_trait::async_trait;
use once_cell::sync::OnceCell;
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, Mutex, RwLock, Weak};
use std::time::{Duration, Instant};
use tokio::sync::{mpsc, Notify, RwLock as TokioRwLock};
/// Subscription kinds reconciled in cluster mode; clusters additionally
/// support sharded pub/sub.
const CLUSTER_KINDS: &[PubSubSubscriptionKind] = &[
    PubSubSubscriptionKind::Exact,
    PubSubSubscriptionKind::Pattern,
    PubSubSubscriptionKind::Sharded,
];
/// Subscription kinds reconciled in standalone mode (no sharded pub/sub).
const STANDALONE_KINDS: &[PubSubSubscriptionKind] = &[
    PubSubSubscriptionKind::Exact,
    PubSubSubscriptionKind::Pattern,
];
/// Initial delay of the resubscription retry loop, in milliseconds.
const RESUBSCRIBE_INITIAL_BACKOFF_MS: u64 = 200;
/// Ceiling the exponential backoff delay is clamped to, in milliseconds.
const RESUBSCRIBE_MAX_BACKOFF_MS: u64 = 2000;
/// Number of retry ticks a single backoff task emits before giving up.
const RESUBSCRIBE_MAX_ATTEMPTS: u32 = 8;
/// Events consumed by the synchronizer's background loop.
enum SyncEvent {
    /// The desired subscription set changed; run a reconcile pass.
    DesiredChanged,
    /// The cluster topology changed. `migrations` holds confirmed
    /// subscriptions that now live on the wrong node; `gone_subs` holds
    /// subscriptions whose node no longer exists. Tuples are
    /// (address, kind, channels).
    TopologyChanged {
        migrations: Vec<(String, PubSubSubscriptionKind, HashSet<PubSubChannelOrPattern>)>,
        gone_subs: Vec<(String, PubSubSubscriptionKind, HashSet<PubSubChannelOrPattern>)>,
    },
    /// Connections to these node addresses dropped; their confirmations
    /// must be discarded and the subscriptions re-established.
    NodeDisconnected { addresses: HashSet<String> },
}
/// Server-acknowledged subscriptions, tracked per node address.
#[derive(Default)]
struct ConfirmedState {
    // address -> (kind -> channels/patterns) confirmed on that node
    by_address: HashMap<String, PubSubSubscriptionInfo>,
}
impl ConfirmedState {
    /// Merge the per-node maps into one view of everything the server side
    /// has acknowledged, regardless of which node holds it.
    fn aggregate(&self) -> PubSubSubscriptionInfo {
        let mut merged = PubSubSubscriptionInfo::new();
        for node_subs in self.by_address.values() {
            for (kind, channels) in node_subs {
                merged
                    .entry(*kind)
                    .or_default()
                    .extend(channels.iter().cloned());
            }
        }
        merged
    }
    /// Record one confirmed channel/pattern for `address`.
    fn add(&mut self, kind: PubSubSubscriptionKind, channel: Vec<u8>, address: String) {
        let node_subs = self.by_address.entry(address).or_default();
        node_subs.entry(kind).or_default().insert(channel);
    }
    /// Drop a confirmed channel. Sharded subscriptions are node-specific, so
    /// only the given address is touched; exact/pattern subscriptions are
    /// removed from every node they were confirmed on.
    fn remove_exact(&mut self, kind: PubSubSubscriptionKind, channel: &[u8], address: &str) {
        if kind == PubSubSubscriptionKind::Sharded {
            let entry = self
                .by_address
                .get_mut(address)
                .and_then(|node_subs| node_subs.get_mut(&kind));
            if let Some(channels) = entry {
                channels.remove(channel);
            }
        } else {
            for node_subs in self.by_address.values_mut() {
                if let Some(channels) = node_subs.get_mut(&kind) {
                    channels.remove(channel);
                }
            }
        }
        self.gc();
    }
    /// Forget every confirmation recorded for the given addresses.
    fn clear_addresses(&mut self, addresses: &HashSet<String>) {
        self.by_address
            .retain(|addr, _| !addresses.contains(addr));
    }
    /// Prune empty channel sets and then empty per-address maps.
    fn gc(&mut self) {
        self.by_address.retain(|_, node_subs| {
            node_subs.retain(|_, channels| !channels.is_empty());
            !node_subs.is_empty()
        });
    }
}
/// Pub/sub subscription synchronizer driven by events rather than polling:
/// it tracks a "desired" subscription set, records what the server has
/// confirmed per node, and reconciles the two in a background task.
pub struct EventDrivenSynchronizer {
    // Set once after construction; weak so the synchronizer does not keep
    // the client alive.
    internal_client: OnceCell<Weak<TokioRwLock<ClientWrapper>>>,
    // Selects which subscription kinds are reconciled (sharded is cluster-only).
    is_cluster: bool,
    // What the caller has asked to be subscribed to.
    desired: RwLock<PubSubSubscriptionInfo>,
    // What the server has acknowledged, per node address.
    confirmed: RwLock<ConfirmedState>,
    // Feeds the background event loop.
    events_tx: mpsc::UnboundedSender<SyncEvent>,
    // Background loop handle; aborted on drop.
    task_handle: Mutex<Option<tokio::task::JoinHandle<()>>>,
    // Woken whenever desired and confirmed are observed to match.
    sync_notify: Notify,
    // Woken after every event the background loop processes.
    reconcile_complete_notify: Notify,
    // Timeout budget for intercepted client-facing operations.
    request_timeout: Duration,
    // True while a topology change is being applied; suppresses backoff retries.
    in_topology_change: std::sync::atomic::AtomicBool,
    // True while a resubscription backoff task is alive (at most one at a time).
    backoff_active: Arc<std::sync::atomic::AtomicBool>,
}
/// RAII guard that clears the shared `backoff_active` flag when the backoff
/// task finishes (or is cancelled), allowing a new backoff task to be spawned.
struct BackoffGuard(Arc<std::sync::atomic::AtomicBool>);
impl Drop for BackoffGuard {
    fn drop(&mut self) {
        self.0.store(false, std::sync::atomic::Ordering::Release);
    }
}
/// Classify a pub/sub command name.
///
/// Returns `(kind, is_subscribe, is_blocking)` for the six subscription verbs
/// and their `_BLOCKING` variants, or `None` for anything else.
fn command_info(cmd_str: &str) -> Option<(PubSubSubscriptionKind, bool, bool)> {
    // A `_BLOCKING` suffix selects the blocking variant of the same verb.
    let (base, is_blocking) = match cmd_str.strip_suffix("_BLOCKING") {
        Some(stripped) => (stripped, true),
        None => (cmd_str, false),
    };
    let (kind, is_subscribe) = match base {
        "SUBSCRIBE" => (PubSubSubscriptionKind::Exact, true),
        "UNSUBSCRIBE" => (PubSubSubscriptionKind::Exact, false),
        "PSUBSCRIBE" => (PubSubSubscriptionKind::Pattern, true),
        "PUNSUBSCRIBE" => (PubSubSubscriptionKind::Pattern, false),
        "SSUBSCRIBE" => (PubSubSubscriptionKind::Sharded, true),
        "SUNSUBSCRIBE" => (PubSubSubscriptionKind::Sharded, false),
        _ => return None,
    };
    Some((kind, is_subscribe, is_blocking))
}
impl EventDrivenSynchronizer {
    /// Build the synchronizer and spawn its background event loop.
    ///
    /// * `initial_subscriptions` — desired subscriptions to start from.
    /// * `is_cluster` — enables sharded pub/sub reconciliation.
    /// * `_reconciliation_interval` — ignored by this event-driven
    ///   implementation.
    /// * `request_timeout` — budget applied to intercepted client commands.
    pub fn new(
        initial_subscriptions: Option<PubSubSubscriptionInfo>,
        is_cluster: bool,
        _reconciliation_interval: Option<Duration>,
        request_timeout: Duration,
    ) -> Arc<Self> {
        let (events_tx, events_rx) = mpsc::unbounded_channel();
        let sync = Arc::new(Self {
            internal_client: OnceCell::new(),
            is_cluster,
            desired: RwLock::new(initial_subscriptions.unwrap_or_default()),
            confirmed: RwLock::new(ConfirmedState::default()),
            events_tx,
            task_handle: Mutex::new(None),
            sync_notify: Notify::new(),
            reconcile_complete_notify: Notify::new(),
            request_timeout,
            in_topology_change: std::sync::atomic::AtomicBool::new(false),
            backoff_active: Arc::new(std::sync::atomic::AtomicBool::new(false)),
        });
        // The loop holds only a Weak back-reference, so this cannot leak.
        sync.start_event_loop(events_rx);
        sync
    }
    /// Install the weak handle to the internal client. Only the first call
    /// has an effect; later calls are silently ignored (`OnceCell::set`).
    pub fn set_internal_client(&self, client: Weak<TokioRwLock<ClientWrapper>>) {
        let _ = self.internal_client.set(client);
    }
pub fn get_current_subscriptions_by_address(&self) -> HashMap<String, PubSubSubscriptionInfo> {
self.confirmed.read().unwrap().by_address.clone()
}
    /// Subscription kinds this deployment reconciles; sharded pub/sub only
    /// exists in cluster mode.
    #[inline]
    fn kinds(&self) -> &'static [PubSubSubscriptionKind] {
        if self.is_cluster {
            CLUSTER_KINDS
        } else {
            STANDALONE_KINDS
        }
    }
    /// Enqueue an event for the background loop. A send error means the
    /// receiver (the loop task) is gone, i.e. we are shutting down, so it is
    /// deliberately ignored.
    fn send_event(&self, event: SyncEvent) {
        let _ = self.events_tx.send(event);
    }
fn check_sync_and_notify(&self) {
let desired = self.desired.read().unwrap_or_else(|e| e.into_inner());
let confirmed = self.confirmed.read().unwrap_or_else(|e| e.into_inner());
let actual = confirmed.aggregate();
let is_synced = self.kinds().iter().all(|kind| {
let d = desired.get(kind).map(|s| s.len()).unwrap_or(0);
let a = actual.get(kind).map(|s| s.len()).unwrap_or(0);
if d != a {
return false;
}
match (desired.get(kind), actual.get(kind)) {
(Some(d_set), Some(a_set)) => d_set == a_set,
(None, None) => true,
(Some(d_set), None) => d_set.is_empty(),
(None, Some(a_set)) => a_set.is_empty(),
}
});
if is_synced {
tracing::debug!(
target: "ferriskey",
event = "pubsub_synced",
"ferriskey: pubsub subscription state synced"
);
self.sync_notify.notify_waiters();
} else {
tracing::warn!(
target: "ferriskey",
event = "pubsub_out_of_sync",
"ferriskey: pubsub subscription state drift detected"
);
}
}
    /// Spawn the background task that serializes all synchronization work.
    ///
    /// The task holds only a `Weak` reference, so it cannot keep the
    /// synchronizer alive; it exits when the synchronizer is dropped or the
    /// event channel closes. Bursts of identical events are coalesced via
    /// `try_recv` before acting, so e.g. many queued `DesiredChanged`
    /// notifications trigger a single reconcile.
    fn start_event_loop(self: &Arc<Self>, mut events_rx: mpsc::UnboundedReceiver<SyncEvent>) {
        let sync_weak = Arc::downgrade(self);
        let handle = tokio::spawn(async move {
            loop {
                let event = events_rx.recv().await;
                // Upgrade after the await: if the synchronizer is gone, stop.
                let Some(sync) = sync_weak.upgrade() else {
                    break;
                };
                // Channel closed -> shutdown.
                let Some(event) = event else {
                    break;
                };
                match event {
                    SyncEvent::DesiredChanged => {
                        // Drain any queued DesiredChanged duplicates; the first
                        // different event is pushed back onto the queue.
                        // NOTE(review): the re-sent event lands at the back of
                        // the queue, behind events that arrived while
                        // coalescing — minor reordering is possible.
                        let mut deferred = Vec::new();
                        while let Ok(evt) = events_rx.try_recv() {
                            match evt {
                                SyncEvent::DesiredChanged => {}
                                other => {
                                    deferred.push(other);
                                    break;
                                }
                            }
                        }
                        for evt in deferred {
                            let _ = sync.events_tx.send(evt);
                        }
                        if let Err(e) = sync.reconcile().await {
                            tracing::error!("pubsub_sync - Reconcile failed: {e:?}");
                        }
                    }
                    SyncEvent::TopologyChanged { migrations, gone_subs } => {
                        // Merge consecutive topology events into one batch.
                        let mut latest_mig = migrations;
                        let mut latest_gone = gone_subs;
                        let mut deferred = Vec::new();
                        while let Ok(evt) = events_rx.try_recv() {
                            match evt {
                                SyncEvent::TopologyChanged { migrations, gone_subs } => {
                                    latest_mig.extend(migrations);
                                    latest_gone.extend(gone_subs);
                                }
                                other => {
                                    deferred.push(other);
                                    break;
                                }
                            }
                        }
                        for evt in deferred {
                            let _ = sync.events_tx.send(evt);
                        }
                        sync.on_topology_changed(latest_mig, latest_gone).await;
                    }
                    SyncEvent::NodeDisconnected { addresses } => {
                        sync.on_node_disconnected(&addresses).await;
                    }
                }
                // Re-evaluate sync state and wake blocked waiters after every
                // processed event.
                sync.check_sync_and_notify();
                sync.reconcile_complete_notify.notify_waiters();
            }
        });
        *self
            .task_handle
            .lock()
            .unwrap_or_else(|e| e.into_inner()) = Some(handle);
    }
    /// One reconciliation pass: subscribe to everything desired but not yet
    /// confirmed, then unsubscribe anything confirmed but no longer desired.
    ///
    /// Snapshots of `desired` and the aggregated confirmed state are taken
    /// up front (and the unsubscribe work list is built inside its own lock
    /// scope) so that no std lock guard is ever held across an `.await`.
    async fn reconcile(&self) -> Result<()> {
        let desired = self
            .desired
            .read()
            .unwrap_or_else(|e| e.into_inner())
            .clone();
        let actual = self
            .confirmed
            .read()
            .unwrap_or_else(|e| e.into_inner())
            .aggregate();
        // Phase 1: issue subscriptions for desired channels missing from the
        // confirmed aggregate (routing left to the client's default).
        for kind in self.kinds() {
            let desired_channels = desired.get(kind);
            let actual_channels = actual.get(kind);
            let to_sub: HashSet<_> = desired_channels
                .iter()
                .flat_map(|d| d.iter())
                .filter(|ch| actual_channels.as_ref().is_none_or(|a| !a.contains(*ch)))
                .cloned()
                .collect();
            if !to_sub.is_empty() {
                self.send_subscription_cmd(to_sub, *kind, true, None)
                    .await;
            }
        }
        // Phase 2: collect per-address unsubscribe work under the read lock,
        // then release it before sending any commands.
        let unsub_work: Vec<(String, PubSubSubscriptionKind, HashSet<PubSubChannelOrPattern>)> = {
            let confirmed = self
                .confirmed
                .read()
                .unwrap_or_else(|e| e.into_inner());
            let mut work = Vec::new();
            for (addr, addr_subs) in &confirmed.by_address {
                for (kind, channels) in addr_subs {
                    let desired_for_kind = desired.get(kind);
                    let to_unsub: HashSet<_> = channels
                        .iter()
                        .filter(|ch| desired_for_kind.is_none_or(|d| !d.contains(*ch)))
                        .cloned()
                        .collect();
                    if !to_unsub.is_empty() {
                        work.push((addr.clone(), *kind, to_unsub));
                    }
                }
            }
            work
        };
        // Unsubscribes are routed to the node that confirmed them; sharded
        // unsubscribes must additionally be grouped by slot.
        for (addr, kind, to_unsub) in unsub_work {
            let routing = parse_address_routing(&addr).ok();
            if kind == PubSubSubscriptionKind::Sharded {
                self.send_sharded_unsubscribe_by_slot(to_unsub, routing)
                    .await;
            } else {
                self.send_subscription_cmd(to_unsub, kind, false, routing)
                    .await;
            }
        }
        Ok(())
    }
    /// Unsubscribe channels that moved to another shard (`migrations`) or
    /// whose node disappeared (`gone_subs`), then run a full reconcile so
    /// they get re-subscribed at their new home.
    ///
    /// `in_topology_change` is raised for the duration so the resubscription
    /// backoff task is not spawned in reaction to unsubscribes we issue
    /// ourselves here.
    async fn on_topology_changed(
        &self,
        migrations: Vec<(String, PubSubSubscriptionKind, HashSet<PubSubChannelOrPattern>)>,
        gone_subs: Vec<(String, PubSubSubscriptionKind, HashSet<PubSubChannelOrPattern>)>,
    ) {
        if migrations.is_empty() && gone_subs.is_empty() {
            return;
        }
        self.in_topology_change.store(true, std::sync::atomic::Ordering::Release);
        for (addr, kind, channels) in migrations.iter().chain(gone_subs.iter()) {
            let routing = parse_address_routing(addr).ok();
            if *kind == PubSubSubscriptionKind::Sharded {
                // Sharded unsubscribes must be grouped per slot.
                self.send_sharded_unsubscribe_by_slot(channels.clone(), routing)
                    .await;
            } else {
                self.send_subscription_cmd(channels.clone(), *kind, false, routing)
                    .await;
            }
        }
        self.in_topology_change.store(false, std::sync::atomic::Ordering::Release);
        if let Err(e) = self.reconcile().await {
            tracing::error!("pubsub_sync - Post-topology reconcile failed: {e:?}");
        }
    }
async fn on_node_disconnected(&self, addresses: &HashSet<String>) {
if addresses.is_empty() {
return;
}
tracing::debug!("pubsub_sync - Clearing confirmations for disconnected: {addresses:?}");
{
let mut confirmed = self.confirmed.write().unwrap_or_else(|e| e.into_inner());
confirmed.clear_addresses(addresses);
}
if let Err(e) = self.reconcile().await {
tracing::error!("pubsub_sync - Post-disconnect reconcile failed: {e:?}");
}
}
async fn send_subscription_cmd(
&self,
channels: HashSet<PubSubChannelOrPattern>,
kind: PubSubSubscriptionKind,
is_subscribe: bool,
routing: Option<SingleNodeRoutingInfo>,
) {
if channels.is_empty() {
return;
}
let cmd_name = match (kind, is_subscribe) {
(PubSubSubscriptionKind::Exact, true) => "SUBSCRIBE",
(PubSubSubscriptionKind::Exact, false) => "UNSUBSCRIBE",
(PubSubSubscriptionKind::Pattern, true) => "PSUBSCRIBE",
(PubSubSubscriptionKind::Pattern, false) => "PUNSUBSCRIBE",
(PubSubSubscriptionKind::Sharded, true) => "SSUBSCRIBE",
(PubSubSubscriptionKind::Sharded, false) => "SUNSUBSCRIBE",
};
let mut command = cmd::cmd(cmd_name);
for channel in &channels {
command.arg(channel.as_slice());
}
if kind == PubSubSubscriptionKind::Sharded && !is_subscribe {
command.set_fenced(true);
}
match self.apply_pubsub(&mut command, routing).await {
Ok(_) => {}
Err(e) => {
let action = if is_subscribe { "subscribe" } else { "unsubscribe" };
tracing::error!("pubsub_sync - Failed to {action} {kind:?}: {e:?}");
}
}
}
async fn send_sharded_unsubscribe_by_slot(
&self,
channels: HashSet<PubSubChannelOrPattern>,
routing: Option<SingleNodeRoutingInfo>,
) {
let by_slot: HashMap<u16, HashSet<_>> =
channels.into_iter().fold(HashMap::new(), |mut acc, ch| {
let slot = crate::cluster::topology::get_slot(&ch);
acc.entry(slot).or_default().insert(ch);
acc
});
for (_, slot_channels) in by_slot {
self.send_subscription_cmd(
slot_channels,
PubSubSubscriptionKind::Sharded,
false,
routing.clone(),
)
.await;
}
}
async fn apply_pubsub(
&self,
cmd: &mut Cmd,
routing: Option<SingleNodeRoutingInfo>,
) -> Result<Value> {
let client_arc = self
.internal_client
.get()
.ok_or_else(|| {
Error::from((
ErrorKind::ClientError,
"Internal client not set in synchronizer",
))
})?
.upgrade()
.ok_or_else(|| {
Error::from((ErrorKind::ClientError, "Internal client has been dropped"))
})?;
let mut client_wrapper = {
let guard = client_arc.read().await;
guard.clone()
};
client_wrapper.apply_pubsub_command(cmd, routing).await
}
fn extract_channels(cmd: &Cmd) -> Vec<PubSubChannelOrPattern> {
cmd.args_iter()
.skip(1)
.filter_map(|arg| match arg {
cmd::Arg::Simple(bytes) => Some(bytes.to_vec()),
cmd::Arg::Cursor => None,
})
.collect()
}
fn extract_channels_and_timeout(cmd: &Cmd) -> (Vec<PubSubChannelOrPattern>, u64) {
let args: Vec<_> = cmd
.args_iter()
.skip(1)
.filter_map(|arg| match arg {
cmd::Arg::Simple(bytes) => Some(bytes.to_vec()),
cmd::Arg::Cursor => None,
})
.collect();
if args.is_empty() {
return (Vec::new(), 0);
}
let last_is_timeout = args.len() > 1
&& args
.last()
.and_then(|arg| String::from_utf8_lossy(arg).parse::<u64>().ok())
.is_some();
if last_is_timeout {
let timeout_ms = String::from_utf8_lossy(args.last().unwrap())
.parse::<u64>()
.unwrap_or(0);
let channels = args[..args.len() - 1].to_vec();
(channels, timeout_ms)
} else {
(args, 0)
}
}
fn handle_lazy(
&self,
cmd: &Cmd,
kind: PubSubSubscriptionKind,
is_subscribe: bool,
) -> Result<Value> {
let channels = Self::extract_channels(cmd);
if is_subscribe && channels.is_empty() {
return Err(Error::from((
ErrorKind::ClientError,
"No channels provided for subscription",
)));
}
let channels_set = if channels.is_empty() {
None
} else {
Some(channels.into_iter().collect())
};
if is_subscribe {
self.add_desired_subscriptions(channels_set.unwrap(), kind);
} else {
self.remove_desired_subscriptions(channels_set, kind);
}
Ok(Value::Nil)
}
    /// Handle a *_BLOCKING (un)subscribe: update the desired state, then wait
    /// until the confirmed state reflects the change, bounded by the timeout
    /// taken from the command's trailing argument (ms; 0 = fail fast).
    async fn handle_blocking(
        &self,
        cmd: &Cmd,
        kind: PubSubSubscriptionKind,
        is_subscribe: bool,
    ) -> Result<Value> {
        let (channels, timeout_ms) = Self::extract_channels_and_timeout(cmd);
        if is_subscribe && channels.is_empty() {
            return Err(Error::from((
                ErrorKind::ClientError,
                "No channels provided for subscription",
            )));
        }
        let channels_set: HashSet<PubSubChannelOrPattern> = channels.into_iter().collect();
        if is_subscribe {
            self.add_desired_subscriptions(channels_set.clone(), kind);
        } else {
            // Unsubscribe with no channels means "drop everything of this
            // kind" (None); the empty set is then waited on below, which
            // checks that both sides are empty for this kind.
            let to_remove = if channels_set.is_empty() {
                None
            } else {
                Some(channels_set.clone())
            };
            self.remove_desired_subscriptions(to_remove, kind);
        }
        // Route the expected set to the matching wait_for_sync parameter.
        let (expected_channels, expected_patterns, expected_sharded) = match kind {
            PubSubSubscriptionKind::Exact => (Some(channels_set), None, None),
            PubSubSubscriptionKind::Pattern => (None, Some(channels_set), None),
            PubSubSubscriptionKind::Sharded => (None, None, Some(channels_set)),
        };
        self.wait_for_sync(timeout_ms, expected_channels, expected_patterns, expected_sharded)
            .await?;
        Ok(Value::Nil)
    }
fn get_subscriptions_value(&self) -> Value {
let (desired, actual) = self.get_subscription_state();
Value::Array(vec![
Ok(Value::BulkString(bytes::Bytes::from_static(b"desired"))),
Ok(sub_map_to_value(desired)),
Ok(Value::BulkString(bytes::Bytes::from_static(b"actual"))),
Ok(sub_map_to_value(actual)),
])
}
async fn run_with_timeout<T, F>(&self, f: F) -> Result<T>
where
F: FnOnce() -> Result<T> + Send,
T: Send,
{
match tokio::time::timeout(self.request_timeout, async move { f() }).await {
Ok(result) => result,
Err(_) => Err(std::io::Error::from(std::io::ErrorKind::TimedOut).into()),
}
}
fn schedule_resubscription_backoff(
&self,
channels: &HashSet<PubSubChannelOrPattern>,
subscription_type: PubSubSubscriptionKind,
) {
let desired = self.desired.read().unwrap_or_else(|e| e.into_inner()).clone();
let still_desired = desired.get(&subscription_type)
.is_some_and(|channels_set| channels.iter().any(|ch| channels_set.contains(ch)));
let in_change = self.in_topology_change.load(std::sync::atomic::Ordering::Acquire);
if still_desired && !in_change
&& !self.backoff_active.swap(true, std::sync::atomic::Ordering::AcqRel)
{
let tx = self.events_tx.clone();
let backoff_flag = self.backoff_active.clone();
tokio::spawn(async move {
let _guard = BackoffGuard(backoff_flag);
let mut delay_ms = RESUBSCRIBE_INITIAL_BACKOFF_MS;
for _ in 0..RESUBSCRIBE_MAX_ATTEMPTS {
let jitter_range = delay_ms / 5;
let jitter_offset = if jitter_range > 0 {
rand::random::<u64>() % (2 * jitter_range + 1)
} else { 0 };
let actual_delay = if jitter_offset >= jitter_range {
Duration::from_millis(delay_ms + (jitter_offset - jitter_range))
} else {
Duration::from_millis(delay_ms - (jitter_range - jitter_offset))
};
tokio::time::sleep(actual_delay).await;
let _ = tx.send(SyncEvent::DesiredChanged);
delay_ms = (delay_ms * 2).min(RESUBSCRIBE_MAX_BACKOFF_MS);
}
});
}
}
}
impl Drop for EventDrivenSynchronizer {
    /// Abort the background event loop; a poisoned handle lock is tolerated
    /// since we are tearing down anyway.
    fn drop(&mut self) {
        let mut slot = self
            .task_handle
            .lock()
            .unwrap_or_else(|e| e.into_inner());
        if let Some(handle) = slot.take() {
            handle.abort();
        }
    }
}
#[async_trait]
impl PubSubSynchronizer for EventDrivenSynchronizer {
    /// Expose `self` for downcasting through the trait object.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
fn add_desired_subscriptions(
&self,
channels: HashSet<PubSubChannelOrPattern>,
subscription_type: PubSubSubscriptionKind,
) {
{
let mut desired = self.desired.write().unwrap_or_else(|e| e.into_inner());
desired.entry(subscription_type).or_default().extend(channels);
}
self.send_event(SyncEvent::DesiredChanged);
}
fn remove_desired_subscriptions(
&self,
channels: Option<HashSet<PubSubChannelOrPattern>>,
subscription_type: PubSubSubscriptionKind,
) {
{
let mut desired = self.desired.write().unwrap_or_else(|e| e.into_inner());
match channels {
Some(to_remove) => {
if let Some(existing) = desired.get_mut(&subscription_type) {
for ch in to_remove {
existing.remove(&ch);
}
}
}
None => {
desired.remove(&subscription_type);
}
}
}
self.send_event(SyncEvent::DesiredChanged);
}
fn add_current_subscriptions(
&self,
channels: HashSet<PubSubChannelOrPattern>,
subscription_type: PubSubSubscriptionKind,
address: String,
) {
let mut confirmed = self.confirmed.write().unwrap_or_else(|e| e.into_inner());
for channel in channels {
confirmed.add(subscription_type, channel, address.clone());
}
drop(confirmed);
self.check_sync_and_notify();
}
fn remove_current_subscriptions(
&self,
channels: HashSet<PubSubChannelOrPattern>,
subscription_type: PubSubSubscriptionKind,
address: String,
) {
let mut confirmed = self.confirmed.write().unwrap_or_else(|e| e.into_inner());
for channel in &channels {
confirmed.remove_exact(subscription_type, channel, &address);
}
drop(confirmed);
self.check_sync_and_notify();
self.schedule_resubscription_backoff(&channels, subscription_type);
}
fn get_subscription_state(
&self,
) -> (PubSubSubscriptionInfo, PubSubSubscriptionInfo) {
let desired = self
.desired
.read()
.unwrap_or_else(|e| e.into_inner())
.clone();
let actual = self
.confirmed
.read()
.unwrap_or_else(|e| e.into_inner())
.aggregate();
(desired, actual)
}
    /// Ask the background loop to run a reconcile pass.
    fn trigger_reconciliation(&self) {
        self.send_event(SyncEvent::DesiredChanged);
    }
fn remove_current_subscriptions_for_addresses(&self, addresses: &HashSet<String>) {
if !addresses.is_empty() {
self.send_event(SyncEvent::NodeDisconnected {
addresses: addresses.clone(),
});
}
}
    /// Diff confirmed subscriptions against a freshly discovered slot map.
    ///
    /// Subscriptions on nodes absent from the new map are collected as
    /// `gone_subs`; subscriptions whose slot no longer maps to the node
    /// holding them are collected as `migrations`. Both are removed from the
    /// confirmed state and a `TopologyChanged` event is queued so the
    /// background loop can move them. When nothing moved but desired and
    /// confirmed still differ, a plain reconcile is triggered instead.
    fn handle_topology_refresh(&self, new_slot_map: &SlotMap) {
        let new_addresses: HashSet<String> = new_slot_map
            .all_node_addresses()
            .iter()
            .map(|arc| arc.to_string())
            .collect();
        let migrations: Vec<(String, PubSubSubscriptionKind, HashSet<PubSubChannelOrPattern>)>;
        let gone_subs: Vec<(String, PubSubSubscriptionKind, HashSet<PubSubChannelOrPattern>)>;
        let confirmed_keys: Vec<String>;
        {
            // Phase 1: read-only diff under the read lock.
            let confirmed = self.confirmed.read().unwrap_or_else(|e| e.into_inner());
            confirmed_keys = confirmed.by_address.keys().cloned().collect();
            let mut mig = Vec::new();
            let mut gone = Vec::new();
            for (addr, addr_subs) in &confirmed.by_address {
                // Node vanished entirely: all of its subscriptions are gone.
                if !new_addresses.contains(addr) {
                    for (kind, channels) in addr_subs {
                        if !channels.is_empty() {
                            gone.push((addr.clone(), *kind, channels.clone()));
                        }
                    }
                    continue;
                }
                // Node still exists: check each channel's slot placement.
                for (kind, channels) in addr_subs {
                    let mut migrated = HashSet::new();
                    for channel in channels {
                        let slot = crate::cluster::topology::get_slot(channel);
                        if let Some(shard_addrs) = new_slot_map.shard_addrs_for_slot(slot) {
                            let needs_migration = match kind {
                                // Sharded subscriptions must live on the
                                // shard primary.
                                PubSubSubscriptionKind::Sharded => {
                                    shard_addrs.primary().as_str() != addr
                                }
                                // Exact/pattern subscriptions may live on any
                                // member of the shard.
                                PubSubSubscriptionKind::Exact
                                | PubSubSubscriptionKind::Pattern => {
                                    !shard_addrs.is_member(addr)
                                }
                            };
                            if needs_migration {
                                migrated.insert(channel.clone());
                            }
                        } else {
                            // No shard owns the slot anymore: treat as migrated.
                            migrated.insert(channel.clone());
                        }
                    }
                    if !migrated.is_empty() {
                        mig.push((addr.clone(), *kind, migrated));
                    }
                }
            }
            migrations = mig;
            gone_subs = gone;
        }
        {
            let new_addrs_count = new_addresses.len();
            let migrations_count = migrations.len();
            let gone_count = gone_subs.len();
            tracing::debug!("pubsub_sync - handle_topology_refresh: confirmed_addrs={confirmed_keys:?}, new_addrs count={new_addrs_count}, migrations={migrations_count}, gone={gone_count}");
        }
        if migrations.is_empty() && gone_subs.is_empty() {
            // Topology is stable; still nudge the loop if state has drifted.
            let desired = self.desired.read().unwrap_or_else(|e| e.into_inner()).clone();
            let actual = self.confirmed.read().unwrap_or_else(|e| e.into_inner()).aggregate();
            if desired != actual {
                self.send_event(SyncEvent::DesiredChanged);
            }
            return;
        }
        {
            // Phase 2: drop the affected confirmations under the write lock.
            let mut confirmed = self.confirmed.write().unwrap_or_else(|e| e.into_inner());
            for (addr, _, _) in &gone_subs {
                confirmed.by_address.remove(addr);
            }
            for (addr, kind, channels) in &migrations {
                if let Some(addr_subs) = confirmed.by_address.get_mut(addr)
                    && let Some(existing) = addr_subs.get_mut(kind)
                {
                    for ch in channels {
                        existing.remove(ch);
                    }
                }
            }
            confirmed.gc();
        }
        // Phase 3: let the background loop unsubscribe and re-place them.
        self.send_event(SyncEvent::TopologyChanged {
            migrations,
            gone_subs,
        });
    }
async fn intercept_pubsub_command(&self, cmd: &Cmd) -> Option<Result<Value>> {
let command_name = cmd.command().unwrap_or_default();
let command_str = std::str::from_utf8(&command_name).unwrap_or("");
if let Some((kind, is_subscribe, is_blocking)) = command_info(command_str) {
return if is_blocking {
Some(self.handle_blocking(cmd, kind, is_subscribe).await)
} else {
let cmd = cmd.clone();
Some(
self.run_with_timeout(|| self.handle_lazy(&cmd, kind, is_subscribe))
.await,
)
};
}
if command_str == "GET_SUBSCRIPTIONS" {
return Some(
self.run_with_timeout(|| Ok(self.get_subscriptions_value())).await,
);
}
None
}
    /// Block until desired and confirmed subscription state agree.
    ///
    /// With all expected sets `None`, waits for full per-kind equality.
    /// Otherwise waits until, for every listed channel, presence in desired
    /// matches presence in confirmed (an empty expected set waits for both
    /// sides to be empty for that kind). `timeout_ms == 0` means fail fast:
    /// if the condition is not already met, TimedOut is returned immediately.
    async fn wait_for_sync(
        &self,
        timeout_ms: u64,
        expected_channels: Option<HashSet<PubSubChannelOrPattern>>,
        expected_patterns: Option<HashSet<PubSubChannelOrPattern>>,
        expected_sharded: Option<HashSet<PubSubChannelOrPattern>>,
    ) -> Result<()> {
        let deadline = if timeout_ms > 0 {
            Some(Instant::now() + Duration::from_millis(timeout_ms))
        } else {
            None
        };
        loop {
            // Register for notification BEFORE checking the condition, so a
            // notify between check and await is not lost.
            let notified = self.reconcile_complete_notify.notified();
            let condition_met = {
                if expected_channels.is_none()
                    && expected_patterns.is_none()
                    && expected_sharded.is_none()
                {
                    // Global mode: every kind must match exactly (missing
                    // entry == empty set).
                    let desired = self.desired.read().unwrap_or_else(|e| e.into_inner());
                    let actual = self
                        .confirmed
                        .read()
                        .unwrap_or_else(|e| e.into_inner())
                        .aggregate();
                    self.kinds().iter().all(|kind| {
                        let d = desired.get(kind);
                        let a = actual.get(kind);
                        match (d, a) {
                            (Some(d_set), Some(a_set)) => d_set == a_set,
                            (None, None) => true,
                            (Some(d_set), None) => d_set.is_empty(),
                            (None, Some(a_set)) => a_set.is_empty(),
                        }
                    })
                } else {
                    // Targeted mode: only the listed channels must agree.
                    let (desired, actual) = self.get_subscription_state();
                    let check = |channels: &Option<HashSet<PubSubChannelOrPattern>>,
                                 kind: PubSubSubscriptionKind|
                     -> bool {
                        channels.as_ref().is_none_or(|chs| {
                            let d = desired.get(&kind);
                            let a = actual.get(&kind);
                            if chs.is_empty() {
                                // Empty expectation: both sides empty.
                                let d_empty = d.is_none_or(|s| s.is_empty());
                                let a_empty = a.is_none_or(|s| s.is_empty());
                                d_empty && a_empty
                            } else {
                                // Each channel: desired-membership must equal
                                // confirmed-membership.
                                chs.iter().all(|ch| {
                                    let in_d = d.is_some_and(|s| s.contains(ch));
                                    let in_a = a.is_some_and(|s| s.contains(ch));
                                    in_d == in_a
                                })
                            }
                        })
                    };
                    check(&expected_channels, PubSubSubscriptionKind::Exact)
                        && check(&expected_patterns, PubSubSubscriptionKind::Pattern)
                        && check(&expected_sharded, PubSubSubscriptionKind::Sharded)
                }
            };
            if condition_met {
                self.check_sync_and_notify();
                return Ok(());
            }
            // No deadline (timeout_ms == 0): fail fast without waiting.
            if deadline.is_none() {
                return Err(std::io::Error::from(std::io::ErrorKind::TimedOut).into());
            }
            // Nudge the loop in case the pending change hasn't been acted on.
            self.trigger_reconciliation();
            let deadline = deadline.unwrap();
            let remaining = deadline.saturating_duration_since(Instant::now());
            if remaining.is_zero() {
                return Err(std::io::Error::from(std::io::ErrorKind::TimedOut).into());
            }
            tokio::select! {
                _ = notified => {}
                _ = tokio::time::sleep(remaining) => {
                    return Err(std::io::Error::from(std::io::ErrorKind::TimedOut).into());
                }
            }
        }
    }
}
fn parse_address_routing(address: &str) -> Result<SingleNodeRoutingInfo> {
let (host, port_str) = address.rsplit_once(':').ok_or_else(|| {
Error::from((
ErrorKind::ClientError,
"Invalid address format",
address.to_string(),
))
})?;
let port = port_str
.parse()
.map_err(|_| Error::from((ErrorKind::ClientError, "Invalid port")))?;
Ok(SingleNodeRoutingInfo::ByAddress {
host: host.to_string(),
port,
})
}
/// Render a subscription map as a protocol `Value::Map` of
/// kind-name -> array of channel bulk strings.
fn sub_map_to_value(map: PubSubSubscriptionInfo) -> Value {
    let mut entries = Vec::new();
    for (kind, values) in map {
        let key = match kind {
            PubSubSubscriptionKind::Exact => "Exact",
            PubSubSubscriptionKind::Pattern => "Pattern",
            PubSubSubscriptionKind::Sharded => "Sharded",
        };
        let channels: Vec<Value> = values
            .into_iter()
            .map(|v| Value::BulkString(bytes::Bytes::from(v)))
            .collect();
        entries.push((
            Value::BulkString(bytes::Bytes::from(key.as_bytes().to_vec())),
            Value::Array(channels.into_iter().map(Ok).collect()),
        ));
    }
    Value::Map(entries)
}