use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use crate::node::NodeDef;
/// A value stored in `SessionState::results`.
///
/// Serialized `untagged`: on the wire it is just the bare string, array of
/// strings, or boolean, with no variant tag. When deserializing, serde tries
/// the variants in declaration order (`Single`, then `Multiple`, then `Bool`)
/// and keeps the first that matches, so reordering variants can change how
/// existing data is decoded.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ResultValue {
    /// A single string value.
    Single(String),
    /// A list of string values.
    Multiple(Vec<String>),
    /// A boolean value.
    Bool(bool),
}
impl std::fmt::Display for ResultValue {
    /// Renders the value as plain text.
    ///
    /// `Single` is written verbatim, `Bool` as `true`/`false`, and `Multiple`
    /// as its items separated by single spaces (empty list -> empty output).
    /// Width/fill flags on the formatter are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ResultValue::Bool(flag) => f.write_str(if *flag { "true" } else { "false" }),
            ResultValue::Single(text) => f.write_str(text),
            ResultValue::Multiple(items) => {
                let joined = items.join(" ");
                f.write_str(&joined)
            }
        }
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionState {
pub session_id: String,
pub created_at: u64,
pub roots: Vec<NodeDef>,
pub stack: Vec<usize>,
#[serde(default)]
pub results: HashMap<String, ResultValue>,
#[serde(default)]
pub checksum: String,
}
impl SessionState {
    /// Creates an empty session with the given id.
    ///
    /// `created_at` is the current Unix time in seconds, or 0 if the system
    /// clock reports a time before the epoch.
    pub fn new(session_id: String) -> Self {
        let created_at = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        SessionState {
            session_id,
            created_at,
            roots: Vec::new(),
            stack: Vec::new(),
            results: HashMap::new(),
            checksum: String::new(),
        }
    }

    /// Encodes the state as MessagePack (named fields) with an embedded
    /// FNV-1a checksum.
    ///
    /// The checksum is computed over the encoding of a copy whose `checksum`
    /// field is empty. The copy is then re-encoded with the checksum filled
    /// in. FNV-1a only detects accidental corruption. It is not a
    /// cryptographic MAC and does not protect against deliberate tampering.
    ///
    /// # Errors
    /// Returns an error if MessagePack encoding fails.
    pub fn to_msgpack_bytes(&self) -> anyhow::Result<Vec<u8>> {
        let mut copy = self.clone();
        copy.checksum = String::new();
        let payload = rmp_serde::to_vec_named(&copy)?;
        copy.checksum = fnv1a_hex(&payload);
        Ok(rmp_serde::to_vec_named(&copy)?)
    }

    /// Decodes bytes produced by `to_msgpack_bytes` and verifies the embedded
    /// checksum.
    ///
    /// Verification re-encodes the decoded state with `checksum` emptied and
    /// compares its FNV-1a hash against the stored value. This relies on the
    /// encoding being deterministic: any field whose serialized order varies
    /// between instances would make valid input fail verification.
    ///
    /// # Errors
    /// Returns an error if `bytes` is not a valid encoding, or
    /// `RoptError::SessionIntegrityFailure` if the checksum does not match.
    pub fn from_msgpack_bytes(bytes: &[u8]) -> anyhow::Result<Self> {
        let mut state: SessionState = rmp_serde::from_slice(bytes)?;
        let stored_checksum = std::mem::take(&mut state.checksum);
        let payload = rmp_serde::to_vec_named(&state)?;
        if fnv1a_hex(&payload) != stored_checksum {
            anyhow::bail!(crate::error::RoptError::SessionIntegrityFailure);
        }
        state.checksum = stored_checksum;
        Ok(state)
    }

    /// Mutable access to the node addressed by `stack`.
    ///
    /// Returns `None` when the stack is empty (top level) or when any index on
    /// the path is out of range.
    pub fn current_node_mut(&mut self) -> Option<&mut NodeDef> {
        let (&root, path) = self.stack.split_first()?;
        let mut node: &mut NodeDef = self.roots.get_mut(root)?;
        for &idx in path {
            node = node.children.get_mut(idx)?;
        }
        Some(node)
    }

    /// Shared access to the node addressed by `stack`.
    ///
    /// Returns `None` when the stack is empty (top level) or when any index on
    /// the path is out of range.
    pub fn current_node(&self) -> Option<&NodeDef> {
        let (&root, path) = self.stack.split_first()?;
        let mut node: &NodeDef = self.roots.get(root)?;
        for &idx in path {
            node = node.children.get(idx)?;
        }
        Some(node)
    }

    /// Returns whether a node named `name` already exists in the current scope.
    ///
    /// The scope is `roots` when the stack is empty, otherwise the children of
    /// the current node. Returns `false` if the stack points at a missing node.
    pub fn name_exists_in_current_scope(&self, name: &str) -> bool {
        let siblings: &[NodeDef] = if self.stack.is_empty() {
            &self.roots
        } else {
            match self.current_node() {
                Some(n) => &n.children,
                None => return false,
            }
        };
        siblings.iter().any(|s| s.name.as_deref() == Some(name))
    }

    /// Appends `child` to the current scope and returns its index there.
    ///
    /// The child goes into `roots` when the stack is empty, otherwise into the
    /// current node's `children`. The stack itself is not changed.
    ///
    /// # Panics
    /// Panics if the stack is non-empty but does not address an existing node.
    pub fn add_child(&mut self, child: NodeDef) -> usize {
        if self.stack.is_empty() {
            self.roots.push(child);
            self.roots.len() - 1
        } else {
            let parent = self
                .current_node_mut()
                .expect("stack points to a valid node");
            parent.children.push(child);
            parent.children.len() - 1
        }
    }

    /// Appends `child` to the current scope and makes it the current node.
    ///
    /// # Panics
    /// Same conditions as `add_child`.
    pub fn push_node(&mut self, child: NodeDef) {
        let idx = self.add_child(child);
        self.stack.push(idx);
    }

    /// Moves the cursor up one level.
    ///
    /// # Errors
    /// Returns `RoptError::StackUnderflow` if already at the top level.
    pub fn pop_node(&mut self) -> anyhow::Result<()> {
        if self.stack.pop().is_none() {
            anyhow::bail!(crate::error::RoptError::StackUnderflow);
        }
        Ok(())
    }

    /// Nesting depth of the cursor (0 = top level).
    pub fn depth(&self) -> usize {
        self.stack.len()
    }
}
/// 64-bit FNV-1a hash of `data`, rendered as 16 lowercase hex digits
/// (zero-padded).
fn fnv1a_hex(data: &[u8]) -> String {
    // Standard 64-bit FNV offset basis and prime (same values as
    // 14695981039346656037 and 1099511628211).
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0000_0100_0000_01b3;

    let digest = data
        .iter()
        .fold(OFFSET_BASIS, |acc, &byte| (acc ^ u64::from(byte)).wrapping_mul(PRIME));

    format!("{digest:016x}")
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A state carrying one stored result survives encode -> decode with its
    /// id and result key intact.
    #[test]
    fn round_trip_preserves_state() {
        let mut original = SessionState::new("test-session".into());
        original
            .results
            .insert("key".into(), ResultValue::Single("value".into()));

        let encoded = original.to_msgpack_bytes().unwrap();
        let decoded = SessionState::from_msgpack_bytes(&encoded).unwrap();

        assert_eq!(decoded.session_id, "test-session");
        assert!(decoded.results.contains_key("key"));
    }

    /// Inverting one byte in the middle of the encoding must be rejected,
    /// either as malformed MessagePack or as a checksum mismatch.
    #[test]
    fn tampered_file_fails_checksum() {
        let mut corrupted = SessionState::new("tamper-test".into())
            .to_msgpack_bytes()
            .unwrap();
        let target = corrupted.len() / 2;
        corrupted[target] ^= 0xFF;

        let outcome = SessionState::from_msgpack_bytes(&corrupted);
        assert!(outcome.is_err());
    }
}