1#![expect(clippy::panic, clippy::unwrap_used)]
16
17use std::{
18 collections::{BTreeMap, BTreeSet},
19 marker::PhantomData,
20};
21
22use crate::protocol::{ProtocolState, Role, RoleT};
23
24pub struct ProtoSpec<State, Message, R> {
25 transitions: BTreeMap<State, PerState<State, Message>>,
26 _phantom: PhantomData<R>,
27}
28
29#[derive(Debug, PartialEq, Eq, Clone)]
30struct PerState<State, Message> {
31 agency: Role,
32 transitions: BTreeMap<Message, (Role, State, bool)>,
33}
34
35impl<State, Message> PerState<State, Message> {
36 fn initiator() -> Self {
37 Self { agency: Role::Initiator, transitions: Default::default() }
38 }
39
40 fn responder() -> Self {
41 Self { agency: Role::Responder, transitions: Default::default() }
42 }
43
44 fn role(role: Role) -> Self {
45 Self { agency: role, transitions: Default::default() }
46 }
47
48 fn insert(&mut self, msg: Message, role: Role, to: State) -> Option<(Role, State)>
49 where
50 Message: std::fmt::Debug + Ord,
51 State: std::fmt::Debug,
52 {
53 assert_eq!(self.agency, role, "inserting {msg:?}@{role:?} to {to:?}");
54 self.transitions.insert(msg, (role, to, false)).map(|(r, t, _)| (r, t))
55 }
56
57 fn insert_sim_open(&mut self, msg: Message, role: Role, to: State) -> Option<(Role, State)>
58 where
59 Message: std::fmt::Debug + Ord,
60 State: std::fmt::Debug,
61 {
62 assert_eq!(self.agency, role, "inserting {msg:?}@{role:?} to {to:?}");
63 self.transitions.insert(msg, (role, to, true)).map(|(r, t, _)| (r, t))
64 }
65}
66
67impl<State, Message, R> Default for ProtoSpec<State, Message, R> {
68 fn default() -> Self {
69 Self { transitions: Default::default(), _phantom: PhantomData }
70 }
71}
72
73impl<State, Message, R> ProtoSpec<State, Message, R>
74where
75 State: Ord + std::fmt::Debug + Clone + ProtocolState<R, WireMsg = Message>,
76 Message: Ord + std::fmt::Debug + Clone,
77 R: RoleT,
78{
79 pub fn init(&mut self, from: State, msg: Message, to: State) {
81 let present = self.transitions.entry(from.clone()).or_insert_with(PerState::initiator).insert(
82 msg.clone(),
83 Role::Initiator,
84 to.clone(),
85 );
86 if let Some(present) = present {
87 panic!("transition {:?} -> {:?} -> {:?} already defined when inserting {:?}", from, msg, present, to);
88 }
89 }
90
91 pub fn resp(&mut self, from: State, msg: Message, to: State) {
93 let present = self.transitions.entry(from.clone()).or_insert_with(PerState::responder).insert(
94 msg.clone(),
95 Role::Responder,
96 to.clone(),
97 );
98 if let Some(present) = present {
99 panic!("transition {:?} -> {:?} -> {:?} already defined when inserting {:?}", from, msg, present, to);
100 }
101 }
102
103 pub fn sim_open(&mut self, from: State, msg: Message, to: State) {
104 let present = self.transitions.entry(from.clone()).or_insert_with(PerState::responder).insert_sim_open(
105 msg.clone(),
106 Role::Responder,
107 to.clone(),
108 );
109 if let Some(present) = present {
110 panic!("transition {:?} -> {:?} -> {:?} already defined when inserting {:?}", from, msg, present, to);
111 }
112 }
113
114 #[expect(clippy::expect_used)]
119 pub fn check(&self, initial: State, local_msg: impl Fn(&Message) -> Option<State::Action>) {
120 let role = const { R::ROLE.unwrap() };
121
122 let states = self.transitions.keys().collect::<Vec<_>>();
123 let messages = self.transitions.values().flat_map(|m| m.transitions.keys()).collect::<BTreeSet<_>>();
124
125 let (out, init) = initial.init().unwrap();
126 match role {
127 Role::Initiator => {
128 if let Some(_send) = out.send.as_ref() {
129 assert_ne!(initial, init, "initialization with send must transition to a different state");
130 } else {
131 assert_eq!(initial, init, "initialization without send must remain in the same state");
132 }
133 }
134 Role::Responder => {
135 assert!(out.send.is_none());
136 assert_eq!(initial, init, "initialization without send must remain in the same state");
137 }
138 }
139 assert_eq!(
140 out.want_next,
141 self.transitions.get(&init).expect("init() transitions to non-existent state").agency == role.opposite(),
142 "initialization must want_next for responder and not for initiator (unless sending from init()) (got {out:?})"
143 );
144
145 for state in states {
146 for &message in &messages {
147 let to = self.transitions.get(state).and_then(|m| m.transitions.get(message));
148 if state == &initial && Some(message) == out.send.as_ref() {
149 assert_eq!(Some(&init), to.map(|(_, s, _)| s));
150 continue;
151 }
152 let (must_be_local, is_sim_open) = to.map(|(r, _, s)| (*r == role, *s)).unwrap_or((false, false));
153
154 let outcome = if must_be_local {
155 assert_eq!(
156 None,
157 state.network(message.clone()).ok(),
158 "state {state:?} allows network message {message:?} while local node has agency"
159 );
160 let Some(local_msg) = local_msg(message) else {
161 if is_sim_open {
162 continue;
163 }
164 panic!("local message {message:?} not declared for {state:?} in check() arguments");
165 };
166 state.local(local_msg).ok()
167 } else {
168 assert_eq!(
169 None,
170 local_msg(message).and_then(|action| state.local(action).ok()),
171 "state {state:?} allows local message {message:?} while the peer may have agency"
172 );
173 state.network(message.clone()).ok().map(|(outcome, next)| (outcome.without_result(), next))
174 };
175
176 let ((r, to, _), (send, next)) = match (to, outcome) {
177 (None, None) => continue,
178 (None, Some(_)) => panic!("extraneous transition {:?} -> {:?}", state, message),
179 (Some(_), None) => panic!("missing transition {:?} -> {:?} for {:?}", state, message, to),
180 (Some(to), Some(outcome)) => (to, outcome),
181 };
182 let is_local = must_be_local;
184
185 if is_local {
186 assert_eq!(*r, role, "sending {message:?} not allowed for {role:?}");
187 assert_eq!(send.send.as_ref(), Some(message), "sending message in state {state:?}");
188 assert_eq!(&next, to, "final state mismatch for {state:?} -> {message:?}");
189 } else {
190 assert_eq!(*r, role.opposite(), "expecting {message:?} not allowed for {role:?}");
191 if let Some(send) = send.send.as_ref() {
192 let to2 = self.transitions.get(to).and_then(|m| m.transitions.get(send));
193 if let Some((r2, to2, _)) = to2 {
194 assert_eq!(*r2, role, "sending {send:?} not allowed for {role:?}");
195 assert_eq!(to2, &next, "final state mismatch for {to:?} -> {send:?}");
196 } else {
197 panic!("extraneous transition {:?} -> {:?}", to, send);
198 }
199 } else {
200 assert_eq!(&next, to, "final state mismatch for {state:?} -> {message:?}");
201 }
202 }
203
204 if let Some(s) = self.transitions.get(&next) {
207 if s.agency == role.opposite() {
208 assert!(
209 send.want_next,
210 "transition into state with remote agency requires want_next: {state:?} -> {message:?} -> {to:?} (got {send:?})"
211 );
212 } else {
213 assert!(
214 !send.want_next,
215 "transition into state with local agency should not want_next: {state:?} -> {message:?} -> {to:?} (got {send:?})"
216 );
217 }
218 }
219 }
220 }
221 }
222
223 pub fn assert_refines<S2, R2>(&self, other: &ProtoSpec<S2, Message, R2>, surjection: impl Fn(&State) -> S2)
228 where
229 S2: Ord + std::fmt::Debug + Clone + ProtocolState<R2, WireMsg = Message>,
230 R2: RoleT,
231 {
232 let mut simplified = BTreeMap::<S2, PerState<S2, Message>>::new();
233
234 for (from, per_state) in &self.transitions {
235 let from = surjection(from);
236 for (message, (role, to, _)) in per_state.transitions.iter() {
237 let to = surjection(to);
238 let existing_target = simplified.entry(from.clone()).or_insert_with(|| PerState::role(*role)).insert(
239 message.clone(),
240 *role,
241 to.clone(),
242 );
243 if let Some((existing_role, existing_target)) = existing_target.as_ref()
244 && (existing_target != &to || existing_role != role)
245 {
246 panic!(
247 "transition {:?} -> {:?} -> {:?} already defined with different target state when inserting {:?}",
248 from, message, existing_target, to
249 );
250 }
251 }
252 }
253
254 assert_eq!(simplified, other.transitions);
255 }
256}