#![doc(html_root_url="https://docs.rs/postgres-binary-copy/0.5.0")]
#![warn(missing_docs)]
extern crate byteorder;
extern crate postgres;
extern crate streaming_iterator;
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use postgres::types::{Type, ToSql, IsNull};
use postgres::stmt::{CopyInfo, ReadWithInfo, WriteWithInfo};
use std::cmp;
use std::fmt;
use std::io::prelude::*;
use std::io::{self, Cursor};
use std::mem;
use streaming_iterator::StreamingIterator;
/// Byte signature that opens a binary `COPY` stream: `BinaryCopyReader` emits
/// it at the start of its output and `BinaryCopyWriter` rejects input that does
/// not begin with it.
const HEADER_MAGIC: &'static [u8] = b"PGCOPY\n\xff\r\n\0";
/// Progress of a `BinaryCopyReader` through the stream it generates.
#[derive(Debug, Copy, Clone)]
enum ReadState {
/// Initial state: only the header has been staged, no value was serialized yet.
Header,
/// At least one value was serialized; the payload is the column index of the
/// most recently serialized value.
Body(usize),
/// The value iterator ran dry and the trailer was staged; nothing follows.
Footer,
}
/// A `ReadWithInfo` implementation that produces binary-format `COPY` data
/// from a streaming iterator of `ToSql` values.
///
/// Values are pulled from the iterator one at a time and matched to the column
/// types in order, wrapping around to the first type after the last one.
pub struct BinaryCopyReader<'a, I> {
// Column types; a value at position `i` within a row is serialized as `types[i]`.
types: &'a [Type],
// Which part of the stream is produced next.
state: ReadState,
// Source of the values to serialize.
it: I,
// Staging area holding the bytes that have been generated but not yet read.
buf: Cursor<Vec<u8>>,
}
impl<'a, I> fmt::Debug for BinaryCopyReader<'a, I>
where
    I: fmt::Debug,
{
    /// Renders the column types, the read state and the value iterator; the
    /// staging buffer is left out of the output.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut out = f.debug_struct("BinaryCopyReader");
        out.field("types", &self.types);
        out.field("state", &self.state);
        out.field("it", &self.it);
        out.finish()
    }
}
impl<'a, I> BinaryCopyReader<'a, I>
where
    I: StreamingIterator<Item = ToSql>,
{
    /// Creates a new `BinaryCopyReader`.
    ///
    /// `types` lists the column types of a single row. `it` yields the values
    /// to copy in row-major order: every `types.len()` consecutive values form
    /// one row, and the value at position `i` of a row is serialized as
    /// `types[i]`.
    pub fn new(types: &'a [Type], it: I) -> BinaryCopyReader<'a, I> {
        // The stream opens with the signature followed by two 32-bit words that
        // are both written as zero.
        let mut buf = Vec::with_capacity(HEADER_MAGIC.len() + 2 * mem::size_of::<i32>());
        buf.extend_from_slice(HEADER_MAGIC);
        buf.extend_from_slice(&[0u8; 8]);

        BinaryCopyReader {
            types: types,
            state: ReadState::Header,
            it: it,
            buf: Cursor::new(buf),
        }
    }

    /// Replaces the staging buffer with the next piece of the stream: one
    /// serialized value (preceded by the field count when it starts a row), the
    /// trailer once the iterator is exhausted, or nothing after the trailer.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` when
    ///
    /// * a value is produced although `types` is empty,
    /// * the iterator ends in the middle of a row,
    /// * there are more columns than fit in an `i16` or a serialized value is
    ///   longer than fits in an `i32`,
    /// * the value cannot be converted to its column type.
    fn fill_buf(&mut self, _: &CopyInfo) -> io::Result<()> {
        enum Op<'a> {
            Value(usize, &'a ToSql),
            Footer,
            Nothing,
        }

        let column_count = self.types.len();

        let op = match (self.state, self.it.next()) {
            (ReadState::Footer, _) => Op::Nothing,
            // Without this guard the modulo and the `types[idx]` lookup below
            // would panic.
            (_, Some(_)) if column_count == 0 => {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "cannot serialize a value without any column types",
                ));
            }
            (ReadState::Header, Some(value)) => {
                self.state = ReadState::Body(0);
                Op::Value(0, value)
            }
            (ReadState::Body(old_idx), Some(value)) => {
                let idx = (old_idx + 1) % column_count;
                self.state = ReadState::Body(idx);
                Op::Value(idx, value)
            }
            // `idx` is the column of the last value written; ending anywhere but
            // on the final column would put the trailer inside a tuple.
            (ReadState::Body(idx), None) if idx + 1 != column_count => {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "values ended in the middle of a row",
                ));
            }
            (ReadState::Header, None) | (ReadState::Body(_), None) => {
                self.state = ReadState::Footer;
                Op::Footer
            }
        };

        self.buf.set_position(0);
        self.buf.get_mut().clear();

        match op {
            Op::Value(idx, value) => {
                if idx == 0 {
                    // A row starts with its field count as a 16-bit integer.
                    if column_count > i16::max_value() as usize {
                        return Err(io::Error::new(
                            io::ErrorKind::InvalidInput,
                            "value too large to transmit",
                        ));
                    }
                    self.buf.write_i16::<BigEndian>(column_count as i16)?;
                }

                // Reserve the 32-bit length word, let the value append its bytes
                // to the underlying vector, then go back and patch the length.
                let len_pos = self.buf.position();
                self.buf.write_i32::<BigEndian>(0)?;
                let len = match value.to_sql_checked(&self.types[idx], self.buf.get_mut()) {
                    // NULL is transmitted as length -1 with no payload.
                    Ok(IsNull::Yes) => -1,
                    Ok(IsNull::No) => {
                        let len = self.buf.get_ref().len() as u64 - 4 - len_pos;
                        if len > i32::max_value() as u64 {
                            return Err(io::Error::new(
                                io::ErrorKind::InvalidInput,
                                "value too large to transmit",
                            ));
                        } else {
                            len as i32
                        }
                    }
                    Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
                };
                self.buf.set_position(len_pos);
                self.buf.write_i32::<BigEndian>(len)?;
            }
            Op::Footer => {
                // The trailer is a field count of -1.
                self.buf.write_i16::<BigEndian>(-1)?;
            }
            Op::Nothing => {}
        }

        // Rewind so that the next read starts at the beginning of the new data.
        self.buf.set_position(0);
        Ok(())
    }
}
impl<'a, I> ReadWithInfo for BinaryCopyReader<'a, I>
where
    I: StreamingIterator<Item = ToSql>,
{
    /// Copies staged bytes into `buf`, generating the next piece of the stream
    /// first when everything staged so far has already been handed out.
    ///
    /// Returns the number of bytes copied, or the error raised while
    /// generating new data.
    fn read_with_info(&mut self, buf: &mut [u8], info: &CopyInfo) -> io::Result<usize> {
        let consumed = self.buf.position();
        let staged = self.buf.get_ref().len() as u64;
        if consumed == staged {
            self.fill_buf(info)?;
        }
        Read::read(&mut self.buf, buf)
    }
}
/// A sink for the field values that `BinaryCopyWriter` extracts from binary
/// `COPY` output.
///
/// The trait is implemented for every
/// `FnMut(Option<&[u8]>, &CopyInfo) -> io::Result<()>` closure.
pub trait WriteValue {
/// Handles a non-null field; `r` holds the field's raw bytes.
fn write_value(&mut self, r: &[u8], info: &CopyInfo) -> io::Result<()>;
/// Handles a null field.
fn write_null_value(&mut self, info: &CopyInfo) -> io::Result<()>;
}
impl<F> WriteValue for F
where
    F: FnMut(Option<&[u8]>, &CopyInfo) -> io::Result<()>,
{
    /// Invokes the closure with `Some(r)`.
    fn write_value(&mut self, r: &[u8], info: &CopyInfo) -> io::Result<()> {
        let payload = Some(r);
        (*self)(payload, info)
    }

    /// Invokes the closure with `None` to signal a null value.
    fn write_null_value(&mut self, info: &CopyInfo) -> io::Result<()> {
        let payload = None;
        (*self)(payload, info)
    }
}
/// Position of a `BinaryCopyWriter` within the stream it is parsing.
#[derive(Debug)]
enum WriteState {
/// Expecting the stream header.
AtHeader,
/// Expecting the 16-bit field count of the next tuple, or the -1 trailer.
AtTuple,
/// Expecting a 32-bit field size; the payload is the number of fields still
/// to be read in the current tuple, including this one.
AtFieldSize(usize),
/// Expecting `size` bytes of field data; `remaining` counts the fields still
/// to be read in the current tuple, including this one.
AtField { size: usize, remaining: usize },
/// The trailer was seen; any further input is rejected.
Done,
}
/// A `WriteWithInfo` implementation that parses binary-format `COPY` output
/// and passes every field it finds to a `WriteValue`.
pub struct BinaryCopyWriter<W> {
// Which element of the stream is expected next.
state: WriteState,
// Whether the header flags announced an OID in front of each tuple's fields.
has_oids: bool,
// Receiver of the decoded field values.
value_writer: W,
// Bytes collected so far for the element currently being parsed.
buf: Vec<u8>,
}
impl<W> fmt::Debug for BinaryCopyWriter<W>
where
    W: fmt::Debug,
{
    /// Renders the parser state, the OID flag and the value writer; the
    /// internal buffer is represented by its length only.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let buffered = self.buf.len();
        let mut out = f.debug_struct("BinaryCopyWriter");
        out.field("state", &self.state);
        out.field("has_oids", &self.has_oids);
        out.field("value_writer", &self.value_writer);
        out.field("buf", &buffered);
        out.finish()
    }
}
impl<W> BinaryCopyWriter<W>
where
W: WriteValue,
{
pub fn new(value_writer: W) -> BinaryCopyWriter<W> {
BinaryCopyWriter {
state: WriteState::AtHeader,
has_oids: false,
value_writer: value_writer,
buf: Vec::new(),
}
}
fn read_to(&mut self, buf: &[u8], size: usize) -> io::Result<(bool, usize)> {
let to_read = cmp::min(size - self.buf.len(), buf.len());
let nread = self.buf.write(&buf[..to_read])?;
Ok((nread == to_read, nread))
}
fn read_header(&mut self, buf: &[u8]) -> io::Result<usize> {
let header_size = HEADER_MAGIC.len() + mem::size_of::<i32>() * 2;
let (done, nread) = self.read_to(buf, header_size)?;
if !done {
return Ok(nread);
}
if !self.buf.starts_with(HEADER_MAGIC) {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid header",
));
}
let flags = (&mut &self.buf[HEADER_MAGIC.len()..])
.read_i32::<BigEndian>()?;
self.has_oids = (flags & 1 << 16) != 0;
if (flags & !0 << 17) != 0 {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"critical file format issue",
));
}
self.buf.clear();
self.state = WriteState::AtTuple;
Ok(nread)
}
fn read_tuple(&mut self, buf: &[u8]) -> io::Result<usize> {
let (done, nread) = self.read_to(buf, mem::size_of::<i16>())?;
if !done {
return Ok(nread);
}
let mut tuple_size = (&mut &self.buf[..]).read_i16::<BigEndian>()?;
self.buf.clear();
if tuple_size == -1 {
self.state = WriteState::Done;
Ok(nread)
} else {
if self.has_oids {
tuple_size += 1;
}
self.state = WriteState::AtFieldSize(tuple_size as usize);
Ok(nread)
}
}
fn read_field_size(
&mut self,
buf: &[u8],
info: &CopyInfo,
remaining: usize,
) -> io::Result<usize> {
let (done, nread) = self.read_to(buf, mem::size_of::<i32>())?;
if !done {
return Ok(nread);
}
let field_size = (&mut &self.buf[..]).read_i32::<BigEndian>()?;
self.buf.clear();
if field_size == -1 {
self.value_writer.write_null_value(info)?;
self.advance_field_state(remaining);
} else {
self.state = WriteState::AtField {
size: field_size as usize,
remaining: remaining,
};
}
Ok(nread)
}
fn advance_field_state(&mut self, remaining: usize) {
self.state = if remaining == 1 {
WriteState::AtTuple
} else {
WriteState::AtFieldSize(remaining - 1)
};
}
fn read_field(
&mut self,
buf: &[u8],
info: &CopyInfo,
size: usize,
remaining: usize,
) -> io::Result<usize> {
let (done, nread) = self.read_to(buf, size)?;
if !done {
return Ok(nread);
}
self.value_writer.write_value(&self.buf, info)?;
self.buf.clear();
self.advance_field_state(remaining);
Ok(nread)
}
}
impl<W> WriteWithInfo for BinaryCopyWriter<W>
where
W: WriteValue,
{
fn write_with_info(&mut self, buf: &[u8], info: &CopyInfo) -> io::Result<usize> {
match self.state {
WriteState::AtHeader => self.read_header(buf),
WriteState::AtTuple => self.read_tuple(buf),
WriteState::AtFieldSize(remaining) => self.read_field_size(buf, info, remaining),
WriteState::AtField { size, remaining } => self.read_field(buf, info, size, remaining),
WriteState::Done => {
Err(io::Error::new(
io::ErrorKind::InvalidInput,
"unexpected input after EOF",
))
}
}
}
}
/// Integration tests; every test talks to a PostgreSQL server on `localhost`
/// using the `postgres` / `password` credentials.
#[cfg(test)]
mod test {
    use super::*;
    use postgres::{Connection, TlsMode};
    use postgres::types::{FromSql, ToSql, INT4, VARCHAR, BYTEA, OID};
    use postgres::stmt::CopyInfo;
    use streaming_iterator::{convert, StreamingIterator};

    /// Opens the connection shared by all tests (one fresh connection per call).
    fn connect() -> Connection {
        Connection::connect("postgres://postgres:password@localhost", TlsMode::None).unwrap()
    }

    /// Copies two rows in, one of them with a NULL column, and reads them back.
    #[test]
    fn write_basic() {
        let conn = connect();
        conn.execute(
            "CREATE TEMPORARY TABLE foo (id INT PRIMARY KEY, bar VARCHAR)",
            &[],
        ).unwrap();

        let stmt = conn.prepare("COPY foo (id, bar) FROM STDIN BINARY")
            .unwrap();

        let types = &[INT4, VARCHAR];
        let values: Vec<Box<ToSql>> = vec![
            Box::new(1i32),
            Box::new("foobar"),
            Box::new(2i32),
            Box::new(None::<String>),
        ];
        let values = convert(values.into_iter()).map_ref(|v| &**v);
        let mut reader = BinaryCopyReader::new(types, values);

        stmt.copy_in(&[], &mut reader).unwrap();

        let stmt = conn.prepare("SELECT id, bar FROM foo ORDER BY id").unwrap();
        assert_eq!(
            vec![(1i32, Some("foobar".to_string())), (2i32, None)],
            stmt.query(&[])
                .unwrap()
                .into_iter()
                .map(|r| (r.get(0), r.get(1)))
                .collect::<Vec<(i32, Option<String>)>>()
        );
    }

    /// Copies 10,000 small rows in and checks each of them.
    #[test]
    fn write_many_rows() {
        let conn = connect();
        conn.execute(
            "CREATE TEMPORARY TABLE foo (id INT PRIMARY KEY, bar VARCHAR)",
            &[],
        ).unwrap();

        let stmt = conn.prepare("COPY foo (id, bar) FROM STDIN BINARY")
            .unwrap();

        let types = &[INT4, VARCHAR];
        let mut values: Vec<Box<ToSql>> = vec![];
        for i in 0..10_000i32 {
            values.push(Box::new(i));
            values.push(Box::new(format!("the value for {}", i)));
        }
        let values = convert(values.into_iter()).map_ref(|v| &**v);
        let mut reader = BinaryCopyReader::new(types, values);

        stmt.copy_in(&[], &mut reader).unwrap();

        let stmt = conn.prepare("SELECT id, bar FROM foo ORDER BY id").unwrap();
        let result = stmt.query(&[]).unwrap();
        assert_eq!(10000, result.len());
        for (i, row) in result.into_iter().enumerate() {
            assert_eq!(i as i32, row.get(0));
            assert_eq!(format!("the value for {}", i), row.get::<_, String>(1));
        }
    }

    /// Copies two rows in whose BYTEA column is 128 KiB each.
    #[test]
    fn write_big_rows() {
        let conn = connect();
        conn.execute(
            "CREATE TEMPORARY TABLE foo (id INT PRIMARY KEY, bar BYTEA)",
            &[],
        ).unwrap();

        let stmt = conn.prepare("COPY foo (id, bar) FROM STDIN BINARY")
            .unwrap();

        let types = &[INT4, BYTEA];
        let mut values: Vec<Box<ToSql>> = vec![];
        for i in 0..2i32 {
            values.push(Box::new(i));
            values.push(Box::new(vec![i as u8; 128 * 1024]));
        }
        let values = convert(values.into_iter()).map_ref(|v| &**v);
        let mut reader = BinaryCopyReader::new(types, values);

        stmt.copy_in(&[], &mut reader).unwrap();

        let stmt = conn.prepare("SELECT id, bar FROM foo ORDER BY id").unwrap();
        let result = stmt.query(&[]).unwrap();
        assert_eq!(2, result.len());
        for (i, row) in result.into_iter().enumerate() {
            assert_eq!(i as i32, row.get(0));
            assert_eq!(vec![i as u8; 128 * 1024], row.get::<_, Vec<u8>>(1));
        }
    }

    /// Copies four rows out, one of them NULL, through a closure writer.
    #[test]
    fn read_basic() {
        let conn = connect();
        conn.execute(
            "CREATE TEMPORARY TABLE foo (id SERIAL PRIMARY KEY, bar INT)",
            &[],
        ).unwrap();
        conn.execute("INSERT INTO foo (bar) VALUES (1), (2), (NULL), (4)", &[])
            .unwrap();

        let mut out = vec![];

        {
            let writer = |r: Option<&[u8]>, _: &CopyInfo| {
                match r {
                    Some(r) => out.push(Option::<i32>::from_sql(&INT4, r).unwrap()),
                    None => out.push(Option::<i32>::from_sql_null(&INT4).unwrap()),
                }
                Ok(())
            };

            let mut writer = BinaryCopyWriter::new(writer);

            let stmt = conn.prepare("COPY (SELECT bar FROM foo ORDER BY id) TO STDOUT BINARY")
                .unwrap();
            stmt.copy_out(&[], &mut writer).unwrap();
        }

        assert_eq!(out, [Some(1), Some(2), None, Some(4)]);
    }

    /// Copies 10,000 rows out and compares them with the inserted values.
    #[test]
    fn read_many_rows() {
        let conn = connect();
        conn.execute("CREATE TEMPORARY TABLE foo (id INT)", &[])
            .unwrap();

        let mut expected = vec![];
        let stmt = conn.prepare("INSERT INTO foo (id) VALUES ($1)").unwrap();
        for i in 0..10_000i32 {
            stmt.execute(&[&i]).unwrap();
            expected.push(i);
        }

        let mut out = vec![];

        {
            let writer = |r: Option<&[u8]>, _: &CopyInfo| {
                out.push(i32::from_sql(&INT4, r.unwrap()).unwrap());
                Ok(())
            };

            let mut writer = BinaryCopyWriter::new(writer);

            let stmt = conn.prepare("COPY (SELECT id FROM foo ORDER BY id) TO STDOUT BINARY")
                .unwrap();
            stmt.copy_out(&[], &mut writer).unwrap();
        }

        assert_eq!(out, expected);
    }

    /// Copies two 128 KiB BYTEA values out.
    #[test]
    fn read_big_rows() {
        let conn = connect();
        conn.execute(
            "CREATE TEMPORARY TABLE foo (id INT PRIMARY KEY, bar BYTEA)",
            &[],
        ).unwrap();

        let mut expected = vec![];
        let stmt = conn.prepare("INSERT INTO foo (id, bar) VALUES ($1, $2)")
            .unwrap();
        for i in 0..2i32 {
            let value = vec![i as u8; 128 * 1024];
            stmt.execute(&[&i, &value]).unwrap();
            expected.push(value);
        }

        let mut out = vec![];

        {
            let writer = |r: Option<&[u8]>, _: &CopyInfo| {
                out.push(Vec::<u8>::from_sql(&BYTEA, r.unwrap()).unwrap());
                Ok(())
            };

            let mut writer = BinaryCopyWriter::new(writer);

            let stmt = conn.prepare(
                "COPY (SELECT bar FROM foo ORDER BY id) TO STDOUT (FORMAT \
                 binary)",
            ).unwrap();
            stmt.copy_out(&[], &mut writer).unwrap();
        }

        assert_eq!(out, expected);
    }

    /// Copies rows out with the `OIDS` option and checks that each row delivers
    /// an OID followed by its INT column.
    // NOTE: relies on `WITH OIDS`, which newer PostgreSQL releases (12 and
    // later) no longer accept.
    #[test]
    fn read_with_oids() {
        let conn = connect();
        conn.execute("CREATE TEMPORARY TABLE foo (id INT) WITH OIDS", &[])
            .unwrap();
        conn.execute("INSERT INTO foo (id) VALUES (1), (2), (3), (4)", &[])
            .unwrap();

        let mut oids = vec![];
        let mut out = vec![];

        {
            // The OID arrives before the row's column, so the two vectors are
            // filled alternately.
            let writer = |r: Option<&[u8]>, _: &CopyInfo| {
                if oids.len() > out.len() {
                    // `id` is an INT column, so it must be decoded as INT4.
                    out.push(i32::from_sql(&INT4, r.unwrap()).unwrap());
                } else {
                    oids.push(u32::from_sql(&OID, r.unwrap()).unwrap());
                }
                Ok(())
            };

            let mut writer = BinaryCopyWriter::new(writer);

            let stmt = conn.prepare("COPY foo (id) TO STDOUT (FORMAT binary, OIDS)")
                .unwrap();
            stmt.copy_out(&[], &mut writer).unwrap();
        }

        assert_eq!(oids.len(), out.len());
        assert_eq!(out, [1, 2, 3, 4]);
    }
}