use std::io::{self, Read, Write};
use std::net::TcpStream;
use std::time::Duration;
use kevy_embedded::Subscription;
use kevy_resp::{Reply, encode_command};
#[cfg(test)]
use crate::subscribe_io::classify;
use crate::subscribe_io::{frame_to_event, invalid, recv_remote, send_to, shape};
use crate::{Target, parse_url, resolve_store};
/// Pub/sub subscriber handle.
///
/// Wraps one of two backends (see `Inner`): a TCP connection to a remote
/// server, or a subscription on an embedded store. Every public method
/// dispatches on the active backend.
#[derive(Debug)]
pub struct Subscriber {
/// The active backend and its per-connection state.
inner: Inner,
}
/// Backend-specific state behind a [`Subscriber`].
#[derive(Debug)]
enum Inner {
/// Subscriber talking to a server over TCP.
Remote {
/// Socket that commands are written to and replies are read from.
stream: TcpStream,
/// Bytes received from `stream` that have not been consumed by the
/// reply parser yet; carried over between reads.
buf: Vec<u8>,
},
/// Subscriber attached to an in-process (embedded) store.
Embedded {
/// Handle used to (un)subscribe and to receive frames.
subscription: Subscription,
/// Receive timeout set through `Subscriber::set_read_timeout`;
/// `None` makes `Subscriber::recv` use the untimed `recv()` call.
timeout: Option<Duration>,
},
}
/// One event delivered to a [`Subscriber`].
///
/// Marked `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching.
///
/// NOTE(review): the exact meaning of each `count` field is decided by the
/// frame-decoding helpers, which are not in this file; presumably it is the
/// number of active subscriptions reported with the confirmation — confirm.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PubsubEvent {
/// Confirmation for a channel subscription.
Subscribe {
channel: Vec<u8>,
count: i64,
},
/// Confirmation for a pattern subscription.
Psubscribe {
pattern: Vec<u8>,
count: i64,
},
/// Confirmation for a channel unsubscription.
///
/// NOTE(review): `channel` is optional; presumably `None` when the
/// backend reports no channel name — confirm against the decoder.
Unsubscribe {
channel: Option<Vec<u8>>,
count: i64,
},
/// Confirmation for a pattern unsubscription.
///
/// NOTE(review): `pattern` is optional; presumably `None` when the
/// backend reports no pattern — confirm against the decoder.
Punsubscribe {
pattern: Option<Vec<u8>>,
count: i64,
},
/// A payload published on `channel`.
Message {
channel: Vec<u8>,
payload: Vec<u8>,
},
/// A payload published on `channel`, carrying the `pattern` it is
/// associated with (presumably the pattern subscription that matched).
Pmessage {
pattern: Vec<u8>,
channel: Vec<u8>,
payload: Vec<u8>,
},
}
impl Subscriber {
pub fn connect(url: &str) -> io::Result<Self> {
let target = parse_url(url)?;
let inner = match target {
Target::EmbedMemoryAnonymous => {
return Err(io::Error::new(
io::ErrorKind::Unsupported,
"anonymous mem:// has no other producer; use mem://<name> for a shared bus",
));
}
Target::EmbedMemoryNamed(_) | Target::EmbedPersist(_) => Inner::Embedded {
subscription: resolve_store(&target)?.subscribe(&[]),
timeout: None,
},
Target::Remote(remote_url) => {
let (host, port) = remote_host_port(&remote_url)?;
let stream = TcpStream::connect((host.as_str(), port))?;
stream.set_nodelay(true).ok();
Inner::Remote {
stream,
buf: Vec::with_capacity(8192),
}
}
};
Ok(Self { inner })
}
/// Connects to `url` and immediately subscribes to `channels`.
///
/// # Errors
///
/// Returns `ErrorKind::InvalidInput` when `channels` is empty (checked
/// before any connection is attempted), and propagates any error from
/// `connect` or `subscribe`.
pub fn open(url: &str, channels: &[&[u8]]) -> io::Result<Self> {
    if channels.is_empty() {
        let reason =
            "Subscriber::open needs ≥ 1 channel — use Subscriber::connect() for empty start";
        return Err(io::Error::new(io::ErrorKind::InvalidInput, reason));
    }
    let mut subscriber = Self::connect(url)?;
    subscriber.subscribe(channels).map(|()| subscriber)
}
/// Subscribes to every channel in `channels`.
///
/// On the remote backend this only writes a `SUBSCRIBE` command; on the
/// embedded backend it registers the channels on the subscription.
///
/// # Errors
///
/// Returns `ErrorKind::InvalidInput` for an empty `channels` slice and
/// propagates any error from sending the command over TCP.
pub fn subscribe(&mut self, channels: &[&[u8]]) -> io::Result<()> {
    if channels.is_empty() {
        let reason = "SUBSCRIBE needs ≥ 1 channel";
        return Err(io::Error::new(io::ErrorKind::InvalidInput, reason));
    }
    let subscription = match &mut self.inner {
        Inner::Remote { stream, .. } => return send_to(stream, b"SUBSCRIBE", channels),
        Inner::Embedded { subscription, .. } => subscription,
    };
    subscription.subscribe(channels);
    Ok(())
}
/// Subscribes to every pattern in `patterns`.
///
/// On the remote backend this only writes a `PSUBSCRIBE` command; on the
/// embedded backend it registers the patterns on the subscription.
///
/// # Errors
///
/// Returns `ErrorKind::InvalidInput` for an empty `patterns` slice and
/// propagates any error from sending the command over TCP.
pub fn psubscribe(&mut self, patterns: &[&[u8]]) -> io::Result<()> {
    if patterns.is_empty() {
        let reason = "PSUBSCRIBE needs ≥ 1 pattern";
        return Err(io::Error::new(io::ErrorKind::InvalidInput, reason));
    }
    match self.inner {
        Inner::Embedded {
            ref mut subscription,
            ..
        } => {
            subscription.psubscribe(patterns);
            Ok(())
        }
        Inner::Remote { ref mut stream, .. } => send_to(stream, b"PSUBSCRIBE", patterns),
    }
}
/// Unsubscribes from the given channels.
///
/// Unlike `subscribe`, an empty slice is not rejected here: it is handed to
/// the backend unchanged (presumably meaning "all channels" — confirm
/// against the backend).
///
/// # Errors
///
/// Propagates any error from sending `UNSUBSCRIBE` over TCP; the embedded
/// backend always reports success.
pub fn unsubscribe(&mut self, channels: &[&[u8]]) -> io::Result<()> {
    let subscription = match &mut self.inner {
        Inner::Remote { stream, .. } => return send_to(stream, b"UNSUBSCRIBE", channels),
        Inner::Embedded { subscription, .. } => subscription,
    };
    subscription.unsubscribe(channels);
    Ok(())
}
/// Unsubscribes from the given patterns.
///
/// Unlike `psubscribe`, an empty slice is not rejected here: it is handed
/// to the backend unchanged (presumably meaning "all patterns" — confirm
/// against the backend).
///
/// # Errors
///
/// Propagates any error from sending `PUNSUBSCRIBE` over TCP; the embedded
/// backend always reports success.
pub fn punsubscribe(&mut self, patterns: &[&[u8]]) -> io::Result<()> {
    match self.inner {
        Inner::Embedded {
            ref mut subscription,
            ..
        } => {
            subscription.punsubscribe(patterns);
            Ok(())
        }
        Inner::Remote { ref mut stream, .. } => send_to(stream, b"PUNSUBSCRIBE", patterns),
    }
}
/// Receives the next pub/sub event from the active backend.
///
/// The remote backend delegates to `recv_remote` with the socket and its
/// carry-over buffer. The embedded backend waits on the subscription,
/// bounded by the timeout stored via `set_read_timeout` when one is set,
/// and converts the received frame into an event.
///
/// # Errors
///
/// Propagates whatever error the backend's receive call reports.
pub fn recv(&mut self) -> io::Result<PubsubEvent> {
    match &mut self.inner {
        Inner::Embedded {
            subscription,
            timeout,
        } => {
            let frame = if let Some(limit) = *timeout {
                subscription.recv_timeout(limit)?
            } else {
                subscription.recv()?
            };
            Ok(frame_to_event(frame))
        }
        Inner::Remote { stream, buf } => recv_remote(stream, buf),
    }
}
/// Receives events until a published message arrives and returns it as
/// `(channel, payload)`.
///
/// Both plain and pattern-matched messages are returned (the matching
/// pattern of the latter is dropped); (un)subscription confirmations are
/// silently skipped.
///
/// # Errors
///
/// Returns the first error produced by `recv`.
pub fn recv_message(&mut self) -> io::Result<(Vec<u8>, Vec<u8>)> {
    loop {
        let event = self.recv()?;
        match event {
            PubsubEvent::Message { channel, payload }
            | PubsubEvent::Pmessage {
                channel, payload, ..
            } => return Ok((channel, payload)),
            // Confirmations carry no payload; keep waiting.
            PubsubEvent::Subscribe { .. }
            | PubsubEvent::Psubscribe { .. }
            | PubsubEvent::Unsubscribe { .. }
            | PubsubEvent::Punsubscribe { .. } => {}
        }
    }
}
pub fn hello3(&mut self) -> io::Result<PubsubEvent> {
match &mut self.inner {
Inner::Embedded { .. } => Err(io::Error::new(
io::ErrorKind::Unsupported,
"HELLO 3 is a remote/TCP-only operation; embedded backend has no proto switch",
)),
Inner::Remote { stream, buf } => {
let mut frame = Vec::new();
encode_command(&mut frame, &[b"HELLO".to_vec(), b"3".to_vec()]);
stream.write_all(&frame)?;
let mut chunk = [0u8; 4096];
loop {
match kevy_resp::parse_reply(buf) {
Ok(Some((reply, used))) => {
buf.drain(..used);
return match reply {
Reply::Map(_) | Reply::Array(_) => {
Ok(PubsubEvent::Subscribe {
channel: b"HELLO".to_vec(),
count: 3,
})
}
Reply::Error(e) => Err(io::Error::other(
String::from_utf8_lossy(&e).into_owned(),
)),
other => Err(invalid(format!(
"unexpected HELLO 3 reply shape: {}",
shape(&other)
))),
};
}
Ok(None) => {}
Err(_) => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"malformed HELLO 3 reply",
));
}
}
let n = stream.read(&mut chunk)?;
if n == 0 {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"server closed connection during HELLO 3",
));
}
buf.extend_from_slice(&chunk[..n]);
}
}
}
}
/// Sets (or clears, with `None`) the timeout applied to receive calls.
///
/// The remote backend forwards the value to the socket's read timeout; the
/// embedded backend stores it and uses it on the next `recv`.
///
/// # Errors
///
/// Propagates any error from `TcpStream::set_read_timeout`; the embedded
/// backend always reports success.
pub fn set_read_timeout(&mut self, dur: Option<Duration>) -> io::Result<()> {
    match &mut self.inner {
        Inner::Embedded { timeout, .. } => {
            *timeout = dur;
            Ok(())
        }
        Inner::Remote { stream, .. } => stream.set_read_timeout(dur),
    }
}
/// Returns an iterator that yields the result of `recv` on each step,
/// mutably borrowing this subscriber for as long as the iterator lives.
pub fn events(&mut self) -> SubscriberEvents<'_> {
    let sub = self;
    SubscriberEvents { sub }
}
/// Returns an iterator that yields the result of `recv_message` on each
/// step, mutably borrowing this subscriber for as long as the iterator
/// lives.
pub fn messages(&mut self) -> SubscriberMessages<'_> {
    let sub = self;
    SubscriberMessages { sub }
}
}
/// Iterator over every event received by a [`Subscriber`].
///
/// Created by `Subscriber::events`; holds the subscriber mutably for the
/// iterator's lifetime.
#[derive(Debug)]
pub struct SubscriberEvents<'a> {
/// Subscriber that each iteration step receives from.
sub: &'a mut Subscriber,
}
impl Iterator for SubscriberEvents<'_> {
    type Item = io::Result<PubsubEvent>;

    /// Receives the next event.
    ///
    /// An `UnexpectedEof` error ends the iteration (`None`); every other
    /// outcome, including other errors, is yielded to the caller.
    fn next(&mut self) -> Option<Self::Item> {
        let outcome = self.sub.recv();
        let closed = matches!(&outcome, Err(e) if e.kind() == io::ErrorKind::UnexpectedEof);
        if closed {
            None
        } else {
            Some(outcome)
        }
    }
}
/// Iterator over the published messages received by a [`Subscriber`],
/// yielded as `(channel, payload)` pairs.
///
/// Created by `Subscriber::messages`; holds the subscriber mutably for the
/// iterator's lifetime.
#[derive(Debug)]
pub struct SubscriberMessages<'a> {
/// Subscriber that each iteration step receives from.
sub: &'a mut Subscriber,
}
impl Iterator for SubscriberMessages<'_> {
    type Item = io::Result<(Vec<u8>, Vec<u8>)>;

    /// Receives the next published message.
    ///
    /// An `UnexpectedEof` error ends the iteration (`None`); every other
    /// outcome, including other errors, is yielded to the caller.
    fn next(&mut self) -> Option<Self::Item> {
        let outcome = self.sub.recv_message();
        let closed = matches!(&outcome, Err(e) if e.kind() == io::ErrorKind::UnexpectedEof);
        if closed {
            None
        } else {
            Some(outcome)
        }
    }
}
/// Extracts the `(host, port)` pair from a `scheme://host[:port][/...]` URL.
///
/// The authority ends at the first `/`, `?` or `#`. A missing port defaults
/// to 6379. IPv6 literals must be written in brackets (`[::1]` or
/// `[::1]:6380`); the brackets are stripped from the returned host so it can
/// be handed directly to `TcpStream::connect`.
///
/// # Errors
///
/// - `ErrorKind::Unsupported` when the text after `://` contains `@`
///   (userinfo is not supported). The check deliberately covers everything
///   after the scheme, not just the authority, so credentials are never
///   silently misread as a host name.
/// - `ErrorKind::InvalidInput` when `://` is missing, the port is not a
///   valid `u16`, an IPv6 bracket is unterminated or followed by anything
///   other than `:port`, or the host is empty.
fn remote_host_port(url: &str) -> io::Result<(String, u16)> {
    const DEFAULT_PORT: u16 = 6379;
    let bad_input = |msg: String| io::Error::new(io::ErrorKind::InvalidInput, msg);

    let (_scheme, rest) = url
        .split_once("://")
        .ok_or_else(|| bad_input("URL missing '://'".to_string()))?;
    if rest.contains('@') {
        return Err(io::Error::new(
            io::ErrorKind::Unsupported,
            "userinfo (user:pass@host) is unsupported — kevy has no AUTH",
        ));
    }

    // The authority stops at the path, query or fragment, whichever is first.
    let authority = rest
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or("");

    let (host, port_text) = match authority.strip_prefix('[') {
        // Bracketed IPv6 literal: the colons inside the brackets belong to
        // the address, so the port separator is searched only after `]`.
        Some(bracketed) => {
            let (addr, tail) = bracketed
                .split_once(']')
                .ok_or_else(|| bad_input("unterminated '[' in IPv6 host".to_string()))?;
            let port_text = match tail.strip_prefix(':') {
                Some(p) => Some(p),
                None if tail.is_empty() => None,
                None => {
                    return Err(bad_input(format!(
                        "unexpected text after IPv6 host: {tail}"
                    )));
                }
            };
            (addr, port_text)
        }
        None => match authority.rsplit_once(':') {
            Some((h, p)) => (h, Some(p)),
            None => (authority, None),
        },
    };

    let port = match port_text {
        Some(p) => p
            .parse::<u16>()
            .map_err(|_| bad_input(format!("bad port: {p}")))?,
        None => DEFAULT_PORT,
    };
    if host.is_empty() {
        return Err(bad_input("empty host".to_string()));
    }
    Ok((host.to_string(), port))
}
#[cfg(test)]
#[path = "subscribe_tests.rs"]
mod tests;