use crate::attr::*;
use crate::error::SqlResult;
use crate::handle::{SqlHandle, SqlRawHandle};
use crate::tran::SqlEndTransaction;
use crate::util;
use crate::util::VecCapacityExt;
use odbc_sys::*;
use std::fmt;
/// Owning wrapper around an ODBC environment handle (`SQLHENV`).
///
/// The handle is allocated by [`SqlEnvironment::new`] and released when the value is dropped.
#[derive(Debug, Hash, PartialEq, Eq)]
pub struct SqlEnvironment {
    /// Raw environment handle owned by this value.
    handle: SQLHENV,
}
impl SqlAttribute for EnvironmentAttribute {
    /// Chooses the `StringLength` marker to pass for an environment attribute.
    ///
    /// The two connection-pooling attributes are exchanged as unsigned integers
    /// (`SQL_IS_UINTEGER`); every other environment attribute is flagged as a signed
    /// integer (`SQL_IS_INTEGER`).
    fn buffer_length(&self) -> Option<SqlAttributeStringLength> {
        let is_unsigned = matches!(self, SQL_ATTR_CONNECTION_POOLING | SQL_ATTR_CP_MATCH);
        Some(if is_unsigned {
            SQL_IS_UINTEGER
        } else {
            SQL_IS_INTEGER
        })
    }
}
// Wires the generic attribute get/set machinery to the environment-level ODBC entry points.
// NOTE(review): presumably the `unsafe` on this impl asserts that `GETTER`/`SETTER` really are
// the attribute functions for `SQLHENV` handles — confirm against the `SqlAttributes`
// trait's safety contract.
unsafe impl SqlAttributes for SqlEnvironment {
    type AttributeType = EnvironmentAttribute;

    /// Function used to read an environment attribute.
    const GETTER: unsafe extern "system" fn(
        SQLHENV,
        EnvironmentAttribute,
        SQLPOINTER,
        SQLINTEGER,
        *mut SQLINTEGER,
    ) -> SQLRETURN = SQLGetEnvAttr;
    /// Name of `GETTER` (presumably used when reporting failures).
    const GETTER_NAME: &'static str = "SQLGetEnvAttr";

    /// Function used to write an environment attribute.
    const SETTER: unsafe extern "system" fn(
        SQLHENV,
        EnvironmentAttribute,
        SQLPOINTER,
        SQLINTEGER,
    ) -> SQLRETURN = SQLSetEnvAttr;
    /// Name of `SETTER` (presumably used when reporting failures).
    const SETTER_NAME: &'static str = "SQLSetEnvAttr";
}
// Exposes the environment handle to the shared handle utilities (allocation, freeing,
// diagnostics) under the `SQL_HANDLE_ENV` handle type.
impl SqlHandle for SqlEnvironment {
    type Type = SQLHENV;
    const TYPE: HandleType = SQL_HANDLE_ENV;

    /// Returns the handle as a generic `SQLHANDLE`.
    unsafe fn handle(&self) -> SQLHANDLE {
        self.handle as SQLHANDLE
    }

    /// Returns the handle with its concrete `SQLHENV` type.
    unsafe fn typed_handle(&self) -> SQLHENV {
        self.handle
    }
}

unsafe impl SqlRawHandle for SqlEnvironment {}

// Enables the default transaction-ending behaviour on environment handles.
impl SqlEndTransaction for SqlEnvironment {}
impl Drop for SqlEnvironment {
    /// Frees the environment handle via `dealloc_handle`.
    ///
    /// # Panics
    ///
    /// Panics if freeing the handle fails. The exception is when the current thread is
    /// already unwinding from another panic: a second panic inside `drop` would abort
    /// the whole process, so in that case the failure is ignored.
    fn drop(&mut self) {
        // SAFETY: `self` is being destroyed, so the owned handle is released exactly once
        // and never used again afterwards.
        let result = unsafe { self.dealloc_handle() };
        if std::thread::panicking() {
            // Best effort only: panicking again while unwinding would abort.
            return;
        }
        result.unwrap();
    }
}
// SAFETY: `SqlEnvironment` holds nothing but its `SQLHENV` handle. Presumably that handle
// type is a raw pointer, which is why the auto traits have to be asserted by hand.
// NOTE(review): soundness relies on the ODBC driver manager allowing an environment handle
// to be moved between and shared across threads — confirm against the ODBC multithreading
// documentation for the driver managers this crate targets.
unsafe impl Send for SqlEnvironment {}
unsafe impl Sync for SqlEnvironment {}
use std::collections::HashMap;
/// One installed ODBC driver, as enumerated by `SqlEnvironment::drivers`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SqlDriver {
    /// Human-readable driver description.
    pub description: String,
    /// Driver attributes, split into `keyword -> value` entries.
    pub attributes: HashMap<String, String>,
}
/// One configured data source, as enumerated by `SqlEnvironment::data_sources`.
///
/// Field order is significant: the derived ordering compares `server_name` first,
/// then `description`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SqlDataSource {
    /// Data source (DSN) name.
    pub server_name: String,
    /// Description of the driver associated with the data source.
    pub description: String,
}
/// Values for the `SQL_ATTR_CONNECTION_POOLING` environment attribute.
///
/// The discriminants are the raw values handed to the driver manager
/// (`SQL_CP_OFF`, `SQL_CP_ONE_PER_DRIVER`, `SQL_CP_ONE_PER_HENV`, `SQL_CP_DRIVER_AWARE`).
#[repr(u32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum SqlConnectionPooling {
    Off = 0,
    OnePerDriver = 1,
    OnePerEnvironment = 2,
    DriverAware = 3,
}
/// Values for the `SQL_ATTR_CP_MATCH` environment attribute.
///
/// The discriminants are the raw values handed to the driver manager
/// (`SQL_CP_STRICT_MATCH`, `SQL_CP_RELAXED_MATCH`).
#[repr(u32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum SqlConnectionPoolingMatch {
    Strict = 0,
    Relaxed = 1,
}
/// Selects which data sources `SqlEnvironment::data_sources` enumerates.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum SqlDataSourceFilter {
    /// Both user and system data sources.
    Both,
    /// User data sources only.
    User,
    /// System data sources only.
    System,
}

impl Default for SqlDataSourceFilter {
    /// No filtering: both user and system data sources.
    fn default() -> Self {
        Self::Both
    }
}
impl fmt::Display for SqlDriver {
    /// Formats as `<description> ("<attribute map>")`.
    ///
    /// The attribute map is printed in `Debug` form, with entries sorted by key. Without
    /// the sort, output would follow the `HashMap`'s randomised iteration order and
    /// change from run to run.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // A `BTreeMap` of borrowed entries has the same `Debug` shape as the `HashMap`
        // ({"k": "v", ...}) but iterates in key order.
        let sorted: std::collections::BTreeMap<&String, &String> =
            self.attributes.iter().collect();
        write!(f, "{} (\"{:?}\")", self.description, sorted)
    }
}
impl fmt::Display for SqlDataSource {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{} \"{}\"", self.server_name, self.description)
}
}
impl SqlEnvironment {
    /// Allocates a new ODBC environment handle.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `alloc_handle` when allocation fails.
    pub fn new() -> SqlResult<SqlEnvironment> {
        let handle = unsafe { Self::alloc_handle()? as SQLHENV };
        Ok(SqlEnvironment { handle })
    }

    /// Declares the ODBC version the application uses (`SQL_ATTR_ODBC_VERSION`).
    pub fn set_version(&mut self, version: OdbcVersion) -> SqlResult {
        unsafe { self.set_attribute(SQL_ATTR_ODBC_VERSION, version as usize) }
    }

    /// Reads back `SQL_ATTR_ODBC_VERSION`.
    pub fn get_version(&mut self) -> SqlResult<OdbcVersion> {
        // NOTE(review): this `transmute` is only sound if the driver manager returns one of
        // the declared `OdbcVersion` discriminants; any other value is undefined behaviour.
        // Replace it with a checked conversion once the variant set of the pinned
        // `odbc_sys` version is confirmed.
        unsafe {
            self.get_attribute::<SQLINTEGER>(SQL_ATTR_ODBC_VERSION)
                .map(|value| std::mem::transmute(value))
        }
    }

    /// Sets the connection-pooling mode (`SQL_ATTR_CONNECTION_POOLING`).
    pub fn set_connection_pooling(&mut self, setting: SqlConnectionPooling) -> SqlResult {
        unsafe { self.set_attribute(SQL_ATTR_CONNECTION_POOLING, setting as usize) }
    }

    /// Reads the connection-pooling mode (`SQL_ATTR_CONNECTION_POOLING`).
    ///
    /// # Panics
    ///
    /// Panics if the driver manager reports a value that is not a
    /// `SqlConnectionPooling` discriminant. The raw value used to be `transmute`d,
    /// which was undefined behaviour for unknown values.
    pub fn get_connection_pooling(&mut self) -> SqlResult<SqlConnectionPooling> {
        unsafe {
            self.get_attribute::<SQLUINTEGER>(SQL_ATTR_CONNECTION_POOLING)
                .map(Self::connection_pooling_from_raw)
        }
    }

    /// Sets how strictly pooled connections are matched (`SQL_ATTR_CP_MATCH`).
    pub fn set_connection_pooling_match(
        &mut self,
        setting: SqlConnectionPoolingMatch,
    ) -> SqlResult {
        unsafe { self.set_attribute(SQL_ATTR_CP_MATCH, setting as usize) }
    }

    /// Reads the pooled-connection match mode (`SQL_ATTR_CP_MATCH`).
    ///
    /// # Panics
    ///
    /// Panics if the driver manager reports a value that is not a
    /// `SqlConnectionPoolingMatch` discriminant.
    pub fn get_connection_pooling_match(&mut self) -> SqlResult<SqlConnectionPoolingMatch> {
        unsafe {
            self.get_attribute::<SQLUINTEGER>(SQL_ATTR_CP_MATCH)
                .map(Self::connection_pooling_match_from_raw)
        }
    }

    /// Lists the installed drivers via `SQLDriversW`.
    ///
    /// The attribute list of each driver is split on `'\0'` into `keyword=value` pairs.
    /// Entries without an `=` are skipped. If any string is reported as truncated, the
    /// buffers are grown and the enumeration restarts from the first driver.
    ///
    /// # Errors
    ///
    /// Returns the detailed diagnostic error when `SQLDriversW` returns `SQL_ERROR`, or
    /// any error from the truncation check or UTF-16 decoding helpers.
    ///
    /// # Panics
    ///
    /// Panics on any return code other than `SQL_SUCCESS`, `SQL_SUCCESS_WITH_INFO`,
    /// `SQL_NO_DATA` or `SQL_ERROR`, or if a reported length is negative.
    pub fn drivers(&self) -> SqlResult<Vec<SqlDriver>> {
        let mut drivers: Vec<SqlDriver> = vec![];
        let mut description: Vec<u16> = Vec::with_capacity(SQL_MAX_MESSAGE_LENGTH as usize);
        let mut description_len: SQLSMALLINT = 0;
        let mut attributes: Vec<u16> = Vec::with_capacity(SQL_MAX_MESSAGE_LENGTH as usize);
        let mut attributes_len: SQLSMALLINT = 0;
        let mut direction = SQL_FETCH_FIRST;
        loop {
            // NOTE(review): `capacity() as SQLSMALLINT` wraps if a buffer ever grows past
            // `SQLSMALLINT::MAX` characters.
            let ret = unsafe {
                SQLDriversW(
                    self.handle,
                    direction,
                    description.as_mut_ptr(),
                    description.capacity() as SQLSMALLINT,
                    &mut description_len,
                    attributes.as_mut_ptr(),
                    attributes.capacity() as SQLSMALLINT,
                    &mut attributes_len,
                )
            };
            match ret {
                SQL_ERROR => return Err(self.get_detailed_error(ret)),
                SQL_NO_DATA => break,
                SQL_SUCCESS | SQL_SUCCESS_WITH_INFO => {
                    assert!(description_len >= 0);
                    let description_len = description_len as usize;
                    assert!(attributes_len >= 0);
                    let attributes_len = attributes_len as usize;
                    if util::is_string_data_right_truncated(self, ret)? {
                        // Grow to the reported sizes (+1 for the terminator) and restart,
                        // discarding entries collected with the undersized buffers.
                        description.reserve_capacity(description_len + 1);
                        attributes.reserve_capacity(attributes_len + 1);
                        direction = SQL_FETCH_FIRST;
                        drivers.clear();
                    } else {
                        unsafe {
                            description.set_len_checked(description_len);
                            attributes.set_len_checked(attributes_len);
                        }
                        let description_text = util::from_utf_16_null_terminated(&description)?;
                        let attributes_text = util::from_utf_16_null_terminated(&attributes)?;
                        let attribute_map: HashMap<String, String> = attributes_text
                            .split('\0')
                            .filter_map(|pair: &str| {
                                // `=` is ASCII, so the byte index from `find` is a char boundary.
                                let pos = pair.find('=')?;
                                Some((pair[..pos].to_owned(), pair[pos + 1..].to_owned()))
                            })
                            .collect();
                        drivers.push(SqlDriver {
                            description: description_text,
                            attributes: attribute_map,
                        });
                        direction = SQL_FETCH_NEXT;
                    }
                }
                _ => panic!("Unexpected SQLDriversW return code: {:?}", ret),
            }
        }
        Ok(drivers)
    }

    /// Lists the configured data sources via `SQLDataSourcesW`, restricted by `filter`.
    ///
    /// If any string is reported as truncated, the buffers are grown and the enumeration
    /// restarts from the filter's initial direction.
    ///
    /// # Errors
    ///
    /// Returns the detailed diagnostic error when `SQLDataSourcesW` returns `SQL_ERROR`,
    /// or any error from the truncation check or UTF-16 decoding helpers.
    ///
    /// # Panics
    ///
    /// Panics on any return code other than `SQL_SUCCESS`, `SQL_SUCCESS_WITH_INFO`,
    /// `SQL_NO_DATA` or `SQL_ERROR`, or if a reported length is negative.
    pub fn data_sources(&self, filter: SqlDataSourceFilter) -> SqlResult<Vec<SqlDataSource>> {
        use self::SqlDataSourceFilter::*;
        let mut data_sources: Vec<SqlDataSource> = vec![];
        let initial_direction = match filter {
            Both => SQL_FETCH_FIRST,
            User => SQL_FETCH_FIRST_USER,
            System => SQL_FETCH_FIRST_SYSTEM,
        };
        let mut server_name: Vec<u16> = Vec::with_capacity(SQL_MAX_MESSAGE_LENGTH as usize);
        let mut server_name_len: SQLSMALLINT = 0;
        let mut description: Vec<u16> = Vec::with_capacity(SQL_MAX_MESSAGE_LENGTH as usize);
        let mut description_len: SQLSMALLINT = 0;
        let mut direction = initial_direction;
        loop {
            // NOTE(review): `capacity() as SQLSMALLINT` wraps if a buffer ever grows past
            // `SQLSMALLINT::MAX` characters.
            let ret = unsafe {
                SQLDataSourcesW(
                    self.handle,
                    direction,
                    server_name.as_mut_ptr(),
                    server_name.capacity() as SQLSMALLINT,
                    &mut server_name_len,
                    description.as_mut_ptr(),
                    description.capacity() as SQLSMALLINT,
                    &mut description_len,
                )
            };
            match ret {
                SQL_ERROR => return Err(self.get_detailed_error(ret)),
                SQL_NO_DATA => break,
                SQL_SUCCESS | SQL_SUCCESS_WITH_INFO => {
                    assert!(server_name_len >= 0);
                    let server_name_len = server_name_len as usize;
                    assert!(description_len >= 0);
                    let description_len = description_len as usize;
                    if util::is_string_data_right_truncated(self, ret)? {
                        // Grow to the reported sizes (+1 for the terminator) and restart,
                        // discarding entries collected with the undersized buffers.
                        server_name.reserve_capacity(server_name_len + 1);
                        description.reserve_capacity(description_len + 1);
                        data_sources.clear();
                        direction = initial_direction;
                    } else {
                        unsafe {
                            server_name.set_len_checked(server_name_len);
                            description.set_len_checked(description_len);
                        }
                        let server_name_text = util::from_utf_16_null_terminated(&server_name)?;
                        let description_text = util::from_utf_16_null_terminated(&description)?;
                        data_sources.push(SqlDataSource {
                            server_name: server_name_text,
                            description: description_text,
                        });
                        direction = SQL_FETCH_NEXT;
                    }
                }
                _ => panic!("Unexpected SQLDataSourcesW return code: {:?}", ret),
            }
        }
        Ok(data_sources)
    }

    /// Maps a raw `SQL_ATTR_CONNECTION_POOLING` value onto the variant with that
    /// discriminant. Panics on an unknown value instead of producing an invalid enum.
    fn connection_pooling_from_raw(value: SQLUINTEGER) -> SqlConnectionPooling {
        [
            SqlConnectionPooling::Off,
            SqlConnectionPooling::OnePerDriver,
            SqlConnectionPooling::OnePerEnvironment,
            SqlConnectionPooling::DriverAware,
        ]
        .iter()
        .copied()
        .find(|&setting| (setting as SQLUINTEGER) == value)
        .unwrap_or_else(|| panic!("Unexpected SQL_ATTR_CONNECTION_POOLING value: {}", value))
    }

    /// Maps a raw `SQL_ATTR_CP_MATCH` value onto the variant with that discriminant.
    /// Panics on an unknown value instead of producing an invalid enum.
    fn connection_pooling_match_from_raw(value: SQLUINTEGER) -> SqlConnectionPoolingMatch {
        [
            SqlConnectionPoolingMatch::Strict,
            SqlConnectionPoolingMatch::Relaxed,
        ]
        .iter()
        .copied()
        .find(|&setting| (setting as SQLUINTEGER) == value)
        .unwrap_or_else(|| panic!("Unexpected SQL_ATTR_CP_MATCH value: {}", value))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::util::print_diagnostics;

    /// Allocates an environment and declares ODBC 3.80 on it, printing the diagnostics
    /// after each of the two steps.
    fn odbc_380_env() -> SqlEnvironment {
        let mut env = SqlEnvironment::new().unwrap();
        print_diagnostics(&env);
        env.set_version(OdbcVersion::SQL_OV_ODBC3_80).unwrap();
        print_diagnostics(&env);
        env
    }

    #[test]
    fn new_environment() {
        SqlEnvironment::new().unwrap();
    }

    #[test]
    fn get_set_version() {
        let mut env = odbc_380_env();
        let version = env.get_version().unwrap();
        print_diagnostics(&env);
        assert_eq!(OdbcVersion::SQL_OV_ODBC3_80, version);
    }

    #[cfg(windows)]
    #[test]
    fn get_set_connection_pooling() {
        let mut env = odbc_380_env();
        env.set_connection_pooling(SqlConnectionPooling::DriverAware)
            .unwrap();
        let pooling = env.get_connection_pooling().unwrap();
        print_diagnostics(&env);
        assert_eq!(SqlConnectionPooling::DriverAware, pooling);
    }

    #[cfg(windows)]
    #[test]
    fn get_set_connection_pooling_match() {
        let mut env = odbc_380_env();
        env.set_connection_pooling_match(SqlConnectionPoolingMatch::Relaxed)
            .unwrap();
        let matching = env.get_connection_pooling_match().unwrap();
        print_diagnostics(&env);
        assert_eq!(SqlConnectionPoolingMatch::Relaxed, matching);
    }

    #[test]
    fn drivers() {
        let env = odbc_380_env();
        let found = env.drivers().unwrap();
        print_diagnostics(&env);
        found.iter().for_each(|driver| println!("{}", driver));
    }

    #[test]
    fn data_sources() {
        let env = odbc_380_env();
        let found = env.data_sources(SqlDataSourceFilter::Both).unwrap();
        print_diagnostics(&env);
        found.iter().for_each(|dsn| println!("{}", dsn));
    }
}