use crate::{Poller, PollingResult, Result, sealed};
use google_cloud_gax::polling_state::PollingState;
use tracing::{Instrument, Span, info_span};
// Task-local slot holding the `LroRecorder` for the LRO the current task is
// driving. It is installed by `LroRecorder::scope` and read (if present) by
// `LroRecorder::current`.
tokio::task_local! {
    static LRO_RECORDER: LroRecorder;
}
/// Carries the tracing state for one long-running operation (LRO).
///
/// Cloning is cheap. Clones share the same destination-id cell because it sits
/// behind an `Arc`, so an id recorded through one clone is visible through all
/// of them.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct LroRecorder {
    /// The span on which LRO attributes and errors are recorded.
    span: Span,
    /// The poll attempt this recorder was created for, if any.
    attempt_count: Option<u32>,
    /// The first destination id recorded for this LRO, shared across clones.
    destination_id: std::sync::Arc<std::sync::OnceLock<String>>,
}
impl LroRecorder {
    /// Creates a recorder for `span` with no poll attempt and no destination id.
    pub fn new(span: Span) -> Self {
        Self {
            span,
            attempt_count: None,
            // An empty, shared `OnceLock`; clones of this recorder point at the same cell.
            destination_id: Default::default(),
        }
    }

    /// Returns the recorder installed for the current task, or `None` when the
    /// caller is not running inside [`LroRecorder::scope`].
    pub fn current() -> Option<Self> {
        match LRO_RECORDER.try_get() {
            Ok(recorder) => Some(recorder),
            Err(_) => None,
        }
    }

    /// Runs `future` with a clone of this recorder installed as the task-local
    /// current recorder, returning the future's output.
    pub async fn scope<F, T>(&self, future: F) -> T
    where
        F: std::future::Future<Output = T>,
    {
        let installed = self.clone();
        LRO_RECORDER.scope(installed, future).await
    }

    /// The span that LRO attributes are recorded on.
    pub fn span(&self) -> &Span {
        &self.span
    }

    /// The poll attempt number attached to this recorder, if any.
    pub fn attempt_count(&self) -> Option<u32> {
        self.attempt_count
    }
}
/// Records the outcome of a discovery-style LRO poll on `$span`.
///
/// Always records `gcp.longrunning.done`. When the operation is done, it also
/// records `gcp.longrunning.status_code`: the error's code, or `0` when there is
/// no error. If there is an error, it marks the span as failed via
/// `otel.status_code`, `otel.status_description`, and `error.type`.
///
/// The expansion is a single block expression, so the macro works in statement
/// position *and* in expression position (e.g. as a `match` arm).
#[macro_export]
#[doc(hidden)]
macro_rules! record_discovery_polling_result {
    ($span:expr, $op:expr) => {{
        let span = &$span;
        let op = &$op;
        let done = $crate::internal::DiscoveryOperation::done(op);
        span.record("gcp.longrunning.done", done);
        if done {
            let error = $crate::internal::DiscoveryOperation::error(op);
            // A finished operation without an error reports status code 0.
            let code = error.as_ref().map(|e| e.code as i32).unwrap_or(0);
            span.record("gcp.longrunning.status_code", code);
            if let Some(status) = error {
                span.record("otel.status_code", "ERROR");
                span.record("otel.status_description", &status.message);
                span.record("error.type", status.code.to_string());
            }
        }
    }};
}
impl LroRecorder {
    /// Returns a copy of this recorder tagged with poll attempt `count`.
    ///
    /// The copy shares the span and the destination-id cell with `self`, so a
    /// destination id recorded through either is visible through both.
    pub fn with_attempt_count(&self, count: u32) -> Self {
        Self {
            attempt_count: Some(count),
            ..self.clone()
        }
    }

    /// Writes `name` to the span's `gcp.resource.destination.id` attribute and
    /// stores it for later polls.
    ///
    /// The span attribute is written on every call, but only the first stored
    /// name is kept.
    pub fn record_destination_id(&self, name: &str) {
        self.span.record("gcp.resource.destination.id", name);
        // `OnceLock::set` fails when a value is already present; that error is ignored.
        let _ = self.destination_id.set(String::from(name));
    }

    /// Returns the stored destination id, if one has been recorded.
    pub fn destination_id(&self) -> Option<String> {
        let stored = self.destination_id.get()?;
        Some(stored.clone())
    }

    /// Marks the span as failed and records `err`'s text as the status description.
    pub fn record_error(&self, err: &crate::Error) {
        let span = &self.span;
        span.record("otel.status_code", "ERROR");
        span.record("otel.status_description", err.to_string());
    }

    /// Runs the action produced by `f` (given a clone of the span) with this
    /// recorder installed as the task-local current recorder.
    ///
    /// `f` is only invoked once the scoped future is first polled, i.e. inside the scope.
    pub async fn record_action<F, Fut, T>(&self, f: F) -> T
    where
        F: FnOnce(Span) -> Fut,
        Fut: std::future::Future<Output = T>,
    {
        let action_span = self.span.clone();
        let action = async move { f(action_span).await };
        self.scope(action).await
    }
}
/// Records LRO polling attributes from the task's current `LroRecorder` on `$span`.
///
/// Does nothing when no recorder is installed for the current task. With a
/// recorder:
/// * if it has an attempt count, records `gcp.longrunning.poll_attempt_count`
///   and `gcp.longrunning.done = false`;
/// * if it has a destination id, records `gcp.resource.destination.id`.
///
/// `$span` is evaluated at most once, and only when there is something to record.
#[macro_export]
#[doc(hidden)]
macro_rules! record_polling_attributes {
    ($span:expr) => {
        if let Some(recorder) = $crate::LroRecorder::current() {
            let attempt = recorder.attempt_count();
            let dest_id = recorder.destination_id();
            if attempt.is_some() || dest_id.is_some() {
                // Bind once so a side-effecting `$span` expression is not evaluated twice.
                let span = &$span;
                if let Some(attempt) = attempt {
                    span.record("gcp.longrunning.poll_attempt_count", attempt);
                    span.record("gcp.longrunning.done", false);
                }
                if let Some(dest_id) = dest_id {
                    span.record("gcp.resource.destination.id", dest_id);
                }
            }
        }
    };
}
/// A [`Poller`] decorator that records tracing information for an LRO.
///
/// Each `poll` call runs with an [`LroRecorder`] tagged with its attempt
/// number. `until_done` records any final error on the LRO span.
#[derive(Clone, Debug)]
pub struct Tracing<P> {
    /// The decorated poller.
    inner: P,
    /// Recorder holding the LRO span.
    recorder: LroRecorder,
    /// Attempt number of the most recent poll (the first poll is attempt 0).
    poll_attempt_count: u32,
    /// Whether `poll` has been called at least once.
    started: bool,
}
impl<P> Tracing<P> {
    /// Wraps `inner` so its polling is traced under `span`.
    ///
    /// The attempt counter starts at zero and no poll has happened yet.
    pub(crate) fn new(inner: P, span: Span) -> Self {
        let recorder = LroRecorder::new(span);
        Self {
            inner,
            recorder,
            started: false,
            poll_attempt_count: 0,
        }
    }
}
impl<P> sealed::Poller for Tracing<P>
where
    P: sealed::Poller + Send,
{
    /// Runs the inner poller's backoff inside an `LRO Sleep` span, with this
    /// decorator's recorder installed as the task-local current recorder.
    async fn backoff(&mut self, state: &PollingState) {
        let sleep_span = info_span!("LRO Sleep");
        let Self { inner, recorder, .. } = self;
        recorder
            .record_action(move |_lro_span| async move {
                inner.backoff(state).instrument(sleep_span).await
            })
            .await
    }
}
impl<P, ResponseType, MetadataType> Poller<ResponseType, MetadataType> for Tracing<P>
where
    P: Poller<ResponseType, MetadataType>,
    ResponseType: Send,
    MetadataType: Send,
{
    /// Polls the inner poller once, instrumented with the LRO span.
    ///
    /// The first call runs as attempt 0. Each later call bumps the counter
    /// first, giving 1, 2, .... The attempt number is exposed through a
    /// task-local [`LroRecorder`] for the duration of the inner poll.
    async fn poll(&mut self) -> Option<PollingResult<ResponseType, MetadataType>> {
        let already_started = std::mem::replace(&mut self.started, true);
        let attempt = if already_started {
            self.poll_attempt_count += 1;
            self.poll_attempt_count
        } else {
            0
        };
        let attempt_recorder = self.recorder.with_attempt_count(attempt);
        let lro_span = attempt_recorder.span().clone();
        let inner = &mut self.inner;
        attempt_recorder
            .scope(async move { inner.poll().instrument(lro_span).await })
            .await
    }

    /// Drives this poller to completion under the LRO span.
    ///
    /// If the final result is an error, it is also recorded on the span.
    async fn until_done(self) -> Result<ResponseType> {
        let recorder = self.recorder.clone();
        let outcome = recorder
            .record_action(move |lro_span| async move {
                crate::until_done(self).instrument(lro_span).await
            })
            .await;
        if let Err(e) = &outcome {
            recorder.record_error(e);
        }
        outcome
    }

    /// Converts this poller into a stream of polling results.
    #[cfg(feature = "unstable-stream")]
    fn into_stream(
        self,
    ) -> impl futures::Stream<Item = PollingResult<ResponseType, MetadataType>> + Unpin {
        crate::into_stream(self)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Error;
    use gaxi::client_request_signals;
    use gaxi::options::InstrumentationClientInfo;
    use google_cloud_test_utils::test_layer::TestLayer;
    use google_cloud_wkt::{Duration, Timestamp};

    // A poller whose `poll` and `until_done` always fail with an I/O error
    // ("logical-test-failure").
    struct FailingPoller;
    impl sealed::Poller for FailingPoller {
        async fn backoff(&mut self, _state: &PollingState) {}
    }
    impl Poller<Duration, Timestamp> for FailingPoller {
        async fn poll(&mut self) -> Option<PollingResult<Duration, Timestamp>> {
            Some(PollingResult::Completed(Err(Error::io(
                "logical-test-failure",
            ))))
        }
        async fn until_done(self) -> Result<Duration> {
            Err(Error::io("logical-test-failure"))
        }
        #[cfg(feature = "unstable-stream")]
        fn into_stream(
            self,
        ) -> impl futures::Stream<Item = PollingResult<Duration, Timestamp>> + Unpin {
            crate::into_stream(self)
        }
    }

    // A failed `until_done` on the decorator must mark the LRO span as an error
    // and carry the error text in `otel.status_description`.
    #[tokio::test]
    async fn test_tracing_decorator_error_reporting() {
        let guard = TestLayer::initialize();
        // Declare the fields up front so `record_error` has somewhere to write.
        let span = tracing::info_span!(
            "test_span",
            "otel.status_code" = tracing::field::Empty,
            "otel.status_description" = tracing::field::Empty,
        );
        let poller = Tracing::new(FailingPoller, span);
        let got = poller.until_done().await;
        assert!(got.is_err());
        {
            let captured = TestLayer::capture(&guard);
            let got = captured
                .iter()
                .find(|s| s.name == "test_span")
                .unwrap_or_else(|| panic!("missing `test_span` in captured spans: {captured:?}"));
            assert_eq!(
                got.attributes
                    .get("otel.status_code")
                    .and_then(|v| v.as_string()),
                Some("ERROR".to_string())
            );
            assert!(
                got.attributes
                    .get("otel.status_description")
                    .and_then(|v| v.as_string())
                    .unwrap()
                    .contains("logical-test-failure")
            );
        }
    }

    // A poller that records the attempt number it sees in the task-local
    // recorder on every `poll` (panicking if none is installed).
    struct CountingPoller {
        attempts: Vec<u32>,
    }
    impl sealed::Poller for CountingPoller {
        async fn backoff(&mut self, _state: &PollingState) {}
    }
    impl Poller<Duration, Timestamp> for CountingPoller {
        async fn poll(&mut self) -> Option<PollingResult<Duration, Timestamp>> {
            let attempt = LroRecorder::current()
                .and_then(|r| r.attempt_count())
                .unwrap();
            self.attempts.push(attempt);
            Some(PollingResult::InProgress(None))
        }
        async fn until_done(self) -> Result<Duration> {
            Ok(Duration::clamp(0, 0))
        }
        #[cfg(feature = "unstable-stream")]
        fn into_stream(
            self,
        ) -> impl futures::Stream<Item = PollingResult<Duration, Timestamp>> + Unpin {
            crate::into_stream(self)
        }
    }

    // Successive polls through the decorator are numbered 0, 1, 2, ...
    #[tokio::test]
    async fn test_tracing_decorator_attempt_counting() {
        let span = tracing::info_span!("test_lro_span");
        let poller = CountingPoller { attempts: vec![] };
        let mut traced = Tracing::new(poller, span);
        let _ = traced.poll().await;
        let _ = traced.poll().await;
        let _ = traced.poll().await;
        assert_eq!(traced.inner.attempts, vec![0, 1, 2]);
    }

    // `record_action` installs the recorder (with its span) as the current
    // recorder while the action runs.
    #[tokio::test]
    async fn test_lro_recorder_span_nesting() {
        let _guard = TestLayer::initialize();
        let span = tracing::info_span!("test_lro_span");
        let recorder = LroRecorder::new(span.clone());
        let span_clone = span.clone();
        recorder
            .record_action(|_| async move {
                let active_recorder = LroRecorder::current().unwrap();
                assert_eq!(
                    active_recorder.span.metadata().unwrap().name(),
                    "test_lro_span"
                );
                assert_eq!(active_recorder.span, span_clone);
            })
            .await;
    }

    // With an attempt count and a destination id available, the macro records
    // the attempt count, `done = false`, and the destination id.
    #[cfg(google_cloud_unstable_tracing)]
    #[tokio::test]
    async fn record_polling_attributes_macro() {
        let guard = TestLayer::initialize();
        let span =
            client_request_signals!(info: &InstrumentationClientInfo::default(), method: "test");
        let recorder = LroRecorder::new(span.clone()).with_attempt_count(42);
        recorder.record_destination_id("my-test-lro-id");
        recorder
            .scope(async move {
                crate::record_polling_attributes!(&span);
            })
            .await;
        // Drop the last handle to the span before inspecting captured data.
        drop(recorder);
        let captured = TestLayer::capture(&guard);
        let got = captured
            .iter()
            .find(|s| s.name == "client_request")
            .unwrap();
        assert_eq!(
            got.attributes.get("gcp.longrunning.poll_attempt_count"),
            Some(&google_cloud_test_utils::test_layer::AttributeValue::UInt64(42))
        );
        assert_eq!(
            got.attributes.get("gcp.longrunning.done"),
            Some(&google_cloud_test_utils::test_layer::AttributeValue::Boolean(false))
        );
        assert_eq!(
            got.attributes.get("gcp.resource.destination.id"),
            Some(
                &google_cloud_test_utils::test_layer::AttributeValue::String(
                    std::borrow::Cow::Borrowed("my-test-lro-id")
                )
            )
        );
    }

    // Outside any recorder scope the macro records no polling attributes.
    #[cfg(google_cloud_unstable_tracing)]
    #[tokio::test]
    async fn record_polling_attributes_macro_no_recorder() {
        let guard = TestLayer::initialize();
        let span =
            client_request_signals!(info: &InstrumentationClientInfo::default(), method: "test");
        crate::record_polling_attributes!(&span);
        drop(span);
        let captured = TestLayer::capture(&guard);
        let got = captured
            .iter()
            .find(|s| s.name == "client_request")
            .unwrap();
        assert!(
            got.attributes
                .get("gcp.longrunning.poll_attempt_count")
                .is_none()
        );
        assert!(got.attributes.get("gcp.longrunning.done").is_none());
    }

    // A recorder without an attempt count records neither the attempt count
    // nor `gcp.longrunning.done`.
    #[cfg(google_cloud_unstable_tracing)]
    #[tokio::test]
    async fn record_polling_attributes_macro_no_attempt_count() {
        let guard = TestLayer::initialize();
        let span =
            client_request_signals!(info: &InstrumentationClientInfo::default(), method: "test");
        let recorder = LroRecorder::new(span.clone());
        recorder
            .scope(async move {
                crate::record_polling_attributes!(&span);
            })
            .await;
        drop(recorder);
        let captured = TestLayer::capture(&guard);
        let got = captured
            .iter()
            .find(|s| s.name == "client_request")
            .unwrap();
        assert!(
            got.attributes
                .get("gcp.longrunning.poll_attempt_count")
                .is_none()
        );
        assert!(got.attributes.get("gcp.longrunning.done").is_none());
    }

    // Minimal `DiscoveryOperation` whose `done` and `error` are set directly.
    #[derive(Default)]
    struct MockDiscoveryOperation {
        done: bool,
        error: Option<google_cloud_gax::error::rpc::Status>,
    }
    impl crate::internal::DiscoveryOperation for MockDiscoveryOperation {
        fn done(&self) -> bool {
            self.done
        }
        fn name(&self) -> Option<&String> {
            None
        }
        fn error(&self) -> Option<google_cloud_gax::error::rpc::Status> {
            self.error.clone()
        }
    }

    // A successfully finished operation records `done = true` and status code 0,
    // and leaves `otel.status_code` at "UNSET" (presumably its initial value
    // from `client_request_signals!`; confirm there).
    #[tokio::test]
    async fn record_discovery_polling_result_success() {
        let guard = TestLayer::initialize();
        let span =
            client_request_signals!(info: &InstrumentationClientInfo::default(), method: "test");
        let op = MockDiscoveryOperation {
            done: true,
            error: None,
        };
        record_discovery_polling_result!(span, op);
        {
            let captured = TestLayer::capture(&guard);
            let got = captured
                .iter()
                .find(|s| s.name == "client_request")
                .unwrap();
            assert_eq!(
                got.attributes
                    .get("gcp.longrunning.done")
                    .and_then(|v| v.as_bool()),
                Some(true)
            );
            assert_eq!(
                got.attributes
                    .get("gcp.longrunning.status_code")
                    .and_then(|v| v.as_i64()),
                Some(0)
            );
            assert_eq!(
                got.attributes
                    .get("otel.status_code")
                    .and_then(|v| v.as_string()),
                Some("UNSET".to_string())
            );
        }
    }

    // A finished operation with an error records its code, marks the span as
    // ERROR, and records the message and the error type.
    #[tokio::test]
    async fn record_discovery_polling_result_error() {
        let guard = TestLayer::initialize();
        let span =
            client_request_signals!(info: &InstrumentationClientInfo::default(), method: "test");
        let status = google_cloud_gax::error::rpc::Status::default()
            .set_code(google_cloud_gax::error::rpc::Code::NotFound)
            .set_message("not found");
        let op = MockDiscoveryOperation {
            done: true,
            error: Some(status),
        };
        record_discovery_polling_result!(span, op);
        {
            let captured = TestLayer::capture(&guard);
            let got = captured
                .iter()
                .find(|s| s.name == "client_request")
                .unwrap();
            assert_eq!(
                got.attributes
                    .get("gcp.longrunning.done")
                    .and_then(|v| v.as_bool()),
                Some(true)
            );
            assert_eq!(
                got.attributes
                    .get("gcp.longrunning.status_code")
                    .and_then(|v| v.as_i64()),
                Some(google_cloud_gax::error::rpc::Code::NotFound as i64)
            );
            assert_eq!(
                got.attributes
                    .get("otel.status_code")
                    .and_then(|v| v.as_string()),
                Some("ERROR".to_string())
            );
            assert_eq!(
                got.attributes
                    .get("otel.status_description")
                    .and_then(|v| v.as_string()),
                Some("not found".to_string())
            );
            assert_eq!(
                got.attributes.get("error.type").and_then(|v| v.as_string()),
                Some("NOT_FOUND".to_string())
            );
        }
    }

    // An operation that is still running records `done = false` and no status code.
    #[tokio::test]
    async fn record_discovery_polling_result_in_progress() {
        let guard = TestLayer::initialize();
        let span =
            client_request_signals!(info: &InstrumentationClientInfo::default(), method: "test");
        let op = MockDiscoveryOperation {
            done: false,
            error: None,
        };
        record_discovery_polling_result!(span, op);
        {
            let captured = TestLayer::capture(&guard);
            let got = captured
                .iter()
                .find(|s| s.name == "client_request")
                .unwrap();
            assert_eq!(
                got.attributes
                    .get("gcp.longrunning.done")
                    .and_then(|v| v.as_bool()),
                Some(false)
            );
            assert!(got.attributes.get("gcp.longrunning.status_code").is_none());
        }
    }
}