1use std::fmt::{Debug, Formatter};
17use std::sync::Arc;
18
19use bytes::Bytes;
20use fred::clients::Client;
21use fred::interfaces::{ClientLike, PubsubInterface};
22use fred::types::{Message, MessageKind};
23use futures::Stream;
24use futures::stream::unfold;
25use ruststream::codec::Codec;
26use ruststream::{
27 AckError, Headers, IncomingMessage, OutgoingMessage, Partitioned, Publisher, SubscriptionSource,
28};
29use tokio::sync::OnceCell;
30use tokio::sync::broadcast::{Receiver, error::RecvError};
31
32use crate::envelope::{SharedEnvelope, frame, unframe};
33use crate::{RedisBroker, error::RedisError, message::PARTITION_KEY_HEADER};
34
/// Redis pub/sub flavour used for a subscription or a publisher.
///
/// The default is [`PubSubMode::Classic`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum PubSubMode {
    /// Classic pub/sub. The publisher in this module sends with `publish`
    /// in this mode, and pattern subscriptions are only accepted here.
    #[default]
    Classic,
    /// Sharded pub/sub. The publisher in this module sends with `spublish`
    /// in this mode; combining it with a pattern subscription is rejected
    /// by `RedisPubSub::validate`.
    Sharded,
}
44
45#[derive(Clone)]
58#[must_use]
59pub struct RedisPubSub {
60 channel: String,
61 mode: PubSubMode,
62 pattern: bool,
63 codec: Option<SharedEnvelope>,
64}
65
66impl Debug for RedisPubSub {
67 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
68 f.debug_struct("RedisPubSub")
69 .field("channel", &self.channel)
70 .field("mode", &self.mode)
71 .field("pattern", &self.pattern)
72 .field("codec", &self.codec.is_some())
73 .finish()
74 }
75}
76
77impl RedisPubSub {
78 pub fn new(channel: impl Into<String>) -> Self {
80 Self {
81 channel: channel.into(),
82 mode: PubSubMode::default(),
83 pattern: false,
84 codec: None,
85 }
86 }
87
88 pub const fn mode(mut self, mode: PubSubMode) -> Self {
90 self.mode = mode;
91 self
92 }
93
94 pub const fn pattern(mut self) -> Self {
97 self.pattern = true;
98 self
99 }
100
101 pub fn codec(mut self, codec: impl Codec + 'static) -> Self {
104 self.codec = Some(Arc::new(codec));
105 self
106 }
107
108 #[must_use]
110 pub fn channel(&self) -> &str {
111 &self.channel
112 }
113
114 pub(crate) const fn delivery_mode(&self) -> PubSubMode {
115 self.mode
116 }
117
118 pub(crate) const fn is_pattern(&self) -> bool {
119 self.pattern
120 }
121
122 pub(crate) fn codec_handle(&self) -> Option<SharedEnvelope> {
123 self.codec.clone()
124 }
125
126 pub(crate) fn validate(&self) -> Result<(), RedisError> {
127 if self.pattern && matches!(self.mode, PubSubMode::Sharded) {
128 return Err(RedisError::InvalidOptions(
129 "pattern subscriptions are classic-only; sharded pub/sub has no PSUBSCRIBE"
130 .to_owned(),
131 ));
132 }
133 Ok(())
134 }
135}
136
137impl SubscriptionSource<RedisBroker> for RedisPubSub {
138 type Subscriber = RedisPubSubSubscriber;
139
140 fn name(&self) -> &str {
141 self.channel()
142 }
143
144 async fn subscribe(self, broker: &RedisBroker) -> Result<Self::Subscriber, RedisError> {
145 broker.subscribe_pubsub(self).await
146 }
147}
148
#[cfg(feature = "testing")]
impl SubscriptionSource<crate::testing::RedisTestBroker> for RedisPubSub {
    type Subscriber = crate::testing::RedisTestSubscriber;

    /// The subscription is named after its channel.
    fn name(&self) -> &str {
        self.channel.as_str()
    }

    /// Subscribes on the test broker using only the channel name.
    ///
    /// The mode, pattern flag and codec of this description are not passed
    /// to the test broker here.
    ///
    /// # Errors
    ///
    /// Propagates whatever the test broker's `subscribe` returns.
    async fn subscribe(
        self,
        broker: &crate::testing::RedisTestBroker,
    ) -> Result<Self::Subscriber, RedisError> {
        let channel = self.channel();
        broker.subscribe(channel).await
    }
}
164
165pub struct RedisPubSubSubscriber {
168 client: Client,
169 rx: Receiver<Message>,
170 codec: Option<SharedEnvelope>,
171}
172
173impl Debug for RedisPubSubSubscriber {
174 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
175 f.debug_struct("RedisPubSubSubscriber")
176 .finish_non_exhaustive()
177 }
178}
179
180impl RedisPubSubSubscriber {
181 pub(crate) fn new(
182 client: Client,
183 rx: Receiver<Message>,
184 codec: Option<SharedEnvelope>,
185 ) -> Self {
186 Self { client, rx, codec }
187 }
188}
189
190impl Drop for RedisPubSubSubscriber {
191 fn drop(&mut self) {
192 let client = self.client.clone();
195 tokio::spawn(async move {
196 let _ = client.quit().await;
197 });
198 }
199}
200
201fn to_message(msg: &Message, codec: Option<&SharedEnvelope>) -> RedisPubSubMessage {
202 let raw = msg.value.as_bytes().unwrap_or(&[]);
203 let (payload, headers) = unframe(codec, raw);
204 RedisPubSubMessage {
205 channel: msg.channel.to_string(),
206 pattern: matches!(msg.kind, MessageKind::PMessage),
209 payload,
210 headers,
211 }
212}
213
214impl ruststream::Subscriber for RedisPubSubSubscriber {
215 type Message = RedisPubSubMessage;
216 type Error = RedisError;
217
218 fn stream(&mut self) -> impl Stream<Item = Result<Self::Message, Self::Error>> + Send + '_ {
226 let codec = self.codec.clone();
227 unfold((&mut self.rx, codec), |(rx, codec)| async move {
228 loop {
229 match rx.recv().await {
230 Ok(msg) => {
231 let message = to_message(&msg, codec.as_ref());
232 return Some((Ok(message), (rx, codec)));
233 }
234 Err(RecvError::Lagged(_)) => {}
236 Err(RecvError::Closed) => return None,
237 }
238 }
239 })
240 }
241}
242
243pub struct RedisPubSubMessage {
245 channel: String,
246 pattern: bool,
248 payload: Bytes,
249 headers: Headers,
250}
251
252impl Debug for RedisPubSubMessage {
253 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
254 f.debug_struct("RedisPubSubMessage")
255 .field("channel", &self.channel)
256 .field("pattern", &self.pattern)
257 .field("payload_len", &self.payload.len())
258 .finish_non_exhaustive()
259 }
260}
261
262impl RedisPubSubMessage {
263 #[must_use]
268 pub fn channel(&self) -> &str {
269 &self.channel
270 }
271
272 #[must_use]
275 pub fn from_pattern(&self) -> bool {
276 self.pattern
277 }
278}
279
280impl IncomingMessage for RedisPubSubMessage {
281 fn payload(&self) -> &[u8] {
282 &self.payload
283 }
284
285 fn headers(&self) -> &Headers {
286 &self.headers
287 }
288
289 async fn ack(self) -> Result<(), AckError> {
290 Err(AckError::Unsupported)
291 }
292
293 async fn nack(self, _requeue: bool) -> Result<(), AckError> {
294 Err(AckError::Unsupported)
295 }
296}
297
298impl Partitioned for RedisPubSubMessage {
299 fn partition_key(&self) -> Option<&[u8]> {
300 self.headers().get(PARTITION_KEY_HEADER)
301 }
302}
303
304#[derive(Clone)]
311pub struct RedisPubSubPublisher {
312 pool: Arc<OnceCell<fred::clients::Pool>>,
313 mode: PubSubMode,
314 codec: Option<SharedEnvelope>,
315}
316
317impl Debug for RedisPubSubPublisher {
318 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
319 f.debug_struct("RedisPubSubPublisher")
320 .field("mode", &self.mode)
321 .field("codec", &self.codec.is_some())
322 .finish_non_exhaustive()
323 }
324}
325
326impl RedisPubSubPublisher {
327 pub(crate) fn new(pool: Arc<OnceCell<fred::clients::Pool>>, mode: PubSubMode) -> Self {
328 Self {
329 pool,
330 mode,
331 codec: None,
332 }
333 }
334
335 #[must_use]
338 pub const fn mode(mut self, mode: PubSubMode) -> Self {
339 self.mode = mode;
340 self
341 }
342
343 #[must_use]
346 pub fn codec(mut self, codec: impl Codec + 'static) -> Self {
347 self.codec = Some(Arc::new(codec));
348 self
349 }
350}
351
352impl Publisher for RedisPubSubPublisher {
353 type Error = RedisError;
354
355 async fn publish(&self, msg: OutgoingMessage<'_>) -> Result<(), Self::Error> {
356 let pool = self.pool.get().cloned().ok_or(RedisError::NotConnected)?;
357 let client = pool.next();
358 let channel = msg.name().to_owned();
359 let body = frame(self.codec.as_ref(), msg.payload(), msg.headers());
360 let _: i64 = match self.mode {
361 PubSubMode::Classic => client.publish(channel, body).await,
362 PubSubMode::Sharded => client.spublish(channel, body).await,
363 }
364 .map_err(RedisError::publish)?;
365 Ok(())
366 }
367}
368
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::PubSubContext;
    use ruststream::BuildContext;

    /// Builds a message on `channel` with a `{}` payload and no headers.
    fn sample(channel: &str, pattern: bool) -> RedisPubSubMessage {
        RedisPubSubMessage {
            channel: channel.to_owned(),
            pattern,
            payload: Bytes::from_static(b"{}"),
            headers: Headers::new(),
        }
    }

    /// The context built from a message exposes its channel and pattern flag.
    #[test]
    fn build_context_reads_channel_and_pattern_flag() {
        let exact = sample("events", false);
        let cx = PubSubContext::build(&exact);
        assert_eq!(cx.channel(), "events");
        assert!(!cx.from_pattern());

        let matched = sample("events.user", true);
        assert!(PubSubContext::build(&matched).from_pattern());
    }

    /// Sharded mode combined with a pattern must fail validation.
    #[test]
    fn pattern_with_sharded_is_rejected() {
        let source = RedisPubSub::new("e.*")
            .mode(PubSubMode::Sharded)
            .pattern();
        let err = source.validate().unwrap_err();
        assert!(matches!(err, RedisError::InvalidOptions(msg) if msg.contains("classic-only")));
    }

    /// A pattern in the default (classic) mode passes validation.
    #[test]
    fn classic_pattern_validates() {
        let source = RedisPubSub::new("e.*").pattern();
        source.validate().expect("ok");
    }
}