use std::fs;
use std::io;
use std::path::Path;
use std::str;
use rustc_serialize::Decodable;
use {
ByteString, Result, Decoded,
Error, LocatableError, ParseError,
};
use self::State::*;
/// Size in bytes of the internal read buffer that `Reader::from_reader` allocates.
const BUF_SIZE: usize = 128 * 1024;
/// The byte(s) that end a CSV record.
///
/// `CRLF` treats either `\r` or `\n` as a record terminator; when a record
/// ends on `\r`, an immediately following `\n` is consumed as part of the
/// same terminator. `Any(b)` ends records only on the single byte `b`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RecordTerminator {
    /// Either `\r` or `\n` ends a record.
    CRLF,
    /// Only the given byte ends a record.
    Any(u8),
}
impl RecordTerminator {
    /// Reports whether this is the `CRLF` terminator (as opposed to a single
    /// custom byte).
    #[inline]
    fn is_crlf(&self) -> bool {
        if let RecordTerminator::Any(_) = *self {
            false
        } else {
            true
        }
    }
}
impl PartialEq<u8> for RecordTerminator {
    /// A byte matches `CRLF` when it is `\r` or `\n`, and matches `Any(b)`
    /// only when it equals `b`.
    #[inline]
    fn eq(&self, other: &u8) -> bool {
        let byte = *other;
        match *self {
            RecordTerminator::Any(term) => byte == term,
            RecordTerminator::CRLF => match byte {
                b'\r' | b'\n' => true,
                _ => false,
            },
        }
    }
}
/// A CSV reader that parses fields and records out of an underlying
/// `io::Read` source.
///
/// Bytes are read into an internal buffer and fed one at a time through the
/// parser state machine in `next_bytes`.
pub struct Reader<R> {
    // The underlying source of bytes.
    rdr: R,
    // Read buffer; after a refill its length is the number of bytes the last
    // read returned.
    buf: Vec<u8>,
    // Index of the next unparsed byte in `buf`; `bufi == buf.len()` means the
    // buffer is exhausted and `fill_buf` will read again.
    bufi: usize,
    // Accumulates the bytes of the field currently being parsed.
    fieldbuf: Vec<u8>,
    // Current parser state.
    state: State,
    // Set once a read from `rdr` returns no bytes.
    eof: bool,
    // Fields of the first record, captured by `next_data` while
    // `first_row_done` is false.
    first_row: Vec<ByteString>,
    // Set at the first end of record; also enables the equal-length check.
    first_row_done: bool,
    // Record counter used in error locations; starts at 1 and is incremented
    // at each end of record.
    irecord: u64,
    // Field counter for the current record; incremented per returned field,
    // reset to 0 at each end of record, and compared against
    // `first_row.len()` for the equal-length check.
    ifield: u64,
    // Position of the next byte to parse: incremented for every consumed
    // byte and set directly by `seek`.
    byte_offset: u64,
    // Byte that separates fields.
    delimiter: u8,
    // Byte that opens and closes a quoted field.
    quote: u8,
    // Optional byte that, inside a quoted field, makes the following byte
    // literal.
    escape: Option<u8>,
    // Whether a doubled quote inside a quoted field yields a literal quote.
    double_quote: bool,
    // What ends a record.
    record_term: RecordTerminator,
    // When true, records are not required to have as many fields as the first.
    flexible: bool,
    // Whether `ByteRecords` treats the first row as a header (skipped) rather
    // than yielding it as a record.
    #[doc(hidden)]
    pub has_headers: bool,
    // Set by `seek`; `ByteRecords` then skips its header-row handling.
    has_seeked: bool,
}
impl<R: io::Read> Reader<R> {
    /// Creates a new CSV reader from an arbitrary `io::Read`.
    ///
    /// Input is read in chunks of up to `BUF_SIZE` bytes into an internal
    /// buffer. Defaults: `,` delimiter, `"` quote, no escape byte, doubled
    /// quotes enabled, `CRLF` record terminator, records must all have the
    /// same length, and the first row is treated as a header.
    pub fn from_reader(rdr: R) -> Reader<R> {
        Reader {
            rdr: rdr,
            buf: vec![0; BUF_SIZE],
            // Start "exhausted" (bufi == buf.len()) so the first `fill_buf`
            // call performs a read.
            bufi: BUF_SIZE,
            fieldbuf: Vec::with_capacity(1024),
            state: StartRecord,
            eof: false,
            first_row: vec![],
            first_row_done: false,
            irecord: 1,
            // `next_eor` resets this counter to 0 after every record and
            // compares it against the field count, so the first record must
            // start from 0 as well. Starting at 1 made field locations in
            // errors raised for the first record off by one.
            ifield: 0,
            byte_offset: 0,
            delimiter: b',',
            quote: b'"',
            escape: None,
            double_quote: true,
            record_term: RecordTerminator::CRLF,
            flexible: false,
            has_headers: true,
            has_seeked: false,
        }
    }
}
impl Reader<fs::File> {
    /// Opens the file at `path` and wraps it in a CSV reader.
    ///
    /// Failure to open the file is converted into this crate's error type
    /// and returned.
    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Reader<fs::File>> {
        let file = try!(fs::File::open(path));
        let reader = Reader::from_reader(file);
        Ok(reader)
    }
}
impl Reader<io::Cursor<Vec<u8>>> {
    /// Creates a CSV reader over an in-memory string.
    ///
    /// The string is converted to its UTF-8 bytes and handed to `from_bytes`.
    pub fn from_string<'a, S>(s: S) -> Reader<io::Cursor<Vec<u8>>>
            where S: Into<String> {
        let owned: String = s.into();
        let bytes = owned.into_bytes();
        Reader::from_bytes(bytes)
    }

    /// Creates a CSV reader over an in-memory byte buffer.
    pub fn from_bytes<'a, V>(bytes: V) -> Reader<io::Cursor<Vec<u8>>>
            where V: Into<Vec<u8>> {
        let data: Vec<u8> = bytes.into();
        let cursor = io::Cursor::new(data);
        Reader::from_reader(cursor)
    }
}
impl<R: io::Read> Reader<R> {
    /// Returns an iterator that decodes each remaining record into a `D`.
    ///
    /// Records come from `byte_records`, so header handling follows that
    /// iterator's rules.
    pub fn decode<'a, D: Decodable>(&'a mut self) -> DecodedRecords<'a, R, D> {
        let records = self.byte_records();
        DecodedRecords { p: records, _phantom: ::std::marker::PhantomData }
    }

    /// Returns an iterator over the remaining records, each converted to a
    /// vector of UTF-8 strings.
    pub fn records<'a>(&'a mut self) -> StringRecords<'a, R> {
        let records = self.byte_records();
        StringRecords { p: records }
    }

    /// Returns the header row (as produced by `byte_headers`) with every
    /// field converted to a `String`.
    ///
    /// Fails if reading the row fails or if any field is not valid UTF-8.
    pub fn headers(&mut self) -> Result<Vec<String>> {
        let raw = try!(self.byte_headers());
        byte_record_to_utf8(raw)
    }
}
impl<R: io::Read> Reader<R> {
pub fn delimiter(mut self, delimiter: u8) -> Reader<R> {
self.delimiter = delimiter;
self
}
pub fn has_headers(mut self, yes: bool) -> Reader<R> {
self.has_headers = yes;
self
}
pub fn flexible(mut self, yes: bool) -> Reader<R> {
self.flexible = yes;
self
}
pub fn record_terminator(mut self, term: RecordTerminator) -> Reader<R> {
self.record_term = term;
self
}
pub fn quote(mut self, quote: u8) -> Reader<R> {
self.quote = quote;
self
}
pub fn escape(mut self, escape: Option<u8>) -> Reader<R> {
self.escape = escape;
self
}
pub fn double_quote(mut self, yes: bool) -> Reader<R> {
self.double_quote = yes;
self
}
pub fn ascii(self) -> Reader<R> {
self.delimiter(b'\x1f')
.record_terminator(RecordTerminator::Any(b'\x1e'))
}
}
/// The outcome of asking a `Reader` for its next field.
///
/// `T` is the field type: `[u8]` for `Reader::next_bytes` and `str` for
/// `Reader::next_str`.
#[derive(Debug)]
pub enum NextField<'a, T: ?Sized + 'a> {
    /// The contents of a field, borrowed from the reader's internal field
    /// buffer.
    Data(&'a T),
    /// Reading failed, e.g. with an I/O error or a parse error such as
    /// invalid UTF-8 or unequal record lengths.
    Error(Error),
    /// The current record has ended.
    EndOfRecord,
    /// The input is exhausted.
    EndOfCsv,
}
impl<'a, T: ?Sized + ::std::fmt::Debug> NextField<'a, T> {
pub fn into_iter_result(self) -> Option<Result<&'a T>> {
match self {
NextField::EndOfRecord | NextField::EndOfCsv => None,
NextField::Error(err) => Some(Err(err)),
NextField::Data(field) => Some(Ok(field)),
}
}
pub fn is_end(&self) -> bool {
if let NextField::EndOfCsv = *self { true } else { false }
}
pub fn unwrap(self) -> &'a T {
match self {
NextField::Data(field) => field,
v => panic!("Cannot unwrap '{:?}'", v),
}
}
}
impl<R: io::Read> Reader<R> {
pub fn byte_headers(&mut self) -> Result<Vec<ByteString>> {
if !self.first_row.is_empty() {
Ok(self.first_row.clone())
} else {
let mut headers = vec![];
loop {
let field = match self.next_bytes() {
NextField::EndOfRecord | NextField::EndOfCsv => break,
NextField::Error(err) => return Err(err),
NextField::Data(field) => field,
};
headers.push(field.to_vec());
}
assert!(headers.len() > 0 || self.done());
Ok(headers)
}
}
pub fn byte_records<'a>(&'a mut self) -> ByteRecords<'a, R> {
let first = self.has_seeked;
ByteRecords { p: self, first: first, errored: false }
}
pub fn done(&self) -> bool {
self.eof
}
pub fn next_bytes(&mut self) -> NextField<[u8]> {
unsafe { self.fieldbuf.set_len(0); }
loop {
if let Err(err) = self.fill_buf() {
return NextField::Error(Error::Io(err));
}
if self.buf.len() == 0 {
self.eof = true;
if let StartRecord = self.state {
return self.next_eoc();
} else if let EndRecord = self.state {
self.state = StartRecord;
return self.next_eor();
} else {
self.state = EndRecord;
return self.next_data();
}
}
while self.bufi < self.buf.len() {
let c = self.buf[self.bufi];
match self.state {
StartRecord => {
if self.is_record_term(c) {
self.bump();
} else {
self.state = StartField;
}
}
EndRecord => {
if self.record_term.is_crlf() && c == b'\n' {
self.bump();
}
self.state = StartRecord;
return self.next_eor();
}
StartField => {
self.bump();
if c == self.quote {
self.state = InQuotedField;
} else if c == self.delimiter {
return self.next_data();
} else if self.is_record_term(c) {
self.state = EndRecord;
return self.next_data();
} else {
self.add(c);
self.state = InField;
}
}
InField => {
self.bump();
if c == self.delimiter {
self.state = StartField;
return self.next_data();
} else if self.is_record_term(c) {
self.state = EndRecord;
return self.next_data();
} else {
self.add(c);
}
}
InQuotedField => {
self.bump();
if c == self.quote {
self.state = InDoubleEscapedQuote;
} else if self.escape == Some(c) {
self.state = InEscapedQuote;
} else {
self.add(c);
}
}
InEscapedQuote => {
self.bump();
self.add(c);
self.state = InQuotedField;
}
InDoubleEscapedQuote => {
self.bump();
if self.double_quote && c == self.quote {
self.add(c);
self.state = InQuotedField;
} else if c == self.delimiter {
self.state = StartField;
return self.next_data();
} else if self.is_record_term(c) {
self.state = EndRecord;
return self.next_data();
} else {
self.add(c);
self.state = InField; }
}
}
}
}
}
pub fn next_str(&mut self) -> NextField<str> {
let (record, field) = (self.irecord, self.ifield);
match self.next_bytes() {
NextField::EndOfRecord => NextField::EndOfRecord,
NextField::EndOfCsv => NextField::EndOfCsv,
NextField::Error(err) => NextField::Error(err),
NextField::Data(bytes) => {
match str::from_utf8(bytes) {
Ok(s) => NextField::Data(s),
Err(_) => NextField::Error(Error::Parse(LocatableError {
record: record,
field: field,
err: ParseError::InvalidUtf8,
})),
}
}
}
}
#[doc(hidden)]
pub unsafe fn byte_fields<'a>(&'a mut self) -> UnsafeByteFields<'a, R> {
UnsafeByteFields { rdr: self }
}
pub fn byte_offset(&self) -> u64 {
self.byte_offset
}
#[inline]
fn next_data(&mut self) -> NextField<[u8]> {
if !self.first_row_done {
self.first_row.push(self.fieldbuf.to_vec());
}
self.ifield += 1;
NextField::Data(&self.fieldbuf)
}
#[inline]
fn next_eor(&mut self) -> NextField<[u8]> {
if !self.flexible
&& self.first_row_done
&& self.ifield != self.first_row.len() as u64 {
return self.parse_error(ParseError::UnequalLengths {
expected: self.first_row.len() as u64,
got: self.ifield as u64,
});
}
self.irecord += 1;
self.ifield = 0;
self.first_row_done = true;
NextField::EndOfRecord
}
#[inline]
fn next_eoc(&self) -> NextField<[u8]> {
NextField::EndOfCsv
}
#[inline]
fn fill_buf(&mut self) -> io::Result<()> {
if self.bufi == self.buf.len() {
unsafe { let cap = self.buf.capacity(); self.buf.set_len(cap); }
let n = try!(self.rdr.read(&mut self.buf));
unsafe { self.buf.set_len(n); }
self.bufi = 0;
}
Ok(())
}
#[inline]
fn bump(&mut self) {
self.bufi += 1;
self.byte_offset += 1;
}
#[inline]
fn add(&mut self, c: u8) {
self.fieldbuf.push(c);
}
#[inline]
fn is_record_term(&self, c: u8) -> bool {
self.record_term == c
}
fn parse_error(&self, err: ParseError) -> NextField<[u8]> {
NextField::Error(Error::Parse(LocatableError {
record: self.irecord,
field: self.ifield,
err: err,
}))
}
}
/// States of the byte-level parser driven by `Reader::next_bytes`.
#[derive(Debug)]
enum State {
    /// At the start of a record; record-terminator bytes here are skipped.
    StartRecord,
    /// A record terminator was seen; the end of the record is reported next.
    EndRecord,
    /// At the start of a field; the next byte decides whether it is quoted.
    StartField,
    /// Inside an unquoted field.
    InField,
    /// Inside a quoted field.
    InQuotedField,
    /// The escape byte was seen inside a quoted field; the next byte is
    /// taken literally.
    InEscapedQuote,
    /// A quote was seen inside a quoted field; it is either the first half
    /// of a doubled quote or the end of the quoted section.
    InDoubleEscapedQuote,
}
impl<R: io::Read + io::Seek> Reader<R> {
    /// Moves to byte offset `pos` of the underlying input and resets the
    /// parser to the start of a record.
    ///
    /// The parser treats `pos` as the beginning of a record, so it should be
    /// a record boundary, presumably a `byte_offset()` value taken between
    /// records. After a seek, `byte_records` no longer performs header-row
    /// handling. If `pos` equals the current offset, no I/O is performed.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying seek fails. In that case the
    /// buffer, offset and parser state are left exactly as they were.
    pub fn seek(&mut self, pos: u64) -> Result<()> {
        if pos != self.byte_offset {
            // Seek first, so that a failure cannot leave `byte_offset`
            // claiming a position the underlying reader never reached.
            try!(self.rdr.seek(io::SeekFrom::Start(pos)));
            // Discard buffered bytes; `bufi == buf.len()` forces a refill.
            self.bufi = self.buf.len();
            self.eof = false;
            self.byte_offset = pos;
        }
        self.has_seeked = true;
        self.state = StartRecord;
        Ok(())
    }
}
/// Iterator over the fields of the current record, created by the unsafe
/// `Reader::byte_fields`. Each item borrows the reader's internal field
/// buffer.
#[doc(hidden)]
pub struct UnsafeByteFields<'a, R: 'a> {
    // The reader that fields are pulled from.
    rdr: &'a mut Reader<R>,
}
#[doc(hidden)]
impl<'a, R> Iterator for UnsafeByteFields<'a, R> where R: io::Read {
    type Item = Result<&'a [u8]>;
    /// Returns the next field of the current record, or `None` at the end
    /// of the record or of the input.
    fn next(&mut self) -> Option<Result<&'a [u8]>> {
        // SAFETY(review): the transmute only stretches the slice's lifetime
        // from this `&mut self` borrow to `'a`. The slice points into the
        // reader's field buffer. The next `next_bytes` call clears and
        // refills that buffer, possibly reallocating it, so a yielded slice
        // must not be used after the following call to `next`. Nothing in
        // the type system enforces this; it is the contract of
        // `Reader::byte_fields`.
        unsafe {
            ::std::mem::transmute(self.rdr.next_bytes().into_iter_result())
        }
    }
}
/// An iterator that decodes each remaining record into a value of type `D`.
///
/// Created by `Reader::decode`.
pub struct DecodedRecords<'a, R: 'a, D> {
    // Source of raw byte records.
    p: ByteRecords<'a, R>,
    // Marks the output type `D` without storing a value of it.
    _phantom: ::std::marker::PhantomData<D>,
}
impl<'a, R, D> Iterator for DecodedRecords<'a, R, D>
        where R: io::Read, D: Decodable {
    type Item = Result<D>;

    /// Reads the next byte record and decodes it into a `D`. Read errors
    /// are passed through unchanged.
    fn next(&mut self) -> Option<Result<D>> {
        let record = match self.p.next() {
            None => return None,
            Some(Err(err)) => return Some(Err(err)),
            Some(Ok(record)) => record,
        };
        let mut decoder = Decoded::new(record);
        Some(Decodable::decode(&mut decoder))
    }
}
/// An iterator over the remaining records, with each field converted to a
/// `String`.
///
/// Created by `Reader::records`.
pub struct StringRecords<'a, R: 'a> {
    // Source of raw byte records.
    p: ByteRecords<'a, R>,
}
impl<'a, R> Iterator for StringRecords<'a, R> where R: io::Read {
    type Item = Result<Vec<String>>;

    /// Reads the next byte record and converts every field to a `String`.
    /// Fails if any field is not valid UTF-8.
    fn next(&mut self) -> Option<Result<Vec<String>>> {
        match self.p.next() {
            Some(Ok(record)) => Some(byte_record_to_utf8(record)),
            Some(Err(err)) => Some(Err(err)),
            None => None,
        }
    }
}
/// An iterator over the remaining records, each a vector of byte-string
/// fields.
///
/// Created by `Reader::byte_records`.
pub struct ByteRecords<'a, R: 'a> {
    // The reader records are pulled from.
    p: &'a mut Reader<R>,
    // True once header-row handling is finished; starts out true after a
    // `seek`, which skips that handling.
    first: bool,
    // Set after an error is yielded; iteration then returns `None`.
    errored: bool,
}
impl<'a, R> Iterator for ByteRecords<'a, R> where R: io::Read {
    type Item = Result<Vec<ByteString>>;

    /// Returns the next record as a vector of byte-string fields.
    ///
    /// On the first call (unless the reader was seeked), the header row is
    /// read. It is yielded as a record if `has_headers` is false and
    /// skipped otherwise. Once an error has been yielded, the iterator
    /// returns `None`.
    fn next(&mut self) -> Option<Result<Vec<ByteString>>> {
        // Handle the header row before checking `done`. Reading it may
        // exhaust the input, but a header-less reader must still yield it.
        if !self.first {
            self.first = true;
            let headers = match self.p.byte_headers() {
                Ok(headers) => headers,
                Err(err) => {
                    // Report header errors even when headers are skipped.
                    // Otherwise parsing would resume wherever the failed
                    // header read stopped, yielding a bogus record.
                    self.errored = true;
                    return Some(Err(err));
                }
            };
            if headers.is_empty() {
                // No records at all.
                assert!(self.p.done());
                return None;
            }
            if !self.p.has_headers {
                return Some(Ok(headers));
            }
        }
        if self.p.done() || self.errored {
            return None;
        }
        let mut record = Vec::with_capacity(self.p.first_row.len());
        loop {
            match self.p.next_bytes() {
                NextField::EndOfRecord | NextField::EndOfCsv => {
                    // An end marker before any field means there is nothing
                    // left to read.
                    if record.is_empty() {
                        return None;
                    }
                    break;
                }
                NextField::Error(err) => {
                    self.errored = true;
                    return Some(Err(err));
                }
                NextField::Data(field) => record.push(field.to_vec()),
            }
        }
        Some(Ok(record))
    }
}
fn byte_record_to_utf8(record: Vec<ByteString>) -> Result<Vec<String>> {
for bytes in record.iter() {
if let Err(err) = ::std::str::from_utf8(&**bytes) {
return Err(Error::Decode(format!(
"Could not decode the following bytes as UTF-8 \
because {}: {:?}", err.to_string(), bytes)));
}
}
Ok(unsafe { ::std::mem::transmute(record) })
}