1use std::{
2 net::SocketAddr,
3 path::PathBuf,
4 sync::Arc,
5 time::{Duration, SystemTime},
6};
7
8use axum_server::tls_rustls::RustlsConfig;
9use bytesize::ByteSize;
10use http::header::AUTHORIZATION;
11use s2_common::encryption::S2_ENCRYPTION_KEY_HEADER;
12use slatedb::object_store;
13use tokio::time::Instant;
14use tower_http::{
15 cors::CorsLayer,
16 sensitive_headers::SetSensitiveRequestHeadersLayer,
17 trace::{DefaultMakeSpan, DefaultOnRequest, DefaultOnResponse, TraceLayer},
18};
19use tracing::info;
20
21use crate::{backend::Backend, handlers, init};
22
// TLS options for the server. At most one mode may be selected:
// - `tls_self`: serve HTTPS with a certificate generated at startup, or
// - `tls_cert` + `tls_key`: serve HTTPS with the given PEM files.
// With neither, the server speaks plain HTTP.
//
// Plain `//` comments are used on purpose: clap turns `///` doc comments into `--help` text.
#[derive(clap::Args, Debug, Clone)]
pub struct TlsConfig {
    // Generate a self-signed certificate (for localhost / 127.0.0.1 / ::1) at startup.
    // Cannot be combined with `tls_cert` / `tls_key`.
    #[arg(long, conflicts_with_all = ["tls_cert", "tls_key"])]
    pub tls_self: bool,

    // Path to a PEM certificate; must be given together with `tls_key`.
    #[arg(long, requires = "tls_key")]
    pub tls_cert: Option<PathBuf>,

    // Path to the PEM private key matching `tls_cert`; must be given together with it.
    #[arg(long, requires = "tls_cert")]
    pub tls_key: Option<PathBuf>,
}
39
// Command-line arguments for running s2-lite (consumed by `run`).
//
// Plain `//` comments are used on purpose: clap turns `///` doc comments into `--help` text.
#[derive(clap::Args, Debug, Clone)]
pub struct LiteArgs {
    // S3 bucket holding the data. If neither this nor `local_root` is given, an in-memory
    // object store is used.
    #[arg(long)]
    pub bucket: Option<String>,

    // Local directory used as the object-store root; created if it does not exist.
    // Mutually exclusive with `bucket`.
    #[arg(long, value_name = "DIR", conflicts_with = "bucket")]
    pub local_root: Option<PathBuf>,

    // Path handed to `slatedb::Db::builder` together with the object store; defaults to the
    // empty string.
    #[arg(long, default_value = "")]
    pub path: String,

    // TLS mode: self-signed, provided cert/key, or (neither) plain HTTP.
    #[command(flatten)]
    pub tls: TlsConfig,

    // Port to listen on (bound on 0.0.0.0). When omitted: 80 for HTTP, 443 for HTTPS.
    #[arg(long)]
    pub port: Option<u16>,

    // Skip installing the permissive CORS layer that is otherwise added to the router.
    #[arg(long)]
    pub no_cors: bool,

    // Optional init file (also settable via `S2LITE_INIT_FILE`), loaded and applied to the
    // backend before the server starts listening.
    #[arg(long, env = "S2LITE_INIT_FILE")]
    pub init_file: Option<PathBuf>,

    // Passed straight to `Backend::new`. NOTE(review): presumably a cap on bytes of appends
    // in flight at once — confirm against `Backend`.
    #[arg(long, default_value = "128MiB")]
    pub append_inflight_bytes: ByteSize,
}
87
/// Backing object store selected from the CLI arguments.
#[derive(Debug, Clone)]
enum StoreType {
    /// An S3 (or S3-compatible) bucket, by name.
    S3Bucket(String),
    /// A directory on the local filesystem used as the store root.
    LocalFileSystem(PathBuf),
    /// An in-process memory store.
    InMemory,
}

impl StoreType {
    /// Default SlateDB `flush_interval` for this kind of store.
    ///
    /// Remote S3 gets 50 ms; local and in-memory stores get 5 ms.
    fn default_flush_interval(&self) -> Duration {
        match self {
            Self::LocalFileSystem(_) | Self::InMemory => Duration::from_millis(5),
            Self::S3Bucket(_) => Duration::from_millis(50),
        }
    }
}
103
104#[derive(Debug, Clone, Copy, PartialEq, Eq)]
105enum ServerProtocol {
106 Http,
107 Https { self_signed: bool },
108}
109
110impl ServerProtocol {
111 fn from_args(args: &LiteArgs) -> Self {
112 if args.tls.tls_self {
113 Self::Https { self_signed: true }
114 } else if args.tls.tls_cert.is_some() {
115 Self::Https { self_signed: false }
116 } else {
117 Self::Http
118 }
119 }
120
121 fn scheme(self) -> &'static str {
122 match self {
123 Self::Http => "http",
124 Self::Https { .. } => "https",
125 }
126 }
127
128 fn default_port(self) -> u16 {
129 match self {
130 Self::Http => 80,
131 Self::Https { .. } => 443,
132 }
133 }
134
135 fn requires_ssl_no_verify(self) -> bool {
136 matches!(self, Self::Https { self_signed: true })
137 }
138}
139
140fn cli_endpoint(protocol: ServerProtocol, port: u16) -> String {
141 format!("{}://localhost:{port}", protocol.scheme())
142}
143
144fn cli_env_hint(protocol: ServerProtocol, port: u16) -> String {
145 let endpoint = cli_endpoint(protocol, port);
146 let mut lines = vec![
147 "copy/paste into a new terminal to point the S2 CLI at this server:".to_string(),
148 format!("export S2_ACCOUNT_ENDPOINT={endpoint}"),
149 format!("export S2_BASIN_ENDPOINT={endpoint}"),
150 "export S2_ACCESS_TOKEN=ignored".to_string(),
151 ];
152
153 if protocol.requires_ssl_no_verify() {
154 lines.push("export S2_SSL_NO_VERIFY=1".to_string());
155 }
156
157 lines.join("\n")
158}
159
160pub async fn run(args: LiteArgs) -> eyre::Result<()> {
161 info!(?args);
162
163 let protocol = ServerProtocol::from_args(&args);
164 let port = args.port.unwrap_or_else(|| protocol.default_port());
165 let addr = format!("0.0.0.0:{port}");
166 let cli_hint = cli_env_hint(protocol, port);
167
168 let store_type = if let Some(bucket) = args.bucket {
169 StoreType::S3Bucket(bucket)
170 } else if let Some(local_root) = args.local_root {
171 StoreType::LocalFileSystem(local_root)
172 } else {
173 StoreType::InMemory
174 };
175
176 let object_store = init_object_store(&store_type).await?;
177
178 let db_settings = slatedb::Settings::from_env_with_default(
179 "SL8_",
180 slatedb::Settings {
181 flush_interval: Some(store_type.default_flush_interval()),
182 ..Default::default()
183 },
184 )?;
185
186 let manifest_poll_interval = db_settings.manifest_poll_interval;
187
188 let db = slatedb::Db::builder(args.path, object_store)
189 .with_settings(db_settings)
190 .build()
191 .await?;
192
193 info!(
194 ?manifest_poll_interval,
195 "sleeping to ensure prior instance fenced out"
196 );
197
198 tokio::time::sleep(manifest_poll_interval).await;
199
200 info!(%args.append_inflight_bytes, "starting backend");
201 let backend = Backend::new(db, args.append_inflight_bytes);
202 crate::backend::bgtasks::spawn(&backend);
203
204 if let Some(init_file) = &args.init_file {
205 let spec = init::load(init_file)?;
206 init::apply(&backend, spec).await?;
207 }
208
209 let mut app = handlers::router()
210 .with_state(backend)
211 .layer(
212 TraceLayer::new_for_http()
213 .make_span_with(DefaultMakeSpan::new().level(tracing::Level::INFO))
214 .on_request(DefaultOnRequest::new().level(tracing::Level::DEBUG))
215 .on_response(DefaultOnResponse::new().level(tracing::Level::INFO)),
216 )
217 .layer(SetSensitiveRequestHeadersLayer::new([
218 AUTHORIZATION,
219 S2_ENCRYPTION_KEY_HEADER.clone(),
220 ]));
221
222 if !args.no_cors {
223 app = app.layer(CorsLayer::very_permissive());
224 }
225
226 let server_handle = axum_server::Handle::new();
227 tokio::spawn(shutdown_signal(server_handle.clone()));
228 match (
229 args.tls.tls_self,
230 args.tls.tls_cert.clone(),
231 args.tls.tls_key.clone(),
232 ) {
233 (false, Some(cert_path), Some(key_path)) => {
234 info!(
235 addr,
236 ?cert_path,
237 "starting https server with provided certificate"
238 );
239 let rustls_config = RustlsConfig::from_pem_file(cert_path, key_path).await?;
240 info!("{}", cli_hint);
241 axum_server::bind_rustls(addr.parse()?, rustls_config)
242 .handle(server_handle)
243 .serve(app.into_make_service())
244 .await?;
245 }
246 (true, None, None) => {
247 info!(
248 addr,
249 "starting https server with self-signed certificate, clients will need to use --insecure"
250 );
251 let rcgen::CertifiedKey { cert, signing_key } = rcgen::generate_simple_self_signed([
252 "localhost".to_string(),
253 "127.0.0.1".to_string(),
254 "::1".to_string(),
255 ])?;
256 let rustls_config = RustlsConfig::from_pem(
257 cert.pem().into_bytes(),
258 signing_key.serialize_pem().into_bytes(),
259 )
260 .await?;
261 info!("{}", cli_hint);
262 axum_server::bind_rustls(addr.parse()?, rustls_config)
263 .handle(server_handle)
264 .serve(app.into_make_service())
265 .await?;
266 }
267 (false, None, None) => {
268 info!(addr, "starting plain http server");
269 info!("{}", cli_hint);
270 axum_server::bind(addr.parse()?)
271 .handle(server_handle)
272 .serve(app.into_make_service())
273 .await?;
274 }
275 _ => {
276 return Err(eyre::eyre!("Invalid TLS configuration"));
278 }
279 }
280
281 Ok(())
282}
283
284async fn init_object_store(
285 store_type: &StoreType,
286) -> eyre::Result<Arc<dyn object_store::ObjectStore>> {
287 Ok(match store_type {
288 StoreType::S3Bucket(bucket) => {
289 info!(bucket, "using s3 object store");
290 let mut builder =
291 object_store::aws::AmazonS3Builder::from_env().with_bucket_name(bucket);
292
293 if let Some(endpoint) =
294 std::env::var_os("AWS_ENDPOINT_URL_S3").and_then(|s| s.into_string().ok())
295 {
296 if endpoint.starts_with("http://") {
297 builder = builder.with_allow_http(true);
298 }
299 builder = builder.with_endpoint(endpoint);
300 }
301
302 match (
303 std::env::var_os("AWS_ACCESS_KEY_ID").and_then(|s| s.into_string().ok()),
304 std::env::var_os("AWS_SECRET_ACCESS_KEY").and_then(|s| s.into_string().ok()),
305 ) {
306 (Some(key_id), Some(secret_key)) => {
307 info!(key_id, "using static credentials from env vars");
308
309 let token =
310 std::env::var_os("AWS_SESSION_TOKEN").and_then(|s| s.into_string().ok());
311 builder = builder.with_credentials(Arc::new(
312 object_store::StaticCredentialProvider::new(
313 object_store::aws::AwsCredential {
314 key_id,
315 secret_key,
316 token,
317 },
318 ),
319 ));
320 }
321 _ => {
322 let aws_config =
323 aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
324 if let Some(region) = aws_config.region() {
325 info!(region = region.as_ref());
326 builder = builder.with_region(region.to_string());
327 }
328 if let Some(credentials_provider) = aws_config.credentials_provider() {
329 info!("using aws-config credentials provider");
330 builder = builder.with_credentials(Arc::new(S3CredentialProvider {
331 aws: credentials_provider.clone(),
332 cache: tokio::sync::Mutex::new(None),
333 }));
334 }
335 }
336 }
337 Arc::new(builder.build()?) as Arc<dyn object_store::ObjectStore>
338 }
339 StoreType::LocalFileSystem(local_root) => {
340 std::fs::create_dir_all(local_root)?;
341 info!(
342 root = %local_root.display(),
343 "using local filesystem object store"
344 );
345 Arc::new(object_store::local::LocalFileSystem::new_with_prefix(
346 local_root,
347 )?)
348 }
349 StoreType::InMemory => {
350 info!("using in-memory object store");
351 Arc::new(object_store::memory::InMemory::new())
352 }
353 })
354}
355
356async fn shutdown_signal(handle: axum_server::Handle<SocketAddr>) {
357 let ctrl_c = async {
358 tokio::signal::ctrl_c().await.expect("ctrl-c");
359 };
360
361 #[cfg(unix)]
362 let term = async {
363 tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
364 .expect("SIGTERM")
365 .recv()
366 .await;
367 };
368
369 #[cfg(not(unix))]
370 let term = std::future::pending::<()>();
371
372 tokio::select! {
373 _ = ctrl_c => {
374 info!("received Ctrl+C, starting graceful shutdown");
375 },
376 _ = term => {
377 info!("received SIGTERM, starting graceful shutdown");
378 },
379 }
380
381 handle.graceful_shutdown(Some(Duration::from_secs(10)));
382}
383
384#[derive(Debug)]
385struct CachedCredential {
386 credential: Arc<object_store::aws::AwsCredential>,
387 expiry: Option<SystemTime>,
388}
389
390impl CachedCredential {
391 fn is_valid(&self) -> bool {
392 self.expiry
393 .is_none_or(|exp| exp > SystemTime::now() + Duration::from_secs(60))
394 }
395}
396
397#[derive(Debug)]
398struct S3CredentialProvider {
399 aws: aws_credential_types::provider::SharedCredentialsProvider,
400 cache: tokio::sync::Mutex<Option<CachedCredential>>,
401}
402
403#[async_trait::async_trait]
404impl object_store::CredentialProvider for S3CredentialProvider {
405 type Credential = object_store::aws::AwsCredential;
406
407 async fn get_credential(&self) -> object_store::Result<Arc<object_store::aws::AwsCredential>> {
408 let mut cached = self.cache.lock().await;
409 if let Some(cached) = cached.as_ref().filter(|c| c.is_valid()) {
410 return Ok(cached.credential.clone());
411 }
412
413 use aws_credential_types::provider::ProvideCredentials as _;
414
415 let start = Instant::now();
416 let creds =
417 self.aws
418 .provide_credentials()
419 .await
420 .map_err(|e| object_store::Error::Generic {
421 store: "S3",
422 source: Box::new(e),
423 })?;
424 info!(
425 key_id = creds.access_key_id(),
426 expiry_s = creds
427 .expiry()
428 .and_then(|t| t.duration_since(SystemTime::now()).ok())
429 .map(|d| d.as_secs()),
430 elapsed_ms = start.elapsed().as_millis(),
431 "fetched credentials"
432 );
433 let credential = Arc::new(object_store::aws::AwsCredential {
434 key_id: creds.access_key_id().to_owned(),
435 secret_key: creds.secret_access_key().to_owned(),
436 token: creds.session_token().map(|s| s.to_owned()),
437 });
438 *cached = Some(CachedCredential {
439 credential: credential.clone(),
440 expiry: creds.expiry(),
441 });
442 Ok(credential)
443 }
444}
445
#[cfg(test)]
mod tests {
    use super::{ServerProtocol, cli_endpoint, cli_env_hint};

    /// First line of every snippet produced by `cli_env_hint`.
    const HINT_HEADER: &str = "copy/paste into a new terminal to point the S2 CLI at this server:";

    #[test]
    fn cli_endpoint_uses_localhost_with_explicit_port() {
        let cases = [
            (ServerProtocol::Http, 80, "http://localhost:80"),
            (
                ServerProtocol::Https { self_signed: false },
                443,
                "https://localhost:443",
            ),
        ];
        for (protocol, port, expected) in cases {
            assert_eq!(
                cli_endpoint(protocol, port),
                expected,
                "{protocol:?} on port {port}"
            );
        }
    }

    #[test]
    fn cli_env_hint_includes_exports_for_http() {
        let expected = [
            HINT_HEADER,
            "export S2_ACCOUNT_ENDPOINT=http://localhost:8080",
            "export S2_BASIN_ENDPOINT=http://localhost:8080",
            "export S2_ACCESS_TOKEN=ignored",
        ]
        .join("\n");
        assert_eq!(cli_env_hint(ServerProtocol::Http, 8080), expected);
    }

    #[test]
    fn cli_env_hint_includes_ssl_no_verify_for_self_signed_tls() {
        let expected = [
            HINT_HEADER,
            "export S2_ACCOUNT_ENDPOINT=https://localhost:8443",
            "export S2_BASIN_ENDPOINT=https://localhost:8443",
            "export S2_ACCESS_TOKEN=ignored",
            "export S2_SSL_NO_VERIFY=1",
        ]
        .join("\n");
        assert_eq!(
            cli_env_hint(ServerProtocol::Https { self_signed: true }, 8443),
            expected
        );
    }
}