use crate::codec::mechanism::ZmqMechanism;
#[cfg(any(feature = "tcp", all(feature = "ipc", target_family = "unix")))]
use crate::codec::ZmqGreeting;
#[cfg(any(feature = "tcp", all(feature = "ipc", target_family = "unix")))]
use crate::mechanism::SessionState;
#[cfg(any(feature = "tcp", all(feature = "ipc", target_family = "unix")))]
use crate::{ZmqError, ZmqResult};
use parking_lot::RwLock;
use std::future::Future;
use std::pin::Pin;
/// Description of a connecting peer that is handed to the installed ZAP handler.
#[derive(Clone, Debug)]
pub struct ZapRequest {
    /// Domain string supplied by the caller of the ZAP check.
    pub domain: String,
    /// Peer address, as a string.
    pub address: String,
    /// Security mechanism announced in the peer's greeting.
    pub mechanism: ZmqMechanism,
    /// Username from the session state, if the mechanism supplied one
    /// (presumably PLAIN — TODO confirm).
    pub username: Option<String>,
    /// Password from the session state, if the mechanism supplied one
    /// (presumably PLAIN — TODO confirm).
    pub password: Option<String>,
    /// Raw client public key bytes, if available
    /// (presumably the CURVE key — TODO confirm).
    pub client_pubkey: Option<Vec<u8>>,
}
/// Verdict returned by a ZAP handler for a single [`ZapRequest`].
#[derive(Clone, Debug)]
pub struct ZapResponse {
    /// Numeric status; only `200` lets the peer through.
    pub status_code: u16,
    /// Human-readable status or rejection reason.
    pub status_text: String,
    /// Identity assigned to the peer (empty for rejections built with `deny`).
    pub user_id: String,
}
impl ZapResponse {
pub fn allow(user_id: impl Into<String>) -> Self {
Self {
status_code: 200,
status_text: "OK".into(),
user_id: user_id.into(),
}
}
pub fn deny(reason: impl Into<String>) -> Self {
Self {
status_code: 400,
status_text: reason.into(),
user_id: String::new(),
}
}
}
/// Type-erased, sendable future that resolves to a handler's verdict.
/// (A boxed trait object already defaults to a `'static` lifetime bound.)
type BoxFuture = Pin<Box<dyn Future<Output = ZapResponse> + Send>>;
/// Type-erased handler: turns a request into a boxed response future.
type HandlerFn = Box<dyn Fn(ZapRequest) -> BoxFuture + Send + Sync>;
/// Process-wide handler slot; `None` means no handler is installed.
static ZAP_HANDLER: RwLock<Option<HandlerFn>> = RwLock::new(None);
pub fn set_zap_handler<F, Fut>(handler: F)
where
F: Fn(ZapRequest) -> Fut + Send + Sync + 'static,
Fut: Future<Output = ZapResponse> + Send + 'static,
{
let boxed: HandlerFn = Box::new(move |req| Box::pin(handler(req)));
*ZAP_HANDLER.write() = Some(boxed);
}
/// Removes the process-wide ZAP handler, if one is installed.
pub fn clear_zap_handler() {
    let mut slot = ZAP_HANDLER.write();
    // Release the old handler while the write lock is still held.
    drop(slot.take());
}
#[cfg(any(feature = "tcp", all(feature = "ipc", target_family = "unix")))]
/// Asks the installed ZAP handler whether the peer may proceed.
///
/// If no handler is installed, the peer is accepted immediately and no
/// request is built.
///
/// # Errors
///
/// * [`ZmqError::ZapTimeout`] if the handler's future does not complete within
///   the timeout (5 seconds).
/// * [`ZmqError::ZapDenied`] if the handler answers with any status code other
///   than `200`; the handler's status code and text are forwarded.
pub(crate) async fn zap_check(
    domain: &str,
    peer_greeting: &ZmqGreeting,
    state: &SessionState,
    peer_addr: &str,
    client_pubkey: Option<Vec<u8>>,
) -> ZmqResult<()> {
    /// Upper bound on how long a handler may take before the peer is rejected.
    const ZAP_HANDLER_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5);

    // Obtain the handler's future inside this scope so the read guard is
    // released before awaiting: a slow handler must not block
    // `set_zap_handler` / `clear_zap_handler`.
    // NOTE(review): the handler's *synchronous* part still runs under the read
    // lock, so a handler that calls `set_zap_handler` before returning its
    // future would deadlock.
    let resp_future = {
        let guard = ZAP_HANDLER.read();
        let handler = match guard.as_ref() {
            Some(handler) => handler,
            // No handler installed: skip building the request entirely.
            None => return Ok(()),
        };
        handler(ZapRequest {
            domain: domain.to_string(),
            address: peer_addr.to_string(),
            mechanism: peer_greeting.mechanism,
            username: state.username.clone(),
            password: state.password.clone(),
            client_pubkey,
        })
    };

    let resp = crate::async_rt::task::timeout(ZAP_HANDLER_TIMEOUT, resp_future)
        .await
        .map_err(|_| ZmqError::ZapTimeout)?;

    if resp.status_code == 200 {
        Ok(())
    } else {
        Err(ZmqError::ZapDenied {
            status_code: resp.status_code,
            status_text: resp.status_text,
        })
    }
}