use crate::{pubsub::WriteItem, types::Request, Router, RpcSend, TaskSet};
use ::tracing::info_span;
use opentelemetry::trace::TraceContextExt;
use serde_json::value::RawValue;
use std::{future::Future, sync::OnceLock};
use tokio::{
sync::mpsc::{self, error::SendError},
task::JoinHandle,
};
use tokio_stream::StreamExt;
use tokio_util::sync::WaitForCancellationFutureOwned;
use tracing::{enabled, Level};
use tracing_opentelemetry::OpenTelemetrySpanExt;
#[derive(thiserror::Error, Debug)]
pub enum NotifyError {
#[error("failed to serialize notification: {0}")]
Serde(#[from] serde_json::Error),
#[error("notification channel closed")]
Send(#[from] SendError<Box<RawValue>>),
}
impl From<SendError<WriteItem>> for NotifyError {
fn from(value: SendError<WriteItem>) -> Self {
SendError(value.0.json).into()
}
}
/// Tracing metadata attached to a request handler.
///
/// Carries the service name, an optional OpenTelemetry context that becomes
/// the OpenTelemetry parent of the spans built from it, and the request span
/// itself, which is created lazily and at most once.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct TracingInfo {
    /// Name of the service handling the request.
    pub service: &'static str,
    /// OpenTelemetry context used as the parent of spans built from this info.
    pub context: Option<opentelemetry::context::Context>,
    /// The request span; filled by `init_request_span`, `child`, or `mock`.
    span: OnceLock<tracing::Span>,
}

impl TracingInfo {
    /// Create tracing info for `service` with no OpenTelemetry parent context.
    ///
    /// The request span is not built until `init_request_span` is called.
    // NOTE(review): the reason for `dead_code` is not visible from here —
    // presumably `new` has no in-crate caller; confirm before removing.
    #[allow(dead_code)]
    pub const fn new(service: &'static str) -> Self {
        Self {
            service,
            context: None,
            span: OnceLock::new(),
        }
    }

    /// Create tracing info for `service` whose spans will use `context` as
    /// their OpenTelemetry parent.
    ///
    /// The request span is not built until `init_request_span` is called.
    pub const fn new_with_context(
        service: &'static str,
        context: opentelemetry::context::Context,
    ) -> Self {
        Self {
            service,
            context: Some(context),
            span: OnceLock::new(),
        }
    }

    /// Build a fresh `AjjRequest` span.
    ///
    /// The tracing parent is `parent` when it is enabled (has an id);
    /// otherwise the span is created as an explicit root. Per-request fields
    /// (method, id, status, error, params, ...) are declared empty so they can
    /// be recorded later. When an OpenTelemetry context is present it is set
    /// as the span's OpenTelemetry parent, and its trace id is recorded in
    /// `trace_id` only if the context carries a valid span context.
    fn make_span<S>(
        &self,
        router: &Router<S>,
        with_notifications: bool,
        parent: Option<&tracing::Span>,
    ) -> tracing::Span
    where
        S: Clone + Send + Sync + 'static,
    {
        let span = info_span!(
            parent: parent.and_then(|p| p.id()),
            "AjjRequest",
            "otel.kind" = "server",
            "rpc.system" = "jsonrpc",
            "rpc.jsonrpc.version" = "2.0",
            "rpc.service" = router.service_name(),
            notifications_enabled = with_notifications,
            "trace_id" = ::tracing::field::Empty,
            "otel.name" = ::tracing::field::Empty,
            "otel.status_code" = ::tracing::field::Empty,
            "rpc.jsonrpc.request_id" = ::tracing::field::Empty,
            "rpc.jsonrpc.error_code" = ::tracing::field::Empty,
            "rpc.jsonrpc.error_message" = ::tracing::field::Empty,
            "rpc.method" = ::tracing::field::Empty,
            params = ::tracing::field::Empty,
        );
        if let Some(context) = &self.context {
            // Best-effort: the return value of `set_parent` is deliberately
            // discarded.
            let _ = span.set_parent(context.clone());
            // A context without an active span carries the all-zero INVALID
            // trace id; recording it would attach a bogus id to the logs.
            let remote = context.span();
            let remote_ctx = remote.span_context();
            if remote_ctx.is_valid() {
                span.record("trace_id", remote_ctx.trace_id().to_string());
            }
        }
        span
    }

    /// Return the request span, building it on the first call.
    ///
    /// Only the first call constructs the span; later calls return the
    /// already-initialized span and ignore `router`, `with_notifications` and
    /// `parent`.
    fn init_request_span<S>(
        &self,
        router: &Router<S>,
        with_notifications: bool,
        parent: Option<&tracing::Span>,
    ) -> &tracing::Span
    where
        S: Clone + Send + Sync + 'static,
    {
        self.span
            .get_or_init(|| self.make_span(router, with_notifications, parent))
    }

    /// Create a new `TracingInfo` with the same service and OpenTelemetry
    /// context, whose request span is freshly built (parented to `parent`)
    /// and already initialized.
    pub fn child<S: Clone + Send + Sync + 'static>(
        &self,
        router: &Router<S>,
        with_notifications: bool,
        parent: Option<&tracing::Span>,
    ) -> Self {
        let span = self.make_span(router, with_notifications, parent);
        Self {
            service: self.service,
            context: self.context.clone(),
            span: OnceLock::from(span),
        }
    }

    /// Return the already-initialized request span.
    ///
    /// # Panics
    ///
    /// Panics with "span not initialized" if neither `init_request_span`
    /// nor `child` has produced a span yet.
    #[track_caller]
    fn request_span(&self) -> &tracing::Span {
        self.span.get().expect("span not initialized")
    }

    /// Test helper: tracing info for service `"test"` with an already
    /// initialized (empty-named) span.
    #[cfg(test)]
    pub fn mock() -> Self {
        Self {
            service: "test",
            context: None,
            span: OnceLock::from(info_span!("")),
        }
    }
}
/// Context available to a handler while it serves a single request.
///
/// Bundles the outbound notification channel (if any), the [`TaskSet`] used
/// by the `spawn*` helpers, and the request's [`TracingInfo`].
#[derive(Clone, Debug)]
pub struct HandlerCtx {
    /// Outbound notification channel; `None` makes `notify` a no-op.
    pub(crate) notifications: Option<mpsc::Sender<WriteItem>>,
    /// Task set that the `spawn*` helpers delegate to.
    pub(crate) tasks: TaskSet,
    /// Tracing metadata for the current request.
    pub(crate) tracing: TracingInfo,
}

impl HandlerCtx {
    /// Assemble a context from its parts.
    pub(crate) const fn new(
        notifications: Option<mpsc::Sender<WriteItem>>,
        tasks: TaskSet,
        tracing: TracingInfo,
    ) -> Self {
        Self {
            notifications,
            tasks,
            tracing,
        }
    }

    /// Test helper: no notification channel, a default task set, and mock
    /// tracing info.
    #[cfg(test)]
    pub fn mock() -> Self {
        Self::new(None, TaskSet::default(), TracingInfo::mock())
    }

    /// Derive a child context.
    ///
    /// The child shares this context's notification channel and task set, but
    /// gets new tracing info from [`TracingInfo::child`], with its span
    /// parented to `parent`.
    pub fn child_ctx<S>(&self, router: &Router<S>, parent: Option<&tracing::Span>) -> Self
    where
        S: Clone + Send + Sync + 'static,
    {
        let child_tracing = self
            .tracing
            .child(router, self.notifications_enabled(), parent);
        Self {
            notifications: self.notifications.clone(),
            tasks: self.tasks.clone(),
            tracing: child_tracing,
        }
    }

    /// Borrow the tracing metadata for this request.
    pub const fn tracing_info(&self) -> &TracingInfo {
        &self.tracing
    }

    /// Name of the service handling this request.
    pub const fn service_name(&self) -> &'static str {
        self.tracing.service
    }

    /// The request span.
    ///
    /// # Panics
    ///
    /// Panics if the request span has not been initialized yet.
    #[track_caller]
    pub fn span(&self) -> &tracing::Span {
        self.tracing.request_span()
    }

    /// Replace the tracing metadata for this context.
    pub fn set_tracing_info(&mut self, tracing: TracingInfo) {
        self.tracing = tracing;
    }

    /// `true` when a notification channel exists and its receiving side has
    /// not been closed.
    pub fn notifications_enabled(&self) -> bool {
        self.notifications
            .as_ref()
            .is_some_and(|sender| !sender.is_closed())
    }

    /// Return the request span, creating it on first use.
    ///
    /// Only the first call builds the span (capturing whether notifications
    /// are enabled at that moment); later calls return the same span and
    /// ignore `router` and `parent`.
    pub fn init_request_span<S>(
        &self,
        router: &Router<S>,
        parent: Option<&tracing::Span>,
    ) -> &tracing::Span
    where
        S: Clone + Send + Sync + 'static,
    {
        let enabled = self.notifications_enabled();
        self.tracing.init_request_span(router, enabled, parent)
    }

    /// Send one notification to the client.
    ///
    /// Returns `Ok(())` without doing anything if this context has no
    /// notification channel. Otherwise `t` is serialized and the call waits
    /// until the channel accepts the item.
    ///
    /// # Errors
    ///
    /// Fails if `t` cannot be serialized, or with [`NotifyError::Send`] if the
    /// channel has been closed.
    ///
    /// # Panics
    ///
    /// Panics if the request span has not been initialized.
    pub async fn notify<T: RpcSend>(&self, t: T) -> Result<(), NotifyError> {
        let Some(sender) = self.notifications.as_ref() else {
            return Ok(());
        };
        let payload = t.into_raw_value()?;
        let item = WriteItem {
            span: self.span().clone(),
            json: payload,
        };
        sender.send(item).await?;
        Ok(())
    }

    /// Forward every item yielded by `stream` as a notification, in order.
    ///
    /// If notifications are disabled the stream is never polled and `Ok(())`
    /// is returned immediately. Forwarding stops at the first failed
    /// notification, whose error is returned.
    pub async fn notify_stream<S, T>(&self, stream: S) -> Result<(), NotifyError>
    where
        S: tokio_stream::Stream<Item = T> + Send,
        T: RpcSend,
    {
        if !self.notifications_enabled() {
            return Ok(());
        }
        tokio::pin!(stream);
        while let Some(next) = stream.next().await {
            self.notify(next).await?;
        }
        Ok(())
    }

    /// Spawn `f` on the context's task set via `spawn_cancellable`.
    ///
    /// The `Option` in the output presumably is `None` when the task was
    /// cancelled before finishing — confirm against `TaskSet`.
    pub fn spawn<F>(&self, f: F) -> JoinHandle<Option<F::Output>>
    where
        F: Future + Send + 'static,
        F::Output: Send + 'static,
    {
        self.tasks.spawn_cancellable(f)
    }

    /// Call `f` with a clone of this context, then spawn the resulting future
    /// like [`HandlerCtx::spawn`].
    pub fn spawn_with_ctx<F, Fut>(&self, f: F) -> JoinHandle<Option<Fut::Output>>
    where
        F: FnOnce(HandlerCtx) -> Fut,
        Fut: Future + Send + 'static,
        Fut::Output: Send + 'static,
    {
        let task = f(self.clone());
        self.tasks.spawn_cancellable(task)
    }

    /// Spawn a task that runs [`HandlerCtx::notify_stream`] over `stream`
    /// using a clone of this context.
    pub fn spawn_notify_stream<S, T>(
        &self,
        stream: S,
    ) -> JoinHandle<Option<Result<(), NotifyError>>>
    where
        S: tokio_stream::Stream<Item = T> + Send + 'static,
        T: RpcSend + 'static,
    {
        let owned_ctx = self.clone();
        let forward = async move { owned_ctx.notify_stream(stream).await };
        self.tasks.spawn_cancellable(forward)
    }

    /// Spawn `f` through the task set's `spawn_blocking_cancellable`.
    ///
    /// NOTE(review): presumably this runs the future where blocking is
    /// allowed — confirm against `TaskSet`.
    pub fn spawn_blocking<F>(&self, f: F) -> JoinHandle<Option<F::Output>>
    where
        F: Future + Send + 'static,
        F::Output: Send + 'static,
    {
        self.tasks.spawn_blocking_cancellable(f)
    }

    /// Call `f` with a clone of this context, then spawn the resulting future
    /// like [`HandlerCtx::spawn_blocking`].
    pub fn spawn_blocking_with_ctx<F, Fut>(&self, f: F) -> JoinHandle<Option<Fut::Output>>
    where
        F: FnOnce(HandlerCtx) -> Fut,
        Fut: Future + Send + 'static,
        Fut::Output: Send + 'static,
    {
        let task = f(self.clone());
        self.tasks.spawn_blocking_cancellable(task)
    }

    /// Spawn a task that receives a cancellation future, so it can shut down
    /// on its own terms instead of being dropped.
    pub fn spawn_graceful<F, Fut>(&self, f: F) -> JoinHandle<Fut::Output>
    where
        F: FnOnce(WaitForCancellationFutureOwned) -> Fut + Send + 'static,
        Fut: Future + Send + 'static,
        Fut::Output: Send + 'static,
    {
        self.tasks.spawn_graceful(f)
    }

    /// Like [`HandlerCtx::spawn_graceful`], but `f` also receives a clone of
    /// this context.
    pub fn spawn_graceful_with_ctx<F, Fut>(&self, f: F) -> JoinHandle<Fut::Output>
    where
        F: FnOnce(HandlerCtx, WaitForCancellationFutureOwned) -> Fut + Send + 'static,
        Fut: Future + Send + 'static,
        Fut::Output: Send + 'static,
    {
        let owned_ctx = self.clone();
        self.tasks
            .spawn_graceful(move |cancelled| f(owned_ctx, cancelled))
    }

    /// Blocking variant of [`HandlerCtx::spawn_graceful`], delegating to the
    /// task set's `spawn_blocking_graceful`.
    pub fn spawn_blocking_graceful<F, Fut>(&self, f: F) -> JoinHandle<Fut::Output>
    where
        F: FnOnce(WaitForCancellationFutureOwned) -> Fut + Send + 'static,
        Fut: Future + Send + 'static,
        Fut::Output: Send + 'static,
    {
        self.tasks.spawn_blocking_graceful(f)
    }

    /// Like [`HandlerCtx::spawn_blocking_graceful`], but `f` also receives a
    /// clone of this context.
    pub fn spawn_blocking_graceful_with_ctx<F, Fut>(&self, f: F) -> JoinHandle<Fut::Output>
    where
        F: FnOnce(HandlerCtx, WaitForCancellationFutureOwned) -> Fut + Send + 'static,
        Fut: Future + Send + 'static,
        Fut::Output: Send + 'static,
    {
        let owned_ctx = self.clone();
        self.tasks
            .spawn_blocking_graceful(move |cancelled| f(owned_ctx, cancelled))
    }
}
/// Arguments handed to a route handler: its [`HandlerCtx`] plus the incoming
/// [`Request`].
#[derive(Clone, Debug)]
pub struct HandlerArgs {
    /// Context of the request being handled.
    ctx: HandlerCtx,
    /// The incoming request.
    req: Request,
    // Zero-sized private marker. `ctx` and `req` are already private, so
    // construction must go through `new` regardless.
    _seal: (),
}

impl HandlerArgs {
    /// Bundle `ctx` and `req` and record request details on the request span.
    ///
    /// Fills `otel.name` (`<service>/<method>`), `rpc.method` and
    /// `rpc.jsonrpc.request_id`. The raw `params` are recorded only when the
    /// `TRACE` level is enabled.
    ///
    /// # Panics
    ///
    /// Panics if `ctx`'s request span has not been initialized.
    #[track_caller]
    pub fn new(ctx: HandlerCtx, req: Request) -> Self {
        let args = Self {
            ctx,
            req,
            _seal: (),
        };
        let span = args.span();
        span.record("otel.name", args.otel_span_name());
        let request = &args.req;
        span.record("rpc.method", request.method());
        span.record("rpc.jsonrpc.request_id", request.id());
        if enabled!(Level::TRACE) {
            span.record("params", request.params());
        }
        args
    }

    /// Split into the handler context and the request.
    pub fn into_parts(self) -> (HandlerCtx, Request) {
        let Self { ctx, req, .. } = self;
        (ctx, req)
    }

    /// Borrow the handler context.
    pub const fn ctx(&self) -> &HandlerCtx {
        &self.ctx
    }

    /// The request span.
    ///
    /// # Panics
    ///
    /// Panics if the request span has not been initialized.
    #[track_caller]
    pub fn span(&self) -> &tracing::Span {
        self.ctx.span()
    }

    /// Borrow the request.
    pub const fn req(&self) -> &Request {
        &self.req
    }

    /// Owned copy of the request id, if the request has one.
    pub fn id_owned(&self) -> Option<Box<RawValue>> {
        self.req.id_owned()
    }

    /// The method name of the request.
    pub fn method(&self) -> &str {
        self.req.method()
    }

    /// Span name in the form `<service>/<method>`.
    pub fn otel_span_name(&self) -> String {
        let service = self.service_name();
        let method = self.method();
        format!("{}/{}", service, method)
    }

    /// Name of the service handling this request.
    pub const fn service_name(&self) -> &'static str {
        self.ctx.service_name()
    }
}