Skip to main content

redis_server_wrapper/
server.rs

1//! Type-safe wrapper for `redis-server` with builder pattern.
2
3use std::collections::HashMap;
4use std::fs;
5use std::path::PathBuf;
6use std::time::Duration;
7
8use tokio::process::Command;
9
10use crate::cli::RedisCli;
11use crate::error::{Error, Result};
12
13/// Builder and lifecycle manager for a single `redis-server` process.
14///
15/// # Example
16///
17/// ```no_run
18/// use redis_server_wrapper::RedisServer;
19///
20/// # async fn example() {
21/// let server = RedisServer::new()
22///     .port(6400)
23///     .bind("127.0.0.1")
24///     .save(false)
25///     .start()
26///     .await
27///     .unwrap();
28///
29/// assert!(server.is_alive().await);
30/// // Stopped automatically on Drop.
31/// # }
32/// ```
33#[derive(Debug, Clone)]
34pub struct RedisServerConfig {
35    // -- network --
36    pub port: u16,
37    pub bind: String,
38    pub protected_mode: bool,
39    pub tcp_backlog: Option<u32>,
40    pub unixsocket: Option<PathBuf>,
41    pub unixsocketperm: Option<u32>,
42    pub timeout: Option<u32>,
43    pub tcp_keepalive: Option<u32>,
44
45    // -- tls --
46    pub tls_port: Option<u16>,
47    pub tls_cert_file: Option<PathBuf>,
48    pub tls_key_file: Option<PathBuf>,
49    pub tls_ca_cert_file: Option<PathBuf>,
50    pub tls_auth_clients: Option<bool>,
51
52    // -- general --
53    pub daemonize: bool,
54    pub dir: PathBuf,
55    pub loglevel: LogLevel,
56    pub databases: Option<u32>,
57
58    // -- memory --
59    pub maxmemory: Option<String>,
60    pub maxmemory_policy: Option<String>,
61    pub maxclients: Option<u32>,
62
63    // -- persistence --
64    pub save: bool,
65    pub appendonly: bool,
66
67    // -- replication --
68    pub replicaof: Option<(String, u16)>,
69    pub masterauth: Option<String>,
70
71    // -- security --
72    pub password: Option<String>,
73    pub acl_file: Option<PathBuf>,
74
75    // -- cluster --
76    pub cluster_enabled: bool,
77    pub cluster_node_timeout: Option<u64>,
78
79    // -- modules --
80    pub loadmodule: Vec<PathBuf>,
81
82    // -- advanced --
83    pub hz: Option<u32>,
84    pub io_threads: Option<u32>,
85    pub io_threads_do_reads: Option<bool>,
86    pub notify_keyspace_events: Option<String>,
87
88    // -- catch-all for anything not covered above --
89    pub extra: HashMap<String, String>,
90
91    // -- binary paths --
92    pub redis_server_bin: String,
93    pub redis_cli_bin: String,
94}
95
/// Redis log verbosity, rendered as the value of the `loglevel` directive.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LogLevel {
    /// `loglevel debug`.
    Debug,
    /// `loglevel verbose`.
    Verbose,
    /// `loglevel notice` (the wrapper's default).
    Notice,
    /// `loglevel warning`.
    Warning,
}

impl LogLevel {
    /// The keyword written to `redis.conf` for this level.
    pub const fn as_str(self) -> &'static str {
        match self {
            LogLevel::Debug => "debug",
            LogLevel::Verbose => "verbose",
            LogLevel::Notice => "notice",
            LogLevel::Warning => "warning",
        }
    }
}

impl std::fmt::Display for LogLevel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` (unlike `write_str`) honours width/alignment flags such as `{:>8}`.
        f.pad(self.as_str())
    }
}
115
116impl Default for RedisServerConfig {
117    fn default() -> Self {
118        Self {
119            port: 6379,
120            bind: "127.0.0.1".into(),
121            protected_mode: false,
122            tcp_backlog: None,
123            unixsocket: None,
124            unixsocketperm: None,
125            timeout: None,
126            tcp_keepalive: None,
127            tls_port: None,
128            tls_cert_file: None,
129            tls_key_file: None,
130            tls_ca_cert_file: None,
131            tls_auth_clients: None,
132            daemonize: true,
133            dir: std::env::temp_dir().join("redis-server-wrapper"),
134            loglevel: LogLevel::Notice,
135            databases: None,
136            maxmemory: None,
137            maxmemory_policy: None,
138            maxclients: None,
139            save: false,
140            appendonly: false,
141            replicaof: None,
142            masterauth: None,
143            password: None,
144            acl_file: None,
145            cluster_enabled: false,
146            cluster_node_timeout: None,
147            loadmodule: Vec::new(),
148            hz: None,
149            io_threads: None,
150            io_threads_do_reads: None,
151            notify_keyspace_events: None,
152            extra: HashMap::new(),
153            redis_server_bin: "redis-server".into(),
154            redis_cli_bin: "redis-cli".into(),
155        }
156    }
157}
158
159/// Builder for a Redis server.
160pub struct RedisServer {
161    config: RedisServerConfig,
162}
163
164impl RedisServer {
165    pub fn new() -> Self {
166        Self {
167            config: RedisServerConfig::default(),
168        }
169    }
170
171    // -- network --
172
173    /// Set the listening port (default: 6379).
174    pub fn port(mut self, port: u16) -> Self {
175        self.config.port = port;
176        self
177    }
178
179    /// Set the bind address (default: `127.0.0.1`).
180    pub fn bind(mut self, bind: impl Into<String>) -> Self {
181        self.config.bind = bind.into();
182        self
183    }
184
185    /// Enable or disable protected mode (default: off).
186    pub fn protected_mode(mut self, protected: bool) -> Self {
187        self.config.protected_mode = protected;
188        self
189    }
190
191    /// Set the TCP backlog queue length.
192    pub fn tcp_backlog(mut self, backlog: u32) -> Self {
193        self.config.tcp_backlog = Some(backlog);
194        self
195    }
196
197    /// Set a Unix socket path for connections.
198    pub fn unixsocket(mut self, path: impl Into<PathBuf>) -> Self {
199        self.config.unixsocket = Some(path.into());
200        self
201    }
202
203    /// Set Unix socket permissions (e.g. `700`).
204    pub fn unixsocketperm(mut self, perm: u32) -> Self {
205        self.config.unixsocketperm = Some(perm);
206        self
207    }
208
209    /// Close idle client connections after this many seconds (0 = disabled).
210    pub fn timeout(mut self, seconds: u32) -> Self {
211        self.config.timeout = Some(seconds);
212        self
213    }
214
215    /// Set TCP keepalive interval in seconds.
216    pub fn tcp_keepalive(mut self, seconds: u32) -> Self {
217        self.config.tcp_keepalive = Some(seconds);
218        self
219    }
220
221    // -- tls --
222
223    /// Set TLS listening port.
224    pub fn tls_port(mut self, port: u16) -> Self {
225        self.config.tls_port = Some(port);
226        self
227    }
228
229    /// Set the TLS certificate file path.
230    pub fn tls_cert_file(mut self, path: impl Into<PathBuf>) -> Self {
231        self.config.tls_cert_file = Some(path.into());
232        self
233    }
234
235    /// Set the TLS private key file path.
236    pub fn tls_key_file(mut self, path: impl Into<PathBuf>) -> Self {
237        self.config.tls_key_file = Some(path.into());
238        self
239    }
240
241    /// Set the TLS CA certificate file path.
242    pub fn tls_ca_cert_file(mut self, path: impl Into<PathBuf>) -> Self {
243        self.config.tls_ca_cert_file = Some(path.into());
244        self
245    }
246
247    /// Require TLS client authentication.
248    pub fn tls_auth_clients(mut self, require: bool) -> Self {
249        self.config.tls_auth_clients = Some(require);
250        self
251    }
252
253    // -- general --
254
255    /// Set the working directory for data files.
256    pub fn dir(mut self, dir: impl Into<PathBuf>) -> Self {
257        self.config.dir = dir.into();
258        self
259    }
260
261    /// Set the log level (default: [`LogLevel::Notice`]).
262    pub fn loglevel(mut self, level: LogLevel) -> Self {
263        self.config.loglevel = level;
264        self
265    }
266
267    /// Set the number of databases (default: 16).
268    pub fn databases(mut self, n: u32) -> Self {
269        self.config.databases = Some(n);
270        self
271    }
272
273    // -- memory --
274
275    /// Set the maximum memory limit (e.g. `"256mb"`, `"1gb"`).
276    pub fn maxmemory(mut self, limit: impl Into<String>) -> Self {
277        self.config.maxmemory = Some(limit.into());
278        self
279    }
280
281    /// Set the eviction policy when maxmemory is reached.
282    pub fn maxmemory_policy(mut self, policy: impl Into<String>) -> Self {
283        self.config.maxmemory_policy = Some(policy.into());
284        self
285    }
286
287    /// Set the maximum number of simultaneous client connections.
288    pub fn maxclients(mut self, n: u32) -> Self {
289        self.config.maxclients = Some(n);
290        self
291    }
292
293    // -- persistence --
294
295    /// Enable or disable RDB snapshots (default: off).
296    pub fn save(mut self, save: bool) -> Self {
297        self.config.save = save;
298        self
299    }
300
301    /// Enable or disable AOF persistence.
302    pub fn appendonly(mut self, appendonly: bool) -> Self {
303        self.config.appendonly = appendonly;
304        self
305    }
306
307    // -- replication --
308
309    /// Configure this server as a replica of the given master.
310    pub fn replicaof(mut self, host: impl Into<String>, port: u16) -> Self {
311        self.config.replicaof = Some((host.into(), port));
312        self
313    }
314
315    /// Set the password for authenticating with a master.
316    pub fn masterauth(mut self, password: impl Into<String>) -> Self {
317        self.config.masterauth = Some(password.into());
318        self
319    }
320
321    // -- security --
322
323    /// Set a `requirepass` password for client connections.
324    pub fn password(mut self, password: impl Into<String>) -> Self {
325        self.config.password = Some(password.into());
326        self
327    }
328
329    /// Set the path to an ACL file.
330    pub fn acl_file(mut self, path: impl Into<PathBuf>) -> Self {
331        self.config.acl_file = Some(path.into());
332        self
333    }
334
335    // -- cluster --
336
337    /// Enable Redis Cluster mode.
338    pub fn cluster_enabled(mut self, enabled: bool) -> Self {
339        self.config.cluster_enabled = enabled;
340        self
341    }
342
343    /// Set the cluster node timeout in milliseconds.
344    pub fn cluster_node_timeout(mut self, ms: u64) -> Self {
345        self.config.cluster_node_timeout = Some(ms);
346        self
347    }
348
349    // -- modules --
350
351    /// Load a Redis module at startup.
352    pub fn loadmodule(mut self, path: impl Into<PathBuf>) -> Self {
353        self.config.loadmodule.push(path.into());
354        self
355    }
356
357    // -- advanced --
358
359    /// Set the server tick frequency in Hz (default: 10).
360    pub fn hz(mut self, hz: u32) -> Self {
361        self.config.hz = Some(hz);
362        self
363    }
364
365    /// Set the number of I/O threads.
366    pub fn io_threads(mut self, n: u32) -> Self {
367        self.config.io_threads = Some(n);
368        self
369    }
370
371    /// Enable I/O threads for reads as well as writes.
372    pub fn io_threads_do_reads(mut self, enable: bool) -> Self {
373        self.config.io_threads_do_reads = Some(enable);
374        self
375    }
376
377    /// Set keyspace notification events (e.g. `"KEA"`).
378    pub fn notify_keyspace_events(mut self, events: impl Into<String>) -> Self {
379        self.config.notify_keyspace_events = Some(events.into());
380        self
381    }
382
383    // -- binary paths --
384
385    /// Set a custom `redis-server` binary path.
386    pub fn redis_server_bin(mut self, bin: impl Into<String>) -> Self {
387        self.config.redis_server_bin = bin.into();
388        self
389    }
390
391    /// Set a custom `redis-cli` binary path.
392    pub fn redis_cli_bin(mut self, bin: impl Into<String>) -> Self {
393        self.config.redis_cli_bin = bin.into();
394        self
395    }
396
397    /// Set an arbitrary config directive not covered by dedicated methods.
398    pub fn extra(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
399        self.config.extra.insert(key.into(), value.into());
400        self
401    }
402
403    /// Start the server. Returns a handle that stops the server on Drop.
404    ///
405    /// Verifies that `redis-server` and `redis-cli` binaries are available
406    /// before attempting to launch anything.
407    pub async fn start(self) -> Result<RedisServerHandle> {
408        if which::which(&self.config.redis_server_bin).is_err() {
409            return Err(Error::BinaryNotFound {
410                binary: self.config.redis_server_bin.clone(),
411            });
412        }
413        if which::which(&self.config.redis_cli_bin).is_err() {
414            return Err(Error::BinaryNotFound {
415                binary: self.config.redis_cli_bin.clone(),
416            });
417        }
418
419        let node_dir = self.config.dir.join(format!("node-{}", self.config.port));
420        fs::create_dir_all(&node_dir)?;
421
422        let conf_path = node_dir.join("redis.conf");
423        let conf_content = self.generate_config(&node_dir);
424        fs::write(&conf_path, conf_content)?;
425
426        let status = Command::new(&self.config.redis_server_bin)
427            .arg(&conf_path)
428            .stdout(std::process::Stdio::null())
429            .stderr(std::process::Stdio::null())
430            .status()
431            .await?;
432
433        if !status.success() {
434            return Err(Error::ServerStart {
435                port: self.config.port,
436            });
437        }
438
439        let mut cli = RedisCli::new()
440            .bin(&self.config.redis_cli_bin)
441            .host(&self.config.bind)
442            .port(self.config.port);
443        if let Some(ref pw) = self.config.password {
444            cli = cli.password(pw);
445        }
446
447        cli.wait_for_ready(Duration::from_secs(10)).await?;
448
449        Ok(RedisServerHandle {
450            config: self.config,
451            cli,
452        })
453    }
454
455    fn generate_config(&self, node_dir: &std::path::Path) -> String {
456        let yn = |b: bool| if b { "yes" } else { "no" };
457
458        let mut conf = format!(
459            "port {port}\n\
460             bind {bind}\n\
461             daemonize {daemonize}\n\
462             pidfile {dir}/redis.pid\n\
463             logfile {dir}/redis.log\n\
464             dir {dir}\n\
465             loglevel {level}\n\
466             protected-mode {protected}\n",
467            port = self.config.port,
468            bind = self.config.bind,
469            daemonize = yn(self.config.daemonize),
470            dir = node_dir.display(),
471            level = self.config.loglevel,
472            protected = yn(self.config.protected_mode),
473        );
474
475        // -- network --
476        if let Some(backlog) = self.config.tcp_backlog {
477            conf.push_str(&format!("tcp-backlog {backlog}\n"));
478        }
479        if let Some(ref path) = self.config.unixsocket {
480            conf.push_str(&format!("unixsocket {}\n", path.display()));
481        }
482        if let Some(perm) = self.config.unixsocketperm {
483            conf.push_str(&format!("unixsocketperm {perm}\n"));
484        }
485        if let Some(t) = self.config.timeout {
486            conf.push_str(&format!("timeout {t}\n"));
487        }
488        if let Some(ka) = self.config.tcp_keepalive {
489            conf.push_str(&format!("tcp-keepalive {ka}\n"));
490        }
491
492        // -- tls --
493        if let Some(port) = self.config.tls_port {
494            conf.push_str(&format!("tls-port {port}\n"));
495        }
496        if let Some(ref path) = self.config.tls_cert_file {
497            conf.push_str(&format!("tls-cert-file {}\n", path.display()));
498        }
499        if let Some(ref path) = self.config.tls_key_file {
500            conf.push_str(&format!("tls-key-file {}\n", path.display()));
501        }
502        if let Some(ref path) = self.config.tls_ca_cert_file {
503            conf.push_str(&format!("tls-ca-cert-file {}\n", path.display()));
504        }
505        if let Some(auth) = self.config.tls_auth_clients {
506            conf.push_str(&format!("tls-auth-clients {}\n", yn(auth)));
507        }
508
509        // -- general --
510        if let Some(n) = self.config.databases {
511            conf.push_str(&format!("databases {n}\n"));
512        }
513
514        // -- memory --
515        if let Some(ref limit) = self.config.maxmemory {
516            conf.push_str(&format!("maxmemory {limit}\n"));
517        }
518        if let Some(ref policy) = self.config.maxmemory_policy {
519            conf.push_str(&format!("maxmemory-policy {policy}\n"));
520        }
521        if let Some(n) = self.config.maxclients {
522            conf.push_str(&format!("maxclients {n}\n"));
523        }
524
525        // -- persistence --
526        if !self.config.save {
527            conf.push_str("save \"\"\n");
528        }
529        if self.config.appendonly {
530            conf.push_str("appendonly yes\n");
531        }
532
533        // -- replication --
534        if let Some((ref host, port)) = self.config.replicaof {
535            conf.push_str(&format!("replicaof {host} {port}\n"));
536        }
537        if let Some(ref pw) = self.config.masterauth {
538            conf.push_str(&format!("masterauth {pw}\n"));
539        }
540
541        // -- security --
542        if let Some(ref pw) = self.config.password {
543            conf.push_str(&format!("requirepass {pw}\n"));
544        }
545        if let Some(ref path) = self.config.acl_file {
546            conf.push_str(&format!("aclfile {}\n", path.display()));
547        }
548
549        // -- cluster --
550        if self.config.cluster_enabled {
551            conf.push_str("cluster-enabled yes\n");
552            conf.push_str(&format!(
553                "cluster-config-file {}/nodes.conf\n",
554                node_dir.display()
555            ));
556            if let Some(timeout) = self.config.cluster_node_timeout {
557                conf.push_str(&format!("cluster-node-timeout {timeout}\n"));
558            }
559        }
560
561        // -- modules --
562        for path in &self.config.loadmodule {
563            conf.push_str(&format!("loadmodule {}\n", path.display()));
564        }
565
566        // -- advanced --
567        if let Some(hz) = self.config.hz {
568            conf.push_str(&format!("hz {hz}\n"));
569        }
570        if let Some(n) = self.config.io_threads {
571            conf.push_str(&format!("io-threads {n}\n"));
572        }
573        if let Some(enable) = self.config.io_threads_do_reads {
574            conf.push_str(&format!("io-threads-do-reads {}\n", yn(enable)));
575        }
576        if let Some(ref events) = self.config.notify_keyspace_events {
577            conf.push_str(&format!("notify-keyspace-events {events}\n"));
578        }
579
580        // -- catch-all --
581        for (key, value) in &self.config.extra {
582            conf.push_str(&format!("{key} {value}\n"));
583        }
584
585        conf
586    }
587}
588
589impl Default for RedisServer {
590    fn default() -> Self {
591        Self::new()
592    }
593}
594
595/// Handle to a running Redis server. Stops the server on Drop.
596pub struct RedisServerHandle {
597    config: RedisServerConfig,
598    cli: RedisCli,
599}
600
601impl RedisServerHandle {
602    /// The server's address as "host:port".
603    pub fn addr(&self) -> String {
604        format!("{}:{}", self.config.bind, self.config.port)
605    }
606
607    /// The server's port.
608    pub fn port(&self) -> u16 {
609        self.config.port
610    }
611
612    /// The server's bind address.
613    pub fn host(&self) -> &str {
614        &self.config.bind
615    }
616
617    /// Check if the server is alive via PING.
618    pub async fn is_alive(&self) -> bool {
619        self.cli.ping().await
620    }
621
622    /// Get a `RedisCli` configured for this server.
623    pub fn cli(&self) -> &RedisCli {
624        &self.cli
625    }
626
627    /// Run a redis-cli command against this server.
628    pub async fn run(&self, args: &[&str]) -> Result<String> {
629        self.cli.run(args).await
630    }
631
632    /// Stop the server via SHUTDOWN NOSAVE.
633    pub fn stop(&self) {
634        self.cli.shutdown();
635    }
636
637    /// Wait until the server is ready (PING -> PONG).
638    pub async fn wait_for_ready(&self, timeout: Duration) -> Result<()> {
639        self.cli.wait_for_ready(timeout).await
640    }
641}
642
643impl Drop for RedisServerHandle {
644    fn drop(&mut self) {
645        self.stop();
646    }
647}
648
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    /// Render the config for `server` against a fixed node directory that is never created.
    fn render(server: &RedisServer) -> String {
        server.generate_config(Path::new("/tmp/redis-server-wrapper-test/node"))
    }

    #[test]
    fn default_config() {
        let s = RedisServer::new();
        assert_eq!(s.config.port, 6379);
        assert_eq!(s.config.bind, "127.0.0.1");
        assert!(!s.config.save);
        assert!(!s.config.appendonly);
        assert!(s.config.daemonize);
        assert!(matches!(s.config.loglevel, LogLevel::Notice));
    }

    #[test]
    fn builder_chain() {
        let s = RedisServer::new()
            .port(6400)
            .bind("0.0.0.0")
            .save(true)
            .appendonly(true)
            .password("secret")
            .loglevel(LogLevel::Warning)
            .extra("maxmemory", "100mb");

        assert_eq!(s.config.port, 6400);
        assert_eq!(s.config.bind, "0.0.0.0");
        assert!(s.config.save);
        assert!(s.config.appendonly);
        assert_eq!(s.config.password.as_deref(), Some("secret"));
        assert!(matches!(s.config.loglevel, LogLevel::Warning));
        assert_eq!(s.config.extra.get("maxmemory").unwrap(), "100mb");
    }

    #[test]
    fn cluster_config() {
        let s = RedisServer::new()
            .port(7000)
            .cluster_enabled(true)
            .cluster_node_timeout(5000);

        assert!(s.config.cluster_enabled);
        assert_eq!(s.config.cluster_node_timeout, Some(5000));
    }

    #[test]
    fn generated_config_has_core_directives() {
        let conf = render(
            &RedisServer::new()
                .port(7001)
                .loglevel(LogLevel::Warning)
                .cluster_enabled(true)
                .cluster_node_timeout(5000),
        );

        assert!(conf.starts_with("port 7001\nbind 127.0.0.1\ndaemonize yes\n"));
        assert!(conf.contains("loglevel warning\n"));
        assert!(conf.contains("protected-mode no\n"));
        assert!(conf.contains("save \"\"\n"));
        assert!(conf.contains("cluster-enabled yes\n"));
        assert!(conf.contains("cluster-node-timeout 5000\n"));
        assert!(!conf.contains("appendonly"));
    }

    #[test]
    fn generated_config_respects_toggles() {
        let persistent = render(&RedisServer::new().save(true).appendonly(true));
        assert!(!persistent.contains("save "));
        assert!(persistent.contains("appendonly yes\n"));

        // The node timeout is cluster-only and must not leak into a standalone config.
        let standalone = render(&RedisServer::new().cluster_node_timeout(5000));
        assert!(!standalone.contains("cluster-"));
    }

    #[test]
    fn generated_config_ends_with_extra_directives() {
        let conf = render(&RedisServer::new().extra("maxmemory-samples", "10"));
        assert!(conf.ends_with("maxmemory-samples 10\n"));
    }
}