use std::sync::Arc;
use mssql_codec::connection::CancelHandle as CodecCancelHandle;
#[cfg(feature = "tls")]
use mssql_tls::TlsStream;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use crate::error::{Error, Result};
// Concrete codec-level cancel handle types, one per stream type a connection can be
// built on. The two TLS variants only exist when the `tls` feature is enabled.

/// Codec cancel handle for a TLS stream wrapping the TCP stream directly.
#[cfg(feature = "tls")]
type TlsCancelHandle = CodecCancelHandle<TlsStream<TcpStream>>;
/// Codec cancel handle for a TLS stream layered on `mssql_tls::TlsPreloginWrapper`.
// NOTE(review): presumably the wrapper carries the TLS handshake inside TDS PRELOGIN
// packets — confirm against `mssql_tls`.
#[cfg(feature = "tls")]
type TlsPreloginCancelHandle =
    CodecCancelHandle<TlsStream<mssql_tls::TlsPreloginWrapper<TcpStream>>>;
/// Codec cancel handle for a plain (non-TLS) TCP stream.
type PlainCancelHandle = CodecCancelHandle<TcpStream>;
/// Cloneable handle for sending a cancel through a connection's codec-level
/// cancel handle, independent of which transport (plain TCP or, with the `tls`
/// feature, TLS) the connection uses.
///
/// All clones share one underlying codec handle behind an
/// `Arc<tokio::sync::Mutex<_>>`, so operations issued through different clones
/// are serialized on that lock.
#[derive(Clone)]
pub struct CancelHandle {
    // Shared by every clone. `cancel` and `wait_cancelled` keep this async lock
    // held while awaiting the codec call.
    inner: Arc<Mutex<CancelHandleInner>>,
}
/// The concrete codec cancel handle, chosen by the stream type the connection
/// was built on. Each `CancelHandle` method dispatches over these variants.
enum CancelHandleInner {
    /// TLS directly over TCP.
    #[cfg(feature = "tls")]
    Tls(TlsCancelHandle),
    /// TLS over `mssql_tls::TlsPreloginWrapper<TcpStream>`.
    #[cfg(feature = "tls")]
    TlsPrelogin(TlsPreloginCancelHandle),
    /// Plain TCP without TLS.
    Plain(PlainCancelHandle),
}
impl CancelHandle {
#[cfg(feature = "tls")]
pub(crate) fn from_tls(handle: TlsCancelHandle) -> Self {
Self {
inner: Arc::new(Mutex::new(CancelHandleInner::Tls(handle))),
}
}
#[cfg(feature = "tls")]
pub(crate) fn from_tls_prelogin(handle: TlsPreloginCancelHandle) -> Self {
Self {
inner: Arc::new(Mutex::new(CancelHandleInner::TlsPrelogin(handle))),
}
}
pub(crate) fn from_plain(handle: PlainCancelHandle) -> Self {
Self {
inner: Arc::new(Mutex::new(CancelHandleInner::Plain(handle))),
}
}
pub async fn cancel(&self) -> Result<()> {
let inner = self.inner.lock().await;
match &*inner {
#[cfg(feature = "tls")]
CancelHandleInner::Tls(h) => h.cancel().await.map_err(|e| Error::Cancel(e.to_string())),
#[cfg(feature = "tls")]
CancelHandleInner::TlsPrelogin(h) => {
h.cancel().await.map_err(|e| Error::Cancel(e.to_string()))
}
CancelHandleInner::Plain(h) => {
h.cancel().await.map_err(|e| Error::Cancel(e.to_string()))
}
}
}
pub async fn wait_cancelled(&self) {
let inner = self.inner.lock().await;
match &*inner {
#[cfg(feature = "tls")]
CancelHandleInner::Tls(h) => h.wait_cancelled().await,
#[cfg(feature = "tls")]
CancelHandleInner::TlsPrelogin(h) => h.wait_cancelled().await,
CancelHandleInner::Plain(h) => h.wait_cancelled().await,
}
}
#[must_use]
pub fn is_cancelling(&self) -> bool {
self.inner
.try_lock()
.map(|inner| match &*inner {
#[cfg(feature = "tls")]
CancelHandleInner::Tls(h) => h.is_cancelling(),
#[cfg(feature = "tls")]
CancelHandleInner::TlsPrelogin(h) => h.is_cancelling(),
CancelHandleInner::Plain(h) => h.is_cancelling(),
})
.unwrap_or(true)
}
}
impl std::fmt::Debug for CancelHandle {
    /// Prints only the best-effort `is_cancelling` flag; the wrapped codec
    /// handle is not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let cancelling = self.is_cancelling();
        let mut builder = f.debug_struct("CancelHandle");
        builder.field("is_cancelling", &cancelling);
        builder.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check: `CancelHandle` may be moved to and shared between threads.
    #[test]
    fn test_cancel_handle_is_send_sync() {
        fn requires_send<T: Send>() {}
        fn requires_sync<T: Sync>() {}
        requires_send::<CancelHandle>();
        requires_sync::<CancelHandle>();
    }

    /// Compile-time check: `CancelHandle` implements `Clone`.
    #[test]
    fn test_cancel_handle_is_clone() {
        fn requires_clone<T: Clone>() {}
        requires_clone::<CancelHandle>();
    }
}