Skip to main content

kevy_client/
subscribe.rs

1//! Pub/sub consumer side — a connection dedicated to receiving messages.
2//!
3//! `SUBSCRIBE` / `PSUBSCRIBE` morph a connection into a one-way event
4//! stream: the client no longer sends ordinary commands and instead reads
5//! an unbounded sequence of `subscribe`, `message`, `pmessage`,
6//! `unsubscribe`, … frames until the connection is closed. That semantic
7//! doesn't fit the one-shot `Connection::request` shape, so subscribed
8//! traffic gets its own type, [`Subscriber`].
9//!
10//! Two backends, switched on the URL:
11//! - `kevy://` / `redis://` / `tcp://` — dedicated TCP socket
12//! - `mem://<name>` / `file:///path` — in-process bus, via the URL
13//!   registry in [`crate::resolve_store`]. Anonymous `mem://` (no name)
14//!   has no bus and is rejected; use a named bus to actually receive
15//!   messages from a [`crate::Connection::publish`] on the same URL.
16//!
17//! ```no_run
18//! use kevy_client::{Subscriber, PubsubEvent};
19//!
20//! let mut sub = Subscriber::open("kevy://localhost:6379", &[b"news"])?;
21//! loop {
22//!     if let PubsubEvent::Message { channel, payload } = sub.recv()? {
23//!         println!("{}: {}", String::from_utf8_lossy(&channel),
24//!                            String::from_utf8_lossy(&payload));
25//!     }
26//! }
27//! # Ok::<(), std::io::Error>(())
28//! ```
29
30use std::io::{self, Read, Write};
31use std::net::TcpStream;
32use std::time::Duration;
33
34use kevy_embedded::{PubsubFrame, Subscription};
35use kevy_resp::{Reply, encode_command, parse_reply};
36
37use crate::{Target, parse_url, resolve_store};
38
39/// One subscribed connection. Owns either a TCP socket or an in-process
40/// [`Subscription`]; the variant is chosen by the URL scheme in
41/// [`Subscriber::open`] / [`Subscriber::connect`].
42#[derive(Debug)]
43pub struct Subscriber {
44    inner: Inner,
45}
46
47#[derive(Debug)]
48enum Inner {
49    /// TCP RESP2 connection, drained one reply at a time.
50    Remote {
51        stream: TcpStream,
52        buf: Vec<u8>,
53    },
54    /// In-process bus subscription. `timeout` mirrors the TCP
55    /// `SO_RCVTIMEO` behaviour for [`Subscriber::recv`] / [`Subscriber::set_read_timeout`].
56    Embedded {
57        subscription: Subscription,
58        timeout: Option<Duration>,
59    },
60}
61
/// A single pub/sub frame, whether it came off the wire or from the
/// in-process bus.
///
/// For [`PubsubEvent::Unsubscribe`] and [`PubsubEvent::Punsubscribe`] the
/// name field is `None` when the acknowledgement carried a nil bulk instead
/// of a channel / pattern name — the Redis wire shape for "unsubscribed
/// from everything".
#[non_exhaustive]
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum PubsubEvent {
    /// Acknowledgement of `SUBSCRIBE`; one is produced per channel.
    Subscribe {
        /// Name of the channel that is now subscribed.
        channel: Vec<u8>,
        /// Channels + patterns this connection is subscribed to after the ack.
        count: i64,
    },
    /// Acknowledgement of `PSUBSCRIBE`; one is produced per pattern.
    Psubscribe {
        /// The pattern that is now subscribed.
        pattern: Vec<u8>,
        /// Channels + patterns this connection is subscribed to after the ack.
        count: i64,
    },
    /// Acknowledgement of `UNSUBSCRIBE`. `channel` is `None` when the
    /// server reports that no channels were subscribed (nil bulk).
    Unsubscribe {
        /// Channel that was dropped, or `None` for the "all" / "none" case.
        channel: Option<Vec<u8>>,
        /// Channels + patterns that remain subscribed.
        count: i64,
    },
    /// Acknowledgement of `PUNSUBSCRIBE`. `pattern` is `None` when the
    /// server reports that no patterns were subscribed.
    Punsubscribe {
        /// Pattern that was dropped, or `None` for the "all" / "none" case.
        pattern: Option<Vec<u8>>,
        /// Channels + patterns that remain subscribed.
        count: i64,
    },
    /// A `PUBLISH` delivered on a channel this connection subscribed to.
    Message {
        /// Channel the message was published on.
        channel: Vec<u8>,
        /// Payload bytes exactly as published; no text encoding is assumed.
        payload: Vec<u8>,
    },
    /// A `PUBLISH` delivered because its channel matched one of this
    /// connection's subscribed patterns.
    Pmessage {
        /// The subscribed pattern that matched.
        pattern: Vec<u8>,
        /// Channel the message was published on.
        channel: Vec<u8>,
        /// Payload bytes exactly as published.
        payload: Vec<u8>,
    },
}
118
119impl Subscriber {
120    /// Open a fresh connection without subscribing to anything yet. Call
121    /// [`Self::subscribe`] / [`Self::psubscribe`] next.
122    ///
123    /// Accepted URLs:
124    /// - `kevy://`, `redis://`, `tcp://` — TCP RESP server
125    /// - `mem://<name>`, `file:///path` — in-process shared bus
126    /// - `mem://` (anonymous), `rediss://`, `kevys://`, `redis://user:pass@…`
127    ///   are rejected with [`io::ErrorKind::Unsupported`]
128    pub fn connect(url: &str) -> io::Result<Self> {
129        let target = parse_url(url)?;
130        let inner = match target {
131            Target::EmbedMemoryAnonymous => {
132                return Err(io::Error::new(
133                    io::ErrorKind::Unsupported,
134                    "anonymous mem:// has no other producer; use mem://<name> for a shared bus",
135                ));
136            }
137            Target::EmbedMemoryNamed(_) | Target::EmbedPersist(_) => Inner::Embedded {
138                subscription: resolve_store(&target)?.subscribe(&[]),
139                timeout: None,
140            },
141            Target::Remote(remote_url) => {
142                let (host, port) = remote_host_port(&remote_url)?;
143                let stream = TcpStream::connect((host.as_str(), port))?;
144                stream.set_nodelay(true).ok();
145                Inner::Remote {
146                    stream,
147                    buf: Vec::with_capacity(8192),
148                }
149            }
150        };
151        Ok(Self { inner })
152    }
153
154    /// Open and subscribe to one or more channels in one step. Returns
155    /// `ErrorKind::InvalidInput` if `channels` is empty (use
156    /// [`Self::connect`] for an empty start).
157    pub fn open(url: &str, channels: &[&[u8]]) -> io::Result<Self> {
158        if channels.is_empty() {
159            return Err(io::Error::new(
160                io::ErrorKind::InvalidInput,
161                "Subscriber::open needs ≥ 1 channel — use Subscriber::connect() for empty start",
162            ));
163        }
164        let mut s = Self::connect(url)?;
165        s.subscribe(channels)?;
166        Ok(s)
167    }
168
169    /// `SUBSCRIBE channel [channel ...]`. Per-channel `Subscribe` acks
170    /// are delivered via [`Self::recv`].
171    pub fn subscribe(&mut self, channels: &[&[u8]]) -> io::Result<()> {
172        if channels.is_empty() {
173            return Err(io::Error::new(
174                io::ErrorKind::InvalidInput,
175                "SUBSCRIBE needs ≥ 1 channel",
176            ));
177        }
178        match &mut self.inner {
179            Inner::Remote { stream, .. } => send_to(stream, b"SUBSCRIBE", channels),
180            Inner::Embedded { subscription, .. } => {
181                subscription.subscribe(channels);
182                Ok(())
183            }
184        }
185    }
186
187    /// `PSUBSCRIBE pattern [pattern ...]`. Patterns use Redis glob syntax
188    /// (`*`, `?`, `[…]`).
189    pub fn psubscribe(&mut self, patterns: &[&[u8]]) -> io::Result<()> {
190        if patterns.is_empty() {
191            return Err(io::Error::new(
192                io::ErrorKind::InvalidInput,
193                "PSUBSCRIBE needs ≥ 1 pattern",
194            ));
195        }
196        match &mut self.inner {
197            Inner::Remote { stream, .. } => send_to(stream, b"PSUBSCRIBE", patterns),
198            Inner::Embedded { subscription, .. } => {
199                subscription.psubscribe(patterns);
200                Ok(())
201            }
202        }
203    }
204
205    /// `UNSUBSCRIBE [channel ...]`. Empty `channels` unsubscribes from
206    /// every channel (Redis wire semantics).
207    pub fn unsubscribe(&mut self, channels: &[&[u8]]) -> io::Result<()> {
208        match &mut self.inner {
209            Inner::Remote { stream, .. } => send_to(stream, b"UNSUBSCRIBE", channels),
210            Inner::Embedded { subscription, .. } => {
211                subscription.unsubscribe(channels);
212                Ok(())
213            }
214        }
215    }
216
217    /// `PUNSUBSCRIBE [pattern ...]`. Empty `patterns` unsubscribes from
218    /// every pattern.
219    pub fn punsubscribe(&mut self, patterns: &[&[u8]]) -> io::Result<()> {
220        match &mut self.inner {
221            Inner::Remote { stream, .. } => send_to(stream, b"PUNSUBSCRIBE", patterns),
222            Inner::Embedded { subscription, .. } => {
223                subscription.punsubscribe(patterns);
224                Ok(())
225            }
226        }
227    }
228
229    /// Block until the next pubsub frame arrives. Apply
230    /// [`Self::set_read_timeout`] for bounded blocking.
231    /// Connection close / bus tear-down yields `ErrorKind::UnexpectedEof`.
232    pub fn recv(&mut self) -> io::Result<PubsubEvent> {
233        match &mut self.inner {
234            Inner::Remote { stream, buf } => recv_remote(stream, buf),
235            Inner::Embedded {
236                subscription,
237                timeout,
238            } => {
239                let frame = match *timeout {
240                    Some(d) => subscription.recv_timeout(d)?,
241                    None => subscription.recv()?,
242                };
243                Ok(frame_to_event(frame))
244            }
245        }
246    }
247
248    /// Apply (or clear) a read timeout. After setting `Some(dur)`,
249    /// [`Self::recv`] returns an `io::Error` of kind `WouldBlock` /
250    /// `TimedOut` when no frame arrives within `dur`.
251    pub fn set_read_timeout(&mut self, dur: Option<Duration>) -> io::Result<()> {
252        match &mut self.inner {
253            Inner::Remote { stream, .. } => stream.set_read_timeout(dur),
254            Inner::Embedded { timeout, .. } => {
255                *timeout = dur;
256                Ok(())
257            }
258        }
259    }
260}
261
262fn send_to(stream: &mut TcpStream, verb: &[u8], args: &[&[u8]]) -> io::Result<()> {
263    let mut argv = Vec::with_capacity(args.len() + 1);
264    argv.push(verb.to_vec());
265    argv.extend(args.iter().map(|a| a.to_vec()));
266    let mut frame = Vec::new();
267    encode_command(&mut frame, &argv);
268    stream.write_all(&frame)
269}
270
271fn recv_remote(stream: &mut TcpStream, buf: &mut Vec<u8>) -> io::Result<PubsubEvent> {
272    let mut chunk = [0u8; 8192];
273    loop {
274        match parse_reply(buf) {
275            Ok(Some((reply, used))) => {
276                buf.drain(..used);
277                return classify(reply);
278            }
279            Ok(None) => {}
280            Err(_) => {
281                return Err(io::Error::new(
282                    io::ErrorKind::InvalidData,
283                    "malformed reply",
284                ));
285            }
286        }
287        let n = stream.read(&mut chunk)?;
288        if n == 0 {
289            return Err(io::Error::new(
290                io::ErrorKind::UnexpectedEof,
291                "server closed connection",
292            ));
293        }
294        buf.extend_from_slice(&chunk[..n]);
295    }
296}
297
298fn frame_to_event(frame: PubsubFrame) -> PubsubEvent {
299    match frame {
300        PubsubFrame::Subscribe { channel, count } => PubsubEvent::Subscribe {
301            channel,
302            count: count as i64,
303        },
304        PubsubFrame::Psubscribe { pattern, count } => PubsubEvent::Psubscribe {
305            pattern,
306            count: count as i64,
307        },
308        PubsubFrame::Unsubscribe { channel, count } => PubsubEvent::Unsubscribe {
309            channel,
310            count: count as i64,
311        },
312        PubsubFrame::Punsubscribe { pattern, count } => PubsubEvent::Punsubscribe {
313            pattern,
314            count: count as i64,
315        },
316        PubsubFrame::Message { channel, payload } => PubsubEvent::Message { channel, payload },
317        PubsubFrame::Pmessage {
318            pattern,
319            channel,
320            payload,
321        } => PubsubEvent::Pmessage {
322            pattern,
323            channel,
324            payload,
325        },
326    }
327}
328
329fn classify(reply: Reply) -> io::Result<PubsubEvent> {
330    let items = match reply {
331        Reply::Array(v) => v,
332        other => return Err(invalid(format!("expected array frame, got {}", shape(&other)))),
333    };
334    let kind = match items.first() {
335        Some(Reply::Bulk(b)) => b.clone(),
336        _ => return Err(invalid("pubsub frame missing kind field")),
337    };
338    match kind.as_slice() {
339        b"subscribe" => {
340            let [_, ch, n] = into_array3(items)?;
341            Ok(PubsubEvent::Subscribe {
342                channel: take_bulk(ch, "channel")?,
343                count: take_int(n, "count")?,
344            })
345        }
346        b"psubscribe" => {
347            let [_, p, n] = into_array3(items)?;
348            Ok(PubsubEvent::Psubscribe {
349                pattern: take_bulk(p, "pattern")?,
350                count: take_int(n, "count")?,
351            })
352        }
353        b"unsubscribe" => {
354            let [_, ch, n] = into_array3(items)?;
355            Ok(PubsubEvent::Unsubscribe {
356                channel: take_bulk_or_nil(ch, "channel")?,
357                count: take_int(n, "count")?,
358            })
359        }
360        b"punsubscribe" => {
361            let [_, p, n] = into_array3(items)?;
362            Ok(PubsubEvent::Punsubscribe {
363                pattern: take_bulk_or_nil(p, "pattern")?,
364                count: take_int(n, "count")?,
365            })
366        }
367        b"message" => {
368            let [_, ch, payload] = into_array3(items)?;
369            Ok(PubsubEvent::Message {
370                channel: take_bulk(ch, "channel")?,
371                payload: take_bulk(payload, "payload")?,
372            })
373        }
374        b"pmessage" => {
375            let [_, pat, ch, payload] = into_array4(items)?;
376            Ok(PubsubEvent::Pmessage {
377                pattern: take_bulk(pat, "pattern")?,
378                channel: take_bulk(ch, "channel")?,
379                payload: take_bulk(payload, "payload")?,
380            })
381        }
382        other => Err(invalid(format!(
383            "unknown pubsub kind '{}'",
384            String::from_utf8_lossy(other)
385        ))),
386    }
387}
388
389fn into_array3(items: Vec<Reply>) -> io::Result<[Reply; 3]> {
390    items.try_into().map_err(|v: Vec<Reply>| {
391        invalid(format!("expected 3-element pubsub frame, got {}", v.len()))
392    })
393}
394
395fn into_array4(items: Vec<Reply>) -> io::Result<[Reply; 4]> {
396    items.try_into().map_err(|v: Vec<Reply>| {
397        invalid(format!("expected 4-element pubsub frame, got {}", v.len()))
398    })
399}
400
401fn take_bulk(r: Reply, field: &str) -> io::Result<Vec<u8>> {
402    match r {
403        Reply::Bulk(b) => Ok(b),
404        other => Err(invalid(format!(
405            "expected bulk for {field}, got {}",
406            shape(&other)
407        ))),
408    }
409}
410
411fn take_bulk_or_nil(r: Reply, field: &str) -> io::Result<Option<Vec<u8>>> {
412    match r {
413        Reply::Bulk(b) => Ok(Some(b)),
414        Reply::Nil => Ok(None),
415        other => Err(invalid(format!(
416            "expected bulk/nil for {field}, got {}",
417            shape(&other)
418        ))),
419    }
420}
421
422fn take_int(r: Reply, field: &str) -> io::Result<i64> {
423    match r {
424        Reply::Int(n) => Ok(n),
425        other => Err(invalid(format!(
426            "expected integer for {field}, got {}",
427            shape(&other)
428        ))),
429    }
430}
431
432fn shape(r: &Reply) -> &'static str {
433    match r {
434        Reply::Simple(_) => "simple-string",
435        Reply::Error(_) => "error",
436        Reply::Int(_) => "integer",
437        Reply::Bulk(_) => "bulk-string",
438        Reply::Nil => "nil",
439        Reply::Array(_) => "array",
440    }
441}
442
/// Build an `io::Error` of kind `InvalidData` carrying `msg` as its text.
fn invalid(msg: impl Into<String>) -> io::Error {
    let text: String = msg.into();
    io::Error::new(io::ErrorKind::InvalidData, text)
}
446
447// ─────────────────────────────────────────────────────────────────────────
448// Remote host:port extraction. Reuses the same authority parsing logic
449// kevy-resp-client::from_url applies, but only needs host+port (pub/sub
450// is global, not db-scoped — any /N path segment is ignored).
451// ─────────────────────────────────────────────────────────────────────────
452
/// Extract `(host, port)` from a remote URL such as `kevy://host:6379/0`.
///
/// Only the authority (the text between `://` and the first `/`) is used;
/// any path is ignored. The port defaults to 6379 when absent.
///
/// A bracketed IPv6 literal (`[::1]` or `[::1]:6379`) is accepted and
/// returned *without* the brackets, so the host can be handed straight to
/// `TcpStream::connect((host, port))`.
///
/// # Errors
/// - `ErrorKind::InvalidInput` — no `://`, unparsable port, empty host,
///   unterminated `[`, or stray text after a closing `]`.
/// - `ErrorKind::Unsupported` — the URL contains `@` (userinfo).
fn remote_host_port(url: &str) -> io::Result<(String, u16)> {
    const DEFAULT_PORT: u16 = 6379;

    let (_scheme, rest) = url.split_once("://").ok_or_else(|| {
        io::Error::new(io::ErrorKind::InvalidInput, "URL missing '://'")
    })?;
    if rest.contains('@') {
        return Err(io::Error::new(
            io::ErrorKind::Unsupported,
            "userinfo (user:pass@host) is unsupported — kevy has no AUTH",
        ));
    }
    let authority = rest.split('/').next().unwrap_or("");

    let parse_port = |p: &str| -> io::Result<u16> {
        p.parse::<u16>().map_err(|_| {
            io::Error::new(io::ErrorKind::InvalidInput, format!("bad port: {p}"))
        })
    };

    let (host, port) = if let Some(bracketed) = authority.strip_prefix('[') {
        // IPv6 literal: the address itself contains ':' characters, so the
        // port separator is only the ':' that follows the closing bracket.
        let (addr, tail) = bracketed.split_once(']').ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "IPv6 host is missing closing ']'",
            )
        })?;
        let port = match tail.strip_prefix(':') {
            Some(p) => parse_port(p)?,
            None if tail.is_empty() => DEFAULT_PORT,
            None => {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    format!("unexpected text after ']': {tail}"),
                ));
            }
        };
        (addr.to_string(), port)
    } else {
        match authority.rsplit_once(':') {
            Some((h, p)) => {
                let port = parse_port(p)?;
                (h.to_string(), port)
            }
            None => (authority.to_string(), DEFAULT_PORT),
        }
    };
    if host.is_empty() {
        return Err(io::Error::new(io::ErrorKind::InvalidInput, "empty host"));
    }
    Ok((host, port))
}
478
479#[cfg(test)]
480#[path = "subscribe_tests.rs"]
481mod tests;