pub(crate) mod handshake;
use std::time::{Duration, Instant};
#[cfg(test)]
use crate::options::ClientOptions;
use crate::{
client::{
auth::Credential,
options::{ServerAddress, TlsOptions},
},
error::{Error as MongoError, ErrorKind, Result},
hello::HelloReply,
options::Socks5Proxy,
runtime,
runtime::{stream::DEFAULT_CONNECT_TIMEOUT, AsyncStream, TlsConfig},
sdam::{topology::TopologySpec, HandshakePhase},
};
use super::{
conn::{
pooled::PooledConnection,
ConnectionGeneration,
LoadBalancedGeneration,
PendingConnection,
},
Connection,
PoolGeneration,
};
use handshake::{Handshaker, HandshakerOptions};
/// Holds everything needed to open a stream to a server and handshake a connection over it.
#[derive(Clone)]
pub(crate) struct ConnectionEstablisher {
/// Performs the handshake on every newly created connection.
handshaker: Handshaker,
/// TLS configuration built once in `new` from the TLS options, if any were supplied;
/// passed to `AsyncStream::connect` for each new stream.
tls_config: Option<TlsConfig>,
/// Upper bound applied to stream creation via `runtime::timeout`.
connect_timeout: Duration,
/// SOCKS5 proxy passed to `AsyncStream::connect`; always `None` when the `socks5-proxy`
/// feature is disabled.
proxy: Option<Socks5Proxy>,
/// Test-only hook that may rewrite the handshake result before it is acted upon.
#[cfg(test)]
test_patch_reply: Option<fn(&mut Result<HelloReply>)>,
}
/// Inputs consumed by `ConnectionEstablisher::new`.
pub(crate) struct EstablisherOptions {
/// Options forwarded to `Handshaker::new`.
handshake_options: HandshakerOptions,
/// TLS options; when present they are converted into a `TlsConfig`.
tls_options: Option<TlsOptions>,
/// Requested connect timeout. `None` selects `DEFAULT_CONNECT_TIMEOUT`; a zero duration
/// is mapped to `Duration::MAX`.
connect_timeout: Option<Duration>,
/// SOCKS5 proxy to connect through.
// `unused` is allowed because this field is never read when the `socks5-proxy` feature
// is disabled (the establisher then hard-codes `None`).
#[allow(unused)]
proxy: Option<Socks5Proxy>,
/// Test-only hook copied into the establisher; see `ConnectionEstablisher`.
#[cfg(test)]
pub(crate) test_patch_reply: Option<fn(&mut Result<HelloReply>)>,
}
impl From<&TopologySpec> for EstablisherOptions {
    /// Derives establisher options from a topology spec.
    ///
    /// Handshake options come from the spec itself; TLS options and the connect timeout
    /// come from the spec's client options. The proxy is only taken from the client
    /// options when the `socks5-proxy` feature is enabled, and the test patch hook is
    /// always left unset.
    fn from(spec: &TopologySpec) -> Self {
        let client_options = &spec.options;

        let handshake_options = HandshakerOptions::from(spec);
        let tls_options = client_options.tls_options();
        let connect_timeout = client_options.connect_timeout;

        #[cfg(feature = "socks5-proxy")]
        let proxy = client_options.socks5_proxy.clone();
        #[cfg(not(feature = "socks5-proxy"))]
        let proxy = None;

        Self {
            handshake_options,
            tls_options,
            connect_timeout,
            proxy,
            #[cfg(test)]
            test_patch_reply: None,
        }
    }
}
/// Test-only shortcut that goes through `TopologySpec` to build establisher options
/// directly from client options.
#[cfg(test)]
impl From<&ClientOptions> for EstablisherOptions {
    /// # Panics
    ///
    /// Panics if the cloned options cannot be converted into a `TopologySpec`.
    fn from(options: &ClientOptions) -> Self {
        let spec = TopologySpec::try_from(options.clone()).unwrap();
        Self::from(&spec)
    }
}
impl ConnectionEstablisher {
pub(crate) fn new(options: EstablisherOptions) -> Result<Self> {
let handshaker = Handshaker::new(options.handshake_options)?;
let tls_config = if let Some(tls_options) = options.tls_options {
Some(TlsConfig::new(tls_options)?)
} else {
None
};
let connect_timeout = match options.connect_timeout {
Some(d) if d.is_zero() => Duration::MAX,
Some(d) => d,
None => DEFAULT_CONNECT_TIMEOUT,
};
Ok(Self {
handshaker,
tls_config,
connect_timeout,
#[cfg(test)]
test_patch_reply: options.test_patch_reply,
#[cfg(feature = "socks5-proxy")]
proxy: options.proxy,
#[cfg(not(feature = "socks5-proxy"))]
proxy: None,
})
}
async fn make_stream(&self, address: ServerAddress) -> Result<AsyncStream> {
runtime::timeout(
self.connect_timeout,
AsyncStream::connect(address, self.tls_config.as_ref(), self.proxy.as_ref()),
)
.await?
}
pub(crate) async fn establish_connection(
&self,
mut pending_connection: PendingConnection,
credential: Option<&Credential>,
) -> std::result::Result<PooledConnection, EstablishError> {
let pool_gen = pending_connection.generation.clone();
let address = pending_connection.address.clone();
let cancellation_receiver = pending_connection.cancellation_receiver.take();
let stream = self
.make_stream(address)
.await
.map_err(|e| EstablishError::pre_hello(e, pool_gen.clone()))?;
let mut connection = PooledConnection::new(pending_connection, stream);
#[allow(unused_mut)]
let mut handshake_result = self
.handshaker
.handshake(&mut connection, credential, cancellation_receiver)
.await;
#[cfg(test)]
if let Some(patch) = self.test_patch_reply {
patch(&mut handshake_result);
}
match (&pool_gen, connection.service_id()) {
(PoolGeneration::Normal(_), _) => {}
(PoolGeneration::LoadBalanced(gen_map), Some(service_id)) => {
connection.generation = LoadBalancedGeneration {
generation: *gen_map.get(&service_id).unwrap_or(&0),
service_id,
}
.into();
}
(PoolGeneration::LoadBalanced(_), None) => {
if handshake_result.is_ok() {
return Err(EstablishError::post_hello(
ErrorKind::IncompatibleServer {
message: "Driver attempted to initialize in load balancing mode, but \
the server does not support this mode."
.to_string(),
}
.into(),
connection.generation,
));
}
}
}
handshake_result.map_err(|e| {
if connection.stream_description().is_err() {
EstablishError::pre_hello(e, pool_gen)
} else {
EstablishError::post_hello(e, connection.generation)
}
})?;
Ok(connection)
}
pub(crate) async fn establish_monitoring_connection(
&self,
address: ServerAddress,
id: u32,
) -> Result<(Connection, HelloReply)> {
let stream = self.make_stream(address.clone()).await?;
let mut connection = Connection::new(address, stream, id, Instant::now());
let hello_reply = self
.handshaker
.handshake(&mut connection, None, None)
.await?;
Ok((connection, hello_reply))
}
}
/// Error returned when establishing a pooled connection fails, pairing the underlying
/// error with the handshake phase the failure was attributed to.
#[derive(Debug, Clone)]
pub(crate) struct EstablishError {
/// The underlying error that made establishment fail.
pub(crate) cause: MongoError,
/// Whether the failure was classified as pre-hello or post-hello, together with the
/// generation recorded for that phase.
pub(crate) handshake_phase: HandshakePhase,
}
impl EstablishError {
    /// Wraps `cause` as a pre-hello failure, recording the pool's `generation`.
    fn pre_hello(cause: MongoError, generation: PoolGeneration) -> Self {
        let handshake_phase = HandshakePhase::PreHello { generation };
        Self {
            cause,
            handshake_phase,
        }
    }

    /// Wraps `cause` as a post-hello failure, recording the connection's `generation`.
    fn post_hello(cause: MongoError, generation: ConnectionGeneration) -> Self {
        let handshake_phase = HandshakePhase::PostHello { generation };
        Self {
            cause,
            handshake_phase,
        }
    }
}