1use anyhow::Result;
4use bytes::Bytes;
5use std::collections::HashMap;
6use std::sync::Arc;
7use tokio::sync::{mpsc, oneshot, Mutex};
8use tracing::{debug, error, info, warn};
9use webrtc::api::interceptor_registry::register_default_interceptors;
10use webrtc::api::media_engine::MediaEngine;
11use webrtc::api::setting_engine::SettingEngine;
12use webrtc::api::APIBuilder;
13use webrtc::data_channel::data_channel_init::RTCDataChannelInit;
14use webrtc::data_channel::data_channel_message::DataChannelMessage;
15use webrtc::data_channel::RTCDataChannel;
16use webrtc::ice_transport::ice_candidate::RTCIceCandidate;
17use webrtc::ice_transport::ice_server::RTCIceServer;
18use webrtc::interceptor::registry::Registry;
19use webrtc::peer_connection::configuration::RTCConfiguration;
20use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState;
21use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
22use webrtc::peer_connection::RTCPeerConnection;
23
24use super::types::{DataMessage, DataRequest, DataResponse, PeerDirection, PeerId, PeerStateEvent, SignalingMessage, encode_message, encode_request, encode_response, parse_message, hash_to_hex};
25
26pub trait ContentStore: Send + Sync + 'static {
28 fn get(&self, hash_hex: &str) -> Result<Option<Vec<u8>>>;
30}
31
32pub struct PendingRequest {
34 pub hash: Vec<u8>,
35 pub response_tx: oneshot::Sender<Option<Vec<u8>>>,
36}
37
38pub struct Peer {
40 pub peer_id: PeerId,
41 pub direction: PeerDirection,
42 pub created_at: std::time::Instant,
43 pub connected_at: Option<std::time::Instant>,
44
45 pc: Arc<RTCPeerConnection>,
46 pub data_channel: Arc<Mutex<Option<Arc<RTCDataChannel>>>>,
48 signaling_tx: mpsc::Sender<SignalingMessage>,
49 my_peer_id: PeerId,
50
51 store: Option<Arc<dyn ContentStore>>,
53
54 pub pending_requests: Arc<Mutex<HashMap<String, PendingRequest>>>,
56
57 #[allow(dead_code)]
59 message_tx: mpsc::Sender<(DataMessage, Option<Vec<u8>>)>,
60 #[allow(dead_code)]
61 message_rx: Option<mpsc::Receiver<(DataMessage, Option<Vec<u8>>)>>,
62
63 state_event_tx: Option<mpsc::Sender<PeerStateEvent>>,
65}
66
67impl Peer {
68 pub async fn new(
70 peer_id: PeerId,
71 direction: PeerDirection,
72 my_peer_id: PeerId,
73 signaling_tx: mpsc::Sender<SignalingMessage>,
74 stun_servers: Vec<String>,
75 ) -> Result<Self> {
76 Self::new_with_store_and_events(peer_id, direction, my_peer_id, signaling_tx, stun_servers, None, None).await
77 }
78
79 pub async fn new_with_store(
81 peer_id: PeerId,
82 direction: PeerDirection,
83 my_peer_id: PeerId,
84 signaling_tx: mpsc::Sender<SignalingMessage>,
85 stun_servers: Vec<String>,
86 store: Option<Arc<dyn ContentStore>>,
87 ) -> Result<Self> {
88 Self::new_with_store_and_events(peer_id, direction, my_peer_id, signaling_tx, stun_servers, store, None).await
89 }
90
91 pub async fn new_with_store_and_events(
93 peer_id: PeerId,
94 direction: PeerDirection,
95 my_peer_id: PeerId,
96 signaling_tx: mpsc::Sender<SignalingMessage>,
97 stun_servers: Vec<String>,
98 store: Option<Arc<dyn ContentStore>>,
99 state_event_tx: Option<mpsc::Sender<PeerStateEvent>>,
100 ) -> Result<Self> {
101 let mut m = MediaEngine::default();
103 m.register_default_codecs()?;
104
105 let mut registry = Registry::new();
106 registry = register_default_interceptors(registry, &mut m)?;
107
108 let setting_engine = SettingEngine::default();
111 let api = APIBuilder::new()
114 .with_media_engine(m)
115 .with_interceptor_registry(registry)
116 .with_setting_engine(setting_engine)
117 .build();
118
119 let ice_servers: Vec<RTCIceServer> = stun_servers
121 .iter()
122 .map(|url| RTCIceServer {
123 urls: vec![url.clone()],
124 ..Default::default()
125 })
126 .collect();
127
128 let config = RTCConfiguration {
129 ice_servers,
130 ..Default::default()
131 };
132
133 let pc = Arc::new(api.new_peer_connection(config).await?);
134 let (message_tx, message_rx) = mpsc::channel(100);
135 Ok(Self {
136 peer_id,
137 direction,
138 created_at: std::time::Instant::now(),
139 connected_at: None,
140 pc,
141 data_channel: Arc::new(Mutex::new(None)),
142 signaling_tx,
143 my_peer_id,
144 store,
145 pending_requests: Arc::new(Mutex::new(HashMap::new())),
146 message_tx,
147 message_rx: Some(message_rx),
148 state_event_tx,
149 })
150 }
151
152 pub fn set_store(&mut self, store: Arc<dyn ContentStore>) {
154 self.store = Some(store);
155 }
156
157 pub fn state(&self) -> RTCPeerConnectionState {
159 self.pc.connection_state()
160 }
161
162 pub fn signaling_state(&self) -> webrtc::peer_connection::signaling_state::RTCSignalingState {
164 self.pc.signaling_state()
165 }
166
167 pub fn is_connected(&self) -> bool {
169 self.pc.connection_state() == RTCPeerConnectionState::Connected
170 }
171
172 pub async fn setup_handlers(&mut self) -> Result<()> {
174 let peer_id = self.peer_id.clone();
175 let signaling_tx = self.signaling_tx.clone();
176 let my_peer_id_str = self.my_peer_id.to_string();
177 let recipient = self.peer_id.to_string();
178
179 self.pc
181 .on_ice_candidate(Box::new(move |candidate: Option<RTCIceCandidate>| {
182 let signaling_tx = signaling_tx.clone();
183 let my_peer_id_str = my_peer_id_str.clone();
184 let recipient = recipient.clone();
185
186 Box::pin(async move {
187 if let Some(c) = candidate {
188 if let Some(init) = c.to_json().ok() {
189 info!("ICE candidate generated: {}", &init.candidate[..init.candidate.len().min(60)]);
190 let msg = SignalingMessage::candidate(
191 serde_json::to_value(&init).unwrap_or_default(),
192 &recipient,
193 &my_peer_id_str,
194 );
195 if let Err(e) = signaling_tx.send(msg).await {
196 error!("Failed to send ICE candidate: {}", e);
197 }
198 }
199 }
200 })
201 }));
202
203 let peer_id_log = peer_id.clone();
205 let state_event_tx = self.state_event_tx.clone();
206 self.pc
207 .on_peer_connection_state_change(Box::new(move |state: RTCPeerConnectionState| {
208 let peer_id = peer_id_log.clone();
209 let state_event_tx = state_event_tx.clone();
210 Box::pin(async move {
211 info!("Peer {} connection state: {:?}", peer_id.short(), state);
212
213 if let Some(tx) = state_event_tx {
215 let event = match state {
216 RTCPeerConnectionState::Connected => Some(PeerStateEvent::Connected(peer_id)),
217 RTCPeerConnectionState::Failed => Some(PeerStateEvent::Failed(peer_id)),
218 RTCPeerConnectionState::Disconnected | RTCPeerConnectionState::Closed => {
219 Some(PeerStateEvent::Disconnected(peer_id))
220 }
221 _ => None,
222 };
223 if let Some(event) = event {
224 if let Err(e) = tx.send(event).await {
225 error!("Failed to send peer state event: {}", e);
226 }
227 }
228 }
229 })
230 }));
231
232 Ok(())
233 }
234
235 pub async fn connect(&mut self) -> Result<serde_json::Value> {
237 println!("[Peer {}] Creating data channel...", self.peer_id.short());
238 let dc_init = RTCDataChannelInit {
241 ordered: Some(false),
242 ..Default::default()
243 };
244 let dc = self.pc.create_data_channel("hashtree", Some(dc_init)).await?;
245 println!("[Peer {}] Data channel created, setting up handlers...", self.peer_id.short());
246 self.setup_data_channel(dc.clone()).await?;
247 println!("[Peer {}] Handlers set up, storing data channel...", self.peer_id.short());
248 {
249 let mut dc_guard = self.data_channel.lock().await;
250 *dc_guard = Some(dc);
251 }
252 println!("[Peer {}] Data channel stored", self.peer_id.short());
253
254 let offer = self.pc.create_offer(None).await?;
257 let mut gathering_complete = self.pc.gathering_complete_promise().await;
258 self.pc.set_local_description(offer).await?;
259
260 let _ = tokio::time::timeout(
262 std::time::Duration::from_secs(10),
263 gathering_complete.recv()
264 ).await;
265
266 let local_desc = self.pc.local_description().await
268 .ok_or_else(|| anyhow::anyhow!("No local description after gathering"))?;
269
270 debug!("Offer created, SDP len: {}, ice_gathering: {:?}",
271 local_desc.sdp.len(), self.pc.ice_gathering_state());
272
273 let offer_json = serde_json::json!({
275 "type": local_desc.sdp_type.to_string().to_lowercase(),
276 "sdp": local_desc.sdp
277 });
278
279 Ok(offer_json)
280 }
281
282 pub async fn handle_offer(&mut self, offer: serde_json::Value) -> Result<serde_json::Value> {
284 let sdp = offer
285 .get("sdp")
286 .and_then(|s| s.as_str())
287 .ok_or_else(|| anyhow::anyhow!("Missing SDP in offer"))?;
288
289 let peer_id = self.peer_id.clone();
292 let message_tx = self.message_tx.clone();
293 let pending_requests = self.pending_requests.clone();
294 let store = self.store.clone();
295 let data_channel_holder = self.data_channel.clone();
296
297 self.pc
298 .on_data_channel(Box::new(move |dc: Arc<RTCDataChannel>| {
299 let peer_id = peer_id.clone();
300 let message_tx = message_tx.clone();
301 let pending_requests = pending_requests.clone();
302 let store = store.clone();
303 let data_channel_holder = data_channel_holder.clone();
304
305 Box::pin(async move {
307 info!("Peer {} received data channel: {}", peer_id.short(), dc.label());
308
309 {
311 let mut dc_guard = data_channel_holder.lock().await;
312 *dc_guard = Some(dc.clone());
313 }
314
315 Self::setup_dc_handlers(
317 dc.clone(),
318 peer_id,
319 message_tx,
320 pending_requests,
321 store,
322 )
323 .await;
324 })
325 }));
326
327 let offer_desc = RTCSessionDescription::offer(sdp.to_string())?;
329 self.pc.set_remote_description(offer_desc).await?;
330
331 let answer = self.pc.create_answer(None).await?;
334 let mut gathering_complete = self.pc.gathering_complete_promise().await;
335 self.pc.set_local_description(answer).await?;
336
337 let _ = tokio::time::timeout(
339 std::time::Duration::from_secs(10),
340 gathering_complete.recv()
341 ).await;
342
343 let local_desc = self.pc.local_description().await
345 .ok_or_else(|| anyhow::anyhow!("No local description after gathering"))?;
346
347 debug!("Answer created, SDP len: {}, ice_gathering: {:?}",
348 local_desc.sdp.len(), self.pc.ice_gathering_state());
349
350 let answer_json = serde_json::json!({
351 "type": local_desc.sdp_type.to_string().to_lowercase(),
352 "sdp": local_desc.sdp
353 });
354
355 Ok(answer_json)
356 }
357
358 pub async fn handle_answer(&mut self, answer: serde_json::Value) -> Result<()> {
360 let sdp = answer
361 .get("sdp")
362 .and_then(|s| s.as_str())
363 .ok_or_else(|| anyhow::anyhow!("Missing SDP in answer"))?;
364
365 let answer_desc = RTCSessionDescription::answer(sdp.to_string())?;
366 self.pc.set_remote_description(answer_desc).await?;
367
368 Ok(())
369 }
370
371 pub async fn handle_candidate(&mut self, candidate: serde_json::Value) -> Result<()> {
373 let candidate_str = candidate
374 .get("candidate")
375 .and_then(|c| c.as_str())
376 .unwrap_or("");
377
378 let sdp_mid = candidate
379 .get("sdpMid")
380 .and_then(|m| m.as_str())
381 .map(|s| s.to_string());
382
383 let sdp_mline_index = candidate
384 .get("sdpMLineIndex")
385 .and_then(|i| i.as_u64())
386 .map(|i| i as u16);
387
388 if !candidate_str.is_empty() {
389 use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;
390 let init = RTCIceCandidateInit {
391 candidate: candidate_str.to_string(),
392 sdp_mid,
393 sdp_mline_index,
394 username_fragment: candidate
395 .get("usernameFragment")
396 .and_then(|u| u.as_str())
397 .map(|s| s.to_string()),
398 };
399 self.pc.add_ice_candidate(init).await?;
400 }
401
402 Ok(())
403 }
404
405 async fn setup_data_channel(&mut self, dc: Arc<RTCDataChannel>) -> Result<()> {
407 let peer_id = self.peer_id.clone();
408 let message_tx = self.message_tx.clone();
409 let pending_requests = self.pending_requests.clone();
410 let store = self.store.clone();
411
412 Self::setup_dc_handlers(dc, peer_id, message_tx, pending_requests, store).await;
413 Ok(())
414 }
415
416 async fn setup_dc_handlers(
418 dc: Arc<RTCDataChannel>,
419 peer_id: PeerId,
420 message_tx: mpsc::Sender<(DataMessage, Option<Vec<u8>>)>,
421 pending_requests: Arc<Mutex<HashMap<String, PendingRequest>>>,
422 store: Option<Arc<dyn ContentStore>>,
423 ) {
424 let label = dc.label().to_string();
425 let peer_short = peer_id.short();
426
427 let _pending_binary: Arc<Mutex<Option<u32>>> = Arc::new(Mutex::new(None));
429
430 let _dc_for_open = dc.clone();
431 let peer_short_open = peer_short.clone();
432 let label_clone = label.clone();
433 dc.on_open(Box::new(move || {
434 let peer_short_open = peer_short_open.clone();
435 let label_clone = label_clone.clone();
436 Box::pin(async move {
438 info!("[Peer {}] Data channel '{}' open", peer_short_open, label_clone);
439 })
440 }));
441
442 let dc_for_msg = dc.clone();
443 let peer_short_msg = peer_short.clone();
444 let _pending_binary_clone = _pending_binary.clone();
445 let store_clone = store.clone();
446
447 dc.on_message(Box::new(move |msg: DataChannelMessage| {
448 let dc = dc_for_msg.clone();
449 let peer_short = peer_short_msg.clone();
450 let pending_requests = pending_requests.clone();
451 let _pending_binary = _pending_binary_clone.clone();
452 let _message_tx = message_tx.clone();
453 let store = store_clone.clone();
454 let msg_data = msg.data.clone();
455
456 Box::pin(async move {
458 debug!("[Peer {}] Received {} bytes on data channel", peer_short, msg_data.len());
460 match parse_message(&msg_data) {
461 Ok(data_msg) => match data_msg {
462 DataMessage::Request(req) => {
463 let hash_hex = hash_to_hex(&req.h);
464 let hash_short = &hash_hex[..8.min(hash_hex.len())];
465 info!(
466 "[Peer {}] Received request for {}",
467 peer_short, hash_short
468 );
469
470 let data = if let Some(ref store) = store {
472 match store.get(&hash_hex) {
473 Ok(Some(data)) => {
474 info!("[Peer {}] Found {} in store ({} bytes)", peer_short, hash_short, data.len());
475 Some(data)
476 },
477 Ok(None) => {
478 info!("[Peer {}] Hash {} not in store", peer_short, hash_short);
479 None
480 },
481 Err(e) => {
482 warn!("[Peer {}] Store error: {}", peer_short, e);
483 None
484 }
485 }
486 } else {
487 warn!("[Peer {}] No store configured - cannot serve requests", peer_short);
488 None
489 };
490
491 if let Some(data) = data {
493 let data_len = data.len();
494 let response = DataResponse {
495 h: req.h,
496 d: data,
497 };
498 if let Ok(wire) = encode_response(&response) {
499 if let Err(e) = dc.send(&Bytes::from(wire)).await {
500 error!(
501 "[Peer {}] Failed to send response: {}",
502 peer_short, e
503 );
504 } else {
505 info!(
506 "[Peer {}] Sent response for {} ({} bytes)",
507 peer_short, hash_short, data_len
508 );
509 }
510 }
511 } else {
512 info!("[Peer {}] Content not found for {}", peer_short, hash_short);
513 }
514 }
515 DataMessage::Response(res) => {
516 let hash_hex = hash_to_hex(&res.h);
517 let hash_short = &hash_hex[..8.min(hash_hex.len())];
518 debug!(
519 "[Peer {}] Received response for {} ({} bytes)",
520 peer_short, hash_short, res.d.len()
521 );
522
523 let mut pending = pending_requests.lock().await;
525 if let Some(req) = pending.remove(&hash_hex) {
526 let _ = req.response_tx.send(Some(res.d));
527 }
528 }
529 },
530 Err(e) => {
531 warn!("[Peer {}] Failed to parse message: {:?}", peer_short, e);
532 let hex_dump: String = msg_data.iter().take(50).map(|b| format!("{:02x}", b)).collect();
534 warn!("[Peer {}] Message hex: {}", peer_short, hex_dump);
535 }
536 }
537 })
538 }));
539 }
540
541 pub fn has_data_channel(&self) -> bool {
543 self.data_channel
545 .try_lock()
546 .map(|guard| guard.is_some())
547 .unwrap_or(false)
548 }
549
550 pub async fn request(&self, hash_hex: &str) -> Result<Option<Vec<u8>>> {
552 let dc_guard = self.data_channel.lock().await;
553 let dc = dc_guard
554 .as_ref()
555 .ok_or_else(|| anyhow::anyhow!("No data channel"))?
556 .clone();
557 drop(dc_guard); let hash = hex::decode(hash_hex)
561 .map_err(|e| anyhow::anyhow!("Invalid hex hash: {}", e))?;
562
563 let (tx, rx) = oneshot::channel();
565
566 {
568 let mut pending = self.pending_requests.lock().await;
569 pending.insert(
570 hash_hex.to_string(),
571 PendingRequest {
572 hash: hash.clone(),
573 response_tx: tx,
574 },
575 );
576 }
577
578 let req = DataRequest {
580 h: hash,
581 htl: crate::webrtc::types::MAX_HTL,
582 };
583 let wire = encode_request(&req)?;
584 dc.send(&Bytes::from(wire)).await?;
585
586 debug!(
587 "[Peer {}] Sent request for {}",
588 self.peer_id.short(),
589 &hash_hex[..8.min(hash_hex.len())]
590 );
591
592 match tokio::time::timeout(std::time::Duration::from_secs(10), rx).await {
594 Ok(Ok(data)) => Ok(data),
595 Ok(Err(_)) => {
596 Ok(None)
598 }
599 Err(_) => {
600 let mut pending = self.pending_requests.lock().await;
602 pending.remove(hash_hex);
603 Ok(None)
604 }
605 }
606 }
607
608 pub async fn send_message(&self, msg: &DataMessage) -> Result<()> {
610 let dc_guard = self.data_channel.lock().await;
611 if let Some(ref dc) = *dc_guard {
612 let wire = encode_message(msg)?;
613 dc.send(&Bytes::from(wire)).await?;
614 }
615 Ok(())
616 }
617
618 pub async fn close(&self) -> Result<()> {
620 {
621 let dc_guard = self.data_channel.lock().await;
622 if let Some(ref dc) = *dc_guard {
623 dc.close().await?;
624 }
625 }
626 self.pc.close().await?;
627 Ok(())
628 }
629}