use crate::channels::{ChannelItem, TypeErasedReceiver, TypeErasedSender};
use crate::execution::ExecutionError;
use crate::http_server::nodes::producer::HttpRequestProducerConfig;
use crate::http_server::types::HttpServerRequest;
use crate::traits::{NodeKind, NodeTrait};
use axum::extract::Request;
use std::any::Any;
use std::sync::Arc;
use tokio::sync::{Mutex, mpsc};
use tracing::{debug, error, info, warn};
/// Configuration for [`HttpServerProducerNode`].
///
/// Wraps the HTTP request producer settings in a single field.
#[derive(Clone, Debug, Default)]
pub struct HttpServerProducerConfig {
    /// HTTP request producer settings (`HttpRequestProducerConfig`).
    pub http_config: HttpRequestProducerConfig,
}
/// Producer node whose execution task turns incoming HTTP requests into graph messages
/// sent on its output port(s).
#[derive(Debug)]
pub struct HttpServerProducerNode {
    /// Node name; used by `NodeTrait::name` and in log fields.
    name: String,
    /// Node configuration, exposed through `config` / `config_mut`.
    config: HttpServerProducerConfig,
    /// Slot for the request receiver. It is filled by `set_request_receiver` and emptied
    /// (taken) by the execution task. The `Arc` is shared with clones of the node.
    request_receiver: Arc<Mutex<Option<mpsc::Receiver<Request>>>>,
    /// Output port names; `new` initializes this to `["out"]`.
    output_port_names: Vec<String>,
}
impl HttpServerProducerNode {
    /// Creates a producer node called `name` with the given `config`.
    ///
    /// The node has a single output port, `"out"`, and starts with no request receiver.
    /// `set_request_receiver` must be called before the execution task is spawned; otherwise
    /// that task fails immediately.
    pub fn new(name: String, config: HttpServerProducerConfig) -> Self {
        debug!(node = %name, "HttpServerProducerNode::new()");
        let output_port_names = vec![String::from("out")];
        let request_receiver = Arc::new(Mutex::new(None));
        Self {
            name,
            config,
            request_receiver,
            output_port_names,
        }
    }

    /// Creates a producer node called `name` using `HttpServerProducerConfig::default()`.
    pub fn with_default_config(name: String) -> Self {
        debug!(node = %name, "HttpServerProducerNode::with_default_config()");
        let config = HttpServerProducerConfig::default();
        Self::new(name, config)
    }

    /// Installs the channel the execution task will read requests from.
    ///
    /// Any receiver already in the slot is replaced and dropped. The slot is shared with
    /// every clone of this node.
    pub(crate) async fn set_request_receiver(&self, receiver: mpsc::Receiver<Request>) {
        debug!(node = %self.name, "HttpServerProducerNode::set_request_receiver()");
        *self.request_receiver.lock().await = Some(receiver);
    }

    /// Returns the node configuration.
    #[must_use]
    pub fn config(&self) -> &HttpServerProducerConfig {
        &self.config
    }

    /// Returns the node configuration for in-place modification.
    pub fn config_mut(&mut self) -> &mut HttpServerProducerConfig {
        &mut self.config
    }
}
impl Clone for HttpServerProducerNode {
    /// Returns a copy of the node that shares the request-receiver slot with `self`.
    ///
    /// Only the `Arc` around the slot is cloned, not the receiver. A receiver installed or
    /// taken through one clone is therefore seen by all of them.
    fn clone(&self) -> Self {
        let Self {
            name,
            config,
            request_receiver,
            output_port_names,
        } = self;
        Self {
            name: name.clone(),
            config: config.clone(),
            request_receiver: Arc::clone(request_receiver),
            output_port_names: output_port_names.clone(),
        }
    }
}
impl NodeTrait for HttpServerProducerNode {
const INPUT_PORTS: &'static [&'static str] = &[];
fn name(&self) -> &str {
&self.name
}
fn node_kind(&self) -> NodeKind {
NodeKind::Producer
}
fn input_port_names(&self) -> Vec<String> {
vec![] }
fn output_port_names(&self) -> Vec<String> {
self.output_port_names.clone()
}
fn has_input_port(&self, _port_name: &str) -> bool {
false }
fn has_output_port(&self, port_name: &str) -> bool {
self.output_port_names.iter().any(|name| name == port_name)
}
fn spawn_execution_task(
&self,
_input_channels: std::collections::HashMap<String, TypeErasedReceiver>,
output_channels: std::collections::HashMap<String, TypeErasedSender>,
pause_signal: std::sync::Arc<tokio::sync::RwLock<bool>>,
_use_shared_memory: bool,
_arc_pool: Option<std::sync::Arc<crate::zero_copy::ArcPool<bytes::Bytes>>>,
) -> Option<tokio::task::JoinHandle<Result<(), ExecutionError>>> {
let node_name = self.name.clone();
debug!(
node = %node_name,
output_channels = output_channels.len(),
"HttpServerProducerNode::spawn_execution_task()"
);
let request_receiver = Arc::clone(&self.request_receiver);
let handle = tokio::spawn(async move {
let mut receiver_guard = request_receiver.lock().await;
let mut receiver = match receiver_guard.take() {
Some(rx) => rx,
None => {
error!(
node = %node_name,
"Request receiver not set - GraphServer must call set_request_receiver()"
);
return Err(ExecutionError::NodeExecutionFailed {
node: node_name.clone(),
reason: "Request receiver not set".to_string(),
message_id: None,
});
}
};
loop {
let request_result =
tokio::time::timeout(tokio::time::Duration::from_millis(100), receiver.recv()).await;
let axum_request = match request_result {
Ok(Some(req)) => req,
Ok(None) => {
info!(
node = %node_name,
"Request channel closed, terminating producer"
);
break;
}
Err(_) => {
let paused = *pause_signal.read().await;
if paused {
return Ok(());
}
continue;
}
};
let pause_check_result =
tokio::time::timeout(tokio::time::Duration::from_millis(100), async {
while *pause_signal.read().await {
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
}
})
.await;
if pause_check_result.is_err() && *pause_signal.read().await {
return Ok(());
}
let http_request = HttpServerRequest::from_axum_request(axum_request).await;
tracing::info!(
node = %node_name,
request_id = %http_request.request_id,
method = ?http_request.method,
path = %http_request.path,
"Received HTTP request"
);
let request_id_clone = http_request.request_id.clone();
let message_id = crate::message::MessageId::new_custom(http_request.request_id.clone());
let message = crate::message::Message::new(http_request, message_id);
let message_arc = Arc::new(message);
let is_fan_out = output_channels.len() > 1;
let mut send_succeeded = 0;
let mut send_failed_count = 0;
if output_channels.is_empty() {
warn!(
node = %node_name,
request_id = %request_id_clone,
"No output channels available - message will not be sent"
);
continue;
}
tracing::info!(
node = %node_name,
request_id = %request_id_clone,
output_channels = output_channels.len(),
"Sending message to output channels"
);
if is_fan_out {
for (port_name, sender) in &output_channels {
let arc_any: Arc<dyn Any + Send + Sync> = unsafe {
Arc::from_raw(Arc::into_raw(message_arc.clone()) as *const (dyn Any + Send + Sync))
};
match sender.send(ChannelItem::Arc(arc_any)).await {
Ok(()) => {
send_succeeded += 1;
tracing::info!(
node = %node_name,
request_id = %request_id_clone,
port = %port_name,
"Message sent successfully (fan-out)"
);
}
Err(e) => {
tracing::error!(
node = %node_name,
request_id = %request_id_clone,
port = %port_name,
error = %e,
"Failed to send message"
);
let paused = *pause_signal.read().await;
if paused {
return Ok(());
}
send_failed_count += 1;
warn!(
node = %node_name,
port = %port_name,
"Output channel receiver dropped (may be normal in fan-out scenarios)"
);
}
}
}
} else {
let (port_name, sender) = output_channels.iter().next().unwrap();
let arc_any: Arc<dyn Any + Send + Sync> =
unsafe { Arc::from_raw(Arc::into_raw(message_arc) as *const (dyn Any + Send + Sync)) };
match sender.send(ChannelItem::Arc(arc_any)).await {
Ok(()) => {
send_succeeded += 1;
tracing::info!(
node = %node_name,
request_id = %request_id_clone,
port = %port_name,
"Message sent successfully (single output)"
);
}
Err(e) => {
tracing::error!(
node = %node_name,
request_id = %request_id_clone,
port = %port_name,
error = %e,
"Failed to send message"
);
let paused = *pause_signal.read().await;
if paused {
return Ok(());
}
send_failed_count += 1;
warn!(
node = %node_name,
port = port_name,
"Output channel receiver dropped"
);
}
}
}
if send_failed_count > 0 && send_succeeded == 0 {
return Ok(());
}
}
Ok(())
});
Some(handle)
}
}