Skip to main content

kevy_client/
subscribe.rs

1//! Pub/sub consumer side — a TCP connection dedicated to receiving messages.
2//!
3//! `SUBSCRIBE` / `PSUBSCRIBE` morph a Redis/kevy connection into a one-way
4//! event stream: the client no longer sends ordinary commands and instead
5//! reads an unbounded sequence of `subscribe`, `message`, `pmessage`,
6//! `unsubscribe`, … frames until the connection is closed. That semantic
7//! doesn't fit the one-shot `Connection::request` shape — so subscribed
8//! traffic gets its own type, [`Subscriber`], on its own socket.
9//!
10//! ```no_run
11//! use kevy_client::Subscriber;
12//!
13//! let mut sub = Subscriber::open("kevy://localhost:6379", &[b"news"])?;
14//! loop {
15//!     match sub.recv()? {
16//!         kevy_client::PubsubEvent::Message { channel, payload } => {
17//!             println!("{}: {}", String::from_utf8_lossy(&channel),
18//!                                String::from_utf8_lossy(&payload));
19//!         }
20//!         _ => {}  // ignore subscribe-acks and other meta frames
21//!     }
22//! }
23//! # Ok::<(), std::io::Error>(())
24//! ```
25//!
26//! `mem://` / `file://` URLs are rejected with `ErrorKind::Unsupported`:
27//! single-process embed has no other producer to receive messages from.
28
29use std::io::{self, Read, Write};
30use std::net::TcpStream;
31use std::time::Duration;
32
33use kevy_resp::{Reply, encode_command, parse_reply};
34
/// One subscribed TCP connection. Owns the socket and a receive buffer.
///
/// The type is `Send` and `Sync` (both fields are), but every operation on it
/// takes `&mut self`; sharing one subscriber between threads therefore needs
/// external locking such as a `Mutex`.
#[derive(Debug)]
pub struct Subscriber {
    /// The dedicated socket; once subscribed it carries only pubsub frames.
    stream: TcpStream,
    /// Bytes read from `stream` that do not yet form a complete frame.
    buf: Vec<u8>,
}
41
/// One pubsub frame received from the server.
///
/// The name field of `Unsubscribe` / `Punsubscribe` is `None` when the server
/// answers with a nil bulk. It does that when there was nothing to
/// unsubscribe from, e.g. a bare `UNSUBSCRIBE` on a connection with no
/// channel subscriptions. This matches the Redis wire shape.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PubsubEvent {
    /// `SUBSCRIBE` ack — one per channel the client subscribed to.
    Subscribe {
        /// Channel that was just subscribed.
        channel: Vec<u8>,
        /// Total number of channels + patterns the connection is now subscribed to.
        count: i64,
    },
    /// `PSUBSCRIBE` ack — one per pattern.
    Psubscribe {
        /// Pattern that was just subscribed.
        pattern: Vec<u8>,
        /// Total number of channels + patterns the connection is now subscribed to.
        count: i64,
    },
    /// `UNSUBSCRIBE` ack — one per channel left. `channel` is `None` when
    /// the connection had no channels to unsubscribe from (the spec's nil bulk).
    Unsubscribe {
        /// Channel that was just unsubscribed, or `None` if there was none.
        channel: Option<Vec<u8>>,
        /// Total number of channels + patterns still subscribed.
        count: i64,
    },
    /// `PUNSUBSCRIBE` ack — one per pattern left. `pattern` is `None` when
    /// the connection had no patterns to unsubscribe from.
    Punsubscribe {
        /// Pattern that was just unsubscribed, or `None` if there was none.
        pattern: Option<Vec<u8>>,
        /// Total number of channels + patterns still subscribed.
        count: i64,
    },
    /// Plain `PUBLISH` delivery on a subscribed channel.
    Message {
        /// Channel the publish was made to.
        channel: Vec<u8>,
        /// Raw payload bytes (no encoding assumed).
        payload: Vec<u8>,
    },
    /// Pattern-match delivery: a `PUBLISH` to a channel that matched one
    /// of this connection's patterns.
    Pmessage {
        /// Pattern the channel matched.
        pattern: Vec<u8>,
        /// Channel the publish was made to.
        channel: Vec<u8>,
        /// Raw payload bytes.
        payload: Vec<u8>,
    },
}
98
99impl Subscriber {
100    /// Open a fresh TCP connection without subscribing to anything.
101    /// Use [`Self::subscribe`] / [`Self::psubscribe`] next.
102    ///
103    /// Accepted URL schemes: `kevy://`, `redis://`, `tcp://` (all wire-identical).
104    /// `mem://` / `file://` return `ErrorKind::Unsupported` — there is no
105    /// other process to receive messages from inside an embedded store.
106    pub fn connect(url: &str) -> io::Result<Self> {
107        let (host, port) = parse_pubsub_url(url)?;
108        let stream = TcpStream::connect((host.as_str(), port))?;
109        stream.set_nodelay(true).ok();
110        Ok(Self {
111            stream,
112            buf: Vec::with_capacity(8192),
113        })
114    }
115
116    /// Open and subscribe to one or more channels in one step. After the
117    /// call returns, the server has the `SUBSCRIBE` command queued — drain
118    /// the per-channel ack frames with [`Self::recv`] before
119    /// you act on `Message` events.
120    ///
121    /// Returns `ErrorKind::InvalidInput` if `channels` is empty (use
122    /// [`Self::connect`] + [`Self::psubscribe`] for a pattern-only start).
123    pub fn open(url: &str, channels: &[&[u8]]) -> io::Result<Self> {
124        if channels.is_empty() {
125            return Err(io::Error::new(
126                io::ErrorKind::InvalidInput,
127                "Subscriber::open needs ≥ 1 channel — use Subscriber::connect() for empty start",
128            ));
129        }
130        let mut s = Self::connect(url)?;
131        s.subscribe(channels)?;
132        Ok(s)
133    }
134
135    /// `SUBSCRIBE channel [channel ...]`. Returns once the bytes are written;
136    /// the server sends one `Subscribe` ack per channel — drain with
137    /// [`Self::recv`].
138    pub fn subscribe(&mut self, channels: &[&[u8]]) -> io::Result<()> {
139        if channels.is_empty() {
140            return Err(io::Error::new(
141                io::ErrorKind::InvalidInput,
142                "SUBSCRIBE needs ≥ 1 channel",
143            ));
144        }
145        self.send(b"SUBSCRIBE", channels)
146    }
147
148    /// `PSUBSCRIBE pattern [pattern ...]`. Patterns use Redis glob syntax
149    /// (`*`, `?`, `[…]`). Same ack-draining note as [`Self::subscribe`].
150    pub fn psubscribe(&mut self, patterns: &[&[u8]]) -> io::Result<()> {
151        if patterns.is_empty() {
152            return Err(io::Error::new(
153                io::ErrorKind::InvalidInput,
154                "PSUBSCRIBE needs ≥ 1 pattern",
155            ));
156        }
157        self.send(b"PSUBSCRIBE", patterns)
158    }
159
160    /// `UNSUBSCRIBE [channel ...]`. Empty `channels` unsubscribes from
161    /// every channel (Redis wire semantics).
162    pub fn unsubscribe(&mut self, channels: &[&[u8]]) -> io::Result<()> {
163        self.send(b"UNSUBSCRIBE", channels)
164    }
165
166    /// `PUNSUBSCRIBE [pattern ...]`. Empty `patterns` unsubscribes from
167    /// every pattern.
168    pub fn punsubscribe(&mut self, patterns: &[&[u8]]) -> io::Result<()> {
169        self.send(b"PUNSUBSCRIBE", patterns)
170    }
171
172    /// Block until the next pubsub frame arrives, parse it, classify it.
173    ///
174    /// `recv` itself never times out — apply a read timeout via
175    /// [`Self::set_read_timeout`] if you need bounded blocking.
176    /// Server close yields `ErrorKind::UnexpectedEof`; a malformed RESP
177    /// frame yields `ErrorKind::InvalidData`.
178    pub fn recv(&mut self) -> io::Result<PubsubEvent> {
179        let mut chunk = [0u8; 8192];
180        loop {
181            match parse_reply(&self.buf) {
182                Ok(Some((reply, used))) => {
183                    self.buf.drain(..used);
184                    return classify(reply);
185                }
186                Ok(None) => {}
187                Err(_) => {
188                    return Err(io::Error::new(
189                        io::ErrorKind::InvalidData,
190                        "malformed reply",
191                    ));
192                }
193            }
194            let n = self.stream.read(&mut chunk)?;
195            if n == 0 {
196                return Err(io::Error::new(
197                    io::ErrorKind::UnexpectedEof,
198                    "server closed connection",
199                ));
200            }
201            self.buf.extend_from_slice(&chunk[..n]);
202        }
203    }
204
205    /// Apply (or clear) a read timeout on the underlying socket.
206    /// After setting `Some(dur)`, [`Self::recv`] will return an `io::Error`
207    /// of kind `WouldBlock` / `TimedOut` if no data arrives within `dur`.
208    pub fn set_read_timeout(&mut self, dur: Option<Duration>) -> io::Result<()> {
209        self.stream.set_read_timeout(dur)
210    }
211
212    fn send(&mut self, verb: &[u8], args: &[&[u8]]) -> io::Result<()> {
213        let mut argv = Vec::with_capacity(args.len() + 1);
214        argv.push(verb.to_vec());
215        argv.extend(args.iter().map(|a| a.to_vec()));
216        let mut frame = Vec::new();
217        encode_command(&mut frame, &argv);
218        self.stream.write_all(&frame)
219    }
220}
221
222fn classify(reply: Reply) -> io::Result<PubsubEvent> {
223    let items = match reply {
224        Reply::Array(v) => v,
225        other => return Err(invalid(format!("expected array frame, got {}", shape(&other)))),
226    };
227    let kind = match items.first() {
228        Some(Reply::Bulk(b)) => b.clone(),
229        _ => return Err(invalid("pubsub frame missing kind field")),
230    };
231    match kind.as_slice() {
232        b"subscribe" => {
233            let [_, ch, n] = into_array3(items)?;
234            Ok(PubsubEvent::Subscribe {
235                channel: take_bulk(ch, "channel")?,
236                count: take_int(n, "count")?,
237            })
238        }
239        b"psubscribe" => {
240            let [_, p, n] = into_array3(items)?;
241            Ok(PubsubEvent::Psubscribe {
242                pattern: take_bulk(p, "pattern")?,
243                count: take_int(n, "count")?,
244            })
245        }
246        b"unsubscribe" => {
247            let [_, ch, n] = into_array3(items)?;
248            Ok(PubsubEvent::Unsubscribe {
249                channel: take_bulk_or_nil(ch, "channel")?,
250                count: take_int(n, "count")?,
251            })
252        }
253        b"punsubscribe" => {
254            let [_, p, n] = into_array3(items)?;
255            Ok(PubsubEvent::Punsubscribe {
256                pattern: take_bulk_or_nil(p, "pattern")?,
257                count: take_int(n, "count")?,
258            })
259        }
260        b"message" => {
261            let [_, ch, payload] = into_array3(items)?;
262            Ok(PubsubEvent::Message {
263                channel: take_bulk(ch, "channel")?,
264                payload: take_bulk(payload, "payload")?,
265            })
266        }
267        b"pmessage" => {
268            let [_, pat, ch, payload] = into_array4(items)?;
269            Ok(PubsubEvent::Pmessage {
270                pattern: take_bulk(pat, "pattern")?,
271                channel: take_bulk(ch, "channel")?,
272                payload: take_bulk(payload, "payload")?,
273            })
274        }
275        other => Err(invalid(format!(
276            "unknown pubsub kind '{}'",
277            String::from_utf8_lossy(other)
278        ))),
279    }
280}
281
282fn into_array3(items: Vec<Reply>) -> io::Result<[Reply; 3]> {
283    items.try_into().map_err(|v: Vec<Reply>| {
284        invalid(format!("expected 3-element pubsub frame, got {}", v.len()))
285    })
286}
287
288fn into_array4(items: Vec<Reply>) -> io::Result<[Reply; 4]> {
289    items.try_into().map_err(|v: Vec<Reply>| {
290        invalid(format!("expected 4-element pubsub frame, got {}", v.len()))
291    })
292}
293
294fn take_bulk(r: Reply, field: &str) -> io::Result<Vec<u8>> {
295    match r {
296        Reply::Bulk(b) => Ok(b),
297        other => Err(invalid(format!(
298            "expected bulk for {field}, got {}",
299            shape(&other)
300        ))),
301    }
302}
303
304fn take_bulk_or_nil(r: Reply, field: &str) -> io::Result<Option<Vec<u8>>> {
305    match r {
306        Reply::Bulk(b) => Ok(Some(b)),
307        Reply::Nil => Ok(None),
308        other => Err(invalid(format!(
309            "expected bulk/nil for {field}, got {}",
310            shape(&other)
311        ))),
312    }
313}
314
315fn take_int(r: Reply, field: &str) -> io::Result<i64> {
316    match r {
317        Reply::Int(n) => Ok(n),
318        other => Err(invalid(format!(
319            "expected integer for {field}, got {}",
320            shape(&other)
321        ))),
322    }
323}
324
325fn shape(r: &Reply) -> &'static str {
326    match r {
327        Reply::Simple(_) => "simple-string",
328        Reply::Error(_) => "error",
329        Reply::Int(_) => "integer",
330        Reply::Bulk(_) => "bulk-string",
331        Reply::Nil => "nil",
332        Reply::Array(_) => "array",
333    }
334}
335
/// Build an `ErrorKind::InvalidData` error carrying `msg`.
/// This is the kind used for every malformed-frame report in this module.
fn invalid(msg: impl Into<String>) -> io::Error {
    let text: String = msg.into();
    io::Error::new(io::ErrorKind::InvalidData, text)
}
339
340// ─────────────────────────────────────────────────────────────────────────
341// URL parsing (kevy://, redis://, tcp://; rejects mem/file/rediss/userinfo)
342// ─────────────────────────────────────────────────────────────────────────
343
/// Split a pub/sub server URL into the `(host, port)` pair handed to
/// `TcpStream::connect`.
///
/// Accepted shape: `scheme://host[:port][/path][?query][#fragment]`, where
/// `scheme` is one of `kevy`, `redis` or `tcp`.
/// - The port defaults to 6379.
/// - Any path, query or fragment is ignored, because pub/sub is not db-scoped.
/// - IPv6 literals must be bracketed (`kevy://[::1]:6379`). The brackets are
///   stripped from the returned host so it can be resolved directly.
///
/// # Errors
///
/// - `ErrorKind::Unsupported` for embedded schemes (`mem`, `file`), TLS
///   schemes (`rediss`, `kevys`), and URLs containing `@` (userinfo).
/// - `ErrorKind::InvalidInput` for any of these:
///   - a missing `://`;
///   - an unknown scheme;
///   - an empty host;
///   - a malformed IPv6 literal;
///   - a port that is not a valid `u16`.
fn parse_pubsub_url(url: &str) -> io::Result<(String, u16)> {
    const DEFAULT_PORT: u16 = 6379;

    fn bad_input(msg: impl Into<String>) -> io::Error {
        io::Error::new(io::ErrorKind::InvalidInput, msg.into())
    }

    let (scheme, rest) = url
        .split_once("://")
        .ok_or_else(|| bad_input("URL missing '://'"))?;
    match scheme {
        "kevy" | "redis" | "tcp" => {}
        "mem" | "file" => {
            return Err(io::Error::new(
                io::ErrorKind::Unsupported,
                format!(
                    "{scheme}:// is an embedded backend — pub/sub needs a TCP server. \
                     Use kevy://host:port instead."
                ),
            ));
        }
        "rediss" | "kevys" => {
            return Err(io::Error::new(
                io::ErrorKind::Unsupported,
                "TLS schemes (rediss://, kevys://) are unsupported — kevy has no TLS",
            ));
        }
        other => {
            return Err(bad_input(format!("unknown URL scheme '{other}://'")));
        }
    }
    // Checked over the whole remainder (not just the authority), so that
    // credentials containing '/', '?' or '#' are still refused instead of
    // being misread as host:port.
    if rest.contains('@') {
        return Err(io::Error::new(
            io::ErrorKind::Unsupported,
            "userinfo (user:pass@host) is unsupported — kevy has no AUTH",
        ));
    }

    // RFC 3986: the authority ends at the first '/', '?' or '#'.
    let authority = rest
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or("");

    let (host, port_text) = match authority.strip_prefix('[') {
        // Bracketed IPv6 literal: `[addr]`, optionally followed by `:port`.
        Some(bracketed) => {
            let (addr, tail) = bracketed
                .split_once(']')
                .ok_or_else(|| bad_input(format!("unterminated IPv6 literal: {authority}")))?;
            let port_text = if tail.is_empty() {
                None
            } else {
                Some(tail.strip_prefix(':').ok_or_else(|| {
                    bad_input(format!("unexpected text after IPv6 literal: {tail}"))
                })?)
            };
            (addr, port_text)
        }
        None => match authority.rsplit_once(':') {
            Some((h, p)) => (h, Some(p)),
            None => (authority, None),
        },
    };

    let port = match port_text {
        Some(p) => p
            .parse::<u16>()
            .map_err(|_| bad_input(format!("bad port: {p}")))?,
        None => DEFAULT_PORT,
    };
    if host.is_empty() {
        return Err(bad_input("empty host"));
    }
    Ok((host.to_string(), port))
}
393
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a bulk-string reply.
    fn bulk(bytes: &[u8]) -> Reply {
        Reply::Bulk(bytes.to_vec())
    }

    /// Shorthand for an array reply.
    fn frame(parts: Vec<Reply>) -> Reply {
        Reply::Array(parts)
    }

    /// Parse a URL that must be accepted; returns its `(host, port)`.
    fn ok_url(url: &str) -> (String, u16) {
        parse_pubsub_url(url).unwrap_or_else(|e| panic!("{url} should parse: {e}"))
    }

    /// Parse a URL that must be rejected; returns the error kind.
    fn err_kind(url: &str) -> io::ErrorKind {
        parse_pubsub_url(url)
            .expect_err("URL should be rejected")
            .kind()
    }

    // ----- URL parsing -----

    #[test]
    fn parses_kevy_redis_tcp() {
        for scheme in ["kevy", "redis", "tcp"] {
            let url = format!("{scheme}://localhost:6379");
            assert_eq!(ok_url(&url), ("localhost".to_string(), 6379));
        }
    }

    #[test]
    fn default_port_when_omitted() {
        assert_eq!(ok_url("kevy://example.com"), ("example.com".to_string(), 6379));
    }

    #[test]
    fn db_path_segment_ignored() {
        // Pub/sub is global, not db-scoped — `/N` is accepted but discarded.
        for url in ["kevy://h:1234/0", "redis://h:1234/3"] {
            assert_eq!(ok_url(url), ("h".to_string(), 1234));
        }
    }

    #[test]
    fn mem_file_rejected_unsupported() {
        for url in ["mem://", "file:///tmp/data"] {
            assert_eq!(err_kind(url), io::ErrorKind::Unsupported);
        }
    }

    #[test]
    fn tls_rejected_unsupported() {
        assert_eq!(err_kind("rediss://h:6379"), io::ErrorKind::Unsupported);
    }

    #[test]
    fn userinfo_rejected_unsupported() {
        assert_eq!(err_kind("kevy://u:p@h:6379"), io::ErrorKind::Unsupported);
    }

    #[test]
    fn unknown_scheme_rejected() {
        assert_eq!(err_kind("memcached://h:11211"), io::ErrorKind::InvalidInput);
    }

    #[test]
    fn bad_port_rejected() {
        for url in ["kevy://h:notaport", "kevy://h:99999"] {
            assert!(parse_pubsub_url(url).is_err(), "{url} should be rejected");
        }
    }

    #[test]
    fn empty_host_rejected() {
        assert!(parse_pubsub_url("kevy://:6379").is_err());
    }

    // ----- classify -----

    #[test]
    fn classify_subscribe_ack() {
        let got = classify(frame(vec![bulk(b"subscribe"), bulk(b"chan"), Reply::Int(1)]));
        let want = PubsubEvent::Subscribe {
            channel: b"chan".to_vec(),
            count: 1,
        };
        assert_eq!(got.unwrap(), want);
    }

    #[test]
    fn classify_psubscribe_ack() {
        let got = classify(frame(vec![bulk(b"psubscribe"), bulk(b"chan.*"), Reply::Int(2)]));
        let want = PubsubEvent::Psubscribe {
            pattern: b"chan.*".to_vec(),
            count: 2,
        };
        assert_eq!(got.unwrap(), want);
    }

    #[test]
    fn classify_message_event() {
        let got = classify(frame(vec![bulk(b"message"), bulk(b"news"), bulk(b"hello")]));
        let want = PubsubEvent::Message {
            channel: b"news".to_vec(),
            payload: b"hello".to_vec(),
        };
        assert_eq!(got.unwrap(), want);
    }

    #[test]
    fn classify_pmessage_event() {
        let got = classify(frame(vec![
            bulk(b"pmessage"),
            bulk(b"news.*"),
            bulk(b"news.tech"),
            bulk(b"hi"),
        ]));
        let want = PubsubEvent::Pmessage {
            pattern: b"news.*".to_vec(),
            channel: b"news.tech".to_vec(),
            payload: b"hi".to_vec(),
        };
        assert_eq!(got.unwrap(), want);
    }

    #[test]
    fn classify_unsubscribe_with_channel() {
        let got = classify(frame(vec![bulk(b"unsubscribe"), bulk(b"chan"), Reply::Int(0)]));
        let want = PubsubEvent::Unsubscribe {
            channel: Some(b"chan".to_vec()),
            count: 0,
        };
        assert_eq!(got.unwrap(), want);
    }

    #[test]
    fn classify_unsubscribe_with_nil_channel() {
        // Spec: when there were no subscribed channels, the server replies
        // with a nil bulk in the channel slot.
        let got = classify(frame(vec![bulk(b"unsubscribe"), Reply::Nil, Reply::Int(0)]));
        let want = PubsubEvent::Unsubscribe {
            channel: None,
            count: 0,
        };
        assert_eq!(got.unwrap(), want);
    }

    #[test]
    fn classify_punsubscribe_with_pattern() {
        let got = classify(frame(vec![bulk(b"punsubscribe"), bulk(b"chan.*"), Reply::Int(0)]));
        let want = PubsubEvent::Punsubscribe {
            pattern: Some(b"chan.*".to_vec()),
            count: 0,
        };
        assert_eq!(got.unwrap(), want);
    }

    #[test]
    fn classify_rejects_unknown_kind() {
        let err = classify(frame(vec![bulk(b"bogus"), bulk(b"x"), Reply::Int(0)])).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn classify_rejects_non_array() {
        let err = classify(Reply::Simple(b"OK".to_vec())).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn classify_rejects_wrong_arity() {
        // subscribe with 2 elements (missing count).
        let err = classify(frame(vec![bulk(b"subscribe"), bulk(b"x")])).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    // ----- subscribe arg validation -----

    #[test]
    fn open_with_empty_channels_rejected() {
        let err = Subscriber::open("kevy://127.0.0.1:1", &[]).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }
}