cdrs_async/
session.rs

1use std::{
2  collections::HashMap,
3  io,
4  pin::Pin,
5  task::{Context, Poll},
6};
7
8use async_std::{future::poll_fn, prelude::*};
9use async_tls::TlsConnector;
10use cassandra_proto::{
11  error,
12  frame::{parser_async::convert_frame_into_result, Frame, IntoBytes, Opcode},
13  query::{Query, QueryBatch, QueryParams},
14};
15use futures::stream::Stream;
16
17use crate::{
18  async_trait::async_trait,
19  authenticators::Authenticator,
20  compressor::Compression,
21  frame_channel::FrameChannel,
22  query::{BatchExecutor, ExecExecutor, PrepareExecutor, PreparedQuery, QueryExecutor},
23  transport::CDRSTransport,
24  utils::prepare_flags,
25  TransportTcp, TransportTls,
26};
27
/// Identifier of a frame stream; used in this module as the key under which
/// not-yet-requested response frames are buffered.
type StreamId = u16;
29
30/// Session structure which allows clients making requests to a server.
31pub struct Session<T> {
32  channel: FrameChannel<T>,
33  responses: HashMap<StreamId, Frame>,
34  authenticator: Authenticator,
35}
36
/// Builds a future that resolves to the response frame for `$stream_id`.
///
/// `$this` must expose a `responses` map and a `channel` stream (as `Session`
/// does). The future:
///
/// 1. returns the frame immediately if it was already buffered in `responses`;
/// 2. otherwise polls `channel`, buffering every frame that belongs to another
///    stream, until the wanted frame arrives;
/// 3. resolves to an error if the channel stream terminates.
///
/// The frame is passed through `convert_frame_into_result` before it is returned.
macro_rules! receive_frame {
  ($this: expr, $stream_id: expr) => {
    poll_fn(|cx: &mut Context| {
      if let Some(response) = $this.responses.remove(&$stream_id) {
        return Poll::Ready(convert_frame_into_result(response));
      }

      // Keep draining the channel until it reports `Pending` itself. Returning
      // `Pending` right after the channel yielded `Ready` (a frame for another
      // stream) would leave the task without any guaranteed wake-up, so the
      // future could stall forever.
      loop {
        match Pin::new(&mut $this.channel).poll_next(cx) {
          Poll::Ready(Some(frame)) => {
            if frame.stream == $stream_id {
              return Poll::Ready(convert_frame_into_result(frame));
            }
            // Not ours: keep it for whoever awaits that stream id.
            $this.responses.insert(frame.stream, frame);
          }
          Poll::Ready(None) => return Poll::Ready(Err("stream was terminated".into())),
          Poll::Pending => return Poll::Pending,
        }
      }
    })
  };
}
59
60impl Session<TransportTcp> {
61  pub async fn connect<Addr: ToString>(
62    addr: Addr,
63    compressor: Compression,
64    authenticator: Authenticator,
65  ) -> error::Result<Self> {
66    let transport = TransportTcp::new(&addr.to_string()).await?;
67    let channel = FrameChannel::new(transport, compressor);
68    let responses = HashMap::new();
69
70    let mut session = Session {
71      channel,
72      responses,
73      authenticator,
74    };
75
76    session.startup().await?;
77
78    Ok(session)
79  }
80}
81
82impl Session<TransportTls> {
83  pub async fn connect_tls<Addr: ToString>(
84    (addr, connector): (Addr, TlsConnector),
85    compressor: Compression,
86    authenticator: Authenticator,
87  ) -> error::Result<Self> {
88    let transport = TransportTls::new(&addr.to_string(), connector).await?;
89    let channel = FrameChannel::new(transport, compressor);
90    let responses = HashMap::new();
91
92    let mut session = Session {
93      channel,
94      responses,
95      authenticator,
96    };
97
98    session.startup().await?;
99
100    Ok(session)
101  }
102}
103
104impl<T: CDRSTransport> Session<T> {
105  async fn startup(&mut self) -> error::Result<()> {
106    let ref mut compression = Compression::None;
107    let startup_frame = Frame::new_req_startup(compression.as_str());
108    let stream = startup_frame.stream;
109
110    self.channel.write(&startup_frame.into_cbytes()).await?;
111    let start_response = receive_frame!(self, stream).await?;
112
113    if start_response.opcode == Opcode::Ready {
114      return Ok(());
115    }
116
117    if start_response.opcode == Opcode::Authenticate {
118      let body = start_response.get_body()?;
119      let authenticator = body.get_authenticator().expect(
120        "Cassandra Server did communicate that it neededs
121                authentication but the auth schema was missing in the body response",
122      );
123
124      // This creates a new scope; avoiding a clone
125      // and we check whether
126      // 1. any authenticators has been passed in by client and if not send error back
127      // 2. authenticator is provided by the client and `auth_scheme` presented by
128      //      the server and client are same if not send error back
129      // 3. if it falls through it means the preliminary conditions are true
130
131      let auth_check = self
132        .authenticator
133        .get_cassandra_name()
134        .ok_or(error::Error::General(
135          "No authenticator was provided".to_string(),
136        ))
137        .map(|auth| {
138          if authenticator != auth {
139            let io_err = io::Error::new(
140              io::ErrorKind::NotFound,
141              format!(
142                "Unsupported type of authenticator. {:?} got,
143                             but {} is supported.",
144                authenticator, auth
145              ),
146            );
147            return Err(error::Error::Io(io_err));
148          }
149          Ok(())
150        });
151
152      if let Err(err) = auth_check {
153        return Err(err);
154      }
155
156      let auth_token_bytes =
157        self
158          .authenticator
159          .get_auth_token()
160          .into_plain()
161          .ok_or(error::Error::from(
162            "Authentication error: cannot get auth token",
163          ))?;
164      let auth_response = Frame::new_req_auth_response(auth_token_bytes);
165      let response_stream = auth_response.stream;
166
167      self.channel.write(&auth_response.into_cbytes()).await?;
168      receive_frame!(self, response_stream).await?;
169
170      return Ok(());
171    }
172
173    unreachable!();
174  }
175}
176
177#[async_trait]
178impl<T: CDRSTransport> QueryExecutor for Session<T> {
179  async fn query_with_params_tw<Q: ToString + Send>(
180    mut self: Pin<&mut Self>,
181    query: Q,
182    query_params: QueryParams,
183    with_tracing: bool,
184    with_warnings: bool,
185  ) -> error::Result<Frame> {
186    let query = Query {
187      query: query.to_string(),
188      params: query_params,
189    };
190
191    let flags = prepare_flags(with_tracing, with_warnings);
192    let query_frame = Frame::new_query(query, flags);
193    let stream = query_frame.stream;
194
195    // send frame
196    self.channel.write(&query_frame.into_cbytes()).await?;
197    receive_frame!(self, stream).await
198  }
199}
200
201#[async_trait]
202impl<T: CDRSTransport> PrepareExecutor for Session<T> {
203  async fn prepare_tw<Q: ToString + Send>(
204    mut self: Pin<&mut Self>,
205    query: Q,
206    with_tracing: bool,
207    with_warnings: bool,
208  ) -> error::Result<PreparedQuery> {
209    let flags = prepare_flags(with_tracing, with_warnings);
210
211    let query_frame = Frame::new_req_prepare(query.to_string(), flags);
212    let stream = query_frame.stream;
213
214    self.channel.write(&query_frame.into_cbytes()).await?;
215
216    let prepared_id = receive_frame!(self, stream)
217      .await?
218      .get_body()?
219      .into_prepared()
220      .ok_or(error::Error::from(
221        "Cannot get prepared query ID from a response",
222      ))?
223      .id;
224
225    Ok(prepared_id)
226  }
227}
228
229#[async_trait]
230impl<T: CDRSTransport> ExecExecutor for Session<T> {
231  async fn exec_with_params_tw(
232    mut self: Pin<&mut Self>,
233    prepared: &PreparedQuery,
234    query_parameters: QueryParams,
235    with_tracing: bool,
236    with_warnings: bool,
237  ) -> error::Result<Frame> {
238    let flags = prepare_flags(with_tracing, with_warnings);
239    let executor_frame = Frame::new_req_execute(prepared, query_parameters, flags);
240    let stream = executor_frame.stream;
241
242    self.channel.write(&executor_frame.into_cbytes()).await?;
243    receive_frame!(self, stream).await
244  }
245}
246
247#[async_trait]
248impl<T: CDRSTransport> BatchExecutor for Session<T> {
249  async fn batch_with_params_tw(
250    mut self: Pin<&mut Self>,
251    batch: QueryBatch,
252    with_tracing: bool,
253    with_warnings: bool,
254  ) -> error::Result<Frame> {
255    let flags = prepare_flags(with_tracing, with_warnings);
256    let batch_frame = Frame::new_req_batch(batch, flags);
257    let stream = batch_frame.stream;
258
259    self.channel.write(&batch_frame.into_cbytes()).await?;
260    receive_frame!(self, stream).await
261  }
262}