use std::collections::HashMap;
use std::fs;
use std::io;
use std::path::{Path, PathBuf};
use serde::{Deserialize, Serialize};
/// Errors produced by keystore operations.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Reading or writing the backing file failed. `#[from]` provides a `From<io::Error>`
    /// conversion, so `?` can be used directly on I/O calls.
    #[error("keystore I/O error: {0}")]
    Io(#[from] io::Error),
    /// The keystore contents could not be parsed from, or rendered to, JSON. Converted
    /// automatically from `serde_json::Error` via `#[from]`.
    #[error("keystore JSON error: {0}")]
    Json(#[from] serde_json::Error),
    /// No key is stored under the given identifier (returned by `JsonKeyStore::get`).
    #[error("key not found: {0}")]
    NotFound(String),
    /// A key is already stored under the given identifier (returned by `JsonKeyStore::store`).
    #[error("key already exists: {0}")]
    AlreadyExists(String),
    /// An algorithm name was not recognized.
    /// NOTE(review): not constructed anywhere in this file; presumably raised by code that
    /// dispatches on `StoredKey::alg` — confirm.
    #[error("unknown algorithm: {0}")]
    UnknownAlgorithm(String),
}
/// One stored key entry: algorithm name, public/private key bytes and an optional label.
///
/// In the JSON file both byte fields are encoded as unpadded base64url strings (through the
/// `base64url_bytes` serde adapter) and `tag` is omitted entirely when it is `None`.
/// NOTE(review): nothing in this type encrypts `prv_key`; it is only encoded, not protected.
#[derive(Clone, Serialize, Deserialize)]
pub struct StoredKey {
    /// Algorithm identifier for this key.
    /// NOTE(review): the accepted values are not defined here — presumably validated by callers.
    pub alg: String,
    /// Public key bytes.
    #[serde(with = "base64url_bytes")]
    pub pub_key: Vec<u8>,
    /// Private key bytes. Never shown in `Debug` output.
    #[serde(with = "base64url_bytes")]
    pub prv_key: Vec<u8>,
    /// Optional free-form label; skipped when serializing if `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tag: Option<String>,
}

// Hand-written instead of derived so that private key material never ends up in logs,
// panic messages, or any other `{:?}` output.
impl std::fmt::Debug for StoredKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StoredKey")
            .field("alg", &self.alg)
            .field("pub_key", &self.pub_key)
            .field("prv_key", &"<redacted>")
            .field("tag", &self.tag)
            .finish()
    }
}
/// Storage backend for key entries, addressed by a string identifier (`tmb`).
///
/// NOTE(review): `tmb` presumably stands for the key's thumbprint — confirm with callers.
pub trait KeyStore {
    /// Adds `key` under the identifier `tmb`.
    /// The `JsonKeyStore` implementation rejects an identifier that is already present.
    fn store(&mut self, tmb: &str, key: StoredKey) -> Result<(), Error>;
    /// Borrows the key stored under `tmb`, or returns an error if there is none.
    fn get(&self, tmb: &str) -> Result<&StoredKey, Error>;
    /// Returns the identifiers of all stored keys.
    fn list(&self) -> Vec<&str>;
    /// Writes the store's current contents to its backing storage
    /// (for `JsonKeyStore`, the JSON file it was opened from).
    fn save(&self) -> Result<(), Error>;
}
/// A [`KeyStore`] backed by a single JSON file.
///
/// All keys are held in memory. Changes made through [`KeyStore::store`] reach disk only
/// when [`KeyStore::save`] is called. Private keys are written as base64url-encoded bytes,
/// not encrypted.
pub struct JsonKeyStore {
    /// Location of the backing JSON file.
    path: PathBuf,
    /// In-memory map from key identifier (`tmb`) to key; serialized as the file's top-level
    /// JSON object.
    keys: HashMap<String, StoredKey>,
}
impl JsonKeyStore {
pub fn open(path: impl AsRef<Path>) -> Result<Self, Error> {
let path = path.as_ref().to_path_buf();
let keys = if path.exists() {
let content = fs::read_to_string(&path)?;
serde_json::from_str(&content)?
} else {
HashMap::new()
};
Ok(Self { path, keys })
}
pub fn path(&self) -> &Path {
&self.path
}
}
impl KeyStore for JsonKeyStore {
fn store(&mut self, tmb: &str, key: StoredKey) -> Result<(), Error> {
if self.keys.contains_key(tmb) {
return Err(Error::AlreadyExists(tmb.to_string()));
}
self.keys.insert(tmb.to_string(), key);
Ok(())
}
fn get(&self, tmb: &str) -> Result<&StoredKey, Error> {
self.keys
.get(tmb)
.ok_or_else(|| Error::NotFound(tmb.to_string()))
}
fn list(&self) -> Vec<&str> {
self.keys.keys().map(String::as_str).collect()
}
fn save(&self) -> Result<(), Error> {
let content = serde_json::to_string_pretty(&self.keys)?;
fs::write(&self.path, content)?;
Ok(())
}
}
mod base64url_bytes {
    //! Serde `with` adapter that represents a byte vector as an unpadded base64url string.
    //!
    //! Decoding uses `Base64UrlUnpadded::decode_vec`. Any string that does not decode under
    //! that encoding is reported as a custom deserialization error.

    use base64ct::{Base64UrlUnpadded, Encoding};
    use serde::{de, Deserialize, Deserializer, Serializer};

    /// Emits `bytes` as a single unpadded base64url string.
    pub fn serialize<S: Serializer>(bytes: &[u8], serializer: S) -> Result<S::Ok, S::Error> {
        serializer.serialize_str(&Base64UrlUnpadded::encode_string(bytes))
    }

    /// Reads an owned string and decodes it from unpadded base64url back into bytes.
    pub fn deserialize<'de, D: Deserializer<'de>>(deserializer: D) -> Result<Vec<u8>, D::Error> {
        let encoded: String = Deserialize::deserialize(deserializer)?;
        match Base64UrlUnpadded::decode_vec(&encoded) {
            Ok(decoded) => Ok(decoded),
            Err(cause) => Err(<D::Error as de::Error>::custom(cause)),
        }
    }
}