1use std::{
2 net::SocketAddr,
3 path::PathBuf,
4 sync::Arc,
5 time::{Duration, SystemTime},
6};
7
8use axum_server::tls_rustls::RustlsConfig;
9use bytesize::ByteSize;
10use slatedb::object_store;
11use tokio::time::Instant;
12use tower_http::{
13 cors::CorsLayer,
14 trace::{DefaultMakeSpan, DefaultOnRequest, DefaultOnResponse, TraceLayer},
15};
16use tracing::info;
17
18use crate::{backend::Backend, handlers, init};
19
20#[derive(clap::Args, Debug, Clone)]
21pub struct TlsConfig {
22 #[arg(long, conflicts_with_all = ["tls_cert", "tls_key"])]
24 pub tls_self: bool,
25
26 #[arg(long, requires = "tls_key")]
29 pub tls_cert: Option<PathBuf>,
30
31 #[arg(long, requires = "tls_cert")]
34 pub tls_key: Option<PathBuf>,
35}
36
37#[derive(clap::Args, Debug, Clone)]
38pub struct LiteArgs {
39 #[arg(long)]
43 pub bucket: Option<String>,
44
45 #[arg(long, value_name = "DIR", conflicts_with = "bucket")]
49 pub local_root: Option<PathBuf>,
50
51 #[arg(long, default_value = "")]
53 pub path: String,
54
55 #[command(flatten)]
57 pub tls: TlsConfig,
58
59 #[arg(long)]
61 pub port: Option<u16>,
62
63 #[arg(long)]
70 pub no_cors: bool,
71
72 #[arg(long, env = "S2LITE_INIT_FILE")]
77 pub init_file: Option<PathBuf>,
78}
79
/// Which object store backs the database, as chosen from the CLI arguments.
#[derive(Debug, Clone)]
enum StoreType {
    /// An S3 bucket, identified by name.
    S3Bucket(String),
    /// A directory on the local filesystem used as the store root.
    LocalFileSystem(PathBuf),
    /// A process-local, in-memory store.
    InMemory,
}

impl StoreType {
    /// Default SlateDB flush interval for this kind of store:
    /// 50 ms for S3, 5 ms for the local-filesystem and in-memory stores.
    fn default_flush_interval(&self) -> Duration {
        match self {
            Self::S3Bucket(_) => Duration::from_millis(50),
            Self::LocalFileSystem(_) | Self::InMemory => Duration::from_millis(5),
        }
    }
}
95
96pub async fn run(args: LiteArgs) -> eyre::Result<()> {
97 info!(?args);
98
99 let addr = {
100 let port = args.port.unwrap_or_else(|| {
101 if args.tls.tls_self || args.tls.tls_cert.is_some() {
102 443
103 } else {
104 80
105 }
106 });
107 format!("0.0.0.0:{port}")
108 };
109
110 let store_type = if let Some(bucket) = args.bucket {
111 StoreType::S3Bucket(bucket)
112 } else if let Some(local_root) = args.local_root {
113 StoreType::LocalFileSystem(local_root)
114 } else {
115 StoreType::InMemory
116 };
117
118 let object_store = init_object_store(&store_type).await?;
119
120 let db_settings = slatedb::Settings::from_env_with_default(
121 "SL8_",
122 slatedb::Settings {
123 flush_interval: Some(store_type.default_flush_interval()),
124 ..Default::default()
125 },
126 )?;
127
128 let manifest_poll_interval = db_settings.manifest_poll_interval;
129
130 let append_inflight_max = if std::env::var("S2LITE_PIPELINE")
131 .is_ok_and(|v| v.eq_ignore_ascii_case("true") || v == "1")
132 {
133 info!("pipelining enabled on append sessions up to 25MiB");
134 ByteSize::mib(25)
135 } else {
136 info!("pipelining disabled");
137 ByteSize::b(1)
138 };
139
140 let db = slatedb::Db::builder(args.path, object_store)
141 .with_settings(db_settings)
142 .build()
143 .await?;
144
145 info!(
146 ?manifest_poll_interval,
147 "sleeping to ensure prior instance fenced out"
148 );
149
150 tokio::time::sleep(manifest_poll_interval).await;
151
152 let backend = Backend::new(db, append_inflight_max);
153 crate::backend::bgtasks::spawn(&backend);
154
155 if let Some(init_file) = &args.init_file {
156 let spec = init::load(init_file)?;
157 init::apply(&backend, spec).await?;
158 }
159
160 let mut app = handlers::router().with_state(backend).layer(
161 TraceLayer::new_for_http()
162 .make_span_with(DefaultMakeSpan::new().level(tracing::Level::INFO))
163 .on_request(DefaultOnRequest::new().level(tracing::Level::DEBUG))
164 .on_response(DefaultOnResponse::new().level(tracing::Level::INFO)),
165 );
166
167 if !args.no_cors {
168 app = app.layer(CorsLayer::very_permissive());
169 }
170
171 let server_handle = axum_server::Handle::new();
172 tokio::spawn(shutdown_signal(server_handle.clone()));
173 match (
174 args.tls.tls_self,
175 args.tls.tls_cert.clone(),
176 args.tls.tls_key.clone(),
177 ) {
178 (false, Some(cert_path), Some(key_path)) => {
179 info!(
180 addr,
181 ?cert_path,
182 "starting https server with provided certificate"
183 );
184 let rustls_config = RustlsConfig::from_pem_file(cert_path, key_path).await?;
185 axum_server::bind_rustls(addr.parse()?, rustls_config)
186 .handle(server_handle)
187 .serve(app.into_make_service())
188 .await?;
189 }
190 (true, None, None) => {
191 info!(
192 addr,
193 "starting https server with self-signed certificate, clients will need to use --insecure"
194 );
195 let rcgen::CertifiedKey { cert, signing_key } = rcgen::generate_simple_self_signed([
196 "localhost".to_string(),
197 "127.0.0.1".to_string(),
198 "::1".to_string(),
199 ])?;
200 let rustls_config = RustlsConfig::from_pem(
201 cert.pem().into_bytes(),
202 signing_key.serialize_pem().into_bytes(),
203 )
204 .await?;
205 axum_server::bind_rustls(addr.parse()?, rustls_config)
206 .handle(server_handle)
207 .serve(app.into_make_service())
208 .await?;
209 }
210 (false, None, None) => {
211 info!(addr, "starting plain http server");
212 axum_server::bind(addr.parse()?)
213 .handle(server_handle)
214 .serve(app.into_make_service())
215 .await?;
216 }
217 _ => {
218 return Err(eyre::eyre!("Invalid TLS configuration"));
220 }
221 }
222
223 Ok(())
224}
225
226async fn init_object_store(
227 store_type: &StoreType,
228) -> eyre::Result<Arc<dyn object_store::ObjectStore>> {
229 Ok(match store_type {
230 StoreType::S3Bucket(bucket) => {
231 info!(bucket, "using s3 object store");
232 let mut builder =
233 object_store::aws::AmazonS3Builder::from_env().with_bucket_name(bucket);
234 match (
235 std::env::var_os("AWS_ENDPOINT_URL_S3").and_then(|s| s.into_string().ok()),
236 std::env::var_os("AWS_ACCESS_KEY_ID").and_then(|s| s.into_string().ok()),
237 std::env::var_os("AWS_SECRET_ACCESS_KEY").and_then(|s| s.into_string().ok()),
238 ) {
239 (endpoint, Some(key_id), Some(secret_key)) => {
240 info!(key_id, "using static credentials from env vars");
241 if let Some(endpoint) = endpoint {
242 builder = builder.with_endpoint(endpoint);
243 }
244 builder = builder.with_credentials(Arc::new(
245 object_store::StaticCredentialProvider::new(
246 object_store::aws::AwsCredential {
247 key_id,
248 secret_key,
249 token: None,
250 },
251 ),
252 ));
253 }
254 _ => {
255 let aws_config =
256 aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
257 if let Some(region) = aws_config.region() {
258 info!(region = region.as_ref());
259 builder = builder.with_region(region.to_string());
260 }
261 if let Some(credentials_provider) = aws_config.credentials_provider() {
262 info!("using aws-config credentials provider");
263 builder = builder.with_credentials(Arc::new(S3CredentialProvider {
264 aws: credentials_provider.clone(),
265 cache: tokio::sync::Mutex::new(None),
266 }));
267 }
268 }
269 }
270 Arc::new(builder.build()?) as Arc<dyn object_store::ObjectStore>
271 }
272 StoreType::LocalFileSystem(local_root) => {
273 std::fs::create_dir_all(local_root)?;
274 info!(
275 root = %local_root.display(),
276 "using local filesystem object store"
277 );
278 Arc::new(object_store::local::LocalFileSystem::new_with_prefix(
279 local_root,
280 )?)
281 }
282 StoreType::InMemory => {
283 info!("using in-memory object store");
284 Arc::new(object_store::memory::InMemory::new())
285 }
286 })
287}
288
289async fn shutdown_signal(handle: axum_server::Handle<SocketAddr>) {
290 let ctrl_c = async {
291 tokio::signal::ctrl_c().await.expect("ctrl-c");
292 };
293
294 #[cfg(unix)]
295 let term = async {
296 tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
297 .expect("SIGTERM")
298 .recv()
299 .await;
300 };
301
302 #[cfg(not(unix))]
303 let term = std::future::pending::<()>();
304
305 tokio::select! {
306 _ = ctrl_c => {
307 info!("received Ctrl+C, starting graceful shutdown");
308 },
309 _ = term => {
310 info!("received SIGTERM, starting graceful shutdown");
311 },
312 }
313
314 handle.graceful_shutdown(Some(Duration::from_secs(10)));
315}
316
317#[derive(Debug)]
318struct CachedCredential {
319 credential: Arc<object_store::aws::AwsCredential>,
320 expiry: Option<SystemTime>,
321}
322
323impl CachedCredential {
324 fn is_valid(&self) -> bool {
325 self.expiry
326 .is_none_or(|exp| exp > SystemTime::now() + Duration::from_secs(60))
327 }
328}
329
330#[derive(Debug)]
331struct S3CredentialProvider {
332 aws: aws_credential_types::provider::SharedCredentialsProvider,
333 cache: tokio::sync::Mutex<Option<CachedCredential>>,
334}
335
336#[async_trait::async_trait]
337impl object_store::CredentialProvider for S3CredentialProvider {
338 type Credential = object_store::aws::AwsCredential;
339
340 async fn get_credential(&self) -> object_store::Result<Arc<object_store::aws::AwsCredential>> {
341 let mut cached = self.cache.lock().await;
342 if let Some(cached) = cached.as_ref().filter(|c| c.is_valid()) {
343 return Ok(cached.credential.clone());
344 }
345
346 use aws_credential_types::provider::ProvideCredentials as _;
347
348 let start = Instant::now();
349 let creds =
350 self.aws
351 .provide_credentials()
352 .await
353 .map_err(|e| object_store::Error::Generic {
354 store: "S3",
355 source: Box::new(e),
356 })?;
357 info!(
358 key_id = creds.access_key_id(),
359 expiry_s = creds
360 .expiry()
361 .and_then(|t| t.duration_since(SystemTime::now()).ok())
362 .map(|d| d.as_secs()),
363 elapsed_ms = start.elapsed().as_millis(),
364 "fetched credentials"
365 );
366 let credential = Arc::new(object_store::aws::AwsCredential {
367 key_id: creds.access_key_id().to_owned(),
368 secret_key: creds.secret_access_key().to_owned(),
369 token: creds.session_token().map(|s| s.to_owned()),
370 });
371 *cached = Some(CachedCredential {
372 credential: credential.clone(),
373 expiry: creds.expiry(),
374 });
375 Ok(credential)
376 }
377}