use std::io::{self,Write,BufWriter,Seek,SeekFrom};
use std::fs::File;
use std::path::Path;
use std::marker::PhantomData;
use std::convert::Infallible as Never;
use byteorder::{WriteBytesExt, LittleEndian};
use crate::serialize::{AutoSerialize, Serialize, TypeWrite};
use crate::header::{self, DType, VersionProps, HeaderSizeType, HeaderEncoding};
use crate::read::Order;
/// Placeholder written into the header's `'shape'` entry when the row count is not known up
/// front; it is later overwritten in place with the real count followed by space padding.
///
/// Its length (19 bytes) is the number of decimal digits available for that count.
const FILLER_FOR_UNKNOWN_SIZE: &[u8] = &[b'*'; 19];
/// Settings gathered from a builder chain and handed to `NpyWriter::_begin`.
///
/// `shape` is `None` on the `begin_1d` path, where the row count is unknown up front and a
/// placeholder is written into the header instead.
struct DataFromBuilder<T: ?Sized> {
    dtype: DType,
    shape: Option<Vec<u64>>,
    order: Order,
    // `fn(&T)` uses `T` purely as a type-level marker; no `T` is stored.
    _marker: PhantomData<fn(&T)>,
}
pub use write_options::{WriteOptions, WriterBuilder};
pub mod write_options {
    //! Typestate builder for configuring and starting an `.npy` writer.
    //!
    //! Each of `dtype`, `shape` and `writer` wraps the current builder in a new type
    //! (`WithDType`, `WithShape`, `WithWriter`). The marker traits `HasDType`, `HasShape`,
    //! `HasWriter` and `MissingWriter` record which settings are present, so `begin_nd`,
    //! `begin_1d` and `write_header_only` only type-check once their requirements are met,
    //! and `writer` can only be called while no writer has been set.
    use super::*;

    /// Root of every builder chain: holds the memory order and names the row type `T`.
    ///
    /// `Debug` and `Clone` are implemented by hand so that neither requires `T: Debug` or
    /// `T: Clone`; `T` is only a type marker here.
    pub struct WriteOptions<T: ?Sized> {
        order: Order,
        _marker: PhantomData<fn(&T)>,
    }

    impl<T: ?Sized> WriteOptions<T> {
        /// Create a builder using `Order::C` (written as `'fortran_order': False`).
        pub fn new() -> Self {
            WriteOptions {
                order: Order::C,
                _marker: PhantomData,
            }
        }
    }

    impl WriteOptions<Never> {
        /// Create a builder intended for `write_header_only`, where no rows are serialized and
        /// therefore no row type is needed.
        pub fn new_header_only() -> Self {
            WriteOptions {
                order: Order::C,
                _marker: PhantomData,
            }
        }
    }

    impl<T: ?Sized> Default for WriteOptions<T> {
        fn default() -> Self { Self::new() }
    }

    impl<T: ?Sized> Clone for WriteOptions<T> {
        fn clone(&self) -> Self {
            // `Order` and `PhantomData` are both `Copy`.
            WriteOptions { order: self.order, _marker: self._marker }
        }
    }

    // Hand-written (rather than derived) to avoid the `T: Debug` bound a derive would add.
    // The output matches what `#[derive(Debug)]` produced.
    impl<T: ?Sized> std::fmt::Debug for WriteOptions<T> {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.debug_struct("WriteOptions")
                .field("order", &self.order)
                .field("_marker", &self._marker)
                .finish()
        }
    }

    /// Builder methods shared by `WriteOptions` and all of its wrapper stages.
    pub trait WriterBuilder<T: ?Sized>: Sized {
        /// Use `T`'s default dtype for the header.
        fn default_dtype(self) -> WithDType<Self> where T: AutoSerialize { self.dtype(T::default_dtype()) }

        /// Set the dtype written to the header.
        fn dtype(self, dtype: DType) -> WithDType<Self> { WithDType { inner: self, dtype } }

        /// Set the shape written to the header (needed by `begin_nd` and `write_header_only`).
        fn shape(self, shape: &[u64]) -> WithShape<Self> { WithShape { inner: self, shape: shape.to_vec() } }

        /// Set the destination writer; only available while no writer has been set yet.
        fn writer<W>(self, writer: W) -> WithWriter<W, Self>
        where
            Self: MissingWriter,
        { WithWriter { inner: self, writer } }

        /// Set the memory order recorded in the header's `'fortran_order'` entry.
        fn order(self, order: Order) -> Self;

        #[doc(hidden)] fn __get_order(&self) -> Order;

        /// Write a header with the configured shape and return a writer for the rows.
        ///
        /// When finishing, the number of pushed rows must equal the product of the shape.
        fn begin_nd(self) -> io::Result<NpyWriter<T, <Self as HasWriter>::Writer>>
        where
            T: Serialize,
            Self: HasDType + HasWriter + HasShape,
            <Self as HasWriter>::Writer: Write,
        {
            NpyWriter::_begin(DataFromBuilder {
                dtype: self.__get_dtype(),
                order: self.__get_order(),
                shape: Some(self.__get_shape()),
                _marker: PhantomData,
            }, MaybeSeek::Isnt(self.__into_writer()))
        }

        /// Write a 1-D header whose length is filled in when the returned writer finishes.
        ///
        /// Requires a seekable writer, because the header is revisited after the rows are
        /// written.
        fn begin_1d(self) -> io::Result<NpyWriter<T, <Self as HasWriter>::Writer>>
        where
            T: Serialize,
            Self: HasDType + HasWriter,
            <Self as HasWriter>::Writer: Write + Seek,
        {
            NpyWriter::_begin(DataFromBuilder {
                dtype: self.__get_dtype(),
                order: self.__get_order(),
                shape: None,
                _marker: PhantomData,
            }, MaybeSeek::new_seek(self.__into_writer()))
        }

        /// Write only the header (dtype, order, shape) and hand back the underlying writer.
        fn write_header_only(mut self) -> io::Result<<Self as HasWriter>::Writer>
        where
            Self: HasDType + HasWriter + HasShape,
            <Self as HasWriter>::Writer: Write,
        {
            let dtype = self.__get_dtype();
            let order = self.__get_order();
            let shape = self.__get_shape();
            write_header(self.__writer_mut(), &dtype, order, Some(shape.as_slice()))?;
            Ok(self.__into_writer())
        }
    }

    /// Builder stage that carries the destination writer.
    #[derive(Debug, Clone)]
    pub struct WithWriter<W, Builder> {
        pub(super) writer: W,
        pub(super) inner: Builder,
    }

    /// Builder stage that carries the dtype.
    #[derive(Debug, Clone)]
    pub struct WithDType<Builder> {
        pub(super) dtype: DType,
        pub(super) inner: Builder,
    }

    /// Builder stage that carries the shape.
    #[derive(Debug, Clone)]
    pub struct WithShape<Builder> {
        pub(super) shape: Vec<u64>,
        pub(super) inner: Builder,
    }

    /// Implemented by builder stages that have a dtype somewhere in the chain.
    pub trait HasDType {
        #[doc(hidden)] fn __get_dtype(&self) -> DType;
    }

    /// Implemented by builder stages that have a shape somewhere in the chain.
    pub trait HasShape {
        #[doc(hidden)] fn __get_shape(&self) -> Vec<u64>;
    }

    /// Implemented by builder stages that have a writer somewhere in the chain.
    pub trait HasWriter {
        type Writer;
        #[doc(hidden)]
        fn __into_writer(self) -> Self::Writer;
        #[doc(hidden)]
        fn __writer_mut(&mut self) -> &mut Self::Writer;
    }

    /// Implemented by builder stages that do not yet have a writer; gates `WriterBuilder::writer`.
    pub trait MissingWriter {}

    impl<T: ?Sized> WriterBuilder<T> for WriteOptions<T> {
        fn order(mut self, order: Order) -> Self { self.order = order; self }
        fn __get_order(&self) -> Order { self.order }
    }

    impl<W, T: ?Sized, B: WriterBuilder<T>> WriterBuilder<T> for WithWriter<W, B> {
        fn order(mut self, order: Order) -> Self { self.inner = self.inner.order(order); self }
        fn __get_order(&self) -> Order { self.inner.__get_order() }
    }

    impl<T: ?Sized, B: WriterBuilder<T>> WriterBuilder<T> for WithDType<B> {
        fn order(mut self, order: Order) -> Self { self.inner = self.inner.order(order); self }
        fn __get_order(&self) -> Order { self.inner.__get_order() }
    }

    impl<T: ?Sized, B: WriterBuilder<T>> WriterBuilder<T> for WithShape<B> {
        fn order(mut self, order: Order) -> Self { self.inner = self.inner.order(order); self }
        fn __get_order(&self) -> Order { self.inner.__get_order() }
    }

    impl<B> HasDType for WithDType<B> {
        fn __get_dtype(&self) -> DType { self.dtype.clone() }
    }

    impl<B> HasShape for WithShape<B> {
        fn __get_shape(&self) -> Vec<u64> { self.shape.clone() }
    }

    impl<W, B> HasWriter for WithWriter<W, B> {
        type Writer = W;
        fn __into_writer(self) -> Self::Writer { self.writer }
        fn __writer_mut(&mut self) -> &mut Self::Writer { &mut self.writer }
    }

    impl<T: ?Sized> MissingWriter for WriteOptions<T> {}

    // Forwards each listed marker trait from a wrapper stage to its `inner` builder, so that
    // for example `WithShape<WithDType<..>>` still implements `HasDType`.
    macro_rules! forward_typestate_impls {
        ( $(
            ( $inner:tt $impl_generics:tt $Self:tt ): ( $($Trait:ident)* )
        )* ) => {
            $($( forward_typestate_impls!(@single $inner $impl_generics $Self [$Trait]); )*)*
        };
        (@single [$inner:ident] [$($impl_generics:tt)*] [$Self:ty] [HasDType]) => {
            impl<$($impl_generics)*> HasDType for $Self where $inner: HasDType {
                fn __get_dtype(&self) -> DType { self.inner.__get_dtype() }
            }
        };
        (@single [$inner:ident] [$($impl_generics:tt)*] [$Self:ty] [HasShape]) => {
            impl<$($impl_generics)*> HasShape for $Self where $inner: HasShape {
                fn __get_shape(&self) -> Vec<u64> { self.inner.__get_shape() }
            }
        };
        (@single [$inner:ident] [$($impl_generics:tt)*] [$Self:ty] [HasWriter]) => {
            impl<$($impl_generics)*> HasWriter for $Self where $inner: HasWriter {
                type Writer = $inner::Writer;
                fn __into_writer(self) -> Self::Writer { self.inner.__into_writer() }
                fn __writer_mut(&mut self) -> &mut Self::Writer { self.inner.__writer_mut() }
            }
        };
        (@single [$inner:ident] [$($impl_generics:tt)*] [$Self:ty] [MissingWriter]) => {
            impl<$($impl_generics)*> MissingWriter for $Self where $inner: MissingWriter { }
        };
    }

    // Each wrapper forwards every marker trait except the one it provides itself.
    // `WithWriter` deliberately does not forward `MissingWriter`, so that a second
    // `.writer(..)` call does not type-check.
    forward_typestate_impls!{
        ([B] [B] [WithShape<B>]): ( HasDType HasWriter MissingWriter)
        ([B] [B] [WithDType<B>]): (HasShape HasWriter MissingWriter)
        ([B] [W, B] [WithWriter<W, B>]): (HasShape HasDType )
    }
}
/// Streaming writer for an `.npy` array whose rows have type `Row`.
///
/// Obtained from the builder methods `begin_nd` / `begin_1d`. Rows are appended with `push` /
/// `extend`. Calling `finish` (or dropping the writer) validates or patches the header and
/// flushes the output.
pub struct NpyWriter<Row: Serialize + ?Sized, W: Write> {
// Stream position at which the header starts; `Some` only for seekable (`MaybeSeek::Is`) writers.
start_pos: Option<u64>,
// Whether the row count was fixed by an explicit shape or must be patched into the header.
shape_info: ShapeInfo,
// Number of rows counted by `push` so far.
num_items: u64,
// Output stream, possibly seekable.
fw: MaybeSeek<W>,
// Serializer for a single row, obtained from `Row::writer` for the header's dtype.
writer: <Row as Serialize>::TypeWriter,
// Properties of the chosen format version; used to locate the header text when finishing.
version_props: VersionProps,
}
/// How the header's `'shape'` entry was produced, which determines what finishing must do.
enum ShapeInfo {
    /// The shape was given explicitly; finishing checks that exactly this many items were pushed.
    Known { expected_num_items: u64 },
    /// The row count was unknown; finishing patches it into the header text at this byte offset.
    Automatic { offset_in_header_text: u64 },
}
/// An `NpyWriter` backed by a buffered `File`; see `OutFile::open`.
#[deprecated(since = "0.5.0", note = "Doesn't carry its weight. Use to_file_1d instead, or replicate the original behavior with WriteOptions::new().default_dtype().writer(std::io::BufWriter::new(std::fs::File::create(path)?)).begin_1d()")]
pub type OutFile<Row> = NpyWriter<Row, BufWriter<File>>;
#[allow(deprecated)]
impl<Row: AutoSerialize> OutFile<Row> {
    /// Create (or truncate) the file at `path` and start writing a 1-D array of `Row` using
    /// `Row`'s default dtype. The row count is filled into the header when writing finishes.
    #[deprecated(since = "0.5.0", note = "Doesn't carry its weight. Use to_file_1d instead, or replicate the original behavior with WriteOptions::new().default_dtype().writer(std::io::BufWriter::new(std::fs::File::create(path)?)).begin_1d()")]
    pub fn open<P: AsRef<Path>>(path: P) -> io::Result<Self> {
        WriteOptions::new()
            .default_dtype()
            .writer(BufWriter::new(File::create(path)?))
            .begin_1d()
    }
}
#[allow(deprecated)]
impl<Row: Serialize> OutFile<Row> {
    /// Old name for `NpyWriter::finish`; forwards to it unchanged.
    #[deprecated(since = "0.5.0", note = "use .finish() instead")]
    pub fn close(self) -> io::Result<()> {
        Self::finish(self)
    }
}
impl<Row: Serialize + ?Sized, W: Write> NpyWriter<Row, W> {
    /// Start an `.npy` stream on `fw` using settings collected by a builder.
    ///
    /// The dtype is checked against `Row` before anything is written, so an incompatible dtype
    /// does not leave a stray header in the output. For a seekable writer, the current position
    /// is recorded so that the header can be revisited when finishing.
    ///
    /// # Errors
    /// `InvalidData` if `Row` cannot be serialized as the dtype; otherwise any I/O error.
    fn _begin(builder: DataFromBuilder<Row>, mut fw: MaybeSeek<W>) -> io::Result<Self> {
        let DataFromBuilder { dtype, order, shape, _marker } = builder;

        let writer = Row::writer(&dtype)
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?;

        let start_pos = match fw {
            MaybeSeek::Is(ref mut fw) => Some(fw.seek(SeekFrom::Current(0))?),
            MaybeSeek::Isnt(_) => None,
        };
        let (shape_info, version_props) = write_header(&mut fw, &dtype, order, shape.as_deref())?;

        Ok(NpyWriter {
            start_pos,
            shape_info,
            num_items: 0,
            fw,
            writer,
            version_props,
        })
    }

    /// Serialize one row and append it to the output.
    ///
    /// A row is only counted once it has been written successfully, so a failed write does not
    /// inflate the count recorded in (or checked against) the header.
    pub fn push(&mut self, row: &Row) -> io::Result<()> {
        self.writer.write_one(&mut self.fw, row)?;
        self.num_items += 1;
        Ok(())
    }

    /// Push every row yielded by `rows`, stopping at the first error.
    pub fn extend(&mut self, rows: impl IntoIterator<Item=Row>) -> io::Result<()> where Row: Sized {
        rows.into_iter().try_for_each(|row| self.push(&row))
    }

    /// Validate or patch the header, then flush.
    ///
    /// * Known shape: errors if the number of pushed rows differs from the shape's product.
    /// * Automatic shape: seeks back into the header, overwrites the placeholder with the row
    ///   count (padded with spaces to the same width), then seeks back to the end.
    ///
    /// This may run more than once (explicitly via `finish`, then again on drop); both paths
    /// leave the output unchanged when repeated.
    fn finish_(&mut self) -> io::Result<()> {
        match self.shape_info {
            ShapeInfo::Known { expected_num_items } => {
                if expected_num_items != self.num_items {
                    return Err(io::Error::new(io::ErrorKind::InvalidData, format!(
                        "shape has {} item(s), but {} item(s) were written!",
                        expected_num_items, self.num_items,
                    )));
                }
            },
            ShapeInfo::Automatic { offset_in_header_text } => {
                let length = self.num_items.to_string();
                // The count must fit in the placeholder; writing a longer one would overwrite
                // the rest of the header text.
                let padding_len = match FILLER_FOR_UNKNOWN_SIZE.len().checked_sub(length.len()) {
                    Some(n) => n,
                    None => return Err(io::Error::new(io::ErrorKind::InvalidData, format!(
                        "item count {} is too large to record in the header", self.num_items,
                    ))),
                };
                let start_pos = self.start_pos
                    .expect("(BUG!) automatic shape is only used with a seekable writer");
                let shape_pos = start_pos + self.version_props.bytes_before_text() as u64 + offset_in_header_text;

                let end_pos = self.fw.seek(SeekFrom::Current(0))?;
                self.fw.seek(SeekFrom::Start(shape_pos))?;
                self.fw.write_all(length.as_bytes())?;
                self.fw.write_all(b",), }")?;
                self.fw.write_all(&vec![b' '; padding_len])?;
                self.fw.seek(SeekFrom::Start(end_pos))?;
            },
        }
        self.fw.flush()?;
        Ok(())
    }

    /// Finish writing, reporting any error from validating/patching the header or flushing.
    ///
    /// Dropping the writer also finishes it, but silently discards such errors.
    pub fn finish(mut self) -> io::Result<()> {
        self.finish_()
    }
}
/// Write a complete `.npy` header to `fw`: the magic string, the format version, the header
/// length and the padded header text.
///
/// Returns how the shape was recorded and the properties of the chosen format version.
///
/// # Panics
/// If `dtype` is a `DType::Array`, or if the padded header breaks the alignment or
/// length-field invariants checked below.
fn write_header<W: Write>(
    fw: &mut W,
    dtype: &DType,
    order: Order,
    shape: Option<&[u64]>,
) -> io::Result<(ShapeInfo, VersionProps)> {
    if matches!(dtype, DType::Array(..)) {
        panic!("the outermost dtype cannot be an array (got: {:?})", dtype);
    }

    let (dict_text, shape_info) = create_dict(dtype, order, shape);
    let (header_text, version, version_props) = determine_required_version_and_pad_header(dict_text);
    let (major, minor) = version;
    let text_len = header_text.len();

    // Magic string followed by the format version.
    fw.write_all(&[0x93u8])?;
    fw.write_all(b"NUMPY")?;
    fw.write_all(&[major, minor])?;

    assert_eq!((text_len + version_props.bytes_before_text()) % 16, 0);
    match version_props.header_size_type {
        HeaderSizeType::U16 => {
            assert!(text_len <= u16::MAX as usize);
            fw.write_u16::<LittleEndian>(text_len as u16)?;
        }
        HeaderSizeType::U32 => {
            assert!(text_len <= u32::MAX as usize);
            fw.write_u32::<LittleEndian>(text_len as u32)?;
        }
    }
    fw.write_all(&header_text)?;

    Ok((shape_info, version_props))
}
/// Build the unpadded header dictionary text (`{'descr': ..., 'fortran_order': ..., 'shape': (...), }`).
///
/// With `Some(shape)`, each dimension is written as `"<dim>, "` and the expected item count is
/// the product of the dimensions. With `None`, a placeholder is written where the count will
/// later go, and its byte offset within the text is returned.
fn create_dict(dtype: &DType, order: Order, shape: Option<&[u64]>) -> (Vec<u8>, ShapeInfo) {
    let fortran_order: &[u8] = match order {
        Order::C => b"False",
        Order::Fortran => b"True",
    };

    let mut dict: Vec<u8> = Vec::new();
    dict.extend_from_slice(b"{'descr': ");
    dict.extend_from_slice(dtype.descr().as_bytes());
    dict.extend_from_slice(b", 'fortran_order': ");
    dict.extend_from_slice(fortran_order);
    dict.extend_from_slice(b", 'shape': (");

    let shape_info = if let Some(dims) = shape {
        for dim in dims {
            // Writing into a `Vec<u8>` cannot fail.
            write!(dict, "{}, ", dim).unwrap();
        }
        dict.extend_from_slice(b"), }");
        ShapeInfo::Known { expected_num_items: dims.iter().product() }
    } else {
        let offset_in_header_text = dict.len() as u64;
        dict.extend_from_slice(FILLER_FOR_UNKNOWN_SIZE);
        dict.extend_from_slice(b",), }");
        ShapeInfo::Automatic { offset_in_header_text }
    };

    (dict, shape_info)
}
impl<Row: Serialize + ?Sized, W: Write> Drop for NpyWriter<Row, W> {
fn drop(&mut self) {
let _ = self.finish_(); }
}
/// Pick the format version for `header_utf8` and pad it for alignment.
///
/// The version is whatever `header::get_minimal_version` returns for the required length-field
/// size and encoding. The text is then padded with spaces and a trailing newline so that the
/// full preamble (everything before the text, plus the text) is a multiple of 64 bytes.
///
/// Returns the padded text, the chosen `(major, minor)` version and that version's properties.
fn determine_required_version_and_pad_header(mut header_utf8: Vec<u8>) -> (Vec<u8>, (u8, u8), VersionProps) {
    use HeaderSizeType::*;
    use HeaderEncoding::*;

    const ALIGN_TO: usize = 64;
    // Unpadded headers at least this long get a u32 length field. Padding below adds at most
    // `ALIGN_TO` bytes, so keeping 1 KiB of headroom under 0x1_0000 guarantees that a header
    // given a u16 length field still fits in one after padding.
    const SAFE_U16_CUTOFF: usize = 0xfc00;

    let required_props = VersionProps {
        header_size_type: if header_utf8.len() >= SAFE_U16_CUTOFF { U32 } else { U16 },
        encoding: if header_utf8.is_ascii() { Ascii } else { Utf8 },
    };
    let version = header::get_minimal_version(required_props);
    let actual_props = header::get_version_props(version).expect("generated internally so must be valid");

    let bytes_before_text = actual_props.bytes_before_text();
    // Pad with spaces, reserving exactly one byte for the terminating newline.
    let padding = ALIGN_TO - 1 - ((header_utf8.len() + bytes_before_text) % ALIGN_TO);
    header_utf8.resize(header_utf8.len() + padding, b' ');
    header_utf8.push(b'\n');
    assert_eq!((header_utf8.len() + bytes_before_text) % ALIGN_TO, 0);

    (header_utf8, version, actual_props)
}
/// Old name for `to_file_1d`: writes every item of `data` to `filename` as a 1-D array.
#[deprecated(since = "0.5.0", note = "renamed to to_file_1d")]
pub fn to_file<S, T, P>(filename: P, data: T) -> std::io::Result<()>
where
    P: AsRef<Path>,
    S: AutoSerialize,
    T: IntoIterator<Item = S>,
{
    to_file_1d::<S, T, P>(filename, data)
}
pub fn to_file_1d<S, T, P>(filename: P, data: T) -> std::io::Result<()>
where
P: AsRef<Path>,
S: AutoSerialize,
T: IntoIterator<Item=S>,
{
#![allow(deprecated)]
let mut of = OutFile::open(filename)?;
for row in data {
of.push(&row)?;
}
of.close()
}
use maybe_seek::MaybeSeek;
// Writer wrapper that may or may not support seeking. It lets `NpyWriter<Row, W>` accept any
// `W: Write`, while writers created through `new_seek` (used by `begin_1d`) can still seek.
mod maybe_seek {
use super::*;
/// A writer that is either seekable (`Is`, boxed as a trait object) or not (`Isnt`).
pub(crate) enum MaybeSeek<W> {
/// Seekable writer, created by `new_seek`.
Is(Box<dyn WriteSeek<W>>),
/// Plain writer; calling `seek` on this variant panics.
Isnt(W),
}
/// `Write + Seek`, sealed so that the only implementation is the blanket one below
/// (every `W: Write + Seek` implements `WriteSeek<W>` for itself).
pub(crate) trait WriteSeek<W>: Write + Seek + sealed::Sealed<W> {}
mod sealed {
use super::*;
pub(crate) trait Sealed<W> {}
impl<W: Write + Seek> Sealed<W> for W {}
}
impl<W: Write + Seek> WriteSeek<W> for W {}
// Forward writes and flushes to whichever variant is present.
impl<W: Write> Write for MaybeSeek<W> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self {
MaybeSeek::Is(w) => (*w).write(buf),
MaybeSeek::Isnt(w) => w.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match self {
MaybeSeek::Is(w) => (*w).flush(),
MaybeSeek::Isnt(w) => w.flush(),
}
}
}
// Seeking is only supported on `Is`. `NpyWriter` only seeks for writers started via
// `begin_1d`, which always wraps them with `new_seek`, so reaching `Isnt` here indicates a bug.
impl<W> Seek for MaybeSeek<W> {
fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
match self {
MaybeSeek::Is(w) => (*w).seek(pos),
MaybeSeek::Isnt(_) => unreachable!("(BUG!) .seek() called on MaybeSeek::Isnt!"),
}
}
}
impl<W: WriteSeek<W>> MaybeSeek<W> {
/// Box `w` as a seekable writer.
pub fn new_seek(w: W) -> Self {
// SAFETY (reviewer note, unverified): the transmute only changes the trait object's
// lifetime bound from an inferred `'_` to `'static`, which is the default bound of the
// field type `Box<dyn WriteSeek<W>>`. The erased concrete type is `W` itself, and `W` still
// appears in `MaybeSeek<W>` (and as the trait's type parameter), so the borrow checker
// presumably keeps tracking `W`'s lifetimes. TODO: confirm soundness, e.g. with miri.
let inner = unsafe {
std::mem::transmute::<
Box<dyn WriteSeek<W> + '_>,
Box<dyn WriteSeek<W> + 'static>,
>(Box::new(w))
};
MaybeSeek::Is(inner)
}
}
}
/// Test helper: serialize `data` as a 1-D `.npy` array into a fresh in-memory buffer.
#[cfg(test)]
pub(crate) fn to_bytes_1d<T: AutoSerialize>(data: &[T]) -> io::Result<Vec<u8>> {
    let mut sink = io::Cursor::new(Vec::new());
    to_writer_1d(&mut sink, data)?;
    let bytes = sink.into_inner();
    Ok(bytes)
}
/// Test helper: write `data` as a 1-D `.npy` array to a seekable writer.
#[cfg(test)]
pub(crate) fn to_writer_1d<W: io::Write + io::Seek, T: AutoSerialize>(writer: W, data: &[T]) -> io::Result<()> {
    to_writer_1d_with_seeking::<W, T>(writer, data)
}
/// Test helper: write `data` as an n-D `.npy` array with the given `shape` (no seeking needed).
#[cfg(test)]
pub(crate) fn to_writer_nd<W: io::Write, T: AutoSerialize>(writer: W, data: &[T], shape: &[u64]) -> io::Result<()> {
    let builder = WriteOptions::new().default_dtype().writer(writer).shape(shape);
    let mut npy = builder.begin_nd()?;
    npy.extend(data)?;
    npy.finish()
}
/// Test helper: write `data` as a 1-D `.npy` array, patching the length into the header at the end.
#[cfg(test)]
pub(crate) fn to_writer_1d_with_seeking<W: io::Write + io::Seek, T: AutoSerialize>(writer: W, data: &[T]) -> io::Result<()> {
    let builder = WriteOptions::new().default_dtype().writer(writer);
    let mut npy = builder.begin_1d()?;
    npy.extend(data)?;
    npy.finish()
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{self, Cursor};
    use crate::NpyFile;
    use crate::header::Field;

    /// Returns true if `needle` occurs as a contiguous subslice of `haystack`.
    fn bytestring_contains(haystack: &[u8], needle: &[u8]) -> bool {
        if needle.is_empty() {
            return true;
        }
        haystack.windows(needle.len()).any(move |w| w == needle)
    }

    #[test]
    fn write_1d_simple() -> io::Result<()> {
        let raw_buffer = to_bytes_1d(&[1.0, 3.0, 5.0])?;
        let reader = NpyFile::new(&raw_buffer[..])?;
        assert_eq!(reader.into_vec::<f64>()?, vec![1.0, 3.0, 5.0]);
        Ok(())
    }

    #[test]
    fn write_1d_in_the_middle() -> io::Result<()> {
        let mut cursor = Cursor::new(vec![]);
        let prefix = b"lorem ipsum dolor sit amet.";
        let suffix = b"and they lived happily ever after.";
        cursor.write_all(prefix)?;
        to_writer_1d_with_seeking(&mut cursor, &[1.0, 3.0, 5.0])?;
        cursor.write_all(suffix)?;
        let raw_buffer = cursor.into_inner();
        assert!(raw_buffer.starts_with(prefix));
        assert!(raw_buffer.ends_with(suffix));
        let written_bytes = &raw_buffer[prefix.len()..raw_buffer.len() - suffix.len()];
        let reader = NpyFile::new(&written_bytes[..])?;
        assert_eq!(reader.into_vec::<f64>()?, vec![1.0, 3.0, 5.0]);
        Ok(())
    }

    #[test]
    fn implicit_finish() -> io::Result<()> {
        let mut cursor = Cursor::new(vec![]);
        let mut writer = WriteOptions::new().default_dtype().writer(&mut cursor).begin_1d()?;
        writer.extend(vec![1.0, 3.0, 5.0, 7.0])?;
        drop(writer);
        let raw_buffer = cursor.into_inner();
        println!("{:?}", raw_buffer);
        assert!(bytestring_contains(&raw_buffer, b"'shape': (4,"));
        Ok(())
    }

    #[test]
    fn write_nd_simple() -> io::Result<()> {
        let mut buffer = vec![];
        to_writer_nd(&mut buffer, &[00, 01, 02, 10, 11, 12], &[2, 3])?;
        let reader = NpyFile::new(&buffer[..])?;
        assert_eq!(reader.shape(), &[2, 3][..]);
        assert_eq!(reader.into_vec::<i32>()?, vec![00, 01, 02, 10, 11, 12]);
        Ok(())
    }

    #[test]
    fn write_nd_wrong_len() -> io::Result<()> {
        let try_writing = |elems: &[i32]| -> io::Result<()> {
            let mut buf = vec![];
            let mut writer = WriteOptions::new().default_dtype().writer(&mut buf).shape(&[2, 3]).begin_nd()?;
            for &x in elems {
                writer.push(&x)?;
            }
            writer.finish()?;
            Ok(())
        };
        assert!(try_writing(&[00, 01, 02, 10, 11]).is_err());
        assert!(try_writing(&[00, 01, 02, 10, 11, 12]).is_ok());
        assert!(try_writing(&[00, 01, 02, 10, 11, 12, 20]).is_err());
        Ok(())
    }

    #[test]
    fn write_header_only_positions() -> io::Result<()> {
        let mut cursor = Cursor::new(vec![]);
        let dtype = DType::new_scalar("|O".parse().unwrap());
        // `write_all` rather than `write`: a bare `write` may accept only part of the buffer.
        cursor.write_all(b"ABCD")?;
        let header_start_pos = cursor.position();
        let returned_writer = WriteOptions::new_header_only()
            .shape(&[2, 3])
            .dtype(dtype.clone())
            .writer(&mut cursor)
            .write_header_only()?;
        returned_writer.write_all(b"dcba")?;
        cursor.set_position(header_start_pos);
        let npy = NpyFile::new(&mut cursor)?;
        assert_eq!(npy.dtype(), dtype);
        assert_eq!(npy.shape(), &[2, 3]);
        assert_eq!(npy.len(), 6);
        assert!(npy.uses_pickled_array());
        let mut trailing_bytes = vec![];
        std::io::Read::read_to_end(&mut cursor, &mut trailing_bytes)?;
        assert_eq!(&trailing_bytes[..], b"dcba");
        Ok(())
    }

    #[test]
    fn write_header_only_funny_type() -> io::Result<()> {
        let mut cursor = Cursor::new(vec![]);
        let dtype = DType::Record(vec![
            Field {
                name: "parent".to_string(),
                dtype: DType::Record(vec![
                    Field {
                        name: "child".to_string(),
                        dtype: DType::Plain("|O".parse().unwrap()),
                    },
                ]),
            }
        ]);
        WriteOptions::new_header_only()
            .shape(&[2, 3])
            .dtype(dtype.clone())
            .writer(&mut cursor)
            .write_header_only()?;
        cursor.set_position(0);
        let npy = NpyFile::new(&mut cursor)?;
        assert_eq!(npy.dtype(), dtype);
        assert_eq!(npy.shape(), &[2, 3]);
        assert_eq!(npy.len(), 6);
        assert!(npy.uses_pickled_array());
        Ok(())
    }
}