Skip to main content

redis_server_wrapper/
cli.rs

1//! Type-safe wrapper for the `redis-cli` command.
2
3use std::path::PathBuf;
4use std::process::{Command, Output, Stdio};
5
6use tokio::process::Command as TokioCommand;
7
8use crate::error::{Error, Result};
9
/// RESP protocol version for client connections.
///
/// Selected with [`RedisCli::resp`]; when unset, no protocol flag is passed
/// and `redis-cli` uses its own default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RespProtocol {
    /// RESP2 (default for most Redis versions). Passed as `-2`.
    Resp2,
    /// RESP3. Passed as `-3`.
    Resp3,
}
18
/// Output format for `redis-cli` commands.
///
/// Selected with [`RedisCli::output_format`]. Note that the non-default
/// formats change how replies are printed (e.g. CSV/JSON quote string
/// replies), which matters when parsing the output of [`RedisCli::run`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OutputFormat {
    /// No format flag is passed; `redis-cli` chooses its own default.
    Default,
    /// Raw output (no formatting). Passed as `--raw`.
    Raw,
    /// CSV output. Passed as `--csv`.
    Csv,
    /// JSON output. Passed as `--json`.
    Json,
    /// Quoted JSON output. Passed as `--quoted-json`.
    QuotedJson,
}
33
34/// Builder for executing `redis-cli` commands.
35#[derive(Debug, Clone)]
36pub struct RedisCli {
37    bin: String,
38    host: String,
39    port: u16,
40    password: Option<String>,
41    user: Option<String>,
42    db: Option<u32>,
43    unixsocket: Option<PathBuf>,
44    tls: bool,
45    sni: Option<String>,
46    cacert: Option<PathBuf>,
47    cert: Option<PathBuf>,
48    key: Option<PathBuf>,
49    resp: Option<RespProtocol>,
50    cluster_mode: bool,
51    output_format: OutputFormat,
52    no_auth_warning: bool,
53}
54
55impl RedisCli {
56    /// Create a new `redis-cli` builder with defaults (localhost:6379).
57    pub fn new() -> Self {
58        Self {
59            bin: "redis-cli".into(),
60            host: "127.0.0.1".into(),
61            port: 6379,
62            password: None,
63            user: None,
64            db: None,
65            unixsocket: None,
66            tls: false,
67            sni: None,
68            cacert: None,
69            cert: None,
70            key: None,
71            resp: None,
72            cluster_mode: false,
73            output_format: OutputFormat::Default,
74            no_auth_warning: false,
75        }
76    }
77
78    /// Set the `redis-cli` binary path.
79    pub fn bin(mut self, bin: impl Into<String>) -> Self {
80        self.bin = bin.into();
81        self
82    }
83
84    /// Set the host to connect to.
85    pub fn host(mut self, host: impl Into<String>) -> Self {
86        self.host = host.into();
87        self
88    }
89
90    /// Set the port to connect to.
91    pub fn port(mut self, port: u16) -> Self {
92        self.port = port;
93        self
94    }
95
96    /// Set the password for AUTH.
97    pub fn password(mut self, password: impl Into<String>) -> Self {
98        self.password = Some(password.into());
99        self
100    }
101
102    /// Set the ACL username for AUTH.
103    pub fn user(mut self, user: impl Into<String>) -> Self {
104        self.user = Some(user.into());
105        self
106    }
107
108    /// Select a database number.
109    pub fn db(mut self, db: u32) -> Self {
110        self.db = Some(db);
111        self
112    }
113
114    /// Connect via a Unix socket instead of TCP.
115    pub fn unixsocket(mut self, path: impl Into<PathBuf>) -> Self {
116        self.unixsocket = Some(path.into());
117        self
118    }
119
120    /// Enable TLS for the connection.
121    pub fn tls(mut self, enable: bool) -> Self {
122        self.tls = enable;
123        self
124    }
125
126    /// Set the SNI hostname for TLS.
127    pub fn sni(mut self, hostname: impl Into<String>) -> Self {
128        self.sni = Some(hostname.into());
129        self
130    }
131
132    /// Set the CA certificate file for TLS verification.
133    pub fn cacert(mut self, path: impl Into<PathBuf>) -> Self {
134        self.cacert = Some(path.into());
135        self
136    }
137
138    /// Set the client certificate file for TLS.
139    pub fn cert(mut self, path: impl Into<PathBuf>) -> Self {
140        self.cert = Some(path.into());
141        self
142    }
143
144    /// Set the client private key file for TLS.
145    pub fn key(mut self, path: impl Into<PathBuf>) -> Self {
146        self.key = Some(path.into());
147        self
148    }
149
150    /// Set the RESP protocol version.
151    pub fn resp(mut self, protocol: RespProtocol) -> Self {
152        self.resp = Some(protocol);
153        self
154    }
155
156    /// Enable cluster mode (`-c` flag) for following redirects.
157    pub fn cluster_mode(mut self, enable: bool) -> Self {
158        self.cluster_mode = enable;
159        self
160    }
161
162    /// Set the output format.
163    pub fn output_format(mut self, format: OutputFormat) -> Self {
164        self.output_format = format;
165        self
166    }
167
168    /// Suppress the AUTH password warning.
169    pub fn no_auth_warning(mut self, suppress: bool) -> Self {
170        self.no_auth_warning = suppress;
171        self
172    }
173
174    /// Run a command and return stdout on success.
175    pub async fn run(&self, args: &[&str]) -> Result<String> {
176        let output = self.raw_output(args).await?;
177        if output.status.success() {
178            Ok(String::from_utf8_lossy(&output.stdout).to_string())
179        } else {
180            let stderr = String::from_utf8_lossy(&output.stderr);
181            Err(Error::Cli {
182                host: self.host.clone(),
183                port: self.port,
184                detail: stderr.into_owned(),
185            })
186        }
187    }
188
189    /// Run a command, ignoring output. Used for fire-and-forget (SHUTDOWN).
190    pub fn fire_and_forget(&self, args: &[&str]) {
191        let _ = Command::new(&self.bin)
192            .args(self.base_args())
193            .args(args)
194            .stdout(Stdio::null())
195            .stderr(Stdio::null())
196            .status();
197    }
198
199    /// Send PING and return true if PONG is received.
200    pub async fn ping(&self) -> bool {
201        self.run(&["PING"])
202            .await
203            .map(|r| r.trim() == "PONG")
204            .unwrap_or(false)
205    }
206
207    /// Send SHUTDOWN NOSAVE. Best-effort.
208    pub fn shutdown(&self) {
209        self.fire_and_forget(&["SHUTDOWN", "NOSAVE"]);
210    }
211
212    /// Wait until the server responds to PING or timeout expires.
213    pub async fn wait_for_ready(&self, timeout: std::time::Duration) -> Result<()> {
214        let start = std::time::Instant::now();
215        loop {
216            if self.ping().await {
217                return Ok(());
218            }
219            if start.elapsed() > timeout {
220                return Err(Error::Timeout {
221                    message: format!(
222                        "{}:{} did not respond within {timeout:?}",
223                        self.host, self.port
224                    ),
225                });
226            }
227            tokio::time::sleep(std::time::Duration::from_millis(250)).await;
228        }
229    }
230
231    /// Run `redis-cli --cluster create ...` to form a cluster.
232    pub async fn cluster_create(
233        &self,
234        node_addrs: &[String],
235        replicas_per_master: u16,
236    ) -> Result<()> {
237        let mut args = self.base_args();
238        args.push("--cluster".into());
239        args.push("create".into());
240        args.extend(node_addrs.iter().cloned());
241        if replicas_per_master > 0 {
242            args.push("--cluster-replicas".into());
243            args.push(replicas_per_master.to_string());
244        }
245        args.push("--cluster-yes".into());
246
247        let str_args: Vec<&str> = args.iter().map(|s| s.as_str()).collect();
248        let output = TokioCommand::new(&self.bin)
249            .args(&str_args)
250            .output()
251            .await?;
252
253        if output.status.success() {
254            Ok(())
255        } else {
256            Err(Error::ClusterCreate {
257                stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
258                stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
259            })
260        }
261    }
262
263    fn base_args(&self) -> Vec<String> {
264        let mut args = Vec::new();
265
266        if let Some(ref path) = self.unixsocket {
267            args.push("-s".to_string());
268            args.push(path.display().to_string());
269        } else {
270            args.push("-h".to_string());
271            args.push(self.host.clone());
272            args.push("-p".to_string());
273            args.push(self.port.to_string());
274        }
275
276        if let Some(ref user) = self.user {
277            args.push("--user".to_string());
278            args.push(user.clone());
279        }
280        if let Some(ref pw) = self.password {
281            args.push("-a".to_string());
282            args.push(pw.clone());
283        }
284        if let Some(db) = self.db {
285            args.push("-n".to_string());
286            args.push(db.to_string());
287        }
288
289        // TLS
290        if self.tls {
291            args.push("--tls".to_string());
292        }
293        if let Some(ref sni) = self.sni {
294            args.push("--sni".to_string());
295            args.push(sni.clone());
296        }
297        if let Some(ref path) = self.cacert {
298            args.push("--cacert".to_string());
299            args.push(path.display().to_string());
300        }
301        if let Some(ref path) = self.cert {
302            args.push("--cert".to_string());
303            args.push(path.display().to_string());
304        }
305        if let Some(ref path) = self.key {
306            args.push("--key".to_string());
307            args.push(path.display().to_string());
308        }
309
310        // Protocol
311        if let Some(ref proto) = self.resp {
312            match proto {
313                RespProtocol::Resp2 => args.push("-2".to_string()),
314                RespProtocol::Resp3 => args.push("-3".to_string()),
315            }
316        }
317
318        // Cluster
319        if self.cluster_mode {
320            args.push("-c".to_string());
321        }
322
323        // Output format
324        match self.output_format {
325            OutputFormat::Default => {}
326            OutputFormat::Raw => args.push("--raw".to_string()),
327            OutputFormat::Csv => args.push("--csv".to_string()),
328            OutputFormat::Json => args.push("--json".to_string()),
329            OutputFormat::QuotedJson => args.push("--quoted-json".to_string()),
330        }
331
332        if self.no_auth_warning {
333            args.push("--no-auth-warning".to_string());
334        }
335
336        args
337    }
338
339    async fn raw_output(&self, args: &[&str]) -> std::io::Result<Output> {
340        TokioCommand::new(&self.bin)
341            .args(self.base_args())
342            .args(args)
343            .output()
344            .await
345    }
346}
347
348impl Default for RedisCli {
349    fn default() -> Self {
350        Self::new()
351    }
352}
353
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_config() {
        let cli = RedisCli::new();
        assert_eq!(cli.bin, "redis-cli");
        assert_eq!(cli.host, "127.0.0.1");
        assert_eq!(cli.port, 6379);
        assert!(cli.password.is_none());
    }

    #[test]
    fn builder_chain() {
        let cli = RedisCli::new()
            .host("10.0.0.1")
            .port(6380)
            .password("secret")
            .bin("/usr/local/bin/redis-cli");
        assert_eq!(cli.host, "10.0.0.1");
        assert_eq!(cli.port, 6380);
        assert_eq!(cli.password.as_deref(), Some("secret"));
        assert_eq!(cli.bin, "/usr/local/bin/redis-cli");
    }

    #[test]
    fn base_args_default_is_tcp_localhost() {
        assert_eq!(RedisCli::new().base_args(), ["-h", "127.0.0.1", "-p", "6379"]);
        assert_eq!(RedisCli::default().base_args(), RedisCli::new().base_args());
    }

    #[test]
    fn base_args_unixsocket_replaces_host_and_port() {
        let args = RedisCli::new()
            .host("10.0.0.1")
            .port(7000)
            .unixsocket("/tmp/redis.sock")
            .base_args();
        assert_eq!(args, ["-s", "/tmp/redis.sock"]);
    }

    #[test]
    fn base_args_emits_all_flags_in_order() {
        let args = RedisCli::new()
            .user("alice")
            .password("secret")
            .db(3)
            .tls(true)
            .sni("redis.example")
            .cacert("/ca.pem")
            .cert("/cert.pem")
            .key("/key.pem")
            .resp(RespProtocol::Resp3)
            .cluster_mode(true)
            .output_format(OutputFormat::Json)
            .no_auth_warning(true)
            .base_args();
        assert_eq!(
            args,
            [
                "-h",
                "127.0.0.1",
                "-p",
                "6379",
                "--user",
                "alice",
                "-a",
                "secret",
                "-n",
                "3",
                "--tls",
                "--sni",
                "redis.example",
                "--cacert",
                "/ca.pem",
                "--cert",
                "/cert.pem",
                "--key",
                "/key.pem",
                "-3",
                "-c",
                "--json",
                "--no-auth-warning",
            ]
        );
    }

    #[test]
    fn base_args_output_format_flags() {
        // With a Unix socket the connection flags occupy exactly two slots,
        // so any format flag lands at index 2.
        let flag = |format| {
            RedisCli::new()
                .unixsocket("/s")
                .output_format(format)
                .base_args()
                .get(2)
                .cloned()
        };
        assert_eq!(flag(OutputFormat::Default), None);
        assert_eq!(flag(OutputFormat::Raw).as_deref(), Some("--raw"));
        assert_eq!(flag(OutputFormat::Csv).as_deref(), Some("--csv"));
        assert_eq!(flag(OutputFormat::Json).as_deref(), Some("--json"));
        assert_eq!(flag(OutputFormat::QuotedJson).as_deref(), Some("--quoted-json"));
    }
}