use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::time::Duration;
use url::Url;
use connection::Connection;
use error::MemcacheError;
use protocol::{Protocol, ProtocolTrait};
use stream::Stream;
use value::{FromMemcacheValueExt, ToMemcacheValue};
/// String-to-string statistics map returned for a single server by `Client::stats`.
pub type Stats = HashMap<String, String>;

/// Something that can be turned into a list of memcache server URLs.
///
/// Implemented for a single URL (`&str`, `String`) and for lists of URLs
/// (`Vec<&str>`, `Vec<String>`), so `Client::connect` accepts either shape.
pub trait Connectable {
    /// Consumes `self` and yields the URL(s) it describes, preserving order.
    fn get_urls(self) -> Vec<String>;
}

impl Connectable for String {
    fn get_urls(self) -> Vec<String> {
        // A single owned URL becomes a one-element list without copying.
        vec![self]
    }
}

impl Connectable for Vec<String> {
    fn get_urls(self) -> Vec<String> {
        // Already the target shape; hand it back untouched.
        self
    }
}

impl Connectable for &str {
    fn get_urls(self) -> Vec<String> {
        vec![String::from(self)]
    }
}

impl Connectable for Vec<&str> {
    fn get_urls(self) -> Vec<String> {
        // Own each borrowed URL, keeping the original order.
        self.into_iter().map(String::from).collect()
    }
}
/// A memcache client holding one connection per configured server.
///
/// Single-key operations are routed to exactly one connection chosen from the
/// key's hash; server-wide operations (version, flush, stats, timeouts) are
/// applied to every connection in turn.
pub struct Client {
    // One entry per URL given to `Client::connect`, in the order they were given.
    connections: Vec<Connection>,
    /// Hash used to pick the server for a key: the key is sent to
    /// `connections[hash as usize % connections.len()]`.
    /// Initialised to `default_hash_function` by `Client::connect`. Replacing it
    /// changes which server each key maps to, so all clients sharing a cache
    /// should presumably use the same function.
    pub hash_function: fn(&str) -> u64,
}
/// Default key hash used for server selection.
///
/// Feeds `key` into a freshly created standard-library [`DefaultHasher`] and
/// returns the resulting 64-bit digest.
fn default_hash_function(key: &str) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(key, &mut state);
    state.finish()
}
impl Client {
    /// Deprecated alias for [`Client::connect`].
    #[deprecated(since = "0.10.0", note = "please use `connect` instead")]
    pub fn new<C: Connectable>(target: C) -> Result<Self, MemcacheError> {
        Self::connect(target)
    }

    /// Opens one connection per URL described by `target`.
    ///
    /// When a URL has an authority with a non-empty username and a password,
    /// the connection authenticates with those credentials right after connecting.
    ///
    /// # Errors
    /// - `MemcacheError::ClientError` if `target` yields no URLs. Without this
    ///   check the client could route no key at all: every keyed operation
    ///   would panic on a remainder by zero.
    /// - `MemcacheError::ClientError` if any URL fails to parse.
    /// - Any error returned by `Connection::connect` or by authentication.
    pub fn connect<C: Connectable>(target: C) -> Result<Self, MemcacheError> {
        let urls = target.get_urls();
        if urls.is_empty() {
            return Err(MemcacheError::ClientError("No memcache URL given".into()));
        }
        let mut connections = Vec::with_capacity(urls.len());
        for url in urls {
            let parsed = Url::parse(&url)
                .map_err(|_| MemcacheError::ClientError("Invalid memcache URL".into()))?;
            let mut connection = Connection::connect(&parsed)?;
            if parsed.has_authority() && !parsed.username().is_empty() {
                if let Some(password) = parsed.password() {
                    connection.protocol.auth(parsed.username(), password)?;
                }
            }
            connections.push(connection);
        }
        Ok(Client {
            connections,
            hash_function: default_hash_function,
        })
    }

    /// Index into `self.connections` of the server responsible for `key`.
    ///
    /// The hash is cast to `usize` *before* the modulus. This is the
    /// historical formula; keeping it preserves the key-to-server mapping.
    fn connection_index(&self, key: &str) -> usize {
        (self.hash_function)(key) as usize % self.connections.len()
    }

    /// Mutable access to the connection responsible for `key`.
    fn get_connection(&mut self, key: &str) -> &mut Connection {
        let index = self.connection_index(key);
        &mut self.connections[index]
    }

    /// Sets the socket read timeout on every connection (`None` disables it).
    pub fn set_read_timeout(&mut self, timeout: Option<Duration>) -> Result<(), MemcacheError> {
        for conn in self.connections.iter_mut() {
            match conn.protocol {
                Protocol::Ascii(ref mut protocol) => protocol.reader.get_mut().set_read_timeout(timeout)?,
                Protocol::Binary(ref mut protocol) => protocol.stream.set_read_timeout(timeout)?,
            }
        }
        Ok(())
    }

    /// Sets the socket write timeout on every connection (`None` disables it).
    pub fn set_write_timeout(&mut self, timeout: Option<Duration>) -> Result<(), MemcacheError> {
        for conn in self.connections.iter_mut() {
            match conn.protocol {
                // Previously this arm set the *read* timeout, so ASCII-protocol
                // connections silently ignored the write timeout.
                Protocol::Ascii(ref mut protocol) => protocol.reader.get_mut().set_write_timeout(timeout)?,
                Protocol::Binary(ref mut protocol) => protocol.stream.set_write_timeout(timeout)?,
            }
        }
        Ok(())
    }

    /// Queries every server's version.
    ///
    /// Returns `(url, version)` pairs, one per connection and in connection
    /// order, matching the `(url, stats)` shape of [`Client::stats`].
    pub fn version(&mut self) -> Result<Vec<(String, String)>, MemcacheError> {
        let mut result = Vec::with_capacity(self.connections.len());
        for connection in &mut self.connections {
            let version = connection.protocol.version()?;
            result.push((connection.url.clone(), version));
        }
        Ok(result)
    }

    /// Issues `flush` on every connection, stopping at the first error.
    pub fn flush(&mut self) -> Result<(), MemcacheError> {
        for connection in &mut self.connections {
            connection.protocol.flush()?;
        }
        Ok(())
    }

    /// Issues `flush_with_delay(delay)` on every connection, stopping at the first error.
    ///
    /// The unit of `delay` is whatever the protocol layer expects (presumably
    /// seconds; confirm against the protocol implementation).
    pub fn flush_with_delay(&mut self, delay: u32) -> Result<(), MemcacheError> {
        for connection in &mut self.connections {
            connection.protocol.flush_with_delay(delay)?;
        }
        Ok(())
    }

    /// Fetches `key` from the server it hashes to.
    pub fn get<V: FromMemcacheValueExt>(&mut self, key: &str) -> Result<Option<V>, MemcacheError> {
        self.get_connection(key).protocol.get(key)
    }

    /// Fetches several keys, issuing at most one multi-get per server.
    ///
    /// Keys are grouped by the server they hash to, and each group's results
    /// are merged into a single map. The first server error aborts the whole
    /// call and discards results already collected.
    pub fn gets<V: FromMemcacheValueExt>(&mut self, keys: &[&str]) -> Result<HashMap<String, V>, MemcacheError> {
        let mut keys_by_server: HashMap<usize, Vec<&str>> = HashMap::new();
        for &key in keys {
            keys_by_server
                .entry(self.connection_index(key))
                .or_default()
                .push(key);
        }
        let mut result: HashMap<String, V> = HashMap::with_capacity(keys.len());
        for (connection_index, server_keys) in keys_by_server {
            let connection = &mut self.connections[connection_index];
            result.extend(connection.protocol.gets(&server_keys)?);
        }
        Ok(result)
    }

    /// Stores `value` under `key` with the given `expiration` (passed to the server unchanged).
    pub fn set<V: ToMemcacheValue<Stream>>(
        &mut self,
        key: &str,
        value: V,
        expiration: u32,
    ) -> Result<(), MemcacheError> {
        self.get_connection(key).protocol.set(key, value, expiration)
    }

    /// Compare-and-swap store of `value` under `key`, guarded by `cas_id`.
    ///
    /// The returned `bool` is whatever the protocol layer reports (presumably
    /// `true` when the swap was applied).
    pub fn cas<V: ToMemcacheValue<Stream>>(
        &mut self,
        key: &str,
        value: V,
        expiration: u32,
        cas_id: u64,
    ) -> Result<bool, MemcacheError> {
        self.get_connection(key).protocol.cas(key, value, expiration, cas_id)
    }

    /// Issues an `add` for `key` on the server it hashes to.
    pub fn add<V: ToMemcacheValue<Stream>>(
        &mut self,
        key: &str,
        value: V,
        expiration: u32,
    ) -> Result<(), MemcacheError> {
        self.get_connection(key).protocol.add(key, value, expiration)
    }

    /// Issues a `replace` for `key` on the server it hashes to.
    pub fn replace<V: ToMemcacheValue<Stream>>(
        &mut self,
        key: &str,
        value: V,
        expiration: u32,
    ) -> Result<(), MemcacheError> {
        self.get_connection(key).protocol.replace(key, value, expiration)
    }

    /// Issues an `append` of `value` to `key` on the server it hashes to.
    pub fn append<V: ToMemcacheValue<Stream>>(&mut self, key: &str, value: V) -> Result<(), MemcacheError> {
        self.get_connection(key).protocol.append(key, value)
    }

    /// Issues a `prepend` of `value` to `key` on the server it hashes to.
    pub fn prepend<V: ToMemcacheValue<Stream>>(&mut self, key: &str, value: V) -> Result<(), MemcacheError> {
        self.get_connection(key).protocol.prepend(key, value)
    }

    /// Deletes `key`; `Ok(true)` if it existed, `Ok(false)` otherwise.
    pub fn delete(&mut self, key: &str) -> Result<bool, MemcacheError> {
        self.get_connection(key).protocol.delete(key)
    }

    /// Increments the counter at `key` by `amount` and returns the new value.
    pub fn increment(&mut self, key: &str, amount: u64) -> Result<u64, MemcacheError> {
        self.get_connection(key).protocol.increment(key, amount)
    }

    /// Decrements the counter at `key` by `amount` (presumably returning the new value).
    pub fn decrement(&mut self, key: &str, amount: u64) -> Result<u64, MemcacheError> {
        self.get_connection(key).protocol.decrement(key, amount)
    }

    /// Updates the expiration of `key`.
    ///
    /// The returned `bool` is whatever the protocol layer reports (presumably
    /// whether the key existed).
    pub fn touch(&mut self, key: &str, expiration: u32) -> Result<bool, MemcacheError> {
        self.get_connection(key).protocol.touch(key, expiration)
    }

    /// Collects `(url, stats)` for every connection, in connection order.
    pub fn stats(&mut self) -> Result<Vec<(String, Stats)>, MemcacheError> {
        let mut result = Vec::with_capacity(self.connections.len());
        for connection in &mut self.connections {
            let stats_info = connection.protocol.stats()?;
            result.push((connection.url.clone(), stats_info));
        }
        Ok(result)
    }
}
/// Integration tests; they expect live memcached servers on the URLs below.
#[cfg(test)]
mod tests {
    use super::Client;

    #[cfg(unix)]
    #[test]
    fn unix() {
        let mut client = Client::connect("memcache:///tmp/memcached.sock").unwrap();
        let versions = client.version().unwrap();
        assert!(!versions[0].1.is_empty());
    }

    #[cfg(feature = "tls")]
    #[test]
    fn ssl_noverify() {
        let mut client = Client::connect("memcache+tls://localhost:12350?verify_mode=none").unwrap();
        let versions = client.version().unwrap();
        assert!(!versions[0].1.is_empty());
    }

    #[cfg(feature = "tls")]
    #[test]
    fn ssl_verify() {
        let url = "memcache+tls://localhost:12350?ca_path=tests/assets/RUST_MEMCACHE_TEST_CERT.crt";
        let mut client = Client::connect(url).unwrap();
        let versions = client.version().unwrap();
        assert!(!versions[0].1.is_empty());
    }

    #[cfg(feature = "tls")]
    #[test]
    fn ssl_client_certs() {
        let url = "memcache+tls://localhost:12351?key_path=tests/assets/client.key&cert_path=tests/assets/client.crt&ca_path=tests/assets/RUST_MEMCACHE_TEST_CERT.crt";
        let mut client = Client::connect(url).unwrap();
        let versions = client.version().unwrap();
        assert!(!versions[0].1.is_empty());
    }

    #[test]
    fn delete() {
        let mut client = Client::connect("memcache://localhost:12345").unwrap();
        client.set("an_exists_key", "value", 0).unwrap();
        // Deleting a present key reports true; an absent one reports false.
        assert!(client.delete("an_exists_key").unwrap());
        assert!(!client.delete("a_not_exists_key").unwrap());
    }

    #[test]
    fn increment() {
        let mut client = Client::connect("memcache://localhost:12345").unwrap();
        client.delete("counter").unwrap();
        client.set("counter", 321, 0).unwrap();
        let updated = client.increment("counter", 123).unwrap();
        assert_eq!(updated, 444);
    }
}