// mongodb 1.2.1
//
// The official MongoDB driver for Rust
// Documentation
use std::{collections::HashMap, time::Duration};

use serde::Deserialize;
use tokio::sync::RwLockReadGuard;

use crate::{
    bson::{doc, oid::ObjectId},
    client::Client,
    error::ErrorKind,
    is_master::{IsMasterCommandResponse, IsMasterReply},
    options::{ClientOptions, ReadPreference, SelectionCriteria, StreamAddress},
    sdam::description::{
        server::{ServerDescription, ServerType},
        topology::{TopologyDescription, TopologyType},
    },
    test::{run_spec_test, TestClient, CLIENT_OPTIONS, LOCK},
};

/// One server-discovery-and-monitoring spec test file, as consumed by `run_test`.
#[derive(Debug, Deserialize)]
pub struct TestFile {
    /// Human-readable test name; used as the prefix of panic/assertion messages in `run_test`.
    description: String,
    /// Connection string parsed into the `ClientOptions` that seed the `TopologyDescription`.
    uri: String,
    /// Phases applied one after another to the same topology description.
    phases: Vec<Phase>,
}

/// A single step of a spec test: responses to feed in, then the state expected afterwards.
#[derive(Debug, Deserialize)]
pub struct Phase {
    /// Responses applied to the topology description in order, before `outcome` is checked.
    responses: Vec<Response>,
    /// Expected topology state once every response in this phase has been applied.
    outcome: Outcome,
}

/// An `(address, isMaster response)` pair from a spec test phase.
///
/// The address string is parsed into a `StreamAddress` by `run_test`. A response equal to
/// `IsMasterCommandResponse::default()` is applied to the topology as a failed isMaster rather
/// than a successful reply.
#[derive(Debug, Deserialize)]
pub struct Response(String, IsMasterCommandResponse);

/// Expected topology state at the end of a phase.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Outcome {
    /// Expected topology type.
    topology_type: TopologyType,
    /// Expected replica set name; `None` means the topology should report no set name.
    set_name: Option<String>,
    /// Expected servers keyed by address string; the topology must hold exactly this many.
    servers: HashMap<String, Server>,
    /// Expected session timeout in minutes; `None` means no timeout should be reported.
    logical_session_timeout_minutes: Option<i32>,
    /// Whether the topology should be free of a compatibility error; not checked when absent.
    compatible: Option<bool>,
}

/// Expected description of a single server within an [`Outcome`].
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Server {
    /// Server type name as written under the `type` key; converted with `server_type_from_str`.
    #[serde(rename = "type")]
    server_type: String,
    /// Expected replica set name; always compared (`None` matches a missing value).
    set_name: Option<String>,
    /// Expected `setVersion`; always compared (`None` matches a missing value).
    set_version: Option<i32>,
    /// Expected `electionId`; always compared (`None` matches a missing value).
    election_id: Option<ObjectId>,
    /// Expected session timeout in minutes; only checked when present.
    logical_session_timeout_minutes: Option<i32>,
    /// Expected minimum wire version; only checked when present.
    min_wire_version: Option<i32>,
    /// Expected maximum wire version; only checked when present.
    max_wire_version: Option<i32>,
}

/// Translates a server type name used in the SDAM spec test files into a [`ServerType`].
///
/// `"PossiblePrimary"` is folded into [`ServerType::Unknown`]. Any name not listed here yields
/// `None`.
fn server_type_from_str(s: &str) -> Option<ServerType> {
    Some(match s {
        // Replica set members.
        "RSPrimary" => ServerType::RSPrimary,
        "RSSecondary" => ServerType::RSSecondary,
        "RSArbiter" => ServerType::RSArbiter,
        "RSOther" => ServerType::RSOther,
        "RSGhost" => ServerType::RSGhost,
        // Everything else the tests may name.
        "Standalone" => ServerType::Standalone,
        "Mongos" => ServerType::Mongos,
        "PossiblePrimary" | "Unknown" => ServerType::Unknown,
        _ => return None,
    })
}

async fn run_test(test_file: TestFile) {
    let options = ClientOptions::parse_uri(&test_file.uri, None)
        .await
        .expect(&test_file.description);

    let test_description = &test_file.description;
    let mut topology_description = TopologyDescription::new(options).expect(test_description);

    for (i, phase) in test_file.phases.into_iter().enumerate() {
        for Response(address, command_response) in phase.responses {
            let is_master_reply = if command_response == Default::default() {
                Err(ErrorKind::OperationError {
                    message: "dummy error".to_string(),
                }
                .into())
            } else {
                Ok(IsMasterReply {
                    command_response,
                    round_trip_time: Some(Duration::from_millis(1234)), // Doesn't matter for tests.
                    cluster_time: None,
                })
            };

            let address = StreamAddress::parse(&address).unwrap_or_else(|_| {
                panic!(
                    "{}: couldn't parse address \"{:?}\"",
                    test_description.as_str(),
                    address
                )
            });
            topology_description
                .update(ServerDescription::new(
                    address.clone(),
                    Some(is_master_reply),
                ))
                .expect(&test_file.description);
        }

        assert_eq!(
            topology_description.topology_type, phase.outcome.topology_type,
            "{}: {}",
            &test_file.description, i,
        );

        assert_eq!(
            topology_description.set_name, phase.outcome.set_name,
            "{}: {}",
            &test_file.description, i,
        );

        let expected_timeout = phase
            .outcome
            .logical_session_timeout_minutes
            .map(|mins| Duration::from_secs((mins as u64) * 60));
        assert_eq!(
            topology_description
                .session_support_status
                .logical_session_timeout(),
            expected_timeout,
            "{}: {}",
            &test_file.description,
            i
        );

        if let Some(compatible) = phase.outcome.compatible {
            assert_eq!(
                topology_description.compatibility_error.is_none(),
                compatible,
                "{}: {}",
                &test_file.description,
                i,
            );
        }

        assert_eq!(
            topology_description.servers.len(),
            phase.outcome.servers.len(),
            "{}: {}",
            &test_file.description,
            i
        );

        for (address, server) in phase.outcome.servers {
            let address = StreamAddress::parse(&address).unwrap_or_else(|_| {
                panic!(
                    "{}: couldn't parse address \"{:?}\"",
                    test_description, address
                )
            });
            let actual_server = &topology_description
                .servers
                .get(&address)
                .unwrap_or_else(|| panic!("{} (phase {})", test_description, i));

            let server_type = server_type_from_str(&server.server_type)
                .unwrap_or_else(|| panic!("{} (phase {})", test_description, i));

            assert_eq!(
                actual_server.server_type, server_type,
                "{} (phase {})",
                &test_file.description, i
            );

            assert_eq!(
                actual_server.set_name().unwrap_or(None),
                server.set_name,
                "{} (phase {})",
                &test_file.description,
                i
            );

            assert_eq!(
                actual_server.set_version().unwrap_or(None),
                server.set_version,
                "{} (phase {})",
                &test_file.description,
                i
            );

            assert_eq!(
                actual_server.election_id().unwrap_or(None),
                server.election_id,
                "{} (phase {})",
                &test_file.description,
                i
            );

            if let Some(logical_session_timeout_minutes) = server.logical_session_timeout_minutes {
                assert_eq!(
                    actual_server.logical_session_timeout().unwrap(),
                    Some(Duration::from_secs(
                        logical_session_timeout_minutes as u64 * 60
                    )),
                    "{} (phase {})",
                    &test_file.description,
                    i
                );
            }

            if let Some(min_wire_version) = server.min_wire_version {
                assert_eq!(
                    actual_server.min_wire_version().unwrap(),
                    Some(min_wire_version),
                    "{} (phase {})",
                    &test_file.description,
                    i
                );
            }

            if let Some(max_wire_version) = server.max_wire_version {
                assert_eq!(
                    actual_server.max_wire_version().unwrap(),
                    Some(max_wire_version),
                    "{} (phase {})",
                    &test_file.description,
                    i
                );
            }
        }
    }
}

/// Runs the `server-discovery-and-monitoring` / `single` spec tests through `run_test`.
#[cfg_attr(feature = "tokio-runtime", tokio::test)]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
async fn single() {
    let spec_path = ["server-discovery-and-monitoring", "single"];
    run_spec_test(&spec_path, run_test).await;
}

/// Runs the `server-discovery-and-monitoring` / `rs` spec tests through `run_test`.
#[cfg_attr(feature = "tokio-runtime", tokio::test)]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
async fn rs() {
    let spec_path = ["server-discovery-and-monitoring", "rs"];
    run_spec_test(&spec_path, run_test).await;
}

/// Runs the `server-discovery-and-monitoring` / `sharded` spec tests through `run_test`.
#[cfg_attr(feature = "tokio-runtime", tokio::test)]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
async fn sharded() {
    let spec_path = ["server-discovery-and-monitoring", "sharded"];
    run_spec_test(&spec_path, run_test).await;
}

/// Checks how `directConnection` affects writes when the only seed is a replica set secondary:
///
/// - `directConnection=false`: the write is expected to succeed.
/// - `directConnection=true`: the write is expected to fail with a "not master" error.
/// - `directConnection` unspecified: the write is expected to succeed.
///
/// Skipped (returns early) unless the test deployment is a replica set.
#[cfg_attr(feature = "tokio-runtime", tokio::test)]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
#[function_name::named]
async fn direct_connection() {
    // Held for the whole test; the guard is released when it goes out of scope.
    let _guard: RwLockReadGuard<_> = LOCK.run_concurrently().await;

    let test_client = TestClient::new().await;
    if !test_client.is_replica_set() {
        println!("Skipping due to non-replica set topology");
        return;
    }

    let criteria = SelectionCriteria::ReadPreference(ReadPreference::Secondary {
        options: Default::default(),
    });
    let secondary_address = test_client
        .test_select_server(Some(&criteria))
        .await
        .expect("failed to select secondary");

    // Seed list containing only the secondary. `direct_connection` is cleared explicitly so the
    // "unspecified" case below really is unspecified instead of inheriting whatever value
    // CLIENT_OPTIONS happens to carry.
    let mut secondary_options = CLIENT_OPTIONS.clone();
    secondary_options.hosts = vec![secondary_address];
    secondary_options.direct_connection = None;

    let mut direct_false_options = secondary_options.clone();
    direct_false_options.direct_connection = Some(false);
    let direct_false_client =
        Client::with_options(direct_false_options).expect("client construction should succeed");
    direct_false_client
        .database(function_name!())
        .collection(function_name!())
        .insert_one(doc! {}, None)
        .await
        .expect("write should succeed with directConnection=false on secondary");

    let mut direct_true_options = secondary_options.clone();
    direct_true_options.direct_connection = Some(true);
    let direct_true_client =
        Client::with_options(direct_true_options).expect("client construction should succeed");
    let error = direct_true_client
        .database(function_name!())
        .collection(function_name!())
        .insert_one(doc! {}, None)
        .await
        .expect_err("write should fail with directConnection=true on secondary");
    assert!(
        error.is_not_master(),
        "expected a \"not master\" error, got {:?}",
        error
    );

    let client =
        Client::with_options(secondary_options).expect("client construction should succeed");
    client
        .database(function_name!())
        .collection(function_name!())
        .insert_one(doc! {}, None)
        .await
        .expect("write should succeed with directConnection unspecified");
}