use std::str;
use bytes::{Buf, BufMut, BytesMut};
use crate::error::{PgWireError, PgWireResult};
/// Reads a NUL-terminated string from the front of `buf`.
///
/// - If `buf` contains no NUL byte, the string is incomplete: `None` is
///   returned and `buf` is left untouched.
/// - Otherwise the string bytes *and* the terminating NUL are split off `buf`.
///   An empty string (a lone NUL) yields `None`. Any other content is decoded
///   with `String::from_utf8_lossy`, so invalid UTF-8 becomes U+FFFD instead of
///   producing an error.
pub(crate) fn get_cstring(buf: &mut BytesMut) -> Option<String> {
    // Locate the terminator; bail out without consuming anything if it's missing.
    let nul = buf.iter().position(|&byte| byte == b'\0')?;
    // Consume the payload together with its terminator.
    let string_buf = buf.split_to(nul + 1);
    if nul == 0 {
        None
    } else {
        Some(String::from_utf8_lossy(&string_buf[..nul]).into_owned())
    }
}
/// Appends `input` to `buf` as a NUL-terminated string.
///
/// The UTF-8 bytes of `input` are written verbatim, followed by a single `0`
/// byte. Interior NUL bytes in `input` are not checked for; a reader that stops
/// at the first NUL would see a truncated string.
pub(crate) fn put_cstring(buf: &mut BytesMut, input: &str) {
    buf.extend_from_slice(input.as_bytes());
    buf.put_u8(0);
}
pub(crate) fn put_option_cstring(buf: &mut BytesMut, input: &Option<String>) {
if let Some(input) = input {
put_cstring(buf, input);
} else {
buf.put_u8(b'\0');
}
}
/// Peeks at a 4-byte big-endian signed length starting at `offset` in `buf`.
///
/// Returns `None` when fewer than `offset + 4` bytes are buffered. `buf` is
/// never advanced. The `i32` is converted with `as usize`, so a negative value
/// becomes a very large length rather than an error.
pub(crate) fn get_length(buf: &BytesMut, offset: usize) -> Option<usize> {
    let end = offset + 4;
    if buf.remaining() < end {
        return None;
    }
    // Read from a borrowed sub-slice so `buf` itself is not consumed.
    let mut len_field = &buf[offset..end];
    Some(len_field.get_i32() as usize)
}
/// Attempts to decode one length-prefixed message from the front of `buf`.
///
/// The message is expected to start with `offset` bytes, followed by a 4-byte
/// big-endian length (read via `get_length`). That length counts the length
/// field itself but not the `offset` bytes before it, so the complete message
/// occupies `offset + msg_len` bytes.
///
/// # Returns
/// - `Ok(None)` if the length field or the full message body is not yet
///   buffered. `buf` is left untouched in this case.
/// - `Ok(Some(t))` if the whole message was available. The `offset` bytes and
///   the length field are consumed, then `decode_fn(buf, msg_len)` is called
///   to consume and parse the rest.
///
/// # Errors
/// - `PgWireError::MessageTooLarge(max_size, msg_len)` if the declared length
///   exceeds `max_size`. This is checked before waiting for the body, and
///   nothing is consumed.
/// - Any error returned by `decode_fn`.
pub(crate) fn decode_packet<T, F>(
    buf: &mut BytesMut,
    offset: usize,
    max_size: usize,
    decode_fn: F,
) -> PgWireResult<Option<T>>
where
    F: Fn(&mut BytesMut, usize) -> PgWireResult<T>,
{
    if let Some(msg_len) = get_length(buf, offset) {
        // Reject oversized messages before buffering them.
        // A negative i32 length casts to a huge usize and is rejected here too.
        if msg_len > max_size {
            return Err(PgWireError::MessageTooLarge(max_size, msg_len));
        }
        // NOTE(review): a declared length below 4 is not rejected here and is
        // handed to `decode_fn` as-is. Confirm the decoders tolerate it.
        if buf.remaining() >= msg_len + offset {
            // Skip the prefix bytes and the length field; the decoder reads the rest.
            buf.advance(offset + 4);
            return decode_fn(buf, msg_len).map(Some);
        }
    }
    Ok(None)
}
/// Number of bytes an optional NUL-terminated string occupies on the wire.
///
/// This is the UTF-8 byte length of the payload (0 for `None`) plus one byte
/// for the terminator.
pub(crate) fn option_string_len(s: &Option<String>) -> usize {
    s.as_deref().map_or(0, str::len) + 1
}
#[cfg(test)]
mod test {
    use super::{get_cstring, option_string_len, put_cstring, put_option_cstring};
    use bytes::{BufMut, BytesMut};

    #[test]
    fn get_cstring_valid() {
        let mut buf = BytesMut::new();
        buf.put(&b"a cstring\0"[..]);
        buf.put(&b"\0"[..]);
        assert_eq!(Some("a cstring".into()), get_cstring(&mut buf));
        assert_eq!(None, get_cstring(&mut buf));
        // Both terminators, including the lone one for the empty string, are consumed.
        assert!(buf.is_empty());
    }

    #[test]
    fn get_cstring_empty() {
        let mut buf = BytesMut::new();
        assert_eq!(None, get_cstring(&mut buf));
    }

    #[test]
    fn get_cstring_without_null() {
        let mut buf = BytesMut::new();
        buf.put(&b"a cstring"[..]);
        assert_eq!(None, get_cstring(&mut buf));
        // An incomplete string must stay buffered so it can be read once the rest arrives.
        assert_eq!(&buf[..], b"a cstring");
    }

    #[test]
    fn put_then_get_roundtrip() {
        let db = Some("db".to_owned());
        let mut buf = BytesMut::new();
        put_cstring(&mut buf, "user");
        put_option_cstring(&mut buf, &None);
        put_option_cstring(&mut buf, &db);
        // "user\0" plus the encoded sizes of the two optional strings.
        assert_eq!(
            buf.len(),
            5 + option_string_len(&None) + option_string_len(&db)
        );
        assert_eq!(Some("user".to_owned()), get_cstring(&mut buf));
        assert_eq!(None, get_cstring(&mut buf));
        assert_eq!(db, get_cstring(&mut buf));
        assert!(buf.is_empty());
    }
}