use futures::stream::{self, Stream, StreamExt};
use serde::Serialize;
use serde::de::DeserializeOwned;
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use super::bidirectional::BidirChannel;
use super::context::PlexusContext;
use super::credential_envelope::{
assemble_envelope_content, serialize_with_credential_capture, CookieProjector,
};
use super::types::{PlexusStreamItem, StreamMetadata};
/// Boxed, type-erased stream of [`PlexusStreamItem`]s; the return type of every helper in
/// this module.
///
/// The stream is `Send`, so it can be moved across threads. The `'static` bound is written out
/// explicitly here, but it is exactly the default object lifetime a `Box<dyn …>` in a type
/// alias would get anyway.
pub type PlexusStream = Pin<Box<dyn Stream<Item = PlexusStreamItem> + Send + 'static>>;
/// Wraps a stream of serializable values into a [`PlexusStream`].
///
/// Each input value is passed through `serialize_with_credential_capture` and
/// `assemble_envelope_content` (with [`CookieProjector::None`]). It is then emitted as a
/// [`PlexusStreamItem::Data`] tagged with `content_type`. The hints returned by
/// `assemble_envelope_content` are discarded. When the input stream ends, exactly one
/// [`PlexusStreamItem::Done`] is appended.
///
/// All emitted items carry metadata built from `provenance` and [`PlexusContext::hash`]. The
/// hash is read once, when this function is called, not per item.
pub fn wrap_stream<T: Serialize + Send + 'static>(
    stream: impl Stream<Item = T> + Send + 'static,
    content_type: &'static str,
    provenance: Vec<String>,
) -> PlexusStream {
    let hash = PlexusContext::hash();
    // Two independent metadata values: the first is moved into the per-item encoder (and cloned
    // for every item), the second into the trailing `Done`.
    let data_meta = StreamMetadata::new(provenance.clone(), hash.clone());
    let done_meta = StreamMetadata::new(provenance, hash);
    let projector = CookieProjector::None;

    // Converts one user value into a `Data` envelope.
    let encode = move |value: T| {
        let (payload, captured) = serialize_with_credential_capture(&value);
        let (content, _hints) = assemble_envelope_content(payload, captured, &projector);
        PlexusStreamItem::Data {
            metadata: data_meta.clone(),
            content_type: content_type.to_string(),
            content,
        }
    };

    let trailer = stream::once(async move { PlexusStreamItem::Done { metadata: done_meta } });

    Box::pin(stream.map(encode).chain(trailer))
}
pub fn create_bidir_stream<Req, Resp>(
_content_type: &'static str,
provenance: Vec<String>,
) -> (
Arc<BidirChannel<Req, Resp>>,
impl FnOnce(Pin<Box<dyn Stream<Item = PlexusStreamItem> + Send>>) -> PlexusStream,
)
where
Req: Serialize + DeserializeOwned + Send + Sync + 'static,
Resp: Serialize + DeserializeOwned + Send + Sync + 'static,
{
let plexus_hash = PlexusContext::hash();
let (bidir_tx, bidir_rx) = mpsc::channel::<PlexusStreamItem>(32);
let bidir_channel = Arc::new(BidirChannel::<Req, Resp>::new(
bidir_tx,
true, provenance.clone(),
plexus_hash.clone(),
));
let wrap_fn = move |user_stream: Pin<Box<dyn Stream<Item = PlexusStreamItem> + Send>>| -> PlexusStream {
let bidir_stream = ReceiverStream::new(bidir_rx);
let merged = stream::select(user_stream, bidir_stream);
Box::pin(merged)
};
(bidir_channel, wrap_fn)
}
/// Convenience combinator: creates a bidirectional channel with [`create_bidir_stream`] and
/// merges it with `stream` after wrapping it via [`wrap_stream`].
///
/// Returns the shared [`BidirChannel`] and the merged [`PlexusStream`]. Both helpers receive
/// the same `content_type` and `provenance`. The channel is created first, then the user
/// stream is wrapped.
pub fn wrap_stream_with_bidir<T, Req, Resp>(
    stream: impl Stream<Item = T> + Send + 'static,
    content_type: &'static str,
    provenance: Vec<String>,
) -> (Arc<BidirChannel<Req, Resp>>, PlexusStream)
where
    T: Serialize + Send + 'static,
    Req: Serialize + DeserializeOwned + Send + Sync + 'static,
    Resp: Serialize + DeserializeOwned + Send + Sync + 'static,
{
    let (channel, merge_with_channel) =
        create_bidir_stream::<Req, Resp>(content_type, provenance.clone());
    let data_stream = wrap_stream(stream, content_type, provenance);
    (channel, merge_with_channel(data_stream))
}
pub fn error_stream(
message: String,
provenance: Vec<String>,
recoverable: bool,
) -> PlexusStream {
let metadata = StreamMetadata::new(provenance, PlexusContext::hash());
Box::pin(stream::once(async move {
PlexusStreamItem::Error {
metadata,
message,
code: None,
recoverable,
}
}))
}
pub fn error_stream_with_code(
message: String,
code: String,
provenance: Vec<String>,
recoverable: bool,
) -> PlexusStream {
let metadata = StreamMetadata::new(provenance, PlexusContext::hash());
Box::pin(stream::once(async move {
PlexusStreamItem::Error {
metadata,
message,
code: Some(code),
recoverable,
}
}))
}
/// Builds a single-item [`PlexusStream`] containing only a [`PlexusStreamItem::Done`].
///
/// The metadata is built from `provenance` and the [`PlexusContext::hash`] read when this
/// function is called.
pub fn done_stream(provenance: Vec<String>) -> PlexusStream {
    let done = PlexusStreamItem::Done {
        metadata: StreamMetadata::new(provenance, PlexusContext::hash()),
    };
    Box::pin(stream::once(async move { done }))
}
/// Builds a single-item [`PlexusStream`] containing a [`PlexusStreamItem::Progress`].
///
/// `message` and `percentage` are forwarded unchanged. In particular, `percentage` is not
/// validated or clamped here. The metadata is built from `provenance` and the
/// [`PlexusContext::hash`] read when this function is called.
pub fn progress_stream(
    message: String,
    percentage: Option<f32>,
    provenance: Vec<String>,
) -> PlexusStream {
    let progress = PlexusStreamItem::Progress {
        metadata: StreamMetadata::new(provenance, PlexusContext::hash()),
        message,
        percentage,
    };
    Box::pin(stream::once(async move { progress }))
}
#[cfg(test)]
mod tests {
    //! Unit tests for the stream helpers. Each test collects the produced stream and checks
    //! the item count, the variants, and the forwarded fields.
    use super::*;
    use futures::StreamExt;
    use serde::{Deserialize, Serialize};

    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct TestEvent {
        value: i32,
    }

    /// `wrap_stream` emits one `Data` per input item, in input order, followed by one `Done`.
    #[tokio::test]
    async fn test_wrap_stream() {
        let events = vec![TestEvent { value: 1 }, TestEvent { value: 2 }];
        let input_stream = stream::iter(events);
        let wrapped = wrap_stream(input_stream, "test.event", vec!["test".into()]);
        let items: Vec<_> = wrapped.collect().await;
        assert_eq!(items.len(), 3);
        match &items[0] {
            PlexusStreamItem::Data {
                content_type,
                content,
                metadata,
            } => {
                assert_eq!(content_type, "test.event");
                assert_eq!(content["value"], 1);
                assert_eq!(metadata.provenance, vec!["test"]);
            }
            _ => panic!("Expected Data item"),
        }
        // The second item must carry the second payload, so ordering is preserved.
        match &items[1] {
            PlexusStreamItem::Data {
                content_type,
                content,
                ..
            } => {
                assert_eq!(content_type, "test.event");
                assert_eq!(content["value"], 2);
            }
            _ => panic!("Expected Data item"),
        }
        match &items[2] {
            PlexusStreamItem::Done { metadata } => {
                assert_eq!(metadata.provenance, vec!["test"]);
            }
            _ => panic!("Expected Done item"),
        }
    }

    /// An empty input stream still terminates with exactly one `Done`.
    #[tokio::test]
    async fn wrap_stream_empty_input_yields_only_done() {
        let input_stream = stream::iter(Vec::<TestEvent>::new());
        let items: Vec<_> = wrap_stream(input_stream, "test.event", vec!["test".into()])
            .collect()
            .await;
        assert_eq!(items.len(), 1);
        assert!(matches!(items[0], PlexusStreamItem::Done { .. }));
    }

    #[tokio::test]
    async fn wrap_stream_credential_free_payload_is_wire_identical() {
        let events = vec![TestEvent { value: 7 }];
        let input_stream = stream::iter(events);
        let wrapped = wrap_stream(input_stream, "test.event", vec!["t".into()]);
        let items: Vec<_> = wrapped.collect().await;
        assert_eq!(items.len(), 2);
        match &items[0] {
            PlexusStreamItem::Data { content, .. } => {
                let obj = content.as_object().expect("object");
                assert_eq!(obj.get("value").unwrap(), &serde_json::json!(7));
                assert!(
                    !obj.contains_key("_credentials"),
                    "_credentials key MUST NOT appear on non-credential payloads"
                );
                assert_eq!(obj.len(), 1, "no extra fields");
            }
            _ => panic!("Expected Data item"),
        }
    }

    #[tokio::test]
    async fn wrap_stream_credential_bearing_payload_emits_sentinel_in_body() {
        use plexus_auth_core::{
            AttachmentSite, Credential, CredentialIssuer, CredentialKind, CredentialMetadata,
            CredentialMinter, CredentialScheme, HeaderName, MethodPath, Origin, Scope,
        };
        // Compile-time reference to the plexus-auth-core API surface this test depends on.
        let _ = (
            CredentialMinter::issuer,
            Credential::<String>::metadata,
            CredentialIssuer::new(
                Origin::new("ws://test"),
                MethodPath::try_new("auth.login").unwrap(),
            ),
            CredentialMetadata::new(
                CredentialKind::Bearer,
                AttachmentSite::Header {
                    name: HeaderName::try_new("authorization").unwrap(),
                },
                Some(CredentialScheme::new("Bearer ")),
                Vec::<Scope>::new(),
                None,
                None,
                None,
                CredentialIssuer::new(
                    Origin::new("ws://test"),
                    MethodPath::try_new("auth.login").unwrap(),
                ),
            ),
        );
        #[derive(Serialize)]
        struct LoginPayload {
            user_id: String,
            session: serde_json::Value,
        }
        let payload = LoginPayload {
            user_id: "alice".into(),
            session: serde_json::json!({ "$credential": "cred_0" }),
        };
        let input_stream = stream::iter(vec![payload]);
        let wrapped = wrap_stream(input_stream, "auth.login.result", vec!["auth".into()]);
        let items: Vec<_> = wrapped.collect().await;
        let content = match &items[0] {
            PlexusStreamItem::Data { content, .. } => content,
            _ => panic!("Expected Data item"),
        };
        assert_eq!(
            content.get("session").unwrap(),
            &serde_json::json!({ "$credential": "cred_0" })
        );
        let obj = content.as_object().unwrap();
        assert!(
            !obj.contains_key("_credentials"),
            "sidecar absent until plexus-auth-core exposes DispatchCaptureGuard::install"
        );
    }

    #[tokio::test]
    async fn test_error_stream() {
        let stream = error_stream("Something failed".into(), vec!["test".into()], false);
        let items: Vec<_> = stream.collect().await;
        assert_eq!(items.len(), 1);
        match &items[0] {
            PlexusStreamItem::Error {
                message,
                recoverable,
                code,
                ..
            } => {
                assert_eq!(message, "Something failed");
                assert!(!recoverable);
                assert!(code.is_none());
            }
            _ => panic!("Expected Error item"),
        }
    }

    #[tokio::test]
    async fn test_error_stream_with_code() {
        let stream = error_stream_with_code(
            "Not found".into(),
            "NOT_FOUND".into(),
            vec!["test".into()],
            true,
        );
        let items: Vec<_> = stream.collect().await;
        assert_eq!(items.len(), 1);
        match &items[0] {
            PlexusStreamItem::Error {
                message,
                code,
                recoverable,
                ..
            } => {
                assert_eq!(message, "Not found");
                assert_eq!(code.as_deref(), Some("NOT_FOUND"));
                assert!(recoverable);
            }
            _ => panic!("Expected Error item"),
        }
    }

    /// `done_stream` yields exactly one `Done` carrying the given provenance.
    #[tokio::test]
    async fn test_done_stream() {
        let items: Vec<_> = done_stream(vec!["test".into()]).collect().await;
        assert_eq!(items.len(), 1);
        match &items[0] {
            PlexusStreamItem::Done { metadata } => {
                assert_eq!(metadata.provenance, vec!["test"]);
            }
            _ => panic!("Expected Done item"),
        }
    }

    /// `progress_stream` yields exactly one `Progress` with message and percentage forwarded.
    #[tokio::test]
    async fn test_progress_stream() {
        let items: Vec<_> = progress_stream("Halfway".into(), Some(50.0), vec!["test".into()])
            .collect()
            .await;
        assert_eq!(items.len(), 1);
        match &items[0] {
            PlexusStreamItem::Progress {
                message,
                percentage,
                ..
            } => {
                assert_eq!(message, "Halfway");
                assert_eq!(*percentage, Some(50.0));
            }
            _ => panic!("Expected Progress item"),
        }
    }
}