use std::{
collections::HashMap,
default::Default,
fs::File,
io::{Read, Seek, SeekFrom, Write},
sync::Mutex,
time::SystemTime,
};
use crate::{
modhex::ModHex,
otp::{self, DecryptedOtp, DecryptedPrivateData, Otp, PrivateId, PublicId},
store::{OtpStore, StoreError},
};
/// File-backed OTP store.
///
/// The backing file is fixed-width text: one header line followed by one line
/// per entry, every line `LINE_LENGTH_WITH_NEWLINE - 1` bytes long and separated
/// by `\n`. `store` relies on entry `i` of `entries` living on line `i + 1`, so
/// that a single entry's counters can be overwritten in place by byte offset.
pub struct Store {
    /// Path of the backing file.
    file_path: String,
    /// Per-id state; the position in this vector is the entry's line number minus one.
    entries: Vec<Data>,
    /// Maps each public id to its index in `entries`.
    lookup: HashMap<PublicId, usize>,
    /// Serializes writes to the backing file.
    write_lock: Mutex<()>,
}
/// Stored state for a single public id.
pub struct Data {
    /// Private data of the last accepted OTP; the next OTP is validated against it.
    pub previous: DecryptedPrivateData,
    /// 16-byte key passed to `Otp::decrypt` for this id.
    pub key: [u8; 16],
}
/// Errors returned while loading or persisting the store.
#[derive(Debug)]
pub enum Error {
    /// Wraps an underlying [`std::io::Error`].
    IoError(std::io::Error),
    /// The backing file's contents could not be parsed.
    Parsing,
}

/// Byte length of every line in the store file, counting its `\n` separator:
/// the header and each entry line are exactly 90 bytes of text.
const LINE_LENGTH_WITH_NEWLINE: usize = 90 + "\n".len();

impl std::error::Error for Error {}

impl std::fmt::Display for Error {
    /// Shows the wrapped I/O error's message, or a fixed message for parse failures.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Parsing => f.write_str("Error parsing file"),
            Self::IoError(source) => write!(f, "{source}"),
        }
    }
}
impl Default for Data {
fn default() -> Self {
Self {
key: [0; 16],
previous: DecryptedPrivateData {
id: PrivateId { raw_bytes: [0; 6] },
usage_counter: 0,
session_counter: 0,
timestamp: 0,
random: [0; 2],
},
}
}
}
impl Store {
pub fn create(file_path: &str) -> Result<Self, Error> {
let mut f = match File::open(file_path) {
Ok(file) => file,
Err(ref e) if e.kind() == std::io::ErrorKind::NotFound => {
let store = Store {
file_path: String::from(file_path),
lookup: HashMap::new(),
entries: Vec::new(),
write_lock: Mutex::new(()),
};
store.store_all().map_err(Error::IoError)?;
return Ok(store);
}
Err(e) => return Err(Error::IoError(e)),
};
let mut buffer = String::new();
f.read_to_string(&mut buffer).map_err(Error::IoError)?;
let mut lookup = HashMap::new();
let mut entries = Vec::new();
let rewrite_whole_file = buffer.contains("\r\n");
for (line_index, line) in buffer.trim().lines().skip(1).enumerate() {
let line = line.trim();
if line.is_empty() {
continue;
}
let Some((public_id_str, data)) = line.split_once('=') else {
return Err(Error::Parsing);
};
let mut public_id = PublicId { raw_bytes: [0; 6] };
public_id.raw_bytes.copy_from_slice(
ModHex::try_from(public_id_str.trim())
.map_err(|_| Error::Parsing)?
.raw_bytes(),
);
let mut usage_counter: Option<u16> = None;
let mut session_counter: Option<u8> = None;
let mut private_id: Option<[u8; 6]> = None;
let mut key: Option<[u8; 16]> = None;
for part in data.split_whitespace() {
if let Some(value) = part.strip_prefix("u_cnt:") {
usage_counter = Some(value.parse().map_err(|_| Error::Parsing)?);
} else if let Some(value) = part.strip_prefix("s_cnt:") {
session_counter = Some(value.parse().map_err(|_| Error::Parsing)?);
} else if let Some(value) = part.strip_prefix("pid:") {
private_id = Some(
decode_hex(value)
.map_err(|_| Error::Parsing)?
.try_into()
.map_err(|_| Error::Parsing)?,
);
} else if let Some(value) = part.strip_prefix("key:") {
key = Some(
decode_hex(value)
.map_err(|_| Error::Parsing)?
.try_into()
.map_err(|_| Error::Parsing)?,
);
}
}
let usage_counter = usage_counter.ok_or(Error::Parsing)?;
let session_counter = session_counter.ok_or(Error::Parsing)?;
let private_id = private_id.ok_or(Error::Parsing)?;
let key = key.ok_or(Error::Parsing)?;
let previous = DecryptedPrivateData {
id: crate::otp::PrivateId {
raw_bytes: private_id,
},
usage_counter,
session_counter,
timestamp: 0,
random: [0; 2],
};
entries.push(Data { key, previous });
lookup.insert(public_id, line_index);
}
let store = Store {
file_path: String::from(file_path),
lookup,
entries,
write_lock: Mutex::new(()),
};
if rewrite_whole_file {
store.store_all().map_err(Error::IoError)?;
}
Ok(store)
}
pub fn get(&self, id: PublicId) -> Option<&Data> {
self.lookup.get(&id).map(|&idx| &self.entries[idx])
}
pub fn validate(&mut self, otp_str: &str) -> Result<(), StoreError> {
let otp = match Otp::from_modhex(otp_str) {
Ok(otp) => otp,
Err(err) => {
return Err(StoreError::Otp(err));
}
};
let Some(previous) = self.get(otp.id) else {
return Err(StoreError::UnknownPublicId);
};
let decrypted_otp = match otp.decrypt(&previous.key) {
Ok(otp) => otp,
Err(err) => {
return Err(StoreError::Otp(err));
}
};
match decrypted_otp.validate(&otp::DecryptedOtp {
id: decrypted_otp.id,
private: previous.previous.clone(),
}) {
Ok(()) => self
.update(&decrypted_otp, &previous.key.clone())
.map_err(|e| StoreError::Other(Box::new(e))),
Err(err) => Err(StoreError::Validation(err)),
}
}
pub fn provision_new_otp(&mut self, otp: &DecryptedOtp, key: &[u8]) -> Result<(), Error> {
self.update(otp, key)
}
fn update(&mut self, otp: &DecryptedOtp, key: &[u8]) -> Result<(), Error> {
let mut is_new = false;
if let Some(existing_index) = self.lookup.get(&otp.id) {
self.entries[*existing_index].previous = otp.private.clone();
} else {
let mut new_data = Data {
previous: otp.private.clone(),
..Default::default()
};
new_data.key.copy_from_slice(key);
self.entries.push(new_data);
self.lookup.insert(otp.id, self.entries.len() - 1);
is_new = true;
}
self.store(otp.id, is_new).map_err(Error::IoError)
}
fn write_otp<W: std::io::Write>(
writer: &mut W,
public_id: PublicId,
entry: &Data,
write_provision_data: bool,
) -> Result<(), std::io::Error> {
write!(
writer,
"{} = u_cnt:{:05}",
ModHex::from(&public_id.raw_bytes[..]),
entry.previous.usage_counter
)?;
write!(writer, " s_cnt:{:03}", entry.previous.session_counter)?;
if write_provision_data {
write!(writer, " pid:")?;
for b in &entry.previous.id.raw_bytes {
write!(writer, "{b:02x}")?;
}
write!(writer, " key:")?;
for b in entry.key {
write!(writer, "{b:02x}")?;
}
}
Ok(())
}
fn store(&mut self, public_id: PublicId, is_new: bool) -> Result<(), std::io::Error> {
let _lock = self.write_lock.lock().unwrap();
if !std::path::Path::new(&self.file_path).exists() {
return self.store_all();
}
let mut file = std::fs::OpenOptions::new()
.write(true)
.open(&self.file_path)?;
let line_index = *self.lookup.get(&public_id).expect("Id must be present");
let entry = &self.entries[line_index];
if is_new {
file.seek(SeekFrom::End(0))?;
writeln!(file)?;
} else {
let position = (line_index + 1) * LINE_LENGTH_WITH_NEWLINE;
file.seek(SeekFrom::Start(position as u64))?;
}
Self::write_otp(&mut file, public_id, entry, is_new)
}
fn store_all(&self) -> Result<(), std::io::Error> {
let _lock = self.write_lock.lock().unwrap();
let mut file = File::create(&self.file_path)?;
let mut first = true;
write!(
file,
"# Do not modify this file while the program is running this will result in a broken file #"
)?;
for (&public_id, &index) in &self.lookup {
if !first {
writeln!(file)?;
}
first = false;
Self::write_otp(&mut file, public_id, &self.entries[index], true)
.map_err(|_| std::io::Error::other("Format error"))?;
}
Ok(())
}
}
/// Decodes a string of hex digit pairs (e.g. `"00ff"`) into bytes.
///
/// # Errors
///
/// Returns a `ParseIntError` if the input has odd length or contains anything
/// other than ASCII hex digits (including a `+` sign). Malformed input never
/// panics, since it may come from the store file.
pub(crate) fn decode_hex(s: &str) -> Result<Vec<u8>, std::num::ParseIntError> {
    s.as_bytes()
        .chunks(2)
        .map(|pair| match std::str::from_utf8(pair) {
            // Require exactly two hex digits: `from_str_radix` alone would accept
            // a lone trailing digit or a leading `+`.
            Ok(digits) if digits.len() == 2 && digits.bytes().all(|b| b.is_ascii_hexdigit()) => {
                u8::from_str_radix(digits, 16)
            }
            // `ParseIntError` has no public constructor; parsing an empty string
            // is the stable way to obtain one for malformed input.
            _ => u8::from_str_radix("", 16),
        })
        .collect()
}
impl OtpStore for Store {
fn validate(&mut self, otp_str: &str, _now: SystemTime) -> Result<(), StoreError> {
self.validate(otp_str)
}
}