#![cfg(native)]
use std::collections::HashMap;
use std::sync::Arc;
use http::{HeaderMap, HeaderValue, StatusCode};
use reinhardt_di::{InjectionContext, SingletonScope};
use uuid::Uuid;
use super::auth::{MockSession, TestUser};
use super::mock_request::MockHttpRequest;
/// Transaction handling mode selected for a server-function test.
///
/// This file only stores and forwards the chosen mode (builder -> built
/// environment); the code that acts on it lives elsewhere.
///
/// NOTE(review): the per-variant descriptions below are inferred from the
/// variant names -- confirm against the consumer of `transaction_mode`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum TransactionMode {
/// Default variant (marked `#[default]`). Presumably: changes made during
/// the test are rolled back.
#[default]
Rollback,
/// Presumably: changes made during the test are committed.
Commit,
/// Presumably: no transaction management is applied.
None,
}
/// Builder that collects everything a server-function test needs and turns
/// it into a `ServerFnTestEnv` via `build()`.
///
/// All fields are private; they are populated through the `with_*` methods.
// The lint is silenced presumably because of the boxed-closure vector type
// used by `overrides`.
#[allow(clippy::type_complexity)]
pub struct ServerFnTestContext {
// Scope handed to `InjectionContext::builder` when the context is built.
singleton_scope: Arc<SingletonScope>,
// Deferred registrations; each closure is run once against the freshly
// built `InjectionContext`, in insertion order.
overrides: Vec<Box<dyn FnOnce(&InjectionContext) + Send>>,
// Optional mock request; registered as a singleton at build time when set.
mock_request: Option<MockHttpRequest>,
// Optional mock session; registered as a singleton at build time when set.
mock_session: Option<MockSession>,
// Optional test user; registered as a singleton at build time when set.
test_user: Option<TestUser>,
// Transaction mode forwarded unchanged to the built environment.
transaction_mode: TransactionMode,
// Request headers forwarded unchanged to the built environment.
request_headers: HeaderMap,
// CSRF token as passed to `with_csrf_token`, if any.
csrf_token: Option<String>,
}
impl ServerFnTestContext {
pub fn new(singleton_scope: Arc<SingletonScope>) -> Self {
Self {
singleton_scope,
overrides: Vec::new(),
mock_request: None,
mock_session: None,
test_user: None,
transaction_mode: TransactionMode::default(),
request_headers: HeaderMap::new(),
csrf_token: None,
}
}
pub fn with_database<P: Clone + Send + Sync + 'static>(mut self, pool: P) -> Self {
self.overrides.push(Box::new(move |ctx| {
ctx.set_singleton(pool);
}));
self
}
pub fn with_singleton<T: Clone + Send + Sync + 'static>(mut self, value: T) -> Self {
self.overrides.push(Box::new(move |ctx| {
ctx.set_singleton(value);
}));
self
}
pub fn with_permissions<S: Into<String>>(mut self, permissions: Vec<S>) -> Self {
if let Some(ref mut user) = self.test_user {
for perm in permissions {
user.permissions.push(perm.into());
}
if let Some(ref mut session) = self.mock_session {
session.user = Some(user.clone());
}
} else {
let mut user = TestUser::authenticated("test-user");
for perm in permissions {
user.permissions.push(perm.into());
}
self.test_user = Some(user.clone());
self.mock_session = Some(MockSession::authenticated(user));
}
self
}
pub fn with_roles<S: Into<String>>(mut self, roles: Vec<S>) -> Self {
if let Some(ref mut user) = self.test_user {
for role in roles {
user.roles.push(role.into());
}
if let Some(ref mut session) = self.mock_session {
session.user = Some(user.clone());
}
} else {
let mut user = TestUser::authenticated("test-user");
for role in roles {
user.roles.push(role.into());
}
self.test_user = Some(user.clone());
self.mock_session = Some(MockSession::authenticated(user));
}
self
}
pub fn with_request(mut self, request: MockHttpRequest) -> Self {
self.mock_request = Some(request);
self
}
pub fn with_request_headers(mut self, headers: HeaderMap) -> Self {
self.request_headers = headers;
self
}
pub fn with_header(mut self, name: &str, value: &str) -> Self {
if let Ok(header_value) = HeaderValue::from_str(value)
&& let Ok(header_name) = http::header::HeaderName::from_bytes(name.as_bytes())
{
self.request_headers.insert(header_name, header_value);
}
self
}
pub fn with_csrf_token(mut self, token: &str) -> Self {
self.csrf_token = Some(token.to_string());
if let Ok(header_value) = HeaderValue::from_str(token) {
self.request_headers
.insert("x-csrf-token", header_value.clone());
}
if let Some(ref mut session) = self.mock_session {
session.csrf_token = token.to_string();
}
self
}
pub fn with_transaction_mode(mut self, mode: TransactionMode) -> Self {
self.transaction_mode = mode;
self
}
pub fn with_transaction_rollback(self) -> Self {
self.with_transaction_mode(TransactionMode::Rollback)
}
pub fn with_session(mut self, session: MockSession) -> Self {
self.mock_session = Some(session);
self
}
#[cfg(native)]
pub fn auth(self) -> crate::auth::ServerFnAuthBuilder {
crate::auth::ServerFnAuthBuilder::new(self)
}
pub fn with_mock_session(mut self) -> Self {
if self.mock_session.is_none() {
self.mock_session = Some(MockSession::anonymous());
}
self
}
pub fn build(self) -> ServerFnTestEnv {
let ctx = InjectionContext::builder(self.singleton_scope.clone()).build();
for override_fn in self.overrides {
override_fn(&ctx);
}
if let Some(session) = self.mock_session.clone() {
ctx.set_singleton(session);
}
if let Some(user) = self.test_user.clone() {
ctx.set_singleton(user);
}
if let Some(request) = self.mock_request.clone() {
ctx.set_singleton(request);
}
ServerFnTestEnv {
injection_context: ctx,
mock_session: self.mock_session,
test_user: self.test_user,
mock_request: self.mock_request,
transaction_mode: self.transaction_mode,
request_headers: self.request_headers,
csrf_token: self.csrf_token,
}
}
pub fn build_context(self) -> InjectionContext {
self.build().injection_context
}
}
/// Fully built test environment produced by `ServerFnTestContext::build`.
///
/// Bundles the injection context with the values the builder was configured
/// with, so tests can inspect them directly. All fields are public.
#[derive(Clone)]
pub struct ServerFnTestEnv {
/// Injection context built from the builder's singleton scope, with the
/// queued overrides and configured session/user/request registered.
pub injection_context: InjectionContext,
/// Mock session configured on the builder, if any.
pub mock_session: Option<MockSession>,
/// Test user configured on the builder, if any.
pub test_user: Option<TestUser>,
/// Mock request configured on the builder, if any.
pub mock_request: Option<MockHttpRequest>,
/// Transaction mode chosen on the builder.
pub transaction_mode: TransactionMode,
/// Request headers accumulated on the builder.
pub request_headers: HeaderMap,
/// CSRF token passed to the builder, if any.
pub csrf_token: Option<String>,
}
impl ServerFnTestEnv {
    /// Borrows the injection context of this environment.
    pub fn context(&self) -> &InjectionContext {
        let Self {
            injection_context, ..
        } = self;
        injection_context
    }

    /// Returns `true` only when a test user is present *and* the mock
    /// session exists and carries a user.
    pub fn is_authenticated(&self) -> bool {
        if self.test_user.is_none() {
            return false;
        }
        matches!(&self.mock_session, Some(session) if session.user.is_some())
    }

    /// Returns the id of the test user, or `None` when there is no test user.
    pub fn user_id(&self) -> Option<Uuid> {
        let user = self.test_user.as_ref()?;
        Some(user.id)
    }

    /// Reports whether the test user has `permission`, as decided by
    /// `TestUser::has_permission`; `false` when there is no test user.
    pub fn has_permission(&self, permission: &str) -> bool {
        match &self.test_user {
            Some(user) => user.has_permission(permission),
            None => false,
        }
    }

    /// Reports whether `role` is among the test user's roles; `false` when
    /// there is no test user.
    pub fn has_role(&self, role: &str) -> bool {
        match &self.test_user {
            Some(user) => user.roles.iter().any(|r| r == role),
            None => false,
        }
    }

    /// Looks up request header `name` and returns it as a string slice.
    ///
    /// Yields `None` when the header is absent or its value cannot be
    /// represented as a string by `HeaderValue::to_str`.
    pub fn get_header(&self, name: &str) -> Option<&str> {
        let raw = self.request_headers.get(name)?;
        raw.to_str().ok()
    }
}
/// Lets a `ServerFnTestEnv` be used wherever a `&InjectionContext` is
/// expected by dereferencing to its `injection_context` field.
impl std::ops::Deref for ServerFnTestEnv {
    type Target = InjectionContext;

    fn deref(&self) -> &InjectionContext {
        let Self {
            injection_context, ..
        } = self;
        injection_context
    }
}
/// Description of the outcome a test expects: an optional value, an optional
/// HTTP status and a set of headers.
///
/// This type only stores the expectation; nothing in this file compares it
/// against an actual result.
#[derive(Debug, Clone)]
pub struct ExpectedResult<T> {
/// Expected value, or `None` when no value has been specified.
pub value: Option<T>,
/// Expected HTTP status, or `None` when no status has been specified.
pub status: Option<StatusCode>,
/// Expected headers, keyed by header name.
pub headers: HashMap<String, String>,
}
impl<T> Default for ExpectedResult<T> {
    /// Returns an expectation with no value, no status and no headers.
    ///
    /// Implemented by hand so that `T` itself does not need to be `Default`.
    fn default() -> Self {
        let (value, status, headers) = (None, None, HashMap::new());
        Self {
            value,
            status,
            headers,
        }
    }
}
impl<T> ExpectedResult<T> {
    /// Creates an expectation with nothing specified (same as `default()`).
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Sets the expected value, replacing any previous one.
    pub fn with_value(self, value: T) -> Self {
        Self {
            value: Some(value),
            ..self
        }
    }

    /// Sets the expected HTTP status, replacing any previous one.
    pub fn with_status(self, status: StatusCode) -> Self {
        Self {
            status: Some(status),
            ..self
        }
    }

    /// Adds an expected header; an existing entry with the same name is
    /// overwritten.
    pub fn with_header(self, name: impl Into<String>, value: impl Into<String>) -> Self {
        let Self {
            value: expected_value,
            status,
            mut headers,
        } = self;
        headers.insert(name.into(), value.into());
        Self {
            value: expected_value,
            status,
            headers,
        }
    }

    /// Expects `StatusCode::OK`.
    pub fn success(self) -> Self {
        Self::with_status(self, StatusCode::OK)
    }

    /// Expects `StatusCode::CREATED`.
    pub fn created(self) -> Self {
        Self::with_status(self, StatusCode::CREATED)
    }

    /// Expects `StatusCode::BAD_REQUEST`.
    pub fn bad_request(self) -> Self {
        Self::with_status(self, StatusCode::BAD_REQUEST)
    }

    /// Expects `StatusCode::UNAUTHORIZED`.
    pub fn unauthorized(self) -> Self {
        Self::with_status(self, StatusCode::UNAUTHORIZED)
    }

    /// Expects `StatusCode::FORBIDDEN`.
    pub fn forbidden(self) -> Self {
        Self::with_status(self, StatusCode::FORBIDDEN)
    }

    /// Expects `StatusCode::NOT_FOUND`.
    pub fn not_found(self) -> Self {
        Self::with_status(self, StatusCode::NOT_FOUND)
    }

    /// Expects `StatusCode::CONFLICT`.
    pub fn conflict(self) -> Self {
        Self::with_status(self, StatusCode::CONFLICT)
    }

    /// Expects `StatusCode::INTERNAL_SERVER_ERROR`.
    pub fn internal_error(self) -> Self {
        Self::with_status(self, StatusCode::INTERNAL_SERVER_ERROR)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a builder backed by its own, freshly created singleton scope.
    fn fresh_builder() -> ServerFnTestContext {
        ServerFnTestContext::new(Arc::new(SingletonScope::new()))
    }

    /// `with_mock_session` leaves the built environment with a session.
    #[test]
    fn test_context_builder() {
        let env = fresh_builder().with_mock_session().build();
        assert!(env.mock_session.is_some());
    }

    /// A session created for a user still carries that user after `build`.
    #[test]
    fn test_authenticated_user() {
        let session = MockSession::authenticated(TestUser::admin());
        let env = fresh_builder().with_session(session).build();
        let session_has_user = matches!(&env.mock_session, Some(s) if s.user.is_some());
        assert!(session_has_user);
    }

    /// Permissions given to the builder are visible through `has_permission`.
    #[test]
    fn test_permissions() {
        let env = fresh_builder()
            .with_permissions(vec!["read", "write"])
            .build();
        for granted in ["read", "write"] {
            assert!(env.has_permission(granted));
        }
        assert!(!env.has_permission("admin"));
    }

    /// The CSRF token ends up both on the environment and in the header.
    #[test]
    fn test_csrf_token() {
        let env = fresh_builder()
            .with_mock_session()
            .with_csrf_token("test-token")
            .build();
        let expected = Some("test-token");
        assert_eq!(env.csrf_token.as_deref(), expected);
        assert_eq!(env.get_header("x-csrf-token"), expected);
    }

    /// `with_transaction_rollback` selects `TransactionMode::Rollback`.
    #[test]
    fn test_transaction_mode() {
        let env = fresh_builder().with_transaction_rollback().build();
        assert_eq!(env.transaction_mode, TransactionMode::Rollback);
    }
}