use serde::{Deserialize, Serialize};
use super::envelope::ProtocolEnvelope;
use crate::effects::{ChoreographyError, LabelId};
use crate::identifiers::RoleName;
/// What a role's protocol state machine is currently waiting for.
///
/// `L` is the label type used to name the branches of `Choice` and `Offer` points.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BlockedOn<L: LabelId> {
    /// Waiting to send a message of type `message_type` to role `to`.
    Send {
        to: RoleName,
        message_type: String,
    },
    /// Waiting to receive a message from role `from`; `expected_types` names the message
    /// types anticipated at this point.
    Recv {
        from: RoleName,
        expected_types: Vec<String>,
    },
    /// Waiting for this role to select one of `branches`.
    Choice {
        branches: Vec<L>,
    },
    /// Waiting for role `from` to select one of `branches`.
    Offer {
        from: RoleName,
        branches: Vec<L>,
    },
    /// The protocol has finished for this role.
    Complete,
    /// The protocol has failed; the string carries the failure message.
    Failed(String),
}
impl<L: LabelId> BlockedOn<L> {
    /// Returns `true` when no further progress is possible: the protocol is either
    /// [`BlockedOn::Complete`] or [`BlockedOn::Failed`].
    #[must_use]
    pub fn is_terminal(&self) -> bool {
        matches!(*self, Self::Complete | Self::Failed(_))
    }

    /// Returns `true` when waiting to send a message.
    #[must_use]
    pub fn is_send(&self) -> bool {
        matches!(*self, Self::Send { .. })
    }

    /// Returns `true` when waiting to receive a message.
    #[must_use]
    pub fn is_recv(&self) -> bool {
        matches!(*self, Self::Recv { .. })
    }

    /// Returns `true` at any branch point, whether this role decides (`Choice`) or a peer
    /// decides (`Offer`).
    #[must_use]
    pub fn is_choice(&self) -> bool {
        matches!(*self, Self::Offer { .. } | Self::Choice { .. })
    }
}
/// An input fed to [`ProtocolStateMachine::step`].
#[derive(Debug, Clone)]
pub enum StepInput<L: LabelId> {
    /// An outgoing envelope this role is sending.
    SendMessage(ProtocolEnvelope),
    /// An incoming envelope this role has received.
    RecvMessage(ProtocolEnvelope),
    /// A branch selected by this role at a choice point.
    MakeChoice(L),
    /// A branch selected by a peer, as received by this role.
    ReceiveOffer(L),
    /// A timeout signal from the caller.
    Timeout,
    /// An error reported by the caller, with its message.
    Error(String),
}
impl<L: LabelId> StepInput<L> {
pub fn send(envelope: ProtocolEnvelope) -> Self {
Self::SendMessage(envelope)
}
pub fn recv(envelope: ProtocolEnvelope) -> Self {
Self::RecvMessage(envelope)
}
pub fn choice(branch: L) -> Self {
Self::MakeChoice(branch)
}
pub fn offer(branch: L) -> Self {
Self::ReceiveOffer(branch)
}
}
/// The result of a successful [`ProtocolStateMachine::step`].
#[derive(Debug, Clone)]
pub enum StepOutput<L: LabelId> {
    /// The given envelope was accepted as the pending send.
    Sent(ProtocolEnvelope),
    /// The given envelope was accepted as the pending receive; `response` is an optional
    /// reply envelope produced by the machine.
    Received {
        envelope: ProtocolEnvelope,
        response: Option<ProtocolEnvelope>,
    },
    /// This role's choice of branch was accepted.
    ChoiceMade(L),
    /// A peer's offered branch was accepted.
    OfferReceived(L),
    /// The machine was stepped while already complete.
    Completed,
    /// The input did not match what the machine was waiting for.
    NoProgress,
}
impl<L: LabelId> StepOutput<L> {
    /// Returns `true` only for [`StepOutput::Completed`].
    #[must_use]
    pub fn is_completed(&self) -> bool {
        matches!(*self, Self::Completed)
    }

    /// Returns `true` for every variant except [`StepOutput::NoProgress`].
    ///
    /// Note that this includes `Completed`, which is reported when stepping a machine that
    /// has already finished.
    #[must_use]
    pub fn made_progress(&self) -> bool {
        let stalled = matches!(*self, Self::NoProgress);
        !stalled
    }
}
/// A serializable snapshot of one role's progress through a protocol.
///
/// NOTE: field order is part of the bincode encoding produced by `to_bytes` and read by
/// `from_bytes`; reordering fields breaks previously written checkpoints.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Checkpoint {
    /// Name of the protocol the snapshot belongs to.
    pub protocol: String,
    /// Role whose state was captured.
    pub role: RoleName,
    /// Human-readable identifier of the captured state.
    pub state_id: String,
    /// Implementation-defined encoded state (opaque to this type).
    pub state_data: Vec<u8>,
    /// Step sequence number at the time of the snapshot.
    pub sequence: u64,
    /// Free-form key/value annotations; a `BTreeMap`, so entries serialize in sorted key order.
    pub metadata: std::collections::BTreeMap<String, String>,
}
impl Checkpoint {
    /// Creates a checkpoint for `protocol`/`role` at `state_id`, with empty state data,
    /// sequence `0`, and no metadata.
    pub fn new(protocol: impl Into<String>, role: RoleName, state_id: impl Into<String>) -> Self {
        let protocol = protocol.into();
        let state_id = state_id.into();
        Self {
            protocol,
            role,
            state_id,
            state_data: Vec::new(),
            sequence: 0,
            metadata: Default::default(),
        }
    }

    /// Replaces the encoded state bytes.
    #[must_use]
    pub fn with_data(self, data: Vec<u8>) -> Self {
        Self {
            state_data: data,
            ..self
        }
    }

    /// Replaces the sequence number.
    #[must_use]
    pub fn with_sequence(self, seq: u64) -> Self {
        Self {
            sequence: seq,
            ..self
        }
    }

    /// Adds (or overwrites) one metadata entry.
    #[must_use]
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (k, v) = (key.into(), value.into());
        self.metadata.insert(k, v);
        self
    }

    /// Encodes the whole checkpoint with bincode.
    ///
    /// # Errors
    ///
    /// Returns [`CheckpointError::Serialization`] if bincode fails to encode it.
    pub fn to_bytes(&self) -> Result<Vec<u8>, CheckpointError> {
        let encoded = bincode::serialize(self)
            .map_err(|err| CheckpointError::Serialization(err.to_string()))?;
        Ok(encoded)
    }

    /// Decodes a checkpoint previously produced by [`Checkpoint::to_bytes`].
    ///
    /// # Errors
    ///
    /// Returns [`CheckpointError::Deserialization`] if `bytes` is not a valid encoding.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, CheckpointError> {
        let decoded = bincode::deserialize::<Self>(bytes)
            .map_err(|err| CheckpointError::Deserialization(err.to_string()))?;
        Ok(decoded)
    }
}
/// Errors raised while creating, encoding, decoding, or restoring a [`Checkpoint`].
#[derive(Debug, thiserror::Error)]
pub enum CheckpointError {
    /// Encoding failed; carries the underlying error's message.
    #[error("Checkpoint serialization error: {0}")]
    Serialization(String),
    /// Decoding failed; carries the underlying error's message.
    #[error("Checkpoint deserialization error: {0}")]
    Deserialization(String),
    /// The checkpoint does not match the machine it is being restored into.
    #[error("Incompatible checkpoint: {0}")]
    Incompatible(String),
}
/// A single role's executable view of a protocol, driven one input at a time.
pub trait ProtocolStateMachine: Send {
    /// Label type naming the branches at `Choice`/`Offer` points.
    type Label: LabelId;

    /// Name of the protocol this machine executes.
    fn protocol_name(&self) -> &str;

    /// Role this machine plays in the protocol.
    fn role(&self) -> &RoleName;

    /// What the machine is currently waiting for.
    fn blocked_on(&self) -> BlockedOn<Self::Label>;

    /// Feeds one input to the machine.
    ///
    /// # Errors
    ///
    /// Implementation-defined; returns a [`ChoreographyError`] when the input cannot be
    /// applied.
    fn step(
        &mut self,
        input: StepInput<Self::Label>,
    ) -> Result<StepOutput<Self::Label>, ChoreographyError>;

    /// Captures the current progress as a [`Checkpoint`].
    ///
    /// # Errors
    ///
    /// Returns a [`CheckpointError`] if the state cannot be captured.
    fn checkpoint(&self) -> Result<Checkpoint, CheckpointError>;

    /// Resets progress from a previously captured [`Checkpoint`].
    ///
    /// # Errors
    ///
    /// Returns a [`CheckpointError`] if the checkpoint cannot be applied to this machine.
    fn restore(&mut self, checkpoint: &Checkpoint) -> Result<(), CheckpointError>;

    /// Current step sequence number.
    fn sequence(&self) -> u64;

    /// Returns `true` once [`blocked_on`](Self::blocked_on) reports a terminal state
    /// (`Complete` or `Failed`).
    fn is_complete(&self) -> bool {
        let state = self.blocked_on();
        state.is_terminal()
    }
}
/// A [`ProtocolStateMachine`] that walks a fixed list of states in order.
///
/// Each accepted input moves the cursor to the next entry. Once the cursor is past the last
/// entry, the machine reports [`BlockedOn::Complete`].
#[derive(Debug)]
pub struct LinearStateMachine<L: LabelId> {
    protocol: String,
    role: RoleName,
    /// Scripted states, visited in order.
    states: Vec<BlockedOn<L>>,
    /// Cursor into `states`; `advance` never moves it past `states.len()`.
    current_state: usize,
    /// Number of transitions taken, or the value restored from a checkpoint.
    sequence: u64,
}
impl<L: LabelId> LinearStateMachine<L> {
    /// Creates a machine for `role` in `protocol`, positioned at the first of `states`
    /// with sequence `0`.
    pub fn new(protocol: impl Into<String>, role: RoleName, states: Vec<BlockedOn<L>>) -> Self {
        let protocol = protocol.into();
        Self {
            protocol,
            role,
            states,
            current_state: 0,
            sequence: 0,
        }
    }

    /// Moves the cursor to the next scripted state and bumps the sequence number.
    ///
    /// Does nothing once every scripted state has been consumed.
    fn advance(&mut self) {
        if self.current_state >= self.states.len() {
            return;
        }
        self.current_state += 1;
        self.sequence += 1;
    }
}
impl<L: LabelId> ProtocolStateMachine for LinearStateMachine<L> {
type Label = L;
fn protocol_name(&self) -> &str {
&self.protocol
}
fn role(&self) -> &RoleName {
&self.role
}
fn blocked_on(&self) -> BlockedOn<Self::Label> {
self.states
.get(self.current_state)
.cloned()
.unwrap_or(BlockedOn::Complete)
}
fn step(
&mut self,
input: StepInput<Self::Label>,
) -> Result<StepOutput<Self::Label>, ChoreographyError> {
let current = self.blocked_on();
match (¤t, &input) {
(BlockedOn::Send { .. }, StepInput::SendMessage(env)) => {
self.advance();
Ok(StepOutput::Sent(env.clone()))
}
(BlockedOn::Recv { .. }, StepInput::RecvMessage(env)) => {
self.advance();
Ok(StepOutput::Received {
envelope: env.clone(),
response: None,
})
}
(BlockedOn::Choice { branches }, StepInput::MakeChoice(branch)) => {
if branches.contains(branch) {
self.advance();
Ok(StepOutput::ChoiceMade(*branch))
} else {
Err(ChoreographyError::InvalidChoice {
expected: branches
.iter()
.map(|label| label.as_str().to_string())
.collect(),
actual: branch.as_str().to_string(),
})
}
}
(BlockedOn::Offer { branches, .. }, StepInput::ReceiveOffer(branch)) => {
if branches.contains(branch) {
self.advance();
Ok(StepOutput::OfferReceived(*branch))
} else {
Err(ChoreographyError::InvalidChoice {
expected: branches
.iter()
.map(|label| label.as_str().to_string())
.collect(),
actual: branch.as_str().to_string(),
})
}
}
(BlockedOn::Complete, _) => Ok(StepOutput::Completed),
(BlockedOn::Failed(msg), _) => Err(ChoreographyError::ExecutionError(msg.clone())),
_ => Ok(StepOutput::NoProgress),
}
}
fn checkpoint(&self) -> Result<Checkpoint, CheckpointError> {
let state_data = bincode::serialize(&self.current_state)
.map_err(|e| CheckpointError::Serialization(e.to_string()))?;
Ok(Checkpoint::new(
&self.protocol,
self.role.clone(),
format!("state_{}", self.current_state),
)
.with_data(state_data)
.with_sequence(self.sequence))
}
fn restore(&mut self, checkpoint: &Checkpoint) -> Result<(), CheckpointError> {
if checkpoint.protocol != self.protocol {
return Err(CheckpointError::Incompatible(format!(
"Protocol mismatch: expected {}, got {}",
self.protocol, checkpoint.protocol
)));
}
self.current_state = bincode::deserialize(&checkpoint.state_data)
.map_err(|e| CheckpointError::Deserialization(e.to_string()))?;
self.sequence = checkpoint.sequence;
Ok(())
}
fn sequence(&self) -> u64 {
self.sequence
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Small closed label set used to exercise choice/offer handling.
    #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
    enum TestLabel {
        Accept,
        Reject,
        Other,
    }

    impl LabelId for TestLabel {
        fn as_str(&self) -> &'static str {
            match self {
                TestLabel::Accept => "Accept",
                TestLabel::Reject => "Reject",
                TestLabel::Other => "Other",
            }
        }

        fn from_str(label: &str) -> Option<Self> {
            match label {
                "Accept" => Some(TestLabel::Accept),
                "Reject" => Some(TestLabel::Reject),
                "Other" => Some(TestLabel::Other),
                _ => None,
            }
        }
    }

    #[test]
    fn test_blocked_on_terminal() {
        assert!(BlockedOn::<TestLabel>::Complete.is_terminal());
        assert!(BlockedOn::<TestLabel>::Failed("error".to_string()).is_terminal());
        assert!(!BlockedOn::<TestLabel>::Send {
            to: RoleName::from_static("Server"),
            message_type: "Request".to_string(),
        }
        .is_terminal());
    }

    #[test]
    fn test_linear_state_machine() {
        let states = vec![
            BlockedOn::Send {
                to: RoleName::from_static("Server"),
                message_type: "Request".to_string(),
            },
            BlockedOn::Recv {
                from: RoleName::from_static("Server"),
                expected_types: vec!["Response".to_string()],
            },
        ];
        let mut sm = LinearStateMachine::<TestLabel>::new(
            "TestProto",
            RoleName::from_static("Client"),
            states,
        );
        assert!(sm.blocked_on().is_send());
        let send_env = super::super::envelope::ProtocolEnvelope::builder()
            .protocol("TestProto")
            .sender(RoleName::from_static("Client"))
            .recipient(RoleName::from_static("Server"))
            .message_type("Request")
            .payload(vec![])
            .build()
            .unwrap();
        let result = sm.step(StepInput::send(send_env));
        assert!(result.is_ok());
        assert!(matches!(result.unwrap(), StepOutput::Sent(_)));
        assert!(sm.blocked_on().is_recv());
        let recv_env = super::super::envelope::ProtocolEnvelope::builder()
            .protocol("TestProto")
            .sender(RoleName::from_static("Server"))
            .recipient(RoleName::from_static("Client"))
            .message_type("Response")
            .payload(vec![])
            .build()
            .unwrap();
        let result = sm.step(StepInput::recv(recv_env));
        assert!(result.is_ok());
        assert!(matches!(
            result.unwrap(),
            StepOutput::Received { response: None, .. }
        ));
        assert!(sm.blocked_on().is_terminal());
        assert_eq!(sm.sequence(), 2);
        // Stepping a finished machine reports completion without advancing.
        assert!(sm.step(StepInput::Timeout).unwrap().is_completed());
        assert_eq!(sm.sequence(), 2);
    }

    #[test]
    fn test_checkpoint_roundtrip() {
        let states = vec![BlockedOn::Send {
            to: RoleName::from_static("Server"),
            message_type: "Msg".to_string(),
        }];
        let sm =
            LinearStateMachine::<TestLabel>::new("Proto", RoleName::from_static("Client"), states);
        let checkpoint = sm.checkpoint().unwrap();
        let bytes = checkpoint.to_bytes().unwrap();
        let restored = Checkpoint::from_bytes(&bytes).unwrap();
        assert_eq!(checkpoint.protocol, restored.protocol);
        assert_eq!(checkpoint.role, restored.role);
        assert_eq!(checkpoint.state_id, restored.state_id);
        assert_eq!(checkpoint.state_data, restored.state_data);
        assert_eq!(checkpoint.sequence, restored.sequence);
    }

    #[test]
    fn test_restore_resumes_position() {
        let states = vec![
            BlockedOn::Choice {
                branches: vec![TestLabel::Accept],
            },
            BlockedOn::Send {
                to: RoleName::from_static("Server"),
                message_type: "Msg".to_string(),
            },
        ];
        let mut original = LinearStateMachine::<TestLabel>::new(
            "Proto",
            RoleName::from_static("Client"),
            states.clone(),
        );
        original.step(StepInput::choice(TestLabel::Accept)).unwrap();
        let bytes = original.checkpoint().unwrap().to_bytes().unwrap();
        let checkpoint = Checkpoint::from_bytes(&bytes).unwrap();

        let mut resumed =
            LinearStateMachine::<TestLabel>::new("Proto", RoleName::from_static("Client"), states);
        resumed.restore(&checkpoint).unwrap();
        assert!(resumed.blocked_on().is_send());
        assert_eq!(resumed.sequence(), 1);
    }

    #[test]
    fn test_restore_rejects_protocol_mismatch() {
        let source =
            LinearStateMachine::<TestLabel>::new("Proto", RoleName::from_static("Client"), Vec::new());
        let mut target = LinearStateMachine::<TestLabel>::new(
            "OtherProto",
            RoleName::from_static("Client"),
            Vec::new(),
        );
        let checkpoint = source.checkpoint().unwrap();
        assert!(matches!(
            target.restore(&checkpoint),
            Err(CheckpointError::Incompatible(_))
        ));
    }

    #[test]
    fn test_choice_validation() {
        let states = vec![BlockedOn::Choice {
            branches: vec![TestLabel::Accept, TestLabel::Reject],
        }];
        let mut sm =
            LinearStateMachine::<TestLabel>::new("Proto", RoleName::from_static("Client"), states);
        let result = sm.step(StepInput::choice(TestLabel::Other));
        assert!(matches!(result, Err(ChoreographyError::InvalidChoice { .. })));
        // A rejected choice must not move the cursor.
        assert!(sm.blocked_on().is_choice());
        assert_eq!(sm.sequence(), 0);
        let result = sm.step(StepInput::choice(TestLabel::Accept));
        assert!(matches!(
            result,
            Ok(StepOutput::ChoiceMade(TestLabel::Accept))
        ));
        assert!(sm.blocked_on().is_terminal());
        assert_eq!(sm.sequence(), 1);
    }
}