use r2d2::{Builder, ManageConnection};
use std::cell::RefCell;
use std::io;
use std::io::Write;
use std::net::ToSocketAddrs;
use crate::authenticators::Authenticator;
use crate::cluster::ConnectionPool;
use crate::cluster::NodeTcpConfig;
use crate::compression::Compression;
use crate::error;
use crate::frame::parser::parse_frame;
use crate::frame::{Frame, IntoBytes, Opcode};
use crate::transport::{CDRSTransport, TransportTcp};
/// A [`ConnectionPool`] whose connections are created and checked by a
/// [`TcpConnectionsManager`] using authenticator type `A`.
pub type TcpConnectionPool<A> = ConnectionPool<TcpConnectionsManager<A>>;
pub fn new_tcp_pool<'a, A: Authenticator + Send + Sync + 'static>(
node_config: NodeTcpConfig<'a, A>,
) -> error::Result<TcpConnectionPool<A>> {
let manager =
TcpConnectionsManager::new(node_config.addr.to_string(), node_config.authenticator);
let pool = Builder::new()
.max_size(node_config.max_size)
.min_idle(node_config.min_idle)
.max_lifetime(node_config.max_lifetime)
.idle_timeout(node_config.idle_timeout)
.connection_timeout(node_config.connection_timeout)
.build(manager)
.map_err(|err| error::Error::from(err.to_string()))?;
let addr = node_config
.addr
.to_socket_addrs()?
.next()
.ok_or_else(|| error::Error::from("Cannot parse address"))?;
Ok(TcpConnectionPool::new(pool, addr))
}
/// Connection manager that holds what is needed to open and authenticate a
/// TCP connection to a single node.
pub struct TcpConnectionsManager<A> {
    /// Node address in textual form, handed to the transport when connecting.
    addr: String,
    /// Authenticator used during the startup handshake of each connection.
    auth: A,
}

impl<A> TcpConnectionsManager<A> {
    /// Creates a manager for the node at `addr` using `auth` as authenticator.
    ///
    /// `addr` may be any value implementing [`ToString`]; its string form is
    /// what gets stored.
    pub fn new<S: ToString>(addr: S, auth: A) -> Self {
        let addr = addr.to_string();
        Self { addr, auth }
    }
}
/// r2d2 integration: lets a pool open, validate and health-check TCP
/// transports through a [`TcpConnectionsManager`].
impl<A: Authenticator + 'static + Send + Sync> ManageConnection for TcpConnectionsManager<A> {
    type Connection = RefCell<TransportTcp>;
    type Error = error::Error;

    /// Opens a new TCP transport to the manager's address and runs the
    /// startup handshake (see [`startup`]) with the manager's authenticator.
    ///
    /// # Errors
    ///
    /// Returns an error if the transport cannot be created or the handshake
    /// fails.
    fn connect(&self) -> Result<Self::Connection, Self::Error> {
        let transport = RefCell::new(TransportTcp::new(&self.addr)?);
        startup(&transport, &self.auth)?;
        Ok(transport)
    }

    /// Checks the connection by sending an options request frame and parsing
    /// one response frame; the response content itself is discarded.
    ///
    /// # Errors
    ///
    /// Returns an error if the request cannot be written in full or the
    /// response cannot be read and parsed.
    fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> {
        let options_frame = Frame::new_req_options().into_cbytes();
        // `write_all` rather than `write`: a short write would leave a
        // truncated frame on the wire and the subsequent read would wait for
        // a reply that never arrives.
        conn.borrow_mut().write_all(options_frame.as_slice())?;
        parse_frame(conn, &Compression::None {}).map(|_| ())
    }

    /// Reports the connection as broken when the transport says it is no
    /// longer alive.
    fn has_broken(&self, conn: &mut Self::Connection) -> bool {
        !conn.borrow().is_alive()
    }
}
pub fn startup<'b, T: CDRSTransport + 'static, A: Authenticator + 'static + Sized>(
transport: &RefCell<T>,
session_authenticator: &'b A,
) -> error::Result<()> {
let ref mut compression = Compression::None;
let startup_frame = Frame::new_req_startup(compression.as_str()).into_cbytes();
transport.borrow_mut().write(startup_frame.as_slice())?;
let start_response = parse_frame(transport, compression)?;
if start_response.opcode == Opcode::Ready {
return Ok(());
}
if start_response.opcode == Opcode::Authenticate {
let body = start_response.get_body()?;
let authenticator = body.get_authenticator().expect(
"Cassandra Server did communicate that it neededs
authentication but the auth schema was missing in the body response",
);
let auth_check = session_authenticator
.get_cassandra_name()
.ok_or(error::Error::General(
"No authenticator was provided".to_string(),
))
.map(|auth| {
if authenticator != auth {
let io_err = io::Error::new(
io::ErrorKind::NotFound,
format!(
"Unsupported type of authenticator. {:?} got,
but {} is supported.",
authenticator, auth
),
);
return Err(error::Error::Io(io_err));
}
Ok(())
});
if let Err(err) = auth_check {
return Err(err);
}
let auth_token_bytes = session_authenticator.get_auth_token();
transport.borrow_mut().write(
Frame::new_req_auth_response(auth_token_bytes)
.into_cbytes()
.as_slice(),
)?;
parse_frame(transport, compression)?;
return Ok(());
}
unreachable!();
}