use std::path::Path;
/// Client-side configuration for a node.
///
/// The `From<&Config>` conversion in this file forwards the client name,
/// requested hostname, and control server URL to `ts_control::Config`.
pub struct Config {
/// Node key state; `load_key_file` returns a value of this type.
pub key_state: crate::NodeState,
/// Optional client name, forwarded as `client_name` to `ts_control::Config`.
pub client_name: Option<String>,
/// URL of the control server; `Default` uses `ts_control::DEFAULT_CONTROL_SERVER`.
pub control_server_url: url::Url,
/// Optional hostname to request, forwarded as `hostname` to `ts_control::Config`.
pub requested_hostname: Option<String>,
/// Tags to request for this node.
// NOTE(review): not forwarded by the `From<&Config>` conversion visible in this
// file; presumably consumed elsewhere, so confirm against the callers.
pub requested_tags: Vec<String>,
}
pub async fn load_key_file(
p: impl AsRef<Path>,
bad_format: BadFormatBehavior,
) -> Result<crate::NodeState, crate::Error> {
let p = p.as_ref();
tracing::trace!(key_file = %p.display(), "loading key file");
let key_file = load_or_init::<KeyFile>(&p, Default::default, bad_format).await?;
Ok(key_file.key_state)
}
/// On-disk representation of the key file.
///
/// `load_key_file` reads and writes this type through `load_or_init`, which
/// (de)serialises it as JSON via `serde_json`.
#[derive(serde::Serialize, serde::Deserialize, Default)]
struct KeyFile {
/// The persisted node key state.
key_state: crate::NodeState,
}
impl From<&Config> for ts_control::Config {
    /// Build a control-client configuration by cloning the relevant fields
    /// (client name, requested hostname, control server URL) out of `value`.
    fn from(value: &Config) -> Self {
        let client_name = value.client_name.clone();
        let hostname = value.requested_hostname.clone();
        let server_url = value.control_server_url.clone();

        Self {
            client_name,
            hostname,
            server_url,
        }
    }
}
impl Default for Config {
fn default() -> Self {
Self {
key_state: Default::default(),
client_name: None,
control_server_url: ts_control::DEFAULT_CONTROL_SERVER.clone(),
requested_hostname: None,
requested_tags: vec![],
}
}
}
/// What `load_or_init` should do when the file exists but its contents fail
/// to deserialize.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)]
pub enum BadFormatBehavior {
/// Log the parse error and fail the load (the default).
#[default]
Error,
/// Log a warning and overwrite the file with a freshly initialised value.
Overwrite,
}
/// Load a JSON-serialised value from `path`, or initialise the file with
/// `default()` if it cannot be loaded.
///
/// Behaviour:
/// - Parent directories of `path` are created if needed.
/// - If the file exists and parses, the parsed value is returned and the file
///   is left untouched.
/// - If the file does not exist, `default()` is serialised to `path` and
///   returned.
/// - If the file exists but fails to parse, `bad_format_behavior` decides:
///   [`BadFormatBehavior::Error`] fails the call, while
///   [`BadFormatBehavior::Overwrite`] replaces the file with `default()`.
///
/// # Errors
///
/// Returns `crate::Error::InternalFailure` (after logging the underlying
/// cause) when creating the parent directories, reading, parsing (under
/// `BadFormatBehavior::Error`), serialising, or writing fails.
#[tracing::instrument(skip_all, fields(?bad_format_behavior, path = %path.as_ref().display()))]
async fn load_or_init<KeyState>(
    path: impl AsRef<Path>,
    default: impl FnOnce() -> KeyState,
    bad_format_behavior: BadFormatBehavior,
) -> Result<KeyState, crate::Error>
where
    KeyState: serde::Serialize + serde::de::DeserializeOwned,
{
    let path = path.as_ref();

    // `Path::parent` returns `None` for a root or empty path. There is nothing
    // to create in that case; skip instead of panicking and let the read/write
    // below report the failure through the normal error path.
    if let Some(parent) = path.parent() {
        tokio::fs::create_dir_all(parent).await.map_err(|e| {
            tracing::error!(error = %e, "creating parent dirs for key file");
            crate::Error::InternalFailure
        })?;
    }

    match tokio::fs::read(path).await {
        Ok(contents) => match serde_json::from_slice::<KeyState>(&contents) {
            Ok(state) => return Ok(state),
            Err(e) => match bad_format_behavior {
                BadFormatBehavior::Error => {
                    tracing::error!(error = %e, "parsing key file");
                    return Err(crate::Error::InternalFailure);
                }
                BadFormatBehavior::Overwrite => {
                    // Fall through to re-initialise the file below.
                    tracing::warn!(
                        error = %e,
                        config_file_contents_len = contents.len(),
                        "failed loading version from key file, overwriting",
                    );
                }
            },
        },
        // A missing file is not an error: fall through and initialise it.
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
        Err(e) => {
            tracing::error!(error = %e, path = %path.display(), "reading key file");
            return Err(crate::Error::InternalFailure);
        }
    }

    // Either the file was absent or we were asked to overwrite a malformed one.
    let value = default();
    let serialized = serde_json::to_vec(&value).map_err(|e| {
        tracing::error!(error = %e, "serializing key state");
        crate::Error::InternalFailure
    })?;
    tokio::fs::write(path, serialized).await.map_err(|e| {
        tracing::error!(error = %e, "saving key state");
        crate::Error::InternalFailure
    })?;

    Ok(value)
}