use crate::error::Result;
use crate::server::progress::ProgressReporter;
use crate::types::{CancelledNotification, Notification};
use std::collections::HashMap;
use std::sync::Arc;
#[cfg(not(target_arch = "wasm32"))]
use tokio::sync::RwLock;
#[cfg(not(target_arch = "wasm32"))]
use tokio_util::sync::CancellationToken;
/// Registry of cancellation tokens keyed by request id.
///
/// The token map lives behind an `Arc<RwLock<..>>`; the optional notification
/// sender is a plain callback invoked when `cancel_request` cancels a tracked
/// token.
pub struct CancellationManager {
    /// Tokens currently tracked, keyed by the request id they were created for.
    tokens: Arc<RwLock<HashMap<String, CancellationToken>>>,
    /// Callback that receives the `Cancelled` notification built by
    /// `cancel_request`; `None` until `set_notification_sender` is called.
    notification_sender: Option<Arc<dyn Fn(Notification) + Send + Sync>>,
}
impl CancellationManager {
    /// Creates a manager that tracks no tokens and has no notification sender.
    pub fn new() -> Self {
        Self {
            tokens: Arc::new(RwLock::new(HashMap::new())),
            notification_sender: None,
        }
    }

    /// Installs the callback that `cancel_request` invokes with the
    /// `Cancelled` notification. Any previously installed sender is replaced.
    pub fn set_notification_sender(&mut self, sender: Arc<dyn Fn(Notification) + Send + Sync>) {
        self.notification_sender = Some(sender);
    }

    /// Creates a fresh token, registers it under `request_id`, and returns a
    /// clone of it.
    ///
    /// If a token is already registered under `request_id`, the map entry is
    /// replaced; the previous token is dropped from the map without being
    /// cancelled.
    pub async fn create_token(&self, request_id: String) -> CancellationToken {
        let token = CancellationToken::new();
        self.tokens.write().await.insert(request_id, token.clone());
        token
    }

    /// Cancels the token registered under `request_id`, if any, and removes it
    /// from the manager.
    ///
    /// When a token was found and a notification sender is installed, the
    /// sender is called with a `Cancelled` notification carrying `reason`
    /// (defaulting to `"Cancelled by server"`). Unknown request ids are
    /// ignored.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok(())`.
    pub async fn cancel_request(&self, request_id: String, reason: Option<String>) -> Result<()> {
        // Scope the write guard so the lock is released before the token is
        // cancelled and before the user-supplied callback runs.
        let removed = {
            let mut tokens = self.tokens.write().await;
            tokens.remove(&request_id)
        };
        let Some(token) = removed else {
            return Ok(());
        };
        token.cancel();

        if let Some(sender) = &self.notification_sender {
            // `request_id` is not needed afterwards, so move it into the
            // notification instead of cloning it.
            let notification = Notification::Client(crate::types::ClientNotification::Cancelled(
                CancelledNotification::new(crate::types::RequestId::String(request_id))
                    .with_reason(reason.unwrap_or_else(|| "Cancelled by server".to_string())),
            ));
            sender(notification);
        }
        Ok(())
    }

    /// Stops tracking `request_id` without cancelling its token.
    pub async fn remove_token(&self, request_id: &str) {
        self.tokens.write().await.remove(request_id);
    }

    /// Returns `true` if a token is still registered under `request_id` and
    /// that token has been cancelled.
    ///
    /// Returns `false` when no token is registered, which includes requests
    /// already handled by `cancel_request` or `clear`, since both remove the
    /// token from the map.
    pub async fn is_cancelled(&self, request_id: &str) -> bool {
        let tokens = self.tokens.read().await;
        tokens
            .get(request_id)
            .is_some_and(CancellationToken::is_cancelled)
    }

    /// Returns a clone of the token registered under `request_id`, if any.
    pub async fn get_token(&self, request_id: &str) -> Option<CancellationToken> {
        let tokens = self.tokens.read().await;
        tokens.get(request_id).cloned()
    }

    /// Cancels every tracked token and empties the manager.
    ///
    /// No notifications are sent for tokens cancelled this way.
    pub async fn clear(&self) {
        let mut tokens = self.tokens.write().await;
        for (_, token) in tokens.drain() {
            token.cancel();
        }
    }
}
impl Default for CancellationManager {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Debug for CancellationManager {
    /// Prints only the number of tracked tokens.
    ///
    /// The count is read with `try_read`; if the guard cannot be obtained the
    /// count is reported as `0`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let active_tokens = match self.tokens.try_read() {
            Ok(guard) => guard.len(),
            Err(_) => 0,
        };
        let mut out = f.debug_struct("CancellationManager");
        out.field("active_tokens", &active_tokens);
        out.finish()
    }
}
/// Per-request context: cancellation token, identifiers, auth data, metadata,
/// an optional progress reporter, and typed extensions.
#[derive(Clone)]
#[non_exhaustive]
pub struct RequestHandlerExtra {
    /// Token observed by `is_cancelled` and awaited by `cancelled`.
    pub cancellation_token: CancellationToken,
    /// Identifier of the request; `Default` fills this with a random UUID v4.
    pub request_id: String,
    /// Session identifier, when one was supplied.
    pub session_id: Option<String>,
    /// Authentication info, when one was supplied.
    pub auth_info: Option<crate::types::auth::AuthInfo>,
    /// Authentication context, when one was supplied.
    pub auth_context: Option<crate::server::auth::AuthContext>,
    /// String key/value metadata. Values whose key contains a known sensitive
    /// fragment are redacted in the `Debug` output.
    pub metadata: HashMap<String, String>,
    /// Reporter used by the `report_*` helpers, which return `Ok(())` without
    /// doing anything when this is `None`.
    // NOTE(review): the field is public and read by the `report_*` helpers, so
    // this `allow` looks unnecessary — confirm before removing.
    #[allow(dead_code)]
    pub progress_reporter: Option<Arc<dyn ProgressReporter>>,
    /// Optional JSON value; `is_task_request` reports whether it is present.
    pub task_request: Option<serde_json::Value>,
    /// Typed extension map.
    pub extensions: http::Extensions,
    /// Optional peer handle; only compiled on non-`wasm32` targets.
    #[cfg(not(target_arch = "wasm32"))]
    pub peer: Option<Arc<dyn crate::shared::peer::PeerHandle>>,
}
impl RequestHandlerExtra {
    /// Builds a context for `request_id` with the given cancellation token.
    ///
    /// Every optional field starts as `None`, and the metadata and extension
    /// maps start empty.
    pub fn new(request_id: String, cancellation_token: CancellationToken) -> Self {
        Self {
            request_id,
            cancellation_token,
            session_id: None,
            auth_info: None,
            auth_context: None,
            task_request: None,
            progress_reporter: None,
            metadata: HashMap::new(),
            extensions: http::Extensions::new(),
            #[cfg(not(target_arch = "wasm32"))]
            peer: None,
        }
    }

    /// Returns `self` with `session_id` replaced.
    pub fn with_session_id(self, session_id: Option<String>) -> Self {
        Self { session_id, ..self }
    }

    /// Returns `self` with `auth_info` replaced.
    pub fn with_auth_info(self, auth_info: Option<crate::types::auth::AuthInfo>) -> Self {
        Self { auth_info, ..self }
    }

    /// Returns `self` with `auth_context` replaced.
    pub fn with_auth_context(
        self,
        auth_context: Option<crate::server::auth::AuthContext>,
    ) -> Self {
        Self {
            auth_context,
            ..self
        }
    }

    /// Returns `self` with `progress_reporter` replaced.
    pub fn with_progress_reporter(
        self,
        progress_reporter: Option<Arc<dyn ProgressReporter>>,
    ) -> Self {
        Self {
            progress_reporter,
            ..self
        }
    }

    /// Returns `self` with `task_request` replaced.
    pub fn with_task_request(self, task_request: Option<serde_json::Value>) -> Self {
        Self {
            task_request,
            ..self
        }
    }

    /// Shared access to the typed extension map.
    pub fn extensions(&self) -> &http::Extensions {
        &self.extensions
    }

    /// Exclusive access to the typed extension map.
    pub fn extensions_mut(&mut self) -> &mut http::Extensions {
        &mut self.extensions
    }

    /// Returns `self` with `peer` set to `Some(peer)`.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn with_peer(self, peer: Arc<dyn crate::shared::peer::PeerHandle>) -> Self {
        Self {
            peer: Some(peer),
            ..self
        }
    }

    /// Borrows the peer handle, if one is set.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn peer(&self) -> Option<&Arc<dyn crate::shared::peer::PeerHandle>> {
        self.peer.as_ref()
    }

    /// Whether a `task_request` value is present.
    pub fn is_task_request(&self) -> bool {
        self.task_request.is_some()
    }

    /// Borrows the authentication context, if one is set.
    pub fn auth_context(&self) -> Option<&crate::server::auth::AuthContext> {
        self.auth_context.as_ref()
    }

    /// Looks up a metadata value by key.
    pub fn get_metadata(&self, key: &str) -> Option<&String> {
        self.metadata.get(key)
    }

    /// Stores `value` under `key`, replacing any existing entry.
    pub fn set_metadata(&mut self, key: String, value: String) {
        let _previous = self.metadata.insert(key, value);
    }

    /// Whether the request's cancellation token has been cancelled.
    pub fn is_cancelled(&self) -> bool {
        self.cancellation_token.is_cancelled()
    }

    /// Completes once the request's cancellation token is cancelled.
    pub async fn cancelled(&self) {
        self.cancellation_token.cancelled().await;
    }

    /// Forwards to the reporter's `report_progress`.
    ///
    /// Returns `Ok(())` without doing anything when no reporter is set.
    ///
    /// # Errors
    ///
    /// Returns whatever error the reporter returns.
    pub async fn report_progress(
        &self,
        progress: f64,
        total: Option<f64>,
        message: Option<String>,
    ) -> crate::Result<()> {
        match &self.progress_reporter {
            Some(reporter) => reporter.report_progress(progress, total, message).await,
            None => Ok(()),
        }
    }

    /// Forwards to the reporter's `report_percent`.
    ///
    /// Returns `Ok(())` without doing anything when no reporter is set.
    ///
    /// # Errors
    ///
    /// Returns whatever error the reporter returns.
    pub async fn report_percent(&self, percent: f64, message: Option<String>) -> crate::Result<()> {
        match &self.progress_reporter {
            Some(reporter) => reporter.report_percent(percent, message).await,
            None => Ok(()),
        }
    }

    /// Forwards to the reporter's `report_count`.
    ///
    /// Returns `Ok(())` without doing anything when no reporter is set.
    ///
    /// # Errors
    ///
    /// Returns whatever error the reporter returns.
    pub async fn report_count(
        &self,
        current: usize,
        total: usize,
        message: Option<String>,
    ) -> crate::Result<()> {
        match &self.progress_reporter {
            Some(reporter) => reporter.report_count(current, total, message).await,
            None => Ok(()),
        }
    }
}
impl Default for RequestHandlerExtra {
    /// Builds a context with a fresh cancellation token and a random UUID v4
    /// as the request id; every other field is as produced by
    /// [`RequestHandlerExtra::new`].
    fn default() -> Self {
        let cancellation_token = CancellationToken::new();
        let request_id = uuid::Uuid::new_v4().to_string();
        Self::new(request_id, cancellation_token)
    }
}
impl std::fmt::Debug for RequestHandlerExtra {
    /// Formats the context for diagnostics while keeping secrets out of logs.
    ///
    /// - Metadata values whose key contains, case-insensitively, any fragment
    ///   in `SENSITIVE_KEYS` are printed as `[REDACTED]`.
    /// - `task_request` is printed only as a presence flag.
    /// - `peer` is printed as a fixed placeholder when set.
    /// - `progress_reporter` is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const SENSITIVE_KEYS: &[&str] = &[
            "oauth_token",
            "access_token",
            "refresh_token",
            "api_key",
            "secret",
            "password",
            "bearer_token",
            "auth_token",
        ];
        const REDACTED: &str = "[REDACTED]";

        // Borrow keys and values rather than cloning them: `&str` has the same
        // `Debug` output as `String`, so the rendered map is unchanged.
        let redacted_metadata: HashMap<&str, &str> = self
            .metadata
            .iter()
            .map(|(key, value)| {
                // Lowercase once per key instead of once per sensitive fragment.
                let lowered = key.to_lowercase();
                let is_sensitive = SENSITIVE_KEYS
                    .iter()
                    .any(|fragment| lowered.contains(*fragment));
                let shown = if is_sensitive {
                    REDACTED
                } else {
                    value.as_str()
                };
                (key.as_str(), shown)
            })
            .collect();

        let mut debug = f.debug_struct("RequestHandlerExtra");
        debug
            .field("cancellation_token", &self.cancellation_token)
            .field("request_id", &self.request_id)
            .field("session_id", &self.session_id)
            .field("auth_info", &self.auth_info)
            .field("auth_context", &self.auth_context)
            .field("metadata", &redacted_metadata)
            .field("task_request", &self.task_request.is_some())
            .field("extensions", &self.extensions);
        #[cfg(not(target_arch = "wasm32"))]
        debug.field("peer", &self.peer.as_ref().map(|_| "Arc<dyn PeerHandle>"));
        debug.finish()
    }
}
impl CancellationManager {
    /// Cancels and stops tracking the token registered under `request_id`
    /// without invoking the notification sender.
    ///
    /// Unknown request ids are ignored.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok(())`.
    pub async fn cancel_request_silent(&self, request_id: String) -> Result<()> {
        // The write guard is a temporary of this statement, so the lock is
        // released before the token is cancelled.
        let removed = self.tokens.write().await.remove(&request_id);
        if let Some(found) = removed {
            found.cancel();
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Cancelling a tracked request cancels its token and stops tracking it.
    #[tokio::test]
    async fn test_create_and_cancel_token() {
        let manager = CancellationManager::new();
        let token = manager.create_token("test-request".to_string()).await;
        assert!(!token.is_cancelled());
        manager
            .cancel_request("test-request".to_string(), None)
            .await
            .unwrap();
        assert!(token.is_cancelled());
        assert!(manager.get_token("test-request").await.is_none());
    }

    /// The notification sender receives the request id and the supplied reason.
    #[tokio::test]
    async fn test_cancel_with_reason() {
        // The sender is a synchronous callback, so a std `Mutex` records the
        // notification inline. No task spawn or sleep is needed to observe it.
        let notifications = Arc::new(std::sync::Mutex::new(Vec::<Notification>::new()));
        let sink = Arc::clone(&notifications);
        let mut manager = CancellationManager::new();
        manager.set_notification_sender(Arc::new(move |notif| {
            sink.lock().unwrap().push(notif);
        }));
        let _token = manager.create_token("test-request".to_string()).await;
        manager
            .cancel_request("test-request".to_string(), Some("Test reason".to_string()))
            .await
            .unwrap();
        let notifs = notifications.lock().unwrap();
        assert_eq!(notifs.len(), 1);
        if let Notification::Client(crate::types::ClientNotification::Cancelled(cancelled)) =
            &notifs[0]
        {
            assert_eq!(
                cancelled.request_id,
                crate::types::RequestId::String("test-request".to_string())
            );
            assert_eq!(cancelled.reason, Some("Test reason".to_string()));
        } else {
            panic!("Expected Cancelled notification");
        }
    }

    /// Removing a token stops tracking it but leaves it uncancelled.
    #[tokio::test]
    async fn test_remove_token() {
        let manager = CancellationManager::new();
        let token = manager.create_token("test-request".to_string()).await;
        assert!(manager.get_token("test-request").await.is_some());
        manager.remove_token("test-request").await;
        assert!(manager.get_token("test-request").await.is_none());
        assert!(!token.is_cancelled());
    }

    /// `clear` cancels every tracked token and empties the manager.
    #[tokio::test]
    async fn test_clear_all_tokens() {
        let manager = CancellationManager::new();
        let token1 = manager.create_token("request1".to_string()).await;
        let token2 = manager.create_token("request2".to_string()).await;
        let token3 = manager.create_token("request3".to_string()).await;
        manager.clear().await;
        assert!(token1.is_cancelled());
        assert!(token2.is_cancelled());
        assert!(token3.is_cancelled());
        assert!(manager.get_token("request1").await.is_none());
        assert!(manager.get_token("request2").await.is_none());
        assert!(manager.get_token("request3").await.is_none());
    }

    /// The handler context mirrors the state of the token it was built with.
    #[tokio::test]
    async fn test_request_handler_extra() {
        let token = CancellationToken::new();
        let extra = RequestHandlerExtra::new("test-req".to_string(), token.clone())
            .with_session_id(Some("session-123".to_string()));
        assert_eq!(extra.request_id, "test-req");
        assert_eq!(extra.session_id, Some("session-123".to_string()));
        assert!(!extra.is_cancelled());
        token.cancel();
        assert!(extra.is_cancelled());
    }

    /// Sensitive metadata values are redacted in `Debug`; others are shown.
    #[tokio::test]
    async fn test_metadata_redaction_in_debug() {
        let token = CancellationToken::new();
        let mut extra = RequestHandlerExtra::new("test-req".to_string(), token);
        extra.set_metadata("oauth_token".to_string(), "secret-token-123".to_string());
        extra.set_metadata("access_token".to_string(), "bearer-xyz".to_string());
        extra.set_metadata("user_id".to_string(), "user-456".to_string());
        extra.set_metadata("request_count".to_string(), "42".to_string());
        let debug_output = format!("{:?}", extra);
        assert!(
            debug_output.contains("[REDACTED]"),
            "Expected redacted values in: {}",
            debug_output
        );
        assert!(
            !debug_output.contains("secret-token-123"),
            "OAuth token should be redacted: {}",
            debug_output
        );
        assert!(
            !debug_output.contains("bearer-xyz"),
            "Access token should be redacted: {}",
            debug_output
        );
        assert!(
            debug_output.contains("user-456"),
            "Non-sensitive metadata should not be redacted: {}",
            debug_output
        );
        assert!(
            debug_output.contains("42"),
            "Non-sensitive metadata should not be redacted: {}",
            debug_output
        );
    }

    /// A default context starts with an empty extension map.
    #[tokio::test]
    async fn test_extensions_default_empty() {
        let extra = RequestHandlerExtra::default();
        assert!(extra.extensions().get::<String>().is_none());
    }

    /// Inserting a second value of the same type returns the previous one.
    #[tokio::test]
    async fn test_extensions_insert_overwrite_returns_old() {
        let mut extra = RequestHandlerExtra::default();
        assert_eq!(extra.extensions_mut().insert(42u64), None);
        assert_eq!(extra.extensions_mut().insert(99u64), Some(42u64));
        assert_eq!(extra.extensions().get::<u64>(), Some(&99u64));
    }

    /// Extension values must never appear in the `Debug` output.
    #[tokio::test]
    async fn test_debug_extensions_prints_type_names_only() {
        let mut extra = RequestHandlerExtra::default();
        extra
            .extensions_mut()
            .insert("SECRET_VALUE_DO_NOT_LEAK".to_string());
        let debug_out = format!("{:?}", extra);
        assert!(!debug_out.contains("SECRET_VALUE_DO_NOT_LEAK"));
    }
}