mod event;
mod lock;
mod matchable;
pub use self::{
event::{CommandEvent, EventClient},
lock::TestLock,
matchable::{assert_matches, Matchable},
};
use std::{collections::HashMap, fmt::Debug, sync::Arc, time::Duration};
use crate::bson::{doc, oid::ObjectId, Bson};
use semver::Version;
use serde::Deserialize;
use self::event::EventHandler;
use super::CLIENT_OPTIONS;
use crate::{
error::{CommandError, ErrorKind, Result},
operation::RunCommand,
options::{AuthMechanism, ClientOptions, CollectionOptions, CreateCollectionOptions},
Client,
Collection,
};
/// A `Client` wrapper used by tests that also keeps the options it was built
/// from and server metadata collected while it was being constructed.
pub struct TestClient {
/// The wrapped driver client; reachable through the `Deref` impl.
client: Client,
/// The options the wrapped client was created with (including any event
/// handlers installed by `with_handler`).
pub options: ClientOptions,
/// Deserialized reply to the `isMaster` command issued during construction.
pub server_info: IsMasterCommandResponse,
/// Server version parsed from the `buildInfo` reply, using only the part of
/// the version string that precedes the first `-`.
pub server_version: Version,
}
impl std::ops::Deref for TestClient {
type Target = Client;
fn deref(&self) -> &Self::Target {
&self.client
}
}
impl TestClient {
pub async fn new() -> Self {
Self::with_options(None).await
}
pub async fn with_options(options: Option<ClientOptions>) -> Self {
Self::with_handler(None, options).await
}
async fn with_handler(
event_handler: Option<EventHandler>,
options: impl Into<Option<ClientOptions>>,
) -> Self {
let mut options = options.into().unwrap_or_else(|| CLIENT_OPTIONS.clone());
if let Some(event_handler) = event_handler {
let handler = Arc::new(event_handler);
options.command_event_handler = Some(handler.clone());
options.cmap_event_handler = Some(handler);
}
let client = Client::with_options(options.clone()).unwrap();
let mut session = client
.start_implicit_session_with_timeout(Duration::from_secs(60 * 60))
.await;
session.mark_dirty();
let is_master = RunCommand::new("admin".into(), doc! { "isMaster": 1 }, None).unwrap();
let server_info = bson::from_bson(Bson::Document(
client
.execute_operation_with_session(is_master, &mut session)
.await
.unwrap(),
))
.unwrap();
let build_info = RunCommand::new("test".into(), doc! { "buildInfo": 1 }, None).unwrap();
let response = client
.execute_operation_with_session(build_info, &mut session)
.await
.unwrap();
let info: BuildInfo = bson::from_bson(Bson::Document(response)).unwrap();
let server_version = info.version.split('-').next().unwrap();
let server_version = Version::parse(server_version).unwrap();
Self {
client,
options,
server_info,
server_version,
}
}
pub async fn with_additional_options(
options: Option<ClientOptions>,
use_multiple_mongoses: bool,
) -> Self {
let mut options = match options {
Some(mut options) => {
options.merge(CLIENT_OPTIONS.clone());
options
}
None => CLIENT_OPTIONS.clone(),
};
if !use_multiple_mongoses && Self::new().await.is_sharded() {
options.hosts = options.hosts.iter().cloned().take(1).collect();
}
Self::with_options(Some(options)).await
}
pub async fn create_user(
&self,
user: &str,
pwd: &str,
roles: &[Bson],
mechanisms: &[AuthMechanism],
) -> Result<()> {
let mut cmd = doc! { "createUser": user, "pwd": pwd, "roles": roles };
if self.server_version_gte(4, 0) {
let ms: bson::Array = mechanisms.iter().map(|s| Bson::from(s.as_str())).collect();
cmd.insert("mechanisms", ms);
}
self.database("admin").run_command(cmd, None).await?;
Ok(())
}
pub fn get_coll(&self, db_name: &str, coll_name: &str) -> Collection {
self.database(db_name).collection(coll_name)
}
pub async fn init_db_and_coll(&self, db_name: &str, coll_name: &str) -> Collection {
let coll = self.get_coll(db_name, coll_name);
drop_collection(&coll).await;
coll
}
pub fn get_coll_with_options(
&self,
db_name: &str,
coll_name: &str,
options: CollectionOptions,
) -> Collection {
self.database(db_name)
.collection_with_options(coll_name, options)
}
pub async fn init_db_and_coll_with_options(
&self,
db_name: &str,
coll_name: &str,
options: CollectionOptions,
) -> Collection {
let coll = self.get_coll_with_options(db_name, coll_name, options);
drop_collection(&coll).await;
coll
}
pub async fn create_fresh_collection(
&self,
db_name: &str,
coll_name: &str,
options: impl Into<Option<CreateCollectionOptions>>,
) -> Collection {
self.drop_collection(db_name, coll_name).await;
self.database(db_name)
.create_collection(coll_name, options)
.await
.unwrap();
self.get_coll(db_name, coll_name)
}
pub fn auth_enabled(&self) -> bool {
self.options.credential.is_some()
}
pub fn is_standalone(&self) -> bool {
!self.is_replica_set() && !self.is_sharded()
}
pub fn is_replica_set(&self) -> bool {
self.options.repl_set_name.is_some()
}
pub fn is_sharded(&self) -> bool {
self.server_info.msg.as_deref() == Some("isdbgrid")
}
#[allow(dead_code)]
pub fn server_version_eq(&self, major: u64, minor: u64) -> bool {
self.server_version.major == major && self.server_version.minor == minor
}
#[allow(dead_code)]
pub fn server_version_gt(&self, major: u64, minor: u64) -> bool {
self.server_version.major > major
|| (self.server_version.major == major && self.server_version.minor > minor)
}
pub fn server_version_gte(&self, major: u64, minor: u64) -> bool {
self.server_version.major > major
|| (self.server_version.major == major && self.server_version.minor >= minor)
}
pub fn server_version_lt(&self, major: u64, minor: u64) -> bool {
self.server_version.major < major
|| (self.server_version.major == major && self.server_version.minor < minor)
}
#[allow(dead_code)]
pub fn server_version_lte(&self, major: u64, minor: u64) -> bool {
self.server_version.major < major
|| (self.server_version.major == major && self.server_version.minor <= minor)
}
pub async fn drop_collection(&self, db_name: &str, coll_name: &str) {
let coll = self.get_coll(db_name, coll_name);
drop_collection(&coll).await;
}
}
pub async fn drop_collection(coll: &Collection) {
match coll.drop(None).await.as_ref().map_err(|e| e.as_ref()) {
Err(ErrorKind::CommandError(CommandError { code: 26, .. })) | Ok(_) => {}
e @ Err(_) => {
e.unwrap();
}
};
}
/// The part of a `buildInfo` reply this module reads: only the version string.
#[derive(Debug, Deserialize)]
struct BuildInfo {
/// Raw server version string; anything after the first `-` is discarded
/// before semver parsing in `TestClient::with_handler`.
version: String,
}
/// Deserialized form of an `isMaster` command reply.
///
/// Every field is an `Option`. Field names map to camelCase keys in the reply
/// (e.g. `set_version` <-> `setVersion`), except for the two fields with an
/// explicit all-lowercase `rename`.
#[derive(Debug, Default, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct IsMasterCommandResponse {
// Reply key is all-lowercase `ismaster`, not the camelCase `isMaster`.
#[serde(rename = "ismaster")]
pub is_master: Option<bool>,
pub ok: Option<f32>,
pub hosts: Option<Vec<String>>,
pub passives: Option<Vec<String>>,
pub arbiters: Option<Vec<String>>,
/// `TestClient::is_sharded` compares this against `"isdbgrid"`.
pub msg: Option<String>,
pub me: Option<String>,
pub set_version: Option<i32>,
pub set_name: Option<String>,
pub hidden: Option<bool>,
pub secondary: Option<bool>,
pub arbiter_only: Option<bool>,
// Reply key is all-lowercase `isreplicaset`.
#[serde(rename = "isreplicaset")]
pub is_replica_set: Option<bool>,
pub logical_session_timeout_minutes: Option<i64>,
pub min_wire_version: Option<i32>,
pub max_wire_version: Option<i32>,
pub tags: Option<HashMap<String, String>>,
pub election_id: Option<ObjectId>,
pub primary: Option<String>,
}
/// Derives a database name from a free-form test description.
///
/// Every `$` becomes `%`, every space becomes `_`, and the result is capped
/// at 63 bytes (presumably to stay under MongoDB's database-name length
/// limit -- confirm against the server documentation).
///
/// The cap never splits a multi-byte UTF-8 character: if byte 63 falls
/// inside one, the name is cut at the preceding character boundary instead.
/// A plain `String::truncate(63)` would panic on such input.
pub fn get_db_name(description: &str) -> String {
    const MAX_DB_NAME_BYTES: usize = 63;

    let mut db_name = description.replace('$', "%").replace(' ', "_");

    // Largest char boundary that is <= the cap. `is_char_boundary` is false
    // for indices past the end, so short names resolve to their own length
    // and the truncate below is a no-op. Index 0 is always a boundary.
    let cut = (0..=MAX_DB_NAME_BYTES)
        .rev()
        .find(|&idx| db_name.is_char_boundary(idx))
        .unwrap_or(0);
    db_name.truncate(cut);

    db_name
}