use std::{
collections::HashSet,
time::{Duration, Instant},
};

use crate::bson::doc;
use crate::{
client::options::{ClientOptions, ServerAddress},
cmap::RawCommandResponse,
error::{Error, ErrorKind},
event::{cmap::CmapEvent, sdam::SdamEvent},
hello::{LEGACY_HELLO_COMMAND_NAME, LEGACY_HELLO_COMMAND_NAME_LOWERCASE},
sdam::{ServerDescription, Topology},
test::{
fail_command_appname_initial_handshake_supported,
get_client_options,
log_uncaptured,
server_version_matches,
topology_is_load_balanced,
topology_is_replica_set,
util::{
event_buffer::EventBuffer,
fail_point::{FailPoint, FailPointMode},
},
Event,
},
Client,
};
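
// Prose test for minHeartbeatFrequencyMS: after each failed handshake the
// monitor must wait at least 500ms before retrying, so five induced hello
// failures should delay a successful ping by at least 2 seconds.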
#[tokio::test(flavor = "multi_thread")]
async fn min_heartbeat_frequency() {
if topology_is_load_balanced().await {
log_uncaptured("skipping min_heartbeat_frequency test due to load-balanced topology");
return;
}
if !fail_command_appname_initial_handshake_supported().await {
        log_uncaptured(
            "skipping min_heartbeat_frequency test due to server not supporting failCommand \
             appName",
        );
return;
}
let mut setup_client_options = get_client_options().await.clone();
setup_client_options.hosts.drain(1..);
setup_client_options.direct_connection = Some(true);
let setup_client = Client::for_test()
.options(setup_client_options.clone())
.await;
let _guard = setup_client
.enable_fail_point(
FailPoint::fail_command(
&[LEGACY_HELLO_COMMAND_NAME, "hello"],
FailPointMode::Times(5),
)
.app_name("SDAMMinHeartbeatFrequencyTest")
.error_code(1234),
)
.await
.unwrap();
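
    // Connect with the same appName so this client's handshakes hit the fail
    // point; the ping can only succeed once the heartbeat retries stop failing.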
let mut options = setup_client_options;
options.app_name = Some("SDAMMinHeartbeatFrequencyTest".to_string());
options.server_selection_timeout = Some(Duration::from_secs(5));
let client = Client::with_options(options).expect("client creation succeeds");
let start = Instant::now();
client
.database("admin")
.run_command(doc! { "ping": 1 })
.await
.expect("ping should eventually succeed");
    let elapsed = start.elapsed().as_millis();
assert!(
elapsed >= 2000,
"expected to take at least 2 seconds, instead took {elapsed}ms"
);
assert!(
elapsed <= 3500,
"expected to take at most 3.5 seconds, instead took {elapsed}ms"
);
}
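
// Verifies that a failed heartbeat marks the server Unknown and clears its
// connection pool, and that a later successful heartbeat readies the pool
// again.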
#[tokio::test(flavor = "multi_thread")]
async fn sdam_pool_management() {
if !server_version_matches(">= 4.2.9").await {
log_uncaptured(
"skipping sdam_pool_management test due to server not supporting appName failCommand",
);
return;
}
let mut options = get_client_options().await.clone();
if options.load_balanced.unwrap_or(false) {
log_uncaptured("skipping sdam_pool_management test due to load-balanced topology");
return;
}
options.hosts.drain(1..);
options.direct_connection = Some(true);
options.app_name = Some("SDAMPoolManagementTest".to_string());
options.heartbeat_freq = Some(Duration::from_millis(50));
let client = Client::for_test()
.options(options)
.use_single_mongos()
.min_heartbeat_freq(Duration::from_millis(50))
.monitor_events()
.await;
let mut subscriber = client.events.stream_all();
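
    // Wait for the initial pool-ready and heartbeat-succeeded events before
    // injecting failures.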
subscriber
.next_match(Duration::from_millis(500), |event| {
matches!(event, Event::Cmap(CmapEvent::PoolReady(_)))
})
.await
.expect("should see pool ready event");
subscriber
.next_match(Duration::from_millis(500), |event| {
matches!(event, Event::Sdam(SdamEvent::ServerHeartbeatSucceeded(_)))
})
.await
.expect("should see server heartbeat succeeded event");
let _guard = client
.enable_fail_point(
FailPoint::fail_command(
&[LEGACY_HELLO_COMMAND_NAME, "hello"],
FailPointMode::Times(4),
)
.app_name("SDAMPoolManagementTest")
.error_code(1234),
)
.await
.unwrap();
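
    // Four failing heartbeats should surface as ServerHeartbeatFailed and
    // PoolCleared, followed by PoolReady and ServerHeartbeatSucceeded once the
    // fail point is exhausted.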
let events = subscriber.collect(Duration::from_secs(1), |_| true).await;
assert!(events
.iter()
.any(|e| matches!(e, Event::Sdam(SdamEvent::ServerHeartbeatFailed(_)))));
assert!(events
.iter()
.any(|e| matches!(e, Event::Cmap(CmapEvent::PoolCleared(_)))));
assert!(events
.iter()
.any(|e| matches!(e, Event::Cmap(CmapEvent::PoolReady(_)))));
assert!(events
.iter()
.any(|e| matches!(e, Event::Sdam(SdamEvent::ServerHeartbeatSucceeded(_)))));
}
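
// Verifies the helloOk handshake upgrade: the first heartbeat uses the legacy
// hello command and advertises helloOk support; once the server replies with
// helloOk: true, subsequent heartbeats switch to the modern hello command.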
#[tokio::test(flavor = "multi_thread")]
async fn hello_ok_true() {
if !server_version_matches(">= 4.4.5").await {
log_uncaptured("skipping hello_ok_true test due to server not supporting hello");
return;
}
let mut setup_client_options = get_client_options().await.clone();
setup_client_options.hosts.drain(1..);
if setup_client_options.server_api.is_some() {
log_uncaptured("skipping hello_ok_true test due to ServerApi being configured");
return;
}
if setup_client_options.load_balanced == Some(true) {
log_uncaptured("skipping hello_ok_true test due to load balanced topology");
return;
}
let buffer = EventBuffer::new();
let mut event_stream = buffer.stream();
let mut options = setup_client_options.clone();
options.sdam_event_handler = Some(buffer.handler());
options.direct_connection = Some(true);
options.heartbeat_freq = Some(Duration::from_millis(500));
let _client = Client::with_options(options).expect("client creation should succeed");
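
    // The first heartbeat reply comes from the legacy hello command, so it
    // should contain the legacy response field rather than isWritablePrimary.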
event_stream
.next_match(Duration::from_millis(2000), |event| {
if let Event::Sdam(SdamEvent::ServerHeartbeatSucceeded(e)) = event {
assert!(e.reply.get_bool("helloOk").unwrap());
assert!(e.reply.get(LEGACY_HELLO_COMMAND_NAME_LOWERCASE).is_some());
assert!(e.reply.get("isWritablePrimary").is_none());
return true;
}
false
})
.await
.expect("first heartbeat reply should contain helloOk: true");
for _ in 0..3 {
event_stream
.next_match(Duration::from_millis(2000), |event| {
if let Event::Sdam(SdamEvent::ServerHeartbeatSucceeded(e)) = event {
assert!(e.reply.get("isWritablePrimary").is_some());
assert!(e.reply.get(LEGACY_HELLO_COMMAND_NAME_LOWERCASE).is_none());
return true;
}
false
})
.await
.expect("subsequent heartbeats should use hello");
}
}
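
// A replSetName that doesn't match the deployment should render the server
// unusable, surfacing as a server selection error.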
#[tokio::test]
async fn repl_set_name_mismatch() -> crate::error::Result<()> {
if !topology_is_replica_set().await {
log_uncaptured("skipping repl_set_name_mismatch due to non-replica set topology");
return Ok(());
}
let mut options = get_client_options().await.clone();
options.hosts.drain(1..);
options.direct_connection = Some(true);
options.repl_set_name = Some("invalid".to_string());
options.server_selection_timeout = Some(Duration::from_secs(5));
let client = Client::with_options(options)?;
let result = client.list_database_names().await;
assert!(
match result {
Err(Error { ref kind, .. }) => matches!(**kind, ErrorKind::ServerSelection { .. }),
_ => false,
},
"Unexpected result {result:?}"
);
Ok(())
}
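
// Verifies that when a server is removed from the topology, its monitor stops
// performing heartbeat checks.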
#[tokio::test(flavor = "multi_thread")]
async fn removed_server_monitor_stops() -> crate::error::Result<()> {
let buffer = EventBuffer::new();
let options = ClientOptions::builder()
.hosts(vec![
ServerAddress::parse("localhost:49152")?,
ServerAddress::parse("localhost:49153")?,
ServerAddress::parse("localhost:49154")?,
])
.heartbeat_freq(Duration::from_millis(50))
.sdam_event_handler(buffer.handler())
.repl_set_name("foo".to_string())
.build();
let hosts = options.hosts.clone();
let set_name = options.repl_set_name.clone().unwrap();
let mut event_stream = buffer.stream();
let topology = Topology::new(options)?;
let mut seen_monitors = HashSet::new();
event_stream
.next_match(Duration::from_millis(500), |event| {
if let Event::Sdam(SdamEvent::ServerHeartbeatStarted(e)) = event {
seen_monitors.insert(e.server_address.clone());
}
seen_monitors.len() == hosts.len()
})
.await
.expect("should see all three monitors start");
let hello = doc! {
"ok": 1,
"isWritablePrimary": true,
"hosts": [
hosts[0].clone().to_string(),
hosts[1].clone().to_string(),
],
"me": hosts[0].clone().to_string(),
"setName": set_name,
"maxBsonObjectSize": 1234,
"maxWriteBatchSize": 1234,
"maxMessageSizeBytes": 1234,
"minWireVersion": 0,
"maxWireVersion": 13,
};
let hello_reply = RawCommandResponse::with_document_and_address(hosts[0].clone(), hello)
.unwrap()
.into_hello_reply()
.unwrap();
topology
.updater()
.update(ServerDescription::new_from_hello_reply(
hosts[0].clone(),
hello_reply,
Duration::from_millis(10),
))
.await;
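
    // The removed server's monitor should be closed and perform at most a
    // couple of straggling checks afterwards.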
    event_stream
        .next_match(Duration::from_secs(1), |event| {
            matches!(event, Event::Sdam(SdamEvent::ServerClosed(e)) if e.address == hosts[2])
        })
        .await
        .expect("should see server closed event");
    let events = event_stream
        .collect(Duration::from_secs(1), |event| {
            matches!(
                event,
                Event::Sdam(SdamEvent::ServerHeartbeatStarted(e)) if e.server_address == hosts[2]
            )
        })
        .await;
assert!(
events.len() < 3,
"expected monitor for removed server to stop performing checks, but saw {} heartbeats",
events.len()
);
Ok(())
}
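
// Verifies that a bracketed IPv6 "me" field ([::1]:8191) is not flagged as a
// mismatch for an address whose host was parsed as ::1.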
#[test]
fn ipv6_invalid_me() {
let addr = ServerAddress::Tcp {
host: "::1".to_string(),
port: Some(8191),
};
let desc = ServerDescription {
address: addr.clone(),
server_type: super::ServerType::RsSecondary,
last_update_time: None,
average_round_trip_time: None,
reply: Ok(Some(crate::hello::HelloReply {
server_address: addr.clone(),
command_response: crate::hello::HelloCommandResponse {
me: Some("[::1]:8191".to_string()),
..Default::default()
},
raw_command_response: crate::bson::RawDocumentBuf::new(),
cluster_time: None,
})),
};
assert!(!desc.invalid_me().unwrap());
}
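
// Regression test: shutting down a client while the topology worker is still
// winding down must not deadlock, even with a command event handler installed.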
#[tokio::test(flavor = "multi_thread")]
async fn shutdown_does_not_deadlock() {
let mut options = get_client_options().await.clone();
options.heartbeat_freq = Some(Duration::from_millis(500));
options.test_options_mut().topology_worker_shutdown_delay = Some(Duration::from_millis(600));
use crate::event::{command::CommandEvent, EventHandler};
options.command_event_handler = Some(EventHandler::callback(|_event: CommandEvent| {
std::thread::yield_now();
}));
for i in 0..5 {
let client = Client::with_options(options.clone()).unwrap();
let _ = client
.database("admin")
.run_command(doc! { "ping": 1 })
.await;
tokio::time::sleep(Duration::from_millis(20)).await;
let result = tokio::time::timeout(Duration::from_secs(5), async {
client.shutdown().immediate(true).await;
})
.await;
        assert!(result.is_ok(), "shutdown() deadlocked on iteration {i}");
}
}