cdrs_tokio/cluster/
connection_manager.rs1use std::io;
2use std::net::SocketAddr;
3use tokio::sync::mpsc::Sender;
4
5#[cfg(test)]
6use mockall::*;
7
8use crate::cluster::KeyspaceHolder;
9use crate::future::BoxFuture;
10use crate::transport::CdrsTransport;
11use cassandra_protocol::authenticators::SaslAuthenticatorProvider;
12use cassandra_protocol::compression::Compression;
13use cassandra_protocol::error::{Error, Result};
14use cassandra_protocol::frame::message_response::ResponseBody;
15use cassandra_protocol::frame::{Envelope, Opcode, Version};
16use cassandra_protocol::query::utils::quote;
17
18pub trait ConnectionManager<T: CdrsTransport>: Send + Sync {
20 fn connection(
23 &self,
24 event_handler: Option<Sender<Envelope>>,
25 error_handler: Option<Sender<Error>>,
26 addr: SocketAddr,
27 ) -> BoxFuture<'_, Result<T>>;
28}
29
// Test-only mock of `ConnectionManager`, generated by mockall's `mock!` macro
// (mockall prefixes the generated struct name with `Mock`).
#[cfg(test)]
mock! {
    // The mocked struct has no inherent methods; only the trait impl below is mocked.
    pub ConnectionManager<T: CdrsTransport> {
    }

    // NOTE(review): presumably silences dead-code warnings for generated
    // expectation items that no test uses — confirm.
    #[allow(dead_code)]
    impl<T: CdrsTransport> ConnectionManager<T> for ConnectionManager<T> {
        // Same signature as the trait method, with the elided `'_` spelled as a
        // named lifetime `'a` (presumably required by `mock!` — confirm).
        fn connection<'a>(
            &'a self,
            event_handler: Option<Sender<Envelope>>,
            error_handler: Option<Sender<Error>>,
            addr: SocketAddr,
        ) -> BoxFuture<'a, Result<T>>;
    }
}
45
46pub async fn startup<
48 T: CdrsTransport + 'static,
49 A: SaslAuthenticatorProvider + Send + Sync + ?Sized + 'static,
50>(
51 transport: &T,
52 authenticator_provider: &A,
53 keyspace_holder: &KeyspaceHolder,
54 compression: Compression,
55 version: Version,
56) -> Result<()> {
57 let startup_envelope =
58 Envelope::new_req_startup(compression.as_str().map(String::from), version);
59
60 let start_response = match transport.write_envelope(&startup_envelope, true).await {
61 Ok(response) => Ok(response),
62 Err(Error::Server { body, .. }) if body.is_bad_protocol() => {
63 Err(Error::InvalidProtocol(transport.address()))
64 }
65 Err(error) => Err(error),
66 }?;
67
68 if start_response.opcode == Opcode::Ready {
69 return set_keyspace(transport, keyspace_holder, version).await;
70 }
71
72 if start_response.opcode == Opcode::Authenticate {
73 let body = start_response.response_body()?;
74 let authenticator = body.authenticator()
75 .ok_or_else(|| Error::General("Cassandra server did communicate that it needed authentication but the auth schema was missing in the body response".into()))?;
76
77 authenticator_provider
85 .name()
86 .ok_or_else(|| Error::General("No authenticator was provided".to_string()))
87 .and_then(|auth| {
88 if authenticator != auth {
89 let io_err = io::Error::new(
90 io::ErrorKind::NotFound,
91 format!(
92 "Unsupported type of authenticator. {authenticator:?} got,
93 but {auth} is supported."
94 ),
95 );
96 return Err(Error::Io(io_err));
97 }
98 Ok(())
99 })?;
100
101 let authenticator = authenticator_provider.create_authenticator();
102 let response = authenticator.initial_response();
103 let mut envelope = transport
104 .write_envelope(&Envelope::new_req_auth_response(response, version), false)
105 .await?;
106
107 loop {
108 match envelope.response_body()? {
109 ResponseBody::AuthChallenge(challenge) => {
110 let response = authenticator.evaluate_challenge(challenge.data)?;
111
112 envelope = transport
113 .write_envelope(&Envelope::new_req_auth_response(response, version), false)
114 .await?;
115 }
116 ResponseBody::AuthSuccess(success) => {
117 authenticator.handle_success(success.data)?;
118 break;
119 }
120 _ => return Err(Error::UnexpectedAuthResponse(envelope.opcode)),
121 }
122 }
123
124 return set_keyspace(transport, keyspace_holder, version).await;
125 }
126
127 Err(Error::UnexpectedStartupResponse(start_response.opcode))
128}
129
130async fn set_keyspace<T: CdrsTransport>(
131 transport: &T,
132 keyspace_holder: &KeyspaceHolder,
133 version: Version,
134) -> Result<()> {
135 if let Some(current_keyspace) = keyspace_holder.current_keyspace() {
136 let use_envelope = Envelope::new_req_query(
137 format!("USE {}", quote(current_keyspace.as_ref())),
138 Default::default(),
139 None,
140 false,
141 None,
142 None,
143 None,
144 None,
145 None,
146 None,
147 Default::default(),
148 version,
149 );
150
151 transport
152 .write_envelope(&use_envelope, false)
153 .await
154 .map(|_| ())
155 } else {
156 Ok(())
157 }
158}