#![doc(html_root_url = "https://docs.rs/postgres_large_object/0.7")]
extern crate postgres;
use postgres::{GenericConnection, Result};
use postgres::transaction::Transaction;
use postgres::types::Oid;
use std::cmp;
use std::fmt;
use std::i32;
use std::io::{self, Write};
/// Helpers for creating and deleting PostgreSQL large objects.
///
/// These operations do not need an open descriptor, so they are available on
/// any connection-like value rather than only inside a transaction.
pub trait LargeObjectExt {
    /// Creates a new large object and returns its OID.
    ///
    /// The OID is chosen by the server (`lo_create` is called with 0).
    fn create_large_object(&self) -> Result<Oid>;

    /// Deletes the large object identified by `oid` via `lo_unlink`.
    fn delete_large_object(&self, oid: Oid) -> Result<()>;
}
impl<T: GenericConnection> LargeObjectExt for T {
    fn create_large_object(&self) -> Result<Oid> {
        let stmt = self.prepare_cached("SELECT pg_catalog.lo_create(0)")?;
        let rows = stmt.query(&[])?;
        // `lo_create` always yields a single row containing the new OID.
        let oid: Oid = rows.iter().next().unwrap().get(0);
        Ok(oid)
    }

    fn delete_large_object(&self, oid: Oid) -> Result<()> {
        let stmt = self.prepare_cached("SELECT pg_catalog.lo_unlink($1)")?;
        // The row count is irrelevant; only success or failure matters.
        stmt.execute(&[&oid])?;
        Ok(())
    }
}
/// The access mode requested when opening a large object.
///
/// Mode is a small fieldless enum, so it is `Copy` and comparable. Callers can
/// reuse and compare modes without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Mode {
    /// Request read access.
    Read,
    /// Request write access.
    Write,
    /// Request both read and write access.
    ReadWrite,
}
impl Mode {
    /// Converts the mode into the integer flag word passed to `lo_open`.
    fn to_i32(&self) -> i32 {
        // Bit flags understood by the server's `lo_open` (INV_WRITE / INV_READ).
        const WRITE_FLAG: i32 = 0x00020000;
        const READ_FLAG: i32 = 0x00040000;

        match *self {
            Mode::ReadWrite => READ_FLAG | WRITE_FLAG,
            Mode::Write => WRITE_FLAG,
            Mode::Read => READ_FLAG,
        }
    }
}
/// Opening large objects, available on transactions.
pub trait LargeObjectTransactionExt {
    /// Opens the large object `oid` with the given access `mode`.
    ///
    /// The returned handle borrows `self` for `'a`, so it cannot outlive the
    /// transaction it was opened in.
    fn open_large_object<'a>(&'a self, oid: Oid, mode: Mode) -> Result<LargeObject<'a>>;
}
impl<'conn> LargeObjectTransactionExt for Transaction<'conn> {
    /// Opens `oid` via `lo_open` and wraps the returned descriptor.
    ///
    /// Also records whether the server is 9.3 or newer. Only then are the
    /// 64-bit `lo_*64` functions used for seeking and truncation.
    ///
    /// # Errors
    ///
    /// Propagates any error from preparing or running `lo_open`, e.g. when the
    /// object does not exist.
    fn open_large_object<'a>(&'a self, oid: Oid, mode: Mode) -> Result<LargeObject<'a>> {
        // The version comes from the server. If it is somehow absent, fall back
        // to the 32-bit functions instead of panicking inside the library.
        let has_64 = match self.connection().parameter("server_version") {
            Some(version) => {
                let (major, minor) = parse_version(&version);
                major > 9 || (major == 9 && minor >= 3)
            }
            None => false,
        };

        let stmt = self.prepare_cached("SELECT pg_catalog.lo_open($1, $2)")?;
        let rows = stmt.query(&[&oid, &mode.to_i32()])?;
        // `lo_open` returns exactly one row holding the new descriptor.
        let fd: i32 = rows.iter().next().unwrap().get(0);

        Ok(LargeObject {
            trans: self,
            fd: fd,
            has_64: has_64,
            finished: false,
        })
    }
}
/// A handle to a large object opened inside a transaction.
///
/// The handle borrows the `Transaction` it was opened in and issues every
/// operation through it.
pub struct LargeObject<'a> {
    /// Transaction through which all large-object calls are issued.
    trans: &'a Transaction<'a>,
    /// Descriptor returned by `lo_open`.
    fd: i32,
    /// Whether the server supports the 64-bit `lo_*64` functions.
    has_64: bool,
    /// Set once `lo_close` has been attempted, so it is not sent twice.
    finished: bool,
}
impl<'a> fmt::Debug for LargeObject<'a> {
    /// Shows the descriptor and owning transaction; internal flags are omitted.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut builder = f.debug_struct("LargeObject");
        builder.field("fd", &self.fd);
        builder.field("transaction", &self.trans);
        builder.finish()
    }
}
impl<'a> Drop for LargeObject<'a> {
    /// Closes the descriptor if `finish` was not called.
    fn drop(&mut self) {
        // A destructor has no way to report failure, so any close error is
        // discarded. Use `finish` to observe it.
        match self.finish_inner() {
            Ok(()) | Err(_) => {}
        }
    }
}
impl<'a> LargeObject<'a> {
    /// Returns the descriptor that `lo_open` assigned to this object.
    pub fn fd(&self) -> i32 {
        self.fd
    }

    /// Truncates the large object to `len` bytes.
    ///
    /// Truncating to a length past the current end extends the object with
    /// zero bytes. Uses `lo_truncate64` when the server supports the 64-bit
    /// large-object functions and `lo_truncate` otherwise.
    ///
    /// # Errors
    ///
    /// On a server without the 64-bit functions, a `len` outside the `i32`
    /// range is rejected with an `InvalidInput` I/O error before anything is
    /// sent. Errors reported by the server are returned unchanged.
    pub fn truncate(&mut self, len: i64) -> Result<()> {
        if self.has_64 {
            let stmt = self
                .trans
                .prepare_cached("SELECT pg_catalog.lo_truncate64($1, $2)")?;
            return stmt.execute(&[&self.fd, &len]).map(|_| ());
        }

        // Range-check both ends before narrowing. A bare `as i32` wraps
        // silently: e.g. -2^32 would become 0 and empty the object.
        let len = if len > i32::max_value() as i64 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "The database does not support objects larger than 2GB",
            )
            .into());
        } else if len < i32::min_value() as i64 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "invalid large object length",
            )
            .into());
        } else {
            len as i32
        };
        let stmt = self
            .trans
            .prepare_cached("SELECT pg_catalog.lo_truncate($1, $2)")?;
        stmt.execute(&[&self.fd, &len]).map(|_| ())
    }

    /// Issues `lo_close` once; later calls are no-ops.
    fn finish_inner(&mut self) -> Result<()> {
        if self.finished {
            return Ok(());
        }
        // Mark finished before closing so a failed close is not retried
        // again from `Drop`.
        self.finished = true;
        let stmt = self.trans.prepare_cached("SELECT pg_catalog.lo_close($1)")?;
        stmt.execute(&[&self.fd]).map(|_| ())
    }

    /// Closes the large object, returning any error from `lo_close`.
    ///
    /// Dropping the handle also closes it, but discards such errors.
    pub fn finish(mut self) -> Result<()> {
        self.finish_inner()
    }
}
impl<'a> io::Read for LargeObject<'a> {
fn read(&mut self, mut buf: &mut [u8]) -> io::Result<usize> {
let stmt = self.trans
.prepare_cached("SELECT pg_catalog.loread($1, $2)")?;
let cap = cmp::min(buf.len(), i32::MAX as usize) as i32;
let rows = stmt.query(&[&self.fd, &cap])?;
buf.write(rows.get(0).get_bytes(0).unwrap())
}
}
impl<'a> io::Write for LargeObject<'a> {
    /// Writes a prefix of `buf` at the current position with `lowrite`.
    ///
    /// At most `i32::MAX` bytes are sent per call. The number of bytes sent is
    /// returned.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let chunk = &buf[..cmp::min(buf.len(), i32::MAX as usize)];
        let stmt = self
            .trans
            .prepare_cached("SELECT pg_catalog.lowrite($1, $2)")?;
        stmt.execute(&[&self.fd, &chunk])?;
        Ok(chunk.len())
    }

    /// No-op: `write` sends data to the server immediately; nothing is
    /// buffered locally.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}
impl<'a> io::Seek for LargeObject<'a> {
    /// Moves the read/write position with `lo_lseek64` or `lo_lseek`.
    ///
    /// Returns the new absolute position reported by the server.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` in two cases:
    /// - a `Start` offset does not fit in an `i64`;
    /// - the server lacks the 64-bit functions and the offset does not fit in
    ///   an `i32`.
    ///
    /// Server errors are propagated.
    fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
        // Whence codes: 0 = from start, 1 = from current position, 2 = from end.
        let (whence, offset): (i32, i64) = match pos {
            io::SeekFrom::Start(pos) => {
                // Compare against the *value* of i64::max_value(); offsets
                // beyond it cannot be represented in the signed argument.
                if pos > i64::max_value() as u64 {
                    return Err(io::Error::new(
                        io::ErrorKind::InvalidInput,
                        "cannot seek more than 2^63 bytes",
                    ));
                }
                (0, pos as i64)
            }
            io::SeekFrom::Current(pos) => (1, pos),
            io::SeekFrom::End(pos) => (2, pos),
        };

        if self.has_64 {
            let stmt = self
                .trans
                .prepare_cached("SELECT pg_catalog.lo_lseek64($1, $2, $3)")?;
            let rows = stmt.query(&[&self.fd, &offset, &whence])?;
            let new_pos: i64 = rows.iter().next().unwrap().get(0);
            Ok(new_pos as u64)
        } else {
            // Relative offsets may be negative, so both bounds must be checked
            // before narrowing; `as i32` would otherwise wrap silently.
            if offset < i32::min_value() as i64 || offset > i32::max_value() as i64 {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "cannot seek more than 2^31 bytes",
                ));
            }
            let offset = offset as i32;
            let stmt = self
                .trans
                .prepare_cached("SELECT pg_catalog.lo_lseek($1, $2, $3)")?;
            let rows = stmt.query(&[&self.fd, &offset, &whence])?;
            let new_pos: i32 = rows.iter().next().unwrap().get(0);
            Ok(new_pos as u64)
        }
    }
}
/// Extracts `(major, minor)` from a PostgreSQL `server_version` string.
///
/// Accepts forms such as `"9.6.3"` and `"10.3 (Debian 10.3-1.pgdg90+1)"`.
/// Only text before the first whitespace is considered.
///
/// Each dot-separated component is read as its leading run of digits. A
/// missing or non-numeric component is reported as 0. Pre-release versions
/// such as `"11beta1"` or `"14devel"` therefore yield `(11, 0)` / `(14, 0)`
/// instead of panicking.
fn parse_version(version: &str) -> (i32, i32) {
    // Parses the ASCII digits at the start of `s`, if there are any.
    fn leading_number(s: &str) -> Option<i32> {
        let end = s.find(|c: char| !c.is_ascii_digit()).unwrap_or(s.len());
        s[..end].parse().ok()
    }

    let version = version.split_whitespace().next().unwrap_or("");
    let mut parts = version.split('.');
    let major = parts.next().and_then(leading_number).unwrap_or(0);
    let minor = parts.next().and_then(leading_number).unwrap_or(0);
    (major, minor)
}
#[cfg(test)]
mod test {
    // Integration tests: every test (except test_parse_version) connects to a
    // live server at postgres://postgres@localhost. The URL carries no
    // password, so the server presumably allows passwordless login for the
    // `postgres` role. TODO confirm against the CI setup.
    use postgres::{Connection, TlsMode};
    use postgres::error::UNDEFINED_OBJECT;
    use {parse_version, LargeObjectExt, LargeObjectTransactionExt, Mode};

    /// A freshly created large object can be deleted again.
    #[test]
    fn test_create_delete() {
        let conn = Connection::connect("postgres://postgres@localhost", TlsMode::None).unwrap();
        let oid = conn.create_large_object().unwrap();
        conn.delete_large_object(oid).unwrap();
    }

    /// Deleting OID 0 must fail with UNDEFINED_OBJECT; any other outcome fails the test.
    #[test]
    fn test_delete_bogus() {
        let conn = Connection::connect("postgres://postgres@localhost", TlsMode::None).unwrap();
        match conn.delete_large_object(0) {
            Ok(()) => panic!("unexpected success"),
            Err(ref e) if e.code() == Some(&UNDEFINED_OBJECT) => {}
            Err(e) => panic!("unexpected error: {:?}", e),
        }
    }

    /// Opening OID 0 must fail with UNDEFINED_OBJECT.
    #[test]
    fn test_open_bogus() {
        let conn = Connection::connect("postgres://postgres@localhost", TlsMode::None).unwrap();
        let trans = conn.transaction().unwrap();
        match trans.open_large_object(0, Mode::Read) {
            Ok(_) => panic!("unexpected success"),
            Err(ref e) if e.code() == Some(&UNDEFINED_OBJECT) => {}
            Err(e) => panic!("unexpected error: {:?}", e),
        };
    }

    /// An opened object can be closed explicitly via `finish`.
    #[test]
    fn test_open_finish() {
        let conn = Connection::connect("postgres://postgres@localhost", TlsMode::None).unwrap();
        let trans = conn.transaction().unwrap();
        let oid = trans.create_large_object().unwrap();
        let lo = trans.open_large_object(oid, Mode::Read).unwrap();
        lo.finish().unwrap();
    }

    /// Data written through one handle is visible through a second, read-only
    /// handle in the same transaction.
    #[test]
    fn test_write_read() {
        use std::io::{Read, Write};
        let conn = Connection::connect("postgres://postgres@localhost", TlsMode::None).unwrap();
        let trans = conn.transaction().unwrap();
        let oid = trans.create_large_object().unwrap();
        let mut lo = trans.open_large_object(oid, Mode::Write).unwrap();
        lo.write_all(b"hello world!!!").unwrap();
        // The write handle above is shadowed, not finished, so it is still open
        // while this second handle reads.
        let mut lo = trans.open_large_object(oid, Mode::Read).unwrap();
        let mut out = vec![];
        lo.read_to_end(&mut out).unwrap();
        assert_eq!(out, b"hello world!!!");
    }

    /// Exercises all three SeekFrom variants and checks reported positions.
    /// The handle is opened with Mode::Write, and the test also reads through it.
    #[test]
    fn test_seek_tell() {
        use std::io::{Read, Seek, SeekFrom, Write};
        let conn = Connection::connect("postgres://postgres@localhost", TlsMode::None).unwrap();
        let trans = conn.transaction().unwrap();
        let oid = trans.create_large_object().unwrap();
        let mut lo = trans.open_large_object(oid, Mode::Write).unwrap();
        lo.write_all(b"hello world!!!").unwrap();
        // After writing 14 bytes the position is at the end.
        assert_eq!(14, lo.seek(SeekFrom::Current(0)).unwrap());
        assert_eq!(1, lo.seek(SeekFrom::Start(1)).unwrap());
        let mut buf = [0];
        assert_eq!(1, lo.read(&mut buf).unwrap());
        assert_eq!(b'e', buf[0]);
        assert_eq!(2, lo.seek(SeekFrom::Current(0)).unwrap());
        // 14 - 4 = 10, which is the 'd' in "world".
        assert_eq!(10, lo.seek(SeekFrom::End(-4)).unwrap());
        assert_eq!(1, lo.read(&mut buf).unwrap());
        assert_eq!(b'd', buf[0]);
        // 11 - 3 = 8, which is the 'r' in "world".
        assert_eq!(8, lo.seek(SeekFrom::Current(-3)).unwrap());
        assert_eq!(1, lo.read(&mut buf).unwrap());
        assert_eq!(b'r', buf[0]);
    }

    /// Writing through a read-only handle must fail.
    #[test]
    fn test_write_with_read_fd() {
        use std::io::Write;
        let conn = Connection::connect("postgres://postgres@localhost", TlsMode::None).unwrap();
        let trans = conn.transaction().unwrap();
        let oid = trans.create_large_object().unwrap();
        let mut lo = trans.open_large_object(oid, Mode::Read).unwrap();
        assert!(lo.write_all(b"hello world!!!").is_err());
    }

    /// Truncation shrinks the object. Truncating to a larger length pads it
    /// with zero bytes.
    #[test]
    fn test_truncate() {
        use std::io::{Read, Seek, SeekFrom, Write};
        let conn = Connection::connect("postgres://postgres@localhost", TlsMode::None).unwrap();
        let trans = conn.transaction().unwrap();
        let oid = trans.create_large_object().unwrap();
        let mut lo = trans.open_large_object(oid, Mode::Write).unwrap();
        lo.write_all(b"hello world!!!").unwrap();
        lo.truncate(5).unwrap();
        lo.seek(SeekFrom::Start(0)).unwrap();
        let mut buf = vec![];
        lo.read_to_end(&mut buf).unwrap();
        assert_eq!(buf, b"hello");
        lo.truncate(10).unwrap();
        lo.seek(SeekFrom::Start(0)).unwrap();
        buf.clear();
        lo.read_to_end(&mut buf).unwrap();
        assert_eq!(buf, b"hello\0\0\0\0\0");
    }

    /// Pure unit test (no database): parses major/minor from server_version strings.
    #[test]
    fn test_parse_version() {
        let version = parse_version("10.3 (Debian 10.3-1.pgdg90+1)");
        assert_eq!(version, (10, 3));
        let version = parse_version("9.5");
        assert_eq!(version, (9, 5));
    }
}