use super::bindings;
use super::error::MgError;
use super::value::{
QueryParam, Record, Value, c_string_to_string, hash_map_to_mg_map, mg_list_to_vec,
mg_map_to_hash_map, mg_value_string,
};
use std::collections::HashMap;
use std::ffi::CString;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::vec::IntoIter;
/// Library-wide reference count of live connections plus in-flight `connect` attempts.
///
/// `Connection::connect` calls `mg_init` when the count rises from zero, and the
/// last holder to release its reference calls `mg_finalize`.
static CONNECTION_COUNT: AtomicUsize = AtomicUsize::new(0_usize);
/// Raw pointer to a user-supplied TLS trust callback.
///
/// The callback is invoked with four strings, in the order mgclient hands them to the
/// trust callback: host, IP address, key type and fingerprint. Its `i32` result is
/// forwarded to mgclient unchanged. NOTE(review): presumably `0` means "trust" — confirm
/// against the mgclient trust-callback documentation.
///
/// The pointee must stay valid for as long as mgclient may invoke it (at least for the
/// duration of `Connection::connect`).
pub type TrustCallback =
    *const dyn Fn(&String, &String, &String, &String) -> i32;
/// Options consumed by `Connection::connect` to open a session.
///
/// Start from `ConnectParams::default()` and override only the fields you need.
pub struct ConnectParams {
    /// Server port, passed to `mg_session_params_set_port`.
    pub port: u16,
    /// Server host, passed to `mg_session_params_set_host` when present.
    pub host: Option<String>,
    /// Server address, passed to `mg_session_params_set_address` when present.
    pub address: Option<String>,
    /// User name, passed to `mg_session_params_set_username` when present.
    pub username: Option<String>,
    /// Password, passed to `mg_session_params_set_password` when present.
    pub password: Option<String>,
    /// Client identifier, sent as the session user agent.
    pub client_name: String,
    /// TLS mode, translated to mgclient's numeric sslmode.
    pub sslmode: SSLMode,
    /// Client certificate for TLS (presumably a file path — see mgclient docs).
    pub sslcert: Option<String>,
    /// Client private key for TLS (presumably a file path — see mgclient docs).
    pub sslkey: Option<String>,
    /// Optional trust callback registered with mgclient while connecting.
    pub trust_callback: Option<TrustCallback>,
    /// `true`: records are pulled on demand by the fetch methods;
    /// `false`: `execute` pulls and buffers the whole result.
    pub lazy: bool,
    /// `false`: `execute` opens an explicit transaction (`BEGIN`) that must be
    /// finished with `commit` or `rollback`.
    pub autocommit: bool,
}
impl Default for ConnectParams {
    /// Port 7687, TLS disabled, lazy fetching on, autocommit off, client name
    /// `rsmgclient/0.1`, and no host, address, credentials, TLS files or trust callback.
    fn default() -> Self {
        Self {
            port: 7687,
            client_name: "rsmgclient/0.1".to_string(),
            sslmode: SSLMode::Disable,
            lazy: true,
            autocommit: false,
            host: None,
            address: None,
            username: None,
            password: None,
            sslcert: None,
            sslkey: None,
            trust_callback: None,
        }
    }
}
/// TLS mode requested from mgclient when connecting.
///
/// Derives `Debug`, `Clone` and `Copy` like `ConnectionStatus` does, so callers can
/// log it and pass it around by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SSLMode {
    /// Connect without TLS (mgclient sslmode `0`).
    Disable,
    /// Require TLS (mgclient sslmode `1`).
    Require,
}
/// An open mgclient session plus the client-side cursor state layered on top of it.
pub struct Connection {
    /// Raw session handle returned by `mg_connect`; destroyed in `Drop`.
    mg_session: *mut bindings::mg_session,
    /// Pull records on demand (`true`) or buffer them in `execute` (`false`).
    lazy: bool,
    /// Run queries without an explicit `BEGIN` when `true`.
    autocommit: bool,
    /// Current position in the connection's state machine.
    status: ConnectionStatus,
    /// Records buffered by `execute` when not lazy.
    results_iter: Option<IntoIter<Record>>,
    /// Default batch size used by `fetchmany`.
    arraysize: u32,
    /// Summary map most recently read from the server.
    summary: Option<HashMap<String, Value>>,
    // Never read on purpose: it owns the boxed callback pointer that `connect` handed to
    // mgclient as trust data, keeping that allocation alive as long as the connection.
    #[allow(dead_code)]
    trust_callback: Option<Box<TrustCallback>>,
}
/// State of a `Connection`'s query/transaction state machine.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum ConnectionStatus {
    /// Idle, with no explicit transaction open.
    Ready,
    /// Idle inside an explicit transaction (`BEGIN` was sent because autocommit is off).
    InTransaction,
    /// A query was run and more results must be pulled before fetching.
    Executing,
    /// Results have been pulled and records are being read.
    Fetching,
    /// `close` was called.
    Closed,
    /// A previous operation failed; the connection should not be reused.
    Bad,
}
/// Returns the last error message mgclient recorded on `mg_session`, as an owned `String`.
fn read_error_message(mg_session: *mut bindings::mg_session) -> String {
    // SAFETY: callers pass a session pointer obtained from mgclient; the message is turned
    // into an owned `String` before this function returns.
    unsafe {
        let message_ptr = bindings::mg_session_error(mg_session);
        c_string_to_string(message_ptr, None)
    }
}
impl Drop for Connection {
    /// Destroys the mgclient session and releases this connection's library reference,
    /// finalizing mgclient when it was the last one.
    fn drop(&mut self) {
        // SAFETY: a `Connection` is only built from a session returned by a successful
        // `mg_connect`, and the session is destroyed exactly once, here.
        unsafe { bindings::mg_session_destroy(self.mg_session) };
        let was_last_connection = CONNECTION_COUNT.fetch_sub(1, Ordering::SeqCst) == 1;
        if was_last_connection {
            Connection::finalize();
        }
    }
}
impl Connection {
pub fn init() {
unsafe {
bindings::mg_init();
}
}
pub fn finalize() {
unsafe {
bindings::mg_finalize();
}
}
pub fn lazy(&self) -> bool {
self.lazy
}
pub fn autocommit(&self) -> bool {
self.autocommit
}
pub fn arraysize(&self) -> u32 {
self.arraysize
}
pub fn status(&self) -> ConnectionStatus {
self.status
}
pub fn summary(&self) -> Option<HashMap<String, Value>> {
self.summary.as_ref().map(|x| (*x).clone())
}
pub fn set_lazy(&mut self, lazy: bool) {
match self.status {
ConnectionStatus::Ready => self.lazy = lazy,
ConnectionStatus::InTransaction => panic!("Can't set lazy while in transaction"),
ConnectionStatus::Executing => panic!("Can't set lazy while executing"),
ConnectionStatus::Fetching => panic!("Can't set lazy while fetching"),
ConnectionStatus::Bad => panic!("Can't set lazy while connection is bad"),
ConnectionStatus::Closed => panic!("Can't set lazy while connection is closed"),
}
}
pub fn set_autocommit(&mut self, autocommit: bool) {
match self.status {
ConnectionStatus::Ready => self.autocommit = autocommit,
ConnectionStatus::InTransaction => {
panic!("Can't set autocommit while in transaction")
}
ConnectionStatus::Executing => panic!("Can't set autocommit while executing"),
ConnectionStatus::Fetching => panic!("Can't set autocommit while fetching"),
ConnectionStatus::Bad => panic!("Can't set autocommit while connection is bad"),
ConnectionStatus::Closed => panic!("Can't set autocommit while connection is closed"),
}
}
pub fn set_arraysize(&mut self, arraysize: u32) {
self.arraysize = arraysize;
}
pub fn connect(param_struct: &ConnectParams) -> Result<Connection, MgError> {
let prev_count = CONNECTION_COUNT.fetch_add(1, Ordering::SeqCst);
if prev_count == 0 {
Connection::init();
}
let mg_session_params = unsafe { bindings::mg_session_params_make() };
if mg_session_params.is_null() {
if CONNECTION_COUNT.fetch_sub(1, Ordering::SeqCst) == 1 {
Connection::finalize();
}
return Err(MgError::ffi(
"Failed to allocate mg_session_params".to_string(),
));
}
let mut trust_callback_box: Option<Box<TrustCallback>> = None;
let c_host = match param_struct.host.as_ref() {
Some(s) => Some(CString::new(s.as_str()).map_err(|_| MgError::null_byte("host"))?),
None => None,
};
let c_address = match param_struct.address.as_ref() {
Some(s) => Some(CString::new(s.as_str()).map_err(|_| MgError::null_byte("address"))?),
None => None,
};
let c_username = match param_struct.username.as_ref() {
Some(s) => Some(CString::new(s.as_str()).map_err(|_| MgError::null_byte("username"))?),
None => None,
};
let c_password = match param_struct.password.as_ref() {
Some(s) => Some(CString::new(s.as_str()).map_err(|_| MgError::null_byte("password"))?),
None => None,
};
let c_client_name = CString::new(param_struct.client_name.as_str())
.map_err(|_| MgError::null_byte("client_name"))?;
let c_sslcert = match param_struct.sslcert.as_ref() {
Some(s) => Some(CString::new(s.as_str()).map_err(|_| MgError::null_byte("sslcert"))?),
None => None,
};
let c_sslkey = match param_struct.sslkey.as_ref() {
Some(s) => Some(CString::new(s.as_str()).map_err(|_| MgError::null_byte("sslkey"))?),
None => None,
};
unsafe {
if let Some(ref x) = c_host {
bindings::mg_session_params_set_host(mg_session_params, x.as_ptr())
}
bindings::mg_session_params_set_port(mg_session_params, param_struct.port);
if let Some(ref x) = c_address {
bindings::mg_session_params_set_address(mg_session_params, x.as_ptr())
}
if let Some(ref x) = c_username {
bindings::mg_session_params_set_username(mg_session_params, x.as_ptr())
}
if let Some(ref x) = c_password {
bindings::mg_session_params_set_password(mg_session_params, x.as_ptr())
}
bindings::mg_session_params_set_user_agent(mg_session_params, c_client_name.as_ptr());
bindings::mg_session_params_set_sslmode(
mg_session_params,
match param_struct.sslmode {
SSLMode::Disable => 0,
SSLMode::Require => 1,
},
);
if let Some(ref x) = c_sslcert {
bindings::mg_session_params_set_sslcert(mg_session_params, x.as_ptr())
}
if let Some(ref x) = c_sslkey {
bindings::mg_session_params_set_sslkey(mg_session_params, x.as_ptr())
}
if let Some(x) = ¶m_struct.trust_callback {
let callback_box = Box::new(*x);
let trust_callback_ptr = Box::into_raw(callback_box);
bindings::mg_session_params_set_trust_data(
mg_session_params,
trust_callback_ptr as *mut ::std::os::raw::c_void,
);
bindings::mg_session_params_set_trust_callback(
mg_session_params,
Some(trust_callback_wrapper),
);
trust_callback_box = Some(Box::from_raw(trust_callback_ptr));
}
}
let mut mg_session: *mut bindings::mg_session = std::ptr::null_mut();
let status = unsafe { bindings::mg_connect(mg_session_params, &mut mg_session) };
unsafe {
bindings::mg_session_params_destroy(mg_session_params);
};
if status != 0 {
if CONNECTION_COUNT.fetch_sub(1, Ordering::SeqCst) == 1 {
Connection::finalize();
}
return Err(MgError::connection(read_error_message(mg_session)));
}
Ok(Connection {
mg_session,
lazy: param_struct.lazy,
autocommit: param_struct.autocommit,
status: ConnectionStatus::Ready,
results_iter: None,
arraysize: 1,
summary: None,
trust_callback: trust_callback_box,
})
}
pub fn execute_without_results(&mut self, query: &str) -> Result<(), MgError> {
let c_query = CString::new(query).map_err(|_| MgError::null_byte("query"))?;
match unsafe {
bindings::mg_session_run(
self.mg_session,
c_query.as_ptr(), std::ptr::null(),
std::ptr::null_mut(),
std::ptr::null_mut(),
std::ptr::null_mut(),
)
} {
0 => {
self.status = ConnectionStatus::Executing;
}
_ => {
self.status = ConnectionStatus::Bad;
return Err(MgError::query(read_error_message(self.mg_session)));
}
}
match unsafe { bindings::mg_session_pull(self.mg_session, std::ptr::null_mut()) } {
0 => {
self.status = ConnectionStatus::Fetching;
}
_ => {
self.status = ConnectionStatus::Bad;
return Err(MgError::query(read_error_message(self.mg_session)));
}
}
loop {
let mut result = std::ptr::null_mut();
match unsafe { bindings::mg_session_fetch(self.mg_session, &mut result) } {
1 => {
continue;
}
0 => {
self.status = ConnectionStatus::Ready;
return Ok(());
}
_ => {
self.status = ConnectionStatus::Bad;
return Err(MgError::query(read_error_message(self.mg_session)));
}
};
}
}
pub fn execute(
&mut self,
query: &str,
params: Option<&HashMap<String, QueryParam>>,
) -> Result<Vec<String>, MgError> {
match self.status {
ConnectionStatus::Ready => {}
ConnectionStatus::InTransaction => {}
ConnectionStatus::Executing => {
return Err(MgError::invalid_state("execute", "already executing"));
}
ConnectionStatus::Fetching => {
return Err(MgError::invalid_state("execute", "fetching"));
}
ConnectionStatus::Closed => {
return Err(MgError::invalid_state("execute", "connection closed"));
}
ConnectionStatus::Bad => {
return Err(MgError::invalid_state("execute", "bad connection"));
}
}
if !self.autocommit && self.status == ConnectionStatus::Ready {
match self.execute_without_results("BEGIN") {
Ok(()) => self.status = ConnectionStatus::InTransaction,
Err(err) => return Err(err),
}
}
self.summary = None;
let c_query = CString::new(query).map_err(|_| MgError::null_byte("query"))?;
let mg_params = match params {
Some(x) => hash_map_to_mg_map(x),
None => std::ptr::null_mut(),
};
let mut columns = std::ptr::null();
let status = unsafe {
bindings::mg_session_run(
self.mg_session,
c_query.as_ptr(),
mg_params,
std::ptr::null_mut(),
&mut columns,
std::ptr::null_mut(),
)
};
if !mg_params.is_null() {
unsafe { bindings::mg_map_destroy(mg_params) };
}
if status != 0 {
self.status = ConnectionStatus::Bad;
return Err(MgError::query(read_error_message(self.mg_session)));
}
self.status = ConnectionStatus::Executing;
if !self.lazy {
match self.pull_and_fetch_all() {
Ok(x) => self.results_iter = Some(x.into_iter()),
Err(x) => {
self.status = ConnectionStatus::Bad;
return Err(x);
}
}
}
Ok(parse_columns(columns))
}
pub fn fetchone(&mut self) -> Result<Option<Record>, MgError> {
match self.status {
ConnectionStatus::Ready => {
return Err(MgError::invalid_state("fetchone", "ready"));
}
ConnectionStatus::InTransaction => {
return Err(MgError::invalid_state("fetchone", "in transaction"));
}
ConnectionStatus::Executing => {}
ConnectionStatus::Fetching => {}
ConnectionStatus::Closed => {
return Err(MgError::invalid_state("fetchone", "connection closed"));
}
ConnectionStatus::Bad => {
return Err(MgError::invalid_state("fetchone", "bad connection"));
}
}
match self.lazy {
true => {
if self.status == ConnectionStatus::Executing {
match self.pull(1) {
Ok(_) => {
}
Err(err) => {
self.status = ConnectionStatus::Bad;
return Err(err);
}
}
}
match self.fetch() {
Ok((Some(x), None)) => {
match self.fetch()? {
(None, Some(has_more)) => {
if has_more {
self.status = ConnectionStatus::Executing;
}
}
_ => {
}
}
Ok(Some(x))
}
Ok((None, Some(has_more))) => {
if has_more {
self.status = ConnectionStatus::Executing;
} else {
self.status = if self.autocommit {
ConnectionStatus::Ready
} else {
ConnectionStatus::InTransaction
};
}
Ok(None)
}
Ok(_) => {
self.status = if self.autocommit {
ConnectionStatus::Ready
} else {
ConnectionStatus::InTransaction
};
Ok(None)
}
Err(_) => {
self.status = if self.autocommit {
ConnectionStatus::Ready
} else {
ConnectionStatus::InTransaction
};
Ok(None)
}
}
}
false => match self.next_record() {
Some(x) => Ok(Some(x)),
None => {
self.status = if self.autocommit {
ConnectionStatus::Ready
} else {
ConnectionStatus::InTransaction
};
Ok(None)
}
},
}
}
fn next_record(&mut self) -> Option<Record> {
if let Some(iter) = self.results_iter.as_mut() {
iter.next()
} else {
None
}
}
pub fn fetchmany(&mut self, size: Option<u32>) -> Result<Vec<Record>, MgError> {
let size = match size {
Some(x) => x,
None => self.arraysize,
};
let mut vec = Vec::new();
for _i in 0..size {
match self.fetchone() {
Ok(record) => match record {
Some(x) => vec.push(x),
None => break,
},
Err(err) => return Err(err),
}
}
Ok(vec)
}
pub fn fetchall(&mut self) -> Result<Vec<Record>, MgError> {
let mut vec = Vec::new();
loop {
match self.fetchone() {
Ok(record) => match record {
Some(x) => vec.push(x),
None => break,
},
Err(err) => return Err(err),
}
}
Ok(vec)
}
fn pull(&mut self, n: i64) -> Result<(), MgError> {
match self.status {
ConnectionStatus::Ready => {
return Err(MgError::invalid_state("pull", "ready"));
}
ConnectionStatus::InTransaction => {
return Err(MgError::invalid_state("pull", "in transaction"));
}
ConnectionStatus::Executing => {}
ConnectionStatus::Fetching => {
return Err(MgError::invalid_state("pull", "fetching"));
}
ConnectionStatus::Closed => {
return Err(MgError::invalid_state("pull", "connection closed"));
}
ConnectionStatus::Bad => {
return Err(MgError::invalid_state("pull", "bad connection"));
}
}
let pull_status = match n {
0 => unsafe { bindings::mg_session_pull(self.mg_session, std::ptr::null_mut()) },
_ => {
let n_key = CString::new("n").expect("'n' is a valid C string");
unsafe {
let mg_map = bindings::mg_map_make_empty(1);
if mg_map.is_null() {
self.status = ConnectionStatus::Bad;
return Err(MgError::ffi("Failed to allocate pull map".to_string()));
}
let mg_int = bindings::mg_value_make_integer(n);
if mg_int.is_null() {
self.status = ConnectionStatus::Bad;
bindings::mg_map_destroy(mg_map);
return Err(MgError::ffi(
"Failed to allocate pull map integer value".to_string(),
));
}
if bindings::mg_map_insert(mg_map, n_key.as_ptr(), mg_int) != 0 {
self.status = ConnectionStatus::Bad;
bindings::mg_map_destroy(mg_map);
bindings::mg_value_destroy(mg_int);
return Err(MgError::ffi("Failed to insert into pull map".to_string()));
}
let status = bindings::mg_session_pull(self.mg_session, mg_map);
bindings::mg_map_destroy(mg_map);
status
}
}
};
match pull_status {
0 => {
self.status = ConnectionStatus::Fetching;
Ok(())
}
_ => {
self.status = ConnectionStatus::Bad;
Err(MgError::query(read_error_message(self.mg_session)))
}
}
}
fn fetch(&mut self) -> Result<(Option<Record>, Option<bool>), MgError> {
match self.status {
ConnectionStatus::Ready => {
return Err(MgError::invalid_state("fetch", "ready"));
}
ConnectionStatus::InTransaction => {
return Err(MgError::invalid_state("fetch", "in transaction"));
}
ConnectionStatus::Executing => {
return Err(MgError::invalid_state("fetch", "executing"));
}
ConnectionStatus::Fetching => {}
ConnectionStatus::Closed => {
return Err(MgError::invalid_state("fetch", "connection closed"));
}
ConnectionStatus::Bad => {
return Err(MgError::invalid_state("fetch", "bad connection"));
}
}
let mut mg_result: *mut bindings::mg_result = std::ptr::null_mut();
let fetch_status = unsafe { bindings::mg_session_fetch(self.mg_session, &mut mg_result) };
match fetch_status {
1 => unsafe {
let row = bindings::mg_result_row(mg_result);
Ok((
Some(Record {
values: mg_list_to_vec(row),
}),
None,
))
},
0 => unsafe {
let mg_summary = bindings::mg_result_summary(mg_result);
let c_has_more = CString::new("has_more").expect("'has_more' is a valid C string");
let mg_has_more = bindings::mg_map_at(mg_summary, c_has_more.as_ptr());
let has_more = bindings::mg_value_bool(mg_has_more) != 0;
self.summary = Some(mg_map_to_hash_map(mg_summary));
Ok((None, Some(has_more)))
},
_ => Err(MgError::query(read_error_message(self.mg_session))),
}
}
fn pull_and_fetch_all(&mut self) -> Result<Vec<Record>, MgError> {
let mut res = Vec::new();
match self.pull(0) {
Ok(_) => loop {
let x = self.fetch()?;
match x {
(Some(x), _) => res.push(x),
(None, _) => break,
}
},
Err(err) => return Err(err),
}
Ok(res)
}
pub fn commit(&mut self) -> Result<(), MgError> {
match self.status {
ConnectionStatus::Ready => {}
ConnectionStatus::InTransaction => {}
ConnectionStatus::Executing => {
return Err(MgError::invalid_state("commit", "executing"));
}
ConnectionStatus::Fetching => {
return Err(MgError::invalid_state("commit", "fetching"));
}
ConnectionStatus::Closed => {
return Err(MgError::invalid_state("commit", "connection closed"));
}
ConnectionStatus::Bad => {
return Err(MgError::invalid_state("commit", "bad connection"));
}
}
if self.autocommit || self.status != ConnectionStatus::InTransaction {
return Ok(());
}
match self.execute_without_results("COMMIT") {
Ok(()) => {
self.status = ConnectionStatus::Ready;
Ok(())
}
Err(err) => Err(err),
}
}
pub fn rollback(&mut self) -> Result<(), MgError> {
match self.status {
ConnectionStatus::Ready => {
return Err(MgError::invalid_state("rollback", "not in transaction"));
}
ConnectionStatus::InTransaction => {}
ConnectionStatus::Executing => {
return Err(MgError::invalid_state("rollback", "executing"));
}
ConnectionStatus::Fetching => {
return Err(MgError::invalid_state("rollback", "fetching"));
}
ConnectionStatus::Closed => {
return Err(MgError::invalid_state("rollback", "connection closed"));
}
ConnectionStatus::Bad => {
return Err(MgError::invalid_state("rollback", "bad connection"));
}
}
if self.autocommit {
return Ok(());
}
match self.execute_without_results("ROLLBACK") {
Ok(()) => {
self.status = ConnectionStatus::Ready;
Ok(())
}
Err(err) => Err(err),
}
}
pub fn close(&mut self) {
match self.status {
ConnectionStatus::Ready => self.status = ConnectionStatus::Closed,
ConnectionStatus::InTransaction => self.status = ConnectionStatus::Closed,
ConnectionStatus::Executing => panic!("Can't close while executing"),
ConnectionStatus::Fetching => panic!("Can't close while fetching"),
ConnectionStatus::Closed => {}
ConnectionStatus::Bad => panic!("Can't closed a bad connection"),
}
}
}
/// Copies the column-name values of an mgclient list into owned `String`s.
///
/// A null list yields no columns instead of being dereferenced by `mg_list_size`.
fn parse_columns(mg_list: *const bindings::mg_list) -> Vec<String> {
    if mg_list.is_null() {
        return Vec::new();
    }
    // SAFETY: `mg_list` is non-null and points at a list provided by mgclient.
    let size = unsafe { bindings::mg_list_size(mg_list) };
    (0..size)
        // SAFETY: `i < size`, so every index is in range.
        .map(|i| mg_value_string(unsafe { bindings::mg_list_at(mg_list, i) }))
        .collect()
}
/// C-ABI trampoline that mgclient calls as the trust callback.
///
/// `fun_raw` is the trust data registered by `Connection::connect`. It points to a boxed
/// `*const dyn Fn(&String, &String, &String, &String) -> i32`. The four C strings are
/// converted to owned `String`s and passed to that closure, and the closure's result is
/// returned to mgclient unchanged.
extern "C" fn trust_callback_wrapper(
    host: *const ::std::os::raw::c_char,
    ip_address: *const ::std::os::raw::c_char,
    key_type: *const ::std::os::raw::c_char,
    fingerprint: *const ::std::os::raw::c_char,
    fun_raw: *mut ::std::os::raw::c_void,
) -> ::std::os::raw::c_int {
    // SAFETY: `fun_raw` points at the live boxed fat pointer installed by `connect`, and
    // that pointer refers to the user's callback. Only a shared reference is formed:
    // invoking an `Fn` never needs `&mut`, and the callback was handed over as a `*const`.
    // Creating `&mut` here (as before) would assert exclusive access nobody promised.
    let callback: &dyn Fn(&String, &String, &String, &String) -> i32 = unsafe {
        &**(fun_raw as *const *const dyn Fn(&String, &String, &String, &String) -> i32)
    };
    // SAFETY: the pointers are passed straight through from mgclient's callback invocation.
    let (host, ip_address, key_type, fingerprint) = unsafe {
        (
            c_string_to_string(host, None),
            c_string_to_string(ip_address, None),
            c_string_to_string(key_type, None),
            c_string_to_string(fingerprint, None),
        )
    };
    callback(&host, &ip_address, &key_type, &fingerprint) as ::std::os::raw::c_int
}
#[cfg(test)]
mod tests;