use std::sync::Arc;
use bcp_types::block::{Block, BlockContent};
use bcp_types::block_type::BlockType;
use bcp_types::content_store::ContentStore;
use bcp_types::summary::Summary;
use bcp_wire::block_frame::{BlockFlags, BlockFrame};
use bcp_wire::header::{HEADER_SIZE, BcpHeader};
use bcp_wire::varint::decode_varint;
use tokio::io::{AsyncRead, AsyncReadExt};
use crate::decompression::{self, MAX_BLOCK_DECOMPRESSED_SIZE, MAX_PAYLOAD_DECOMPRESSED_SIZE};
use crate::error::DecodeError;
/// One item produced by [`StreamingDecoder::next`].
#[derive(Debug, Clone)]
pub enum DecoderEvent {
    /// The stream header; it is the first event of every stream.
    Header(BcpHeader),
    /// A single fully decoded block, delivered in stream order.
    Block(Block),
}
/// Incrementally decodes a BCP stream from an async reader.
///
/// Each call to `next` yields the header first and then one event per block.
pub struct StreamingDecoder<R> {
    /// Byte source the header and the uncompressed frames are read from.
    reader: R,
    /// Which part of the stream the next call to `next` will read.
    state: StreamState,
    /// Scratch buffer reused for block bodies read directly from `reader`.
    buf: Vec<u8>,
    /// The entire decompressed payload, present when the header marks the stream as compressed.
    /// In that case frames are parsed from here instead of from `reader`.
    decompressed_payload: Option<Vec<u8>>,
    /// Offset into `decompressed_payload` of the next frame that has not been parsed yet.
    decompressed_cursor: usize,
    /// Optional store used to resolve content-addressed (reference) block bodies.
    content_store: Option<Arc<dyn ContentStore>>,
}
/// Position of a [`StreamingDecoder`] within the stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum StreamState {
    /// Nothing has been read yet; the next call reads the stream header.
    ReadHeader,
    /// The header has been read; the next call reads one block frame.
    ReadBlocks,
    /// The stream is finished; further calls yield `None`.
    Done,
}
impl<R: AsyncRead + Unpin> StreamingDecoder<R> {
#[must_use]
pub fn new(reader: R) -> Self {
Self {
reader,
state: StreamState::ReadHeader,
buf: Vec::with_capacity(4096),
decompressed_payload: None,
decompressed_cursor: 0,
content_store: None,
}
}
#[must_use]
pub fn with_content_store(mut self, store: Arc<dyn ContentStore>) -> Self {
self.content_store = Some(store);
self
}
pub async fn next(&mut self) -> Option<Result<DecoderEvent, DecodeError>> {
match self.state {
StreamState::ReadHeader => Some(self.read_header().await),
StreamState::ReadBlocks => self.read_next_block().await,
StreamState::Done => None,
}
}
async fn read_header(&mut self) -> Result<DecoderEvent, DecodeError> {
let mut header_buf = [0u8; HEADER_SIZE];
self.reader.read_exact(&mut header_buf).await.map_err(|_| {
DecodeError::InvalidHeader(bcp_wire::WireError::UnexpectedEof { offset: 0 })
})?;
let header = BcpHeader::read_from(&header_buf).map_err(DecodeError::InvalidHeader)?;
if header.flags.is_compressed() {
let mut compressed = Vec::new();
self.reader
.read_to_end(&mut compressed)
.await
.map_err(DecodeError::Io)?;
let decompressed =
decompression::decompress(&compressed, MAX_PAYLOAD_DECOMPRESSED_SIZE)?;
self.decompressed_payload = Some(decompressed);
self.decompressed_cursor = 0;
}
self.state = StreamState::ReadBlocks;
Ok(DecoderEvent::Header(header))
}
async fn read_next_block(&mut self) -> Option<Result<DecoderEvent, DecodeError>> {
if let Some(ref payload) = self.decompressed_payload {
if self.decompressed_cursor >= payload.len() {
self.state = StreamState::Done;
return Some(Err(DecodeError::MissingEndSentinel));
}
let remaining = &payload[self.decompressed_cursor..];
match BlockFrame::read_from(remaining) {
Ok(Some((frame, consumed))) => {
self.decompressed_cursor += consumed;
Some(self.decode_frame(&frame))
}
Ok(None) => {
match end_sentinel_size(remaining) {
Ok(size) => self.decompressed_cursor += size,
Err(e) => return Some(Err(e)),
}
self.state = StreamState::Done;
None
}
Err(e) => Some(Err(DecodeError::from(e))),
}
} else {
self.read_next_block_from_reader().await
}
}
async fn read_next_block_from_reader(&mut self) -> Option<Result<DecoderEvent, DecodeError>> {
let block_type_raw = match self.read_varint().await {
Ok(v) => v,
Err(e) => return Some(Err(e)),
};
#[allow(clippy::cast_possible_truncation)]
let block_type_byte = block_type_raw as u8;
if block_type_byte == 0xFF {
match self.read_end_frame_tail().await {
Ok(()) => {}
Err(e) => return Some(Err(e)),
}
self.state = StreamState::Done;
return None;
}
let mut flags_byte = [0u8; 1];
if let Err(e) = self.reader.read_exact(&mut flags_byte).await {
return Some(Err(DecodeError::Io(e)));
}
let flags = BlockFlags::from_raw(flags_byte[0]);
#[allow(clippy::cast_possible_truncation)]
let content_len = match self.read_varint().await {
Ok(v) => v as usize,
Err(e) => return Some(Err(e)),
};
self.buf.clear();
self.buf.resize(content_len, 0);
if let Err(e) = self.reader.read_exact(&mut self.buf[..content_len]).await {
return Some(Err(DecodeError::Io(e)));
}
let frame = bcp_wire::block_frame::BlockFrame {
block_type: block_type_byte,
flags,
body: self.buf[..content_len].to_vec(),
};
Some(self.decode_frame(&frame))
}
fn decode_frame(
&self,
frame: &bcp_wire::block_frame::BlockFrame,
) -> Result<DecoderEvent, DecodeError> {
let block_type = BlockType::from_wire_id(frame.block_type);
let resolved_body = if frame.flags.is_reference() {
let store = self
.content_store
.as_ref()
.ok_or(DecodeError::MissingContentStore)?;
if frame.body.len() != 32 {
return Err(DecodeError::Wire(bcp_wire::WireError::UnexpectedEof {
offset: frame.body.len(),
}));
}
let hash: [u8; 32] = frame.body[..32].try_into().expect("length already checked");
store
.get(&hash)
.ok_or(DecodeError::UnresolvedReference { hash })?
} else {
frame.body.clone()
};
let decompressed_body = if frame.flags.is_compressed() {
decompression::decompress(&resolved_body, MAX_BLOCK_DECOMPRESSED_SIZE)?
} else {
resolved_body
};
let mut body = decompressed_body.as_slice();
let mut summary = None;
if frame.flags.has_summary() {
match Summary::decode(body) {
Ok((sum, consumed)) => {
summary = Some(sum);
body = &body[consumed..];
}
Err(e) => return Err(e.into()),
}
}
let content = BlockContent::decode_body(&block_type, body)?;
Ok(DecoderEvent::Block(Block {
block_type,
flags: frame.flags,
summary,
content,
}))
}
async fn read_end_frame_tail(&mut self) -> Result<(), DecodeError> {
let mut byte = [0u8; 1];
self.reader
.read_exact(&mut byte)
.await
.map_err(DecodeError::Io)?;
let _content_len = self.read_varint().await?;
Ok(())
}
async fn read_varint(&mut self) -> Result<u64, DecodeError> {
let mut varint_buf = [0u8; 10];
let mut len = 0;
loop {
let mut byte = [0u8; 1];
self.reader
.read_exact(&mut byte)
.await
.map_err(DecodeError::Io)?;
varint_buf[len] = byte[0];
len += 1;
if byte[0] & 0x80 == 0 {
break;
}
if len >= 10 {
return Err(DecodeError::Wire(bcp_wire::WireError::VarintTooLong));
}
}
let (value, _) = decode_varint(&varint_buf[..len])?;
Ok(value)
}
}
/// Returns how many bytes the END sentinel's header occupies at the start of `buf`.
///
/// The header consists of the block-type varint, one flags byte and the content-length varint.
///
/// # Errors
///
/// - A wire error if either varint is malformed.
/// - `UnexpectedEof` (at the offset of the length field) if `buf` ends before the flags byte.
fn end_sentinel_size(buf: &[u8]) -> Result<usize, DecodeError> {
    let (_, type_width) = decode_varint(buf)?;
    // The single flags byte sits between the type varint and the length varint.
    let len_offset = type_width + 1;
    let len_field = match buf.get(len_offset..) {
        Some(field) => field,
        None => {
            return Err(DecodeError::Wire(bcp_wire::WireError::UnexpectedEof {
                offset: len_offset,
            }))
        }
    };
    let (_, len_width) = decode_varint(len_field)?;
    Ok(len_offset + len_width)
}
#[cfg(test)]
mod tests {
    use super::*;
    use bcp_encoder::BcpEncoder;
    use bcp_types::enums::{Lang, Priority, Role, Status};

    /// Pulls every event out of `decoder`, unwrapping each result so that the first decode
    /// error fails the test.
    async fn drain<R: AsyncRead + Unpin>(mut decoder: StreamingDecoder<R>) -> Vec<DecoderEvent> {
        let mut events = Vec::new();
        while let Some(result) = decoder.next().await {
            events.push(result.unwrap());
        }
        events
    }

    /// Encodes `encoder` and streams the bytes back through a buffered in-memory reader,
    /// returning every decoded event.
    async fn stream_roundtrip(encoder: &BcpEncoder) -> Vec<DecoderEvent> {
        let bytes = encoder.encode().unwrap();
        let reader = tokio::io::BufReader::new(std::io::Cursor::new(bytes));
        drain(StreamingDecoder::new(reader)).await
    }

    /// Discards header events, keeping only the decoded blocks in stream order.
    fn blocks_of(events: Vec<DecoderEvent>) -> Vec<Block> {
        events
            .into_iter()
            .filter_map(|event| match event {
                DecoderEvent::Block(block) => Some(block),
                DecoderEvent::Header(_) => None,
            })
            .collect()
    }

    /// Borrows the block inside `event`, failing the test if it is a header.
    fn expect_block(event: &DecoderEvent) -> &Block {
        match event {
            DecoderEvent::Block(block) => block,
            other => panic!("expected Block, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn streaming_produces_header_then_blocks() {
        let mut enc = BcpEncoder::new();
        enc.add_code(Lang::Rust, "main.rs", b"fn main() {}")
            .add_conversation(Role::User, b"hello");
        let events = stream_roundtrip(&enc).await;
        assert_eq!(events.len(), 3);
        assert!(matches!(&events[0], DecoderEvent::Header(h) if h.version_major == 1));
        assert_eq!(expect_block(&events[1]).block_type, BlockType::Code);
        assert_eq!(expect_block(&events[2]).block_type, BlockType::Conversation);
    }

    #[tokio::test]
    async fn streaming_matches_sync_decoder() {
        let mut encoder = BcpEncoder::new();
        encoder
            .add_code(Lang::Rust, "lib.rs", b"pub fn x() {}")
            .with_summary("Function x.")
            .unwrap()
            .with_priority(Priority::High)
            .unwrap()
            .add_conversation(Role::User, b"What does x do?")
            .add_tool_result("docs", Status::Ok, b"x is a placeholder.");
        let payload = encoder.encode().unwrap();
        let sync_decoded = crate::BcpDecoder::decode(&payload).unwrap();
        let stream_blocks = blocks_of(stream_roundtrip(&encoder).await);
        assert_eq!(sync_decoded.blocks.len(), stream_blocks.len());
        for (expected, actual) in sync_decoded.blocks.iter().zip(&stream_blocks) {
            assert_eq!(expected.block_type, actual.block_type);
            assert_eq!(expected.flags, actual.flags);
            assert_eq!(expected.summary, actual.summary);
        }
    }

    #[tokio::test]
    async fn streaming_handles_summary_blocks() {
        let mut enc = BcpEncoder::new();
        enc.add_code(Lang::Python, "app.py", b"print('hi')")
            .with_summary("Prints a greeting.")
            .unwrap();
        let events = stream_roundtrip(&enc).await;
        let block = expect_block(&events[1]);
        assert!(block.flags.has_summary());
        assert_eq!(block.summary.as_ref().unwrap().text, "Prints a greeting.");
    }

    #[tokio::test]
    async fn streaming_empty_body_blocks() {
        let mut enc = BcpEncoder::new();
        enc.add_extension("ns", "t", b"");
        // Header plus the single extension block.
        assert_eq!(stream_roundtrip(&enc).await.len(), 2);
    }

    #[tokio::test]
    async fn streaming_terminates_at_end_sentinel() {
        let mut enc = BcpEncoder::new();
        enc.add_conversation(Role::User, b"hi");
        // The END sentinel stops the stream without producing an event of its own.
        assert_eq!(stream_roundtrip(&enc).await.len(), 2);
    }

    #[tokio::test]
    async fn streaming_per_block_compression_roundtrip() {
        let big_content = "fn main() { println!(\"hello world\"); }\n".repeat(50);
        let mut enc = BcpEncoder::new();
        enc.add_code(Lang::Rust, "main.rs", big_content.as_bytes())
            .with_compression()
            .unwrap();
        let events = stream_roundtrip(&enc).await;
        // Header plus the single compressed block.
        assert_eq!(events.len(), 2);
        match &expect_block(&events[1]).content {
            BlockContent::Code(code) => assert_eq!(code.content, big_content.as_bytes()),
            other => panic!("expected Code, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn streaming_whole_payload_compression_roundtrip() {
        let big_content = "use std::io;\n".repeat(100);
        let mut enc = BcpEncoder::new();
        enc.add_code(Lang::Rust, "a.rs", big_content.as_bytes())
            .add_code(Lang::Rust, "b.rs", big_content.as_bytes());
        enc.compress_payload();
        let events = stream_roundtrip(&enc).await;
        assert_eq!(events.len(), 3);
        match &events[0] {
            DecoderEvent::Header(h) => assert!(h.flags.is_compressed()),
            other => panic!("expected Header, got {other:?}"),
        }
        for event in &events[1..] {
            match &expect_block(event).content {
                BlockContent::Code(code) => assert_eq!(code.content, big_content.as_bytes()),
                other => panic!("expected Code, got {other:?}"),
            }
        }
    }

    #[tokio::test]
    async fn streaming_content_addressing_roundtrip() {
        let store = Arc::new(bcp_encoder::MemoryContentStore::new());
        let mut enc = BcpEncoder::new();
        enc.set_content_store(store.clone())
            .add_code(Lang::Rust, "main.rs", b"fn main() {}")
            .with_content_addressing()
            .unwrap();
        let payload = enc.encode().unwrap();
        let reader = tokio::io::BufReader::new(std::io::Cursor::new(payload));
        let events = drain(StreamingDecoder::new(reader).with_content_store(store)).await;
        // Header plus the single content-addressed block.
        assert_eq!(events.len(), 2);
        match &expect_block(&events[1]).content {
            BlockContent::Code(code) => assert_eq!(code.content, b"fn main() {}"),
            other => panic!("expected Code, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn streaming_matches_sync_compressed() {
        let big_content = "pub fn hello() -> &'static str { \"world\" }\n".repeat(100);
        let mut encoder = BcpEncoder::new();
        encoder
            .add_code(Lang::Rust, "lib.rs", big_content.as_bytes())
            .with_summary("Hello function.")
            .unwrap()
            .add_conversation(Role::User, b"explain");
        encoder.compress_payload();
        let payload = encoder.encode().unwrap();
        let sync_decoded = crate::BcpDecoder::decode(&payload).unwrap();
        let stream_blocks = blocks_of(stream_roundtrip(&encoder).await);
        assert_eq!(sync_decoded.blocks.len(), stream_blocks.len());
        for (expected, actual) in sync_decoded.blocks.iter().zip(&stream_blocks) {
            assert_eq!(expected.block_type, actual.block_type);
            assert_eq!(expected.summary, actual.summary);
        }
    }
}