use alloc::{format, string::ToString, vec::Vec};
use serde::{Deserialize, Serialize};
use crate::core::error::{OxiRouterError, Result};
use crate::core::query_log::QueryLog;
use crate::core::router::RouterConfig;
use crate::core::source::DataSource;
/// Four-byte prefix that every encoded `RouterState` snapshot starts with (ASCII "OXIR").
pub(crate) const STATE_MAGIC: [u8; 4] = [b'O', b'X', b'I', b'R'];
/// Legacy snapshot format version; `RouterState::from_bytes` still accepts it.
pub(crate) const STATE_VERSION_V1: u32 = 1;
/// Snapshot format version written into the header by `RouterState::to_bytes`.
pub(crate) const STATE_VERSION: u32 = 2;
/// Serializable snapshot of router state.
///
/// Encoded by `RouterState::to_bytes` and decoded by `RouterState::from_bytes`.
/// The body is JSON, and serde emits fields in declaration order, so reordering
/// fields changes the encoded bytes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouterState {
/// Version number stored inside the JSON body.
// NOTE(review): independent of the header version; `from_bytes` does not compare them — confirm intended.
pub(crate) version: u32,
/// Data sources captured in the snapshot.
pub(crate) sources: Vec<DataSource>,
/// Optional opaque byte payload for a model (format not visible in this module).
pub(crate) model_bytes: Option<Vec<u8>>,
/// Optional opaque byte payload; presumably serialized RL state — TODO confirm.
pub(crate) rl_bytes: Option<Vec<u8>>,
/// Query log carried with the snapshot.
pub(crate) query_log: QueryLog,
/// Router configuration, if one was saved.
// `#[serde(default)]` lets a body with no `config` key decode as `None`
// (presumably older/v1 snapshots — confirm).
#[serde(default)]
pub(crate) config: Option<RouterConfig>,
}
impl RouterState {
#[must_use]
pub fn config(&self) -> Option<&RouterConfig> {
self.config.as_ref()
}
pub fn to_bytes(&self) -> Result<Vec<u8>> {
let body =
serde_json::to_vec(self).map_err(|e| OxiRouterError::ModelError(e.to_string()))?;
let mut out = Vec::with_capacity(8 + body.len());
out.extend_from_slice(&STATE_MAGIC);
out.extend_from_slice(&STATE_VERSION.to_le_bytes());
out.extend_from_slice(&body);
Ok(out)
}
pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
if bytes.len() < 8 {
return Err(OxiRouterError::IncompatibleModel {
reason: "state snapshot too short".into(),
});
}
if bytes[0..4] != STATE_MAGIC {
return Err(OxiRouterError::IncompatibleModel {
reason: "invalid magic bytes (expected OXIR)".into(),
});
}
let ver = u32::from_le_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]);
if ver != STATE_VERSION_V1 && ver != STATE_VERSION {
return Err(OxiRouterError::IncompatibleModel {
reason: format!(
"state version mismatch: expected {STATE_VERSION_V1} or {STATE_VERSION}, found {ver}"
),
});
}
serde_json::from_slice(&bytes[8..]).map_err(|e| OxiRouterError::ModelError(e.to_string()))
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for the snapshot header checks and the JSON round-trip.
    use super::*;
    #[cfg(all(not(feature = "std"), feature = "alloc"))]
    use alloc::vec;

    /// Builds an empty state at the current format version.
    fn make_state() -> RouterState {
        RouterState {
            version: STATE_VERSION,
            sources: Vec::new(),
            model_bytes: None,
            rl_bytes: None,
            query_log: QueryLog::new(),
            config: None,
        }
    }

    #[test]
    fn round_trip_empty() {
        let state = make_state();
        let bytes = state.to_bytes().expect("encode failed");
        assert_eq!(&bytes[0..4], b"OXIR");
        let ver = u32::from_le_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]);
        assert_eq!(ver, STATE_VERSION);
        let restored = RouterState::from_bytes(&bytes).expect("decode failed");
        assert_eq!(restored.version, STATE_VERSION);
        assert!(restored.sources.is_empty());
        assert!(restored.model_bytes.is_none());
        assert!(restored.rl_bytes.is_none());
    }

    #[test]
    fn rejects_short_bytes() {
        let err = RouterState::from_bytes(&[0u8; 4]).unwrap_err();
        assert!(
            matches!(err, OxiRouterError::IncompatibleModel { .. }),
            "expected IncompatibleModel, got {err:?}"
        );
    }

    #[test]
    fn rejects_bad_magic() {
        let mut bytes = [0u8; 16];
        bytes[0..4].copy_from_slice(b"BAAD");
        let err = RouterState::from_bytes(&bytes).unwrap_err();
        assert!(
            matches!(err, OxiRouterError::IncompatibleModel { .. }),
            "expected IncompatibleModel, got {err:?}"
        );
    }

    #[test]
    fn rejects_wrong_version() {
        let state = make_state();
        let mut bytes = state.to_bytes().expect("encode failed");
        bytes[4..8].copy_from_slice(&99u32.to_le_bytes());
        let err = RouterState::from_bytes(&bytes).unwrap_err();
        assert!(
            matches!(err, OxiRouterError::IncompatibleModel { .. }),
            "expected IncompatibleModel, got {err:?}"
        );
    }

    /// A v1 header whose body has no `config` key must still decode, with
    /// `config` falling back to `None` via `#[serde(default)]`.
    #[test]
    fn accepts_v1_snapshot_without_config() {
        let state = make_state();
        let mut body = serde_json::to_value(&state).expect("to_value failed");
        body.as_object_mut()
            .expect("state serializes as a JSON object")
            .remove("config");
        let body = serde_json::to_vec(&body).expect("encode body failed");
        let mut bytes = Vec::with_capacity(8 + body.len());
        bytes.extend_from_slice(&STATE_MAGIC);
        bytes.extend_from_slice(&STATE_VERSION_V1.to_le_bytes());
        bytes.extend_from_slice(&body);
        let restored = RouterState::from_bytes(&bytes).expect("v1 snapshot should decode");
        assert!(restored.config().is_none());
        assert!(restored.sources.is_empty());
    }

    /// A well-formed header followed by invalid JSON surfaces as `ModelError`,
    /// not as a header error.
    #[test]
    fn rejects_corrupt_body() {
        let mut bytes = [0u8; 10];
        bytes[0..4].copy_from_slice(&STATE_MAGIC);
        bytes[4..8].copy_from_slice(&STATE_VERSION.to_le_bytes());
        bytes[8..].copy_from_slice(b"!!");
        let err = RouterState::from_bytes(&bytes).unwrap_err();
        assert!(
            matches!(err, OxiRouterError::ModelError(_)),
            "expected ModelError, got {err:?}"
        );
    }

    #[test]
    fn round_trip_with_data() {
        use crate::core::source::DataSource;
        let mut state = make_state();
        state
            .sources
            .push(DataSource::new("test-src", "https://example.org/sparql"));
        state.model_bytes = Some(vec![1, 2, 3, 4]);
        state.rl_bytes = Some(vec![5, 6, 7, 8]);
        let bytes = state.to_bytes().expect("encode failed");
        let restored = RouterState::from_bytes(&bytes).expect("decode failed");
        assert_eq!(restored.sources.len(), 1);
        assert_eq!(restored.sources[0].id, "test-src");
        assert_eq!(restored.model_bytes, Some(vec![1, 2, 3, 4]));
        assert_eq!(restored.rl_bytes, Some(vec![5, 6, 7, 8]));
    }
}