use crate::{
actor::ActorContext,
connector::{ConnectionPoint, Connector, InitialPacket},
network::{Network, NetworkConfig},
};
use reflow_actor_macro::actor;
use std::{collections::HashMap, sync::Arc, time::Duration};
use serde_json::Value;
use crate::{
actor::{Actor, ActorBehavior, MemoryState, Port},
message::Message,
};
// Adds the integers arriving on `A` and `B` and emits the result on `Out`.
//
// Both inports are awaited together (`await_all_inports`); a missing port panics with the
// `expect` message. Non-integer messages are treated as 0.
//
// Errors: returns an error if `A + B` does not fit in an i64, instead of panicking (debug) or
// silently wrapping (release).
#[actor(
    SumActor,
    inports::<100>(A, B),
    outports::<100>(Out),
    await_all_inports
)]
async fn sum_actor(context: ActorContext) -> Result<HashMap<String, Message>, anyhow::Error> {
    let payload = context.get_payload();
    println!("SumActor received payload: {:?}", payload);
    let a_msg = payload.get("A").expect("expected to get data from port A");
    let b_msg = payload.get("B").expect("expected to get data from port B");
    // Lenient input handling: anything that is not an integer counts as 0.
    let a = match a_msg {
        Message::Integer(value) => *value,
        _ => 0,
    };
    let b = match b_msg {
        Message::Integer(value) => *value,
        _ => 0,
    };
    let sum = a
        .checked_add(b)
        .ok_or_else(|| anyhow::anyhow!("SumActor: {} + {} overflows i64", a, b))?;
    Ok([("Out".to_owned(), Message::Integer(sum))].into())
}
// Squares the integer arriving on `In` and emits it on `Out`.
//
// A missing `In` entry panics with the `expect` message; non-integer input is treated as 0.
//
// Errors: returns an error if the square does not fit in an i64, instead of panicking (debug)
// or silently wrapping (release).
#[actor(
    SquareActor,
    inports::<100>(In),
    outports::<50>(Out)
)]
async fn square_actor(context: ActorContext) -> Result<HashMap<String, Message>, anyhow::Error> {
    let payload = context.get_payload();
    let message = payload
        .get("In")
        .expect("expected to get data from port In");
    let input = match message {
        Message::Integer(value) => *value,
        _ => 0,
    };
    let squared = input
        .checked_mul(input)
        .ok_or_else(|| anyhow::anyhow!("SquareActor: {}^2 overflows i64", input))?;
    Ok([("Out".to_owned(), Message::Integer(squared))].into())
}
// Test sink: waits for integers on both `A` and `B` and asserts they are equal.
//
// On success it logs the compared values and emits nothing. A missing port panics with the
// `expect` message, and a mismatch panics via `assert_eq!`.
//
// Errors: returns an error if either message is not a `Message::Integer`. Coercing
// non-integers to 0 (as the arithmetic actors do) would let unrelated messages compare
// "equal" and hide a wiring bug.
#[actor(
    AssertEqActor,
    state(MemoryState),
    inports::<100>(A, B),
    outports(Out),
    await_all_inports
)]
async fn _assert_eq(context: ActorContext) -> Result<HashMap<String, Message>, anyhow::Error> {
    let payload = context.get_payload();
    let data_a = payload.get("A").expect("expected to get data from port A");
    let data_b = payload.get("B").expect("expected to get data from port B");
    let (a, b) = match (data_a, data_b) {
        (Message::Integer(a), Message::Integer(b)) => (*a, *b),
        _ => anyhow::bail!(
            "AssertEqActor expects integers on A and B, got {:?} and {:?}",
            data_a,
            data_b
        ),
    };
    assert_eq!(a, b);
    println!("====================================");
    println!("|| [ASSERT_LOG] {} == {} ||", a, b);
    println!("====================================");
    Ok([].into())
}
#[tokio::test]
async fn test_network() -> Result<(), anyhow::Error> {
let mut network = Network::new(NetworkConfig::default());
let sum_id = "Sum";
let square_id = "Square";
let asser_eq_id = "AssertEq";
network.register_actor("sum_process", SumActor::new())?;
network.register_actor("square_process", SquareActor::new())?;
network.register_actor("assert_eq_process", AssertEqActor::new())?;
network.add_node(sum_id, "sum_process", None)?;
network.add_node(square_id, "square_process", None)?;
network.add_node(asser_eq_id, "assert_eq_process", None)?;
network.add_connection(Connector {
from: ConnectionPoint {
actor: sum_id.to_owned(),
port: "Out".to_owned(),
..Default::default()
},
to: ConnectionPoint {
actor: square_id.to_owned(),
port: "In".to_owned(),
..Default::default()
},
});
network.add_connection(Connector {
from: ConnectionPoint {
actor: square_id.to_owned(),
port: "Out".to_owned(),
..Default::default()
},
to: ConnectionPoint {
actor: asser_eq_id.to_owned(),
port: "A".to_owned(),
..Default::default()
},
});
network.add_initial(InitialPacket {
to: ConnectionPoint {
actor: sum_id.to_owned(),
port: "A".to_owned(),
initial_data: Some(Message::Integer(2)),
},
});
network.add_initial(InitialPacket {
to: ConnectionPoint {
actor: sum_id.to_owned(),
port: "B".to_owned(),
initial_data: Some(Message::Integer(3)),
},
});
network.add_initial(InitialPacket {
to: ConnectionPoint {
actor: asser_eq_id.to_owned(),
port: "B".to_owned(),
initial_data: Some(Message::Integer(25)),
},
});
network.start()?;
tokio::time::sleep(Duration::from_secs(2)).await;
network.shutdown();
Ok(())
}
// Offsets each incoming value by the number of messages this actor has already handled.
//
// The running count is stored in the actor's `MemoryState` under "count" and is read and
// incremented under the state lock. Integers get the count added, strings get it appended,
// and any other message becomes `Message::any(Value::Null)`. If the state does not downcast
// to `MemoryState`, the offset is 0 and nothing is recorded.
#[actor(
    TransformActor,
    inports::<100>(Input),
    outports::<50>(Output),
    state(MemoryState)
)]
async fn transform_actor(context: ActorContext) -> Result<HashMap<String, Message>, anyhow::Error> {
    let payload = context.get_payload();
    let shared_state = context.get_state();
    let input = payload.get("Input").expect("expected Input data");

    // The guard is dropped at the end of this block, before the output is built.
    let offset = {
        let mut guard = shared_state.lock();
        guard
            .as_mut_any()
            .downcast_mut::<MemoryState>()
            .map_or(0, |memory| {
                let seen = memory.get("count").and_then(|v| v.as_i64()).unwrap_or(0);
                memory.insert("count", Value::Number((seen + 1).into()));
                seen
            })
    };

    let output = match input {
        Message::Integer(n) => Message::Integer(n + offset),
        Message::String(s) => Message::string(format!("{}{}", s, offset)),
        _ => Message::any(Value::Null.into()),
    };
    Ok(HashMap::from([("Output".to_owned(), output)]))
}
#[actor(
FilterActor,
inports::<100>(In),
outports::<50>(Passed, Failed)
)]
async fn filter_actor(
context: ActorContext,
) -> Result<HashMap<std::string::String, Message>, anyhow::Error> {
let payload = context.get_payload();
let input = payload.get("In").expect("expected input");
match input {
Message::Integer(n) if *n > 0 => Ok([("Passed".to_owned(), input.clone())].into()),
_ => Ok([("Failed".to_owned(), input.clone())].into()),
}
}
// Keeps a running sum and count of the integers arriving on `Value`.
//
// The totals live in the actor's `MemoryState` under "sum" and "count". Every invocation emits
// the current totals on `Sum` and `Count`. A non-integer message leaves the totals unchanged
// but they are still emitted. If the state does not downcast to `MemoryState`, both outputs
// are 0 and nothing is stored.
#[actor(
    AggregatorActor,
    inports::<100>(Value),
    outports::<50>(Sum, Count),
    state(MemoryState)
)]
async fn aggregator_actor(
    context: ActorContext,
) -> Result<HashMap<String, Message>, anyhow::Error> {
    let payload = context.get_payload();
    let state = context.get_state();
    let value = payload.get("Value").expect("expected Value");

    // Do the read-modify-write under the lock, then release it before logging and building
    // the output, so the guard is never held across stdout I/O.
    let (sum, count) = {
        let mut guard = state.lock();
        match guard.as_mut_any().downcast_mut::<MemoryState>() {
            Some(memory) => {
                let mut sum = memory.get("sum").and_then(|v| v.as_i64()).unwrap_or(0);
                let mut count = memory.get("count").and_then(|v| v.as_i64()).unwrap_or(0);
                if let Message::Integer(n) = value {
                    sum += n;
                    count += 1;
                }
                memory.insert("sum", Value::Number(sum.into()));
                memory.insert("count", Value::Number(count.into()));
                (sum, count)
            }
            None => (0, 0),
        }
    };

    println!("AggregatorActor: Sum = {}, Count = {}", sum, count);
    Ok(HashMap::from([
        ("Sum".to_owned(), Message::Integer(sum)),
        ("Count".to_owned(), Message::Integer(count)),
    ]))
}
// Pipeline test: transform1 -> filter1 (`Passed`) -> aggregator1.
//
// The integers 1..=5 are seeded into the transform. TransformActor adds its running message
// count (0, 1, 2, 3, 4) to each value, so every output is positive and passes the filter. The
// aggregator's final totals should therefore be Sum = (1+2+3+4+5) + (0+1+2+3+4) = 25 and
// Count = 5, regardless of the order in which the transform handles the packets.
#[tokio::test]
async fn test_complex_network() -> Result<(), anyhow::Error> {
    let mut network = Network::new(NetworkConfig::default());
    network.register_actor("transform", TransformActor::new())?;
    network.register_actor("filter", FilterActor::new())?;
    network.register_actor("aggregator", AggregatorActor::new())?;
    network.add_node("transform1", "transform", None)?;
    network.add_node("filter1", "filter", None)?;
    network.add_node("aggregator1", "aggregator", None)?;
    network.add_connection(Connector {
        from: ConnectionPoint {
            actor: "transform1".to_owned(),
            port: "Output".to_owned(),
            ..Default::default()
        },
        to: ConnectionPoint {
            actor: "filter1".to_owned(),
            port: "In".to_owned(),
            ..Default::default()
        },
    });
    network.add_connection(Connector {
        from: ConnectionPoint {
            actor: "filter1".to_owned(),
            port: "Passed".to_owned(),
            ..Default::default()
        },
        to: ConnectionPoint {
            actor: "aggregator1".to_owned(),
            port: "Value".to_owned(),
            ..Default::default()
        },
    });
    for i in 1..=5 {
        network.add_initial(InitialPacket {
            to: ConnectionPoint {
                actor: "transform1".to_owned(),
                port: "Input".to_owned(),
                initial_data: Some(Message::Integer(i)),
            },
        });
    }
    network.start()?;
    tokio::time::sleep(Duration::from_secs(2)).await;

    // Outports are looked up by registered process name, as the other tests in this file do.
    let aggregator_outport = network
        .actors
        .get("aggregator")
        .unwrap()
        .get_outports()
        .1;
    // Track the largest Sum/Count seen rather than the last packet. This keeps the check
    // independent of packet order and of whether Sum and Count arrive in the same packet.
    let mut max_sum: Option<i64> = None;
    let mut max_count: Option<i64> = None;
    while let Ok(pkt) = aggregator_outport.try_recv() {
        if let Some(Message::Integer(n)) = pkt.get("Sum") {
            max_sum = max_sum.max(Some(*n));
        }
        if let Some(Message::Integer(n)) = pkt.get("Count") {
            max_count = max_count.max(Some(*n));
        }
    }
    // Shut down before asserting so a failed assertion does not skip cleanup.
    network.shutdown();

    assert_eq!(
        max_count,
        Some(5),
        "aggregator should have counted all 5 filtered values"
    );
    assert_eq!(
        max_sum,
        Some(25),
        "aggregator sum should be (1+2+3+4+5) + (0+1+2+3+4) = 25"
    );
    Ok(())
}
// Pass-through actor: re-emits whatever arrives on `In` unchanged on `Out`.
// A missing `In` entry panics with the `expect` message.
#[actor(
    ForwardActor,
    inports::<100>(In),
    outports::<100>(Out)
)]
async fn forward_actor(context: ActorContext) -> Result<HashMap<String, Message>, anyhow::Error> {
    let payload = context.get_payload();
    let forwarded = payload
        .get("In")
        .cloned()
        .expect("expected input on In port");
    Ok(HashMap::from([("Out".to_owned(), forwarded)]))
}
#[tokio::test]
async fn test_fanout_broadcast() -> Result<(), anyhow::Error> {
let mut network = Network::new(NetworkConfig::default());
network.register_actor("transform", TransformActor::new())?;
network.register_actor("forward_a", ForwardActor::new())?;
network.register_actor("forward_b", ForwardActor::new())?;
network.add_node("source", "transform", None)?;
network.add_node("sink_a", "forward_a", None)?;
network.add_node("sink_b", "forward_b", None)?;
network.add_connection(Connector {
from: ConnectionPoint {
actor: "source".to_owned(),
port: "Output".to_owned(),
..Default::default()
},
to: ConnectionPoint {
actor: "sink_a".to_owned(),
port: "In".to_owned(),
..Default::default()
},
});
network.add_connection(Connector {
from: ConnectionPoint {
actor: "source".to_owned(),
port: "Output".to_owned(),
..Default::default()
},
to: ConnectionPoint {
actor: "sink_b".to_owned(),
port: "In".to_owned(),
..Default::default()
},
});
for i in 1..=3 {
network.add_initial(InitialPacket {
to: ConnectionPoint {
actor: "source".to_owned(),
port: "Input".to_owned(),
initial_data: Some(Message::Integer(i)),
},
});
}
network.start()?;
tokio::time::sleep(Duration::from_secs(2)).await;
let forward_a_outport = network.actors.get("forward_a").unwrap().get_outports().1;
let forward_b_outport = network.actors.get("forward_b").unwrap().get_outports().1;
let mut a_messages = Vec::new();
while let Ok(pkt) = forward_a_outport.try_recv() {
if let Some(msg) = pkt.get("Out") {
a_messages.push(msg.clone());
}
}
let mut b_messages = Vec::new();
while let Ok(pkt) = forward_b_outport.try_recv() {
if let Some(msg) = pkt.get("Out") {
b_messages.push(msg.clone());
}
}
println!(
"forward_a received {} messages: {:?}",
a_messages.len(),
a_messages
);
println!(
"forward_b received {} messages: {:?}",
b_messages.len(),
b_messages
);
assert_eq!(
a_messages.len(),
3,
"forward_a should have received all 3 messages (got {})",
a_messages.len()
);
assert_eq!(
b_messages.len(),
3,
"forward_b should have received all 3 messages (got {})",
b_messages.len()
);
network.shutdown();
Ok(())
}
// Emits a short octet stream on `Out`: a Begin frame, three 7-byte Data chunks (21 bytes in
// total), then End. It then returns the stream handle as the `Out` message. The Trigger
// payload is never read.
//
// Every frame is sent before the handle is returned, so no consumer can be draining the channel
// yet. NOTE(review): the fourth `create_stream` argument (`Some(8)`) looks like the channel
// buffer size. If so, it must stay >= 5 (the number of frames sent here), or the sends would
// block forever — confirm. The 300-byte size hint also overstates the 21 bytes actually sent;
// presumably harmless since it is only a hint — confirm.
//
// Errors: a failed send (presumably the receiving side was dropped) is returned as an actor
// error instead of panicking inside the actor task.
#[actor(
    StreamProducerActor,
    inports::<10>(Trigger),
    outports::<10>(Out)
)]
async fn stream_producer_actor(
    context: ActorContext,
) -> Result<HashMap<String, Message>, anyhow::Error> {
    use reflow_actor::stream::StreamFrame;
    let (tx, handle) = context.create_stream(
        "Out",
        Some("application/octet-stream".into()),
        Some(300),
        Some(8),
    );
    tx.send_async(StreamFrame::Begin {
        content_type: Some("application/octet-stream".into()),
        size_hint: Some(300),
        metadata: None,
    })
    .await
    .map_err(|_| anyhow::anyhow!("failed to send Begin frame on Out stream"))?;
    for chunk in [
        b"chunk-1".to_vec(),
        b"chunk-2".to_vec(),
        b"chunk-3".to_vec(),
    ] {
        tx.send_async(StreamFrame::Data(Arc::new(chunk)))
            .await
            .map_err(|_| anyhow::anyhow!("failed to send Data frame on Out stream"))?;
    }
    tx.send_async(StreamFrame::End)
        .await
        .map_err(|_| anyhow::anyhow!("failed to send End frame on Out stream"))?;
    Ok([("Out".to_owned(), Message::stream_handle(handle))].into())
}
#[actor(
StreamConsumerActor,
inports::<10>(In),
outports::<10>(ByteCount)
)]
async fn stream_consumer_actor(
context: ActorContext,
) -> Result<HashMap<String, Message>, anyhow::Error> {
use reflow_actor::stream::StreamFrame;
let rx = context
.take_stream_receiver("In")
.expect("expected stream receiver on In port");
let mut total_bytes: usize = 0;
let mut got_begin = false;
let mut got_end = false;
loop {
match rx.recv_async().await {
Ok(StreamFrame::Begin { .. }) => got_begin = true,
Ok(StreamFrame::Data(data)) => total_bytes += data.len(),
Ok(StreamFrame::End) => {
got_end = true;
break;
}
Ok(StreamFrame::Error(e)) => return Err(anyhow::anyhow!("Stream error: {}", e)),
Err(_) => break,
}
}
assert!(got_begin, "consumer should have received Begin frame");
assert!(got_end, "consumer should have received End frame");
Ok([("ByteCount".to_owned(), Message::Integer(total_bytes as i64))].into())
}
#[tokio::test]
async fn test_network_actor_streaming() -> Result<(), anyhow::Error> {
let mut network = Network::new(NetworkConfig::default());
network.register_actor("stream_producer", StreamProducerActor::new())?;
network.register_actor("stream_consumer", StreamConsumerActor::new())?;
network.add_node("producer", "stream_producer", None)?;
network.add_node("consumer", "stream_consumer", None)?;
network.add_connection(Connector {
from: ConnectionPoint {
actor: "producer".to_owned(),
port: "Out".to_owned(),
..Default::default()
},
to: ConnectionPoint {
actor: "consumer".to_owned(),
port: "In".to_owned(),
..Default::default()
},
});
network.add_initial(InitialPacket {
to: ConnectionPoint {
actor: "producer".to_owned(),
port: "Trigger".to_owned(),
initial_data: Some(Message::Flow),
},
});
network.start()?;
tokio::time::sleep(Duration::from_secs(2)).await;
let consumer_outport = network
.actors
.get("stream_consumer")
.unwrap()
.get_outports()
.1;
let mut byte_counts = Vec::new();
while let Ok(pkt) = consumer_outport.try_recv() {
if let Some(Message::Integer(n)) = pkt.get("ByteCount") {
byte_counts.push(*n);
}
}
assert_eq!(
byte_counts,
vec![21],
"consumer should have received 21 total bytes from the stream"
);
network.shutdown();
Ok(())
}
// Lenient stream consumer: totals the Data-frame bytes of the stream on `In` and emits the total
// on `ByteCount`.
//
// If no stream receiver can be taken from `In`, it emits -1 instead of panicking. Presumably
// this is the case when another consumer has already claimed the stream; the single-consumer
// fan-out test expects exactly that. The loop stops at End or on a receive error; Begin and
// Error frames are ignored.
#[actor(
    StreamConsumerGracefulActor,
    inports::<10>(In),
    outports::<10>(ByteCount)
)]
async fn stream_consumer_graceful_actor(
    context: ActorContext,
) -> Result<HashMap<String, Message>, anyhow::Error> {
    use reflow_actor::stream::StreamFrame;

    let rx = match context.take_stream_receiver("In") {
        Some(rx) => rx,
        None => return Ok(HashMap::from([("ByteCount".to_owned(), Message::Integer(-1))])),
    };

    let mut byte_total = 0usize;
    while let Ok(frame) = rx.recv_async().await {
        match frame {
            StreamFrame::Data(chunk) => byte_total += chunk.len(),
            StreamFrame::End => break,
            _ => {}
        }
    }
    Ok(HashMap::from([(
        "ByteCount".to_owned(),
        Message::Integer(byte_total as i64),
    )]))
}
#[tokio::test]
async fn test_network_stream_fanout_single_consumer() -> Result<(), anyhow::Error> {
let mut network = Network::new(NetworkConfig::default());
network.register_actor("stream_producer", StreamProducerActor::new())?;
network.register_actor("consumer_a", StreamConsumerGracefulActor::new())?;
network.register_actor("consumer_b", StreamConsumerGracefulActor::new())?;
network.add_node("producer", "stream_producer", None)?;
network.add_node("sink_a", "consumer_a", None)?;
network.add_node("sink_b", "consumer_b", None)?;
network.add_connection(Connector {
from: ConnectionPoint {
actor: "producer".to_owned(),
port: "Out".to_owned(),
..Default::default()
},
to: ConnectionPoint {
actor: "sink_a".to_owned(),
port: "In".to_owned(),
..Default::default()
},
});
network.add_connection(Connector {
from: ConnectionPoint {
actor: "producer".to_owned(),
port: "Out".to_owned(),
..Default::default()
},
to: ConnectionPoint {
actor: "sink_b".to_owned(),
port: "In".to_owned(),
..Default::default()
},
});
network.add_initial(InitialPacket {
to: ConnectionPoint {
actor: "producer".to_owned(),
port: "Trigger".to_owned(),
initial_data: Some(Message::Flow),
},
});
network.start()?;
tokio::time::sleep(Duration::from_secs(2)).await;
let mut results: Vec<i64> = Vec::new();
for name in ["consumer_a", "consumer_b"] {
let outport = network.actors.get(name).unwrap().get_outports().1;
while let Ok(pkt) = outport.try_recv() {
if let Some(Message::Integer(n)) = pkt.get("ByteCount") {
results.push(*n);
}
}
}
results.sort();
assert_eq!(
results,
vec![-1, 21],
"One consumer should get 21 bytes, the other -1 (single-consumer stream)"
);
network.shutdown();
Ok(())
}