use axum::{
Json, Router,
extract::{DefaultBodyLimit, Path as AxumPath, Query, State},
http::{
HeaderValue, Method, StatusCode,
header::{self, AUTHORIZATION, CONTENT_TYPE},
},
middleware,
response::{IntoResponse, Response},
routing::{get, post},
};
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use std::{sync::Arc, time::Instant};
use tower_http::{
cors::{AllowOrigin, CorsLayer},
trace::TraceLayer,
};
use crate::{
application::{
commands::*,
handlers::{
CommandHandlerGat, QueryHandlerGat,
command_handlers::SessionCommandHandler,
query_handlers::{SessionQueryHandler, StreamQueryHandler, SystemQueryHandler},
},
queries::*,
},
domain::{
aggregates::stream_session::{SessionConfig, SessionHealth},
ports::{
DictionaryStore, EventPublisherGat, NoopDictionaryStore, StreamRepositoryGat,
StreamStoreGat,
},
value_objects::{Priority, SessionId, StreamId},
},
infrastructure::http::middleware::{RateLimitMiddleware, security_middleware},
};
/// Configuration for building the PJS HTTP router.
#[derive(Debug, Clone)]
pub struct HttpServerConfig {
    /// Origins accepted by the CORS layer.
    ///
    /// An empty list denies every cross-origin request, `["*"]` allows any
    /// origin, and `"*"` may not be combined with explicit origins.
    pub allowed_origins: Vec<String>,
}

impl Default for HttpServerConfig {
    /// Allows only the local development origin `http://localhost:3000`.
    fn default() -> Self {
        let local_dev_origin = String::from("http://localhost:3000");
        HttpServerConfig {
            allowed_origins: vec![local_dev_origin],
        }
    }
}
/// Builds the CORS layer described by `config.allowed_origins`.
///
/// Every outcome allows only `GET`/`POST`, the `Content-Type` and
/// `Authorization` request headers, and a one-hour preflight cache. Origins are
/// handled as follows:
///
/// * empty list: no origin is allowed (deny-all);
/// * only `"*"` entries: any origin is allowed;
/// * `"*"` mixed with explicit origins: rejected as a configuration error;
/// * explicit origins only: each is parsed as a header value and allow-listed.
///
/// # Errors
///
/// Returns [`PjsError::HttpError`] when `"*"` is combined with explicit
/// origins, or when an explicit origin is not a valid HTTP header value.
fn build_cors_layer(config: &HttpServerConfig) -> Result<CorsLayer, PjsError> {
    const PREFLIGHT_MAX_AGE: std::time::Duration = std::time::Duration::from_secs(3600);

    let layer = CorsLayer::new()
        .allow_methods([Method::GET, Method::POST])
        .allow_headers([CONTENT_TYPE, AUTHORIZATION])
        .max_age(PREFLIGHT_MAX_AGE);

    let configured = &config.allowed_origins;
    if configured.is_empty() {
        // An empty allow-list rejects every cross-origin request.
        return Ok(layer.allow_origin(AllowOrigin::list(std::iter::empty::<HeaderValue>())));
    }

    let (wildcards, explicit): (Vec<&String>, Vec<&String>) =
        configured.iter().partition(|origin| origin.as_str() == "*");

    match (wildcards.is_empty(), explicit.is_empty()) {
        // Browsers reject `*` alongside specific origins, so treat it as misconfiguration.
        (false, false) => Err(PjsError::HttpError(
            "CORS: wildcard '*' cannot be combined with explicit origins".into(),
        )),
        (false, true) => Ok(layer.allow_origin(tower_http::cors::Any)),
        (true, _) => {
            let mut allowed = Vec::with_capacity(explicit.len());
            for o in explicit {
                let value = o.parse::<HeaderValue>().map_err(|e| {
                    PjsError::HttpError(format!("invalid CORS origin {o:?}: {e}"))
                })?;
                allowed.push(value);
            }
            Ok(layer.allow_origin(AllowOrigin::list(allowed)))
        }
    }
}
pub struct PjsAppState<R, P, S>
where
R: StreamRepositoryGat + Send + Sync + 'static,
P: EventPublisherGat + Send + Sync + 'static,
S: StreamStoreGat + Send + Sync + 'static,
{
command_handler: Arc<SessionCommandHandler<R, P>>,
session_query_handler: Arc<SessionQueryHandler<R>>,
stream_query_handler: Arc<StreamQueryHandler<R, S>>,
system_handler: Arc<SystemQueryHandler<R>>,
pub(crate) dictionary_store: Arc<dyn DictionaryStore>,
}
impl<R, P, S> Clone for PjsAppState<R, P, S>
where
R: StreamRepositoryGat + Send + Sync + 'static,
P: EventPublisherGat + Send + Sync + 'static,
S: StreamStoreGat + Send + Sync + 'static,
{
fn clone(&self) -> Self {
Self {
command_handler: self.command_handler.clone(),
session_query_handler: self.session_query_handler.clone(),
stream_query_handler: self.stream_query_handler.clone(),
system_handler: self.system_handler.clone(),
dictionary_store: self.dictionary_store.clone(),
}
}
}
impl<R, P, S> PjsAppState<R, P, S>
where
    R: StreamRepositoryGat + Send + Sync + 'static,
    P: EventPublisherGat + Send + Sync + 'static,
    S: StreamStoreGat + Send + Sync + 'static,
{
    /// Builds the state with a `NoopDictionaryStore`.
    ///
    /// See [`PjsAppState::with_dictionary_store`] for how the handlers are wired.
    pub fn new(repository: Arc<R>, event_publisher: Arc<P>, stream_store: Arc<S>) -> Self {
        let noop_store: Arc<dyn DictionaryStore> = Arc::new(NoopDictionaryStore);
        Self::with_dictionary_store(repository, event_publisher, stream_store, noop_store)
    }

    /// Builds the state around a caller-supplied dictionary store.
    ///
    /// All handlers share the same `repository`. The command handler also gets
    /// the event publisher, and the stream query handler gets the stream store.
    /// The moment of construction (`Instant::now()`) is passed to the system
    /// handler as its start time.
    pub fn with_dictionary_store(
        repository: Arc<R>,
        event_publisher: Arc<P>,
        stream_store: Arc<S>,
        dictionary_store: Arc<dyn DictionaryStore>,
    ) -> Self {
        let started_at = Instant::now();

        let command_handler = Arc::new(SessionCommandHandler::new(
            Arc::clone(&repository),
            event_publisher,
        ));
        let session_query_handler =
            Arc::new(SessionQueryHandler::new(Arc::clone(&repository)));
        let stream_query_handler = Arc::new(StreamQueryHandler::new(
            Arc::clone(&repository),
            stream_store,
        ));
        // Last user of `repository`, so it takes ownership of the original Arc.
        let system_handler = Arc::new(SystemQueryHandler::with_start_time(repository, started_at));

        Self {
            command_handler,
            session_query_handler,
            stream_query_handler,
            system_handler,
            dictionary_store,
        }
    }
}
/// JSON body of `POST /pjs/sessions`; every field is optional.
#[derive(Debug, Deserialize)]
pub struct CreateSessionRequest {
    /// Maximum number of concurrent streams; `create_session` defaults it to 10.
    pub max_concurrent_streams: Option<usize>,
    /// Session timeout in seconds; `create_session` defaults it to 3600.
    pub timeout_seconds: Option<u64>,
    /// Free-form client description, forwarded to `CreateSessionCommand::client_info`.
    pub client_info: Option<String>,
}

/// JSON response of `POST /pjs/sessions`.
#[derive(Debug, Serialize)]
pub struct CreateSessionResponse {
    /// String form of the newly created `SessionId`.
    pub session_id: String,
    /// Expiry timestamp, computed by `create_session` as "now + timeout".
    pub expires_at: chrono::DateTime<chrono::Utc>,
}

/// JSON body of `POST /pjs/sessions/{session_id}/streams`.
#[derive(Debug, Deserialize)]
pub struct StartStreamRequest {
    /// Source JSON document, forwarded as `CreateStreamCommand::source_data`.
    pub data: JsonValue,
    /// NOTE(review): accepted but not read by `create_stream` (it passes
    /// `config: None`); clients setting it get no effect.
    pub priority_threshold: Option<u8>,
    /// NOTE(review): accepted but not read by `create_stream`; see `priority_threshold`.
    pub max_frames: Option<usize>,
}

/// Stream query parameters.
///
/// NOTE(review): not referenced by any handler in this file; presumably used
/// elsewhere in the crate — verify before changing.
#[derive(Debug, Deserialize)]
pub struct StreamParams {
    pub session_id: String,
    pub priority: Option<u8>,
    pub format: Option<String>,
}

/// JSON response of `GET /pjs/sessions/{session_id}/health`; mirrors
/// `SessionHealth` field by field.
#[derive(Debug, Serialize)]
pub struct SessionHealthResponse {
    pub is_healthy: bool,
    pub active_streams: usize,
    pub failed_streams: usize,
    pub is_expired: bool,
    pub uptime_seconds: i64,
}
impl From<SessionHealth> for SessionHealthResponse {
fn from(health: SessionHealth) -> Self {
Self {
is_healthy: health.is_healthy,
active_streams: health.active_streams,
failed_streams: health.failed_streams,
is_expired: health.is_expired,
uptime_seconds: health.uptime_seconds,
}
}
}
/// Builds the full PJS router (public and protected routes, no auth, no rate
/// limiting) using [`HttpServerConfig::default`].
///
/// # Panics
///
/// Only if the default configuration fails to produce a CORS layer, which would
/// be a bug.
pub fn create_pjs_router<R, P, S>() -> Router<PjsAppState<R, P, S>>
where
    R: StreamRepositoryGat + Send + Sync + 'static,
    P: EventPublisherGat + Send + Sync + 'static,
    S: StreamStoreGat + Send + Sync + 'static,
{
    let defaults = HttpServerConfig::default();
    create_pjs_router_with_config::<R, P, S>(&defaults)
        .expect("default HttpServerConfig must always produce a valid CORS layer")
}

/// Builds the full PJS router (public and protected routes, no auth, no rate
/// limiting) with the common layers configured from `config`.
///
/// # Errors
///
/// Returns [`PjsError::HttpError`] when `config` yields an invalid CORS setup.
pub fn create_pjs_router_with_config<R, P, S>(
    config: &HttpServerConfig,
) -> Result<Router<PjsAppState<R, P, S>>, PjsError>
where
    R: StreamRepositoryGat + Send + Sync + 'static,
    P: EventPublisherGat + Send + Sync + 'static,
    S: StreamStoreGat + Send + Sync + 'static,
{
    let public = public_routes::<R, P, S>();
    let protected = protected_routes::<R, P, S>();
    apply_common_layers(public.merge(protected), config)
}
/// Builds the full PJS router with `rate_limit_middleware` applied to every
/// route, using [`HttpServerConfig::default`].
///
/// # Panics
///
/// Only if the default configuration fails to produce a CORS layer, which would
/// be a bug.
pub fn create_pjs_router_with_rate_limit<R, P, S>(
    rate_limit_middleware: RateLimitMiddleware,
) -> Router<PjsAppState<R, P, S>>
where
    R: StreamRepositoryGat + Send + Sync + 'static,
    P: EventPublisherGat + Send + Sync + 'static,
    S: StreamStoreGat + Send + Sync + 'static,
{
    let defaults = HttpServerConfig::default();
    create_pjs_router_with_rate_limit_and_config::<R, P, S>(&defaults, rate_limit_middleware)
        .expect("default HttpServerConfig must always produce a valid CORS layer")
}

/// Builds the full PJS router with `rate_limit_middleware` applied to every
/// route and the common layers configured from `config`.
///
/// The rate limiter is attached before the common layers, so those layers wrap it.
///
/// # Errors
///
/// Returns [`PjsError::HttpError`] when `config` yields an invalid CORS setup.
pub fn create_pjs_router_with_rate_limit_and_config<R, P, S>(
    config: &HttpServerConfig,
    rate_limit_middleware: RateLimitMiddleware,
) -> Result<Router<PjsAppState<R, P, S>>, PjsError>
where
    R: StreamRepositoryGat + Send + Sync + 'static,
    P: EventPublisherGat + Send + Sync + 'static,
    S: StreamStoreGat + Send + Sync + 'static,
{
    let routes = public_routes::<R, P, S>().merge(protected_routes::<R, P, S>());
    let throttled = routes.layer(rate_limit_middleware);
    apply_common_layers(throttled, config)
}
/// Builds the PJS router with API-key authentication on the protected routes
/// only. The health probe (and `/metrics`, when enabled) stays public.
///
/// # Errors
///
/// Returns [`PjsError::HttpError`] when `config` yields an invalid CORS setup.
#[cfg(feature = "http-server")]
pub fn create_pjs_router_with_auth<R, P, S>(
    config: &HttpServerConfig,
    auth: crate::infrastructure::http::auth::ApiKeyAuthLayer,
) -> Result<Router<PjsAppState<R, P, S>>, PjsError>
where
    R: StreamRepositoryGat + Send + Sync + 'static,
    P: EventPublisherGat + Send + Sync + 'static,
    S: StreamStoreGat + Send + Sync + 'static,
{
    let authenticated = protected_routes::<R, P, S>().layer(auth);
    let routes = public_routes::<R, P, S>().merge(authenticated);
    apply_common_layers(routes, config)
}

/// Builds the PJS router with API-key authentication on the protected routes
/// and `rate_limit` applied to all routes.
///
/// Rate limiting wraps the authentication layer, so unauthenticated requests
/// are throttled too.
///
/// # Errors
///
/// Returns [`PjsError::HttpError`] when `config` yields an invalid CORS setup.
#[cfg(feature = "http-server")]
pub fn create_pjs_router_with_rate_limit_and_auth<R, P, S>(
    config: &HttpServerConfig,
    rate_limit: RateLimitMiddleware,
    auth: crate::infrastructure::http::auth::ApiKeyAuthLayer,
) -> Result<Router<PjsAppState<R, P, S>>, PjsError>
where
    R: StreamRepositoryGat + Send + Sync + 'static,
    P: EventPublisherGat + Send + Sync + 'static,
    S: StreamStoreGat + Send + Sync + 'static,
{
    let authenticated = protected_routes::<R, P, S>().layer(auth);
    let throttled = public_routes::<R, P, S>()
        .merge(authenticated)
        .layer(rate_limit);
    apply_common_layers(throttled, config)
}
/// Routes that the `*_with_auth` constructors leave unauthenticated: the health
/// probe and, with the `metrics` feature, the `/metrics` scrape endpoint.
fn public_routes<R, P, S>() -> Router<PjsAppState<R, P, S>>
where
    R: StreamRepositoryGat + Send + Sync + 'static,
    P: EventPublisherGat + Send + Sync + 'static,
    S: StreamStoreGat + Send + Sync + 'static,
{
    let public = Router::new().route("/pjs/health", get(system_health));

    #[cfg(feature = "metrics")]
    let public = public.route(
        "/metrics",
        get(crate::infrastructure::http::metrics::metrics_handler),
    );

    public
}
/// Session, stream and statistics routes.
///
/// The `*_with_auth` constructors wrap these routes in the API-key layer. The
/// per-session dictionary endpoint is registered only with the `compression`
/// feature on non-wasm32 targets.
fn protected_routes<R, P, S>() -> Router<PjsAppState<R, P, S>>
where
    R: StreamRepositoryGat + Send + Sync + 'static,
    P: EventPublisherGat + Send + Sync + 'static,
    S: StreamStoreGat + Send + Sync + 'static,
{
    let routes = Router::new()
        // Session collection: POST creates, GET lists active sessions.
        .route(
            "/pjs/sessions",
            post(create_session::<R, P, S>).get(list_sessions::<R, P, S>),
        )
        .route("/pjs/sessions/search", get(search_sessions::<R, P, S>))
        // Single session.
        .route("/pjs/sessions/{session_id}", get(get_session::<R, P, S>))
        .route(
            "/pjs/sessions/{session_id}/health",
            get(session_health::<R, P, S>),
        )
        .route(
            "/pjs/sessions/{session_id}/stats",
            get(get_session_stats::<R, P, S>),
        )
        // Streams within a session.
        .route(
            "/pjs/sessions/{session_id}/streams",
            post(create_stream::<R, P, S>),
        )
        .route(
            "/pjs/sessions/{session_id}/streams/{stream_id}",
            get(get_stream::<R, P, S>),
        )
        .route(
            "/pjs/sessions/{session_id}/streams/{stream_id}/start",
            post(start_stream::<R, P, S>),
        )
        .route(
            "/pjs/sessions/{session_id}/streams/{stream_id}/frames",
            get(get_stream_frames::<R, P, S>),
        )
        // System-wide statistics.
        .route("/pjs/stats", get(get_system_stats::<R, P, S>));

    #[cfg(all(feature = "compression", not(target_arch = "wasm32")))]
    let routes = routes.route(
        "/pjs/sessions/{session_id}/dictionary",
        get(crate::infrastructure::http::dictionary::get_session_dictionary::<R, P, S>),
    );

    routes
}
/// Wraps `router` in the layers shared by every router constructor.
///
/// With axum's `Router::layer`, each later layer wraps the earlier ones. Request
/// tracing is therefore outermost, then CORS, and `security_middleware` runs
/// closest to the routes. `DefaultBodyLimit` caps extracted request bodies at
/// 10 MiB.
///
/// # Errors
///
/// Propagates the [`PjsError`] from building the CORS layer.
fn apply_common_layers<R, P, S>(
    router: Router<PjsAppState<R, P, S>>,
    config: &HttpServerConfig,
) -> Result<Router<PjsAppState<R, P, S>>, PjsError>
where
    R: StreamRepositoryGat + Send + Sync + 'static,
    P: EventPublisherGat + Send + Sync + 'static,
    S: StreamStoreGat + Send + Sync + 'static,
{
    const MAX_BODY_BYTES: usize = 10 * 1024 * 1024;

    let cors = build_cors_layer(config)?;
    let inner = router
        .layer(middleware::from_fn(security_middleware))
        .layer(DefaultBodyLimit::max(MAX_BODY_BYTES));
    let outer = inner.layer(cors).layer(TraceLayer::new_for_http());
    Ok(outer)
}
async fn create_session<R, P, S>(
State(state): State<PjsAppState<R, P, S>>,
headers: axum::http::HeaderMap,
Json(request): Json<CreateSessionRequest>,
) -> Result<Json<CreateSessionResponse>, PjsError>
where
R: StreamRepositoryGat + Send + Sync + 'static,
P: EventPublisherGat + Send + Sync + 'static,
S: StreamStoreGat + Send + Sync + 'static,
{
let config = SessionConfig {
max_concurrent_streams: request.max_concurrent_streams.unwrap_or(10),
session_timeout_seconds: request.timeout_seconds.unwrap_or(3600),
default_stream_config: Default::default(),
enable_compression: true,
metadata: Default::default(),
};
let user_agent = headers
.get(header::USER_AGENT)
.and_then(|h| h.to_str().ok())
.map(String::from);
let command = CreateSessionCommand {
config,
client_info: request.client_info,
user_agent,
ip_address: None,
};
let session_id: SessionId = CommandHandlerGat::handle(&*state.command_handler, command)
.await
.map_err(PjsError::Application)?;
let expires_at = chrono::Utc::now()
+ chrono::Duration::seconds(request.timeout_seconds.unwrap_or(3600) as i64);
Ok(Json(CreateSessionResponse {
session_id: session_id.to_string(),
expires_at,
}))
}
async fn get_session<R, P, S>(
State(state): State<PjsAppState<R, P, S>>,
AxumPath(session_id): AxumPath<String>,
) -> Result<Json<SessionResponse>, PjsError>
where
R: StreamRepositoryGat + Send + Sync + 'static,
P: EventPublisherGat + Send + Sync + 'static,
S: StreamStoreGat + Send + Sync + 'static,
{
let session_id =
SessionId::from_string(&session_id).map_err(|_| PjsError::InvalidSessionId(session_id))?;
let query = GetSessionQuery {
session_id: session_id.into(),
};
let response = <SessionQueryHandler<R> as QueryHandlerGat<GetSessionQuery>>::handle(
&*state.session_query_handler,
query,
)
.await
.map_err(PjsError::Application)?;
Ok(Json(response))
}
/// Handles `GET /pjs/sessions/{session_id}/health` by reporting the health of
/// one session.
///
/// # Errors
///
/// [`PjsError::InvalidSessionId`] (400) for an unparsable id. Otherwise, any
/// error from the session query handler.
async fn session_health<R, P, S>(
    State(state): State<PjsAppState<R, P, S>>,
    AxumPath(raw_session_id): AxumPath<String>,
) -> Result<Json<SessionHealthResponse>, PjsError>
where
    R: StreamRepositoryGat + Send + Sync + 'static,
    P: EventPublisherGat + Send + Sync + 'static,
    S: StreamStoreGat + Send + Sync + 'static,
{
    let Ok(parsed_id) = SessionId::from_string(&raw_session_id) else {
        return Err(PjsError::InvalidSessionId(raw_session_id));
    };
    let query = GetSessionHealthQuery {
        session_id: parsed_id.into(),
    };
    <SessionQueryHandler<R> as QueryHandlerGat<GetSessionHealthQuery>>::handle(
        &*state.session_query_handler,
        query,
    )
    .await
    .map(|outcome| Json(SessionHealthResponse::from(outcome.health)))
    .map_err(PjsError::Application)
}
async fn create_stream<R, P, S>(
State(state): State<PjsAppState<R, P, S>>,
AxumPath(session_id): AxumPath<String>,
Json(request): Json<StartStreamRequest>,
) -> Result<Json<serde_json::Value>, PjsError>
where
R: StreamRepositoryGat + Send + Sync + 'static,
P: EventPublisherGat + Send + Sync + 'static,
S: StreamStoreGat + Send + Sync + 'static,
{
let session_id =
SessionId::from_string(&session_id).map_err(|_| PjsError::InvalidSessionId(session_id))?;
let command = CreateStreamCommand {
session_id: session_id.into(),
source_data: request.data,
config: None,
};
let stream_id: StreamId = CommandHandlerGat::handle(&*state.command_handler, command)
.await
.map_err(PjsError::Application)?;
Ok(Json(serde_json::json!({
"stream_id": stream_id.to_string(),
"status": "created"
})))
}
/// Handles `POST /pjs/sessions/{session_id}/streams/{stream_id}/start` by
/// starting a stream. Responds with `{"stream_id", "status": "started"}`, where
/// the id is the parsed value's canonical string form.
///
/// # Errors
///
/// [`PjsError::InvalidSessionId`] / [`PjsError::InvalidStreamId`] (400) for
/// unparsable ids, checked in that order. Otherwise, any error from the command
/// handler.
async fn start_stream<R, P, S>(
    State(state): State<PjsAppState<R, P, S>>,
    AxumPath((raw_session_id, raw_stream_id)): AxumPath<(String, String)>,
) -> Result<Json<serde_json::Value>, PjsError>
where
    R: StreamRepositoryGat + Send + Sync + 'static,
    P: EventPublisherGat + Send + Sync + 'static,
    S: StreamStoreGat + Send + Sync + 'static,
{
    let Ok(session) = SessionId::from_string(&raw_session_id) else {
        return Err(PjsError::InvalidSessionId(raw_session_id));
    };
    let Ok(stream) = StreamId::from_string(&raw_stream_id) else {
        return Err(PjsError::InvalidStreamId(raw_stream_id));
    };
    let stream_label = stream.to_string();
    let command = StartStreamCommand {
        session_id: session.into(),
        stream_id: stream.into(),
    };
    <SessionCommandHandler<R, P> as CommandHandlerGat<StartStreamCommand>>::handle(
        &*state.command_handler,
        command,
    )
    .await
    .map_err(PjsError::Application)?;
    Ok(Json(serde_json::json!({
        "stream_id": stream_label,
        "status": "started"
    })))
}
async fn get_stream<R, P, S>(
State(state): State<PjsAppState<R, P, S>>,
AxumPath((session_id, stream_id)): AxumPath<(String, String)>,
) -> Result<Json<StreamResponse>, PjsError>
where
R: StreamRepositoryGat + Send + Sync + 'static,
P: EventPublisherGat + Send + Sync + 'static,
S: StreamStoreGat + Send + Sync + 'static,
{
let session_id = SessionId::from_string(&session_id)
.map_err(|_| PjsError::InvalidSessionId(session_id.clone()))?;
let stream_id =
StreamId::from_string(&stream_id).map_err(|_| PjsError::InvalidStreamId(stream_id))?;
let query = GetStreamQuery {
session_id: session_id.into(),
stream_id: stream_id.into(),
};
let response = <StreamQueryHandler<R, S> as QueryHandlerGat<GetStreamQuery>>::handle(
&*state.stream_query_handler,
query,
)
.await
.map_err(PjsError::Application)?;
Ok(Json(response))
}
async fn list_sessions<R, P, S>(
State(state): State<PjsAppState<R, P, S>>,
Query(params): Query<PaginationParams>,
) -> Result<Json<SessionsResponse>, PjsError>
where
R: StreamRepositoryGat + Send + Sync + 'static,
P: EventPublisherGat + Send + Sync + 'static,
S: StreamStoreGat + Send + Sync + 'static,
{
let query = GetActiveSessionsQuery {
limit: params.limit,
offset: params.offset,
};
let response = <SessionQueryHandler<R> as QueryHandlerGat<GetActiveSessionsQuery>>::handle(
&*state.session_query_handler,
query,
)
.await
.map_err(PjsError::Application)?;
Ok(Json(response))
}
async fn search_sessions<R, P, S>(
State(state): State<PjsAppState<R, P, S>>,
Query(params): Query<SearchSessionsParams>,
) -> Result<Json<SessionsResponse>, PjsError>
where
R: StreamRepositoryGat + Send + Sync + 'static,
P: EventPublisherGat + Send + Sync + 'static,
S: StreamStoreGat + Send + Sync + 'static,
{
let sort_by = params.sort_by.as_deref().and_then(|s| match s {
"created_at" => Some(SessionSortField::CreatedAt),
"updated_at" => Some(SessionSortField::UpdatedAt),
"stream_count" => Some(SessionSortField::StreamCount),
"total_bytes" => Some(SessionSortField::TotalBytes),
_ => None,
});
let sort_order = params.sort_order.as_deref().and_then(|s| match s {
"ascending" | "asc" => Some(SortOrder::Ascending),
"descending" | "desc" => Some(SortOrder::Descending),
_ => None,
});
let query = SearchSessionsQuery {
filters: SessionFilters {
state: params.state,
created_after: None,
created_before: None,
client_info: None,
has_active_streams: None,
},
sort_by,
sort_order,
limit: params.limit,
offset: params.offset,
};
let response = <SessionQueryHandler<R> as QueryHandlerGat<SearchSessionsQuery>>::handle(
&*state.session_query_handler,
query,
)
.await
.map_err(PjsError::Application)?;
Ok(Json(response))
}
/// Query parameters of `GET /pjs/sessions`.
#[derive(Debug, Deserialize)]
pub struct PaginationParams {
    /// Maximum number of sessions to return; forwarded unchanged.
    pub limit: Option<usize>,
    /// Number of sessions to skip; forwarded unchanged.
    pub offset: Option<usize>,
}

/// Query parameters of `GET /pjs/sessions/search`.
#[derive(Debug, Deserialize)]
pub struct SearchSessionsParams {
    /// Session-state filter, passed through unvalidated to `SessionFilters::state`.
    pub state: Option<String>,
    /// One of `created_at`, `updated_at`, `stream_count`, `total_bytes`;
    /// any other value is ignored.
    pub sort_by: Option<String>,
    /// `asc`/`ascending` or `desc`/`descending`; any other value is ignored.
    pub sort_order: Option<String>,
    /// Maximum number of sessions to return; forwarded unchanged.
    pub limit: Option<usize>,
    /// Number of sessions to skip; forwarded unchanged.
    pub offset: Option<usize>,
}
/// Handles `GET /pjs/health` with a static liveness payload: status, crate
/// version and a fixed list of feature tags. It takes no state and performs no
/// checks.
async fn system_health() -> Json<serde_json::Value> {
    const FEATURES: [&str; 3] = ["pjs_streaming", "axum_integration", "gat_handlers"];
    let payload = serde_json::json!({
        "status": "healthy",
        "version": env!("CARGO_PKG_VERSION"),
        "features": FEATURES
    });
    Json(payload)
}
async fn get_system_stats<R, P, S>(
State(state): State<PjsAppState<R, P, S>>,
) -> Result<Json<SystemStatsResponse>, PjsError>
where
R: StreamRepositoryGat + Send + Sync + 'static,
P: EventPublisherGat + Send + Sync + 'static,
S: StreamStoreGat + Send + Sync + 'static,
{
let query = GetSystemStatsQuery {
include_historical: false,
};
let response = <SystemQueryHandler<R> as QueryHandlerGat<GetSystemStatsQuery>>::handle(
&*state.system_handler,
query,
)
.await
.map_err(PjsError::Application)?;
Ok(Json(response))
}
/// Query parameters of `GET /pjs/sessions/{session_id}/streams/{stream_id}/frames`.
#[derive(Debug, Deserialize)]
pub struct FrameQueryParams {
    /// Forwarded as `GetStreamFramesQuery::since_sequence`; presumably a
    /// lower sequence bound — verify against the stream query handler.
    pub since_sequence: Option<u64>,
    /// Priority filter; validated with `Priority::new`, invalid values yield a 400.
    pub priority: Option<u8>,
    /// Maximum number of frames; forwarded unchanged.
    pub limit: Option<usize>,
}
async fn get_stream_frames<R, P, S>(
State(state): State<PjsAppState<R, P, S>>,
AxumPath((session_id, stream_id)): AxumPath<(String, String)>,
Query(params): Query<FrameQueryParams>,
) -> Result<Json<FramesResponse>, PjsError>
where
R: StreamRepositoryGat + Send + Sync + 'static,
P: EventPublisherGat + Send + Sync + 'static,
S: StreamStoreGat + Send + Sync + 'static,
{
let session_id = SessionId::from_string(&session_id)
.map_err(|_| PjsError::InvalidSessionId(session_id.clone()))?;
let stream_id =
StreamId::from_string(&stream_id).map_err(|_| PjsError::InvalidStreamId(stream_id))?;
let priority_filter = params
.priority
.map(|p| Priority::new(p).map(Into::into))
.transpose()
.map_err(|e: crate::domain::DomainError| PjsError::InvalidPriority(e.to_string()))?;
let query = GetStreamFramesQuery {
session_id: session_id.into(),
stream_id: stream_id.into(),
since_sequence: params.since_sequence,
priority_filter,
limit: params.limit,
};
let response = <StreamQueryHandler<R, S> as QueryHandlerGat<GetStreamFramesQuery>>::handle(
&*state.stream_query_handler,
query,
)
.await
.map_err(PjsError::Application)?;
Ok(Json(response))
}
async fn get_session_stats<R, P, S>(
State(state): State<PjsAppState<R, P, S>>,
AxumPath(session_id): AxumPath<String>,
) -> Result<Json<SessionStatsResponse>, PjsError>
where
R: StreamRepositoryGat + Send + Sync + 'static,
P: EventPublisherGat + Send + Sync + 'static,
S: StreamStoreGat + Send + Sync + 'static,
{
let session_id =
SessionId::from_string(&session_id).map_err(|_| PjsError::InvalidSessionId(session_id))?;
let query = GetSessionStatsQuery {
session_id: session_id.into(),
};
let response = <SessionQueryHandler<R> as QueryHandlerGat<GetSessionStatsQuery>>::handle(
&*state.session_query_handler,
query,
)
.await
.map_err(PjsError::Application)?;
Ok(Json(response))
}
/// Errors produced by the PJS HTTP layer.
///
/// The `IntoResponse` impl renders every variant as `{"error": "<Display text>"}`
/// with a status code chosen per variant.
#[derive(Debug, thiserror::Error)]
pub enum PjsError {
    /// Error returned by an application command or query handler; the HTTP
    /// status depends on the wrapped `ApplicationError` variant.
    #[error("Application error: {0}")]
    Application(#[from] crate::application::ApplicationError),
    /// A `session_id` path segment failed to parse; holds the raw input (400).
    #[error("Invalid session ID: {0}")]
    InvalidSessionId(String),
    /// A `stream_id` path segment failed to parse; holds the raw input (400).
    #[error("Invalid stream ID: {0}")]
    InvalidStreamId(String),
    /// A `priority` query value was rejected by `Priority::new`; holds the
    /// domain error text (400).
    #[error("Invalid priority: {0}")]
    InvalidPriority(String),
    /// Server-side HTTP setup failure such as an invalid CORS configuration (500).
    #[error("HTTP error: {0}")]
    HttpError(String),
}
/// Renders a [`PjsError`] as a JSON body `{"error": "<Display text>"}`.
///
/// Status mapping:
///
/// * 400: malformed ids, invalid priority, application `Validation`;
/// * 404: application `NotFound`;
/// * 401: application `Authorization`;
/// * 409: application `Concurrency` / `Conflict`;
/// * 500: application `Domain` / `Logic`, and `HttpError`.
///
/// NOTE(review): the full error text is also returned for 500 responses, which
/// may expose internal details. Consider whether `Authorization` failures should
/// map to 403 rather than 401.
impl IntoResponse for PjsError {
    fn into_response(self) -> Response {
        use crate::application::ApplicationError as AppError;

        let status = match &self {
            PjsError::Application(AppError::NotFound(_)) => StatusCode::NOT_FOUND,
            PjsError::Application(AppError::Validation(_))
            | PjsError::InvalidSessionId(_)
            | PjsError::InvalidStreamId(_)
            | PjsError::InvalidPriority(_) => StatusCode::BAD_REQUEST,
            PjsError::Application(AppError::Authorization(_)) => StatusCode::UNAUTHORIZED,
            PjsError::Application(AppError::Concurrency(_) | AppError::Conflict(_)) => {
                StatusCode::CONFLICT
            }
            PjsError::Application(AppError::Domain(_) | AppError::Logic(_))
            | PjsError::HttpError(_) => StatusCode::INTERNAL_SERVER_ERROR,
        };
        let body = Json(serde_json::json!({
            "error": self.to_string()
        }));
        (status, body).into_response()
    }
}
// Unit tests for CORS configuration and application-state wiring, plus a router
// smoke test backed by in-memory mocks.
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty allow-list must build successfully (as a deny-all layer).
    /// Only construction is checked here, not actual request rejection.
    #[test]
    fn cors_empty_origins_denies_all() {
        let config = HttpServerConfig {
            allowed_origins: vec![],
        };
        let result = build_cors_layer(&config);
        assert!(
            result.is_ok(),
            "empty origins should return Ok (deny-all layer)"
        );
    }

    /// A lone `"*"` is a valid configuration.
    #[test]
    fn cors_wildcard_only_is_ok() {
        let config = HttpServerConfig {
            allowed_origins: vec!["*".to_string()],
        };
        let result = build_cors_layer(&config);
        assert!(result.is_ok(), "wildcard-only should return Ok");
    }

    /// Mixing `"*"` with explicit origins is rejected with a message naming the wildcard.
    #[test]
    fn cors_mixed_wildcard_and_explicit_is_err() {
        let config = HttpServerConfig {
            allowed_origins: vec!["*".to_string(), "http://example.com".to_string()],
        };
        let result = build_cors_layer(&config);
        assert!(
            result.is_err(),
            "mixing wildcard with explicit origins must fail"
        );
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("wildcard"),
            "error message should mention wildcard: {msg}"
        );
    }

    /// A single explicit origin builds successfully.
    #[test]
    fn cors_valid_single_origin_is_ok() {
        let config = HttpServerConfig {
            allowed_origins: vec!["http://example.com".to_string()],
        };
        assert!(build_cors_layer(&config).is_ok());
    }

    /// Several explicit origins build successfully.
    #[test]
    fn cors_valid_multiple_origins_is_ok() {
        let config = HttpServerConfig {
            allowed_origins: vec![
                "https://app.example.com".to_string(),
                "https://admin.example.com".to_string(),
            ],
        };
        assert!(build_cors_layer(&config).is_ok());
    }

    /// An origin containing a newline is not a valid header value and must be rejected.
    #[test]
    fn cors_invalid_origin_string_is_err() {
        let config = HttpServerConfig {
            allowed_origins: vec!["not a\nvalid header".to_string()],
        };
        let result = build_cors_layer(&config);
        assert!(result.is_err(), "invalid origin string must return Err");
    }

    /// Guards the `expect` in the default router constructors.
    #[test]
    fn default_config_is_valid() {
        assert!(
            build_cors_layer(&HttpServerConfig::default()).is_ok(),
            "default HttpServerConfig must produce a valid CORS layer"
        );
    }

    // Domain types and port traits needed by the in-memory mocks below.
    use crate::domain::{
        aggregates::StreamSession,
        entities::Stream,
        events::DomainEvent,
        ports::{
            EventPublisherGat, Pagination, PriorityDistribution, SessionHealthSnapshot,
            SessionQueryCriteria, SessionQueryResult, StreamFilter, StreamRepositoryGat,
            StreamStatistics, StreamStatus, StreamStoreGat,
        },
        value_objects::{SessionId, StreamId},
    };
    use chrono::Utc;
    use std::collections::HashMap;

    /// In-memory session repository keyed by `SessionId`.
    struct MockRepository {
        sessions: parking_lot::Mutex<HashMap<SessionId, StreamSession>>,
    }

    impl MockRepository {
        fn new() -> Self {
            Self {
                sessions: parking_lot::Mutex::new(HashMap::new()),
            }
        }
    }

    // Futures are declared with `impl Trait` in associated-type position; each
    // method's body is an `async move` block over the locked map.
    impl StreamRepositoryGat for MockRepository {
        type FindSessionFuture<'a>
            = impl std::future::Future<Output = crate::domain::DomainResult<Option<StreamSession>>>
            + Send
            + 'a
        where
            Self: 'a;
        type SaveSessionFuture<'a>
            = impl std::future::Future<Output = crate::domain::DomainResult<()>> + Send + 'a
        where
            Self: 'a;
        type RemoveSessionFuture<'a>
            = impl std::future::Future<Output = crate::domain::DomainResult<()>> + Send + 'a
        where
            Self: 'a;
        type FindActiveSessionsFuture<'a>
            = impl std::future::Future<Output = crate::domain::DomainResult<Vec<StreamSession>>>
            + Send
            + 'a
        where
            Self: 'a;
        type FindSessionsByCriteriaFuture<'a>
            = impl std::future::Future<Output = crate::domain::DomainResult<SessionQueryResult>>
            + Send
            + 'a
        where
            Self: 'a;
        type GetSessionHealthFuture<'a>
            = impl std::future::Future<Output = crate::domain::DomainResult<SessionHealthSnapshot>>
            + Send
            + 'a
        where
            Self: 'a;
        type SessionExistsFuture<'a>
            = impl std::future::Future<Output = crate::domain::DomainResult<bool>> + Send + 'a
        where
            Self: 'a;

        fn find_session(&self, session_id: SessionId) -> Self::FindSessionFuture<'_> {
            async move { Ok(self.sessions.lock().get(&session_id).cloned()) }
        }

        fn save_session(&self, session: StreamSession) -> Self::SaveSessionFuture<'_> {
            async move {
                self.sessions.lock().insert(session.id(), session);
                Ok(())
            }
        }

        fn remove_session(&self, session_id: SessionId) -> Self::RemoveSessionFuture<'_> {
            async move {
                self.sessions.lock().remove(&session_id);
                Ok(())
            }
        }

        // Treats every stored session as active.
        fn find_active_sessions(&self) -> Self::FindActiveSessionsFuture<'_> {
            async move { Ok(self.sessions.lock().values().cloned().collect()) }
        }

        // Ignores the criteria and only applies offset/limit pagination.
        fn find_sessions_by_criteria(
            &self,
            _criteria: SessionQueryCriteria,
            pagination: Pagination,
        ) -> Self::FindSessionsByCriteriaFuture<'_> {
            async move {
                let sessions: Vec<_> = self.sessions.lock().values().cloned().collect();
                let total_count = sessions.len();
                let paginated: Vec<_> = sessions
                    .into_iter()
                    .skip(pagination.offset)
                    .take(pagination.limit)
                    .collect();
                let has_more = pagination.offset + paginated.len() < total_count;
                Ok(SessionQueryResult {
                    sessions: paginated,
                    total_count,
                    has_more,
                    query_duration_ms: 0,
                    scan_limit_reached: false,
                })
            }
        }

        // Always reports a healthy, idle session, whether or not it is stored.
        fn get_session_health(&self, session_id: SessionId) -> Self::GetSessionHealthFuture<'_> {
            async move {
                Ok(SessionHealthSnapshot {
                    session_id,
                    is_healthy: true,
                    active_streams: 0,
                    total_frames: 0,
                    last_activity: Utc::now(),
                    error_rate: 0.0,
                    metrics: HashMap::new(),
                })
            }
        }

        fn session_exists(&self, session_id: SessionId) -> Self::SessionExistsFuture<'_> {
            async move { Ok(self.sessions.lock().contains_key(&session_id)) }
        }
    }

    /// Event publisher that accepts and discards every event.
    struct MockEventPublisher;

    impl EventPublisherGat for MockEventPublisher {
        type PublishFuture<'a>
            = impl std::future::Future<Output = crate::domain::DomainResult<()>> + Send + 'a
        where
            Self: 'a;
        type PublishBatchFuture<'a>
            = impl std::future::Future<Output = crate::domain::DomainResult<()>> + Send + 'a
        where
            Self: 'a;

        fn publish(&self, _event: DomainEvent) -> Self::PublishFuture<'_> {
            async move { Ok(()) }
        }

        fn publish_batch(&self, _events: Vec<DomainEvent>) -> Self::PublishBatchFuture<'_> {
            async move { Ok(()) }
        }
    }

    /// Stream store that stores nothing: writes succeed, lookups find nothing,
    /// and statistics are all zero.
    struct MockStreamStore;

    impl StreamStoreGat for MockStreamStore {
        type StoreStreamFuture<'a>
            = impl std::future::Future<Output = crate::domain::DomainResult<()>> + Send + 'a
        where
            Self: 'a;
        type GetStreamFuture<'a>
            = impl std::future::Future<Output = crate::domain::DomainResult<Option<Stream>>>
            + Send
            + 'a
        where
            Self: 'a;
        type DeleteStreamFuture<'a>
            = impl std::future::Future<Output = crate::domain::DomainResult<()>> + Send + 'a
        where
            Self: 'a;
        type ListStreamsForSessionFuture<'a>
            =
            impl std::future::Future<Output = crate::domain::DomainResult<Vec<Stream>>> + Send + 'a
        where
            Self: 'a;
        type FindStreamsBySessionFuture<'a>
            =
            impl std::future::Future<Output = crate::domain::DomainResult<Vec<Stream>>> + Send + 'a
        where
            Self: 'a;
        type UpdateStreamStatusFuture<'a>
            = impl std::future::Future<Output = crate::domain::DomainResult<()>> + Send + 'a
        where
            Self: 'a;
        type GetStreamStatisticsFuture<'a>
            = impl std::future::Future<Output = crate::domain::DomainResult<StreamStatistics>>
            + Send
            + 'a
        where
            Self: 'a;

        fn store_stream(&self, _stream: Stream) -> Self::StoreStreamFuture<'_> {
            async move { Ok(()) }
        }

        fn get_stream(&self, _stream_id: StreamId) -> Self::GetStreamFuture<'_> {
            async move { Ok(None) }
        }

        fn delete_stream(&self, _stream_id: StreamId) -> Self::DeleteStreamFuture<'_> {
            async move { Ok(()) }
        }

        fn list_streams_for_session(
            &self,
            _session_id: SessionId,
        ) -> Self::ListStreamsForSessionFuture<'_> {
            async move { Ok(vec![]) }
        }

        fn find_streams_by_session(
            &self,
            _session_id: SessionId,
            _filter: StreamFilter,
        ) -> Self::FindStreamsBySessionFuture<'_> {
            async move { Ok(vec![]) }
        }

        fn update_stream_status(
            &self,
            _stream_id: StreamId,
            _status: StreamStatus,
        ) -> Self::UpdateStreamStatusFuture<'_> {
            async move { Ok(()) }
        }

        fn get_stream_statistics(
            &self,
            _stream_id: StreamId,
        ) -> Self::GetStreamStatisticsFuture<'_> {
            async move {
                Ok(StreamStatistics {
                    total_frames: 0,
                    total_bytes: 0,
                    priority_distribution: PriorityDistribution::default(),
                    avg_frame_size: 0.0,
                    creation_time: Utc::now(),
                    completion_time: None,
                    processing_duration: None,
                })
            }
        }
    }

    /// The health probe reports `"healthy"` and a non-empty feature list.
    #[tokio::test]
    async fn test_system_health() {
        let response = system_health().await;
        let health_data: serde_json::Value = response.0;
        assert_eq!(health_data["status"], "healthy");
        assert!(!health_data["features"].as_array().unwrap().is_empty());
    }

    /// `PjsAppState::new` wires all handlers from the mocks without panicking.
    #[tokio::test]
    async fn test_app_state_creation() {
        let repository = Arc::new(MockRepository::new());
        let event_publisher = Arc::new(MockEventPublisher);
        let stream_store = Arc::new(MockStreamStore);
        let _state = PjsAppState::new(repository, event_publisher, stream_store);
    }

    /// Uptime is measured from the supplied start instant rather than a
    /// hard-coded placeholder.
    #[tokio::test]
    async fn test_get_system_stats_returns_real_uptime() {
        use crate::application::handlers::QueryHandlerGat;
        use crate::application::handlers::query_handlers::SystemQueryHandler;
        use crate::application::queries::GetSystemStatsQuery;
        use std::time::{Duration, Instant};
        let repository = Arc::new(MockRepository::new());
        // Pretend the system started five seconds ago.
        let started_at = Instant::now() - Duration::from_secs(5);
        let handler = SystemQueryHandler::with_start_time(repository, started_at);
        let query = GetSystemStatsQuery {
            include_historical: false,
        };
        let result = QueryHandlerGat::handle(&handler, query).await.unwrap();
        assert!(
            result.uptime_seconds >= 5,
            "uptime_seconds should be at least 5, got {}",
            result.uptime_seconds
        );
        assert_ne!(
            result.uptime_seconds, 3600,
            "uptime_seconds must not be the hard-coded placeholder 3600"
        );
    }

    /// The global metrics recorder renders non-JSON output. Installing it twice
    /// succeeds, and both handles render the same metrics.
    #[cfg(feature = "metrics")]
    #[tokio::test]
    async fn test_metrics_endpoint_returns_prometheus_format() {
        use crate::infrastructure::http::metrics::install_global_recorder;
        let handle = install_global_recorder().expect("recorder install should succeed");
        let rendered = handle.render();
        assert!(
            !rendered.contains("{\"error\""),
            "rendered metrics should not be a JSON error: {rendered}"
        );
        let handle2 = install_global_recorder().expect("second call must not fail");
        assert_eq!(
            handle.render(),
            handle2.render(),
            "both handles must render the same metrics"
        );
    }

    /// Building the router with the `metrics` feature (which adds `/metrics`) succeeds.
    #[cfg(feature = "metrics")]
    #[test]
    fn test_metrics_router_has_metrics_route() {
        let _router =
            create_pjs_router_with_config::<MockRepository, MockEventPublisher, MockStreamStore>(
                &HttpServerConfig::default(),
            )
            .expect("router should build successfully with metrics feature");
    }

    /// `/pjs/sessions/search` is routed to the search handler rather than
    /// matching `/pjs/sessions/{session_id}`, and returns 200 with no query
    /// parameters.
    #[tokio::test]
    async fn search_sessions_route_returns_ok() {
        use axum::http::Request;
        use tower::ServiceExt;
        let repository = Arc::new(MockRepository::new());
        let event_publisher = Arc::new(MockEventPublisher);
        let stream_store = Arc::new(MockStreamStore);
        let state = PjsAppState::new(repository, event_publisher, stream_store);
        let router =
            create_pjs_router_with_config::<MockRepository, MockEventPublisher, MockStreamStore>(
                &HttpServerConfig::default(),
            )
            .expect("router should build")
            .with_state(state);
        let req = Request::builder()
            .uri("/pjs/sessions/search")
            .body(axum::body::Body::empty())
            .unwrap();
        let resp = router.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }
}