1use alloc::{format, sync::Arc, vec::Vec};
4
5use rand_core::CryptoRngCore;
6use tokio::{sync::mpsc, task::JoinHandle};
7use tracing::{debug, trace};
8
9use super::{
10 message::Message,
11 session::{CanFinalize, ProcessedArtifact, ProcessedMessage, RoundOutcome, Session, SessionId, SessionParameters},
12 transcript::SessionReport,
13 LocalError,
14};
15use crate::protocol::Protocol;
16
/// An outgoing message produced by a local session, together with its routing information.
///
/// The runners in this module push values of this type into the `tx` channel;
/// delivering them to the recipient is up to whoever owns the receiving end of that channel.
#[derive(Debug)]
pub struct MessageOut<SP: SessionParameters> {
    /// The ID of the session that created the message (a clone of `Session::session_id()`).
    pub session_id: SessionId,
    /// The verifier of the party that created the message (a clone of `Session::verifier()`).
    pub from: SP::Verifier,
    /// The verifier of the party the message is addressed to
    /// (one of the entries of `Session::message_destinations()`).
    pub to: SP::Verifier,
    /// The message payload, as returned by `Session::make_message()`.
    pub message: Message<SP::Verifier>,
}
35
/// An incoming message for a local session.
///
/// The runners in this module pull values of this type from the `rx` channel
/// and hand both fields to `Session::preprocess_message()`.
#[derive(Debug)]
pub struct MessageIn<SP: SessionParameters> {
    /// The verifier of the party the message is attributed to.
    pub from: SP::Verifier,
    /// The message payload.
    pub message: Message<SP::Verifier>,
}
46
47pub async fn run_session<P, SP>(
50 rng: &mut impl CryptoRngCore,
51 tx: &mpsc::Sender<MessageOut<SP>>,
52 rx: &mut mpsc::Receiver<MessageIn<SP>>,
53 session: Session<P, SP>,
54) -> Result<SessionReport<P, SP>, LocalError>
55where
56 P: Protocol<SP::Verifier>,
57 SP: SessionParameters,
58{
59 let mut session = session;
60 let mut cached_messages = Vec::new();
63
64 let my_id = format!("{:?}", session.verifier());
65
66 loop {
78 debug!("{my_id}: *** starting round {:?} ***", session.round_id());
79
80 let mut accum = session.make_accumulator();
83
84 let destinations = session.message_destinations();
89 for destination in destinations.iter() {
90 let (message, artifact) = session.make_message(rng, destination)?;
95 debug!("{my_id}: Sending a message to {destination:?}",);
96 tx.send(MessageOut {
97 session_id: session.session_id().clone(),
98 from: session.verifier().clone(),
99 to: destination.clone(),
100 message,
101 })
102 .await
103 .map_err(|err| {
104 LocalError::new(format!(
105 "Failed to send a message from {:?} to {:?}: {err}",
106 session.verifier(),
107 destination
108 ))
109 })?;
110
111 session.add_artifact(&mut accum, artifact)?;
113 }
114
115 for preprocessed in cached_messages {
116 debug!("{my_id}: Applying a cached message");
118 let processed = session.process_message(preprocessed);
119
120 session.add_processed_message(&mut accum, processed)?;
122 }
123
124 loop {
125 match session.can_finalize(&accum) {
126 CanFinalize::Yes => break,
127 CanFinalize::NotYet => {}
128 CanFinalize::Never => {
132 tracing::warn!("{my_id}: This session cannot ever be finalized. Terminating.");
133 return session.terminate_due_to_errors(accum);
134 }
135 }
136
137 debug!("{my_id}: Waiting for a message");
138 let message_in = rx
139 .recv()
140 .await
141 .ok_or_else(|| LocalError::new("Failed to receive a message"))?;
142
143 match session
145 .preprocess_message(&mut accum, &message_in.from, message_in.message)?
146 .ok()
147 {
148 Some(preprocessed) => {
149 debug!("{my_id}: Applying a message from {:?}", message_in.from);
151 let processed = session.process_message(preprocessed);
152 session.add_processed_message(&mut accum, processed)?;
154 }
155 None => {
156 trace!("{my_id} Pre-processing complete. Current state: {accum:?}")
157 }
158 }
159 }
160
161 debug!("{my_id}: Finalizing the round");
162
163 match session.finalize_round(rng, accum)? {
164 RoundOutcome::Finished(report) => break Ok(report),
165 RoundOutcome::AnotherRound {
166 session: new_session,
167 cached_messages: new_cached_messages,
168 } => {
169 session = new_session;
170 cached_messages = new_cached_messages;
171 }
172 }
173 }
174}
175
176pub async fn par_run_session<P, SP>(
184 rng: &mut (impl 'static + Clone + CryptoRngCore + Send),
185 tx: &mpsc::Sender<MessageOut<SP>>,
186 rx: &mut mpsc::Receiver<MessageIn<SP>>,
187 session: Session<P, SP>,
188) -> Result<SessionReport<P, SP>, LocalError>
189where
190 P: Protocol<SP::Verifier>,
191 SP: SessionParameters,
192 <SP as SessionParameters>::Signer: Send + Sync,
193 <P as Protocol<SP::Verifier>>::ProtocolError: Send + Sync,
194{
195 let mut session = Arc::new(session);
196 let mut cached_messages = Vec::new();
199
200 let my_id = format!("{:?}", session.verifier());
201
202 loop {
214 debug!("{my_id}: *** starting round {:?} ***", session.round_id());
215
216 let (processed_tx, mut processed_rx) = mpsc::channel::<ProcessedMessage<P, SP>>(100);
217 let (outgoing_tx, mut outgoing_rx) = mpsc::channel::<(MessageOut<SP>, ProcessedArtifact<SP>)>(100);
218
219 let mut accum = session.make_accumulator();
222
223 let destinations = session.message_destinations();
228 let mut message_creation_tasks = Vec::new();
229 for destination in destinations {
230 let rng = rng.clone();
231 let session = session.clone();
232 let my_id = my_id.clone();
233 let outgoing_tx = outgoing_tx.clone();
234 let destination = destination.clone();
235 let message_creation = tokio::task::spawn_blocking(move || {
236 let mut rng = rng;
237 let (message, artifact) = session.make_message(&mut rng, &destination)?;
238 debug!("{my_id}: Sending a message to {destination:?}",);
239 let message_out = MessageOut {
240 session_id: session.session_id().clone(),
241 from: session.verifier().clone(),
242 to: destination.clone(),
243 message,
244 };
245 outgoing_tx.blocking_send((message_out, artifact)).map_err(|err| {
246 LocalError::new(format!(
247 "Failed to send a created message from {:?} to {:?}: {err}",
248 session.verifier(),
249 destination
250 ))
251 })
252 });
253 message_creation_tasks.push(message_creation);
254 }
255
256 let mut message_processing_tasks = Vec::new();
257 for preprocessed in cached_messages {
258 let session = session.clone();
259 let processed_tx = processed_tx.clone();
260 let my_id = my_id.clone();
261 let message_processing: JoinHandle<Result<(), LocalError>> = tokio::task::spawn_blocking(move || {
262 debug!("{my_id}: Applying a cached message");
263 let processed = session.process_message(preprocessed);
264 processed_tx
265 .blocking_send(processed)
266 .map_err(|_err| LocalError::new("Failed to send a processed message"))
267 });
268 message_processing_tasks.push(message_processing);
269 }
270
271 let can_finalize = loop {
272 match session.can_finalize(&accum) {
273 CanFinalize::Yes => break true,
274 CanFinalize::NotYet => {}
275 CanFinalize::Never => break false,
279 }
280
281 tokio::select! {
282 processed = processed_rx.recv() => {
283 if let Some(processed) = processed {
284 session.add_processed_message(&mut accum, processed)?;
285 }
286 }
287 outgoing = outgoing_rx.recv() => {
288 if let Some((message_out, artifact)) = outgoing {
289 let from = message_out.from.clone();
290 let to = message_out.to.clone();
291 tx.send(message_out)
292 .await
293 .map_err(|err| {
294 LocalError::new(format!(
295 "Failed to send a message from {from:?} to {to:?}: {err}",
296 ))
297 })?;
298
299 session.add_artifact(&mut accum, artifact)?;
300 }
301 }
302 message_in = rx.recv() => {
303 if let Some(message_in) = message_in {
304 match session
305 .preprocess_message(&mut accum, &message_in.from, message_in.message)?
306 .ok()
307 {
308 Some(preprocessed) => {
309 let session = session.clone();
310 let processed_tx = processed_tx.clone();
311 let my_id = my_id.clone();
312 let message_processing = tokio::task::spawn_blocking(move || {
313 debug!("{my_id}: Applying a message from {:?}", message_in.from);
314 let processed = session.process_message(preprocessed);
315 processed_tx.blocking_send(processed).map_err(|_err| {
316 LocalError::new("Failed to send a processed message")
317 })
318 });
319 message_processing_tasks.push(message_processing);
320 }
321 None => {
322 trace!("{my_id} Pre-processing complete. Current state: {accum:?}")
323 }
324 }
325 }
326 }
327 }
328 };
329
330 debug!("{my_id}: Finalizing the round {}", session.round_id());
331
332 for message_creation_task in message_creation_tasks {
335 message_creation_task
336 .await
337 .map_err(|_err| LocalError::new("Failed to join a message creation task"))??;
338 }
339
340 for message_processing_task in message_processing_tasks {
341 message_processing_task
342 .await
343 .map_err(|_err| LocalError::new("Failed to join a message processing task"))??;
344 }
345
346 drop(outgoing_tx);
348 drop(processed_tx);
349
350 while let Some((message_out, artifact)) = outgoing_rx.recv().await {
352 let from = message_out.from.clone();
353 let to = message_out.to.clone();
354 tx.send(message_out)
355 .await
356 .map_err(|err| LocalError::new(format!("Failed to send a message from {from:?} to {to:?}: {err}",)))?;
357
358 session.add_artifact(&mut accum, artifact)?;
359 }
360
361 debug!("{my_id}: Sent out all remaining messages");
362
363 let session_inner = Arc::into_inner(session)
364 .ok_or_else(|| LocalError::new("There are still references to the session left"))?;
365
366 if !can_finalize {
367 return session_inner.terminate_due_to_errors(accum);
368 }
369
370 match session_inner.finalize_round(rng, accum)? {
371 RoundOutcome::Finished(report) => return Ok(report),
372 RoundOutcome::AnotherRound {
373 session: new_session,
374 cached_messages: new_cached_messages,
375 } => {
376 session = Arc::new(new_session);
377 cached_messages = new_cached_messages;
378 }
379 }
380 }
381}