1use std::{
2 collections::HashMap,
3 io,
4 pin::Pin,
5 task::{Context, Poll},
6};
7
8use async_std::{future::poll_fn, prelude::*};
9use async_tls::TlsConnector;
10use cassandra_proto::{
11 error,
12 frame::{parser_async::convert_frame_into_result, Frame, IntoBytes, Opcode},
13 query::{Query, QueryBatch, QueryParams},
14};
15use futures::stream::Stream;
16
17use crate::{
18 async_trait::async_trait,
19 authenticators::Authenticator,
20 compressor::Compression,
21 frame_channel::FrameChannel,
22 query::{BatchExecutor, ExecExecutor, PrepareExecutor, PreparedQuery, QueryExecutor},
23 transport::CDRSTransport,
24 utils::prepare_flags,
25 TransportTcp, TransportTls,
26};
27
28type StreamId = u16;
29
30pub struct Session<T> {
32 channel: FrameChannel<T>,
33 responses: HashMap<StreamId, Frame>,
34 authenticator: Authenticator,
35}
36
/// Builds a future that resolves to the response for stream `$stream_id`.
///
/// `$this` must expose a `responses` map (stream id -> frame) and a
/// `channel` that can be polled with `poll_next`. The future first looks for
/// an already-buffered answer in `responses`. Otherwise it keeps pulling
/// frames from `channel`, parking frames that belong to other streams in
/// `responses`, until the wanted frame arrives. The matching frame is passed
/// through `convert_frame_into_result`. If the channel ends first, the
/// future resolves to an error ("stream was terminated").
macro_rules! receive_frame {
    ($this: expr, $stream_id: expr) => {
        poll_fn(|cx: &mut Context| {
            if let Some(response) = $this.responses.remove(&$stream_id) {
                return Poll::Ready(convert_frame_into_result(response));
            }

            // Keep draining the channel while it has frames ready. Returning
            // `Pending` right after the channel itself returned `Ready` would
            // leave no waker registered, so this future might never be polled
            // again (lost wakeup).
            loop {
                match Pin::new(&mut $this.channel).poll_next(cx) {
                    Poll::Ready(Some(frame)) => {
                        if frame.stream == $stream_id {
                            return Poll::Ready(convert_frame_into_result(frame));
                        }
                        // Response to another in-flight request: park it so
                        // that request can pick it up later.
                        $this.responses.insert(frame.stream, frame);
                    }
                    Poll::Ready(None) => {
                        return Poll::Ready(Err("stream was terminated".into()));
                    }
                    // Per the `poll_next` contract, the channel has arranged
                    // for `cx`'s waker to be notified.
                    Poll::Pending => return Poll::Pending,
                }
            }
        })
    };
}
59
60impl Session<TransportTcp> {
61 pub async fn connect<Addr: ToString>(
62 addr: Addr,
63 compressor: Compression,
64 authenticator: Authenticator,
65 ) -> error::Result<Self> {
66 let transport = TransportTcp::new(&addr.to_string()).await?;
67 let channel = FrameChannel::new(transport, compressor);
68 let responses = HashMap::new();
69
70 let mut session = Session {
71 channel,
72 responses,
73 authenticator,
74 };
75
76 session.startup().await?;
77
78 Ok(session)
79 }
80}
81
82impl Session<TransportTls> {
83 pub async fn connect_tls<Addr: ToString>(
84 (addr, connector): (Addr, TlsConnector),
85 compressor: Compression,
86 authenticator: Authenticator,
87 ) -> error::Result<Self> {
88 let transport = TransportTls::new(&addr.to_string(), connector).await?;
89 let channel = FrameChannel::new(transport, compressor);
90 let responses = HashMap::new();
91
92 let mut session = Session {
93 channel,
94 responses,
95 authenticator,
96 };
97
98 session.startup().await?;
99
100 Ok(session)
101 }
102}
103
104impl<T: CDRSTransport> Session<T> {
105 async fn startup(&mut self) -> error::Result<()> {
106 let ref mut compression = Compression::None;
107 let startup_frame = Frame::new_req_startup(compression.as_str());
108 let stream = startup_frame.stream;
109
110 self.channel.write(&startup_frame.into_cbytes()).await?;
111 let start_response = receive_frame!(self, stream).await?;
112
113 if start_response.opcode == Opcode::Ready {
114 return Ok(());
115 }
116
117 if start_response.opcode == Opcode::Authenticate {
118 let body = start_response.get_body()?;
119 let authenticator = body.get_authenticator().expect(
120 "Cassandra Server did communicate that it neededs
121 authentication but the auth schema was missing in the body response",
122 );
123
124 let auth_check = self
132 .authenticator
133 .get_cassandra_name()
134 .ok_or(error::Error::General(
135 "No authenticator was provided".to_string(),
136 ))
137 .map(|auth| {
138 if authenticator != auth {
139 let io_err = io::Error::new(
140 io::ErrorKind::NotFound,
141 format!(
142 "Unsupported type of authenticator. {:?} got,
143 but {} is supported.",
144 authenticator, auth
145 ),
146 );
147 return Err(error::Error::Io(io_err));
148 }
149 Ok(())
150 });
151
152 if let Err(err) = auth_check {
153 return Err(err);
154 }
155
156 let auth_token_bytes =
157 self
158 .authenticator
159 .get_auth_token()
160 .into_plain()
161 .ok_or(error::Error::from(
162 "Authentication error: cannot get auth token",
163 ))?;
164 let auth_response = Frame::new_req_auth_response(auth_token_bytes);
165 let response_stream = auth_response.stream;
166
167 self.channel.write(&auth_response.into_cbytes()).await?;
168 receive_frame!(self, response_stream).await?;
169
170 return Ok(());
171 }
172
173 unreachable!();
174 }
175}
176
177#[async_trait]
178impl<T: CDRSTransport> QueryExecutor for Session<T> {
179 async fn query_with_params_tw<Q: ToString + Send>(
180 mut self: Pin<&mut Self>,
181 query: Q,
182 query_params: QueryParams,
183 with_tracing: bool,
184 with_warnings: bool,
185 ) -> error::Result<Frame> {
186 let query = Query {
187 query: query.to_string(),
188 params: query_params,
189 };
190
191 let flags = prepare_flags(with_tracing, with_warnings);
192 let query_frame = Frame::new_query(query, flags);
193 let stream = query_frame.stream;
194
195 self.channel.write(&query_frame.into_cbytes()).await?;
197 receive_frame!(self, stream).await
198 }
199}
200
201#[async_trait]
202impl<T: CDRSTransport> PrepareExecutor for Session<T> {
203 async fn prepare_tw<Q: ToString + Send>(
204 mut self: Pin<&mut Self>,
205 query: Q,
206 with_tracing: bool,
207 with_warnings: bool,
208 ) -> error::Result<PreparedQuery> {
209 let flags = prepare_flags(with_tracing, with_warnings);
210
211 let query_frame = Frame::new_req_prepare(query.to_string(), flags);
212 let stream = query_frame.stream;
213
214 self.channel.write(&query_frame.into_cbytes()).await?;
215
216 let prepared_id = receive_frame!(self, stream)
217 .await?
218 .get_body()?
219 .into_prepared()
220 .ok_or(error::Error::from(
221 "Cannot get prepared query ID from a response",
222 ))?
223 .id;
224
225 Ok(prepared_id)
226 }
227}
228
229#[async_trait]
230impl<T: CDRSTransport> ExecExecutor for Session<T> {
231 async fn exec_with_params_tw(
232 mut self: Pin<&mut Self>,
233 prepared: &PreparedQuery,
234 query_parameters: QueryParams,
235 with_tracing: bool,
236 with_warnings: bool,
237 ) -> error::Result<Frame> {
238 let flags = prepare_flags(with_tracing, with_warnings);
239 let executor_frame = Frame::new_req_execute(prepared, query_parameters, flags);
240 let stream = executor_frame.stream;
241
242 self.channel.write(&executor_frame.into_cbytes()).await?;
243 receive_frame!(self, stream).await
244 }
245}
246
247#[async_trait]
248impl<T: CDRSTransport> BatchExecutor for Session<T> {
249 async fn batch_with_params_tw(
250 mut self: Pin<&mut Self>,
251 batch: QueryBatch,
252 with_tracing: bool,
253 with_warnings: bool,
254 ) -> error::Result<Frame> {
255 let flags = prepare_flags(with_tracing, with_warnings);
256 let batch_frame = Frame::new_req_batch(batch, flags);
257 let stream = batch_frame.stream;
258
259 self.channel.write(&batch_frame.into_cbytes()).await?;
260 receive_frame!(self, stream).await
261 }
262}