1use std::collections::BTreeMap;
4
5use pureflow_types::PortId;
6
7use crate::{PortPacket, Result};
8
9#[derive(Debug, Clone, Default, PartialEq)]
11#[cfg_attr(not(feature = "arrow"), derive(Eq))]
12pub struct BatchInputs {
13 packets_by_port: BTreeMap<PortId, Vec<PortPacket>>,
14}
15
16impl BatchInputs {
17 #[must_use]
19 pub const fn new() -> Self {
20 Self {
21 packets_by_port: BTreeMap::new(),
22 }
23 }
24
25 #[must_use]
27 pub fn from_packets(packets_by_port: impl Into<BTreeMap<PortId, Vec<PortPacket>>>) -> Self {
28 Self {
29 packets_by_port: packets_by_port.into(),
30 }
31 }
32
33 pub fn push(&mut self, port_id: PortId, packet: PortPacket) {
35 self.packets_by_port
36 .entry(port_id)
37 .or_default()
38 .push(packet);
39 }
40
41 #[must_use]
43 pub fn packets(&self, port_id: &PortId) -> &[PortPacket] {
44 self.packets_by_port.get(port_id).map_or(&[], Vec::as_slice)
45 }
46
47 #[must_use]
49 pub const fn packets_by_port(&self) -> &BTreeMap<PortId, Vec<PortPacket>> {
50 &self.packets_by_port
51 }
52
53 #[must_use]
55 pub fn into_packets_by_port(self) -> BTreeMap<PortId, Vec<PortPacket>> {
56 self.packets_by_port
57 }
58}
59
60#[derive(Debug, Clone, Default, PartialEq)]
62#[cfg_attr(not(feature = "arrow"), derive(Eq))]
63pub struct BatchOutputs {
64 packets_by_port: BTreeMap<PortId, Vec<PortPacket>>,
65}
66
67impl BatchOutputs {
68 #[must_use]
70 pub const fn new() -> Self {
71 Self {
72 packets_by_port: BTreeMap::new(),
73 }
74 }
75
76 #[must_use]
78 pub fn from_packets(packets_by_port: impl Into<BTreeMap<PortId, Vec<PortPacket>>>) -> Self {
79 Self {
80 packets_by_port: packets_by_port.into(),
81 }
82 }
83
84 pub fn push(&mut self, port_id: PortId, packet: PortPacket) {
86 self.packets_by_port
87 .entry(port_id)
88 .or_default()
89 .push(packet);
90 }
91
92 #[must_use]
94 pub fn packets(&self, port_id: &PortId) -> &[PortPacket] {
95 self.packets_by_port.get(port_id).map_or(&[], Vec::as_slice)
96 }
97
98 #[must_use]
100 pub const fn packets_by_port(&self) -> &BTreeMap<PortId, Vec<PortPacket>> {
101 &self.packets_by_port
102 }
103
104 #[must_use]
106 pub fn into_packets_by_port(self) -> BTreeMap<PortId, Vec<PortPacket>> {
107 self.packets_by_port
108 }
109}
110
111pub trait BatchExecutor: Send + Sync {
124 fn invoke(&self, inputs: BatchInputs) -> Result<BatchOutputs>;
130}
131
132pub struct WasmModule {
134 executor: Box<dyn BatchExecutor>,
135}
136
137impl WasmModule {
138 #[must_use]
140 pub const fn new(executor: Box<dyn BatchExecutor>) -> Self {
141 Self { executor }
142 }
143
144 pub fn invoke(&self, inputs: BatchInputs) -> Result<BatchOutputs> {
150 self.executor.invoke(inputs)
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    use crate::{
        PacketPayload,
        context::ExecutionMetadata,
        message::{MessageEndpoint, MessageMetadata, MessageRoute},
    };
    use pureflow_types::{ExecutionId, MessageId, NodeId, WorkflowId};

    fn execution_id(value: &str) -> ExecutionId {
        ExecutionId::new(value).expect("valid execution id")
    }

    fn message_id(value: &str) -> MessageId {
        MessageId::new(value).expect("valid message id")
    }

    fn node_id(value: &str) -> NodeId {
        NodeId::new(value).expect("valid node id")
    }

    fn port_id(value: &str) -> PortId {
        PortId::new(value).expect("valid port id")
    }

    fn workflow_id(value: &str) -> WorkflowId {
        WorkflowId::new(value).expect("valid workflow id")
    }

    /// Builds a packet routed `source.out -> wasm.in` carrying `value` as bytes.
    fn packet(value: &'static [u8]) -> PortPacket {
        let source: MessageEndpoint = MessageEndpoint::new(node_id("source"), port_id("out"));
        let target: MessageEndpoint = MessageEndpoint::new(node_id("wasm"), port_id("in"));
        let route: MessageRoute = MessageRoute::new(Some(source), target);
        let execution: ExecutionMetadata = ExecutionMetadata::first_attempt(execution_id("run-1"));
        let metadata: MessageMetadata =
            MessageMetadata::new(message_id("msg-1"), workflow_id("flow"), execution, route);

        PortPacket::new(metadata, PacketPayload::from(value))
    }

    /// Test executor that copies every packet on port `in` to port `out`.
    struct EchoBatchExecutor;

    impl BatchExecutor for EchoBatchExecutor {
        fn invoke(&self, inputs: BatchInputs) -> Result<BatchOutputs> {
            let mut outputs: BatchOutputs = BatchOutputs::new();
            for packet in inputs.packets(&port_id("in")) {
                outputs.push(port_id("out"), packet.clone());
            }
            Ok(outputs)
        }
    }

    #[test]
    fn batch_inputs_preserve_port_order_and_packet_order() {
        let mut inputs: BatchInputs = BatchInputs::new();
        inputs.push(port_id("right"), packet(b"second"));
        inputs.push(port_id("left"), packet(b"first"));
        inputs.push(port_id("right"), packet(b"third"));

        // Ports iterate in `PortId` order regardless of insertion order.
        assert_eq!(
            inputs
                .packets_by_port()
                .keys()
                .map(PortId::as_str)
                .collect::<Vec<_>>(),
            vec!["left", "right"]
        );

        // Packets within one port must keep their insertion order; checking
        // only the length would not catch a reordering.
        let right: &[PortPacket] = inputs.packets(&port_id("right"));
        assert_eq!(right.len(), 2);
        assert_eq!(
            right[0]
                .payload()
                .as_bytes()
                .expect("payload should contain bytes")
                .as_ref(),
            b"second"
        );
        assert_eq!(
            right[1]
                .payload()
                .as_bytes()
                .expect("payload should contain bytes")
                .as_ref(),
            b"third"
        );

        let left: &[PortPacket] = inputs.packets(&port_id("left"));
        assert_eq!(left.len(), 1);
        assert_eq!(
            left[0]
                .payload()
                .as_bytes()
                .expect("payload should contain bytes")
                .as_ref(),
            b"first"
        );
    }

    #[test]
    fn batch_inputs_return_empty_slice_for_unknown_port() {
        let mut inputs: BatchInputs = BatchInputs::new();
        inputs.push(port_id("in"), packet(b"payload"));

        assert!(inputs.packets(&port_id("missing")).is_empty());
    }

    #[test]
    fn batch_outputs_round_trip_through_packet_map() {
        let mut outputs: BatchOutputs = BatchOutputs::new();
        outputs.push(port_id("out"), packet(b"payload"));

        let map: BTreeMap<PortId, Vec<PortPacket>> = outputs.clone().into_packets_by_port();

        assert_eq!(BatchOutputs::from_packets(map), outputs);
    }

    #[test]
    fn wasm_module_invokes_opaque_batch_executor() {
        let module: WasmModule = WasmModule::new(Box::new(EchoBatchExecutor));
        let mut inputs: BatchInputs = BatchInputs::new();
        inputs.push(port_id("in"), packet(b"payload"));

        let outputs: BatchOutputs = module.invoke(inputs).expect("batch should run");

        assert_eq!(outputs.packets(&port_id("out")).len(), 1);
        assert_eq!(
            outputs.packets(&port_id("out"))[0]
                .payload()
                .as_bytes()
                .expect("payload should contain bytes")
                .as_ref(),
            b"payload"
        );
    }

    #[test]
    fn batch_executor_accepts_empty_inputs() {
        let module: WasmModule = WasmModule::new(Box::new(EchoBatchExecutor));

        let outputs: BatchOutputs = module.invoke(BatchInputs::new()).expect("batch should run");

        assert!(outputs.packets_by_port().is_empty());
    }
}