1use std::sync::Arc;
9
10use crate::errors::RpcError;
11use crate::wire::Metadata;
12
/// Traffic counters for a single call.
///
/// Batches, rows and bytes are tracked separately for the input and the
/// output side. Every counter is a `u64`, and the derived [`Default`] starts
/// all of them at zero.
#[derive(Debug, Default, Clone)]
pub struct CallStatistics {
    /// Input-side batch counter.
    pub input_batches: u64,
    /// Output-side batch counter.
    pub output_batches: u64,
    /// Input-side row counter.
    pub input_rows: u64,
    /// Output-side row counter.
    pub output_rows: u64,
    /// Input-side byte counter.
    pub input_bytes: u64,
    /// Output-side byte counter.
    pub output_bytes: u64,
}
27
28#[derive(Clone, Debug)]
30pub struct DispatchInfo {
31 pub method: String,
32 pub method_type: &'static str,
33 pub server_id: String,
34 pub protocol: String,
36 pub protocol_hash: String,
38 pub protocol_version: String,
40 pub request_id: String,
41 pub transport_metadata: Arc<Metadata>,
43 pub principal: String,
45 pub auth_domain: String,
47 pub authenticated: bool,
49 pub remote_addr: String,
51 pub http_status: u16,
53 pub request_data: Vec<u8>,
55 pub stream_id: String,
57 pub cancelled: bool,
59 pub claims: std::collections::BTreeMap<String, String>,
65}
66
67impl DispatchInfo {
68 pub fn from_request(
71 server: &crate::server::RpcServer,
72 req: &crate::server::Request,
73 method_type: &'static str,
74 auth: &crate::auth::AuthContext,
75 ) -> Self {
76 Self {
77 method: req.method.clone(),
78 method_type,
79 server_id: server.server_id.clone(),
80 protocol: server.protocol_name().to_string(),
81 protocol_hash: server.protocol_hash().to_string(),
82 protocol_version: server.protocol_version().to_string(),
83 request_id: req.request_id.clone(),
84 transport_metadata: Arc::new(req.metadata.clone()),
85 principal: auth.principal.clone(),
86 auth_domain: auth.domain.clone(),
87 authenticated: auth.authenticated,
88 remote_addr: String::new(),
89 http_status: 0,
90 request_data: Vec::new(),
91 stream_id: String::new(),
92 cancelled: false,
93 claims: auth.claims.clone(),
94 }
95 }
96}
97
98pub type HookToken = u64;
100
101pub trait DispatchHook: Send + Sync {
103 fn on_dispatch_start(&self, info: &DispatchInfo) -> HookToken;
106
107 fn on_dispatch_end(
110 &self,
111 token: HookToken,
112 info: &DispatchInfo,
113 error: Option<&RpcError>,
114 stats: &CallStatistics,
115 );
116}
117
118pub type SharedHook = Arc<dyn DispatchHook>;
120
121pub struct ChainHook {
123 inner: Vec<SharedHook>,
124}
125
126impl ChainHook {
127 pub fn new(hooks: Vec<SharedHook>) -> Self {
128 Self { inner: hooks }
129 }
130}
131
132impl DispatchHook for ChainHook {
133 fn on_dispatch_start(&self, info: &DispatchInfo) -> HookToken {
134 for h in &self.inner {
138 let _ = h.on_dispatch_start(info);
139 }
140 0
141 }
142
143 fn on_dispatch_end(
144 &self,
145 token: HookToken,
146 info: &DispatchInfo,
147 error: Option<&RpcError>,
148 stats: &CallStatistics,
149 ) {
150 for h in &self.inner {
151 h.on_dispatch_end(token, info, error, stats);
152 }
153 }
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Hook that only counts how many times each callback fired.
    #[derive(Default)]
    struct CountingHook {
        starts: AtomicU64,
        ends: AtomicU64,
    }

    impl DispatchHook for CountingHook {
        fn on_dispatch_start(&self, _info: &DispatchInfo) -> HookToken {
            // The token is the number of starts seen so far, this one included.
            let previous = self.starts.fetch_add(1, Ordering::Relaxed);
            previous + 1
        }

        fn on_dispatch_end(
            &self,
            _token: HookToken,
            _info: &DispatchInfo,
            _error: Option<&RpcError>,
            _stats: &CallStatistics,
        ) {
            self.ends.fetch_add(1, Ordering::Relaxed);
        }
    }

    /// Mostly empty `DispatchInfo` for an unauthenticated unary `echo` call.
    fn echo_info() -> DispatchInfo {
        DispatchInfo {
            method: String::from("echo"),
            method_type: "unary",
            server_id: String::from("test"),
            protocol: String::new(),
            protocol_hash: String::new(),
            protocol_version: String::new(),
            request_id: String::new(),
            transport_metadata: Arc::new(Default::default()),
            principal: String::new(),
            auth_domain: String::new(),
            authenticated: false,
            remote_addr: String::new(),
            http_status: 0,
            request_data: Vec::new(),
            stream_id: String::new(),
            cancelled: false,
            claims: std::collections::BTreeMap::new(),
        }
    }

    /// One start/end pair through the chain must reach every inner hook once.
    #[test]
    fn chain_hook_fans_out() {
        let first = Arc::new(CountingHook::default());
        let second = Arc::new(CountingHook::default());
        let chain = ChainHook::new(vec![first.clone(), second.clone()]);
        let info = echo_info();

        let token = chain.on_dispatch_start(&info);
        chain.on_dispatch_end(token, &info, None, &CallStatistics::default());

        for hook in [&first, &second] {
            assert_eq!(hook.starts.load(Ordering::Relaxed), 1);
            assert_eq!(hook.ends.load(Ordering::Relaxed), 1);
        }
    }
}