Skip to main content

secrets_provider_aws_sm/
lib.rs

1use anyhow::{Context, Result};
2use aws_config::BehaviorVersion;
3use aws_sdk_kms::{Client as KmsClient, primitives::Blob as KmsBlob};
4use aws_sdk_secretsmanager::Client as SecretsManagerClient;
5use aws_sdk_secretsmanager::error::{ProvideErrorMetadata, SdkError};
6use aws_sdk_secretsmanager::operation::list_secret_version_ids::ListSecretVersionIdsError;
7use aws_sdk_secretsmanager::types::{Filter, FilterNameStringType, Tag};
8use aws_types::region::Region;
9use greentic_secrets_core::rt;
10use greentic_secrets_spec::{
11    KeyProvider, Scope, SecretListItem, SecretRecord, SecretUri, SecretVersion, SecretsBackend,
12    SecretsError, SecretsResult, VersionedSecret,
13};
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::env;
17
/// Environment variable overriding the first path segment of every secret name.
const PREFIX_ENV: &str = "GREENTIC_AWS_SECRET_PREFIX";
/// Secret-name prefix used when `PREFIX_ENV` is not set.
const DEFAULT_PREFIX: &str = "gtsec";

/// Environment variable overriding the staging label attached to written versions.
const STAGE_ENV: &str = "GREENTIC_AWS_VERSION_STAGE";
/// Staging label used when `STAGE_ENV` is not set.
const DEFAULT_STAGE: &str = "AWSCURRENT";

/// Environment variable naming the KMS key used to wrap data keys (required).
const KMS_KEY_ENV: &str = "GREENTIC_AWS_KMS_KEY_ID";
/// Environment variable overriding the AWS region.
const REGION_ENV: &str = "GREENTIC_AWS_REGION";
/// Environment variable overriding the Secrets Manager endpoint URL.
const SM_ENDPOINT_ENV: &str = "GREENTIC_AWS_SM_ENDPOINT";
/// Environment variable overriding the KMS endpoint URL.
const KMS_ENDPOINT_ENV: &str = "GREENTIC_AWS_KMS_ENDPOINT";

/// Path segment written in place of the team when a scope has none.
const TEAM_PLACEHOLDER: &str = "_";
27
28/// Components returned for integration with the broker/core wiring.
29pub struct BackendComponents {
30    pub backend: Box<dyn SecretsBackend>,
31    pub key_provider: Box<dyn KeyProvider>,
32}
33
34/// Build the AWS Secrets Manager backend and corresponding KMS key provider.
35pub async fn build_backend() -> Result<BackendComponents> {
36    let (config, shared_config) = AwsProviderConfig::load_from_env().await?;
37
38    let secrets_client = {
39        let mut builder = aws_sdk_secretsmanager::config::Builder::from(&shared_config);
40        if let Some(endpoint) = config.secrets_endpoint.as_deref() {
41            builder = builder.endpoint_url(endpoint);
42        }
43        SecretsManagerClient::from_conf(builder.build())
44    };
45
46    let kms_client = {
47        let mut builder = aws_sdk_kms::config::Builder::from(&shared_config);
48        if let Some(endpoint) = config.kms_endpoint.as_deref() {
49            builder = builder.endpoint_url(endpoint);
50        }
51        KmsClient::from_conf(builder.build())
52    };
53
54    let backend = AwsSecretsBackend::new(secrets_client, config.clone());
55    let key_provider = AwsKmsKeyProvider::new(kms_client, config.clone());
56
57    Ok(BackendComponents {
58        backend: Box::new(backend),
59        key_provider: Box::new(key_provider),
60    })
61}
62
63#[derive(Clone)]
64struct AwsProviderConfig {
65    secret_prefix: String,
66    version_stage: String,
67    kms_key_id: String,
68    secrets_endpoint: Option<String>,
69    kms_endpoint: Option<String>,
70    resource_tags: Vec<aws_sdk_secretsmanager::types::Tag>,
71}
72
73impl AwsProviderConfig {
74    async fn load_from_env() -> Result<(Self, aws_types::SdkConfig)> {
75        let prefix = env::var(PREFIX_ENV).unwrap_or_else(|_| DEFAULT_PREFIX.to_string());
76        let stage = env::var(STAGE_ENV).unwrap_or_else(|_| DEFAULT_STAGE.to_string());
77        let kms_key_id = env::var(KMS_KEY_ENV)
78            .context("GREENTIC_AWS_KMS_KEY_ID must be set for the AWS provider")?;
79        let mut loader = aws_config::defaults(BehaviorVersion::latest());
80        if let Ok(region) = env::var(REGION_ENV) {
81            loader = loader.region(Region::new(region));
82        }
83
84        let shared_config = loader.load().await;
85        let secrets_endpoint = env::var(SM_ENDPOINT_ENV)
86            .ok()
87            .filter(|s| !s.trim().is_empty());
88        let kms_endpoint = env::var(KMS_ENDPOINT_ENV)
89            .ok()
90            .filter(|s| !s.trim().is_empty());
91        let resource_tags = build_resource_tags();
92
93        Ok((
94            Self {
95                secret_prefix: prefix,
96                version_stage: stage,
97                kms_key_id,
98                secrets_endpoint,
99                kms_endpoint,
100                resource_tags,
101            },
102            shared_config,
103        ))
104    }
105
106    fn secret_name(&self, uri: &SecretUri) -> String {
107        format!(
108            "{}/{}/{}/{}/{}/{}",
109            self.secret_prefix,
110            uri.scope().env(),
111            uri.scope().tenant(),
112            uri.scope().team().unwrap_or(TEAM_PLACEHOLDER),
113            uri.category(),
114            uri.name()
115        )
116    }
117
118    fn scope_prefix(&self, scope: &Scope) -> String {
119        format!(
120            "{prefix}/{env}/{tenant}/",
121            prefix = self.secret_prefix,
122            env = scope.env(),
123            tenant = scope.tenant()
124        )
125    }
126}
127
128fn build_resource_tags() -> Vec<Tag> {
129    let mut tags = Vec::new();
130    tags.push(Tag::builder().key("greentic-ci").value("true").build());
131    if let Ok(repo) = env::var("GITHUB_REPOSITORY") {
132        tags.push(Tag::builder().key("greentic-repo").value(repo).build());
133    }
134    if let Ok(run_id) = env::var("GITHUB_RUN_ID") {
135        let attempt = env::var("GITHUB_RUN_ATTEMPT").unwrap_or_default();
136        let combined = if attempt.is_empty() {
137            run_id
138        } else {
139            format!("{run_id}/{attempt}")
140        };
141        tags.push(Tag::builder().key("greentic-run").value(combined).build());
142    }
143    tags
144}
145
146#[derive(Clone)]
147pub struct AwsSecretsBackend {
148    client: SecretsManagerClient,
149    config: AwsProviderConfig,
150}
151
152async fn fetch_secret_version_inner(
153    client: SecretsManagerClient,
154    secret_id: String,
155    version_id: Option<String>,
156) -> SecretsResult<Option<StoredSecret>> {
157    let mut request = client.get_secret_value().secret_id(secret_id);
158    if let Some(version) = version_id {
159        request = request.version_id(version);
160    }
161    match request.send().await {
162        Ok(output) => deserialize_secret_payload(output.secret_string(), output.secret_binary()),
163        Err(err) => {
164            if is_not_found(&err) {
165                Ok(None)
166            } else {
167                Err(storage_error("get_secret_value", err))
168            }
169        }
170    }
171}
172
173impl AwsSecretsBackend {
174    fn new(client: SecretsManagerClient, config: AwsProviderConfig) -> Self {
175        Self { client, config }
176    }
177
178    fn fetch_latest_version(&self, secret_id: &str) -> SecretsResult<Option<StoredSecret>> {
179        let client = self.client.clone();
180        let secret_id = secret_id.to_owned();
181        rt::sync_await(async move { fetch_secret_version_inner(client, secret_id, None).await })
182    }
183
184    fn load_all_versions(&self, uri: &SecretUri) -> SecretsResult<Vec<StoredSecret>> {
185        let client = self.client.clone();
186        let secret_id = self.config.secret_name(uri);
187        rt::sync_await(async move {
188            let mut collected = Vec::new();
189            let mut token: Option<String> = None;
190
191            loop {
192                let mut request = client
193                    .list_secret_version_ids()
194                    .secret_id(secret_id.clone())
195                    .include_deprecated(true);
196
197                if let Some(ref next) = token {
198                    request = request.next_token(next);
199                }
200
201                let response = match request.send().await {
202                    Ok(resp) => resp,
203                    Err(err) => {
204                        if is_not_found(&err) {
205                            return Ok(Vec::new());
206                        }
207                        if list_versions_unsupported(&err) {
208                            let latest =
209                                fetch_secret_version_inner(client.clone(), secret_id.clone(), None)
210                                    .await?;
211                            return Ok(latest.into_iter().collect());
212                        }
213                        return Err(storage_error("list_secret_version_ids", err));
214                    }
215                };
216
217                for entry in response.versions() {
218                    if let Some(version_id) = entry.version_id()
219                        && let Some(stored) = fetch_secret_version_inner(
220                            client.clone(),
221                            secret_id.clone(),
222                            Some(version_id.to_string()),
223                        )
224                        .await?
225                    {
226                        collected.push(stored);
227                    }
228                }
229
230                if let Some(next) = response.next_token() {
231                    token = Some(next.to_string());
232                } else {
233                    break;
234                }
235            }
236
237            collected.sort_by_key(|item| item.version);
238            Ok(collected)
239        })
240    }
241
242    fn ensure_secret_created(
243        &self,
244        secret_id: &str,
245        payload: &str,
246        record: Option<&SecretRecord>,
247    ) -> SecretsResult<bool> {
248        let client = self.client.clone();
249        let secret_id = secret_id.to_owned();
250        let payload = payload.to_owned();
251        let description = record.and_then(|rec| rec.meta.description.clone());
252        let config = self.config.clone();
253        rt::sync_await(async move {
254            let mut request = client
255                .create_secret()
256                .name(secret_id.clone())
257                .secret_string(payload.clone());
258            if !config.resource_tags.is_empty() {
259                request = request.set_tags(Some(config.resource_tags.clone()));
260            }
261            if let Some(desc) = description.as_ref()
262                && !desc.is_empty()
263            {
264                request = request.description(desc.clone());
265            }
266
267            match request.send().await {
268                Ok(_) => Ok(true),
269                Err(err) => {
270                    if let SdkError::ServiceError(context) = &err
271                        && context.err().is_resource_exists_exception()
272                    {
273                        return Ok(false);
274                    }
275                    Err(storage_error("create_secret", err))
276                }
277            }
278        })
279    }
280
281    fn write_new_version(&self, secret_id: &str, payload: &str) -> SecretsResult<()> {
282        let client = self.client.clone();
283        let secret_id = secret_id.to_owned();
284        let payload = payload.to_owned();
285        let version_stage = self.config.version_stage.clone();
286        rt::sync_await(async move {
287            match client
288                .put_secret_value()
289                .secret_id(secret_id)
290                .secret_string(payload)
291                .set_version_stages(Some(vec![version_stage]))
292                .send()
293                .await
294            {
295                Ok(_) => Ok(()),
296                Err(err) => Err(storage_error("put_secret_value", err)),
297            }
298        })
299    }
300
301    fn list_scope(
302        &self,
303        scope: &Scope,
304        category_prefix: Option<&str>,
305        name_prefix: Option<&str>,
306    ) -> SecretsResult<Vec<SecretListItem>> {
307        let prefix = self.config.scope_prefix(scope);
308        let client = self.client.clone();
309        let secret_prefix = self.config.secret_prefix.clone();
310        let scope_env = scope.env().to_string();
311        let scope_tenant = scope.tenant().to_string();
312        let scope_team = scope.team().map(|s| s.to_string());
313        let category_prefix = category_prefix.map(|s| s.to_string());
314        let name_prefix = name_prefix.map(|s| s.to_string());
315        rt::sync_await(async move {
316            let mut items = Vec::new();
317            let mut token: Option<String> = None;
318
319            loop {
320                let mut request = client.list_secrets();
321                let filter = Filter::builder()
322                    .key(FilterNameStringType::Name)
323                    .values(prefix.clone())
324                    .build();
325                request = request.filters(filter);
326                if let Some(ref next) = token {
327                    request = request.next_token(next);
328                }
329
330                let response = match request.send().await {
331                    Ok(resp) => resp,
332                    Err(err) => return Err(storage_error("list_secrets", err)),
333                };
334
335                for entry in response.secret_list() {
336                    let name = match entry.name() {
337                        Some(value) => value.to_string(),
338                        None => continue,
339                    };
340                    if !name.starts_with(&prefix) {
341                        continue;
342                    }
343                    let uri = match parse_secret_name(&secret_prefix, &name) {
344                        Some(uri) => uri,
345                        None => continue,
346                    };
347                    if uri.scope().env() != scope_env {
348                        continue;
349                    }
350                    if uri.scope().tenant() != scope_tenant {
351                        continue;
352                    }
353                    if let Some(ref team) = scope_team
354                        && uri.scope().team() != Some(team.as_str())
355                    {
356                        continue;
357                    }
358                    if let Some(ref cat_prefix) = category_prefix
359                        && !uri.category().starts_with(cat_prefix)
360                    {
361                        continue;
362                    }
363                    if let Some(ref name_prefix) = name_prefix
364                        && !uri.name().starts_with(name_prefix)
365                    {
366                        continue;
367                    }
368
369                    if let Some(stored) =
370                        fetch_secret_version_inner(client.clone(), name.clone(), None).await?
371                    {
372                        if stored.deleted {
373                            continue;
374                        }
375                        if let Some(record) = stored.record {
376                            let latest = Some(stored.version.to_string());
377                            items.push(SecretListItem::from_meta(&record.meta, latest));
378                        }
379                    }
380                }
381
382                if let Some(next) = response.next_token() {
383                    token = Some(next.to_string());
384                } else {
385                    break;
386                }
387            }
388
389            Ok(items)
390        })
391    }
392}
393
394impl SecretsBackend for AwsSecretsBackend {
395    fn put(&self, record: SecretRecord) -> SecretsResult<SecretVersion> {
396        let secret_id = self.config.secret_name(&record.meta.uri);
397
398        let versions = self.load_all_versions(&record.meta.uri)?;
399        let next_version = versions
400            .iter()
401            .map(|stored| stored.version)
402            .max()
403            .unwrap_or(0)
404            .saturating_add(1);
405
406        let stored = StoredSecret::live(next_version, record.clone());
407        let payload = serde_json::to_string(&stored)
408            .map_err(|err| SecretsError::Storage(format!("serialize secret payload: {err}")))?;
409
410        if versions.is_empty() {
411            let created = self.ensure_secret_created(&secret_id, &payload, Some(&record))?;
412            if !created {
413                self.write_new_version(&secret_id, &payload)?;
414            }
415        } else {
416            self.write_new_version(&secret_id, &payload)?;
417        }
418
419        Ok(SecretVersion {
420            version: next_version,
421            deleted: false,
422        })
423    }
424
425    fn get(&self, uri: &SecretUri, version: Option<u64>) -> SecretsResult<Option<VersionedSecret>> {
426        let secret_id = self.config.secret_name(uri);
427
428        if let Some(version) = version {
429            let versions = self.load_all_versions(uri)?;
430            return Ok(versions
431                .into_iter()
432                .find(|stored| stored.version == version)
433                .map(|stored| stored.into_versioned()));
434        }
435
436        match self.fetch_latest_version(&secret_id)? {
437            Some(stored) if !stored.deleted => Ok(Some(stored.into_versioned())),
438            _ => Ok(None),
439        }
440    }
441
442    fn list(
443        &self,
444        scope: &Scope,
445        category_prefix: Option<&str>,
446        name_prefix: Option<&str>,
447    ) -> SecretsResult<Vec<SecretListItem>> {
448        self.list_scope(scope, category_prefix, name_prefix)
449    }
450
451    fn delete(&self, uri: &SecretUri) -> SecretsResult<SecretVersion> {
452        let secret_id = self.config.secret_name(uri);
453        let versions = self.load_all_versions(uri)?;
454        if versions.is_empty() {
455            return Err(SecretsError::NotFound {
456                entity: uri.to_string(),
457            });
458        }
459
460        let next_version = versions
461            .iter()
462            .map(|stored| stored.version)
463            .max()
464            .unwrap_or(0)
465            .saturating_add(1);
466
467        let stored = StoredSecret::tombstone(next_version);
468        let payload = serde_json::to_string(&stored)
469            .map_err(|err| SecretsError::Storage(format!("serialize tombstone payload: {err}")))?;
470
471        self.write_new_version(&secret_id, &payload)?;
472
473        Ok(SecretVersion {
474            version: next_version,
475            deleted: true,
476        })
477    }
478
479    fn versions(&self, uri: &SecretUri) -> SecretsResult<Vec<SecretVersion>> {
480        Ok(self
481            .load_all_versions(uri)?
482            .into_iter()
483            .map(|stored| SecretVersion {
484                version: stored.version,
485                deleted: stored.deleted,
486            })
487            .collect())
488    }
489
490    fn exists(&self, uri: &SecretUri) -> SecretsResult<bool> {
491        Ok(self.get(uri, None)?.is_some())
492    }
493}
494
495#[derive(Clone)]
496pub struct AwsKmsKeyProvider {
497    client: KmsClient,
498    key_id: String,
499}
500
501impl AwsKmsKeyProvider {
502    fn new(client: KmsClient, config: AwsProviderConfig) -> Self {
503        Self {
504            client,
505            key_id: config.kms_key_id,
506        }
507    }
508
509    fn context(scope: &Scope) -> HashMap<String, String> {
510        let mut ctx = HashMap::new();
511        ctx.insert("env".into(), scope.env().to_string());
512        ctx.insert("tenant".into(), scope.tenant().to_string());
513        if let Some(team) = scope.team() {
514            ctx.insert("team".into(), team.to_string());
515        }
516        ctx
517    }
518}
519
520impl KeyProvider for AwsKmsKeyProvider {
521    fn wrap_dek(&self, scope: &Scope, dek: &[u8]) -> SecretsResult<Vec<u8>> {
522        let context = Self::context(scope);
523        let client = self.client.clone();
524        let key_id = self.key_id.clone();
525        let dek = dek.to_vec();
526        rt::sync_await(async move {
527            match client
528                .encrypt()
529                .key_id(&key_id)
530                .set_encryption_context(Some(context))
531                .plaintext(KmsBlob::new(dek))
532                .send()
533                .await
534            {
535                Ok(output) => output
536                    .ciphertext_blob()
537                    .map(|blob| blob.as_ref().to_vec())
538                    .ok_or_else(|| {
539                        SecretsError::Backend("kms encrypt returned no ciphertext".into())
540                    }),
541                Err(err) => Err(SecretsError::Backend(format!("kms encrypt: {err}"))),
542            }
543        })
544    }
545
546    fn unwrap_dek(&self, scope: &Scope, wrapped: &[u8]) -> SecretsResult<Vec<u8>> {
547        let context = Self::context(scope);
548        let client = self.client.clone();
549        let key_id = self.key_id.clone();
550        let wrapped = wrapped.to_vec();
551        rt::sync_await(async move {
552            match client
553                .decrypt()
554                .key_id(&key_id)
555                .set_encryption_context(Some(context))
556                .ciphertext_blob(KmsBlob::new(wrapped))
557                .send()
558                .await
559            {
560                Ok(output) => output
561                    .plaintext()
562                    .map(|blob| blob.as_ref().to_vec())
563                    .ok_or_else(|| {
564                        SecretsError::Backend("kms decrypt returned no plaintext".into())
565                    }),
566                Err(err) => Err(SecretsError::Backend(format!("kms decrypt: {err}"))),
567            }
568        })
569    }
570}
571
572#[derive(Debug, Clone, Serialize, Deserialize)]
573struct StoredSecret {
574    version: u64,
575    deleted: bool,
576    #[serde(skip_serializing_if = "Option::is_none")]
577    record: Option<SecretRecord>,
578}
579
580impl StoredSecret {
581    fn live(version: u64, record: SecretRecord) -> Self {
582        Self {
583            version,
584            deleted: false,
585            record: Some(record),
586        }
587    }
588
589    fn tombstone(version: u64) -> Self {
590        Self {
591            version,
592            deleted: true,
593            record: None,
594        }
595    }
596
597    fn into_versioned(self) -> VersionedSecret {
598        VersionedSecret {
599            version: self.version,
600            deleted: self.deleted,
601            record: self.record,
602        }
603    }
604}
605
606fn parse_secret_name(prefix: &str, name: &str) -> Option<SecretUri> {
607    let mut segments = name.split('/');
608    let prefix_segment = segments.next()?;
609    if prefix_segment != prefix {
610        return None;
611    }
612    let env = segments.next()?;
613    let tenant = segments.next()?;
614    let team_segment = segments.next()?;
615    let category = segments.next()?;
616    let name_segment = segments.next()?;
617    if segments.next().is_some() {
618        return None;
619    }
620
621    let team = if team_segment == TEAM_PLACEHOLDER {
622        None
623    } else {
624        Some(team_segment.to_string())
625    };
626
627    let scope = Scope::new(env.to_string(), tenant.to_string(), team).ok()?;
628    SecretUri::new(scope, category, name_segment).ok()
629}
630
631fn deserialize_secret_payload(
632    secret_string: Option<&str>,
633    secret_binary: Option<&aws_smithy_types::Blob>,
634) -> SecretsResult<Option<StoredSecret>> {
635    if let Some(value) = secret_string {
636        if value.trim().is_empty() {
637            return Ok(None);
638        }
639        return serde_json::from_str::<StoredSecret>(value)
640            .map(Some)
641            .map_err(|err| SecretsError::Storage(format!("decode secret payload: {err}")));
642    }
643
644    if let Some(blob) = secret_binary {
645        let bytes = blob.as_ref();
646        if bytes.is_empty() {
647            return Ok(None);
648        }
649        return serde_json::from_slice::<StoredSecret>(bytes)
650            .map(Some)
651            .map_err(|err| SecretsError::Storage(format!("decode secret payload: {err}")));
652    }
653
654    Ok(None)
655}
656
657fn is_not_found<T>(err: &SdkError<T>) -> bool
658where
659    T: aws_smithy_types::error::metadata::ProvideErrorMetadata + Send + Sync + std::fmt::Debug,
660{
661    if let SdkError::ServiceError(context) = err {
662        return context.err().code() == Some("ResourceNotFoundException");
663    }
664    false
665}
666
667fn storage_error<T>(operation: &str, err: SdkError<T>) -> SecretsError
668where
669    T: std::fmt::Display,
670{
671    SecretsError::Storage(format!("{operation} failed: {err}"))
672}
673
674fn list_versions_unsupported(err: &SdkError<ListSecretVersionIdsError>) -> bool {
675    match err {
676        SdkError::DispatchFailure(_) => true,
677        SdkError::ServiceError(ctx) => matches!(
678            ctx.err().code(),
679            Some("NotImplementedException") | Some("UnknownOperationException")
680        ),
681        _ => false,
682    }
683}
684
#[cfg(test)]
mod tests {
    use super::*;
    use rustls_native_certs::load_native_certs;
    use serial_test::serial;
    use std::env;

    fn set_env(key: &str, value: &str) {
        // SAFETY: only called from `#[serial]` tests, which do not run concurrently
        // with each other. NOTE(review): non-serial tests touching the environment
        // elsewhere in the crate would still race with this.
        unsafe { env::set_var(key, value) };
    }

    fn clear_env(key: &str) {
        // SAFETY: same reasoning as `set_env`.
        unsafe { env::remove_var(key) };
    }

    /// Point both SDK clients at a local address with dummy credentials.
    fn setup_env() {
        const VARS: &[(&str, &str)] = &[
            (
                "GREENTIC_AWS_KMS_KEY_ID",
                "arn:aws:kms:us-east-1:000000000000:key/test",
            ),
            ("GREENTIC_AWS_SECRET_PREFIX", "unit"),
            ("GREENTIC_AWS_VERSION_STAGE", "AWSCURRENT"),
            ("GREENTIC_AWS_REGION", "us-east-1"),
            ("AWS_ALLOW_HTTP", "1"),
            ("AWS_ACCESS_KEY_ID", "test"),
            ("AWS_SECRET_ACCESS_KEY", "test"),
            ("AWS_SESSION_TOKEN", "test"),
            (SM_ENDPOINT_ENV, "http://127.0.0.1:9"),
            (KMS_ENDPOINT_ENV, "http://127.0.0.1:9"),
        ];
        for (key, value) in VARS {
            set_env(key, value);
        }
        clear_env("AWS_PROFILE");
    }

    /// Whether the platform root certificate store could be loaded.
    fn native_roots_available() -> bool {
        let loaded = load_native_certs();
        let available = !loaded.certs.is_empty();
        if !available {
            eprintln!("native root certs unavailable: {:?}", loaded.errors);
        }
        available
    }

    #[tokio::test(flavor = "multi_thread")]
    #[serial]
    async fn aws_provider_ok_under_tokio() {
        if !native_roots_available() {
            eprintln!("skipping aws_provider_ok_under_tokio: no native root certs");
            return;
        }

        setup_env();
        let BackendComponents { backend, .. } = build_backend()
            .await
            .expect("aws backend builds with env config");
        let scope = Scope::new("dev", "tenant", None).expect("scope");

        // The endpoints point at a local port presumably nothing serves, so the
        // call is expected to return an error from inside the runtime instead of
        // panicking.
        let outcome = backend.list(&scope, None, None);
        assert!(
            outcome.is_err(),
            "list should attempt network and bubble up errors without panicking"
        );
    }
}