1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
// Copyright (c) 2016 The Rouille developers
// Licensed under the Apache License, Version 2.0
// <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT
// license <LICENSE-MIT or http://opensource.org/licenses/MIT>,
// at your option. All files in the project carrying such
// notice may not be copied, modified, or distributed except
// according to those terms.

//! Sessions handling.
//!
//! The main feature of this module is the `session` function, which handles a session. This
//! function guarantees that a single unique identifier is assigned to each client. This identifier
//! is accessible through the `Session` parameter passed to the inner closure, via `Session::id()`.
//!
//! # Basic example
//!
//! Here is a basic example showing how to get a session ID.
//!
//! ```
//! use rouille::Request;
//! use rouille::Response;
//! use rouille::session;
//!
//! fn handle_request(request: &Request) -> Response {
//!     session::session(request, "SID", 3600, |session| {
//!         let id: &str = session.id();
//!
//!         // This id is unique to each client.
//!
//!         Response::text(format!("Session ID: {}", id))
//!     })
//! }
//! ```

use std::borrow::Cow;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use rand;
use rand::Rng;
use rand::distributions::Alphanumeric;

use Request;
use Response;
use input;

pub fn session<'r, F>(request: &'r Request, cookie_name: &str, timeout_s: u64, inner: F) -> Response
    where F: FnOnce(&Session<'r>) -> Response
{
    let mut cookie = input::cookies(request).into_iter();
    let cookie = cookie.find(|&(ref k, _)| k == &cookie_name);
    let cookie = cookie.map(|(_, v)| v);

    let session = if let Some(cookie) = cookie {
        Session {
            key_was_retreived: AtomicBool::new(false),
            key_was_given: true,
            key: cookie.into(),
        }
    } else {
        Session {
            key_was_retreived: AtomicBool::new(false),
            key_was_given: false,
            key: generate_session_id().into(),
        }
    };

    let mut response = inner(&session);

    if session.key_was_retreived.load(Ordering::Relaxed) {       // TODO: use `get_mut()`
        // FIXME: correct interactions with existing headers
        // TODO: allow setting domain
        let header_value = format!("{}={}; Max-Age={}; Path=/; HttpOnly",
                                    cookie_name, session.key, timeout_s);
        response.headers.push(("Set-Cookie".into(), header_value.into()));
    }

    response
}

/// Contains the ID of the session.
///
/// A `Session` is created by `session()` and passed by reference to its closure. Reading the ID
/// through `id()` is what makes `session()` send the ID back to the client in a `Set-Cookie`
/// header.
pub struct Session<'r> {
    // Set to true by `id()`; checked by `session()` after the closure returns to decide whether
    // to emit `Set-Cookie`. Atomic so that `id()` can record it through a shared `&self`.
    key_was_retreived: AtomicBool,
    // True when the ID came from the client's cookie rather than being freshly generated.
    key_was_given: bool,
    // The session ID itself.
    key: Cow<'r, str>,
}

impl<'r> Session<'r> {
    /// Returns true if the client gave us a session ID.
    ///
    /// If this returns false, the ID returned by `id()` was generated for this request, so we
    /// are sure that no data is available for it yet.
    #[inline]
    pub fn client_has_sid(&self) -> bool {
        self.key_was_given
    }

    /// Returns the ID of the session.
    ///
    /// Calling this marks the ID as used: when the closure passed to `session()` returns, a
    /// `Set-Cookie` header carrying this ID is added to its response. If `id()` is never called,
    /// no cookie is sent.
    #[inline]
    pub fn id(&self) -> &str {
        self.key_was_retreived.store(true, Ordering::Relaxed);
        &self.key
    }

    // TODO: add a `regenerate_id()` method that replaces the value returned by `id()` with a
    // freshly generated ID.
}

/// Generates a random string suitable for use as a session ID.
///
/// The result is 64 ASCII alphanumeric characters (`[A-Za-z0-9]`) drawn from the operating
/// system's random number generator. It never contains punctuation, quotes, brackets or any
/// other character that would need to be escaped.
///
/// # Panics
///
/// Panics if the OS random number generator cannot be initialized.
pub fn generate_session_id() -> String {
    // 62^64 ≈ 5e+114 possible IDs, which is reasonable.
    const SESSION_ID_LEN: usize = 64;

    // TODO: handle a failure to initialize the OS RNG instead of panicking?
    rand::OsRng::new()
        .expect("Failed to initialize OsRng")
        .sample_iter(&Alphanumeric)
        .filter(|c: &char| c.is_ascii_alphanumeric())
        .take(SESSION_ID_LEN)
        .collect()
}