use std::cell::RefCell;
use std::future::Future;
use std::sync::Arc;
use crate::connection::client::DatabaseClient;
use crate::error::{Result, SurqlError};
// Per-task slot holding the database client for the innermost active
// `connection_scope` on the current task.
//
// The value is wrapped in a `RefCell` so that `set_db` / `clear_db` can replace
// the contents in place through the shared reference that task-local access
// provides. `None` means a scope is active but its client has been cleared.
// Outside any scope the slot is not accessible at all (`try_with` fails).
tokio::task_local! {
// Each spawned task has its own slot; bindings are not shared between tasks.
static CURRENT_CLIENT: RefCell<Option<Arc<DatabaseClient>>>;
}
/// Returns the database client bound to the current task.
///
/// The returned value is a clone of the stored [`Arc`], so this only bumps a
/// reference count; the binding itself is left in place.
///
/// # Errors
///
/// Returns [`SurqlError::Context`] when:
/// - no [`connection_scope`] is active on this task, or
/// - a scope is active but its client was removed with [`clear_db`].
pub fn get_db() -> Result<Arc<DatabaseClient>> {
    let bound = CURRENT_CLIENT
        .try_with(|cell| cell.borrow().as_ref().map(Arc::clone))
        .map_err(|_| no_scope_error())?;

    match bound {
        Some(client) => Ok(client),
        // A scope exists, but its slot is empty (e.g. after `clear_db`).
        None => Err(SurqlError::Context {
            reason: "no active database connection; use connection_scope() or set_db() first"
                .into(),
        }),
    }
}
/// Replaces the client bound in the current task's innermost [`connection_scope`].
///
/// The new binding is visible to later [`get_db`] calls on this task until the
/// innermost enclosing scope ends. The previous client, if any, is dropped.
///
/// # Errors
///
/// Returns [`SurqlError::Context`] if no [`connection_scope`] is active on this task.
pub fn set_db(client: Arc<DatabaseClient>) -> Result<()> {
    let outcome = CURRENT_CLIENT.try_with(move |cell| {
        *cell.borrow_mut() = Some(client);
    });
    outcome.map_err(|_| no_scope_error())
}
/// Empties the client slot of the current task's innermost [`connection_scope`].
///
/// After this call, [`has_db`] returns `false` and [`get_db`] fails until a new
/// client is installed with [`set_db`] or a new scope is entered.
///
/// # Errors
///
/// Returns [`SurqlError::Context`] if no [`connection_scope`] is active on this task.
pub fn clear_db() -> Result<()> {
    let outcome = CURRENT_CLIENT.try_with(|cell| {
        *cell.borrow_mut() = None;
    });
    outcome.map_err(|_| no_scope_error())
}
/// Reports whether a client is currently bound on this task.
///
/// Never fails. It returns `false` both outside any [`connection_scope`] and
/// inside a scope whose client has been cleared.
pub fn has_db() -> bool {
    let probe = CURRENT_CLIENT.try_with(|cell| cell.borrow().is_some());
    matches!(probe, Ok(true))
}
/// Runs `fut` with `client` bound as the current task's database client.
///
/// Inside `fut`, [`get_db`], [`set_db`], [`clear_db`] and [`has_db`] operate
/// on a fresh slot initialised with `client`. That slot, including any changes
/// made to it, disappears when `fut` completes. Nested scopes shadow outer
/// ones for their duration.
pub async fn connection_scope<F, T>(client: Arc<DatabaseClient>, fut: F) -> T
where
    F: Future<Output = T>,
{
    let binding = RefCell::new(Some(client));
    CURRENT_CLIENT.scope(binding, fut).await
}
/// Runs `fut` with `client` temporarily shadowing any enclosing binding.
///
/// This is a thin alias for [`connection_scope`]. The inner scope gets its own
/// slot, so the outer binding is visible again once `fut` completes.
pub async fn connection_override<F, T>(client: Arc<DatabaseClient>, fut: F) -> T
where
    F: Future<Output = T>,
{
    let shadowed = connection_scope(client, fut);
    shadowed.await
}
/// Builds the error returned when a task-local accessor runs outside any
/// [`connection_scope`].
fn no_scope_error() -> SurqlError {
    let reason = "no active connection_scope on this task".into();
    SurqlError::Context { reason }
}
#[cfg(test)]
mod tests {
    //! Tests for the task-local connection context: behaviour outside a scope,
    //! scope entry/exit, nested overrides, in-scope mutation, and isolation
    //! between tasks.

    use super::*;
    use crate::connection::ConnectionConfig;

    /// Builds a fresh client from the default configuration.
    fn make_client() -> Arc<DatabaseClient> {
        Arc::new(
            DatabaseClient::new(ConnectionConfig::default()).expect("default config is valid"),
        )
    }

    #[tokio::test]
    async fn get_db_outside_scope_errors() {
        let err = get_db().unwrap_err();
        assert!(matches!(err, SurqlError::Context { .. }));
        assert!(!has_db());
    }

    #[tokio::test]
    async fn set_db_outside_scope_errors() {
        let client = make_client();
        let err = set_db(client).unwrap_err();
        assert!(matches!(err, SurqlError::Context { .. }));
    }

    #[tokio::test]
    async fn clear_db_outside_scope_errors() {
        let err = clear_db().unwrap_err();
        assert!(matches!(err, SurqlError::Context { .. }));
    }

    #[tokio::test]
    async fn scope_sets_and_restores() {
        assert!(!has_db());
        let client = make_client();
        connection_scope(client.clone(), async {
            assert!(has_db());
            let got = get_db().expect("client in scope");
            assert!(Arc::ptr_eq(&got, &client));
        })
        .await;
        assert!(!has_db(), "scope must release the binding");
    }

    #[tokio::test]
    async fn override_swaps_inside_outer_scope() {
        let outer = make_client();
        let inner = make_client();
        connection_scope(outer.clone(), async {
            let got = get_db().unwrap();
            assert!(Arc::ptr_eq(&got, &outer));
            connection_override(inner.clone(), async {
                let got = get_db().unwrap();
                assert!(Arc::ptr_eq(&got, &inner));
            })
            .await;
            let got = get_db().unwrap();
            assert!(Arc::ptr_eq(&got, &outer));
        })
        .await;
    }

    #[tokio::test]
    async fn set_and_clear_inside_scope() {
        let first = make_client();
        let second = make_client();
        connection_scope(first.clone(), async {
            set_db(second.clone()).expect("set in scope");
            let got = get_db().unwrap();
            assert!(Arc::ptr_eq(&got, &second));
            clear_db().expect("clear in scope");
            assert!(!has_db());
            assert!(matches!(get_db().unwrap_err(), SurqlError::Context { .. }));
        })
        .await;
    }

    /// Mutations made inside an override act on the override's own slot and
    /// must not leak into the enclosing scope.
    #[tokio::test]
    async fn set_and_clear_inside_override_do_not_leak_to_outer_scope() {
        let outer = make_client();
        let inner = make_client();
        let replacement = make_client();
        connection_scope(outer.clone(), async {
            connection_override(inner.clone(), async {
                set_db(replacement.clone()).expect("set inside override");
                let got = get_db().unwrap();
                assert!(Arc::ptr_eq(&got, &replacement));
                clear_db().expect("clear inside override");
                assert!(!has_db());
            })
            .await;
            let got = get_db().expect("outer binding survives inner set/clear");
            assert!(Arc::ptr_eq(&got, &outer));
        })
        .await;
    }

    /// Task-local bindings belong to the task that entered the scope; a task
    /// spawned from inside the scope starts with no binding.
    #[tokio::test]
    async fn spawned_task_does_not_inherit_scope() {
        connection_scope(make_client(), async {
            assert!(has_db());
            let child_has_db = tokio::spawn(async { has_db() })
                .await
                .expect("child task completed");
            assert!(
                !child_has_db,
                "task-local bindings must not propagate to spawned tasks"
            );
            assert!(has_db(), "parent binding unaffected by child task");
        })
        .await;
    }

    #[tokio::test]
    async fn scopes_are_isolated_across_tasks() {
        let a = make_client();
        let b = make_client();
        let task_a = {
            let a = a.clone();
            tokio::spawn(async move {
                connection_scope(a.clone(), async {
                    tokio::task::yield_now().await;
                    let got = get_db().unwrap();
                    assert!(Arc::ptr_eq(&got, &a));
                })
                .await;
            })
        };
        let task_b = {
            let b = b.clone();
            tokio::spawn(async move {
                connection_scope(b.clone(), async {
                    tokio::task::yield_now().await;
                    let got = get_db().unwrap();
                    assert!(Arc::ptr_eq(&got, &b));
                })
                .await;
            })
        };
        task_a.await.unwrap();
        task_b.await.unwrap();
        assert!(!has_db());
    }
}