use async_trait::async_trait;
use serde::{de::DeserializeOwned, Serialize};
use std::time::{Duration, Instant};
use tracing::{debug, error, info, info_span, instrument, Instrument, Span};
use crate::effects::{ChoreoHandler, ChoreoResult, LabelId, RoleId};
use crate::identifiers::RoleName;
/// Canonical field names attached to protocol spans and events.
///
/// NOTE(review): the span/event macros in this file spell these names out
/// literally instead of referencing these constants; keep the two in sync.
pub mod fields {
    /// Field carrying the protocol name.
    pub const PROTOCOL: &str = "protocol";
    /// Field carrying the local role's name.
    pub const ROLE: &str = "role";
    /// Field carrying the local role's index, when the role is indexed.
    pub const ROLE_INDEX: &str = "role_index";
    /// Field carrying the current phase name.
    pub const PHASE: &str = "phase";
    /// Field carrying the Rust type name of a message.
    pub const MESSAGE_TYPE: &str = "message_type";
    /// Field carrying the reported message size.
    pub const MESSAGE_SIZE: &str = "message_size";
    /// Field carrying the peer a message or choice is sent to.
    pub const TARGET_ROLE: &str = "target_role";
    /// Field carrying the peer a message or choice was received from.
    pub const SOURCE_ROLE: &str = "source_role";
    /// Field carrying a branch label for choose/offer.
    pub const CHOICE_LABEL: &str = "choice_label";
    /// Field carrying an elapsed duration in milliseconds.
    pub const DURATION_MS: &str = "duration_ms";
    /// Field carrying an error message.
    pub const ERROR: &str = "error";
}
/// Event targets (the `target:` of each emitted event) used by the
/// `trace_*` helpers, so subscribers can filter protocol activity by kind.
pub mod events {
    /// Target for outgoing messages.
    pub const SEND: &str = "protocol.send";
    /// Target for incoming messages.
    pub const RECV: &str = "protocol.recv";
    /// Target for choices made by the local role.
    pub const CHOOSE: &str = "protocol.choose";
    /// Target for choices received from a peer.
    pub const OFFER: &str = "protocol.offer";
    /// Target for the start of a phase.
    pub const PHASE_START: &str = "protocol.phase.start";
    /// Target for the end of a phase.
    pub const PHASE_END: &str = "protocol.phase.end";
    /// Target for protocol errors.
    pub const ERROR: &str = "protocol.error";
}
/// Builds the INFO-level `protocol.execute` span for one role's execution.
///
/// The span always records `protocol` and `role`; `role_index` is recorded
/// only when `role_index` is `Some`, so non-indexed roles carry no index field.
pub fn protocol_span(protocol: &str, role: &RoleName, role_index: Option<u32>) -> Span {
    let role = role.as_str();
    if let Some(role_index) = role_index {
        info_span!("protocol.execute", protocol, role, role_index)
    } else {
        info_span!("protocol.execute", protocol, role)
    }
}
/// Builds the INFO-level `protocol.phase` span recording `protocol`, `role`
/// and `phase`.
pub fn phase_span(protocol: &str, role: &RoleName, phase: &str) -> Span {
    let role = role.as_str();
    info_span!("protocol.phase", protocol, role, phase)
}
/// Emits an INFO event (target [`events::SEND`]) describing an outgoing message.
pub fn trace_send(target_role: &str, message_type: &str, message_size: usize) {
    info!(
        target: events::SEND,
        target_role,
        message_type,
        message_size,
        "sending message"
    );
}
/// Emits an INFO event (target [`events::RECV`]) describing a received message.
pub fn trace_recv(source_role: &str, message_type: &str, message_size: usize) {
    info!(
        target: events::RECV,
        source_role,
        message_type,
        message_size,
        "received message"
    );
}
/// Emits an INFO event (target [`events::CHOOSE`]) recording the branch
/// `label` chosen and sent to `target_role`.
pub fn trace_choose(target_role: &str, label: &str) {
    let choice_label = label;
    info!(target: events::CHOOSE, target_role, choice_label, "made choice");
}
/// Emits an INFO event (target [`events::OFFER`]) recording the branch
/// `label` received from `source_role`.
pub fn trace_offer(source_role: &str, label: &str) {
    let choice_label = label;
    info!(target: events::OFFER, source_role, choice_label, "received choice");
}
/// Emits a DEBUG event (target [`events::PHASE_START`]) marking the start of `phase`.
pub fn trace_phase_start(phase: &str) {
    debug!(
        target: events::PHASE_START,
        phase,
        "phase started"
    );
}
/// Emits a DEBUG event (target [`events::PHASE_END`]) marking the end of
/// `phase` together with its elapsed time in milliseconds.
pub fn trace_phase_end(phase: &str, duration_ms: u64) {
    debug!(target: events::PHASE_END, phase, duration_ms, "phase completed");
}
pub fn trace_error(error_message: &str) {
error!(
target: events::ERROR,
error = error_message,
"protocol error"
);
}
/// Renders a role for log fields: `Name` when the role has no index,
/// `Name[i]` when `role_index()` returns `Some(i)`.
fn format_role<R: RoleId>(role: R) -> String {
    let index = role.role_index();
    let name = role.role_name();
    index.map_or_else(|| name.to_string(), |i| format!("{}[{}]", name, i))
}
/// A [`ChoreoHandler`] wrapper that emits tracing spans and events around
/// every operation of the wrapped handler `H`.
// NOTE: field declaration order is also drop order (`inner` is dropped
// before `span`); keep it stable.
pub struct TracingHandler<H> {
    /// The wrapped handler every operation is delegated to.
    inner: H,
    /// Protocol name recorded on spans.
    protocol: &'static str,
    /// Name of the local role.
    role: RoleName,
    /// Index of the local role, if it is an indexed role.
    role_index: Option<u32>,
    /// `protocol.execute` span built at construction; delegated futures are
    /// instrumented with clones of it.
    span: Span,
}
impl<H> TracingHandler<H> {
pub fn new(inner: H, protocol: &'static str, role: H::Role) -> Self
where
H: ChoreoHandler,
{
let role_name = role.role_name();
let role_index = role.role_index();
let span = protocol_span(protocol, &role_name, role_index);
Self {
inner,
protocol,
role: role_name,
role_index,
span,
}
}
pub fn indexed(inner: H, protocol: &'static str, role: RoleName, index: u32) -> Self {
let span = protocol_span(protocol, &role, Some(index));
Self {
inner,
protocol,
role,
role_index: Some(index),
span,
}
}
pub fn protocol(&self) -> &'static str {
self.protocol
}
pub fn role(&self) -> &RoleName {
&self.role
}
pub fn role_index(&self) -> Option<u32> {
self.role_index
}
pub fn inner(&self) -> &H {
&self.inner
}
pub fn inner_mut(&mut self) -> &mut H {
&mut self.inner
}
pub fn into_inner(self) -> H {
self.inner
}
}
/// Delegates every operation to the inner handler. Each future runs
/// instrumented with the handler's `protocol.execute` span. Each operation
/// also gets its own `#[instrument]` span and a `trace_*` event.
#[async_trait]
impl<H: ChoreoHandler> ChoreoHandler for TracingHandler<H> {
    type Role = H::Role;
    type Endpoint = H::Endpoint;

    /// Records the send (before delegating), then forwards it.
    #[instrument(
        skip(self, ep, msg),
        fields(
            protocol = self.protocol,
            role = self.role.as_str(),
            target_role = ?to,
            message_type = std::any::type_name::<M>()
        )
    )]
    async fn send<M: Serialize + Send + Sync>(
        &mut self,
        ep: &mut Self::Endpoint,
        to: Self::Role,
        msg: &M,
    ) -> ChoreoResult<()> {
        let message_type = std::any::type_name::<M>();
        // NOTE(review): `msg` is never serialized at this layer, so the size
        // is always reported as 0 — consumers must not read it as "empty".
        trace_send(&format_role(to), message_type, 0);
        let pending = self.inner.send(ep, to, msg);
        pending.instrument(self.span.clone()).await
    }

    /// Forwards the receive; emits the recv event only when it succeeded.
    #[instrument(
        skip(self, ep),
        fields(
            protocol = self.protocol,
            role = self.role.as_str(),
            source_role = ?from,
            message_type = std::any::type_name::<M>()
        )
    )]
    async fn recv<M: DeserializeOwned + Send>(
        &mut self,
        ep: &mut Self::Endpoint,
        from: Self::Role,
    ) -> ChoreoResult<M> {
        let pending = self.inner.recv::<M>(ep, from);
        let received = pending.instrument(self.span.clone()).await;
        if received.is_ok() {
            // Size is unknown here as well; reported as 0.
            trace_recv(&format_role(from), std::any::type_name::<M>(), 0);
        }
        received
    }

    /// Records the chosen label (before delegating), then forwards it.
    #[instrument(
        skip(self, ep),
        fields(
            protocol = self.protocol,
            role = self.role.as_str(),
            target_role = ?to,
            choice_label = label.as_str()
        )
    )]
    async fn choose(
        &mut self,
        ep: &mut Self::Endpoint,
        to: Self::Role,
        label: <Self::Role as RoleId>::Label,
    ) -> ChoreoResult<()> {
        trace_choose(&format_role(to), label.as_str());
        let pending = self.inner.choose(ep, to, label);
        pending.instrument(self.span.clone()).await
    }

    /// Forwards the offer; emits the offer event with the received label
    /// only when it succeeded.
    #[instrument(
        skip(self, ep),
        fields(
            protocol = self.protocol,
            role = self.role.as_str(),
            source_role = ?from
        )
    )]
    async fn offer(
        &mut self,
        ep: &mut Self::Endpoint,
        from: Self::Role,
    ) -> ChoreoResult<<Self::Role as RoleId>::Label> {
        let pending = self.inner.offer(ep, from);
        let offered = pending.instrument(self.span.clone()).await;
        if let Ok(label) = &offered {
            trace_offer(&format_role(from), label.as_str());
        }
        offered
    }

    /// Forwards the timeout wrapper unchanged, inside the protocol span.
    async fn with_timeout<F, T>(
        &mut self,
        ep: &mut Self::Endpoint,
        at: Self::Role,
        dur: Duration,
        body: F,
    ) -> ChoreoResult<T>
    where
        F: std::future::Future<Output = ChoreoResult<T>> + Send,
    {
        let pending = self.inner.with_timeout(ep, at, dur, body);
        pending.instrument(self.span.clone()).await
    }
}
/// Guard that times a named protocol phase.
///
/// Creating it emits a `protocol.phase.start` event inside a `protocol.phase`
/// span. Dropping it emits `protocol.phase.end` with the elapsed milliseconds.
/// Discarding the guard (`PhaseGuard::new(..);` or `let _ = ..`) drops it
/// immediately and ends the phase at once, hence `#[must_use]`.
#[must_use = "dropping a PhaseGuard immediately ends the phase; bind it to a named variable"]
#[derive(Debug)]
pub struct PhaseGuard {
    /// Phase name reported in the start/end events.
    phase: &'static str,
    /// When timing started (just after the start event was emitted).
    start: Instant,
    /// The `protocol.phase` span both events are emitted in.
    span: Span,
}
impl PhaseGuard {
    /// Opens a `protocol.phase` span for `phase`, emits the start event
    /// inside it, and then starts the phase timer.
    pub fn new(protocol: &'static str, role: &RoleName, phase: &'static str) -> Self {
        let span = phase_span(protocol, role, phase);
        span.in_scope(|| trace_phase_start(phase));
        let start = Instant::now();
        Self { phase, start, span }
    }

    /// The `protocol.phase` span this guard emits its events in.
    pub fn span(&self) -> &Span {
        &self.span
    }
}
impl Drop for PhaseGuard {
    /// Emits the phase-end event inside the phase span, reporting elapsed
    /// whole milliseconds (saturating at `u64::MAX`).
    fn drop(&mut self) {
        let phase = self.phase;
        let started = self.start;
        self.span.in_scope(|| {
            let elapsed_ms = u64::try_from(started.elapsed().as_millis()).unwrap_or(u64::MAX);
            trace_phase_end(phase, elapsed_ms);
        });
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tracing::span::{Attributes, Id, Record};
    use tracing::{Event, Metadata, Subscriber};

    /// Minimal subscriber that enables everything. Spans built while it is
    /// the default are real (not disabled), so their metadata can be checked.
    /// Without a subscriber, every span is disabled and assertions are vacuous.
    struct EnableAll;

    impl Subscriber for EnableAll {
        fn enabled(&self, _metadata: &Metadata<'_>) -> bool {
            true
        }

        fn new_span(&self, _attrs: &Attributes<'_>) -> Id {
            Id::from_u64(1)
        }

        fn record(&self, _span: &Id, _values: &Record<'_>) {}

        fn record_follows_from(&self, _span: &Id, _follows: &Id) {}

        fn event(&self, _event: &Event<'_>) {}

        fn enter(&self, _span: &Id) {}

        fn exit(&self, _span: &Id) {}
    }

    /// Runs `f` with [`EnableAll`] installed as the thread's default subscriber.
    fn with_enabled_subscriber<T>(f: impl FnOnce() -> T) -> T {
        tracing::subscriber::with_default(EnableAll, f)
    }

    #[test]
    fn test_protocol_span() {
        with_enabled_subscriber(|| {
            let span = protocol_span("TestProtocol", &RoleName::from_static("Client"), None);
            assert!(!span.is_disabled());
            let meta = span.metadata().expect("enabled span carries metadata");
            assert_eq!(meta.name(), "protocol.execute");
            assert!(meta.fields().field("protocol").is_some());
            assert!(meta.fields().field("role").is_some());
            assert!(meta.fields().field("role_index").is_none());
        });
    }

    #[test]
    fn test_protocol_span_indexed() {
        with_enabled_subscriber(|| {
            let span = protocol_span("TestProtocol", &RoleName::from_static("Worker"), Some(3));
            assert!(!span.is_disabled());
            let meta = span.metadata().expect("enabled span carries metadata");
            assert_eq!(meta.name(), "protocol.execute");
            assert!(meta.fields().field("role_index").is_some());
        });
    }

    #[test]
    fn test_phase_span() {
        with_enabled_subscriber(|| {
            let span = phase_span(
                "TestProtocol",
                &RoleName::from_static("Client"),
                "handshake",
            );
            assert!(!span.is_disabled());
            let meta = span.metadata().expect("enabled span carries metadata");
            assert_eq!(meta.name(), "protocol.phase");
            assert!(meta.fields().field("phase").is_some());
        });
    }

    #[test]
    fn test_phase_guard_span() {
        with_enabled_subscriber(|| {
            let guard = PhaseGuard::new(
                "TestProtocol",
                &RoleName::from_static("Client"),
                "handshake",
            );
            let meta = guard
                .span()
                .metadata()
                .expect("enabled span carries metadata");
            assert_eq!(meta.name(), "protocol.phase");
            // Exercise the Drop path (emits the phase-end event).
            drop(guard);
        });
    }
}