use std::collections::HashSet;
use crate::util::bencode::{self, BtValue};
use crate::config::config_message::{
ConfigData, ConfigValue, MutableConfigMessage, SignCallable, VerifyCallable,
};
use crate::config::encrypt::{config_decrypt, config_encrypt, pad_message, ENCRYPT_DATA_OVERHEAD};
use crate::config::namespaces::Namespace;
/// Synchronisation state of a config relative to what has been pushed.
///
/// Discriminants are explicit so that `state as i32` yields a stable value.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConfigState {
    /// No local changes are pending a push.
    Clean = 0,
    /// Local changes exist that have not yet been handed out by `push`.
    Dirty = 1,
    /// `push` produced data; waiting for `confirm_pushed` with the matching seqno.
    Waiting = 2,
}
/// Per-type hooks that a concrete config plugs into `ConfigBase`.
///
/// Implementors must be `Default`; `ConfigBase` starts from `T::default()` and then fills the
/// value via [`ConfigType::load_from_data`].
pub trait ConfigType: Default {
    /// Namespace associated with this config type.
    fn namespace() -> Namespace;

    /// Domain string handed to `config_encrypt` / `config_decrypt` for this type.
    fn encryption_domain() -> &'static str;

    /// Lag value passed to config-message parsing and merging. Defaults to `5`.
    fn config_lags() -> i64 { 5 }

    /// Defaults to `false`. NOTE(review): presumably signals whether protobuf-wrapped
    /// payloads are accepted — confirm at the call sites that read it.
    fn accepts_protobuf() -> bool { false }

    /// When `true`, `ConfigBase` never reports a need to push and `push` returns empty data.
    /// Defaults to `false`.
    fn is_readonly() -> bool { false }

    /// Populates `self` from raw config data (after construction, dump load and merge).
    fn load_from_data(&mut self, data: &ConfigData);

    /// Writes `self` into raw config data (before push and dump).
    fn store_to_data(&self, data: &mut ConfigData);
}
/// Result of `ConfigBase::push`: what to upload and which old hashes are now obsolete.
#[derive(Clone, Debug)]
pub struct PushData {
    /// Sequence number of the pushed message; pass it back to `confirm_pushed`.
    pub seqno: i64,
    /// Padded, serialized config payloads to upload (see `ConfigBase::push`).
    pub messages: Vec<Vec<u8>>,
    /// Message hashes the config considered obsolete at push time.
    pub obsolete_hashes: Vec<String>,
}
/// Generic container tracking a typed config, its underlying config message, the keys used to
/// encrypt/decrypt it, and the push/merge bookkeeping (state and message hashes).
pub struct ConfigBase<T>
where
    T: ConfigType,
{
    /// Typed view of the config; synced into `message` before push and dump.
    pub config_type: T,
    /// Underlying mutable config message holding the raw data and seqno.
    message: MutableConfigMessage,
    /// Symmetric keys: all are tried in order when decrypting, the first encrypts pushes.
    keys: Vec<[u8; 32]>,
    /// Current synchronisation state.
    state: ConfigState,
    /// Set whenever persisted state changes; cleared by `dump`.
    needs_dump: bool,
    /// Hashes of the stored message(s) that reflect the current state.
    curr_hashes: HashSet<String>,
    /// Hashes considered obsolete; handed out and cleared by `push`.
    old_hashes: HashSet<String>,
    /// Optional signature verifier passed to config-message parsing.
    verifier: Option<VerifyCallable>,
    /// Optional signer passed to config-message parsing.
    signer: Option<SignCallable>,
}
impl<T: ConfigType> ConfigBase<T> {
pub fn new(secret_key: &[u8], dump: Option<&[u8]>) -> Result<Self, String> {
let enc_key = derive_enc_key(secret_key)?;
let mut base = ConfigBase {
config_type: T::default(),
message: MutableConfigMessage::new_empty(),
keys: vec![enc_key],
state: ConfigState::Clean,
needs_dump: false,
curr_hashes: HashSet::new(),
old_hashes: HashSet::new(),
verifier: None,
signer: None,
};
if let Some(dump_data) = dump {
base.load_dump(dump_data)?;
}
let data = base.message.data().clone();
base.config_type.load_from_data(&data);
Ok(base)
}
pub fn new_group(
ed25519_pubkey: &[u8; 32],
ed25519_secretkey: Option<&[u8]>,
dump: Option<&[u8]>,
) -> Result<Self, String> {
let mut base = ConfigBase {
config_type: T::default(),
message: MutableConfigMessage::new_empty(),
keys: Vec::new(),
state: ConfigState::Clean,
needs_dump: false,
curr_hashes: HashSet::new(),
old_hashes: HashSet::new(),
verifier: None,
signer: None,
};
let pubkey = *ed25519_pubkey;
base.verifier = Some(std::sync::Arc::new(move |data: &[u8], sig: &[u8]| {
if sig.len() != 64 {
return false;
}
crate::crypto::ed25519::verify(sig, &pubkey, data).unwrap_or(false)
}));
if let Some(sk) = ed25519_secretkey {
let sk_owned: Vec<u8> = sk.to_vec();
base.signer = Some(std::sync::Arc::new(move |data: &[u8]| {
crate::crypto::ed25519::sign(&sk_owned, data)
.expect("ed25519 signing failed")
.to_vec()
}));
}
if let Some(dump_data) = dump {
base.load_dump(dump_data)?;
}
let data = base.message.data().clone();
base.config_type.load_from_data(&data);
Ok(base)
}
pub fn add_key(&mut self, key: [u8; 32], high_priority: bool) {
if high_priority {
self.keys.insert(0, key);
} else {
self.keys.push(key);
}
}
pub fn clear_keys(&mut self) {
self.keys.clear();
}
pub fn key_count(&self) -> usize {
self.keys.len()
}
pub fn needs_push(&self) -> bool {
if T::is_readonly() {
return false;
}
self.state == ConfigState::Dirty
}
pub fn needs_dump(&self) -> bool {
self.needs_dump
}
pub fn state(&self) -> ConfigState {
self.state
}
pub fn is_dirty(&self) -> bool {
self.state == ConfigState::Dirty
}
pub fn seqno(&self) -> i64 {
self.message.seqno()
}
pub fn current_hashes(&self) -> Vec<String> {
self.curr_hashes.iter().cloned().collect()
}
pub fn old_hashes(&self) -> Vec<String> {
self.old_hashes.iter().cloned().collect()
}
pub fn data(&self) -> &ConfigData {
self.message.data()
}
pub fn dirty_data(&mut self) -> &mut ConfigData {
self.mark_dirty();
self.message.data_mut()
}
fn mark_dirty(&mut self) {
if self.state != ConfigState::Dirty {
self.set_state(ConfigState::Dirty);
}
}
fn set_state(&mut self, s: ConfigState) {
if s == ConfigState::Dirty && self.state == ConfigState::Clean && !self.curr_hashes.is_empty()
{
for h in self.curr_hashes.drain() {
self.old_hashes.insert(h);
}
}
self.state = s;
self.needs_dump = true;
}
fn sync_to_message(&mut self) {
let mut data = self.message.data().clone();
self.config_type.store_to_data(&mut data);
*self.message.data_mut() = data;
}
pub fn push(&mut self) -> PushData {
if T::is_readonly() {
return PushData {
seqno: 0,
messages: vec![],
obsolete_hashes: vec![],
};
}
self.sync_to_message();
let seqno = self.message.seqno();
let serialized = self.message.serialize();
let mut padded = serialized;
pad_message(&mut padded, ENCRYPT_DATA_OVERHEAD);
let encrypted = if let Some(key) = self.keys.first() {
config_encrypt(&padded, key, T::encryption_domain())
} else {
padded
};
let obsolete: Vec<String> = self.old_hashes.drain().collect();
if self.state == ConfigState::Dirty {
self.set_state(ConfigState::Waiting);
}
PushData {
seqno,
messages: vec![encrypted],
obsolete_hashes: obsolete,
}
}
pub fn confirm_pushed(&mut self, seqno: i64, msg_hash: &str) {
if seqno == self.message.seqno() && self.state == ConfigState::Waiting {
self.curr_hashes.clear();
self.curr_hashes.insert(msg_hash.to_string());
self.state = ConfigState::Clean;
self.needs_dump = true;
}
}
pub fn merge(&mut self, messages: &[(&str, &[u8])]) -> Result<Vec<String>, String> {
if self.keys.is_empty() {
return Err("Cannot merge configs without any decryption keys".into());
}
let mut good_hashes = Vec::new();
let mut decrypted_messages: Vec<Vec<u8>> = Vec::new();
let mut decrypted_hashes: Vec<String> = Vec::new();
for (hash, encrypted) in messages {
let mut decrypted = None;
for key in &self.keys {
match config_decrypt(encrypted, key, T::encryption_domain()) {
Ok(plain) => {
decrypted = Some(plain);
break;
}
Err(_) => continue,
}
}
if let Some(mut plain) = decrypted {
if let Some(pos) = plain.iter().position(|&b| b != 0)
&& pos > 0 {
plain = plain[pos..].to_vec();
}
if plain.is_empty() {
continue;
}
if plain[0] == b'd' {
good_hashes.push(hash.to_string());
decrypted_hashes.push(hash.to_string());
decrypted_messages.push(plain);
}
}
}
if decrypted_messages.is_empty() {
return Ok(good_hashes);
}
let mine = self.message.serialize();
let old_seqno = self.message.seqno();
let mut all_configs: Vec<&[u8]> = Vec::new();
all_configs.push(&mine);
for msg in &decrypted_messages {
all_configs.push(msg);
}
let verifier_for_merge = self.verifier.clone();
let signer_for_merge = self.signer.clone();
match MutableConfigMessage::from_multiple(
&all_configs,
verifier_for_merge,
signer_for_merge,
T::config_lags(),
Some(&|_i, _e| {
}),
) {
Ok(new_msg) => {
let merged = new_msg.merged();
let new_seqno = new_msg.seqno();
for hash in &decrypted_hashes {
self.old_hashes.insert(hash.clone());
}
self.message = new_msg;
if new_seqno != old_seqno {
if merged {
self.set_state(ConfigState::Dirty);
} else {
self.state = ConfigState::Clean;
self.needs_dump = true;
}
} else {
self.needs_dump = true;
}
let data = self.message.data().clone();
self.config_type.load_from_data(&data);
}
Err(e) => {
return Err(format!("Merge failed: {}", e));
}
}
Ok(good_hashes)
}
pub fn dump(&mut self) -> Vec<u8> {
self.sync_to_message();
self.needs_dump = false;
let serialized = self.message.serialize();
let mut dump_dict = std::collections::BTreeMap::new();
dump_dict.insert(b"!".to_vec(), BtValue::String(serialized));
if !self.curr_hashes.is_empty() {
let hash_list: Vec<BtValue> = self
.curr_hashes
.iter()
.map(|h| BtValue::String(h.as_bytes().to_vec()))
.collect();
dump_dict.insert(b"+".to_vec(), BtValue::List(hash_list));
}
if !self.old_hashes.is_empty() {
let hash_list: Vec<BtValue> = self
.old_hashes
.iter()
.map(|h| BtValue::String(h.as_bytes().to_vec()))
.collect();
dump_dict.insert(b"-".to_vec(), BtValue::List(hash_list));
}
bencode::encode(&BtValue::Dict(dump_dict))
}
fn load_dump(&mut self, dump_data: &[u8]) -> Result<(), String> {
let top = bencode::decode(dump_data).map_err(|e| format!("Invalid dump: {}", e))?;
let dict = match &top {
BtValue::Dict(d) => d,
_ => return Err("Dump must be a bencode dict".into()),
};
if let Some(BtValue::String(body)) = dict.get(b"!".as_ref()) {
let verifier = self.verifier.clone();
let signer = self.signer.clone();
match MutableConfigMessage::from_bytes(body, verifier, signer, T::config_lags()) {
Ok(msg) => {
self.message = msg;
}
Err(e) => {
return Err(format!("Failed to parse dump body: {}", e));
}
}
}
if let Some(BtValue::List(hashes)) = dict.get(b"+".as_ref()) {
for h in hashes {
if let BtValue::String(s) = h
&& let Ok(hash_str) = String::from_utf8(s.clone()) {
self.curr_hashes.insert(hash_str);
}
}
}
if let Some(BtValue::List(hashes)) = dict.get(b"-".as_ref()) {
for h in hashes {
if let BtValue::String(s) = h
&& let Ok(hash_str) = String::from_utf8(s.clone()) {
self.old_hashes.insert(hash_str);
}
}
}
Ok(())
}
pub fn get(&self) -> &T {
&self.config_type
}
pub fn get_mut(&mut self) -> &mut T {
self.mark_dirty();
&mut self.config_type
}
}
pub mod field_helpers {
    //! Typed get/set helpers over [`ConfigData`], keyed by byte strings.
    //!
    //! Setters follow a "store only meaningful values" convention: empty strings, zero or
    //! non-positive integers and `false` flags remove the key instead of storing a value.
    use super::*;

    /// Returns the value at `key` as a `String`; `None` if missing, not a string, empty, or
    /// not valid UTF-8.
    pub fn get_string(data: &ConfigData, key: &[u8]) -> Option<String> {
        match data.get(key) {
            Some(ConfigValue::String(s)) if !s.is_empty() => {
                std::str::from_utf8(s).ok().map(str::to_owned)
            }
            _ => None,
        }
    }

    /// Returns the raw bytes at `key`; `None` if missing, not a string, or empty.
    pub fn get_bytes(data: &ConfigData, key: &[u8]) -> Option<Vec<u8>> {
        match data.get(key) {
            Some(ConfigValue::String(s)) if !s.is_empty() => Some(s.clone()),
            _ => None,
        }
    }

    /// Returns the integer at `key`; `None` if missing or not an integer.
    pub fn get_int(data: &ConfigData, key: &[u8]) -> Option<i64> {
        match data.get(key) {
            Some(ConfigValue::Integer(n)) => Some(*n),
            _ => None,
        }
    }

    /// Like [`get_int`] but yields `0` when the key is missing or not an integer.
    pub fn get_int_or_zero(data: &ConfigData, key: &[u8]) -> i64 {
        get_int(data, key).unwrap_or(0)
    }

    /// Returns the nested dict at `key`; `None` if missing or not a dict.
    pub fn get_dict<'a>(data: &'a ConfigData, key: &[u8]) -> Option<&'a ConfigData> {
        match data.get(key) {
            Some(ConfigValue::Dict(d)) => Some(d),
            _ => None,
        }
    }

    /// Returns the nested dict at `key`, creating an empty one if absent.
    ///
    /// If `key` currently holds a non-dict value (e.g. from malformed merged data), that value
    /// is replaced by an empty dict rather than panicking.
    pub fn get_or_create_dict<'a>(data: &'a mut ConfigData, key: &[u8]) -> &'a mut ConfigData {
        let slot = data
            .entry(key.to_vec())
            .or_insert_with(|| ConfigValue::Dict(ConfigData::new()));
        if !matches!(*slot, ConfigValue::Dict(_)) {
            *slot = ConfigValue::Dict(ConfigData::new());
        }
        match slot {
            ConfigValue::Dict(d) => d,
            _ => unreachable!("slot was just ensured to hold a Dict"),
        }
    }

    /// Stores `val` as a string, or removes `key` when `val` is empty.
    pub fn set_nonempty_str(data: &mut ConfigData, key: &[u8], val: &str) {
        set_nonempty_bytes(data, key, val.as_bytes());
    }

    /// Stores `val` as a string unconditionally (even when empty).
    pub fn set_str_always(data: &mut ConfigData, key: &[u8], val: &str) {
        data.insert(key.to_vec(), ConfigValue::String(val.as_bytes().to_vec()));
    }

    /// Stores `val` as a byte string, or removes `key` when `val` is empty.
    pub fn set_nonempty_bytes(data: &mut ConfigData, key: &[u8], val: &[u8]) {
        if val.is_empty() {
            data.remove(key);
        } else {
            data.insert(key.to_vec(), ConfigValue::String(val.to_vec()));
        }
    }

    /// Stores `val`, or removes `key` when `val` is zero.
    pub fn set_nonzero_int(data: &mut ConfigData, key: &[u8], val: i64) {
        if val == 0 {
            data.remove(key);
        } else {
            data.insert(key.to_vec(), ConfigValue::Integer(val));
        }
    }

    /// Stores `val`, or removes `key` when `val` is zero or negative.
    pub fn set_positive_int(data: &mut ConfigData, key: &[u8], val: i64) {
        if val <= 0 {
            data.remove(key);
        } else {
            data.insert(key.to_vec(), ConfigValue::Integer(val));
        }
    }

    /// Stores `1` for `true`; removes `key` for `false`.
    pub fn set_flag(data: &mut ConfigData, key: &[u8], val: bool) {
        if val {
            data.insert(key.to_vec(), ConfigValue::Integer(1));
        } else {
            data.remove(key);
        }
    }

    /// Stores both string values when `condition` holds; otherwise removes both keys.
    pub fn set_pair_if(
        data: &mut ConfigData,
        condition: bool,
        key1: &[u8],
        val1: &[u8],
        key2: &[u8],
        val2: &[u8],
    ) {
        if condition {
            data.insert(key1.to_vec(), ConfigValue::String(val1.to_vec()));
            data.insert(key2.to_vec(), ConfigValue::String(val2.to_vec()));
        } else {
            data.remove(key1);
            data.remove(key2);
        }
    }
}
/// Derives the 32-byte config encryption key from an Ed25519 key.
///
/// Accepts either a 32-byte seed or a 64-byte key, in which case only the first 32 bytes are
/// used. The seed is hashed with keyed BLAKE2b (32-byte output, key `"SessionConfig"`).
///
/// # Errors
/// Returns an error when `secret_key` is neither 32 nor 64 bytes long.
fn derive_enc_key(secret_key: &[u8]) -> Result<[u8; 32], String> {
    let seed: &[u8] = match secret_key.len() {
        64 => &secret_key[..32],
        32 => secret_key,
        other => {
            return Err(format!(
                "Invalid secret key length: expected 32 or 64, got {}",
                other
            ));
        }
    };
    let digest = blake2b_simd::Params::new()
        .hash_length(32)
        .key(b"SessionConfig")
        .hash(seed);
    let mut out = [0u8; 32];
    out.copy_from_slice(digest.as_bytes());
    Ok(out)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal config type with a single optional name stored under key `n`.
    #[derive(Default)]
    struct TestConfig {
        name: Option<String>,
    }

    impl ConfigType for TestConfig {
        fn namespace() -> Namespace {
            Namespace::UserProfile
        }
        fn encryption_domain() -> &'static str {
            "UserProfile"
        }
        fn load_from_data(&mut self, data: &ConfigData) {
            self.name = field_helpers::get_string(data, b"n");
        }
        fn store_to_data(&self, data: &mut ConfigData) {
            if let Some(ref name) = self.name {
                field_helpers::set_nonempty_str(data, b"n", name);
            }
        }
    }

    #[test]
    fn test_new_config_is_clean() {
        let seed = [0u8; 32];
        let base: ConfigBase<TestConfig> = ConfigBase::new(&seed, None).unwrap();
        assert_eq!(base.state(), ConfigState::Clean);
        assert!(!base.needs_push());
        assert!(!base.needs_dump());
    }

    #[test]
    fn test_dirty_after_mutation() {
        let seed = [0u8; 32];
        let mut base: ConfigBase<TestConfig> = ConfigBase::new(&seed, None).unwrap();
        base.get_mut().name = Some("Test".to_string());
        assert_eq!(base.state(), ConfigState::Dirty);
        assert!(base.needs_push());
        assert!(base.needs_dump());
    }

    #[test]
    fn test_push_transitions_to_waiting() {
        let seed = [0u8; 32];
        let mut base: ConfigBase<TestConfig> = ConfigBase::new(&seed, None).unwrap();
        base.get_mut().name = Some("Test".to_string());
        let push_data = base.push();
        assert_eq!(base.state(), ConfigState::Waiting);
        // The pushed seqno is the one `confirm_pushed` will compare against.
        assert_eq!(push_data.seqno, base.seqno());
        assert!(!push_data.messages.is_empty());
        // A Waiting config has nothing further to push.
        assert!(!base.needs_push());
    }

    #[test]
    fn test_confirm_transitions_to_clean() {
        let seed = [0u8; 32];
        let mut base: ConfigBase<TestConfig> = ConfigBase::new(&seed, None).unwrap();
        base.get_mut().name = Some("Test".to_string());
        let push_data = base.push();
        base.confirm_pushed(push_data.seqno, "abc123");
        assert_eq!(base.state(), ConfigState::Clean);
        assert_eq!(base.current_hashes(), vec!["abc123"]);
        assert!(!base.needs_push());
    }

    #[test]
    fn test_dump_and_load_roundtrip() {
        let seed = [0u8; 32];
        let mut base: ConfigBase<TestConfig> = ConfigBase::new(&seed, None).unwrap();
        base.get_mut().name = Some("Test User".to_string());
        let push_data = base.push();
        base.confirm_pushed(push_data.seqno, "hash1");
        let dump = base.dump();
        assert!(!dump.is_empty());
        assert!(!base.needs_dump());
        let base2: ConfigBase<TestConfig> = ConfigBase::new(&seed, Some(&dump)).unwrap();
        assert_eq!(base2.get().name.as_deref(), Some("Test User"));
        // Current hashes survive the dump/load round trip.
        assert_eq!(base2.current_hashes(), vec!["hash1"]);
    }

    #[test]
    fn test_derive_enc_key_deterministic() {
        let seed1 = [1u8; 32];
        let key1 = derive_enc_key(&seed1).unwrap();
        let key2 = derive_enc_key(&seed1).unwrap();
        assert_eq!(key1, key2);
        let seed2 = [2u8; 32];
        let key3 = derive_enc_key(&seed2).unwrap();
        assert_ne!(key1, key3);
    }

    #[test]
    fn test_derive_enc_key_from_64_byte_key() {
        let seed = [1u8; 32];
        let mut full_key = [0u8; 64];
        full_key[..32].copy_from_slice(&seed);
        full_key[32..].copy_from_slice(&[2u8; 32]);
        let key_from_seed = derive_enc_key(&seed).unwrap();
        let key_from_full = derive_enc_key(&full_key).unwrap();
        assert_eq!(key_from_seed, key_from_full);
    }

    #[test]
    fn test_field_helpers() {
        let mut data = ConfigData::new();
        field_helpers::set_nonempty_str(&mut data, b"n", "Alice");
        assert_eq!(field_helpers::get_string(&data, b"n"), Some("Alice".into()));
        field_helpers::set_nonempty_str(&mut data, b"n", "");
        assert_eq!(field_helpers::get_string(&data, b"n"), None);
        field_helpers::set_nonzero_int(&mut data, b"x", 42);
        assert_eq!(field_helpers::get_int(&data, b"x"), Some(42));
        field_helpers::set_nonzero_int(&mut data, b"x", 0);
        assert_eq!(field_helpers::get_int(&data, b"x"), None);
        field_helpers::set_flag(&mut data, b"f", true);
        assert_eq!(field_helpers::get_int(&data, b"f"), Some(1));
        field_helpers::set_flag(&mut data, b"f", false);
        assert_eq!(field_helpers::get_int(&data, b"f"), None);
    }
}