kade_proto/cmd/
subscribe.rs1use crate::cmd::{Parse, ParseError, Unknown};
2use crate::prelude::*;
3
4use bytes::Bytes;
5use std::pin::Pin;
6use tokio::select;
7use tokio::sync::broadcast;
8use tokio_stream::{Stream, StreamExt, StreamMap};
9
/// A subscribe command carrying the names of the channels to subscribe to.
///
/// Built either directly with `Subscribe::new` or from a received frame
/// with `Subscribe::parse_frames`.
#[derive(Debug)]
pub struct Subscribe {
    /// Names of the channels this command refers to.
    channels: Vec<String>,
}
14
/// An unsubscribe command carrying the names of the channels to leave.
///
/// The channel list may be empty; when handled inside a subscription loop an
/// empty list is expanded to every channel currently subscribed.
#[derive(Debug, Clone)]
pub struct Unsubscribe {
    /// Names of the channels this command refers to (possibly none).
    channels: Vec<String>,
}
19
20type Messages = Pin<Box<dyn Stream<Item = Bytes> + Send>>;
21
22impl Subscribe {
23 pub fn new(channels: Vec<String>) -> Subscribe { Subscribe { channels } }
24
25 pub fn parse_frames(parse: &mut Parse) -> crate::Result<Subscribe> {
26 use ParseError::EndOfStream;
27 let mut channels = vec![parse.next_string()?];
28
29 loop {
30 match parse.next_string() {
31 Ok(s) => channels.push(s),
32 Err(EndOfStream) => break,
33 Err(err) => return Err(err.into()),
34 }
35 }
36
37 Ok(Subscribe { channels })
38 }
39
40 pub async fn apply(mut self, db: &Db, dst: &mut Connection, shutdown: &mut Shutdown) -> crate::Result<()> {
41 let mut subscriptions = StreamMap::new();
42
43 loop {
44 for channel_name in self.channels.drain(..) {
45 subscribe_to_channel(channel_name, &mut subscriptions, db, dst).await?;
46 }
47
48 select! {
49 Some((channel_name, msg)) = subscriptions.next() => {
50 dst.write_frame(&make_message_frame(channel_name, msg)).await?;
51 }
52 res = dst.read_frame() => {
53 let frame = match res? {
54 Some(frame) => frame,
55 None => return Ok(())
56 };
57
58 handle_command(
59 frame,
60 &mut self.channels,
61 &mut subscriptions,
62 dst,
63 ).await?;
64 }
65 _ = shutdown.recv() => {
66 return Ok(());
67 }
68 };
69 }
70 }
71
72 pub fn into_frame(self) -> Frame {
73 let mut frame = Frame::array();
74 frame.push_bulk(Bytes::from("subscribe".as_bytes()));
75 for channel in self.channels {
76 frame.push_bulk(Bytes::from(channel.into_bytes()));
77 }
78 frame
79 }
80}
81
82async fn subscribe_to_channel(channel_name: String, subscriptions: &mut StreamMap<String, Messages>, db: &Db, dst: &mut Connection) -> crate::Result<()> {
83 let mut rx = db.subscribe(channel_name.clone());
84
85 let rx = Box::pin(async_stream::stream! {
86 loop {
87 match rx.recv().await {
88 Ok(msg) => yield msg,
89 Err(broadcast::error::RecvError::Lagged(_)) => {}
90 Err(_) => break,
91 }
92 }
93 });
94
95 subscriptions.insert(channel_name.clone(), rx);
96
97 let response = make_subscribe_frame(channel_name, subscriptions.len());
98 dst.write_frame(&response).await?;
99
100 Ok(())
101}
102
103async fn handle_command(frame: Frame, subscribe_to: &mut Vec<String>, subscriptions: &mut StreamMap<String, Messages>, dst: &mut Connection) -> crate::Result<()> {
104 match Command::from_frame(frame)? {
105 Command::Subscribe(subscribe) => {
106 subscribe_to.extend(subscribe.channels.into_iter());
107 }
108 Command::Unsubscribe(mut unsubscribe) => {
109 if unsubscribe.channels.is_empty() {
110 unsubscribe.channels = subscriptions.keys().map(|channel_name| channel_name.to_string()).collect();
111 }
112
113 for channel_name in unsubscribe.channels {
114 subscriptions.remove(&channel_name);
115
116 let response = make_unsubscribe_frame(channel_name, subscriptions.len());
117 dst.write_frame(&response).await?;
118 }
119 }
120 command => {
121 let cmd = Unknown::new(command.get_name());
122 cmd.apply(dst).await?;
123 }
124 }
125 Ok(())
126}
127
128fn make_subscribe_frame(channel_name: String, num_subs: usize) -> Frame {
129 let mut response = Frame::array();
130 response.push_bulk(Bytes::from_static(b"subscribe"));
131 response.push_bulk(Bytes::from(channel_name));
132 response.push_int(num_subs as u64);
133 response
134}
135
136fn make_unsubscribe_frame(channel_name: String, num_subs: usize) -> Frame {
137 let mut response = Frame::array();
138 response.push_bulk(Bytes::from_static(b"unsubscribe"));
139 response.push_bulk(Bytes::from(channel_name));
140 response.push_int(num_subs as u64);
141 response
142}
143
144fn make_message_frame(channel_name: String, msg: Bytes) -> Frame {
145 let mut response = Frame::array();
146 response.push_bulk(Bytes::from_static(b"message"));
147 response.push_bulk(Bytes::from(channel_name));
148 response.push_bulk(msg);
149 response
150}
151
152impl Unsubscribe {
153 pub fn new(channels: &[String]) -> Unsubscribe { Unsubscribe { channels: channels.to_vec() } }
154
155 pub fn parse_frames(parse: &mut Parse) -> Result<Unsubscribe, ParseError> {
156 use ParseError::EndOfStream;
157 let mut channels = vec![];
158
159 loop {
160 match parse.next_string() {
161 Ok(s) => channels.push(s),
162 Err(EndOfStream) => break,
163 Err(err) => return Err(err),
164 }
165 }
166
167 Ok(Unsubscribe { channels })
168 }
169
170 pub fn into_frame(self) -> Frame {
171 let mut frame = Frame::array();
172 frame.push_bulk(Bytes::from("unsubscribe".as_bytes()));
173
174 for channel in self.channels {
175 frame.push_bulk(Bytes::from(channel.into_bytes()));
176 }
177
178 frame
179 }
180}