use bytes::Bytes;
use pureflow_types::{MessageId, NodeId, PortId, WorkflowId};
use serde_json::Value;
use crate::context::ExecutionMetadata;
/// Payload carried inside a [`MessageEnvelope`].
///
/// - `Bytes` holds raw data as [`Bytes`]; clones and slices of a `Bytes` buffer share the
///   underlying storage instead of copying it.
/// - `Control` holds a structured JSON [`Value`].
/// - With the `arrow` feature enabled, `Arrow` holds an `arrow_array::RecordBatch`.
///
/// `Eq` is derived only when the `arrow` feature is disabled, because `RecordBatch` is
/// `PartialEq` but not `Eq`.
///
/// NOTE(review): Cargo unifies features across the dependency graph. If any crate enables
/// `arrow`, every dependent sees the extra `Arrow` variant, which breaks their exhaustive
/// `match`es, and they also lose the `Eq` impl. Confirm this is acceptable for downstream users.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(not(feature = "arrow"), derive(Eq))]
pub enum PacketPayload {
    /// Raw byte data backed by a reference-counted [`Bytes`] buffer.
    Bytes(Bytes),
    /// Structured control data as a JSON value.
    Control(Value),
    /// Columnar data as an Arrow record batch.
    #[cfg(feature = "arrow")]
    Arrow(arrow_array::RecordBatch),
}

impl PacketPayload {
    /// Builds a byte payload from anything convertible into [`Bytes`].
    #[must_use]
    pub fn bytes(value: impl Into<Bytes>) -> Self {
        Self::Bytes(value.into())
    }

    /// Builds a control payload from anything convertible into a JSON [`Value`].
    #[must_use]
    pub fn control(value: impl Into<Value>) -> Self {
        Self::Control(value.into())
    }

    /// Borrows the byte buffer, or returns `None` for any other variant.
    #[must_use]
    pub const fn as_bytes(&self) -> Option<&Bytes> {
        match self {
            Self::Bytes(bytes) => Some(bytes),
            Self::Control(_) => None,
            #[cfg(feature = "arrow")]
            Self::Arrow(_) => None,
        }
    }

    /// Borrows the JSON control value, or returns `None` for any other variant.
    #[must_use]
    pub const fn as_control(&self) -> Option<&Value> {
        match self {
            Self::Bytes(_) => None,
            Self::Control(value) => Some(value),
            #[cfg(feature = "arrow")]
            Self::Arrow(_) => None,
        }
    }

    /// Borrows the Arrow record batch, or returns `None` for any other variant.
    #[cfg(feature = "arrow")]
    #[must_use]
    pub const fn as_arrow(&self) -> Option<&arrow_array::RecordBatch> {
        match self {
            Self::Bytes(_) | Self::Control(_) => None,
            Self::Arrow(batch) => Some(batch),
        }
    }

    /// Consumes the payload and returns its byte buffer without cloning it.
    ///
    /// # Errors
    ///
    /// Returns the original payload unchanged if it is not a `Bytes` payload, so the caller
    /// can try another variant.
    pub fn into_bytes(self) -> Result<Bytes, Self> {
        match self {
            Self::Bytes(bytes) => Ok(bytes),
            other => Err(other),
        }
    }

    /// Consumes the payload and returns its JSON control value without deep-cloning it.
    ///
    /// # Errors
    ///
    /// Returns the original payload unchanged if it is not a `Control` payload, so the
    /// caller can try another variant.
    pub fn into_control(self) -> Result<Value, Self> {
        match self {
            Self::Control(value) => Ok(value),
            other => Err(other),
        }
    }

    /// Consumes the payload and returns its Arrow record batch.
    ///
    /// # Errors
    ///
    /// Returns the original payload unchanged if it is not an `Arrow` payload, so the caller
    /// can try another variant.
    #[cfg(feature = "arrow")]
    pub fn into_arrow(self) -> Result<arrow_array::RecordBatch, Self> {
        match self {
            Self::Arrow(batch) => Ok(batch),
            other => Err(other),
        }
    }
}
impl From<Bytes> for PacketPayload {
fn from(value: Bytes) -> Self {
Self::Bytes(value)
}
}
impl From<Vec<u8>> for PacketPayload {
fn from(value: Vec<u8>) -> Self {
Self::Bytes(Bytes::from(value))
}
}
/// Wraps a `'static` byte slice as a byte payload; [`Bytes::from_static`] references the
/// slice directly instead of copying it.
impl From<&'static [u8]> for PacketPayload {
    fn from(data: &'static [u8]) -> Self {
        let buffer = Bytes::from_static(data);
        Self::Bytes(buffer)
    }
}
impl From<Value> for PacketPayload {
fn from(value: Value) -> Self {
Self::Control(value)
}
}
/// Wraps an Arrow record batch as a payload. Only compiled with the `arrow` feature.
#[cfg(feature = "arrow")]
impl From<arrow_array::RecordBatch> for PacketPayload {
    fn from(batch: arrow_array::RecordBatch) -> Self {
        PacketPayload::Arrow(batch)
    }
}
/// One end of a message route: a specific port on a specific node.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MessageEndpoint {
    /// Node that owns the port.
    node_id: NodeId,
    /// Port on that node.
    port_id: PortId,
}

impl MessageEndpoint {
    /// Creates an endpoint addressing `port_id` on `node_id`.
    #[must_use]
    pub const fn new(node_id: NodeId, port_id: PortId) -> Self {
        Self { node_id, port_id }
    }

    /// Returns the node this endpoint belongs to.
    #[must_use]
    pub const fn node_id(&self) -> &NodeId {
        let Self { node_id, .. } = self;
        node_id
    }

    /// Returns the port on the node.
    #[must_use]
    pub const fn port_id(&self) -> &PortId {
        let Self { port_id, .. } = self;
        port_id
    }
}
/// Where a message travels: an optional source endpoint and a required target endpoint.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MessageRoute {
    /// Sending endpoint, or `None` when no source is recorded.
    /// NOTE(review): presumably `None` marks messages injected from outside the graph — confirm.
    source: Option<MessageEndpoint>,
    /// Receiving endpoint.
    target: MessageEndpoint,
}

impl MessageRoute {
    /// Creates a route from an optional `source` to `target`.
    #[must_use]
    pub const fn new(source: Option<MessageEndpoint>, target: MessageEndpoint) -> Self {
        Self { source, target }
    }

    /// Borrows the source endpoint, if one was recorded.
    #[must_use]
    pub const fn source(&self) -> Option<&MessageEndpoint> {
        self.source.as_ref()
    }

    /// Borrows the target endpoint.
    #[must_use]
    pub const fn target(&self) -> &MessageEndpoint {
        let Self { target, .. } = self;
        target
    }
}
/// Metadata attached to a message: its id, the workflow and execution it belongs to, and
/// its route.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MessageMetadata {
    /// Identifier of this message.
    message_id: MessageId,
    /// Workflow the message belongs to.
    workflow_id: WorkflowId,
    /// Execution metadata associated with the message.
    execution: ExecutionMetadata,
    /// Source (if any) and target endpoints.
    route: MessageRoute,
}

impl MessageMetadata {
    /// Bundles the given identifiers, execution metadata and route.
    #[must_use]
    pub const fn new(
        message_id: MessageId,
        workflow_id: WorkflowId,
        execution: ExecutionMetadata,
        route: MessageRoute,
    ) -> Self {
        Self {
            message_id,
            workflow_id,
            execution,
            route,
        }
    }

    /// Returns the message identifier.
    #[must_use]
    pub const fn message_id(&self) -> &MessageId {
        let Self { message_id, .. } = self;
        message_id
    }

    /// Returns the workflow identifier.
    #[must_use]
    pub const fn workflow_id(&self) -> &WorkflowId {
        let Self { workflow_id, .. } = self;
        workflow_id
    }

    /// Returns the execution metadata.
    #[must_use]
    pub const fn execution(&self) -> &ExecutionMetadata {
        let Self { execution, .. } = self;
        execution
    }

    /// Returns the message route.
    #[must_use]
    pub const fn route(&self) -> &MessageRoute {
        let Self { route, .. } = self;
        route
    }
}
/// A payload of type `P` paired with the [`MessageMetadata`] that describes it.
///
/// The payload type is generic, so the same metadata can travel with different payload
/// representations (see [`MessageEnvelope::map_payload`]). `PartialEq`/`Eq` apply only when
/// `P` implements them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MessageEnvelope<P> {
    /// Identity, workflow, execution and routing information.
    metadata: MessageMetadata,
    /// The message body.
    payload: P,
}

impl<P> MessageEnvelope<P> {
    /// Wraps `payload` together with `metadata`.
    #[must_use]
    pub const fn new(metadata: MessageMetadata, payload: P) -> Self {
        Self { metadata, payload }
    }

    /// Borrows the metadata.
    #[must_use]
    pub const fn metadata(&self) -> &MessageMetadata {
        &self.metadata
    }

    /// Borrows the payload.
    #[must_use]
    pub const fn payload(&self) -> &P {
        &self.payload
    }

    /// Consumes the envelope and returns only the payload, dropping the metadata.
    ///
    /// Use [`MessageEnvelope::into_parts`] to keep the metadata as well.
    #[must_use]
    pub fn into_payload(self) -> P {
        self.payload
    }

    /// Consumes the envelope and returns its metadata and payload without cloning either.
    #[must_use]
    pub fn into_parts(self) -> (MessageMetadata, P) {
        (self.metadata, self.payload)
    }

    /// Transforms the payload with `f`, carrying the metadata over unchanged.
    #[must_use]
    pub fn map_payload<Q>(self, f: impl FnOnce(P) -> Q) -> MessageEnvelope<Q> {
        let (metadata, payload) = self.into_parts();
        MessageEnvelope::new(metadata, f(payload))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use pureflow_types::ExecutionId;
    use serde_json::json;

    fn execution_id(value: &str) -> ExecutionId {
        ExecutionId::new(value).expect("valid execution id")
    }

    fn message_id(value: &str) -> MessageId {
        MessageId::new(value).expect("valid message id")
    }

    fn node_id(value: &str) -> NodeId {
        NodeId::new(value).expect("valid node id")
    }

    fn port_id(value: &str) -> PortId {
        PortId::new(value).expect("valid port id")
    }

    fn workflow_id(value: &str) -> WorkflowId {
        WorkflowId::new(value).expect("valid workflow id")
    }

    fn execution() -> ExecutionMetadata {
        ExecutionMetadata::first_attempt(execution_id("run-1"))
    }

    #[test]
    fn message_envelope_keeps_payload_separate_from_metadata() {
        let target: MessageEndpoint = MessageEndpoint::new(node_id("consumer"), port_id("in"));
        let route: MessageRoute = MessageRoute::new(None, target);
        let metadata: MessageMetadata =
            MessageMetadata::new(message_id("msg-1"), workflow_id("flow"), execution(), route);
        let envelope: MessageEnvelope<&str> = MessageEnvelope::new(metadata, "payload");
        let mapped: MessageEnvelope<usize> = envelope.map_payload(str::len);
        assert_eq!(mapped.payload(), &7);
        assert_eq!(mapped.metadata().message_id().as_str(), "msg-1");
        assert_eq!(
            mapped.metadata().route().target().node_id().as_str(),
            "consumer"
        );
    }

    #[test]
    fn message_route_exposes_optional_source() {
        let source: MessageEndpoint = MessageEndpoint::new(node_id("producer"), port_id("out"));
        let target: MessageEndpoint = MessageEndpoint::new(node_id("consumer"), port_id("in"));
        let route: MessageRoute = MessageRoute::new(Some(source.clone()), target);
        assert_eq!(route.source(), Some(&source));
        assert_eq!(route.target().port_id(), &port_id("in"));
        assert_eq!(route.target().node_id().as_str(), "consumer");
    }

    #[test]
    fn packet_payload_bytes_clone_and_slice_without_copying_user_data() {
        let payload: PacketPayload = PacketPayload::bytes(Bytes::from_static(b"abcdef"));
        let cloned: PacketPayload = payload.clone();
        let original: &Bytes = payload.as_bytes().expect("payload should contain bytes");
        let shared: &Bytes = cloned.as_bytes().expect("payload should contain bytes");
        let sliced: Bytes = shared.slice(1..4);
        assert_eq!(original.as_ref(), b"abcdef");
        assert!(payload.as_control().is_none());
        assert_eq!(sliced.as_ref(), b"bcd");
        // The test name promises zero-copy behaviour, so check buffer identity, not just
        // contents: the clone and the slice must point into the same storage.
        assert_eq!(shared.as_ptr(), original.as_ptr());
        assert_eq!(sliced.as_ptr(), original.as_ptr().wrapping_add(1));
    }

    #[test]
    fn packet_payload_control_carries_structured_values() {
        let payload: PacketPayload = PacketPayload::control(json!({
            "command": "flush",
            "priority": 3,
        }));
        let control: &Value = payload
            .as_control()
            .expect("payload should contain control data");
        assert_eq!(control["command"], "flush");
        assert_eq!(control["priority"], 3);
        assert!(payload.as_bytes().is_none());
    }

    #[cfg(feature = "arrow")]
    #[test]
    fn packet_payload_arrow_carries_record_batches() {
        use std::sync::Arc;

        use arrow_array::{Int32Array, RecordBatch};
        use arrow_schema::{DataType, Field, Schema};

        let schema = Arc::new(Schema::new(vec![Field::new(
            "value",
            DataType::Int32,
            false,
        )]));
        let values = Arc::new(Int32Array::from(vec![1, 2, 3]));
        let batch: RecordBatch =
            RecordBatch::try_new(schema, vec![values]).expect("record batch should be valid");
        let payload: PacketPayload = PacketPayload::from(batch.clone());
        assert_eq!(payload.as_arrow(), Some(&batch));
        assert!(payload.as_bytes().is_none());
        assert!(payload.as_control().is_none());
    }
}