use super::{QueryResult, Session};

use crate::client::{connections::CmdResponse, Error, Result};
use crate::{at_least_one_correct_elder, elder_count};
use sn_interface::messaging::{
    data::{CmdError, DataQuery, QueryResponse},
    DstLocation, MsgId, MsgKind, ServiceAuth, WireMsg,
};
use sn_interface::network_knowledge::prefix_map::NetworkPrefixMap;
use sn_interface::types::{Peer, PeerLinks, PublicKey, SendToOneError};

use backoff::{backoff::Backoff, ExponentialBackoff};
use bytes::Bytes;
use dashmap::DashMap;
use futures::future::join_all;
use qp2p::{Close, Config as QuicP2pConfig, ConnectionError, Endpoint, SendError};
use rand::{rngs::OsRng, seq::SliceRandom};
use std::{net::SocketAddr, path::PathBuf, sync::Arc, time::Duration};
use tokio::{
    sync::mpsc::{channel, Sender},
    sync::RwLock,
    task::JoinHandle,
};
use tracing::{debug, error, info, instrument, trace, warn};
use xor_name::XorName;

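// Number of Elders subset to send queries to.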
pub(crate) const NUM_OF_ELDERS_SUBSET_FOR_QUERIES: usize = 3;

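// Number of nodes to contact per batch when making initial contact with the network.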
pub(crate) const NODES_TO_CONTACT_PER_STARTUP_BATCH: usize = 3;

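// Initial wait, in seconds, to give the client a chance to pick up network knowledge before checking it.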
const INITIAL_WAIT: u64 = 1;

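// Number of times to retry sending a message to a peer when the first attempt fails.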
const CLIENT_SEND_RETRIES: usize = 1;

impl Session {
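    /// Create a new session, opening a client endpoint and setting up the shared state
    /// (the pending cmds/queries maps) used to track in-flight messages.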
    #[instrument(skip(err_sender), level = "debug")]
    pub(crate) fn new(
        client_pk: PublicKey,
        genesis_key: bls::PublicKey,
        qp2p_config: QuicP2pConfig,
        err_sender: Sender<CmdError>,
        local_addr: SocketAddr,
        cmd_ack_wait: Duration,
        prefix_map: NetworkPrefixMap,
    ) -> Result<Session> {
        let endpoint = Endpoint::new_client(local_addr, qp2p_config)?;
        let peer_links = PeerLinks::new(endpoint.clone());

        let session = Session {
            pending_queries: Arc::new(DashMap::default()),
            incoming_err_sender: Arc::new(err_sender),
            pending_cmds: Arc::new(DashMap::default()),
            endpoint,
            network: Arc::new(prefix_map),
            genesis_key,
            initial_connection_check_msg_id: Arc::new(RwLock::new(None)),
            cmd_ack_wait,
            peer_links,
        };

        Ok(session)
    }

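    /// Send a cmd to the Elders closest to the dst address, then wait until a
    /// supermajority of them have ack'd it, a majority have errored, or we time out.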
    #[instrument(skip(self, auth, payload), level = "debug", name = "session send cmd")]
    pub(crate) async fn send_cmd(
        &self,
        dst_address: XorName,
        auth: ServiceAuth,
        payload: Bytes,
    ) -> Result<()> {
        let endpoint = self.endpoint.clone();
        let (section_pk, elders) = self.get_cmd_elders(dst_address).await?;

        let msg_id = MsgId::new();

        debug!(
            "Sending cmd w/id {:?}, from {}, to {} Elders w/ dst: {:?}",
            msg_id,
            endpoint.public_addr(),
            elders.len(),
            dst_address
        );

        let dst_location = DstLocation::Section {
            name: dst_address,
            section_pk,
        };

        let msg_kind = MsgKind::ServiceMsg(auth);
        let wire_msg = WireMsg::new_msg(msg_id, payload, msg_kind, dst_location)?;

        let elders_len = elders.len();
        let (sender, mut receiver) = channel::<CmdResponse>(elders_len);
        let _ = self.pending_cmds.insert(msg_id, sender);
        trace!("Inserted channel for cmd {:?}", msg_id);

        let res = send_msg(self.clone(), elders, wire_msg, msg_id).await;

        if res.is_err() {
            return res;
        }

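        // We wait for acks from a supermajority (2/3) of the contacted Elders.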
        let expected_acks = std::cmp::max(1, elders_len * 2 / 3);
        let mut received_ack = 0;
        let mut received_err = 0;
        let mut attempts = 0;
        let interval = Duration::from_millis(1000);
        let expected_cmd_ack_wait_attempts =
            std::cmp::max(10, self.cmd_ack_wait.as_millis() / interval.as_millis());
        loop {
            match receiver.try_recv() {
                Ok((src, None)) => {
                    received_ack += 1;
                    trace!(
                        "received CmdAck of {:?} from {:?}, so far {:?} / {:?}",
                        msg_id,
                        src,
                        received_ack,
                        expected_acks
                    );
                    if received_ack >= expected_acks {
                        let _ = self.pending_cmds.remove(&msg_id);
                        break;
                    }
                }
                Ok((src, Some(error))) => {
                    received_err += 1;
                    error!(
                        "received error response {:?} of cmd {:?} from {:?}, so far {:?} acks vs. {:?} errors",
                        error, msg_id, src, received_ack, received_err
                    );
                    if received_err >= expected_acks {
                        error!("Received majority of error responses for cmd {:?}", msg_id);
                        let _ = self.pending_cmds.remove(&msg_id);
                        return Err(Error::from((error, msg_id)));
                    }
                }
                Err(_err) => {
                    // Nothing on the channel yet; we'll sleep below and poll again.
                }
            }
            attempts += 1;
            if attempts >= expected_cmd_ack_wait_attempts {
                warn!(
                    "Terminated with insufficient CmdAcks for {:?}, {:?} / {:?} acks received",
                    msg_id, received_ack, expected_acks
                );
                break;
            }
            trace!(
                "current ack waiting loop count {:?}/{:?}",
                attempts,
                expected_cmd_ack_wait_attempts
            );
            tokio::time::sleep(interval).await;
        }

        trace!("Wait for any cmd response/reaction (e.g. AE msgs) is over");
        res
    }

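    /// Send a query to a subset of the Elders closest to the data name, returning the
    /// first valid response; error responses are only surfaced once every contacted
    /// Elder has returned one.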
    #[instrument(skip_all, level = "debug")]
    pub(crate) async fn send_query(
        &self,
        query: DataQuery,
        auth: ServiceAuth,
        payload: Bytes,
    ) -> Result<QueryResult> {
        let endpoint = self.endpoint.clone();

        let chunk_addr = if let DataQuery::GetChunk(address) = query {
            Some(address)
        } else {
            None
        };

        let dst = query.dst_name();

        let (section_pk, elders) = self.get_query_elders(dst).await?;
        let elders_len = elders.len();
        let msg_id = MsgId::new();

        debug!(
            "Sending query message {:?}, msg_id: {:?}, from {}, to the {} Elders closest to data name: {:?}",
            query,
            msg_id,
            endpoint.public_addr(),
            elders_len,
            elders
        );

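        // Channel to receive responses on; capacity sized for one response per Elder
        // of a section (7 being the usual Elder count).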
        let (sender, mut receiver) = channel::<QueryResponse>(7);

        if let Ok(op_id) = query.operation_id() {
            trace!("Inserting channel for op_id {:?}", (msg_id, op_id));
            if let Some(mut entry) = self.pending_queries.get_mut(&op_id) {
                let senders_vec = entry.value_mut();
                senders_vec.push((msg_id, sender))
            } else {
                let _nonexistent_entry = self.pending_queries.insert(op_id, vec![(msg_id, sender)]);
            }

            trace!("Inserted channel for {:?}", op_id);
        } else {
            warn!("No op_id found for query");
        }

        let dst_location = DstLocation::Section {
            name: dst,
            section_pk,
        };
        let msg_kind = MsgKind::ServiceMsg(auth);
        let wire_msg = WireMsg::new_msg(msg_id, payload, msg_kind, dst_location)?;

        send_msg(self.clone(), elders, wire_msg, msg_id).await?;

        let mut discarded_responses: usize = 0;

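        // Keep reading responses, discarding errors and invalid chunks, until we get a
        // valid response or all the Elders we contacted have replied with an error.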
        let response = loop {
            let mut error_response = None;
            match (receiver.recv().await, chunk_addr) {
                (Some(QueryResponse::GetChunk(Ok(chunk))), Some(chunk_addr)) => {
                    debug!("Chunk QueryResponse received is: {:#?}", chunk);

                    if chunk_addr.name() == chunk.name() {
                        trace!("Valid Chunk received for {:?}", msg_id);
                        break Some(QueryResponse::GetChunk(Ok(chunk)));
                    } else {
                        warn!("We received an invalid Chunk response from one of the nodes");
                        discarded_responses += 1;
                    }
                }
                (response @ Some(QueryResponse::GetChunk(Err(_))), Some(_))
                | (response @ Some(QueryResponse::GetRegister((Err(_), _))), None)
                | (response @ Some(QueryResponse::GetRegisterPolicy((Err(_), _))), None)
                | (response @ Some(QueryResponse::GetRegisterOwner((Err(_), _))), None)
                | (response @ Some(QueryResponse::GetRegisterUserPermissions((Err(_), _))), None) => {
                    debug!("QueryResponse error received (but may be overridden by a non-error response from another elder): {:#?}", &response);
                    error_response = response;
                    discarded_responses += 1;
                }
                (Some(response), _) => {
                    debug!("QueryResponse received is: {:#?}", response);
                    break Some(response);
                }
                (None, _) => {
                    debug!("QueryResponse channel closed.");
                    break None;
                }
            }
            if discarded_responses == elders_len {
                break error_response;
            }
        };

        debug!(
            "Response obtained for query w/id {:?}: {:?}",
            msg_id, response
        );

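        // Remove the listener we registered for this msg_id from the pending queries map.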
        if let Some(query) = &response {
            if let Ok(query_op_id) = query.operation_id() {
                trace!("Removing channel for {:?}", (msg_id, &query_op_id));
                if let Some(mut entry) = self.pending_queries.get_mut(&query_op_id) {
                    let listeners_for_op = entry.value_mut();
                    if let Some(index) = listeners_for_op
                        .iter()
                        .position(|(id, _sender)| *id == msg_id)
                    {
                        let _old_listener = listeners_for_op.swap_remove(index);
                    }
                } else {
                    warn!("No listeners found for our op_id: {:?}", query_op_id)
                }
            }
        }

        match response {
            Some(response) => {
                let operation_id = response
                    .operation_id()
                    .map_err(|_| Error::UnknownOperationId)?;
                Ok(QueryResult {
                    response,
                    operation_id,
                })
            }
            None => Err(Error::NoResponse),
        }
    }

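    /// Make initial contact with the network: msg a first batch of nodes, then keep
    /// msging further batches (with backoff) until we hold a SAP with a full set of
    /// Elders, or we've tried every contact we know of.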
    #[instrument(skip_all, level = "debug")]
    pub(crate) async fn make_contact_with_nodes(
        &self,
        nodes: Vec<Peer>,
        dst_address: XorName,
        auth: ServiceAuth,
        payload: Bytes,
    ) -> Result<(), Error> {
        let endpoint = self.endpoint.clone();
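        // Prefer the Elders of the known section closest to the dst address; if we have
        // no such knowledge yet, fall back to the bootstrap nodes we were given.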
        let (elders_or_adults, section_pk) =
            if let Some(sap) = self.network.closest_or_opposite(&dst_address, None) {
                let mut nodes: Vec<_> = sap.elders_vec();

                nodes.shuffle(&mut OsRng);

                (nodes, sap.section_key())
            } else {
                (nodes, self.genesis_key)
            };

        let msg_id = MsgId::new();

        debug!(
            "Making initial contact with nodes. Our PublicAddr: {:?}. Using {:?} to {} nodes: {:?}",
            endpoint.public_addr(),
            msg_id,
            elders_or_adults.len(),
            elders_or_adults
        );

        let dst_location = DstLocation::Section {
            name: dst_address,
            section_pk,
        };
        let msg_kind = MsgKind::ServiceMsg(auth);
        let wire_msg = WireMsg::new_msg(msg_id, payload, msg_kind, dst_location)?;

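        // Send the initial contact msg to a first small batch only; more batches are
        // tried below if network knowledge doesn't arrive.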
        let initial_contacts = elders_or_adults
            .clone()
            .into_iter()
            .take(NODES_TO_CONTACT_PER_STARTUP_BATCH)
            .collect();

        send_msg(self.clone(), initial_contacts, wire_msg.clone(), msg_id).await?;

        *self.initial_connection_check_msg_id.write().await = Some(msg_id);

        let mut knowledge_checks = 0;
        let mut outgoing_msg_rounds = 1;
        let mut last_start_pos = 0;
        let mut tried_every_contact = false;

        let mut backoff = ExponentialBackoff {
            initial_interval: Duration::from_millis(1500),
            max_interval: Duration::from_secs(5),
            max_elapsed_time: Some(Duration::from_secs(60)),
            ..Default::default()
        };

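        // Reset so the custom initial_interval is used for the first next_backoff() call.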
        backoff.reset();

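        // Give the network a moment to deliver any AE responses before we check knowledge.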
        tokio::time::sleep(Duration::from_secs(INITIAL_WAIT)).await;

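        // If all we hold is the genesis key we have no real network knowledge yet, so
        // wait here until we've learnt of at least one section with a full Elder set.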
        if section_pk == self.genesis_key {
            info!("On client startup, awaiting some network knowledge");

            let mut known_sap = self.network.closest_or_opposite(&dst_address, None);

            let mut insufficient_sap_peers = false;

            if let Some(sap) = known_sap.clone() {
                if sap.elders_vec().len() < elder_count() {
                    insufficient_sap_peers = true;
                }
            }

            while known_sap.is_none() || insufficient_sap_peers {
                if tried_every_contact {
                    return Err(Error::NetworkContact);
                }

                let stats = self.network.known_sections_count();
                debug!("Client still has not received a complete section's AE-Retry message... Current sections known: {:?}. Do we have insufficient peers: {:?}", stats, insufficient_sap_peers);
                knowledge_checks += 1;

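                // Only send out further batches of contact msgs once a couple of knowledge checks have failed.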
                if knowledge_checks > 2 {
                    let mut start_pos = outgoing_msg_rounds * NODES_TO_CONTACT_PER_STARTUP_BATCH;
                    outgoing_msg_rounds += 1;

                    if start_pos > elders_or_adults.len() {
                        start_pos = last_start_pos;
                    }

                    last_start_pos = start_pos;

                    let next_batch_end = start_pos + NODES_TO_CONTACT_PER_STARTUP_BATCH;

                    let next_contacts = if next_batch_end > elders_or_adults.len() {
                        let next = elders_or_adults[start_pos..].to_vec();
                        tried_every_contact = true;

                        next
                    } else {
                        elders_or_adults[start_pos..start_pos + NODES_TO_CONTACT_PER_STARTUP_BATCH]
                            .to_vec()
                    };

                    trace!("Sending out another batch of initial contact msgs to new nodes");
                    send_msg(self.clone(), next_contacts, wire_msg.clone(), msg_id).await?;

                    let next_wait = backoff.next_backoff();
                    trace!(
                        "Awaiting a duration of {:?} before trying new nodes",
                        next_wait
                    );

                    if let Some(wait) = next_wait {
                        tokio::time::sleep(wait).await;
                    }

                    known_sap = self.network.closest_or_opposite(&dst_address, None);

                    debug!("Known sap: {known_sap:?}");
                    insufficient_sap_peers = false;
                    if let Some(sap) = known_sap.clone() {
                        if sap.elders_vec().len() < elder_count() {
                            debug!("Known elders: {:?}", sap.elders_vec().len());
                            insufficient_sap_peers = true;
                        }
                    }
                }
            }

            let stats = self.network.known_sections_count();
            debug!("Client has received updated network knowledge. Current sections known: {:?}. Sap for our startup-query: {:?}", stats, known_sap);
        }

        Ok(())
    }

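    /// Pick the subset of Elders to send a query to: a random sample from the SAP
    /// closest to the dst address.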
    async fn get_query_elders(&self, dst: XorName) -> Result<(bls::PublicKey, Vec<Peer>)> {
        let sap = self.network.closest_or_opposite(&dst, None);
        let (section_pk, mut elders) = if let Some(sap) = &sap {
            (sap.section_key(), sap.elders_vec())
        } else {
            return Err(Error::NoNetworkKnowledge);
        };

        elders.shuffle(&mut OsRng);

        let elders: Vec<_> = elders
            .into_iter()
            .take(NUM_OF_ELDERS_SUBSET_FOR_QUERIES)
            .collect();

        let elders_len = elders.len();
        if elders_len < NUM_OF_ELDERS_SUBSET_FOR_QUERIES && elders_len > 1 {
            return Err(Error::InsufficientElderConnections {
                connections: elders_len,
                required: NUM_OF_ELDERS_SUBSET_FOR_QUERIES,
            });
        }

        Ok((section_pk, elders))
    }

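    /// Pick the Elders to send a cmd to: a random sample, from the SAP closest to the
    /// dst address, large enough to reach at least one correct (non-faulty) Elder.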
    async fn get_cmd_elders(&self, dst_address: XorName) -> Result<(bls::PublicKey, Vec<Peer>)> {
        let (mut elders, section_pk) =
            if let Some(sap) = self.network.closest_or_opposite(&dst_address, None) {
                let sap_elders = sap.elders_vec();

                trace!("SAP elders found {:?}", sap_elders);

                (sap_elders, sap.section_key())
            } else {
                return Err(Error::NoNetworkKnowledge);
            };

        let targets_count = at_least_one_correct_elder();

        if elders.len() < targets_count {
            error!("Insufficient knowledge to send to {:?}", dst_address);
            return Err(Error::InsufficientElderKnowledge {
                connections: elders.len(),
                required: targets_count,
                section_pk,
            });
        }

        elders.shuffle(&mut OsRng);

        let elders = elders.into_iter().take(targets_count).collect();

        Ok((section_pk, elders))
    }
}

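/// Send `wire_msg` to all the given nodes in parallel, returning `Ok(())` if at least
/// as many sends succeeded as failed, and the last error observed otherwise.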
#[instrument(skip_all, level = "trace")]
pub(super) async fn send_msg(
    session: Session,
    nodes: Vec<Peer>,
    wire_msg: WireMsg,
    msg_id: MsgId,
) -> Result<()> {
    let priority = wire_msg.clone().into_msg()?.priority();
    let msg_bytes = wire_msg.serialize()?;

    let mut last_error = None;
    drop(wire_msg);

    let mut tasks = Vec::default();

    let successes = Arc::new(RwLock::new(0));

    for peer in nodes.clone() {
        let session = session.clone();
        let msg_bytes_clone = msg_bytes.clone();
        let peer_name = peer.name();

        let task_handle: JoinHandle<(XorName, Result<()>)> = tokio::spawn(async move {
            let link = session.peer_links.get_or_create(&peer).await;

            let listen = |conn, incoming_msgs| {
                Session::spawn_msg_listener_thread(session.clone(), peer, conn, incoming_msgs);
            };

            let mut retries = 0;

            let send_and_retry = || async {
                match link
                    .send_with(msg_bytes_clone.clone(), priority, None, listen)
                    .await
                {
                    Ok(()) => Ok(()),
                    Err(error) => match error {
                        SendToOneError::Connection(err) => Err(Error::QuicP2pConnection(err)),
                        SendToOneError::Send(err) => Err(Error::QuicP2pSend(err)),
                    },
                }
            };
            let mut result = send_and_retry().await;

            while result.is_err() && retries < CLIENT_SEND_RETRIES {
                warn!(
                    "Attempting to send msg {msg_id:?} again, attempt #{:?}",
                    retries
                );
                retries += 1;
                result = send_and_retry().await;
            }

            (peer_name, result)
        });

        tasks.push(task_handle);
    }

    let results = join_all(tasks).await;

    for r in results {
        match r {
            Ok((peer_name, send_result)) => match send_result {
                Err(Error::QuicP2pSend(SendError::ConnectionLost(ConnectionError::Closed(
                    Close::Application { reason, error_code },
                )))) => {
                    warn!(
                        "Connection was closed by node {}, reason: {:?}",
                        peer_name,
                        String::from_utf8(reason.to_vec())
                    );
                    last_error = Some(Error::QuicP2pSend(SendError::ConnectionLost(
                        ConnectionError::Closed(Close::Application { reason, error_code }),
                    )));
                }
                Err(Error::QuicP2pSend(SendError::ConnectionLost(error))) => {
                    warn!("Connection to {} was lost: {:?}", peer_name, error);
                    last_error = Some(Error::QuicP2pSend(SendError::ConnectionLost(error)));
                }
                Err(error) => {
                    warn!(
                        "Issue during {:?} send to {}: {:?}",
                        msg_id, peer_name, error
                    );
                    last_error = Some(error);
                }
                Ok(_) => *successes.write().await += 1,
            },
            Err(join_error) => {
                warn!("Tokio join error while sending: {:?}", join_error)
            }
        }
    }

    let failures = nodes.len() - *successes.read().await;

    if failures > 0 {
        trace!(
            "Sending message {:?} from {} failed for {}/{} nodes: {:?}",
            msg_id,
            session.endpoint.public_addr(),
            failures,
            nodes.len(),
            nodes,
        );
    }

    let successful_sends = *successes.read().await;
    if failures > successful_sends {
        warn!("More errors when sending a message than successes");
        if let Some(error) = last_error {
            warn!("The relevant error is: {error}");
            return Err(error);
        }
    }

    Ok(())
}

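/// Ensure the client's `~/.safe` directory exists, and return its path.
///
/// A minimal usage sketch (`prefix_maps` here is just an illustrative subdir name):
/// ```ignore
/// let safe_dir = create_safe_dir().await?;
/// let cache_dir = safe_dir.join("prefix_maps");
/// ```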
#[instrument(skip_all, level = "trace")]
pub(crate) async fn create_safe_dir() -> Result<PathBuf, Error> {
    let mut root_dir = dirs_next::home_dir().ok_or(Error::CouldNotReadHomeDir)?;
    root_dir.push(".safe");

    tokio::fs::create_dir_all(&root_dir)
        .await
        .map_err(|_| Error::CouldNotCreateSafeDir)?;

    Ok(root_dir)
}