use std::sync::Arc;
use crate::errors::RpcError;
use crate::wire::Metadata;
/// Per-call counters handed to [`DispatchHook::on_dispatch_end`].
///
/// Every counter starts at zero via `Default`. `PartialEq`/`Eq` are derived
/// so two snapshots can be compared directly (for example in tests).
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct CallStatistics {
    /// Input-side batch count.
    pub input_batches: u64,
    /// Output-side batch count.
    pub output_batches: u64,
    /// Input-side row count.
    pub input_rows: u64,
    /// Output-side row count.
    pub output_rows: u64,
    /// Input-side byte count.
    pub input_bytes: u64,
    /// Output-side byte count.
    pub output_bytes: u64,
}
/// Snapshot of a single dispatched call, passed by reference to both
/// [`DispatchHook`] callbacks.
///
/// `transport_metadata` sits behind an [`Arc`], so cloning a `DispatchInfo`
/// shares it. Every `String`, the `Vec`, and the claims map are deep-copied.
#[derive(Debug, Clone)]
pub struct DispatchInfo {
    /// Method name (`from_request` copies it from the request).
    pub method: String,
    /// Method kind supplied by the caller of `from_request`
    /// (the module's test uses `"unary"`).
    pub method_type: &'static str,
    /// Identifier of the server handling the call.
    pub server_id: String,
    /// Protocol name reported by the server.
    pub protocol: String,
    /// Protocol hash reported by the server.
    pub protocol_hash: String,
    /// Protocol version reported by the server.
    pub protocol_version: String,
    /// Request identifier copied from the request.
    pub request_id: String,
    /// Request metadata, shared rather than copied between clones.
    pub transport_metadata: Arc<Metadata>,
    /// Principal taken from the authentication context.
    pub principal: String,
    /// Authentication domain taken from the authentication context.
    pub auth_domain: String,
    /// Whether the authentication context reported the caller as authenticated.
    pub authenticated: bool,
    /// Remote address; `from_request` leaves it empty.
    pub remote_addr: String,
    /// HTTP status; `from_request` leaves it at `0` (presumably "not set" —
    /// TODO confirm).
    pub http_status: u16,
    /// Raw request bytes; `from_request` leaves this empty.
    pub request_data: Vec<u8>,
    /// Stream identifier; `from_request` leaves it empty.
    pub stream_id: String,
    /// Cancellation flag; `from_request` starts it as `false`.
    pub cancelled: bool,
    /// Claims copied from the authentication context, in key order.
    pub claims: std::collections::BTreeMap<String, String>,
}
impl DispatchInfo {
    /// Assembles the hook-facing view of a call from the server, the incoming
    /// request, and the caller's authentication context.
    ///
    /// The transport-level fields (`remote_addr`, `http_status`,
    /// `request_data`, `stream_id`, `cancelled`) start out empty, zero, or
    /// `false` here. Presumably they are filled in later by the transport
    /// layer — TODO confirm.
    pub fn from_request(
        server: &crate::server::RpcServer,
        req: &crate::server::Request,
        method_type: &'static str,
        auth: &crate::auth::AuthContext,
    ) -> Self {
        // Values are gathered in the same order the struct literal used to
        // evaluate them.
        let method = req.method.clone();
        let server_id = server.server_id.clone();
        let protocol = server.protocol_name().to_string();
        let protocol_hash = server.protocol_hash().to_string();
        let protocol_version = server.protocol_version().to_string();
        let request_id = req.request_id.clone();
        let transport_metadata = Arc::new(req.metadata.clone());
        let principal = auth.principal.clone();
        let auth_domain = auth.domain.clone();
        let authenticated = auth.authenticated;
        let claims = auth.claims.clone();

        Self {
            method,
            method_type,
            server_id,
            protocol,
            protocol_hash,
            protocol_version,
            request_id,
            transport_metadata,
            principal,
            auth_domain,
            authenticated,
            remote_addr: String::new(),
            http_status: 0,
            request_data: Vec::new(),
            stream_id: String::new(),
            cancelled: false,
            claims,
        }
    }
}
/// Opaque value a hook returns from [`DispatchHook::on_dispatch_start`].
/// It is handed back to [`DispatchHook::on_dispatch_end`] for the same call.
pub type HookToken = u64;

/// Observer invoked around each dispatched call.
///
/// The `Sync + Send` bounds allow a hook to be shared across threads behind
/// an [`Arc`] (see [`SharedHook`]).
pub trait DispatchHook: Sync + Send {
    /// Start-of-call notification. The returned token is expected to be
    /// passed back unchanged to
    /// [`on_dispatch_end`](DispatchHook::on_dispatch_end).
    fn on_dispatch_start(&self, info: &DispatchInfo) -> HookToken;

    /// End-of-call notification carrying the token from `on_dispatch_start`,
    /// the call's error (if any), and its statistics.
    fn on_dispatch_end(
        &self,
        token: HookToken,
        info: &DispatchInfo,
        error: Option<&RpcError>,
        stats: &CallStatistics,
    );
}

/// Reference-counted, thread-shareable handle to any hook.
pub type SharedHook = Arc<dyn DispatchHook>;
/// A [`DispatchHook`] that forwards every callback to an ordered list of
/// inner hooks.
///
/// Each inner hook issues its own [`HookToken`] from `on_dispatch_start`.
/// The chain returns a single chain-level token to its caller and stores the
/// inner tokens under it. In `on_dispatch_end`, every inner hook then gets
/// back the token *it* issued, not the chain's token.
///
/// The stored tokens for a call are released in `on_dispatch_end`. A start
/// that is never matched by an end keeps its entry for the lifetime of the
/// chain.
pub struct ChainHook {
    inner: Vec<SharedHook>,
    /// Source of chain-level tokens; starts at 1 and increments per call.
    next_token: std::sync::atomic::AtomicU64,
    /// Inner-hook tokens for in-flight calls, keyed by chain-level token and
    /// stored in the same order as `inner`.
    pending: std::sync::Mutex<std::collections::HashMap<HookToken, Vec<HookToken>>>,
}

impl ChainHook {
    /// Creates a chain that invokes `hooks` in the given order.
    pub fn new(hooks: Vec<SharedHook>) -> Self {
        Self {
            inner: hooks,
            next_token: std::sync::atomic::AtomicU64::new(1),
            pending: std::sync::Mutex::new(std::collections::HashMap::new()),
        }
    }
}

impl DispatchHook for ChainHook {
    /// Starts every inner hook in order, records their tokens, and returns a
    /// fresh chain-level token that identifies them.
    fn on_dispatch_start(&self, info: &DispatchInfo) -> HookToken {
        let tokens: Vec<HookToken> = self
            .inner
            .iter()
            .map(|hook| hook.on_dispatch_start(info))
            .collect();
        // Only uniqueness is needed from the counter, so Relaxed suffices;
        // the map itself is synchronized by the mutex.
        let token = self
            .next_token
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        // A poisoned lock only means another thread panicked while holding
        // it. Each critical section here is a single insert or remove, so the
        // map is still consistent and we keep going instead of panicking.
        self.pending
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .insert(token, tokens);
        token
    }

    /// Ends every inner hook in order, passing each the token it returned
    /// from `on_dispatch_start`.
    ///
    /// If `token` has no recorded entry (not issued by this chain, or already
    /// ended), it is forwarded unchanged to every inner hook as a best-effort
    /// fallback, so each hook is still notified.
    fn on_dispatch_end(
        &self,
        token: HookToken,
        info: &DispatchInfo,
        error: Option<&RpcError>,
        stats: &CallStatistics,
    ) {
        // Remove the entry in its own statement so the lock guard is dropped
        // before any hook code runs; a hook may re-enter this chain.
        let issued = self
            .pending
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .remove(&token);
        match issued {
            Some(tokens) => {
                for (hook, inner_token) in self.inner.iter().zip(tokens) {
                    hook.on_dispatch_end(inner_token, info, error, stats);
                }
            }
            None => {
                for hook in &self.inner {
                    hook.on_dispatch_end(token, info, error, stats);
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Hook that only counts how often each callback fires.
    #[derive(Default)]
    struct CountingHook {
        starts: AtomicU64,
        ends: AtomicU64,
    }

    impl CountingHook {
        /// Current `(starts, ends)` counts.
        fn counts(&self) -> (u64, u64) {
            (
                self.starts.load(Ordering::Relaxed),
                self.ends.load(Ordering::Relaxed),
            )
        }
    }

    impl DispatchHook for CountingHook {
        fn on_dispatch_start(&self, _info: &DispatchInfo) -> HookToken {
            // Token is the 1-based ordinal of this start call.
            1 + self.starts.fetch_add(1, Ordering::Relaxed)
        }

        fn on_dispatch_end(
            &self,
            _token: HookToken,
            _info: &DispatchInfo,
            _error: Option<&RpcError>,
            _stats: &CallStatistics,
        ) {
            self.ends.fetch_add(1, Ordering::Relaxed);
        }
    }

    /// `DispatchInfo` for a unary `echo` call on server `test`, with every
    /// other field blank.
    fn echo_info() -> DispatchInfo {
        DispatchInfo {
            method: "echo".into(),
            method_type: "unary",
            server_id: "test".into(),
            protocol: String::new(),
            protocol_hash: String::new(),
            protocol_version: String::new(),
            request_id: String::new(),
            transport_metadata: Arc::new(Default::default()),
            principal: String::new(),
            auth_domain: String::new(),
            authenticated: false,
            remote_addr: String::new(),
            http_status: 0,
            request_data: Vec::new(),
            stream_id: String::new(),
            cancelled: false,
            claims: std::collections::BTreeMap::new(),
        }
    }

    #[test]
    fn chain_hook_fans_out() {
        let hooks = [
            Arc::new(CountingHook::default()),
            Arc::new(CountingHook::default()),
        ];
        let chain = ChainHook::new(
            hooks
                .iter()
                .map(|hook| Arc::clone(hook) as SharedHook)
                .collect(),
        );
        let info = echo_info();

        let token = chain.on_dispatch_start(&info);
        chain.on_dispatch_end(token, &info, None, &CallStatistics::default());

        for hook in &hooks {
            assert_eq!(hook.counts(), (1, 1));
        }
    }
}