cdrs_tokio/cluster/connection_manager.rs

1use std::io;
2use std::net::SocketAddr;
3use tokio::sync::mpsc::Sender;
4
5#[cfg(test)]
6use mockall::*;
7
8use crate::cluster::KeyspaceHolder;
9use crate::future::BoxFuture;
10use crate::transport::CdrsTransport;
11use cassandra_protocol::authenticators::SaslAuthenticatorProvider;
12use cassandra_protocol::compression::Compression;
13use cassandra_protocol::error::{Error, Result};
14use cassandra_protocol::frame::message_response::ResponseBody;
15use cassandra_protocol::frame::{Envelope, Opcode, Version};
16use cassandra_protocol::query::utils::quote;
17
18/// Manages establishing connections to nodes.
19pub trait ConnectionManager<T: CdrsTransport>: Send + Sync {
20    /// Tries to establish a new, ready to use connection with optional server event and error
21    /// handlers.
22    fn connection(
23        &self,
24        event_handler: Option<Sender<Envelope>>,
25        error_handler: Option<Sender<Error>>,
26        addr: SocketAddr,
27    ) -> BoxFuture<'_, Result<T>>;
28}
29
// Test-only mock of the `ConnectionManager` trait. mockall's `mock!` macro generates a
// `MockConnectionManager<T>` type that implements the trait below, so tests can stub
// `connection` without opening real transports.
#[cfg(test)]
mock! {
    // The mock itself carries no extra state or inherent methods.
    pub ConnectionManager<T: CdrsTransport> {
    }

    // NOTE(review): presumably silences dead-code warnings for generated mock items that some
    // test builds never use — confirm before removing.
    #[allow(dead_code)]
    impl<T: CdrsTransport> ConnectionManager<T> for ConnectionManager<T> {
        // Same signature as `ConnectionManager::connection`, with the trait's elided `'_`
        // lifetime spelled out as `'a`.
        fn connection<'a>(
            &'a self,
            event_handler: Option<Sender<Envelope>>,
            error_handler: Option<Sender<Error>>,
            addr: SocketAddr,
        ) -> BoxFuture<'a, Result<T>>;
    }
}
45
46/// Establishes Cassandra connection with given authentication, last used keyspace and compression.
47pub async fn startup<
48    T: CdrsTransport + 'static,
49    A: SaslAuthenticatorProvider + Send + Sync + ?Sized + 'static,
50>(
51    transport: &T,
52    authenticator_provider: &A,
53    keyspace_holder: &KeyspaceHolder,
54    compression: Compression,
55    version: Version,
56) -> Result<()> {
57    let startup_envelope =
58        Envelope::new_req_startup(compression.as_str().map(String::from), version);
59
60    let start_response = match transport.write_envelope(&startup_envelope, true).await {
61        Ok(response) => Ok(response),
62        Err(Error::Server { body, .. }) if body.is_bad_protocol() => {
63            Err(Error::InvalidProtocol(transport.address()))
64        }
65        Err(error) => Err(error),
66    }?;
67
68    if start_response.opcode == Opcode::Ready {
69        return set_keyspace(transport, keyspace_holder, version).await;
70    }
71
72    if start_response.opcode == Opcode::Authenticate {
73        let body = start_response.response_body()?;
74        let authenticator = body.authenticator()
75            .ok_or_else(|| Error::General("Cassandra server did communicate that it needed authentication but the auth schema was missing in the body response".into()))?;
76
77        // This creates a new scope; avoiding a clone
78        // and we check whether
79        // 1. any authenticators has been passed in by client and if not send error back
80        // 2. authenticator is provided by the client and `auth_scheme` presented by
81        //      the server and client are same if not send error back
82        // 3. if it falls through it means the preliminary conditions are true
83
84        authenticator_provider
85            .name()
86            .ok_or_else(|| Error::General("No authenticator was provided".to_string()))
87            .and_then(|auth| {
88                if authenticator != auth {
89                    let io_err = io::Error::new(
90                        io::ErrorKind::NotFound,
91                        format!(
92                            "Unsupported type of authenticator. {authenticator:?} got,
93                             but {auth} is supported."
94                        ),
95                    );
96                    return Err(Error::Io(io_err));
97                }
98                Ok(())
99            })?;
100
101        let authenticator = authenticator_provider.create_authenticator();
102        let response = authenticator.initial_response();
103        let mut envelope = transport
104            .write_envelope(&Envelope::new_req_auth_response(response, version), false)
105            .await?;
106
107        loop {
108            match envelope.response_body()? {
109                ResponseBody::AuthChallenge(challenge) => {
110                    let response = authenticator.evaluate_challenge(challenge.data)?;
111
112                    envelope = transport
113                        .write_envelope(&Envelope::new_req_auth_response(response, version), false)
114                        .await?;
115                }
116                ResponseBody::AuthSuccess(success) => {
117                    authenticator.handle_success(success.data)?;
118                    break;
119                }
120                _ => return Err(Error::UnexpectedAuthResponse(envelope.opcode)),
121            }
122        }
123
124        return set_keyspace(transport, keyspace_holder, version).await;
125    }
126
127    Err(Error::UnexpectedStartupResponse(start_response.opcode))
128}
129
130async fn set_keyspace<T: CdrsTransport>(
131    transport: &T,
132    keyspace_holder: &KeyspaceHolder,
133    version: Version,
134) -> Result<()> {
135    if let Some(current_keyspace) = keyspace_holder.current_keyspace() {
136        let use_envelope = Envelope::new_req_query(
137            format!("USE {}", quote(current_keyspace.as_ref())),
138            Default::default(),
139            None,
140            false,
141            None,
142            None,
143            None,
144            None,
145            None,
146            None,
147            Default::default(),
148            version,
149        );
150
151        transport
152            .write_envelope(&use_envelope, false)
153            .await
154            .map(|_| ())
155    } else {
156        Ok(())
157    }
158}