1use crate::core::{ErrorKind, Key, Memcached, Meta, MtopError, SlabItems, Slabs, Stats, Value};
2use crate::discovery::{Server, ServerID};
3use crate::net::{self, TlsConfig};
4use crate::pool::{ClientFactory, ClientPool, ClientPoolConfig, PooledClient};
5use std::collections::HashMap;
6use std::future::Future;
7use std::hash::DefaultHasher;
8use std::hash::{Hash, Hasher};
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::runtime::Handle;
12use tokio::task::JoinHandle;
13use tokio_rustls::rustls::ClientConfig;
14use tokio_rustls::rustls::pki_types::ServerName;
15use tracing::instrument::WithSubscriber;
16
17#[derive(Debug)]
20pub struct TcpClientFactory {
21 client_config: Option<Arc<ClientConfig>>,
22 server_name: Option<ServerName<'static>>,
23}
24
25impl TcpClientFactory {
26 pub async fn new(tls: TlsConfig) -> Result<Self, MtopError> {
27 let server_name = if tls.enabled { tls.server_name.clone() } else { None };
28
29 let client_config = if tls.enabled {
30 Some(Arc::new(net::tls_client_config(tls).await?))
31 } else {
32 None
33 };
34
35 Ok(Self {
36 client_config,
37 server_name,
38 })
39 }
40
41 fn get_server_name(&self, server: &Server) -> ServerName<'static> {
42 self.server_name.clone().unwrap_or_else(|| server.server_name().clone())
43 }
44
45 async fn connect(id: &ServerID) -> Result<Memcached, MtopError> {
46 let (read, write) = match id {
47 ServerID::Socket(sock) => net::tcp_connect(sock).await?,
48 ServerID::Name(name) => net::tcp_connect(name).await?,
49 };
50
51 Ok(Memcached::new(read, write))
52 }
53
54 async fn connect_tls(
55 id: &ServerID,
56 server_name: ServerName<'static>,
57 cfg: Arc<ClientConfig>,
58 ) -> Result<Memcached, MtopError> {
59 let (read, write) = match id {
60 ServerID::Socket(sock) => net::tcp_tls_connect(sock, server_name, cfg.clone()).await?,
61 ServerID::Name(name) => net::tcp_tls_connect(name, server_name, cfg.clone()).await?,
62 };
63
64 Ok(Memcached::new(read, write))
65 }
66}
67
68impl ClientFactory<Server, Memcached> for TcpClientFactory {
69 async fn make(&self, server: &Server) -> Result<Memcached, MtopError> {
70 if let Some(cfg) = &self.client_config {
71 Self::connect_tls(server.id(), self.get_server_name(server), cfg.clone()).await
72 } else {
73 Self::connect(server.id()).await
74 }
75 }
76}
77
78pub trait Selector {
80 fn servers(&self) -> impl Future<Output = Vec<Server>> + Send + Sync;
82
83 fn server(&self, key: &Key) -> impl Future<Output = Result<Server, MtopError>> + Send + Sync;
85}
86
87#[derive(Debug)]
92pub struct RendezvousSelector {
93 servers: Vec<Server>,
94}
95
96impl RendezvousSelector {
97 pub fn new(servers: Vec<Server>) -> Self {
99 Self { servers }
100 }
101
102 fn score(server: &Server, key: &Key) -> u64 {
103 let mut hasher = DefaultHasher::new();
104
105 match server.id() {
107 ServerID::Name(name) => name.hash(&mut hasher),
108 ServerID::Socket(addr) => addr.hash(&mut hasher),
109 }
110
111 hasher.write(key.as_ref().as_bytes());
112 hasher.finish()
113 }
114}
115
116impl Selector for RendezvousSelector {
117 async fn servers(&self) -> Vec<Server> {
118 self.servers.clone()
119 }
120
121 async fn server(&self, key: &Key) -> Result<Server, MtopError> {
122 if self.servers.is_empty() {
123 Err(MtopError::runtime("no servers available"))
124 } else if self.servers.len() == 1 {
125 Ok(self.servers.first().cloned().unwrap())
126 } else {
127 let mut max = u64::MIN;
128 let mut choice = None;
129
130 for server in self.servers.iter() {
131 let score = Self::score(server, key);
132 if score >= max {
133 choice = Some(server);
134 max = score;
135 }
136 }
137
138 Ok(choice.cloned().unwrap())
139 }
140 }
141}
142
143#[derive(Debug, Default)]
145pub struct ServersResponse<T> {
146 pub values: HashMap<ServerID, T>,
147 pub errors: HashMap<ServerID, MtopError>,
148}
149
150#[derive(Debug, Default)]
152pub struct ValuesResponse {
153 pub values: HashMap<String, Value>,
154 pub errors: HashMap<ServerID, MtopError>,
155}
156
/// Spawn a task on `$self.handle` that checks out a pooled connection to `$server` and
/// calls `$method($args...)` on it.
///
/// Evaluates to the spawned task's `JoinHandle`. The connection goes back to the pool
/// after a success or an `ErrorKind::Protocol` error. After any other error it is
/// dropped instead (presumably because its state can no longer be trusted).
macro_rules! spawn_for_host {
    ($self:ident, $method:ident, $server:expr $(, $args:expr)* $(,)?) => {{
        let pool = $self.pool.clone();
        let task = async move {
            let mut conn = pool.get($server).await?;
            let result = conn.$method($($args,)*).await;
            let reusable = match &result {
                Ok(_) => true,
                Err(e) => e.kind() == ErrorKind::Protocol,
            };
            if reusable {
                pool.put(conn).await;
            }
            result
        };
        $self.handle.spawn(task.with_current_subscriber())
    }};
}
183
/// Validate `$key`, pick its server with `$self.selector`, and call
/// `$method(&key, $args...)` on a pooled connection to that server.
///
/// Must be expanded inside a function returning `Result<_, MtopError>`: key validation,
/// server selection and connection checkout errors are propagated with `?`. The
/// connection is returned to the pool only after a success or an `ErrorKind::Protocol`
/// error.
macro_rules! operation_for_key {
    ($self:ident, $method:ident, $key:expr $(, $args:expr)* $(,)?) => {{
        let key = Key::one($key)?;
        let owner = $self.selector.server(&key).await?;
        let mut conn = $self.pool.get(&owner).await?;

        let result = conn.$method(&key, $($args,)*).await;
        let reusable = match &result {
            Ok(_) => true,
            Err(e) => e.kind() == ErrorKind::Protocol,
        };
        if reusable {
            $self.pool.put(conn).await;
        }
        result
    }};
}
208
/// Run the argument-less `$method` against every server from `$self.selector`, one
/// spawned task per server.
///
/// Evaluates to `Ok(ServersResponse<_>)`; per-server failures end up in its `errors` map.
macro_rules! operation_for_all {
    ($self:ident, $method:ident) => {{
        let servers = $self.selector.servers().await;
        let mut tasks = Vec::with_capacity(servers.len());
        for server in servers {
            let id = server.id().clone();
            tasks.push((id, spawn_for_host!($self, $method, &server)));
        }
        Ok(collect_results(tasks).await)
    }};
}
221
222async fn collect_results<T>(tasks: Vec<(ServerID, JoinHandle<Result<T, MtopError>>)>) -> ServersResponse<T> {
224 let mut values = HashMap::with_capacity(tasks.len());
225 let mut errors = HashMap::new();
226
227 for (id, task) in tasks {
228 match task.await {
229 Ok(Ok(result)) => {
230 values.insert(id, result);
231 }
232 Ok(Err(e)) => {
233 errors.insert(id, e);
234 }
235 Err(e) => {
236 errors.insert(id, MtopError::runtime_cause("fetching cluster values", e));
237 }
238 };
239 }
240
241 ServersResponse { values, errors }
242}
243
/// Tunables for a `MemcachedClient`.
#[derive(Debug, Clone)]
pub struct MemcachedClientConfig {
    /// Passed through as the `max_idle` setting of the client's connection pool.
    pub pool_max_idle: u64,
}

impl Default for MemcachedClientConfig {
    /// A configuration with `pool_max_idle` set to 4.
    fn default() -> Self {
        const DEFAULT_POOL_MAX_IDLE: u64 = 4;
        MemcachedClientConfig {
            pool_max_idle: DEFAULT_POOL_MAX_IDLE,
        }
    }
}
255
256#[derive(Debug)]
259pub struct MemcachedClient<S = RendezvousSelector, F = TcpClientFactory>
260where
261 S: Selector + Send + Sync + 'static,
262 F: ClientFactory<Server, Memcached> + Send + Sync + 'static,
263{
264 handle: Handle,
265 selector: S,
266 pool: Arc<ClientPool<Server, Memcached, F>>,
267}
268
269impl<S, F> MemcachedClient<S, F>
270where
271 S: Selector + Send + Sync + 'static,
272 F: ClientFactory<Server, Memcached> + Send + Sync + 'static,
273{
274 pub fn new(config: MemcachedClientConfig, handle: Handle, selector: S, factory: F) -> Self {
281 let pool_config = ClientPoolConfig {
282 name: "memcached-tcp".to_owned(),
283 max_idle: config.pool_max_idle,
284 };
285
286 Self {
287 handle,
288 selector,
289 pool: Arc::new(ClientPool::new(pool_config, factory)),
290 }
291 }
292
293 pub async fn raw_open(&self, server: &Server) -> Result<PooledClient<Server, Memcached>, MtopError> {
296 self.pool.get(server).await
297 }
298
299 pub async fn raw_close(&self, connection: PooledClient<Server, Memcached>) {
303 self.pool.put(connection).await
304 }
305
306 pub async fn stats(&self) -> Result<ServersResponse<Stats>, MtopError> {
312 operation_for_all!(self, stats)
313 }
314
315 pub async fn slabs(&self) -> Result<ServersResponse<Slabs>, MtopError> {
323 operation_for_all!(self, slabs)
324 }
325
326 pub async fn items(&self) -> Result<ServersResponse<SlabItems>, MtopError> {
335 operation_for_all!(self, items)
336 }
337
338 pub async fn metas(&self) -> Result<ServersResponse<Vec<Meta>>, MtopError> {
346 operation_for_all!(self, metas)
347 }
348
349 pub async fn ping(&self) -> Result<ServersResponse<()>, MtopError> {
355 operation_for_all!(self, ping)
356 }
357
358 pub async fn flush_all(&self, wait: Option<Duration>) -> Result<ServersResponse<()>, MtopError> {
366 let servers = self.selector.servers().await;
370
371 let tasks = servers
372 .into_iter()
373 .enumerate()
374 .map(|(i, server)| {
375 let delay = wait.map(|d| d * i as u32);
376 (server.id().clone(), spawn_for_host!(self, flush_all, &server, delay))
377 })
378 .collect::<Vec<_>>();
379
380 Ok(collect_results(tasks).await)
381 }
382
383 pub async fn get<I, K>(&self, keys: I) -> Result<ValuesResponse, MtopError>
391 where
392 I: IntoIterator<Item = K>,
393 K: Into<String>,
394 {
395 let keys = Key::many(keys)?;
396 if keys.is_empty() {
397 return Ok(ValuesResponse::default());
398 }
399
400 let num_keys = keys.len();
401 let mut by_server: HashMap<Server, Vec<Key>> = HashMap::new();
402 for key in keys {
403 let server = self.selector.server(&key).await?;
404 let entry = by_server.entry(server).or_default();
405 entry.push(key);
406 }
407
408 let tasks = by_server
409 .into_iter()
410 .map(|(server, keys)| (server.id().clone(), spawn_for_host!(self, get, &server, &keys)))
411 .collect::<Vec<_>>();
412
413 let mut values = HashMap::with_capacity(num_keys);
414 let mut errors = HashMap::new();
415
416 for (id, task) in tasks {
417 match task.await {
418 Ok(Ok(results)) => {
419 values.extend(results);
420 }
421 Ok(Err(e)) => {
422 errors.insert(id, e);
423 }
424 Err(e) => {
425 errors.insert(id, MtopError::runtime_cause("fetching keys", e));
426 }
427 };
428 }
429
430 Ok(ValuesResponse { values, errors })
431 }
432
433 pub async fn incr<K>(&self, key: K, delta: u64) -> Result<u64, MtopError>
440 where
441 K: Into<String>,
442 {
443 operation_for_key!(self, incr, key, delta)
444 }
445
446 pub async fn decr<K>(&self, key: K, delta: u64) -> Result<u64, MtopError>
454 where
455 K: Into<String>,
456 {
457 operation_for_key!(self, decr, key, delta)
458 }
459
460 pub async fn set<K, V>(&self, key: K, flags: u64, ttl: u32, data: V) -> Result<(), MtopError>
466 where
467 K: Into<String>,
468 V: AsRef<[u8]>,
469 {
470 operation_for_key!(self, set, key, flags, ttl, data)
471 }
472
473 pub async fn add<K, V>(&self, key: K, flags: u64, ttl: u32, data: V) -> Result<(), MtopError>
479 where
480 K: Into<String>,
481 V: AsRef<[u8]>,
482 {
483 operation_for_key!(self, add, key, flags, ttl, data)
484 }
485
486 pub async fn replace<K, V>(&self, key: K, flags: u64, ttl: u32, data: V) -> Result<(), MtopError>
492 where
493 K: Into<String>,
494 V: AsRef<[u8]>,
495 {
496 operation_for_key!(self, replace, key, flags, ttl, data)
497 }
498
499 pub async fn touch<K>(&self, key: K, ttl: u32) -> Result<(), MtopError>
505 where
506 K: Into<String>,
507 {
508 operation_for_key!(self, touch, key, ttl)
509 }
510
511 pub async fn delete<K>(&self, key: K) -> Result<(), MtopError>
517 where
518 K: Into<String>,
519 {
520 operation_for_key!(self, delete, key)
521 }
522}
523
#[cfg(test)]
mod test {
    use super::{MemcachedClient, MemcachedClientConfig, RendezvousSelector, Selector};
    use crate::core::{ErrorKind, Key, Memcached, MtopError, Value};
    use crate::discovery::{Server, ServerID};
    use crate::pool::ClientFactory;
    use rustls_pki_types::ServerName;
    use std::collections::HashMap;
    use std::io::Cursor;
    use std::time::Duration;
    use tokio::runtime::Handle;

    // Selector backed by a fixed key -> server mapping.
    #[derive(Debug, Default)]
    struct MockSelector {
        mapping: HashMap<Key, Server>,
    }

    impl Selector for MockSelector {
        async fn servers(&self) -> Vec<Server> {
            self.mapping.values().cloned().collect()
        }

        async fn server(&self, key: &Key) -> Result<Server, MtopError> {
            self.mapping
                .get(key)
                .cloned()
                .ok_or_else(|| MtopError::runtime("no servers available"))
        }
    }

    // Factory whose connections replay canned server output and discard writes.
    #[derive(Debug, Default)]
    struct MockClientFactory {
        contents: HashMap<Server, Vec<u8>>,
    }

    impl ClientFactory<Server, Memcached> for MockClientFactory {
        async fn make(&self, key: &Server) -> Result<Memcached, MtopError> {
            let bytes = self
                .contents
                .get(key)
                .cloned()
                .ok_or_else(|| MtopError::runtime(format!("no server for {:?}", key)))?;
            let reads = Cursor::new(bytes);
            Ok(Memcached::new(reads, Vec::new()))
        }
    }

    // Build a `Server` for `host:port`, using `host` as its TLS server name.
    fn new_server(host: &'static str, port: u16) -> Server {
        let id = ServerID::from((host, port));
        let name = ServerName::try_from(host).unwrap();
        Server::new(id, name)
    }

    macro_rules! new_client {
        () => {{
            let cfg = MemcachedClientConfig { pool_max_idle: 1 };
            let handle = Handle::current();
            let selector = MockSelector::default();
            let factory = MockClientFactory::default();
            MemcachedClient::new(cfg, handle, selector, factory)
        }};

        ($($host_and_port:expr => $key:expr => $contents:expr$(,)?)*) => {{
            let mut mapping = HashMap::new();
            let mut contents = HashMap::new();

            $(
                let server = {
                    let (host, port_str) = $host_and_port.split_once(':').unwrap();
                    let port: u16 = port_str.parse().unwrap();
                    new_server(host, port)
                };
                mapping.insert(Key::one($key).unwrap(), server.clone());
                contents.insert(server, $contents.to_vec());
            )*

            let cfg = MemcachedClientConfig { pool_max_idle: 1 };
            let handle = Handle::current();
            let selector = MockSelector { mapping };
            let factory = MockClientFactory { contents };
            MemcachedClient::new(cfg, handle, selector, factory)
        }};
    }

    #[tokio::test]
    async fn test_rendezvous_selector_no_servers() {
        let selector = RendezvousSelector::new(Vec::new());
        let key = Key::one("key1").unwrap();
        let err = selector.server(&key).await.unwrap_err();

        assert_eq!(ErrorKind::Runtime, err.kind());
    }

    #[tokio::test]
    async fn test_rendezvous_selector_single_server() {
        let server = new_server("cache01.example.com", 11211);
        let selector = RendezvousSelector::new(vec![server.clone()]);
        let key = Key::one("key1").unwrap();

        assert_eq!(server, selector.server(&key).await.unwrap());
    }

    #[tokio::test]
    async fn test_rendezvous_selector_stable_choice() {
        let servers = vec![
            new_server("cache01.example.com", 11211),
            new_server("cache02.example.com", 11211),
            new_server("cache03.example.com", 11211),
        ];
        let key = Key::one("key1").unwrap();

        let selector = RendezvousSelector::new(servers.clone());
        let first = selector.server(&key).await.unwrap();
        let second = selector.server(&key).await.unwrap();

        assert_eq!(first, second);
        assert!(servers.contains(&first));

        // Placement depends only on server IDs and the key, not on list order.
        let mut reversed = servers.clone();
        reversed.reverse();
        let from_reversed = RendezvousSelector::new(reversed).server(&key).await.unwrap();
        assert_eq!(first, from_reversed);
    }

    #[tokio::test]
    async fn test_memcached_client_ping_no_servers() {
        let client = new_client!();
        let response = client.ping().await.unwrap();

        assert!(response.values.is_empty());
        assert!(response.errors.is_empty());
    }

    #[tokio::test]
    async fn test_memcached_client_ping_no_errors() {
        let client = new_client!(
            "cache01.example.com:11211" => "unused1" => "VERSION 1.6.22\r\n".as_bytes(),
            "cache02.example.com:11211" => "unused2" => "VERSION 1.6.22\r\n".as_bytes(),
        );
        let response = client.ping().await.unwrap();

        assert!(response.values.contains_key(&ServerID::from(("cache01.example.com", 11211))));
        assert!(response.values.contains_key(&ServerID::from(("cache02.example.com", 11211))));
        assert!(response.errors.is_empty());
    }

    #[tokio::test]
    async fn test_memcached_client_ping_some_errors() {
        let client = new_client!(
            "cache01.example.com:11211" => "unused1" => "VERSION 1.6.22\r\n".as_bytes(),
            "cache02.example.com:11211" => "unused2" => "ERROR Too many open connections\r\n".as_bytes(),
        );
        let response = client.ping().await.unwrap();

        assert!(response.values.contains_key(&ServerID::from(("cache01.example.com", 11211))));
        assert!(response.errors.contains_key(&ServerID::from(("cache02.example.com", 11211))));
    }

    #[tokio::test]
    async fn test_memcached_client_set_no_servers() {
        let client = new_client!();
        let res = client.set("key1", 1, 60, "foo".as_bytes()).await;
        let err = res.unwrap_err();

        assert_eq!(ErrorKind::Runtime, err.kind());
    }

    #[tokio::test]
    async fn test_memcached_client_set_success() {
        let client = new_client!(
            "cache01.example.com:11211" => "key1" => "STORED\r\n".as_bytes(),
        );
        client.set("key1", 1, 60, "foo".as_bytes()).await.unwrap();
    }

    #[tokio::test]
    async fn test_memcached_client_get_invalid_keys() {
        let client = new_client!();
        let res = client.get(vec!["invalid key"]).await;
        let err = res.unwrap_err();

        assert_eq!(ErrorKind::Runtime, err.kind());
    }

    #[tokio::test]
    async fn test_memcached_client_get_no_keys() {
        let client = new_client!();
        let keys: Vec<String> = Vec::new();
        let response = client.get(keys).await.unwrap();

        assert!(response.values.is_empty());
        assert!(response.errors.is_empty());
    }

    #[tokio::test]
    async fn test_memcached_client_get_no_servers() {
        let client = new_client!();
        let res = client.get(vec!["key1", "key2"]).await;
        let err = res.unwrap_err();

        assert_eq!(ErrorKind::Runtime, err.kind());
    }

    #[tokio::test]
    async fn test_memcached_client_get_no_errors() {
        let client = new_client!(
            "cache01.example.com:11211" => "key1" => "VALUE key1 1 6 123\r\nfoobar\r\nEND\r\n".as_bytes(),
            "cache02.example.com:11211" => "key2" => "VALUE key2 2 7 456\r\nbazbing\r\nEND\r\n".as_bytes(),
        );
        let response = client.get(vec!["key1", "key2"]).await.unwrap();

        assert_eq!(
            response.values.get("key1"),
            Some(&Value {
                key: "key1".to_owned(),
                cas: 123,
                flags: 1,
                data: "foobar".as_bytes().to_owned(),
            })
        );
        assert_eq!(
            response.values.get("key2"),
            Some(&Value {
                key: "key2".to_owned(),
                cas: 456,
                flags: 2,
                data: "bazbing".as_bytes().to_owned(),
            })
        );
        assert!(response.errors.is_empty());
    }

    #[tokio::test]
    async fn test_memcached_client_get_some_errors() {
        let client = new_client!(
            "cache01.example.com:11211" => "key1" => "VALUE key1 1 6 123\r\nfoobar\r\nEND\r\n".as_bytes(),
            "cache02.example.com:11211" => "key2" => "ERROR Too many open connections\r\n".as_bytes(),
        );
        let res = client.get(vec!["key1", "key2"]).await;
        let values = res.unwrap();

        assert_eq!(
            values.values.get("key1"),
            Some(&Value {
                key: "key1".to_owned(),
                cas: 123,
                flags: 1,
                data: "foobar".as_bytes().to_owned(),
            })
        );
        assert_eq!(values.values.get("key2"), None);

        let id = ServerID::from(("cache02.example.com", 11211));
        assert_eq!(values.errors.get(&id).map(|e| e.kind()), Some(ErrorKind::Protocol))
    }

    #[tokio::test]
    async fn test_memcached_client_flush_all_no_wait_success() {
        let client = new_client!(
            "cache01.example.com:11211" => "unused1" => "OK\r\n".as_bytes(),
            "cache02.example.com:11211" => "unused2" => "OK\r\n".as_bytes(),
        );

        let res = client.flush_all(None).await;
        let response = res.unwrap();

        assert!(response.values.contains_key(&ServerID::from(("cache01.example.com", 11211))));
        assert!(response.values.contains_key(&ServerID::from(("cache02.example.com", 11211))));
        assert!(response.errors.is_empty());
    }

    #[tokio::test]
    async fn test_memcached_client_flush_all_no_wait_some_errors() {
        let client = new_client!(
            "cache01.example.com:11211" => "unused1" => "OK\r\n".as_bytes(),
            "cache02.example.com:11211" => "unused2" => "ERROR\r\n".as_bytes(),
        );

        let res = client.flush_all(None).await;
        let response = res.unwrap();

        assert!(response.values.contains_key(&ServerID::from(("cache01.example.com", 11211))));
        assert!(response.errors.contains_key(&ServerID::from(("cache02.example.com", 11211))));
    }

    #[tokio::test]
    async fn test_memcached_client_flush_all_wait_success() {
        let client = new_client!(
            "cache01.example.com:11211" => "unused1" => "OK\r\n".as_bytes(),
            "cache02.example.com:11211" => "unused2" => "OK\r\n".as_bytes(),
        );

        let res = client.flush_all(Some(Duration::from_secs(5))).await;
        let response = res.unwrap();

        assert!(response.values.contains_key(&ServerID::from(("cache01.example.com", 11211))));
        assert!(response.values.contains_key(&ServerID::from(("cache02.example.com", 11211))));
        assert!(response.errors.is_empty());
    }

    #[tokio::test]
    async fn test_memcached_client_flush_all_wait_some_errors() {
        let client = new_client!(
            "cache01.example.com:11211" => "unused1" => "OK\r\n".as_bytes(),
            "cache02.example.com:11211" => "unused2" => "ERROR\r\n".as_bytes(),
        );

        let res = client.flush_all(Some(Duration::from_secs(5))).await;
        let response = res.unwrap();

        assert!(response.values.contains_key(&ServerID::from(("cache01.example.com", 11211))));
        assert!(response.errors.contains_key(&ServerID::from(("cache02.example.com", 11211))));
    }
}