use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::time::Duration;
use url::Url;
use connection::ConnectionManager;
use error::{ClientError, MemcacheError};
use protocol::{Protocol, ProtocolTrait};
use r2d2::Pool;
use stream::Stream;
use value::{FromMemcacheValueExt, ToMemcacheValue};
/// String-keyed, string-valued statistics for a single server.
///
/// `Client::stats` returns one of these per connection pool, paired with the
/// URL of the connection it was read from.
pub type Stats = HashMap<String, String>;
/// Conversion of a connection target into a list of server URL strings.
///
/// Implemented in this file for `String`, `&str`, `Vec<String>` and
/// `Vec<&str>`; `Client::connect` and `Client::with_pool_size` accept any
/// implementor.
pub trait Connectable {
/// Consumes the target and returns its URLs as owned strings.
fn get_urls(self) -> Vec<String>;
}
/// A single owned URL string is treated as a one-element server list.
impl Connectable for String {
    /// Wraps the string in a vector without copying it.
    fn get_urls(self) -> Vec<String> {
        vec![self]
    }
}
/// A vector of owned URL strings already has the required shape.
impl Connectable for Vec<String> {
    /// Returns the vector unchanged.
    fn get_urls(self) -> Vec<String> {
        self
    }
}
/// A borrowed URL string is copied into a one-element server list.
impl Connectable for &str {
    /// Allocates an owned copy of the string and wraps it in a vector.
    fn get_urls(self) -> Vec<String> {
        vec![String::from(self)]
    }
}
/// A vector of borrowed URL strings is copied element by element.
impl Connectable for Vec<&str> {
    /// Returns an owned copy of every URL, in the original order.
    fn get_urls(self) -> Vec<String> {
        self.into_iter().map(String::from).collect()
    }
}
/// Client that spreads keys over one or more connection pools.
///
/// `Client::with_pool_size` builds one pool per URL it is given. Key-based
/// operations pick a pool with `hash_function(key) % number_of_pools`.
/// Cloning the client clones the pool handles and copies the function
/// pointer.
#[derive(Clone)]
pub struct Client {
/// One pool per server URL, in the order the URLs were supplied.
connections: Vec<Pool<ConnectionManager>>,
/// Maps a key to a 64-bit hash used for pool selection.
///
/// Set to `default_hash_function` by `Client::with_pool_size`; being a
/// public field, callers can replace it after construction.
pub hash_function: fn(&str) -> u64,
}
// SAFETY(review): no justification is recorded for this manual impl. `Client`
// only holds a `Vec<Pool<ConnectionManager>>` and a plain `fn` pointer;
// presumably both are already `Send`, which would make this impl redundant.
// TODO confirm against r2d2's bounds and remove the `unsafe` if so.
unsafe impl Send for Client {}
/// Hashes `key` with a freshly created standard-library `DefaultHasher`.
///
/// This is the hash function installed by `Client::with_pool_size`.
///
/// NOTE(review): the standard library does not specify `DefaultHasher`'s
/// algorithm across Rust releases, so binaries built with different
/// toolchains may route the same key to different servers. Confirm whether
/// that matters for your deployment.
fn default_hash_function(key: &str) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(key, &mut state);
    state.finish()
}
pub(crate) fn check_key_len(key: &str) -> Result<(), MemcacheError> {
if key.len() > 250 {
Err(ClientError::KeyTooLong)?
}
Ok(())
}
impl Client {
/// Creates a client for the given target(s).
///
/// Kept for backward compatibility; it simply forwards to `Client::connect`.
#[deprecated(since = "0.10.0", note = "please use `connect` instead")]
pub fn new<C: Connectable>(target: C) -> Result<Self, MemcacheError> {
    Self::connect(target)
}
pub fn with_pool_size<C: Connectable>(target: C, size: u32) -> Result<Self, MemcacheError> {
let urls = target.get_urls();
let mut connections = vec![];
for url in urls {
let parsed = Url::parse(url.as_str())?;
let pool = r2d2::Pool::builder()
.max_size(size)
.build(ConnectionManager::new(parsed))?;
connections.push(pool);
}
Ok(Client {
connections,
hash_function: default_hash_function,
})
}
/// Connects to the given target(s) with pools capped at one connection each.
///
/// Equivalent to `Client::with_pool_size(target, 1)`; see that function for
/// the error conditions.
pub fn connect<C: Connectable>(target: C) -> Result<Self, MemcacheError> {
    const DEFAULT_POOL_SIZE: u32 = 1;
    Self::with_pool_size(target, DEFAULT_POOL_SIZE)
}
/// Picks the pool responsible for `key`.
///
/// The pool index is `hash_function(key) as usize % number_of_pools`. The
/// returned value is a clone of the selected pool handle.
///
/// # Panics
/// Panics if the client has no pools, because the remainder is then taken
/// with a zero divisor.
fn get_connection(&self, key: &str) -> Pool<ConnectionManager> {
    let hash = (self.hash_function)(key) as usize;
    let index = hash % self.connections.len();
    self.connections[index].clone()
}
/// Sets the read timeout on one connection checked out from every pool.
///
/// NOTE(review): only the connection handed out by `get()` is touched, so
/// with a pool size above one the remaining pooled connections presumably
/// keep their previous timeout — confirm.
///
/// # Errors
/// Stops at the first pool that cannot hand out a connection or whose
/// stream rejects the timeout; pools visited earlier keep the new value.
pub fn set_read_timeout(&self, timeout: Option<Duration>) -> Result<(), MemcacheError> {
    for pool in &self.connections {
        let mut conn = pool.get()?;
        match &mut **conn {
            Protocol::Ascii(ascii) => ascii.stream().set_read_timeout(timeout)?,
            Protocol::Binary(binary) => binary.stream.set_read_timeout(timeout)?,
        }
    }
    Ok(())
}
/// Sets the write timeout on one connection checked out from every pool.
///
/// NOTE(review): only the connection handed out by `get()` is touched, so
/// with a pool size above one the remaining pooled connections presumably
/// keep their previous timeout — confirm.
///
/// # Errors
/// Stops at the first pool that cannot hand out a connection or whose
/// stream rejects the timeout; pools visited earlier keep the new value.
pub fn set_write_timeout(&self, timeout: Option<Duration>) -> Result<(), MemcacheError> {
    for pool in &self.connections {
        let mut conn = pool.get()?;
        match **conn {
            // Both protocol variants must go through the *write* setter. Using
            // the read setter for ASCII connections would leave their write
            // timeout unset and clobber their read timeout.
            Protocol::Ascii(ref mut protocol) => protocol.stream().set_write_timeout(timeout)?,
            Protocol::Binary(ref mut protocol) => protocol.stream.set_write_timeout(timeout)?,
        }
    }
    Ok(())
}
/// Queries the version of every server.
///
/// Returns one `(url, version)` pair per pool, in pool order. The URL is
/// read from the checked-out connection before the version request is sent.
///
/// # Errors
/// Returns the first error from checking out a connection or from the
/// version request; later pools are not queried.
pub fn version(&self) -> Result<Vec<(String, String)>, MemcacheError> {
    self.connections
        .iter()
        .map(|pool| -> Result<(String, String), MemcacheError> {
            let mut conn = pool.get()?;
            let url = conn.get_url();
            let version = conn.version()?;
            Ok((url, version))
        })
        .collect()
}
/// Sends a flush request to every server, one pool after another.
///
/// # Errors
/// Stops at the first pool that cannot hand out a connection or whose flush
/// request fails; servers visited earlier have already been flushed.
pub fn flush(&self) -> Result<(), MemcacheError> {
    for pool in &self.connections {
        let mut conn = pool.get()?;
        conn.flush()?;
    }
    Ok(())
}
/// Sends a delayed flush request to every server, one pool after another.
///
/// `delay` is passed through unchanged to each connection (presumably in
/// seconds, as in the memcached protocol — confirm).
///
/// # Errors
/// Stops at the first pool that cannot hand out a connection or whose
/// request fails; servers visited earlier have already received the request.
pub fn flush_with_delay(&self, delay: u32) -> Result<(), MemcacheError> {
    for pool in &self.connections {
        let mut conn = pool.get()?;
        conn.flush_with_delay(delay)?;
    }
    Ok(())
}
/// Reads the value stored under `key` from the server the key hashes to.
///
/// The result is an `Option`; presumably `None` means the key is not
/// present on the server — confirm against the protocol implementation.
///
/// # Errors
/// Fails if `check_key_len` rejects the key, if no pooled connection can be
/// obtained, or if the underlying `get` request fails.
///
/// # Panics
/// Panics if the client has no pools (see `get_connection`).
pub fn get<V: FromMemcacheValueExt>(&self, key: &str) -> Result<Option<V>, MemcacheError> {
    check_key_len(key)?;
    let mut conn = self.get_connection(key).get()?;
    conn.get(key)
}
/// Reads several keys at once, sending one request per server involved.
///
/// All keys are length-checked before any routing or network I/O happens.
/// Keys are then grouped by the pool they hash to, each group is fetched
/// with a single `gets` call on a connection from that pool, and the
/// per-server maps are merged into one result.
///
/// # Errors
/// Fails if `check_key_len` rejects any key, if a pooled connection cannot
/// be obtained, or if one of the per-server requests fails.
///
/// # Panics
/// Panics if `keys` is non-empty and the client has no pools, because the
/// pool index is a remainder by the number of pools.
pub fn gets<V: FromMemcacheValueExt>(&self, keys: &[&str]) -> Result<HashMap<String, V>, MemcacheError> {
    keys.iter().try_for_each(|key| check_key_len(key))?;

    let pool_count = self.connections.len();
    let mut keys_by_pool: HashMap<usize, Vec<&str>> = HashMap::new();
    for &key in keys {
        let index = (self.hash_function)(key) as usize % pool_count;
        keys_by_pool.entry(index).or_default().push(key);
    }

    let mut found: HashMap<String, V> = HashMap::new();
    for (index, pool_keys) in &keys_by_pool {
        let mut conn = self.connections[*index].get()?;
        found.extend(conn.gets(pool_keys)?);
    }
    Ok(found)
}
/// Stores `value` under `key` on the server the key hashes to.
///
/// `expiration` is forwarded unchanged (presumably seconds, with `0`
/// meaning "never expire" as in the memcached protocol — confirm).
///
/// # Errors
/// Fails if `check_key_len` rejects the key, if no pooled connection can be
/// obtained, or if the underlying `set` request fails.
///
/// # Panics
/// Panics if the client has no pools (see `get_connection`).
pub fn set<V: ToMemcacheValue<Stream>>(&self, key: &str, value: V, expiration: u32) -> Result<(), MemcacheError> {
    check_key_len(key)?;
    let mut conn = self.get_connection(key).get()?;
    conn.set(key, value, expiration)
}
/// Issues a compare-and-swap store for `key` on the server it hashes to.
///
/// `cas_id` and `expiration` are forwarded unchanged. The boolean result
/// comes straight from the connection's `cas` call (presumably `true` when
/// the swap was applied — confirm against the protocol implementation).
///
/// # Errors
/// Fails if `check_key_len` rejects the key, if no pooled connection can be
/// obtained, or if the underlying `cas` request fails.
///
/// # Panics
/// Panics if the client has no pools (see `get_connection`).
pub fn cas<V: ToMemcacheValue<Stream>>(
    &self,
    key: &str,
    value: V,
    expiration: u32,
    cas_id: u64,
) -> Result<bool, MemcacheError> {
    check_key_len(key)?;
    let mut conn = self.get_connection(key).get()?;
    conn.cas(key, value, expiration, cas_id)
}
/// Sends an `add` request for `key` to the server it hashes to.
///
/// Presumably stores the value only when the key is not yet present, as in
/// the memcached protocol — confirm against the protocol implementation.
/// `expiration` is forwarded unchanged.
///
/// # Errors
/// Fails if `check_key_len` rejects the key, if no pooled connection can be
/// obtained, or if the underlying `add` request fails.
///
/// # Panics
/// Panics if the client has no pools (see `get_connection`).
pub fn add<V: ToMemcacheValue<Stream>>(&self, key: &str, value: V, expiration: u32) -> Result<(), MemcacheError> {
    check_key_len(key)?;
    let mut conn = self.get_connection(key).get()?;
    conn.add(key, value, expiration)
}
/// Sends a `replace` request for `key` to the server it hashes to.
///
/// Presumably stores the value only when the key already exists, as in the
/// memcached protocol — confirm against the protocol implementation.
/// `expiration` is forwarded unchanged.
///
/// # Errors
/// Fails if `check_key_len` rejects the key, if no pooled connection can be
/// obtained, or if the underlying `replace` request fails.
///
/// # Panics
/// Panics if the client has no pools (see `get_connection`).
pub fn replace<V: ToMemcacheValue<Stream>>(
    &self,
    key: &str,
    value: V,
    expiration: u32,
) -> Result<(), MemcacheError> {
    check_key_len(key)?;
    let mut conn = self.get_connection(key).get()?;
    conn.replace(key, value, expiration)
}
/// Sends an `append` request for `key` to the server it hashes to.
///
/// Presumably adds `value` after the data already stored under the key —
/// confirm against the protocol implementation.
///
/// # Errors
/// Fails if `check_key_len` rejects the key, if no pooled connection can be
/// obtained, or if the underlying `append` request fails.
///
/// # Panics
/// Panics if the client has no pools (see `get_connection`).
pub fn append<V: ToMemcacheValue<Stream>>(&self, key: &str, value: V) -> Result<(), MemcacheError> {
    check_key_len(key)?;
    let mut conn = self.get_connection(key).get()?;
    conn.append(key, value)
}
/// Sends a `prepend` request for `key` to the server it hashes to.
///
/// Presumably adds `value` before the data already stored under the key —
/// confirm against the protocol implementation.
///
/// # Errors
/// Fails if `check_key_len` rejects the key, if no pooled connection can be
/// obtained, or if the underlying `prepend` request fails.
///
/// # Panics
/// Panics if the client has no pools (see `get_connection`).
pub fn prepend<V: ToMemcacheValue<Stream>>(&self, key: &str, value: V) -> Result<(), MemcacheError> {
    check_key_len(key)?;
    let mut conn = self.get_connection(key).get()?;
    conn.prepend(key, value)
}
/// Deletes `key` on the server it hashes to.
///
/// The module's own test expects `true` when the key existed and `false`
/// when it did not.
///
/// # Errors
/// Fails if `check_key_len` rejects the key, if no pooled connection can be
/// obtained, or if the underlying `delete` request fails.
///
/// # Panics
/// Panics if the client has no pools (see `get_connection`).
pub fn delete(&self, key: &str) -> Result<bool, MemcacheError> {
    check_key_len(key)?;
    let mut conn = self.get_connection(key).get()?;
    conn.delete(key)
}
/// Increments the numeric value stored under `key` by `amount`.
///
/// The module's own test expects the returned number to be the value after
/// the increment (321 + 123 gives 444).
///
/// # Errors
/// Fails if `check_key_len` rejects the key, if no pooled connection can be
/// obtained, or if the underlying `increment` request fails.
///
/// # Panics
/// Panics if the client has no pools (see `get_connection`).
pub fn increment(&self, key: &str, amount: u64) -> Result<u64, MemcacheError> {
    check_key_len(key)?;
    let mut conn = self.get_connection(key).get()?;
    conn.increment(key, amount)
}
/// Decrements the numeric value stored under `key` by `amount`.
///
/// The returned number comes straight from the connection's `decrement`
/// call (presumably the value after the decrement — confirm).
///
/// # Errors
/// Fails if `check_key_len` rejects the key, if no pooled connection can be
/// obtained, or if the underlying `decrement` request fails.
///
/// # Panics
/// Panics if the client has no pools (see `get_connection`).
pub fn decrement(&self, key: &str, amount: u64) -> Result<u64, MemcacheError> {
    check_key_len(key)?;
    let mut conn = self.get_connection(key).get()?;
    conn.decrement(key, amount)
}
/// Sends a `touch` request that updates the expiration of `key`.
///
/// `expiration` is forwarded unchanged. The boolean result comes straight
/// from the connection's `touch` call (presumably `true` when the key
/// existed — confirm against the protocol implementation).
///
/// # Errors
/// Fails if `check_key_len` rejects the key, if no pooled connection can be
/// obtained, or if the underlying `touch` request fails.
///
/// # Panics
/// Panics if the client has no pools (see `get_connection`).
pub fn touch(&self, key: &str, expiration: u32) -> Result<bool, MemcacheError> {
    check_key_len(key)?;
    let mut conn = self.get_connection(key).get()?;
    conn.touch(key, expiration)
}
pub fn stats(&self) -> Result<Vec<(String, Stats)>, MemcacheError> {
let mut result: Vec<(String, HashMap<String, String>)> = vec![];
for connection in self.connections.iter() {
let mut connection = connection.get()?;
let stats_info = connection.stats()?;
let url = connection.get_url();
result.push((url, stats_info));
}
return Ok(result);
}
}
/// Integration tests: each one connects to the server named in its URL, so
/// they only pass when those servers are reachable.
#[cfg(test)]
mod tests {
    use super::Client;

    /// Connects over a Unix domain socket and checks a version is reported.
    #[cfg(unix)]
    #[test]
    fn unix() {
        let client = Client::connect("memcache:///tmp/memcached.sock").unwrap();
        let versions = client.version().unwrap();
        assert!(!versions[0].1.is_empty());
    }

    /// Connects over TLS with certificate verification switched off.
    #[cfg(feature = "tls")]
    #[test]
    fn ssl_noverify() {
        let client = Client::connect("memcache+tls://localhost:12350?verify_mode=none").unwrap();
        let versions = client.version().unwrap();
        assert!(!versions[0].1.is_empty());
    }

    /// Connects over TLS, verifying the server against the test CA file.
    #[cfg(feature = "tls")]
    #[test]
    fn ssl_verify() {
        let url = "memcache+tls://localhost:12350?ca_path=tests/assets/RUST_MEMCACHE_TEST_CERT.crt";
        let client = Client::connect(url).unwrap();
        let versions = client.version().unwrap();
        assert!(!versions[0].1.is_empty());
    }

    /// Connects over TLS while presenting a client key and certificate.
    #[cfg(feature = "tls")]
    #[test]
    fn ssl_client_certs() {
        let url = "memcache+tls://localhost:12351?key_path=tests/assets/client.key&cert_path=tests/assets/client.crt&ca_path=tests/assets/RUST_MEMCACHE_TEST_CERT.crt";
        let client = Client::connect(url).unwrap();
        let versions = client.version().unwrap();
        assert!(!versions[0].1.is_empty());
    }

    /// Deleting an existing key reports `true`; a missing key reports `false`.
    #[test]
    fn delete() {
        let client = Client::connect("memcache://localhost:12345").unwrap();
        client.set("an_exists_key", "value", 0).unwrap();
        assert!(client.delete("an_exists_key").unwrap());
        assert!(!client.delete("a_not_exists_key").unwrap());
    }

    /// Incrementing returns the value after the increment.
    #[test]
    fn increment() {
        let client = Client::connect("memcache://localhost:12345").unwrap();
        // Start from a known state regardless of earlier runs.
        client.delete("counter").unwrap();
        client.set("counter", 321, 0).unwrap();
        let updated = client.increment("counter", 123).unwrap();
        assert_eq!(updated, 444);
    }
}