use std::borrow::Cow;
use std::sync::Arc;
use bytes::Bytes;
use mssql_types::decode::{TypeInfo, decode_value};
use mssql_types::{FromSql, SqlValue, TypeError};
use crate::blob::BlobReader;
/// Describes where one column's raw bytes live inside a row buffer.
///
/// The slice stores a byte `offset`, a byte `length`, and a flag telling
/// whether the column is NULL.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct ColumnSlice {
    /// Byte offset of the value within the row buffer.
    pub offset: u32,
    /// Number of bytes the value spans.
    pub length: u32,
    /// `true` when the column holds NULL.
    pub is_null: bool,
}

impl ColumnSlice {
    /// Builds a slice from an offset, a length and a NULL flag.
    pub fn new(offset: u32, length: u32, is_null: bool) -> Self {
        Self { offset, length, is_null }
    }

    /// Builds the slice representing a NULL column: zero offset, zero length,
    /// and `is_null` set.
    pub fn null() -> Self {
        Self::new(0, 0, true)
    }
}
/// Metadata describing a single column.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct Column {
/// Column name; `ColMetaData::find_by_name` compares it ASCII case-insensitively.
pub name: String,
/// Index supplied at construction.
// NOTE(review): presumably the zero-based ordinal of the column in the result set — confirm against callers.
pub index: usize,
/// SQL type name; `to_type_info` maps it to a type id via `type_name_to_id`.
pub type_name: String,
/// Whether the column may hold NULL (`Column::new` defaults this to `true`).
pub nullable: bool,
/// Maximum length, if set; forwarded as `TypeInfo::length` by `to_type_info`.
pub max_length: Option<u32>,
/// Precision, if set; forwarded to `TypeInfo` by `to_type_info`.
pub precision: Option<u8>,
/// Scale, if set; forwarded to `TypeInfo` by `to_type_info`.
pub scale: Option<u8>,
/// Collation, if set; consulted for encoding lookups and forwarded to `TypeInfo`.
pub collation: Option<tds_protocol::Collation>,
}
impl Column {
    /// Creates column metadata from a name, an index and a type name.
    ///
    /// The new column is marked nullable and has no max length, precision,
    /// scale or collation; use the `with_*` builders to set them.
    pub fn new(name: impl Into<String>, index: usize, type_name: impl Into<String>) -> Self {
        let name = name.into();
        let type_name = type_name.into();
        Self {
            name,
            index,
            type_name,
            nullable: true,
            max_length: None,
            precision: None,
            scale: None,
            collation: None,
        }
    }

    /// Returns the column with its `nullable` flag replaced.
    #[must_use]
    pub fn with_nullable(self, nullable: bool) -> Self {
        Self { nullable, ..self }
    }

    /// Returns the column with `max_length` set to `Some(max_length)`.
    #[must_use]
    pub fn with_max_length(self, max_length: u32) -> Self {
        Self {
            max_length: Some(max_length),
            ..self
        }
    }

    /// Returns the column with both `precision` and `scale` set.
    #[must_use]
    pub fn with_precision_scale(self, precision: u8, scale: u8) -> Self {
        Self {
            precision: Some(precision),
            scale: Some(scale),
            ..self
        }
    }

    /// Returns the column with `collation` set to `Some(collation)`.
    #[must_use]
    pub fn with_collation(self, collation: tds_protocol::Collation) -> Self {
        Self {
            collation: Some(collation),
            ..self
        }
    }

    /// Returns the encoding name reported by the column's collation.
    ///
    /// Yields `"unknown"` when the column has no collation or when the
    /// `encoding` feature is disabled.
    #[must_use]
    pub fn encoding_name(&self) -> &'static str {
        #[cfg(feature = "encoding")]
        if let Some(collation) = self.collation.as_ref() {
            return collation.encoding_name();
        }
        "unknown"
    }

    /// Reports whether the column's collation says it is UTF-8.
    ///
    /// Yields `false` when the column has no collation or when the
    /// `encoding` feature is disabled.
    #[must_use]
    pub fn is_utf8_collation(&self) -> bool {
        #[cfg(feature = "encoding")]
        if let Some(collation) = self.collation.as_ref() {
            return collation.is_utf8();
        }
        false
    }

    /// Builds the `TypeInfo` handed to the value decoder for this column.
    ///
    /// The type id is derived from `type_name`; length, scale and precision
    /// are copied over unchanged.
    pub fn to_type_info(&self) -> TypeInfo {
        // NOTE(review): the protocol collation's `sort_id` is stored in the
        // decoder collation's `flags` field — confirm this mapping is intended.
        let collation = self.collation.map(|source| mssql_types::decode::Collation {
            lcid: source.lcid,
            flags: source.sort_id,
        });
        TypeInfo {
            type_id: type_name_to_id(&self.type_name),
            length: self.max_length,
            scale: self.scale,
            precision: self.precision,
            collation,
        }
    }
}
/// Maps a SQL type name to the type id byte placed in `TypeInfo::type_id`.
///
/// Matching is case-insensitive: the name is upper-cased once and every arm,
/// including the trailing-`N` fallback, inspects the upper-cased form.
///
/// Names that match no known type resolve to `0x26` when they end in `N`
/// and to `0xA5` otherwise.
fn type_name_to_id(name: &str) -> u8 {
    let upper = name.to_uppercase();
    match upper.as_str() {
        "INT" | "INTEGER" => 0x38,
        "BIGINT" => 0x7F,
        "SMALLINT" => 0x34,
        "TINYINT" => 0x30,
        "BIT" => 0x32,
        "FLOAT" => 0x3E,
        "REAL" => 0x3B,
        "DECIMAL" | "NUMERIC" => 0x6C,
        "MONEY" | "SMALLMONEY" => 0x6E,
        "NVARCHAR" | "NCHAR" | "NTEXT" => 0xE7,
        "VARCHAR" | "CHAR" | "TEXT" => 0xA7,
        "VARBINARY" | "BINARY" | "IMAGE" => 0xA5,
        "DATE" => 0x28,
        "TIME" => 0x29,
        "DATETIME2" => 0x2A,
        "DATETIMEOFFSET" => 0x2B,
        "DATETIME" => 0x3D,
        // NOTE(review): MS-TDS lists DATETIM4TYPE as 0x3A and 0x3F as the
        // legacy NUMERICTYPE — confirm which id `decode_value` expects here.
        "SMALLDATETIME" => 0x3F,
        "UNIQUEIDENTIFIER" => 0x24,
        "XML" => 0xF1,
        // The guard used to test the caller's original spelling, so "intn"
        // and "INTN" resolved to different ids. It now tests the upper-cased
        // name like every other arm.
        other if other.ends_with('N') => 0x26,
        _ => 0xA5,
    }
}
/// Column metadata held behind an `Arc`, so cloning does not copy the columns.
#[derive(Debug, Clone)]
pub struct ColMetaData {
    /// The column descriptors, in the order they were supplied to `new`.
    pub columns: Arc<[Column]>,
}

impl ColMetaData {
    /// Wraps the given columns in reference-counted storage.
    pub fn new(columns: Vec<Column>) -> Self {
        Self {
            columns: Arc::from(columns),
        }
    }

    /// Returns how many columns are described.
    #[must_use]
    pub fn len(&self) -> usize {
        let described: &[Column] = &self.columns;
        described.len()
    }

    /// Returns `true` when no columns are described.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        let described: &[Column] = &self.columns;
        described.is_empty()
    }

    /// Returns the column at `index`, or `None` when the index is out of range.
    #[must_use]
    pub fn get(&self, index: usize) -> Option<&Column> {
        let described: &[Column] = &self.columns;
        described.get(index)
    }

    /// Returns the position of the first column whose name equals `name`,
    /// ignoring ASCII case, or `None` when there is no such column.
    #[must_use]
    pub fn find_by_name(&self, name: &str) -> Option<usize> {
        self.columns
            .iter()
            .enumerate()
            .find(|(_, column)| column.name.eq_ignore_ascii_case(name))
            .map(|(position, _)| position)
    }
}
/// A single result row.
///
/// Every field is reference-counted, so cloning a row does not copy its data.
#[derive(Clone)]
pub struct Row {
/// Raw row bytes that the column slices index into.
buffer: Arc<Bytes>,
/// Per-column offset/length/NULL markers into `buffer`.
slices: Arc<[ColumnSlice]>,
/// Column metadata used for name lookups and for decoding.
metadata: Arc<ColMetaData>,
/// Already-decoded values. When `Some`, `get`, `try_get` and `get_raw`
/// read from here instead of decoding from `buffer`.
values: Option<Arc<[SqlValue]>>,
}
impl Row {
/// Creates a row backed by a raw byte buffer.
///
/// `slices` locates each column inside `buffer`, and `metadata` describes
/// the columns. No pre-decoded values are attached.
pub fn new(buffer: Arc<Bytes>, slices: Arc<[ColumnSlice]>, metadata: Arc<ColMetaData>) -> Self {
    // Buffer-backed rows start without cached values.
    let values = None;
    Self {
        buffer,
        slices,
        metadata,
        values,
    }
}
/// Creates a row from already-decoded values.
///
/// The buffer is left empty. One placeholder slice is created per value
/// (offset = position, length = 0) so that `len` and `is_null` still work.
#[allow(dead_code)]
pub(crate) fn from_values(columns: Vec<Column>, values: Vec<SqlValue>) -> Self {
    let mut placeholders = Vec::with_capacity(values.len());
    for (position, value) in values.iter().enumerate() {
        placeholders.push(ColumnSlice::new(position as u32, 0, value.is_null()));
    }
    Self {
        buffer: Arc::new(Bytes::new()),
        slices: placeholders.into(),
        metadata: Arc::new(ColMetaData::new(columns)),
        values: Some(values.into()),
    }
}
/// Returns the raw bytes of the column at `index`.
///
/// Returns `None` when the index is out of range, the column is NULL, or
/// the recorded byte range does not fit inside the row buffer.
#[must_use]
pub fn get_bytes(&self, index: usize) -> Option<&[u8]> {
    let slice = self.slices.get(index)?;
    if slice.is_null {
        return None;
    }
    let start = slice.offset as usize;
    // `offset + length` can exceed `usize::MAX` where `usize` is 32 bits;
    // treat that as an invalid range rather than overflowing.
    let end = start.checked_add(slice.length as usize)?;
    let data: &[u8] = &self.buffer;
    // Checked slicing yields `None` for any range outside the buffer.
    data.get(start..end)
}
/// Returns the column at `index` as text.
///
/// Decoding is attempted in this order:
/// 1. UTF-8, borrowed straight from the row buffer.
/// 2. With the `encoding` feature, the encoding derived from the column's
///    collation (used only when it decodes without errors).
/// 3. UTF-16LE.
///
/// Returns `None` when the column is missing or NULL, or when no decoding
/// succeeds. Input with an odd number of bytes is never treated as UTF-16LE.
#[must_use]
pub fn get_str(&self, index: usize) -> Option<Cow<'_, str>> {
    let bytes = self.get_bytes(index)?;

    // NOTE(review): ASCII-range UTF-16LE data is also valid UTF-8 (with
    // interleaved NULs), so such data is returned by this path as-is —
    // confirm callers only pass single-byte text through here.
    if let Ok(text) = std::str::from_utf8(bytes) {
        return Some(Cow::Borrowed(text));
    }

    #[cfg(feature = "encoding")]
    if let Some(column) = self.metadata.get(index) {
        if let Some(ref collation) = column.collation {
            if let Some(encoding) = collation.encoding() {
                let (decoded, _, had_errors) = encoding.decode(bytes);
                if had_errors {
                    tracing::warn!(
                        column_name = %column.name,
                        column_index = index,
                        encoding = %encoding.name(),
                        lcid = collation.lcid,
                        byte_len = bytes.len(),
                        "collation-aware decoding had errors, falling back to UTF-16LE"
                    );
                } else {
                    return Some(Cow::Owned(decoded.into_owned()));
                }
            } else {
                tracing::debug!(
                    column_name = %column.name,
                    column_index = index,
                    lcid = collation.lcid,
                    "no encoding found for LCID, falling back to UTF-16LE"
                );
            }
        }
    }

    // UTF-16 code units are two bytes each. An odd byte count cannot be
    // UTF-16LE, and decoding it anyway would silently drop the last byte.
    if bytes.len() % 2 != 0 {
        return None;
    }
    let code_units = bytes
        .chunks_exact(2)
        .map(|pair| u16::from_le_bytes([pair[0], pair[1]]));
    // Fails on unpaired surrogates, as `String::from_utf16` does.
    char::decode_utf16(code_units)
        .collect::<Result<String, _>>()
        .ok()
        .map(Cow::Owned)
}
/// Returns the column at `index` as an owned `String`.
///
/// Same decoding rules as `get_str`; a borrowed result is copied.
#[must_use]
pub fn get_string(&self, index: usize) -> Option<String> {
    let text = self.get_str(index)?;
    Some(text.into_owned())
}
/// Returns a `BlobReader` over the bytes of the column at `index`.
///
/// Returns `None` when the index is out of range, the column is NULL, or
/// the recorded byte range does not fit inside the row buffer.
#[must_use]
pub fn get_stream(&self, index: usize) -> Option<BlobReader> {
    let slice = self.slices.get(index)?;
    if slice.is_null {
        return None;
    }
    let start = slice.offset as usize;
    // `offset + length` can exceed `usize::MAX` where `usize` is 32 bits;
    // treat that as an invalid range rather than overflowing.
    let end = start.checked_add(slice.length as usize)?;
    if end > self.buffer.len() {
        return None;
    }
    Some(BlobReader::from_bytes(self.buffer.slice(start..end)))
}
/// Returns a `BlobReader` for the column called `name`.
///
/// Returns `None` when no column has that name, or when `get_stream`
/// returns `None` for the resolved index.
#[must_use]
pub fn get_stream_by_name(&self, name: &str) -> Option<BlobReader> {
    self.metadata
        .find_by_name(name)
        .and_then(|index| self.get_stream(index))
}
/// Converts the column at `index` into `T`.
///
/// Pre-decoded values are used when the row carries them; otherwise the
/// value is decoded from the row buffer first.
///
/// # Errors
///
/// - `TypeError::TypeMismatch` when `index` is out of bounds.
/// - `TypeError::UnexpectedNull` when a buffer-backed column is NULL.
/// - Any error from decoding the value or from `T::from_sql`.
pub fn get<T: FromSql>(&self, index: usize) -> Result<T, TypeError> {
    let out_of_bounds = || TypeError::TypeMismatch {
        expected: "valid column index",
        actual: format!("index {index} out of bounds"),
    };

    match self.values.as_deref() {
        Some(cached) => {
            let value = cached.get(index).ok_or_else(out_of_bounds)?;
            T::from_sql(value)
        }
        None => {
            let slice = self.slices.get(index).ok_or_else(out_of_bounds)?;
            if slice.is_null {
                return Err(TypeError::UnexpectedNull);
            }
            let decoded = self.parse_value(index, slice)?;
            T::from_sql(&decoded)
        }
    }
}
/// Converts the column called `name` into `T`.
///
/// # Errors
///
/// Returns `TypeError::TypeMismatch` when no column has that name;
/// otherwise returns whatever `get` returns for the resolved index.
pub fn get_by_name<T: FromSql>(&self, name: &str) -> Result<T, TypeError> {
    match self.metadata.find_by_name(name) {
        Some(index) => self.get(index),
        None => Err(TypeError::TypeMismatch {
            expected: "valid column name",
            actual: format!("column '{name}' not found"),
        }),
    }
}
/// Converts the column at `index` into `T`, folding every failure into `None`.
///
/// Returns `None` when the index is out of range, the column is NULL, or
/// the conversion fails.
pub fn try_get<T: FromSql>(&self, index: usize) -> Option<T> {
    match &self.values {
        Some(cached) => {
            let value = cached.get(index)?;
            T::from_sql_nullable(value).ok().flatten()
        }
        None => {
            let column_is_null = self.slices.get(index)?.is_null;
            if column_is_null {
                None
            } else {
                self.get(index).ok()
            }
        }
    }
}
/// Converts the column called `name` into `T`, folding every failure into `None`.
///
/// Returns `None` when no column has that name, or when `try_get` returns
/// `None` for the resolved index.
pub fn try_get_by_name<T: FromSql>(&self, name: &str) -> Option<T> {
    self.metadata
        .find_by_name(name)
        .and_then(|index| self.try_get(index))
}
#[must_use]
pub fn get_raw(&self, index: usize) -> Option<SqlValue> {
if let Some(ref values) = self.values {
return values.get(index).cloned();
}
let slice = self.slices.get(index)?;
self.parse_value(index, slice).ok()
}
/// Returns the column called `name` as a `SqlValue`.
///
/// Returns `None` when no column has that name, or when `get_raw` returns
/// `None` for the resolved index.
#[must_use]
pub fn get_raw_by_name(&self, name: &str) -> Option<SqlValue> {
    self.metadata
        .find_by_name(name)
        .and_then(|index| self.get_raw(index))
}
/// Returns the number of columns in this row, counted as the number of
/// column slices it holds.
#[must_use]
pub fn len(&self) -> usize {
self.slices.len()
}
/// Returns `true` when this row holds no column slices.
#[must_use]
pub fn is_empty(&self) -> bool {
self.slices.is_empty()
}
/// Returns the column descriptors from this row's metadata.
#[must_use]
pub fn columns(&self) -> &[Column] {
&self.metadata.columns
}
/// Returns the reference-counted metadata handle held by this row.
#[must_use]
pub fn metadata(&self) -> &Arc<ColMetaData> {
&self.metadata
}
/// Reports whether the column at `index` is NULL.
///
/// An out-of-range index is reported as NULL.
#[must_use]
pub fn is_null(&self, index: usize) -> bool {
    match self.slices.get(index) {
        Some(slice) => slice.is_null,
        None => true,
    }
}
/// Reports whether the column called `name` is NULL.
///
/// An unknown column name is reported as NULL.
#[must_use]
pub fn is_null_by_name(&self, name: &str) -> bool {
    match self.metadata.find_by_name(name) {
        Some(index) => self.is_null(index),
        None => true,
    }
}
/// Decodes the bytes described by `slice` into a `SqlValue`, using the
/// metadata of column `index` to build the type information.
///
/// A NULL slice yields `SqlValue::Null` without touching the buffer.
///
/// # Errors
///
/// - `TypeError::TypeMismatch` when there is no metadata for `index`, or
///   when the slice's byte range does not fit inside the row buffer.
/// - Any error returned by `decode_value`.
fn parse_value(&self, index: usize, slice: &ColumnSlice) -> Result<SqlValue, TypeError> {
    if slice.is_null {
        return Ok(SqlValue::Null);
    }
    let column = self
        .metadata
        .get(index)
        .ok_or_else(|| TypeError::TypeMismatch {
            expected: "valid column metadata",
            actual: format!("no metadata for column {index}"),
        })?;

    let start = slice.offset as usize;
    // The sum of two `u32` values always fits in `u64`, so computing the end
    // there cannot overflow even where `usize` is 32 bits.
    let wide_end = u64::from(slice.offset) + u64::from(slice.length);
    let buffer_len = self.buffer.len();
    if wide_end > buffer_len as u64 {
        return Err(TypeError::TypeMismatch {
            expected: "valid byte range",
            actual: format!(
                "range {}..{} exceeds buffer length {}",
                start, wide_end, buffer_len
            ),
        });
    }
    // Lossless: `wide_end` is bounded by `buffer_len`, which is a `usize`.
    let end = wide_end as usize;

    let type_info = column.to_type_info();
    let mut buf = self.buffer.slice(start..end);
    decode_value(&mut buf, &type_info)
}
}
impl std::fmt::Debug for Row {
    /// Writes a compact summary: column count, buffer size in bytes, and
    /// whether pre-decoded values are attached. Row contents are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let column_count = self.metadata.columns.len();
        let buffer_size = self.buffer.len();
        let has_cached_values = self.values.is_some();

        let mut summary = f.debug_struct("Row");
        summary.field("columns", &column_count);
        summary.field("buffer_size", &buffer_size);
        summary.field("has_cached_values", &has_cached_values);
        summary.finish()
    }
}
/// Borrowing iterator over the values of a `Row`, one `SqlValue` per column.
pub struct RowIter<'a> {
/// The row being iterated.
row: &'a Row,
/// Index of the next column to yield.
index: usize,
}
impl Iterator for RowIter<'_> {
type Item = SqlValue;
fn next(&mut self) -> Option<Self::Item> {
if self.index >= self.row.len() {
return None;
}
let value = self.row.get_raw(self.index);
self.index += 1;
value
}
fn size_hint(&self) -> (usize, Option<usize>) {
let remaining = self.row.len() - self.index;
(remaining, Some(remaining))
}
}
impl<'a> IntoIterator for &'a Row {
    type Item = SqlValue;
    type IntoIter = RowIter<'a>;

    /// Starts iterating the row's values from the first column.
    fn into_iter(self) -> Self::IntoIter {
        let (row, index) = (self, 0);
        RowIter { row, index }
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// `ColumnSlice::null` yields a zeroed slice flagged as NULL.
    #[test]
    fn test_column_slice_null() {
        let null_slice = ColumnSlice::null();
        assert!(null_slice.is_null);
        assert_eq!(null_slice.offset, 0);
        assert_eq!(null_slice.length, 0);
    }

    /// The builder methods update the fields they name.
    #[test]
    fn test_column_metadata() {
        let column = Column::new("id", 0, "INT")
            .with_nullable(false)
            .with_precision_scale(10, 0);

        assert_eq!(column.name, "id");
        assert_eq!(column.index, 0);
        assert!(!column.nullable);
        assert_eq!(column.precision, Some(10));
    }

    /// Name lookup ignores ASCII case and reports unknown names as `None`.
    #[test]
    fn test_col_metadata_find_by_name() {
        let metadata = ColMetaData::new(vec![
            Column::new("id", 0, "INT"),
            Column::new("Name", 1, "NVARCHAR"),
        ]);

        assert_eq!(metadata.find_by_name("id"), Some(0));
        assert_eq!(metadata.find_by_name("ID"), Some(0));
        assert_eq!(metadata.find_by_name("name"), Some(1));
        assert_eq!(metadata.find_by_name("unknown"), None);
    }

    /// Rows built from pre-decoded values support typed access by index and name.
    #[test]
    fn test_row_from_values_backward_compat() {
        let row = Row::from_values(
            vec![
                Column::new("id", 0, "INT"),
                Column::new("name", 1, "NVARCHAR"),
            ],
            vec![SqlValue::Int(42), SqlValue::String("Alice".to_string())],
        );

        assert_eq!(row.len(), 2);
        assert_eq!(row.get::<i32>(0).unwrap(), 42);
        assert_eq!(row.get_by_name::<String>("name").unwrap(), "Alice");
    }

    /// NULL detection covers real NULLs and out-of-range indexes.
    #[test]
    fn test_row_is_null() {
        let row = Row::from_values(
            vec![
                Column::new("id", 0, "INT"),
                Column::new("nullable_col", 1, "NVARCHAR"),
            ],
            vec![SqlValue::Int(1), SqlValue::Null],
        );

        assert!(!row.is_null(0));
        assert!(row.is_null(1));
        // An index past the last column is reported as NULL.
        assert!(row.is_null(99));
    }

    /// Raw byte access follows the offset and length of each slice.
    #[test]
    fn test_row_get_bytes_with_buffer() {
        let layout = vec![
            ColumnSlice::new(0, 5, false),
            ColumnSlice::new(6, 5, false),
        ];
        let metadata = ColMetaData::new(vec![
            Column::new("greeting", 0, "VARCHAR"),
            Column::new("subject", 1, "VARCHAR"),
        ]);
        let row = Row::new(
            Arc::new(Bytes::from_static(b"Hello World")),
            layout.into(),
            Arc::new(metadata),
        );

        assert_eq!(row.get_bytes(0), Some(b"Hello".as_slice()));
        assert_eq!(row.get_bytes(1), Some(b"World".as_slice()));
    }

    /// Valid UTF-8 is returned borrowed from the row buffer.
    #[test]
    fn test_row_get_str() {
        let layout = vec![ColumnSlice::new(0, 4, false)];
        let metadata = ColMetaData::new(vec![Column::new("val", 0, "VARCHAR")]);
        let row = Row::new(
            Arc::new(Bytes::from_static(b"Test")),
            layout.into(),
            Arc::new(metadata),
        );

        let text = row.get_str(0).unwrap();
        assert_eq!(text, "Test");
        assert!(matches!(text, Cow::Borrowed(_)));
    }

    /// Column and metadata accessors expose what the row was built with.
    #[test]
    fn test_row_metadata_access() {
        let row = Row::from_values(vec![Column::new("col1", 0, "INT")], vec![SqlValue::Int(1)]);

        assert_eq!(row.columns().len(), 1);
        assert_eq!(row.columns()[0].name, "col1");
        assert_eq!(row.metadata().len(), 1);
    }

    /// Streams cover exactly the slice bytes; NULL and missing columns give `None`.
    #[test]
    fn test_row_get_stream() {
        let layout = vec![
            ColumnSlice::new(0, 5, false),
            ColumnSlice::new(7, 5, false),
            ColumnSlice::null(),
        ];
        let metadata = ColMetaData::new(vec![
            Column::new("greeting", 0, "VARBINARY"),
            Column::new("subject", 1, "VARBINARY"),
            Column::new("nullable", 2, "VARBINARY"),
        ]);
        let row = Row::new(
            Arc::new(Bytes::from_static(b"Hello, World!")),
            layout.into(),
            Arc::new(metadata),
        );

        let first = row.get_stream(0).unwrap();
        assert_eq!(first.len(), Some(5));
        assert_eq!(first.as_bytes().as_ref(), b"Hello");

        let second = row.get_stream(1).unwrap();
        assert_eq!(second.len(), Some(5));
        assert_eq!(second.as_bytes().as_ref(), b"World");

        assert!(row.get_stream(2).is_none());
        assert!(row.get_stream(99).is_none());
    }

    /// Stream lookup by name ignores ASCII case and rejects unknown names.
    #[test]
    fn test_row_get_stream_by_name() {
        let layout = vec![ColumnSlice::new(0, 11, false)];
        let metadata = ColMetaData::new(vec![Column::new("document", 0, "VARBINARY")]);
        let row = Row::new(
            Arc::new(Bytes::from_static(b"Binary data here")),
            layout.into(),
            Arc::new(metadata),
        );

        let exact_case = row.get_stream_by_name("document").unwrap();
        assert_eq!(exact_case.len(), Some(11));

        let upper_case = row.get_stream_by_name("DOCUMENT").unwrap();
        assert_eq!(upper_case.len(), Some(11));

        assert!(row.get_stream_by_name("unknown").is_none());
    }
}