use std::collections::HashMap;
use std::sync::Arc;
use parking_lot::Mutex;
use reflow_tracing_protocol::client::TracingIntegration;
use serde_json::Value;
use crate::message::{EncodableValue, Message};
use crate::{ActorBehavior, ActorConfig, ActorContext, ActorLoad, ActorState, Port};
/// A single running actor: the receive loop that feeds inport packets into the actor's
/// behavior and forwards the behavior's output.
///
/// Constructed with [`ActorProcess::new`] and consumed by [`ActorProcess::run`] /
/// `ActorProcess::into_future`.
pub struct ActorProcess {
    /// Node identifier used as the prefix of this process's stderr diagnostics.
    node_id: String,
    /// The actor's behavior, invoked once per delivered payload with a fresh `ActorContext`.
    behavior: ActorBehavior,
    /// Declared inport names; their count is the await-all threshold when no per-port
    /// connection counts are configured.
    inport_names: Vec<String>,
    /// When set, packets are accumulated until one packet per expected connection has arrived.
    await_all_inports: bool,
    /// Inports that must all have received data before the behavior fires (ignored when
    /// `await_all_inports` is set; an empty list disables gating).
    required_inports: Vec<String>,
    /// Incoming packets, each mapping inport name to message; the loop ends when it closes.
    inport_rx: flume::Receiver<HashMap<String, Message>>,
    /// Outport handle: cloned into every context, and `.0` is used to forward behavior results.
    outports: Port,
    /// Actor state shared with every `ActorContext`.
    state: Arc<Mutex<dyn ActorState>>,
    /// Load counter: incremented per received packet, reset after each behavior invocation.
    load: Arc<ActorLoad>,
    /// Actor configuration (node id, per-inport connection counts, ...), cloned into contexts.
    config: ActorConfig,
    /// Optional tracing client notified when a behavior invocation completes or fails.
    tracing: Option<TracingIntegration>,
}
impl ActorProcess {
    /// Assembles an actor process from its already-wired parts.
    ///
    /// Only stores the arguments; no packets are processed until the future returned by
    /// [`ActorProcess::run`] (or `into_future`) is polled.
    // One parameter per field; a builder would add a type with no other purpose.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        node_id: String,
        behavior: ActorBehavior,
        inport_names: Vec<String>,
        await_all_inports: bool,
        required_inports: Vec<String>,
        inport_rx: flume::Receiver<HashMap<String, Message>>,
        outports: Port,
        state: Arc<Mutex<dyn ActorState>>,
        load: Arc<ActorLoad>,
        config: ActorConfig,
        tracing: Option<TracingIntegration>,
    ) -> Self {
        Self {
            node_id,
            behavior,
            inport_names,
            await_all_inports,
            required_inports,
            inport_rx,
            outports,
            state,
            load,
            config,
            tracing,
        }
    }

    /// Receives packets from the inport channel and invokes the behavior until the channel
    /// is closed (every sender dropped and the queue drained).
    ///
    /// Packets become behavior invocations in one of three modes:
    ///
    /// * `await_all_inports`: packets are merged with `merge_accumulate` until as many
    ///   packets have arrived as the sum of the configured inport connection counts (or, if
    ///   that sum is zero, the number of inport names). The accumulated map is then delivered
    ///   and cleared.
    /// * non-empty `required_inports`: packets are merged and counted per port. The behavior
    ///   fires once every required port has received at least its configured connection
    ///   count (1 when none is configured). Required ports are then cleared; other ports stay
    ///   accumulated and are delivered again with later firings.
    /// * otherwise: every packet is delivered to the behavior unchanged.
    ///
    /// The load counter is incremented per received packet and reset after each invocation.
    /// A non-empty behavior result is sent on the outport; a failed send or a behavior error
    /// is reported on stderr. Tracing notifications are best-effort and their errors ignored.
    pub async fn run(self) {
        let mut accumulated: HashMap<String, Message> = HashMap::new();
        let actor_id = self.config.get_node_id();

        // Loop-invariant: how many packets the await-all mode waits for before firing.
        let total_connections: usize = self.config.inport_connection_counts.values().sum();
        let await_all_threshold = if total_connections > 0 {
            total_connections
        } else {
            self.inport_names.len()
        };

        // Packets seen since the last await-all firing.
        let mut tick_message_count: usize = 0;
        // Packets seen per port since that port was last consumed (required-inports mode).
        let mut port_counts: HashMap<String, usize> = HashMap::new();

        loop {
            // `recv_async` errors only once all senders are gone and the queue is empty.
            let packet = match self.inport_rx.recv_async().await {
                Ok(p) => p,
                Err(_) => {
                    eprintln!("[INPORT CLOSED] {}", self.node_id);
                    break;
                }
            };
            self.load.inc();

            let payload = if self.await_all_inports {
                merge_accumulate(&mut accumulated, packet);
                tick_message_count += 1;
                if tick_message_count < await_all_threshold {
                    continue;
                }
                tick_message_count = 0;
                std::mem::take(&mut accumulated)
            } else if !self.required_inports.is_empty() {
                // Count before merging so the packet can be moved instead of cloned.
                for port in packet.keys() {
                    *port_counts.entry(port.clone()).or_insert(0) += 1;
                }
                merge_accumulate(&mut accumulated, packet);
                if !self.required_inports_satisfied(&port_counts) {
                    continue;
                }
                // Required ports are consumed by this firing; optional ports are kept so they
                // accompany later firings too.
                let payload = accumulated.clone();
                for req in &self.required_inports {
                    accumulated.remove(req);
                    port_counts.remove(req);
                }
                payload
            } else {
                packet
            };

            let context = ActorContext::new(
                payload,
                self.outports.clone(),
                self.state.clone(),
                self.config.clone(),
                self.load.clone(),
            );

            match (self.behavior)(context).await {
                Ok(result) => {
                    if !result.is_empty() && self.outports.0.send_async(result).await.is_err() {
                        // The output is dropped either way; make the loss visible.
                        eprintln!("[{}] outport send failed; output dropped", self.node_id);
                    }
                    self.load.reset();
                    if let Some(ref tracing) = self.tracing {
                        let _ = tracing.trace_actor_completed(actor_id).await;
                    }
                }
                Err(e) => {
                    self.load.reset();
                    eprintln!("[{}] behavior error: {:?}", self.node_id, e);
                    if let Some(ref tracing) = self.tracing {
                        let _ = tracing.trace_actor_failed(actor_id, e.to_string()).await;
                    }
                }
            }
        }
    }

    /// Returns `true` once every required inport has received at least as many packets as
    /// its configured connection count (1 when no count is configured for it).
    fn required_inports_satisfied(&self, port_counts: &HashMap<String, usize>) -> bool {
        self.required_inports.iter().all(|req| {
            let needed = self
                .config
                .inport_connection_counts
                .get(req)
                .copied()
                .unwrap_or(1);
            port_counts.get(req).copied().unwrap_or(0) >= needed
        })
    }

    /// Boxes [`ActorProcess::run`] as a `Send` future for multi-threaded executors.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn into_future(
        self,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + 'static>> {
        Box::pin(self.run())
    }

    /// Boxes [`ActorProcess::run`] as a future; on wasm32 no `Send` bound is required.
    #[cfg(target_arch = "wasm32")]
    pub fn into_future(
        self,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + 'static>> {
        Box::pin(self.run())
    }
}
fn merge_accumulate(accumulated: &mut HashMap<String, Message>, packet: HashMap<String, Message>) {
for (port, msg) in packet {
match accumulated.get(&port) {
Some(Message::Object(existing_obj)) => {
if let Message::Object(new_obj) = &msg {
let mut merged: Value = existing_obj.as_ref().clone().into();
let new_v: Value = new_obj.as_ref().clone().into();
if let (Some(m), Some(n)) = (merged.as_object_mut(), new_v.as_object()) {
for (k, v) in n {
m.insert(k.clone(), v.clone());
}
}
accumulated.insert(
port,
Message::Object(Arc::new(EncodableValue::from(merged))),
);
} else {
accumulated.insert(port, msg);
}
}
_ => {
accumulated.insert(port, msg);
}
}
}
}