use crate::Transformer;
use crate::channels::{ChannelItem, TypeErasedReceiver, TypeErasedSender};
use crate::execution::ExecutionError;
use crate::http_server::nodes::path_router_transformer::{
PathBasedRouterTransformer, PathRouterConfig,
};
use crate::http_server::types::HttpServerRequest;
use crate::message::Message;
use crate::traits::{NodeKind, NodeTrait};
use async_stream::stream;
use futures::StreamExt;
use std::any::Any;
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::Mutex;
use tracing::{debug, warn};
/// Configuration for [`HttpPathRouterNode`].
///
/// Wraps the [`PathRouterConfig`] that is cloned into the node's
/// `PathBasedRouterTransformer` when the node is constructed.
#[derive(Clone, Debug, Default)]
pub struct HttpPathRouterConfig {
    /// Router settings handed (by clone) to `PathBasedRouterTransformer::new`.
    pub router_config: PathRouterConfig,
}
/// Graph node that forwards `HttpServerRequest` messages from its input
/// port(s) to the output ports `out_0`..`out_4`.
///
/// The choice of output port is delegated to a [`PathBasedRouterTransformer`]
/// (configured through `add_route` / `set_default_port`).
#[derive(Debug)]
pub struct HttpPathRouterNode {
    /// Node name, used in log fields and returned by `NodeTrait::name`.
    name: String,
    /// Configuration the node was constructed with.
    /// NOTE(review): the transformer is built from a clone of this in `new`;
    /// later edits through `config_mut` or routes added via `add_route` are
    /// not reflected between the two — confirm that is intended.
    config: HttpPathRouterConfig,
    /// Router state shared (via `Arc`) by every clone of this node and by
    /// spawned execution tasks.
    transformer: Arc<Mutex<PathBasedRouterTransformer>>,
    /// Input port names; set to `["in"]` in `new`.
    input_port_names: Vec<String>,
    /// Output port names; set to `out_0`..`out_4` in `new`.
    output_port_names: Vec<String>,
}
impl HttpPathRouterNode {
pub fn new(name: String, config: HttpPathRouterConfig) -> Self {
let transformer = PathBasedRouterTransformer::new(config.router_config.clone());
Self {
name,
config,
transformer: Arc::new(Mutex::new(transformer)),
input_port_names: vec!["in".to_string()],
output_port_names: vec![
"out_0".to_string(),
"out_1".to_string(),
"out_2".to_string(),
"out_3".to_string(),
"out_4".to_string(),
],
}
}
pub fn with_default_config(name: String) -> Self {
debug!(node = %name, "HttpPathRouterNode::with_default_config()");
Self::new(name, HttpPathRouterConfig::default())
}
pub fn add_route(&mut self, pattern: String, port: usize) {
debug!(
node = %self.name,
pattern = %pattern,
port = port,
"HttpPathRouterNode::add_route()"
);
let mut transformer = self.transformer.try_lock().unwrap();
transformer.add_route(pattern, port);
}
pub fn set_default_port(&mut self, port: Option<usize>) {
debug!(
node = %self.name,
default_port = ?port,
"HttpPathRouterNode::set_default_port()"
);
let mut transformer = self.transformer.try_lock().unwrap();
transformer.set_default_port(port);
}
#[must_use]
pub fn config(&self) -> &HttpPathRouterConfig {
&self.config
}
pub fn config_mut(&mut self) -> &mut HttpPathRouterConfig {
&mut self.config
}
}
impl Clone for HttpPathRouterNode {
    /// Copies the node's metadata but shares the transformer: the new node
    /// holds another `Arc` handle to the same `Mutex<PathBasedRouterTransformer>`,
    /// so routes added through any clone are seen by all of them.
    fn clone(&self) -> Self {
        let Self {
            name,
            config,
            transformer,
            input_port_names,
            output_port_names,
        } = self;
        Self {
            transformer: Arc::clone(transformer),
            name: name.clone(),
            config: config.clone(),
            input_port_names: input_port_names.clone(),
            output_port_names: output_port_names.clone(),
        }
    }
}
/// Boxed stream of decoded HTTP requests fed into the path router transformer.
type HttpRequestStream = Pin<Box<dyn futures::Stream<Item = Message<HttpServerRequest>> + Send>>;

/// Turns one type-erased input receiver into a stream of `Message<HttpServerRequest>`.
///
/// `ChannelItem::Arc` items that fail to downcast, and all `SharedMemory`
/// items, are logged at `warn` and skipped. `multi_input` only adds the
/// " (multi-input)" suffix to the log messages so fan-in inputs can be told
/// apart in logs.
fn http_request_stream(
    node_name: String,
    mut receiver: TypeErasedReceiver,
    multi_input: bool,
) -> HttpRequestStream {
    let suffix = if multi_input { " (multi-input)" } else { "" };
    Box::pin(stream! {
        while let Some(channel_item) = receiver.recv().await {
            // Per-item log; the request-level log below carries the useful detail.
            tracing::trace!(
                node = %node_name,
                "Path router received channel item{}",
                suffix
            );
            match channel_item {
                item @ ChannelItem::Arc(_) => {
                    match item.downcast_message_arc::<HttpServerRequest>() {
                        Ok(msg_arc) => {
                            let message = (*msg_arc).clone();
                            let request = message.payload();
                            tracing::info!(
                                node = %node_name,
                                request_id = %request.request_id,
                                path = %request.path,
                                "Path router received HTTP request{}",
                                suffix
                            );
                            yield message;
                        }
                        Err(_) => {
                            warn!(
                                node = %node_name,
                                "Failed to downcast Arc to Message<HttpServerRequest>, skipping"
                            );
                        }
                    }
                }
                ChannelItem::SharedMemory(_) => {
                    warn!(
                        node = %node_name,
                        "SharedMemory items not yet supported in HttpPathRouterNode"
                    );
                }
            }
        }
    })
}

impl NodeTrait for HttpPathRouterNode {
    const INPUT_PORTS: &'static [&'static str] = &["in"];

    fn name(&self) -> &str {
        &self.name
    }

    fn node_kind(&self) -> NodeKind {
        NodeKind::Transformer
    }

    fn input_port_names(&self) -> Vec<String> {
        self.input_port_names.clone()
    }

    fn output_port_names(&self) -> Vec<String> {
        self.output_port_names.clone()
    }

    fn has_input_port(&self, port_name: &str) -> bool {
        self.input_port_names.iter().any(|name| name == port_name)
    }

    fn has_output_port(&self, port_name: &str) -> bool {
        self.output_port_names.iter().any(|name| name == port_name)
    }

    /// Spawns the routing task.
    ///
    /// All input receivers are merged into one request stream and passed to
    /// the shared transformer. Each `Some(message)` in the transformer's
    /// 5-tuple output is sent to the matching `out_N` port if that port is
    /// connected. The task finishes with `Ok(())` when the output stream ends,
    /// or when the pause flag is observed set at one of the checkpoints below.
    /// `_use_shared_memory` and `_arc_pool` are ignored.
    fn spawn_execution_task(
        &self,
        input_channels: HashMap<String, TypeErasedReceiver>,
        output_channels: HashMap<String, TypeErasedSender>,
        pause_signal: Arc<tokio::sync::RwLock<bool>>,
        _use_shared_memory: bool,
        _arc_pool: Option<Arc<crate::zero_copy::ArcPool<bytes::Bytes>>>,
    ) -> Option<tokio::task::JoinHandle<Result<(), ExecutionError>>> {
        /// How long to wait for output (or for a pause to clear) before re-checking.
        const POLL_INTERVAL: tokio::time::Duration = tokio::time::Duration::from_millis(100);
        /// Sleep between pause-flag reads while waiting for a pause to clear.
        const PAUSE_POLL: tokio::time::Duration = tokio::time::Duration::from_millis(10);

        let node_name = self.name.clone();
        debug!(
            node = %node_name,
            input_channels = input_channels.len(),
            output_channels = output_channels.len(),
            "HttpPathRouterNode::spawn_execution_task()"
        );
        let transformer = Arc::clone(&self.transformer);
        let handle = tokio::spawn(async move {
            // Fan every input port into a single request stream; port names are
            // not needed because all inputs carry the same message type.
            let mut receivers: Vec<TypeErasedReceiver> = input_channels.into_values().collect();
            let input_stream: HttpRequestStream = match receivers.len() {
                0 => Box::pin(futures::stream::empty()),
                1 => {
                    let receiver = receivers.pop().expect("length checked to be 1");
                    http_request_stream(node_name.clone(), receiver, false)
                }
                _ => Box::pin(futures::stream::select_all(
                    receivers
                        .into_iter()
                        .map(|receiver| http_request_stream(node_name.clone(), receiver, true)),
                )),
            };

            // Hold the transformer lock only while building the output stream.
            let output_stream = {
                let mut guard = transformer.lock().await;
                guard.transform(input_stream).await
            };
            let mut output_stream = std::pin::pin!(output_stream);

            loop {
                let output_tuple =
                    match tokio::time::timeout(POLL_INTERVAL, output_stream.next()).await {
                        Ok(Some(tuple)) => tuple,
                        Ok(None) => break,
                        // Nothing produced within the interval: stop if paused, else keep waiting.
                        Err(_) => {
                            if *pause_signal.read().await {
                                return Ok(());
                            }
                            continue;
                        }
                    };

                // Give a pause up to POLL_INTERVAL to clear; if it is still set, stop.
                // NOTE(review): `output_tuple` has already been pulled from the stream
                // and is dropped unsent on this path — confirm that is intended.
                let pause_wait = tokio::time::timeout(POLL_INTERVAL, async {
                    while *pause_signal.read().await {
                        tokio::time::sleep(PAUSE_POLL).await;
                    }
                })
                .await;
                if pause_wait.is_err() && *pause_signal.read().await {
                    return Ok(());
                }

                // Move each routed message out of the tuple so it can be wrapped in
                // an `Arc` without cloning it.
                let routed = [
                    ("out_0", output_tuple.0),
                    ("out_1", output_tuple.1),
                    ("out_2", output_tuple.2),
                    ("out_3", output_tuple.3),
                    ("out_4", output_tuple.4),
                ];
                for (port_name, maybe_message) in routed {
                    let Some(message) = maybe_message else {
                        continue;
                    };
                    let request = message.payload();
                    tracing::info!(
                        node = %node_name,
                        request_id = %request.request_id,
                        path = %request.path,
                        output_port = %port_name,
                        "Routing HTTP request"
                    );
                    // Unconnected output ports are skipped silently.
                    let Some(sender) = output_channels.get(port_name) else {
                        continue;
                    };
                    // Safe unsizing coercion to a type-erased Arc; no raw-pointer round trip.
                    let payload: Arc<dyn Any + Send + Sync> = Arc::new(message);
                    if sender.send(ChannelItem::Arc(payload)).await.is_err() {
                        if *pause_signal.read().await {
                            return Ok(());
                        }
                        warn!(
                            node = %node_name,
                            port = %port_name,
                            "Output channel receiver dropped"
                        );
                    }
                }
            }
            Ok(())
        });
        Some(handle)
    }
}