pdk_classy/hl/body_stream.rs

// Copyright (c) 2025, Salesforce, Inc.,
// All rights reserved.
// For full license text, see the LICENSE.txt file

5use std::pin::pin;
6use std::task::Context;
7use std::{cell::RefCell, marker::PhantomData, pin::Pin, ptr, rc::Rc, task::Poll};
8
9use futures::{AsyncRead, Stream, StreamExt};
10use std::future::Future;
11
12use super::{
13    dynamic_exchange::{DynamicExchange, ExchangeEvent},
14    entity::Chunk,
15};
16use crate::event::{BodyEvent, Event};
17
18pub struct BodyStreamExchange<B: Event> {
19    #[cfg(feature = "experimental")]
20    host: Option<Rc<dyn crate::host::Host>>,
21    contains_body: bool,
22    #[allow(unused)]
23    exchange: Rc<RefCell<DynamicExchange>>,
24    stream: RefCell<Option<Pin<Box<dyn Stream<Item = Chunk>>>>>,
25    _event: PhantomData<B>,
26}
27
28impl<B: ExchangeEvent + BodyEvent + 'static> BodyStreamExchange<B> {
29    #[allow(clippy::await_holding_refcell_ref)]
30    pub(super) async fn new(exchange: Rc<RefCell<DynamicExchange>>, contains_body: bool) -> Self {
31        if !contains_body {
32            return Self {
33                #[cfg(feature = "experimental")]
34                host: None,
35                contains_body,
36                exchange,
37                stream: RefCell::new(None),
38                _event: PhantomData,
39            };
40        };
41
42        let mut ref_mut = exchange.as_ref().borrow_mut();
43        let ex = ref_mut.wait_for_event_buffering::<B>(false).await;
44        let stream = ex.map(|e| {
45            Box::pin(
46                e.static_event_data_stream()
47                    .map(|e| Chunk::new(e.read_body(0, e.chunk_size()))),
48            ) as Pin<Box<dyn Stream<Item = Chunk>>>
49        });
50        #[cfg(feature = "experimental")]
51        let host = ex.map(|e| e.host.clone());
52        drop(ref_mut);
53        Self {
54            #[cfg(feature = "experimental")]
55            host,
56            contains_body,
57            exchange,
58            stream: RefCell::new(stream),
59            _event: PhantomData,
60        }
61    }
62
63    pub(super) fn contains_body(&self) -> bool {
64        self.contains_body
65    }
66
67    pub(super) fn stream(&self) -> BodyStream {
68        let mut stream = self.stream.borrow_mut();
69        BodyStream {
70            inner: stream.take(),
71            _lifetime: PhantomData,
72        }
73    }
74
75    #[cfg(feature = "experimental")]
76    pub(super) fn write_chunk(&self, chunk: Chunk) -> Result<(), super::BodyError> {
77        let Some(host) = &self.host else {
78            return Err(super::BodyError::BodyNotSent);
79        };
80
81        if chunk.bytes().len() >= crate::hl::body::MAX_BODY_SIZE {
82            return Err(super::BodyError::ExceededBodySize(chunk.bytes().len()));
83        }
84
85        B::write_body(host.as_ref(), 0, usize::MAX, chunk.bytes());
86
87        Ok(())
88    }
89}
90
91pub struct BodyStream<'b> {
92    inner: Option<Pin<Box<dyn Stream<Item = Chunk>>>>,
93    _lifetime: PhantomData<&'b ()>,
94}
95
96impl BodyStream<'_> {
97    pub async fn next(&mut self) -> Option<Chunk> {
98        StreamExt::next(self).await
99    }
100
101    pub async fn collect(&mut self) -> Chunk {
102        let mut bytes = Vec::<u8>::new();
103        while let Some(chunk) = self.next().await {
104            bytes.append(&mut chunk.into_bytes());
105        }
106        Chunk::new(bytes)
107    }
108}
109
110impl<'b> Stream for BodyStream<'b>
111where
112    Self: 'b,
113{
114    type Item = Chunk;
115
116    fn poll_next(
117        mut self: Pin<&mut Self>,
118        cx: &mut std::task::Context<'_>,
119    ) -> Poll<Option<Self::Item>> {
120        let Some(inner) = self.inner.as_mut() else {
121            return Poll::Ready(None);
122        };
123        inner.as_mut().poll_next(cx)
124    }
125}
126
127pub struct BodyStreamAsyncReader<'a> {
128    stream: RefCell<BodyStream<'a>>,
129    chunk: RefCell<Option<Chunk>>,
130    last: RefCell<usize>,
131}
132
133impl<'a> BodyStreamAsyncReader<'a> {
134    pub fn new(stream: BodyStream<'a>) -> Self {
135        Self {
136            stream: RefCell::new(stream),
137            chunk: RefCell::new(None),
138            last: RefCell::new(0),
139        }
140    }
141}
142
143impl AsyncRead for BodyStreamAsyncReader<'_> {
144    fn poll_read(
145        self: Pin<&mut Self>,
146        cx: &mut Context<'_>,
147        buf: &mut [u8],
148    ) -> Poll<std::io::Result<usize>> {
149        let mut last = *self.last.borrow();
150        let mut chunk_len = self
151            .chunk
152            .borrow()
153            .as_ref()
154            .map(|c| c.bytes().len())
155            .unwrap_or_default();
156
157        if last == chunk_len {
158            match pin!(self.stream.borrow_mut().next()).poll(cx) {
159                Poll::Ready(Some(chunk)) => {
160                    last = 0;
161                    chunk_len = chunk.bytes().len();
162                    *self.chunk.borrow_mut() = Some(chunk);
163                }
164                Poll::Ready(None) => return Poll::Ready(Ok(0)),
165                Poll::Pending => return Poll::Pending,
166            }
167        };
168
169        let mut read = 0;
170        if let Some(chunk) = self.chunk.borrow().as_ref() {
171            let new_last = match chunk_len - last <= buf.len() {
172                true => chunk_len,
173                false => last + buf.len(),
174            };
175
176            read = new_last - last;
177
178            unsafe {
179                ptr::copy_nonoverlapping(
180                    chunk.bytes()[last..new_last].as_ptr(),
181                    buf.as_mut_ptr(),
182                    read,
183                );
184            }
185
186            *self.last.borrow_mut() = new_last;
187        }
188        Poll::Ready(Ok(read))
189    }
190}