pub(super) mod handshake;
#[cfg(test)]
mod test;
use self::handshake::Handshaker;
use super::{
conn::{ConnectionGeneration, PendingConnection},
options::ConnectionPoolOptions,
Connection,
PoolGeneration,
};
use crate::{
client::{auth::Credential, options::ServerApi},
error::{Error as MongoError, ErrorKind},
runtime::HttpClient,
sdam::HandshakePhase,
};
/// Opens new pooled connections: connects, runs the initial handshake, and authenticates the
/// connection when a credential is configured.
#[derive(Debug, Clone)]
pub(super) struct ConnectionEstablisher {
    /// Runs the initial handshake on every freshly opened connection.
    handshaker: Handshaker,
    /// HTTP client handed to `Credential::authenticate_stream`.
    http_client: HttpClient,
    /// Credential used to authenticate new connections; `None` skips authentication entirely.
    credential: Option<Credential>,
    /// Server API passed through to authentication.
    server_api: Option<ServerApi>,
}
impl ConnectionEstablisher {
pub(super) fn new(http_client: HttpClient, options: Option<&ConnectionPoolOptions>) -> Self {
let handshaker = Handshaker::new(options.cloned().map(Into::into));
Self {
handshaker,
http_client,
credential: options.and_then(|options| options.credential.clone()),
server_api: options.and_then(|options| options.server_api.clone()),
}
}
pub(super) async fn establish_connection(
&self,
pending_connection: PendingConnection,
) -> std::result::Result<Connection, EstablishError> {
let pool_gen = pending_connection.generation.clone();
let mut connection = Connection::connect(pending_connection)
.await
.map_err(|e| EstablishError::pre_hello(e, pool_gen.clone()))?;
let handshake = self
.handshaker
.handshake(&mut connection, None, &None)
.await
.map_err(|e| EstablishError::pre_hello(e, pool_gen.clone()))?;
let service_id = handshake.is_master_reply.command_response.service_id;
match (pool_gen, service_id) {
(PoolGeneration::Normal(_), _) => {}
(PoolGeneration::LoadBalanced(gen_map), Some(service_id)) => {
connection.generation = ConnectionGeneration::LoadBalanced {
generation: *gen_map.get(&service_id).unwrap_or(&0),
service_id,
};
}
_ => {
return Err(EstablishError::post_hello(
ErrorKind::Internal {
message: "load-balanced mode mismatch".to_string(),
}
.into(),
connection.generation.clone(),
));
}
}
if let Some(ref credential) = self.credential {
credential
.authenticate_stream(
&mut connection,
&self.http_client,
self.server_api.as_ref(),
handshake.first_round,
)
.await
.map_err(|e| EstablishError::post_hello(e, connection.generation.clone()))?
}
Ok(connection)
}
}
/// A failure while establishing a connection, tagged with the handshake phase in which it
/// occurred.
#[derive(Clone, Debug)]
pub(crate) struct EstablishError {
    /// The underlying driver error.
    pub(crate) cause: MongoError,
    /// Whether the failure happened before the handshake succeeded or after it, together with
    /// the generation recorded for that phase.
    pub(crate) handshake_phase: HandshakePhase,
}
impl EstablishError {
    /// Wraps an error raised before the handshake succeeded (while connecting or during the
    /// handshake itself), recording the pool's generation.
    fn pre_hello(cause: MongoError, generation: PoolGeneration) -> Self {
        let handshake_phase = HandshakePhase::PreHello { generation };
        Self {
            cause,
            handshake_phase,
        }
    }

    /// Wraps an error raised after the handshake succeeded, recording the connection's
    /// generation.
    fn post_hello(cause: MongoError, generation: ConnectionGeneration) -> Self {
        let handshake_phase = HandshakePhase::PostHello { generation };
        Self {
            cause,
            handshake_phase,
        }
    }
}