use rmpv::Value as MpValue;
use crate::Error;
/// Outcome of fusing several shard payloads into one.
#[derive(Debug)]
pub struct FuseResult {
    /// Bytes of the fused payload.
    pub payload: Vec<u8>,
    /// Number of input payloads counted as merged into `payload`.
    pub shards_merged: usize,
}
/// Fuses per-shard MessagePack payloads into a single MessagePack array payload.
///
/// Behaviour by input shape:
/// - no payloads, or a single *empty* payload: the result is the canonical
///   empty array and `shards_merged` is `0`;
/// - exactly one non-empty payload: the bytes are passed through unchanged
///   (no decode/re-encode) and `shards_merged` is `1`;
/// - two or more payloads: every non-empty payload is decoded, its rows are
///   appended in input order, and the combined rows are re-encoded as one
///   array. `shards_merged` counts the payloads that contributed at least one
///   row.
///
/// # Errors
///
/// Returns `Error::Serialization` when a payload cannot be decoded or the
/// merged rows cannot be encoded.
pub fn fuse_payloads(payloads: Vec<Vec<u8>>) -> Result<FuseResult, Error> {
    // Fast paths: with zero or one payload there is nothing to merge.
    if payloads.len() <= 1 {
        return Ok(match payloads.into_iter().next() {
            // A lone non-empty payload is forwarded as-is, skipping the codec.
            Some(single) if !single.is_empty() => FuseResult {
                payload: single,
                shards_merged: 1,
            },
            // No payload, or a lone empty buffer: an empty buffer is not valid
            // MessagePack, so emit the empty array and count nothing, matching
            // how empty payloads are treated on the multi-payload path.
            _ => FuseResult {
                payload: encode_empty_array(),
                shards_merged: 0,
            },
        });
    }

    let mut all_rows: Vec<MpValue> = Vec::new();
    let mut non_empty = 0usize;
    for payload in payloads.iter().filter(|p| !p.is_empty()) {
        let rows = decode_msgpack_array(payload)?;
        // Payloads that decode to zero rows do not count as merged.
        if !rows.is_empty() {
            non_empty += 1;
            all_rows.extend(rows);
        }
    }

    let merged = encode_msgpack_array(&all_rows).map_err(|e| Error::Serialization {
        format: "msgpack".into(),
        detail: format!("fuser: encode failed: {e}"),
    })?;

    Ok(FuseResult {
        payload: merged,
        shards_merged: non_empty,
    })
}
fn decode_msgpack_array(bytes: &[u8]) -> Result<Vec<MpValue>, Error> {
if bytes.is_empty() {
return Ok(Vec::new());
}
let mut cursor = std::io::Cursor::new(bytes);
let value: MpValue =
rmpv::decode::read_value(&mut cursor).map_err(|e| Error::Serialization {
format: "msgpack".into(),
detail: format!("fuser: decode failed: {e}"),
})?;
match value {
MpValue::Array(rows) => Ok(rows),
other => Ok(vec![other]),
}
}
/// Encodes `rows` as a single MessagePack array and returns the bytes.
///
/// The rows are cloned into an owned array value before encoding.
///
/// # Errors
///
/// Propagates any error reported by the `rmpv` encoder.
fn encode_msgpack_array(rows: &[MpValue]) -> Result<Vec<u8>, rmpv::encode::Error> {
    let array = MpValue::Array(rows.to_vec());
    let mut out = Vec::new();
    // Hand the buffer back only if the encoder succeeded.
    rmpv::encode::write_value(&mut out, &array).map(|()| out)
}
/// Returns the one-byte encoding used for an empty array payload.
fn encode_empty_array() -> Vec<u8> {
    // 0x90 is the MessagePack fixarray marker with a length of zero.
    const EMPTY_FIXARRAY: u8 = 0x90;
    vec![EMPTY_FIXARRAY]
}
/// Fuses shard payloads for aggregate types this module treats as commutative.
///
/// `agg_type` is upper-cased before comparison, so matching is
/// case-insensitive. Only `SUM` and `COUNT` are accepted; any other value
/// returns `None` without touching the payloads.
///
/// For an accepted type the payloads are passed to [`fuse_payloads`] and the
/// fused bytes are returned. This function concatenates rows only; it does not
/// itself compute the aggregate value.
///
/// # Errors
///
/// The inner `Result` carries any error returned by [`fuse_payloads`].
pub fn push_up_commutative_aggregate(
    payloads: Vec<Vec<u8>>,
    agg_type: &str,
) -> Option<Result<Vec<u8>, Error>> {
    let normalized = agg_type.to_uppercase();
    let is_commutative = matches!(normalized.as_str(), "SUM" | "COUNT");
    is_commutative.then(|| fuse_payloads(payloads).map(|fused| fused.payload))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `values` as a MessagePack array of integers.
    fn encode_row_array(values: &[i64]) -> Result<Vec<u8>, rmpv::encode::Error> {
        let rows: Vec<MpValue> = values
            .iter()
            .copied()
            .map(|v| MpValue::Integer(v.into()))
            .collect();
        encode_msgpack_array(&rows)
    }

    // No payloads at all: the result is the one-byte empty array.
    #[test]
    fn fuse_empty_produces_empty_array() {
        let fused = fuse_payloads(Vec::new()).unwrap();
        assert_eq!(fused.payload, vec![0x90]);
        assert_eq!(fused.shards_merged, 0);
    }

    // A single payload is returned byte-for-byte.
    #[test]
    fn fuse_single_passthrough() {
        let data = vec![0x91, 0x01];
        let fused = fuse_payloads(vec![data.clone()]).unwrap();
        assert_eq!(fused.payload, data);
        assert_eq!(fused.shards_merged, 1);
    }

    // Two one-row arrays merge into one two-row array.
    #[test]
    fn fuse_two_arrays() {
        let first = encode_row_array(&[1i64]).unwrap();
        let second = encode_row_array(&[2i64]).unwrap();
        let fused = fuse_payloads(vec![first, second]).unwrap();
        let rows = decode_msgpack_array(&fused.payload).unwrap();
        assert_eq!(rows.len(), 2);
        assert_eq!(fused.shards_merged, 2);
    }

    // Empty payloads contribute neither rows nor to the merged count.
    #[test]
    fn fuse_skips_empty_payloads() {
        let empty = Vec::new();
        let populated = encode_row_array(&[99i64]).unwrap();
        let fused = fuse_payloads(vec![empty, populated]).unwrap();
        let rows = decode_msgpack_array(&fused.payload).unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(fused.shards_merged, 1);
    }

    // SUM is accepted and fuses successfully.
    #[test]
    fn push_up_sum_is_commutative() {
        let first = encode_row_array(&[1i64]).unwrap();
        let second = encode_row_array(&[2i64]).unwrap();
        let outcome = push_up_commutative_aggregate(vec![first, second], "SUM");
        assert!(matches!(outcome, Some(Ok(_))));
    }

    // AVG is rejected: the function declines with `None`.
    #[test]
    fn push_up_avg_is_not_commutative() {
        let only = encode_row_array(&[1i64]).unwrap();
        let outcome = push_up_commutative_aggregate(vec![only], "AVG");
        assert!(outcome.is_none());
    }
}