simple_pg 0.5.8

Provides extensions and utilities for working with PostgreSQL.
Documentation
use postgres_types::Type;

/// Failure to parse a record/composite field; see [`composite_from_sql`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompositeError {
    /// The input ended before a field's type OID could be read.
    OidEOF,
    /// The input ended before a field's length could be read.
    SizeEOF,
    /// A field's declared length exceeds the bytes remaining in the input.
    SizeTooBig,
}

/// Human-readable description of a [`CompositeError`].
impl std::fmt::Display for CompositeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            CompositeError::OidEOF => "Oid EOF parsing composite",
            CompositeError::SizeEOF => "Size EOF parsing composite",
            CompositeError::SizeTooBig => "Size too big parsing composite",
        };
        f.write_str(message)
    }
}

impl std::error::Error for CompositeError {}

/// Iterate over the fields of a binary-encoded PostgreSQL record.
///
/// `bytes` is expected to start with a 4-byte field count (skipped, not validated),
/// followed by one entry per field: a big-endian `u32` type OID, a big-endian `i32`
/// length, and then `length` bytes of data. A negative length marks a NULL field,
/// which is yielded as `None` data.
///
/// Each item pairs the field's type (`None` when [`Type::from_oid`] does not
/// recognise the OID) with its raw bytes.
///
/// # Errors
///
/// On malformed input a single [`CompositeError`] is yielded and the iterator then
/// ends: [`CompositeError::OidEOF`] / [`CompositeError::SizeEOF`] when an OID or
/// length header is cut short, [`CompositeError::SizeTooBig`] when a declared length
/// exceeds the remaining data.
pub fn record_from_sql(
    mut bytes: &[u8],
) -> impl Iterator<Item = Result<(Option<Type>, Option<&[u8]>), CompositeError>> {
    // Skip the leading field count. `>=` so that a zero-field record (exactly four
    // bytes) yields an empty iterator instead of a spurious error.
    if bytes.len() >= 4 {
        bytes = &bytes[4..];
    }
    std::iter::from_fn(move || {
        if bytes.is_empty() {
            return None;
        }
        match split_record_field(bytes) {
            Ok((field, rest)) => {
                bytes = rest;
                Some(Ok(field))
            }
            Err(err) => {
                // A broken header cannot be skipped past; report it once and stop
                // rather than yielding the same error forever.
                bytes = &[];
                Some(Err(err))
            }
        }
    })
}

/// Split one field off the front of `bytes`, returning the field and the rest.
fn split_record_field(
    bytes: &[u8],
) -> Result<((Option<Type>, Option<&[u8]>), &[u8]), CompositeError> {
    let (oid_bytes, rest) = bytes
        .split_first_chunk::<4>()
        .ok_or(CompositeError::OidEOF)?;
    let oid = Type::from_oid(u32::from_be_bytes(*oid_bytes));
    let (size_bytes, rest) = rest
        .split_first_chunk::<4>()
        .ok_or(CompositeError::SizeEOF)?;
    // A negative length marks a NULL field with no data bytes following it.
    let Ok(size) = usize::try_from(i32::from_be_bytes(*size_bytes)) else {
        return Ok(((oid, None), rest));
    };
    let (data, rest) = rest
        .split_at_checked(size)
        .ok_or(CompositeError::SizeTooBig)?;
    Ok(((oid, Some(data)), rest))
}

/// Iterate over the raw bytes of each field of a binary-encoded composite value.
///
/// This is [`record_from_sql`] with the per-field type information dropped; a
/// `None` item represents a NULL field. Errors are passed through unchanged.
pub fn composite_from_sql(
    bytes: &[u8],
) -> impl Iterator<Item = Result<Option<&[u8]>, CompositeError>> {
    record_from_sql(bytes).map(|field| match field {
        Ok((_field_type, data)) => Ok(data),
        Err(err) => Err(err),
    })
}

/// Errors produced while decoding a binary `numeric` value.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum NumericParseError {
    /// The input ended before the header or every declared digit was read.
    UnexpectedEof,
    /// The digit count was negative, or a base-10000 digit was outside `0..=9999`.
    InvalidDigit,
}

/// Decode a binary-encoded PostgreSQL `numeric` into an `f64`.
///
/// The value is first rendered as a decimal string (see `numeric_string_from_sql`)
/// and then parsed with `str::parse::<f64>`. `NaN`, `Infinity` and `-Infinity`
/// map to the corresponding `f64` special values.
///
/// # Errors
///
/// Returns an error if the input is truncated or malformed, or if the rendered
/// string does not parse as an `f64`.
pub fn numeric_f64_from_sql(
    bytes: &[u8],
) -> Result<f64, Box<dyn std::error::Error + Send + Sync>> {
    // todo optimize: decode the digits directly instead of going through a string.
    match numeric_string_from_sql(bytes) {
        Ok(value) => Ok(value.parse()?),
        Err(NumericParseError::UnexpectedEof) => Err("Unexpected EOF while parsing Numeric".into()),
        Err(NumericParseError::InvalidDigit) => Err("Invalid Digit while parsing Numeric".into()),
    }
}

/// Render a binary-encoded PostgreSQL `numeric` as a decimal string.
///
/// Input layout: four big-endian `i16` header fields (`ndigits`, `weight`, `sign`,
/// `dscale`), followed by `ndigits` big-endian `i16` base-10000 digits.
///
/// Formatting is a port of PostgreSQL's `get_str_from_var`, with one difference:
/// the fractional part is emitted in whole base-10000 digits and then all trailing
/// zeros are stripped. PostgreSQL instead truncates to exactly `dscale` digits.
/// As a result, `1.50` renders as `1.5` and `1.00` as `1.`. Special values render
/// as `NaN`, `Infinity` and `-Infinity`.
///
/// # Errors
///
/// * [`NumericParseError::UnexpectedEof`] if the header or a digit is truncated.
/// * [`NumericParseError::InvalidDigit`] if `ndigits` is negative or a digit is
///   outside `0..=9999`.
fn numeric_string_from_sql(mut bytes: &[u8]) -> Result<String, NumericParseError> {
    // Decimal digits per base-10000 digit (PostgreSQL's `DEC_DIGITS`).
    const DEC_DIGITS: i32 = 4;
    // Values of the `sign` header field, as defined in PostgreSQL's numeric.c.
    const NUMERIC_NEG: u16 = 0x4000;
    const NUMERIC_NAN: u16 = 0xC000;
    const NUMERIC_PINF: u16 = 0xD000;
    const NUMERIC_NINF: u16 = 0xF000;

    let mut next_i16 = move || -> Result<i16, NumericParseError> {
        let (chunk, rest) = bytes
            .split_first_chunk::<2>()
            .ok_or(NumericParseError::UnexpectedEof)?;
        bytes = rest;
        Ok(i16::from_be_bytes(*chunk))
    };

    // A negative count must be rejected: cast to `usize` it becomes a huge
    // allocation request.
    let ndigits = usize::try_from(next_i16()?).map_err(|_| NumericParseError::InvalidDigit)?;
    // Widen to i32 so `weight + 1` and digit-position arithmetic cannot overflow.
    let weight = i32::from(next_i16()?);
    // The sign field is a bit pattern; reinterpret it as unsigned.
    let sign = next_i16()? as u16;
    let dscale = i32::from(next_i16()?);
    let digits = (0..ndigits)
        .map(|_| -> Result<u16, NumericParseError> {
            let digit = next_i16()?;
            // Anything outside 0..=9999 cannot be rendered as four decimal digits.
            u16::try_from(digit)
                .ok()
                .filter(|&d| d < 10_000)
                .ok_or(NumericParseError::InvalidDigit)
        })
        .collect::<Result<Vec<u16>, NumericParseError>>()?;

    match sign {
        NUMERIC_NAN => return Ok(String::from("NaN")),
        NUMERIC_PINF => return Ok(String::from("Infinity")),
        NUMERIC_NINF => return Ok(String::from("-Infinity")),
        _ => {}
    }

    // Positions outside the stored digits are implicit zeros.
    let digit_at = |pos: i32| -> u16 {
        usize::try_from(pos)
            .ok()
            .and_then(|i| digits.get(i))
            .copied()
            .unwrap_or(0)
    };

    // --- A port of `get_str_from_var` ---
    // Upper bound on the output length: the integer part (at least one char),
    // the fraction, one spare base-10000 digit, plus sign and decimal point.
    let int_len = ((weight + 1) * DEC_DIGITS).max(1);
    let capacity = int_len + dscale.max(0) + DEC_DIGITS + 2;
    let mut out = String::with_capacity(usize::try_from(capacity).unwrap_or(0));

    // Output a dash for negative values.
    if sign == NUMERIC_NEG {
        out.push('-');
    }
    // Output all digits before the decimal point.
    if weight < 0 {
        out.push('0');
    } else {
        for pos in 0..=weight {
            // Only the leading group drops its leading zeros.
            push_dec_digits(&mut out, digit_at(pos), pos > 0);
        }
    }
    // If requested, output a decimal point and the digits that follow it: whole
    // base-10000 groups covering `dscale`, then trailing zeros stripped.
    if dscale > 0 {
        out.push('.');
        let first = weight + 1;
        let groups = (dscale + DEC_DIGITS - 1) / DEC_DIGITS;
        for pos in first..first + groups {
            push_dec_digits(&mut out, digit_at(pos), true);
        }
        // Stops at the '.' at the latest, so integer digits are never removed.
        let kept = out.trim_end_matches('0').len();
        out.truncate(kept);
    }
    Ok(out)
}

/// Append a base-10000 digit (`0..=9999`) as decimal characters: exactly four,
/// zero-padded, when `pad` is set; otherwise without leading zeros (at least one).
fn push_dec_digits(out: &mut String, digit: u16, pad: bool) {
    let decimal = [digit / 1000, digit / 100 % 10, digit / 10 % 10, digit % 10];
    // Never skip the units position, so zero still renders as "0".
    let skip = if pad {
        0
    } else {
        decimal[..3].iter().take_while(|&&d| d == 0).count()
    };
    out.extend(decimal[skip..].iter().map(|&d| char::from(b'0' + d as u8)));
}

#[cfg(test)]
mod test {
    use super::*;

    /// Binary `numeric` payloads and the string each must render as.
    const NUMERIC_CASES: &[(&[u8], &str)] = &[
        // weight 0, dscale 16: the trailing zero of the last group is stripped.
        (
            &[0, 5, 0, 0, 0, 0, 0, 16, 0, 33, 28, 32, 36, 253, 27, 57, 27, 58],
            "33.720094696969697",
        ),
        // weight -5, negative sign, dscale 19.
        (&[0, 1, 255, 251, 64, 0, 0, 19, 0, 30], "-0.0000000000000000003"),
    ];

    #[test]
    fn numeric() {
        for &(input, expected) in NUMERIC_CASES {
            let rendered = numeric_string_from_sql(input).unwrap();
            assert_eq!(rendered, expected, "input {input:?}");
        }
    }
}