pub(in crate::biome) mod models;
mod operations;
pub(in crate::biome) mod schema;
use std::sync::{Arc, RwLock};
use diesel::r2d2::{ConnectionManager, Pool};
use crate::store::pool::ConnectionPool;
use super::{
InsertableOAuthUserSession, OAuthUser, OAuthUserIter, OAuthUserSession, OAuthUserSessionStore,
OAuthUserSessionStoreError,
};
use operations::{
add_session::OAuthUserSessionStoreAddSession as _,
get_session::OAuthUserSessionStoreGetSession as _, get_user::OAuthUserSessionStoreGetUser as _,
list_users::OAuthUserSessionStoreListUsers as _,
remove_session::OAuthUserSessionStoreRemoveSession as _,
update_session::OAuthUserSessionStoreUpdateSession as _, OAuthUserSessionStoreOperations,
};
/// An [`OAuthUserSessionStore`] implementation backed by a Diesel connection pool.
///
/// All database access goes through the wrapped [`ConnectionPool`]. The trait impls in this
/// module use its `execute_write` for mutating operations and its `execute_read` for queries.
pub struct DieselOAuthUserSessionStore<C: diesel::Connection + 'static> {
    connection_pool: ConnectionPool<C>,
}

impl<C: diesel::Connection + 'static> DieselOAuthUserSessionStore<C> {
    /// Builds a store on top of a plain r2d2 connection pool.
    ///
    /// # Arguments
    ///
    /// * `connection_pool` - the r2d2 pool of Diesel connections; it is converted into the
    ///   crate's `ConnectionPool` wrapper.
    pub fn new(connection_pool: Pool<ConnectionManager<C>>) -> Self {
        let connection_pool: ConnectionPool<C> = connection_pool.into();
        Self { connection_pool }
    }

    /// Builds a store on top of an r2d2 pool that is shared behind an `Arc<RwLock<_>>`.
    ///
    /// NOTE(review): judging by the name, the lock is presumably used by `ConnectionPool` to
    /// give writes exclusive access to the pool (e.g. for SQLite) — confirm against
    /// `crate::store::pool`.
    ///
    /// # Arguments
    ///
    /// * `connection_pool` - the shared, lock-guarded r2d2 pool; it is converted into the
    ///   crate's `ConnectionPool` wrapper.
    pub fn new_with_write_exclusivity(
        connection_pool: Arc<RwLock<Pool<ConnectionManager<C>>>>,
    ) -> Self {
        let connection_pool: ConnectionPool<C> = connection_pool.into();
        Self { connection_pool }
    }
}
#[cfg(feature = "sqlite")]
impl OAuthUserSessionStore for DieselOAuthUserSessionStore<diesel::sqlite::SqliteConnection> {
    /// Inserts `session`, running the operation with the pool's write access.
    fn add_session(
        &self,
        session: InsertableOAuthUserSession,
    ) -> Result<(), OAuthUserSessionStoreError> {
        self.connection_pool.execute_write(|conn| {
            let operations = OAuthUserSessionStoreOperations::new(conn);
            operations.add_session(session)
        })
    }

    /// Updates an existing session from `session`, running with the pool's write access.
    fn update_session(
        &self,
        session: InsertableOAuthUserSession,
    ) -> Result<(), OAuthUserSessionStoreError> {
        self.connection_pool.execute_write(|conn| {
            let operations = OAuthUserSessionStoreOperations::new(conn);
            operations.update_session(session)
        })
    }

    /// Deletes the session identified by `splinter_access_token`, with write access.
    fn remove_session(
        &self,
        splinter_access_token: &str,
    ) -> Result<(), OAuthUserSessionStoreError> {
        self.connection_pool.execute_write(|conn| {
            let operations = OAuthUserSessionStoreOperations::new(conn);
            operations.remove_session(splinter_access_token)
        })
    }

    /// Looks up the session identified by `splinter_access_token`, with read access.
    fn get_session(
        &self,
        splinter_access_token: &str,
    ) -> Result<Option<OAuthUserSession>, OAuthUserSessionStoreError> {
        self.connection_pool.execute_read(|conn| {
            let operations = OAuthUserSessionStoreOperations::new(conn);
            operations.get_session(splinter_access_token)
        })
    }

    /// Looks up the user with the given `subject`, with read access.
    fn get_user(&self, subject: &str) -> Result<Option<OAuthUser>, OAuthUserSessionStoreError> {
        self.connection_pool.execute_read(|conn| {
            let operations = OAuthUserSessionStoreOperations::new(conn);
            operations.get_user(subject)
        })
    }

    /// Lists all stored users, with read access.
    fn list_users(&self) -> Result<OAuthUserIter, OAuthUserSessionStoreError> {
        self.connection_pool.execute_read(|conn| {
            let operations = OAuthUserSessionStoreOperations::new(conn);
            operations.list_users()
        })
    }

    /// Returns a boxed copy of this store that shares a clone of the connection pool.
    fn clone_box(&self) -> Box<dyn OAuthUserSessionStore> {
        let connection_pool = self.connection_pool.clone();
        Box::new(Self { connection_pool })
    }
}
#[cfg(feature = "postgres")]
impl OAuthUserSessionStore for DieselOAuthUserSessionStore<diesel::pg::PgConnection> {
    /// Inserts `session`, running the operation with the pool's write access.
    fn add_session(
        &self,
        session: InsertableOAuthUserSession,
    ) -> Result<(), OAuthUserSessionStoreError> {
        self.connection_pool.execute_write(|conn| {
            let operations = OAuthUserSessionStoreOperations::new(conn);
            operations.add_session(session)
        })
    }

    /// Updates an existing session from `session`, running with the pool's write access.
    fn update_session(
        &self,
        session: InsertableOAuthUserSession,
    ) -> Result<(), OAuthUserSessionStoreError> {
        self.connection_pool.execute_write(|conn| {
            let operations = OAuthUserSessionStoreOperations::new(conn);
            operations.update_session(session)
        })
    }

    /// Deletes the session identified by `splinter_access_token`, with write access.
    fn remove_session(
        &self,
        splinter_access_token: &str,
    ) -> Result<(), OAuthUserSessionStoreError> {
        self.connection_pool.execute_write(|conn| {
            let operations = OAuthUserSessionStoreOperations::new(conn);
            operations.remove_session(splinter_access_token)
        })
    }

    /// Looks up the session identified by `splinter_access_token`, with read access.
    fn get_session(
        &self,
        splinter_access_token: &str,
    ) -> Result<Option<OAuthUserSession>, OAuthUserSessionStoreError> {
        self.connection_pool.execute_read(|conn| {
            let operations = OAuthUserSessionStoreOperations::new(conn);
            operations.get_session(splinter_access_token)
        })
    }

    /// Looks up the user with the given `subject`, with read access.
    fn get_user(&self, subject: &str) -> Result<Option<OAuthUser>, OAuthUserSessionStoreError> {
        self.connection_pool.execute_read(|conn| {
            let operations = OAuthUserSessionStoreOperations::new(conn);
            operations.get_user(subject)
        })
    }

    /// Lists all stored users, with read access.
    fn list_users(&self) -> Result<OAuthUserIter, OAuthUserSessionStoreError> {
        self.connection_pool.execute_read(|conn| {
            let operations = OAuthUserSessionStoreOperations::new(conn);
            operations.list_users()
        })
    }

    /// Returns a boxed copy of this store that shares a clone of the connection pool.
    fn clone_box(&self) -> Box<dyn OAuthUserSessionStore> {
        let connection_pool = self.connection_pool.clone();
        Box::new(Self { connection_pool })
    }
}
#[cfg(all(test, feature = "sqlite"))]
pub mod tests {
    use super::*;
    use crate::biome::oauth::store::InsertableOAuthUserSessionBuilder;
    use crate::migrations::run_sqlite_migrations;
    use diesel::{
        r2d2::{ConnectionManager, Pool},
        sqlite::SqliteConnection,
    };

    /// Verifies that a session can be added and read back, that an unknown token yields
    /// `None`, and that adding a second session with an already-used splinter access token is
    /// rejected with `ConstraintViolation`.
    #[test]
    fn sqlite_add_and_get_session() {
        let pool = create_connection_pool_and_migrate();
        let oauth_user_session_store = DieselOAuthUserSessionStore::new(pool);

        let splinter_access_token = "splinter_access_token";
        let subject = "subject";
        let oauth_access_token = "oauth_access_token";

        oauth_user_session_store
            .add_session(build_session(splinter_access_token, subject, oauth_access_token))
            .expect("Unable to add session");

        let session = oauth_user_session_store
            .get_session(splinter_access_token)
            .expect("Unable to get session")
            .expect("Session not found");
        assert_eq!(session.splinter_access_token(), splinter_access_token);
        assert_eq!(session.user().subject(), subject);
        assert_eq!(session.oauth_access_token(), oauth_access_token);
        assert_eq!(session.oauth_refresh_token(), None);

        assert!(oauth_user_session_store
            .get_session("NonExistentToken")
            .expect("Unable to query non-existent token")
            .is_none());

        // Reuse the *same* splinter access token (via the variable, so the two cannot drift
        // apart) with different subject/OAuth token; the splinter token must be unique.
        let non_unique_session = build_session(
            splinter_access_token,
            "different_subject",
            "different_oauth_access_token",
        );
        assert!(matches!(
            oauth_user_session_store.add_session(non_unique_session),
            Err(OAuthUserSessionStoreError::ConstraintViolation(_)),
        ));
    }

    /// Verifies that updating a session replaces its OAuth tokens and advances
    /// `last_authenticated`, that updating an unknown token is `InvalidState`, and that an
    /// update may not change the session's subject (`InvalidArgument`).
    #[test]
    fn sqlite_update_session() {
        let pool = create_connection_pool_and_migrate();
        let oauth_user_session_store = DieselOAuthUserSessionStore::new(pool);

        let splinter_access_token = "splinter_access_token";
        let subject = "subject";
        let oauth_access_token = "oauth_access_token";

        oauth_user_session_store
            .add_session(build_session(splinter_access_token, subject, oauth_access_token))
            .expect("Unable to add session");

        let session = oauth_user_session_store
            .get_session(splinter_access_token)
            .expect("Unable to get session")
            .expect("Session not found");
        let originally_authenticated = session.last_authenticated();

        // The assertion below requires a strictly later `last_authenticated` after the update,
        // so wall-clock time must advance between the two writes.
        // NOTE(review): presumably timestamps are persisted at coarse (second) resolution;
        // confirm before shortening this sleep.
        std::thread::sleep(std::time::Duration::from_secs(5));

        let updated_oauth_access_token = "updated_oauth_access_token";
        let updated_oauth_refresh_token = "updated_oauth_refresh_token";
        let updated_session = session
            .into_update_builder()
            .with_oauth_access_token(updated_oauth_access_token.into())
            .with_oauth_refresh_token(Some(updated_oauth_refresh_token.into()))
            .build();
        oauth_user_session_store
            .update_session(updated_session)
            .expect("Unable to update session");

        let updated_session = oauth_user_session_store
            .get_session(splinter_access_token)
            .expect("Unable to get updated session")
            .expect("Updated session not found");
        assert_eq!(
            updated_session.oauth_access_token(),
            updated_oauth_access_token
        );
        assert_eq!(
            updated_session.oauth_refresh_token(),
            Some(updated_oauth_refresh_token)
        );
        assert!(updated_session.last_authenticated() > originally_authenticated);

        // Updating a session whose token was never added is an invalid state.
        let non_existent_session =
            build_session("NonExistentToken", subject, oauth_access_token);
        assert!(matches!(
            oauth_user_session_store.update_session(non_existent_session),
            Err(OAuthUserSessionStoreError::InvalidState(_)),
        ));

        // An update that tries to move the session to a different subject is rejected.
        let update_subject =
            build_session(splinter_access_token, "updated_subject", oauth_access_token);
        assert!(matches!(
            oauth_user_session_store.update_session(update_subject),
            Err(OAuthUserSessionStoreError::InvalidArgument(_)),
        ));
    }

    /// Verifies that a removed session can no longer be fetched and that removing an unknown
    /// token is `InvalidState`.
    #[test]
    fn sqlite_remove_session() {
        let pool = create_connection_pool_and_migrate();
        let oauth_user_session_store = DieselOAuthUserSessionStore::new(pool);

        let splinter_access_token = "splinter_access_token";
        oauth_user_session_store
            .add_session(build_session(
                splinter_access_token,
                "subject",
                "oauth_access_token",
            ))
            .expect("Unable to add session");

        oauth_user_session_store
            .remove_session(splinter_access_token)
            .expect("Unable to remove session");
        assert!(oauth_user_session_store
            .get_session(splinter_access_token)
            .expect("Unable to attempt to get session")
            .is_none());

        assert!(matches!(
            oauth_user_session_store.remove_session("NonExistentToken"),
            Err(OAuthUserSessionStoreError::InvalidState(_)),
        ));
    }

    /// Verifies that sessions sharing a subject resolve to the same user (same `user_id`), and
    /// that the user remains retrievable after all of its sessions are removed.
    #[test]
    fn sqlite_get_user() {
        let pool = create_connection_pool_and_migrate();
        let oauth_user_session_store = DieselOAuthUserSessionStore::new(pool);

        let subject = "subject";

        let splinter_access_token1 = "splinter_access_token1";
        oauth_user_session_store
            .add_session(build_session(
                splinter_access_token1,
                subject,
                "oauth_access_token1",
            ))
            .expect("Unable to add session1");

        let user = oauth_user_session_store
            .get_user(subject)
            .expect("Unable to get user")
            .expect("User not found");
        assert_eq!(user.subject(), subject);

        let splinter_access_token2 = "splinter_access_token2";
        oauth_user_session_store
            .add_session(build_session(
                splinter_access_token2,
                subject,
                "oauth_access_token2",
            ))
            .expect("Unable to add session2");

        let same_user = oauth_user_session_store
            .get_user(subject)
            .expect("Unable to get user")
            .expect("User not found");
        assert_eq!(user.subject(), same_user.subject());
        assert_eq!(user.user_id(), same_user.user_id());

        oauth_user_session_store
            .remove_session(splinter_access_token1)
            .expect("Unable to remove session1");
        oauth_user_session_store
            .remove_session(splinter_access_token2)
            .expect("Unable to remove session2");

        // Users outlive their sessions.
        let still_the_same_user = oauth_user_session_store
            .get_user(subject)
            .expect("Unable to get user")
            .expect("User not found");
        assert_eq!(user.subject(), still_the_same_user.subject());
        assert_eq!(user.user_id(), still_the_same_user.user_id());
    }

    /// Verifies that one subject can hold several concurrent sessions, each retrievable by its
    /// own splinter access token with its own OAuth access token.
    #[test]
    fn sqlite_multiple_sessions() {
        let pool = create_connection_pool_and_migrate();
        let oauth_user_session_store = DieselOAuthUserSessionStore::new(pool);

        let subject = "subject";

        let splinter_access_token1 = "splinter_access_token1";
        let oauth_access_token1 = "oauth_access_token1";
        oauth_user_session_store
            .add_session(build_session(
                splinter_access_token1,
                subject,
                oauth_access_token1,
            ))
            .expect("Unable to add session1");

        let splinter_access_token2 = "splinter_access_token2";
        let oauth_access_token2 = "oauth_access_token2";
        oauth_user_session_store
            .add_session(build_session(
                splinter_access_token2,
                subject,
                oauth_access_token2,
            ))
            .expect("Unable to add session2");

        let stored_session1 = oauth_user_session_store
            .get_session(splinter_access_token1)
            .expect("Unable to get session1")
            .expect("Session1 not found");
        assert_eq!(
            stored_session1.splinter_access_token(),
            splinter_access_token1
        );
        assert_eq!(stored_session1.user().subject(), subject);
        assert_eq!(stored_session1.oauth_access_token(), oauth_access_token1);

        let stored_session2 = oauth_user_session_store
            .get_session(splinter_access_token2)
            .expect("Unable to get session2")
            .expect("Session2 not found");
        assert_eq!(
            stored_session2.splinter_access_token(),
            splinter_access_token2
        );
        assert_eq!(stored_session2.user().subject(), subject);
        assert_eq!(stored_session2.oauth_access_token(), oauth_access_token2);
    }

    /// Verifies that `list_users` returns one entry per distinct subject (not per session) and
    /// that users are still listed after all sessions are removed.
    #[test]
    fn sqlite_list_users() {
        let pool = create_connection_pool_and_migrate();
        let oauth_user_session_store = DieselOAuthUserSessionStore::new(pool);

        let subject = "subject";

        let splinter_access_token1 = "splinter_access_token1";
        oauth_user_session_store
            .add_session(build_session(
                splinter_access_token1,
                subject,
                "oauth_access_token1",
            ))
            .expect("Unable to add session1");

        let users = oauth_user_session_store
            .list_users()
            .expect("Unable to list users")
            .collect::<Vec<OAuthUser>>();
        assert_eq!(users.len(), 1);
        let user = users.get(0).expect("Unable to get user");
        assert_eq!(user.subject(), subject);

        // A second session for the same subject must not create a second user.
        let splinter_access_token2 = "splinter_access_token2";
        oauth_user_session_store
            .add_session(build_session(
                splinter_access_token2,
                subject,
                "oauth_access_token2",
            ))
            .expect("Unable to add session2");

        let users = oauth_user_session_store
            .list_users()
            .expect("Unable to list users")
            .collect::<Vec<OAuthUser>>();
        assert_eq!(users.len(), 1);
        let first_user = users.get(0).expect("Unable to get user");
        assert_eq!(first_user.subject(), subject);

        let splinter_access_token3 = "splinter_access_token3";
        let second_subject = "second_subject";
        oauth_user_session_store
            .add_session(build_session(
                splinter_access_token3,
                second_subject,
                "oauth_access_token3",
            ))
            .expect("Unable to add session3");

        let users = oauth_user_session_store
            .list_users()
            .expect("Unable to list users")
            .collect::<Vec<OAuthUser>>();
        assert_eq!(users.len(), 2);
        // NOTE(review): these index checks assume `list_users` yields users in insertion order;
        // confirm the query specifies an ordering rather than relying on SQLite's default.
        let user = users.get(0).expect("Unable to get user");
        assert_eq!(user.subject(), subject);
        let second_user = users.get(1).expect("Unable to get user");
        assert_eq!(second_user.subject(), second_subject);

        oauth_user_session_store
            .remove_session(splinter_access_token1)
            .expect("Unable to remove session1");
        oauth_user_session_store
            .remove_session(splinter_access_token2)
            .expect("Unable to remove session2");
        oauth_user_session_store
            .remove_session(splinter_access_token3)
            .expect("Unable to remove session3");

        // Removing every session leaves the users in place.
        let users = oauth_user_session_store
            .list_users()
            .expect("Unable to list users")
            .collect::<Vec<OAuthUser>>();
        assert_eq!(users.len(), 2);
    }

    /// Builds an insertable session from the given splinter access token, subject and OAuth
    /// access token, without setting an OAuth refresh token.
    fn build_session(
        splinter_access_token: &str,
        subject: &str,
        oauth_access_token: &str,
    ) -> InsertableOAuthUserSession {
        InsertableOAuthUserSessionBuilder::new()
            .with_splinter_access_token(splinter_access_token.into())
            .with_subject(subject.into())
            .with_oauth_access_token(oauth_access_token.into())
            .build()
            .expect("Unable to build session")
    }

    /// Creates an in-memory SQLite connection pool and runs all SQLite migrations on it.
    fn create_connection_pool_and_migrate() -> Pool<ConnectionManager<SqliteConnection>> {
        let connection_manager = ConnectionManager::<SqliteConnection>::new(":memory:");
        // Every connection to ":memory:" opens its own private database, so the pool must hold
        // exactly one connection for the migrations and the tests to see the same schema/data.
        let pool = Pool::builder()
            .max_size(1)
            .build(connection_manager)
            .expect("Failed to build connection pool");

        run_sqlite_migrations(&*pool.get().expect("Failed to get connection for migrations"))
            .expect("Failed to run migrations");

        pool
    }
}