1use crate::core::{ErrorKind, Key, Memcached, Meta, MtopError, SlabItems, Slabs, Stats, Value};
2use crate::discovery::{Server, ServerID};
3use crate::net::{self, TlsConfig};
4use crate::pool::{ClientFactory, ClientPool, ClientPoolConfig, PooledClient};
5use async_trait::async_trait;
6use std::collections::HashMap;
7use std::fmt;
8use std::hash::DefaultHasher;
9use std::hash::{Hash, Hasher};
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::runtime::Handle;
13use tokio::task::JoinHandle;
14use tokio_rustls::rustls::ClientConfig;
15use tokio_rustls::rustls::pki_types::ServerName;
16use tracing::instrument::WithSubscriber;
17
18#[derive(Debug)]
21pub struct TlsTcpClientFactory {
22 client_config: Arc<ClientConfig>,
23 server_name: Option<ServerName<'static>>,
24}
25
26impl TlsTcpClientFactory {
27 pub async fn new(tls: TlsConfig) -> Result<Self, MtopError> {
28 let server_name = tls.server_name.clone();
29 let client_config = Arc::new(net::tls_client_config(tls).await?);
30
31 Ok(Self {
32 client_config,
33 server_name,
34 })
35 }
36}
37
38#[async_trait]
39impl ClientFactory<Server, Memcached> for TlsTcpClientFactory {
40 async fn make(&self, key: &Server) -> Result<Memcached, MtopError> {
41 let server_name = self
42 .server_name
43 .clone()
44 .or_else(|| key.server_name().clone())
45 .expect("TLS server name must be set on each server when using TlsTcpClientFactory: this is a bug");
46
47 let (read, write) = match key.id() {
48 ServerID::Socket(sock) => net::tcp_tls_connect(sock, server_name, self.client_config.clone()).await?,
49 ServerID::Name(name) => net::tcp_tls_connect(name, server_name, self.client_config.clone()).await?,
50 id => panic!("unexpected {:?} passed to TlsTcpClientFactory: this is a bug", id),
51 };
52
53 Ok(Memcached::new(read, write))
54 }
55}
56
57#[derive(Debug)]
60pub struct TcpClientFactory;
61
62#[async_trait]
63impl ClientFactory<Server, Memcached> for TcpClientFactory {
64 async fn make(&self, key: &Server) -> Result<Memcached, MtopError> {
65 let (read, write) = match key.id() {
66 ServerID::Socket(sock) => net::tcp_connect(sock).await?,
67 ServerID::Name(name) => net::tcp_connect(name).await?,
68 id => panic!("unexpected {:?} passed to TcpClientFactory: this is a bug", id),
69 };
70
71 Ok(Memcached::new(read, write))
72 }
73}
74
75#[cfg(unix)]
78#[derive(Debug)]
79pub struct UnixClientFactory;
80
81#[cfg(unix)]
82#[async_trait]
83impl ClientFactory<Server, Memcached> for UnixClientFactory {
84 async fn make(&self, key: &Server) -> Result<Memcached, MtopError> {
85 let (read, write) = match key.id() {
86 ServerID::Path(path) => net::unix_connect(path).await?,
87 id => panic!("unexpected {:?} passed to UnixClientFactory: this is a bug", id),
88 };
89
90 Ok(Memcached::new(read, write))
91 }
92}
93
94pub trait Selector {
96 fn servers(&self) -> Vec<Server>;
98
99 fn server(&self, key: &Key) -> Result<Server, MtopError>;
101}
102
103#[derive(Debug)]
108pub struct RendezvousSelector {
109 servers: Vec<Server>,
110}
111
112impl RendezvousSelector {
113 pub fn new(servers: Vec<Server>) -> Self {
115 Self { servers }
116 }
117
118 fn score(server: &Server, key: &Key) -> u64 {
119 let mut hasher = DefaultHasher::new();
120
121 match server.id() {
123 ServerID::Name(name) => name.hash(&mut hasher),
124 ServerID::Socket(addr) => addr.hash(&mut hasher),
125 ServerID::Path(path) => path.hash(&mut hasher),
126 }
127
128 hasher.write(key.as_ref().as_bytes());
129 hasher.finish()
130 }
131}
132
133impl Selector for RendezvousSelector {
134 fn servers(&self) -> Vec<Server> {
135 self.servers.clone()
136 }
137
138 fn server(&self, key: &Key) -> Result<Server, MtopError> {
139 if self.servers.is_empty() {
140 Err(MtopError::runtime("no servers available"))
141 } else if self.servers.len() == 1 {
142 Ok(self.servers.first().cloned().unwrap())
143 } else {
144 let mut max = u64::MIN;
145 let mut choice = None;
146
147 for server in self.servers.iter() {
148 let score = Self::score(server, key);
149 if score >= max {
150 choice = Some(server);
151 max = score;
152 }
153 }
154
155 Ok(choice.cloned().unwrap())
156 }
157 }
158}
159
160#[derive(Debug, Default)]
162pub struct ServersResponse<T> {
163 pub values: HashMap<ServerID, T>,
164 pub errors: HashMap<ServerID, MtopError>,
165}
166
167#[derive(Debug, Default)]
169pub struct ValuesResponse {
170 pub values: HashMap<String, Value>,
171 pub errors: HashMap<ServerID, MtopError>,
172}
173
// Check out a connection to `$server` from `$pool`, call `$method($args...)` on it and
// return the result. Connection failures propagate via `?`. The connection goes back to
// the pool after a success or after a protocol error (the server answered with a
// well-formed error). For any other error the connection's state is unknown, so it is
// dropped instead of reused.
macro_rules! run_for_host {
    ($pool:expr, $server:expr, $method:ident $(, $args:expr)* $(,)?) => {{
        let mut client = $pool.get($server).await?;
        let result = client.$method($($args,)*).await;

        let reusable = match &result {
            Ok(_) => true,
            Err(e) => e.kind() == ErrorKind::Protocol,
        };
        if reusable {
            $pool.put(client).await;
        }

        result
    }};
}
196
// Spawn `run_for_host!` for `$server` as a task on `$self.handle`, propagating the
// current tracing subscriber into the task. Evaluates to the task's `JoinHandle`.
macro_rules! spawn_for_host {
    ($self:ident, $server:expr, $method:ident $(, $args:expr)* $(,)?) => {{
        let pool = $self.pool.clone();
        let task = async move { run_for_host!(pool, $server, $method, $($args,)*) };
        $self.handle.spawn(task.with_current_subscriber())
    }};
}
208
// Validate `$key`, ask the selector which server owns it, then run `$method(&key, $args...)`
// on that server inline (no task is spawned). Invalid keys and selector failures
// return early from the enclosing function via `?`.
macro_rules! operation_for_key {
    ($self:ident, $method:ident, $key:expr $(, $args:expr)* $(,)?) => {{
        let validated = Key::one($key)?;
        let target = $self.selector.server(&validated)?;
        run_for_host!($self.pool, &target, $method, &validated, $($args,)*)
    }};
}
218
// Run `$method()` concurrently on every server known to the selector, then gather the
// per-server results and errors into `Ok(ServersResponse)`.
macro_rules! operation_for_all {
    ($self:ident, $method:ident) => {{
        let servers = $self.selector.servers();
        let mut tasks = Vec::with_capacity(servers.len());
        for server in servers {
            // Clone the ID first: `server` itself is moved into the spawned task.
            let id = server.id().clone();
            tasks.push((id, spawn_for_host!($self, &server, $method)));
        }

        Ok(collect_results(tasks).await)
    }};
}
231
232async fn collect_results<T>(tasks: Vec<(ServerID, JoinHandle<Result<T, MtopError>>)>) -> ServersResponse<T> {
234 let mut values = HashMap::with_capacity(tasks.len());
235 let mut errors = HashMap::new();
236
237 for (id, task) in tasks {
238 match task.await {
239 Ok(Ok(result)) => {
240 values.insert(id, result);
241 }
242 Ok(Err(e)) => {
243 errors.insert(id, e);
244 }
245 Err(e) => {
246 errors.insert(id, MtopError::runtime_cause("fetching cluster values", e));
247 }
248 };
249 }
250
251 ServersResponse { values, errors }
252}
253
/// Settings for a `MemcachedClient`.
#[derive(Debug, Clone)]
pub struct MemcachedClientConfig {
    /// Forwarded to the connection pool as its `max_idle` setting.
    pub pool_max_idle: u64,
    /// Forwarded to the connection pool as its `name`.
    pub pool_name: String,
}

impl Default for MemcachedClientConfig {
    /// An idle limit of 4 and the pool name `"memcached-tcp"`.
    fn default() -> Self {
        MemcachedClientConfig {
            pool_name: String::from("memcached-tcp"),
            pool_max_idle: 4,
        }
    }
}
269
270pub struct MemcachedClient {
273 handle: Handle,
274 selector: Box<dyn Selector + Send + Sync>,
275 pool: Arc<ClientPool<Server, Memcached>>,
276}
277
278impl MemcachedClient {
279 pub fn new<S, F>(config: MemcachedClientConfig, handle: Handle, selector: S, factory: F) -> Self
286 where
287 S: Selector + Send + Sync + 'static,
288 F: ClientFactory<Server, Memcached> + Send + Sync + 'static,
289 {
290 let pool_config = ClientPoolConfig {
291 name: config.pool_name,
292 max_idle: config.pool_max_idle,
293 };
294
295 Self {
296 handle,
297 selector: Box::new(selector),
298 pool: Arc::new(ClientPool::new(pool_config, factory)),
299 }
300 }
301
302 pub async fn raw_open(&self, server: &Server) -> Result<PooledClient<Server, Memcached>, MtopError> {
305 self.pool.get(server).await
306 }
307
308 pub async fn raw_close(&self, connection: PooledClient<Server, Memcached>) {
312 self.pool.put(connection).await
313 }
314
315 pub async fn stats(&self) -> Result<ServersResponse<Stats>, MtopError> {
321 operation_for_all!(self, stats)
322 }
323
324 pub async fn slabs(&self) -> Result<ServersResponse<Slabs>, MtopError> {
332 operation_for_all!(self, slabs)
333 }
334
335 pub async fn items(&self) -> Result<ServersResponse<SlabItems>, MtopError> {
344 operation_for_all!(self, items)
345 }
346
347 pub async fn metas(&self) -> Result<ServersResponse<Vec<Meta>>, MtopError> {
355 operation_for_all!(self, metas)
356 }
357
358 pub async fn ping(&self) -> Result<ServersResponse<()>, MtopError> {
364 operation_for_all!(self, ping)
365 }
366
367 pub async fn flush_all(&self, wait: Option<Duration>) -> Result<ServersResponse<()>, MtopError> {
375 let servers = self.selector.servers();
379
380 let tasks = servers
381 .into_iter()
382 .enumerate()
383 .map(|(i, server)| {
384 let delay = wait.map(|d| d * i as u32);
385 (server.id().clone(), spawn_for_host!(self, &server, flush_all, delay))
386 })
387 .collect::<Vec<_>>();
388
389 Ok(collect_results(tasks).await)
390 }
391
392 pub async fn get<I, K>(&self, keys: I) -> Result<ValuesResponse, MtopError>
400 where
401 I: IntoIterator<Item = K>,
402 K: Into<String>,
403 {
404 let keys = Key::many(keys)?;
405 if keys.is_empty() {
406 return Ok(ValuesResponse::default());
407 }
408
409 let num_keys = keys.len();
410 let mut by_server: HashMap<Server, Vec<Key>> = HashMap::new();
411 for key in keys {
412 let server = self.selector.server(&key)?;
413 let entry = by_server.entry(server).or_default();
414 entry.push(key);
415 }
416
417 let tasks = by_server
418 .into_iter()
419 .map(|(server, keys)| (server.id().clone(), spawn_for_host!(self, &server, get, &keys)))
420 .collect::<Vec<_>>();
421
422 let mut values = HashMap::with_capacity(num_keys);
423 let mut errors = HashMap::new();
424
425 for (id, task) in tasks {
426 match task.await {
427 Ok(Ok(results)) => {
428 values.extend(results);
429 }
430 Ok(Err(e)) => {
431 errors.insert(id, e);
432 }
433 Err(e) => {
434 errors.insert(id, MtopError::runtime_cause("fetching keys", e));
435 }
436 };
437 }
438
439 Ok(ValuesResponse { values, errors })
440 }
441
442 pub async fn incr<K>(&self, key: K, delta: u64) -> Result<u64, MtopError>
449 where
450 K: Into<String>,
451 {
452 operation_for_key!(self, incr, key, delta)
453 }
454
455 pub async fn decr<K>(&self, key: K, delta: u64) -> Result<u64, MtopError>
463 where
464 K: Into<String>,
465 {
466 operation_for_key!(self, decr, key, delta)
467 }
468
469 pub async fn set<K, V>(&self, key: K, flags: u64, ttl: u32, data: V) -> Result<(), MtopError>
475 where
476 K: Into<String>,
477 V: AsRef<[u8]>,
478 {
479 operation_for_key!(self, set, key, flags, ttl, data)
480 }
481
482 pub async fn add<K, V>(&self, key: K, flags: u64, ttl: u32, data: V) -> Result<(), MtopError>
488 where
489 K: Into<String>,
490 V: AsRef<[u8]>,
491 {
492 operation_for_key!(self, add, key, flags, ttl, data)
493 }
494
495 pub async fn replace<K, V>(&self, key: K, flags: u64, ttl: u32, data: V) -> Result<(), MtopError>
501 where
502 K: Into<String>,
503 V: AsRef<[u8]>,
504 {
505 operation_for_key!(self, replace, key, flags, ttl, data)
506 }
507
508 pub async fn touch<K>(&self, key: K, ttl: u32) -> Result<(), MtopError>
514 where
515 K: Into<String>,
516 {
517 operation_for_key!(self, touch, key, ttl)
518 }
519
520 pub async fn delete<K>(&self, key: K) -> Result<(), MtopError>
526 where
527 K: Into<String>,
528 {
529 operation_for_key!(self, delete, key)
530 }
531}
532
533impl fmt::Debug for MemcachedClient {
534 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
535 f.debug_struct("MemcachedClient")
536 .field("selector", &"...")
537 .field("pool", &self.pool)
538 .finish()
539 }
540}
541
#[cfg(test)]
mod test {
    use super::{MemcachedClient, MemcachedClientConfig, Selector};
    use crate::core::{ErrorKind, Key, Memcached, MtopError, Value};
    use crate::discovery::{Server, ServerID};
    use crate::pool::ClientFactory;
    use async_trait::async_trait;
    use rustls_pki_types::ServerName;
    use std::collections::HashMap;
    use std::io::Cursor;
    use std::time::Duration;
    use tokio::runtime::Handle;

    /// Selector that routes each key to a fixed server from a lookup table.
    #[derive(Debug, Default)]
    struct MockSelector {
        mapping: HashMap<Key, Server>,
    }

    impl Selector for MockSelector {
        fn servers(&self) -> Vec<Server> {
            self.mapping.values().cloned().collect()
        }

        fn server(&self, key: &Key) -> Result<Server, MtopError> {
            self.mapping
                .get(key)
                .cloned()
                .ok_or_else(|| MtopError::runtime("no servers available"))
        }
    }

    /// Factory whose clients read canned server responses from memory and discard writes.
    #[derive(Debug, Default)]
    struct MockClientFactory {
        contents: HashMap<Server, Vec<u8>>,
    }

    #[async_trait]
    impl ClientFactory<Server, Memcached> for MockClientFactory {
        async fn make(&self, key: &Server) -> Result<Memcached, MtopError> {
            let bytes = self
                .contents
                .get(key)
                .cloned()
                .ok_or_else(|| MtopError::runtime(format!("no server for {:?}", key)))?;
            let reads = Cursor::new(bytes);
            Ok(Memcached::new(reads, Vec::new()))
        }
    }

    // Build a `MemcachedClient` backed by the mocks above. With no arguments the client has
    // no servers. Otherwise each `"host:port" => key => response` entry routes `key` to that
    // server, and the server replies with the `response` bytes.
    macro_rules! new_client {
        () => {{
            let cfg = MemcachedClientConfig::default();
            let handle = Handle::current();
            let selector = MockSelector::default();
            let factory = MockClientFactory::default();
            MemcachedClient::new(cfg, handle, selector, factory)
        }};

        ($($host_and_port:expr => $key:expr => $contents:expr$(,)?)*) => {{
            let mut mapping = HashMap::new();
            let mut contents = HashMap::new();

            $(
                let server = {
                    let (host, port_str) = $host_and_port.split_once(':').unwrap();
                    let port: u16 = port_str.parse().unwrap();
                    let id = ServerID::from((host, port));
                    let name = ServerName::try_from(host).unwrap();

                    Server::new(id, name)
                };
                mapping.insert(Key::one($key).unwrap(), server.clone());
                contents.insert(server, $contents.to_vec());
            )*

            let cfg = MemcachedClientConfig::default();
            let handle = Handle::current();
            let selector = MockSelector { mapping };
            let factory = MockClientFactory { contents };
            MemcachedClient::new(cfg, handle, selector, factory)
        }};
    }

    #[tokio::test]
    async fn test_memcached_client_ping_no_servers() {
        let client = new_client!();
        let response = client.ping().await.unwrap();

        assert!(response.values.is_empty());
        assert!(response.errors.is_empty());
    }

    #[tokio::test]
    async fn test_memcached_client_ping_no_errors() {
        let client = new_client!(
            "cache01.example.com:11211" => "unused1" => "VERSION 1.6.22\r\n".as_bytes(),
            "cache02.example.com:11211" => "unused2" => "VERSION 1.6.22\r\n".as_bytes(),
        );
        let response = client.ping().await.unwrap();

        assert!(response.values.contains_key(&ServerID::from(("cache01.example.com", 11211))));
        assert!(response.values.contains_key(&ServerID::from(("cache02.example.com", 11211))));
        assert!(response.errors.is_empty());
    }

    #[tokio::test]
    async fn test_memcached_client_ping_some_errors() {
        let client = new_client!(
            "cache01.example.com:11211" => "unused1" => "VERSION 1.6.22\r\n".as_bytes(),
            "cache02.example.com:11211" => "unused2" => "ERROR Too many open connections\r\n".as_bytes(),
        );
        let response = client.ping().await.unwrap();

        assert!(response.values.contains_key(&ServerID::from(("cache01.example.com", 11211))));
        assert!(response.errors.contains_key(&ServerID::from(("cache02.example.com", 11211))));
    }

    #[tokio::test]
    async fn test_memcached_client_set_no_servers() {
        let client = new_client!();
        let res = client.set("key1", 1, 60, "foo".as_bytes()).await;
        let err = res.unwrap_err();

        assert_eq!(ErrorKind::Runtime, err.kind());
    }

    #[tokio::test]
    async fn test_memcached_client_set_success() {
        let client = new_client!(
            "cache01.example.com:11211" => "key1" => "STORED\r\n".as_bytes(),
        );
        client.set("key1", 1, 60, "foo".as_bytes()).await.unwrap();
    }

    #[tokio::test]
    async fn test_memcached_client_get_invalid_keys() {
        let client = new_client!();
        let res = client.get(vec!["invalid key"]).await;
        let err = res.unwrap_err();

        assert_eq!(ErrorKind::Runtime, err.kind());
    }

    #[tokio::test]
    async fn test_memcached_client_get_no_keys() {
        let client = new_client!();
        let keys: Vec<String> = Vec::new();
        let response = client.get(keys).await.unwrap();

        assert!(response.values.is_empty());
        assert!(response.errors.is_empty());
    }

    #[tokio::test]
    async fn test_memcached_client_get_no_servers() {
        let client = new_client!();
        let res = client.get(vec!["key1", "key2"]).await;
        let err = res.unwrap_err();

        assert_eq!(ErrorKind::Runtime, err.kind());
    }

    #[tokio::test]
    async fn test_memcached_client_get_no_errors() {
        let client = new_client!(
            "cache01.example.com:11211" => "key1" => "VALUE key1 1 6 123\r\nfoobar\r\nEND\r\n".as_bytes(),
            "cache02.example.com:11211" => "key2" => "VALUE key2 2 7 456\r\nbazbing\r\nEND\r\n".as_bytes(),
        );
        let response = client.get(vec!["key1", "key2"]).await.unwrap();

        assert_eq!(
            response.values.get("key1"),
            Some(&Value {
                key: "key1".to_owned(),
                cas: 123,
                flags: 1,
                data: "foobar".as_bytes().to_owned(),
            })
        );
        assert_eq!(
            response.values.get("key2"),
            Some(&Value {
                key: "key2".to_owned(),
                cas: 456,
                flags: 2,
                data: "bazbing".as_bytes().to_owned(),
            })
        );
        // The test name promises no errors, so check that explicitly as well.
        assert!(response.errors.is_empty());
    }

    #[tokio::test]
    async fn test_memcached_client_get_some_errors() {
        let client = new_client!(
            "cache01.example.com:11211" => "key1" => "VALUE key1 1 6 123\r\nfoobar\r\nEND\r\n".as_bytes(),
            "cache02.example.com:11211" => "key2" => "ERROR Too many open connections\r\n".as_bytes(),
        );
        let res = client.get(vec!["key1", "key2"]).await;
        let values = res.unwrap();

        assert_eq!(
            values.values.get("key1"),
            Some(&Value {
                key: "key1".to_owned(),
                cas: 123,
                flags: 1,
                data: "foobar".as_bytes().to_owned(),
            })
        );
        assert_eq!(values.values.get("key2"), None);

        let id = ServerID::from(("cache02.example.com", 11211));
        assert_eq!(values.errors.get(&id).map(|e| e.kind()), Some(ErrorKind::Protocol))
    }

    #[tokio::test]
    async fn test_memcached_client_flush_all_no_wait_success() {
        let client = new_client!(
            "cache01.example.com:11211" => "unused1" => "OK\r\n".as_bytes(),
            "cache02.example.com:11211" => "unused2" => "OK\r\n".as_bytes(),
        );

        let res = client.flush_all(None).await;
        let response = res.unwrap();

        assert!(response.values.contains_key(&ServerID::from(("cache01.example.com", 11211))));
        assert!(response.values.contains_key(&ServerID::from(("cache02.example.com", 11211))));
        assert!(response.errors.is_empty());
    }

    #[tokio::test]
    async fn test_memcached_client_flush_all_no_wait_some_errors() {
        let client = new_client!(
            "cache01.example.com:11211" => "unused1" => "OK\r\n".as_bytes(),
            "cache02.example.com:11211" => "unused2" => "ERROR\r\n".as_bytes(),
        );

        let res = client.flush_all(None).await;
        let response = res.unwrap();

        assert!(response.values.contains_key(&ServerID::from(("cache01.example.com", 11211))));
        assert!(response.errors.contains_key(&ServerID::from(("cache02.example.com", 11211))));
    }

    #[tokio::test]
    async fn test_memcached_client_flush_all_wait_success() {
        let client = new_client!(
            "cache01.example.com:11211" => "unused1" => "OK\r\n".as_bytes(),
            "cache02.example.com:11211" => "unused2" => "OK\r\n".as_bytes(),
        );

        let res = client.flush_all(Some(Duration::from_secs(5))).await;
        let response = res.unwrap();

        assert!(response.values.contains_key(&ServerID::from(("cache01.example.com", 11211))));
        assert!(response.values.contains_key(&ServerID::from(("cache02.example.com", 11211))));
        assert!(response.errors.is_empty());
    }

    #[tokio::test]
    async fn test_memcached_client_flush_all_wait_some_errors() {
        let client = new_client!(
            "cache01.example.com:11211" => "unused1" => "OK\r\n".as_bytes(),
            "cache02.example.com:11211" => "unused2" => "ERROR\r\n".as_bytes(),
        );

        let res = client.flush_all(Some(Duration::from_secs(5))).await;
        let response = res.unwrap();

        assert!(response.values.contains_key(&ServerID::from(("cache01.example.com", 11211))));
        assert!(response.errors.contains_key(&ServerID::from(("cache02.example.com", 11211))));
    }
}