rsmc_tokio/
lib.rs

1use async_trait::async_trait;
2use rsmc_core::client::Connection;
3use std::{ops::DerefMut, sync::Arc};
4use tokio::{
5    io::{AsyncReadExt, AsyncWriteExt},
6    net::TcpStream,
7    sync::Mutex,
8};
9
10pub use rsmc_core::client::{ClientConfig, Compressor, Error, NoCompressor};
11#[cfg(feature = "zlib")]
12pub use rsmc_core::zlib::ZlibCompressor;
13
14/// A pool of connections to memcached using tokio for async I/O and
15/// the desired compression scheme. Use this to create a connection pool.
16/// For example:
17///
18/// ```ignore
19/// use rsmc_tokio::{Pool, ClientConfig};
20///
21/// let cfg = ClientConfig::new_uncompressed(vec!["localhost:11211".into()]);
22/// let pool = Pool::builder(cfg).max_size(16).build().unwrap();
23/// ```
24pub type Pool<P> = rsmc_core::client::Pool<TokioConnection, P>;
25
26/// A TokioConnection uses the tokio runtime to form TCP connections to
27/// memcached.
28#[derive(Debug, Clone)]
29pub struct TokioConnection {
30    stream: Arc<Mutex<TcpStream>>,
31}
32
33#[async_trait]
34impl Connection for TokioConnection {
35    async fn connect(url: String) -> Result<Self, Error> {
36        let stream = TcpStream::connect(url).await?;
37        let stream = Arc::new(Mutex::new(stream));
38        Ok(TokioConnection { stream })
39    }
40
41    async fn read(&mut self, buf: &mut Vec<u8>) -> Result<usize, Error> {
42        let mut lock = self.stream.lock().await;
43        let stream = lock.deref_mut();
44        Ok(stream.read(buf).await?)
45    }
46
47    async fn write(&mut self, data: &[u8]) -> Result<(), Error> {
48        let mut lock = self.stream.lock().await;
49        let stream = lock.deref_mut();
50        Ok(stream.write_all(data).await?)
51    }
52}
53
#[cfg(test)]
mod test {
    // `Compression` is only needed by the zlib-backed cluster test, and
    // `ZlibCompressor` is only re-exported when the `zlib` feature is enabled.
    #[cfg(feature = "zlib")]
    use flate2::Compression;
    use rand::prelude::*;
    use std::{
        collections::HashMap,
        future::Future,
        io::{BufRead, BufReader},
        process::{Child, Command, Stdio},
    };

    use super::*;

    /// Throwaway memcached servers for integration tests, each running in a
    /// `docker run --rm` container. The containers are stopped on drop.
    #[derive(Debug)]
    struct MemcachedTester {
        // Container names, passed to `docker run --name` and later to `docker stop`.
        names: Vec<String>,
        // The `docker run` client processes, one per container.
        procs: Vec<Child>,
    }

    impl MemcachedTester {
        /// Starts a single container with memcached's port 11211 published on host `port`.
        fn new(port: usize) -> Self {
            Self::new_cluster(vec![port])
        }

        /// Starts one container per host port in `ports`.
        ///
        /// Containers are named `test_memcached_<port>`. Because the ports are
        /// random, the names differ between runs. A fixed per-index name would
        /// collide with a container left running by a killed test process.
        fn new_cluster(ports: Vec<usize>) -> Self {
            let (names, procs) = ports
                .into_iter()
                .map(|port| {
                    let name = format!("test_memcached_{}", port);
                    let proc = Self::new_proc(&name, port);
                    (name, proc)
                })
                .unzip();

            Self { names, procs }
        }

        /// Spawns `docker run` for a memcached container called `name`, publishing
        /// container port 11211 on host `port`.
        ///
        /// Blocks until one line has been read from the process's stdout (memcached
        /// runs with `-vv`, so it logs at startup), or until that stream ends.
        fn new_proc(name: &str, port: usize) -> Child {
            let mut proc = Command::new("docker")
                .args(&[
                    "run",
                    "--rm",
                    "-t",
                    "--name",
                    name,
                    "-p",
                    &format!("{}:11211", port),
                    "memcached",
                    "memcached",
                    "-vv",
                ])
                .stdout(Stdio::piped())
                .spawn()
                .expect("failed to spawn `docker run`");

            let stdout = proc.stdout.as_mut().expect("stdout was configured as piped");
            let mut first_line = String::new();
            BufReader::new(stdout)
                .read_line(&mut first_line)
                .expect("failed to read memcached output");

            proc
        }

        /// Drives `call` to completion with `tokio_test::block_on`. The containers
        /// stay alive until it finishes, because `self` is dropped afterwards.
        fn run<F: Future>(self, call: F) {
            tokio_test::block_on(call);
        }
    }

    impl Drop for MemcachedTester {
        /// Stops every container, then waits for the `docker run` client processes.
        ///
        /// Cleanup is best-effort: failures are printed instead of unwrapped. A
        /// panic inside `drop` while a failing test is already unwinding would
        /// abort the entire test binary.
        fn drop(&mut self) {
            for name in &self.names {
                if let Err(err) = Command::new("docker")
                    .args(&["stop", name.as_str()])
                    .output()
                {
                    eprintln!("failed to run `docker stop {}`: {}", name, err);
                }
            }

            for proc in &mut self.procs {
                if let Err(err) = proc.wait() {
                    eprintln!("failed to wait for `docker run` process: {}", err);
                }
            }
        }
    }

    #[test]
    fn test_connect() {
        let mut rng = rand::thread_rng();
        let random_port = rng.gen_range(10000..20000);
        MemcachedTester::new(random_port).run(async {
            tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
            let host = format!("127.0.0.1:{}", random_port);
            TokioConnection::connect(host).await.unwrap();
        })
    }

    /// Exercises single-key get/set/delete and the multi-key variants against `pool`.
    async fn test_run<P: Compressor>(pool: Pool<P>) {
        let mut client = pool.get().await.unwrap();

        for (k, v) in &[
            ("key", "value"),
            ("hello", "world"),
            ("abc", "123"),
            ("dead", "beef"),
        ] {
            assert_eq!(None, client.get::<_, String>(k).await.unwrap());
            assert_eq!(None, client.get::<_, String>(k).await.unwrap());
            client.set(k, v, 1).await.unwrap();
            let expect = Some(v.to_string());
            let actual = client.get::<_, String>(k).await.unwrap();
            assert_eq!(expect, actual);

            client.delete(k).await.unwrap();
            assert_eq!(None, client.get::<_, String>(k).await.unwrap());
        }

        for map in &[
            &[("key", "value"), ("hello", "world")],
            &[("abc", "123"), ("dead", "beef")],
        ] {
            let keys = map.iter().map(|(k, _)| k.as_bytes()).collect::<Vec<_>>();
            let hash_map = map.iter().fold(HashMap::new(), |mut acc, (k, v)| {
                acc.insert(k.as_bytes(), *v);
                acc
            });

            let (result, _) = client.get_multi::<_, String>(&keys).await.unwrap();
            assert_eq!(0, result.len());

            client.set_multi(hash_map.clone(), 1).await.unwrap();

            // One extra key that was never stored must simply be absent from the result.
            let get_keys = [keys.clone(), vec![b"not found"]].concat();
            let (result, _) = client.get_multi::<_, String>(&get_keys).await.unwrap();
            assert_eq!(keys.len(), result.len());
            result.into_iter().for_each(|(k, v)| {
                let expect = hash_map.get(&k[..]).unwrap();
                assert_eq!(*expect, v);
            });

            client.delete_multi(&keys).await.unwrap();
            let (result, _) = client.get_multi::<_, String>(&keys).await.unwrap();
            assert_eq!(0, result.len());
        }
    }

    #[test]
    fn test_single_connection() {
        let mut rng = rand::thread_rng();
        let random_port = rng.gen_range(20000..30000);
        MemcachedTester::new(random_port).run(async {
            tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
            let host = format!("127.0.0.1:{}", random_port);
            let cfg = ClientConfig::new_uncompressed(vec![host]);
            let pool = Pool::builder(cfg).max_size(16).build().unwrap();
            test_run(pool).await;
        });
    }

    // Requires the `zlib` feature: `ZlibCompressor` is only exported with it.
    #[cfg(feature = "zlib")]
    #[test]
    fn test_cluster() {
        let rng = &mut rand::thread_rng();
        let mut random_ports = (30001..40000).collect::<Vec<_>>();
        random_ports.shuffle(rng);
        random_ports.truncate(3);
        MemcachedTester::new_cluster(random_ports.clone()).run(async {
            tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
            let cfg = ClientConfig::new(
                random_ports
                    .into_iter()
                    .map(|port| format!("127.0.0.1:{}", port))
                    .collect(),
                ZlibCompressor::new(Compression::default(), 1),
            );
            let pool = Pool::builder(cfg).max_size(16).build().unwrap();
            test_run(pool).await;
        });
    }
}