use chacha20poly1305::aead::{Aead, AeadCore, KeyInit, OsRng};
use chacha20poly1305::{ChaCha20Poly1305, Nonce};
use serde::{Deserialize, Serialize};
use sha2::Sha256;
use std::time::{SystemTime, UNIX_EPOCH};
use crate::error::{Error, ErrorCode};
use crate::types::mrtr::InputResponses;
const NONCE_LEN: usize = 12;
/// Derives the 32-byte ChaCha20-Poly1305 key from a caller-supplied secret.
///
/// The key is `SHA-256(label || secret)`. The fixed label domain-separates this key
/// from any other use of the same secret; it must not change, or previously
/// issued tokens stop decrypting.
fn derive_key(secret: &[u8]) -> [u8; 32] {
    use sha2::Digest;
    const DOMAIN_LABEL: &[u8] = b"neva:mrtr:requestState:v1";
    let mut hasher = Sha256::new();
    // Order matters: label first, then the secret.
    for part in [DOMAIN_LABEL, secret] {
        hasher.update(part);
    }
    hasher.finalize().into()
}
/// Continuation state that is serialized to JSON and sealed by [`StateCodec`]
/// into an opaque `requestState` token. The token is encrypted and authenticated,
/// so its contents can be neither read nor altered without the codec's secret.
///
/// Fields marked `#[serde(default)]` may be missing from the encoded JSON. They
/// then decode as empty collections, which keeps tokens written before those fields
/// existed decodable.
///
/// NOTE(review): `Debug` is derived, so `{:?}` prints `memos` values in clear text
/// even though they are encrypted on the wire. Avoid logging this type if memos
/// can hold secrets.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct StatePayload {
    /// Input responses collected so far (an `InputResponses` map).
    pub answers: InputResponses,
    /// List of string keys. NOTE(review): presumably the inputs already requested
    /// from the client; confirm against callers.
    #[serde(default)]
    pub requested: Vec<String>,
    /// Arbitrary JSON values kept between rounds, keyed by name. They travel only
    /// inside the encrypted token.
    #[serde(default)]
    pub memos: std::collections::HashMap<String, serde_json::Value>,
    /// Set of effect names. NOTE(review): presumably side effects already performed
    /// for this request; confirm against callers.
    #[serde(default)]
    pub effects: std::collections::HashSet<String>,
    /// Expiry as Unix seconds (compare with `now_secs()`). `StateCodec::decode`
    /// does not check it; the caller must.
    pub exp: u64,
    /// Request binding string, as produced by `request_binding`
    /// (`"<method>:<base64 digest>"`).
    pub req: String,
    /// Optional principal string. NOTE(review): presumably the identity the token
    /// was issued to; confirm against callers.
    pub principal: Option<String>,
}
/// Seals [`StatePayload`]s into `"<base64 nonce>.<base64 ciphertext>"` strings and
/// opens them again. Both parts use URL-safe base64 without padding.
///
/// The codec borrows the caller's secret. The ChaCha20-Poly1305 key is derived from
/// it with SHA-256 under a fixed domain-separation label (see `derive_key`).
pub(crate) struct StateCodec<'a> {
    /// Raw secret from which the cipher key is derived on every operation.
    key: &'a [u8],
}
impl<'a> StateCodec<'a> {
pub(crate) fn new(key: &'a [u8]) -> Self {
Self { key }
}
fn cipher(&self) -> Result<ChaCha20Poly1305, Error> {
ChaCha20Poly1305::new_from_slice(&derive_key(self.key))
.map_err(|_| Error::new(ErrorCode::InternalError, "bad state key"))
}
pub(crate) fn encode(&self, payload: &StatePayload) -> Result<String, Error> {
use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD as B64};
let json = serde_json::to_vec(payload).map_err(Error::from)?;
let cipher = self.cipher()?;
let nonce = ChaCha20Poly1305::generate_nonce(&mut OsRng);
let sealed = cipher
.encrypt(&nonce, json.as_slice())
.map_err(|_| Error::new(ErrorCode::InternalError, "requestState encryption failed"))?;
Ok(format!("{}.{}", B64.encode(nonce), B64.encode(sealed)))
}
pub(crate) fn decode(&self, blob: &str) -> Result<StatePayload, Error> {
use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD as B64};
let (n_b64, c_b64) = blob
.split_once('.')
.ok_or_else(|| Error::new(ErrorCode::InvalidParams, "malformed requestState"))?;
let nonce = B64
.decode(n_b64)
.map_err(|_| Error::new(ErrorCode::InvalidParams, "bad requestState nonce"))?;
let sealed = B64
.decode(c_b64)
.map_err(|_| Error::new(ErrorCode::InvalidParams, "bad requestState payload"))?;
let nonce: [u8; NONCE_LEN] = nonce
.try_into()
.map_err(|_| Error::new(ErrorCode::InvalidParams, "bad requestState nonce"))?;
let json = self
.cipher()?
.decrypt(&Nonce::from(nonce), sealed.as_slice())
.map_err(|_| {
Error::new(
ErrorCode::InvalidParams,
"requestState integrity check failed",
)
})?;
serde_json::from_slice(&json).map_err(Error::from)
}
}
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, this returns `0` instead
/// of failing.
pub(crate) fn now_secs() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
/// Builds the string that ties a `requestState` token to one specific request.
///
/// The format is `"<method>:<digest>"`. `digest` is the URL-safe, unpadded base64
/// SHA-256 of `salient_params` serialized after `canonicalize`, so object key order
/// does not affect the result. If serialization fails, the digest of empty input
/// is used.
pub(crate) fn request_binding(method: &str, salient_params: &serde_json::Value) -> String {
    use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD as B64};
    use sha2::Digest;
    let canonical = canonicalize(salient_params);
    let encoded = match serde_json::to_vec(&canonical) {
        Ok(bytes) => bytes,
        Err(_) => Vec::new(),
    };
    let fingerprint = B64.encode(Sha256::digest(&encoded));
    format!("{}:{}", method, fingerprint)
}
/// Fingerprints a set of input responses as URL-safe, unpadded base64 SHA-256 of
/// their canonicalized JSON form. Map ordering therefore does not change the result.
///
/// If conversion to JSON fails, the responses are treated as `null`. If
/// serialization fails, the digest of empty input is used.
pub(crate) fn input_responses_digest(responses: &InputResponses) -> String {
    use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD as B64};
    use sha2::Digest;
    let as_json = serde_json::to_value(responses).unwrap_or(serde_json::Value::Null);
    let canonical_bytes =
        serde_json::to_vec(&canonicalize(&as_json)).unwrap_or_else(|_| Vec::new());
    let hash = Sha256::digest(&canonical_bytes);
    B64.encode(hash)
}
fn canonicalize(value: &serde_json::Value) -> serde_json::Value {
match value {
serde_json::Value::Object(map) => {
let mut keys: Vec<&String> = map.keys().collect();
keys.sort_unstable();
let mut out = serde_json::Map::with_capacity(keys.len());
for key in keys {
out.insert(key.clone(), canonicalize(&map[key]));
}
serde_json::Value::Object(out)
}
serde_json::Value::Array(items) => {
serde_json::Value::Array(items.iter().map(canonicalize).collect())
}
other => other.clone(),
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// A minimal, unexpired payload bound to a `tools/call` request for `alice`.
    fn payload() -> StatePayload {
        StatePayload {
            answers: HashMap::new(),
            requested: Vec::new(),
            memos: HashMap::new(),
            effects: std::collections::HashSet::new(),
            exp: now_secs() + 300,
            req: request_binding("tools/call", &serde_json::json!({"name":"t"})),
            principal: Some("alice".into()),
        }
    }

    #[test]
    fn memos_and_effects_roundtrip() {
        let codec = StateCodec::new(b"secret-key");
        let mut p = payload();
        p.memos
            .insert("quote".into(), serde_json::json!({"price": 42}));
        p.effects.insert("charge".into());
        let blob = codec.encode(&p).unwrap();
        let got = codec.decode(&blob).unwrap();
        assert_eq!(
            got.memos.get("quote"),
            Some(&serde_json::json!({"price": 42}))
        );
        assert!(got.effects.contains("charge"));
    }

    #[test]
    fn payload_without_memos_or_effects_decodes_with_defaults() {
        let json = serde_json::json!({
            "answers": {},
            "exp": now_secs() + 300,
            "req": request_binding("tools/call", &serde_json::json!({"name":"t"})),
            "principal": serde_json::Value::Null,
        });
        let codec = StateCodec::new(b"secret-key");
        // Seal the legacy-shaped JSON by hand so the optional fields are truly absent.
        let blob = {
            use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD as B64};
            let bytes = serde_json::to_vec(&json).unwrap();
            let cipher = codec.cipher().unwrap();
            let nonce = ChaCha20Poly1305::generate_nonce(&mut OsRng);
            let sealed = cipher.encrypt(&nonce, bytes.as_slice()).unwrap();
            format!("{}.{}", B64.encode(nonce), B64.encode(sealed))
        };
        let got = codec.decode(&blob).unwrap();
        assert!(got.memos.is_empty());
        assert!(got.effects.is_empty());
        assert!(got.requested.is_empty());
    }

    #[test]
    fn memo_values_are_not_readable_from_the_wire_blob() {
        let codec = StateCodec::new(b"secret-key");
        let mut p = payload();
        p.memos
            .insert("token".into(), serde_json::json!("super-secret-value"));
        let blob = codec.encode(&p).unwrap();
        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD as B64};
        for segment in blob.split('.') {
            let bytes = B64.decode(segment).unwrap_or_default();
            assert!(
                !bytes
                    .windows(b"super-secret-value".len())
                    .any(|w| w == b"super-secret-value"),
                "memo value leaked in plaintext within the blob"
            );
        }
        assert_eq!(
            codec.decode(&blob).unwrap().memos.get("token"),
            Some(&serde_json::json!("super-secret-value"))
        );
    }

    #[test]
    fn encode_decode_roundtrips() {
        let codec = StateCodec::new(b"secret-key");
        let p = payload();
        let blob = codec.encode(&p).unwrap();
        let got = codec.decode(&blob).unwrap();
        assert_eq!(got.exp, p.exp);
        assert_eq!(got.req, p.req);
        assert_eq!(got.principal, p.principal);
        assert!(got.answers.is_empty());
    }

    #[test]
    fn tampered_blob_is_rejected() {
        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD as B64};
        let codec = StateCodec::new(b"secret-key");
        let blob = codec.encode(&payload()).unwrap();

        // Appending a character breaks either the base64 framing or the sealed
        // length, depending on the payload size.
        let mut appended = blob.clone();
        appended.push('x');
        assert!(codec.decode(&appended).is_err());

        // Flip a single ciphertext bit while keeping the base64 framing valid, so
        // the rejection must come from the AEAD integrity check itself.
        let (nonce_b64, sealed_b64) = blob.split_once('.').unwrap();
        let mut sealed = B64.decode(sealed_b64).unwrap();
        sealed[0] ^= 0x01;
        let flipped = format!("{nonce_b64}.{}", B64.encode(&sealed));
        assert!(codec.decode(&flipped).is_err());
    }

    #[test]
    fn malformed_blobs_are_rejected() {
        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD as B64};
        let codec = StateCodec::new(b"secret-key");
        // Missing the nonce/ciphertext separator.
        assert!(codec.decode("no-separator-here").is_err());
        // Nonce that is valid base64 but not NONCE_LEN bytes long.
        let blob = codec.encode(&payload()).unwrap();
        let (_, sealed_b64) = blob.split_once('.').unwrap();
        let short_nonce = B64.encode([0u8; NONCE_LEN - 1]);
        assert!(codec.decode(&format!("{short_nonce}.{sealed_b64}")).is_err());
    }

    #[test]
    fn wrong_key_is_rejected() {
        let blob = StateCodec::new(b"key-a").encode(&payload()).unwrap();
        assert!(StateCodec::new(b"key-b").decode(&blob).is_err());
    }

    #[test]
    fn request_binding_is_stable_and_distinct() {
        let a = request_binding("tools/call", &serde_json::json!({"name":"t"}));
        let b = request_binding("tools/call", &serde_json::json!({"name":"t"}));
        let c = request_binding("tools/call", &serde_json::json!({"name":"u"}));
        assert_eq!(a, b);
        assert_ne!(a, c);
    }

    #[test]
    fn request_binding_is_independent_of_object_key_order() {
        let mut first = serde_json::Map::new();
        first.insert("name".into(), serde_json::json!("t"));
        first.insert(
            "args".into(),
            serde_json::json!([{"a": 1, "b": 2}, {"c": 3}]),
        );
        let mut second = serde_json::Map::new();
        second.insert(
            "args".into(),
            serde_json::json!([{"b": 2, "a": 1}, {"c": 3}]),
        );
        second.insert("name".into(), serde_json::json!("t"));
        assert_eq!(
            request_binding("tools/call", &serde_json::Value::Object(first)),
            request_binding("tools/call", &serde_json::Value::Object(second)),
        );
    }

    /// An accepted elicitation answer carrying `content`.
    fn answer(content: serde_json::Value) -> crate::types::elicitation::ElicitResult {
        crate::types::elicitation::ElicitResult {
            action: crate::types::elicitation::ElicitationAction::Accept,
            content: Some(content),
            meta: None,
        }
    }

    #[test]
    fn input_responses_digest_distinguishes_distinct_answers() {
        let mut a = InputResponses::new();
        a.insert("k".into(), answer(serde_json::json!({"v": 1})));
        let mut b = InputResponses::new();
        b.insert("k".into(), answer(serde_json::json!({"v": 2})));
        assert_ne!(input_responses_digest(&a), input_responses_digest(&b));
        let mut a2 = InputResponses::new();
        a2.insert("k".into(), answer(serde_json::json!({"v": 1})));
        assert_eq!(input_responses_digest(&a), input_responses_digest(&a2));
    }

    #[test]
    fn input_responses_digest_is_independent_of_key_order() {
        let mut first = InputResponses::new();
        first.insert("a".into(), answer(serde_json::json!({"x": 1, "y": 2})));
        first.insert("b".into(), answer(serde_json::json!(null)));
        let mut second = InputResponses::new();
        second.insert("b".into(), answer(serde_json::json!(null)));
        second.insert("a".into(), answer(serde_json::json!({"y": 2, "x": 1})));
        assert_eq!(
            input_responses_digest(&first),
            input_responses_digest(&second)
        );
    }
}