//! session 0.1.8
//!
//! Iron session middleware backed by Redis.
extern crate bincode;
extern crate iron;
extern crate cookie;
extern crate rand;
extern crate r2d2;
extern crate r2d2_redis;
extern crate rustc_serialize;
extern crate redis;
#[macro_use]
extern crate log;

use iron::prelude::*;
use iron::middleware::{Handler, AroundMiddleware};
use iron::typemap::Key;
use iron::headers::SetCookie;
use iron::headers::Cookie as IronCookie;
use cookie::{Cookie, CookieJar};
use rand::{thread_rng, Rng};
use rustc_serialize::{Encodable, Decodable};
use bincode::SizeLimit;
use bincode::rustc_serialize::{encode, decode};
use redis::{RedisResult, ToRedisArgs, FromRedisValue, RedisError, Value};
use std::sync::Arc;
use std::io::{Error, ErrorKind};
use std::default::Default;
use r2d2::{Pool, PooledConnection};
use r2d2_redis::RedisConnectionManager;

/// Name of the signed cookie that carries the session id.
const COOKIE_SESSION_KEY: &'static str = "RUSESSION";
/// Error reported by the session accessors when the middleware's request extensions are absent.
const ERROR_MESSAGE: &'static str = "No required extension fields found in request. Check that \
                                     session::Session has been linked around the Chain. If more \
                                     than one around middleware is linked to the same chain, \
                                     session::Session should be placed as the last around.";

/// Session accessors added to Iron's `Request`.
///
/// In this crate's `Request` implementation, every method first looks up the session id and
/// the shared Redis-backed cache that the `Session` middleware stored in the request
/// extensions. If they are missing (the middleware was not linked around the handler), an
/// error is returned instead of touching Redis.
pub trait RequestSessionExt {
    /// Reads and bincode-decodes the value stored under `key` in the current session.
    /// Returns an error if the field is absent or cannot be decoded.
    fn get_session<'c, 'd, T: Decodable>(&'c mut self, key: &'d str) -> RedisResult<T>;
    /// Deletes the field `key` from the current session.
    fn remove_session<'c, 'd>(&'c mut self, key: &'d str) -> RedisResult<()>;
    /// Bincode-encodes `value` and stores it under `key` in the current session.
    fn set_session<'c, 'd, T: Encodable>(&'c mut self, key: &'d str, value: &T) -> RedisResult<()>;
    /// Deletes the whole current session.
    fn clear_session<'c, 'd>(&'c mut self) -> RedisResult<()>;
}

/// Iron `AroundMiddleware` that attaches a Redis-backed session to every request.
pub struct Session {
    cache: Arc<Cache>,
}

impl Session {
    /// Builds the middleware.
    ///
    /// * `signing_key` - secret used to sign the session cookie.
    /// * `expire_seconds` - TTL (in seconds) applied to session data in Redis.
    /// * `connect_str` - Redis connection string handed to `RedisConnectionManager`.
    ///
    /// Connection failures are not reported here. The pool result is kept, and any error
    /// surfaces on the first session operation.
    pub fn new<T1, T2>(signing_key: T1, expire_seconds: usize, connect_str: T2) -> Self
        where T1: AsRef<str>,
              T2: AsRef<str>
    {
        let key_bytes = signing_key.as_ref().as_bytes().to_vec();
        let pool = connect_pool(connect_str.as_ref());
        Session {
            cache: Arc::new(Cache {
                signing_key: Arc::new(key_bytes),
                expire_seconds: expire_seconds,
                pool: pool,
            }),
        }
    }
}

impl AroundMiddleware for Session {
    fn around(self, handler: Box<Handler>) -> Box<Handler> {
        struct SessionHandler<H: Handler> {
            cache: Arc<Cache>,
            handler: H,
        }
        impl<H: Handler> Handler for SessionHandler<H> {
            fn handle(&self, req: &mut Request) -> IronResult<Response> {
                let (session_id, new_created) = req.headers
                    .get::<IronCookie>()
                    .map(|c| c.to_cookie_jar(&self.cache.signing_key))
                    .unwrap_or(CookieJar::new(&self.cache.signing_key))
                    .signed()
                    .iter()
                    .filter(|c| c.name == COOKIE_SESSION_KEY)
                    .map(|c| (c.value, false))
                    .next()
                    .unwrap_or((thread_rng().gen_ascii_chars().take(32).collect::<String>(), true));

                req.extensions.insert::<SessionKey>(Arc::new(session_id.clone()));
                req.extensions.insert::<CacheKey>(self.cache.clone());

                self.handler.handle(req).map(|mut res| {
                    if new_created {
                        let c = CookieJar::new(&self.cache.signing_key);
                        c.signed().add(Cookie::new(COOKIE_SESSION_KEY.into(), session_id));
                        res.headers.set(SetCookie(c.delta()));
                    } else {
                        self.cache.refresh(&session_id).expect("refresh session error");
                    }
                    res
                })
            }
        }
        Box::new(SessionHandler {
            cache: self.cache,
            handler: handler,
        })
    }
}

struct Cache {
    expire_seconds: usize,
    signing_key: Arc<Vec<u8>>,
    pool: RedisResult<Pool<RedisConnectionManager>>,
}

impl Cache {
    fn get_conn(&self) -> RedisResult<PooledConnection<RedisConnectionManager>> {
        match self.pool {
            Ok(ref p) => p.get().map_err(|e| error(&format!("{}", e))),
            Err(ref e) => Err(error(&format!("{}", e))),
        }
    }

    fn refresh(&self, session_id: &str) -> RedisResult<()> {
        self.get_conn()
            .and_then(|c| redis::cmd("EXPIRE").arg(session_id).arg(self.expire_seconds).query(&*c))
    }
    fn set<T: Encodable>(&self, session_id: &str, key: &str, value: &T) -> RedisResult<()> {
        self.get_conn().and_then(|c| {
            redis::cmd("HSET").arg(session_id).arg(key).arg(EncodeWrapper(value)).query(&*c)
        })
    }

    fn get<T: Decodable>(&self, session_id: &str, key: &str) -> RedisResult<T> {
        self.get_conn().and_then(|c| {
            redis::cmd("HGET").arg(session_id).arg(key).query::<DecodeWrapper<T>>(&*c).map(|v| v.0)
        })
    }

    fn remove(&self, session_id: &str, key: &str) -> RedisResult<()> {
        self.get_conn().and_then(|c| redis::cmd("HDEL").arg(session_id).arg(key).query(&*c))
    }

    fn clear(&self, session_id: &str) -> RedisResult<()> {
        self.get_conn().and_then(|c| redis::cmd("DEL").arg(session_id).query(&*c))
    }
}

fn connect_pool(connect_str: &str) -> RedisResult<Pool<RedisConnectionManager>> {
    let cache = Default::default();
    info!("Connecting to {}", connect_str);
    RedisConnectionManager::new(connect_str)
        .map_err(|e| error(&format!("error occurs in RedisConnectionManager::new():{}", e)))
        .and_then(|m| {
            Pool::new(cache, m)
                .map_err(|e| error(&format!("Error connecting redis {},{}", connect_str, e)))
        })
}

fn get_extension<'a>(req: &'a mut Request) -> RedisResult<(&'a Arc<Cache>, &'a Arc<String>)> {
    match req.extensions.get::<CacheKey>() {
        Some(cache) => {
            match req.extensions.get::<SessionKey>() {
                Some(session_id) => Ok((cache, session_id)),
                None => Err(error(ERROR_MESSAGE)),
            }
        }
        None => Err(error(ERROR_MESSAGE)),
    }
}

/// Session accessors for Iron requests. Each method resolves the cache and session id from
/// the request extensions and forwards to the corresponding `Cache` operation.
impl<'a, 'b> RequestSessionExt for Request<'a, 'b> {
    fn get_session<'c, 'd, T: Decodable>(&'c mut self, key: &'d str) -> RedisResult<T> {
        match get_extension(self) {
            Ok((cache, session_id)) => cache.get(session_id.as_str(), key),
            Err(e) => Err(e),
        }
    }
    fn set_session<'c, 'd, T: Encodable>(&'c mut self, key: &'d str, value: &T) -> RedisResult<()> {
        match get_extension(self) {
            Ok((cache, session_id)) => cache.set(session_id.as_str(), key, value),
            Err(e) => Err(e),
        }
    }
    fn remove_session<'c, 'd>(&'c mut self, key: &'d str) -> RedisResult<()> {
        match get_extension(self) {
            Ok((cache, session_id)) => cache.remove(session_id.as_str(), key),
            Err(e) => Err(e),
        }
    }
    fn clear_session<'c, 'd>(&'c mut self) -> RedisResult<()> {
        match get_extension(self) {
            Ok((cache, session_id)) => cache.clear(session_id.as_str()),
            Err(e) => Err(e),
        }
    }
}

/// Typemap key under which the middleware stores the current session id in the request
/// extensions.
struct SessionKey;
impl Key for SessionKey {
    type Value = Arc<String>;
}

/// Typemap key under which the middleware stores the shared `Cache` in the request extensions.
struct CacheKey;
impl Key for CacheKey {
    type Value = Arc<Cache>;
}



/// Command-argument wrapper that bincode-encodes a borrowed value into one Redis argument.
struct EncodeWrapper<'a, T: 'a>(&'a T);
impl<'a, T: Encodable + 'a> ToRedisArgs for EncodeWrapper<'a, T> {
    /// # Panics
    ///
    /// Panics if bincode cannot encode the value, because `to_redis_args` has no way to
    /// return an error.
    fn to_redis_args(&self) -> Vec<Vec<u8>> {
        let encoded = encode(self.0, SizeLimit::Infinite)
            .unwrap_or_else(|e| panic!("error occurs in to_redis_args,error is:{}", e));
        vec![encoded]
    }
}

fn error(e: &str) -> RedisError {
    RedisError::from(Error::new(ErrorKind::Other, e))
}

struct DecodeWrapper<T>(T);
impl<T: Decodable> FromRedisValue for DecodeWrapper<T> {
    fn from_redis_value(v: &Value) -> RedisResult<DecodeWrapper<T>> {
        match *v {
            Value::Data(ref items) => {
                decode(&items[..])
                    .map(|v| DecodeWrapper(v))
                    .map_err(|e| error(&format!("{}", e)))
            }
            _ => Err(error("no session found!")),
        }
    }
}