use crate::core::{ErrorKind, Key, Memcached, Meta, MtopError, SlabItems, Slabs, Stats, Value};
use crate::discovery::{Server, ServerID};
use crate::net::{self, TlsConfig};
use crate::pool::{ClientFactory, ClientPool, ClientPoolConfig, PooledClient};
use async_trait::async_trait;
use std::collections::HashMap;
use std::fmt;
use std::hash::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use std::time::Duration;
use tokio::runtime::Handle;
use tokio::task::JoinHandle;
use tokio_rustls::rustls::ClientConfig;
use tokio_rustls::rustls::pki_types::ServerName;
use tracing::instrument::WithSubscriber;
/// `ClientFactory` implementation that opens TLS connections over TCP to
/// Memcached servers.
#[derive(Debug)]
pub struct TlsTcpClientFactory {
    // Shared rustls client configuration reused for every connection.
    client_config: Arc<ClientConfig>,
    // Optional name used for certificate validation on every connection.
    // When unset, each `Server` must supply its own name (see `make`).
    server_name: Option<ServerName<'static>>,
}
impl TlsTcpClientFactory {
    /// Create a new factory from the given TLS configuration.
    ///
    /// The configuration's server name, if any, is kept and used for
    /// certificate validation on every connection this factory opens.
    pub async fn new(tls: TlsConfig) -> Result<Self, MtopError> {
        let name = tls.server_name.clone();
        let config = net::tls_client_config(tls).await?;
        Ok(Self {
            client_config: Arc::new(config),
            server_name: name,
        })
    }
}
#[async_trait]
impl ClientFactory<Server, Memcached> for TlsTcpClientFactory {
    /// Open a new TLS connection to the given server.
    ///
    /// The name used for certificate validation is the factory-wide name
    /// when one was configured, otherwise the per-server name. Panics if
    /// neither is set or if the server is not TCP addressable, since both
    /// indicate a caller bug.
    async fn make(&self, key: &Server) -> Result<Memcached, MtopError> {
        let server_name = match self.server_name.clone() {
            Some(name) => name,
            None => key
                .server_name()
                .clone()
                .expect("TLS server name must be set on each server when using TlsTcpClientFactory: this is a bug"),
        };

        let config = self.client_config.clone();
        let (read, write) = match key.id() {
            ServerID::Socket(sock) => net::tcp_tls_connect(sock, server_name, config).await?,
            ServerID::Name(name) => net::tcp_tls_connect(name, server_name, config).await?,
            id => panic!("unexpected {:?} passed to TlsTcpClientFactory: this is a bug", id),
        };

        Ok(Memcached::new(read, write))
    }
}
/// `ClientFactory` implementation that opens plaintext TCP connections to
/// Memcached servers.
#[derive(Debug)]
pub struct TcpClientFactory;
#[async_trait]
impl ClientFactory<Server, Memcached> for TcpClientFactory {
    /// Open a new plaintext TCP connection to the given server.
    ///
    /// Panics if the server is not TCP addressable (a socket address or a
    /// name), since that indicates a caller bug.
    async fn make(&self, key: &Server) -> Result<Memcached, MtopError> {
        match key.id() {
            ServerID::Socket(sock) => {
                let (read, write) = net::tcp_connect(sock).await?;
                Ok(Memcached::new(read, write))
            }
            ServerID::Name(name) => {
                let (read, write) = net::tcp_connect(name).await?;
                Ok(Memcached::new(read, write))
            }
            id => panic!("unexpected {:?} passed to TcpClientFactory: this is a bug", id),
        }
    }
}
/// `ClientFactory` implementation that connects to Memcached servers over
/// UNIX domain sockets. Only compiled on UNIX platforms.
#[cfg(unix)]
#[derive(Debug)]
pub struct UnixClientFactory;
#[cfg(unix)]
#[async_trait]
impl ClientFactory<Server, Memcached> for UnixClientFactory {
    /// Open a new UNIX domain socket connection to the given server.
    ///
    /// Panics if the server is not identified by a filesystem path, since
    /// that indicates a caller bug.
    async fn make(&self, key: &Server) -> Result<Memcached, MtopError> {
        match key.id() {
            ServerID::Path(path) => {
                let (read, write) = net::unix_connect(path).await?;
                Ok(Memcached::new(read, write))
            }
            id => panic!("unexpected {:?} passed to UnixClientFactory: this is a bug", id),
        }
    }
}
/// Strategy for deciding which server in a cluster owns a particular key.
pub trait Selector {
    /// All servers currently known to this selector.
    fn servers(&self) -> Vec<Server>;

    /// The server responsible for `key`, or an error if no server is
    /// available for it.
    fn server(&self, key: &Key) -> Result<Server, MtopError>;
}
/// `Selector` that picks the owning server for each key via rendezvous
/// (highest-random-weight) hashing: every server is scored against the key
/// and the highest score wins.
#[derive(Debug)]
pub struct RendezvousSelector {
    servers: Vec<Server>,
}
impl RendezvousSelector {
    /// Create a selector that distributes keys across the given servers.
    pub fn new(servers: Vec<Server>) -> Self {
        Self { servers }
    }

    /// Score a server for a key by hashing the server's identity followed by
    /// the raw key bytes. For a fixed key, the server with the highest score
    /// owns the key.
    fn score(server: &Server, key: &Key) -> u64 {
        let mut state = DefaultHasher::new();
        match server.id() {
            ServerID::Name(name) => name.hash(&mut state),
            ServerID::Socket(addr) => addr.hash(&mut state),
            ServerID::Path(path) => path.hash(&mut state),
        }
        // Feed the raw bytes directly (not the slice's `Hash` impl) so the
        // score is a plain continuation of the server-identity hash.
        state.write(key.as_ref().as_bytes());
        state.finish()
    }
}
impl Selector for RendezvousSelector {
    fn servers(&self) -> Vec<Server> {
        self.servers.clone()
    }

    /// Pick the owning server for `key` via highest-random-weight hashing.
    ///
    /// `max_by_key` returns the last of several equally-maximal elements,
    /// matching the original `score >= max` scan, so tie-breaking behavior
    /// is preserved. Returns an error when there are no servers at all.
    fn server(&self, key: &Key) -> Result<Server, MtopError> {
        self.servers
            .iter()
            .max_by_key(|server| Self::score(server, key))
            .cloned()
            .ok_or_else(|| MtopError::runtime("no servers available"))
    }
}
/// Outcome of running a command against every server in a cluster:
/// successful results and failures, each keyed by the server they came from.
#[derive(Debug, Default)]
pub struct ServersResponse<T> {
    pub values: HashMap<ServerID, T>,
    pub errors: HashMap<ServerID, MtopError>,
}
/// Outcome of fetching keys spread across multiple servers: values merged
/// by key, with any per-server failures recorded separately.
#[derive(Debug, Default)]
pub struct ValuesResponse {
    pub values: HashMap<String, Value>,
    pub errors: HashMap<ServerID, MtopError>,
}
/// Check a connection for `$server` out of `$pool`, invoke `$method` on it
/// with the given arguments, and return the result.
///
/// The connection goes back to the pool on success, and on `Protocol`
/// errors (the server responded, so the connection is still usable). On any
/// other error the connection is dropped instead, since its stream may be
/// in an inconsistent state.
macro_rules! run_for_host {
    ($pool:expr, $server:expr, $method:ident $(, $args:expr)* $(,)?) => {{
        let mut conn = $pool.get($server).await?;
        match conn.$method($($args,)*).await {
            Ok(v) => {
                $pool.put(conn).await;
                Ok(v)
            }
            Err(e) => {
                // A protocol-level error means the connection itself is fine.
                if e.kind() == ErrorKind::Protocol {
                    $pool.put(conn).await;
                }
                Err(e)
            }
        }
    }};
}
/// Spawn `run_for_host!` for `$server` as a task on the client's runtime
/// handle, propagating the current tracing subscriber into the task so
/// spans/events emitted there are not lost.
macro_rules! spawn_for_host {
    ($self:ident, $server:expr, $method:ident $(, $args:expr)* $(,)?) => {{
        // Clone the pool so the spawned task owns its handle to it.
        let pool = $self.pool.clone();
        $self.handle.spawn(async move {
            run_for_host!(pool, $server, $method, $($args,)*)
        }
        .with_current_subscriber())
    }};
}
/// Validate `$key`, route it to its owning server via the client's
/// selector, and run `$method` against that server with the validated key
/// as the first argument. Fails before any I/O if the key is invalid or
/// has no owning server.
macro_rules! operation_for_key {
    ($self:ident, $method:ident, $key:expr $(, $args:expr)* $(,)?) => {{
        let key = Key::one($key)?;
        let server = $self.selector.server(&key)?;
        run_for_host!($self.pool, &server, $method, &key, $($args,)*)
    }};
}
/// Run `$method` concurrently against every server known to the selector,
/// spawning one task per server and gathering per-server results and
/// errors into a `ServersResponse`.
macro_rules! operation_for_all {
    ($self:ident, $method:ident) => {{
        let servers = $self.selector.servers();
        let tasks = servers
            .into_iter()
            .map(|server| (server.id().clone(), spawn_for_host!($self, &server, $method)))
            .collect::<Vec<_>>();
        Ok(collect_results(tasks).await)
    }};
}
/// Await every per-server task and partition the outcomes by server:
/// successes into `values`, failures into `errors`. Join failures (a task
/// panicked or was cancelled) are converted into runtime errors for the
/// corresponding server.
async fn collect_results<T>(tasks: Vec<(ServerID, JoinHandle<Result<T, MtopError>>)>) -> ServersResponse<T> {
    let mut values = HashMap::with_capacity(tasks.len());
    let mut errors = HashMap::new();

    for (id, task) in tasks {
        // Flatten the join result first, then split success from failure.
        let outcome = match task.await {
            Ok(inner) => inner,
            Err(e) => Err(MtopError::runtime_cause("fetching cluster values", e)),
        };
        match outcome {
            Ok(v) => {
                values.insert(id, v);
            }
            Err(e) => {
                errors.insert(id, e);
            }
        }
    }

    ServersResponse { values, errors }
}
/// Configuration for a `MemcachedClient` and its underlying connection pool.
#[derive(Debug, Clone)]
pub struct MemcachedClientConfig {
    // Maximum number of idle connections retained by the pool — presumably
    // per server key; confirm against `ClientPool` semantics.
    pub pool_max_idle: u64,
    // Name assigned to the connection pool.
    pub pool_name: String,
}
impl Default for MemcachedClientConfig {
fn default() -> Self {
Self {
pool_max_idle: 4,
pool_name: "memcached-tcp".to_owned(),
}
}
}
/// Memcached client for a cluster of servers: each key is routed to its
/// owning server by the `Selector`, and connections are reused via a
/// shared `ClientPool`.
pub struct MemcachedClient {
    // Runtime handle used to spawn per-server tasks when fanning out.
    handle: Handle,
    // Decides which server owns each key.
    selector: Box<dyn Selector + Send + Sync>,
    // Shared connection pool; cloned into spawned tasks.
    pool: Arc<ClientPool<Server, Memcached>>,
}
impl MemcachedClient {
    /// Create a new `MemcachedClient`.
    ///
    /// `handle` is the Tokio runtime handle used to spawn tasks that fan
    /// commands out to multiple servers concurrently. `selector` decides
    /// which server owns each key, and `factory` creates fresh connections
    /// when the pool has no idle one for a server.
    pub fn new<S, F>(config: MemcachedClientConfig, handle: Handle, selector: S, factory: F) -> Self
    where
        S: Selector + Send + Sync + 'static,
        F: ClientFactory<Server, Memcached> + Send + Sync + 'static,
    {
        let pool_config = ClientPoolConfig {
            name: config.pool_name,
            max_idle: config.pool_max_idle,
        };
        Self {
            handle,
            selector: Box::new(selector),
            pool: Arc::new(ClientPool::new(pool_config, factory)),
        }
    }

    /// Check a connection to `server` out of the pool. Callers should
    /// return it with `raw_close` when finished so it can be reused.
    pub async fn raw_open(&self, server: &Server) -> Result<PooledClient<Server, Memcached>, MtopError> {
        self.pool.get(server).await
    }

    /// Return a connection previously obtained from `raw_open` to the pool.
    pub async fn raw_close(&self, connection: PooledClient<Server, Memcached>) {
        self.pool.put(connection).await
    }

    /// Run the `stats` command on every server concurrently.
    pub async fn stats(&self) -> Result<ServersResponse<Stats>, MtopError> {
        operation_for_all!(self, stats)
    }

    /// Run the `slabs` command on every server concurrently.
    pub async fn slabs(&self) -> Result<ServersResponse<Slabs>, MtopError> {
        operation_for_all!(self, slabs)
    }

    /// Run the `items` command on every server concurrently.
    pub async fn items(&self) -> Result<ServersResponse<SlabItems>, MtopError> {
        operation_for_all!(self, items)
    }

    /// Fetch item metadata from every server concurrently.
    pub async fn metas(&self) -> Result<ServersResponse<Vec<Meta>>, MtopError> {
        operation_for_all!(self, metas)
    }

    /// Ping every server concurrently.
    pub async fn ping(&self) -> Result<ServersResponse<()>, MtopError> {
        operation_for_all!(self, ping)
    }

    /// Run `flush_all` on every server concurrently. When `wait` is set,
    /// the i-th server's flush is delayed by `i * wait` so the servers do
    /// not all flush at the same instant (the first server gets no delay).
    pub async fn flush_all(&self, wait: Option<Duration>) -> Result<ServersResponse<()>, MtopError> {
        let servers = self.selector.servers();
        let tasks = servers
            .into_iter()
            .enumerate()
            .map(|(i, server)| {
                // Stagger the delay by the server's position in the list.
                let delay = wait.map(|d| d * i as u32);
                (server.id().clone(), spawn_for_host!(self, &server, flush_all, delay))
            })
            .collect::<Vec<_>>();
        Ok(collect_results(tasks).await)
    }

    /// Fetch the given keys, grouping them by owning server and issuing one
    /// concurrent batched `get` per server. Values from all servers are
    /// merged by key; per-server failures land in `errors`.
    ///
    /// Returns an error immediately — before any I/O — if any key is
    /// invalid or has no owning server.
    pub async fn get<I, K>(&self, keys: I) -> Result<ValuesResponse, MtopError>
    where
        I: IntoIterator<Item = K>,
        K: Into<String>,
    {
        let keys = Key::many(keys)?;
        if keys.is_empty() {
            return Ok(ValuesResponse::default());
        }

        let num_keys = keys.len();
        // Group keys by owning server so each server gets a single request.
        let mut by_server: HashMap<Server, Vec<Key>> = HashMap::new();
        for key in keys {
            let server = self.selector.server(&key)?;
            let entry = by_server.entry(server).or_default();
            entry.push(key);
        }

        let tasks = by_server
            .into_iter()
            .map(|(server, keys)| (server.id().clone(), spawn_for_host!(self, &server, get, &keys)))
            .collect::<Vec<_>>();

        let mut values = HashMap::with_capacity(num_keys);
        let mut errors = HashMap::new();
        for (id, task) in tasks {
            match task.await {
                Ok(Ok(results)) => {
                    values.extend(results);
                }
                Ok(Err(e)) => {
                    errors.insert(id, e);
                }
                Err(e) => {
                    // The spawned task itself failed (panic or cancellation).
                    errors.insert(id, MtopError::runtime_cause("fetching keys", e));
                }
            };
        }

        Ok(ValuesResponse { values, errors })
    }

    /// Run the `incr` command for `key` with the given delta against its
    /// owning server, returning the server's numeric response.
    pub async fn incr<K>(&self, key: K, delta: u64) -> Result<u64, MtopError>
    where
        K: Into<String>,
    {
        operation_for_key!(self, incr, key, delta)
    }

    /// Run the `decr` command for `key` with the given delta against its
    /// owning server, returning the server's numeric response.
    pub async fn decr<K>(&self, key: K, delta: u64) -> Result<u64, MtopError>
    where
        K: Into<String>,
    {
        operation_for_key!(self, decr, key, delta)
    }

    /// Run the `set` command for `key` against its owning server, storing
    /// `data` with the given flags and TTL.
    pub async fn set<K, V>(&self, key: K, flags: u64, ttl: u32, data: V) -> Result<(), MtopError>
    where
        K: Into<String>,
        V: AsRef<[u8]>,
    {
        operation_for_key!(self, set, key, flags, ttl, data)
    }

    /// Run the `add` command for `key` against its owning server, storing
    /// `data` with the given flags and TTL.
    pub async fn add<K, V>(&self, key: K, flags: u64, ttl: u32, data: V) -> Result<(), MtopError>
    where
        K: Into<String>,
        V: AsRef<[u8]>,
    {
        operation_for_key!(self, add, key, flags, ttl, data)
    }

    /// Run the `replace` command for `key` against its owning server,
    /// storing `data` with the given flags and TTL.
    pub async fn replace<K, V>(&self, key: K, flags: u64, ttl: u32, data: V) -> Result<(), MtopError>
    where
        K: Into<String>,
        V: AsRef<[u8]>,
    {
        operation_for_key!(self, replace, key, flags, ttl, data)
    }

    /// Run the `touch` command for `key` against its owning server,
    /// updating the entry's TTL.
    pub async fn touch<K>(&self, key: K, ttl: u32) -> Result<(), MtopError>
    where
        K: Into<String>,
    {
        operation_for_key!(self, touch, key, ttl)
    }

    /// Run the `delete` command for `key` against its owning server.
    pub async fn delete<K>(&self, key: K) -> Result<(), MtopError>
    where
        K: Into<String>,
    {
        operation_for_key!(self, delete, key)
    }
}
impl fmt::Debug for MemcachedClient {
    // Hand-written because `selector` is a boxed trait object without a
    // `Debug` bound; it is rendered as a placeholder.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("MemcachedClient");
        out.field("selector", &"...");
        out.field("pool", &self.pool);
        out.finish()
    }
}
#[cfg(test)]
mod test {
    use super::{MemcachedClient, MemcachedClientConfig, Selector};
    use crate::core::{ErrorKind, Key, Memcached, MtopError, Value};
    use crate::discovery::{Server, ServerID};
    use crate::pool::ClientFactory;
    use async_trait::async_trait;
    use rustls_pki_types::ServerName;
    use std::collections::HashMap;
    use std::io::Cursor;
    use std::time::Duration;
    use tokio::runtime::Handle;

    /// Test selector with an explicit key -> server routing table in place
    /// of rendezvous hashing.
    #[derive(Debug, Default)]
    struct MockSelector {
        mapping: HashMap<Key, Server>,
    }

    impl Selector for MockSelector {
        fn servers(&self) -> Vec<Server> {
            self.mapping.values().cloned().collect()
        }

        // Keys missing from the table behave like a cluster with no servers.
        fn server(&self, key: &Key) -> Result<Server, MtopError> {
            self.mapping
                .get(key)
                .cloned()
                .ok_or_else(|| MtopError::runtime("no servers available"))
        }
    }

    /// Test factory producing clients that read canned per-server response
    /// bytes and write into a discarded buffer — no real network I/O.
    #[derive(Debug, Default)]
    struct MockClientFactory {
        contents: HashMap<Server, Vec<u8>>,
    }

    #[async_trait]
    impl ClientFactory<Server, Memcached> for MockClientFactory {
        async fn make(&self, key: &Server) -> Result<Memcached, MtopError> {
            let bytes = self
                .contents
                .get(key)
                .cloned()
                .ok_or_else(|| MtopError::runtime(format!("no server for {:?}", key)))?;
            let reads = Cursor::new(bytes);
            // Reads come from the canned bytes; writes go to a Vec and are ignored.
            Ok(Memcached::new(reads, Vec::new()))
        }
    }

    /// Build a `MemcachedClient` backed by the mocks above. The no-argument
    /// form creates a client with no servers; otherwise each
    /// `"host:port" => key => response-bytes` entry routes `key` to that
    /// server and serves the canned response bytes for it.
    macro_rules! new_client {
        () => {{
            let cfg = MemcachedClientConfig::default();
            let handle = Handle::current();
            let selector = MockSelector::default();
            let factory = MockClientFactory::default();
            MemcachedClient::new(cfg, handle, selector, factory)
        }};
        ($($host_and_port:expr => $key:expr => $contents:expr$(,)?)*) => {{
            let mut mapping = HashMap::new();
            let mut contents = HashMap::new();
            $(
                let server = {
                    let (host, port_str) = $host_and_port.split_once(':').unwrap();
                    let port: u16 = port_str.parse().unwrap();
                    let id = ServerID::from((host, port));
                    let name = ServerName::try_from(host).unwrap();
                    Server::new(id, name)
                };
                mapping.insert(Key::one($key).unwrap(), server.clone());
                contents.insert(server, $contents.to_vec());
            )*
            let cfg = MemcachedClientConfig::default();
            let handle = Handle::current();
            let selector = MockSelector { mapping };
            let factory = MockClientFactory { contents };
            MemcachedClient::new(cfg, handle, selector, factory)
        }};
    }

    #[tokio::test]
    async fn test_memcached_client_ping_no_servers() {
        let client = new_client!();
        let response = client.ping().await.unwrap();
        assert!(response.values.is_empty());
        assert!(response.errors.is_empty());
    }

    #[tokio::test]
    async fn test_memcached_client_ping_no_errors() {
        let client = new_client!(
            "cache01.example.com:11211" => "unused1" => "VERSION 1.6.22\r\n".as_bytes(),
            "cache02.example.com:11211" => "unused2" => "VERSION 1.6.22\r\n".as_bytes(),
        );
        let response = client.ping().await.unwrap();
        assert!(response.values.contains_key(&ServerID::from(("cache01.example.com", 11211))));
        assert!(response.values.contains_key(&ServerID::from(("cache02.example.com", 11211))));
        assert!(response.errors.is_empty());
    }

    #[tokio::test]
    async fn test_memcached_client_ping_some_errors() {
        // One server responds normally, the other with a protocol error;
        // both outcomes should be reported per server.
        let client = new_client!(
            "cache01.example.com:11211" => "unused1" => "VERSION 1.6.22\r\n".as_bytes(),
            "cache02.example.com:11211" => "unused2" => "ERROR Too many open connections\r\n".as_bytes(),
        );
        let response = client.ping().await.unwrap();
        assert!(response.values.contains_key(&ServerID::from(("cache01.example.com", 11211))));
        assert!(response.errors.contains_key(&ServerID::from(("cache02.example.com", 11211))));
    }

    #[tokio::test]
    async fn test_memcached_client_set_no_servers() {
        let client = new_client!();
        let res = client.set("key1", 1, 60, "foo".as_bytes()).await;
        let err = res.unwrap_err();
        assert_eq!(ErrorKind::Runtime, err.kind());
    }

    #[tokio::test]
    async fn test_memcached_client_set_success() {
        let client = new_client!(
            "cache01.example.com:11211" => "key1" => "STORED\r\n".as_bytes(),
        );
        client.set("key1", 1, 60, "foo".as_bytes()).await.unwrap();
    }

    #[tokio::test]
    async fn test_memcached_client_get_invalid_keys() {
        // Keys containing spaces are rejected before any server is contacted.
        let client = new_client!();
        let res = client.get(vec!["invalid key"]).await;
        let err = res.unwrap_err();
        assert_eq!(ErrorKind::Runtime, err.kind());
    }

    #[tokio::test]
    async fn test_memcached_client_get_no_keys() {
        let client = new_client!();
        let keys: Vec<String> = Vec::new();
        let response = client.get(keys).await.unwrap();
        assert!(response.values.is_empty());
        assert!(response.errors.is_empty());
    }

    #[tokio::test]
    async fn test_memcached_client_get_no_servers() {
        let client = new_client!();
        let res = client.get(vec!["key1", "key2"]).await;
        let err = res.unwrap_err();
        assert_eq!(ErrorKind::Runtime, err.kind());
    }

    #[tokio::test]
    async fn test_memcached_client_get_no_errors() {
        let client = new_client!(
            "cache01.example.com:11211" => "key1" => "VALUE key1 1 6 123\r\nfoobar\r\nEND\r\n".as_bytes(),
            "cache02.example.com:11211" => "key2" => "VALUE key2 2 7 456\r\nbazbing\r\nEND\r\n".as_bytes(),
        );
        let response = client.get(vec!["key1", "key2"]).await.unwrap();
        assert_eq!(
            response.values.get("key1"),
            Some(&Value {
                key: "key1".to_owned(),
                cas: 123,
                flags: 1,
                data: "foobar".as_bytes().to_owned(),
            })
        );
        assert_eq!(
            response.values.get("key2"),
            Some(&Value {
                key: "key2".to_owned(),
                cas: 456,
                flags: 2,
                data: "bazbing".as_bytes().to_owned(),
            })
        );
    }

    #[tokio::test]
    async fn test_memcached_client_get_some_errors() {
        // A failure on one server must not lose results from the other.
        let client = new_client!(
            "cache01.example.com:11211" => "key1" => "VALUE key1 1 6 123\r\nfoobar\r\nEND\r\n".as_bytes(),
            "cache02.example.com:11211" => "key2" => "ERROR Too many open connections\r\n".as_bytes(),
        );
        let res = client.get(vec!["key1", "key2"]).await;
        let values = res.unwrap();
        assert_eq!(
            values.values.get("key1"),
            Some(&Value {
                key: "key1".to_owned(),
                cas: 123,
                flags: 1,
                data: "foobar".as_bytes().to_owned(),
            })
        );
        assert_eq!(values.values.get("key2"), None);
        let id = ServerID::from(("cache02.example.com", 11211));
        assert_eq!(values.errors.get(&id).map(|e| e.kind()), Some(ErrorKind::Protocol))
    }

    #[tokio::test]
    async fn test_memcached_client_flush_all_no_wait_success() {
        let client = new_client!(
            "cache01.example.com:11211" => "unused1" => "OK\r\n".as_bytes(),
            "cache02.example.com:11211" => "unused2" => "OK\r\n".as_bytes(),
        );
        let res = client.flush_all(None).await;
        let response = res.unwrap();
        assert!(response.values.contains_key(&ServerID::from(("cache01.example.com", 11211))));
        assert!(response.values.contains_key(&ServerID::from(("cache02.example.com", 11211))));
    }

    #[tokio::test]
    async fn test_memcached_client_flush_all_no_wait_some_errors() {
        let client = new_client!(
            "cache01.example.com:11211" => "unused1" => "OK\r\n".as_bytes(),
            "cache02.example.com:11211" => "unused2" => "ERROR\r\n".as_bytes(),
        );
        let res = client.flush_all(None).await;
        let response = res.unwrap();
        assert!(response.values.contains_key(&ServerID::from(("cache01.example.com", 11211))));
        assert!(response.errors.contains_key(&ServerID::from(("cache02.example.com", 11211))));
    }

    #[tokio::test]
    async fn test_memcached_client_flush_all_wait_success() {
        let client = new_client!(
            "cache01.example.com:11211" => "unused1" => "OK\r\n".as_bytes(),
            "cache02.example.com:11211" => "unused2" => "OK\r\n".as_bytes(),
        );
        let res = client.flush_all(Some(Duration::from_secs(5))).await;
        let response = res.unwrap();
        assert!(response.values.contains_key(&ServerID::from(("cache01.example.com", 11211))));
        assert!(response.values.contains_key(&ServerID::from(("cache02.example.com", 11211))));
    }

    #[tokio::test]
    async fn test_memcached_client_flush_all_wait_some_errors() {
        let client = new_client!(
            "cache01.example.com:11211" => "unused1" => "OK\r\n".as_bytes(),
            "cache02.example.com:11211" => "unused2" => "ERROR\r\n".as_bytes(),
        );
        let res = client.flush_all(Some(Duration::from_secs(5))).await;
        let response = res.unwrap();
        assert!(response.values.contains_key(&ServerID::from(("cache01.example.com", 11211))));
        assert!(response.errors.contains_key(&ServerID::from(("cache02.example.com", 11211))));
    }
}