1use std::io::{self, Read, Write};
30use std::net::TcpStream;
31use std::time::Duration;
32
33use kevy_resp::{Reply, encode_command, parse_reply};
34
/// Blocking pub/sub client that owns a single TCP connection to the server.
///
/// Create one with [`Subscriber::connect`] or [`Subscriber::open`], issue
/// (un)subscribe commands, then read events with [`Subscriber::recv`].
#[derive(Debug)]
pub struct Subscriber {
    /// Connection to the server; commands are written to it and events read from it.
    stream: TcpStream,
    /// Bytes read from `stream` that have not yet been parsed into a complete reply.
    buf: Vec<u8>,
}
41
/// One decoded pub/sub frame, as returned by [`Subscriber::recv`].
///
/// Subscription acknowledgements and delivered messages share this type because the
/// server sends both on the same connection. The enum is `#[non_exhaustive]`, so
/// matches outside this crate need a wildcard arm.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PubsubEvent {
    /// Acknowledgement of a `SUBSCRIBE` for one channel.
    Subscribe {
        /// The channel that was subscribed to.
        channel: Vec<u8>,
        /// Integer count carried by the acknowledgement (in Redis semantics, the
        /// connection's total number of active subscriptions).
        count: i64,
    },
    /// Acknowledgement of a `PSUBSCRIBE` for one glob pattern.
    Psubscribe {
        /// The pattern that was subscribed to.
        pattern: Vec<u8>,
        /// Integer count carried by the acknowledgement.
        count: i64,
    },
    /// Acknowledgement of an `UNSUBSCRIBE`.
    Unsubscribe {
        /// The channel that was dropped, or `None` when the server sent a nil channel
        /// (presumably an argument-less `UNSUBSCRIBE` with nothing subscribed).
        channel: Option<Vec<u8>>,
        /// Integer count carried by the acknowledgement.
        count: i64,
    },
    /// Acknowledgement of a `PUNSUBSCRIBE`.
    Punsubscribe {
        /// The pattern that was dropped, or `None` when the server sent a nil pattern.
        pattern: Option<Vec<u8>>,
        /// Integer count carried by the acknowledgement.
        count: i64,
    },
    /// A message published to a channel this connection subscribed to directly.
    Message {
        /// Channel the message was published on.
        channel: Vec<u8>,
        /// Raw message bytes.
        payload: Vec<u8>,
    },
    /// A message delivered because its channel matched a subscribed pattern.
    Pmessage {
        /// The subscribed pattern that matched.
        pattern: Vec<u8>,
        /// Concrete channel the message was published on.
        channel: Vec<u8>,
        /// Raw message bytes.
        payload: Vec<u8>,
    },
}
98
99impl Subscriber {
100 pub fn connect(url: &str) -> io::Result<Self> {
107 let (host, port) = parse_pubsub_url(url)?;
108 let stream = TcpStream::connect((host.as_str(), port))?;
109 stream.set_nodelay(true).ok();
110 Ok(Self {
111 stream,
112 buf: Vec::with_capacity(8192),
113 })
114 }
115
116 pub fn open(url: &str, channels: &[&[u8]]) -> io::Result<Self> {
124 if channels.is_empty() {
125 return Err(io::Error::new(
126 io::ErrorKind::InvalidInput,
127 "Subscriber::open needs ≥ 1 channel — use Subscriber::connect() for empty start",
128 ));
129 }
130 let mut s = Self::connect(url)?;
131 s.subscribe(channels)?;
132 Ok(s)
133 }
134
135 pub fn subscribe(&mut self, channels: &[&[u8]]) -> io::Result<()> {
139 if channels.is_empty() {
140 return Err(io::Error::new(
141 io::ErrorKind::InvalidInput,
142 "SUBSCRIBE needs ≥ 1 channel",
143 ));
144 }
145 self.send(b"SUBSCRIBE", channels)
146 }
147
148 pub fn psubscribe(&mut self, patterns: &[&[u8]]) -> io::Result<()> {
151 if patterns.is_empty() {
152 return Err(io::Error::new(
153 io::ErrorKind::InvalidInput,
154 "PSUBSCRIBE needs ≥ 1 pattern",
155 ));
156 }
157 self.send(b"PSUBSCRIBE", patterns)
158 }
159
160 pub fn unsubscribe(&mut self, channels: &[&[u8]]) -> io::Result<()> {
163 self.send(b"UNSUBSCRIBE", channels)
164 }
165
166 pub fn punsubscribe(&mut self, patterns: &[&[u8]]) -> io::Result<()> {
169 self.send(b"PUNSUBSCRIBE", patterns)
170 }
171
172 pub fn recv(&mut self) -> io::Result<PubsubEvent> {
179 let mut chunk = [0u8; 8192];
180 loop {
181 match parse_reply(&self.buf) {
182 Ok(Some((reply, used))) => {
183 self.buf.drain(..used);
184 return classify(reply);
185 }
186 Ok(None) => {}
187 Err(_) => {
188 return Err(io::Error::new(
189 io::ErrorKind::InvalidData,
190 "malformed reply",
191 ));
192 }
193 }
194 let n = self.stream.read(&mut chunk)?;
195 if n == 0 {
196 return Err(io::Error::new(
197 io::ErrorKind::UnexpectedEof,
198 "server closed connection",
199 ));
200 }
201 self.buf.extend_from_slice(&chunk[..n]);
202 }
203 }
204
205 pub fn set_read_timeout(&mut self, dur: Option<Duration>) -> io::Result<()> {
209 self.stream.set_read_timeout(dur)
210 }
211
212 fn send(&mut self, verb: &[u8], args: &[&[u8]]) -> io::Result<()> {
213 let mut argv = Vec::with_capacity(args.len() + 1);
214 argv.push(verb.to_vec());
215 argv.extend(args.iter().map(|a| a.to_vec()));
216 let mut frame = Vec::new();
217 encode_command(&mut frame, &argv);
218 self.stream.write_all(&frame)
219 }
220}
221
222fn classify(reply: Reply) -> io::Result<PubsubEvent> {
223 let items = match reply {
224 Reply::Array(v) => v,
225 other => return Err(invalid(format!("expected array frame, got {}", shape(&other)))),
226 };
227 let kind = match items.first() {
228 Some(Reply::Bulk(b)) => b.clone(),
229 _ => return Err(invalid("pubsub frame missing kind field")),
230 };
231 match kind.as_slice() {
232 b"subscribe" => {
233 let [_, ch, n] = into_array3(items)?;
234 Ok(PubsubEvent::Subscribe {
235 channel: take_bulk(ch, "channel")?,
236 count: take_int(n, "count")?,
237 })
238 }
239 b"psubscribe" => {
240 let [_, p, n] = into_array3(items)?;
241 Ok(PubsubEvent::Psubscribe {
242 pattern: take_bulk(p, "pattern")?,
243 count: take_int(n, "count")?,
244 })
245 }
246 b"unsubscribe" => {
247 let [_, ch, n] = into_array3(items)?;
248 Ok(PubsubEvent::Unsubscribe {
249 channel: take_bulk_or_nil(ch, "channel")?,
250 count: take_int(n, "count")?,
251 })
252 }
253 b"punsubscribe" => {
254 let [_, p, n] = into_array3(items)?;
255 Ok(PubsubEvent::Punsubscribe {
256 pattern: take_bulk_or_nil(p, "pattern")?,
257 count: take_int(n, "count")?,
258 })
259 }
260 b"message" => {
261 let [_, ch, payload] = into_array3(items)?;
262 Ok(PubsubEvent::Message {
263 channel: take_bulk(ch, "channel")?,
264 payload: take_bulk(payload, "payload")?,
265 })
266 }
267 b"pmessage" => {
268 let [_, pat, ch, payload] = into_array4(items)?;
269 Ok(PubsubEvent::Pmessage {
270 pattern: take_bulk(pat, "pattern")?,
271 channel: take_bulk(ch, "channel")?,
272 payload: take_bulk(payload, "payload")?,
273 })
274 }
275 other => Err(invalid(format!(
276 "unknown pubsub kind '{}'",
277 String::from_utf8_lossy(other)
278 ))),
279 }
280}
281
282fn into_array3(items: Vec<Reply>) -> io::Result<[Reply; 3]> {
283 items.try_into().map_err(|v: Vec<Reply>| {
284 invalid(format!("expected 3-element pubsub frame, got {}", v.len()))
285 })
286}
287
288fn into_array4(items: Vec<Reply>) -> io::Result<[Reply; 4]> {
289 items.try_into().map_err(|v: Vec<Reply>| {
290 invalid(format!("expected 4-element pubsub frame, got {}", v.len()))
291 })
292}
293
294fn take_bulk(r: Reply, field: &str) -> io::Result<Vec<u8>> {
295 match r {
296 Reply::Bulk(b) => Ok(b),
297 other => Err(invalid(format!(
298 "expected bulk for {field}, got {}",
299 shape(&other)
300 ))),
301 }
302}
303
304fn take_bulk_or_nil(r: Reply, field: &str) -> io::Result<Option<Vec<u8>>> {
305 match r {
306 Reply::Bulk(b) => Ok(Some(b)),
307 Reply::Nil => Ok(None),
308 other => Err(invalid(format!(
309 "expected bulk/nil for {field}, got {}",
310 shape(&other)
311 ))),
312 }
313}
314
315fn take_int(r: Reply, field: &str) -> io::Result<i64> {
316 match r {
317 Reply::Int(n) => Ok(n),
318 other => Err(invalid(format!(
319 "expected integer for {field}, got {}",
320 shape(&other)
321 ))),
322 }
323}
324
325fn shape(r: &Reply) -> &'static str {
326 match r {
327 Reply::Simple(_) => "simple-string",
328 Reply::Error(_) => "error",
329 Reply::Int(_) => "integer",
330 Reply::Bulk(_) => "bulk-string",
331 Reply::Nil => "nil",
332 Reply::Array(_) => "array",
333 }
334}
335
/// Builds an `InvalidData` I/O error carrying `msg`.
fn invalid(msg: impl Into<String>) -> io::Error {
    let text: String = msg.into();
    io::Error::new(io::ErrorKind::InvalidData, text)
}
339
/// Parses a pub/sub server URL into `(host, port)`.
///
/// Accepted forms are `kevy://host[:port][/path][?query][#fragment]`, and the same with
/// `redis://` or `tcp://`. The port defaults to 6379.
///
/// IPv6 literals must be bracketed (`kevy://[::1]:6379`). The brackets are stripped
/// from the returned host so it can be passed straight to `TcpStream::connect`.
///
/// Any path (such as a Redis database index), query string or fragment is ignored.
///
/// # Errors
/// * `InvalidInput` — missing `://`, unknown scheme, empty host, bad port, or a
///   malformed bracketed host.
/// * `Unsupported` — embedded schemes (`mem://`, `file://`), TLS schemes (`rediss://`,
///   `kevys://`), or userinfo (`user:pass@`).
fn parse_pubsub_url(url: &str) -> io::Result<(String, u16)> {
    let (scheme, rest) = url.split_once("://").ok_or_else(|| {
        io::Error::new(io::ErrorKind::InvalidInput, "URL missing '://'")
    })?;
    match scheme {
        "kevy" | "redis" | "tcp" => {}
        "mem" | "file" => {
            return Err(io::Error::new(
                io::ErrorKind::Unsupported,
                format!(
                    "{scheme}:// is an embedded backend — pub/sub needs a TCP server. \
                     Use kevy://host:port instead."
                ),
            ));
        }
        "rediss" | "kevys" => {
            return Err(io::Error::new(
                io::ErrorKind::Unsupported,
                "TLS schemes (rediss://, kevys://) are unsupported — kevy has no TLS",
            ));
        }
        other => {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("unknown URL scheme '{other}://'"),
            ));
        }
    }
    if rest.contains('@') {
        return Err(io::Error::new(
            io::ErrorKind::Unsupported,
            "userinfo (user:pass@host) is unsupported — kevy has no AUTH",
        ));
    }
    // The authority ends at the first '/', '?' or '#'. Whatever follows (db index,
    // query parameters, fragment) is not needed to open the socket.
    let authority_end = rest
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .unwrap_or(rest.len());
    let authority = &rest[..authority_end];

    let parse_port = |p: &str| -> io::Result<u16> {
        p.parse().map_err(|_| {
            io::Error::new(io::ErrorKind::InvalidInput, format!("bad port: {p}"))
        })
    };
    let (host, port) = match authority.strip_prefix('[') {
        // Bracketed IPv6 literal: `[::1]` or `[::1]:6380`. Splitting on the last ':'
        // would cut inside the address, so locate the closing bracket instead.
        Some(bracketed) => {
            let (host, after) = bracketed.split_once(']').ok_or_else(|| {
                io::Error::new(io::ErrorKind::InvalidInput, "unterminated '[' in host")
            })?;
            let port = if after.is_empty() {
                6379
            } else if let Some(p) = after.strip_prefix(':') {
                parse_port(p)?
            } else {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    format!("unexpected text after bracketed host: {after}"),
                ));
            };
            (host, port)
        }
        None => match authority.rsplit_once(':') {
            Some((h, p)) => (h, parse_port(p)?),
            None => (authority, 6379),
        },
    };
    if host.is_empty() {
        return Err(io::Error::new(io::ErrorKind::InvalidInput, "empty host"));
    }
    Ok((host.to_string(), port))
}
393
// Unit tests for URL parsing and frame classification. None of them needs a running
// server: the only test that touches `Subscriber` fails validation before connecting.
#[cfg(test)]
mod tests {
    use super::*;

    // The three TCP schemes are interchangeable aliases.
    #[test]
    fn parses_kevy_redis_tcp() {
        for url in [
            "kevy://localhost:6379",
            "redis://localhost:6379",
            "tcp://localhost:6379",
        ] {
            let (h, p) = parse_pubsub_url(url).unwrap();
            assert_eq!(h, "localhost");
            assert_eq!(p, 6379);
        }
    }

    #[test]
    fn default_port_when_omitted() {
        let (h, p) = parse_pubsub_url("kevy://example.com").unwrap();
        assert_eq!(h, "example.com");
        assert_eq!(p, 6379);
    }

    #[test]
    fn db_path_segment_ignored() {
        let (h, p) = parse_pubsub_url("kevy://h:1234/0").unwrap();
        // The `/0` database index is dropped; only host and port matter for pub/sub.
        assert_eq!(h, "h");
        assert_eq!(p, 1234);
        let (h, p) = parse_pubsub_url("redis://h:1234/3").unwrap();
        assert_eq!(h, "h");
        assert_eq!(p, 1234);
    }

    // Embedded backends cannot serve pub/sub; they get `Unsupported` rather than
    // `InvalidInput` so callers can tell them apart from typos.
    #[test]
    fn mem_file_rejected_unsupported() {
        for url in ["mem://", "file:///tmp/data"] {
            let err = parse_pubsub_url(url).unwrap_err();
            assert_eq!(err.kind(), io::ErrorKind::Unsupported);
        }
    }

    #[test]
    fn tls_rejected_unsupported() {
        assert_eq!(
            parse_pubsub_url("rediss://h:6379").unwrap_err().kind(),
            io::ErrorKind::Unsupported
        );
    }

    #[test]
    fn userinfo_rejected_unsupported() {
        assert_eq!(
            parse_pubsub_url("kevy://u:p@h:6379").unwrap_err().kind(),
            io::ErrorKind::Unsupported
        );
    }

    #[test]
    fn unknown_scheme_rejected() {
        assert_eq!(
            parse_pubsub_url("memcached://h:11211").unwrap_err().kind(),
            io::ErrorKind::InvalidInput
        );
    }

    // Non-numeric ports and ports above u16::MAX are both rejected.
    #[test]
    fn bad_port_rejected() {
        assert!(parse_pubsub_url("kevy://h:notaport").is_err());
        assert!(parse_pubsub_url("kevy://h:99999").is_err());
    }

    #[test]
    fn empty_host_rejected() {
        assert!(parse_pubsub_url("kevy://:6379").is_err());
    }

    // classify(): one test per frame kind, followed by rejection of malformed frames.
    #[test]
    fn classify_subscribe_ack() {
        let r = Reply::Array(vec![
            Reply::Bulk(b"subscribe".to_vec()),
            Reply::Bulk(b"chan".to_vec()),
            Reply::Int(1),
        ]);
        assert_eq!(
            classify(r).unwrap(),
            PubsubEvent::Subscribe {
                channel: b"chan".to_vec(),
                count: 1,
            }
        );
    }

    #[test]
    fn classify_psubscribe_ack() {
        let r = Reply::Array(vec![
            Reply::Bulk(b"psubscribe".to_vec()),
            Reply::Bulk(b"chan.*".to_vec()),
            Reply::Int(2),
        ]);
        assert_eq!(
            classify(r).unwrap(),
            PubsubEvent::Psubscribe {
                pattern: b"chan.*".to_vec(),
                count: 2,
            }
        );
    }

    #[test]
    fn classify_message_event() {
        let r = Reply::Array(vec![
            Reply::Bulk(b"message".to_vec()),
            Reply::Bulk(b"news".to_vec()),
            Reply::Bulk(b"hello".to_vec()),
        ]);
        assert_eq!(
            classify(r).unwrap(),
            PubsubEvent::Message {
                channel: b"news".to_vec(),
                payload: b"hello".to_vec(),
            }
        );
    }

    // pmessage is the only four-element frame: pattern, channel, payload.
    #[test]
    fn classify_pmessage_event() {
        let r = Reply::Array(vec![
            Reply::Bulk(b"pmessage".to_vec()),
            Reply::Bulk(b"news.*".to_vec()),
            Reply::Bulk(b"news.tech".to_vec()),
            Reply::Bulk(b"hi".to_vec()),
        ]);
        assert_eq!(
            classify(r).unwrap(),
            PubsubEvent::Pmessage {
                pattern: b"news.*".to_vec(),
                channel: b"news.tech".to_vec(),
                payload: b"hi".to_vec(),
            }
        );
    }

    #[test]
    fn classify_unsubscribe_with_channel() {
        let r = Reply::Array(vec![
            Reply::Bulk(b"unsubscribe".to_vec()),
            Reply::Bulk(b"chan".to_vec()),
            Reply::Int(0),
        ]);
        assert_eq!(
            classify(r).unwrap(),
            PubsubEvent::Unsubscribe {
                channel: Some(b"chan".to_vec()),
                count: 0,
            }
        );
    }

    #[test]
    fn classify_unsubscribe_with_nil_channel() {
        let r = Reply::Array(vec![
            // A nil channel (presumably sent for an argument-less UNSUBSCRIBE with no
            // active subscriptions) must decode as `None`, not as an error.
            Reply::Bulk(b"unsubscribe".to_vec()),
            Reply::Nil,
            Reply::Int(0),
        ]);
        assert_eq!(
            classify(r).unwrap(),
            PubsubEvent::Unsubscribe {
                channel: None,
                count: 0,
            }
        );
    }

    #[test]
    fn classify_punsubscribe_with_pattern() {
        let r = Reply::Array(vec![
            Reply::Bulk(b"punsubscribe".to_vec()),
            Reply::Bulk(b"chan.*".to_vec()),
            Reply::Int(0),
        ]);
        assert_eq!(
            classify(r).unwrap(),
            PubsubEvent::Punsubscribe {
                pattern: Some(b"chan.*".to_vec()),
                count: 0,
            }
        );
    }

    #[test]
    fn classify_rejects_unknown_kind() {
        let r = Reply::Array(vec![
            Reply::Bulk(b"bogus".to_vec()),
            Reply::Bulk(b"x".to_vec()),
            Reply::Int(0),
        ]);
        assert_eq!(classify(r).unwrap_err().kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn classify_rejects_non_array() {
        assert_eq!(
            classify(Reply::Simple(b"OK".to_vec())).unwrap_err().kind(),
            io::ErrorKind::InvalidData
        );
    }

    #[test]
    fn classify_rejects_wrong_arity() {
        let r = Reply::Array(vec![
            // `subscribe` needs three elements; the count is missing here.
            Reply::Bulk(b"subscribe".to_vec()),
            Reply::Bulk(b"x".to_vec()),
        ]);
        assert_eq!(classify(r).unwrap_err().kind(), io::ErrorKind::InvalidData);
    }

    // Port 1 is never contacted: `open` rejects the empty channel list before it
    // calls `connect`.
    #[test]
    fn open_with_empty_channels_rejected() {
        let err = Subscriber::open("kevy://127.0.0.1:1", &[]).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }
}