1use std::io::{self, Read, Write};
31use std::net::TcpStream;
32use std::time::Duration;
33
34use kevy_embedded::{PubsubFrame, Subscription};
35use kevy_resp::{Reply, encode_command, parse_reply};
36
37use crate::{Target, parse_url, resolve_store};
38
/// A pubsub subscriber that receives [`PubsubEvent`]s.
///
/// The subscriber is backed either by a TCP connection to a remote server or by a
/// subscription on an embedded, in-process store. Which one is chosen depends on the URL
/// passed to [`Subscriber::connect`] / [`Subscriber::open`].
#[derive(Debug)]
pub struct Subscriber {
    /// Transport-specific state; see [`Inner`].
    inner: Inner,
}
46
/// Transport backing a [`Subscriber`].
#[derive(Debug)]
enum Inner {
    /// Connected to a remote server over TCP.
    Remote {
        /// The socket commands are written to and replies are read from.
        stream: TcpStream,
        /// Bytes already read from `stream` that have not yet been parsed into a
        /// complete reply; kept across `recv` calls so partial frames are not lost.
        buf: Vec<u8>,
    },
    /// Attached to an in-process store's pubsub bus.
    Embedded {
        /// The embedded bus subscription events are received from.
        subscription: Subscription,
        /// Receive timeout applied by `recv`: `None` blocks via `Subscription::recv`,
        /// `Some(d)` uses `Subscription::recv_timeout(d)`. Set by `set_read_timeout`.
        timeout: Option<Duration>,
    },
}
61
/// An event returned by [`Subscriber::recv`].
///
/// Remote events are decoded from the server's pubsub array frames. Embedded events are
/// converted from the in-process bus's frames. The enum is `#[non_exhaustive]`, so
/// downstream `match`es need a wildcard arm and new kinds can be added compatibly.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PubsubEvent {
    /// Acknowledgement of a channel subscription (`subscribe` frame).
    Subscribe {
        /// The channel that was subscribed.
        channel: Vec<u8>,
        /// Count reported with the acknowledgement (presumably the number of active
        /// channel + pattern subscriptions — confirm against the server/bus).
        count: i64,
    },
    /// Acknowledgement of a pattern subscription (`psubscribe` frame).
    Psubscribe {
        /// The pattern that was subscribed.
        pattern: Vec<u8>,
        /// Count reported with the acknowledgement (see `Subscribe::count`).
        count: i64,
    },
    /// Acknowledgement of a channel unsubscription (`unsubscribe` frame).
    Unsubscribe {
        /// The channel that was unsubscribed, or `None` when the frame carried nil
        /// instead of a channel name.
        channel: Option<Vec<u8>>,
        /// Count reported with the acknowledgement (see `Subscribe::count`).
        count: i64,
    },
    /// Acknowledgement of a pattern unsubscription (`punsubscribe` frame).
    Punsubscribe {
        /// The pattern that was unsubscribed, or `None` when the frame carried nil
        /// instead of a pattern.
        pattern: Option<Vec<u8>>,
        /// Count reported with the acknowledgement (see `Subscribe::count`).
        count: i64,
    },
    /// A message published to a subscribed channel (`message` frame).
    Message {
        /// The channel the message was published to.
        channel: Vec<u8>,
        /// The published payload bytes.
        payload: Vec<u8>,
    },
    /// A message delivered through a pattern subscription (`pmessage` frame).
    Pmessage {
        /// The subscribed pattern that matched `channel`.
        pattern: Vec<u8>,
        /// The channel the message was actually published to.
        channel: Vec<u8>,
        /// The published payload bytes.
        payload: Vec<u8>,
    },
}
118
119impl Subscriber {
120 pub fn connect(url: &str) -> io::Result<Self> {
129 let target = parse_url(url)?;
130 let inner = match target {
131 Target::EmbedMemoryAnonymous => {
132 return Err(io::Error::new(
133 io::ErrorKind::Unsupported,
134 "anonymous mem:// has no other producer; use mem://<name> for a shared bus",
135 ));
136 }
137 Target::EmbedMemoryNamed(_) | Target::EmbedPersist(_) => Inner::Embedded {
138 subscription: resolve_store(&target)?.subscribe(&[]),
139 timeout: None,
140 },
141 Target::Remote(remote_url) => {
142 let (host, port) = remote_host_port(&remote_url)?;
143 let stream = TcpStream::connect((host.as_str(), port))?;
144 stream.set_nodelay(true).ok();
145 Inner::Remote {
146 stream,
147 buf: Vec::with_capacity(8192),
148 }
149 }
150 };
151 Ok(Self { inner })
152 }
153
154 pub fn open(url: &str, channels: &[&[u8]]) -> io::Result<Self> {
158 if channels.is_empty() {
159 return Err(io::Error::new(
160 io::ErrorKind::InvalidInput,
161 "Subscriber::open needs ≥ 1 channel — use Subscriber::connect() for empty start",
162 ));
163 }
164 let mut s = Self::connect(url)?;
165 s.subscribe(channels)?;
166 Ok(s)
167 }
168
169 pub fn subscribe(&mut self, channels: &[&[u8]]) -> io::Result<()> {
172 if channels.is_empty() {
173 return Err(io::Error::new(
174 io::ErrorKind::InvalidInput,
175 "SUBSCRIBE needs ≥ 1 channel",
176 ));
177 }
178 match &mut self.inner {
179 Inner::Remote { stream, .. } => send_to(stream, b"SUBSCRIBE", channels),
180 Inner::Embedded { subscription, .. } => {
181 subscription.subscribe(channels);
182 Ok(())
183 }
184 }
185 }
186
187 pub fn psubscribe(&mut self, patterns: &[&[u8]]) -> io::Result<()> {
190 if patterns.is_empty() {
191 return Err(io::Error::new(
192 io::ErrorKind::InvalidInput,
193 "PSUBSCRIBE needs ≥ 1 pattern",
194 ));
195 }
196 match &mut self.inner {
197 Inner::Remote { stream, .. } => send_to(stream, b"PSUBSCRIBE", patterns),
198 Inner::Embedded { subscription, .. } => {
199 subscription.psubscribe(patterns);
200 Ok(())
201 }
202 }
203 }
204
205 pub fn unsubscribe(&mut self, channels: &[&[u8]]) -> io::Result<()> {
208 match &mut self.inner {
209 Inner::Remote { stream, .. } => send_to(stream, b"UNSUBSCRIBE", channels),
210 Inner::Embedded { subscription, .. } => {
211 subscription.unsubscribe(channels);
212 Ok(())
213 }
214 }
215 }
216
217 pub fn punsubscribe(&mut self, patterns: &[&[u8]]) -> io::Result<()> {
220 match &mut self.inner {
221 Inner::Remote { stream, .. } => send_to(stream, b"PUNSUBSCRIBE", patterns),
222 Inner::Embedded { subscription, .. } => {
223 subscription.punsubscribe(patterns);
224 Ok(())
225 }
226 }
227 }
228
229 pub fn recv(&mut self) -> io::Result<PubsubEvent> {
233 match &mut self.inner {
234 Inner::Remote { stream, buf } => recv_remote(stream, buf),
235 Inner::Embedded {
236 subscription,
237 timeout,
238 } => {
239 let frame = match *timeout {
240 Some(d) => subscription.recv_timeout(d)?,
241 None => subscription.recv()?,
242 };
243 Ok(frame_to_event(frame))
244 }
245 }
246 }
247
248 pub fn set_read_timeout(&mut self, dur: Option<Duration>) -> io::Result<()> {
252 match &mut self.inner {
253 Inner::Remote { stream, .. } => stream.set_read_timeout(dur),
254 Inner::Embedded { timeout, .. } => {
255 *timeout = dur;
256 Ok(())
257 }
258 }
259 }
260}
261
262fn send_to(stream: &mut TcpStream, verb: &[u8], args: &[&[u8]]) -> io::Result<()> {
263 let mut argv = Vec::with_capacity(args.len() + 1);
264 argv.push(verb.to_vec());
265 argv.extend(args.iter().map(|a| a.to_vec()));
266 let mut frame = Vec::new();
267 encode_command(&mut frame, &argv);
268 stream.write_all(&frame)
269}
270
271fn recv_remote(stream: &mut TcpStream, buf: &mut Vec<u8>) -> io::Result<PubsubEvent> {
272 let mut chunk = [0u8; 8192];
273 loop {
274 match parse_reply(buf) {
275 Ok(Some((reply, used))) => {
276 buf.drain(..used);
277 return classify(reply);
278 }
279 Ok(None) => {}
280 Err(_) => {
281 return Err(io::Error::new(
282 io::ErrorKind::InvalidData,
283 "malformed reply",
284 ));
285 }
286 }
287 let n = stream.read(&mut chunk)?;
288 if n == 0 {
289 return Err(io::Error::new(
290 io::ErrorKind::UnexpectedEof,
291 "server closed connection",
292 ));
293 }
294 buf.extend_from_slice(&chunk[..n]);
295 }
296}
297
298fn frame_to_event(frame: PubsubFrame) -> PubsubEvent {
299 match frame {
300 PubsubFrame::Subscribe { channel, count } => PubsubEvent::Subscribe {
301 channel,
302 count: count as i64,
303 },
304 PubsubFrame::Psubscribe { pattern, count } => PubsubEvent::Psubscribe {
305 pattern,
306 count: count as i64,
307 },
308 PubsubFrame::Unsubscribe { channel, count } => PubsubEvent::Unsubscribe {
309 channel,
310 count: count as i64,
311 },
312 PubsubFrame::Punsubscribe { pattern, count } => PubsubEvent::Punsubscribe {
313 pattern,
314 count: count as i64,
315 },
316 PubsubFrame::Message { channel, payload } => PubsubEvent::Message { channel, payload },
317 PubsubFrame::Pmessage {
318 pattern,
319 channel,
320 payload,
321 } => PubsubEvent::Pmessage {
322 pattern,
323 channel,
324 payload,
325 },
326 }
327}
328
329fn classify(reply: Reply) -> io::Result<PubsubEvent> {
330 let items = match reply {
331 Reply::Array(v) => v,
332 other => return Err(invalid(format!("expected array frame, got {}", shape(&other)))),
333 };
334 let kind = match items.first() {
335 Some(Reply::Bulk(b)) => b.clone(),
336 _ => return Err(invalid("pubsub frame missing kind field")),
337 };
338 match kind.as_slice() {
339 b"subscribe" => {
340 let [_, ch, n] = into_array3(items)?;
341 Ok(PubsubEvent::Subscribe {
342 channel: take_bulk(ch, "channel")?,
343 count: take_int(n, "count")?,
344 })
345 }
346 b"psubscribe" => {
347 let [_, p, n] = into_array3(items)?;
348 Ok(PubsubEvent::Psubscribe {
349 pattern: take_bulk(p, "pattern")?,
350 count: take_int(n, "count")?,
351 })
352 }
353 b"unsubscribe" => {
354 let [_, ch, n] = into_array3(items)?;
355 Ok(PubsubEvent::Unsubscribe {
356 channel: take_bulk_or_nil(ch, "channel")?,
357 count: take_int(n, "count")?,
358 })
359 }
360 b"punsubscribe" => {
361 let [_, p, n] = into_array3(items)?;
362 Ok(PubsubEvent::Punsubscribe {
363 pattern: take_bulk_or_nil(p, "pattern")?,
364 count: take_int(n, "count")?,
365 })
366 }
367 b"message" => {
368 let [_, ch, payload] = into_array3(items)?;
369 Ok(PubsubEvent::Message {
370 channel: take_bulk(ch, "channel")?,
371 payload: take_bulk(payload, "payload")?,
372 })
373 }
374 b"pmessage" => {
375 let [_, pat, ch, payload] = into_array4(items)?;
376 Ok(PubsubEvent::Pmessage {
377 pattern: take_bulk(pat, "pattern")?,
378 channel: take_bulk(ch, "channel")?,
379 payload: take_bulk(payload, "payload")?,
380 })
381 }
382 other => Err(invalid(format!(
383 "unknown pubsub kind '{}'",
384 String::from_utf8_lossy(other)
385 ))),
386 }
387}
388
389fn into_array3(items: Vec<Reply>) -> io::Result<[Reply; 3]> {
390 items.try_into().map_err(|v: Vec<Reply>| {
391 invalid(format!("expected 3-element pubsub frame, got {}", v.len()))
392 })
393}
394
395fn into_array4(items: Vec<Reply>) -> io::Result<[Reply; 4]> {
396 items.try_into().map_err(|v: Vec<Reply>| {
397 invalid(format!("expected 4-element pubsub frame, got {}", v.len()))
398 })
399}
400
401fn take_bulk(r: Reply, field: &str) -> io::Result<Vec<u8>> {
402 match r {
403 Reply::Bulk(b) => Ok(b),
404 other => Err(invalid(format!(
405 "expected bulk for {field}, got {}",
406 shape(&other)
407 ))),
408 }
409}
410
411fn take_bulk_or_nil(r: Reply, field: &str) -> io::Result<Option<Vec<u8>>> {
412 match r {
413 Reply::Bulk(b) => Ok(Some(b)),
414 Reply::Nil => Ok(None),
415 other => Err(invalid(format!(
416 "expected bulk/nil for {field}, got {}",
417 shape(&other)
418 ))),
419 }
420}
421
422fn take_int(r: Reply, field: &str) -> io::Result<i64> {
423 match r {
424 Reply::Int(n) => Ok(n),
425 other => Err(invalid(format!(
426 "expected integer for {field}, got {}",
427 shape(&other)
428 ))),
429 }
430}
431
432fn shape(r: &Reply) -> &'static str {
433 match r {
434 Reply::Simple(_) => "simple-string",
435 Reply::Error(_) => "error",
436 Reply::Int(_) => "integer",
437 Reply::Bulk(_) => "bulk-string",
438 Reply::Nil => "nil",
439 Reply::Array(_) => "array",
440 }
441}
442
/// Builds an `io::Error` of kind `InvalidData` carrying `msg` as its message.
fn invalid(msg: impl Into<String>) -> io::Error {
    let text: String = msg.into();
    io::Error::new(io::ErrorKind::InvalidData, text)
}
446
/// Extracts `(host, port)` from a remote URL such as `scheme://host:6380/0`.
///
/// Only the authority is inspected: the text after `://` up to the first `/`, `?` or
/// `#`. Any path, query or fragment is ignored. The port defaults to 6379 when absent.
/// A bracketed IPv6 literal (`[::1]` or `[::1]:6380`) is returned without its brackets,
/// so the host can be handed straight to `TcpStream::connect`.
///
/// # Errors
///
/// * `InvalidInput` if the URL has no `://`, the port does not parse as `u16`, the host
///   is empty, or an IPv6 literal is malformed.
/// * `Unsupported` if the URL contains `@` (userinfo), since there is no AUTH support.
fn remote_host_port(url: &str) -> io::Result<(String, u16)> {
    let (_scheme, rest) = url.split_once("://").ok_or_else(|| {
        io::Error::new(io::ErrorKind::InvalidInput, "URL missing '://'")
    })?;
    // Deliberately checked on everything after the scheme, not just the authority.
    // A password containing '/', '?' or '#' is then still rejected here, rather than
    // being mis-split and partly echoed back in a "bad port" error.
    if rest.contains('@') {
        return Err(io::Error::new(
            io::ErrorKind::Unsupported,
            "userinfo (user:pass@host) is unsupported — kevy has no AUTH",
        ));
    }
    // The authority ends at the first path, query or fragment delimiter.
    let authority = rest
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or("");
    let (host, port_text) = match authority.strip_prefix('[') {
        // Bracketed IPv6 literal: the colons inside the brackets belong to the address,
        // so only a ':' directly after the closing bracket introduces a port.
        Some(bracketed) => {
            let (addr, tail) = bracketed.split_once(']').ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::InvalidInput,
                    format!("unterminated IPv6 literal: {authority}"),
                )
            })?;
            let port_text = if tail.is_empty() {
                None
            } else {
                Some(tail.strip_prefix(':').ok_or_else(|| {
                    io::Error::new(
                        io::ErrorKind::InvalidInput,
                        format!("unexpected text after IPv6 literal: {tail}"),
                    )
                })?)
            };
            (addr, port_text)
        }
        None => match authority.rsplit_once(':') {
            Some((h, p)) => (h, Some(p)),
            None => (authority, None),
        },
    };
    let port = match port_text {
        Some(p) => p.parse::<u16>().map_err(|_| {
            io::Error::new(io::ErrorKind::InvalidInput, format!("bad port: {p}"))
        })?,
        None => 6379,
    };
    if host.is_empty() {
        return Err(io::Error::new(io::ErrorKind::InvalidInput, "empty host"));
    }
    Ok((host.to_string(), port))
}
478
479#[cfg(test)]
480#[path = "subscribe_tests.rs"]
481mod tests;