use std::collections::HashMap;
use aes_gcm::aead::{Aead, KeyInit, OsRng, Payload};
use aes_gcm::{AeadCore, Aes256Gcm, Key, Nonce};
use base64::engine::general_purpose::STANDARD as B64;
use base64::Engine as _;
use serde_json::Value;
use thiserror::Error;
use crate::client::{ConfigClient, ConfigClientError};
/// Section a fetched config key is routed to when a bundle is built.
///
/// Produced by a [`Classifier`]; when no classifier is supplied, every key is
/// treated as [`Classification::Public`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Classification {
    /// Stored under the `"public"` object of the encrypted JSON payload.
    Public,
    /// Stored under the `"secret"` object of the encrypted JSON payload.
    Secret,
    /// Omitted from the payload entirely; only counted in `skipped_count`.
    Skip,
}
/// Callback deciding which bundle section each fetched `(key, value)` pair goes to.
///
/// Invoked once per fetched key while a bundle is built. The `Send + Sync`
/// bounds allow a boxed classifier to be moved to, and shared between, threads.
pub type Classifier = Box<dyn Send + Sync + Fn(&str, &Value) -> Classification>;
/// Inputs for [`build_bundle`].
pub struct BuildBundleOptions {
    /// Base URL of the config service, passed to `ConfigClient`.
    pub base_url: String,
    /// Auth endpoint override. When `Some`, `build_bundle` writes it to the
    /// process-wide `SMOOAI_CONFIG_AUTH_URL` environment variable before creating
    /// the client, and does not restore the previous value afterwards.
    pub auth_url: Option<String>,
    /// Client id passed to `ConfigClient`; falls back to `api_key` when `None`.
    pub client_id: Option<String>,
    /// API key passed to `ConfigClient`. Treat as sensitive.
    pub api_key: String,
    /// Organization id passed to `ConfigClient`.
    pub org_id: String,
    /// Environment to fetch values for. When `Some`, the client is built with
    /// `ConfigClient::with_environment`; when `None`, with `ConfigClient::new`.
    pub environment: Option<String>,
    /// Routes each fetched key to a section; `None` makes every key public.
    pub classify: Option<Classifier>,
}
/// Output of [`build_bundle`]: the encrypted payload plus the key needed to open it.
///
/// `Debug` is implemented by hand so that formatting this value (in logs,
/// panics, or `dbg!`) never prints the encryption key or the ciphertext bytes.
pub struct BuildBundleResult {
    /// Standard-alphabet base64 of the 32-byte AES-256-GCM key. This is a secret.
    pub key_b64: String,
    /// 12-byte nonce followed by the ciphertext with its GCM tag appended.
    pub blob: Vec<u8>,
    /// Length of `blob` in bytes.
    pub size: u64,
    /// Number of keys written to the `public` and `secret` sections combined.
    pub key_count: usize,
    /// Number of keys the classifier marked as `Skip`.
    pub skipped_count: usize,
}

impl std::fmt::Debug for BuildBundleResult {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Redact the key and summarize the blob; both are sensitive and the
        // blob can be arbitrarily large.
        f.debug_struct("BuildBundleResult")
            .field("key_b64", &"<redacted>")
            .field("blob", &format_args!("<{} bytes>", self.blob.len()))
            .field("size", &self.size)
            .field("key_count", &self.key_count)
            .field("skipped_count", &self.skipped_count)
            .finish()
    }
}
/// Everything that can go wrong while building an encrypted config bundle.
#[derive(Debug, Error)]
pub enum BuildError {
    /// A bare `reqwest` error was propagated with `?`.
    ///
    /// NOTE(review): `build_bundle` only `?`-propagates the result of
    /// `ConfigClient::get_all_values`; whether that can ever surface a raw
    /// `reqwest::Error` depends on that method's signature — confirm.
    #[error("config fetch transport error: {0}")]
    Request(#[from] reqwest::Error),

    /// The config client failed to return the organization's values.
    #[error("failed to fetch config values: {0}")]
    Fetch(#[from] ConfigClientError),

    /// The partitioned values could not be encoded as JSON.
    #[error("failed to serialize config values to JSON: {0}")]
    Serialize(#[from] serde_json::Error),

    /// AES-256-GCM encryption reported an error; carries its display text.
    #[error("aes-gcm encryption failed: {0}")]
    Encrypt(String),
}
pub async fn build_bundle(options: BuildBundleOptions) -> Result<BuildBundleResult, BuildError> {
let BuildBundleOptions {
base_url,
auth_url,
client_id,
api_key,
org_id,
environment,
classify,
} = options;
let resolved_client_id = client_id.unwrap_or_else(|| api_key.clone());
if let Some(url) = &auth_url {
std::env::set_var("SMOOAI_CONFIG_AUTH_URL", url);
}
let mut client = match &environment {
Some(env) => ConfigClient::with_environment(&base_url, &resolved_client_id, &api_key, &org_id, env),
None => ConfigClient::new(&base_url, &resolved_client_id, &api_key, &org_id),
};
let all = client.get_all_values(environment.as_deref()).await?;
let mut public_map: HashMap<String, Value> = HashMap::new();
let mut secret_map: HashMap<String, Value> = HashMap::new();
let mut skipped_count: usize = 0;
for (key, value) in all {
let section = match classify {
Some(ref f) => f(&key, &value),
None => Classification::Public,
};
match section {
Classification::Public => {
public_map.insert(key, value);
}
Classification::Secret => {
secret_map.insert(key, value);
}
Classification::Skip => {
skipped_count += 1;
}
}
}
let key_count = public_map.len() + secret_map.len();
let partitioned = serde_json::json!({
"public": public_map,
"secret": secret_map,
});
let plaintext = serde_json::to_vec(&partitioned)?;
let key_bytes: [u8; 32] = {
let k = Aes256Gcm::generate_key(&mut OsRng);
k.into()
};
let nonce_bytes: [u8; 12] = {
let n = Aes256Gcm::generate_nonce(&mut OsRng);
n.into()
};
let cipher = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(&key_bytes));
let nonce = Nonce::from_slice(&nonce_bytes);
let ciphertext_and_tag = cipher
.encrypt(
nonce,
Payload {
msg: &plaintext,
aad: &[],
},
)
.map_err(|e| BuildError::Encrypt(e.to_string()))?;
let mut blob = Vec::with_capacity(nonce_bytes.len() + ciphertext_and_tag.len());
blob.extend_from_slice(&nonce_bytes);
blob.extend_from_slice(&ciphertext_and_tag);
let size = blob.len() as u64;
let key_b64 = B64.encode(key_bytes);
Ok(BuildBundleResult {
key_b64,
blob,
size,
key_count,
skipped_count,
})
}
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{header, method, path_regex, query_param};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // NOTE(review): `build_bundle` sets the process-global SMOOAI_CONFIG_AUTH_URL,
    // and these tests run in parallel against different mock servers. They stay
    // green only because both token mocks return the same stub JWT.

    /// Decrypts a bundle with its own key and parses the JSON payload, checking
    /// the `nonce (12) || ciphertext-with-tag` framing along the way.
    fn open_bundle(result: &BuildBundleResult) -> Value {
        let key = B64.decode(&result.key_b64).expect("key is valid base64");
        assert_eq!(key.len(), 32, "AES-256 key must be 32 bytes");
        assert!(result.blob.len() > 12 + 16, "blob must hold nonce + tag");
        let (nonce, ciphertext) = result.blob.split_at(12);
        let cipher = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(&key));
        let plaintext = cipher
            .decrypt(Nonce::from_slice(nonce), ciphertext)
            .expect("bundle decrypts with its own key");
        serde_json::from_slice(&plaintext).expect("plaintext is JSON")
    }

    #[tokio::test]
    async fn build_bundle_encrypts_and_reports_counts() {
        let mock_server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path_regex(r"^/token$"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "access_token": "stub-jwt",
                "expires_in": 3600
            })))
            .mount(&mock_server)
            .await;
        Mock::given(method("GET"))
            .and(path_regex(r"/organizations/.+/config/values"))
            .and(query_param("environment", "production"))
            .and(header("Authorization", "Bearer stub-jwt"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "values": {
                    "apiUrl": "https://api.example.com",
                    "tavilyApiKey": "tvly-abc",
                    "newFlow": true,
                }
            })))
            .mount(&mock_server)
            .await;
        let classify: Classifier = Box::new(|key, _v| match key {
            "tavilyApiKey" => Classification::Secret,
            "newFlow" => Classification::Skip,
            _ => Classification::Public,
        });
        let result = build_bundle(BuildBundleOptions {
            base_url: mock_server.uri(),
            auth_url: Some(mock_server.uri()),
            client_id: Some("test-api-key".to_string()),
            api_key: "test-api-key".to_string(),
            org_id: "test-org".to_string(),
            environment: Some("production".to_string()),
            classify: Some(classify),
        })
        .await
        .unwrap();

        assert_eq!(result.key_count, 2);
        assert_eq!(result.skipped_count, 1);
        assert!(result.blob.len() > 12 + 16);
        assert_eq!(result.size, result.blob.len() as u64);

        let bundle = open_bundle(&result);
        assert_eq!(bundle["public"]["apiUrl"], "https://api.example.com");
        assert_eq!(bundle["secret"]["tavilyApiKey"], "tvly-abc");
        assert!(bundle["public"].get("tavilyApiKey").is_none());
        assert!(bundle["public"].get("newFlow").is_none());
        assert!(bundle["secret"].get("newFlow").is_none());
    }

    #[tokio::test]
    async fn build_bundle_default_classifier_makes_everything_public() {
        let mock_server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path_regex(r"^/token$"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "access_token": "stub-jwt",
                "expires_in": 3600
            })))
            .mount(&mock_server)
            .await;
        Mock::given(method("GET"))
            .and(path_regex(r"/organizations/.+/config/values"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "values": {"FOO": "bar", "BAZ": 42}
            })))
            .mount(&mock_server)
            .await;
        let result = build_bundle(BuildBundleOptions {
            base_url: mock_server.uri(),
            auth_url: Some(mock_server.uri()),
            client_id: Some("k".to_string()),
            api_key: "k".to_string(),
            org_id: "o".to_string(),
            environment: Some("test".to_string()),
            classify: None,
        })
        .await
        .unwrap();

        assert_eq!(result.key_count, 2);
        assert_eq!(result.skipped_count, 0);

        let bundle = open_bundle(&result);
        assert_eq!(bundle["public"]["FOO"], "bar");
        assert_eq!(bundle["public"]["BAZ"], 42);
        assert_eq!(bundle["secret"], serde_json::json!({}));
    }
}