// mongodb 3.6.0
//
// The official MongoDB driver for Rust
// Documentation
use std::collections::HashMap;

use opentelemetry::SpanId;
use opentelemetry_sdk::trace::{
    BatchSpanProcessor,
    InMemorySpanExporter,
    InMemorySpanExporterBuilder,
    SdkTracerProvider,
    SpanData,
};
use serde::Deserialize;

use crate::{
    bson::{doc, Bson, Document},
    test::spec::unified_runner::{results_match, EntityMap, TestRunner},
};

/// Deserialized options describing how a test client's tracing spans are captured.
///
/// Keys are expected in camelCase and unknown keys are rejected.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase", deny_unknown_fields)]
pub(crate) struct ObserveTracingMessages {
    /// When `Some(true)`, `ClientTracing::new` sets `query_text_max_length` to 1000 on the
    /// options it returns; when absent or `false` that field is left untouched.
    enable_command_payload: Option<bool>,
}

/// The spans one client is expected to have recorded; checked by `TestRunner::match_spans`.
///
/// Keys are expected in camelCase and unknown keys are rejected.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase", deny_unknown_fields)]
pub(crate) struct ExpectedTracingMessages {
    /// Name of the client entity whose recorded spans are checked.
    client: String,
    /// When `Some(true)`, extra actual spans beyond the expected ones may be present, at every
    /// nesting level. When absent or `false`, the span counts must match exactly.
    ignore_extra_spans: Option<bool>,
    /// Expected root spans (spans with no parent), in order.
    spans: Vec<ExpectedSpan>,
}

/// One expected span together with, recursively, its expected child spans.
///
/// Keys are expected in camelCase and unknown keys are rejected.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase", deny_unknown_fields)]
struct ExpectedSpan {
    /// Must equal the actual span's name.
    name: String,
    /// Each listed key is checked against the actual span's attribute of the same name.
    attributes: Document,
    /// Expected child spans, in order; empty when omitted from the test file.
    #[serde(default)]
    nested: Vec<ExpectedSpan>,
}

/// Per-client span capture: an in-memory exporter plus the tracer provider that feeds it.
#[derive(Clone, Debug)]
pub(crate) struct ClientTracing {
    /// Collects finished spans; read back by `TestRunner::match_spans`.
    exporter: InMemorySpanExporter,
    /// Provider handed to the client's options; flushed before the spans are read.
    provider: SdkTracerProvider,
}

impl ClientTracing {
    /// Builds an in-memory tracing pipeline for one test client.
    ///
    /// Returns the capture handle, used later to read the finished spans back, together with
    /// enabled `OpentelemetryOptions` whose tracer provider exports into that handle. When
    /// `observe` requests command payloads, `query_text_max_length` is set to 1000.
    pub(crate) fn new(observe: &ObserveTracingMessages) -> (Self, super::OpentelemetryOptions) {
        let exporter = InMemorySpanExporterBuilder::new().build();
        let provider = SdkTracerProvider::builder()
            .with_span_processor(BatchSpanProcessor::builder(exporter.clone()).build())
            .build();

        let capture_payload = matches!(observe.enable_command_payload, Some(true));
        let mut options = super::OpentelemetryOptions::builder()
            .enabled(true)
            .tracer_provider(provider.clone())
            .build();
        if capture_payload {
            options.query_text_max_length = Some(1000);
        }

        let tracing = Self { exporter, provider };
        (tracing, options)
    }
}

impl TestRunner {
    /// Checks the spans recorded for `expected.client` against `expected.spans`.
    ///
    /// The client's tracer provider is flushed first so that every finished span has reached
    /// the in-memory exporter. Spans whose parent ID is `SpanId::INVALID` are compared with the
    /// top-level expectations. Every other span is grouped under its parent's ID and compared
    /// with the matching `nested` expectations.
    ///
    /// # Errors
    ///
    /// Returns a description of the first mismatch. Also returns an error when the spans cannot
    /// be obtained: the client was not set up to observe tracing, the flush failed, or the
    /// exporter could not be read.
    pub(crate) async fn match_spans(
        &self,
        expected: &ExpectedTracingMessages,
    ) -> Result<(), String> {
        // Report setup problems through the `Result` instead of panicking, so the failure
        // names the client involved.
        let Some(client_tracing) = self.get_client(&expected.client).await.tracing else {
            return Err(format!(
                "client {:?} was not configured to observe tracing messages",
                expected.client
            ));
        };
        client_tracing
            .provider
            .force_flush()
            .map_err(|e| format!("failed to flush tracer provider: {:?}", e))?;
        let finished = client_tracing
            .exporter
            .get_finished_spans()
            .map_err(|e| format!("failed to read finished spans: {:?}", e))?;

        let mut root_spans = Vec::new();
        let mut nested_spans = HashMap::<SpanId, Vec<SpanData>>::new();
        for span in finished {
            if span.parent_span_id == SpanId::INVALID {
                root_spans.push(span);
            } else {
                nested_spans
                    .entry(span.parent_span_id)
                    .or_default()
                    .push(span);
            }
        }

        let entities = self.entities.read().await;
        Matcher {
            nested: &nested_spans,
            entities: &entities,
            ignore_extra: expected.ignore_extra_spans.unwrap_or(false),
        }
        .match_span_slice(&root_spans, &expected.spans)
    }
}

/// State shared by every level of a recursive span-tree comparison.
struct Matcher<'a> {
    /// Whether extra actual spans are tolerated (taken from `ignoreExtraSpans`).
    ignore_extra: bool,
    /// Finished spans grouped by the ID of their parent span.
    nested: &'a HashMap<SpanId, Vec<SpanData>>,
    /// Test entities passed to `results_match` when checking attribute values.
    entities: &'a EntityMap,
}

impl Matcher<'_> {
    /// Matches a list of sibling spans (the roots, or the children of one span) against the
    /// expectations for that level.
    ///
    /// Without `ignore_extra`, both lists must have the same length and are compared pairwise
    /// in order. With `ignore_extra`, each expected span must fully match a distinct actual
    /// span, keeping the same relative order; actual spans that match no expectation are
    /// skipped.
    ///
    /// # Errors
    ///
    /// Returns a description of the first mismatch, including both lists for context.
    fn match_span_slice(
        &self,
        actual: &[SpanData],
        expected: &[ExpectedSpan],
    ) -> Result<(), String> {
        let err_suffix = || format!("actual:\n{:#?}\n\nexpected:\n{:#?}", actual, expected);

        if !self.ignore_extra {
            if actual.len() != expected.len() {
                return Err(format!(
                    "expected exactly {} spans, got {}\n{}",
                    expected.len(),
                    actual.len(),
                    err_suffix(),
                ));
            }
            for (act_span, exp_span) in actual.iter().zip(expected) {
                self.match_span(act_span, exp_span)?;
            }
            return Ok(());
        }

        if actual.len() < expected.len() {
            return Err(format!(
                "expected at least {} spans, got {}\n{}",
                expected.len(),
                actual.len(),
                err_suffix(),
            ));
        }

        // Earliest-full-match subsequence search. If a same-named candidate's attributes or
        // children do not match, it is skipped rather than treated as fatal: a later span with
        // the same name may be the intended match, and the earlier one is then just an extra
        // span. Whether a pair matches does not depend on any other pair, so taking the earliest
        // full match for each expectation never rules out matches for later ones, and no
        // backtracking is needed.
        let mut remaining = actual;
        for exp_span in expected {
            let mut first_mismatch: Option<String> = None;
            loop {
                let Some((candidate, rest)) = remaining.split_first() else {
                    return Err(match first_mismatch {
                        Some(e) => format!(
                            "no matching span found with name {:?}; first candidate: {}",
                            exp_span.name, e
                        ),
                        None => format!(
                            "no span found with name {:?}\n{}",
                            exp_span.name,
                            err_suffix(),
                        ),
                    });
                };
                remaining = rest;
                if candidate.name != exp_span.name {
                    continue;
                }
                match self.match_span(candidate, exp_span) {
                    Ok(()) => break,
                    Err(e) => {
                        if first_mismatch.is_none() {
                            first_mismatch = Some(e);
                        }
                    }
                }
            }
        }

        Ok(())
    }

    /// Matches one actual span, and recursively its children, against an expectation.
    ///
    /// The names must be equal, and every key in `expected.attributes` must pass
    /// `results_match` against the actual attribute of that key (a missing attribute is passed
    /// as `None`). Attributes the expectation does not mention are neither converted nor
    /// compared. The span's children are then matched against `expected.nested`.
    ///
    /// # Errors
    ///
    /// Returns a description of the first mismatch, or of an attribute that cannot be
    /// converted to BSON.
    fn match_span(&self, actual: &SpanData, expected: &ExpectedSpan) -> Result<(), String> {
        let err_suffix = || format!("actual:\n{:#?}\nexpected:\n{:#?}", actual, expected);
        if expected.name != actual.name {
            return Err(format!(
                "expected name {:?}, got {:?}\n{}",
                expected.name,
                actual.name,
                err_suffix(),
            ));
        }

        let mut actual_attrs = doc! {};
        for kv in &actual.attributes {
            let key = kv.key.as_str();
            // Only convert attributes the expectation checks, so that an unrelated attribute
            // whose value `value_to_bson` cannot represent does not fail the whole match.
            if !expected.attributes.contains_key(key) {
                continue;
            }
            let value = match key {
                // This attribute holds a JSON string; parse it into BSON so it is compared
                // structurally rather than as text.
                "db.mongodb.lsid" => match &kv.value {
                    opentelemetry::Value::String(s) => {
                        let parsed: Bson =
                            serde_json::from_str::<serde_json::Value>(s.as_str())
                                .map_err(|e| format!("serde_json error: {}", e))?
                                .try_into()
                                .map_err(|e| format!("json value error: {}", e))?;
                        parsed
                    }
                    _ => return Err(format!("unexpected type for {:?}: {:?}", key, kv.value)),
                },
                _ => value_to_bson(&kv.value)?,
            };
            actual_attrs.insert(key, value);
        }
        for (k, expected_v) in &expected.attributes {
            if let Err(e) =
                results_match(actual_attrs.get(k), expected_v, false, Some(self.entities))
            {
                return Err(format!("span attribute {}: {}\n{}", k, e, err_suffix()));
            }
        }

        // A span with no recorded children is compared as having an empty child list.
        let actual_nested = self
            .nested
            .get(&actual.span_context.span_id())
            .map(|v| v.as_slice())
            .unwrap_or(&[]);
        self.match_span_slice(actual_nested, &expected.nested)
    }
}

/// Converts an OpenTelemetry attribute value into the corresponding BSON value.
///
/// Booleans, 64-bit integers, doubles and strings (and arrays of each) map to their BSON
/// equivalents. Any other value or array variant yields an `Err` containing its debug form.
fn value_to_bson(val: &opentelemetry::Value) -> Result<Bson, String> {
    use opentelemetry::{Array, Value};
    let converted = match val {
        Value::Bool(flag) => Bson::Boolean(*flag),
        Value::I64(int) => Bson::Int64(*int),
        Value::F64(float) => Bson::Double(*float),
        Value::String(text) => Bson::String(text.as_str().to_owned()),
        Value::Array(Array::Bool(items)) => items.into(),
        Value::Array(Array::I64(items)) => items.into(),
        Value::Array(Array::F64(items)) => items.into(),
        Value::Array(Array::String(items)) => {
            let strs: Vec<&str> = items.iter().map(|s| s.as_str()).collect();
            strs.into()
        }
        Value::Array(other) => {
            return Err(format!("unhandled opentelemetry array {:?}", other));
        }
        other => return Err(format!("unhandled opentelemetry value {:?}", other)),
    };
    Ok(converted)
}