use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::errors::*;
use crate::appstate::{AppStateManager, AppStateType, AppStateSyncKey, AppStateSyncKeyRequest, AppStateSyncPatch};
use crate::session::MultiDeviceSession;
use crate::node_protocol::{AppMessage, AppEvent};
use crate::node_wire::{Node, NodeContent};
/// Kind of app-state sync carried by an [`AppStateSyncRequest`].
///
/// `AppStateSyncRequest::to_node` renders each variant as the `type` attribute
/// token of the `appstate` node: `"full"`, `"patch"` or `"initial"`.
// Variant order is intentionally unchanged: non-self-describing serde formats
// encode the variant index, so reordering would change the serialized form.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AppStateSyncKind {
    /// Rendered as the `"full"` token.
    Full,
    /// Rendered as the `"patch"` token.
    Patch,
    /// Rendered as the `"initial"` token.
    Initial,
}
/// A request to sync one or more app-state collections.
///
/// Built with `AppStateSyncRequest::new` and encoded via `to_node`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AppStateSyncRequest {
    /// Kind of sync requested; encoded as the `type` attribute of the node.
    pub r#type: AppStateSyncKind,
    /// Collections to sync; each becomes a `collection` child node.
    pub app_state_types: Vec<AppStateType>,
    /// Set to 0 by `new`; not written by `to_node`.
    pub version: u64,
}
impl AppStateSyncRequest {
    /// Creates a request of the given `kind` covering `types`, with `version` set to 0.
    pub fn new(kind: AppStateSyncKind, types: Vec<AppStateType>) -> Self {
        Self {
            r#type: kind,
            app_state_types: types,
            version: 0,
        }
    }

    /// Encodes this request as an `appstate` node.
    ///
    /// The node gets a `type` token attribute derived from the request kind and
    /// one `collection` child per requested app-state type, each carrying a
    /// `name` attribute. The `version` field is not encoded.
    pub fn to_node(&self) -> Node {
        let kind_token = match self.r#type {
            AppStateSyncKind::Full => "full",
            AppStateSyncKind::Patch => "patch",
            AppStateSyncKind::Initial => "initial",
        };
        let mut attributes = HashMap::new();
        attributes.insert("type".into(), NodeContent::Token(kind_token));

        // Owned names are kept alive for the whole function while the child
        // nodes are built from them.
        let names: Vec<String> = self
            .app_state_types
            .iter()
            .map(|app_type| app_type.into_string().to_string())
            .collect();

        let mut collections = Vec::with_capacity(names.len());
        for name in &names {
            let mut collection_attrs = HashMap::new();
            collection_attrs.insert("name".into(), NodeContent::String(name.as_str().into()));
            collections.push(Node::new("collection", collection_attrs, NodeContent::None));
        }

        Node::new("appstate", attributes, NodeContent::List(collections))
    }
}
/// Parsed reply to an app-state sync request (see `AppStateSyncResponse::from_node`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AppStateSyncResponse {
    /// One entry per `patch` child that had a `version` attribute, in node order.
    pub patches: Vec<AppStateSyncPatch>,
    /// Snapshot payload; `from_node` always leaves this as `None`.
    pub snapshot: Option<Vec<u8>>,
    /// `true` when the response node contained a `sync` child.
    pub has_more: bool,
}
impl AppStateSyncResponse {
    /// Parses a sync response from `node`.
    ///
    /// Only list content is inspected; any other content yields an empty
    /// response. `patch` children with a `version` attribute become
    /// `AppStateSyncPatch` entries whose MAC and mutation fields are left
    /// empty. A `sync` child sets `has_more`. Other children are ignored, and
    /// `snapshot` is always `None`.
    ///
    /// # Errors
    ///
    /// Returns an error if a `patch` child's `version` attribute is present but
    /// does not parse as a `u64`.
    pub fn from_node(node: &Node) -> Result<Self> {
        let mut parsed_patches = Vec::new();
        let mut more_pending = false;

        if let NodeContent::List(children) = &node.content {
            for child in children {
                if child.desc() == "sync" {
                    more_pending = true;
                    continue;
                }
                if child.desc() != "patch" {
                    continue;
                }
                // A patch without a readable `version` attribute is skipped silently.
                if let Ok(raw_version) = child.get_attribute("version") {
                    let version = raw_version
                        .as_str()
                        .parse::<u64>()
                        .map_err(|_| "Invalid version in patch")?;
                    parsed_patches.push(AppStateSyncPatch {
                        version,
                        snapshot_mac: Vec::new(),
                        patch_mac: Vec::new(),
                        patches: Vec::new(),
                    });
                }
            }
        }

        Ok(Self {
            patches: parsed_patches,
            snapshot: None,
            has_more: more_pending,
        })
    }
}
/// Tracks outstanding app-state sync requests and applies their responses.
pub struct AppStateSyncManager {
    /// Holds the sync keys, per-type versions and patch-MAC verification.
    pub app_state_manager: AppStateManager,
    /// Requests issued by `request_sync`, keyed by their generated UUID string.
    pub pending_requests: HashMap<String, AppStateSyncRequest>,
    /// Session supplied at construction; not read by the methods in this impl.
    pub session: MultiDeviceSession,
}
impl AppStateSyncManager {
    /// Creates a manager for `session` with a fresh `AppStateManager` and no
    /// pending requests.
    pub fn new(session: MultiDeviceSession) -> Self {
        AppStateSyncManager {
            app_state_manager: AppStateManager::new(),
            pending_requests: HashMap::new(),
            session,
        }
    }

    /// Records an initial sync request for `types` and returns the query message to send.
    ///
    /// A fresh UUID v4 string is used both as the `pending_requests` key and as
    /// the `id` of the returned query.
    ///
    /// NOTE(review): the returned message is a `MessagesBefore` query against
    /// `status@broadcast` with `count: 1`. It carries neither the requested
    /// app-state types nor a `critical_block` query node. Confirm this is the
    /// intended wire request.
    pub fn request_sync(&mut self, types: Vec<AppStateType>) -> Result<AppMessage> {
        let request = AppStateSyncRequest::new(AppStateSyncKind::Initial, types);
        let request_id = uuid::Uuid::new_v4().to_string();
        // Clone for the map key; the original id is moved into the returned query.
        self.pending_requests.insert(request_id.clone(), request);
        Ok(AppMessage::Query(crate::node_protocol::Query::MessagesBefore {
            jid: crate::Jid {
                id: "status@broadcast".to_string(),
                is_group: false,
            },
            id: request_id,
            count: 1,
        }))
    }

    /// Verifies and applies each patch in `response`, in order.
    ///
    /// For every patch whose MAC verifies, `update_version` is called on the
    /// underlying manager.
    ///
    /// # Errors
    ///
    /// Returns `"Invalid patch MAC"` at the first patch that fails
    /// verification. Patches before it have already been applied at that point.
    pub fn handle_sync_response(&mut self, response: &AppStateSyncResponse) -> Result<()> {
        for patch in &response.patches {
            if !self.app_state_manager.verify_patch_mac(patch) {
                return Err("Invalid patch MAC".into());
            }
            self.app_state_manager.update_version();
        }
        Ok(())
    }

    /// Looks up a registered sync key by its id.
    pub fn get_app_state_key(&self, key_id: &[u8]) -> Option<&AppStateSyncKey> {
        self.app_state_manager.keys.get(key_id)
    }

    /// Registers `key` under its own `key_id`, replacing any existing key with that id.
    pub fn register_app_state_key(&mut self, key: AppStateSyncKey) {
        self.app_state_manager.keys.insert(key.key_id.clone(), key);
    }

    /// Returns the stored version for `app_state_type`, if one has been recorded.
    pub fn get_current_version(&self, app_state_type: &str) -> Option<u64> {
        self.app_state_manager.current_state.get(app_state_type).copied()
    }
}
#[cfg(test)]
mod tests {
    // `super::*` also brings in the parent's `MultiDeviceSession`, `Node`,
    // `NodeContent` and `HashMap` imports.
    use super::*;

    #[test]
    fn test_app_state_sync_request() {
        let types = vec![AppStateType::Regular, AppStateType::CriticalBlock];
        let request = AppStateSyncRequest::new(AppStateSyncKind::Initial, types);
        let node = request.to_node();
        assert_eq!(node.desc(), "appstate");
        // One `collection` child per requested type.
        match &node.content {
            NodeContent::List(children) => {
                assert_eq!(children.len(), 2);
                assert!(children.iter().all(|c| c.desc() == "collection"));
            }
            _ => panic!("appstate node should carry a list of collections"),
        }
    }

    #[test]
    fn test_app_state_sync_manager() {
        let session = MultiDeviceSession::default();
        let mut manager = AppStateSyncManager::new(session);
        let types = vec![AppStateType::Regular];
        let result = manager.request_sync(types);
        assert!(result.is_ok());
        // The request must be tracked until its response is handled.
        assert_eq!(manager.pending_requests.len(), 1);
        let pending = manager.pending_requests.values().next().unwrap();
        assert!(matches!(pending.r#type, AppStateSyncKind::Initial));
        assert_eq!(pending.app_state_types.len(), 1);

        let response = AppStateSyncResponse {
            patches: vec![],
            snapshot: None,
            has_more: false,
        };
        assert!(manager.handle_sync_response(&response).is_ok());
    }

    #[test]
    fn test_app_state_sync_response_from_non_list_node() {
        let node = Node::new("iq", HashMap::new(), NodeContent::None);
        let response = match AppStateSyncResponse::from_node(&node) {
            Ok(response) => response,
            Err(_) => panic!("a node without list content should parse"),
        };
        assert!(response.patches.is_empty());
        assert!(response.snapshot.is_none());
        assert!(!response.has_more);
    }
}