use std::{collections::HashMap, sync::Arc, time::Duration};

use approx::abs_diff_eq;
use bson::{doc, Document};
use semver::VersionReq;
use serde::Deserialize;
use tokio::sync::RwLockWriteGuard;

use crate::{
    options::ServerAddress,
    runtime::AsyncJoinHandle,
    sdam::{description::topology::server_selection, Server},
    selection_criteria::ReadPreference,
    test::{
        run_spec_test,
        EventClient,
        FailCommandOptions,
        FailPoint,
        FailPointMode,
        TestClient,
        CLIENT_OPTIONS,
        LOCK,
    },
    RUNTIME,
};

use super::TestTopologyDescription;
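
/// A deserialized server selection `in_window` spec test.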
#[derive(Debug, Deserialize)]
struct TestFile {
    description: String,
    topology_description: TestTopologyDescription,
    mocked_topology_state: Vec<TestServer>,
    iterations: u32,
    outcome: TestOutcome,
}
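
/// The expected selection frequency for each server, along with the allowed
/// deviation.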
#[derive(Debug, Deserialize)]
struct TestOutcome {
    tolerance: f64,
    expected_frequencies: HashMap<ServerAddress, f64>,
}
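
/// A mocked server with a fixed number of in-progress operations.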
#[derive(Debug, Deserialize)]
struct TestServer {
    address: ServerAddress,
    operation_count: u32,
}
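
/// Executes a single `in_window` test: runs server selection `iterations`
/// times against the mocked topology and verifies the observed selection
/// frequencies.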
async fn run_test(test_file: TestFile) {
    println!("Running {}", test_file.description);

    let mut tallies: HashMap<ServerAddress, u32> = HashMap::new();

    let servers: HashMap<ServerAddress, Arc<Server>> = test_file
        .mocked_topology_state
        .into_iter()
        .map(|desc| {
            (
                desc.address.clone(),
                Arc::new(Server::new_mocked(desc.address, desc.operation_count)),
            )
        })
        .collect();

    let topology_description = test_file
        .topology_description
        .into_topology_description(None)
        .unwrap();

    let read_pref = ReadPreference::Nearest {
        options: Default::default(),
    }
    .into();
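
    // Run server selection repeatedly, tallying how often each server is
    // chosen.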
    for _ in 0..test_file.iterations {
        let selection =
            server_selection::attempt_to_select_server(&read_pref, &topology_description, &servers)
                .expect("selection should not fail")
                .expect("a server should have been selected");
        *tallies.entry(selection.address.clone()).or_insert(0) += 1;
    }
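
    // Check each server's observed selection frequency against the expected
    // frequency. Frequencies of exactly 0.0 or 1.0 must match exactly; all
    // others may deviate by at most the test's tolerance.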
    for (address, expected_frequency) in test_file.outcome.expected_frequencies {
        let actual_frequency =
            tallies.get(&address).cloned().unwrap_or(0) as f64 / (test_file.iterations as f64);

        let epsilon = if expected_frequency != 1.0 && expected_frequency != 0.0 {
            test_file.outcome.tolerance
        } else {
            f64::EPSILON
        };

        assert!(
            abs_diff_eq!(actual_frequency, expected_frequency, epsilon = epsilon),
            "{}: for server {} expected frequency = {}, actual = {}",
            test_file.description,
            address,
            expected_frequency,
            actual_frequency
        );
    }
}
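
/// Runs the `in_window` server selection spec tests.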
#[cfg_attr(feature = "tokio-runtime", tokio::test)]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
async fn select_in_window() {
run_spec_test(&["server-selection", "in_window"], run_test).await;
}
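
/// Prose test of operation-count-based load balancing: when one of two
/// mongoses is artificially slowed via a failpoint, the driver should send
/// most operations to the other one, and rebalance once the failpoint is
/// removed.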
#[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
async fn load_balancing_test() {
let _guard: RwLockWriteGuard<_> = LOCK.run_exclusively().await;
let mut setup_client_options = CLIENT_OPTIONS.clone();
if setup_client_options.credential.is_some() {
println!("skipping load_balancing_test test due to auth being enabled");
return;
}
setup_client_options.hosts.drain(1..);
setup_client_options.direct_connection = Some(true);
let setup_client = TestClient::with_options(Some(setup_client_options)).await;
let version = VersionReq::parse(">= 4.2.9").unwrap();
if !version.matches(&setup_client.server_version) {
println!(
"skipping load_balancing_test test due to server not supporting blockConnection option"
);
return;
}
if !setup_client.is_sharded() {
println!("skipping load_balancing_test test due to topology not being sharded");
return;
}
if CLIENT_OPTIONS.hosts.len() != 2 {
println!("skipping load_balancing_test test due to topology not having 2 mongoses");
return;
}
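
    // Seed the collection that the `find` operations below will query.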
    setup_client
        .database("load_balancing_test")
        .collection("load_balancing_test")
        .insert_one(doc! {}, None)
        .await
        .unwrap();
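
    /// Runs 10 concurrent tasks that each perform `iterations` find
    /// operations, then asserts that the share of operations sent to the
    /// less-used mongos falls within `[min_share, max_share]`.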
    async fn do_test(client: &mut EventClient, min_share: f64, max_share: f64, iterations: usize) {
        client.clear_cached_events();

        // Spawn 10 tasks that each run `iterations` find operations
        // concurrently.
        let mut handles: Vec<AsyncJoinHandle<()>> = Vec::new();
        for _ in 0..10 {
            let collection = client
                .database("load_balancing_test")
                .collection::<Document>("load_balancing_test");
            handles.push(
                RUNTIME
                    .spawn(async move {
                        for _ in 0..iterations {
                            let _ = collection.find_one(None, None).await;
                        }
                    })
                    .unwrap(),
            );
        }
        futures::future::join_all(handles).await;

        // Tally how many `find` commands were sent to each mongos.
        let mut tallies: HashMap<ServerAddress, u32> = HashMap::new();
        for event in client.get_command_started_events(&["find"]) {
            *tallies.entry(event.connection.address.clone()).or_insert(0) += 1;
        }

        assert_eq!(tallies.len(), 2);
        let mut counts: Vec<_> = tallies.values().collect();
        counts.sort();

        // The share of operations received by the less-used mongos must fall
        // within the expected bounds.
        let share_of_selections = (*counts[0] as f64) / ((*counts[0] + *counts[1]) as f64);
        assert!(
            share_of_selections <= max_share,
            "expected no more than {}% of selections, instead got {}%",
            (max_share * 100.0) as u32,
            (share_of_selections * 100.0) as u32
        );
        assert!(
            share_of_selections >= min_share,
            "expected at least {}% of selections, instead got {}%",
            (min_share * 100.0) as u32,
            (share_of_selections * 100.0) as u32
        );
    }
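
    // Baseline run with no failpoint: the share bounds here are trivially
    // wide, so this mainly verifies that both mongoses receive operations and
    // warms up the connection pools.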
    let mut client = EventClient::new().await;
    do_test(&mut client, 0.0, 0.50, 100).await;
    let options = FailCommandOptions::builder()
        .block_connection(Duration::from_millis(500))
        .build();
    let failpoint = FailPoint::fail_command(&["find"], FailPointMode::AlwaysOn, options);
    let fp_guard = setup_client
        .enable_failpoint(failpoint, None)
        .await
        .expect("enabling failpoint should succeed");

    do_test(&mut client, 0.05, 0.25, 10).await;
    drop(fp_guard);
    do_test(&mut client, 0.40, 0.50, 100).await;
}