use std::path::Path;
use serde::Serializer;
use ts_keys::PersistState;
use crate::keys::NodeState;
/// Environment variable that overrides the control server URL (read by `Config::default_from_env`).
const CONTROL_URL_VAR: &str = "TS_CONTROL_URL";
/// Environment variable holding the requested hostname (read by `Config::default_from_env`).
const HOSTNAME_VAR: &str = "TS_HOSTNAME";
/// Environment variable holding the auth key (read by `auth_key_from_env`).
const AUTHKEY_VAR: &str = "TS_AUTH_KEY";
/// Client configuration: key state plus the settings forwarded to the control client
/// (see the `From<&Config> for ts_control::Config` impl).
pub struct Config {
    /// Key state, e.g. loaded from a key file by [`Config::default_with_key_file`].
    pub key_state: PersistState,
    /// Client name, forwarded as `ts_control::Config::client_name`.
    pub client_name: Option<String>,
    /// Control server URL, forwarded as `ts_control::Config::server_url`.
    pub control_server_url: url::Url,
    /// Hostname to request, forwarded as `ts_control::Config::hostname`.
    pub requested_hostname: Option<String>,
    /// Tags to request, forwarded as `ts_control::Config::tags`.
    pub requested_tags: Vec<String>,
}
impl Config {
    /// Returns [`Config::default`] with `key_state` loaded from the key file at `p`.
    ///
    /// A missing key file is created with default contents. A key file that exists but
    /// cannot be parsed is an error, because the default [`BadFormatBehavior`] is
    /// [`BadFormatBehavior::Error`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`load_key_file`].
    pub async fn default_with_key_file(p: impl AsRef<Path>) -> Result<Self, crate::Error> {
        Ok(Config {
            key_state: load_key_file(p, Default::default()).await?,
            ..Default::default()
        })
    }

    /// Returns [`Config::default`] with overrides taken from the environment.
    ///
    /// - `TS_CONTROL_URL`: the control server URL. A value that fails to parse is logged,
    ///   and the default URL is kept.
    /// - `TS_HOSTNAME`: the hostname to request.
    ///
    /// A variable that is unset, empty, or not valid Unicode is treated as absent.
    pub fn default_from_env() -> Config {
        let mut config = Config::default();
        // Treat `VAR=` like an unset variable rather than as an explicit empty value.
        if let Some(u) = std::env::var(CONTROL_URL_VAR).ok().filter(|v| !v.is_empty()) {
            match u.parse() {
                Ok(u) => config.control_server_url = u,
                Err(e) => {
                    tracing::error!(error = %e, "parsing {CONTROL_URL_VAR} (fall back to default value)");
                }
            }
        }
        config.requested_hostname = std::env::var(HOSTNAME_VAR).ok().filter(|v| !v.is_empty());
        config
    }
}
/// Returns the auth key from the `TS_AUTH_KEY` environment variable, if one is provided.
///
/// An unset, empty, or non-Unicode variable yields `None`. As a result, `TS_AUTH_KEY=`
/// behaves like an unset variable instead of supplying an empty key.
pub fn auth_key_from_env() -> Option<String> {
    std::env::var(AUTHKEY_VAR).ok().filter(|key| !key.is_empty())
}
/// Loads the persisted key state from the JSON key file at `p`, creating the file if needed.
///
/// - A missing file is initialized with a default [`KeyFile`], which is written to disk and
///   returned.
/// - A file in the legacy (`NodeState`-based) layout is converted and rewritten in the current
///   layout. The rewrite is best effort: if it fails, the converted state is still returned.
/// - `bad_format` decides what happens when the file exists but cannot be parsed.
///
/// # Errors
///
/// - `crate::Error::KeyFileRead` if the file cannot be read, or cannot be parsed under
///   [`BadFormatBehavior::Error`].
/// - `crate::Error::KeyFileWrite` if the directory or file cannot be written.
pub async fn load_key_file(
    p: impl AsRef<Path>,
    bad_format: BadFormatBehavior,
) -> Result<PersistState, crate::Error> {
    let p = p.as_ref();
    tracing::trace!(key_file = %p.display(), "loading key file");
    let key_file = load_or_init::<KeyFile>(
        p,
        KeyFile::default,
        |x| match x {
            // The legacy variant is deprecated but must still be matched so it can be upgraded.
            #[allow(deprecated)]
            KeyFile::Old(old) => Some(KeyFile::New(KeyFileNew {
                key_state: PersistState::from(&old.key_state),
            })),
            // Already in the current layout: nothing to migrate.
            KeyFile::New(_) => None,
        },
        bad_format,
    )
    .await?;
    Ok(key_file.key_state())
}
/// On-disk key file, in either the legacy or the current layout.
///
/// With `#[serde(untagged)]`, serde tries the variants in declaration order. A file is
/// therefore first parsed as [`KeyFileOld`], and only then as [`KeyFileNew`].
#[derive(serde::Deserialize)]
#[serde(untagged)]
enum KeyFile {
    /// Legacy layout (`key_state` stored as a `NodeState`). It is still readable, and
    /// `load_key_file` tries (best effort) to rewrite it in the current layout.
    #[deprecated]
    Old(KeyFileOld),
    /// Current layout (`key_state` stored as a `PersistState`).
    New(KeyFileNew),
}
impl KeyFile {
#[allow(deprecated)]
pub fn key_state(&self) -> PersistState {
match self {
Self::Old(old) => (&old.key_state).into(),
Self::New(new) => new.key_state.clone(),
}
}
}
impl Default for KeyFile {
fn default() -> Self {
KeyFile::New(KeyFileNew::default())
}
}
impl serde::Serialize for KeyFile {
    /// Always serializes in the current [`KeyFileNew`] layout, whichever variant `self` is,
    /// so writing a legacy file upgrades it.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let current = KeyFileNew {
            key_state: self.key_state(),
        };
        serde::Serialize::serialize(&current, serializer)
    }
}
/// Current key-file layout: a JSON object whose `key_state` field holds a [`PersistState`].
///
/// This is the only layout that `KeyFile` serializes to.
#[derive(serde::Deserialize, serde::Serialize, Default)]
struct KeyFileNew {
    key_state: PersistState,
}
/// Legacy key-file layout whose `key_state` is a [`NodeState`].
///
/// This layout is deserialize-only (it has no `Serialize` impl). Its state is converted to
/// [`PersistState`] through `From<&NodeState>`.
#[derive(serde::Deserialize)]
struct KeyFileOld {
    key_state: NodeState,
}
impl From<&Config> for ts_control::Config {
fn from(value: &Config) -> ts_control::Config {
ts_control::Config {
client_name: value.client_name.clone(),
hostname: value.requested_hostname.clone(),
server_url: value.control_server_url.clone(),
tags: value.requested_tags.clone(),
}
}
}
impl Default for Config {
fn default() -> Self {
Self {
key_state: Default::default(),
client_name: None,
control_server_url: ts_control::DEFAULT_CONTROL_SERVER.clone(),
requested_hostname: None,
requested_tags: vec![],
}
}
}
/// What to do when a key file exists but its contents cannot be parsed.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)]
pub enum BadFormatBehavior {
    /// Fail with `crate::Error::KeyFileRead` and leave the file untouched. This is the default.
    #[default]
    Error,
    /// Log a warning and replace the file with freshly initialized default contents,
    /// discarding whatever it previously held.
    Overwrite,
}
#[tracing::instrument(skip_all, fields(?bad_format_behavior, path = %path.as_ref().display()))]
async fn load_or_init<KeyState>(
path: impl AsRef<Path>,
default: impl FnOnce() -> KeyState,
migrate: impl FnOnce(&KeyState) -> Option<KeyState>,
bad_format_behavior: BadFormatBehavior,
) -> Result<KeyState, crate::Error>
where
KeyState: serde::Serialize + serde::de::DeserializeOwned,
{
let path = path.as_ref();
tokio::fs::create_dir_all(path.parent().unwrap())
.await
.map_err(|e| {
tracing::error!(error = %e, "creating parent dirs for key file");
crate::Error::KeyFileWrite
})?;
match tokio::fs::read(path).await {
Ok(contents) => match serde_json::from_slice::<KeyState>(&contents) {
Ok(state) => {
if let Some(migrated) = migrate(&state) {
match try_write(path, &migrated).await {
Ok(_) => {
tracing::info!("migrated key file to new disco-less format");
return Ok(migrated);
}
Err(e) => {
tracing::error!(error = %e, "unable to migrate key file");
}
}
}
return Ok(state);
}
Err(e) => match bad_format_behavior {
BadFormatBehavior::Error => {
tracing::error!(error = %e, "parsing key file");
return Err(crate::Error::KeyFileRead);
}
BadFormatBehavior::Overwrite => {
tracing::warn!(
error = %e,
config_file_contents_len = contents.len(),
"failed loading version from key file, overwriting",
);
}
},
},
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
Err(e) => {
tracing::error!(error = %e, path = %path.display(), "reading key file");
return Err(crate::Error::KeyFileRead);
}
}
let value = default();
try_write(path, &value).await?;
Ok(value)
}
async fn try_write(
path: impl AsRef<Path>,
value: &impl serde::Serialize,
) -> Result<(), crate::Error> {
tokio::fs::write(
path,
serde_json::to_vec(value).map_err(|e| {
tracing::error!(error = %e, "serializing key state");
crate::Error::KeyFileWrite
})?,
)
.await
.map_err(|e| {
tracing::error!(error = %e, "saving key state");
crate::Error::KeyFileWrite
})?;
Ok(())
}