use crate::{scripts, BaseDistributedObject, RObject, RObjectBase, RedissonResult, SyncRedisConnectionManager};
use parking_lot::Mutex;
use serde::{de::DeserializeOwned, Serialize};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread;
/// Callback interface for messages delivered by an `RTopic`.
///
/// Implementors must be `Send + Sync + 'static` because listeners are stored
/// behind `Arc` and invoked from the topic's background subscription thread.
pub trait MessageListener<V: DeserializeOwned + Send + Sync + 'static>: Send + Sync + 'static {
    /// Handles one decoded `message` that arrived on `channel`.
    fn on_message(&self, channel: &str, message: V);
}
/// A handle to a pub/sub topic whose messages are exchanged as `V`.
///
/// Cloned handles share the same listener list, subscription flag and
/// subscription thread, because those fields are held in `Arc`s.
pub struct RTopic<V> {
    /// Name/connection plumbing shared by distributed objects.
    base: BaseDistributedObject,
    /// Ties the handle to its message type without storing a `V`.
    _marker: std::marker::PhantomData<V>,
    /// Registered listeners; also shared with the subscription thread.
    listeners: Arc<Mutex<Vec<Arc<dyn MessageListener<V>>>>>,
    /// Whether a subscription thread is (or should be) running.
    is_subscribed: Arc<AtomicBool>,
    /// Join handle of the current subscription thread, if any.
    subscription_thread: Arc<Mutex<Option<thread::JoinHandle<()>>>>,
}

// Implemented by hand: `#[derive(Clone)]` would add a `V: Clone` bound even
// though no `V` value is stored; every field is cloneable for any `V`.
impl<V> Clone for RTopic<V> {
    fn clone(&self) -> Self {
        Self {
            base: self.base.clone(),
            _marker: std::marker::PhantomData,
            listeners: Arc::clone(&self.listeners),
            is_subscribed: Arc::clone(&self.is_subscribed),
            subscription_thread: Arc::clone(&self.subscription_thread),
        }
    }
}
impl<V> RTopic<V>
where
V: Serialize + DeserializeOwned + Send + Sync + 'static + Clone,
{
pub fn new(connection_manager: Arc<SyncRedisConnectionManager>, name: String) -> Self {
Self {
base: BaseDistributedObject::new(connection_manager, name),
_marker: std::marker::PhantomData,
listeners: Arc::new(Mutex::new(Vec::new())),
is_subscribed: Arc::new(AtomicBool::new(false)),
subscription_thread: Arc::new(Mutex::new(None)),
}
}
pub fn publish(&self, message: &V) -> RedissonResult<u64> {
let mut conn = self.base.get_connection()?;
let message_json = BaseDistributedObject::serialize(message)?;
let receivers: u64 = scripts::PUBLISH_SCRIPT
.key(self.base.get_full_key())
.arg(message_json)
.invoke(&mut conn)?;
Ok(receivers)
}
pub fn add_listener<L>(&self, listener: L) -> RedissonResult<()>
where
L: MessageListener<V>,
{
let arc_listener = Arc::new(listener);
{
let mut listeners = self.listeners.lock();
listeners.push(arc_listener);
}
if !self.is_subscribed.load(Ordering::Acquire) {
self.start_subscription()?;
}
Ok(())
}
pub fn remove_listener<L>(&self, listener: &L) -> RedissonResult<bool>
where
L: MessageListener<V> + ?Sized,
{
let ptr = listener as *const L as *const ();
let mut listeners = self.listeners.lock();
let original_len = listeners.len();
listeners.retain(|l| {
let listener_ptr = l.as_ref() as *const dyn MessageListener<V> as *const ();
listener_ptr != ptr
});
let removed = original_len != listeners.len();
if listeners.is_empty() {
self.stop_subscription()?;
}
Ok(removed)
}
pub fn listener_count(&self) -> usize {
self.listeners.lock().len()
}
pub fn clear_listeners(&self) -> RedissonResult<()> {
{
let mut listeners = self.listeners.lock();
listeners.clear();
}
self.stop_subscription()?;
Ok(())
}
pub fn is_subscribed(&self) -> bool {
self.is_subscribed.load(Ordering::Acquire)
}
pub fn get_topic_name(&self) -> &str {
&self.base.get_name()
}
pub fn get_channel_name(&self) -> String {
self.base.get_full_key()
}
fn start_subscription(&self) -> RedissonResult<()> {
if self.is_subscribed.swap(true, Ordering::SeqCst) {
return Ok(());
}
let channel = self.base.get_full_key();
let connection_manager = self.base.connection_manager().clone();
let listeners = self.listeners.clone();
let is_subscribed = self.is_subscribed.clone();
let handle = thread::spawn(move || {
match connection_manager.get_connection() {
Ok(mut conn) => {
let mut pubsub_conn = conn.as_pubsub().unwrap();
if let Err(e) = pubsub_conn.subscribe(&channel) {
eprintln!("Failed to subscribe to channel {}: {}", channel, e);
is_subscribed.store(false, Ordering::SeqCst);
return;
}
while is_subscribed.load(Ordering::Acquire) {
match pubsub_conn.get_message() {
Ok(msg) => {
if let Ok(payload) = msg.get_payload::<String>() {
let message_result: Result<V, _> = BaseDistributedObject::deserialize(&payload);
match message_result {
Ok(message) => {
let listeners_guard = listeners.lock();
for listener in listeners_guard.iter() {
listener.on_message(&channel, message.clone());
}
}
Err(e) => {
eprintln!("Failed to deserialize message: {}", e);
}
}
}
}
Err(e) => {
eprintln!("Failed to get message: {}", e);
break;
}
}
}
let _ = pubsub_conn.unsubscribe(&channel);
is_subscribed.store(false, Ordering::SeqCst);
}
Err(e) => {
eprintln!("Failed to get connection for subscription: {}", e);
is_subscribed.store(false, Ordering::SeqCst);
}
}
});
{
let mut thread_guard = self.subscription_thread.lock();
*thread_guard = Some(handle);
}
Ok(())
}
fn stop_subscription(&self) -> RedissonResult<()> {
if !self.is_subscribed.swap(false, Ordering::SeqCst) {
return Ok(());
}
let handle = {
let mut thread_guard = self.subscription_thread.lock();
thread_guard.take()
};
if let Some(handle) = handle {
let _ = handle.join();
}
Ok(())
}
}
/// Adapts a closure `Fn(String, V)` into a `MessageListener<V>`.
///
/// Trait bounds live on the impl blocks rather than on the struct definition,
/// so the type itself places no requirements on `F` or `V`.
pub struct FunctionListener<F, V> {
    /// Closure invoked with an owned channel name and the decoded message.
    callback: F,
    /// Records the message type `V` without storing a value of it.
    _marker: std::marker::PhantomData<V>,
}
impl<F, V> FunctionListener<F, V>
where
F: Fn(String, V) + Send + Sync + 'static,
V: DeserializeOwned + Send + Sync + 'static,
{
pub fn new(callback: F) -> Self {
Self {
callback,
_marker: std::marker::PhantomData,
}
}
}
impl<F, V> MessageListener<V> for FunctionListener<F, V>
where
    V: DeserializeOwned + Send + Sync + 'static,
    F: Fn(String, V) + Send + Sync + 'static,
{
    /// Forwards the message to the wrapped closure, handing it an owned copy
    /// of the channel name.
    fn on_message(&self, channel: &str, message: V) {
        let callback = &self.callback;
        callback(channel.to_owned(), message)
    }
}
impl<V> RTopic<V>
where
    V: Serialize + DeserializeOwned + Send + Sync + 'static + Clone,
{
    /// Registers a closure as a listener on this topic.
    ///
    /// The closure is wrapped in a `FunctionListener` and handed to
    /// `add_listener`; that method's behavior and errors apply unchanged.
    pub fn add_listener_fn<F>(&self, callback: F) -> RedissonResult<()>
    where
        F: Fn(String, V) + Send + Sync + 'static,
    {
        self.add_listener(FunctionListener::<F, V>::new(callback))
    }
}
impl<V> crate::RObjectBase<String> for RTopic<V>
where
V: Serialize + DeserializeOwned + Send + Sync + 'static + Clone,
{
fn get_connection(&self) -> RedissonResult<crate::RedisConnection> {
self.base.get_connection()
}
fn serialize_value(&self, value: &String) -> RedissonResult<String> {
Ok(value.clone())
}
fn deserialize_value(&self, data: &str) -> RedissonResult<String> {
Ok(data.to_string())
}
fn get_full_key(&self) -> String {
self.base.get_full_key()
}
}
impl<V> crate::RObject for RTopic<V>
where
V: Serialize + DeserializeOwned + Send + Sync + 'static + Clone,
{
fn get_name(&self) -> &str {
self.get_topic_name()
}
fn delete(&self) -> RedissonResult<bool> {
self.base.delete()
}
fn rename(&self, new_name: &str) -> RedissonResult<()> {
self.base.rename(new_name)
}
fn is_exists(&self) -> RedissonResult<bool> {
self.base.is_exists()
}
fn move_to_db(&self, db_index: i32) -> RedissonResult<bool> {
self.base.move_to_db(db_index)
}
fn get_expire_time(&self) -> RedissonResult<Option<std::time::Duration>> {
self.base.get_expire_time()
}
fn expire(&self, duration: std::time::Duration) -> RedissonResult<bool> {
self.base.expire(duration)
}
fn expire_at(&self, timestamp: i64) -> RedissonResult<bool> {
self.base.expire_at(timestamp)
}
fn clear_expire(&self) -> RedissonResult<bool> {
self.base.clear_expire()
}
}