use crate::error::{ComponentInfo, ErrorAction, ErrorContext, ErrorStrategy, StreamError};
use crate::{Consumer, ConsumerConfig};
use crate::{HttpServerResponse, Input};
use async_stream::stream;
use async_trait::async_trait;
use axum::body::Bytes;
use axum::{body::Body, response::Response};
use futures::Stream;
use futures::StreamExt;
use std::pin::Pin;
use tracing::{error, warn};
/// HTTP-specific options for [`HttpResponseConsumer`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HttpResponseConsumerConfig {
    /// Streaming-mode flag. NOTE(review): not read by the code in this file; presumably it
    /// selects `create_streaming_response` over `get_response` — confirm with callers.
    pub stream_response: bool,
    /// Maximum number of input items `consume` processes; `None` means no limit.
    pub max_items: Option<usize>,
    /// When `get_response` finds several buffered responses, merge them into one instead of
    /// returning only the oldest.
    pub merge_responses: bool,
    /// Status used for responses synthesised when there is no response to take it from.
    pub default_status: axum::http::StatusCode,
}
impl Default for HttpResponseConsumerConfig {
    /// All flags off (`stream_response`, `merge_responses`), no `max_items` limit, and
    /// `StatusCode::OK` as `default_status`.
    fn default() -> Self {
        let default_status = axum::http::StatusCode::OK;
        Self {
            default_status,
            merge_responses: false,
            max_items: None,
            stream_response: false,
        }
    }
}
impl HttpResponseConsumerConfig {
    /// Returns this config with `stream_response` replaced by `stream`.
    #[must_use]
    pub fn with_stream_response(self, stream: bool) -> Self {
        Self {
            stream_response: stream,
            ..self
        }
    }

    /// Returns this config with `max_items` replaced by `max` (`None` = unlimited).
    #[must_use]
    pub fn with_max_items(self, max: Option<usize>) -> Self {
        Self {
            max_items: max,
            ..self
        }
    }

    /// Returns this config with `merge_responses` replaced by `merge`.
    #[must_use]
    pub fn with_merge_responses(self, merge: bool) -> Self {
        Self {
            merge_responses: merge,
            ..self
        }
    }

    /// Returns this config with `default_status` replaced by `status`.
    #[must_use]
    pub fn with_default_status(self, status: axum::http::StatusCode) -> Self {
        Self {
            default_status: status,
            ..self
        }
    }
}
/// Consumer that collects [`HttpServerResponse`] items and turns them into an axum
/// [`Response`].
#[derive(Debug)]
pub struct HttpResponseConsumer {
    /// Generic consumer settings; this file reads its `name` and `error_strategy`.
    pub config: ConsumerConfig<HttpServerResponse>,
    /// HTTP-specific options (item limit, merging, fallback status).
    pub http_config: HttpResponseConsumerConfig,
    /// Responses accepted by `consume`, in arrival order; `get_response` takes from the front.
    pub responses: Vec<HttpServerResponse>,
    /// Set to `true` once `consume` returns.
    pub finished: bool,
}
impl HttpResponseConsumer {
#[must_use]
pub fn new() -> Self {
Self {
config: ConsumerConfig::default(),
http_config: HttpResponseConsumerConfig::default(),
responses: Vec::new(),
finished: false,
}
}
#[must_use]
pub fn with_config(http_config: HttpResponseConsumerConfig) -> Self {
Self {
config: ConsumerConfig::default(),
http_config,
responses: Vec::new(),
finished: false,
}
}
#[must_use]
pub fn with_error_strategy(mut self, strategy: ErrorStrategy<HttpServerResponse>) -> Self {
self.config.error_strategy = strategy;
self
}
#[must_use]
pub fn with_name(mut self, name: String) -> Self {
self.config.name = name;
self
}
pub async fn get_response(&mut self) -> Response<Body> {
if self.responses.is_empty() {
HttpServerResponse::new(
self.http_config.default_status,
Vec::new(),
crate::http_server::types::ContentType::Text,
)
.to_axum_response()
} else if self.responses.len() == 1 {
self.responses.remove(0).to_axum_response()
} else if self.http_config.merge_responses {
self.merge_responses().to_axum_response()
} else {
self.responses.remove(0).to_axum_response()
}
}
pub async fn create_streaming_response(
&self,
stream: impl futures::Stream<Item = HttpServerResponse> + Send + 'static,
) -> Response<Body> {
use std::pin::Pin;
use std::sync::{Arc, Mutex};
let component_name = self.config.name.clone();
let error_strategy = self.config.error_strategy.clone();
let default_status = self.http_config.default_status;
let first_response_meta: Arc<Mutex<Option<(axum::http::StatusCode, axum::http::HeaderMap)>>> =
Arc::new(Mutex::new(None));
let first_response_meta_clone = first_response_meta.clone();
let pinned_stream = Box::pin(stream);
let body_stream: Pin<Box<dyn futures::Stream<Item = Result<Bytes, std::io::Error>> + Send>> =
Box::pin(stream! {
let mut stream = pinned_stream;
let mut first_response_handled = false;
while let Some(response) = stream.next().await {
if !first_response_handled {
let content_type = response.content_type.clone();
let mut headers = response.headers.clone();
if !headers.contains_key("content-type") {
let content_type_value = axum::http::HeaderValue::from_str(content_type.as_str())
.unwrap_or_else(|_| axum::http::HeaderValue::from_static("application/octet-stream"));
headers.insert("content-type", content_type_value);
}
headers.insert(
"transfer-encoding",
axum::http::HeaderValue::from_static("chunked"),
);
*first_response_meta_clone.lock().unwrap() = Some((response.status, headers));
first_response_handled = true;
}
if response.status.as_u16() >= 400 {
let error = StreamError::new(
Box::new(std::io::Error::other(format!(
"HTTP error status: {}",
response.status
))),
ErrorContext {
timestamp: chrono::Utc::now(),
item: Some(response.clone()),
component_name: component_name.clone(),
component_type: std::any::type_name::<HttpResponseConsumer>().to_string(),
},
ComponentInfo {
name: component_name.clone(),
type_name: std::any::type_name::<HttpResponseConsumer>().to_string(),
},
);
match error_strategy {
ErrorStrategy::Stop => {
error!(
component = %component_name,
error = %error,
"Stopping stream due to HTTP error status"
);
break;
}
ErrorStrategy::Skip => {
warn!(
component = %component_name,
error = %error,
"Skipping response with error status"
);
continue;
}
ErrorStrategy::Retry(_) => {
warn!(
component = %component_name,
error = %error,
"Cannot retry HTTP response streaming"
);
continue;
}
ErrorStrategy::Custom(_) => {
warn!(
component = %component_name,
"Custom error handler not applicable for streaming"
);
continue;
}
}
}
if !response.body.is_empty() {
yield Ok(Bytes::from(response.body));
}
}
});
let body = Body::from_stream(body_stream);
let (status, headers) = if let Some(meta) = first_response_meta.lock().unwrap().take() {
meta
} else {
let mut default_headers = axum::http::HeaderMap::new();
default_headers.insert(
"content-type",
axum::http::HeaderValue::from_static("text/plain"),
);
(default_status, default_headers)
};
let mut response = Response::builder().status(status).body(body).unwrap();
*response.headers_mut() = headers;
response
}
fn merge_responses(&mut self) -> HttpServerResponse {
if self.responses.is_empty() {
return HttpServerResponse::new(
self.http_config.default_status,
Vec::new(),
crate::http_server::types::ContentType::Text,
);
}
let first = self.responses.remove(0);
let mut merged_body = first.body.clone();
let status = first.status;
let content_type = first.content_type;
let mut headers = first.headers;
for response in &self.responses {
merged_body.extend_from_slice(&response.body);
for (key, value) in response.headers.iter() {
if !headers.contains_key(key) {
headers.insert(key.clone(), value.clone());
}
}
}
HttpServerResponse {
request_id: first.request_id.clone(),
status,
headers,
body: merged_body,
content_type,
}
}
#[must_use]
pub fn http_config(&self) -> &HttpResponseConsumerConfig {
&self.http_config
}
#[must_use]
pub fn responses(&self) -> &[HttpServerResponse] {
&self.responses
}
}
impl Default for HttpResponseConsumer {
fn default() -> Self {
Self::new()
}
}
impl Clone for HttpResponseConsumer {
    /// Field-by-field deep copy, including any buffered responses.
    fn clone(&self) -> Self {
        // Destructure so a newly added field becomes a compile error here instead of being
        // silently left out of the clone.
        let Self {
            config,
            http_config,
            responses,
            finished,
        } = self;
        Self {
            config: config.clone(),
            http_config: http_config.clone(),
            responses: responses.clone(),
            finished: *finished,
        }
    }
}
/// This consumer accepts a pinned, boxed, `Send` stream of [`HttpServerResponse`] items.
impl Input for HttpResponseConsumer {
    type Input = HttpServerResponse;
    type InputStream = Pin<Box<dyn Stream<Item = HttpServerResponse> + Send>>;
}
#[async_trait]
impl Consumer for HttpResponseConsumer {
    type InputPorts = (HttpServerResponse,);

    /// Buffers the responses read from `stream` into `self.responses`.
    ///
    /// Responses with a status >= 400 are passed to `handle_error`:
    /// - `ErrorAction::Stop` ends consumption.
    /// - `Skip` and `Retry` drop the response; a consumed stream item cannot be re-read.
    ///
    /// When `http_config.max_items` is `Some(max)`, at most `max` items are pulled from
    /// `stream`, and error responses count towards that limit. `self.finished` is set to
    /// `true` however consumption ends.
    async fn consume(&mut self, mut stream: Self::InputStream) {
        let component_name = self.config.name.clone();
        let max_items = self.http_config.max_items;
        let mut count: usize = 0;
        loop {
            // Enforce the limit *before* polling. Pulling one more item only to discard it would
            // also wait for it, which on an idle unbounded stream never finishes.
            if let Some(max) = max_items
                && count >= max
            {
                warn!(
                    component = %component_name,
                    "Maximum items limit ({}) reached, stopping consumption",
                    max
                );
                break;
            }
            let Some(response) = stream.next().await else {
                break;
            };
            count += 1;
            if response.status.as_u16() >= 400 {
                let error = StreamError::new(
                    Box::new(std::io::Error::other(format!(
                        "HTTP error status: {}",
                        response.status
                    ))),
                    ErrorContext {
                        timestamp: chrono::Utc::now(),
                        item: Some(response.clone()),
                        component_name: component_name.clone(),
                        component_type: std::any::type_name::<HttpResponseConsumer>().to_string(),
                    },
                    ComponentInfo {
                        name: component_name.clone(),
                        type_name: std::any::type_name::<HttpResponseConsumer>().to_string(),
                    },
                );
                match self.handle_error(&error) {
                    ErrorAction::Stop => {
                        error!(
                            component = %component_name,
                            error = %error,
                            "Stopping due to HTTP error status"
                        );
                        break;
                    }
                    ErrorAction::Skip => {
                        warn!(
                            component = %component_name,
                            error = %error,
                            "Skipping response with error status"
                        );
                        continue;
                    }
                    ErrorAction::Retry => {
                        warn!(
                            component = %component_name,
                            error = %error,
                            "Cannot retry HTTP response consumption"
                        );
                        continue;
                    }
                }
            }
            self.responses.push(response);
        }
        self.finished = true;
    }

    fn set_config_impl(&mut self, config: ConsumerConfig<HttpServerResponse>) {
        self.config = config;
    }

    fn get_config_impl(&self) -> &ConsumerConfig<HttpServerResponse> {
        &self.config
    }

    fn get_config_mut_impl(&mut self) -> &mut ConsumerConfig<HttpServerResponse> {
        &mut self.config
    }
}