extern crate bincode;
extern crate iron;
extern crate cookie;
extern crate rand;
extern crate r2d2;
extern crate r2d2_redis;
extern crate rustc_serialize;
extern crate redis;
#[macro_use]
extern crate log;
use iron::prelude::*;
use iron::middleware::{Handler, AroundMiddleware};
use iron::typemap::Key;
use iron::headers::SetCookie;
use iron::headers::Cookie as IronCookie;
use cookie::{Cookie, CookieJar};
use rand::{thread_rng, Rng};
use rustc_serialize::{Encodable, Decodable};
use bincode::SizeLimit;
use bincode::rustc_serialize::{encode, decode};
use redis::{RedisResult, ToRedisArgs, FromRedisValue, RedisError, Value};
use std::sync::Arc;
use std::io::{Error, ErrorKind};
use std::default::Default;
use r2d2::{Pool, PooledConnection};
use r2d2_redis::RedisConnectionManager;
/// Name of the signed cookie that carries the session id.
const COOKIE_SESSION_KEY: &'static str = "RUSESSION";
/// Error reported by the session accessors when the request's extensions do
/// not contain the cache handle and session id inserted by the `Session`
/// middleware.
const ERROR_MESSAGE: &'static str = "No required session extension fields found in the request. \
                                     Check that the Session middleware has been linked around \
                                     the Chain. If more than one middleware is linked around the \
                                     same chain, Session should be the last one linked around.";
/// Session accessors added to Iron's `Request`.
///
/// As implemented for `Request` in this crate, every method returns a
/// `RedisError` when the `Session` middleware has not populated the request's
/// extensions, or when the underlying Redis operation fails.
pub trait RequestSessionExt {
    /// Reads the value stored under `key` in the current session and decodes it as `T`.
    fn get_session<'c, 'd, T: Decodable>(&'c mut self, key: &'d str) -> RedisResult<T>;

    /// Stores `value` under `key` in the current session.
    fn set_session<'c, 'd, T: Encodable>(&'c mut self,
                                         key: &'d str,
                                         value: &T)
                                         -> RedisResult<()>;

    /// Removes the entry stored under `key` from the current session.
    fn remove_session<'c, 'd>(&'c mut self, key: &'d str) -> RedisResult<()>;

    /// Removes the whole current session.
    fn clear_session<'c, 'd>(&'c mut self) -> RedisResult<()>;
}
/// Iron around-middleware that attaches a Redis-backed session to each request.
pub struct Session {
    cache: Arc<Cache>,
}
impl Session {
    /// Creates the middleware.
    ///
    /// * `signing_key` - secret used to sign and verify the session cookie.
    /// * `expire_seconds` - TTL, in seconds, applied to a session's Redis key
    ///   whenever the session is refreshed.
    /// * `connect_str` - Redis connection string passed to `connect_pool`.
    ///
    /// Connecting does not make this constructor fail. A pool error is stored
    /// and then returned by every later session operation.
    pub fn new<T1, T2>(signing_key: T1, expire_seconds: usize, connect_str: T2) -> Self
        where T1: AsRef<str>,
              T2: AsRef<str>
    {
        let key_bytes = signing_key.as_ref().as_bytes().to_vec();
        let pool = connect_pool(connect_str.as_ref());
        let shared = Cache {
            expire_seconds: expire_seconds,
            signing_key: Arc::new(key_bytes),
            pool: pool,
        };
        Session { cache: Arc::new(shared) }
    }
}
impl AroundMiddleware for Session {
fn around(self, handler: Box<Handler>) -> Box<Handler> {
struct SessionHandler<H: Handler> {
cache: Arc<Cache>,
handler: H,
}
impl<H: Handler> Handler for SessionHandler<H> {
fn handle(&self, req: &mut Request) -> IronResult<Response> {
let (session_id, new_created) = req.headers
.get::<IronCookie>()
.map(|c| c.to_cookie_jar(&self.cache.signing_key))
.unwrap_or(CookieJar::new(&self.cache.signing_key))
.signed()
.iter()
.filter(|c| c.name == COOKIE_SESSION_KEY)
.map(|c| (c.value, false))
.next()
.unwrap_or((thread_rng().gen_ascii_chars().take(32).collect::<String>(), true));
req.extensions.insert::<SessionKey>(Arc::new(session_id.clone()));
req.extensions.insert::<CacheKey>(self.cache.clone());
self.handler.handle(req).map(|mut res| {
if new_created {
let c = CookieJar::new(&self.cache.signing_key);
c.signed().add(Cookie::new(COOKIE_SESSION_KEY.into(), session_id));
res.headers.set(SetCookie(c.delta()));
} else {
self.cache.refresh(&session_id).expect("refresh session error");
}
res
})
}
}
Box::new(SessionHandler {
cache: self.cache,
handler: handler,
})
}
}
struct Cache {
expire_seconds: usize,
signing_key: Arc<Vec<u8>>,
pool: RedisResult<Pool<RedisConnectionManager>>,
}
impl Cache {
fn get_conn(&self) -> RedisResult<PooledConnection<RedisConnectionManager>> {
match self.pool {
Ok(ref p) => p.get().map_err(|e| error(&format!("{}", e))),
Err(ref e) => Err(error(&format!("{}", e))),
}
}
fn refresh(&self, session_id: &str) -> RedisResult<()> {
self.get_conn()
.and_then(|c| redis::cmd("EXPIRE").arg(session_id).arg(self.expire_seconds).query(&*c))
}
fn set<T: Encodable>(&self, session_id: &str, key: &str, value: &T) -> RedisResult<()> {
self.get_conn().and_then(|c| {
redis::cmd("HSET").arg(session_id).arg(key).arg(EncodeWrapper(value)).query(&*c)
})
}
fn get<T: Decodable>(&self, session_id: &str, key: &str) -> RedisResult<T> {
self.get_conn().and_then(|c| {
redis::cmd("HGET").arg(session_id).arg(key).query::<DecodeWrapper<T>>(&*c).map(|v| v.0)
})
}
fn remove(&self, session_id: &str, key: &str) -> RedisResult<()> {
self.get_conn().and_then(|c| redis::cmd("HDEL").arg(session_id).arg(key).query(&*c))
}
fn clear(&self, session_id: &str) -> RedisResult<()> {
self.get_conn().and_then(|c| redis::cmd("DEL").arg(session_id).query(&*c))
}
}
fn connect_pool(connect_str: &str) -> RedisResult<Pool<RedisConnectionManager>> {
let cache = Default::default();
info!("Connecting to {}", connect_str);
RedisConnectionManager::new(connect_str)
.map_err(|e| error(&format!("error occurs in RedisConnectionManager::new():{}", e)))
.and_then(|m| {
Pool::new(cache, m)
.map_err(|e| error(&format!("Error connecting redis {},{}", connect_str, e)))
})
}
fn get_extension<'a>(req: &'a mut Request) -> RedisResult<(&'a Arc<Cache>, &'a Arc<String>)> {
match req.extensions.get::<CacheKey>() {
Some(cache) => {
match req.extensions.get::<SessionKey>() {
Some(session_id) => Ok((cache, session_id)),
None => Err(error(ERROR_MESSAGE)),
}
}
None => Err(error(ERROR_MESSAGE)),
}
}
impl<'a, 'b> RequestSessionExt for Request<'a, 'b> {
fn get_session<'c, 'd, T: Decodable>(&'c mut self, key: &'d str) -> RedisResult<T> {
get_extension(self).and_then(|(c, s)| c.get(s, key))
}
fn set_session<'c, 'd, T: Encodable>(&'c mut self, key: &'d str, value: &T) -> RedisResult<()> {
get_extension(self).and_then(|(c, s)| c.set(s, key, value))
}
fn remove_session<'c, 'd>(&'c mut self, key: &'d str) -> RedisResult<()> {
get_extension(self).and_then(|(c, s)| c.remove(s, key))
}
fn clear_session<'c, 'd>(&'c mut self) -> RedisResult<()> {
get_extension(self).and_then(|(c, s)| c.clear(s))
}
}
/// Typemap key for the current request's session id in `Request::extensions`.
struct SessionKey;

/// Typemap key for the shared `Cache` handle in `Request::extensions`.
struct CacheKey;

impl Key for SessionKey {
    type Value = Arc<String>;
}

impl Key for CacheKey {
    type Value = Arc<Cache>;
}
/// Adapter that passes any `Encodable` value to a Redis command as a single
/// bincode-encoded binary argument.
struct EncodeWrapper<'a, T: 'a>(&'a T);
impl<'a, T: Encodable + 'a> ToRedisArgs for EncodeWrapper<'a, T> {
    /// # Panics
    /// Panics if bincode fails to encode the wrapped value, because this trait
    /// method has no way to return an error.
    fn to_redis_args(&self) -> Vec<Vec<u8>> {
        let bytes = encode(self.0, SizeLimit::Infinite)
            .unwrap_or_else(|e| panic!("error occurs in to_redis_args,error is:{}", e));
        vec![bytes]
    }
}
fn error(e: &str) -> RedisError {
RedisError::from(Error::new(ErrorKind::Other, e))
}
struct DecodeWrapper<T>(T);
impl<T: Decodable> FromRedisValue for DecodeWrapper<T> {
fn from_redis_value(v: &Value) -> RedisResult<DecodeWrapper<T>> {
match *v {
Value::Data(ref items) => {
decode(&items[..])
.map(|v| DecodeWrapper(v))
.map_err(|e| error(&format!("{}", e)))
}
_ => Err(error("no session found!")),
}
}
}