kade_proto/clients/client.rs

1use crate::cmd::{Dump, Get, Load, Ping, Publish, Set, Subscribe, Unsubscribe, Version};
2use crate::pkg::{Connection, Frame};
3
4use async_stream::try_stream;
5use bytes::Bytes;
6use std::io::{Error, ErrorKind};
7use std::{path::Path, time::Duration};
8use tokio::net::{TcpStream, ToSocketAddrs};
9use tokio_stream::Stream;
10use tracing::{debug, instrument};
11
12pub struct Client {
13    connection: Connection,
14}
15
16pub struct Subscriber {
17    client: Client,
18    subscribed_channels: Vec<String>,
19}
20
21#[derive(Debug, Clone)]
22pub struct Message {
23    pub channel: String,
24    pub content: Bytes,
25}
26
27impl Client {
28    pub async fn connect<T: ToSocketAddrs>(addr: T) -> crate::Result<Client> {
29        let socket = TcpStream::connect(addr).await?;
30        let connection = Connection::new(socket);
31
32        Ok(Client { connection })
33    }
34
35    pub async fn version(&mut self) -> crate::Result<Bytes> {
36        let frame = Version::new().into_frame();
37        self.connection.write_frame(&frame).await?;
38
39        match self.read_response().await? {
40            Frame::Simple(value) => Ok(value.into()),
41            Frame::Bulk(value) => Ok(value),
42            frame => Err(frame.to_error()),
43        }
44    }
45
46    #[instrument(skip(self))]
47    pub async fn ping(&mut self, msg: Option<Bytes>) -> crate::Result<Bytes> {
48        let frame = Ping::new(msg).into_frame();
49        debug!(request = ?frame);
50        self.connection.write_frame(&frame).await?;
51
52        match self.read_response().await? {
53            Frame::Simple(value) => Ok(value.into()),
54            Frame::Bulk(value) => Ok(value),
55            frame => Err(frame.to_error()),
56        }
57    }
58
59    #[instrument(skip(self))]
60    pub async fn dump(&mut self, path: &Path) -> crate::Result<()> {
61        let frame = Dump::new(path.to_path_buf()).into_frame();
62        debug!(request = ?frame);
63
64        self.connection.write_frame(&frame).await?;
65
66        match self.read_response().await? {
67            Frame::Simple(response) if response == "OK" => Ok(()),
68            frame => Err(frame.to_error()),
69        }
70    }
71
72    #[instrument(skip(self))]
73    pub async fn load(&mut self, path: &Path) -> crate::Result<()> {
74        let frame = Load::new(path.to_path_buf()).into_frame();
75        debug!(request = ?frame);
76
77        self.connection.write_frame(&frame).await?;
78
79        match self.read_response().await? {
80            Frame::Simple(response) if response == "OK" => Ok(()),
81            frame => Err(frame.to_error()),
82        }
83    }
84
85    #[instrument(skip(self))]
86    pub async fn get(&mut self, key: &str) -> crate::Result<Option<Bytes>> {
87        let frame = Get::new(key).into_frame();
88        debug!(request = ?frame);
89
90        self.connection.write_frame(&frame).await?;
91
92        match self.read_response().await? {
93            Frame::Simple(value) => Ok(Some(value.into())),
94            Frame::Bulk(value) => Ok(Some(value)),
95            Frame::Null => Ok(None),
96            frame => Err(frame.to_error()),
97        }
98    }
99
100    #[instrument(skip(self))]
101    pub async fn set(&mut self, key: &str, value: Bytes) -> crate::Result<()> {
102        (); // instrument
103        self.set_cmd(Set::new(key, value, None)).await
104    }
105
106    #[instrument(skip(self))]
107    pub async fn set_expires(&mut self, key: &str, value: Bytes, expiration: Duration) -> crate::Result<()> {
108        (); // instrument
109        self.set_cmd(Set::new(key, value, Some(expiration))).await
110    }
111
112    async fn set_cmd(&mut self, cmd: Set) -> crate::Result<()> {
113        let frame = cmd.into_frame();
114        debug!(request = ?frame);
115
116        self.connection.write_frame(&frame).await?;
117
118        match self.read_response().await? {
119            Frame::Simple(response) if response == "OK" => Ok(()),
120            frame => Err(frame.to_error()),
121        }
122    }
123
124    #[instrument(skip(self))]
125    pub async fn publish(&mut self, channel: &str, message: Bytes) -> crate::Result<u64> {
126        let frame = Publish::new(channel, message).into_frame();
127        debug!(request = ?frame);
128
129        self.connection.write_frame(&frame).await?;
130
131        match self.read_response().await? {
132            Frame::Integer(response) => Ok(response),
133            frame => Err(frame.to_error()),
134        }
135    }
136
137    #[instrument(skip(self))]
138    pub async fn subscribe(mut self, channels: Vec<String>) -> crate::Result<Subscriber> {
139        self.subscribe_cmd(&channels).await?;
140
141        Ok(Subscriber {
142            client: self,
143            subscribed_channels: channels,
144        })
145    }
146
147    async fn subscribe_cmd(&mut self, channels: &[String]) -> crate::Result<()> {
148        let frame = Subscribe::new(channels.to_vec()).into_frame();
149        debug!(request = ?frame);
150
151        self.connection.write_frame(&frame).await?;
152
153        for channel in channels {
154            let response = self.read_response().await?;
155
156            match response {
157                Frame::Array(ref frame) => match frame.as_slice() {
158                    [subscribe, schannel, ..] if *subscribe == "subscribe" && *schannel == channel => {}
159                    _ => return Err(response.to_error()),
160                },
161                frame => return Err(frame.to_error()),
162            };
163        }
164
165        Ok(())
166    }
167
168    async fn read_response(&mut self) -> crate::Result<Frame> {
169        let response = self.connection.read_frame().await?;
170        debug!(?response);
171
172        match response {
173            Some(Frame::Error(msg)) => Err(msg.into()),
174            Some(frame) => Ok(frame),
175            None => Err(Error::new(ErrorKind::ConnectionReset, "connection reset by server").into()),
176        }
177    }
178}
179
180impl Subscriber {
181    pub fn get_subscribed(&self) -> &[String] { &self.subscribed_channels }
182
183    pub async fn next_message(&mut self) -> crate::Result<Option<Message>> {
184        match self.client.connection.read_frame().await? {
185            Some(mframe) => {
186                debug!(?mframe);
187
188                match mframe {
189                    Frame::Array(ref frame) => match frame.as_slice() {
190                        [message, channel, content] if *message == "message" => Ok(Some(Message {
191                            channel: channel.to_string(),
192                            content: Bytes::from(content.to_string()),
193                        })),
194                        _ => Err(mframe.to_error()),
195                    },
196                    frame => Err(frame.to_error()),
197                }
198            }
199            None => Ok(None),
200        }
201    }
202
203    pub fn into_stream(mut self) -> impl Stream<Item = crate::Result<Message>> {
204        try_stream! {
205            while let Some(message) = self.next_message().await? {
206                yield message;
207            }
208        }
209    }
210
211    #[instrument(skip(self))]
212    pub async fn subscribe(&mut self, channels: &[String]) -> crate::Result<()> {
213        self.client.subscribe_cmd(channels).await?;
214        self.subscribed_channels.extend(channels.iter().map(Clone::clone));
215
216        Ok(())
217    }
218
219    #[instrument(skip(self))]
220    pub async fn unsubscribe(&mut self, channels: &[String]) -> crate::Result<()> {
221        let frame = Unsubscribe::new(channels).into_frame();
222        debug!(request = ?frame);
223
224        self.client.connection.write_frame(&frame).await?;
225
226        let num = if channels.is_empty() { self.subscribed_channels.len() } else { channels.len() };
227
228        for _ in 0..num {
229            let response = self.client.read_response().await?;
230
231            match response {
232                Frame::Array(ref frame) => match frame.as_slice() {
233                    [unsubscribe, channel, ..] if *unsubscribe == "unsubscribe" => {
234                        let len = self.subscribed_channels.len();
235
236                        if len == 0 {
237                            return Err(response.to_error());
238                        }
239
240                        self.subscribed_channels.retain(|c| *channel != &c[..]);
241
242                        if self.subscribed_channels.len() != len - 1 {
243                            return Err(response.to_error());
244                        }
245                    }
246                    _ => return Err(response.to_error()),
247                },
248                frame => return Err(frame.to_error()),
249            };
250        }
251
252        Ok(())
253    }
254}