1use async_trait::async_trait;
2use rsmc_core::client::Connection;
3use std::{ops::DerefMut, sync::Arc};
4use tokio::{
5 io::{AsyncReadExt, AsyncWriteExt},
6 net::TcpStream,
7 sync::Mutex,
8};
9
10pub use rsmc_core::client::{ClientConfig, Compressor, Error, NoCompressor};
11#[cfg(feature = "zlib")]
12pub use rsmc_core::zlib::ZlibCompressor;
13
/// Connection pool of [`TokioConnection`]s, parameterised over the
/// [`Compressor`] type `P`.
///
/// This is [`rsmc_core::client::Pool`] with the connection type fixed to
/// [`TokioConnection`].
pub type Pool<P> = rsmc_core::client::Pool<TokioConnection, P>;
25
/// A [`Connection`] backed by a tokio [`TcpStream`].
///
/// The stream is held behind an `Arc<Mutex<_>>`, so cloning a
/// `TokioConnection` yields another handle to the *same* socket. Every read
/// and write locks the mutex first, so operations issued through different
/// clones are serialised rather than interleaved.
#[derive(Debug, Clone)]
pub struct TokioConnection {
    /// Shared TCP stream, guarded by an async (tokio) mutex.
    stream: Arc<Mutex<TcpStream>>,
}
32
33#[async_trait]
34impl Connection for TokioConnection {
35 async fn connect(url: String) -> Result<Self, Error> {
36 let stream = TcpStream::connect(url).await?;
37 let stream = Arc::new(Mutex::new(stream));
38 Ok(TokioConnection { stream })
39 }
40
41 async fn read(&mut self, buf: &mut Vec<u8>) -> Result<usize, Error> {
42 let mut lock = self.stream.lock().await;
43 let stream = lock.deref_mut();
44 Ok(stream.read(buf).await?)
45 }
46
47 async fn write(&mut self, data: &[u8]) -> Result<(), Error> {
48 let mut lock = self.stream.lock().await;
49 let stream = lock.deref_mut();
50 Ok(stream.write_all(data).await?)
51 }
52}
53
54#[cfg(test)]
55mod test {
56 use flate2::Compression;
57 use futures::Future;
58 use rand::prelude::*;
59 use std::{
60 collections::HashMap,
61 io::{BufRead, BufReader},
62 process::{Child, Command, Stdio},
63 };
64
65 use super::*;
66
    /// Test fixture owning one or more memcached docker containers.
    ///
    /// `names[i]` is the container name used by the attached `docker run`
    /// client `procs[i]`; the constructors build both vectors together.
    #[derive(Debug)]
    struct MemcachedTester {
        /// Docker container names (passed to `docker stop` when dropped).
        names: Vec<String>,
        /// Attached `docker run` client processes (waited on when dropped).
        procs: Vec<Child>,
    }
72
73 impl MemcachedTester {
74 fn new(port: usize) -> Self {
75 let name = format!("test_memcached_{}", port);
76 let proc = MemcachedTester::new_proc(&name, port);
77
78 Self {
79 procs: vec![proc],
80 names: vec![name],
81 }
82 }
83
84 fn new_cluster(ports: Vec<usize>) -> Self {
85 let (names, procs) = ports
86 .into_iter()
87 .enumerate()
88 .map(|(i, port)| {
89 let name = format!("test_memcached_{}", i);
90 let proc = MemcachedTester::new_proc(&name, port);
91 (name, proc)
92 })
93 .unzip();
94
95 Self { procs, names }
96 }
97
98 fn new_proc(name: &str, port: usize) -> Child {
99 let mut proc = Command::new("docker")
100 .args(&[
101 "run",
102 "--rm",
103 "-t",
104 "--name",
105 name,
106 "-p",
107 &format!("{}:11211", port),
108 "memcached",
109 "memcached",
110 "-vv",
111 ])
112 .stdout(Stdio::piped())
113 .spawn()
114 .unwrap();
115
116 let stdout = proc.stdout.as_mut().unwrap();
117 let mut reader = BufReader::new(stdout);
118 let mut buf = String::new();
119 reader.read_line(&mut buf).unwrap();
120
121 proc
122 }
123
124 fn run<F: Future>(self, call: F) {
125 tokio_test::block_on(call);
126 }
127 }
128
129 impl Drop for MemcachedTester {
130 fn drop(&mut self) {
131 for name in self.names.iter() {
132 Command::new("docker")
133 .args(&["stop", &name])
134 .output()
135 .unwrap();
136 }
137
138 for proc in self.procs.iter_mut() {
139 proc.wait().unwrap();
140 }
141 }
142 }
143
144 #[test]
145 fn test_connect() {
146 let mut rng = rand::thread_rng();
147 let random_port = rng.gen_range(10000..20000);
148 MemcachedTester::new(random_port).run(async {
149 tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
150 let host = format!("127.0.0.1:{}", random_port);
151 TokioConnection::connect(host).await.unwrap();
152 })
153 }
154
155 async fn test_run<P: Compressor>(pool: Pool<P>) {
156 let mut client = pool.get().await.unwrap();
157
158 for (k, v) in &[
159 ("key", "value"),
160 ("hello", "world"),
161 ("abc", "123"),
162 ("dead", "beef"),
163 ] {
164 assert_eq!(None, client.get::<_, String>(k).await.unwrap());
165 assert_eq!(None, client.get::<_, String>(k).await.unwrap());
166 assert_eq!((), client.set(k, v, 1).await.unwrap());
167 let expect = Some(v.to_string());
168 let actual = client.get::<_, String>(k).await.unwrap();
169 assert_eq!(expect, actual);
170
171 assert_eq!((), client.delete(k).await.unwrap());
172 assert_eq!(None, client.get::<_, String>(k).await.unwrap());
173 }
174
175 for map in &[
176 &[("key", "value"), ("hello", "world")],
177 &[("abc", "123"), ("dead", "beef")],
178 ] {
179 let keys = map.iter().map(|(k, _)| k.as_bytes()).collect::<Vec<_>>();
180 let hash_map = map.iter().fold(HashMap::new(), |mut acc, (k, v)| {
181 acc.insert(k.as_bytes(), *v);
182 acc
183 });
184
185 let (result, _) = client.get_multi::<_, String>(&keys).await.unwrap();
186 assert_eq!(0, result.len());
187
188 client.set_multi(hash_map.clone(), 1).await.unwrap();
189
190 let get_keys = [keys.clone(), vec![b"not found"]].concat();
191 let (result, _) = client.get_multi::<_, String>(&get_keys).await.unwrap();
192 assert_eq!(keys.len(), result.len());
193 result.into_iter().for_each(|(k, v)| {
194 let expect = hash_map.get(&k[..]).unwrap();
195 assert_eq!(*expect, v);
196 });
197
198 client.delete_multi(&keys).await.unwrap();
199 let (result, _) = client.get_multi::<_, String>(&keys).await.unwrap();
200 assert_eq!(0, result.len());
201 }
202 }
203
204 #[test]
205 fn test_single_connection() {
206 let mut rng = rand::thread_rng();
207 let random_port = rng.gen_range(20000..30000);
208 MemcachedTester::new(random_port).run(async {
209 tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
210 let host = format!("127.0.0.1:{}", random_port);
211 let cfg = ClientConfig::new_uncompressed(vec![host]);
212 let pool = Pool::builder(cfg).max_size(16).build().unwrap();
213 test_run(pool).await;
214 });
215 }
216
217 #[test]
218 fn test_cluster() {
219 let rng = &mut rand::thread_rng();
220 let mut random_ports = (30001..40000).collect::<Vec<_>>();
221 random_ports.shuffle(rng);
222 let random_ports: Vec<_> = random_ports[0..3].into();
223 MemcachedTester::new_cluster(random_ports.clone()).run(async {
224 tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
225 let cfg = ClientConfig::new(
226 random_ports
227 .into_iter()
228 .map(|port| format!("127.0.0.1:{}", port))
229 .collect(),
230 ZlibCompressor::new(Compression::default(), 1),
231 );
232 let pool = Pool::builder(cfg).max_size(16).build().unwrap();
233 test_run(pool).await;
234 });
235 }
236}