use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
/// A cheaply clonable cancellation flag.
///
/// Every clone shares the same underlying `Arc<AtomicBool>`, so calling
/// [`CancellationToken::cancel`] on any clone is observed by all of them.
/// No method on this type resets the flag once it has been set.
///
/// `Default` yields a fresh, uncancelled token, equivalent to
/// [`CancellationToken::new`].
#[derive(Clone, Debug, Default)]
pub struct CancellationToken {
    is_cancelled: Arc<AtomicBool>,
}

impl CancellationToken {
    /// Creates a new, uncancelled token backed by its own flag.
    #[must_use]
    pub fn new() -> Self {
        Self {
            is_cancelled: Arc::new(AtomicBool::new(false)),
        }
    }

    /// Marks this token, and every clone of it, as cancelled.
    ///
    /// Calling it repeatedly has no further effect.
    pub fn cancel(&self) {
        // Release pairs with the Acquire load in `is_cancelled`. Any thread that
        // observes `true` is then guaranteed to see the writes made before this call.
        self.is_cancelled.store(true, Ordering::Release);
    }

    /// Returns `true` once [`cancel`](Self::cancel) has been called on this
    /// token or any of its clones.
    #[must_use]
    pub fn is_cancelled(&self) -> bool {
        // Acquire pairs with the Release store in `cancel`.
        self.is_cancelled.load(Ordering::Acquire)
    }
}
/// Extra context bundled with a request: a cancellation token, request/session
/// identifiers, optional authentication data, and a type-keyed extensions map.
///
/// The type is `#[non_exhaustive]`, so code outside this crate cannot build it
/// with a struct literal. Use [`RequestHandlerExtra::new`] and the `with_*`
/// builder methods instead.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct RequestHandlerExtra {
/// Token polled by `is_cancelled()` to detect cancellation; shares state with the token passed to `new`.
pub cancellation_token: CancellationToken,
/// Identifier of the request, exactly as passed to `new`.
pub request_id: String,
/// Session identifier, if any; `new` sets it to `None` and `with_session_id` replaces it.
pub session_id: Option<String>,
/// Authentication info, if any; `new` sets it to `None` and `with_auth_info` replaces it.
pub auth_info: Option<crate::types::auth::AuthInfo>,
/// Server-side auth context (present only on non-`wasm32` targets); `new` sets it to `None`.
#[cfg(not(target_arch = "wasm32"))]
pub auth_context: Option<crate::server::auth::AuthContext>,
/// Type-keyed map for arbitrary additional per-request data; starts empty.
pub extensions: http::Extensions,
}
impl RequestHandlerExtra {
    /// Creates the extra context for a single request.
    ///
    /// `session_id` and `auth_info` start as `None` (as does `auth_context` on
    /// non-`wasm32` targets), and `extensions` starts empty. Populate them with
    /// the `with_*` builder methods or [`Self::extensions_mut`].
    ///
    /// The stored [`CancellationToken`] shares state with the one passed in.
    /// Cancelling the caller's copy is therefore visible through
    /// [`Self::is_cancelled`].
    #[must_use]
    pub fn new(request_id: String, cancellation_token: CancellationToken) -> Self {
        Self {
            cancellation_token,
            request_id,
            session_id: None,
            auth_info: None,
            #[cfg(not(target_arch = "wasm32"))]
            auth_context: None,
            extensions: http::Extensions::new(),
        }
    }

    /// Returns `self` with `session_id` replaced; passing `None` clears it.
    ///
    /// Consumes `self`. Ignoring the return value discards the update, hence
    /// `#[must_use]`.
    #[must_use]
    pub fn with_session_id(mut self, session_id: Option<String>) -> Self {
        self.session_id = session_id;
        self
    }

    /// Returns `self` with `auth_info` replaced; passing `None` clears it.
    ///
    /// Consumes `self`. Ignoring the return value discards the update.
    #[must_use]
    pub fn with_auth_info(mut self, auth_info: Option<crate::types::auth::AuthInfo>) -> Self {
        self.auth_info = auth_info;
        self
    }

    /// Returns `self` with `auth_context` replaced; passing `None` clears it.
    ///
    /// Only available on non-`wasm32` targets. Consumes `self`, so ignoring the
    /// return value discards the update.
    #[cfg(not(target_arch = "wasm32"))]
    #[must_use]
    pub fn with_auth_context(
        mut self,
        auth_context: Option<crate::server::auth::AuthContext>,
    ) -> Self {
        self.auth_context = auth_context;
        self
    }

    /// Borrows the auth context, if one has been set (non-`wasm32` targets only).
    #[cfg(not(target_arch = "wasm32"))]
    pub fn auth_context(&self) -> Option<&crate::server::auth::AuthContext> {
        self.auth_context.as_ref()
    }

    /// Returns `true` once the request's cancellation token has been cancelled.
    pub fn is_cancelled(&self) -> bool {
        self.cancellation_token.is_cancelled()
    }

    /// Borrows the type-keyed extensions map.
    pub fn extensions(&self) -> &http::Extensions {
        &self.extensions
    }

    /// Mutably borrows the type-keyed extensions map, e.g. to insert values.
    pub fn extensions_mut(&mut self) -> &mut http::Extensions {
        &mut self.extensions
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Values inserted through `extensions_mut` are readable through `extensions`.
    #[test]
    fn test_shared_extensions_parity() {
        let mut extra = RequestHandlerExtra::new("r1".to_string(), CancellationToken::new());
        extra.extensions_mut().insert(42u64);
        assert_eq!(extra.extensions().get::<u64>(), Some(&42u64));
    }

    /// `new` stores the request id and leaves optional fields empty.
    #[test]
    fn test_new_defaults() {
        let extra = RequestHandlerExtra::new("r1".to_string(), CancellationToken::new());
        assert_eq!(extra.request_id, "r1");
        assert!(extra.session_id.is_none());
        assert!(extra.auth_info.is_none());
        assert!(extra.extensions().get::<u64>().is_none());
        assert!(!extra.is_cancelled());
    }

    /// Cancelling the caller's token is visible through the extra, because the
    /// token's state is shared between clones.
    #[test]
    fn test_cancellation_propagates_to_extra() {
        let token = CancellationToken::new();
        let extra = RequestHandlerExtra::new("r1".to_string(), token.clone());
        assert!(!extra.is_cancelled());
        token.cancel();
        assert!(extra.is_cancelled());
        assert!(extra.cancellation_token.is_cancelled());
    }

    /// `with_session_id` both sets and clears the session id.
    #[test]
    fn test_with_session_id_sets_and_clears() {
        let extra = RequestHandlerExtra::new("r1".to_string(), CancellationToken::new())
            .with_session_id(Some("s1".to_string()));
        assert_eq!(extra.session_id.as_deref(), Some("s1"));
        let extra = extra.with_session_id(None);
        assert!(extra.session_id.is_none());
    }

    /// `with_auth_info(None)` leaves the auth info unset.
    #[test]
    fn test_with_auth_info_none() {
        let extra = RequestHandlerExtra::new("r1".to_string(), CancellationToken::new())
            .with_auth_info(None);
        assert!(extra.auth_info.is_none());
    }
}