kade_proto/cmd/subscribe.rs

1use crate::cmd::{Parse, ParseError, Unknown};
2use crate::prelude::*;
3
4use bytes::Bytes;
5use std::pin::Pin;
6use tokio::select;
7use tokio::sync::broadcast;
8use tokio_stream::{Stream, StreamExt, StreamMap};
9
/// The `SUBSCRIBE` command: the list of channel names the client asked to
/// subscribe to.
///
/// Derives `Clone` for parity with [`Unsubscribe`].
#[derive(Clone, Debug)]
pub struct Subscribe {
    /// Channel names requested by the client, in request order.
    channels: Vec<String>,
}
14
/// The `UNSUBSCRIBE` command: the channel names to leave.
///
/// An empty list is meaningful here: it asks to drop every current
/// subscription (see `handle_command`).
#[derive(Clone, Debug)]
pub struct Unsubscribe {
    /// Channel names to unsubscribe from; may be empty.
    channels: Vec<String>,
}
19
20type Messages = Pin<Box<dyn Stream<Item = Bytes> + Send>>;
21
22impl Subscribe {
23    pub fn new(channels: Vec<String>) -> Subscribe { Subscribe { channels } }
24
25    pub fn parse_frames(parse: &mut Parse) -> crate::Result<Subscribe> {
26        use ParseError::EndOfStream;
27        let mut channels = vec![parse.next_string()?];
28
29        loop {
30            match parse.next_string() {
31                Ok(s) => channels.push(s),
32                Err(EndOfStream) => break,
33                Err(err) => return Err(err.into()),
34            }
35        }
36
37        Ok(Subscribe { channels })
38    }
39
40    pub async fn apply(mut self, db: &Db, dst: &mut Connection, shutdown: &mut Shutdown) -> crate::Result<()> {
41        let mut subscriptions = StreamMap::new();
42
43        loop {
44            for channel_name in self.channels.drain(..) {
45                subscribe_to_channel(channel_name, &mut subscriptions, db, dst).await?;
46            }
47
48            select! {
49                Some((channel_name, msg)) = subscriptions.next() => {
50                    dst.write_frame(&make_message_frame(channel_name, msg)).await?;
51                }
52                res = dst.read_frame() => {
53                    let frame = match res? {
54                        Some(frame) => frame,
55                        None => return Ok(())
56                    };
57
58                    handle_command(
59                        frame,
60                        &mut self.channels,
61                        &mut subscriptions,
62                        dst,
63                    ).await?;
64                }
65                _ = shutdown.recv() => {
66                    return Ok(());
67                }
68            };
69        }
70    }
71
72    pub fn into_frame(self) -> Frame {
73        let mut frame = Frame::array();
74        frame.push_bulk(Bytes::from("subscribe".as_bytes()));
75        for channel in self.channels {
76            frame.push_bulk(Bytes::from(channel.into_bytes()));
77        }
78        frame
79    }
80}
81
82async fn subscribe_to_channel(channel_name: String, subscriptions: &mut StreamMap<String, Messages>, db: &Db, dst: &mut Connection) -> crate::Result<()> {
83    let mut rx = db.subscribe(channel_name.clone());
84
85    let rx = Box::pin(async_stream::stream! {
86        loop {
87            match rx.recv().await {
88                Ok(msg) => yield msg,
89                Err(broadcast::error::RecvError::Lagged(_)) => {}
90                Err(_) => break,
91            }
92        }
93    });
94
95    subscriptions.insert(channel_name.clone(), rx);
96
97    let response = make_subscribe_frame(channel_name, subscriptions.len());
98    dst.write_frame(&response).await?;
99
100    Ok(())
101}
102
103async fn handle_command(frame: Frame, subscribe_to: &mut Vec<String>, subscriptions: &mut StreamMap<String, Messages>, dst: &mut Connection) -> crate::Result<()> {
104    match Command::from_frame(frame)? {
105        Command::Subscribe(subscribe) => {
106            subscribe_to.extend(subscribe.channels.into_iter());
107        }
108        Command::Unsubscribe(mut unsubscribe) => {
109            if unsubscribe.channels.is_empty() {
110                unsubscribe.channels = subscriptions.keys().map(|channel_name| channel_name.to_string()).collect();
111            }
112
113            for channel_name in unsubscribe.channels {
114                subscriptions.remove(&channel_name);
115
116                let response = make_unsubscribe_frame(channel_name, subscriptions.len());
117                dst.write_frame(&response).await?;
118            }
119        }
120        command => {
121            let cmd = Unknown::new(command.get_name());
122            cmd.apply(dst).await?;
123        }
124    }
125    Ok(())
126}
127
128fn make_subscribe_frame(channel_name: String, num_subs: usize) -> Frame {
129    let mut response = Frame::array();
130    response.push_bulk(Bytes::from_static(b"subscribe"));
131    response.push_bulk(Bytes::from(channel_name));
132    response.push_int(num_subs as u64);
133    response
134}
135
136fn make_unsubscribe_frame(channel_name: String, num_subs: usize) -> Frame {
137    let mut response = Frame::array();
138    response.push_bulk(Bytes::from_static(b"unsubscribe"));
139    response.push_bulk(Bytes::from(channel_name));
140    response.push_int(num_subs as u64);
141    response
142}
143
144fn make_message_frame(channel_name: String, msg: Bytes) -> Frame {
145    let mut response = Frame::array();
146    response.push_bulk(Bytes::from_static(b"message"));
147    response.push_bulk(Bytes::from(channel_name));
148    response.push_bulk(msg);
149    response
150}
151
152impl Unsubscribe {
153    pub fn new(channels: &[String]) -> Unsubscribe { Unsubscribe { channels: channels.to_vec() } }
154
155    pub fn parse_frames(parse: &mut Parse) -> Result<Unsubscribe, ParseError> {
156        use ParseError::EndOfStream;
157        let mut channels = vec![];
158
159        loop {
160            match parse.next_string() {
161                Ok(s) => channels.push(s),
162                Err(EndOfStream) => break,
163                Err(err) => return Err(err),
164            }
165        }
166
167        Ok(Unsubscribe { channels })
168    }
169
170    pub fn into_frame(self) -> Frame {
171        let mut frame = Frame::array();
172        frame.push_bulk(Bytes::from("unsubscribe".as_bytes()));
173
174        for channel in self.channels {
175            frame.push_bulk(Bytes::from(channel.into_bytes()));
176        }
177
178        frame
179    }
180}