use std::collections::HashMap;
use std::fmt::{Debug, Formatter};
use std::io::{Read, Write};
use std::net::{Ipv4Addr, Shutdown, SocketAddr, TcpStream, ToSocketAddrs};
use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
use std::sync::mpsc::{Receiver, Sender};
use std::sync::{Arc, Mutex, mpsc};
use std::thread;
use std::time::Duration;
use bytebuffer::ByteBuffer;
use byteorder::{BigEndian, ReadBytesExt};
use crate::encode::{Value, VoltError};
use crate::procedure_invocation::new_procedure_invocation;
use crate::protocol::{PING_HANDLE, build_auth_message, parse_auth_response};
use crate::response::VoltResponseInfo;
use crate::table::{VoltTable, new_volt_table};
use crate::volt_param;
// Error-level logging for node internals. With the `tracing` feature the
// arguments are forwarded verbatim to `tracing::error!`.
#[cfg(feature = "tracing")]
macro_rules! node_error {
    ($($tokens:tt)*) => {
        tracing::error!($($tokens)*)
    };
}
// Without `tracing` the invocation expands to nothing, so its arguments are
// neither evaluated nor type-checked.
#[cfg(not(feature = "tracing"))]
macro_rules! node_error {
    ($($tokens:tt)*) => {};
}
/// Connection options for one or more hosts.
///
/// Create with [`Opts::new`] (hosts only) or the fluent [`Opts::builder`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Opts(pub(crate) Box<InnerOpts>);

/// A host name (or IP literal) paired with a TCP port.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IpPort {
    pub(crate) ip_host: String,
    pub(crate) port: u16,
}

impl IpPort {
    /// Pairs `ip_host` with `port`; no resolution or validation happens here.
    pub fn new(ip_host: String, port: u16) -> Self {
        Self { ip_host, port }
    }
}

impl Opts {
    /// Options for `hosts` with no credentials and no timeouts.
    pub fn new(hosts: Vec<IpPort>) -> Opts {
        let inner = InnerOpts {
            ip_ports: hosts,
            user: None,
            pass: None,
            connect_timeout: None,
            read_timeout: None,
        };
        Opts(Box::new(inner))
    }

    /// Starts an empty [`OptsBuilder`].
    pub fn builder() -> OptsBuilder {
        OptsBuilder::default()
    }
}
#[derive(Debug, Clone, Default)]
pub struct OptsBuilder {
hosts: Vec<IpPort>,
user: Option<String>,
pass: Option<String>,
connect_timeout: Option<Duration>,
read_timeout: Option<Duration>,
}
impl OptsBuilder {
pub fn host(mut self, ip: &str, port: u16) -> Self {
self.hosts.push(IpPort::new(ip.to_string(), port));
self
}
pub fn hosts(mut self, hosts: Vec<IpPort>) -> Self {
self.hosts.extend(hosts);
self
}
pub fn user(mut self, user: &str) -> Self {
self.user = Some(user.to_string());
self
}
pub fn password(mut self, pass: &str) -> Self {
self.pass = Some(pass.to_string());
self
}
pub fn connect_timeout(mut self, timeout: Duration) -> Self {
self.connect_timeout = Some(timeout);
self
}
pub fn read_timeout(mut self, timeout: Duration) -> Self {
self.read_timeout = Some(timeout);
self
}
pub fn build(self) -> Result<Opts, VoltError> {
if self.hosts.is_empty() {
return Err(VoltError::InvalidConfig);
}
Ok(Opts(Box::new(InnerOpts {
ip_ports: self.hosts,
user: self.user,
pass: self.pass,
connect_timeout: self.connect_timeout,
read_timeout: self.read_timeout,
})))
}
}
#[derive(Debug, Clone, Eq, PartialEq)]
pub(crate) struct InnerOpts {
pub(crate) ip_ports: Vec<IpPort>,
pub(crate) user: Option<String>,
pub(crate) pass: Option<String>,
pub(crate) connect_timeout: Option<Duration>,
pub(crate) read_timeout: Option<Duration>,
}
pub struct NodeOpt {
pub ip_port: IpPort,
pub user: Option<String>,
pub pass: Option<String>,
pub connect_timeout: Option<Duration>,
pub read_timeout: Option<Duration>,
}
/// In-flight requests keyed by client handle; each sender receives the decoded
/// response table for its handle.
type PendingRequests = HashMap<i64, Sender<VoltTable>>;
/// Marker trait for connection types that can be moved and shared across threads.
pub trait Connection: Send + Sync + 'static {}
/// A single authenticated TCP connection plus its response-listener thread.
// `info` is stored but not read in the visible code, hence the allow.
#[allow(dead_code)]
pub struct Node {
    /// Socket used for sending; `None` once `shutdown` has taken it.
    write_stream: Mutex<Option<TcpStream>>,
    /// Details obtained from the authentication handshake.
    info: ConnInfo,
    /// Shared with the listener thread, which removes entries as responses arrive.
    requests: Arc<Mutex<PendingRequests>>,
    /// Stop flag observed by the listener thread.
    stop: Arc<Mutex<bool>>,
    /// Source of client handles.
    counter: AtomicI64,
    /// Write-serialisation flag (writes themselves go through `write_stream`'s mutex).
    write_lock: AtomicBool,
    /// Listener thread handle; taken when the node is shut down.
    listener_handle: Option<thread::JoinHandle<()>>,
}
impl Debug for Node {
    /// Reports how many requests are still awaiting a response.
    ///
    /// Uses `try_lock` so formatting never blocks, and cannot deadlock when the
    /// caller already holds the requests lock. A poisoned lock is still read.
    /// A contended lock is shown as `<unavailable>`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let pending = match self.requests.try_lock() {
            Ok(guard) => Some(guard.len()),
            Err(std::sync::TryLockError::Poisoned(poisoned)) => Some(poisoned.into_inner().len()),
            Err(std::sync::TryLockError::WouldBlock) => None,
        };
        match pending {
            Some(count) => write!(f, "Pending request: {}", count),
            None => f.write_str("Pending request: <unavailable>"),
        }
    }
}
impl Drop for Node {
fn drop(&mut self) {
let res = self.shutdown();
match res {
Ok(_) => {}
Err(_e) => {
node_error!(error = ?_e, "error during node shutdown");
}
}
}
}
impl Connection for Node {}
impl Node {
pub fn new(opt: NodeOpt) -> Result<Node, VoltError> {
let ip_host = opt.ip_port;
let addr_str = format!("{}:{}", ip_host.ip_host, ip_host.port);
let auth_msg = build_auth_message(opt.user.as_deref(), opt.pass.as_deref())?;
let mut stream: TcpStream = match opt.connect_timeout {
Some(timeout) => {
let socket_addr: SocketAddr = addr_str
.to_socket_addrs()
.map_err(|_| VoltError::InvalidConfig)?
.find(|s| s.is_ipv4())
.ok_or(VoltError::InvalidConfig)?;
TcpStream::connect_timeout(&socket_addr, timeout)?
}
None => TcpStream::connect(&addr_str)?,
};
if let Some(read_timeout) = opt.read_timeout {
stream.set_read_timeout(Some(read_timeout))?;
}
stream.write_all(&auth_msg)?;
stream.flush()?;
let read = stream.read_u32::<BigEndian>()?;
let mut all = vec![0; read as usize];
stream.read_exact(&mut all)?;
let info = parse_auth_response(&all)?;
let read_stream = stream.try_clone()?;
let listener_timeout = opt.read_timeout.unwrap_or(Duration::from_secs(2));
read_stream.set_read_timeout(Some(listener_timeout))?;
let requests = Arc::new(Mutex::new(HashMap::new()));
let stop = Arc::new(Mutex::new(false));
let handle = Self::start_listener(read_stream, Arc::clone(&requests), Arc::clone(&stop));
Ok(Node {
stop,
write_stream: Mutex::new(Some(stream)),
info,
requests,
counter: AtomicI64::new(1),
write_lock: AtomicBool::new(false),
listener_handle: Some(handle),
})
}
pub fn get_sequence(&self) -> i64 {
self.counter.fetch_add(1, Ordering::Relaxed)
}
pub fn list_procedures(&self) -> Result<Receiver<VoltTable>, VoltError> {
self.call_sp("@SystemCatalog", volt_param!("PROCEDURES"))
}
pub fn call_sp(
&self,
query: &str,
param: Vec<&dyn Value>,
) -> Result<Receiver<VoltTable>, VoltError> {
let handle = self.get_sequence();
let mut proc = new_procedure_invocation(handle, false, ¶m, query);
let (tx, rx): (Sender<VoltTable>, Receiver<VoltTable>) = mpsc::channel();
self.requests.lock()?.insert(handle, tx);
let bs = proc.bytes();
let result = {
let mut stream_guard = self.write_stream.lock()?;
match stream_guard.as_mut() {
None => Err(VoltError::ConnectionNotAvailable),
Some(stream) => {
stream.write_all(&bs)?;
Ok(rx)
}
}
};
self.write_lock.store(false, Ordering::Release);
result
}
pub fn upload_jar(&self, bs: Vec<u8>) -> Result<Receiver<VoltTable>, VoltError> {
self.call_sp("@UpdateClasses", volt_param!(bs, ""))
}
pub fn query(&self, sql: &str) -> Result<Receiver<VoltTable>, VoltError> {
let zero_vec: Vec<&dyn Value> = vec![&sql];
self.call_sp("@AdHoc", zero_vec)
}
pub fn ping(&self) -> Result<(), VoltError> {
let zero_vec: Vec<&dyn Value> = Vec::new();
let mut proc = new_procedure_invocation(PING_HANDLE, false, &zero_vec, "@Ping");
let bs = proc.bytes();
while self
.write_lock
.compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed)
.is_err()
{
std::hint::spin_loop();
}
let result = {
let mut stream_guard = self.write_stream.lock()?;
match stream_guard.as_mut() {
None => Err(VoltError::ConnectionNotAvailable),
Some(stream) => {
stream.write_all(&bs)?;
Ok(())
}
}
};
self.write_lock.store(false, Ordering::Release);
result
}
fn job(
tcp: &mut impl Read,
requests: &Arc<Mutex<PendingRequests>>,
buffer: &mut Vec<u8>,
) -> Result<(), VoltError> {
let msg_len = tcp.read_u32::<BigEndian>()?;
if msg_len == 0 {
return Ok(());
}
buffer.resize(msg_len as usize, 0);
tcp.read_exact(buffer)?;
let mut res = ByteBuffer::from_bytes(buffer);
let _ = res.read_u8()?;
let handle = res.read_i64()?;
if handle == PING_HANDLE {
return Ok(()); }
if let Some(sender) = requests.lock()?.remove(&handle) {
let info = VoltResponseInfo::new(&mut res, handle)?;
let table = new_volt_table(&mut res, info)?;
let _ = sender.send(table);
}
Ok(())
}
pub fn shutdown(&mut self) -> Result<(), VoltError> {
{
let mut stop = self.stop.lock()?;
*stop = true;
}
let mut stream_guard = self.write_stream.lock()?;
if let Some(stream) = stream_guard.take() {
let _ = stream.shutdown(Shutdown::Both);
}
drop(stream_guard);
if let Some(handle) = self.listener_handle.take() {
let _ = handle.join();
}
Ok(())
}
fn start_listener(
mut tcp: TcpStream,
requests: Arc<Mutex<PendingRequests>>,
stopping: Arc<Mutex<bool>>,
) -> thread::JoinHandle<()> {
thread::spawn(move || {
let mut buffer = Vec::with_capacity(4096);
loop {
let should_stop = stopping
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
if *should_stop {
break;
}
drop(should_stop);
if let Err(_err) = Node::job(&mut tcp, &requests, &mut buffer) {
if let VoltError::Io(ref io_err) = _err {
if io_err.kind() == std::io::ErrorKind::WouldBlock
|| io_err.kind() == std::io::ErrorKind::TimedOut
{
continue;
}
}
let is_stopping = stopping
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
if !*is_stopping {
node_error!(error = %_err, "VoltDB listener error");
}
break; }
}
})
}
}
/// Connection details parsed from the server's authentication reply.
#[derive(Debug, Clone)]
pub struct ConnInfo {
    pub host_id: i32,
    pub connection: i64,
    pub leader_addr: Ipv4Addr,
    pub build: String,
}
impl Default for ConnInfo {
    /// Zero ids, loopback (`127.0.0.1`) leader address and an empty build string.
    fn default() -> Self {
        ConnInfo {
            leader_addr: Ipv4Addr::LOCALHOST,
            build: String::default(),
            host_id: Default::default(),
            connection: Default::default(),
        }
    }
}
/// Blocks until the response on `res` arrives.
///
/// # Errors
/// - Returns an error if the sending side is dropped before a table is delivered.
/// - Returns the error reported by the table's `has_error`, if any.
pub fn block_for_result(res: &Receiver<VoltTable>) -> Result<VoltTable, VoltError> {
    let mut table = res.recv()?;
    if let Some(server_err) = table.has_error() {
        return Err(server_err);
    }
    Ok(table)
}
/// Currently a no-op: the body is empty.
pub fn reset() {
    // Nothing to do.
}
pub fn get_node(addr: &str) -> Result<Node, VoltError> {
let addrs = addr
.to_socket_addrs()
.map_err(|_| VoltError::InvalidConfig)?;
let socket_addr = addrs
.into_iter()
.find(|s| s.is_ipv4())
.ok_or(VoltError::InvalidConfig)?;
let ip_port = IpPort::new(socket_addr.ip().to_string(), socket_addr.port());
let opt = NodeOpt {
ip_port,
user: None,
pass: None,
connect_timeout: None,
read_timeout: None,
};
Node::new(opt)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A single host with no credentials.
    #[test]
    fn test_opts_builder_basic() {
        let opts = Opts::builder().host("localhost", 21212).build().unwrap();
        let inner = &opts.0;
        assert_eq!(inner.ip_ports.len(), 1);
        let first = &inner.ip_ports[0];
        assert_eq!((first.ip_host.as_str(), first.port), ("localhost", 21212));
        assert!(inner.user.is_none() && inner.pass.is_none());
    }

    /// Credentials are carried through unchanged.
    #[test]
    fn test_opts_builder_with_auth() {
        let inner = Opts::builder()
            .host("127.0.0.1", 21211)
            .user("admin")
            .password("secret")
            .build()
            .unwrap()
            .0;
        assert_eq!(inner.user.as_deref(), Some("admin"));
        assert_eq!(inner.pass.as_deref(), Some("secret"));
    }

    /// Hosts keep their insertion order.
    #[test]
    fn test_opts_builder_multiple_hosts() {
        let names = ["host1", "host2", "host3"];
        let opts = names
            .iter()
            .fold(Opts::builder(), |builder, name| builder.host(name, 21212))
            .build()
            .unwrap();
        let actual: Vec<&str> = opts.0.ip_ports.iter().map(|p| p.ip_host.as_str()).collect();
        assert_eq!(actual, names);
    }

    /// `hosts` accepts a prepared list.
    #[test]
    fn test_opts_builder_with_hosts_vec() {
        let hosts: Vec<IpPort> = [("node1", 21212), ("node2", 21213)]
            .iter()
            .map(|&(host, port)| IpPort::new(host.to_string(), port))
            .collect();
        let opts = Opts::builder().hosts(hosts).build().unwrap();
        assert_eq!(opts.0.ip_ports.len(), 2);
    }

    /// Both timeouts are stored.
    #[test]
    fn test_opts_builder_with_timeouts() {
        let (connect, read) = (Duration::from_secs(10), Duration::from_secs(30));
        let opts = Opts::builder()
            .host("localhost", 21212)
            .connect_timeout(connect)
            .read_timeout(read)
            .build()
            .unwrap();
        assert_eq!(opts.0.connect_timeout, Some(connect));
        assert_eq!(opts.0.read_timeout, Some(read));
    }

    /// Building without any host is rejected.
    #[test]
    fn test_opts_builder_no_hosts_fails() {
        let result = Opts::builder().build();
        assert!(
            matches!(result, Err(VoltError::InvalidConfig)),
            "Expected InvalidConfig error"
        );
    }

    /// `Opts::new` leaves credentials and timeouts unset.
    #[test]
    fn test_opts_new_compatibility() {
        let opts = Opts::new(vec![IpPort::new("localhost".to_string(), 21212)]);
        let inner = &opts.0;
        assert_eq!(inner.ip_ports.len(), 1);
        assert!(inner.user.is_none());
        assert!(inner.connect_timeout.is_none());
    }

    /// `IpPort::new` stores both parts verbatim.
    #[test]
    fn test_ip_port_new() {
        let IpPort { ip_host, port } = IpPort::new("192.168.1.1".to_string(), 8080);
        assert_eq!(ip_host, "192.168.1.1");
        assert_eq!(port, 8080);
    }
}