1use std::{
2 collections::BTreeMap,
3 io::Write,
4 process::{Command, Stdio},
5 sync::Arc,
6};
7
8use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64_STANDARD};
9use serde::{Deserialize, Serialize};
10
11use crate::{
12 error::SdkError,
13 local::{LocalSymmetricKey, ManagedSymmetricKeyReference},
14 models::KeyTransportMode,
15};
16
17#[derive(Debug, Clone, PartialEq, Eq)]
18pub struct ManagedSymmetricKeyProviderCapabilities {
19 supported_transport_modes: Vec<KeyTransportMode>,
20}
21
22impl ManagedSymmetricKeyProviderCapabilities {
23 pub fn new(supported_transport_modes: Vec<KeyTransportMode>) -> Self {
24 let mut supported_transport_modes = supported_transport_modes;
25 supported_transport_modes.dedup();
26 Self {
27 supported_transport_modes,
28 }
29 }
30
31 pub fn supports(&self, mode: KeyTransportMode) -> bool {
32 self.supported_transport_modes.contains(&mode)
33 }
34
35 pub fn supported_transport_modes(&self) -> &[KeyTransportMode] {
36 &self.supported_transport_modes
37 }
38}
39
40pub trait ManagedSymmetricKeyProvider: Send + Sync {
41 fn provider_name(&self) -> &str;
42
43 fn capabilities(&self) -> &ManagedSymmetricKeyProviderCapabilities;
44
45 fn resolve_key(
46 &self,
47 key_reference: &ManagedSymmetricKeyReference,
48 ) -> Result<LocalSymmetricKey, SdkError>;
49}
50
51#[derive(Clone, Default)]
52pub struct ManagedSymmetricKeyProviderRegistry {
53 providers: BTreeMap<String, Arc<dyn ManagedSymmetricKeyProvider>>,
54}
55
56impl ManagedSymmetricKeyProviderRegistry {
57 pub fn new() -> Self {
58 Self::default()
59 }
60
61 pub fn register<P>(&mut self, provider: P)
62 where
63 P: ManagedSymmetricKeyProvider + 'static,
64 {
65 self.register_arc(Arc::new(provider));
66 }
67
68 pub fn register_arc(&mut self, provider: Arc<dyn ManagedSymmetricKeyProvider>) {
69 self.providers
70 .insert(provider.provider_name().to_string(), provider);
71 }
72
73 pub fn resolve(
74 &self,
75 provider_name: Option<&str>,
76 ) -> Result<Arc<dyn ManagedSymmetricKeyProvider>, SdkError> {
77 match provider_name {
78 Some(provider_name) => self.providers.get(provider_name).cloned().ok_or_else(|| {
79 SdkError::InvalidInput(format!(
80 "managed symmetric key provider {provider_name:?} is not registered"
81 ))
82 }),
83 None => match self.providers.len() {
84 0 => Err(SdkError::InvalidInput(
85 "managed key execution requires a registered symmetric key provider, but none were configured"
86 .to_string(),
87 )),
88 1 => self
89 .providers
90 .values()
91 .next()
92 .cloned()
93 .ok_or_else(|| {
94 SdkError::InvalidInput(
95 "managed symmetric key provider registry was unexpectedly empty"
96 .to_string(),
97 )
98 }),
99 _ => Err(SdkError::InvalidInput(
100 "multiple managed symmetric key providers are registered; specify provider_name in the key source"
101 .to_string(),
102 )),
103 },
104 }
105 }
106}
107
108#[derive(Clone)]
109pub struct InMemoryManagedSymmetricKeyProvider {
110 name: String,
111 capabilities: ManagedSymmetricKeyProviderCapabilities,
112 keys: BTreeMap<String, LocalSymmetricKey>,
113}
114
115impl InMemoryManagedSymmetricKeyProvider {
116 pub fn new(name: impl Into<String>, keys: BTreeMap<String, LocalSymmetricKey>) -> Self {
117 Self {
118 name: name.into(),
119 capabilities: ManagedSymmetricKeyProviderCapabilities::new(vec![
120 KeyTransportMode::WrappedKeyReference,
121 KeyTransportMode::AuthorizedKeyRelease,
122 ]),
123 keys,
124 }
125 }
126
127 pub fn with_supported_transport_modes<I>(mut self, modes: I) -> Self
128 where
129 I: IntoIterator<Item = KeyTransportMode>,
130 {
131 self.capabilities =
132 ManagedSymmetricKeyProviderCapabilities::new(modes.into_iter().collect());
133 self
134 }
135}
136
137impl ManagedSymmetricKeyProvider for InMemoryManagedSymmetricKeyProvider {
138 fn provider_name(&self) -> &str {
139 &self.name
140 }
141
142 fn capabilities(&self) -> &ManagedSymmetricKeyProviderCapabilities {
143 &self.capabilities
144 }
145
146 fn resolve_key(
147 &self,
148 key_reference: &ManagedSymmetricKeyReference,
149 ) -> Result<LocalSymmetricKey, SdkError> {
150 self.keys
151 .get(key_reference.key_reference())
152 .cloned()
153 .ok_or_else(|| {
154 SdkError::InvalidInput(format!(
155 "managed symmetric key reference {:?} is not available from provider {:?}",
156 key_reference.key_reference(),
157 self.name
158 ))
159 })
160 }
161}
162
163#[derive(Clone)]
164pub struct CommandManagedSymmetricKeyProvider {
165 name: String,
166 capabilities: ManagedSymmetricKeyProviderCapabilities,
167 command: String,
168 args: Vec<String>,
169 env: BTreeMap<String, String>,
170}
171
/// JSON request written to the provider command's stdin.
///
/// serde serializes struct fields in declaration order. Reordering these fields therefore
/// changes the bytes the command receives.
#[derive(Debug, Serialize)]
#[serde(rename_all = "snake_case")]
struct CommandManagedSymmetricKeyProviderRequest<'a> {
    /// Name of the provider issuing the request.
    provider_name: &'a str,
    /// The key-reference string to resolve.
    key_reference: &'a str,
    /// Provider name carried by the key reference, if any. Serialized as `null` when absent.
    requested_provider_name: Option<&'a str>,
}
179
180#[derive(Debug, Deserialize)]
181#[serde(rename_all = "snake_case")]
182struct CommandManagedSymmetricKeyProviderResponse {
183 key_b64: String,
184}
185
186impl CommandManagedSymmetricKeyProvider {
187 pub fn new(name: impl Into<String>, command: impl Into<String>) -> Self {
188 Self {
189 name: name.into(),
190 capabilities: ManagedSymmetricKeyProviderCapabilities::new(vec![
191 KeyTransportMode::WrappedKeyReference,
192 KeyTransportMode::AuthorizedKeyRelease,
193 ]),
194 command: command.into(),
195 args: Vec::new(),
196 env: BTreeMap::new(),
197 }
198 }
199
200 pub fn with_args<I, S>(mut self, args: I) -> Self
201 where
202 I: IntoIterator<Item = S>,
203 S: Into<String>,
204 {
205 self.args = args.into_iter().map(Into::into).collect();
206 self
207 }
208
209 pub fn with_envs<I, K, V>(mut self, envs: I) -> Self
210 where
211 I: IntoIterator<Item = (K, V)>,
212 K: Into<String>,
213 V: Into<String>,
214 {
215 self.env = envs
216 .into_iter()
217 .map(|(key, value)| (key.into(), value.into()))
218 .collect();
219 self
220 }
221
222 pub fn with_supported_transport_modes<I>(mut self, modes: I) -> Self
223 where
224 I: IntoIterator<Item = KeyTransportMode>,
225 {
226 self.capabilities =
227 ManagedSymmetricKeyProviderCapabilities::new(modes.into_iter().collect());
228 self
229 }
230}
231
232impl ManagedSymmetricKeyProvider for CommandManagedSymmetricKeyProvider {
233 fn provider_name(&self) -> &str {
234 &self.name
235 }
236
237 fn capabilities(&self) -> &ManagedSymmetricKeyProviderCapabilities {
238 &self.capabilities
239 }
240
241 fn resolve_key(
242 &self,
243 key_reference: &ManagedSymmetricKeyReference,
244 ) -> Result<LocalSymmetricKey, SdkError> {
245 let mut child = Command::new(&self.command)
246 .args(&self.args)
247 .envs(&self.env)
248 .stdin(Stdio::piped())
249 .stdout(Stdio::piped())
250 .stderr(Stdio::piped())
251 .spawn()
252 .map_err(|error| {
253 SdkError::Connection(format!(
254 "failed to launch managed symmetric key provider command {:?}: {error}",
255 self.command
256 ))
257 })?;
258
259 let request = serde_json::to_vec(&CommandManagedSymmetricKeyProviderRequest {
260 provider_name: &self.name,
261 key_reference: key_reference.key_reference(),
262 requested_provider_name: key_reference.provider_name(),
263 })
264 .map_err(|error| {
265 SdkError::Serialization(format!(
266 "failed to serialize managed symmetric key provider command request: {error}"
267 ))
268 })?;
269
270 if let Some(mut stdin) = child.stdin.take() {
271 stdin.write_all(&request).map_err(|error| {
272 SdkError::Connection(format!(
273 "failed to send key reference to managed symmetric key provider command {:?}: {error}",
274 self.command
275 ))
276 })?;
277 }
278
279 let output = child.wait_with_output().map_err(|error| {
280 SdkError::Connection(format!(
281 "failed waiting for managed symmetric key provider command {:?}: {error}",
282 self.command
283 ))
284 })?;
285
286 if !output.status.success() {
287 let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string();
288 return Err(SdkError::Connection(format!(
289 "managed symmetric key provider command {:?} failed{}",
290 self.command,
291 if stderr.is_empty() {
292 String::new()
293 } else {
294 format!(": {stderr}")
295 }
296 )));
297 }
298
299 let response: CommandManagedSymmetricKeyProviderResponse =
300 serde_json::from_slice(&output.stdout).map_err(|error| {
301 SdkError::Serialization(format!(
302 "failed to decode managed symmetric key provider command output: {error}"
303 ))
304 })?;
305
306 let decoded_key = BASE64_STANDARD.decode(&response.key_b64).map_err(|error| {
307 SdkError::InvalidInput(format!(
308 "managed symmetric key provider command returned invalid base64 key material: {error}"
309 ))
310 })?;
311 let key: [u8; 32] = decoded_key.try_into().map_err(|_| {
312 SdkError::InvalidInput(
313 "managed symmetric key provider command must return exactly 32 bytes of key material"
314 .to_string(),
315 )
316 })?;
317
318 Ok(LocalSymmetricKey::from(key))
319 }
320}