1use std::{
2 net::SocketAddr,
3 path::PathBuf,
4 sync::Arc,
5 time::{Duration, SystemTime},
6};
7
8use axum_server::tls_rustls::RustlsConfig;
9use bytesize::ByteSize;
10use slatedb::object_store;
11use tokio::time::Instant;
12use tower_http::trace::{DefaultMakeSpan, DefaultOnRequest, DefaultOnResponse, TraceLayer};
13use tracing::info;
14
15use crate::{backend::Backend, handlers};
16
// TLS options for the HTTP listener. Three combinations are accepted by `run`:
// `--tls-self` alone (a self-signed certificate is generated at startup),
// `--tls-cert` together with `--tls-key` (PEM files on disk), or none of them
// (plain HTTP). The clap `conflicts_with_all` / `requires` constraints below
// reject the other combinations at parse time.
//
// These are `//` comments rather than `///` so they do not alter clap's
// `--help` output.
#[derive(clap::Args, Debug, Clone)]
pub struct TlsConfig {
    // Serve HTTPS with a self-signed certificate; cannot be combined with a
    // user-supplied certificate or key.
    #[arg(long, conflicts_with_all = ["tls_cert", "tls_key"])]
    pub tls_self: bool,

    // Path to the PEM certificate to serve; must be given with `--tls-key`.
    #[arg(long, requires = "tls_key")]
    pub tls_cert: Option<PathBuf>,

    // Path to the PEM private key matching `--tls-cert`.
    #[arg(long, requires = "tls_cert")]
    pub tls_key: Option<PathBuf>,
}
33
// Command-line arguments consumed by `run`.
//
// Object store selection: `--bucket` (S3) and `--local-root` (local directory)
// are mutually exclusive; when neither is given, `run` uses an in-memory store.
//
// These are `//` comments rather than `///` so they do not alter clap's
// `--help` output.
#[derive(clap::Args, Debug, Clone)]
pub struct LiteArgs {
    // Name of the S3 bucket to use as the object store.
    #[arg(long)]
    pub bucket: Option<String>,

    // Root directory for a local-filesystem object store (created if missing).
    #[arg(long, value_name = "DIR", conflicts_with = "bucket")]
    pub local_root: Option<PathBuf>,

    // Database path passed to `slatedb::Db::builder`; defaults to "".
    #[arg(long, default_value = "")]
    pub path: String,

    // TLS flags for the listener (flattened into this argument set).
    #[command(flatten)]
    pub tls: TlsConfig,

    // Listen port; when unset, `run` uses 443 if TLS is enabled, otherwise 80.
    #[arg(long)]
    pub port: Option<u16>,
}
60
/// Object store backend chosen from the CLI arguments in `run`.
#[derive(Debug, Clone)]
enum StoreType {
    /// An S3 bucket, by name (the endpoint may be overridden via environment).
    S3Bucket(String),
    /// A local directory used as the object store root.
    LocalFileSystem(PathBuf),
    /// A fresh in-process in-memory store; nothing is persisted.
    InMemory,
}
67
68impl StoreType {
69 fn default_flush_interval(&self) -> Duration {
70 Duration::from_millis(match self {
71 StoreType::S3Bucket(_) => 50,
72 StoreType::LocalFileSystem(_) | StoreType::InMemory => 5,
73 })
74 }
75}
76
77pub async fn run(args: LiteArgs) -> eyre::Result<()> {
78 info!(?args);
79
80 let addr = {
81 let port = args.port.unwrap_or_else(|| {
82 if args.tls.tls_self || args.tls.tls_cert.is_some() {
83 443
84 } else {
85 80
86 }
87 });
88 format!("0.0.0.0:{port}")
89 };
90
91 let store_type = if let Some(bucket) = args.bucket {
92 StoreType::S3Bucket(bucket)
93 } else if let Some(local_root) = args.local_root {
94 StoreType::LocalFileSystem(local_root)
95 } else {
96 StoreType::InMemory
97 };
98
99 let object_store = init_object_store(&store_type).await?;
100
101 let db_settings = slatedb::Settings::from_env_with_default(
102 "SL8_",
103 slatedb::Settings {
104 flush_interval: Some(store_type.default_flush_interval()),
105 ..Default::default()
106 },
107 )?;
108
109 let manifest_poll_interval = db_settings.manifest_poll_interval;
110
111 let append_inflight_max = if std::env::var("S2LITE_PIPELINE")
112 .is_ok_and(|v| v.eq_ignore_ascii_case("true") || v == "1")
113 {
114 info!("pipelining enabled on append sessions up to 25MiB");
115 ByteSize::mib(25)
116 } else {
117 info!("pipelining disabled");
118 ByteSize::b(1)
119 };
120
121 let db = slatedb::Db::builder(args.path, object_store)
122 .with_settings(db_settings)
123 .build()
124 .await?;
125
126 info!(
127 ?manifest_poll_interval,
128 "sleeping to ensure prior instance fenced out"
129 );
130
131 tokio::time::sleep(manifest_poll_interval).await;
132
133 let app = handlers::router()
134 .with_state(Backend::new(db, append_inflight_max))
135 .layer(
136 TraceLayer::new_for_http()
137 .make_span_with(DefaultMakeSpan::new().level(tracing::Level::INFO))
138 .on_request(DefaultOnRequest::new().level(tracing::Level::DEBUG))
139 .on_response(DefaultOnResponse::new().level(tracing::Level::INFO)),
140 );
141
142 let server_handle = axum_server::Handle::new();
143 tokio::spawn(shutdown_signal(server_handle.clone()));
144 match (
145 args.tls.tls_self,
146 args.tls.tls_cert.clone(),
147 args.tls.tls_key.clone(),
148 ) {
149 (false, Some(cert_path), Some(key_path)) => {
150 info!(
151 addr,
152 ?cert_path,
153 "starting https server with provided certificate"
154 );
155 let rustls_config = RustlsConfig::from_pem_file(cert_path, key_path).await?;
156 axum_server::bind_rustls(addr.parse()?, rustls_config)
157 .handle(server_handle)
158 .serve(app.into_make_service())
159 .await?;
160 }
161 (true, None, None) => {
162 info!(
163 addr,
164 "starting https server with self-signed certificate, clients will need to use --insecure"
165 );
166 let rcgen::CertifiedKey { cert, signing_key } = rcgen::generate_simple_self_signed([
167 "localhost".to_string(),
168 "127.0.0.1".to_string(),
169 "::1".to_string(),
170 ])?;
171 let rustls_config = RustlsConfig::from_pem(
172 cert.pem().into_bytes(),
173 signing_key.serialize_pem().into_bytes(),
174 )
175 .await?;
176 axum_server::bind_rustls(addr.parse()?, rustls_config)
177 .handle(server_handle)
178 .serve(app.into_make_service())
179 .await?;
180 }
181 (false, None, None) => {
182 info!(addr, "starting plain http server");
183 axum_server::bind(addr.parse()?)
184 .handle(server_handle)
185 .serve(app.into_make_service())
186 .await?;
187 }
188 _ => {
189 return Err(eyre::eyre!("Invalid TLS configuration"));
191 }
192 }
193
194 Ok(())
195}
196
197async fn init_object_store(
198 store_type: &StoreType,
199) -> eyre::Result<Arc<dyn object_store::ObjectStore>> {
200 Ok(match store_type {
201 StoreType::S3Bucket(bucket) => {
202 info!(bucket, "using s3 object store");
203 let mut builder =
204 object_store::aws::AmazonS3Builder::from_env().with_bucket_name(bucket);
205 match (
206 std::env::var_os("AWS_ENDPOINT_URL_S3").and_then(|s| s.into_string().ok()),
207 std::env::var_os("AWS_ACCESS_KEY_ID").and_then(|s| s.into_string().ok()),
208 std::env::var_os("AWS_SECRET_ACCESS_KEY").and_then(|s| s.into_string().ok()),
209 ) {
210 (endpoint, Some(key_id), Some(secret_key)) => {
211 info!(key_id, "using static credentials from env vars");
212 if let Some(endpoint) = endpoint {
213 builder = builder.with_endpoint(endpoint);
214 }
215 builder = builder.with_credentials(Arc::new(
216 object_store::StaticCredentialProvider::new(
217 object_store::aws::AwsCredential {
218 key_id,
219 secret_key,
220 token: None,
221 },
222 ),
223 ));
224 }
225 _ => {
226 let aws_config =
227 aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
228 if let Some(region) = aws_config.region() {
229 info!(region = region.as_ref());
230 builder = builder.with_region(region.to_string());
231 }
232 if let Some(credentials_provider) = aws_config.credentials_provider() {
233 info!("using aws-config credentials provider");
234 builder = builder.with_credentials(Arc::new(S3CredentialProvider {
235 aws: credentials_provider.clone(),
236 cache: tokio::sync::Mutex::new(None),
237 }));
238 }
239 }
240 }
241 Arc::new(builder.build()?) as Arc<dyn object_store::ObjectStore>
242 }
243 StoreType::LocalFileSystem(local_root) => {
244 std::fs::create_dir_all(local_root)?;
245 info!(
246 root = %local_root.display(),
247 "using local filesystem object store"
248 );
249 Arc::new(object_store::local::LocalFileSystem::new_with_prefix(
250 local_root,
251 )?)
252 }
253 StoreType::InMemory => {
254 info!("using in-memory object store");
255 Arc::new(object_store::memory::InMemory::new())
256 }
257 })
258}
259
260async fn shutdown_signal(handle: axum_server::Handle<SocketAddr>) {
261 let ctrl_c = async {
262 tokio::signal::ctrl_c().await.expect("ctrl-c");
263 };
264
265 #[cfg(unix)]
266 let term = async {
267 tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
268 .expect("SIGTERM")
269 .recv()
270 .await;
271 };
272
273 #[cfg(not(unix))]
274 let term = std::future::pending::<()>();
275
276 tokio::select! {
277 _ = ctrl_c => {
278 info!("received Ctrl+C, starting graceful shutdown");
279 },
280 _ = term => {
281 info!("received SIGTERM, starting graceful shutdown");
282 },
283 }
284
285 handle.graceful_shutdown(Some(Duration::from_secs(10)));
286}
287
/// A credential fetched from the AWS provider, cached together with its expiry.
#[derive(Debug)]
struct CachedCredential {
    /// The credential in the form handed to `object_store`.
    credential: Arc<object_store::aws::AwsCredential>,
    /// Expiry reported by the provider; `None` is treated as non-expiring.
    expiry: Option<SystemTime>,
}
293
294impl CachedCredential {
295 fn is_valid(&self) -> bool {
296 self.expiry
297 .is_none_or(|exp| exp > SystemTime::now() + Duration::from_secs(60))
298 }
299}
300
/// Adapts an AWS SDK credentials provider to `object_store`'s
/// `CredentialProvider`, caching the last fetched credential until it is close
/// to expiry.
#[derive(Debug)]
struct S3CredentialProvider {
    /// Underlying AWS SDK credentials provider (taken from `aws_config` in
    /// `init_object_store`).
    aws: aws_credential_types::provider::SharedCredentialsProvider,
    /// Last fetched credential. An async mutex, because the lock is held across
    /// the `.await` on the provider.
    cache: tokio::sync::Mutex<Option<CachedCredential>>,
}
306
307#[async_trait::async_trait]
308impl object_store::CredentialProvider for S3CredentialProvider {
309 type Credential = object_store::aws::AwsCredential;
310
311 async fn get_credential(&self) -> object_store::Result<Arc<object_store::aws::AwsCredential>> {
312 let mut cached = self.cache.lock().await;
313 if let Some(cached) = cached.as_ref().filter(|c| c.is_valid()) {
314 return Ok(cached.credential.clone());
315 }
316
317 use aws_credential_types::provider::ProvideCredentials as _;
318
319 let start = Instant::now();
320 let creds =
321 self.aws
322 .provide_credentials()
323 .await
324 .map_err(|e| object_store::Error::Generic {
325 store: "S3",
326 source: Box::new(e),
327 })?;
328 info!(
329 key_id = creds.access_key_id(),
330 expiry_s = creds
331 .expiry()
332 .and_then(|t| t.duration_since(SystemTime::now()).ok())
333 .map(|d| d.as_secs()),
334 elapsed_ms = start.elapsed().as_millis(),
335 "fetched credentials"
336 );
337 let credential = Arc::new(object_store::aws::AwsCredential {
338 key_id: creds.access_key_id().to_owned(),
339 secret_key: creds.secret_access_key().to_owned(),
340 token: creds.session_token().map(|s| s.to_owned()),
341 });
342 *cached = Some(CachedCredential {
343 credential: credential.clone(),
344 expiry: creds.expiry(),
345 });
346 Ok(credential)
347 }
348}