1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
use std::error::Error;
use std::result::Result;
use nickel::{Request, Response, Middleware, Continue, MiddlewareResult};
use nickel::status::StatusCode;
use r2d2_sqlite::SqliteConnectionManager;
use r2d2::{Config, Pool, PooledConnection, GetTimeout};
use typemap::Key;
use plugin::Extensible;

/// Nickel middleware that makes an r2d2 SQLite connection pool available to routes.
///
/// For every request it stores a clone of `pool` in the request extensions,
/// where `SqliteRequestExtensions::db_conn` picks it up.
pub struct SqliteMiddleware {
  /// Pool that request handlers check connections out of.
  pub pool: Pool<SqliteConnectionManager>,
}

impl SqliteMiddleware {

  /// Create middleware using defaults
  ///
  /// The middleware will be setup with the r2d2 defaults.
  pub fn new(db_url: &str) -> Result<SqliteMiddleware, Box<Error>> {
      let manager = SqliteConnectionManager::new(db_url);
      let pool = try!(Pool::new(Config::default(), manager));
      Ok(SqliteMiddleware { pool: pool })
  }

  /// Create middleware using pre-built `r2d2::Pool`
  ///
  /// This allows the caller to create and configure the pool with specific settings.
  pub fn with_pool(pool: Pool<SqliteConnectionManager>) -> SqliteMiddleware {
    SqliteMiddleware { pool: pool }
  }
}

// `SqliteMiddleware` doubles as the typemap key for its pool: `invoke` inserts
// the pool under this key and `db_conn` looks it up again.
impl Key for SqliteMiddleware {
  type Value = Pool<SqliteConnectionManager>;
}

impl<D> Middleware<D> for SqliteMiddleware {
  /// Attaches a clone of the connection pool to the request's extensions,
  /// keyed by `SqliteMiddleware`, then lets the chain continue with the
  /// response untouched.
  fn invoke<'a>(&self, req: &mut Request<D>, res: Response<'a, D>) -> MiddlewareResult<'a, D> {
    let shared_pool = self.pool.clone();
    req.extensions_mut().insert::<SqliteMiddleware>(shared_pool);
    Ok(Continue(res))
  }
}

/// Extends `nickel::Request` with a `db_conn()` helper.
///
/// The helper depends on `SqliteMiddleware` being registered: it reads the
/// pool that the middleware stores on each request. The `Request`
/// implementation in this module panics if the middleware is missing.
///
/// Failures come back as a `(StatusCode, error)` tuple, the shape Nickel's
/// `try_with!` macro expects, so a route can simply write:
///
/// ```ignore
/// app.get("/my_counter", middleware! { |request, response|
///     let db = try_with!(response, request.db_conn());
/// });
/// ```
pub trait SqliteRequestExtensions {
  /// Checks a connection out of the SQLite pool attached to this request.
  fn db_conn(&self) -> Result<PooledConnection<SqliteConnectionManager>, (StatusCode, GetTimeout)>;
}

impl<'a, 'b, D> SqliteRequestExtensions for Request<'a, 'b, D> {
  /// Checks a connection out of the pool that `SqliteMiddleware` attached to
  /// this request.
  ///
  /// # Errors
  ///
  /// If the pool's `get()` fails, returns `StatusCode::InternalServerError`
  /// paired with the pool error, ready to be used with Nickel's `try_with!`.
  ///
  /// # Panics
  ///
  /// Panics if `SqliteMiddleware` was not registered, because no pool is then
  /// present in the request extensions. This is a setup bug, not a runtime
  /// condition.
  fn db_conn(&self) -> Result<PooledConnection<SqliteConnectionManager>, (StatusCode, GetTimeout)> {
    self.extensions()
        .get::<SqliteMiddleware>()
        .expect("SqliteMiddleware must be registered before using SqliteRequestExtensions")
        .get()
        // Only the error needs reshaping into Nickel's (status, error) tuple.
        .map_err(|err| (StatusCode::InternalServerError, err))
  }
}