#![warn(missing_docs)]
pub mod actions;
#[cfg(feature = "server")]
pub mod capture;
#[cfg(feature = "server")]
pub mod config;
#[doc(hidden)]
pub mod constants;
pub mod error;
#[cfg(not(target_arch = "wasm32"))]
#[doc(hidden)]
pub mod event_stream;
pub mod health;
#[doc(hidden)]
pub mod metrics;
#[cfg(feature = "mirror")]
#[doc(hidden)]
pub mod mirror;
#[cfg(not(target_arch = "wasm32"))]
#[doc(hidden)]
pub mod persistence;
#[cfg(any(
not(target_arch = "wasm32"),
all(target_arch = "wasm32", feature = "wasm"),
all(target_os = "wasi", target_env = "p2", feature = "wasi"),
))]
#[doc(hidden)]
pub mod retention;
#[doc(hidden)]
pub mod runtime;
#[doc(hidden)]
pub mod sequence;
pub mod server;
#[doc(hidden)]
pub mod shard_iterator;
pub mod store;
pub mod types;
#[doc(hidden)]
pub mod util;
#[doc(hidden)]
pub mod validation;
use axum::Router;
#[cfg(all(feature = "server", not(target_arch = "wasm32")))]
use axum::body::Body;
use axum::middleware;
use axum::routing::{any, get};
#[cfg(all(feature = "server", not(target_arch = "wasm32")))]
use hyper::body::Incoming;
#[cfg(all(feature = "server", not(target_arch = "wasm32")))]
use hyper_util::rt::{TokioExecutor, TokioIo};
#[cfg(all(feature = "server", not(target_arch = "wasm32")))]
use hyper_util::server::conn::auto::Builder as AutoBuilder;
#[cfg(all(feature = "server", not(target_arch = "wasm32")))]
use hyper_util::server::graceful::GracefulShutdown;
#[cfg(all(feature = "server", not(target_arch = "wasm32")))]
use hyper_util::service::TowerToHyperService;
#[cfg(all(feature = "server", not(target_arch = "wasm32")))]
use std::time::Duration;
use store::Store;
#[cfg(any(
not(target_arch = "wasm32"),
all(target_os = "wasi", target_env = "p2", feature = "wasi"),
))]
use store::StoreOptions;
#[cfg(all(feature = "server", not(target_arch = "wasm32")))]
use tokio::task::JoinSet;
#[cfg(all(feature = "server", not(target_arch = "wasm32")))]
use tower::ServiceExt as _;
#[cfg(feature = "access-log")]
use tower_http::trace::{DefaultMakeSpan, DefaultOnResponse, TraceLayer};
/// Maximum time [`serve_plain_http`] waits for in-flight connections to finish
/// after shutdown has been requested; connection tasks still running after
/// this deadline are aborted.
#[cfg(all(feature = "server", not(target_arch = "wasm32")))]
const PLAIN_HTTP_GRACEFUL_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(10);
#[cfg(any(
not(target_arch = "wasm32"),
all(target_os = "wasi", target_env = "p2", feature = "wasi"),
))]
pub fn create_app(options: StoreOptions) -> (Router, Store) {
let store = Store::new(options.clone());
let app = create_router(store.clone());
spawn_retention_reaper(&store, &options);
(app, store)
}
/// Builds a [`Store`] and its HTTP [`Router`], passing an optional capture
/// writer through to [`Store::with_capture`].
///
/// Apart from how the store is constructed, this performs the same steps as
/// [`create_app`]: the router is created for the store, and the retention
/// reaper is started when `options.retention_check_interval_secs` is greater
/// than zero.
///
/// Returns the router and the store it is bound to, in that order.
#[cfg(all(feature = "server", not(target_arch = "wasm32")))]
pub fn create_app_with_capture(
    options: StoreOptions,
    capture: Option<capture::CaptureWriter>,
) -> (Router, Store) {
    // The store consumes its options, so it gets its own copy; the original
    // is still needed below to decide whether to start the reaper.
    let store = Store::with_capture(StoreOptions::clone(&options), capture);
    let router = create_router(Store::clone(&store));
    spawn_retention_reaper(&store, &options);
    (router, store)
}
/// Assembles the HTTP [`Router`] backed by `store`.
///
/// Registered routes:
///
/// * `GET /_health`, `GET /_health/live`, `GET /_health/ready` — handled by
///   the corresponding functions in [`health`].
/// * `GET /metrics` — handled by `health::metrics`.
/// * everything else, for any method — handled by `server::handler`.
///
/// `server::kinesis_413_middleware` is layered over the whole router. With
/// the `access-log` feature enabled, an HTTP trace layer that emits spans and
/// response events at `INFO` level is layered on after that.
pub fn create_router(store: Store) -> Router {
    // Stage 1: routes that still expect a `Store` to be supplied as state.
    let stateful: Router<Store> = Router::new()
        .route("/_health", get(health::health))
        .route("/_health/live", get(health::live))
        .route("/_health/ready", get(health::ready))
        .route("/metrics", get(health::metrics))
        .fallback(any(server::handler));

    // Stage 2: bind the state, then wrap every route in the middleware.
    let router: Router = stateful
        .with_state(store)
        .layer(middleware::from_fn(server::kinesis_413_middleware));

    #[cfg(feature = "access-log")]
    let router = {
        let span_factory = DefaultMakeSpan::new().level(tracing::Level::INFO);
        let response_logger = DefaultOnResponse::new().level(tracing::Level::INFO);
        router.layer(
            TraceLayer::new_for_http()
                .make_span_with(span_factory)
                .on_response(response_logger),
        )
    };

    router
}
/// Reports whether an `accept` failure concerns only the connection that was
/// being accepted rather than the listening socket itself.
///
/// Such errors (the peer aborted, reset, or refused the connection before it
/// was handed to us) say nothing about the health of the listener, so the
/// accept loop can simply move on to the next connection.
#[cfg(all(feature = "server", not(target_arch = "wasm32")))]
fn is_transient_accept_error(err: &std::io::Error) -> bool {
    matches!(
        err.kind(),
        std::io::ErrorKind::ConnectionAborted
            | std::io::ErrorKind::ConnectionReset
            | std::io::ErrorKind::ConnectionRefused
    )
}

/// Serves `app` over plain (non-TLS) HTTP on `listener` until `shutdown`
/// resolves, then drains in-flight connections.
///
/// Every accepted connection is served on its own task, using hyper-util's
/// auto connection builder with upgrade support. Once `shutdown` completes:
///
/// 1. the listener is dropped so no further connections are accepted;
/// 2. open connections are asked to shut down gracefully and are given up to
///    [`PLAIN_HTTP_GRACEFUL_SHUTDOWN_TIMEOUT`] to finish;
/// 3. connection tasks still running after the deadline are aborted;
/// 4. all connection tasks are joined before returning.
///
/// # Errors
///
/// Returns the error from `accept` when accepting fails for a reason other
/// than a per-connection failure (connection aborted, reset, or refused);
/// those per-connection failures are logged at debug level and skipped. On
/// such an error the function returns immediately, without running the
/// drain steps above.
#[cfg(all(feature = "server", not(target_arch = "wasm32")))]
pub async fn serve_plain_http(
    listener: tokio::net::TcpListener,
    app: Router,
    shutdown: impl std::future::Future<Output = ()> + Send + 'static,
) -> std::io::Result<()> {
    let mut shutdown = std::pin::pin!(shutdown);
    let graceful = GracefulShutdown::new();
    let mut connections = JoinSet::new();
    loop {
        tokio::select! {
            // Poll branches in declaration order so a completed shutdown
            // signal always wins over accepting another connection.
            biased;
            _ = &mut shutdown => break,
            accept = listener.accept() => {
                let (stream, _addr) = match accept {
                    Ok(accepted) => accepted,
                    // A peer that went away while queued must not take the
                    // whole server down; skip it and keep accepting.
                    Err(err) if is_transient_accept_error(&err) => {
                        tracing::debug!("ignoring transient accept error: {err}");
                        continue;
                    }
                    Err(err) => return Err(err),
                };
                let io = TokioIo::new(stream);
                // Adapt the axum router to a hyper service, converting the
                // incoming hyper body into an axum `Body`.
                let service = TowerToHyperService::new(
                    app.clone()
                        .into_service::<Body>()
                        .map_request(|req: axum::http::Request<Incoming>| req.map(Body::new)),
                );
                let builder = AutoBuilder::new(TokioExecutor::new());
                // Register the connection with the graceful-shutdown tracker
                // so it is signalled during the drain phase below.
                let connection = graceful.watch(
                    builder
                        .serve_connection_with_upgrades(io, service)
                        .into_owned(),
                );
                connections.spawn(async move {
                    if let Err(err) = connection.await {
                        tracing::debug!("plain connection closed with error: {err}");
                    }
                });
            }
            // Reap finished connection tasks while running so the set does
            // not keep accumulating completed entries.
            Some(result) = connections.join_next(), if !connections.is_empty() => {
                if let Err(err) = result {
                    tracing::debug!("plain connection task failed: {err}");
                }
            }
        }
    }
    // Stop accepting new connections before draining the existing ones.
    drop(listener);
    let active_connections = graceful.count();
    if active_connections > 0 {
        tracing::info!(
            connections = active_connections,
            timeout_secs = PLAIN_HTTP_GRACEFUL_SHUTDOWN_TIMEOUT.as_secs(),
            "draining plain HTTP connections"
        );
    }
    if tokio::time::timeout(PLAIN_HTTP_GRACEFUL_SHUTDOWN_TIMEOUT, graceful.shutdown())
        .await
        .is_err()
    {
        tracing::warn!(
            connections = connections.len(),
            timeout_secs = PLAIN_HTTP_GRACEFUL_SHUTDOWN_TIMEOUT.as_secs(),
            "timed out draining plain HTTP connections; aborting remaining tasks"
        );
        connections.abort_all();
    }
    // Join every task (finished or aborted) so none outlives this function.
    while let Some(result) = connections.join_next().await {
        if let Err(err) = result {
            tracing::debug!("plain connection task failed during shutdown: {err}");
        }
    }
    Ok(())
}
/// Starts the background retention reaper for `store` when it is enabled.
///
/// The reaper is enabled when `options.retention_check_interval_secs` is
/// greater than zero; that value is passed to `retention::run_reaper` along
/// with a clone of the store. Otherwise this function does nothing.
#[cfg(any(
    not(target_arch = "wasm32"),
    all(target_os = "wasi", target_env = "p2", feature = "wasi"),
))]
fn spawn_retention_reaper(store: &Store, options: &StoreOptions) {
    let interval_secs = options.retention_check_interval_secs;
    if interval_secs > 0 {
        // The background task needs its own handle to the store.
        let reaper = retention::run_reaper(store.clone(), interval_secs);
        runtime::spawn_background(reaper);
    }
}
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use super::*;
    use crate::actions::{Operation, dispatch};
    use crate::types::StreamStatus;
    use serde_json::json;
    use std::time::Duration;

    /// A stream created through a store built by `create_app` starts out as
    /// `Creating` and is later observed as `Active`.
    #[tokio::test]
    async fn create_app_preserves_background_transitions() {
        const STREAM_NAME: &str = "native-no-default-features";
        const POLL_ATTEMPTS: usize = 50;
        const POLL_INTERVAL: Duration = Duration::from_millis(10);

        let options = StoreOptions {
            create_stream_ms: 1,
            ..StoreOptions::default()
        };
        let (_app, store) = create_app(options);

        let request = json!({
            "StreamName": STREAM_NAME,
            "ShardCount": 1,
        });
        dispatch(&store, Operation::CreateStream, request)
            .await
            .unwrap();

        // Right after creation the stream must not be active yet.
        let initial = store.get_stream(STREAM_NAME).await.unwrap();
        assert_eq!(initial.stream_status, StreamStatus::Creating);

        // Poll until the stream is reported as active.
        let mut attempts_left = POLL_ATTEMPTS;
        while attempts_left > 0 {
            let polled = store.get_stream(STREAM_NAME).await.unwrap();
            if polled.stream_status == StreamStatus::Active {
                return;
            }
            tokio::time::sleep(POLL_INTERVAL).await;
            attempts_left -= 1;
        }

        // One last look after the polling budget is spent; this assertion
        // produces the failure message if the stream never became active.
        let last = store.get_stream(STREAM_NAME).await.unwrap();
        assert_eq!(last.stream_status, StreamStatus::Active);
    }
}