Skip to main content

sdk_rust/
providers.rs

1use std::{
2    collections::BTreeMap,
3    io::Write,
4    process::{Command, Stdio},
5    sync::Arc,
6};
7
8use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64_STANDARD};
9use serde::{Deserialize, Serialize};
10
11use crate::{
12    error::SdkError,
13    local::{LocalSymmetricKey, ManagedSymmetricKeyReference},
14    models::KeyTransportMode,
15};
16
17#[derive(Debug, Clone, PartialEq, Eq)]
18pub struct ManagedSymmetricKeyProviderCapabilities {
19    supported_transport_modes: Vec<KeyTransportMode>,
20}
21
22impl ManagedSymmetricKeyProviderCapabilities {
23    pub fn new(supported_transport_modes: Vec<KeyTransportMode>) -> Self {
24        let mut supported_transport_modes = supported_transport_modes;
25        supported_transport_modes.dedup();
26        Self {
27            supported_transport_modes,
28        }
29    }
30
31    pub fn supports(&self, mode: KeyTransportMode) -> bool {
32        self.supported_transport_modes.contains(&mode)
33    }
34
35    pub fn supported_transport_modes(&self) -> &[KeyTransportMode] {
36        &self.supported_transport_modes
37    }
38}
39
/// A source of managed symmetric key material.
///
/// Providers are stored by the registry as `Arc<dyn ManagedSymmetricKeyProvider>`
/// and must be `Send + Sync`.
pub trait ManagedSymmetricKeyProvider: Send + Sync {
    /// Name identifying this provider; the registry indexes providers by this value.
    fn provider_name(&self) -> &str;

    /// Key transport modes this provider advertises support for.
    fn capabilities(&self) -> &ManagedSymmetricKeyProviderCapabilities;

    /// Resolves `key_reference` to local symmetric key material.
    ///
    /// # Errors
    ///
    /// Returns an [`SdkError`] when the key cannot be resolved; which variant is
    /// used depends on the implementation.
    fn resolve_key(
        &self,
        key_reference: &ManagedSymmetricKeyReference,
    ) -> Result<LocalSymmetricKey, SdkError>;
}
50
51#[derive(Clone, Default)]
52pub struct ManagedSymmetricKeyProviderRegistry {
53    providers: BTreeMap<String, Arc<dyn ManagedSymmetricKeyProvider>>,
54}
55
56impl ManagedSymmetricKeyProviderRegistry {
57    pub fn new() -> Self {
58        Self::default()
59    }
60
61    pub fn register<P>(&mut self, provider: P)
62    where
63        P: ManagedSymmetricKeyProvider + 'static,
64    {
65        self.register_arc(Arc::new(provider));
66    }
67
68    pub fn register_arc(&mut self, provider: Arc<dyn ManagedSymmetricKeyProvider>) {
69        self.providers
70            .insert(provider.provider_name().to_string(), provider);
71    }
72
73    pub fn resolve(
74        &self,
75        provider_name: Option<&str>,
76    ) -> Result<Arc<dyn ManagedSymmetricKeyProvider>, SdkError> {
77        match provider_name {
78            Some(provider_name) => self.providers.get(provider_name).cloned().ok_or_else(|| {
79                SdkError::InvalidInput(format!(
80                    "managed symmetric key provider {provider_name:?} is not registered"
81                ))
82            }),
83            None => match self.providers.len() {
84                0 => Err(SdkError::InvalidInput(
85                    "managed key execution requires a registered symmetric key provider, but none were configured"
86                        .to_string(),
87                )),
88                1 => self
89                    .providers
90                    .values()
91                    .next()
92                    .cloned()
93                    .ok_or_else(|| {
94                        SdkError::InvalidInput(
95                            "managed symmetric key provider registry was unexpectedly empty"
96                                .to_string(),
97                        )
98                    }),
99                _ => Err(SdkError::InvalidInput(
100                    "multiple managed symmetric key providers are registered; specify provider_name in the key source"
101                        .to_string(),
102                )),
103            },
104        }
105    }
106}
107
108#[derive(Clone)]
109pub struct InMemoryManagedSymmetricKeyProvider {
110    name: String,
111    capabilities: ManagedSymmetricKeyProviderCapabilities,
112    keys: BTreeMap<String, LocalSymmetricKey>,
113}
114
115impl InMemoryManagedSymmetricKeyProvider {
116    pub fn new(name: impl Into<String>, keys: BTreeMap<String, LocalSymmetricKey>) -> Self {
117        Self {
118            name: name.into(),
119            capabilities: ManagedSymmetricKeyProviderCapabilities::new(vec![
120                KeyTransportMode::WrappedKeyReference,
121                KeyTransportMode::AuthorizedKeyRelease,
122            ]),
123            keys,
124        }
125    }
126
127    pub fn with_supported_transport_modes<I>(mut self, modes: I) -> Self
128    where
129        I: IntoIterator<Item = KeyTransportMode>,
130    {
131        self.capabilities =
132            ManagedSymmetricKeyProviderCapabilities::new(modes.into_iter().collect());
133        self
134    }
135}
136
137impl ManagedSymmetricKeyProvider for InMemoryManagedSymmetricKeyProvider {
138    fn provider_name(&self) -> &str {
139        &self.name
140    }
141
142    fn capabilities(&self) -> &ManagedSymmetricKeyProviderCapabilities {
143        &self.capabilities
144    }
145
146    fn resolve_key(
147        &self,
148        key_reference: &ManagedSymmetricKeyReference,
149    ) -> Result<LocalSymmetricKey, SdkError> {
150        self.keys
151            .get(key_reference.key_reference())
152            .cloned()
153            .ok_or_else(|| {
154                SdkError::InvalidInput(format!(
155                    "managed symmetric key reference {:?} is not available from provider {:?}",
156                    key_reference.key_reference(),
157                    self.name
158                ))
159            })
160    }
161}
162
163#[derive(Clone)]
164pub struct CommandManagedSymmetricKeyProvider {
165    name: String,
166    capabilities: ManagedSymmetricKeyProviderCapabilities,
167    command: String,
168    args: Vec<String>,
169    env: BTreeMap<String, String>,
170}
171
/// JSON request written to the provider command's stdin.
///
/// Serde serializes the fields as a JSON object in declaration order; the field
/// order and names are part of the command protocol.
#[derive(Debug, Serialize)]
#[serde(rename_all = "snake_case")]
struct CommandManagedSymmetricKeyProviderRequest<'a> {
    /// Name of the provider sending the request.
    provider_name: &'a str,
    /// The key reference to resolve.
    key_reference: &'a str,
    /// Provider name carried by the key reference, if any (serialized as `null`
    /// when absent).
    requested_provider_name: Option<&'a str>,
}
179
/// JSON response the provider command must print to stdout.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
struct CommandManagedSymmetricKeyProviderResponse {
    /// Key material encoded with the standard base64 alphabet
    /// (`general_purpose::STANDARD`); it must decode to exactly 32 bytes.
    // NOTE(review): this holds secret key material — avoid logging it through the
    // derived `Debug` impl.
    key_b64: String,
}
185
186impl CommandManagedSymmetricKeyProvider {
187    pub fn new(name: impl Into<String>, command: impl Into<String>) -> Self {
188        Self {
189            name: name.into(),
190            capabilities: ManagedSymmetricKeyProviderCapabilities::new(vec![
191                KeyTransportMode::WrappedKeyReference,
192                KeyTransportMode::AuthorizedKeyRelease,
193            ]),
194            command: command.into(),
195            args: Vec::new(),
196            env: BTreeMap::new(),
197        }
198    }
199
200    pub fn with_args<I, S>(mut self, args: I) -> Self
201    where
202        I: IntoIterator<Item = S>,
203        S: Into<String>,
204    {
205        self.args = args.into_iter().map(Into::into).collect();
206        self
207    }
208
209    pub fn with_envs<I, K, V>(mut self, envs: I) -> Self
210    where
211        I: IntoIterator<Item = (K, V)>,
212        K: Into<String>,
213        V: Into<String>,
214    {
215        self.env = envs
216            .into_iter()
217            .map(|(key, value)| (key.into(), value.into()))
218            .collect();
219        self
220    }
221
222    pub fn with_supported_transport_modes<I>(mut self, modes: I) -> Self
223    where
224        I: IntoIterator<Item = KeyTransportMode>,
225    {
226        self.capabilities =
227            ManagedSymmetricKeyProviderCapabilities::new(modes.into_iter().collect());
228        self
229    }
230}
231
232impl ManagedSymmetricKeyProvider for CommandManagedSymmetricKeyProvider {
233    fn provider_name(&self) -> &str {
234        &self.name
235    }
236
237    fn capabilities(&self) -> &ManagedSymmetricKeyProviderCapabilities {
238        &self.capabilities
239    }
240
241    fn resolve_key(
242        &self,
243        key_reference: &ManagedSymmetricKeyReference,
244    ) -> Result<LocalSymmetricKey, SdkError> {
245        let mut child = Command::new(&self.command)
246            .args(&self.args)
247            .envs(&self.env)
248            .stdin(Stdio::piped())
249            .stdout(Stdio::piped())
250            .stderr(Stdio::piped())
251            .spawn()
252            .map_err(|error| {
253                SdkError::Connection(format!(
254                    "failed to launch managed symmetric key provider command {:?}: {error}",
255                    self.command
256                ))
257            })?;
258
259        let request = serde_json::to_vec(&CommandManagedSymmetricKeyProviderRequest {
260            provider_name: &self.name,
261            key_reference: key_reference.key_reference(),
262            requested_provider_name: key_reference.provider_name(),
263        })
264        .map_err(|error| {
265            SdkError::Serialization(format!(
266                "failed to serialize managed symmetric key provider command request: {error}"
267            ))
268        })?;
269
270        if let Some(mut stdin) = child.stdin.take() {
271            stdin.write_all(&request).map_err(|error| {
272                SdkError::Connection(format!(
273                    "failed to send key reference to managed symmetric key provider command {:?}: {error}",
274                    self.command
275                ))
276            })?;
277        }
278
279        let output = child.wait_with_output().map_err(|error| {
280            SdkError::Connection(format!(
281                "failed waiting for managed symmetric key provider command {:?}: {error}",
282                self.command
283            ))
284        })?;
285
286        if !output.status.success() {
287            let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string();
288            return Err(SdkError::Connection(format!(
289                "managed symmetric key provider command {:?} failed{}",
290                self.command,
291                if stderr.is_empty() {
292                    String::new()
293                } else {
294                    format!(": {stderr}")
295                }
296            )));
297        }
298
299        let response: CommandManagedSymmetricKeyProviderResponse =
300            serde_json::from_slice(&output.stdout).map_err(|error| {
301                SdkError::Serialization(format!(
302                    "failed to decode managed symmetric key provider command output: {error}"
303                ))
304            })?;
305
306        let decoded_key = BASE64_STANDARD.decode(&response.key_b64).map_err(|error| {
307            SdkError::InvalidInput(format!(
308                "managed symmetric key provider command returned invalid base64 key material: {error}"
309            ))
310        })?;
311        let key: [u8; 32] = decoded_key.try_into().map_err(|_| {
312            SdkError::InvalidInput(
313                "managed symmetric key provider command must return exactly 32 bytes of key material"
314                    .to_string(),
315            )
316        })?;
317
318        Ok(LocalSymmetricKey::from(key))
319    }
320}