use super::expression::Expression;
use super::ids::{ExportId, ImportId};
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
#[derive(Debug, Clone, PartialEq)]
pub enum Message {
Push(Expression),
Pull(ImportId),
Resolve(ExportId, Expression),
Reject(ExportId, Expression),
Release(ImportId, u32),
Abort(Expression),
}
impl Message {
pub fn from_json(value: &JsonValue) -> Result<Self, MessageError> {
let arr = value.as_array().ok_or(MessageError::NotAnArray)?;
if arr.is_empty() {
return Err(MessageError::EmptyMessage);
}
let msg_type = arr[0].as_str().ok_or(MessageError::InvalidMessageType)?;
match msg_type {
"push" => {
if arr.len() != 2 {
return Err(MessageError::InvalidPush);
}
let expr = Expression::from_json(&arr[1])?;
Ok(Message::Push(expr))
}
"pull" => {
if arr.len() != 2 {
return Err(MessageError::InvalidPull);
}
let import_id = arr[1].as_i64().ok_or(MessageError::InvalidImportId)?;
Ok(Message::Pull(ImportId(import_id)))
}
"resolve" => {
if arr.len() != 3 {
return Err(MessageError::InvalidResolve);
}
let export_id = arr[1].as_i64().ok_or(MessageError::InvalidExportId)?;
let expr = Expression::from_json(&arr[2])?;
Ok(Message::Resolve(ExportId(export_id), expr))
}
"reject" => {
if arr.len() != 3 {
return Err(MessageError::InvalidReject);
}
let export_id = arr[1].as_i64().ok_or(MessageError::InvalidExportId)?;
let expr = Expression::from_json(&arr[2])?;
Ok(Message::Reject(ExportId(export_id), expr))
}
"release" => {
if arr.len() != 3 {
return Err(MessageError::InvalidRelease);
}
let import_id = arr[1].as_i64().ok_or(MessageError::InvalidImportId)?;
let refcount = arr[2].as_u64().ok_or(MessageError::InvalidRefcount)? as u32;
Ok(Message::Release(ImportId(import_id), refcount))
}
"abort" => {
if arr.len() != 2 {
return Err(MessageError::InvalidAbort);
}
let expr = Expression::from_json(&arr[1])?;
Ok(Message::Abort(expr))
}
_ => Err(MessageError::UnknownMessageType(msg_type.to_string())),
}
}
pub fn to_json(&self) -> JsonValue {
match self {
Message::Push(expr) => {
serde_json::json!(["push", expr.to_json()])
}
Message::Pull(import_id) => {
serde_json::json!(["pull", import_id.0])
}
Message::Resolve(export_id, expr) => {
serde_json::json!(["resolve", export_id.0, expr.to_json()])
}
Message::Reject(export_id, expr) => {
serde_json::json!(["reject", export_id.0, expr.to_json()])
}
Message::Release(import_id, refcount) => {
serde_json::json!(["release", import_id.0, refcount])
}
Message::Abort(expr) => {
serde_json::json!(["abort", expr.to_json()])
}
}
}
}
impl Serialize for Message {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
self.to_json().serialize(serializer)
}
}
impl<'de> Deserialize<'de> for Message {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let value = JsonValue::deserialize(deserializer)?;
Message::from_json(&value).map_err(|e| serde::de::Error::custom(e.to_string()))
}
}
#[derive(Debug, thiserror::Error)]
pub enum MessageError {
#[error("Message must be a JSON array")]
NotAnArray,
#[error("Message array cannot be empty")]
EmptyMessage,
#[error("Message type must be a string")]
InvalidMessageType,
#[error("Invalid push message format")]
InvalidPush,
#[error("Invalid pull message format")]
InvalidPull,
#[error("Invalid resolve message format")]
InvalidResolve,
#[error("Invalid reject message format")]
InvalidReject,
#[error("Invalid release message format")]
InvalidRelease,
#[error("Invalid abort message format")]
InvalidAbort,
#[error("Invalid import ID")]
InvalidImportId,
#[error("Invalid export ID")]
InvalidExportId,
#[error("Invalid refcount")]
InvalidRefcount,
#[error("Unknown message type: {0}")]
UnknownMessageType(String),
#[error("Expression error: {0}")]
ExpressionError(#[from] super::expression::ExpressionError),
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn test_push_message() {
let json = json!(["push", "hello"]);
let msg = Message::from_json(&json).unwrap();
match &msg {
Message::Push(expr) => {
assert_eq!(expr, &Expression::String("hello".to_string()));
}
_ => panic!("Expected Push message"),
}
assert_eq!(msg.to_json(), json);
}
#[test]
fn test_pull_message() {
let json = json!(["pull", 42]);
let msg = Message::from_json(&json).unwrap();
match msg {
Message::Pull(id) => {
assert_eq!(id, ImportId(42));
}
_ => panic!("Expected Pull message"),
}
assert_eq!(msg.to_json(), json);
}
#[test]
fn test_resolve_message() {
let json = json!(["resolve", -1, "result"]);
let msg = Message::from_json(&json).unwrap();
match &msg {
Message::Resolve(id, expr) => {
assert_eq!(id, &ExportId(-1));
assert_eq!(expr, &Expression::String("result".to_string()));
}
_ => panic!("Expected Resolve message"),
}
assert_eq!(msg.to_json(), json);
}
#[test]
fn test_serialization_roundtrip() {
let original = Message::Push(Expression::Number(serde_json::Number::from(42)));
let json = serde_json::to_value(&original).unwrap();
let deserialized: Message = serde_json::from_value(json.clone()).unwrap();
assert_eq!(original, deserialized);
assert_eq!(json, json!(["push", 42]));
}
}