#![cfg(feature = "async-tokio")]
use crate::data::error::{DarraError, Result};
use crate::utils::ffi;
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex as StdMutex, OnceLock};
use tokio::sync::Mutex as TokioMutex;
/// Shared, thread-safe progress callback; receives a human-readable status message.
pub type ProgressFn = Arc<dyn Fn(&str) + Sync + Send>;
/// Cooperative cancellation flag shared between an async caller and the
/// background work it launches.
///
/// Cloning a token yields another handle to the *same* flag, so `cancel()` on
/// any clone is observed by every other clone.
#[derive(Clone, Debug, Default)]
pub struct CancelToken {
    flag: Arc<AtomicBool>,
}

impl CancelToken {
    /// Creates a token in the not-cancelled state.
    pub fn new() -> Self {
        Self::default()
    }

    /// Requests cancellation. Calling it more than once has no further effect.
    pub fn cancel(&self) {
        self.flag.store(true, Ordering::SeqCst);
    }

    /// Returns `true` once `cancel()` has been called on this token or a clone of it.
    pub fn is_cancelled(&self) -> bool {
        self.flag.load(Ordering::SeqCst)
    }

    /// Returns `Ok(())` while not cancelled; otherwise `DarraError::Cancelled`
    /// with `where_` embedded in the message.
    pub fn check(&self, where_: &str) -> Result<()> {
        if !self.is_cancelled() {
            return Ok(());
        }
        Err(DarraError::Cancelled(format!("{} 被取消", where_)))
    }
}
/// Scope guard that sets a shared "inside native code" flag when created and
/// clears it again when dropped (also when a panic unwinds through the scope).
struct InNativeGuard(Arc<AtomicBool>);

impl InNativeGuard {
    /// Sets `marker` to `true` and returns a guard that resets it to `false` on drop.
    fn enter(marker: Arc<AtomicBool>) -> Self {
        let guard = Self(marker);
        guard.0.store(true, Ordering::SeqCst);
        guard
    }
}

impl Drop for InNativeGuard {
    fn drop(&mut self) {
        let InNativeGuard(marker) = self;
        marker.store(false, Ordering::SeqCst);
    }
}
/// Per-master coordination state for the async wrappers.
///
/// `gate` is locked by `run_exclusive_async` for the duration of an operation;
/// `shutting_down` is set by `mark_shutdown` and makes pending operations return
/// `DarraError::Cancelled` instead of running their body.
struct MasterAsyncState {
    gate: Arc<TokioMutex<()>>,
    shutting_down: Arc<AtomicBool>,
}

impl MasterAsyncState {
    /// Fresh state: gate unlocked, shutdown flag cleared.
    fn new() -> Self {
        let gate = Arc::new(TokioMutex::new(()));
        let shutting_down = Arc::new(AtomicBool::new(false));
        Self { gate, shutting_down }
    }
}
/// Process-wide map from master index to its async coordination state,
/// created lazily on first access.
fn registry() -> &'static StdMutex<HashMap<u16, MasterAsyncState>> {
    static REGISTRY: OnceLock<StdMutex<HashMap<u16, MasterAsyncState>>> = OnceLock::new();
    REGISTRY.get_or_init(Default::default)
}
fn state_handles(master_index: u16) -> (Arc<TokioMutex<()>>, Arc<AtomicBool>) {
let mut reg = registry().lock().expect("async 隔离注册表 poisoned");
let st = reg.entry(master_index).or_insert_with(MasterAsyncState::new);
(st.gate.clone(), st.shutting_down.clone())
}
pub fn mark_shutdown(master_index: u16) {
let reg = registry().lock().expect("async 隔离注册表 poisoned");
if let Some(st) = reg.get(&master_index) {
st.shutting_down.store(true, Ordering::SeqCst);
}
}
/// Removes the registry entry for `master_index`.
///
/// Operations that already hold clones of the gate/flag keep them. A later
/// operation on the same index starts from a newly created entry.
///
/// # Panics
/// Panics if the registry mutex is poisoned.
pub fn release_state(master_index: u16) {
    registry()
        .lock()
        .expect("async 隔离注册表 poisoned")
        .remove(&master_index);
}
/// Runs the blocking closure `body` on Tokio's blocking pool while holding the
/// async gate of `master_index`, so operations on the same master never overlap.
///
/// * `op_name` – label used in error messages.
/// * `progress` – passed through to `body` unchanged.
/// * `cancel` – optional token, checked at three points:
///   - before waiting for the gate;
///   - after acquiring it;
///   - on the blocking thread just before `body` starts.
///
///   If it becomes cancelled while `body` is running, `cancel_abort` is invoked once.
/// * `cancel_abort` – callback meant to interrupt the running body. It is executed
///   on the blocking pool because it may block the calling thread.
/// * `body` – the blocking / native work.
///
/// The gate guard is owned by the blocking task. If the returned future is
/// dropped early (e.g. by `tokio::time::timeout` or `tokio::select!`), the gate
/// therefore stays locked until `body` has actually returned. The cancellation
/// watcher is aborted on drop instead of polling forever.
///
/// # Errors
/// * `DarraError::Cancelled` if `cancel` fires before `body` starts, or if the
///   master was flagged by `mark_shutdown`.
/// * `DarraError::Other` if the blocking task panics or is cancelled by the runtime.
/// * Any error returned by `body`.
pub async fn run_exclusive_async<T, F, A>(
    master_index: u16,
    op_name: &str,
    progress: Option<ProgressFn>,
    cancel: Option<CancelToken>,
    cancel_abort: A,
    body: F,
) -> Result<T>
where
    T: Send + 'static,
    F: FnOnce(Option<ProgressFn>) -> Result<T> + Send + 'static,
    A: Fn() + Send + Sync + 'static,
{
    /// Aborts the wrapped task when dropped, so the watcher never outlives this call.
    struct AbortOnDrop(tokio::task::JoinHandle<()>);
    impl Drop for AbortOnDrop {
        fn drop(&mut self) {
            self.0.abort();
        }
    }

    let (gate, shutting_down) = state_handles(master_index);
    if let Some(ct) = &cancel {
        ct.check(op_name)?;
    }
    let gate_guard = gate.lock_owned().await;
    if shutting_down.load(Ordering::SeqCst) {
        return Err(DarraError::Cancelled(format!(
            "{}: 主站正在关闭/已释放, async 操作中止 (防 use-after-free)",
            op_name
        )));
    }
    if let Some(ct) = &cancel {
        ct.check(op_name)?;
    }

    let in_native = Arc::new(AtomicBool::new(false));
    let abort = Arc::new(cancel_abort);
    let watcher = cancel.as_ref().map(|ct| {
        let ct = ct.clone();
        let abort = Arc::clone(&abort);
        let in_native = Arc::clone(&in_native);
        AbortOnDrop(tokio::spawn(async move {
            loop {
                if ct.is_cancelled() && in_native.load(Ordering::SeqCst) {
                    // The abort callback may block (e.g. sleep), so keep it off the
                    // async worker threads. The task is intentionally detached.
                    drop(tokio::task::spawn_blocking(move || (abort)()));
                    break;
                }
                tokio::time::sleep(std::time::Duration::from_millis(20)).await;
            }
        }))
    });

    let op_name_bg = op_name.to_string();
    let in_native_bg = Arc::clone(&in_native);
    let join = tokio::task::spawn_blocking(move || -> Result<T> {
        // Owning the guard here (not in the async frame) keeps the master locked
        // until `body` returns, even if the caller drops the outer future.
        let _gate_guard = gate_guard;
        if let Some(ct) = &cancel {
            ct.check(&op_name_bg)?;
        }
        if shutting_down.load(Ordering::SeqCst) {
            return Err(DarraError::Cancelled(format!(
                "{}: 主站正在关闭/已释放, async 操作中止 (防 use-after-free)",
                op_name_bg
            )));
        }
        // Declared after `_gate_guard`, so it is dropped first: the in-native flag
        // is cleared before the gate is released.
        let _native_guard = InNativeGuard::enter(in_native_bg);
        body(progress)
    })
    .await;
    drop(watcher);

    match join {
        Ok(r) => r,
        Err(e) => Err(DarraError::Other(format!(
            "{}: 后台异步任务异常 (join error): {}",
            op_name, e
        ))),
    }
}
/// Thin wrapper over `run_exclusive_async` for bodies that produce no value.
/// Arguments, gating, cancellation and errors are exactly those of
/// `run_exclusive_async`.
pub async fn run_exclusive_async_unit<F, A>(
    master_index: u16,
    op_name: &str,
    progress: Option<ProgressFn>,
    cancel: Option<CancelToken>,
    cancel_abort: A,
    body: F,
) -> Result<()>
where
    F: FnOnce(Option<ProgressFn>) -> Result<()> + Send + 'static,
    A: Fn() + Send + Sync + 'static,
{
    run_exclusive_async::<(), F, A>(
        master_index,
        op_name,
        progress,
        cancel,
        cancel_abort,
        body,
    )
    .await
}
/// Process-wide async gate used by `run_static_exclusive_async`, which is not
/// tied to a master index. Created lazily on first use.
fn static_scan_gate() -> &'static TokioMutex<()> {
    static SCAN_GATE: OnceLock<TokioMutex<()>> = OnceLock::new();
    SCAN_GATE.get_or_init(|| TokioMutex::new(()))
}
pub async fn run_static_exclusive_async<T, F>(
op_name: &str,
cancel: Option<CancelToken>,
body: F,
) -> Result<T>
where
T: Send + 'static,
F: FnOnce() -> Result<T> + Send + 'static,
{
if let Some(ct) = &cancel {
ct.check(op_name)?;
}
let _guard = static_scan_gate().lock().await;
if let Some(ct) = &cancel {
ct.check(op_name)?;
}
let in_native = Arc::new(AtomicBool::new(false));
let watcher = cancel.as_ref().map(|ct| {
let ct = ct.clone();
let in_native_w = in_native.clone();
tokio::spawn(async move {
loop {
if ct.is_cancelled() && in_native_w.load(Ordering::SeqCst) {
safe_abort_scan();
break;
}
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
}
})
});
let cancel_bg = cancel.clone();
let op_name_bg = op_name.to_string();
let in_native_bg = in_native.clone();
let join = tokio::task::spawn_blocking(move || -> Result<T> {
if let Some(ct) = &cancel_bg {
ct.check(&op_name_bg)?;
}
let _native_guard = InNativeGuard::enter(in_native_bg);
body()
})
.await;
if let Some(w) = watcher {
w.abort();
}
match join {
Ok(r) => r,
Err(e) => Err(DarraError::Other(format!(
"{}: 后台异步任务异常 (join error): {}",
op_name, e
))),
}
}
/// Calls `ffi::AbortNetwork`, waits 50 ms, then calls `ffi::ResetAbortNetwork`.
///
/// Blocks the calling thread for about 50 ms; avoid calling it directly from an
/// async task.
pub fn safe_abort() {
    use std::{thread, time::Duration};

    // SAFETY: NOTE(review) assumes these argument-less FFI calls may be issued
    // from any thread at any time; confirm against the native library contract.
    unsafe {
        ffi::AbortNetwork();
    }
    thread::sleep(Duration::from_millis(50));
    // SAFETY: same assumption as above.
    unsafe {
        ffi::ResetAbortNetwork();
    }
}
/// Calls `ffi::AbortScan`, waits 50 ms, then calls `ffi::ResetScanAbort`.
///
/// Blocks the calling thread for about 50 ms; avoid calling it directly from an
/// async task.
pub fn safe_abort_scan() {
    use std::{thread, time::Duration};

    // SAFETY: NOTE(review) assumes these argument-less FFI calls may be issued
    // from any thread at any time; confirm against the native library contract.
    unsafe {
        ffi::AbortScan();
    }
    thread::sleep(Duration::from_millis(50));
    // SAFETY: same assumption as above.
    unsafe {
        ffi::ResetScanAbort();
    }
}