use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::time::Duration;
use url::Url;
use crate::connection::ConnectionManager;
use crate::error::{ClientError, MemcacheError};
use crate::protocol::{Protocol, ProtocolTrait};
use crate::stream::Stream;
use crate::value::{FromMemcacheValueExt, ToMemcacheValue};
use r2d2::Pool;
/// Key/value statistics reported by a single server, as collected by `Client::stats`.
pub type Stats = HashMap<String, String>;
/// A value that describes one or more memcached servers as URL strings.
///
/// `Client::connect`, `Client::with_pool_size` and `ClientBuilder::add_server` accept any
/// `Connectable`; every string in the returned list describes one server.
pub trait Connectable {
    /// Consumes `self` and returns the server URL string(s) it describes.
    fn get_urls(self) -> Vec<String>;
}
/// Builds a `memcache://host:port` URL for a plain `(host, port)` pair.
///
/// The scheme is required: `ClientBuilder::build` only accepts `memcache`, `memcache+tls` and
/// `memcache+udp` URLs. A bare `host:port` string does not parse that way: `Url::parse` reads
/// `host` as the scheme, or rejects the string when the host is an IP address.
/// Bare IPv6 literals such as `::1` are wrapped in brackets so the port stays unambiguous.
fn host_port_url(host: &str, port: u16) -> String {
    if host.contains(':') && !host.starts_with('[') {
        format!("memcache://[{}]:{}", host, port)
    } else {
        format!("memcache://{}:{}", host, port)
    }
}

impl Connectable for (&str, u16) {
    /// Returns a single `memcache://host:port` URL for the pair.
    fn get_urls(self) -> Vec<String> {
        let (host, port) = self;
        vec![host_port_url(host, port)]
    }
}

impl Connectable for &[(&str, u16)] {
    /// Returns one `memcache://host:port` URL per pair, preserving order.
    fn get_urls(self) -> Vec<String> {
        self.iter().map(|&(host, port)| host_port_url(host, port)).collect()
    }
}
impl Connectable for Url {
    /// Uses the URL's string form as the single server URL.
    fn get_urls(self) -> Vec<String> {
        let serialized = self.to_string();
        vec![serialized]
    }
}
impl Connectable for String {
    /// Moves the string into a one-element URL list without copying it.
    fn get_urls(self) -> Vec<String> {
        vec![self]
    }
}
impl Connectable for Vec<String> {
    /// Uses the vector unchanged: each element is one server URL.
    fn get_urls(self) -> Vec<String> {
        self
    }
}
impl Connectable for &str {
    /// Copies the string slice into a one-element URL list.
    fn get_urls(self) -> Vec<String> {
        vec![String::from(self)]
    }
}
impl Connectable for Vec<&str> {
    /// Copies every string slice into an owned URL, preserving order.
    fn get_urls(self) -> Vec<String> {
        // `collect` sizes the output from the exact-size iterator instead of growing it push by push.
        self.into_iter().map(String::from).collect()
    }
}
/// A memcached client that routes each key to one of several per-server connection pools.
///
/// Cloning a `Client` clones the list of pool handles and copies the hash function.
#[derive(Clone)]
pub struct Client {
    /// One pool per configured server; a key is routed to the pool at
    /// `hash_function(key) % connections.len()`.
    connections: Vec<Pool<ConnectionManager>>,
    /// Hash used to choose the server for a key.
    pub hash_function: fn(&str) -> u64,
}

// No `unsafe impl Send` is needed: every field of `Client` is `Send`, so the compiler derives it.
// This assertion keeps the public guarantee that a `Client` can be moved across threads, and it
// turns a future non-`Send` field into a compile error instead of silent unsoundness.
const _: fn() = || {
    fn assert_send<T: Send>() {}
    assert_send::<Client>();
};
/// Default key-routing hash: feeds the key string into std's [`DefaultHasher`].
///
/// NOTE(review): the standard library leaves `DefaultHasher`'s algorithm unspecified across Rust
/// releases. Clients built with different toolchains may therefore route the same key to
/// different servers; confirm this is acceptable, or pin a specific hash.
fn default_hash_function(key: &str) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(key, &mut state);
    state.finish()
}
/// Rejects keys longer than 250 bytes.
///
/// # Errors
///
/// Returns `ClientError::KeyTooLong`, converted into `MemcacheError`, when `key.len() > 250`.
pub(crate) fn check_key_len(key: &str) -> Result<(), MemcacheError> {
    // `str::len` counts bytes, not characters.
    const MAX_KEY_LEN: usize = 250;
    if key.len() <= MAX_KEY_LEN {
        Ok(())
    } else {
        Err(ClientError::KeyTooLong.into())
    }
}
impl Client {
    /// Connects to `target`. Deprecated alias of [`Client::connect`].
    #[deprecated(since = "0.10.0", note = "please use `connect` instead")]
    pub fn new<C: Connectable>(target: C) -> Result<Self, MemcacheError> {
        Self::connect(target)
    }

    /// Returns a [`ClientBuilder`] with its default settings.
    pub fn builder() -> ClientBuilder {
        ClientBuilder::new()
    }

    /// Creates a client with one r2d2 pool of at most `size` connections per URL in `target`.
    ///
    /// A `connect_timeout=<seconds>` query parameter on a URL (fractional values allowed) sets that
    /// pool's r2d2 connection timeout. Some values are ignored: those that do not parse as a
    /// number, and those that are not a positive duration representable by [`Duration`].
    /// Unlike [`ClientBuilder::build`], the URL scheme is not validated here.
    ///
    /// # Errors
    ///
    /// Returns `MemcacheError::BadURL` when `target` yields no URLs. URL-parse and pool
    /// construction errors are propagated.
    ///
    /// NOTE(review): r2d2's `Builder::max_size` is documented to panic when `size` is 0.
    pub fn with_pool_size<C: Connectable>(target: C, size: u32) -> Result<Self, MemcacheError> {
        let urls = target.get_urls();
        if urls.is_empty() {
            // A client without pools would panic on `% 0` for every keyed operation; fail up
            // front with the same error `ClientBuilder` uses.
            return Err(MemcacheError::BadURL("No servers specified".to_string()));
        }
        let mut connections = Vec::with_capacity(urls.len());
        for url in urls {
            let parsed = Url::parse(url.as_str())?;
            let timeout = parsed
                .query_pairs()
                .find(|&(ref k, ref _v)| k == "connect_timeout")
                .and_then(|(ref _k, ref v)| v.parse::<f64>().ok())
                .and_then(Self::connect_timeout_from_secs);
            let builder = r2d2::Pool::builder().max_size(size);
            let builder = match timeout {
                Some(timeout) => builder.connection_timeout(timeout),
                None => builder,
            };
            connections.push(builder.build(ConnectionManager::new(parsed))?);
        }
        Ok(Client {
            connections,
            hash_function: default_hash_function,
        })
    }

    /// Converts a `connect_timeout` value in seconds into a [`Duration`], or `None` if the value
    /// cannot be used as an r2d2 connection timeout.
    fn connect_timeout_from_secs(secs: f64) -> Option<Duration> {
        // `Duration::from_secs_f64` panics on negative, non-finite or overflowing input, and
        // r2d2's `Builder::connection_timeout` panics on a zero duration, which tiny positive
        // inputs round down to. Both comparisons below are false for NaN.
        let representable = secs >= 0.0 && secs < u64::MAX as f64;
        if !representable {
            return None;
        }
        let timeout = Duration::from_secs_f64(secs);
        if timeout > Duration::from_secs(0) {
            Some(timeout)
        } else {
            None
        }
    }

    /// Wraps a single, already-built pool in a client.
    pub fn with_pool(pool: Pool<ConnectionManager>) -> Result<Self, MemcacheError> {
        Ok(Client {
            connections: vec![pool],
            hash_function: default_hash_function,
        })
    }

    /// Wraps already-built pools (one per server) in a client that routes keys with
    /// `default_hash_function`.
    ///
    /// # Errors
    ///
    /// Returns `MemcacheError::BadURL` if `pools` is empty, because such a client could not route
    /// any key.
    pub fn with_pools(pools: Vec<Pool<ConnectionManager>>) -> Result<Self, MemcacheError> {
        if pools.is_empty() {
            return Err(MemcacheError::BadURL("No servers specified".to_string()));
        }
        Ok(Client {
            connections: pools,
            hash_function: default_hash_function,
        })
    }

    /// Connects to every server described by `target` using the default [`ClientBuilder`]
    /// settings.
    pub fn connect<C: Connectable>(target: C) -> Result<Self, MemcacheError> {
        Self::builder().add_server(target)?.build()
    }

    /// Index into `self.connections` of the server responsible for `key`.
    fn connection_index(&self, key: &str) -> usize {
        // Reduce in `u64` before narrowing. Casting the hash to `usize` first would truncate it on
        // 32-bit targets, routing the same key to a different server than a 64-bit client would.
        // The remainder is smaller than `connections.len()`, so the final cast is lossless.
        let server_count = self.connections.len() as u64;
        ((self.hash_function)(key) % server_count) as usize
    }

    /// Pool of the server responsible for `key`.
    fn get_connection(&self, key: &str) -> &Pool<ConnectionManager> {
        &self.connections[self.connection_index(key)]
    }

    /// Sets the read timeout on a connection checked out from each server's pool; `None` clears it.
    ///
    /// NOTE(review): only the connection returned by `Pool::get` for each server is updated.
    /// Other connections in a pool with `max_size > 1` appear to keep their previous timeout;
    /// confirm whether that is intended.
    ///
    /// # Errors
    ///
    /// Propagates pool checkout errors and errors from the stream's `set_read_timeout`.
    pub fn set_read_timeout(&self, timeout: Option<Duration>) -> Result<(), MemcacheError> {
        for conn in self.connections.iter() {
            let mut conn = conn.get()?;
            match **conn {
                Protocol::Ascii(ref mut protocol) => protocol.stream().set_read_timeout(timeout)?,
                Protocol::Binary(ref mut protocol) => protocol.stream.set_read_timeout(timeout)?,
            }
        }
        Ok(())
    }

    /// Sets the write timeout on a connection checked out from each server's pool; `None` clears it.
    ///
    /// NOTE(review): same single-connection-per-pool caveat as [`Client::set_read_timeout`].
    ///
    /// # Errors
    ///
    /// Propagates pool checkout errors and errors from the stream's `set_write_timeout`.
    pub fn set_write_timeout(&self, timeout: Option<Duration>) -> Result<(), MemcacheError> {
        for conn in self.connections.iter() {
            let mut conn = conn.get()?;
            match **conn {
                Protocol::Ascii(ref mut protocol) => protocol.stream().set_write_timeout(timeout)?,
                Protocol::Binary(ref mut protocol) => protocol.stream.set_write_timeout(timeout)?,
            }
        }
        Ok(())
    }

    /// Returns `(url, version)` for every server, in the order the pools were configured.
    pub fn version(&self) -> Result<Vec<(String, String)>, MemcacheError> {
        let mut result = Vec::with_capacity(self.connections.len());
        for connection in self.connections.iter() {
            let mut connection = connection.get()?;
            let url = connection.get_url();
            result.push((url, connection.version()?));
        }
        Ok(result)
    }

    /// Flushes every server, stopping at the first error.
    pub fn flush(&self) -> Result<(), MemcacheError> {
        for connection in self.connections.iter() {
            connection.get()?.flush()?;
        }
        Ok(())
    }

    /// Asks every server to flush with the given `delay` (presumably seconds, as for memcached's
    /// `flush_all`), stopping at the first error.
    pub fn flush_with_delay(&self, delay: u32) -> Result<(), MemcacheError> {
        for connection in self.connections.iter() {
            connection.get()?.flush_with_delay(delay)?;
        }
        Ok(())
    }

    /// Fetches `key` from its server.
    ///
    /// # Errors
    ///
    /// Fails with `ClientError::KeyTooLong` for keys over 250 bytes, or with pool/protocol errors.
    pub fn get<V: FromMemcacheValueExt>(&self, key: &str) -> Result<Option<V>, MemcacheError> {
        check_key_len(key)?;
        self.get_connection(key).get()?.get(key)
    }

    /// Fetches several keys, issuing one multi-get per server that owns at least one of them.
    ///
    /// The result holds whatever the servers returned, so keys they do not report (presumably
    /// missing keys) are absent from the map.
    ///
    /// # Errors
    ///
    /// Fails with `ClientError::KeyTooLong` if any key exceeds 250 bytes (checked before any network
    /// I/O), or with the first pool/protocol error.
    pub fn gets<V: FromMemcacheValueExt>(&self, keys: &[&str]) -> Result<HashMap<String, V>, MemcacheError> {
        for key in keys {
            check_key_len(key)?;
        }
        // Group keys by owning server so each server sees exactly one request.
        let mut keys_by_connection: HashMap<usize, Vec<&str>> = HashMap::new();
        for &key in keys {
            keys_by_connection.entry(self.connection_index(key)).or_default().push(key);
        }
        let mut result: HashMap<String, V> = HashMap::new();
        for (connection_index, server_keys) in keys_by_connection {
            result.extend(self.connections[connection_index].get()?.gets(&server_keys)?);
        }
        Ok(result)
    }

    /// Stores `value` under `key` with the given `expiration`.
    ///
    /// # Errors
    ///
    /// Fails with `ClientError::KeyTooLong` for keys over 250 bytes, or with pool/protocol errors.
    pub fn set<V: ToMemcacheValue<Stream>>(&self, key: &str, value: V, expiration: u32) -> Result<(), MemcacheError> {
        check_key_len(key)?;
        self.get_connection(key).get()?.set(key, value, expiration)
    }

    /// Compare-and-swap: stores `value` under `key` conditioned on `cas_id`, returning the
    /// protocol's boolean result.
    ///
    /// # Errors
    ///
    /// Fails with `ClientError::KeyTooLong` for keys over 250 bytes, or with pool/protocol errors.
    pub fn cas<V: ToMemcacheValue<Stream>>(
        &self,
        key: &str,
        value: V,
        expiration: u32,
        cas_id: u64,
    ) -> Result<bool, MemcacheError> {
        check_key_len(key)?;
        self.get_connection(key).get()?.cas(key, value, expiration, cas_id)
    }

    /// Sends an `add` of `value` under `key` to the key's server.
    ///
    /// # Errors
    ///
    /// Fails with `ClientError::KeyTooLong` for keys over 250 bytes, or with pool/protocol errors.
    pub fn add<V: ToMemcacheValue<Stream>>(&self, key: &str, value: V, expiration: u32) -> Result<(), MemcacheError> {
        check_key_len(key)?;
        self.get_connection(key).get()?.add(key, value, expiration)
    }

    /// Sends a `replace` of `value` under `key` to the key's server.
    ///
    /// # Errors
    ///
    /// Fails with `ClientError::KeyTooLong` for keys over 250 bytes, or with pool/protocol errors.
    pub fn replace<V: ToMemcacheValue<Stream>>(
        &self,
        key: &str,
        value: V,
        expiration: u32,
    ) -> Result<(), MemcacheError> {
        check_key_len(key)?;
        self.get_connection(key).get()?.replace(key, value, expiration)
    }

    /// Sends an `append` of `value` to `key` on the key's server.
    ///
    /// # Errors
    ///
    /// Fails with `ClientError::KeyTooLong` for keys over 250 bytes, or with pool/protocol errors.
    pub fn append<V: ToMemcacheValue<Stream>>(&self, key: &str, value: V) -> Result<(), MemcacheError> {
        check_key_len(key)?;
        self.get_connection(key).get()?.append(key, value)
    }

    /// Sends a `prepend` of `value` to `key` on the key's server.
    ///
    /// # Errors
    ///
    /// Fails with `ClientError::KeyTooLong` for keys over 250 bytes, or with pool/protocol errors.
    pub fn prepend<V: ToMemcacheValue<Stream>>(&self, key: &str, value: V) -> Result<(), MemcacheError> {
        check_key_len(key)?;
        self.get_connection(key).get()?.prepend(key, value)
    }

    /// Deletes `key`; returns `true` if it existed and `false` otherwise.
    ///
    /// # Errors
    ///
    /// Fails with `ClientError::KeyTooLong` for keys over 250 bytes, or with pool/protocol errors.
    pub fn delete(&self, key: &str) -> Result<bool, MemcacheError> {
        check_key_len(key)?;
        self.get_connection(key).get()?.delete(key)
    }

    /// Increments the counter at `key` by `amount` and returns the new value.
    ///
    /// # Errors
    ///
    /// Fails with `ClientError::KeyTooLong` for keys over 250 bytes, or with pool/protocol errors.
    pub fn increment(&self, key: &str, amount: u64) -> Result<u64, MemcacheError> {
        check_key_len(key)?;
        self.get_connection(key).get()?.increment(key, amount)
    }

    /// Decrements the counter at `key` by `amount` and returns the protocol's resulting value.
    ///
    /// # Errors
    ///
    /// Fails with `ClientError::KeyTooLong` for keys over 250 bytes, or with pool/protocol errors.
    pub fn decrement(&self, key: &str, amount: u64) -> Result<u64, MemcacheError> {
        check_key_len(key)?;
        self.get_connection(key).get()?.decrement(key, amount)
    }

    /// Updates the expiration of `key`, returning the protocol's boolean result.
    ///
    /// # Errors
    ///
    /// Fails with `ClientError::KeyTooLong` for keys over 250 bytes, or with pool/protocol errors.
    pub fn touch(&self, key: &str, expiration: u32) -> Result<bool, MemcacheError> {
        check_key_len(key)?;
        self.get_connection(key).get()?.touch(key, expiration)
    }

    /// Returns `(url, stats)` for every server, in the order the pools were configured.
    pub fn stats(&self) -> Result<Vec<(String, Stats)>, MemcacheError> {
        let mut result: Vec<(String, Stats)> = Vec::with_capacity(self.connections.len());
        for connection in self.connections.iter() {
            let mut connection = connection.get()?;
            let stats_info = connection.stats()?;
            let url = connection.get_url();
            result.push((url, stats_info));
        }
        Ok(result)
    }
}
/// Step-by-step configuration for a [`Client`]; finish with `ClientBuilder::build`.
pub struct ClientBuilder {
    /// Server URL strings added via `add_server`; `build` creates one pool per entry.
    targets: Vec<String>,
    /// Maximum number of connections in each server's pool (r2d2 `max_size`).
    max_size: u32,
    /// Minimum idle connections per pool (r2d2 `min_idle`); `None` leaves r2d2's default.
    min_idle: Option<u32>,
    /// Maximum lifetime of a pooled connection (r2d2 `max_lifetime`).
    max_lifetime: Option<Duration>,
    /// Socket read timeout applied to the client after its pools are built.
    read_timeout: Option<Duration>,
    /// Socket write timeout applied to the client after its pools are built.
    write_timeout: Option<Duration>,
    /// Pool checkout timeout (r2d2 `connection_timeout`); `None` leaves r2d2's default.
    connection_timeout: Option<Duration>,
    /// Hash used by the built client to route keys to servers.
    hash_function: fn(&str) -> u64,
}
impl ClientBuilder {
    /// Creates a builder with no servers, a maximum pool size of 1, no optional pool or socket
    /// settings, and `default_hash_function` for key routing.
    pub fn new() -> Self {
        ClientBuilder {
            targets: vec![],
            max_size: 1,
            min_idle: None,
            max_lifetime: None,
            read_timeout: None,
            write_timeout: None,
            connection_timeout: None,
            hash_function: default_hash_function,
        }
    }

    /// Appends the server URL(s) described by `target`.
    ///
    /// The URLs are only parsed and validated later, by [`ClientBuilder::build`].
    ///
    /// # Errors
    ///
    /// Returns `MemcacheError::BadURL` if `target` yields no URLs.
    pub fn add_server<C: Connectable>(mut self, target: C) -> Result<Self, MemcacheError> {
        let targets = target.get_urls();
        if targets.is_empty() {
            return Err(MemcacheError::BadURL("No servers specified".to_string()));
        }
        self.targets.extend(targets);
        Ok(self)
    }

    /// Sets the maximum number of connections in each server's pool (default: 1).
    ///
    /// NOTE(review): r2d2's pool builder panics if this is 0 or smaller than the configured
    /// min-idle count; [`ClientBuilder::build`] does not pre-validate it.
    pub fn with_max_pool_size(mut self, max_size: u32) -> Self {
        self.max_size = max_size;
        self
    }

    /// Sets the minimum number of idle connections r2d2 keeps in each pool.
    pub fn with_min_idle_conns(mut self, min_idle: u32) -> Self {
        self.min_idle = Some(min_idle);
        self
    }

    /// Sets the maximum lifetime of a pooled connection.
    pub fn with_max_conn_lifetime(mut self, max_lifetime: Duration) -> Self {
        self.max_lifetime = Some(max_lifetime);
        self
    }

    /// Sets the socket read timeout applied once the client is built.
    pub fn with_read_timeout(mut self, read_timeout: Duration) -> Self {
        self.read_timeout = Some(read_timeout);
        self
    }

    /// Sets the socket write timeout applied once the client is built.
    pub fn with_write_timeout(mut self, write_timeout: Duration) -> Self {
        self.write_timeout = Some(write_timeout);
        self
    }

    /// Sets r2d2's connection timeout for every pool.
    ///
    /// NOTE(review): r2d2 documents a panic for a zero connection timeout; confirm callers
    /// never pass `Duration::from_secs(0)`.
    pub fn with_connection_timeout(mut self, connection_timeout: Duration) -> Self {
        self.connection_timeout = Some(connection_timeout);
        self
    }

    /// Replaces the hash used to route keys to servers.
    pub fn with_hash_function(mut self, hash_function: fn(&str) -> u64) -> Self {
        self.hash_function = hash_function;
        self
    }

    /// Builds a [`Client`] with one r2d2 pool per added server.
    ///
    /// Every URL must parse and use the `memcache`, `memcache+tls` or `memcache+udp` scheme.
    /// Read and write timeouts are applied only when they were configured.
    ///
    /// # Errors
    ///
    /// Returns `MemcacheError::BadURL` when no servers were added, a URL fails to parse, or a
    /// scheme is unsupported. Returns `MemcacheError::PoolError` when a pool cannot be built.
    /// Errors raised while applying configured timeouts are propagated.
    pub fn build(self) -> Result<Client, MemcacheError> {
        if self.targets.is_empty() {
            return Err(MemcacheError::BadURL("No servers specified".to_string()));
        }
        let mut connections = Vec::with_capacity(self.targets.len());
        for target in &self.targets {
            let url = Url::parse(target).map_err(|e| MemcacheError::BadURL(e.to_string()))?;
            match url.scheme() {
                "memcache" | "memcache+tls" | "memcache+udp" => {}
                _ => {
                    return Err(MemcacheError::BadURL(format!("Unsupported protocol: {}", url.scheme())));
                }
            }
            let mut builder = r2d2::Pool::builder()
                .max_size(self.max_size)
                .min_idle(self.min_idle)
                .max_lifetime(self.max_lifetime);
            if let Some(timeout) = self.connection_timeout {
                builder = builder.connection_timeout(timeout);
            }
            let pool = builder
                .build(ConnectionManager::new(url))
                .map_err(MemcacheError::PoolError)?;
            connections.push(pool);
        }
        let client = Client {
            connections,
            hash_function: self.hash_function,
        };
        // Applying a timeout checks a connection out of every pool. Skip it when nothing was
        // configured, so an unset option does not override whatever timeout the socket already has.
        if self.read_timeout.is_some() {
            client.set_read_timeout(self.read_timeout)?;
        }
        if self.write_timeout.is_some() {
            client.set_write_timeout(self.write_timeout)?;
        }
        Ok(client)
    }
}

impl Default for ClientBuilder {
    /// Equivalent to [`ClientBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::Client;
    use std::time::Duration;

    /// Plain-TCP memcached instance these integration tests expect to be running.
    const LOCAL_MEMCACHE: &str = "memcache://localhost:12345";

    #[test]
    fn build_client_happy_path() {
        let client = Client::builder().add_server(LOCAL_MEMCACHE).unwrap().build().unwrap();
        let versions = client.version().unwrap();
        assert!(!versions[0].1.is_empty());
    }

    #[test]
    fn build_client_bad_url() {
        // The trailing colon makes the URL unparsable, which only `build` detects.
        let outcome = Client::builder()
            .add_server("memcache://localhost:12345:")
            .unwrap()
            .build();
        assert!(outcome.is_err());
    }

    #[test]
    fn build_client_no_url() {
        assert!(Client::builder().build().is_err());
        assert!(Client::builder().add_server(Vec::<String>::new()).is_err());
    }

    #[test]
    fn build_client_with_large_pool_size() {
        let outcome = Client::builder()
            .add_server(LOCAL_MEMCACHE)
            .unwrap()
            .with_max_pool_size(100)
            .build();
        assert!(
            outcome.is_ok(),
            "Expected successful client creation with large pool size"
        );
    }

    #[test]
    fn build_client_with_custom_hash_function() {
        fn always_42(_key: &str) -> u64 {
            42
        }
        let client = Client::builder()
            .add_server(LOCAL_MEMCACHE)
            .unwrap()
            .with_hash_function(always_42)
            .build()
            .unwrap();
        assert_eq!(
            (client.hash_function)("any_key"),
            42,
            "Expected custom hash function to be used"
        );
    }

    #[test]
    fn build_client_zero_min_idle_conns() {
        let outcome = Client::builder()
            .add_server(LOCAL_MEMCACHE)
            .unwrap()
            .with_min_idle_conns(0)
            .build();
        assert!(outcome.is_ok(), "Should handle zero min idle conns");
    }

    #[test]
    fn build_client_invalid_hash_function() {
        // Building must not invoke the hash function; only keyed operations do.
        let panicking_hash = |_: &str| -> u64 {
            panic!("This should not be called");
        };
        let outcome = Client::builder()
            .add_server(LOCAL_MEMCACHE)
            .unwrap()
            .with_hash_function(panicking_hash)
            .build();
        assert!(outcome.is_ok(), "Should handle custom hash function gracefully");
    }

    #[test]
    fn build_client_with_unsupported_protocol() {
        let outcome = Client::builder()
            .add_server("unsupported://localhost:12345")
            .unwrap()
            .build();
        assert!(outcome.is_err(), "Expected error when using an unsupported protocol");
    }

    #[test]
    fn build_client_with_all_optional_parameters() {
        let outcome = Client::builder()
            .add_server(LOCAL_MEMCACHE)
            .unwrap()
            .with_max_pool_size(10)
            .with_min_idle_conns(2)
            .with_max_conn_lifetime(Duration::from_secs(30))
            .with_read_timeout(Duration::from_secs(5))
            .with_write_timeout(Duration::from_secs(5))
            .with_connection_timeout(Duration::from_secs(2))
            .build();
        assert!(outcome.is_ok(), "Should successfully build with all optional parameters");
    }

    #[cfg(unix)]
    #[test]
    fn unix() {
        let client = Client::connect("memcache:///tmp/memcached.sock").unwrap();
        assert!(!client.version().unwrap()[0].1.is_empty());
    }

    #[cfg(feature = "tls")]
    #[test]
    fn ssl_noverify() {
        let client = Client::connect("memcache+tls://localhost:12350?verify_mode=none").unwrap();
        assert!(!client.version().unwrap()[0].1.is_empty());
    }

    #[cfg(feature = "tls")]
    #[test]
    fn ssl_verify() {
        let client =
            Client::connect("memcache+tls://localhost:12350?ca_path=tests/assets/RUST_MEMCACHE_TEST_CERT.crt")
                .unwrap();
        assert!(!client.version().unwrap()[0].1.is_empty());
    }

    #[cfg(feature = "tls")]
    #[test]
    fn ssl_client_certs() {
        let client = Client::connect("memcache+tls://localhost:12351?key_path=tests/assets/client.key&cert_path=tests/assets/client.crt&ca_path=tests/assets/RUST_MEMCACHE_TEST_CERT.crt").unwrap();
        assert!(!client.version().unwrap()[0].1.is_empty());
    }

    #[test]
    fn delete() {
        let client = Client::connect(LOCAL_MEMCACHE).unwrap();
        client.set("an_exists_key", "value", 0).unwrap();
        assert!(client.delete("an_exists_key").unwrap());
        assert!(!client.delete("a_not_exists_key").unwrap());
    }

    #[test]
    fn increment() {
        let client = Client::connect(LOCAL_MEMCACHE).unwrap();
        client.delete("counter").unwrap();
        client.set("counter", 321, 0).unwrap();
        assert_eq!(client.increment("counter", 123).unwrap(), 444);
    }
}