kade_proto/clients/
client.rs1use crate::cmd::{Dump, Get, Load, Ping, Publish, Set, Subscribe, Unsubscribe, Version};
2use crate::pkg::{Connection, Frame};
3
4use async_stream::try_stream;
5use bytes::Bytes;
6use std::io::{Error, ErrorKind};
7use std::{path::Path, time::Duration};
8use tokio::net::{TcpStream, ToSocketAddrs};
9use tokio_stream::Stream;
10use tracing::{debug, instrument};
11
12pub struct Client {
13 connection: Connection,
14}
15
16pub struct Subscriber {
17 client: Client,
18 subscribed_channels: Vec<String>,
19}
20
21#[derive(Debug, Clone)]
22pub struct Message {
23 pub channel: String,
24 pub content: Bytes,
25}
26
27impl Client {
28 pub async fn connect<T: ToSocketAddrs>(addr: T) -> crate::Result<Client> {
29 let socket = TcpStream::connect(addr).await?;
30 let connection = Connection::new(socket);
31
32 Ok(Client { connection })
33 }
34
35 pub async fn version(&mut self) -> crate::Result<Bytes> {
36 let frame = Version::new().into_frame();
37 self.connection.write_frame(&frame).await?;
38
39 match self.read_response().await? {
40 Frame::Simple(value) => Ok(value.into()),
41 Frame::Bulk(value) => Ok(value),
42 frame => Err(frame.to_error()),
43 }
44 }
45
46 #[instrument(skip(self))]
47 pub async fn ping(&mut self, msg: Option<Bytes>) -> crate::Result<Bytes> {
48 let frame = Ping::new(msg).into_frame();
49 debug!(request = ?frame);
50 self.connection.write_frame(&frame).await?;
51
52 match self.read_response().await? {
53 Frame::Simple(value) => Ok(value.into()),
54 Frame::Bulk(value) => Ok(value),
55 frame => Err(frame.to_error()),
56 }
57 }
58
59 #[instrument(skip(self))]
60 pub async fn dump(&mut self, path: &Path) -> crate::Result<()> {
61 let frame = Dump::new(path.to_path_buf()).into_frame();
62 debug!(request = ?frame);
63
64 self.connection.write_frame(&frame).await?;
65
66 match self.read_response().await? {
67 Frame::Simple(response) if response == "OK" => Ok(()),
68 frame => Err(frame.to_error()),
69 }
70 }
71
72 #[instrument(skip(self))]
73 pub async fn load(&mut self, path: &Path) -> crate::Result<()> {
74 let frame = Load::new(path.to_path_buf()).into_frame();
75 debug!(request = ?frame);
76
77 self.connection.write_frame(&frame).await?;
78
79 match self.read_response().await? {
80 Frame::Simple(response) if response == "OK" => Ok(()),
81 frame => Err(frame.to_error()),
82 }
83 }
84
85 #[instrument(skip(self))]
86 pub async fn get(&mut self, key: &str) -> crate::Result<Option<Bytes>> {
87 let frame = Get::new(key).into_frame();
88 debug!(request = ?frame);
89
90 self.connection.write_frame(&frame).await?;
91
92 match self.read_response().await? {
93 Frame::Simple(value) => Ok(Some(value.into())),
94 Frame::Bulk(value) => Ok(Some(value)),
95 Frame::Null => Ok(None),
96 frame => Err(frame.to_error()),
97 }
98 }
99
100 #[instrument(skip(self))]
101 pub async fn set(&mut self, key: &str, value: Bytes) -> crate::Result<()> {
102 (); self.set_cmd(Set::new(key, value, None)).await
104 }
105
106 #[instrument(skip(self))]
107 pub async fn set_expires(&mut self, key: &str, value: Bytes, expiration: Duration) -> crate::Result<()> {
108 (); self.set_cmd(Set::new(key, value, Some(expiration))).await
110 }
111
112 async fn set_cmd(&mut self, cmd: Set) -> crate::Result<()> {
113 let frame = cmd.into_frame();
114 debug!(request = ?frame);
115
116 self.connection.write_frame(&frame).await?;
117
118 match self.read_response().await? {
119 Frame::Simple(response) if response == "OK" => Ok(()),
120 frame => Err(frame.to_error()),
121 }
122 }
123
124 #[instrument(skip(self))]
125 pub async fn publish(&mut self, channel: &str, message: Bytes) -> crate::Result<u64> {
126 let frame = Publish::new(channel, message).into_frame();
127 debug!(request = ?frame);
128
129 self.connection.write_frame(&frame).await?;
130
131 match self.read_response().await? {
132 Frame::Integer(response) => Ok(response),
133 frame => Err(frame.to_error()),
134 }
135 }
136
137 #[instrument(skip(self))]
138 pub async fn subscribe(mut self, channels: Vec<String>) -> crate::Result<Subscriber> {
139 self.subscribe_cmd(&channels).await?;
140
141 Ok(Subscriber {
142 client: self,
143 subscribed_channels: channels,
144 })
145 }
146
147 async fn subscribe_cmd(&mut self, channels: &[String]) -> crate::Result<()> {
148 let frame = Subscribe::new(channels.to_vec()).into_frame();
149 debug!(request = ?frame);
150
151 self.connection.write_frame(&frame).await?;
152
153 for channel in channels {
154 let response = self.read_response().await?;
155
156 match response {
157 Frame::Array(ref frame) => match frame.as_slice() {
158 [subscribe, schannel, ..] if *subscribe == "subscribe" && *schannel == channel => {}
159 _ => return Err(response.to_error()),
160 },
161 frame => return Err(frame.to_error()),
162 };
163 }
164
165 Ok(())
166 }
167
168 async fn read_response(&mut self) -> crate::Result<Frame> {
169 let response = self.connection.read_frame().await?;
170 debug!(?response);
171
172 match response {
173 Some(Frame::Error(msg)) => Err(msg.into()),
174 Some(frame) => Ok(frame),
175 None => Err(Error::new(ErrorKind::ConnectionReset, "connection reset by server").into()),
176 }
177 }
178}
179
180impl Subscriber {
181 pub fn get_subscribed(&self) -> &[String] { &self.subscribed_channels }
182
183 pub async fn next_message(&mut self) -> crate::Result<Option<Message>> {
184 match self.client.connection.read_frame().await? {
185 Some(mframe) => {
186 debug!(?mframe);
187
188 match mframe {
189 Frame::Array(ref frame) => match frame.as_slice() {
190 [message, channel, content] if *message == "message" => Ok(Some(Message {
191 channel: channel.to_string(),
192 content: Bytes::from(content.to_string()),
193 })),
194 _ => Err(mframe.to_error()),
195 },
196 frame => Err(frame.to_error()),
197 }
198 }
199 None => Ok(None),
200 }
201 }
202
203 pub fn into_stream(mut self) -> impl Stream<Item = crate::Result<Message>> {
204 try_stream! {
205 while let Some(message) = self.next_message().await? {
206 yield message;
207 }
208 }
209 }
210
211 #[instrument(skip(self))]
212 pub async fn subscribe(&mut self, channels: &[String]) -> crate::Result<()> {
213 self.client.subscribe_cmd(channels).await?;
214 self.subscribed_channels.extend(channels.iter().map(Clone::clone));
215
216 Ok(())
217 }
218
219 #[instrument(skip(self))]
220 pub async fn unsubscribe(&mut self, channels: &[String]) -> crate::Result<()> {
221 let frame = Unsubscribe::new(channels).into_frame();
222 debug!(request = ?frame);
223
224 self.client.connection.write_frame(&frame).await?;
225
226 let num = if channels.is_empty() { self.subscribed_channels.len() } else { channels.len() };
227
228 for _ in 0..num {
229 let response = self.client.read_response().await?;
230
231 match response {
232 Frame::Array(ref frame) => match frame.as_slice() {
233 [unsubscribe, channel, ..] if *unsubscribe == "unsubscribe" => {
234 let len = self.subscribed_channels.len();
235
236 if len == 0 {
237 return Err(response.to_error());
238 }
239
240 self.subscribed_channels.retain(|c| *channel != &c[..]);
241
242 if self.subscribed_channels.len() != len - 1 {
243 return Err(response.to_error());
244 }
245 }
246 _ => return Err(response.to_error()),
247 },
248 frame => return Err(frame.to_error()),
249 };
250 }
251
252 Ok(())
253 }
254}