mod models;
mod operations;
mod schema;
use std::sync::{Arc, RwLock};
use diesel::{
r2d2::{ConnectionManager, Pool},
result,
};
use crate::error::{ConstraintViolationError, ConstraintViolationType, InternalError};
use crate::oauth::PendingAuthorization;
use crate::store::pool::ConnectionPool;
use super::{InflightOAuthRequestStore, InflightOAuthRequestStoreError};
use operations::insert_request::InflightOAuthRequestStoreInsertRequestOperation as _;
use operations::remove_request::InflightOAuthRequestStoreRemoveRequestOperation as _;
use operations::InflightOAuthRequestOperations;
/// Diesel-backed store for in-flight OAuth requests.
///
/// The store owns a [`ConnectionPool`] over the connection type `C`. The
/// `InflightOAuthRequestStore` trait is implemented for the SQLite and
/// PostgreSQL connection types behind the `sqlite` and `postgres` features.
pub struct DieselInflightOAuthRequestStore<C>
where
    C: diesel::Connection + 'static,
{
    /// Pool used to run every store operation.
    connection_pool: ConnectionPool<C>,
}
impl<C: diesel::Connection + 'static> DieselInflightOAuthRequestStore<C> {
pub fn new(connection_pool: Pool<ConnectionManager<C>>) -> Self {
Self {
connection_pool: connection_pool.into(),
}
}
pub fn new_with_write_exclusivity(
connection_pool: Arc<RwLock<Pool<ConnectionManager<C>>>>,
) -> Self {
Self {
connection_pool: connection_pool.into(),
}
}
}
#[cfg(feature = "sqlite")]
impl InflightOAuthRequestStore
    for DieselInflightOAuthRequestStore<diesel::sqlite::SqliteConnection>
{
    /// Stores `pending_authorization` under `request_id`.
    ///
    /// The request ID, PKCE verifier and client redirect URL are packed into
    /// an `OAuthInflightRequest` row model, which is handed to the insert
    /// operation through the pool's `execute_write`.
    ///
    /// # Errors
    ///
    /// Returns whatever `InflightOAuthRequestStoreError` is produced by
    /// `execute_write` or by the insert operation.
    fn insert_request(
        &self,
        request_id: String,
        pending_authorization: PendingAuthorization,
    ) -> Result<(), InflightOAuthRequestStoreError> {
        // Assemble the row up front so the closure only performs the insert.
        let new_request = models::OAuthInflightRequest {
            id: request_id,
            pkce_verifier: pending_authorization.pkce_verifier,
            client_redirect_url: pending_authorization.client_redirect_url,
        };

        self.connection_pool.execute_write(|connection| {
            InflightOAuthRequestOperations::new(connection).insert_request(new_request)
        })
    }

    /// Removes the request identified by `request_id` and returns it as a
    /// `PendingAuthorization`.
    ///
    /// The `Option` comes straight from the remove operation; presumably it
    /// is `None` when no request with that ID is stored - verify against the
    /// operation's implementation.
    ///
    /// # Errors
    ///
    /// Returns whatever `InflightOAuthRequestStoreError` is produced by
    /// `execute_write` or by the remove operation.
    fn remove_request(
        &self,
        request_id: &str,
    ) -> Result<Option<PendingAuthorization>, InflightOAuthRequestStoreError> {
        self.connection_pool.execute_write(|connection| {
            let removed =
                InflightOAuthRequestOperations::new(connection).remove_request(request_id);
            // Convert the row model (if any) into the public type.
            removed.map(|found| found.map(PendingAuthorization::from))
        })
    }

    /// Returns a boxed store that holds a clone of this store's connection
    /// pool.
    fn clone_box(&self) -> Box<dyn InflightOAuthRequestStore> {
        let connection_pool = self.connection_pool.clone();
        Box::new(Self { connection_pool })
    }
}
#[cfg(feature = "postgres")]
impl InflightOAuthRequestStore for DieselInflightOAuthRequestStore<diesel::pg::PgConnection> {
    /// Stores `pending_authorization` under `request_id`.
    ///
    /// The request ID, PKCE verifier and client redirect URL are packed into
    /// an `OAuthInflightRequest` row model, which is handed to the insert
    /// operation through the pool's `execute_write`.
    ///
    /// # Errors
    ///
    /// Returns whatever `InflightOAuthRequestStoreError` is produced by
    /// `execute_write` or by the insert operation.
    fn insert_request(
        &self,
        request_id: String,
        pending_authorization: PendingAuthorization,
    ) -> Result<(), InflightOAuthRequestStoreError> {
        // Assemble the row up front so the closure only performs the insert.
        let new_request = models::OAuthInflightRequest {
            id: request_id,
            pkce_verifier: pending_authorization.pkce_verifier,
            client_redirect_url: pending_authorization.client_redirect_url,
        };

        self.connection_pool.execute_write(|connection| {
            InflightOAuthRequestOperations::new(connection).insert_request(new_request)
        })
    }

    /// Removes the request identified by `request_id` and returns it as a
    /// `PendingAuthorization`.
    ///
    /// The `Option` comes straight from the remove operation; presumably it
    /// is `None` when no request with that ID is stored - verify against the
    /// operation's implementation.
    ///
    /// # Errors
    ///
    /// Returns whatever `InflightOAuthRequestStoreError` is produced by
    /// `execute_write` or by the remove operation.
    fn remove_request(
        &self,
        request_id: &str,
    ) -> Result<Option<PendingAuthorization>, InflightOAuthRequestStoreError> {
        self.connection_pool.execute_write(|connection| {
            let removed =
                InflightOAuthRequestOperations::new(connection).remove_request(request_id);
            // Convert the row model (if any) into the public type.
            removed.map(|found| found.map(PendingAuthorization::from))
        })
    }

    /// Returns a boxed store that holds a clone of this store's connection
    /// pool.
    fn clone_box(&self) -> Box<dyn InflightOAuthRequestStore> {
        let connection_pool = self.connection_pool.clone();
        Box::new(Self { connection_pool })
    }
}
/// Converts a stored row model into the public `PendingAuthorization`,
/// keeping the PKCE verifier and client redirect URL and discarding the
/// remaining row fields (the request ID).
impl From<models::OAuthInflightRequest> for PendingAuthorization {
    fn from(model: models::OAuthInflightRequest) -> Self {
        let models::OAuthInflightRequest {
            pkce_verifier,
            client_redirect_url,
            ..
        } = model;

        Self {
            pkce_verifier,
            client_redirect_url,
        }
    }
}
/// Wraps an r2d2 pool error as the source of an
/// `InflightOAuthRequestStoreError::InternalError`.
impl From<diesel::r2d2::PoolError> for InflightOAuthRequestStoreError {
    fn from(err: diesel::r2d2::PoolError) -> Self {
        let source = Box::new(err);
        Self::InternalError(InternalError::from_source(source))
    }
}
impl From<result::Error> for InflightOAuthRequestStoreError {
fn from(err: result::Error) -> Self {
match err {
result::Error::DatabaseError(result::DatabaseErrorKind::UniqueViolation, _) => {
InflightOAuthRequestStoreError::ConstraintViolation(
ConstraintViolationError::from_source_with_violation_type(
ConstraintViolationType::Unique,
Box::new(err),
),
)
}
result::Error::DatabaseError(_, _) => InflightOAuthRequestStoreError::InternalError(
InternalError::from_source(Box::new(err)),
),
_ => InflightOAuthRequestStoreError::InternalError(InternalError::from_source(
Box::new(err),
)),
}
}
}
#[cfg(all(test, feature = "sqlite"))]
pub mod tests {
    use super::*;

    use diesel::{
        r2d2::{ConnectionManager, Pool},
        sqlite::SqliteConnection,
    };

    use crate::migrations::run_sqlite_migrations;
    use crate::oauth::store::tests::{
        test_duplicate_id_insert, test_request_store_insert_and_remove,
    };

    /// Runs the shared insert-and-remove store test against a SQLite-backed
    /// `DieselInflightOAuthRequestStore`.
    #[test]
    fn sqlite_insert_request_and_remove() {
        let store = DieselInflightOAuthRequestStore::new(create_connection_pool_and_migrate());

        test_request_store_insert_and_remove(&store);
    }

    /// Runs the shared duplicate-ID insert test against a SQLite-backed
    /// `DieselInflightOAuthRequestStore`.
    #[test]
    fn sqlite_duplicate_id_insert() {
        let store = DieselInflightOAuthRequestStore::new(create_connection_pool_and_migrate());

        test_duplicate_id_insert(&store);
    }

    /// Builds a connection pool for an in-memory SQLite database and runs the
    /// SQLite migrations on it.
    ///
    /// The pool is capped at one connection. Each SQLite `:memory:`
    /// connection opens its own private database, so presumably the cap is
    /// what keeps the tests on the same database the migrations ran against.
    fn create_connection_pool_and_migrate() -> Pool<ConnectionManager<SqliteConnection>> {
        let manager = ConnectionManager::<SqliteConnection>::new(":memory:");
        let pool = Pool::builder()
            .max_size(1)
            .build(manager)
            .expect("Failed to build connection pool");

        // Scope the checkout so the only connection is back in the pool
        // before the pool is handed to the caller.
        {
            let connection = pool
                .get()
                .expect("Failed to get connection for migrations");
            run_sqlite_migrations(&*connection).expect("Failed to run migrations");
        }

        pool
    }
}