#![recursion_limit = "512"]
use std::error::Error;
use std::io::IsTerminal;
use std::sync::Arc;
use aws_credential_types::provider::ProvideCredentials;
use clap::{Parser, ValueEnum};
use hyper_util::rt::{TokioExecutor, TokioIo};
use hyper_util::server::conn::auto::Builder as ConnBuilder;
use s3s::S3;
use s3s::auth::SimpleAuth;
use s3s::host::SingleDomain;
use s3s::service::S3ServiceBuilder;
use s4_codec::cpu_zstd::CpuZstd;
use s4_codec::dispatcher::{AlwaysDispatcher, SamplingDispatcher};
use s4_codec::passthrough::Passthrough;
use s4_codec::{CodecDispatcher, CodecKind, CodecRegistry};
use s4_server::S4Service;
use s4_server::routing::{HealthRouter, ReadyCheck};
use tokio::net::TcpListener;
use tracing::info;
// Codec that can be picked on the command line (see the `codec` field of `Opt`).
// Each variant is translated to the matching `CodecKind` by `CodecChoice::as_kind`.
//
// NOTE: plain `//` comments are used here on purpose. clap's derive macros read
// `///` doc comments and turn them into help text, which would change CLI output.
#[derive(Debug, Clone, Copy, ValueEnum)]
enum CodecChoice {
// Maps to `CodecKind::Passthrough`.
Passthrough,
// Maps to `CodecKind::CpuZstd`; this is the variant named by the `cpu-zstd`
// default of the `codec` option.
CpuZstd,
// Maps to `CodecKind::NvcompZstd`; only compiled with the `nvcomp-gpu` feature.
#[cfg(feature = "nvcomp-gpu")]
NvcompZstd,
// Maps to `CodecKind::NvcompBitcomp`; only compiled with the `nvcomp-gpu` feature.
#[cfg(feature = "nvcomp-gpu")]
NvcompBitcomp,
}
impl CodecChoice {
fn as_kind(self) -> CodecKind {
match self {
Self::Passthrough => CodecKind::Passthrough,
Self::CpuZstd => CodecKind::CpuZstd,
#[cfg(feature = "nvcomp-gpu")]
Self::NvcompZstd => CodecKind::NvcompZstd,
#[cfg(feature = "nvcomp-gpu")]
Self::NvcompBitcomp => CodecKind::NvcompBitcomp,
}
}
}
// Dispatcher implementation that can be picked on the command line (see the
// `dispatcher` field of `Opt`); consumed by `build_dispatcher`.
//
// NOTE: `//` rather than `///` so clap does not pick the text up as help output.
#[derive(Debug, Clone, Copy, ValueEnum)]
enum DispatcherChoice {
// `build_dispatcher` wraps the default codec kind in an `AlwaysDispatcher`.
Always,
// `build_dispatcher` creates a `SamplingDispatcher` from the default codec kind.
Sampling,
}
// Log output format that can be picked on the command line (see the `log_format`
// field of `Opt`); consumed by `setup_tracing`.
//
// NOTE: `//` rather than `///` so clap does not pick the text up as help output.
#[derive(Debug, Clone, Copy, ValueEnum)]
enum LogFormat {
// Default `tracing_subscriber::fmt` layer; ANSI colours are enabled only when
// stdout is a terminal.
Pretty,
// `tracing_subscriber::fmt` layer switched to `.json()` output, with the current
// span included and the span list disabled.
Json,
}
// Command-line options of the `s4` binary, parsed with clap in `main`.
//
// NOTE: the comments in this struct are deliberately plain `//` comments. clap's
// derive macros turn `///` doc comments into `--help` text, so using them here
// would change the program's output.
#[derive(Debug, Parser)]
#[command(
name = "s4",
version,
about = "S4 — Squished S3 (GPU 透過圧縮 S3 互換ゲートウェイ)"
)]
struct Opt {
// Address the TCP listener binds to (together with `port`) in `run_server`.
#[clap(long, default_value = "127.0.0.1")]
host: String,
// TCP port the listener binds to in `run_server`.
#[clap(long, default_value = "8014")]
port: u16,
// Optional domain; when present, `run_server` installs `SingleDomain::new(domain)`
// as the host handler of the S3 service.
#[clap(long)]
domain: Option<String>,
// Endpoint URL handed to the AWS SDK configuration in `main`, i.e. the backend the
// S3 client talks to. No default is declared, so the option must be supplied.
#[clap(long)]
endpoint_url: String,
// Default codec; converted with `CodecChoice::as_kind` and passed to both the codec
// registry and the dispatcher.
#[clap(long, value_enum, default_value = "cpu-zstd")]
codec: CodecChoice,
// Level passed to `CpuZstd::new` when the registry is built.
#[clap(long, default_value_t = CpuZstd::DEFAULT_LEVEL)]
zstd_level: i32,
// Which dispatcher `build_dispatcher` constructs.
#[clap(long, value_enum, default_value = "sampling")]
dispatcher: DispatcherChoice,
// Which formatter `setup_tracing` installs for log output.
#[clap(long, value_enum, default_value = "pretty")]
log_format: LogFormat,
// Optional OTLP endpoint; when present, `setup_tracing` additionally builds an OTLP
// span exporter pointed at it.
#[clap(long)]
otlp_endpoint: Option<String>,
// Used by `setup_tracing` as the OpenTelemetry resource service name and tracer name.
#[clap(long, default_value = "s4")]
service_name: String,
}
/// Installs the process-wide `tracing` subscriber.
///
/// * `format` selects between the default formatter and the JSON formatter.
/// * `otlp_endpoint`, when `Some`, additionally builds an OTLP span exporter (tonic
///   transport) for that endpoint, registers the resulting tracer provider globally and
///   adds a `tracing_opentelemetry` layer to the subscriber.
/// * `service_name` is used both as the OpenTelemetry resource service name and as the
///   tracer name.
///
/// The level filter comes from the default `EnvFilter` environment variable and falls
/// back to `info` when that cannot be read. It is attached to the formatting layer only;
/// the OpenTelemetry layer is added without it.
///
/// # Errors
///
/// Propagates the error returned when building the OTLP span exporter.
fn setup_tracing(
    format: LogFormat,
    otlp_endpoint: Option<&str>,
    service_name: &str,
) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
    use tracing_subscriber::EnvFilter;
    use tracing_subscriber::Layer;
    use tracing_subscriber::layer::SubscriberExt;
    use tracing_subscriber::util::SubscriberInitExt;

    let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));

    // Optional OpenTelemetry layer; only present when an OTLP endpoint was given.
    let telemetry = match otlp_endpoint {
        None => None,
        Some(endpoint) => {
            use opentelemetry::trace::TracerProvider;
            use opentelemetry_otlp::WithExportConfig;

            let span_exporter = opentelemetry_otlp::SpanExporter::builder()
                .with_tonic()
                .with_endpoint(endpoint)
                .build()?;
            let resource = opentelemetry_sdk::Resource::builder()
                .with_service_name(service_name.to_owned())
                .build();
            let tracer_provider = opentelemetry_sdk::trace::SdkTracerProvider::builder()
                .with_resource(resource)
                .with_batch_exporter(span_exporter)
                .build();
            let tracer = tracer_provider.tracer(service_name.to_owned());
            opentelemetry::global::set_tracer_provider(tracer_provider);
            Some(tracing_opentelemetry::layer().with_tracer(tracer))
        }
    };

    // Each of the four combinations builds its own formatting layer: the layer's
    // subscriber type parameter depends on what sits below it in the stack, so one
    // shared binding could not serve both the with- and without-telemetry stacks.
    match format {
        LogFormat::Pretty => {
            let ansi = std::io::stdout().is_terminal();
            if let Some(otel) = telemetry {
                let console = tracing_subscriber::fmt::layer()
                    .with_ansi(ansi)
                    .with_filter(filter);
                tracing_subscriber::registry()
                    .with(otel)
                    .with(console)
                    .init();
            } else {
                let console = tracing_subscriber::fmt::layer()
                    .with_ansi(ansi)
                    .with_filter(filter);
                tracing_subscriber::registry().with(console).init();
            }
        }
        LogFormat::Json => {
            if let Some(otel) = telemetry {
                let json = tracing_subscriber::fmt::layer()
                    .json()
                    .with_current_span(true)
                    .with_span_list(false)
                    .with_filter(filter);
                tracing_subscriber::registry().with(otel).with(json).init();
            } else {
                let json = tracing_subscriber::fmt::layer()
                    .json()
                    .with_current_span(true)
                    .with_span_list(false)
                    .with_filter(filter);
                tracing_subscriber::registry().with(json).init();
            }
        }
    }

    Ok(())
}
/// Builds the shared codec registry.
///
/// `default` is handed to `CodecRegistry::new`. The `Passthrough` codec and a `CpuZstd`
/// codec created with `zstd_level` are always registered.
///
/// When the crate is compiled with the `nvcomp-gpu` feature and `is_gpu_available()`
/// reports a GPU, the nvCOMP zstd and bitcomp codecs are registered as well. A codec
/// whose construction fails is skipped with a warning, and a warning is also logged
/// when no GPU is detected.
fn build_registry(default: CodecKind, zstd_level: i32) -> Arc<CodecRegistry> {
    let registry = CodecRegistry::new(default)
        .with(Arc::new(Passthrough))
        .with(Arc::new(CpuZstd::new(zstd_level)));

    #[cfg(feature = "nvcomp-gpu")]
    let registry = {
        use s4_codec::nvcomp::{NvcompBitcompCodec, NvcompZstdCodec, is_gpu_available};

        if !is_gpu_available() {
            tracing::warn!(
                "nvcomp-gpu feature is enabled but no CUDA-capable GPU detected at runtime"
            );
            registry
        } else {
            // Each GPU codec is optional: a failed init leaves the registry untouched.
            let after_zstd = match NvcompZstdCodec::new() {
                Ok(codec) => registry.with(Arc::new(codec)),
                Err(e) => {
                    tracing::warn!("nvcomp-zstd init failed: {e}");
                    registry
                }
            };
            match NvcompBitcompCodec::default_general() {
                Ok(codec) => after_zstd.with(Arc::new(codec)),
                Err(e) => {
                    tracing::warn!("nvcomp-bitcomp init failed: {e}");
                    after_zstd
                }
            }
        }
    };

    Arc::new(registry)
}
/// Creates the codec dispatcher selected on the command line.
///
/// Both dispatcher kinds are constructed from `default`, the default codec kind, and
/// returned behind a shared `Arc<dyn CodecDispatcher>`.
fn build_dispatcher(choice: DispatcherChoice, default: CodecKind) -> Arc<dyn CodecDispatcher> {
    match choice {
        DispatcherChoice::Sampling => {
            let sampling = SamplingDispatcher::new(default);
            Arc::new(sampling)
        }
        DispatcherChoice::Always => {
            let always = AlwaysDispatcher(default);
            Arc::new(always)
        }
    }
}
/// Program entry point.
///
/// Parses the command line, installs tracing, builds an S3 client for the configured
/// endpoint (path-style addressing forced on), assembles the codec registry and
/// dispatcher, wraps everything in an `S4Service` and hands it to [`run_server`].
///
/// # Errors
///
/// Returns any error from [`setup_tracing`] or [`run_server`].
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
    let opt = Opt::parse();

    setup_tracing(
        opt.log_format,
        opt.otlp_endpoint.as_deref(),
        &opt.service_name,
    )?;

    // Shared AWS configuration from the environment, pointed at the backend endpoint.
    let sdk_conf = aws_config::from_env()
        .endpoint_url(&opt.endpoint_url)
        .load()
        .await;

    let s3_conf = aws_sdk_s3::config::Builder::from(&sdk_conf)
        .force_path_style(true)
        .build();
    let backend = aws_sdk_s3::Client::from_conf(s3_conf);

    // The readiness probe keeps its own handle; the proxy takes ownership of the other.
    let ready_client = backend.clone();
    let proxy = s3s_aws::Proxy::from(backend);

    let default_kind = opt.codec.as_kind();
    let registry = build_registry(default_kind, opt.zstd_level);
    let dispatcher = build_dispatcher(opt.dispatcher, default_kind);

    info!(
        codec = ?opt.codec,
        dispatcher = ?opt.dispatcher,
        registered = ?registry.kinds().collect::<Vec<_>>(),
        "S4 codec registry built"
    );

    let service = S4Service::new(proxy, registry, dispatcher);
    run_server(service, &sdk_conf, &opt, ready_client).await
}
/// Builds the readiness probe used by the health router.
///
/// Every invocation of the returned closure clones `client` and issues `list_buckets`
/// against the backend:
///
/// * a successful call reports ready;
/// * an error whose `Debug` rendering contains `AccessDenied`, `InvalidAccessKeyId` or
///   `SignatureDoesNotMatch` also reports ready (presumably because such a reply shows
///   the backend is reachable even though the credentials were rejected);
/// * any other error reports not-ready with a `backend list_buckets failed: …` message.
fn build_ready_check(client: aws_sdk_s3::Client) -> ReadyCheck {
    // Substrings of the error's Debug output that are treated as "backend is up".
    const AUTH_REJECTION_MARKERS: [&str; 3] = [
        "AccessDenied",
        "InvalidAccessKeyId",
        "SignatureDoesNotMatch",
    ];

    Arc::new(move || {
        let backend = client.clone();
        Box::pin(async move {
            let Err(e) = backend.list_buckets().send().await else {
                return Ok(());
            };

            let rendered = format!("{e:?}");
            let auth_rejected = AUTH_REJECTION_MARKERS
                .iter()
                .any(|marker| rendered.contains(*marker));

            if auth_rejected {
                Ok(())
            } else {
                Err(format!("backend list_buckets failed: {e}"))
            }
        })
    })
}
async fn run_server<S>(
s4: S,
sdk_conf: &aws_config::SdkConfig,
opt: &Opt,
ready_client: aws_sdk_s3::Client,
) -> Result<(), Box<dyn Error + Send + Sync + 'static>>
where
S: S3 + Send + Sync + 'static,
{
let service = {
let mut b = S3ServiceBuilder::new(s4);
if let Some(cred_provider) = sdk_conf.credentials_provider() {
let cred = cred_provider.provide_credentials().await?;
b.set_auth(SimpleAuth::from_single(
cred.access_key_id(),
cred.secret_access_key(),
));
}
if let Some(domain) = &opt.domain {
b.set_host(SingleDomain::new(domain)?);
}
b.build()
};
let ready_check = build_ready_check(ready_client);
let metrics_handle = s4_server::metrics::install();
let routed_service = HealthRouter::new(service, Some(ready_check)).with_metrics(metrics_handle);
let listener = TcpListener::bind((opt.host.as_str(), opt.port)).await?;
let http_server = ConnBuilder::new(TokioExecutor::new());
let graceful = hyper_util::server::graceful::GracefulShutdown::new();
let mut ctrl_c = std::pin::pin!(tokio::signal::ctrl_c());
info!(
host = %opt.host,
port = opt.port,
endpoint_url = %opt.endpoint_url,
"S4 listening (paths /health and /ready served alongside S3 traffic)"
);
loop {
let (socket, _) = tokio::select! {
res = listener.accept() => match res {
Ok(conn) => conn,
Err(err) => {
tracing::error!("accept error: {err}");
continue;
}
},
_ = ctrl_c.as_mut() => break,
};
let conn = http_server.serve_connection(TokioIo::new(socket), routed_service.clone());
let conn = graceful.watch(conn.into_owned());
tokio::spawn(async move {
let _ = conn.await;
});
}
tokio::select! {
() = graceful.shutdown() => tracing::debug!("graceful shutdown complete"),
() = tokio::time::sleep(std::time::Duration::from_secs(10)) =>
tracing::warn!("graceful shutdown timeout, aborting"),
}
info!("S4 stopped");
Ok(())
}