use crate::Output;
use crate::error::{ComponentInfo, ErrorContext, ErrorStrategy, StreamError};
use crate::http_server::types::HttpServerRequest;
use crate::{Producer, ProducerConfig};
use async_stream::stream;
use async_trait::async_trait;
use axum::extract::Request;
use chrono;
use futures::Stream;
use futures::StreamExt;
#[allow(unused_imports)] use http_body_util::BodyExt;
use std::pin::Pin;
use tracing::{error, warn};
/// Options controlling how an incoming HTTP request is turned into an
/// [`HttpServerRequest`] by [`HttpRequestProducer::from_axum_request`].
#[derive(Debug, Clone)]
pub struct HttpRequestProducerConfig {
/// Whether the request body is read at all; when `false` the body is discarded.
pub extract_body: bool,
/// Body size limit in bytes. A request whose declared `Content-Length` exceeds
/// this value has its body skipped. `None` disables the limit.
pub max_body_size: Option<usize>,
/// JSON-parsing flag. NOTE(review): not acted upon by the body-extraction code in
/// this file; the body is stored as raw bytes regardless.
pub parse_json: bool,
/// Query-parameter flag. NOTE(review): not read anywhere in this file — confirm
/// whether another component consults it.
pub extract_query_params: bool,
/// Path-parameter flag. NOTE(review): not read anywhere in this file — confirm
/// whether another component consults it.
pub extract_path_params: bool,
/// When `true`, the body is kept as a stream and emitted chunk by chunk from
/// `produce` instead of being buffered into `HttpServerRequest::body`.
pub stream_body: bool,
/// Intended chunk size in bytes for streamed bodies. NOTE(review): currently not
/// applied — chunks are forwarded exactly as the body yields them.
pub chunk_size: usize,
}
impl Default for HttpRequestProducerConfig {
    /// Extract and buffer the body (capped at 10 MiB), with every extraction flag
    /// enabled and streaming disabled; streamed chunks default to 64 KiB.
    fn default() -> Self {
        const KIB: usize = 1024;
        const MIB: usize = 1024 * KIB;

        Self {
            chunk_size: 64 * KIB,
            stream_body: false,
            extract_path_params: true,
            extract_query_params: true,
            parse_json: true,
            max_body_size: Some(10 * MIB),
            extract_body: true,
        }
    }
}
impl HttpRequestProducerConfig {
    /// Returns this config with `extract_body` replaced; `false` means the body is not read.
    #[must_use]
    pub fn with_extract_body(self, extract: bool) -> Self {
        Self {
            extract_body: extract,
            ..self
        }
    }

    /// Returns this config with `max_body_size` replaced; `None` removes the limit.
    #[must_use]
    pub fn with_max_body_size(self, size: Option<usize>) -> Self {
        Self {
            max_body_size: size,
            ..self
        }
    }

    /// Returns this config with `parse_json` replaced.
    #[must_use]
    pub fn with_parse_json(self, parse: bool) -> Self {
        Self {
            parse_json: parse,
            ..self
        }
    }

    /// Returns this config with `extract_query_params` replaced.
    #[must_use]
    pub fn with_extract_query_params(self, extract: bool) -> Self {
        Self {
            extract_query_params: extract,
            ..self
        }
    }

    /// Returns this config with `extract_path_params` replaced.
    #[must_use]
    pub fn with_extract_path_params(self, extract: bool) -> Self {
        Self {
            extract_path_params: extract,
            ..self
        }
    }

    /// Returns this config with `stream_body` replaced.
    #[must_use]
    pub fn with_stream_body(self, stream: bool) -> Self {
        Self {
            stream_body: stream,
            ..self
        }
    }

    /// Returns this config with `chunk_size` replaced.
    #[must_use]
    pub fn with_chunk_size(self, size: usize) -> Self {
        Self {
            chunk_size: size,
            ..self
        }
    }
}
/// Producer that emits a single captured HTTP request and, in streaming mode,
/// one additional item per body chunk.
#[derive(Debug)]
pub struct HttpRequestProducer {
/// Generic producer settings (name, error strategy).
pub config: ProducerConfig<HttpServerRequest>,
/// HTTP-specific extraction options this producer was built with.
pub http_config: HttpRequestProducerConfig,
/// The captured request; taken (set to `None`) by `produce`.
pub request: Option<HttpServerRequest>,
/// Unread request body, present only when body streaming was selected;
/// taken by `produce` and never duplicated by `clone`.
pub body_stream: Option<axum::body::Body>,
}
impl HttpRequestProducer {
    /// Builds a producer from an incoming axum request.
    ///
    /// The request head is converted into an [`HttpServerRequest`]. Depending on
    /// `http_config`, the body is then:
    /// - discarded, when `extract_body` is `false`;
    /// - skipped (with a warning), when the declared `Content-Length` exceeds
    ///   `max_body_size`;
    /// - kept aside as a stream for `produce`, when `stream_body` is `true`;
    /// - otherwise buffered into `request.body`, reading at most `max_body_size` bytes.
    ///
    /// Body read failures, including exceeding the limit while buffering, are logged.
    /// They leave the body unextracted, so this constructor never fails.
    pub async fn from_axum_request(
        axum_request: Request,
        http_config: HttpRequestProducerConfig,
    ) -> Self {
        // Declared size only; a missing or malformed header yields `None`.
        let body_size = axum_request
            .headers()
            .get("content-length")
            .and_then(|v| v.to_str().ok())
            .and_then(|s| s.parse::<usize>().ok());
        let (parts, body) = axum_request.into_parts();
        // Convert only the head; the body is handled below according to `http_config`.
        let request_without_body = Request::from_parts(parts, axum::body::Body::empty());
        let mut request = HttpServerRequest::from_axum_request(request_without_body).await;

        let body_stream = if !http_config.extract_body {
            None
        } else if let Some(size) = body_size
            && let Some(max_size) = http_config.max_body_size
            && size > max_size
        {
            warn!(
                "Request body size {} exceeds maximum {} bytes, body will not be extracted",
                size, max_size
            );
            None
        } else if http_config.stream_body {
            // NOTE(review): in streaming mode only the declared Content-Length is checked
            // against `max_body_size`; bodies sent without that header are not capped.
            request.body = None;
            Some(body)
        } else {
            // Enforce the limit during the read as well: Content-Length may be absent
            // (e.g. chunked transfer encoding), so the header check alone is not enough.
            let limit = http_config.max_body_size.unwrap_or(usize::MAX);
            match axum::body::to_bytes(body, limit).await {
                Ok(body_bytes) => {
                    // `parse_json` is not applied here; the raw bytes are stored as-is.
                    request.body = Some(body_bytes.to_vec());
                }
                Err(e) => {
                    warn!(
                        error = %e,
                        "Failed to extract request body"
                    );
                }
            }
            None
        };

        Self {
            config: ProducerConfig::default(),
            http_config,
            request: Some(request),
            body_stream,
        }
    }

    /// Replaces the error strategy used when reading streamed body chunks.
    #[must_use]
    pub fn with_error_strategy(mut self, strategy: ErrorStrategy<HttpServerRequest>) -> Self {
        self.config.error_strategy = strategy;
        self
    }

    /// Sets the component name used in log output.
    #[must_use]
    pub fn with_name(mut self, name: String) -> Self {
        self.config.name = Some(name);
        self
    }

    /// Overwrites the path parameters of the captured request, if one is still held.
    pub fn set_path_params(&mut self, params: std::collections::HashMap<String, String>) {
        if let Some(request) = self.request.as_mut() {
            request.path_params = params;
        }
    }

    /// Returns the HTTP extraction options this producer was built with.
    #[must_use]
    pub fn http_config(&self) -> &HttpRequestProducerConfig {
        &self.http_config
    }
}
impl Clone for HttpRequestProducer {
    /// Copies the configuration and captured request metadata.
    ///
    /// The body stream cannot be duplicated, so the clone always starts with
    /// `body_stream: None`; only the original instance can stream the body.
    fn clone(&self) -> Self {
        let Self {
            config,
            http_config,
            request,
            ..
        } = self;
        Self {
            config: config.clone(),
            http_config: http_config.clone(),
            request: request.clone(),
            body_stream: None,
        }
    }
}
// Items emitted by this producer are `HttpServerRequest`s, delivered through a
// boxed, `Send`-able stream.
impl Output for HttpRequestProducer {
    type OutputStream = Pin<Box<dyn Stream<Item = HttpServerRequest> + Send>>;
    type Output = HttpServerRequest;
}
#[async_trait]
impl Producer for HttpRequestProducer {
    type OutputPorts = (HttpServerRequest,);

    /// Emits the captured request, followed by one item per body chunk when streaming.
    ///
    /// The first item is always the request itself. When `stream_body` is enabled and a
    /// body stream is held, each data chunk read from the body is emitted as a copy of
    /// the request. That copy's `body` holds only the chunk, and its headers additionally
    /// carry:
    /// - `x-streamweave-chunk-offset`: byte offset of the chunk's first byte in the body;
    /// - `x-streamweave-chunk-size`: the chunk's length in bytes.
    ///
    /// Chunks are forwarded exactly as the body yields them; `chunk_size` is not applied.
    /// A chunk read error ends the stream under `ErrorStrategy::Stop`. Under every other
    /// strategy the failed chunk is skipped.
    ///
    /// The request and body are taken out of `self`. A second call therefore emits
    /// nothing and only logs the missing request.
    fn produce(&mut self) -> Self::OutputStream {
        let component_name = self
            .config
            .name
            .clone()
            .unwrap_or_else(|| "http_request_producer".to_string());
        let error_strategy = self.config.error_strategy.clone();
        let request = self.request.take();
        let body_stream = self.body_stream.take();
        let stream_body = self.http_config.stream_body;
        Box::pin(stream! {
            match request {
                Some(req) => {
                    yield req.clone();
                    if stream_body
                        && let Some(body) = body_stream {
                        let mut data_stream = body.into_data_stream();
                        // Bytes successfully read so far, i.e. the offset of the next chunk.
                        let mut total_bytes = 0u64;
                        while let Some(chunk_result) = data_stream.next().await {
                            match chunk_result {
                                Ok(chunk) => {
                                    // Capture the start offset before advancing the counter.
                                    let chunk_offset = total_bytes;
                                    total_bytes += chunk.len() as u64;
                                    let mut chunk_headers = req.headers.clone();
                                    chunk_headers.insert(
                                        axum::http::HeaderName::from_static("x-streamweave-chunk-offset"),
                                        axum::http::HeaderValue::from(chunk_offset),
                                    );
                                    chunk_headers.insert(
                                        axum::http::HeaderName::from_static("x-streamweave-chunk-size"),
                                        axum::http::HeaderValue::from(chunk.len()),
                                    );
                                    let chunk_request = HttpServerRequest {
                                        request_id: req.request_id.clone(),
                                        method: req.method,
                                        uri: req.uri.clone(),
                                        path: req.path.clone(),
                                        headers: chunk_headers,
                                        query_params: req.query_params.clone(),
                                        path_params: req.path_params.clone(),
                                        body: Some(chunk.to_vec()),
                                        content_type: req.content_type.clone(),
                                        remote_addr: req.remote_addr.clone(),
                                    };
                                    yield chunk_request;
                                }
                                Err(e) => {
                                    warn!(
                                        component = %component_name,
                                        error = %e,
                                        total_bytes = total_bytes,
                                        "Error reading body chunk during streaming"
                                    );
                                    match error_strategy {
                                        ErrorStrategy::Stop => {
                                            error!(
                                                component = %component_name,
                                                error = %e,
                                                "Stopping stream due to body read error"
                                            );
                                            break;
                                        }
                                        ErrorStrategy::Skip => {
                                            continue;
                                        }
                                        ErrorStrategy::Retry(_) => {
                                            // A consumed body chunk cannot be re-read.
                                            warn!(
                                                component = %component_name,
                                                "Cannot retry body chunk read, skipping"
                                            );
                                            continue;
                                        }
                                        ErrorStrategy::Custom(_) => {
                                            warn!(
                                                component = %component_name,
                                                "Custom error handler not applicable for body chunk read"
                                            );
                                            continue;
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
                None => {
                    // Nothing to emit; build a StreamError purely for diagnostic logging.
                    let error: StreamError<HttpServerRequest> = StreamError::new(
                        Box::new(std::io::Error::other("No request available")),
                        ErrorContext {
                            timestamp: chrono::Utc::now(),
                            item: None,
                            component_name: component_name.clone(),
                            component_type: std::any::type_name::<HttpRequestProducer>().to_string(),
                        },
                        ComponentInfo {
                            name: component_name.clone(),
                            type_name: std::any::type_name::<HttpRequestProducer>().to_string(),
                        },
                    );
                    match error_strategy {
                        ErrorStrategy::Stop => {
                            error!(
                                component = %component_name,
                                error = %error,
                                "Stopping due to missing request"
                            );
                        }
                        ErrorStrategy::Skip => {
                            warn!(
                                component = %component_name,
                                error = %error,
                                "Skipping missing request"
                            );
                        }
                        ErrorStrategy::Retry(_) => {
                            warn!(
                                component = %component_name,
                                error = %error,
                                "Cannot retry missing request"
                            );
                        }
                        ErrorStrategy::Custom(_) => {
                            error!(
                                component = %component_name,
                                error = %error,
                                "Custom error handler not applicable for missing request"
                            );
                        }
                    }
                }
            }
        })
    }

    fn set_config_impl(&mut self, config: ProducerConfig<HttpServerRequest>) {
        self.config = config;
    }

    fn get_config_impl(&self) -> &ProducerConfig<HttpServerRequest> {
        &self.config
    }

    fn get_config_mut_impl(&mut self) -> &mut ProducerConfig<HttpServerRequest> {
        &mut self.config
    }
}