#![recursion_limit = "512"]
use std::error::Error;
use std::io::IsTerminal;
use std::sync::Arc;
use aws_credential_types::provider::ProvideCredentials;
use clap::{Parser, ValueEnum};
use hyper_util::rt::{TokioExecutor, TokioIo};
use hyper_util::server::conn::auto::Builder as ConnBuilder;
use s3s::S3;
use s3s::auth::SimpleAuth;
use s3s::host::SingleDomain;
use s3s::service::S3ServiceBuilder;
use s4_codec::cpu_zstd::CpuZstd;
use s4_codec::dispatcher::{AlwaysDispatcher, SamplingDispatcher};
use s4_codec::passthrough::Passthrough;
use s4_codec::{CodecDispatcher, CodecKind, CodecRegistry};
use s4_server::S4Service;
use s4_server::routing::{HealthRouter, ReadyCheck};
use tokio::net::TcpListener;
use tracing::info;
/// Codec selectable with `--codec`; converted to a [`CodecKind`] by `as_kind`.
///
/// Variant doc comments are rendered by clap as the help text of each possible value.
/// The nvCOMP variants only exist in builds with the `nvcomp-gpu` feature.
#[derive(Debug, Clone, Copy, ValueEnum)]
enum CodecChoice {
    /// Passthrough codec (no compression)
    Passthrough,
    /// zstd on the CPU, at the level given by --zstd-level
    CpuZstd,
    /// nvCOMP zstd on a CUDA GPU
    #[cfg(feature = "nvcomp-gpu")]
    NvcompZstd,
    /// nvCOMP Bitcomp on a CUDA GPU
    #[cfg(feature = "nvcomp-gpu")]
    NvcompBitcomp,
    /// nvCOMP GDeflate on a CUDA GPU
    #[cfg(feature = "nvcomp-gpu")]
    NvcompGdeflate,
}
impl CodecChoice {
fn as_kind(self) -> CodecKind {
match self {
Self::Passthrough => CodecKind::Passthrough,
Self::CpuZstd => CodecKind::CpuZstd,
#[cfg(feature = "nvcomp-gpu")]
Self::NvcompZstd => CodecKind::NvcompZstd,
#[cfg(feature = "nvcomp-gpu")]
Self::NvcompBitcomp => CodecKind::NvcompBitcomp,
#[cfg(feature = "nvcomp-gpu")]
Self::NvcompGdeflate => CodecKind::NvcompGDeflate,
}
}
}
/// Codec dispatch strategy selectable with `--dispatcher`.
///
/// Both strategies are constructed with the `--codec` choice as their default codec.
/// Variant doc comments are rendered by clap as per-value help text.
#[derive(Debug, Clone, Copy, ValueEnum)]
enum DispatcherChoice {
    /// Always use the codec given by --codec
    Always,
    /// Sampling dispatcher, with --codec as its default codec
    Sampling,
}
/// Log output format selectable with `--log-format`.
///
/// Variant doc comments are rendered by clap as per-value help text.
#[derive(Debug, Clone, Copy, ValueEnum)]
enum LogFormat {
    /// Human-readable text, colored only when stdout is a terminal
    Pretty,
    /// One JSON object per event, including the current span
    Json,
}
// Command-line options for the `s4` binary.
//
// Each field's `///` doc comment is rendered by clap as that flag's `--help` text, so the
// comments are written for operators and kept to a single paragraph. This header uses `//`
// on purpose: a struct-level doc comment would feed clap's about/long_about. The `about`
// string reads, in English: "S4 — Squished S3 (GPU transparent-compression S3-compatible
// gateway)".
#[derive(Debug, Parser)]
#[command(
    name = "s4",
    version,
    about = "S4 — Squished S3 (GPU 透過圧縮 S3 互換ゲートウェイ)"
)]
struct Opt {
    /// Address the S4 listener binds to
    #[clap(long, default_value = "127.0.0.1")]
    host: String,
    /// Port the S4 listener binds to
    #[clap(long, default_value = "8014")]
    port: u16,
    /// Base domain of this endpoint; enables virtual-hosted-style bucket addressing
    #[clap(long)]
    domain: Option<String>,
    /// Endpoint URL of the backend S3 service that requests are proxied to
    #[clap(long)]
    endpoint_url: String,
    /// Default compression codec
    #[clap(long, value_enum, default_value = "cpu-zstd")]
    codec: CodecChoice,
    /// Compression level for the CPU zstd codec
    #[clap(long, default_value_t = CpuZstd::DEFAULT_LEVEL)]
    zstd_level: i32,
    /// Strategy used to pick a codec, with --codec as its default
    #[clap(long, value_enum, default_value = "sampling")]
    dispatcher: DispatcherChoice,
    /// Log output format (level filter taken from RUST_LOG, default: info)
    #[clap(long, value_enum, default_value = "pretty")]
    log_format: LogFormat,
    /// OTLP/gRPC collector endpoint; when set, spans are exported via OpenTelemetry
    #[clap(long)]
    otlp_endpoint: Option<String>,
    /// Service name attached to exported OpenTelemetry traces
    #[clap(long, default_value = "s4")]
    service_name: String,
    /// TLS certificate file; serves HTTPS (requires --tls-key, reloaded on SIGHUP)
    #[clap(long, requires = "tls_key")]
    tls_cert: Option<std::path::PathBuf>,
    /// TLS private key file (requires --tls-cert, reloaded on SIGHUP)
    #[clap(long, requires = "tls_cert")]
    tls_key: Option<std::path::PathBuf>,
    /// Comma-separated domains to obtain certificates for via ACME; serves HTTPS
    #[clap(long, conflicts_with_all = ["tls_cert", "tls_key"])]
    acme: Option<String>,
    /// Contact for the ACME account (requires --acme)
    #[clap(long, requires = "acme")]
    acme_contact: Option<String>,
    /// Directory for cached ACME state (default: $HOME/.s4/acme; requires --acme)
    #[clap(long, requires = "acme")]
    acme_cache_dir: Option<std::path::PathBuf>,
    /// Use the ACME staging environment (requires --acme)
    #[clap(long, requires = "acme")]
    acme_staging: bool,
    /// Bucket policy file to apply
    #[clap(long)]
    policy: Option<std::path::PathBuf>,
}
/// Installs the global `tracing` subscriber.
///
/// Events are written to stdout in `format` and filtered by `RUST_LOG` (default `info`).
/// When `otlp_endpoint` is set, spans are also exported over OTLP/gRPC (tonic) to that
/// endpoint, tagged with `service_name` as the OpenTelemetry service name.
///
/// # Errors
/// Returns an error if the OTLP span exporter cannot be built.
///
/// # Panics
/// `init()` panics if a global subscriber has already been installed.
fn setup_tracing(
    format: LogFormat,
    otlp_endpoint: Option<&str>,
    service_name: &str,
) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
    use tracing_subscriber::EnvFilter;
    use tracing_subscriber::Layer;
    use tracing_subscriber::layer::SubscriberExt;
    use tracing_subscriber::util::SubscriberInitExt;

    let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));

    // Optional OpenTelemetry layer. `Option<L>` is itself a `Layer` (a no-op when `None`),
    // so one subscriber stack covers both the traced and the untraced configuration.
    let otel_layer = if let Some(endpoint) = otlp_endpoint {
        use opentelemetry::trace::TracerProvider;
        use opentelemetry_otlp::WithExportConfig;
        let exporter = opentelemetry_otlp::SpanExporter::builder()
            .with_tonic()
            .with_endpoint(endpoint)
            .build()?;
        let provider = opentelemetry_sdk::trace::SdkTracerProvider::builder()
            .with_resource(
                opentelemetry_sdk::Resource::builder()
                    .with_service_name(service_name.to_owned())
                    .build(),
            )
            .with_batch_exporter(exporter)
            .build();
        let tracer = provider.tracer(service_name.to_owned());
        // NOTE(review): nothing visible here keeps a handle to `provider` or calls
        // `shutdown()` on exit, so spans still buffered by the batch exporter may be lost
        // when the process stops — consider returning the provider to the caller.
        opentelemetry::global::set_tracer_provider(provider);
        Some(tracing_opentelemetry::layer().with_tracer(tracer))
    } else {
        None
    };

    // Only the formatter differs between the two outputs; boxing gives both arms one type.
    // `env_filter` is a per-layer filter on this fmt output only — the OpenTelemetry layer
    // (when present) is not filtered by it.
    let fmt_layer = match format {
        LogFormat::Pretty => tracing_subscriber::fmt::layer()
            .with_ansi(std::io::stdout().is_terminal())
            .with_filter(env_filter)
            .boxed(),
        LogFormat::Json => tracing_subscriber::fmt::layer()
            .json()
            .with_current_span(true)
            .with_span_list(false)
            .with_filter(env_filter)
            .boxed(),
    };

    tracing_subscriber::registry()
        .with(otel_layer)
        .with(fmt_layer)
        .init();
    Ok(())
}
/// Builds the shared codec registry with `default` as its default codec.
///
/// The CPU codecs (passthrough, and zstd at `zstd_level`) are always registered. Builds with
/// the `nvcomp-gpu` feature additionally register the nvCOMP codecs when a CUDA-capable GPU
/// is detected; any GPU codec that fails to initialize is skipped with a warning.
///
/// NOTE(review): if `default` names a GPU codec that ends up unregistered (no GPU, or init
/// failure), the registry's default refers to a codec it does not hold — confirm how
/// `CodecRegistry` handles that.
fn build_registry(default: CodecKind, zstd_level: i32) -> Arc<CodecRegistry> {
    let cpu_only = CodecRegistry::new(default)
        .with(Arc::new(Passthrough))
        .with(Arc::new(CpuZstd::new(zstd_level)));

    #[cfg(feature = "nvcomp-gpu")]
    let cpu_only = register_gpu_codecs(cpu_only);

    Arc::new(cpu_only)
}

/// Adds every nvCOMP codec that initializes successfully, provided a CUDA GPU is present.
#[cfg(feature = "nvcomp-gpu")]
fn register_gpu_codecs(registry: CodecRegistry) -> CodecRegistry {
    use s4_codec::nvcomp::{
        NvcompBitcompCodec, NvcompGDeflateCodec, NvcompZstdCodec, is_gpu_available,
    };

    if !is_gpu_available() {
        tracing::warn!(
            "nvcomp-gpu feature is enabled but no CUDA-capable GPU detected at runtime"
        );
        return registry;
    }

    let registry = match NvcompZstdCodec::new() {
        Ok(codec) => registry.with(Arc::new(codec)),
        Err(e) => {
            tracing::warn!("nvcomp-zstd init failed: {e}");
            registry
        }
    };
    let registry = match NvcompBitcompCodec::default_general() {
        Ok(codec) => registry.with(Arc::new(codec)),
        Err(e) => {
            tracing::warn!("nvcomp-bitcomp init failed: {e}");
            registry
        }
    };
    match NvcompGDeflateCodec::new() {
        Ok(codec) => registry.with(Arc::new(codec)),
        Err(e) => {
            tracing::warn!("nvcomp-gdeflate init failed: {e}");
            registry
        }
    }
}
/// Instantiates the [`CodecDispatcher`] selected on the command line.
///
/// Both implementations receive `default` (the `--codec` choice) as their default codec.
fn build_dispatcher(choice: DispatcherChoice, default: CodecKind) -> Arc<dyn CodecDispatcher> {
    match choice {
        DispatcherChoice::Sampling => Arc::new(SamplingDispatcher::new(default)),
        DispatcherChoice::Always => Arc::new(AlwaysDispatcher(default)),
    }
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
let opt = Opt::parse();
setup_tracing(
opt.log_format,
opt.otlp_endpoint.as_deref(),
&opt.service_name,
)?;
let sdk_conf = aws_config::from_env()
.endpoint_url(&opt.endpoint_url)
.load()
.await;
let client = aws_sdk_s3::Client::from_conf(
aws_sdk_s3::config::Builder::from(&sdk_conf)
.force_path_style(true)
.build(),
);
let ready_client = client.clone();
let proxy = s3s_aws::Proxy::from(client);
let default_kind = opt.codec.as_kind();
let registry = build_registry(default_kind, opt.zstd_level);
let dispatcher = build_dispatcher(opt.dispatcher, default_kind);
info!(
codec = ?opt.codec,
dispatcher = ?opt.dispatcher,
registered = ?registry.kinds().collect::<Vec<_>>(),
"S4 codec registry built"
);
let mut s4 = S4Service::new(proxy, registry, dispatcher);
let listener_secure = opt.tls_cert.is_some() || opt.acme.is_some();
s4 = s4.with_secure_transport(listener_secure);
if let Some(ref policy_path) = opt.policy {
let policy = s4_server::policy::Policy::from_path(policy_path)
.map_err(|e| format!("--policy {}: {e}", policy_path.display()))?;
info!(path = %policy_path.display(), "S4 bucket policy loaded");
s4 = s4.with_policy(std::sync::Arc::new(policy));
}
run_server(s4, &sdk_conf, &opt, ready_client).await
}
fn build_ready_check(client: aws_sdk_s3::Client) -> ReadyCheck {
Arc::new(move || {
let c = client.clone();
Box::pin(async move {
match c.list_buckets().send().await {
Ok(_) => Ok(()),
Err(e) => {
let dbg = format!("{e:?}");
if dbg.contains("AccessDenied")
|| dbg.contains("InvalidAccessKeyId")
|| dbg.contains("SignatureDoesNotMatch")
{
Ok(())
} else {
Err(format!("backend list_buckets failed: {e}"))
}
}
}
})
})
}
async fn run_server<S>(
s4: S,
sdk_conf: &aws_config::SdkConfig,
opt: &Opt,
ready_client: aws_sdk_s3::Client,
) -> Result<(), Box<dyn Error + Send + Sync + 'static>>
where
S: S3 + Send + Sync + 'static,
{
let service = {
let mut b = S3ServiceBuilder::new(s4);
if let Some(cred_provider) = sdk_conf.credentials_provider() {
let cred = cred_provider.provide_credentials().await?;
b.set_auth(SimpleAuth::from_single(
cred.access_key_id(),
cred.secret_access_key(),
));
}
if let Some(domain) = &opt.domain {
b.set_host(SingleDomain::new(domain)?);
}
b.build()
};
let ready_check = build_ready_check(ready_client);
let metrics_handle = s4_server::metrics::install();
let routed_service = HealthRouter::new(service, Some(ready_check)).with_metrics(metrics_handle);
let listener = TcpListener::bind((opt.host.as_str(), opt.port)).await?;
let http_server = ConnBuilder::new(TokioExecutor::new());
let graceful = hyper_util::server::graceful::GracefulShutdown::new();
let mut ctrl_c = std::pin::pin!(tokio::signal::ctrl_c());
let tls_state: Option<Arc<s4_server::tls::TlsState>> = match (&opt.tls_cert, &opt.tls_key) {
(Some(cert), Some(key)) => {
s4_server::tls::install_default_crypto_provider();
let state = Arc::new(s4_server::tls::TlsState::load(cert, key)?);
let reload_state = Arc::clone(&state);
tokio::spawn(async move {
use tokio::signal::unix::{SignalKind, signal};
let mut hup = match signal(SignalKind::hangup()) {
Ok(s) => s,
Err(e) => {
tracing::warn!("could not install SIGHUP handler: {e}");
return;
}
};
while hup.recv().await.is_some() {
match reload_state.reload() {
Ok(()) => {
tracing::info!("S4 TLS cert hot-reload succeeded");
s4_server::metrics::record_tls_cert_reload(true);
}
Err(e) => {
tracing::warn!(
"S4 TLS cert hot-reload failed (keeping previous config): {e}"
);
s4_server::metrics::record_tls_cert_reload(false);
}
}
}
});
Some(state)
}
_ => None,
};
let acme_acceptors: Option<Arc<s4_server::acme::AcmeAcceptors>> = match &opt.acme {
Some(domains_csv) => {
s4_server::tls::install_default_crypto_provider();
let domains: Vec<String> = domains_csv
.split(',')
.map(|s| s.trim().to_string())
.collect();
let cache_dir = opt.acme_cache_dir.clone().unwrap_or_else(|| {
let home = std::env::var("HOME").unwrap_or_else(|_| ".".into());
std::path::PathBuf::from(home).join(".s4/acme")
});
info!(
domains = ?domains,
staging = opt.acme_staging,
cache_dir = %cache_dir.display(),
"S4 ACME acceptor bootstrapping"
);
Some(Arc::new(s4_server::acme::bootstrap(
s4_server::acme::AcmeOptions {
domains,
contact: opt.acme_contact.clone(),
cache_dir,
staging: opt.acme_staging,
},
)))
}
None => None,
};
let scheme = if tls_state.is_some() || acme_acceptors.is_some() {
"https"
} else {
"http"
};
info!(
host = %opt.host,
port = opt.port,
scheme,
endpoint_url = %opt.endpoint_url,
"S4 listening (paths /health and /ready served alongside S3 traffic)"
);
loop {
let (socket, _) = tokio::select! {
res = listener.accept() => match res {
Ok(conn) => conn,
Err(err) => {
tracing::error!("accept error: {err}");
continue;
}
},
_ = ctrl_c.as_mut() => break,
};
let svc = routed_service.clone();
let server = http_server.clone();
let watch_handle = graceful.watcher();
if let Some(acceptors) = acme_acceptors.as_ref() {
let acceptors = Arc::clone(acceptors);
tokio::spawn(async move {
match s4_server::acme::accept_one(socket, &acceptors).await {
Ok(Some(tls_stream)) => {
let conn = server.serve_connection(TokioIo::new(tls_stream), svc);
let conn = watch_handle.watch(conn.into_owned());
let _ = conn.await;
}
Ok(None) => {
}
Err(err) => {
tracing::warn!("acme handshake failed: {err}");
}
}
});
} else if let Some(state) = tls_state.as_ref() {
let acceptor = state.acceptor();
tokio::spawn(async move {
let tls_stream = match acceptor.accept(socket).await {
Ok(s) => s,
Err(err) => {
tracing::warn!("tls handshake failed: {err}");
return;
}
};
let conn = server.serve_connection(TokioIo::new(tls_stream), svc);
let conn = watch_handle.watch(conn.into_owned());
let _ = conn.await;
});
} else {
let conn = server.serve_connection(TokioIo::new(socket), svc);
let conn = watch_handle.watch(conn.into_owned());
tokio::spawn(async move {
let _ = conn.await;
});
}
}
tokio::select! {
() = graceful.shutdown() => tracing::debug!("graceful shutdown complete"),
() = tokio::time::sleep(std::time::Duration::from_secs(10)) =>
tracing::warn!("graceful shutdown timeout, aborting"),
}
info!("S4 stopped");
Ok(())
}