Skip to main content

amaru_protocols/protocol/
check.rs

1// Copyright 2025 PRAGMA
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#![expect(clippy::panic, clippy::unwrap_used)]
16
17use std::{
18    collections::{BTreeMap, BTreeSet},
19    marker::PhantomData,
20};
21
22use crate::protocol::{ProtocolState, Role, RoleT};
23
24pub struct ProtoSpec<State, Message, R> {
25    transitions: BTreeMap<State, PerState<State, Message>>,
26    _phantom: PhantomData<R>,
27}
28
29#[derive(Debug, PartialEq, Eq, Clone)]
30struct PerState<State, Message> {
31    agency: Role,
32    transitions: BTreeMap<Message, (Role, State, bool)>,
33}
34
35impl<State, Message> PerState<State, Message> {
36    fn initiator() -> Self {
37        Self { agency: Role::Initiator, transitions: Default::default() }
38    }
39
40    fn responder() -> Self {
41        Self { agency: Role::Responder, transitions: Default::default() }
42    }
43
44    fn role(role: Role) -> Self {
45        Self { agency: role, transitions: Default::default() }
46    }
47
48    fn insert(&mut self, msg: Message, role: Role, to: State) -> Option<(Role, State)>
49    where
50        Message: std::fmt::Debug + Ord,
51        State: std::fmt::Debug,
52    {
53        assert_eq!(self.agency, role, "inserting {msg:?}@{role:?} to {to:?}");
54        self.transitions.insert(msg, (role, to, false)).map(|(r, t, _)| (r, t))
55    }
56
57    fn insert_sim_open(&mut self, msg: Message, role: Role, to: State) -> Option<(Role, State)>
58    where
59        Message: std::fmt::Debug + Ord,
60        State: std::fmt::Debug,
61    {
62        assert_eq!(self.agency, role, "inserting {msg:?}@{role:?} to {to:?}");
63        self.transitions.insert(msg, (role, to, true)).map(|(r, t, _)| (r, t))
64    }
65}
66
67impl<State, Message, R> Default for ProtoSpec<State, Message, R> {
68    fn default() -> Self {
69        Self { transitions: Default::default(), _phantom: PhantomData }
70    }
71}
72
73impl<State, Message, R> ProtoSpec<State, Message, R>
74where
75    State: Ord + std::fmt::Debug + Clone + ProtocolState<R, WireMsg = Message>,
76    Message: Ord + std::fmt::Debug + Clone,
77    R: RoleT,
78{
79    /// Add a transition that can be executed by the initiator.
80    pub fn init(&mut self, from: State, msg: Message, to: State) {
81        let present = self.transitions.entry(from.clone()).or_insert_with(PerState::initiator).insert(
82            msg.clone(),
83            Role::Initiator,
84            to.clone(),
85        );
86        if let Some(present) = present {
87            panic!("transition {:?} -> {:?} -> {:?} already defined when inserting {:?}", from, msg, present, to);
88        }
89    }
90
91    /// Add a transition that can be executed by the responder.
92    pub fn resp(&mut self, from: State, msg: Message, to: State) {
93        let present = self.transitions.entry(from.clone()).or_insert_with(PerState::responder).insert(
94            msg.clone(),
95            Role::Responder,
96            to.clone(),
97        );
98        if let Some(present) = present {
99            panic!("transition {:?} -> {:?} -> {:?} already defined when inserting {:?}", from, msg, present, to);
100        }
101    }
102
103    pub fn sim_open(&mut self, from: State, msg: Message, to: State) {
104        let present = self.transitions.entry(from.clone()).or_insert_with(PerState::responder).insert_sim_open(
105            msg.clone(),
106            Role::Responder,
107            to.clone(),
108        );
109        if let Some(present) = present {
110            panic!("transition {:?} -> {:?} -> {:?} already defined when inserting {:?}", from, msg, present, to);
111        }
112    }
113
114    /// Check that the protocol implementation follows the spec.
115    ///
116    /// The `local_msg` function turns the network message under test
117    /// into a local action so that the protocol can be tested.
118    #[expect(clippy::expect_used)]
119    pub fn check(&self, initial: State, local_msg: impl Fn(&Message) -> Option<State::Action>) {
120        let role = const { R::ROLE.unwrap() };
121
122        let states = self.transitions.keys().collect::<Vec<_>>();
123        let messages = self.transitions.values().flat_map(|m| m.transitions.keys()).collect::<BTreeSet<_>>();
124
125        let (out, init) = initial.init().unwrap();
126        match role {
127            Role::Initiator => {
128                if let Some(_send) = out.send.as_ref() {
129                    assert_ne!(initial, init, "initialization with send must transition to a different state");
130                } else {
131                    assert_eq!(initial, init, "initialization without send must remain in the same state");
132                }
133            }
134            Role::Responder => {
135                assert!(out.send.is_none());
136                assert_eq!(initial, init, "initialization without send must remain in the same state");
137            }
138        }
139        assert_eq!(
140            out.want_next,
141            self.transitions.get(&init).expect("init() transitions to non-existent state").agency == role.opposite(),
142            "initialization must want_next for responder and not for initiator (unless sending from init()) (got {out:?})"
143        );
144
145        for state in states {
146            for &message in &messages {
147                let to = self.transitions.get(state).and_then(|m| m.transitions.get(message));
148                if state == &initial && Some(message) == out.send.as_ref() {
149                    assert_eq!(Some(&init), to.map(|(_, s, _)| s));
150                    continue;
151                }
152                let (must_be_local, is_sim_open) = to.map(|(r, _, s)| (*r == role, *s)).unwrap_or((false, false));
153
154                let outcome = if must_be_local {
155                    assert_eq!(
156                        None,
157                        state.network(message.clone()).ok(),
158                        "state {state:?} allows network message {message:?} while local node has agency"
159                    );
160                    let Some(local_msg) = local_msg(message) else {
161                        if is_sim_open {
162                            continue;
163                        }
164                        panic!("local message {message:?} not declared for {state:?} in check() arguments");
165                    };
166                    state.local(local_msg).ok()
167                } else {
168                    assert_eq!(
169                        None,
170                        local_msg(message).and_then(|action| state.local(action).ok()),
171                        "state {state:?} allows local message {message:?} while the peer may have agency"
172                    );
173                    state.network(message.clone()).ok().map(|(outcome, next)| (outcome.without_result(), next))
174                };
175
176                let ((r, to, _), (send, next)) = match (to, outcome) {
177                    (None, None) => continue,
178                    (None, Some(_)) => panic!("extraneous transition {:?} -> {:?}", state, message),
179                    (Some(_), None) => panic!("missing transition {:?} -> {:?} for {:?}", state, message, to),
180                    (Some(to), Some(outcome)) => (to, outcome),
181                };
182                // we only get here if `to` was `Some`, meaning that must_be_local == is_local
183                let is_local = must_be_local;
184
185                if is_local {
186                    assert_eq!(*r, role, "sending {message:?} not allowed for {role:?}");
187                    assert_eq!(send.send.as_ref(), Some(message), "sending message in state {state:?}");
188                    assert_eq!(&next, to, "final state mismatch for {state:?} -> {message:?}");
189                } else {
190                    assert_eq!(*r, role.opposite(), "expecting {message:?} not allowed for {role:?}");
191                    if let Some(send) = send.send.as_ref() {
192                        let to2 = self.transitions.get(to).and_then(|m| m.transitions.get(send));
193                        if let Some((r2, to2, _)) = to2 {
194                            assert_eq!(*r2, role, "sending {send:?} not allowed for {role:?}");
195                            assert_eq!(to2, &next, "final state mismatch for {to:?} -> {send:?}");
196                        } else {
197                            panic!("extraneous transition {:?} -> {:?}", to, send);
198                        }
199                    } else {
200                        assert_eq!(&next, to, "final state mismatch for {state:?} -> {message:?}");
201                    }
202                }
203
204                // check that want-next is called when transitioning into a state with remote agency
205                // (note that transition into final state will yield None for the get())
206                if let Some(s) = self.transitions.get(&next) {
207                    if s.agency == role.opposite() {
208                        assert!(
209                            send.want_next,
210                            "transition into state with remote agency requires want_next: {state:?} -> {message:?} -> {to:?} (got {send:?})"
211                        );
212                    } else {
213                        assert!(
214                            !send.want_next,
215                            "transition into state with local agency should not want_next: {state:?} -> {message:?} -> {to:?} (got {send:?})"
216                        );
217                    }
218                }
219            }
220        }
221    }
222
223    /// Assert that this protocol refines the other protocol.
224    ///
225    /// This means that this protocol has more states than the other
226    /// protocol, thus the state projection must be a surjection.
227    pub fn assert_refines<S2, R2>(&self, other: &ProtoSpec<S2, Message, R2>, surjection: impl Fn(&State) -> S2)
228    where
229        S2: Ord + std::fmt::Debug + Clone + ProtocolState<R2, WireMsg = Message>,
230        R2: RoleT,
231    {
232        let mut simplified = BTreeMap::<S2, PerState<S2, Message>>::new();
233
234        for (from, per_state) in &self.transitions {
235            let from = surjection(from);
236            for (message, (role, to, _)) in per_state.transitions.iter() {
237                let to = surjection(to);
238                let existing_target = simplified.entry(from.clone()).or_insert_with(|| PerState::role(*role)).insert(
239                    message.clone(),
240                    *role,
241                    to.clone(),
242                );
243                if let Some((existing_role, existing_target)) = existing_target.as_ref()
244                    && (existing_target != &to || existing_role != role)
245                {
246                    panic!(
247                        "transition {:?} -> {:?} -> {:?} already defined with different target state when inserting {:?}",
248                        from, message, existing_target, to
249                    );
250                }
251            }
252        }
253
254        assert_eq!(simplified, other.transitions);
255    }
256}