use std::collections::HashMap;
use std::io::Write;
use std::rc::Rc;
use failure::{err_msg, Error};
use rand::random;
use serde::Serialize;
use serde_json;
use encode::{encode, encode_to_vec};
use schema::Schema;
use ser::Serializer;
use types::{ToAvro, Value};
use Codec;
/// Length in bytes of a sync marker (the writer generates 16 random bytes).
const SYNC_SIZE: usize = 16;
/// Buffered-byte threshold: once the in-memory block reaches this size,
/// `Writer::append` flushes it to the underlying writer. Also used as the
/// initial capacity of that buffer.
const SYNC_INTERVAL: usize = SYNC_SIZE * 1000;
/// Leading bytes of the container header: ASCII `Obj` followed by the byte `1`.
const AVRO_OBJECT_HEADER: &'static [u8] = b"Obj\x01";
/// Writes values as an Avro object container stream into any `W`.
///
/// Values are validated and encoded against `schema`, accumulated in an
/// in-memory buffer, and emitted to `writer` as blocks when the buffer grows
/// large enough or when `flush` is called explicitly.
pub struct Writer<'a, W> {
// Schema every appended value is validated and encoded against.
schema: &'a Schema,
// Serde serializer used by `append_ser` to turn `Serialize` values into Avro values.
serializer: Serializer,
// Underlying sink that receives the header and the flushed blocks.
writer: W,
// Encoded values that have been appended but not yet flushed.
buffer: Vec<u8>,
// Number of values currently held in `buffer`.
num_values: usize,
// Codec applied to `buffer` at flush time; also recorded in the header metadata.
codec: Codec,
// Random bytes generated at construction; written at the end of the header
// and after every flushed block.
marker: Vec<u8>,
// Whether the header has already been written to `writer`.
has_header: bool,
}
impl<'a, W: Write> Writer<'a, W> {
pub fn new(schema: &'a Schema, writer: W) -> Writer<'a, W> {
Self::with_codec(schema, writer, Codec::Null)
}
pub fn with_codec(schema: &'a Schema, writer: W, codec: Codec) -> Writer<'a, W> {
let mut marker = Vec::with_capacity(16);
for _ in 0..16 {
marker.push(random::<u8>());
}
Writer {
schema,
serializer: Serializer::new(),
writer,
buffer: Vec::with_capacity(SYNC_INTERVAL),
num_values: 0,
codec,
marker,
has_header: false,
}
}
pub fn schema(&self) -> &'a Schema {
self.schema
}
pub fn append<T: ToAvro>(&mut self, value: T) -> Result<usize, Error> {
let n = if !self.has_header {
let header = self.header()?;
let n = self.append_bytes(header.as_ref())?;
self.has_header = true;
n
} else {
0
};
write_avro_datum(self.schema, value, &mut self.buffer)?;
self.num_values += 1;
if self.buffer.len() >= SYNC_INTERVAL {
return self.flush().map(|b| b + n)
}
Ok(n)
}
pub fn append_ser<S: Serialize>(&mut self, value: S) -> Result<usize, Error> {
let avro_value = value.serialize(&mut self.serializer)?;
self.append(avro_value)
}
pub fn extend<I, T: ToAvro>(&mut self, values: I) -> Result<usize, Error>
where
I: Iterator<Item = T>,
{
let mut num_bytes = 0;
for value in values {
num_bytes += self.append(value)?;
}
num_bytes += self.flush()?;
Ok(num_bytes)
}
pub fn extend_ser<I, T: Serialize>(&mut self, values: I) -> Result<usize, Error>
where
I: Iterator<Item = T>,
{
let mut num_bytes = 0;
for value in values {
num_bytes += self.append_ser(value)?;
}
num_bytes += self.flush()?;
Ok(num_bytes)
}
pub fn flush(&mut self) -> Result<usize, Error> {
if self.num_values == 0 {
return Ok(0)
}
self.codec.compress(&mut self.buffer)?;
let num_values = self.num_values;
let stream_len = self.buffer.len();
let num_bytes = self.append_raw(num_values.avro(), &Schema::Long)?
+ self.append_raw(stream_len.avro(), &Schema::Long)?
+ self.writer.write(self.buffer.as_ref())?
+ self.append_marker()?;
self.buffer.clear();
self.num_values = 0;
Ok(num_bytes)
}
pub fn into_inner(self) -> W {
self.writer
}
fn append_marker(&mut self) -> Result<usize, Error> {
Ok(self.writer.write(&self.marker)?)
}
fn append_raw(&mut self, value: Value, schema: &Schema) -> Result<usize, Error> {
self.append_bytes(encode_to_vec(value, schema).as_ref())
}
fn append_bytes(&mut self, bytes: &[u8]) -> Result<usize, Error> {
Ok(self.writer.write(bytes)?)
}
fn header(&self) -> Result<Vec<u8>, Error> {
let schema_bytes = serde_json::to_string(self.schema)?.into_bytes();
let mut metadata = HashMap::with_capacity(2);
metadata.insert("avro.schema", Value::Bytes(schema_bytes));
metadata.insert("avro.codec", self.codec.avro());
let mut header = Vec::new();
header.extend_from_slice(AVRO_OBJECT_HEADER);
encode(
metadata.avro(),
&Schema::Map(Rc::new(Schema::Bytes)),
&mut header,
);
header.extend_from_slice(&self.marker);
Ok(header)
}
}
/// Converts `value` to its Avro representation, checks it against `schema`
/// and, when it matches, appends the encoded bytes to `buffer`.
///
/// Returns an error carrying the message "value does not match schema" when
/// validation fails; `buffer` is left untouched in that case.
fn write_avro_datum<T: ToAvro>(
    schema: &Schema,
    value: T,
    buffer: &mut Vec<u8>,
) -> Result<(), Error> {
    let datum = value.avro();
    if datum.validate(schema) {
        encode(datum, schema, buffer);
        Ok(())
    } else {
        Err(err_msg("value does not match schema"))
    }
}
/// Encodes a single `value` against `schema` and returns the raw bytes.
///
/// The output is only the encoded datum: no container header, block framing
/// or sync marker is added. Fails when `value` does not match `schema`.
pub fn to_avro_datum<T: ToAvro>(schema: &Schema, value: T) -> Result<Vec<u8>, Error> {
    let mut encoded = Vec::new();
    write_avro_datum(schema, value, &mut encoded).map(|_| encoded)
}
#[cfg(test)]
mod tests {
    use super::*;
    use types::Record;
    use util::zig_i64;

    static SCHEMA: &'static str = r#"
{
"type": "record",
"name": "test",
"fields": [
{"name": "a", "type": "long", "default": 42},
{"name": "b", "type": "string"}
]
}
"#;

    /// Number of trailing bytes (the sync marker) skipped before comparing
    /// the block payload.
    const TRAILING_MARKER_LEN: usize = 16;

    /// Expected encoding of the test record (`a = 27`, `b = "foo"`),
    /// repeated `copies` times back to back.
    fn encoded_records(copies: usize) -> Vec<u8> {
        let mut single = Vec::new();
        zig_i64(27, &mut single);
        zig_i64(3, &mut single);
        single.extend_from_slice(b"foo");

        let mut all = Vec::with_capacity(single.len() * copies);
        for _ in 0..copies {
            all.extend_from_slice(&single);
        }
        all
    }

    /// Checks that `result` starts with the container magic bytes and that
    /// the bytes immediately preceding the trailing sync marker equal
    /// `payload`.
    fn assert_container(result: &[u8], payload: &[u8]) {
        let magic = vec![b'O', b'b', b'j', b'\x01'];
        let prefix: Vec<u8> = result.iter().cloned().take(magic.len()).collect();
        assert_eq!(prefix, magic);

        // Drop the sync marker, then look at the last `payload.len()` bytes.
        let body = &result[..result.len().saturating_sub(TRAILING_MARKER_LEN)];
        let start = body.len().saturating_sub(payload.len());
        assert_eq!(&body[start..], payload);
    }

    #[test]
    fn test_to_avro_datum() {
        let schema = Schema::parse_str(SCHEMA).unwrap();
        let mut record = Record::new(&schema).unwrap();
        record.put("a", 27i64);
        record.put("b", "foo");

        assert_eq!(to_avro_datum(&schema, record).unwrap(), encoded_records(1));
    }

    #[test]
    fn test_writer_append() {
        let schema = Schema::parse_str(SCHEMA).unwrap();
        let mut writer = Writer::new(&schema, Vec::new());

        let mut record = Record::new(&schema).unwrap();
        record.put("a", 27i64);
        record.put("b", "foo");

        let n1 = writer.append(record.clone()).unwrap();
        let n2 = writer.append(record.clone()).unwrap();
        let n3 = writer.flush().unwrap();
        let result = writer.into_inner();

        assert_eq!(n1 + n2 + n3, result.len());
        assert_container(&result, &encoded_records(2));
    }

    #[test]
    fn test_writer_extend() {
        let schema = Schema::parse_str(SCHEMA).unwrap();
        let mut writer = Writer::new(&schema, Vec::new());

        let mut record = Record::new(&schema).unwrap();
        record.put("a", 27i64);
        record.put("b", "foo");
        let record_copy = record.clone();
        let records = vec![record, record_copy];

        let n1 = writer.extend(records.into_iter()).unwrap();
        let n2 = writer.flush().unwrap();
        let result = writer.into_inner();

        assert_eq!(n1 + n2, result.len());
        assert_container(&result, &encoded_records(2));
    }

    #[derive(Debug, Clone, Deserialize, Serialize)]
    struct TestSerdeSerialize {
        a: i64,
        b: String,
    }

    #[test]
    fn test_writer_append_ser() {
        let schema = Schema::parse_str(SCHEMA).unwrap();
        let mut writer = Writer::new(&schema, Vec::new());

        let record = TestSerdeSerialize {
            a: 27,
            b: "foo".to_owned(),
        };

        let n1 = writer.append_ser(record).unwrap();
        let n2 = writer.flush().unwrap();
        let result = writer.into_inner();

        assert_eq!(n1 + n2, result.len());
        assert_container(&result, &encoded_records(1));
    }

    #[test]
    fn test_writer_extend_ser() {
        let schema = Schema::parse_str(SCHEMA).unwrap();
        let mut writer = Writer::new(&schema, Vec::new());

        let record = TestSerdeSerialize {
            a: 27,
            b: "foo".to_owned(),
        };
        let record_copy = record.clone();
        let records = vec![record, record_copy];

        let n1 = writer.extend_ser(records.into_iter()).unwrap();
        let n2 = writer.flush().unwrap();
        let result = writer.into_inner();

        assert_eq!(n1 + n2, result.len());
        assert_container(&result, &encoded_records(2));
    }

    #[test]
    fn test_writer_with_codec() {
        let schema = Schema::parse_str(SCHEMA).unwrap();
        let mut writer = Writer::with_codec(&schema, Vec::new(), Codec::Deflate);

        let mut record = Record::new(&schema).unwrap();
        record.put("a", 27i64);
        record.put("b", "foo");

        let n1 = writer.append(record.clone()).unwrap();
        let n2 = writer.append(record.clone()).unwrap();
        let n3 = writer.flush().unwrap();
        let result = writer.into_inner();

        assert_eq!(n1 + n2 + n3, result.len());

        // The block payload is the deflate-compressed form of both records.
        let mut payload = encoded_records(2);
        Codec::Deflate.compress(&mut payload).unwrap();
        assert_container(&result, &payload);
    }
}