use super::{
as_handle::AsHandle,
buffer::{clamp_small_int, mut_buf_ptr},
};
use odbc_sys::{SQLGetDiagRecW, SqlReturn, SQLSTATE_SIZEW};
use std::{convert::TryInto, fmt};
use widestring::{U16CStr, U16Str};
/// Fixed-size buffer of UTF-16 code units receiving the SQLSTATE of a diagnostic record.
///
/// Holds `SQLSTATE_SIZEW` code units plus one extra element; the extra element is presumably the
/// terminating zero (`Record`'s `Display` implementation decodes the buffer as a nul-terminated
/// string).
pub type State = [u16; SQLSTATE_SIZEW + 1];
/// The fixed-size part of a diagnostic record, as returned by [`diagnostics`].
///
/// The message text is not part of this struct; `diagnostics` writes it into the buffer supplied
/// by the caller.
#[derive(Debug, Clone, Copy)]
pub struct DiagResult {
/// SQLSTATE code units written by `SQLGetDiagRecW`.
pub state: State,
/// Native error code written by `SQLGetDiagRecW`.
pub native_error: i32,
}
/// Fetches diagnostic record number `rec_number` of `handle` using `SQLGetDiagRecW`.
///
/// # Parameters
///
/// * `handle`: Handle whose diagnostic records are queried.
/// * `rec_number`: One-based index of the record to fetch. Must be greater than zero.
/// * `message_text`: Buffer receiving the message text as UTF-16 code units. Its existing
///   capacity is offered to the driver first; it is only grown if the message does not fit.
///
/// # Return
///
/// * `Some(result)` if the record exists. `message_text` then holds exactly the message, without
///   a terminating zero and without trailing zero padding.
/// * `None` if the driver reports `NO_DATA` for `rec_number`. The contents of `message_text` are
///   unspecified in that case.
///
/// # Panics
///
/// Panics if `rec_number` is not greater than zero, if the driver reports a negative text length
/// for a successful call, or if `SQLGetDiagRecW` returns anything other than `SUCCESS`,
/// `SUCCESS_WITH_INFO` or `NO_DATA`.
pub fn diagnostics(
    handle: &dyn AsHandle,
    rec_number: i16,
    message_text: &mut Vec<u16>,
) -> Option<DiagResult> {
    assert!(rec_number > 0);

    // Offer the driver all the memory the vector already owns, but do not allocate any more yet.
    let cap = message_text.capacity();
    message_text.resize(cap, 0);

    // Receives the number of code units available for the message, excluding the terminating zero.
    let mut text_length = 0;
    let mut state = [0; SQLSTATE_SIZEW + 1];
    let mut native_error = 0;
    // SAFETY: `state`, `native_error` and `text_length` are live locals for the duration of the
    // call. The message pointer and the buffer length are both derived from `message_text`.
    // NOTE(review): presumably `mut_buf_ptr` yields a pointer valid for `message_text.len()`
    // elements and `clamp_small_int` caps the length to the `i16` range -- confirm in `buffer`.
    let ret = unsafe {
        SQLGetDiagRecW(
            handle.handle_type(),
            handle.as_handle(),
            rec_number,
            state.as_mut_ptr(),
            &mut native_error,
            mut_buf_ptr(message_text),
            clamp_small_int(message_text.len()),
            &mut text_length,
        )
    };

    match ret {
        SqlReturn::SUCCESS | SqlReturn::SUCCESS_WITH_INFO => {
            // Only interpret the reported length once we know the call succeeded.
            let text_length: usize = text_length
                .try_into()
                .expect("SQLGetDiagRecW must not report a negative text length");

            // The reported length excludes the terminating zero, so the complete message needs
            // `text_length + 1` elements. If the length equals the buffer size, the driver had to
            // drop the last code unit to make room for the terminator, so `>=` (not `>`) is the
            // correct truncation check.
            if text_length >= message_text.len() {
                // +1 to account for the terminating zero.
                message_text.resize(text_length + 1, 0);
                // Ask again with the larger buffer; this time the message fits.
                diagnostics(handle, rec_number, message_text)
            } else {
                // Some drivers pad the message with zeros. Keep everything up to and including
                // the last non-zero code unit.
                let message_length = message_text[..text_length]
                    .iter()
                    .rposition(|&unit| unit != 0)
                    .map_or(0, |last| last + 1);
                message_text.truncate(message_length);
                Some(DiagResult {
                    state,
                    native_error,
                })
            }
        }
        SqlReturn::NO_DATA => None,
        SqlReturn::ERROR => panic!("rec_number argument of diagnostics must be > 0."),
        unexpected => panic!("SQLGetDiagRec returned: {:?}", unexpected),
    }
}
/// An owned diagnostic record: SQLSTATE, native error code and message text.
///
/// The derived `Default` yields an all-zero state, a native error of `0` and an empty message.
#[derive(Default)]
pub struct Record {
/// SQLSTATE code units of the record.
pub state: State,
/// Native error code of the record.
pub native_error: i32,
/// Message text as UTF-16 code units. When populated through `fill_from`, this is the buffer
/// handed to `diagnostics`.
pub message: Vec<u16>,
}
impl Record {
    /// Overwrites this record with diagnostic record `record_number` of `handle`.
    ///
    /// The message buffer of `self` is handed to `diagnostics`, so its allocation is reused.
    ///
    /// Returns `true` if a record was found and `state` and `native_error` were updated. Returns
    /// `false` if `diagnostics` yielded `None`; `state` and `native_error` are untouched then.
    pub fn fill_from(&mut self, handle: &dyn AsHandle, record_number: i16) -> bool {
        let found = match diagnostics(handle, record_number, &mut self.message) {
            Some(found) => found,
            // No such record: leave state and native error as they were.
            None => return false,
        };
        self.state = found.state;
        self.native_error = found.native_error;
        true
    }
}
impl fmt::Display for Record {
    /// Writes the record as `State: <state>, Native error: <code>, Message: <text>`.
    ///
    /// State and message are decoded lossily from UTF-16. If the state buffer cannot be
    /// interpreted as a nul-terminated string, the decoding error is printed in its place.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let state_text = match U16CStr::from_slice_with_nul(&self.state) {
            Ok(state) => state.to_string_lossy(),
            Err(e) => format!("Error decoding state: {}", e),
        };
        let message_text = U16Str::from_slice(&self.message).to_string_lossy();
        write!(
            f,
            "State: {}, Native error: {}, Message: {}",
            state_text, self.native_error, message_text,
        )
    }
}
impl fmt::Debug for Record {
    /// Produces the same output as the `Display` implementation, so the record stays readable
    /// in debug output.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        <Self as fmt::Display>::fmt(self, formatter)
    }
}
#[cfg(test)]
mod test {
    use super::Record;

    /// A record with state `HY010`, default native error and a fixed message must be rendered
    /// in the `State: .., Native error: .., Message: ..` layout.
    #[test]
    fn formatting() {
        let mut rec = Record::default();
        // Write the state code units to the front of the zeroed buffer; the rest stays zero and
        // serves as terminator.
        for (slot, unit) in rec.state.iter_mut().zip("HY010".encode_utf16()) {
            *slot = unit;
        }
        rec.message = "[Microsoft][ODBC Driver Manager] Function sequence error"
            .encode_utf16()
            .collect();

        let expected = "State: HY010, Native error: 0, Message: [Microsoft][ODBC Driver Manager] \
                        Function sequence error";
        assert_eq!(format!("{}", rec), expected);
    }
}