// arraylake/lib.rs

use std::{
    collections::{HashMap, HashSet},
    num::{NonZeroU16, NonZeroUsize},
    sync::Arc,
};

use anyhow::{Context, bail};
use async_trait::async_trait;
use chrono::{DateTime, TimeDelta, Utc};
use icechunk::{
    Storage,
    config::{
        GcsBearerCredential, GcsCredentials, GcsCredentialsFetcher, S3Credentials,
        S3CredentialsFetcher, S3Options, S3StaticCredentials,
    },
    format::SnapshotId,
    new_s3_storage,
    ops::gc::{Action, ExpireResult, GCConfig, GCSummary},
    refs::Ref,
    storage::{new_gcs_storage, new_r2_storage},
};
use serde::{Deserialize, Serialize};
use tracing::{debug, info, instrument};
use url::Url;

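/// Bucket credentials as returned by the Arraylake API.
///
/// Deserialized `untagged`: the variant is chosen by which fields are present
/// in the JSON payload (AWS-style keys for S3, a bearer token for GCS).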
#[derive(Debug, Deserialize, PartialEq, Eq)]
#[serde(untagged)]
pub enum ALBucketCredentials {
    S3 {
        aws_access_key_id: String,
        aws_secret_access_key: String,
        aws_session_token: Option<String>,
        expiration: Option<DateTime<Utc>>,
    },
    GCS {
        access_token: String,
        principal: String,
        expiration: Option<DateTime<Utc>>,
    },
}

#[derive(Debug, Deserialize, PartialEq, Eq)]
#[serde(untagged)]
pub enum ALBucketExtraConfigValue {
    B(bool),
    S(String),
}

#[derive(Debug, Deserialize, PartialEq, Eq)]
pub enum ALBucketPlatform {
    #[serde(rename = "s3")]
    S3,
    #[serde(rename = "s3-compatible")]
    S3Compatible,
    #[serde(rename = "minio")]
    Minio,
    #[serde(rename = "gs")]
    GS,
}

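/// How clients authenticate to the bucket backing a repo.
///
/// Internally tagged: the JSON `method` field selects the variant.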
#[derive(Debug, Deserialize, PartialEq, Eq)]
#[serde(tag = "method")]
pub enum ALBucketAuthConfig {
    #[serde(rename = "aws_customer_managed_role")]
    #[serde(alias = "customer_managed_role")]
    AWSCustomerManagedRole {
        external_customer_id: String,
        external_role_name: String,
        shared_secret: String, // This is a secret string, not used by the client
    },

    #[serde(rename = "gcp_customer_managed_role")]
    GCPCustomerManagedRole { target_service_account: String },

    #[serde(rename = "r2_customer_managed_role")]
    R2CustomerManagedRole {
        // Accept the misspelled wire name as an alias, in case the API still
        // sends it
        #[serde(alias = "external_accound_id")]
        external_account_id: String,
        account_api_token: String, // This is a secret string, not used by the client
        parent_access_key_id: String, // This is a secret string, not used by the client
        duration_seconds: i32,     // 3600 seconds by default
    },

    #[serde(rename = "hmac")]
    Hmac { access_key_id: String, secret_access_key: String },

    #[serde(rename = "anonymous")]
    Anonymous,
}

#[derive(Debug, Deserialize, PartialEq, Eq)]
pub struct ALBucketInfo {
    pub nickname: String,
    pub platform: ALBucketPlatform,
    pub name: String,
    pub prefix: String,
    pub is_default: bool,
    pub extra_config: HashMap<String, ALBucketExtraConfigValue>,
    pub auth_config: ALBucketAuthConfig,
}

#[derive(Debug, Deserialize, PartialEq)]
pub struct OptimizationConfig {
    pub gc_config: Option<ALGCConfig>,
    pub expiration_config: Option<ALExpirationConfig>,
}

#[derive(Debug, Deserialize, PartialEq)]
pub struct ALRepoInfo {
    pub name: String,
    pub org: String,
    pub bucket: Option<ALBucketInfo>,
    pub prefix: String,
    pub kind: String,
    pub optimization_config: Option<OptimizationConfig>,
}

#[derive(Debug, Deserialize, PartialEq)]
pub struct ALGCConfig {
    pub extra_gc_roots: HashSet<String>,
    pub dangling_chunks: Option<f64>,
    pub dangling_manifests: Option<f64>,
    pub dangling_attributes: Option<f64>,
    pub dangling_transaction_logs: Option<f64>,
    pub dangling_snapshots: Option<f64>,
    pub enabled: bool,
}

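/// Converts the Arraylake GC settings into an icechunk [`GCConfig`]: each
/// optional "dangling" age in seconds becomes a `DeleteIfCreatedBefore`
/// action anchored at the current time, a missing age means `Keep`, and the
/// extra GC roots are parsed into [`SnapshotId`]s, failing the conversion on
/// invalid ids.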
impl TryFrom<ALGCConfig> for GCConfig {
    type Error = anyhow::Error;

    fn try_from(value: ALGCConfig) -> Result<Self, Self::Error> {
        fn to_action(now: DateTime<Utc>, seconds: Option<f64>) -> Action {
            seconds
                .map(|sec| {
                    Action::DeleteIfCreatedBefore(
                        now - TimeDelta::seconds(sec.round() as i64),
                    )
                })
                .unwrap_or(Action::Keep)
        }

        let now = Utc::now();
        let config = GCConfig::new(
            value
                .extra_gc_roots
                .into_iter()
                .map(|s| {
                    SnapshotId::try_from(s.as_str()).map_err(|error| {
                        anyhow::anyhow!(
                            "Cannot parse a SnapshotId from <{}>: {}",
                            s,
                            error
                        )
                    })
                })
                .collect::<Result<_, _>>()?,
            to_action(now, value.dangling_chunks),
            to_action(now, value.dangling_manifests),
            to_action(now, value.dangling_attributes),
            to_action(now, value.dangling_transaction_logs),
            to_action(now, value.dangling_snapshots),
            NonZeroU16::new(1_000).unwrap(),
            NonZeroUsize::new(5 * 1_024 * 1_024 * 1_024).unwrap(), // 5 GiB
            NonZeroU16::new(500).unwrap(),
            false,
        );
        Ok(config)
    }
}

#[derive(Debug, Deserialize, PartialEq)]
pub struct ALExpirationConfig {
    pub expire_versions_older_than: f64,
    pub expire_every: Option<f64>,
    pub enabled: bool,
}

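/// A thin HTTP client for the Arraylake API: fetches repository metadata and
/// bucket credentials, builds icechunk [`Storage`] instances, and reports GC
/// and expiration results back to the service.
///
/// A minimal usage sketch (hypothetical org, repo, and token values; assumes
/// this crate is named `arraylake`):
///
/// ```no_run
/// use std::sync::Arc;
///
/// use arraylake::ALClient;
///
/// # async fn example() -> anyhow::Result<()> {
/// // `None` falls back to the default service URI, https://api.earthmover.io
/// let client = Arc::new(ALClient::new(None, "my-api-token".to_string())?);
/// let repo = client.get_repo_info("my-org", "my-repo").await?;
/// let _storage = client.get_storage_for_repo(&repo).await?;
/// # Ok(())
/// # }
/// ```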
#[derive(Debug)]
pub struct ALClient {
    http: reqwest::Client,
    uri: Url,
    token: String,
}

// We implement Default to satisfy the serialization requirements of the
// credential refreshers
impl Default for ALClient {
    fn default() -> Self {
        Self {
            http: Default::default(),
            uri: Url::parse("https://api.earthmover.io").unwrap(),
            token: Default::default(),
        }
    }
}

impl ALClient {
    pub fn new(service_uri: Option<String>, token: String) -> anyhow::Result<Self> {
        let uri = service_uri.unwrap_or_else(|| "https://api.earthmover.io".to_string());
        let uri = Url::parse(uri.as_str())?;
        Ok(ALClient { token, uri, http: reqwest::Client::new() })
    }

    #[instrument(skip(self))]
    pub async fn get_credentials(
        &self,
        org: &str,
        repo: &str,
    ) -> anyhow::Result<ALBucketCredentials> {
        info!(repo_name = repo, org_name = org, "Getting credentials");
        let path = format!("repos/{org}/{repo}/bucket-credentials");
        let url = self.make_url(path.as_str())?;
        let res = self
            .http
            .get(url)
            .bearer_auth(self.token.clone())
            .send()
            .await
            .context("Fetching credentials")?
            .json::<ALBucketCredentials>()
            .await
            .context("Parsing credentials")?;
        Ok(res)
    }

    #[instrument(skip(self))]
    pub async fn get_repo_info(
        &self,
        org: &str,
        repo: &str,
    ) -> anyhow::Result<ALRepoInfo> {
        info!(repo_name = repo, org_name = org, "Getting repository info");
        let path = format!("repos/{org}/{repo}");
        let url = self.make_url(path.as_str())?;
        let res = self
            .http
            .get(url)
            .bearer_auth(self.token.clone())
            .send()
            .await
            .context("Fetching repo info")?
            .json::<ALRepoInfo>()
            .await
            .context("Parsing repo info")?;
        Ok(res)
    }

    #[instrument(skip(self))]
    pub async fn get_storage_for_repo(
        self: &Arc<Self>,
        repo: &ALRepoInfo,
    ) -> anyhow::Result<Arc<dyn Storage>> {
        info!(?repo, "Getting storage for repo");
        let Some(bucket) = &repo.bucket else { bail!("Bucket info is missing") };

        if !can_we_get_bucket_credentials(bucket) {
            bail!("We don't have credentials to the bucket")
        }
        let credential_refresher = ALCredentialsFetcher {
            al: Arc::clone(self),
            repo: repo.name.clone(),
            org: repo.org.clone(),
        };
        match bucket.platform {
            ALBucketPlatform::S3 => get_s3_storage(repo, credential_refresher),
            ALBucketPlatform::S3Compatible => {
                get_s3_compatible_storage(repo, credential_refresher)
            }
            ALBucketPlatform::Minio => {
                get_s3_compatible_storage(repo, credential_refresher)
            }
            ALBucketPlatform::GS => get_gcs_storage(repo, credential_refresher).await,
        }
    }

    #[instrument(skip(self))]
    pub async fn store_gc_results(
        &self,
        org: &str,
        repo: &str,
        job_run_id: Option<&str>,
        result: GCSummary,
    ) -> anyhow::Result<()> {
        info!(?repo, "Storing gc result");
        let path = format!("repos/icechunk/{org}/{repo}/gc_results");
        let url = self.make_url(path.as_str())?;
        let mut body: HashMap<String, serde_json::Value> = [
            ("bytes_deleted".to_string(), serde_json::json!(result.bytes_deleted)),
            ("chunks_deleted".to_string(), serde_json::json!(result.chunks_deleted)),
            (
                "manifests_deleted".to_string(),
                serde_json::json!(result.manifests_deleted),
            ),
            (
                "snapshots_deleted".to_string(),
                serde_json::json!(result.snapshots_deleted),
            ),
            (
                "attributes_deleted".to_string(),
                serde_json::json!(result.attributes_deleted),
            ),
            (
                "transaction_logs_deleted".to_string(),
                serde_json::json!(result.transaction_logs_deleted),
            ),
        ]
        .into_iter()
        .collect();

        if let Some(id) = job_run_id {
            body.insert("job_run_id".to_string(), serde_json::json!(id));
        }

        self.http
            .post(url)
            .bearer_auth(self.token.clone())
            .json(&body)
            .send()
            .await
            .context("Uploading gc results")?
            .error_for_status()?;

        Ok(())
    }

    #[instrument(skip(self))]
    pub async fn store_expiration_results(
        &self,
        org: &str,
        repo: &str,
        job_run_id: Option<&str>,
        result: ExpireResult,
    ) -> anyhow::Result<()> {
        info!(?repo, "Storing expiration result");
        let path = format!("repos/icechunk/{org}/{repo}/expiration_results");
        let url = self.make_url(path.as_str())?;
        let mut body: HashMap<String, serde_json::Value> = [
            (
                "released_snapshots".to_string(),
                serde_json::json!(
                    result
                        .released_snapshots
                        .into_iter()
                        .map(|r| r.to_string())
                        .collect::<HashSet<_>>()
                ),
            ),
            (
                "edited_snapshots".to_string(),
                serde_json::json!(
                    result
                        .edited_snapshots
                        .into_iter()
                        .map(|r| r.to_string())
                        .collect::<HashSet<_>>()
                ),
            ),
            (
                "deleted_tags".to_string(),
                serde_json::json!(
                    result
                        .deleted_refs
                        .iter()
                        .filter_map(|r| match r {
                            Ref::Tag(name) => Some(name.clone()),
                            Ref::Branch(_) => None,
                        })
                        .collect::<HashSet<_>>()
                ),
            ),
            (
                "deleted_branches".to_string(),
                serde_json::json!(
                    result
                        .deleted_refs
                        .iter()
                        .filter_map(|r| match r {
                            Ref::Tag(_) => None,
                            Ref::Branch(name) => Some(name.clone()),
                        })
                        .collect::<HashSet<_>>()
                ),
            ),
        ]
        .into_iter()
        .collect();

        if let Some(id) = job_run_id {
            body.insert("job_run_id".to_string(), serde_json::json!(id));
        }

        self.http
            .post(url)
            .bearer_auth(self.token.clone())
            .json(&body)
            .send()
            .await
            .context("Uploading expiration results")?
            .error_for_status()?;

        Ok(())
    }

    fn make_url(&self, path: &str) -> anyhow::Result<String> {
        Ok(self.uri.join(path)?.as_str().to_string())
    }
}

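/// True when Arraylake can mint short-lived bucket credentials by assuming a
/// customer-managed role, which only applies on the platforms listed below.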
fn use_delegated_credentials(bucket: &ALBucketInfo) -> bool {
    matches!(
        bucket.auth_config,
        ALBucketAuthConfig::AWSCustomerManagedRole { .. }
            | ALBucketAuthConfig::GCPCustomerManagedRole { .. }
            | ALBucketAuthConfig::R2CustomerManagedRole { .. }
    ) && matches!(
        bucket.platform,
        ALBucketPlatform::GS | ALBucketPlatform::S3 | ALBucketPlatform::S3Compatible
    )
}

fn use_hmac_credentials(bucket: &ALBucketInfo) -> bool {
    matches!(bucket.auth_config, ALBucketAuthConfig::Hmac { .. })
}

fn can_we_get_bucket_credentials(bucket: &ALBucketInfo) -> bool {
    use_delegated_credentials(bucket) || use_hmac_credentials(bucket)
}

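/// Fetches fresh bucket credentials from the Arraylake API on demand.
///
/// Serialized via `typetag` so icechunk can persist it in its config; the
/// HTTP client is skipped during (de)serialization and rebuilt from
/// `ALClient::default()`.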
#[derive(Debug, Serialize, Deserialize)]
struct ALCredentialsFetcher {
    #[serde(skip)]
    al: Arc<ALClient>,
    repo: String,
    org: String,
}

#[async_trait]
#[typetag::serde]
impl GcsCredentialsFetcher for ALCredentialsFetcher {
    async fn get(&self) -> Result<GcsBearerCredential, String> {
        debug!("Refreshing GCS credentials");
        match self.al.get_credentials(self.org.as_str(), self.repo.as_str()).await {
            Ok(ALBucketCredentials::GCS { access_token, expiration, .. }) => {
                Ok(GcsBearerCredential {
                    bearer: access_token,
                    expires_after: expiration,
                })
            }
            Ok(_) => Err("Invalid credential type for GCS bucket".to_string()),
            Err(err) => {
                Err(format!("Cannot get credentials for GCS bucket: {err:#}"))
            }
        }
    }
}

#[async_trait]
#[typetag::serde]
impl S3CredentialsFetcher for ALCredentialsFetcher {
    async fn get(&self) -> Result<S3StaticCredentials, String> {
        debug!("Refreshing S3 credentials");
        match self.al.get_credentials(self.org.as_str(), self.repo.as_str()).await {
            Ok(ALBucketCredentials::S3 {
                aws_access_key_id,
                aws_secret_access_key,
                aws_session_token,
                expiration,
            }) => Ok(S3StaticCredentials {
                access_key_id: aws_access_key_id,
                secret_access_key: aws_secret_access_key,
                session_token: aws_session_token,
                expires_after: expiration,
            }),
            Ok(_) => Err("Invalid credential type for S3 bucket".to_string()),
            Err(err) => Err(format!("Cannot get credentials for S3 bucket: {err:#}")),
        }
    }
}

#[instrument(skip(fetcher))]
async fn get_gcs_storage(
    repo: &ALRepoInfo,
    fetcher: ALCredentialsFetcher,
) -> anyhow::Result<Arc<dyn Storage>> {
    use ALBucketAuthConfig::*;

    let Some(bucket) = &repo.bucket else { bail!("Bucket info is missing") };

    let credentials = match &bucket.auth_config {
        GCPCustomerManagedRole { .. } => GcsCredentials::Refreshable(Arc::new(fetcher)),
        Anonymous => GcsCredentials::Anonymous,
        _ => bail!("Cannot create credentials for the bucket"),
    };
    new_gcs_storage(
        bucket.name.clone(),
        Some(repo.prefix.clone()), // already includes the bucket prefix
        Some(credentials),
        None, // FIXME: should we pass something here?
    )
    .await
    .context("Creating GCS storage")
}

#[instrument(skip(fetcher))]
fn get_s3_storage(
    repo: &ALRepoInfo,
    fetcher: ALCredentialsFetcher,
) -> anyhow::Result<Arc<dyn Storage>> {
    use ALBucketAuthConfig::*;

    let Some(bucket) = &repo.bucket else { bail!("Bucket info is missing") };

    let credentials = match &bucket.auth_config {
        AWSCustomerManagedRole { .. } => S3Credentials::Refreshable(Arc::new(fetcher)),
        Hmac { access_key_id, secret_access_key } => {
            S3Credentials::Static(S3StaticCredentials {
                access_key_id: access_key_id.clone(),
                secret_access_key: secret_access_key.clone(),
                session_token: None,
                expires_after: None,
            })
        }
        Anonymous => S3Credentials::Anonymous,
        _ => bail!("Cannot create credentials for the bucket"),
    };

    let region = match bucket.extra_config.get("region_name") {
        Some(ALBucketExtraConfigValue::S(region)) => Some(region.clone()),
        _ => None,
    };
    let options = S3Options {
        region,
        endpoint_url: None,
        anonymous: false,
        allow_http: false,
        force_path_style: false,
        network_stream_timeout_seconds: None,
        requester_pays: false,
    };

    new_s3_storage(
        options,
        bucket.name.clone(),
        Some(repo.prefix.clone()), // already includes the bucket prefix
        Some(credentials),
    )
    .context("Creating S3 storage")
}

#[instrument(skip(fetcher))]
fn get_s3_compatible_storage(
    repo: &ALRepoInfo,
    fetcher: ALCredentialsFetcher,
) -> anyhow::Result<Arc<dyn Storage>> {
    use ALBucketAuthConfig::*;
    let Some(bucket) = repo.bucket.as_ref() else {
        bail!("Bucket not found");
    };
    let credentials = match &bucket.auth_config {
        AWSCustomerManagedRole { .. } => S3Credentials::Refreshable(Arc::new(fetcher)),
        R2CustomerManagedRole { .. } => S3Credentials::Refreshable(Arc::new(fetcher)),
        Hmac { access_key_id, secret_access_key } => {
            S3Credentials::Static(S3StaticCredentials {
                access_key_id: access_key_id.clone(),
                secret_access_key: secret_access_key.clone(),
                session_token: None,
                expires_after: None,
            })
        }
        Anonymous => S3Credentials::Anonymous,
        _ => bail!("Cannot create credentials for the bucket"),
    };
    // FIXME: add tigris, minio and R2
    let region = match bucket.extra_config.get("region_name") {
        Some(ALBucketExtraConfigValue::S(region)) => Some(region.clone()),
        _ => None,
    };
    let endpoint_url = match bucket.extra_config.get("endpoint_url") {
        Some(ALBucketExtraConfigValue::S(ep)) => Some(ep.clone()),
        _ => None,
    };
    // Allow plain HTTP only when the bucket explicitly disables SSL; default
    // to requiring HTTPS when `use_ssl` is absent.
    let allow_http = match bucket.extra_config.get("use_ssl") {
        Some(ALBucketExtraConfigValue::B(use_ssl)) => !*use_ssl,
        _ => false,
    };
    let force_path_style = allow_http; // this is what the Python AL client does

    let options = S3Options {
        region,
        endpoint_url: endpoint_url.clone(),
        anonymous: false,
        allow_http,
        force_path_style,
        network_stream_timeout_seconds: None,
        requester_pays: false,
    };

    // If it's an R2 bucket, use R2 storage
    if let Some(endpoint_url) = endpoint_url
        && endpoint_url.contains("r2.cloudflarestorage.com")
    {
        new_r2_storage(
            options,
            Some(bucket.name.clone()),
            Some(repo.prefix.clone()),
            None,
            Some(credentials),
        )
        .context("Creating R2 storage")
    } else {
        new_s3_storage(
            options,
            bucket.name.clone(),
            Some(repo.prefix.clone()), // already includes the bucket prefix
            Some(credentials),
        )
        .context("Creating S3 storage")
    }
}

#[cfg(test)]
#[allow(clippy::panic, clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use chrono::{TimeZone, Utc};

    use super::{
        ALBucketAuthConfig, ALBucketCredentials, ALBucketExtraConfigValue, ALBucketInfo,
        ALBucketPlatform, ALRepoInfo,
    };

    #[test]
    fn test_parsing_repo_info() -> Result<(), Box<dyn std::error::Error>> {
        let json = r#"
            {
            "_id": "id",
            "org": "earthmover-demos",
            "name": "repo",
            "created": "2025-04-23T16:31:30+00:00",
            "updated": "2025-04-23T16:31:30+00:00",
            "description": null,
            "metadata": null,
            "created_by": "abcdef",
            "visibility": "PRIVATE",
            "bucket": {
                "id": "some-bucket-id",
                "nickname": "nickname",
                "platform": "s3",
                "name": "name",
                "created": "2024-04-17T23:00:43+00:00",
                "updated": "2025-04-08T13:28:28+00:00",
                "created_by": "somebody@example.com",
                "prefix": "bucket_prefix",
                "extra_config": {
                    "region_name": "us-east-1",
                    "anonymous": false
                },
                "is_default": true,
                "auth_config": {
                    "method": "aws_customer_managed_role",
                    "external_customer_id": "1234",
                    "external_role_name": "role",
                    "shared_secret": "secret"
                }
            },
            "status": {
                "mode": "online",
                "message": "new repo creation",
                "initiated_by": {
                    "principal_id": "principal",
                    "system_id": null
                },
                "estimated_end_time": null
            },
            "kind": "icechunk",
            "prefix": "repo_prefix",
            "gc_config": null,
            "expiration_config": null
            }
        "#;

        let info: ALRepoInfo = serde_json::from_str(json)?;
        let extra_config = [
            (
                "region_name".to_string(),
                ALBucketExtraConfigValue::S("us-east-1".to_string()),
            ),
            ("anonymous".to_string(), ALBucketExtraConfigValue::B(false)),
        ]
        .into_iter()
        .collect();
        let auth_config = ALBucketAuthConfig::AWSCustomerManagedRole {
            external_customer_id: "1234".to_string(),
            external_role_name: "role".to_string(),
            shared_secret: "secret".to_string(),
        };
        assert_eq!(
            info,
            ALRepoInfo {
                org: "earthmover-demos".to_string(),
                name: "repo".to_string(),
                bucket: Some(ALBucketInfo {
                    nickname: "nickname".to_string(),
                    platform: ALBucketPlatform::S3,
                    name: "name".to_string(),
                    prefix: "bucket_prefix".to_string(),
                    is_default: true,
                    extra_config,
                    auth_config,
                }),
                prefix: "repo_prefix".to_string(),
                optimization_config: None,
                kind: "icechunk".to_string(),
            }
        );

        Ok(())
    }

    #[test]
    fn test_parsing_credentials() -> Result<(), Box<dyn std::error::Error>> {
        let json = r#"
            {
                "aws_access_key_id": "access",
                "aws_secret_access_key": "secret",
                "aws_session_token": "token",
                "expiration": "2025-05-13T19:11:50Z"
            }
        "#;

        let expiration = Utc.with_ymd_and_hms(2025, 5, 13, 19, 11, 50).single().unwrap();
        let creds: ALBucketCredentials = serde_json::from_str(json)?;
        assert_eq!(
            creds,
            ALBucketCredentials::S3 {
                aws_access_key_id: "access".to_string(),
                aws_secret_access_key: "secret".to_string(),
                aws_session_token: Some("token".to_string()),
                expiration: Some(expiration),
            }
        );
        Ok(())
    }
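
    // A hedged sketch exercising the GCS branch of the untagged credentials
    // enum: a payload carrying `access_token`/`principal` (values here are
    // hypothetical) should deserialize as the GCS variant, not the S3 one.
    #[test]
    fn test_parsing_gcs_credentials() -> Result<(), Box<dyn std::error::Error>> {
        let json = r#"
            {
                "access_token": "token",
                "principal": "someone@example.com",
                "expiration": null
            }
        "#;

        let creds: ALBucketCredentials = serde_json::from_str(json)?;
        assert_eq!(
            creds,
            ALBucketCredentials::GCS {
                access_token: "token".to_string(),
                principal: "someone@example.com".to_string(),
                expiration: None,
            }
        );
        Ok(())
    }

    // A sketch of the ALGCConfig -> GCConfig conversion, assuming that
    // `SnapshotId::try_from` rejects strings that are not valid snapshot ids:
    // an unparseable extra GC root should fail the whole conversion.
    #[test]
    fn test_gc_config_rejects_invalid_root() {
        use std::collections::HashSet;

        use icechunk::ops::gc::GCConfig;

        use super::ALGCConfig;

        let config = ALGCConfig {
            extra_gc_roots: HashSet::from(["not-a-snapshot-id".to_string()]),
            dangling_chunks: Some(3600.0),
            dangling_manifests: None,
            dangling_attributes: None,
            dangling_transaction_logs: None,
            dangling_snapshots: None,
            enabled: true,
        };
        assert!(GCConfig::try_from(config).is_err());
    }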
}