// bcp_decoder/streaming.rs

1use std::sync::Arc;
2
3use bcp_types::block::{Block, BlockContent};
4use bcp_types::block_type::BlockType;
5use bcp_types::content_store::ContentStore;
6use bcp_types::summary::Summary;
7use bcp_wire::block_frame::{BlockFlags, BlockFrame};
8use bcp_wire::header::{HEADER_SIZE, BcpHeader};
9use bcp_wire::varint::decode_varint;
10use tokio::io::{AsyncRead, AsyncReadExt};
11
12use crate::decompression::{self, MAX_BLOCK_DECOMPRESSED_SIZE, MAX_PAYLOAD_DECOMPRESSED_SIZE};
13use crate::error::DecodeError;
14
15/// Events emitted by the streaming decoder.
16///
17/// The stream yields a `Header` event first (once the 8-byte file
18/// header has been read and validated), then a sequence of `Block`
19/// events for each decoded block, terminating when the END sentinel
20/// is encountered.
21///
22/// ```text
23///   Header(BcpHeader)
24///   Block(Block)
25///   Block(Block)
26///   Block(Block)
27///   ... (stream ends at END sentinel)
28/// ```
29#[derive(Clone, Debug)]
30pub enum DecoderEvent {
31    /// The file header has been parsed and validated.
32    Header(BcpHeader),
33
34    /// A block has been fully decoded.
35    Block(Block),
36}
37
38/// Asynchronous streaming decoder — yields blocks one at a time
39/// without buffering the entire payload.
40///
41/// This is the primary API for large payloads or network streams.
42/// The decoder reads the header first, then yields blocks as they
43/// are fully received. Backpressure is handled naturally: the stream
44/// only reads the next block when the caller awaits the next item.
45///
46/// Unlike the synchronous [`BcpDecoder`](crate::BcpDecoder) which
47/// requires the entire payload in memory, `StreamingDecoder` reads
48/// incrementally from any `AsyncRead` source (files, TCP sockets,
49/// HTTP response bodies, etc.).
50///
51/// # Important: whole-payload compression disables streaming
52///
53/// **When the header's `COMPRESSED` flag is set (whole-payload
54/// compression), the streaming decoder falls back to buffering the
55/// entire payload before yielding any blocks.** This is unavoidable:
56/// zstd requires the full compressed input to decompress. The API
57/// surface remains the same (you still call `next()` in a loop), but
58/// the memory and latency characteristics become identical to
59/// [`BcpDecoder::decode`](crate::BcpDecoder::decode).
60///
61/// True incremental streaming is only achieved with **uncompressed**
62/// or **per-block compressed** payloads. If streaming is critical for
63/// your use case, prefer per-block compression
64/// ([`BcpEncoder::compress_blocks`]) over whole-payload compression
65/// ([`BcpEncoder::compress_payload`]).
66///
67/// # Content store
68///
69/// To decode payloads with `IS_REFERENCE` blocks, provide a content
70/// store via [`with_content_store`](Self::with_content_store).
71///
72/// # Example
73///
74/// ```rust,no_run
75/// use bcp_decoder::StreamingDecoder;
76/// use tokio::io::AsyncRead;
77///
78/// async fn decode_from_reader(reader: impl AsyncRead + Unpin) {
79///     let mut stream = StreamingDecoder::new(reader);
80///     while let Some(event) = stream.next().await.transpose().unwrap() {
81///         // Process each DecoderEvent...
82///     }
83/// }
84/// ```
85pub struct StreamingDecoder<R> {
86    reader: R,
87    state: StreamState,
88    /// Internal read buffer. Block bodies are read into this buffer
89    /// before being parsed. The buffer is reused across blocks to
90    /// avoid repeated allocations.
91    buf: Vec<u8>,
92    /// When whole-payload compression is detected, the entire stream
93    /// is read and decompressed into this buffer. Subsequent block
94    /// reads consume from here instead of the original reader.
95    decompressed_payload: Option<Vec<u8>>,
96    /// Read cursor into `decompressed_payload`.
97    decompressed_cursor: usize,
98    /// Optional content store for resolving `IS_REFERENCE` blocks.
99    content_store: Option<Arc<dyn ContentStore>>,
100}
101
/// States of the streaming decoder's internal state machine.
///
/// Progression is strictly forward:
///
/// ```text
///   ReadHeader → ReadBlocks → Done
/// ```
///
/// The decoder begins in `ReadHeader`. Once the header has been
/// consumed it moves to `ReadBlocks`, where it remains until the END
/// sentinel arrives and it settles in `Done`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum StreamState {
    ReadHeader,
    ReadBlocks,
    Done,
}
119
120impl<R: AsyncRead + Unpin> StreamingDecoder<R> {
121    /// Create a new streaming decoder over the given async reader.
122    ///
123    /// The decoder starts in `ReadHeader` state and will read the
124    /// 8-byte file header on the first call to [`next`](Self::next).
125    #[must_use]
126    pub fn new(reader: R) -> Self {
127        Self {
128            reader,
129            state: StreamState::ReadHeader,
130            buf: Vec::with_capacity(4096),
131            decompressed_payload: None,
132            decompressed_cursor: 0,
133            content_store: None,
134        }
135    }
136
137    /// Attach a content store for resolving `IS_REFERENCE` blocks.
138    ///
139    /// When a block has the `IS_REFERENCE` flag set, its 32-byte body
140    /// is looked up in this store to retrieve the original content.
141    #[must_use]
142    pub fn with_content_store(mut self, store: Arc<dyn ContentStore>) -> Self {
143        self.content_store = Some(store);
144        self
145    }
146
147    /// Read the next event from the stream.
148    ///
149    /// Returns `Ok(Some(event))` for each decoded event, `Ok(None)`
150    /// when the stream is exhausted (END sentinel reached), or `Err`
151    /// on any decode error.
152    ///
153    /// The first call always yields `DecoderEvent::Header`. Subsequent
154    /// calls yield `DecoderEvent::Block` until the END sentinel.
155    pub async fn next(&mut self) -> Option<Result<DecoderEvent, DecodeError>> {
156        match self.state {
157            StreamState::ReadHeader => Some(self.read_header().await),
158            StreamState::ReadBlocks => self.read_next_block().await,
159            StreamState::Done => None,
160        }
161    }
162
163    /// Read and validate the 8-byte file header.
164    ///
165    /// If the header's `COMPRESSED` flag is set, the decoder reads
166    /// all remaining bytes from the stream, decompresses them with
167    /// zstd, and stores the result internally. Subsequent block reads
168    /// consume from the decompressed buffer.
169    async fn read_header(&mut self) -> Result<DecoderEvent, DecodeError> {
170        let mut header_buf = [0u8; HEADER_SIZE];
171        self.reader.read_exact(&mut header_buf).await.map_err(|_| {
172            DecodeError::InvalidHeader(bcp_wire::WireError::UnexpectedEof { offset: 0 })
173        })?;
174
175        let header = BcpHeader::read_from(&header_buf).map_err(DecodeError::InvalidHeader)?;
176
177        // Whole-payload decompression: buffer everything, decompress.
178        if header.flags.is_compressed() {
179            let mut compressed = Vec::new();
180            self.reader
181                .read_to_end(&mut compressed)
182                .await
183                .map_err(DecodeError::Io)?;
184            let decompressed =
185                decompression::decompress(&compressed, MAX_PAYLOAD_DECOMPRESSED_SIZE)?;
186            self.decompressed_payload = Some(decompressed);
187            self.decompressed_cursor = 0;
188        }
189
190        self.state = StreamState::ReadBlocks;
191        Ok(DecoderEvent::Header(header))
192    }
193
194    /// Read the next block frame from the stream.
195    ///
196    /// If a decompressed payload buffer exists (whole-payload mode),
197    /// reads from that buffer. Otherwise reads from the async reader.
198    ///
199    /// Per-block decompression and reference resolution are applied
200    /// transparently.
201    ///
202    /// Returns `None` when the END sentinel is encountered, transitioning
203    /// the state to `Done`.
204    async fn read_next_block(&mut self) -> Option<Result<DecoderEvent, DecodeError>> {
205        // If we have a decompressed payload buffer, parse from it
206        // using BlockFrame::read_from (synchronous path).
207        if let Some(ref payload) = self.decompressed_payload {
208            if self.decompressed_cursor >= payload.len() {
209                self.state = StreamState::Done;
210                return Some(Err(DecodeError::MissingEndSentinel));
211            }
212
213            let remaining = &payload[self.decompressed_cursor..];
214            match BlockFrame::read_from(remaining) {
215                Ok(Some((frame, consumed))) => {
216                    self.decompressed_cursor += consumed;
217                    Some(self.decode_frame(&frame))
218                }
219                Ok(None) => {
220                    // END sentinel — compute its size and advance cursor.
221                    // END = varint(0xFF) + flags(1 byte) + varint(0x00)
222                    match end_sentinel_size(remaining) {
223                        Ok(size) => self.decompressed_cursor += size,
224                        Err(e) => return Some(Err(e)),
225                    }
226                    self.state = StreamState::Done;
227                    None
228                }
229                Err(e) => Some(Err(DecodeError::from(e))),
230            }
231        } else {
232            self.read_next_block_from_reader().await
233        }
234    }
235
236    /// Read the next block frame from the async reader (non-buffered path).
237    async fn read_next_block_from_reader(&mut self) -> Option<Result<DecoderEvent, DecodeError>> {
238        // Read block_type varint
239        let block_type_raw = match self.read_varint().await {
240            Ok(v) => v,
241            Err(e) => return Some(Err(e)),
242        };
243
244        #[allow(clippy::cast_possible_truncation)]
245        let block_type_byte = block_type_raw as u8;
246
247        // Check for END sentinel
248        if block_type_byte == 0xFF {
249            match self.read_end_frame_tail().await {
250                Ok(()) => {}
251                Err(e) => return Some(Err(e)),
252            }
253            self.state = StreamState::Done;
254            return None;
255        }
256
257        // Read flags (single byte)
258        let mut flags_byte = [0u8; 1];
259        if let Err(e) = self.reader.read_exact(&mut flags_byte).await {
260            return Some(Err(DecodeError::Io(e)));
261        }
262        let flags = BlockFlags::from_raw(flags_byte[0]);
263
264        // Read content_len varint
265        #[allow(clippy::cast_possible_truncation)]
266        let content_len = match self.read_varint().await {
267            Ok(v) => v as usize,
268            Err(e) => return Some(Err(e)),
269        };
270
271        // Read body bytes
272        self.buf.clear();
273        self.buf.resize(content_len, 0);
274        if let Err(e) = self.reader.read_exact(&mut self.buf[..content_len]).await {
275            return Some(Err(DecodeError::Io(e)));
276        }
277
278        let frame = bcp_wire::block_frame::BlockFrame {
279            block_type: block_type_byte,
280            flags,
281            body: self.buf[..content_len].to_vec(),
282        };
283
284        Some(self.decode_frame(&frame))
285    }
286
287    /// Decode a `BlockFrame` into a `DecoderEvent::Block`.
288    ///
289    /// Handles reference resolution, decompression, summary extraction,
290    /// and body deserialization.
291    fn decode_frame(
292        &self,
293        frame: &bcp_wire::block_frame::BlockFrame,
294    ) -> Result<DecoderEvent, DecodeError> {
295        let block_type = BlockType::from_wire_id(frame.block_type);
296
297        // Stage 1: Resolve content-addressed references.
298        let resolved_body = if frame.flags.is_reference() {
299            let store = self
300                .content_store
301                .as_ref()
302                .ok_or(DecodeError::MissingContentStore)?;
303            if frame.body.len() != 32 {
304                return Err(DecodeError::Wire(bcp_wire::WireError::UnexpectedEof {
305                    offset: frame.body.len(),
306                }));
307            }
308            let hash: [u8; 32] = frame.body[..32].try_into().expect("length already checked");
309            store
310                .get(&hash)
311                .ok_or(DecodeError::UnresolvedReference { hash })?
312        } else {
313            frame.body.clone()
314        };
315
316        // Stage 2: Per-block decompression.
317        let decompressed_body = if frame.flags.is_compressed() {
318            decompression::decompress(&resolved_body, MAX_BLOCK_DECOMPRESSED_SIZE)?
319        } else {
320            resolved_body
321        };
322
323        // Stage 3 & 4: Summary extraction + TLV decode.
324        let mut body = decompressed_body.as_slice();
325        let mut summary = None;
326
327        if frame.flags.has_summary() {
328            match Summary::decode(body) {
329                Ok((sum, consumed)) => {
330                    summary = Some(sum);
331                    body = &body[consumed..];
332                }
333                Err(e) => return Err(e.into()),
334            }
335        }
336
337        let content = BlockContent::decode_body(&block_type, body)?;
338
339        Ok(DecoderEvent::Block(Block {
340            block_type,
341            flags: frame.flags,
342            summary,
343            content,
344        }))
345    }
346
347    /// Read the trailing flags + `content_len` bytes of an END frame.
348    ///
349    /// The END sentinel has: `flags`=0x00 (1 byte) + `content_len`=0x00 (1 byte).
350    /// We read and discard these to fully consume the END frame.
351    async fn read_end_frame_tail(&mut self) -> Result<(), DecodeError> {
352        // flags byte
353        let mut byte = [0u8; 1];
354        self.reader
355            .read_exact(&mut byte)
356            .await
357            .map_err(DecodeError::Io)?;
358
359        // content_len varint (should be 0)
360        let _content_len = self.read_varint().await?;
361        Ok(())
362    }
363
364    /// Read a single varint from the async reader.
365    ///
366    /// Varints are read byte-by-byte: each byte's MSB indicates whether
367    /// more bytes follow (1 = more, 0 = last byte). Maximum 10 bytes
368    /// for a 64-bit value.
369    async fn read_varint(&mut self) -> Result<u64, DecodeError> {
370        let mut varint_buf = [0u8; 10];
371        let mut len = 0;
372
373        loop {
374            let mut byte = [0u8; 1];
375            self.reader
376                .read_exact(&mut byte)
377                .await
378                .map_err(DecodeError::Io)?;
379            varint_buf[len] = byte[0];
380            len += 1;
381
382            // MSB clear means this is the last byte
383            if byte[0] & 0x80 == 0 {
384                break;
385            }
386
387            if len >= 10 {
388                return Err(DecodeError::Wire(bcp_wire::WireError::VarintTooLong));
389            }
390        }
391
392        let (value, _) = decode_varint(&varint_buf[..len])?;
393        Ok(value)
394    }
395}
396
397/// Calculate the byte size of the END sentinel from a buffer slice.
398///
399/// Used by the streaming decoder when parsing from a decompressed
400/// payload buffer. The END sentinel is: `varint(0xFF)` + `flags(0x00)`
401/// + `varint(0x00)`.
402fn end_sentinel_size(buf: &[u8]) -> Result<usize, DecodeError> {
403    let (_, type_len) = decode_varint(buf)?;
404    let mut size = type_len;
405    // flags byte
406    size += 1;
407    // content_len varint
408    let rest = buf
409        .get(size..)
410        .ok_or(DecodeError::Wire(bcp_wire::WireError::UnexpectedEof {
411            offset: size,
412        }))?;
413    let (_, len_size) = decode_varint(rest)?;
414    size += len_size;
415    Ok(size)
416}
417
#[cfg(test)]
mod tests {
    //! Round-trip tests for `StreamingDecoder`: encode with `BcpEncoder`,
    //! decode via the streaming API, and check the resulting events.

    use super::*;
    use bcp_encoder::BcpEncoder;
    use bcp_types::enums::{Lang, Priority, Role, Status};

    /// Helper: encode a payload and decode it via the streaming decoder,
    /// collecting all events into a Vec.
    ///
    /// Panics (via `unwrap`) on any encode or decode error, which is the
    /// desired behavior inside tests.
    async fn stream_roundtrip(encoder: &BcpEncoder) -> Vec<DecoderEvent> {
        let payload = encoder.encode().unwrap();
        // An in-memory Cursor wrapped in BufReader stands in for a real
        // async source (file / socket).
        let cursor = std::io::Cursor::new(payload);
        let reader = tokio::io::BufReader::new(cursor);

        let mut decoder = StreamingDecoder::new(reader);
        let mut events = Vec::new();

        while let Some(result) = decoder.next().await {
            events.push(result.unwrap());
        }

        events
    }

    #[tokio::test]
    async fn streaming_produces_header_then_blocks() {
        let mut enc = BcpEncoder::new();
        enc.add_code(Lang::Rust, "main.rs", b"fn main() {}")
            .add_conversation(Role::User, b"hello");
        let events = stream_roundtrip(&enc).await;

        assert_eq!(events.len(), 3); // Header + 2 blocks

        // Events must arrive as Header first, then blocks in encode order.
        assert!(matches!(&events[0], DecoderEvent::Header(h) if h.version_major == 1));
        assert!(matches!(&events[1], DecoderEvent::Block(b) if b.block_type == BlockType::Code));
        assert!(
            matches!(&events[2], DecoderEvent::Block(b) if b.block_type == BlockType::Conversation)
        );
    }

    #[tokio::test]
    async fn streaming_matches_sync_decoder() {
        let mut encoder = BcpEncoder::new();
        encoder
            .add_code(Lang::Rust, "lib.rs", b"pub fn x() {}")
            .with_summary("Function x.").unwrap()
            .with_priority(Priority::High).unwrap()
            .add_conversation(Role::User, b"What does x do?")
            .add_tool_result("docs", Status::Ok, b"x is a placeholder.");

        let payload = encoder.encode().unwrap();

        // Sync decode
        let sync_decoded = crate::BcpDecoder::decode(&payload).unwrap();

        // Streaming decode
        let events = stream_roundtrip(&encoder).await;

        // Extract blocks from events (skip the Header event)
        let stream_blocks: Vec<_> = events
            .into_iter()
            .filter_map(|e| match e {
                DecoderEvent::Block(b) => Some(b),
                _ => None,
            })
            .collect();

        // Same number of blocks
        assert_eq!(sync_decoded.blocks.len(), stream_blocks.len());

        // Same block types in same order
        for (sync_block, stream_block) in sync_decoded.blocks.iter().zip(stream_blocks.iter()) {
            assert_eq!(sync_block.block_type, stream_block.block_type);
            assert_eq!(sync_block.flags, stream_block.flags);
            assert_eq!(sync_block.summary, stream_block.summary);
        }
    }

    #[tokio::test]
    async fn streaming_handles_summary_blocks() {
        let mut enc = BcpEncoder::new();
        enc.add_code(Lang::Python, "app.py", b"print('hi')")
            .with_summary("Prints a greeting.").unwrap();
        let events = stream_roundtrip(&enc).await;

        let block = match &events[1] {
            DecoderEvent::Block(b) => b,
            other => panic!("expected Block, got {other:?}"),
        };

        // The summary flag and the decoded summary text must survive the
        // round trip.
        assert!(block.flags.has_summary());
        assert_eq!(block.summary.as_ref().unwrap().text, "Prints a greeting.");
    }

    #[tokio::test]
    async fn streaming_empty_body_blocks() {
        // A zero-length extension body must not confuse frame parsing.
        let mut enc = BcpEncoder::new();
        enc.add_extension("ns", "t", b"");
        let events = stream_roundtrip(&enc).await;

        assert_eq!(events.len(), 2); // Header + Extension
    }

    #[tokio::test]
    async fn streaming_terminates_at_end_sentinel() {
        let mut enc = BcpEncoder::new();
        enc.add_conversation(Role::User, b"hi");
        let events = stream_roundtrip(&enc).await;

        // After all events, decoder should return None
        assert_eq!(events.len(), 2); // Header + 1 block
    }

    // ── Per-block compression streaming tests ───────────────────────────

    #[tokio::test]
    async fn streaming_per_block_compression_roundtrip() {
        // Repeated content compresses well, ensuring the COMPRESSED
        // block flag is exercised.
        let big_content = "fn main() { println!(\"hello world\"); }\n".repeat(50);
        let mut enc = BcpEncoder::new();
        enc.add_code(Lang::Rust, "main.rs", big_content.as_bytes())
            .with_compression().unwrap();
        let events = stream_roundtrip(&enc).await;

        assert_eq!(events.len(), 2); // Header + 1 block
        let block = match &events[1] {
            DecoderEvent::Block(b) => b,
            other => panic!("expected Block, got {other:?}"),
        };

        match &block.content {
            BlockContent::Code(code) => {
                assert_eq!(code.content, big_content.as_bytes());
            }
            other => panic!("expected Code, got {other:?}"),
        }
    }

    // ── Whole-payload compression streaming tests ───────────────────────

    #[tokio::test]
    async fn streaming_whole_payload_compression_roundtrip() {
        // Exercises the buffered fallback path: the decoder must drain
        // and decompress the whole payload before yielding blocks.
        let big_content = "use std::io;\n".repeat(100);
        let mut enc = BcpEncoder::new();
        enc.add_code(Lang::Rust, "a.rs", big_content.as_bytes())
            .add_code(Lang::Rust, "b.rs", big_content.as_bytes());
        enc.compress_payload();
        let events = stream_roundtrip(&enc).await;

        // Header + 2 blocks
        assert_eq!(events.len(), 3);

        // Verify header has COMPRESSED flag
        match &events[0] {
            DecoderEvent::Header(h) => assert!(h.flags.is_compressed()),
            other => panic!("expected Header, got {other:?}"),
        }

        // Both blocks should decompress correctly
        for event in &events[1..] {
            match event {
                DecoderEvent::Block(block) => match &block.content {
                    BlockContent::Code(code) => {
                        assert_eq!(code.content, big_content.as_bytes());
                    }
                    other => panic!("expected Code, got {other:?}"),
                },
                other => panic!("expected Block, got {other:?}"),
            }
        }
    }

    // ── Content store streaming tests ───────────────────────────────────

    #[tokio::test]
    async fn streaming_content_addressing_roundtrip() {
        // Encoder and decoder share one store so the IS_REFERENCE hash
        // written by the encoder can be resolved on decode.
        let store = Arc::new(bcp_encoder::MemoryContentStore::new());
        let mut enc = BcpEncoder::new();
        enc.set_content_store(store.clone())
            .add_code(Lang::Rust, "main.rs", b"fn main() {}")
            .with_content_addressing().unwrap();

        let payload = enc.encode().unwrap();
        let cursor = std::io::Cursor::new(payload);
        let reader = tokio::io::BufReader::new(cursor);

        let mut decoder = StreamingDecoder::new(reader).with_content_store(store);
        let mut events = Vec::new();
        while let Some(result) = decoder.next().await {
            events.push(result.unwrap());
        }

        assert_eq!(events.len(), 2); // Header + 1 block
        match &events[1] {
            DecoderEvent::Block(block) => match &block.content {
                BlockContent::Code(code) => {
                    assert_eq!(code.content, b"fn main() {}");
                }
                other => panic!("expected Code, got {other:?}"),
            },
            other => panic!("expected Block, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn streaming_matches_sync_compressed() {
        // Whole-payload compression + summary: streaming output must
        // agree with the synchronous decoder block-for-block.
        let big_content = "pub fn hello() -> &'static str { \"world\" }\n".repeat(100);
        let mut encoder = BcpEncoder::new();
        encoder
            .add_code(Lang::Rust, "lib.rs", big_content.as_bytes())
            .with_summary("Hello function.").unwrap()
            .add_conversation(Role::User, b"explain");
        encoder.compress_payload();

        let payload = encoder.encode().unwrap();

        // Sync decode
        let sync_decoded = crate::BcpDecoder::decode(&payload).unwrap();

        // Streaming decode
        let events = stream_roundtrip(&encoder).await;
        let stream_blocks: Vec<_> = events
            .into_iter()
            .filter_map(|e| match e {
                DecoderEvent::Block(b) => Some(b),
                _ => None,
            })
            .collect();

        assert_eq!(sync_decoded.blocks.len(), stream_blocks.len());
        for (sync_block, stream_block) in sync_decoded.blocks.iter().zip(stream_blocks.iter()) {
            assert_eq!(sync_block.block_type, stream_block.block_type);
            assert_eq!(sync_block.summary, stream_block.summary);
        }
    }
}