ember_client/
connection.rs1use std::io;
7use std::pin::Pin;
8use std::task::{Context, Poll};
9use std::time::Duration;
10
11use bytes::BytesMut;
12use ember_protocol::parse::parse_frame;
13use ember_protocol::types::Frame;
14use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
15use tokio::net::TcpStream;
16
17#[cfg(feature = "tls")]
18use crate::tls::TlsClientConfig;
19
/// Upper bound, in bytes, on response data that may sit in the read buffer
/// without forming a complete frame (64 KiB).
const MAX_READ_BUF: usize = 64 * 1024;

/// How long a connection attempt may take before it is abandoned.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(5);

/// How long a single read from the transport may wait for data.
const READ_TIMEOUT: Duration = Duration::from_secs(10);
29
30#[derive(Debug, thiserror::Error)]
32pub enum ClientError {
33 #[error("connection failed: {0}")]
34 Io(#[from] io::Error),
35
36 #[error("protocol error: {0}")]
37 Protocol(String),
38
39 #[error("server error: {0}")]
40 Server(String),
41
42 #[error("server disconnected")]
43 Disconnected,
44
45 #[error("authentication failed: {0}")]
46 AuthFailed(String),
47
48 #[error("connection timed out")]
49 Timeout,
50
51 #[error("response too large (exceeded {MAX_READ_BUF} bytes)")]
52 ResponseTooLarge,
53}
54
55pub(crate) enum Transport {
60 Tcp(TcpStream),
61 #[cfg(feature = "tls")]
62 Tls(Box<tokio_rustls::client::TlsStream<TcpStream>>),
63}
64
65impl AsyncRead for Transport {
66 fn poll_read(
67 self: Pin<&mut Self>,
68 cx: &mut Context<'_>,
69 buf: &mut ReadBuf<'_>,
70 ) -> Poll<io::Result<()>> {
71 match self.get_mut() {
72 Transport::Tcp(s) => Pin::new(s).poll_read(cx, buf),
73 #[cfg(feature = "tls")]
74 Transport::Tls(s) => Pin::new(s.as_mut()).poll_read(cx, buf),
75 }
76 }
77}
78
79impl AsyncWrite for Transport {
80 fn poll_write(
81 self: Pin<&mut Self>,
82 cx: &mut Context<'_>,
83 buf: &[u8],
84 ) -> Poll<io::Result<usize>> {
85 match self.get_mut() {
86 Transport::Tcp(s) => Pin::new(s).poll_write(cx, buf),
87 #[cfg(feature = "tls")]
88 Transport::Tls(s) => Pin::new(s.as_mut()).poll_write(cx, buf),
89 }
90 }
91
92 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
93 match self.get_mut() {
94 Transport::Tcp(s) => Pin::new(s).poll_flush(cx),
95 #[cfg(feature = "tls")]
96 Transport::Tls(s) => Pin::new(s.as_mut()).poll_flush(cx),
97 }
98 }
99
100 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
101 match self.get_mut() {
102 Transport::Tcp(s) => Pin::new(s).poll_shutdown(cx),
103 #[cfg(feature = "tls")]
104 Transport::Tls(s) => Pin::new(s.as_mut()).poll_shutdown(cx),
105 }
106 }
107}
108
109pub struct Client {
114 transport: Transport,
115 read_buf: BytesMut,
116 write_buf: BytesMut,
117}
118
119impl Client {
120 pub async fn connect(host: &str, port: u16) -> Result<Self, ClientError> {
124 let tcp = tokio::time::timeout(CONNECT_TIMEOUT, TcpStream::connect((host, port)))
125 .await
126 .map_err(|_| ClientError::Timeout)?
127 .map_err(ClientError::Io)?;
128
129 Ok(Self::from_transport(Transport::Tcp(tcp)))
130 }
131
132 #[cfg(feature = "tls")]
137 pub async fn connect_tls(
138 host: &str,
139 port: u16,
140 tls: &TlsClientConfig,
141 ) -> Result<Self, ClientError> {
142 let stream = tokio::time::timeout(CONNECT_TIMEOUT, crate::tls::connect(host, port, tls))
143 .await
144 .map_err(|_| ClientError::Timeout)?
145 .map_err(ClientError::Io)?;
146
147 Ok(Self::from_transport(stream))
148 }
149
150 fn from_transport(transport: Transport) -> Self {
151 Self {
152 transport,
153 read_buf: BytesMut::with_capacity(4096),
154 write_buf: BytesMut::with_capacity(4096),
155 }
156 }
157
158 pub async fn send(&mut self, args: &[&str]) -> Result<Frame, ClientError> {
175 let parts = args
176 .iter()
177 .map(|t| Frame::Bulk(bytes::Bytes::copy_from_slice(t.as_bytes())))
178 .collect();
179 self.send_frame(Frame::Array(parts)).await
180 }
181
182 pub(crate) async fn send_frame(&mut self, frame: Frame) -> Result<Frame, ClientError> {
187 self.write_buf.clear();
188 frame.serialize(&mut self.write_buf);
189 self.transport.write_all(&self.write_buf).await?;
190 self.transport.flush().await?;
191 self.read_response().await
192 }
193
194 pub(crate) async fn send_batch(&mut self, frames: &[Frame]) -> Result<Vec<Frame>, ClientError> {
199 self.write_buf.clear();
200 for frame in frames {
201 frame.serialize(&mut self.write_buf);
202 }
203 self.transport.write_all(&self.write_buf).await?;
204 self.transport.flush().await?;
205
206 let mut results = Vec::with_capacity(frames.len());
207 for _ in 0..frames.len() {
208 results.push(self.read_response().await?);
209 }
210 Ok(results)
211 }
212
213 pub async fn auth(&mut self, password: &str) -> Result<(), ClientError> {
218 let frame = Frame::Array(vec![
219 Frame::Bulk(bytes::Bytes::from_static(b"AUTH")),
220 Frame::Bulk(bytes::Bytes::copy_from_slice(password.as_bytes())),
221 ]);
222
223 match self.send_frame(frame).await? {
224 Frame::Simple(s) if s == "OK" => Ok(()),
225 Frame::Error(e) => Err(ClientError::AuthFailed(e)),
226 _ => Err(ClientError::AuthFailed(
227 "unexpected response to AUTH".into(),
228 )),
229 }
230 }
231
232 pub async fn disconnect(&mut self) {
237 let quit = Frame::Array(vec![Frame::Bulk(bytes::Bytes::from_static(b"QUIT"))]);
238 self.write_buf.clear();
239 quit.serialize(&mut self.write_buf);
240 let _ = self.transport.write_all(&self.write_buf).await;
241 let _ = self.transport.flush().await;
242 let _ = self.transport.shutdown().await;
243 }
244
245 pub(crate) async fn write_frame(&mut self, frame: Frame) -> Result<(), ClientError> {
250 self.write_buf.clear();
251 frame.serialize(&mut self.write_buf);
252 self.transport.write_all(&self.write_buf).await?;
253 self.transport.flush().await?;
254 Ok(())
255 }
256
257 pub(crate) async fn read_response(&mut self) -> Result<Frame, ClientError> {
259 loop {
260 if !self.read_buf.is_empty() {
261 match parse_frame(&self.read_buf) {
262 Ok(Some((frame, consumed))) => {
263 let _ = self.read_buf.split_to(consumed);
264 return Ok(frame);
265 }
266 Ok(None) => {
267 }
269 Err(e) => {
270 return Err(ClientError::Protocol(e.to_string()));
271 }
272 }
273 }
274
275 if self.read_buf.len() >= MAX_READ_BUF {
276 return Err(ClientError::ResponseTooLarge);
277 }
278
279 let read_result =
280 tokio::time::timeout(READ_TIMEOUT, self.transport.read_buf(&mut self.read_buf))
281 .await;
282
283 match read_result {
284 Ok(Ok(0)) => return Err(ClientError::Disconnected),
285 Ok(Ok(_)) => {} Ok(Err(e)) => return Err(ClientError::Io(e)),
287 Err(_) => return Err(ClientError::Timeout),
288 }
289 }
290 }
291}