use crate;
use crate::execution::{ExecutionError, GraphExecutor};
use crate::http_server::nodes::consumer_node::HttpServerConsumerNode;
use crate::http_server::nodes::producer_node::HttpServerProducerNode;
use axum::Router;
use axum::body::Body;
use axum::extract::Request;
use axum::http::StatusCode;
use axum::response::Response;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::time::timeout;
use tracing::{debug, error, info, trace, warn};
/// Tunable settings for a [`GraphServer`].
#[derive(Debug, Clone)]
pub struct GraphServerConfig {
    /// Time limit applied by `GraphServer::handle` while a request waits on
    /// the graph for its response.
    pub request_timeout: Duration,
    /// Capacity handed to the bounded channel that carries incoming requests
    /// towards the producer node.
    pub request_channel_buffer: usize,
}
impl Default for GraphServerConfig {
    /// Returns the stock configuration: a 30 second request timeout and a
    /// request channel that buffers up to 100 entries.
    fn default() -> Self {
        let request_timeout = Duration::from_secs(30);
        let request_channel_buffer = 100;
        Self {
            request_timeout,
            request_channel_buffer,
        }
    }
}
/// HTTP front-end that feeds incoming requests into a graph and returns the
/// responses the graph produces.
///
/// Every field is either an `Arc`, a channel sender, or plain configuration,
/// so clones of a `GraphServer` share the same executor, channel and nodes.
pub struct GraphServer {
    /// Executor for the graph; `start` and `stop` take the write lock.
    executor: Arc<tokio::sync::RwLock<GraphExecutor>>,
    /// Sending half of the request channel; the receiving half is handed to
    /// the producer node when the server is built.
    request_sender: mpsc::Sender<Request>,
    /// Handle to the HTTP producer node (`Some` when built through
    /// `from_graph_with_node_names`).
    producer_node: Arc<tokio::sync::RwLock<Option<Arc<HttpServerProducerNode>>>>,
    /// Handle to the HTTP consumer node used to register and unregister
    /// per-request response channels.
    consumer_node: Arc<tokio::sync::RwLock<Option<Arc<HttpServerConsumerNode>>>>,
    /// Timeout and buffer settings this server was created with.
    config: GraphServerConfig,
}
impl GraphServer {
    /// Builds a server around `graph`, using the nodes registered under
    /// `producer_node_name` and `consumer_node_name` as the HTTP entry and
    /// exit points.
    ///
    /// The producer and consumer are looked up by name, checked for the
    /// expected `NodeKind`, downcast to their concrete HTTP node types and
    /// cloned into `Arc`s held by the server. A bounded request channel of
    /// `config.request_channel_buffer` entries is created and its receiver is
    /// given to the producer via `set_request_receiver`; the graph itself is
    /// moved into a new `GraphExecutor`.
    ///
    /// NOTE(review): the server keeps *clones* of the nodes rather than the
    /// instances stored in `graph`; this presumably relies on the node types
    /// sharing their internal state across clones — confirm.
    ///
    /// # Errors
    ///
    /// Returns `ExecutionError::Other` when
    /// - either node name is not present in the graph,
    /// - the producer is not `NodeKind::Producer` or the consumer is not
    ///   `NodeKind::Consumer`,
    /// - either node is not of the expected concrete HTTP node type, or
    /// - `config.request_channel_buffer` is zero (a bounded Tokio channel
    ///   cannot be created with zero capacity; this used to panic).
    pub async fn from_graph_with_node_names(
        graph: Graph,
        producer_node_name: String,
        consumer_node_name: String,
        config: GraphServerConfig,
    ) -> Result<Self, ExecutionError> {
        trace!(
            producer_node = %producer_node_name,
            consumer_node = %consumer_node_name,
            "GraphServer::from_graph_with_node_names"
        );

        let producer_node_trait = graph.get_node(&producer_node_name).ok_or_else(|| {
            ExecutionError::Other(format!(
                "Producer node '{}' not found in graph",
                producer_node_name
            ))
        })?;
        let consumer_node_trait = graph.get_node(&consumer_node_name).ok_or_else(|| {
            ExecutionError::Other(format!(
                "Consumer node '{}' not found in graph",
                consumer_node_name
            ))
        })?;

        if producer_node_trait.node_kind() != crate::traits::NodeKind::Producer {
            return Err(ExecutionError::Other(format!(
                "Node '{}' is not a Producer (found {:?})",
                producer_node_name,
                producer_node_trait.node_kind()
            )));
        }
        if consumer_node_trait.node_kind() != crate::traits::NodeKind::Consumer {
            return Err(ExecutionError::Other(format!(
                "Node '{}' is not a Consumer (found {:?})",
                consumer_node_name,
                consumer_node_trait.node_kind()
            )));
        }

        let producer_node: Arc<HttpServerProducerNode> = {
            let any_node = producer_node_trait as &dyn std::any::Any;
            let Some(node) = any_node.downcast_ref::<HttpServerProducerNode>() else {
                return Err(ExecutionError::Other(format!(
                    "Node '{}' is not an HttpServerProducerNode (found: {})",
                    producer_node_name,
                    std::any::type_name_of_val(producer_node_trait)
                )));
            };
            Arc::new(node.clone())
        };
        let consumer_node: Arc<HttpServerConsumerNode> = {
            let any_node = consumer_node_trait as &dyn std::any::Any;
            let Some(node) = any_node.downcast_ref::<HttpServerConsumerNode>() else {
                return Err(ExecutionError::Other(format!(
                    "Node '{}' is not an HttpServerConsumerNode (found: {})",
                    consumer_node_name,
                    std::any::type_name_of_val(consumer_node_trait)
                )));
            };
            Arc::new(node.clone())
        };

        // `mpsc::channel` panics on a zero capacity; surface that as a
        // regular configuration error instead of aborting the caller.
        if config.request_channel_buffer == 0 {
            return Err(ExecutionError::Other(
                "request_channel_buffer must be greater than zero".to_string(),
            ));
        }

        debug!(
            buffer_size = config.request_channel_buffer,
            "Creating request channel"
        );
        let (request_sender, request_receiver) = mpsc::channel(config.request_channel_buffer);

        debug!(
            producer_node = %producer_node_name,
            "Setting request receiver on producer node"
        );
        producer_node.set_request_receiver(request_receiver).await;

        debug!("Creating GraphExecutor");
        let executor = GraphExecutor::new(graph);

        Ok(Self {
            executor: Arc::new(tokio::sync::RwLock::new(executor)),
            request_sender,
            producer_node: Arc::new(tokio::sync::RwLock::new(Some(producer_node))),
            consumer_node: Arc::new(tokio::sync::RwLock::new(Some(consumer_node))),
            config,
        })
    }

    /// Placeholder constructor that would discover the HTTP nodes on its own.
    ///
    /// # Errors
    ///
    /// Always returns `ExecutionError::Other`; both arguments are ignored.
    /// Use [`GraphServer::from_graph_with_node_names`] instead.
    pub async fn from_graph(
        _graph: Graph,
        _config: GraphServerConfig,
    ) -> Result<Self, ExecutionError> {
        Err(ExecutionError::Other(
            "GraphServer::from_graph not yet implemented - use from_graph_with_node_names or pass nodes directly".to_string()
        ))
    }

    /// Starts the underlying graph executor and logs the outcome.
    ///
    /// Takes the executor write lock for the duration of the call.
    ///
    /// # Errors
    ///
    /// Propagates whatever error `GraphExecutor::start` reports.
    pub async fn start(&self) -> Result<(), ExecutionError> {
        trace!("GraphServer::start()");
        info!("Starting graph executor...");
        let mut executor = self.executor.write().await;
        let result = executor.start().await;
        match &result {
            Ok(()) => info!("Graph executor started successfully"),
            Err(e) => error!(error = %e, "Failed to start graph executor"),
        }
        result
    }

    /// Stops the underlying graph executor.
    ///
    /// Takes the executor write lock for the duration of the call.
    ///
    /// # Errors
    ///
    /// Propagates whatever error `GraphExecutor::stop` reports.
    pub async fn stop(&self) -> Result<(), ExecutionError> {
        trace!("GraphServer::stop()");
        let mut executor = self.executor.write().await;
        executor.stop().await
    }

    /// Builds a plain-text response with the given status.
    ///
    /// Constructed directly rather than through `Response::builder()` so
    /// there is no fallible step and therefore nothing to `unwrap`.
    fn plain_response(status: StatusCode, message: &'static str) -> Response<Body> {
        let mut response = Response::new(Body::from(message));
        *response.status_mut() = status;
        response
    }

    /// Pushes one HTTP request through the graph and returns its response.
    ///
    /// Steps:
    /// 1. generate a UUID v4 request id,
    /// 2. register a response channel for that id with the consumer node,
    /// 3. attach the id to the request as a `RequestIdExtension` and send the
    ///    request into the request channel,
    /// 4. wait for the response on the registered channel.
    ///
    /// Steps 3 (sending) and 4 share a single `config.request_timeout`
    /// budget, so a full request channel cannot stall a request forever.
    ///
    /// Failure responses (all plain text):
    /// - `500` "Server configuration error" — no consumer node is set,
    /// - `500` "Failed to process request" — the request channel rejected
    ///   the request; the registration is removed again,
    /// - `500` "Request processing failed" — the response channel closed
    ///   without delivering a response,
    /// - `408` "Request timeout" — the budget elapsed; the registration is
    ///   removed again.
    pub async fn handle(&self, request: Request) -> Response<Body> {
        trace!("GraphServer::handle()");
        let request_id = uuid::Uuid::new_v4().to_string();
        debug!(request_id = %request_id, "Generated request ID");

        let (response_sender, mut response_receiver) = mpsc::channel(1);
        debug!(request_id = %request_id, "Registering request with consumer node");

        // Clone the Arc out of the lock so the read guard is released before
        // any await on the node itself.
        let consumer = self.consumer_node.read().await.as_ref().map(Arc::clone);
        let Some(consumer) = consumer else {
            error!("HttpServerConsumerNode not found - cannot handle request");
            return Self::plain_response(
                StatusCode::INTERNAL_SERVER_ERROR,
                "Server configuration error",
            );
        };
        consumer
            .register_request(request_id.clone(), response_sender)
            .await;

        debug!(request_id = %request_id, "Injecting request into graph");
        let mut request_with_id = request;
        request_with_id
            .extensions_mut()
            .insert(crate::http_server::types::RequestIdExtension(
                request_id.clone(),
            ));

        // Send and receive under one timeout: back-pressure on the request
        // channel counts against the same budget as graph processing.
        let exchange = async {
            match self.request_sender.send(request_with_id).await {
                Ok(()) => Ok(response_receiver.recv().await),
                Err(send_error) => Err(send_error),
            }
        };
        let outcome = timeout(self.config.request_timeout, exchange).await;

        match outcome {
            Ok(Ok(Some(response))) => response,
            Ok(Ok(None)) => {
                warn!(request_id = %request_id, "Response channel closed before response received");
                Self::plain_response(
                    StatusCode::INTERNAL_SERVER_ERROR,
                    "Request processing failed",
                )
            }
            Ok(Err(e)) => {
                error!(error = %e, "Failed to inject request into graph");
                // The request never entered the graph, so nothing will ever
                // consume the registration made above; remove it here.
                consumer.unregister_request(&request_id).await;
                Self::plain_response(
                    StatusCode::INTERNAL_SERVER_ERROR,
                    "Failed to process request",
                )
            }
            Err(_) => {
                warn!(request_id = %request_id, "Request timed out");
                consumer.unregister_request(&request_id).await;
                Self::plain_response(StatusCode::REQUEST_TIMEOUT, "Request timeout")
            }
        }
    }

    /// Returns a clonable closure that forwards each request to
    /// [`GraphServer::handle`].
    ///
    /// The closure owns an `Arc` around a clone of this server and returns a
    /// boxed, pinned, `Send` future per call.
    pub fn handler(
        &self,
    ) -> impl Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response<Body>> + Send>>
           + Clone {
        let server = Arc::new(self.clone());
        move |request: Request| {
            let server = Arc::clone(&server);
            Box::pin(async move { server.handle(request).await })
        }
    }

    /// Starts the executor, binds `addr` and serves HTTP until the server
    /// stops.
    ///
    /// Only the path `/` is routed (for any HTTP method) to
    /// [`GraphServer::handle`].
    /// NOTE(review): requests to other paths never reach the graph —
    /// confirm whether a fallback route is intended.
    ///
    /// The executor is started before the listener is bound; if binding or
    /// serving fails, this method does not stop the executor again.
    ///
    /// # Errors
    ///
    /// Returns the executor's start error, or `ExecutionError::Other` if the
    /// address cannot be bound or the HTTP server terminates with an error.
    pub async fn serve(&self, addr: std::net::SocketAddr) -> Result<(), ExecutionError> {
        trace!(addr = %addr, "GraphServer::serve()");
        info!("Starting graph executor before serving...");
        self.start().await?;
        info!("Graph executor started, starting HTTP server...");

        let server = Arc::new(self.clone());
        let handler = move |request: Request| {
            let server = server.clone();
            async move { server.handle(request).await }
        };
        let router = Router::new().route("/", axum::routing::any(handler));

        let listener = tokio::net::TcpListener::bind(addr)
            .await
            .map_err(|e| ExecutionError::Other(format!("Failed to bind to address {}: {}", addr, e)))?;
        axum::serve(listener, router)
            .await
            .map_err(|e| ExecutionError::Other(format!("Server error: {}", e)))?;
        Ok(())
    }
}
impl Clone for GraphServer {
    /// Produces another handle to the same server.
    ///
    /// The executor and both node slots are shared through their `Arc`s, the
    /// request sender is cloned, and the configuration is copied. Implemented
    /// by hand (rather than derived) so that each clone emits a trace event.
    fn clone(&self) -> Self {
        trace!("GraphServer::clone()");
        let Self {
            executor,
            request_sender,
            producer_node,
            consumer_node,
            config,
        } = self;
        Self {
            executor: Arc::clone(executor),
            request_sender: request_sender.clone(),
            producer_node: Arc::clone(producer_node),
            consumer_node: Arc::clone(consumer_node),
            config: config.clone(),
        }
    }
}