1use crate::protocol::{
11 bytes_to_hash, create_fragment_response, create_request, create_response, encode_request,
12 encode_response, hash_to_key, is_fragmented, parse_message, DataMessage as ProtoMessage,
13 DataResponse, FRAGMENT_SIZE,
14};
15use crate::types::{
16 should_forward, ForwardRequest, ForwardTx, PeerHTLConfig, PeerId, PeerState, SignalingMessage,
17 DATA_CHANNEL_LABEL, MAX_HTL,
18};
19use bytes::Bytes;
20use hashtree_core::{Hash, Store};
21use lru::LruCache;
22use std::collections::HashMap;
23use std::num::NonZeroUsize;
24use std::sync::Arc;
25use thiserror::Error;
26use tokio::sync::{mpsc, oneshot, RwLock};
27use webrtc::api::interceptor_registry::register_default_interceptors;
28use webrtc::api::media_engine::MediaEngine;
29use webrtc::api::APIBuilder;
30use webrtc::data_channel::data_channel_init::RTCDataChannelInit;
31use webrtc::data_channel::data_channel_message::DataChannelMessage;
32use webrtc::data_channel::RTCDataChannel;
33use webrtc::ice_transport::ice_candidate::{RTCIceCandidate, RTCIceCandidateInit};
34use webrtc::ice_transport::ice_server::RTCIceServer;
35use webrtc::interceptor::registry::Registry;
36use webrtc::peer_connection::configuration::RTCConfiguration;
37use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState;
38use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
39use webrtc::peer_connection::RTCPeerConnection;
40
41#[derive(Debug, Error)]
42pub enum PeerError {
43 #[error("WebRTC error: {0}")]
44 WebRTC(#[from] webrtc::Error),
45 #[error("JSON error: {0}")]
46 Json(#[from] serde_json::Error),
47 #[error("Channel closed")]
48 ChannelClosed,
49 #[error("Request timeout")]
50 Timeout,
51 #[error("Peer not ready")]
52 NotReady,
53 #[error("Data not found")]
54 NotFound,
55}
56
/// Capacity of the LRU cache that remembers requests received from the remote peer.
const THEIR_REQUESTS_SIZE: usize = 200;

/// Stall timeout, in milliseconds, for reassembling a fragmented response.
// `allow(dead_code)`: may be unreferenced depending on the reassembly implementation.
#[allow(dead_code)]
const FRAGMENT_STALL_TIMEOUT_MS: u64 = 5_000;
/// Overall timeout, in milliseconds, for reassembling a fragmented response.
#[allow(dead_code)]
const FRAGMENT_TOTAL_TIMEOUT_MS: u64 = 120_000;
65
66struct PendingRequest {
69 #[allow(dead_code)] hash: Hash,
71 response_tx: oneshot::Sender<Option<Vec<u8>>>,
72}
73
74#[derive(Debug, Clone)]
77struct TheirRequest {
78 hash: Hash,
80 #[allow(dead_code)]
82 requested_at: std::time::Instant,
83}
84
85struct PendingReassembly {
87 #[allow(dead_code)] hash: Hash,
89 fragments: HashMap<u32, Vec<u8>>,
90 total_expected: u32,
91 received_bytes: usize,
92 #[allow(dead_code)]
94 first_fragment_at: std::time::Instant,
95 last_fragment_at: std::time::Instant,
96}
97
98pub type ForwardRequestCallback = Arc<
102 dyn Fn(Hash, String, u8) -> futures::future::BoxFuture<'static, Option<Vec<u8>>> + Send + Sync,
103>;
104
105async fn forward_via_channel(
107 forward_tx: &ForwardTx,
108 hash: Hash,
109 exclude_peer_id: String,
110 htl: u8,
111) -> Option<Vec<u8>> {
112 let (response_tx, response_rx) = oneshot::channel();
113 let req = ForwardRequest {
114 hash,
115 exclude_peer_id,
116 htl,
117 response: response_tx,
118 };
119
120 if forward_tx.send(req).await.is_err() {
121 return None;
122 }
123
124 response_rx.await.ok().flatten()
125}
126
127pub struct Peer<S: Store> {
136 pub remote_id: PeerId,
138 state: Arc<RwLock<PeerState>>,
140 connection: Arc<RTCPeerConnection>,
142 data_channel: Arc<RwLock<Option<Arc<RTCDataChannel>>>>,
144 pending_candidates: Arc<RwLock<Vec<RTCIceCandidateInit>>>,
146 pending_requests: Arc<RwLock<HashMap<String, PendingRequest>>>,
149 their_requests: Arc<RwLock<LruCache<String, TheirRequest>>>,
153 pending_reassemblies: Arc<RwLock<HashMap<String, PendingReassembly>>>,
155 signaling_tx: mpsc::Sender<SignalingMessage>,
157 local_store: Arc<S>,
159 local_peer_id: String,
161 debug: bool,
163 htl_config: PeerHTLConfig,
165 forward_tx: Option<ForwardTx>,
167 on_forward_request: Option<ForwardRequestCallback>,
169}
170
171impl<S: Store + 'static> Peer<S> {
172 pub async fn new(
174 remote_id: PeerId,
175 local_peer_id: String,
176 signaling_tx: mpsc::Sender<SignalingMessage>,
177 local_store: Arc<S>,
178 debug: bool,
179 ) -> Result<Self, PeerError> {
180 Self::with_forward_channel(
181 remote_id,
182 local_peer_id,
183 signaling_tx,
184 local_store,
185 debug,
186 None,
187 )
188 .await
189 }
190
191 pub async fn with_forward_channel(
193 remote_id: PeerId,
194 local_peer_id: String,
195 signaling_tx: mpsc::Sender<SignalingMessage>,
196 local_store: Arc<S>,
197 debug: bool,
198 forward_tx: Option<ForwardTx>,
199 ) -> Result<Self, PeerError> {
200 let mut media_engine = MediaEngine::default();
202 media_engine.register_default_codecs()?;
203
204 let mut registry = Registry::new();
205 registry = register_default_interceptors(registry, &mut media_engine)?;
206
207 let api = APIBuilder::new()
208 .with_media_engine(media_engine)
209 .with_interceptor_registry(registry)
210 .build();
211
212 let config = RTCConfiguration {
214 ice_servers: vec![RTCIceServer {
215 urls: vec![
216 "stun:stun.iris.to:3478".to_string(),
217 "stun:stun.l.google.com:19302".to_string(),
218 "stun:stun.cloudflare.com:3478".to_string(),
219 ],
220 ..Default::default()
221 }],
222 ..Default::default()
223 };
224
225 let connection = Arc::new(api.new_peer_connection(config).await?);
226
227 let peer = Self {
228 remote_id,
229 state: Arc::new(RwLock::new(PeerState::New)),
230 connection,
231 data_channel: Arc::new(RwLock::new(None)),
232 pending_candidates: Arc::new(RwLock::new(Vec::new())),
233 pending_requests: Arc::new(RwLock::new(HashMap::new())),
234 their_requests: Arc::new(RwLock::new(LruCache::new(
235 NonZeroUsize::new(THEIR_REQUESTS_SIZE).unwrap(),
236 ))),
237 pending_reassemblies: Arc::new(RwLock::new(HashMap::new())),
238 signaling_tx,
239 local_store,
240 local_peer_id,
241 debug,
242 htl_config: PeerHTLConfig::random(),
243 forward_tx,
244 on_forward_request: None,
245 };
246
247 peer.setup_handlers().await?;
248
249 Ok(peer)
250 }
251
252 async fn setup_handlers(&self) -> Result<(), PeerError> {
254 let state = self.state.clone();
255 let data_channel = self.data_channel.clone();
256 let pending_requests = self.pending_requests.clone();
257 let their_requests = self.their_requests.clone();
258 let pending_reassemblies = self.pending_reassemblies.clone();
259 let local_store = self.local_store.clone();
260 let debug = self.debug;
261 let htl_config = self.htl_config;
262 let forward_tx = self.forward_tx.clone();
263 let on_forward_request = self.on_forward_request.clone();
264 let peer_id_str = self.remote_id.to_peer_string();
265
266 let state_clone = state.clone();
268 self.connection.on_peer_connection_state_change(Box::new(
269 move |s: RTCPeerConnectionState| {
270 let state = state_clone.clone();
271 Box::pin(async move {
272 if debug {
273 println!("[Peer] Connection state changed: {:?}", s);
274 }
275 let mut state = state.write().await;
276 match s {
277 RTCPeerConnectionState::Connected => {
278 *state = PeerState::Connected;
279 if debug {
280 println!("[Peer] Connection established");
281 }
282 }
283 RTCPeerConnectionState::Disconnected
284 | RTCPeerConnectionState::Failed
285 | RTCPeerConnectionState::Closed => {
286 *state = PeerState::Disconnected;
287 if debug {
288 println!("[Peer] Connection closed: {:?}", s);
289 }
290 }
291 _ => {}
292 }
293 })
294 },
295 ));
296
297 let data_channel_clone = data_channel.clone();
299 let pending_requests_clone = pending_requests.clone();
300 let their_requests_clone = their_requests.clone();
301 let pending_reassemblies_clone = pending_reassemblies.clone();
302 let local_store_clone = local_store.clone();
303 let state_clone = state.clone();
304 let forward_tx_clone = forward_tx.clone();
305 let on_forward_clone = on_forward_request.clone();
306 let peer_id_clone = peer_id_str.clone();
307 self.connection.on_data_channel(Box::new(move |dc| {
308 let data_channel = data_channel_clone.clone();
309 let pending_requests = pending_requests_clone.clone();
310 let their_requests = their_requests_clone.clone();
311 let pending_reassemblies = pending_reassemblies_clone.clone();
312 let local_store = local_store_clone.clone();
313 let state = state_clone.clone();
314 let forward_tx = forward_tx_clone.clone();
315 let on_forward = on_forward_clone.clone();
316 let peer_id = peer_id_clone.clone();
317
318 Box::pin(async move {
319 if dc.label() == DATA_CHANNEL_LABEL {
320 Self::setup_data_channel_handlers(
321 dc.clone(),
322 pending_requests,
323 their_requests,
324 pending_reassemblies,
325 local_store,
326 debug,
327 htl_config,
328 forward_tx,
329 on_forward,
330 peer_id,
331 )
332 .await;
333 *data_channel.write().await = Some(dc);
334 *state.write().await = PeerState::Ready;
335 if debug {
336 println!("[Peer] Data channel opened (incoming)");
337 }
338 }
339 })
340 }));
341
342 let signaling_tx = self.signaling_tx.clone();
344 let local_peer_id = self.local_peer_id.clone();
345 let remote_id = self.remote_id.to_peer_string();
346 self.connection
347 .on_ice_candidate(Box::new(move |candidate: Option<RTCIceCandidate>| {
348 let signaling_tx = signaling_tx.clone();
349 let local_peer_id = local_peer_id.clone();
350 let remote_id = remote_id.clone();
351
352 Box::pin(async move {
353 if let Some(candidate) = candidate {
354 let json = candidate.to_json().unwrap();
355 let msg = SignalingMessage::Candidate {
356 peer_id: local_peer_id,
357 target_peer_id: remote_id,
358 candidate: json.candidate,
359 sdp_m_line_index: json.sdp_mline_index,
360 sdp_mid: json.sdp_mid,
361 };
362 let _ = signaling_tx.send(msg).await;
363 }
364 })
365 }));
366
367 let debug_clone = debug;
369 self.connection
370 .on_ice_connection_state_change(Box::new(move |s| {
371 if debug_clone {
372 println!("[Peer] ICE connection state: {:?}", s);
373 }
374 Box::pin(async {})
375 }));
376
377 let debug_clone2 = debug;
379 self.connection
380 .on_ice_gathering_state_change(Box::new(move |s| {
381 if debug_clone2 {
382 println!("[Peer] ICE gathering state: {:?}", s);
383 }
384 Box::pin(async {})
385 }));
386
387 Ok(())
388 }
389
390 #[allow(clippy::too_many_arguments)]
393 async fn setup_data_channel_handlers(
394 dc: Arc<RTCDataChannel>,
395 pending_requests: Arc<RwLock<HashMap<String, PendingRequest>>>,
396 their_requests: Arc<RwLock<LruCache<String, TheirRequest>>>,
397 pending_reassemblies: Arc<RwLock<HashMap<String, PendingReassembly>>>,
398 local_store: Arc<S>,
399 debug: bool,
400 htl_config: PeerHTLConfig,
401 forward_tx: Option<ForwardTx>,
402 on_forward_request: Option<ForwardRequestCallback>,
403 peer_id: String,
404 ) {
405 let pending_requests_clone = pending_requests.clone();
406 let their_requests_clone = their_requests.clone();
407 let pending_reassemblies_clone = pending_reassemblies.clone();
408 let local_store_clone = local_store.clone();
409 let dc_clone = dc.clone();
410 let forward_tx_clone = forward_tx.clone();
411 let on_forward_clone = on_forward_request.clone();
412 let peer_id_clone = peer_id.clone();
413
414 dc.on_message(Box::new(move |msg: DataChannelMessage| {
415 let pending_requests = pending_requests_clone.clone();
416 let their_requests = their_requests_clone.clone();
417 let pending_reassemblies = pending_reassemblies_clone.clone();
418 let local_store = local_store_clone.clone();
419 let dc = dc_clone.clone();
420 let forward_tx = forward_tx_clone.clone();
421 let on_forward = on_forward_clone.clone();
422 let peer_id = peer_id_clone.clone();
423
424 Box::pin(async move {
425 let data = msg.data.to_vec();
426 if data.is_empty() {
427 return;
428 }
429
430 let parsed = match parse_message(&data) {
432 Some(m) => m,
433 None => {
434 if debug {
435 println!("[Peer] Failed to parse message");
436 }
437 return;
438 }
439 };
440
441 match parsed {
442 ProtoMessage::Request(req) => {
443 let htl = req.htl.unwrap_or(MAX_HTL);
444 let hash_key = hash_to_key(&req.h);
445
446 if debug {
447 println!(
448 "[Peer] Request: hash={}..., htl={}",
449 &hash_key[..16.min(hash_key.len())],
450 htl
451 );
452 }
453
454 let hash_bytes = match bytes_to_hash(&req.h) {
456 Some(h) => h,
457 None => return,
458 };
459
460 let local_result = local_store.get(&hash_bytes).await;
462
463 if let Ok(Some(payload)) = local_result {
464 Self::send_response(&dc, &hash_bytes, payload, debug).await;
466 return;
467 }
468
469 let can_forward = forward_tx.is_some() || on_forward.is_some();
471 if can_forward && should_forward(htl) {
472 {
474 let mut their_reqs = their_requests.write().await;
475 their_reqs.put(
476 hash_key.clone(),
477 TheirRequest {
478 hash: hash_bytes,
479 requested_at: std::time::Instant::now(),
480 },
481 );
482 }
483
484 let forward_htl = htl_config.decrement(htl);
486
487 if debug {
488 println!(
489 "[Peer] Forwarding request htl={}->{}, hash={}...",
490 htl,
491 forward_htl,
492 &hash_key[..16.min(hash_key.len())]
493 );
494 }
495
496 let forward_result = if let Some(ref tx) = forward_tx {
498 forward_via_channel(tx, hash_bytes, peer_id.clone(), forward_htl)
499 .await
500 } else if let Some(ref forward_cb) = on_forward {
501 forward_cb(hash_bytes, peer_id.clone(), forward_htl).await
502 } else {
503 None
504 };
505
506 if let Some(payload) = forward_result {
507 their_requests.write().await.pop(&hash_key);
509 Self::send_response(&dc, &hash_bytes, payload, debug).await;
510
511 if debug {
512 println!(
513 "[Peer] Forward success for hash={}...",
514 &hash_key[..16.min(hash_key.len())]
515 );
516 }
517 return;
518 }
519 }
520
521 {
524 let mut their_reqs = their_requests.write().await;
525 their_reqs.put(
526 hash_key,
527 TheirRequest {
528 hash: hash_bytes,
529 requested_at: std::time::Instant::now(),
530 },
531 );
532 }
533 }
534 ProtoMessage::Response(res) => {
535 let hash_key = hash_to_key(&res.h);
536
537 let final_data = if is_fragmented(&res) {
539 Self::handle_fragment_response(&res, &pending_reassemblies, debug).await
541 } else {
542 Some(res.d)
544 };
545
546 let final_data = match final_data {
547 Some(d) => d,
548 None => return, };
550
551 if debug {
552 println!(
553 "[Peer] Response: hash={}..., size={}",
554 &hash_key[..16.min(hash_key.len())],
555 final_data.len()
556 );
557 }
558
559 let mut requests = pending_requests.write().await;
561 if let Some(request) = requests.remove(&hash_key) {
562 let computed_hash = hashtree_core::sha256(&final_data);
564 if computed_hash.to_vec() == res.h {
565 let _ = request.response_tx.send(Some(final_data));
566 } else {
567 if debug {
568 println!("[Peer] Hash mismatch for response");
569 }
570 let _ = request.response_tx.send(None);
571 }
572 }
573 }
574 ProtoMessage::QuoteRequest(_) | ProtoMessage::QuoteResponse(_) => {
575 if debug {
576 println!("[Peer] Ignoring quote message on legacy peer path");
577 }
578 }
579 }
580 })
581 }));
582 }
583
584 async fn send_response(dc: &Arc<RTCDataChannel>, hash: &Hash, data: Vec<u8>, debug: bool) {
586 if data.len() <= FRAGMENT_SIZE {
587 let res = create_response(hash, data);
589 let encoded = encode_response(&res);
590 let _ = dc.send(&Bytes::from(encoded)).await;
591 } else {
592 let total_fragments = data.len().div_ceil(FRAGMENT_SIZE) as u32;
594 for i in 0..total_fragments {
595 let start = (i as usize) * FRAGMENT_SIZE;
596 let end = std::cmp::min(start + FRAGMENT_SIZE, data.len());
597 let fragment = data[start..end].to_vec();
598
599 let res = create_fragment_response(hash, fragment, i, total_fragments);
600 let encoded = encode_response(&res);
601 let _ = dc.send(&Bytes::from(encoded)).await;
602
603 if debug && i == 0 {
604 println!("[Peer] Sending {} fragments for hash", total_fragments);
605 }
606 }
607 }
608 }
609
610 async fn handle_fragment_response(
612 res: &DataResponse,
613 pending_reassemblies: &Arc<RwLock<HashMap<String, PendingReassembly>>>,
614 debug: bool,
615 ) -> Option<Vec<u8>> {
616 let hash_key = hash_to_key(&res.h);
617 let now = std::time::Instant::now();
618 let index = res.i.unwrap();
619 let total = res.n.unwrap();
620
621 let mut reassemblies = pending_reassemblies.write().await;
622
623 let pending = reassemblies.entry(hash_key.clone()).or_insert_with(|| {
624 let hash = bytes_to_hash(&res.h).unwrap_or([0u8; 32]);
625 PendingReassembly {
626 hash,
627 fragments: HashMap::new(),
628 total_expected: total,
629 received_bytes: 0,
630 first_fragment_at: now,
631 last_fragment_at: now,
632 }
633 });
634
635 if !pending.fragments.contains_key(&index) {
637 pending.received_bytes += res.d.len();
638 pending.fragments.insert(index, res.d.clone());
639 pending.last_fragment_at = now;
640 }
641
642 if pending.fragments.len() == pending.total_expected as usize {
644 let total = pending.total_expected;
645 let mut assembled = Vec::with_capacity(pending.received_bytes);
646 for i in 0..total {
647 if let Some(fragment) = pending.fragments.get(&i) {
648 assembled.extend_from_slice(fragment);
649 }
650 }
651 reassemblies.remove(&hash_key);
652
653 if debug {
654 println!(
655 "[Peer] Reassembled {} fragments, {} bytes",
656 total,
657 assembled.len()
658 );
659 }
660
661 Some(assembled)
662 } else {
663 None }
665 }
666
667 pub async fn connect(&self) -> Result<(), PeerError> {
669 *self.state.write().await = PeerState::Connecting;
670
671 let dc_init = RTCDataChannelInit {
674 ordered: Some(false),
675 ..Default::default()
676 };
677 let dc = self
678 .connection
679 .create_data_channel(DATA_CHANNEL_LABEL, Some(dc_init))
680 .await?;
681
682 Self::setup_data_channel_handlers(
683 dc.clone(),
684 self.pending_requests.clone(),
685 self.their_requests.clone(),
686 self.pending_reassemblies.clone(),
687 self.local_store.clone(),
688 self.debug,
689 self.htl_config,
690 self.forward_tx.clone(),
691 self.on_forward_request.clone(),
692 self.remote_id.to_peer_string(),
693 )
694 .await;
695
696 let data_channel = self.data_channel.clone();
697 let state = self.state.clone();
698 let debug = self.debug;
699 dc.on_open(Box::new(move || {
700 let _data_channel = data_channel.clone();
701 let state = state.clone();
702
703 Box::pin(async move {
704 *state.write().await = PeerState::Ready;
705 if debug {
706 println!("[Peer] Data channel opened (outgoing)");
707 }
708 })
709 }));
710
711 *self.data_channel.write().await = Some(dc);
712
713 let offer = self.connection.create_offer(None).await?;
715 self.connection.set_local_description(offer.clone()).await?;
716
717 let msg = SignalingMessage::Offer {
718 peer_id: self.local_peer_id.clone(),
719 target_peer_id: self.remote_id.to_peer_string(),
720 sdp: offer.sdp,
721 };
722 self.signaling_tx
723 .send(msg)
724 .await
725 .map_err(|_| PeerError::ChannelClosed)?;
726
727 Ok(())
728 }
729
730 pub async fn handle_signaling(&self, msg: SignalingMessage) -> Result<(), PeerError> {
732 match msg {
733 SignalingMessage::Offer { sdp, .. } => {
734 if self.debug {
735 println!("[Peer] Received offer, setting remote description");
736 }
737 let offer = RTCSessionDescription::offer(sdp)?;
738 self.connection.set_remote_description(offer).await?;
739
740 let candidates = self
742 .pending_candidates
743 .write()
744 .await
745 .drain(..)
746 .collect::<Vec<_>>();
747 if self.debug && !candidates.is_empty() {
748 println!(
749 "[Peer] Adding {} pending candidates after offer",
750 candidates.len()
751 );
752 }
753 for candidate in candidates {
754 self.connection.add_ice_candidate(candidate).await?;
755 }
756
757 let answer = self.connection.create_answer(None).await?;
759 self.connection
760 .set_local_description(answer.clone())
761 .await?;
762
763 let msg = SignalingMessage::Answer {
764 peer_id: self.local_peer_id.clone(),
765 target_peer_id: self.remote_id.to_peer_string(),
766 sdp: answer.sdp,
767 };
768 self.signaling_tx
769 .send(msg)
770 .await
771 .map_err(|_| PeerError::ChannelClosed)?;
772
773 *self.state.write().await = PeerState::Connecting;
774 }
775 SignalingMessage::Answer { sdp, .. } => {
776 if self.debug {
777 println!("[Peer] Received answer, setting remote description");
778 }
779 let answer = RTCSessionDescription::answer(sdp)?;
780 self.connection.set_remote_description(answer).await?;
781
782 let candidates = self
784 .pending_candidates
785 .write()
786 .await
787 .drain(..)
788 .collect::<Vec<_>>();
789 if self.debug && !candidates.is_empty() {
790 println!(
791 "[Peer] Adding {} pending candidates after answer",
792 candidates.len()
793 );
794 }
795 for candidate in candidates {
796 self.connection.add_ice_candidate(candidate).await?;
797 }
798 }
799 SignalingMessage::Candidate {
800 candidate,
801 sdp_m_line_index,
802 sdp_mid,
803 ..
804 } => {
805 let init = RTCIceCandidateInit {
806 candidate: candidate.clone(),
807 sdp_mid,
808 sdp_mline_index: sdp_m_line_index,
809 ..Default::default()
810 };
811
812 if self.connection.remote_description().await.is_some() {
814 if self.debug {
815 println!(
816 "[Peer] Adding ICE candidate: {}...",
817 &candidate[..candidate.len().min(50)]
818 );
819 }
820 self.connection.add_ice_candidate(init).await?;
821 } else {
822 if self.debug {
823 println!("[Peer] Queueing ICE candidate (no remote description yet)");
824 }
825 self.pending_candidates.write().await.push(init);
826 }
827 }
828 SignalingMessage::Candidates { candidates, .. } => {
829 for c in candidates {
830 let init = RTCIceCandidateInit {
831 candidate: c.candidate,
832 sdp_mid: c.sdp_mid,
833 sdp_mline_index: c.sdp_m_line_index,
834 ..Default::default()
835 };
836
837 if self.connection.remote_description().await.is_some() {
838 self.connection.add_ice_candidate(init).await?;
839 } else {
840 self.pending_candidates.write().await.push(init);
841 }
842 }
843 }
844 _ => {}
845 }
846
847 Ok(())
848 }
849
850 pub async fn request(&self, hash: &Hash) -> Result<Option<Vec<u8>>, PeerError> {
852 self.request_with_htl(hash, MAX_HTL).await
853 }
854
855 pub async fn request_with_htl(
858 &self,
859 hash: &Hash,
860 htl: u8,
861 ) -> Result<Option<Vec<u8>>, PeerError> {
862 let state = *self.state.read().await;
863 if state != PeerState::Ready {
864 return Err(PeerError::NotReady);
865 }
866
867 let dc = self.data_channel.read().await;
868 let dc = dc.as_ref().ok_or(PeerError::NotReady)?;
869
870 let hash_key = hash_to_key(hash);
871
872 {
874 let requests = self.pending_requests.read().await;
875 if requests.contains_key(&hash_key) {
876 drop(requests);
878 }
880 }
881
882 let (tx, rx) = oneshot::channel();
884 self.pending_requests.write().await.insert(
885 hash_key.clone(),
886 PendingRequest {
887 hash: *hash,
888 response_tx: tx,
889 },
890 );
891
892 let send_htl = self.htl_config.decrement(htl);
895 let req = create_request(hash, send_htl);
896 let encoded = encode_request(&req);
897 dc.send(&Bytes::from(encoded)).await?;
898
899 if self.debug {
900 println!(
901 "[Peer] Sent request: htl={}, hash={}...",
902 send_htl,
903 &hash_key[..16.min(hash_key.len())]
904 );
905 }
906
907 match tokio::time::timeout(std::time::Duration::from_secs(10), rx).await {
909 Ok(Ok(data)) => Ok(data),
910 Ok(Err(_)) => Err(PeerError::ChannelClosed),
911 Err(_) => {
912 self.pending_requests.write().await.remove(&hash_key);
914 Err(PeerError::Timeout)
915 }
916 }
917 }
918
919 pub async fn send_response_for_hash(
923 &self,
924 hash: &Hash,
925 data: Option<&[u8]>,
926 ) -> Result<(), PeerError> {
927 let dc = self.data_channel.read().await;
928 let dc = dc.as_ref().ok_or(PeerError::NotReady)?;
929
930 if let Some(payload) = data {
931 Self::send_response(dc, hash, payload.to_vec(), self.debug).await;
933 }
934 Ok(())
938 }
939
940 pub async fn state(&self) -> PeerState {
942 *self.state.read().await
943 }
944
945 pub async fn close(&self) -> Result<(), PeerError> {
947 self.connection.close().await?;
948 *self.state.write().await = PeerState::Disconnected;
949 Ok(())
950 }
951
952 pub fn set_on_forward_request<F>(&mut self, callback: F)
956 where
957 F: Fn(Hash, String, u8) -> futures::future::BoxFuture<'static, Option<Vec<u8>>>
958 + Send
959 + Sync
960 + 'static,
961 {
962 self.on_forward_request = Some(Arc::new(callback));
963 }
964
965 pub fn htl_config(&self) -> PeerHTLConfig {
967 self.htl_config
968 }
969
970 pub async fn send_data(&self, hash_hex: &str, data: &[u8]) -> Result<bool, PeerError> {
973 let their_req = {
974 let mut requests = self.their_requests.write().await;
975 requests.pop(hash_hex)
976 };
977
978 let Some(their_req) = their_req else {
979 return Ok(false);
980 };
981
982 let dc = self.data_channel.read().await;
983 let dc = dc.as_ref().ok_or(PeerError::NotReady)?;
984
985 Self::send_response(dc, &their_req.hash, data.to_vec(), self.debug).await;
987
988 if self.debug {
989 println!(
990 "[Peer] Sent data for hash: {}...",
991 &hash_hex[..16.min(hash_hex.len())]
992 );
993 }
994
995 Ok(true)
996 }
997
998 pub async fn has_requested(&self, hash_hex: &str) -> bool {
1000 self.their_requests.read().await.peek(hash_hex).is_some()
1001 }
1002
1003 pub async fn their_request_count(&self) -> usize {
1005 self.their_requests.read().await.len()
1006 }
1007
1008 pub async fn our_request_count(&self) -> usize {
1010 self.pending_requests.read().await.len()
1011 }
1012}
1013
1014#[cfg(test)]
1015mod tests {
1016 use super::*;
1017 use std::collections::HashMap;
1018
1019 #[tokio::test]
1020 async fn test_fragment_reassembly_completes_and_clears_pending() {
1021 let pending = Arc::new(RwLock::new(HashMap::new()));
1022 let hash = vec![0x11u8; 32];
1023
1024 let first = DataResponse {
1025 h: hash.clone(),
1026 d: b"world".to_vec(),
1027 i: Some(1),
1028 n: Some(2),
1029 };
1030 let second = DataResponse {
1031 h: hash.clone(),
1032 d: b"hello ".to_vec(),
1033 i: Some(0),
1034 n: Some(2),
1035 };
1036
1037 let incomplete =
1038 Peer::<hashtree_core::MemoryStore>::handle_fragment_response(&first, &pending, false)
1039 .await;
1040 assert!(incomplete.is_none());
1041 assert_eq!(pending.read().await.len(), 1);
1042
1043 let completed =
1044 Peer::<hashtree_core::MemoryStore>::handle_fragment_response(&second, &pending, false)
1045 .await;
1046 assert_eq!(completed, Some(b"hello world".to_vec()));
1047 assert_eq!(pending.read().await.len(), 0);
1048 }
1049
1050 #[tokio::test]
1051 async fn test_fragment_reassembly_ignores_duplicate_fragment() {
1052 let pending = Arc::new(RwLock::new(HashMap::new()));
1053 let hash = vec![0x22u8; 32];
1054
1055 let frag0 = DataResponse {
1056 h: hash.clone(),
1057 d: b"abc".to_vec(),
1058 i: Some(0),
1059 n: Some(2),
1060 };
1061 let frag1 = DataResponse {
1062 h: hash,
1063 d: b"def".to_vec(),
1064 i: Some(1),
1065 n: Some(2),
1066 };
1067
1068 let r1 =
1070 Peer::<hashtree_core::MemoryStore>::handle_fragment_response(&frag0, &pending, false)
1071 .await;
1072 assert!(r1.is_none());
1073
1074 let r2 =
1076 Peer::<hashtree_core::MemoryStore>::handle_fragment_response(&frag0, &pending, false)
1077 .await;
1078 assert!(r2.is_none());
1079
1080 let r3 =
1081 Peer::<hashtree_core::MemoryStore>::handle_fragment_response(&frag1, &pending, false)
1082 .await;
1083 assert_eq!(r3, Some(b"abcdef".to_vec()));
1084 assert_eq!(pending.read().await.len(), 0);
1085 }
1086}