use crate::channels::{ChannelItem, TypeErasedReceiver, TypeErasedSender};
use crate::execution::ExecutionError;
use crate::http_server::types::{HttpServerRequest, HttpServerResponse};
use crate::message::Message;
use crate::traits::{NodeKind, NodeTrait};
use axum::body::Body;
use axum::response::Response;
use futures::StreamExt;
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::{RwLock, mpsc};
use tracing::{debug, warn};
/// Tunable settings for [`HttpServerConsumerNode`].
#[derive(Debug, Clone)]
pub struct HttpServerConsumerConfig {
    /// Time budget for a single request (30 seconds in the `Default` config).
    /// NOTE(review): not read anywhere in this file — presumably enforced by
    /// whoever registers requests with the node; confirm.
    pub request_timeout: std::time::Duration,
    /// Default HTTP status code (`200 OK` in the `Default` config).
    pub default_status: axum::http::StatusCode,
}
impl Default for HttpServerConsumerConfig {
    /// A 30-second request timeout and a `200 OK` default status.
    fn default() -> Self {
        let request_timeout = std::time::Duration::from_secs(30);
        let default_status = axum::http::StatusCode::OK;
        Self {
            request_timeout,
            default_status,
        }
    }
}
/// Sink node that routes HTTP response messages back to waiting clients.
///
/// Clients register a channel per request id (`register_request`); each
/// response message is delivered on the channel whose id matches its
/// `request_id`.
#[derive(Debug)]
pub struct HttpServerConsumerNode {
    /// Node name, returned by `name()` and attached to log records.
    name: String,
    /// Node configuration.
    config: HttpServerConsumerConfig,
    /// Pending requests: request id -> channel for that client's response.
    /// Held behind an `Arc` so every clone of the node shares one registry.
    response_senders: Arc<RwLock<HashMap<String, mpsc::Sender<Response<Body>>>>>,
    /// Input port names (`["in"]` for nodes built through `new`).
    input_port_names: Vec<String>,
}
impl HttpServerConsumerNode {
pub fn new(name: String, config: HttpServerConsumerConfig) -> Self {
Self {
name,
config,
response_senders: Arc::new(RwLock::new(HashMap::new())),
input_port_names: vec!["in".to_string()],
}
}
pub fn with_default_config(name: String) -> Self {
debug!(node = %name, "HttpServerConsumerNode::with_default_config()");
Self::new(name, HttpServerConsumerConfig::default())
}
pub(crate) async fn register_request(
&self,
request_id: String,
sender: mpsc::Sender<Response<Body>>,
) {
debug!(
node = %self.name,
request_id = %request_id,
"HttpServerConsumerNode::register_request()"
);
let mut senders = self.response_senders.write().await;
senders.insert(request_id, sender);
}
pub(crate) async fn unregister_request(&self, request_id: &str) {
debug!(
node = %self.name,
request_id = %request_id,
"HttpServerConsumerNode::unregister_request()"
);
let mut senders = self.response_senders.write().await;
senders.remove(request_id);
}
#[must_use]
pub fn config(&self) -> &HttpServerConsumerConfig {
&self.config
}
pub fn config_mut(&mut self) -> &mut HttpServerConsumerConfig {
&mut self.config
}
fn http_response_to_axum(response: &HttpServerResponse) -> Response<Body> {
let mut axum_response = Response::builder()
.status(response.status)
.body(Body::from(response.body.clone()))
.unwrap_or_else(|_| {
Response::builder()
.status(axum::http::StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::empty())
.unwrap()
});
for (key, value) in response.headers.iter() {
if let Ok(header_name) = axum::http::HeaderName::from_bytes(key.as_str().as_bytes()) {
axum_response
.headers_mut()
.insert(header_name, value.clone());
}
}
axum_response
}
}
impl Clone for HttpServerConsumerNode {
    /// Clones the node while sharing its response-sender registry.
    ///
    /// Because the registry is shared through `Arc::clone`, a response handled
    /// by any clone can reach a request registered through any other clone.
    fn clone(&self) -> Self {
        // Exhaustive destructuring: adding a field forces this impl to be revisited.
        let Self {
            name,
            config,
            response_senders,
            input_port_names,
        } = self;
        Self {
            name: name.clone(),
            config: config.clone(),
            response_senders: Arc::clone(response_senders),
            input_port_names: input_port_names.clone(),
        }
    }
}
impl NodeTrait for HttpServerConsumerNode {
const INPUT_PORTS: &'static [&'static str] = &["in"];
fn name(&self) -> &str {
&self.name
}
fn node_kind(&self) -> NodeKind {
NodeKind::Consumer
}
fn input_port_names(&self) -> Vec<String> {
self.input_port_names.clone()
}
fn output_port_names(&self) -> Vec<String> {
vec![] }
fn has_input_port(&self, port_name: &str) -> bool {
self.input_port_names.iter().any(|name| name == port_name)
}
fn has_output_port(&self, _port_name: &str) -> bool {
false }
fn spawn_execution_task(
&self,
input_channels: std::collections::HashMap<String, TypeErasedReceiver>,
_output_channels: std::collections::HashMap<String, TypeErasedSender>,
pause_signal: std::sync::Arc<tokio::sync::RwLock<bool>>,
_use_shared_memory: bool,
_arc_pool: Option<std::sync::Arc<crate::zero_copy::ArcPool<bytes::Bytes>>>,
) -> Option<tokio::task::JoinHandle<Result<(), ExecutionError>>> {
let node_name = self.name.clone();
debug!(
node = %node_name,
input_channels = input_channels.len(),
"HttpServerConsumerNode::spawn_execution_task()"
);
let response_senders = Arc::clone(&self.response_senders);
let _default_status = self.config.default_status;
let handle = tokio::spawn(async move {
let receivers: Vec<(String, TypeErasedReceiver)> = input_channels.into_iter().collect();
let mut merged_stream: Pin<
Box<dyn futures::Stream<Item = Message<HttpServerResponse>> + Send>,
> = if receivers.is_empty() {
Box::pin(futures::stream::empty())
} else if receivers.len() == 1 {
let (_port_name, mut receiver) = receivers.into_iter().next().unwrap();
let node_name_clone = node_name.clone();
Box::pin(async_stream::stream! {
while let Some(channel_item) = receiver.recv().await {
debug!(
node = %node_name_clone,
"HttpServerConsumerNode: Received channel item"
);
match channel_item {
ChannelItem::Arc(arc) => {
let channel_item = ChannelItem::Arc(arc.clone());
match channel_item.downcast_message_arc::<HttpServerResponse>() {
Ok(msg_arc) => {
let message = (*msg_arc).clone();
debug!(
node = %node_name_clone,
request_id = %message.id(),
"HttpServerConsumerNode: Received HTTP response"
);
yield message;
}
Err(_) => {
let channel_item = ChannelItem::Arc(arc.clone());
match channel_item.downcast_message_arc::<HttpServerRequest>() {
Ok(msg_arc) => {
let request_message = (*msg_arc).clone();
let request = request_message.payload();
debug!(
node = %node_name_clone,
request_id = %request.request_id,
"HttpServerConsumerNode: Received HTTP request, converting to response"
);
let response_text = format!("Hello from StreamWeave! You requested: {}", request.path);
let response = HttpServerResponse::text_with_request_id(
axum::http::StatusCode::OK,
&response_text,
request.request_id.clone(),
);
let response_message = Message::with_metadata(
response,
request_message.id().clone(),
request_message.metadata().clone(),
);
yield response_message;
}
Err(_) => {
warn!(
node = %node_name_clone,
"Failed to downcast Arc to Message<HttpServerResponse> or Message<HttpServerRequest>, skipping"
);
continue;
}
}
}
}
}
ChannelItem::SharedMemory(_) => {
warn!(
node = %node_name_clone,
"SharedMemory items not yet supported in HttpServerConsumerNode"
);
continue;
}
}
}
})
} else {
let node_name_clone = node_name.clone();
let streams: Vec<_> = receivers
.into_iter()
.map(move |(_port_name, mut receiver)| {
let node_name_inner = node_name_clone.clone();
Box::pin(async_stream::stream! {
while let Some(channel_item) = receiver.recv().await {
match channel_item {
ChannelItem::Arc(arc) => {
let channel_item = ChannelItem::Arc(arc.clone());
match channel_item.downcast_message_arc::<HttpServerResponse>() {
Ok(msg_arc) => {
let message = (*msg_arc).clone();
yield message;
}
Err(_) => {
let channel_item = ChannelItem::Arc(arc.clone());
match channel_item.downcast_message_arc::<HttpServerRequest>() {
Ok(msg_arc) => {
let request_message = (*msg_arc).clone();
let request = request_message.payload();
let response_text = format!("Hello from StreamWeave! You requested: {}", request.path);
let response = HttpServerResponse::text_with_request_id(
axum::http::StatusCode::OK,
&response_text,
request.request_id.clone(),
);
let response_message = Message::with_metadata(
response,
request_message.id().clone(),
request_message.metadata().clone(),
);
yield response_message;
}
Err(_) => {
warn!(
node = %node_name_inner,
"Failed to downcast Arc to Message<HttpServerResponse> or Message<HttpServerRequest>, skipping"
);
continue;
}
}
}
}
}
ChannelItem::SharedMemory(_) => {
warn!(
node = %node_name_inner,
"SharedMemory items not yet supported in HttpServerConsumerNode"
);
continue;
}
}
}
})
as Pin<Box<dyn futures::Stream<Item = Message<HttpServerResponse>> + Send>>
})
.collect();
Box::pin(futures::stream::select_all(streams))
};
loop {
let message_result = tokio::time::timeout(
tokio::time::Duration::from_millis(100),
merged_stream.next(),
)
.await;
let message = match message_result {
Ok(Some(msg)) => msg,
Ok(None) => {
break;
}
Err(_) => {
let paused = *pause_signal.read().await;
if paused {
return Ok(());
}
continue;
}
};
let pause_check_result =
tokio::time::timeout(tokio::time::Duration::from_millis(100), async {
while *pause_signal.read().await {
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
}
})
.await;
if pause_check_result.is_err() && *pause_signal.read().await {
return Ok(());
}
let response = message.payload();
let request_id = response.request_id.clone();
tracing::info!(
node = %node_name,
request_id = %request_id,
status = %response.status,
"Sending HTTP response"
);
let sender = {
let senders = response_senders.read().await;
senders.get(&request_id).cloned()
};
match sender {
Some(sender) => {
let axum_response = Self::http_response_to_axum(response);
if let Err(e) = sender.send(axum_response).await {
warn!(
node = %node_name,
request_id = %request_id,
error = %e,
"Failed to send response to client"
);
} else {
tracing::debug!(
node = %node_name,
request_id = %request_id,
"Response sent to client"
);
}
let mut senders = response_senders.write().await;
senders.remove(&request_id);
}
None => {
warn!(
node = %node_name,
request_id = %request_id,
"No sender found for request ID - request may have timed out"
);
}
}
}
Ok(())
});
Some(handle)
}
}