use super::{Encoding, Result, get_content_from_file};
use notify::{EventHandler, RecursiveMode, Watcher};
use serde::{Deserialize, Serialize, de::Visitor};
use std::{fmt, fs, path::PathBuf, sync::Arc};
use zeroize::Zeroize;
/// A secret loaded from a file together with the filesystem watcher that
/// keeps it up to date.
///
/// Clones share the same watch channel and the same underlying watcher
/// (the receiver is cloned, the watcher sits behind an `Arc`).
#[derive(Clone)]
pub struct SecretWatcher {
    /// Path of the secret file as it was given at deserialization time.
    path: PathBuf,
    /// Optional key forwarded to `get_content_from_file` when (re)reading.
    key: Option<String>,
    /// Receiving side of the channel that carries the latest read result.
    // NOTE(review): this field is feature-gated while `read` uses it
    // unconditionally — confirm the `notify` feature is always enabled.
    #[cfg(feature = "notify")]
    content: tokio::sync::watch::Receiver<Result<String>>,
    /// Never read; stored only so the watcher (and the event handler it owns)
    /// lives as long as any clone of this value.
    ///
    /// `RecommendedWatcher` is the platform alias returned by
    /// `notify::recommended_watcher`; naming `INotifyWatcher` directly would
    /// restrict this type to Linux/Android targets.
    #[allow(unused)]
    watcher: Arc<notify::RecommendedWatcher>,
    /// Optional encoding forwarded to `get_content_from_file`.
    encoding: Option<Encoding>,
}
impl SecretWatcher {
    /// Returns a copy of the value currently held by the watch channel:
    /// either the latest content or the error of the latest failed read.
    pub fn read(&self) -> Result<String> {
        let latest = self.content.borrow();
        (*latest).clone()
    }

    /// Returns a stream backed by a clone of the internal watch receiver.
    ///
    /// The stream owns its receiver, so it does not borrow from `self`.
    #[cfg(all(feature = "notify-watch", feature = "tokio-notify"))]
    pub fn stream(&self) -> impl futures::Stream<Item = Result<String>> + 'static {
        // The enclosing cfg already requires `tokio-notify`, so no inner gate
        // is needed around the stream construction.
        let receiver = self.content.clone();
        tokio_stream::wrappers::WatchStream::new(receiver)
    }
}
impl fmt::Debug for SecretWatcher {
    /// Prints only `path` and `key`; the secret content, the watcher and the
    /// encoding are deliberately left out of the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("SecretWatcher");
        out.field("path", &self.path);
        out.field("key", &self.key);
        out.finish()
    }
}
impl fmt::Display for SecretWatcher {
    /// Writes a one-line description containing the path and the key.
    ///
    /// The secret content itself is never written.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { path, key, .. } = self;
        // `key` is an `Option<String>` rendered with `{:?}` (`Some("...")` or
        // `None`), which brings its own quotes; wrapping it in an extra,
        // unterminated `"` produced unbalanced output.
        write!(f, r#"type: "file_watcher", path: {path:?}, key: {key:?}"#)
    }
}
impl PartialEq for SecretWatcher {
    /// Two watchers are equal when they describe the same source (`path`,
    /// `key`, `encoding`) and are attached to the same watch channel.
    fn eq(&self, other: &Self) -> bool {
        self.path == other.path
            && self.key == other.key
            && self.encoding == other.encoding
            // Compare channel identity, not the addresses of the receiver
            // fields: an address comparison is only true when both sides are
            // the very same struct, so a clone would never equal its source.
            && self.content.same_channel(&other.content)
    }
}
#[cfg(feature = "json-schema")]
impl ::schemars::JsonSchema for SecretWatcher {
    /// Name under which this schema is registered.
    fn schema_name() -> std::borrow::Cow<'static, str> {
        std::borrow::Cow::Borrowed("SecretWatcher")
    }

    /// Describes the two accepted shapes: a bare string, or an object with a
    /// required `path` and optional `key` / `encoding`.
    fn json_schema(generator: &mut ::schemars::SchemaGenerator) -> ::schemars::Schema {
        use ::schemars::json_schema;

        // Built up front so the macro body below is plain data.
        let encoding_schema = Encoding::json_schema(generator);

        json_schema!({
            "examples": [
                "/path/to/file",
                {
                    "path": "/path/to/file",
                    "encoding": "base64"
                }
            ],
            "anyOf": [
                { "type": "string" },
                {
                    "type": "object",
                    "required": ["path"],
                    "properties": {
                        "key": { "type": "string" },
                        "path": { "type": "string" },
                        "encoding": encoding_schema
                    }
                }
            ]
        })
    }
}
impl Serialize for SecretWatcher {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
use serde::ser::SerializeMap;
let Self {
path,
key,
encoding,
..
} = self;
if key.is_none() && encoding.is_none() {
serializer.serialize_str(path.to_str().unwrap())
} else {
let mut map = serializer.serialize_map(None)?;
map.serialize_entry("path", path)?;
if let Some(key) = key {
map.serialize_entry("key", key)?;
}
if let Some(encoding) = encoding {
map.serialize_entry("encoding", encoding)?;
}
map.end()
}
}
}
struct SecretWatcherVisitor;
impl<'de> Visitor<'de> for SecretWatcherVisitor {
type Value = (String, PathBuf, Option<String>, Option<Encoding>);
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("struct SecretWatcher")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
let path = PathBuf::from(v);
let content = fs::read_to_string(&path)
.map_err(|err| E::custom(format!("cannot read file at '{path:?}': {err}")))?;
Ok((content, path, None, None))
}
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: serde::de::MapAccess<'de>,
{
let mut key_name = None;
let mut path_name = None;
let mut encoding_name = None;
while let Ok(Some(key)) = map.next_key::<String>() {
match key.as_str() {
"key" => {
let key_value = map.next_value::<String>()?;
key_name = Some(key_value);
}
"path" => {
let path_value = map.next_value::<PathBuf>()?;
path_name = Some(path_value);
}
"encoding" => {
let encoding_value = map.next_value::<Encoding>()?;
encoding_name = Some(encoding_value);
}
_ => {}
}
}
match (key_name, path_name, encoding_name) {
(key, Some(path), encoding) => Ok((
get_content_from_file(&path, key.as_deref(), encoding)
.map_err(|err| <A::Error as serde::de::Error>::custom(err.to_string()))?,
path,
key,
encoding,
)),
_ => Err(<A::Error as serde::de::Error>::missing_field("path")),
}
}
}
impl<'de> Deserialize<'de> for SecretWatcher {
    /// Reads the secret once, then installs a filesystem watcher that
    /// re-reads it on modification events and publishes the result on a
    /// watch channel.
    ///
    /// What gets watched:
    /// * path is not a symlink: the path itself;
    /// * path is a symlink, `k8s` feature on: the directory containing the
    ///   symlink;
    /// * path is a symlink, `k8s` feature off: the directory containing the
    ///   canonicalized (fully resolved) target.
    ///
    /// # Errors
    ///
    /// Fails when the initial read fails, when the directory to watch cannot
    /// be determined, or when the watcher cannot be created or attached.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let (init, path, key, encoding) = deserializer.deserialize_any(SecretWatcherVisitor)?;

        #[cfg(feature = "notify")]
        let (sender, recv) = tokio::sync::watch::channel(Ok(init));

        // Event handler handed to the watcher; owns the sending half of the
        // channel, so the channel closes when the watcher is dropped.
        struct Reader {
            path: PathBuf,
            key: Option<String>,
            encoding: Option<Encoding>,
            #[cfg(feature = "notify")]
            sender: tokio::sync::watch::Sender<Result<String>>,
        }

        impl EventHandler for Reader {
            fn handle_event(&mut self, event: notify::Result<notify::Event>) {
                // Only modification events trigger a re-read; watcher errors
                // and every other event kind are ignored.
                if let Ok(notify::Event {
                    kind: notify::EventKind::Modify(_),
                    ..
                }) = event
                {
                    let next =
                        get_content_from_file(&self.path, self.key.as_deref(), self.encoding);

                    // Stores `src` in `dest` and wipes the secret that was
                    // there before, so stale content does not linger in memory.
                    fn replace_and_zeroize(dest: &mut Result<String>, src: Result<String>) {
                        if let Ok(mut previous) = std::mem::replace(dest, src) {
                            previous.zeroize();
                        }
                    }

                    self.sender.send_if_modified(|curr| {
                        // Receivers are notified unless both the old and the
                        // new value are `Ok` with identical content. The slot
                        // is replaced (and the old value wiped) either way.
                        let unchanged =
                            matches!((&*curr, &next), (Ok(old), Ok(new)) if old == new);
                        replace_and_zeroize(curr, next);
                        !unchanged
                    });
                }
            }
        }

        let reader = Reader {
            path: path.clone(),
            key: key.clone(),
            encoding,
            sender,
        };

        let watch_path = if path.is_symlink() {
            #[cfg(feature = "k8s")]
            {
                let parent = path.parent().ok_or_else(|| {
                    <D::Error as serde::de::Error>::custom(format!(
                        "cannot determine parent directory of '{path:?}'"
                    ))
                })?;
                // A bare relative name such as `secret.txt` has the empty
                // path as its parent, which cannot be watched; use the
                // current directory instead.
                if parent.as_os_str().is_empty() {
                    PathBuf::from(".")
                } else {
                    parent.to_path_buf()
                }
            }
            #[cfg(not(feature = "k8s"))]
            {
                path.canonicalize()
                    .map_err(|e| {
                        <D::Error as serde::de::Error>::custom(format!(
                            "failed to canonicalize path '{path:?}': {e}"
                        ))
                    })?
                    .parent()
                    .ok_or_else(|| {
                        <D::Error as serde::de::Error>::custom(
                            "cannot determine parent directory of canonicalized path",
                        )
                    })?
                    .to_path_buf()
            }
        } else {
            path.clone()
        };

        let mut watcher = notify::recommended_watcher(reader).map_err(|err| {
            <D::Error as serde::de::Error>::custom(format!(
                "cannot create file watcher for '{path:?}': {err}"
            ))
        })?;
        watcher
            .watch(&watch_path, RecursiveMode::NonRecursive)
            .map_err(|err| {
                <D::Error as serde::de::Error>::custom(format!(
                    "cannot watch file at '{watch_path:?}': {err}"
                ))
            })?;

        Ok(SecretWatcher {
            content: recv,
            watcher: Arc::new(watcher),
            path,
            key,
            encoding,
        })
    }
}
// Tests for `SecretWatcher`: deserialization from JSON and propagation of
// file changes through `read()` and `stream()`. They touch the real
// filesystem inside a temporary directory and rely on short sleeps/retries
// to give the watcher time to deliver events.
#[cfg(test)]
mod tests {
use super::*;
use assert_fs::{
TempDir,
fixture::ChildPath,
prelude::{FileWriteStr, PathChild},
};
use backon::RetryableWithContext;
use futures::{FutureExt, future::BoxFuture};
use rstest::{fixture, rstest};
// Fixture: a fresh temporary directory.
#[fixture]
fn empty_tmp_dir() -> TempDir {
TempDir::new().unwrap()
}
// Fixture: a child path named `secret` inside a fresh temporary directory,
// returned together with the directory handle. The tests write the file
// themselves before using it.
#[fixture]
fn temp_dir(empty_tmp_dir: TempDir) -> (ChildPath, TempDir) {
let file = empty_tmp_dir.child("secret");
(file, empty_tmp_dir)
}
// One polling step for `backon`'s retry-with-context: the `(watcher, expected)`
// context is handed back unchanged in both cases. Returns `Ok(())` right away
// when `read()` equals `expected`; otherwise waits 100 ms and returns
// `Err(())` so the retry loop tries again.
#[allow(clippy::type_complexity)]
fn check_secret(
(s, expected): (SecretWatcher, Result<String>),
) -> BoxFuture<'static, ((SecretWatcher, Result<String>), Result<(), ()>)> {
if s.read() == expected {
async { ((s, expected), Ok(())) }.boxed()
} else {
async {
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
((s, expected), Err(()))
}
.boxed()
}
}
// Deserializes from the bare-string form (path only, no key), checks the
// initial empty content, then writes "hello" and polls until `read()`
// reflects it.
#[rstest]
#[tokio::test]
async fn secrets_in_file_with_keys(temp_dir: (ChildPath, TempDir)) {
use backon::ConstantBuilder;
// `_` does not bind, so the `TempDir` stays owned by `temp_dir` and is only
// dropped when this function returns.
let (file, _) = temp_dir;
file.write_str("").unwrap();
let secret: SecretWatcher = serde_json::from_str(&format!(
r#""{}""#,
String::from_utf8_lossy(file.to_string_lossy().as_bytes())
))
.expect("input to be deserialized");
let ((secret, _), res) = check_secret
.retry(ConstantBuilder::default())
.context((secret, Ok("".to_string())))
.await;
assert!(res.is_ok());
file.write_str("hello").unwrap();
let (_, res) = check_secret
.retry(ConstantBuilder::default())
.context((secret, Ok("hello".to_string())))
.await;
assert!(res.is_ok());
}
// Deserializes from the map form with a `key`, forwards every stream item
// into an mpsc channel from a spawned task, rewrites the file three times
// (100 ms apart) and expects exactly four values: the initial one plus one
// per rewrite.
#[cfg(feature = "tokio-notify")]
#[rstest]
#[tokio::test]
async fn secret_notify(temp_dir: (ChildPath, TempDir)) {
use futures::{StreamExt, TryStreamExt, stream};
let (file, _) = temp_dir;
file.write_str("MY_KEY=").unwrap();
let secret: SecretWatcher = serde_json::from_str(&format!(
r#"{{"path":"{}","key":"MY_KEY"}}"#,
String::from_utf8_lossy(file.to_string_lossy().as_bytes())
))
.expect("input to be deserialized");
let (tx, rx) = tokio::sync::mpsc::channel(4);
// Stream items are wrapped in `Ok` so `try_fold` can stop on a send error.
let task = tokio::spawn(
secret
.stream()
.map(Ok::<_, tokio::sync::mpsc::error::SendError<_>>)
.try_fold(tx, |tx, next| async {
tx.send(next).await?;
Ok(tx)
}),
);
stream::iter(1..=3)
.fold(file, |file, next| async move {
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
file.write_str(&format!("MY_KEY={next}")).unwrap();
file
})
.await;
let messages = tokio_stream::wrappers::ReceiverStream::new(rx)
.take(4)
.try_collect::<Vec<_>>()
.await;
assert_eq!(
messages,
Ok(vec![
"".to_string(),
"1".to_string(),
"2".to_string(),
"3".to_string(),
])
);
// The forwarding task is not awaited; dropping the handle leaves it detached.
drop(task);
drop(secret);
}
// Symlink layout built here:
//   secret.txt -> ..data/secret.txt
//   ..data     -> ..2024_01_15_12_30_00
// The update is performed by creating a second timestamped directory and
// renaming a temporary symlink over `..data`. The watcher is expected to
// report the old and then the new content.
#[cfg(all(unix, feature = "tokio-notify", feature = "k8s"))]
#[rstest]
#[tokio::test]
async fn k8s_secret_notify(empty_tmp_dir: TempDir) {
use futures::TryStreamExt;
use std::os::unix::fs as unix_fs;
let data_dir_1 = empty_tmp_dir.child("..2024_01_15_12_30_00");
fs::create_dir(data_dir_1.path()).unwrap();
let actual_file_1 = data_dir_1.child("secret.txt");
actual_file_1.write_str("SECRET1").unwrap();
let data_link = empty_tmp_dir.child("..data");
unix_fs::symlink("..2024_01_15_12_30_00", data_link.path()).unwrap();
let secret_link = empty_tmp_dir.child("secret.txt");
unix_fs::symlink("..data/secret.txt", secret_link.path()).unwrap();
let secret: SecretWatcher = serde_json::from_str(&format!(
r#"{{"path":"{}"}}"#,
String::from_utf8_lossy(secret_link.to_string_lossy().as_bytes())
))
.unwrap();
let stream = secret.stream();
// Collects every stream item until the stream ends.
let task = tokio::spawn(async move {
stream
.try_fold(vec![], |mut acc, s| {
acc.push(s);
futures::future::ready(Ok(acc))
})
.await
});
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
let data_dir_2 = empty_tmp_dir.child("..2024_01_15_12_31_00");
fs::create_dir(data_dir_2.path()).unwrap();
let actual_file_2 = data_dir_2.child("secret.txt");
actual_file_2.write_str("SECRET2").unwrap();
let data_tmp = empty_tmp_dir.child("..data_tmp");
unix_fs::symlink("..2024_01_15_12_31_00", data_tmp.path()).unwrap();
// Replace `..data` in one step by renaming the temporary symlink over it.
fs::rename(data_tmp.path(), data_link.path()).unwrap();
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// NOTE(review): the task is awaited only after `secret` is dropped —
// presumably dropping the watcher closes the channel and ends the stream.
drop(secret);
let results = task.await.unwrap().unwrap();
assert_eq!(results, vec!["SECRET1".to_string(), "SECRET2".to_string()]);
}
// The configured path is a symlink in `config/` pointing at a real file in
// a different directory (`actual_secrets/`). The real file is rewritten in
// place and the watcher is expected to report the old and then the new
// content.
#[cfg(all(unix, feature = "tokio-notify", not(feature = "k8s")))]
#[rstest]
#[tokio::test]
async fn external_symlink_secret_notify(empty_tmp_dir: TempDir) {
use futures::TryStreamExt;
use std::os::unix::fs as unix_fs;
let actual_secrets_dir = empty_tmp_dir.child("actual_secrets");
fs::create_dir(actual_secrets_dir.path()).unwrap();
let actual_file = actual_secrets_dir.child("key.txt");
actual_file.write_str("INITIAL_SECRET").unwrap();
let config_dir = empty_tmp_dir.child("config");
fs::create_dir(config_dir.path()).unwrap();
let symlink_file = config_dir.child("secret.txt");
unix_fs::symlink(actual_file.path(), symlink_file.path()).unwrap();
let secret: SecretWatcher = serde_json::from_str(&format!(
r#"{{"path":"{}"}}"#,
String::from_utf8_lossy(symlink_file.to_string_lossy().as_bytes())
))
.unwrap();
let stream = secret.stream();
// Collects every stream item until the stream ends.
let task = tokio::spawn(async move {
stream
.try_fold(vec![], |mut acc, s| {
acc.push(s);
futures::future::ready(Ok(acc))
})
.await
});
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
actual_file.write_str("UPDATED_SECRET").unwrap();
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// NOTE(review): as above, the stream is presumably ended by dropping `secret`.
drop(secret);
let results = task.await.unwrap().unwrap();
assert_eq!(
results,
vec!["INITIAL_SECRET".to_string(), "UPDATED_SECRET".to_string()]
);
}
}