1use std::sync::Arc;
2
3use bcp_types::block::{Block, BlockContent};
4use bcp_types::block_type::BlockType;
5use bcp_types::content_store::ContentStore;
6use bcp_types::summary::Summary;
7use bcp_wire::block_frame::{BlockFlags, BlockFrame};
8use bcp_wire::header::{HEADER_SIZE, BcpHeader};
9use bcp_wire::varint::decode_varint;
10use tokio::io::{AsyncRead, AsyncReadExt};
11
12use crate::decompression::{self, MAX_BLOCK_DECOMPRESSED_SIZE, MAX_PAYLOAD_DECOMPRESSED_SIZE};
13use crate::error::DecodeError;
14
15#[derive(Clone, Debug)]
30pub enum DecoderEvent {
31 Header(BcpHeader),
33
34 Block(Block),
36}
37
38pub struct StreamingDecoder<R> {
86 reader: R,
87 state: StreamState,
88 buf: Vec<u8>,
92 decompressed_payload: Option<Vec<u8>>,
96 decompressed_cursor: usize,
98 content_store: Option<Arc<dyn ContentStore>>,
100}
101
/// Lifecycle position of a [`StreamingDecoder`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum StreamState {
    /// Nothing consumed yet; the next call reads the fixed-size header.
    ReadHeader,
    /// Header consumed; each call reads one block frame.
    ReadBlocks,
    /// Terminal state; `next` returns `None` from here on.
    Done,
}
119
120impl<R: AsyncRead + Unpin> StreamingDecoder<R> {
121 #[must_use]
126 pub fn new(reader: R) -> Self {
127 Self {
128 reader,
129 state: StreamState::ReadHeader,
130 buf: Vec::with_capacity(4096),
131 decompressed_payload: None,
132 decompressed_cursor: 0,
133 content_store: None,
134 }
135 }
136
137 #[must_use]
142 pub fn with_content_store(mut self, store: Arc<dyn ContentStore>) -> Self {
143 self.content_store = Some(store);
144 self
145 }
146
147 pub async fn next(&mut self) -> Option<Result<DecoderEvent, DecodeError>> {
156 match self.state {
157 StreamState::ReadHeader => Some(self.read_header().await),
158 StreamState::ReadBlocks => self.read_next_block().await,
159 StreamState::Done => None,
160 }
161 }
162
163 async fn read_header(&mut self) -> Result<DecoderEvent, DecodeError> {
170 let mut header_buf = [0u8; HEADER_SIZE];
171 self.reader.read_exact(&mut header_buf).await.map_err(|_| {
172 DecodeError::InvalidHeader(bcp_wire::WireError::UnexpectedEof { offset: 0 })
173 })?;
174
175 let header = BcpHeader::read_from(&header_buf).map_err(DecodeError::InvalidHeader)?;
176
177 if header.flags.is_compressed() {
179 let mut compressed = Vec::new();
180 self.reader
181 .read_to_end(&mut compressed)
182 .await
183 .map_err(DecodeError::Io)?;
184 let decompressed =
185 decompression::decompress(&compressed, MAX_PAYLOAD_DECOMPRESSED_SIZE)?;
186 self.decompressed_payload = Some(decompressed);
187 self.decompressed_cursor = 0;
188 }
189
190 self.state = StreamState::ReadBlocks;
191 Ok(DecoderEvent::Header(header))
192 }
193
194 async fn read_next_block(&mut self) -> Option<Result<DecoderEvent, DecodeError>> {
205 if let Some(ref payload) = self.decompressed_payload {
208 if self.decompressed_cursor >= payload.len() {
209 self.state = StreamState::Done;
210 return Some(Err(DecodeError::MissingEndSentinel));
211 }
212
213 let remaining = &payload[self.decompressed_cursor..];
214 match BlockFrame::read_from(remaining) {
215 Ok(Some((frame, consumed))) => {
216 self.decompressed_cursor += consumed;
217 Some(self.decode_frame(&frame))
218 }
219 Ok(None) => {
220 match end_sentinel_size(remaining) {
223 Ok(size) => self.decompressed_cursor += size,
224 Err(e) => return Some(Err(e)),
225 }
226 self.state = StreamState::Done;
227 None
228 }
229 Err(e) => Some(Err(DecodeError::from(e))),
230 }
231 } else {
232 self.read_next_block_from_reader().await
233 }
234 }
235
236 async fn read_next_block_from_reader(&mut self) -> Option<Result<DecoderEvent, DecodeError>> {
238 let block_type_raw = match self.read_varint().await {
240 Ok(v) => v,
241 Err(e) => return Some(Err(e)),
242 };
243
244 #[allow(clippy::cast_possible_truncation)]
245 let block_type_byte = block_type_raw as u8;
246
247 if block_type_byte == 0xFF {
249 match self.read_end_frame_tail().await {
250 Ok(()) => {}
251 Err(e) => return Some(Err(e)),
252 }
253 self.state = StreamState::Done;
254 return None;
255 }
256
257 let mut flags_byte = [0u8; 1];
259 if let Err(e) = self.reader.read_exact(&mut flags_byte).await {
260 return Some(Err(DecodeError::Io(e)));
261 }
262 let flags = BlockFlags::from_raw(flags_byte[0]);
263
264 #[allow(clippy::cast_possible_truncation)]
266 let content_len = match self.read_varint().await {
267 Ok(v) => v as usize,
268 Err(e) => return Some(Err(e)),
269 };
270
271 self.buf.clear();
273 self.buf.resize(content_len, 0);
274 if let Err(e) = self.reader.read_exact(&mut self.buf[..content_len]).await {
275 return Some(Err(DecodeError::Io(e)));
276 }
277
278 let frame = bcp_wire::block_frame::BlockFrame {
279 block_type: block_type_byte,
280 flags,
281 body: self.buf[..content_len].to_vec(),
282 };
283
284 Some(self.decode_frame(&frame))
285 }
286
287 fn decode_frame(
292 &self,
293 frame: &bcp_wire::block_frame::BlockFrame,
294 ) -> Result<DecoderEvent, DecodeError> {
295 let block_type = BlockType::from_wire_id(frame.block_type);
296
297 let resolved_body = if frame.flags.is_reference() {
299 let store = self
300 .content_store
301 .as_ref()
302 .ok_or(DecodeError::MissingContentStore)?;
303 if frame.body.len() != 32 {
304 return Err(DecodeError::Wire(bcp_wire::WireError::UnexpectedEof {
305 offset: frame.body.len(),
306 }));
307 }
308 let hash: [u8; 32] = frame.body[..32].try_into().expect("length already checked");
309 store
310 .get(&hash)
311 .ok_or(DecodeError::UnresolvedReference { hash })?
312 } else {
313 frame.body.clone()
314 };
315
316 let decompressed_body = if frame.flags.is_compressed() {
318 decompression::decompress(&resolved_body, MAX_BLOCK_DECOMPRESSED_SIZE)?
319 } else {
320 resolved_body
321 };
322
323 let mut body = decompressed_body.as_slice();
325 let mut summary = None;
326
327 if frame.flags.has_summary() {
328 match Summary::decode(body) {
329 Ok((sum, consumed)) => {
330 summary = Some(sum);
331 body = &body[consumed..];
332 }
333 Err(e) => return Err(e.into()),
334 }
335 }
336
337 let content = BlockContent::decode_body(&block_type, body)?;
338
339 Ok(DecoderEvent::Block(Block {
340 block_type,
341 flags: frame.flags,
342 summary,
343 content,
344 }))
345 }
346
347 async fn read_end_frame_tail(&mut self) -> Result<(), DecodeError> {
352 let mut byte = [0u8; 1];
354 self.reader
355 .read_exact(&mut byte)
356 .await
357 .map_err(DecodeError::Io)?;
358
359 let _content_len = self.read_varint().await?;
361 Ok(())
362 }
363
364 async fn read_varint(&mut self) -> Result<u64, DecodeError> {
370 let mut varint_buf = [0u8; 10];
371 let mut len = 0;
372
373 loop {
374 let mut byte = [0u8; 1];
375 self.reader
376 .read_exact(&mut byte)
377 .await
378 .map_err(DecodeError::Io)?;
379 varint_buf[len] = byte[0];
380 len += 1;
381
382 if byte[0] & 0x80 == 0 {
384 break;
385 }
386
387 if len >= 10 {
388 return Err(DecodeError::Wire(bcp_wire::WireError::VarintTooLong));
389 }
390 }
391
392 let (value, _) = decode_varint(&varint_buf[..len])?;
393 Ok(value)
394 }
395}
396
397fn end_sentinel_size(buf: &[u8]) -> Result<usize, DecodeError> {
403 let (_, type_len) = decode_varint(buf)?;
404 let mut size = type_len;
405 size += 1;
407 let rest = buf
409 .get(size..)
410 .ok_or(DecodeError::Wire(bcp_wire::WireError::UnexpectedEof {
411 offset: size,
412 }))?;
413 let (_, len_size) = decode_varint(rest)?;
414 size += len_size;
415 Ok(size)
416}
417
#[cfg(test)]
mod tests {
    use super::*;
    use bcp_encoder::BcpEncoder;
    use bcp_types::enums::{Lang, Priority, Role, Status};

    /// Drains `decoder` to completion, panicking on the first decode error.
    async fn collect_events<R: AsyncRead + Unpin>(
        mut decoder: StreamingDecoder<R>,
    ) -> Vec<DecoderEvent> {
        let mut events = Vec::new();
        while let Some(result) = decoder.next().await {
            events.push(result.unwrap());
        }
        events
    }

    /// Encodes `encoder` and decodes the payload through a buffered in-memory reader.
    async fn stream_roundtrip(encoder: &BcpEncoder) -> Vec<DecoderEvent> {
        let payload = encoder.encode().unwrap();
        let reader = tokio::io::BufReader::new(std::io::Cursor::new(payload));
        collect_events(StreamingDecoder::new(reader)).await
    }

    /// Keeps only the block events, preserving order.
    fn block_events(events: Vec<DecoderEvent>) -> Vec<Block> {
        events
            .into_iter()
            .filter_map(|event| match event {
                DecoderEvent::Block(block) => Some(block),
                DecoderEvent::Header(_) => None,
            })
            .collect()
    }

    #[tokio::test]
    async fn streaming_produces_header_then_blocks() {
        let mut enc = BcpEncoder::new();
        enc.add_code(Lang::Rust, "main.rs", b"fn main() {}")
            .add_conversation(Role::User, b"hello");
        let events = stream_roundtrip(&enc).await;

        assert_eq!(events.len(), 3);
        assert!(matches!(&events[0], DecoderEvent::Header(h) if h.version_major == 1));
        assert!(matches!(&events[1], DecoderEvent::Block(b) if b.block_type == BlockType::Code));
        assert!(
            matches!(&events[2], DecoderEvent::Block(b) if b.block_type == BlockType::Conversation)
        );
    }

    #[tokio::test]
    async fn streaming_matches_sync_decoder() {
        let mut encoder = BcpEncoder::new();
        encoder
            .add_code(Lang::Rust, "lib.rs", b"pub fn x() {}")
            .with_summary("Function x.")
            .unwrap()
            .with_priority(Priority::High)
            .unwrap()
            .add_conversation(Role::User, b"What does x do?")
            .add_tool_result("docs", Status::Ok, b"x is a placeholder.");

        let payload = encoder.encode().unwrap();
        let sync_decoded = crate::BcpDecoder::decode(&payload).unwrap();
        let stream_blocks = block_events(stream_roundtrip(&encoder).await);

        assert_eq!(sync_decoded.blocks.len(), stream_blocks.len());
        for (sync_block, stream_block) in sync_decoded.blocks.iter().zip(stream_blocks.iter()) {
            assert_eq!(sync_block.block_type, stream_block.block_type);
            assert_eq!(sync_block.flags, stream_block.flags);
            assert_eq!(sync_block.summary, stream_block.summary);
        }
    }

    #[tokio::test]
    async fn streaming_handles_summary_blocks() {
        let mut enc = BcpEncoder::new();
        enc.add_code(Lang::Python, "app.py", b"print('hi')")
            .with_summary("Prints a greeting.")
            .unwrap();
        let events = stream_roundtrip(&enc).await;

        let block = match &events[1] {
            DecoderEvent::Block(b) => b,
            other => panic!("expected Block, got {other:?}"),
        };

        assert!(block.flags.has_summary());
        assert_eq!(block.summary.as_ref().unwrap().text, "Prints a greeting.");
    }

    #[tokio::test]
    async fn streaming_empty_body_blocks() {
        let mut enc = BcpEncoder::new();
        enc.add_extension("ns", "t", b"");
        let events = stream_roundtrip(&enc).await;

        assert_eq!(events.len(), 2);
    }

    #[tokio::test]
    async fn streaming_terminates_at_end_sentinel() {
        let mut enc = BcpEncoder::new();
        enc.add_conversation(Role::User, b"hi");
        let events = stream_roundtrip(&enc).await;

        assert_eq!(events.len(), 2);
    }

    #[tokio::test]
    async fn streaming_reports_truncated_block() {
        let mut enc = BcpEncoder::new();
        enc.add_code(Lang::Rust, "main.rs", b"fn main() {}");
        let payload = enc.encode().unwrap();
        // Keep the header plus two frame bytes, cutting the block off mid-frame.
        let truncated = payload[..HEADER_SIZE + 2].to_vec();
        let reader = tokio::io::BufReader::new(std::io::Cursor::new(truncated));
        let mut decoder = StreamingDecoder::new(reader);

        assert!(matches!(decoder.next().await, Some(Ok(DecoderEvent::Header(_)))));
        assert!(matches!(decoder.next().await, Some(Err(_))));
    }

    #[tokio::test]
    async fn streaming_per_block_compression_roundtrip() {
        let big_content = "fn main() { println!(\"hello world\"); }\n".repeat(50);
        let mut enc = BcpEncoder::new();
        enc.add_code(Lang::Rust, "main.rs", big_content.as_bytes())
            .with_compression()
            .unwrap();
        let events = stream_roundtrip(&enc).await;

        assert_eq!(events.len(), 2);
        let block = match &events[1] {
            DecoderEvent::Block(b) => b,
            other => panic!("expected Block, got {other:?}"),
        };

        match &block.content {
            BlockContent::Code(code) => assert_eq!(code.content, big_content.as_bytes()),
            other => panic!("expected Code, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn streaming_whole_payload_compression_roundtrip() {
        let big_content = "use std::io;\n".repeat(100);
        let mut enc = BcpEncoder::new();
        enc.add_code(Lang::Rust, "a.rs", big_content.as_bytes())
            .add_code(Lang::Rust, "b.rs", big_content.as_bytes());
        enc.compress_payload();
        let events = stream_roundtrip(&enc).await;

        assert_eq!(events.len(), 3);

        match &events[0] {
            DecoderEvent::Header(h) => assert!(h.flags.is_compressed()),
            other => panic!("expected Header, got {other:?}"),
        }

        for event in &events[1..] {
            match event {
                DecoderEvent::Block(block) => match &block.content {
                    BlockContent::Code(code) => assert_eq!(code.content, big_content.as_bytes()),
                    other => panic!("expected Code, got {other:?}"),
                },
                other => panic!("expected Block, got {other:?}"),
            }
        }
    }

    #[tokio::test]
    async fn streaming_content_addressing_roundtrip() {
        let store = Arc::new(bcp_encoder::MemoryContentStore::new());
        let mut enc = BcpEncoder::new();
        enc.set_content_store(store.clone())
            .add_code(Lang::Rust, "main.rs", b"fn main() {}")
            .with_content_addressing()
            .unwrap();

        let payload = enc.encode().unwrap();
        let reader = tokio::io::BufReader::new(std::io::Cursor::new(payload));
        let events = collect_events(StreamingDecoder::new(reader).with_content_store(store)).await;

        assert_eq!(events.len(), 2);
        match &events[1] {
            DecoderEvent::Block(block) => match &block.content {
                BlockContent::Code(code) => assert_eq!(code.content, b"fn main() {}"),
                other => panic!("expected Code, got {other:?}"),
            },
            other => panic!("expected Block, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn streaming_matches_sync_compressed() {
        let big_content = "pub fn hello() -> &'static str { \"world\" }\n".repeat(100);
        let mut encoder = BcpEncoder::new();
        encoder
            .add_code(Lang::Rust, "lib.rs", big_content.as_bytes())
            .with_summary("Hello function.")
            .unwrap()
            .add_conversation(Role::User, b"explain");
        encoder.compress_payload();

        let payload = encoder.encode().unwrap();
        let sync_decoded = crate::BcpDecoder::decode(&payload).unwrap();
        let stream_blocks = block_events(stream_roundtrip(&encoder).await);

        assert_eq!(sync_decoded.blocks.len(), stream_blocks.len());
        for (sync_block, stream_block) in sync_decoded.blocks.iter().zip(stream_blocks.iter()) {
            assert_eq!(sync_block.block_type, stream_block.block_type);
            assert_eq!(sync_block.summary, stream_block.summary);
        }
    }
}