use std::{collections::HashMap, sync::Arc, time::Duration};
use bson::Document;
use serde::Deserialize;
use tokio::sync::{RwLockReadGuard, RwLockWriteGuard};
use super::TestSdamEvent;
use crate::{
bson::{doc, oid::ObjectId},
client::Client,
cmap::{conn::ConnectionGeneration, PoolGeneration},
error::{BulkWriteFailure, CommandError, Error, ErrorKind},
is_master::{IsMasterCommandResponse, IsMasterReply, LastWrite},
options::{ClientOptions, ReadPreference, SelectionCriteria, ServerAddress},
sdam::{
description::{
server::{ServerDescription, ServerType},
topology::TopologyType,
},
HandshakePhase,
Topology,
TopologyDescription,
},
selection_criteria::TagSet,
test::{
run_spec_test,
Event,
EventClient,
EventHandler,
FailCommandOptions,
FailPoint,
FailPointMode,
SdamEvent,
TestClient,
CLIENT_OPTIONS,
LOCK,
},
};
/// A single server-discovery-and-monitoring spec test file.
#[derive(Debug, Deserialize)]
pub struct TestFile {
    /// Human-readable test name; `run_test` also matches on it to skip unsupported tests.
    description: String,
    /// Connection string from which the mocked topology's options are parsed.
    uri: String,
    /// Phases applied to the topology in order.
    phases: Vec<Phase>,
}

/// One phase of a spec test: inputs fed to the topology plus the expected outcome.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Phase {
    /// Optional phase label; `run_test` falls back to the phase index when it is absent.
    description: Option<String>,
    /// Mocked monitoring responses, applied before `application_errors`. Empty if omitted.
    #[serde(default)]
    responses: Vec<Response>,
    /// Simulated errors encountered by application operations. Empty if omitted.
    #[serde(default)]
    application_errors: Vec<ApplicationError>,
    /// Expected topology description or SDAM events after this phase has been applied.
    outcome: Outcome,
}

/// A mocked monitoring response: the server address (as written in the test file) and the
/// response body.
#[derive(Debug, Deserialize)]
pub struct Response(String, TestIsMasterCommandResponse);
/// Deserialization target for the mocked `hello` / legacy `isMaster` responses in spec tests.
///
/// Every field is optional, and `ok` is included, so that partial responses, failed responses
/// and the empty response object can all be represented. `Default` + `PartialEq` let callers
/// detect the empty response by comparing against `Default::default()`.
#[derive(Debug, Clone, Default, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub(crate) struct TestIsMasterCommandResponse {
    pub is_writable_primary: Option<bool>,
    // The legacy field is spelled `ismaster` (all lowercase) on the wire.
    #[serde(rename = "ismaster")]
    pub is_master: Option<bool>,
    // Command status; `run_test` treats a value other than 1 as a failed check.
    pub ok: Option<f32>,
    pub hosts: Option<Vec<String>>,
    pub passives: Option<Vec<String>>,
    pub arbiters: Option<Vec<String>>,
    pub msg: Option<String>,
    pub me: Option<String>,
    pub set_version: Option<i32>,
    pub set_name: Option<String>,
    pub hidden: Option<bool>,
    pub secondary: Option<bool>,
    pub arbiter_only: Option<bool>,
    // Spelled `isreplicaset` (all lowercase) on the wire.
    #[serde(rename = "isreplicaset")]
    pub is_replica_set: Option<bool>,
    pub logical_session_timeout_minutes: Option<i64>,
    pub last_write: Option<LastWrite>,
    pub min_wire_version: Option<i32>,
    pub max_wire_version: Option<i32>,
    pub tags: Option<TagSet>,
    pub election_id: Option<ObjectId>,
    pub primary: Option<String>,
    pub sasl_supported_mechs: Option<Vec<String>>,
    pub speculative_authenticate: Option<Document>,
    // Optional here, but filled with a placeholder when converted to `IsMasterCommandResponse`.
    pub max_bson_object_size: Option<i64>,
    // Optional here, but filled with a placeholder when converted to `IsMasterCommandResponse`.
    pub max_write_batch_size: Option<i64>,
    pub service_id: Option<ObjectId>,
}
impl From<TestIsMasterCommandResponse> for IsMasterCommandResponse {
fn from(test: TestIsMasterCommandResponse) -> Self {
IsMasterCommandResponse {
is_writable_primary: test.is_writable_primary,
is_master: test.is_master,
hosts: test.hosts,
passives: test.passives,
arbiters: test.arbiters,
msg: test.msg,
me: test.me,
set_version: test.set_version,
set_name: test.set_name,
hidden: test.hidden,
secondary: test.secondary,
arbiter_only: test.arbiter_only,
is_replica_set: test.is_replica_set,
logical_session_timeout_minutes: test.logical_session_timeout_minutes,
last_write: test.last_write,
min_wire_version: test.min_wire_version,
max_wire_version: test.max_wire_version,
tags: test.tags,
election_id: test.election_id,
primary: test.primary,
sasl_supported_mechs: test.sasl_supported_mechs,
speculative_authenticate: test.speculative_authenticate,
max_bson_object_size: test.max_bson_object_size.unwrap_or(1234),
max_write_batch_size: test.max_write_batch_size.unwrap_or(1234),
service_id: test.service_id,
}
}
}
/// An error encountered by an application operation, as described by a spec test phase.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ApplicationError {
    /// Server the error is attributed to.
    address: ServerAddress,
    /// Generation the error belongs to; when absent, `run_test` uses the server pool's current
    /// generation.
    generation: Option<u32>,
    /// Max wire version reported with post-handshake errors.
    max_wire_version: i32,
    /// Whether the error happened before or after the connection handshake completed.
    when: ErrorHandshakePhase,
    /// Kind of error to simulate (command, network, or timeout).
    #[serde(rename = "type")]
    error_type: ErrorType,
    /// Server response for command-type errors.
    response: Option<ServerError>,
}
impl ApplicationError {
    /// Builds the driver `Error` that this spec-test application error simulates.
    ///
    /// Network errors become an I/O `UnexpectedEof` error and timeouts an I/O `TimedOut`
    /// error. Command errors are converted from the test's `response`.
    ///
    /// # Panics
    ///
    /// Panics with a descriptive message if `error_type` is `command` but the test file does
    /// not provide a `response`.
    fn to_error(&self) -> Error {
        match self.error_type {
            ErrorType::Command => self
                .response
                .clone()
                .expect("application error of type \"command\" must specify a response")
                .into(),
            ErrorType::Network => {
                ErrorKind::Io(Arc::new(std::io::ErrorKind::UnexpectedEof.into())).into()
            }
            ErrorType::Timeout => {
                ErrorKind::Io(Arc::new(std::io::ErrorKind::TimedOut.into())).into()
            }
        }
    }
}
/// When, relative to the connection handshake, a simulated application error occurred.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum ErrorHandshakePhase {
    /// Mapped by `run_test` to `HandshakePhase::PreHello` with a pool generation.
    BeforeHandshakeCompletes,
    /// Mapped by `run_test` to `HandshakePhase::AfterCompletion` with a connection generation.
    AfterHandshakeCompletes,
}

/// Kind of simulated application error.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum ErrorType {
    /// A server error built from the test's `response` document.
    Command,
    /// Simulated as an I/O `UnexpectedEof` error.
    Network,
    /// Simulated as an I/O `TimedOut` error.
    Timeout,
}

/// Server response attached to a command-type application error.
///
/// Untagged: deserialization tries a `CommandError` first, then a `BulkWriteFailure`.
#[derive(Clone, Debug, Deserialize)]
#[serde(untagged)]
pub enum ServerError {
    CommandError(CommandError),
    WriteError(BulkWriteFailure),
}
impl From<ServerError> for Error {
    /// Wraps the spec-test server response in the matching driver `ErrorKind`.
    fn from(server_error: ServerError) -> Self {
        let kind = match server_error {
            ServerError::CommandError(inner) => ErrorKind::Command(inner),
            ServerError::WriteError(failure) => ErrorKind::BulkWrite(failure),
        };
        kind.into()
    }
}
/// Expected result of a spec test phase.
///
/// Untagged: deserialization tries a topology description first, then a list of SDAM events.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum Outcome {
    Description(DescriptionOutcome),
    Events(EventsOutcome),
}

/// Expected topology description after a phase.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DescriptionOutcome {
    topology_type: TopologyType,
    set_name: Option<String>,
    /// Expected servers keyed by address string (parsed into `ServerAddress` when verified).
    servers: HashMap<String, Server>,
    /// Expected topology-wide logical session timeout, in minutes.
    logical_session_timeout_minutes: Option<i32>,
    /// When present, whether the topology should have no compatibility error.
    compatible: Option<bool>,
}

/// Expected SDAM events, in order, after a phase.
#[derive(Debug, Deserialize)]
pub struct EventsOutcome {
    events: Vec<TestSdamEvent>,
}

/// Expected description of one server in a `DescriptionOutcome`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Server {
    /// Spec-test server type name, converted with `server_type_from_str`.
    #[serde(rename = "type")]
    server_type: String,
    set_name: Option<String>,
    set_version: Option<i32>,
    election_id: Option<ObjectId>,
    /// Only checked when present; in minutes.
    logical_session_timeout_minutes: Option<i32>,
    /// Only checked when present.
    min_wire_version: Option<i32>,
    /// Only checked when present.
    max_wire_version: Option<i32>,
}
/// Maps a spec-test server type name to the driver's `ServerType`.
///
/// `PossiblePrimary` has no driver counterpart and maps to `ServerType::Unknown`. Any
/// unrecognized name yields `None`.
fn server_type_from_str(s: &str) -> Option<ServerType> {
    // Replica set member types all share the `RS` prefix.
    if let Some(member_kind) = s.strip_prefix("RS") {
        return match member_kind {
            "Primary" => Some(ServerType::RsPrimary),
            "Secondary" => Some(ServerType::RsSecondary),
            "Arbiter" => Some(ServerType::RsArbiter),
            "Other" => Some(ServerType::RsOther),
            "Ghost" => Some(ServerType::RsGhost),
            _ => None,
        };
    }

    match s {
        "Standalone" => Some(ServerType::Standalone),
        "Mongos" => Some(ServerType::Mongos),
        "Unknown" | "PossiblePrimary" => Some(ServerType::Unknown),
        _ => None,
    }
}
async fn run_test(test_file: TestFile) {
let test_description = &test_file.description;
if test_description.contains("topologyVersion") {
println!("Skipping {} (RUST-360)", test_description);
return;
}
if test_description.contains("load balancer") {
println!("Skipping {} (RUST-653)", test_description);
return;
}
let mut options = ClientOptions::parse_uri(&test_file.uri, None)
.await
.expect(test_description);
let handler = Arc::new(EventHandler::new());
options.sdam_event_handler = Some(handler.clone());
let topology = Topology::new_mocked(options.clone());
let mut servers = topology.get_servers().await;
for (i, phase) in test_file.phases.into_iter().enumerate() {
for Response(address, command_response) in phase.responses {
let address = ServerAddress::parse(&address).unwrap_or_else(|_| {
panic!(
"{}: couldn't parse address \"{:?}\"",
test_description.as_str(),
address
)
});
let is_master_reply = if command_response.ok != Some(1.0) {
Err(Error::from(ErrorKind::Command(CommandError {
code: 1234,
code_name: "dummy error".to_string(),
message: "dummy".to_string(),
})))
} else if command_response == Default::default() {
Err(Error::from(ErrorKind::Io(Arc::new(
std::io::ErrorKind::BrokenPipe.into(),
))))
} else {
Ok(IsMasterReply {
server_address: address.clone(),
command_response: command_response.into(),
round_trip_time: Duration::from_millis(1234), cluster_time: None,
})
};
if let Some(server) = servers.get(&address).and_then(|s| s.upgrade()) {
match is_master_reply {
Ok(reply) => {
let new_sd = ServerDescription::new(address.clone(), Some(Ok(reply)));
if topology.update(&server, new_sd).await {
servers = topology.get_servers().await
}
}
Err(e) => {
topology.handle_monitor_error(e, &server).await;
}
}
}
}
for application_error in phase.application_errors {
if let Some(server) = servers
.get(&application_error.address)
.and_then(|s| s.upgrade())
{
let error = application_error.to_error();
let pool_generation = application_error
.generation
.map(PoolGeneration::Normal)
.unwrap_or_else(|| server.pool.generation());
let conn_generation = application_error
.generation
.or_else(|| server.pool.generation().as_normal())
.unwrap_or(0);
let conn_generation = ConnectionGeneration::Normal(conn_generation);
let handshake_phase = match application_error.when {
ErrorHandshakePhase::BeforeHandshakeCompletes => HandshakePhase::PreHello {
generation: pool_generation,
},
ErrorHandshakePhase::AfterHandshakeCompletes => {
HandshakePhase::AfterCompletion {
generation: conn_generation,
max_wire_version: application_error.max_wire_version,
}
}
};
topology
.handle_application_error(error, handshake_phase, &server)
.await;
}
}
let topology_description = topology.description().await;
let phase_description = phase.description.unwrap_or_else(|| format!("{}", i));
match phase.outcome {
Outcome::Description(outcome) => {
verify_description_outcome(
outcome,
topology_description,
test_description,
phase_description,
);
}
Outcome::Events(EventsOutcome { events: expected }) => {
let actual = handler.get_all_sdam_events();
assert_eq!(actual.len(), expected.len());
for (actual, expected) in actual.iter().zip(expected.iter()) {
assert_eq!(
actual, expected,
"SDAM events do not match:\n actual: {:#?}, expected: {:#?}",
actual, expected
);
}
}
}
}
}
/// Asserts that `topology_description` matches the expected `outcome` of a spec test phase.
///
/// The topology-level fields are compared first: type, set name, logical session timeout,
/// compatibility (when specified) and server count. Then each expected server is looked up by
/// address and its type, set name, set version and election ID are compared. Its logical
/// session timeout and wire versions are compared only when the outcome specifies them.
///
/// # Panics
///
/// Panics (via failed assertions) on any mismatch, on an unparseable address, on a missing
/// server, or on an unrecognized server type name.
fn verify_description_outcome(
    outcome: DescriptionOutcome,
    topology_description: TopologyDescription,
    test_description: &str,
    phase_description: String,
) {
    // Failure context for topology-level assertions.
    let topology_ctx = format!("{}: {}", test_description, phase_description);
    // Failure context for per-server assertions.
    let server_ctx = format!("{} (phase {})", test_description, phase_description);

    assert_eq!(
        topology_description.topology_type, outcome.topology_type,
        "{}", topology_ctx
    );
    assert_eq!(
        topology_description.set_name, outcome.set_name,
        "{}", topology_ctx
    );

    // The test file gives minutes; the driver reports a `Duration`.
    let expected_timeout = outcome
        .logical_session_timeout_minutes
        .map(|minutes| Duration::from_secs(minutes as u64 * 60));
    assert_eq!(
        topology_description
            .session_support_status
            .logical_session_timeout(),
        expected_timeout,
        "{}", topology_ctx
    );

    if let Some(expected_compatible) = outcome.compatible {
        let actually_compatible = topology_description.compatibility_error.is_none();
        assert_eq!(actually_compatible, expected_compatible, "{}", topology_ctx);
    }

    assert_eq!(
        topology_description.servers.len(),
        outcome.servers.len(),
        "{}", topology_ctx
    );

    for (raw_address, expected) in outcome.servers {
        let address = ServerAddress::parse(&raw_address).unwrap_or_else(|_| {
            panic!(
                "{}: couldn't parse address \"{:?}\"",
                test_description, raw_address
            )
        });

        let actual = topology_description
            .servers
            .get(&address)
            .unwrap_or_else(|| panic!("{}", server_ctx));
        let expected_type = server_type_from_str(&expected.server_type)
            .unwrap_or_else(|| panic!("{}", server_ctx));

        assert_eq!(
            actual.server_type, expected_type,
            "{} (phase {}, address: {})",
            test_description, phase_description, address,
        );
        assert_eq!(
            actual.set_name().unwrap_or(None),
            expected.set_name,
            "{}", server_ctx
        );
        assert_eq!(
            actual.set_version().unwrap_or(None),
            expected.set_version,
            "{}", server_ctx
        );
        assert_eq!(
            actual.election_id().unwrap_or(None),
            expected.election_id,
            "{}", server_ctx
        );

        // The remaining fields are only checked when the outcome specifies them.
        if let Some(minutes) = expected.logical_session_timeout_minutes {
            assert_eq!(
                actual.logical_session_timeout().unwrap(),
                Some(Duration::from_secs(minutes as u64 * 60)),
                "{}", server_ctx
            );
        }
        if let Some(min_wire_version) = expected.min_wire_version {
            assert_eq!(
                actual.min_wire_version().unwrap(),
                Some(min_wire_version),
                "{}", server_ctx
            );
        }
        if let Some(max_wire_version) = expected.max_wire_version {
            assert_eq!(
                actual.max_wire_version().unwrap(),
                Some(max_wire_version),
                "{}", server_ctx
            );
        }
    }
}
/// Runs every spec test file in the given subdirectory of the
/// `server-discovery-and-monitoring` spec tests through `run_test`.
async fn run_sdam_spec_directory(subdirectory: &'static str) {
    run_spec_test(&["server-discovery-and-monitoring", subdirectory], run_test).await;
}

#[cfg_attr(feature = "tokio-runtime", tokio::test)]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
async fn single() {
    run_sdam_spec_directory("single").await;
}

#[cfg_attr(feature = "tokio-runtime", tokio::test)]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
async fn rs() {
    run_sdam_spec_directory("rs").await;
}

#[cfg_attr(feature = "tokio-runtime", tokio::test)]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
async fn sharded() {
    run_sdam_spec_directory("sharded").await;
}

#[cfg_attr(feature = "tokio-runtime", tokio::test)]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
async fn errors() {
    run_sdam_spec_directory("errors").await;
}

#[cfg_attr(feature = "tokio-runtime", tokio::test)]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
async fn monitoring() {
    run_sdam_spec_directory("monitoring").await;
}
/// Checks that `TopologyClosed` is emitted after the client is dropped, and that no further
/// SDAM event follows it.
#[cfg_attr(feature = "tokio-runtime", tokio::test)]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
#[function_name::named]
async fn topology_closed_event_last() {
    let _guard: RwLockReadGuard<_> = LOCK.run_concurrently().await;

    let event_handler = EventHandler::new();
    let mut subscriber = event_handler.subscribe();
    let client = EventClient::with_additional_options(
        None,
        Some(Duration::from_millis(50)),
        None,
        event_handler.clone(),
    )
    .await;

    let name = function_name!();
    client
        .database(name)
        .collection(name)
        .insert_one(doc! { "x": 1 }, None)
        .await
        .unwrap();

    drop(client);

    let wait = Duration::from_millis(500);
    subscriber
        .wait_for_event(wait, |event| {
            matches!(event, Event::Sdam(SdamEvent::TopologyClosed(_)))
        })
        .await
        .expect("should see topology closed event");

    // No SDAM event of any kind may arrive after the topology-closed event.
    let trailing_event = subscriber
        .wait_for_event(wait, |event| matches!(event, Event::Sdam(_)))
        .await;
    assert!(trailing_event.is_none());
}
/// Checks that heartbeat started/succeeded events are emitted for a single monitored host.
/// If the server supports `failCommand`, it also checks that a failed heartbeat is reported.
#[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
async fn heartbeat_events() {
    let _guard: RwLockWriteGuard<_> = LOCK.run_exclusively().await;

    // Keep only the first seed host and heartbeat frequently.
    let mut options = CLIENT_OPTIONS.clone();
    options.hosts.drain(1..);
    options.heartbeat_freq = Some(Duration::from_millis(50));

    let event_handler = EventHandler::new();
    let mut subscriber = event_handler.subscribe();
    let client = EventClient::with_additional_options(
        Some(options),
        Some(Duration::from_millis(50)),
        None,
        event_handler.clone(),
    )
    .await;

    let wait = Duration::from_millis(500);

    subscriber
        .wait_for_event(wait, |event| {
            matches!(event, Event::Sdam(SdamEvent::ServerHeartbeatStarted(_)))
        })
        .await
        .expect("should see server heartbeat started event");
    subscriber
        .wait_for_event(wait, |event| {
            matches!(event, Event::Sdam(SdamEvent::ServerHeartbeatSucceeded(_)))
        })
        .await
        .expect("should see server heartbeat succeeded event");

    if !client.supports_fail_command().await {
        return;
    }

    // Fail the next `isMaster`/`hello` command once with error code 1234.
    let failpoint = FailPoint::fail_command(
        &["isMaster", "hello"],
        FailPointMode::Times(1),
        FailCommandOptions::builder().error_code(1234).build(),
    );
    let _fp_guard = client
        .enable_failpoint(failpoint, None)
        .await
        .expect("enabling failpoint should succeed");

    subscriber
        .wait_for_event(wait, |event| {
            matches!(event, Event::Sdam(SdamEvent::ServerHeartbeatFailed(_)))
        })
        .await
        .expect("should see server heartbeat failed event");
}
/// Against a replica set secondary, checks that writes succeed with `directConnection=false`
/// and with it unspecified, but fail with a "not master" error when `directConnection=true`.
#[cfg_attr(feature = "tokio-runtime", tokio::test)]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
#[function_name::named]
async fn direct_connection() {
    let _guard: RwLockReadGuard<_> = LOCK.run_concurrently().await;

    let test_client = TestClient::new().await;
    if !test_client.is_replica_set() {
        println!("Skipping due to non-replica set topology");
        return;
    }

    let name = function_name!();

    let secondary_only = SelectionCriteria::ReadPreference(ReadPreference::Secondary {
        options: Default::default(),
    });
    let secondary_address = test_client
        .test_select_server(Some(&secondary_only))
        .await
        .expect("failed to select secondary");

    // Seed every client below with only the secondary's address.
    let mut secondary_options = CLIENT_OPTIONS.clone();
    secondary_options.hosts = vec![secondary_address];

    let mut discovering_options = secondary_options.clone();
    discovering_options.direct_connection = Some(false);
    let discovering_client =
        Client::with_options(discovering_options).expect("client construction should succeed");
    discovering_client
        .database(name)
        .collection(name)
        .insert_one(doc! {}, None)
        .await
        .expect("write should succeed with directConnection=false on secondary");

    let mut direct_options = secondary_options.clone();
    direct_options.direct_connection = Some(true);
    let direct_client =
        Client::with_options(direct_options).expect("client construction should succeed");
    let write_error = direct_client
        .database(name)
        .collection(name)
        .insert_one(doc! {}, None)
        .await
        .expect_err("write should fail with directConnection=true on secondary");
    assert!(write_error.is_not_master());

    let default_client =
        Client::with_options(secondary_options).expect("client construction should succeed");
    default_client
        .database(name)
        .collection(name)
        .insert_one(doc! {}, None)
        .await
        .expect("write should succeed with directConnection unspecified");
}
/// Checks that a `ConnectionPoolCleared` application error leaves a known standalone server's
/// type unchanged (it is not marked `Unknown`).
#[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
async fn pool_cleared_error_does_not_mark_unknown() {
    // A mocked topology with a single seed address.
    let address = ServerAddress::parse("a:1234").unwrap();
    let options = ClientOptions::builder()
        .hosts(vec![address.clone()])
        .build();
    let topology = Topology::new_mocked(options);

    let (_, weak_server) = topology.get_servers().await.into_iter().next().unwrap();
    let server = weak_server.upgrade().unwrap();

    // Feed a successful standalone heartbeat so the server becomes known.
    let heartbeat_response: IsMasterCommandResponse = bson::from_document(doc! {
        "ok": 1,
        "ismaster": true,
        "minWireVersion": 0,
        "maxWireVersion": 6,
        "maxBsonObjectSize": 16_000,
        "maxWriteBatchSize": 10_000,
    })
    .unwrap();
    let reply = IsMasterReply {
        server_address: address.clone(),
        command_response: heartbeat_response,
        round_trip_time: Duration::from_secs(1),
        cluster_time: None,
    };
    let description = ServerDescription::new(address.clone(), Some(Ok(reply)));
    topology.update(&server, description).await;

    assert_eq!(
        topology
            .get_server_description(&address)
            .await
            .unwrap()
            .server_type,
        ServerType::Standalone
    );

    let error: Error = ErrorKind::ConnectionPoolCleared {
        message: "foo".to_string(),
    }
    .into();
    let phase = HandshakePhase::PreHello {
        generation: server.pool.generation(),
    };

    // The handler is expected to return false for this error.
    let reported = topology
        .handle_application_error(error, phase, &server)
        .await;
    assert!(!reported);

    // The server must still be a standalone afterwards.
    assert_eq!(
        topology
            .get_server_description(&address)
            .await
            .unwrap()
            .server_type,
        ServerType::Standalone
    );
}