manul/session/tokio.rs

1//! High-level API for executing sessions in `tokio` tasks.
2
3use alloc::{format, sync::Arc, vec::Vec};
4
5use rand_core::CryptoRngCore;
6use tokio::{sync::mpsc, task::JoinHandle};
7use tracing::{debug, trace};
8
9use super::{
10    message::Message,
11    session::{CanFinalize, ProcessedArtifact, ProcessedMessage, RoundOutcome, Session, SessionId, SessionParameters},
12    transcript::SessionReport,
13    LocalError,
14};
15use crate::protocol::Protocol;
16
17/// The outgoing message from a local session.
18#[derive(Debug)]
19pub struct MessageOut<SP: SessionParameters> {
20    /// The session ID that created the message.
21    ///
22    /// Useful when there are several sessions running on a node, pushing messages into the same channel.
23    pub session_id: SessionId,
24    /// The verifying key of the party that created the message.
25    ///
26    /// Useful when there are several sessions running on a node, pushing messages into the same channel.
27    pub from: SP::Verifier,
28    /// The verifying key of the party the message is intended for.
29    pub to: SP::Verifier,
30    /// The message to be sent.
31    ///
32    /// Note that the caller is responsible for encrypting the message and attaching authentication info.
33    pub message: Message<SP::Verifier>,
34}
35
36/// The incoming message from a remote session.
37#[derive(Debug)]
38pub struct MessageIn<SP: SessionParameters> {
39    /// The verifying key of the party the message originated from.
40    ///
41    /// It is assumed that the message's authentication info has been checked at this point.
42    pub from: SP::Verifier,
43    /// The incoming message.
44    pub message: Message<SP::Verifier>,
45}
46
47/// Executes the session waiting for the messages from the `rx` channel
48/// and pushing outgoing messages into the `tx` channel.
49pub async fn run_session<P, SP>(
50    rng: &mut impl CryptoRngCore,
51    tx: &mpsc::Sender<MessageOut<SP>>,
52    rx: &mut mpsc::Receiver<MessageIn<SP>>,
53    session: Session<P, SP>,
54) -> Result<SessionReport<P, SP>, LocalError>
55where
56    P: Protocol<SP::Verifier>,
57    SP: SessionParameters,
58{
59    let mut session = session;
60    // Some rounds can finalize early and put off sending messages to the next round. Such messages
61    // will be stored here and applied after the messages for this round are sent.
62    let mut cached_messages = Vec::new();
63
64    let my_id = format!("{:?}", session.verifier());
65
66    // Each iteration of the loop progresses the session as follows:
67    //  - Send out messages as dictated by the session "destinations".
68    //  - Apply any cached messages.
69    //  - Enter a nested loop:
70    //      - Try to finalize the session; if we're done, exit the inner loop.
71    //      - Wait until we get an incoming message.
72    //      - Process the message we received and continue the loop.
73    //  - When all messages have been sent and received as specified by the protocol, finalize the
74    //    round.
75    //  - If the protocol outcome is a new round, go to the top of the loop and start over with a
76    //    new session.
77    loop {
78        debug!("{my_id}: *** starting round {:?} ***", session.round_id());
79
80        // This is kept in the main task since it's mutable,
81        // and we don't want to bother with synchronization.
82        let mut accum = session.make_accumulator();
83
84        // Note: generating/sending messages and verifying newly received messages
85        // can be done in parallel, with the results being assembled into `accum`
86        // sequentially in the host task.
87
88        let destinations = session.message_destinations();
89        for destination in destinations.iter() {
90            // In production usage, this will happen in a spawned task
91            // (since it can take some time to create a message),
92            // and the artifact will be sent back to the host task
93            // to be added to the accumulator.
94            let (message, artifact) = session.make_message(rng, destination)?;
95            debug!("{my_id}: Sending a message to {destination:?}",);
96            tx.send(MessageOut {
97                session_id: session.session_id().clone(),
98                from: session.verifier().clone(),
99                to: destination.clone(),
100                message,
101            })
102            .await
103            .map_err(|err| {
104                LocalError::new(format!(
105                    "Failed to send a message from {:?} to {:?}: {err}",
106                    session.verifier(),
107                    destination
108                ))
109            })?;
110
111            // This would happen in a host task
112            session.add_artifact(&mut accum, artifact)?;
113        }
114
115        for preprocessed in cached_messages {
116            // In production usage, this would happen in a spawned task and relayed back to the main task.
117            debug!("{my_id}: Applying a cached message");
118            let processed = session.process_message(preprocessed);
119
120            // This would happen in a host task.
121            session.add_processed_message(&mut accum, processed)?;
122        }
123
124        loop {
125            match session.can_finalize(&accum) {
126                CanFinalize::Yes => break,
127                CanFinalize::NotYet => {}
128                // Due to already registered invalid messages from nodes,
129                // even if the remaining nodes send correct messages, it won't be enough.
130                // Terminating.
131                CanFinalize::Never => {
132                    tracing::warn!("{my_id}: This session cannot ever be finalized. Terminating.");
133                    return session.terminate_due_to_errors(accum);
134                }
135            }
136
137            debug!("{my_id}: Waiting for a message");
138            let message_in = rx
139                .recv()
140                .await
141                .ok_or_else(|| LocalError::new("Failed to receive a message"))?;
142
143            // Perform quick checks before proceeding with the verification.
144            match session
145                .preprocess_message(&mut accum, &message_in.from, message_in.message)?
146                .ok()
147            {
148                Some(preprocessed) => {
149                    // In production usage, this would happen in a separate task.
150                    debug!("{my_id}: Applying a message from {:?}", message_in.from);
151                    let processed = session.process_message(preprocessed);
152                    // In production usage, this would be a host task.
153                    session.add_processed_message(&mut accum, processed)?;
154                }
155                None => {
156                    trace!("{my_id} Pre-processing complete. Current state: {accum:?}")
157                }
158            }
159        }
160
161        debug!("{my_id}: Finalizing the round");
162
163        match session.finalize_round(rng, accum)? {
164            RoundOutcome::Finished(report) => break Ok(report),
165            RoundOutcome::AnotherRound {
166                session: new_session,
167                cached_messages: new_cached_messages,
168            } => {
169                session = new_session;
170                cached_messages = new_cached_messages;
171            }
172        }
173    }
174}
175
176/// Executes the session waiting for the messages from the `rx` channel
177/// and pushing outgoing messages into the `tx` channel.
178/// The messages are processed in parallel.
179///
180/// This function should be used if message creation and verification takes a significant amount of time,
181/// to offset the parallelizing overhead.
182/// Use [`tokio::run_async`](`crate::dev::tokio::run_async`) to benchmark your specific protocol.
183pub async fn par_run_session<P, SP>(
184    rng: &mut (impl 'static + Clone + CryptoRngCore + Send),
185    tx: &mpsc::Sender<MessageOut<SP>>,
186    rx: &mut mpsc::Receiver<MessageIn<SP>>,
187    session: Session<P, SP>,
188) -> Result<SessionReport<P, SP>, LocalError>
189where
190    P: Protocol<SP::Verifier>,
191    SP: SessionParameters,
192    <SP as SessionParameters>::Signer: Send + Sync,
193    <P as Protocol<SP::Verifier>>::ProtocolError: Send + Sync,
194{
195    let mut session = Arc::new(session);
196    // Some rounds can finalize early and put off sending messages to the next round. Such messages
197    // will be stored here and applied after the messages for this round are sent.
198    let mut cached_messages = Vec::new();
199
200    let my_id = format!("{:?}", session.verifier());
201
202    // Each iteration of the loop progresses the session as follows:
203    //  - Send out messages as dictated by the session "destinations".
204    //  - Apply any cached messages.
205    //  - Enter a nested loop:
206    //      - Try to finalize the session; if we're done, exit the inner loop.
207    //      - Wait until we get an incoming message.
208    //      - Process the message we received and continue the loop.
209    //  - When all messages have been sent and received as specified by the protocol, finalize the
210    //    round.
211    //  - If the protocol outcome is a new round, go to the top of the loop and start over with a
212    //    new session.
213    loop {
214        debug!("{my_id}: *** starting round {:?} ***", session.round_id());
215
216        let (processed_tx, mut processed_rx) = mpsc::channel::<ProcessedMessage<P, SP>>(100);
217        let (outgoing_tx, mut outgoing_rx) = mpsc::channel::<(MessageOut<SP>, ProcessedArtifact<SP>)>(100);
218
219        // This is kept in the main task since it's mutable,
220        // and we don't want to bother with synchronization.
221        let mut accum = session.make_accumulator();
222
223        // Note: generating/sending messages and verifying newly received messages
224        // can be done in parallel, with the results being assembled into `accum`
225        // sequentially in the host task.
226
227        let destinations = session.message_destinations();
228        let mut message_creation_tasks = Vec::new();
229        for destination in destinations {
230            let rng = rng.clone();
231            let session = session.clone();
232            let my_id = my_id.clone();
233            let outgoing_tx = outgoing_tx.clone();
234            let destination = destination.clone();
235            let message_creation = tokio::task::spawn_blocking(move || {
236                let mut rng = rng;
237                let (message, artifact) = session.make_message(&mut rng, &destination)?;
238                debug!("{my_id}: Sending a message to {destination:?}",);
239                let message_out = MessageOut {
240                    session_id: session.session_id().clone(),
241                    from: session.verifier().clone(),
242                    to: destination.clone(),
243                    message,
244                };
245                outgoing_tx.blocking_send((message_out, artifact)).map_err(|err| {
246                    LocalError::new(format!(
247                        "Failed to send a created message from {:?} to {:?}: {err}",
248                        session.verifier(),
249                        destination
250                    ))
251                })
252            });
253            message_creation_tasks.push(message_creation);
254        }
255
256        let mut message_processing_tasks = Vec::new();
257        for preprocessed in cached_messages {
258            let session = session.clone();
259            let processed_tx = processed_tx.clone();
260            let my_id = my_id.clone();
261            let message_processing: JoinHandle<Result<(), LocalError>> = tokio::task::spawn_blocking(move || {
262                debug!("{my_id}: Applying a cached message");
263                let processed = session.process_message(preprocessed);
264                processed_tx
265                    .blocking_send(processed)
266                    .map_err(|_err| LocalError::new("Failed to send a processed message"))
267            });
268            message_processing_tasks.push(message_processing);
269        }
270
271        let can_finalize = loop {
272            match session.can_finalize(&accum) {
273                CanFinalize::Yes => break true,
274                CanFinalize::NotYet => {}
275                // Due to already registered invalid messages from nodes,
276                // even if the remaining nodes send correct messages, it won't be enough.
277                // Terminating.
278                CanFinalize::Never => break false,
279            }
280
281            tokio::select! {
282                processed = processed_rx.recv() => {
283                    if let Some(processed) = processed {
284                        session.add_processed_message(&mut accum, processed)?;
285                    }
286                }
287                outgoing = outgoing_rx.recv() => {
288                    if let Some((message_out, artifact)) = outgoing {
289                        let from = message_out.from.clone();
290                        let to = message_out.to.clone();
291                        tx.send(message_out)
292                        .await
293                        .map_err(|err| {
294                            LocalError::new(format!(
295                                "Failed to send a message from {from:?} to {to:?}: {err}",
296                            ))
297                        })?;
298
299                        session.add_artifact(&mut accum, artifact)?;
300                    }
301                }
302                message_in = rx.recv() => {
303                    if let Some(message_in) = message_in {
304                        match session
305                            .preprocess_message(&mut accum, &message_in.from, message_in.message)?
306                            .ok()
307                        {
308                            Some(preprocessed) => {
309                                let session = session.clone();
310                                let processed_tx = processed_tx.clone();
311                                let my_id = my_id.clone();
312                                let message_processing = tokio::task::spawn_blocking(move || {
313                                    debug!("{my_id}: Applying a message from {:?}", message_in.from);
314                                    let processed = session.process_message(preprocessed);
315                                    processed_tx.blocking_send(processed).map_err(|_err| {
316                                        LocalError::new("Failed to send a processed message")
317                                    })
318                                });
319                                message_processing_tasks.push(message_processing);
320                            }
321                            None => {
322                                trace!("{my_id} Pre-processing complete. Current state: {accum:?}")
323                            }
324                        }
325                    }
326                }
327            }
328        };
329
330        debug!("{my_id}: Finalizing the round {}", session.round_id());
331
332        // Join all the handles created in this iteration.
333
334        for message_creation_task in message_creation_tasks {
335            message_creation_task
336                .await
337                .map_err(|_err| LocalError::new("Failed to join a message creation task"))??;
338        }
339
340        for message_processing_task in message_processing_tasks {
341            message_processing_task
342                .await
343                .map_err(|_err| LocalError::new("Failed to join a message processing task"))??;
344        }
345
346        // Drop our copies of `Sender`s to let the channels close.
347        drop(outgoing_tx);
348        drop(processed_tx);
349
350        // Send all the remaining messages
351        while let Some((message_out, artifact)) = outgoing_rx.recv().await {
352            let from = message_out.from.clone();
353            let to = message_out.to.clone();
354            tx.send(message_out)
355                .await
356                .map_err(|err| LocalError::new(format!("Failed to send a message from {from:?} to {to:?}: {err}",)))?;
357
358            session.add_artifact(&mut accum, artifact)?;
359        }
360
361        debug!("{my_id}: Sent out all remaining messages");
362
363        let session_inner = Arc::into_inner(session)
364            .ok_or_else(|| LocalError::new("There are still references to the session left"))?;
365
366        if !can_finalize {
367            return session_inner.terminate_due_to_errors(accum);
368        }
369
370        match session_inner.finalize_round(rng, accum)? {
371            RoundOutcome::Finished(report) => return Ok(report),
372            RoundOutcome::AnotherRound {
373                session: new_session,
374                cached_messages: new_cached_messages,
375            } => {
376                session = Arc::new(new_session);
377                cached_messages = new_cached_messages;
378            }
379        }
380    }
381}