Skip to main content

ember_client/
connection.rs

//! Async client connection to an ember server.
//!
//! Handles connecting, sending commands as RESP3 arrays, and reading back
//! parsed frames. Works transparently over plain TCP or, when the `tls`
//! feature is enabled, TLS.
5
6use std::io;
7use std::pin::Pin;
8use std::task::{Context, Poll};
9use std::time::Duration;
10
11use bytes::BytesMut;
12use ember_protocol::parse::parse_frame;
13use ember_protocol::types::Frame;
14use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
15use tokio::net::TcpStream;
16
17#[cfg(feature = "tls")]
18use crate::tls::TlsClientConfig;
19
/// Upper bound, in bytes, on unparsed response data the client will buffer
/// (64 KiB).
///
/// Guards against unbounded memory growth when the server sends a reply that
/// never completes. The same bound applies to legitimate replies: once this
/// many bytes are buffered without a complete frame having been parsed, the
/// read fails with `ClientError::ResponseTooLarge`.
const MAX_READ_BUF: usize = 65_536;

/// How long `Client::connect` / `Client::connect_tls` may take to establish
/// the connection before failing with `ClientError::Timeout`.
const CONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);

/// How long a single socket read may wait for response bytes before failing
/// with `ClientError::Timeout`. The budget restarts on every read.
const READ_TIMEOUT: Duration = Duration::from_millis(10_000);
29
30/// Errors that can occur during client operations.
31#[derive(Debug, thiserror::Error)]
32pub enum ClientError {
33    #[error("connection failed: {0}")]
34    Io(#[from] io::Error),
35
36    #[error("protocol error: {0}")]
37    Protocol(String),
38
39    #[error("server error: {0}")]
40    Server(String),
41
42    #[error("server disconnected")]
43    Disconnected,
44
45    #[error("authentication failed: {0}")]
46    AuthFailed(String),
47
48    #[error("connection timed out")]
49    Timeout,
50
51    #[error("response too large (exceeded {MAX_READ_BUF} bytes)")]
52    ResponseTooLarge,
53}
54
55/// Underlying transport — plain TCP or (optionally) TLS.
56///
57/// Centralising the dispatch here keeps the `Client` logic clean regardless
58/// of which features are compiled in.
59pub(crate) enum Transport {
60    Tcp(TcpStream),
61    #[cfg(feature = "tls")]
62    Tls(Box<tokio_rustls::client::TlsStream<TcpStream>>),
63}
64
65impl AsyncRead for Transport {
66    fn poll_read(
67        self: Pin<&mut Self>,
68        cx: &mut Context<'_>,
69        buf: &mut ReadBuf<'_>,
70    ) -> Poll<io::Result<()>> {
71        match self.get_mut() {
72            Transport::Tcp(s) => Pin::new(s).poll_read(cx, buf),
73            #[cfg(feature = "tls")]
74            Transport::Tls(s) => Pin::new(s.as_mut()).poll_read(cx, buf),
75        }
76    }
77}
78
79impl AsyncWrite for Transport {
80    fn poll_write(
81        self: Pin<&mut Self>,
82        cx: &mut Context<'_>,
83        buf: &[u8],
84    ) -> Poll<io::Result<usize>> {
85        match self.get_mut() {
86            Transport::Tcp(s) => Pin::new(s).poll_write(cx, buf),
87            #[cfg(feature = "tls")]
88            Transport::Tls(s) => Pin::new(s.as_mut()).poll_write(cx, buf),
89        }
90    }
91
92    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
93        match self.get_mut() {
94            Transport::Tcp(s) => Pin::new(s).poll_flush(cx),
95            #[cfg(feature = "tls")]
96            Transport::Tls(s) => Pin::new(s.as_mut()).poll_flush(cx),
97        }
98    }
99
100    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
101        match self.get_mut() {
102            Transport::Tcp(s) => Pin::new(s).poll_shutdown(cx),
103            #[cfg(feature = "tls")]
104            Transport::Tls(s) => Pin::new(s.as_mut()).poll_shutdown(cx),
105        }
106    }
107}
108
109/// An async client connected to a single ember server.
110///
111/// Buffers reads and writes internally. Not thread-safe — use one `Client`
112/// per task, or wrap in an `Arc<Mutex<_>>` if sharing is needed.
113pub struct Client {
114    transport: Transport,
115    read_buf: BytesMut,
116    write_buf: BytesMut,
117}
118
119impl Client {
120    /// Connects to an ember server over plain TCP.
121    ///
122    /// Times out after 5 seconds if the server is unreachable.
123    pub async fn connect(host: &str, port: u16) -> Result<Self, ClientError> {
124        let tcp = tokio::time::timeout(CONNECT_TIMEOUT, TcpStream::connect((host, port)))
125            .await
126            .map_err(|_| ClientError::Timeout)?
127            .map_err(ClientError::Io)?;
128
129        Ok(Self::from_transport(Transport::Tcp(tcp)))
130    }
131
132    /// Connects to an ember server with TLS.
133    ///
134    /// Performs the TCP connection and TLS handshake within the 5-second
135    /// connect timeout.
136    #[cfg(feature = "tls")]
137    pub async fn connect_tls(
138        host: &str,
139        port: u16,
140        tls: &TlsClientConfig,
141    ) -> Result<Self, ClientError> {
142        let stream = tokio::time::timeout(CONNECT_TIMEOUT, crate::tls::connect(host, port, tls))
143            .await
144            .map_err(|_| ClientError::Timeout)?
145            .map_err(ClientError::Io)?;
146
147        Ok(Self::from_transport(stream))
148    }
149
150    fn from_transport(transport: Transport) -> Self {
151        Self {
152            transport,
153            read_buf: BytesMut::with_capacity(4096),
154            write_buf: BytesMut::with_capacity(4096),
155        }
156    }
157
158    /// Sends a command and returns the server's RESP3 response.
159    ///
160    /// Arguments are serialized as a RESP3 array of bulk strings, which is
161    /// the standard client-to-server wire format.
162    ///
163    /// # Example
164    ///
165    /// ```no_run
166    /// # use ember_client::Client;
167    /// # async fn example() -> Result<(), ember_client::ClientError> {
168    /// let mut client = Client::connect("127.0.0.1", 6379).await?;
169    /// let pong = client.send(&["PING"]).await?;
170    /// let value = client.send(&["GET", "mykey"]).await?;
171    /// # Ok(())
172    /// # }
173    /// ```
174    pub async fn send(&mut self, args: &[&str]) -> Result<Frame, ClientError> {
175        let parts = args
176            .iter()
177            .map(|t| Frame::Bulk(bytes::Bytes::copy_from_slice(t.as_bytes())))
178            .collect();
179        self.send_frame(Frame::Array(parts)).await
180    }
181
182    /// Sends a single pre-built frame and returns the response.
183    ///
184    /// Used internally by typed command methods to avoid re-encoding
185    /// arguments through `&[&str]`.
186    pub(crate) async fn send_frame(&mut self, frame: Frame) -> Result<Frame, ClientError> {
187        self.write_buf.clear();
188        frame.serialize(&mut self.write_buf);
189        self.transport.write_all(&self.write_buf).await?;
190        self.transport.flush().await?;
191        self.read_response().await
192    }
193
194    /// Writes all frames in a single flush, then reads one response per frame.
195    ///
196    /// This is the core of pipelining: batching multiple commands into one
197    /// syscall and reading responses sequentially afterwards.
198    pub(crate) async fn send_batch(&mut self, frames: &[Frame]) -> Result<Vec<Frame>, ClientError> {
199        self.write_buf.clear();
200        for frame in frames {
201            frame.serialize(&mut self.write_buf);
202        }
203        self.transport.write_all(&self.write_buf).await?;
204        self.transport.flush().await?;
205
206        let mut results = Vec::with_capacity(frames.len());
207        for _ in 0..frames.len() {
208            results.push(self.read_response().await?);
209        }
210        Ok(results)
211    }
212
213    /// Authenticates with the server using the `AUTH` command.
214    ///
215    /// Returns `Ok(())` on success, or `ClientError::AuthFailed` if the server
216    /// rejects the password.
217    pub async fn auth(&mut self, password: &str) -> Result<(), ClientError> {
218        let frame = Frame::Array(vec![
219            Frame::Bulk(bytes::Bytes::from_static(b"AUTH")),
220            Frame::Bulk(bytes::Bytes::copy_from_slice(password.as_bytes())),
221        ]);
222
223        match self.send_frame(frame).await? {
224            Frame::Simple(s) if s == "OK" => Ok(()),
225            Frame::Error(e) => Err(ClientError::AuthFailed(e)),
226            _ => Err(ClientError::AuthFailed(
227                "unexpected response to AUTH".into(),
228            )),
229        }
230    }
231
232    /// Gracefully disconnects from the server.
233    ///
234    /// Sends `QUIT` so the server can clean up, then shuts down the transport.
235    /// Errors are ignored — this is best-effort cleanup.
236    pub async fn disconnect(&mut self) {
237        let quit = Frame::Array(vec![Frame::Bulk(bytes::Bytes::from_static(b"QUIT"))]);
238        self.write_buf.clear();
239        quit.serialize(&mut self.write_buf);
240        let _ = self.transport.write_all(&self.write_buf).await;
241        let _ = self.transport.flush().await;
242        let _ = self.transport.shutdown().await;
243    }
244
245    /// Serializes and writes a single frame without waiting for a response.
246    ///
247    /// Used by [`Subscriber`] to send SUBSCRIBE/UNSUBSCRIBE frames before
248    /// draining the confirmation responses separately.
249    pub(crate) async fn write_frame(&mut self, frame: Frame) -> Result<(), ClientError> {
250        self.write_buf.clear();
251        frame.serialize(&mut self.write_buf);
252        self.transport.write_all(&self.write_buf).await?;
253        self.transport.flush().await?;
254        Ok(())
255    }
256
257    /// Reads a complete RESP3 frame from the server.
258    pub(crate) async fn read_response(&mut self) -> Result<Frame, ClientError> {
259        loop {
260            if !self.read_buf.is_empty() {
261                match parse_frame(&self.read_buf) {
262                    Ok(Some((frame, consumed))) => {
263                        let _ = self.read_buf.split_to(consumed);
264                        return Ok(frame);
265                    }
266                    Ok(None) => {
267                        // incomplete frame — need more data
268                    }
269                    Err(e) => {
270                        return Err(ClientError::Protocol(e.to_string()));
271                    }
272                }
273            }
274
275            if self.read_buf.len() >= MAX_READ_BUF {
276                return Err(ClientError::ResponseTooLarge);
277            }
278
279            let read_result =
280                tokio::time::timeout(READ_TIMEOUT, self.transport.read_buf(&mut self.read_buf))
281                    .await;
282
283            match read_result {
284                Ok(Ok(0)) => return Err(ClientError::Disconnected),
285                Ok(Ok(_)) => {} // got data, loop back to try parsing
286                Ok(Err(e)) => return Err(ClientError::Io(e)),
287                Err(_) => return Err(ClientError::Timeout),
288            }
289        }
290    }
291}