1use crate::conn_pool::ConnPool;
2use crate::discovery::Discovery;
3use crate::error::{Result, WebTorrentError};
4use crate::nat::NatTraversal;
5use crate::throttling::ThrottleGroup;
6use crate::torrent::Torrent;
7use bytes::Bytes;
8use rand::Rng;
9use std::collections::HashMap;
10use std::sync::Arc;
11use tokio::sync::{mpsc, RwLock};
12use tokio::time::{Duration, Instant};
13
14#[derive(Debug, Clone)]
16pub struct WebTorrentOptions {
17 pub peer_id: Option<[u8; 20]>,
18 pub node_id: Option<[u8; 20]>,
19 pub torrent_port: u16,
20 pub dht_port: u16,
21 pub max_conns: usize,
22 pub utp: bool,
23 pub nat_upnp: bool,
24 pub nat_pmp: bool,
25 pub lsd: bool,
26 pub ut_pex: bool,
27 pub seed_outgoing_connections: bool,
28 pub download_limit: Option<u64>, pub upload_limit: Option<u64>, pub blocklist: Option<String>,
31 pub tracker: Option<TrackerConfig>,
32 pub web_seeds: bool,
33}
34
35impl Default for WebTorrentOptions {
36 fn default() -> Self {
37 Self {
38 peer_id: None,
39 node_id: None,
40 torrent_port: 0,
41 dht_port: 0,
42 max_conns: 55,
43 utp: true,
44 nat_upnp: true,
45 nat_pmp: true,
46 lsd: true,
47 ut_pex: true,
48 seed_outgoing_connections: true,
49 download_limit: None,
50 upload_limit: None,
51 blocklist: None,
52 tracker: None,
53 web_seeds: true,
54 }
55 }
56}
57
/// Tracker settings supplied through [`WebTorrentOptions::tracker`].
///
/// `Debug` and `Clone` are implemented by hand because the callback field supports
/// neither. Cloning discards the callback.
pub struct TrackerConfig {
    /// Announce URLs.
    pub announce: Vec<String>,
    /// Optional callback producing extra announce parameters as string key/value pairs.
    #[cfg_attr(not(test), allow(dead_code))]
    pub get_announce_opts: Option<Box<dyn Fn() -> HashMap<String, String> + Send + Sync>>,
}
63
64impl std::fmt::Debug for TrackerConfig {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 f.debug_struct("TrackerConfig")
67 .field("announce", &self.announce)
68 .field("get_announce_opts", &"<function>")
69 .finish()
70 }
71}
72
73impl Clone for TrackerConfig {
74 fn clone(&self) -> Self {
75 Self {
76 announce: self.announce.clone(),
77 get_announce_opts: None, }
79 }
80}
81
82pub struct WebTorrent {
84 pub(crate) peer_id: [u8; 20],
85 pub(crate) node_id: [u8; 20],
86 pub(crate) options: WebTorrentOptions,
87 pub(crate) torrents: Arc<RwLock<Vec<Arc<Torrent>>>>,
88 pub(crate) conn_pool: Arc<RwLock<Option<Arc<ConnPool>>>>,
89 pub(crate) nat_traversal: Option<Arc<NatTraversal>>,
90 pub(crate) dht: Option<Arc<Discovery>>,
91 pub(crate) destroyed: Arc<RwLock<bool>>,
92 pub(crate) listening: Arc<RwLock<bool>>,
93 pub(crate) ready: Arc<RwLock<bool>>,
94 pub(crate) torrent_port: Arc<RwLock<u16>>,
95 pub(crate) dht_port: Arc<RwLock<u16>>,
96 pub(crate) event_tx: mpsc::UnboundedSender<ClientEvent>,
97 pub(crate) event_rx: Arc<RwLock<mpsc::UnboundedReceiver<ClientEvent>>>,
98 download_speed_tracker: Arc<SpeedTracker>,
100 upload_speed_tracker: Arc<SpeedTracker>,
101 download_throttle: Arc<ThrottleGroup>,
103 upload_throttle: Arc<ThrottleGroup>,
104}
105
106struct SpeedTracker {
108 bytes: Arc<RwLock<Vec<(Instant, u64)>>>, window: Duration,
110}
111
112impl SpeedTracker {
113 fn new(window: Duration) -> Self {
114 Self {
115 bytes: Arc::new(RwLock::new(Vec::new())),
116 window,
117 }
118 }
119
120 async fn add_bytes(&self, amount: u64) {
121 let now = Instant::now();
122 let mut bytes = self.bytes.write().await;
123 bytes.push((now, amount));
124
125 let cutoff = now.checked_sub(self.window).unwrap_or(Instant::now());
127 bytes.retain(|(time, _)| *time > cutoff);
128 }
129
130 async fn get_speed(&self) -> u64 {
131 let bytes = self.bytes.read().await;
132 if bytes.is_empty() {
133 return 0;
134 }
135
136 let now = Instant::now();
137 let cutoff = now.checked_sub(self.window).unwrap_or(Instant::now());
138
139 let total_bytes: u64 = bytes.iter()
140 .filter(|(time, _)| *time > cutoff)
141 .map(|(_, amount)| *amount)
142 .sum();
143
144 let oldest_time = bytes.iter()
145 .filter(|(time, _)| *time > cutoff)
146 .map(|(time, _)| *time)
147 .min();
148
149 let elapsed = if let Some(oldest) = oldest_time {
150 now.duration_since(oldest)
151 } else {
152 Duration::from_secs(1) };
154 if elapsed.as_secs_f64() > 0.0 {
155 (total_bytes as f64 / elapsed.as_secs_f64()) as u64
156 } else {
157 0
158 }
159 }
160}
161
162#[derive(Clone)]
163pub enum ClientEvent {
164 Ready,
165 Listening,
166 TorrentAdded(Arc<Torrent>),
167 TorrentRemoved(Arc<Torrent>),
168 Error(String), Download(u64),
170 Upload(u64),
171}
172
173impl WebTorrent {
174 pub fn peer_id(&self) -> [u8; 20] {
176 self.peer_id
177 }
178
179 pub async fn new(options: WebTorrentOptions) -> Result<Self> {
181 let peer_id = options.peer_id.unwrap_or_else(|| {
182 let mut id = [0u8; 20];
183 id[0..3].copy_from_slice(b"-WW");
184 let version_str = format!("{:04}", env!("CARGO_PKG_VERSION_MAJOR").parse::<u16>().unwrap_or(1) * 100 +
186 env!("CARGO_PKG_VERSION_MINOR").parse::<u16>().unwrap_or(0));
187 let version_bytes = version_str.as_bytes();
188 if version_bytes.len() >= 4 {
190 id[3..7].copy_from_slice(&version_bytes[0..4]);
191 } else {
192 id[3..3+version_bytes.len()].copy_from_slice(version_bytes);
194 }
195 id[7] = b'-';
196 let mut rng = rand::thread_rng();
198 rng.fill(&mut id[8..]);
199 id
200 });
201
202 let node_id = options.node_id.unwrap_or_else(|| {
203 let mut id = [0u8; 20];
204 let mut rng = rand::thread_rng();
205 rng.fill(&mut id);
206 id
207 });
208
209 let (event_tx, event_rx) = mpsc::unbounded_channel();
210
211 let download_speed_tracker = Arc::new(SpeedTracker::new(Duration::from_secs(1)));
213 let upload_speed_tracker = Arc::new(SpeedTracker::new(Duration::from_secs(1)));
214
215 let download_throttle = Arc::new(ThrottleGroup::new(
217 options.download_limit.unwrap_or(u64::MAX),
218 options.download_limit.is_some(),
219 ));
220 let upload_throttle = Arc::new(ThrottleGroup::new(
221 options.upload_limit.unwrap_or(u64::MAX),
222 options.upload_limit.is_some(),
223 ));
224
225 let mut client = Self {
226 peer_id,
227 node_id,
228 options: options.clone(),
229 torrents: Arc::new(RwLock::new(Vec::new())),
230 conn_pool: Arc::new(RwLock::new(None)),
231 nat_traversal: None,
232 dht: None,
233 destroyed: Arc::new(RwLock::new(false)),
234 listening: Arc::new(RwLock::new(false)),
235 ready: Arc::new(RwLock::new(false)),
236 torrent_port: Arc::new(RwLock::new(options.torrent_port)),
237 dht_port: Arc::new(RwLock::new(options.dht_port)),
238 event_tx,
239 event_rx: Arc::new(RwLock::new(event_rx)),
240 download_speed_tracker,
241 upload_speed_tracker,
242 download_throttle,
243 upload_throttle,
244 };
245
246 if options.nat_upnp || options.nat_pmp {
248 let nat = Arc::new(NatTraversal::new(options.nat_upnp, options.nat_pmp).await?);
249 client.nat_traversal = Some(nat);
250 }
251
252 Ok(client)
259 }
260
261 pub async fn add(&self, torrent_id: impl Into<TorrentId>) -> Result<Arc<Torrent>> {
263 if *self.destroyed.read().await {
264 return Err(WebTorrentError::ClientDestroyed);
265 }
266
267 {
270 let port = *self.torrent_port.read().await;
271 if port > 0 {
272 let mut conn_pool_guard = self.conn_pool.write().await;
273 if conn_pool_guard.is_none() {
274 let client_for_pool = Arc::new(self.clone());
277 match ConnPool::new(client_for_pool).await {
278 Ok(pool) => {
279 *conn_pool_guard = Some(Arc::new(pool));
280 *self.listening.write().await = true;
281 }
282 Err(e) => {
283 eprintln!("Warning: Failed to initialize connection pool: {}. Tracker announcements may not work.", e);
284 *self.listening.write().await = true;
286 }
287 }
288 } else {
289 *self.listening.write().await = true;
290 }
291 }
292 }
293
294 let torrent_id = torrent_id.into();
295 let torrent = Torrent::new(torrent_id, self.clone()).await?;
296
297 let info_hash = torrent.info_hash();
299 let torrents = self.torrents.read().await;
300 for existing in torrents.iter() {
301 if existing.info_hash() == info_hash {
302 return Err(WebTorrentError::DuplicateTorrent(hex::encode(info_hash)));
303 }
304 }
305 drop(torrents);
306
307 let torrent = Arc::new(torrent);
308
309 torrent.start_discovery().await?;
311
312 self.torrents.write().await.push(torrent.clone());
313
314 self.event_tx.send(ClientEvent::TorrentAdded(torrent.clone()))
315 .map_err(|_| WebTorrentError::Network("Event channel closed".to_string()))?;
316
317 Ok(torrent)
318 }
319
320 pub async fn remove(&self, torrent: Arc<Torrent>) -> Result<()> {
322 if *self.destroyed.read().await {
323 return Err(WebTorrentError::ClientDestroyed);
324 }
325
326 let mut torrents = self.torrents.write().await;
327 if let Some(pos) = torrents.iter().position(|t| Arc::ptr_eq(t, &torrent)) {
328 torrents.remove(pos);
329 torrent.destroy().await?;
330 self.event_tx.send(ClientEvent::TorrentRemoved(torrent))
331 .map_err(|_| WebTorrentError::Network("Event channel closed".to_string()))?;
332 }
333
334 Ok(())
335 }
336
337 pub async fn get(&self, info_hash: &[u8; 20]) -> Option<Arc<Torrent>> {
339 let torrents = self.torrents.read().await;
340 torrents.iter().find(|t| t.info_hash() == *info_hash).cloned()
341 }
342
343 pub async fn seed(
345 &self,
346 name: String,
347 data: Bytes,
348 announce: Option<Vec<String>>,
349 ) -> Result<Arc<Torrent>> {
350 if *self.destroyed.read().await {
351 return Err(WebTorrentError::ClientDestroyed);
352 }
353
354 use crate::torrent_creator::TorrentCreator;
355
356 let announce_list = announce.unwrap_or_else(|| {
358 vec!["http://dig-relay-prod.eba-2cmanxbe.us-east-1.elasticbeanstalk.com:8000/announce".to_string()]
359 });
360
361 let creator = TorrentCreator::new()
363 .with_announce(announce_list.clone());
364
365 let (torrent_file, info_hash) = creator.create_from_data(name.clone(), data.clone()).await?;
366
367 if self.get(&info_hash).await.is_some() {
369 return Err(WebTorrentError::DuplicateTorrent(hex::encode(info_hash)));
370 }
371
372 let torrent = self.add(torrent_file).await?;
374
375 Ok(torrent)
379 }
380
381 pub async fn download_speed(&self) -> u64 {
383 self.download_speed_tracker.get_speed().await
384 }
385
386 pub async fn upload_speed(&self) -> u64 {
388 self.upload_speed_tracker.get_speed().await
389 }
390
391 #[cfg_attr(test, allow(dead_code))]
393 pub(crate) async fn record_download(&self, bytes: u64) {
394 if bytes > 0 {
395 self.download_speed_tracker.add_bytes(bytes).await;
396 let _ = self.event_tx.send(ClientEvent::Download(bytes));
397 }
398 }
399
400 #[allow(dead_code)]
402 pub(crate) async fn record_upload(&self, bytes: u64) {
403 if bytes > 0 {
404 self.upload_speed_tracker.add_bytes(bytes).await;
405 let _ = self.event_tx.send(ClientEvent::Upload(bytes));
406 }
407 }
408
409 pub async fn progress(&self) -> f64 {
411 let torrents = self.torrents.read().await;
412 let mut total_downloaded = 0u64;
413 let mut total_length = 0u64;
414
415 for torrent in torrents.iter() {
416 if torrent.progress().await < 1.0 {
417 total_downloaded += torrent.downloaded().await;
418 total_length += torrent.length().await;
419 }
420 }
421
422 if total_length == 0 {
423 return 1.0;
424 }
425
426 total_downloaded as f64 / total_length as f64
427 }
428
429 pub async fn ratio(&self) -> f64 {
431 let torrents = self.torrents.read().await;
432 let mut total_uploaded = 0u64;
433 let mut total_received = 0u64;
434
435 for torrent in torrents.iter() {
436 total_uploaded += torrent.uploaded().await;
437 total_received += torrent.received().await;
438 }
439
440 if total_received == 0 {
441 return 0.0;
442 }
443
444 total_uploaded as f64 / total_received as f64
445 }
446
447 pub async fn throttle_download(&self, rate: Option<u64>) {
449 if let Some(rate) = rate {
450 self.download_throttle.set_rate(rate).await;
451 self.download_throttle.set_enabled(true).await;
452 } else {
453 self.download_throttle.set_enabled(false).await;
454 }
455 }
456
457 pub async fn throttle_upload(&self, rate: Option<u64>) {
459 if let Some(rate) = rate {
460 self.upload_throttle.set_rate(rate).await;
461 self.upload_throttle.set_enabled(true).await;
462 } else {
463 self.upload_throttle.set_enabled(false).await;
464 }
465 }
466
467 #[allow(dead_code)]
469 pub(crate) fn download_throttle(&self) -> Arc<ThrottleGroup> {
470 Arc::clone(&self.download_throttle)
471 }
472
473 #[allow(dead_code)]
475 pub(crate) fn upload_throttle(&self) -> Arc<ThrottleGroup> {
476 Arc::clone(&self.upload_throttle)
477 }
478
479 pub async fn destroy(&self) -> Result<()> {
481 if *self.destroyed.read().await {
482 return Err(WebTorrentError::ClientDestroyed);
483 }
484
485 *self.destroyed.write().await = true;
486
487 let torrents = self.torrents.read().await.clone();
489 for torrent in torrents {
490 let _ = torrent.destroy().await;
491 }
492
493 if let Some(conn_pool) = self.conn_pool.read().await.as_ref() {
495 conn_pool.destroy().await?;
496 }
497
498 if let Some(nat) = &self.nat_traversal {
500 nat.destroy().await?;
501 }
502
503 if let Some(dht) = &self.dht {
505 dht.destroy().await?;
506 }
507
508 Ok(())
509 }
510
511 pub async fn address(&self) -> Option<(String, u16)> {
513 if !*self.listening.read().await {
514 return None;
515 }
516
517 let port = *self.torrent_port.read().await;
518 Some(("0.0.0.0".to_string(), port))
519 }
520}
521
522impl Clone for WebTorrent {
523 fn clone(&self) -> Self {
524 Self {
525 peer_id: self.peer_id,
526 node_id: self.node_id,
527 options: self.options.clone(),
528 torrents: Arc::clone(&self.torrents),
529 conn_pool: Arc::clone(&self.conn_pool),
530 nat_traversal: self.nat_traversal.clone(),
531 dht: self.dht.clone(),
532 destroyed: Arc::clone(&self.destroyed),
533 listening: Arc::clone(&self.listening),
534 ready: Arc::clone(&self.ready),
535 torrent_port: Arc::clone(&self.torrent_port),
536 dht_port: Arc::clone(&self.dht_port),
537 event_tx: self.event_tx.clone(),
538 event_rx: Arc::clone(&self.event_rx),
539 download_speed_tracker: Arc::clone(&self.download_speed_tracker),
540 upload_speed_tracker: Arc::clone(&self.upload_speed_tracker),
541 download_throttle: Arc::clone(&self.download_throttle),
542 upload_throttle: Arc::clone(&self.upload_throttle),
543 }
544 }
545}
546
547#[derive(Debug, Clone)]
549pub enum TorrentId {
550 InfoHash([u8; 20]),
551 MagnetUri(String),
552 TorrentFile(Bytes),
553 Url(String),
554}
555
556impl From<[u8; 20]> for TorrentId {
557 fn from(hash: [u8; 20]) -> Self {
558 TorrentId::InfoHash(hash)
559 }
560}
561
562impl From<String> for TorrentId {
563 fn from(s: String) -> Self {
564 if s.starts_with("magnet:") {
565 TorrentId::MagnetUri(s)
566 } else if s.starts_with("http://") || s.starts_with("https://") {
567 TorrentId::Url(s)
568 } else {
569 if let Ok(bytes) = hex::decode(&s) {
571 if bytes.len() == 20 {
572 let mut hash = [0u8; 20];
573 hash.copy_from_slice(&bytes);
574 return TorrentId::InfoHash(hash);
575 }
576 }
577 TorrentId::MagnetUri(s) }
579 }
580}
581
582impl From<&str> for TorrentId {
583 fn from(s: &str) -> Self {
584 s.to_string().into()
585 }
586}
587
588impl From<Bytes> for TorrentId {
589 fn from(bytes: Bytes) -> Self {
590 TorrentId::TorrentFile(bytes)
591 }
592}
593
594impl From<Vec<u8>> for TorrentId {
595 fn from(bytes: Vec<u8>) -> Self {
596 TorrentId::TorrentFile(bytes.into())
597 }
598}
599