1use futures::stream::{self, Stream, StreamExt};
8use serde::Serialize;
9use serde::de::DeserializeOwned;
10use std::pin::Pin;
11use std::sync::Arc;
12use tokio::sync::mpsc;
13use tokio_stream::wrappers::ReceiverStream;
14
15use super::bidirectional::BidirChannel;
16use super::context::PlexusContext;
17use super::types::{PlexusStreamItem, StreamMetadata};
18
/// Boxed, pinned, `Send` stream of [`PlexusStreamItem`]s; the common return
/// type of every stream helper in this module.
pub type PlexusStream = Pin<Box<dyn Stream<Item = PlexusStreamItem> + Send>>;
21
22pub fn wrap_stream<T: Serialize + Send + 'static>(
37 stream: impl Stream<Item = T> + Send + 'static,
38 content_type: &'static str,
39 provenance: Vec<String>,
40) -> PlexusStream {
41 let plexus_hash = PlexusContext::hash();
42 let metadata = StreamMetadata::new(provenance.clone(), plexus_hash.clone());
43 let done_metadata = StreamMetadata::new(provenance, plexus_hash);
44
45 let data_stream = stream.map(move |item| PlexusStreamItem::Data {
46 metadata: metadata.clone(),
47 content_type: content_type.to_string(),
48 content: serde_json::to_value(item).expect("serialization failed"),
49 });
50
51 let done_stream = stream::once(async move { PlexusStreamItem::Done {
52 metadata: done_metadata,
53 }});
54
55 Box::pin(data_stream.chain(done_stream))
56}
57
58
59pub fn create_bidir_stream<Req, Resp>(
89 content_type: &'static str,
90 provenance: Vec<String>,
91) -> (
92 Arc<BidirChannel<Req, Resp>>,
93 impl FnOnce(Pin<Box<dyn Stream<Item = PlexusStreamItem> + Send>>) -> PlexusStream,
94)
95where
96 Req: Serialize + DeserializeOwned + Send + Sync + 'static,
97 Resp: Serialize + DeserializeOwned + Send + Sync + 'static,
98{
99 let plexus_hash = PlexusContext::hash();
100
101 let (bidir_tx, bidir_rx) = mpsc::channel::<PlexusStreamItem>(32);
103
104 let bidir_channel = Arc::new(BidirChannel::<Req, Resp>::new(
108 bidir_tx,
109 true, provenance.clone(),
111 plexus_hash.clone(),
112 ));
113
114 let wrap_fn = move |user_stream: Pin<Box<dyn Stream<Item = PlexusStreamItem> + Send>>| -> PlexusStream {
116 let bidir_stream = ReceiverStream::new(bidir_rx);
117
118 let merged = stream::select(user_stream, bidir_stream);
121
122 Box::pin(merged)
123 };
124
125 (bidir_channel, wrap_fn)
126}
127
128pub fn wrap_stream_with_bidir<T, Req, Resp>(
153 stream: impl Stream<Item = T> + Send + 'static,
154 content_type: &'static str,
155 provenance: Vec<String>,
156) -> (Arc<BidirChannel<Req, Resp>>, PlexusStream)
157where
158 T: Serialize + Send + 'static,
159 Req: Serialize + DeserializeOwned + Send + Sync + 'static,
160 Resp: Serialize + DeserializeOwned + Send + Sync + 'static,
161{
162 let (ctx, wrap_fn) = create_bidir_stream::<Req, Resp>(content_type, provenance.clone());
163
164 let wrapped_user_stream = wrap_stream(stream, content_type, provenance);
166
167 let merged = wrap_fn(wrapped_user_stream);
169
170 (ctx, merged)
171}
172
173pub fn error_stream(
177 message: String,
178 provenance: Vec<String>,
179 recoverable: bool,
180) -> PlexusStream {
181 let metadata = StreamMetadata::new(provenance, PlexusContext::hash());
182
183 Box::pin(stream::once(async move {
184 PlexusStreamItem::Error {
185 metadata,
186 message,
187 code: None,
188 recoverable,
189 }
190 }))
191}
192
193pub fn error_stream_with_code(
197 message: String,
198 code: String,
199 provenance: Vec<String>,
200 recoverable: bool,
201) -> PlexusStream {
202 let metadata = StreamMetadata::new(provenance, PlexusContext::hash());
203
204 Box::pin(stream::once(async move {
205 PlexusStreamItem::Error {
206 metadata,
207 message,
208 code: Some(code),
209 recoverable,
210 }
211 }))
212}
213
214pub fn done_stream(provenance: Vec<String>) -> PlexusStream {
218 let metadata = StreamMetadata::new(provenance, PlexusContext::hash());
219
220 Box::pin(stream::once(async move {
221 PlexusStreamItem::Done { metadata }
222 }))
223}
224
225pub fn progress_stream(
229 message: String,
230 percentage: Option<f32>,
231 provenance: Vec<String>,
232) -> PlexusStream {
233 let metadata = StreamMetadata::new(provenance, PlexusContext::hash());
234
235 Box::pin(stream::once(async move {
236 PlexusStreamItem::Progress {
237 metadata,
238 message,
239 percentage,
240 }
241 }))
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;
    use serde::{Deserialize, Serialize};

    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct TestEvent {
        value: i32,
    }

    #[tokio::test]
    async fn test_wrap_stream() {
        let events = vec![TestEvent { value: 1 }, TestEvent { value: 2 }];
        let input_stream = stream::iter(events);

        let wrapped = wrap_stream(input_stream, "test.event", vec!["test".into()]);
        let items: Vec<_> = wrapped.collect().await;

        // Two data items followed by exactly one Done.
        assert_eq!(items.len(), 3);

        // Check every data item, not just the first, including ordering.
        for (idx, item) in items[..2].iter().enumerate() {
            match item {
                PlexusStreamItem::Data {
                    content_type,
                    content,
                    metadata,
                } => {
                    assert_eq!(content_type, "test.event");
                    assert_eq!(content["value"], idx as i64 + 1);
                    assert_eq!(metadata.provenance, vec!["test"]);
                }
                _ => panic!("Expected Data item at index {}", idx),
            }
        }

        assert!(matches!(items[2], PlexusStreamItem::Done { .. }));
    }

    #[tokio::test]
    async fn test_wrap_stream_empty_input_yields_only_done() {
        let wrapped = wrap_stream(
            stream::iter(Vec::<TestEvent>::new()),
            "test.event",
            vec!["test".into()],
        );
        let items: Vec<_> = wrapped.collect().await;

        assert_eq!(items.len(), 1);
        assert!(matches!(items[0], PlexusStreamItem::Done { .. }));
    }

    #[tokio::test]
    async fn test_error_stream() {
        let stream = error_stream("Something failed".into(), vec!["test".into()], false);
        let items: Vec<_> = stream.collect().await;

        assert_eq!(items.len(), 1);
        match &items[0] {
            PlexusStreamItem::Error {
                message,
                recoverable,
                code,
                ..
            } => {
                assert_eq!(message, "Something failed");
                assert!(!recoverable);
                assert!(code.is_none());
            }
            _ => panic!("Expected Error item"),
        }
    }

    #[tokio::test]
    async fn test_error_stream_with_code() {
        let stream = error_stream_with_code(
            "Not found".into(),
            "NOT_FOUND".into(),
            vec!["test".into()],
            true,
        );
        let items: Vec<_> = stream.collect().await;

        assert_eq!(items.len(), 1);
        match &items[0] {
            PlexusStreamItem::Error {
                message,
                code,
                recoverable,
                ..
            } => {
                assert_eq!(message, "Not found");
                assert_eq!(code.as_deref(), Some("NOT_FOUND"));
                assert!(recoverable);
            }
            _ => panic!("Expected Error item"),
        }
    }

    #[tokio::test]
    async fn test_done_stream() {
        let items: Vec<_> = done_stream(vec!["test".into()]).collect().await;

        assert_eq!(items.len(), 1);
        match &items[0] {
            PlexusStreamItem::Done { metadata } => {
                assert_eq!(metadata.provenance, vec!["test"]);
            }
            _ => panic!("Expected Done item"),
        }
    }

    #[tokio::test]
    async fn test_progress_stream() {
        let items: Vec<_> = progress_stream("Halfway".into(), Some(50.0), vec!["test".into()])
            .collect()
            .await;

        assert_eq!(items.len(), 1);
        match &items[0] {
            PlexusStreamItem::Progress {
                message,
                percentage,
                ..
            } => {
                assert_eq!(message, "Halfway");
                assert_eq!(*percentage, Some(50.0));
            }
            _ => panic!("Expected Progress item"),
        }
    }
}
331}