use crate::row::sealed::{AsName, Sealed};
use crate::simple_query::SimpleColumn;
use crate::statement::Column;
use crate::types::{FromSql, Type, WrongType};
use crate::{Error, Statement};
use fallible_iterator::FallibleIterator;
use postgres_protocol::message::backend::DataRowBody;
use postgres_types::Field;
use std::fmt;
use std::ops::Range;
use std::str;
use std::sync::Arc;
/// Helper traits backing the public [`RowIndex`](super::RowIndex) API.
///
/// The module itself is private, so these traits cannot be named outside of this module. That
/// keeps [`RowIndex`](super::RowIndex) (which has `Sealed` as a supertrait) from being implemented
/// for new types elsewhere.
mod sealed {
    /// Marker supertrait of [`RowIndex`](super::RowIndex).
    pub trait Sealed {}
    /// Something that exposes a column name for name-based lookups.
    pub trait AsName {
        /// Returns the column name.
        fn as_name(&self) -> &str;
    }
}
impl AsName for Column {
fn as_name(&self) -> &str {
self.name()
}
}
/// An owned string is its own column name.
impl AsName for String {
    fn as_name(&self) -> &str {
        self.as_str()
    }
}
/// A value that can select a column of a row.
///
/// Implemented in this module for `usize` (by position), `str` (by name), and references to any
/// `RowIndex` type. The `Sealed` supertrait lives in a private module, so the trait cannot be
/// implemented outside of this module.
pub trait RowIndex: Sealed {
    /// Resolves this index against `columns`, returning the matching column's position, or
    /// `None` if no column matches.
    #[doc(hidden)]
    fn __idx<T>(&self, columns: &[T]) -> Option<usize>
    where
        T: AsName;
}
impl Sealed for usize {}
/// Positional lookup: the `usize` is the column's offset.
impl RowIndex for usize {
    /// Yields the index itself when it is in bounds for `columns`, otherwise `None`.
    #[inline]
    fn __idx<T>(&self, columns: &[T]) -> Option<usize>
    where
        T: AsName,
    {
        let position = *self;
        (position < columns.len()).then_some(position)
    }
}
impl Sealed for str {}
/// Name-based lookup.
///
/// An exact match wins. Failing that, the first column whose name matches ignoring ASCII case
/// is chosen.
impl RowIndex for str {
    #[inline]
    fn __idx<T>(&self, columns: &[T]) -> Option<usize>
    where
        T: AsName,
    {
        let exact = |col: &T| col.as_name() == self;
        let loose = |col: &T| col.as_name().eq_ignore_ascii_case(self);
        columns
            .iter()
            .position(exact)
            .or_else(|| columns.iter().position(loose))
    }
}
impl<T: ?Sized + Sealed> Sealed for &T {}
/// Lets a reference to any index type be used wherever the index itself could be.
impl<T> RowIndex for &T
where
    T: ?Sized + RowIndex,
{
    #[inline]
    fn __idx<U>(&self, columns: &[U]) -> Option<usize>
    where
        U: AsName,
    {
        (**self).__idx(columns)
    }
}
/// A row of results: column metadata from its statement plus the raw bytes of each column.
pub struct Row {
    // Statement the row belongs to; `columns()` reads column metadata from it.
    statement: Statement,
    // Raw row body whose buffer holds every column's bytes.
    body: DataRowBody,
    // Byte range of each column inside `body.buffer()`; `None` means the column is NULL.
    ranges: Vec<Option<Range<usize>>>,
}
/// Debug output lists only the column metadata, not the raw row bytes.
impl fmt::Debug for Row {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let columns = self.columns();
        let mut builder = f.debug_struct("Row");
        builder.field("columns", &columns);
        builder.finish()
    }
}
/// An array column whose elements are composite values, decoded one element at a time.
///
/// Produced by [`Row::get_composite_array`]; elements are pulled with [`CompositeArray::next`].
pub struct CompositeArray<'a> {
    // Field descriptions of the composite element type.
    fields: &'a [Field],
    // Iterator over the raw element values; an element may be NULL.
    array: postgres_protocol::types::ArrayValues<'a>,
}
/// One composite value produced by [`CompositeArray::next`], with its fields accessible by position.
pub struct CompositeRow<'a> {
    // Field descriptions of the composite type.
    fields: &'a [Field],
    // The composite value's bytes, excluding its leading 4-byte header.
    data: &'a [u8],
    // `(start, end)` byte offsets into `data` for each field; `(0, 0)` marks a NULL field.
    ranges: &'a [(u32, u32)],
}
/// Ways a field inside a binary composite value can be malformed.
#[derive(Debug)]
enum CompositeError {
    /// Fewer than 4 bytes were left for the field's OID.
    OidEOF,
    /// Fewer than 4 bytes were left for the field's length.
    SizeEOF,
    /// The field's declared length runs past the end of the input.
    ValueLengthOutOfBounds,
}

/// Splits one field off the front of a binary composite value.
///
/// A field is a big-endian `i32` OID, a big-endian `i32` byte length, and then that many bytes
/// of data. A negative length marks a NULL field that carries no data bytes.
///
/// Returns `(Some((oid, bytes)), rest)` for a value, or `(None, rest)` for a NULL field. `rest`
/// is the input following the field, and for a value `bytes` is immediately followed by `rest`.
///
/// # Errors
///
/// Returns a [`CompositeError`] if the OID or length header is truncated, or if the declared
/// length exceeds the bytes remaining after the header.
fn munch_composite(input: &[u8]) -> Result<(Option<(i32, &[u8])>, &[u8]), CompositeError> {
    let (oid_bytes, rest) = input
        .split_first_chunk::<4>()
        .ok_or(CompositeError::OidEOF)?;
    let oid = i32::from_be_bytes(*oid_bytes);
    let (size_bytes, rest) = rest
        .split_first_chunk::<4>()
        .ok_or(CompositeError::SizeEOF)?;
    let size = i32::from_be_bytes(*size_bytes);
    if size < 0 {
        return Ok((None, rest));
    }
    let size = size as usize;
    // Compare against the bytes left *after* the 8-byte header. Checking against the whole
    // input let a slightly-too-large length reach `split_at` and panic.
    if size > rest.len() {
        return Err(CompositeError::ValueLengthOutOfBounds);
    }
    let (data, rest) = rest.split_at(size);
    Ok((Some((oid, data)), rest))
}
impl<'b> CompositeRow<'b> {
fn col_buffer(&self, idx: usize) -> Option<&[u8]> {
let (a, b) = self.ranges[idx];
if b == 0 {
return None;
}
Some(&self.data[a as usize..b as usize])
}
pub fn get<'a, T: FromSql<'a>>(&'a self, idx: usize) -> Result<T, Error> {
let Some(column) = self.fields.get(idx) else {
return Err(Error::column_index(idx));
};
let ty = column.type_();
if !T::accepts(ty) {
return Err(Error::from_sql(
Box::new(WrongType::new::<T>(ty.clone())),
idx,
));
}
Ok(FromSql::from_sql_nullable(ty, self.col_buffer(idx)).unwrap())
}
}
impl<'a> CompositeArray<'a> {
pub fn fields(&self) -> &[Field] {
&self.fields
}
pub fn len(&self) -> usize {
self.array.size_hint().0
}
pub fn next<'b>(
&'b mut self,
buffer: &'b mut Vec<(u32, u32)>,
) -> Result<Option<CompositeRow<'b>>, Error> {
let data = match self.array.next() {
Ok(Some(None)) => {
return Err(Error::custom(format!("Unexpected NULL composite").into()))
}
Ok(None) => return Ok(None),
Ok(Some(Some(value))) => value,
Err(err) => return Err(Error::from_sql(err, 0)),
};
buffer.clear();
if data.len() < 4 {
return Err(Error::custom("Missing composite length header".into()));
}
let data: &[u8] = &data[4..];
let mut tdata = data;
while !tdata.is_empty() {
let entry = match munch_composite(&tdata) {
Ok((entry, rest)) => {
tdata = rest;
entry
}
Err(err) => {
return Err(Error::custom(
format!("Invalid composite encoding: {:?}", err).into(),
));
}
};
if let Some((_oid, bytes)) = entry {
let start = unsafe { bytes.as_ptr().offset_from(data.as_ptr()) } as u32;
buffer.push((start, start + bytes.len() as u32))
} else {
buffer.push((0, 0))
}
}
Ok(Some(CompositeRow {
fields: &self.fields,
data,
ranges: buffer,
}))
}
}
/// Positional, typed access to the values of something row-like.
///
/// Implemented for both [`Row`] and [`CompositeRow`], so decoding code can treat them alike.
pub trait Record {
    /// Decodes the value at position `idx` as a `T`.
    fn get<'a, T: FromSql<'a>>(&'a self, idx: usize) -> Result<T, Error>;
}
impl Record for Row {
fn get<'a, T: FromSql<'a>>(&'a self, idx: usize) -> Result<T, Error> {
self.get(idx)
}
}
impl<'a> Record for CompositeRow<'a> {
fn get<'b, T: FromSql<'b>>(&'b self, idx: usize) -> Result<T, Error> {
self.get(idx)
}
}
impl Row {
pub(crate) fn new(statement: Statement, body: DataRowBody) -> Result<Row, Error> {
let ranges = body.ranges().collect().map_err(Error::parse)?;
Ok(Row {
statement,
body,
ranges,
})
}
pub fn get_composite_array<'a>(&'a self, idx: usize) -> Result<CompositeArray<'a>, Error> {
let Some(column) = self.columns().get(idx) else {
return Err(Error::column_index(idx));
};
let ty = column.type_();
if let postgres_types::Kind::Array(arr) = column.type_().kind() {
if let postgres_types::Kind::Composite(fields) = &arr.kind() {
let Some(evts) = self.col_buffer(idx) else {
return Err(Error::row_count());
};
match postgres_protocol::types::array_from_sql(evts) {
Ok(array) => {
return Ok(CompositeArray {
array: array.values(),
fields,
})
}
Err(err) => return Err(Error::from_sql(err, idx)),
}
}
}
return Err(Error::from_sql(
Box::new(WrongType::new::<()>(ty.clone())),
idx,
));
}
pub fn columns(&self) -> &[Column] {
self.statement.columns()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn len(&self) -> usize {
self.columns().len()
}
pub fn get_unwrap<'a, I, T>(&'a self, idx: I) -> T
where
I: RowIndex + fmt::Display,
T: FromSql<'a>,
{
match self.get_inner(&idx) {
Ok(ok) => ok,
Err(err) => panic!("error retrieving column {}: {}", idx, err),
}
}
pub fn get<'a, T: FromSql<'a>>(&'a self, idx: usize) -> Result<T, Error> {
let Some(column) = self.columns().get(idx) else {
return Err(Error::column_index(idx));
};
let ty = column.type_();
if !T::accepts(ty) {
return Err(Error::from_sql(
Box::new(WrongType::new::<T>(ty.clone())),
idx,
));
}
FromSql::from_sql_nullable(ty, self.col_buffer(idx)).map_err(|e| Error::from_sql(e, idx))
}
pub fn get_by<'a, I, T>(&'a self, idx: I) -> Result<T, Error>
where
I: RowIndex + fmt::Display,
T: FromSql<'a>,
{
self.get_inner(&idx)
}
fn get_inner<'a, I, T>(&'a self, idx: &I) -> Result<T, Error>
where
I: RowIndex + fmt::Display,
T: FromSql<'a>,
{
let idx = match idx.__idx(self.columns()) {
Some(idx) => idx,
None => return Err(Error::column(idx.to_string())),
};
let ty = self.columns()[idx].type_();
if !T::accepts(ty) {
return Err(Error::from_sql(
Box::new(WrongType::new::<T>(ty.clone())),
idx,
));
}
FromSql::from_sql_nullable(ty, self.col_buffer(idx)).map_err(|e| Error::from_sql(e, idx))
}
#[doc(hidden)]
pub fn col_buffer(&self, idx: usize) -> Option<&[u8]> {
let range = self.ranges[idx].to_owned()?;
Some(&self.body.buffer()[range])
}
}
impl AsName for SimpleColumn {
fn as_name(&self) -> &str {
self.name()
}
}
/// A row returned by a simple query; its values are decoded as text.
#[derive(Debug)]
pub struct SimpleQueryRow {
    // Column descriptions, held behind an `Arc` so they can be shared cheaply.
    columns: Arc<[SimpleColumn]>,
    // Raw row body whose buffer holds every column's bytes.
    body: DataRowBody,
    // Byte range of each column inside `body.buffer()`; `None` means the column is NULL.
    ranges: Vec<Option<Range<usize>>>,
}
impl SimpleQueryRow {
#[allow(clippy::new_ret_no_self)]
pub(crate) fn new(
columns: Arc<[SimpleColumn]>,
body: DataRowBody,
) -> Result<SimpleQueryRow, Error> {
let ranges = body.ranges().collect().map_err(Error::parse)?;
Ok(SimpleQueryRow {
columns,
body,
ranges,
})
}
pub fn columns(&self) -> &[SimpleColumn] {
&self.columns
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn len(&self) -> usize {
self.columns.len()
}
pub fn get<I>(&self, idx: I) -> Option<&str>
where
I: RowIndex + fmt::Display,
{
match self.get_inner(&idx) {
Ok(ok) => ok,
Err(err) => panic!("error retrieving column {}: {}", idx, err),
}
}
pub fn try_get<I>(&self, idx: I) -> Result<Option<&str>, Error>
where
I: RowIndex + fmt::Display,
{
self.get_inner(&idx)
}
fn get_inner<I>(&self, idx: &I) -> Result<Option<&str>, Error>
where
I: RowIndex + fmt::Display,
{
let idx = match idx.__idx(&self.columns) {
Some(idx) => idx,
None => return Err(Error::column(idx.to_string())),
};
let buf = self.ranges[idx].clone().map(|r| &self.body.buffer()[r]);
FromSql::from_sql_nullable(&Type::TEXT, buf).map_err(|e| Error::from_sql(e, idx))
}
}