use crate::protocol::error::Error;
use crate::protocol::headers::{self, names};
use crate::protocol::json_mode;
use crate::protocol::problem::{ProblemResponse, Result, request_instance};
use crate::protocol::stream_name::StreamName;
use crate::router::StreamBasePath;
use crate::storage::{CreateStreamResult, CreateWithDataResult, Storage, StreamConfig};
use axum::{
Extension,
body::Body,
extract::{OriginalUri, State},
http::{HeaderMap, HeaderValue, StatusCode},
response::{IntoResponse, Response},
};
use chrono::Utc;
use std::sync::Arc;
pub async fn create_stream<S: Storage>(
State(storage): State<Arc<S>>,
StreamName(name): StreamName,
original_uri: OriginalUri,
Extension(StreamBasePath(stream_base_path)): Extension<StreamBasePath>,
headers: HeaderMap,
body: Body,
) -> Result<Response> {
let instance = request_instance(&original_uri);
let result = async {
let body_bytes =
axum::body::to_bytes(body, usize::MAX)
.await
.map_err(|e| Error::InvalidHeader {
header: "Content-Length".to_string(),
reason: format!("Failed to read body: {e}"),
})?;
let content_type = headers.get("content-type").and_then(|v| v.to_str().ok());
if let Some(ct) = content_type
&& ct.trim().is_empty()
{
return Err(ProblemResponse::from(Error::InvalidHeader {
header: "Content-Type".to_string(),
reason: "empty value".to_string(),
}));
}
let normalized_ct = content_type.map_or_else(
|| "application/octet-stream".to_string(),
headers::normalize_content_type,
);
let ttl_seconds =
if let Some(ttl_value) = headers.get(names::STREAM_TTL).and_then(|v| v.to_str().ok()) {
Some(headers::parse_ttl(ttl_value)?)
} else {
None
};
let expires_at = if let Some(expires_value) = headers
.get(names::STREAM_EXPIRES_AT)
.and_then(|v| v.to_str().ok())
{
Some(headers::parse_expires_at(expires_value)?)
} else {
None
};
if ttl_seconds.is_some() && expires_at.is_some() {
return Err(ProblemResponse::from(Error::ConflictingExpiration));
}
let created_closed = headers
.get(names::STREAM_CLOSED)
.and_then(|v| v.to_str().ok())
.is_some_and(headers::parse_bool);
let mut config = StreamConfig::new(normalized_ct.clone());
if let Some(ttl) = ttl_seconds {
let expires_at =
Utc::now() + chrono::Duration::seconds(i64::try_from(ttl).unwrap_or(i64::MAX));
config = config.with_expires_at(expires_at);
config = config.with_ttl(ttl);
} else if let Some(expires) = expires_at {
config = config.with_expires_at(expires);
}
if created_closed {
config = config.with_created_closed(true);
}
let messages = if body_bytes.is_empty() {
vec![]
} else if json_mode::is_json_content_type(&normalized_ct) {
json_mode::process_append(&body_bytes)?
} else {
vec![body_bytes]
};
let CreateWithDataResult {
status: create_status,
next_offset,
closed,
} = storage.create_stream_with_data(&name, config, messages, created_closed)?;
let status = if matches!(create_status, CreateStreamResult::Created) {
StatusCode::CREATED
} else {
StatusCode::OK
};
let location = build_location_url(&headers, &stream_base_path, &name);
let mut response_headers = HeaderMap::new();
response_headers.insert("content-type", normalized_ct.parse().unwrap());
response_headers.insert(
names::STREAM_NEXT_OFFSET,
HeaderValue::from_bytes(next_offset.as_str().as_bytes()).unwrap(),
);
response_headers.insert("location", location.parse().unwrap());
if closed {
response_headers.insert(names::STREAM_CLOSED, "true".parse().unwrap());
}
Ok((status, response_headers).into_response())
}
.await;
result.map_err(|problem| problem.with_instance(instance))
}
fn build_location_url(headers: &HeaderMap, stream_base_path: &str, name: &str) -> String {
let scheme = headers
.get("x-forwarded-proto")
.and_then(|v| v.to_str().ok())
.unwrap_or("http");
let host = headers
.get("x-forwarded-host")
.and_then(|v| v.to_str().ok())
.or_else(|| headers.get("host").and_then(|v| v.to_str().ok()))
.unwrap_or("localhost");
if stream_base_path == "/" {
format!("{scheme}://{host}/{name}")
} else {
format!("{scheme}://{host}{stream_base_path}/{name}")
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_build_location_prefers_x_forwarded_host() {
let mut headers = HeaderMap::new();
headers.insert("x-forwarded-proto", "https".parse().unwrap());
headers.insert("x-forwarded-host", "proxy.example.com".parse().unwrap());
headers.insert("host", "internal.local".parse().unwrap());
let location = build_location_url(&headers, "/v1/stream", "orders");
assert_eq!(location, "https://proxy.example.com/v1/stream/orders");
}
#[test]
fn test_build_location_falls_back_to_host_and_http() {
let mut headers = HeaderMap::new();
headers.insert("host", "localhost:4437".parse().unwrap());
let location = build_location_url(&headers, "/v1/stream", "orders");
assert_eq!(location, "http://localhost:4437/v1/stream/orders");
}
#[test]
fn test_build_location_supports_custom_base_path() {
let mut headers = HeaderMap::new();
headers.insert("host", "localhost:4437".parse().unwrap());
let location = build_location_url(&headers, "/streams", "orders");
assert_eq!(location, "http://localhost:4437/streams/orders");
}
#[test]
fn test_build_location_supports_root_base_path() {
let mut headers = HeaderMap::new();
headers.insert("host", "localhost:4437".parse().unwrap());
let location = build_location_url(&headers, "/", "orders");
assert_eq!(location, "http://localhost:4437/orders");
}
}