1use crate::want_list::Priority;
40use bytes::Bytes;
41use cid::Cid;
42use dashmap::DashMap;
43use parking_lot::RwLock;
44use std::sync::Arc;
45use std::time::{Duration, Instant};
46use thiserror::Error;
47use tokio::sync::{mpsc, watch};
48use tracing::{debug, info, warn};
49
50pub type SessionId = u64;
52
53#[derive(Error, Debug)]
55pub enum SessionError {
56 #[error("Session not found: {0}")]
57 NotFound(SessionId),
58
59 #[error("Session already exists: {0}")]
60 AlreadyExists(SessionId),
61
62 #[error("Session closed: {0}")]
63 Closed(SessionId),
64
65 #[error("Block not in session: {0}")]
66 BlockNotInSession(String),
67
68 #[error("Timeout waiting for session completion")]
69 Timeout,
70}
71
/// Lifecycle state of a [`Session`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionState {
    /// Accepting new blocks; eligible to complete once every block resolves.
    Active,
    /// Temporarily halted; new blocks are rejected while in this state.
    Paused,
    /// NOTE(review): no code in this file sets this state — confirm its intended use.
    Completing,
    /// Every tracked block was received or failed.
    Completed,
    /// The session was cancelled.
    Cancelled,
}
86
87#[derive(Debug, Clone)]
89pub struct SessionConfig {
90 pub timeout: Duration,
92 pub default_priority: Priority,
94 pub max_concurrent_blocks: usize,
96 pub progress_notifications: bool,
98}
99
100impl Default for SessionConfig {
101 fn default() -> Self {
102 Self {
103 timeout: Duration::from_secs(300), default_priority: Priority::Normal,
105 max_concurrent_blocks: 100,
106 progress_notifications: true,
107 }
108 }
109}
110
/// Aggregate counters for a session.
#[derive(Debug, Clone, Default)]
pub struct SessionStats {
    /// Number of blocks added to the session.
    pub total_blocks: usize,
    /// Number of blocks recorded as received.
    pub blocks_received: usize,
    /// Number of blocks recorded as failed.
    pub blocks_failed: usize,
    /// Sum of received payload lengths, in bytes.
    pub bytes_transferred: u64,
    /// When the session started.
    pub started_at: Option<Instant>,
    /// When the session was marked complete.
    pub completed_at: Option<Instant>,
    /// Per-block fetch-time aggregate (request to receipt).
    pub avg_block_time: Option<Duration>,
}

impl SessionStats {
    /// Percentage (0–100) of blocks received; `0.0` when no blocks were added.
    pub fn progress(&self) -> f64 {
        match self.total_blocks {
            0 => 0.0,
            total => (self.blocks_received as f64 / total as f64) * 100.0,
        }
    }

    /// Bytes per second over the whole session.
    ///
    /// `None` until both timestamps are set, or when they are equal.
    pub fn throughput(&self) -> Option<f64> {
        let elapsed = self
            .completed_at?
            .duration_since(self.started_at?)
            .as_secs_f64();
        (elapsed > 0.0).then(|| self.bytes_transferred as f64 / elapsed)
    }

    /// `true` once received + failed blocks cover `total_blocks`
    /// (so an empty session counts as complete).
    pub fn is_complete(&self) -> bool {
        self.total_blocks <= self.blocks_received + self.blocks_failed
    }
}
155
156#[derive(Debug, Clone)]
158#[allow(dead_code)]
159struct BlockRequest {
160 cid: Cid,
161 priority: Priority,
162 requested_at: Instant,
163 completed_at: Option<Instant>,
164 size: Option<usize>,
165}
166
167#[derive(Debug, Clone)]
169pub enum SessionEvent {
170 Started { session_id: SessionId },
172 BlockReceived {
174 session_id: SessionId,
175 cid: Cid,
176 size: usize,
177 },
178 BlockFailed {
180 session_id: SessionId,
181 cid: Cid,
182 error: String,
183 },
184 Progress {
186 session_id: SessionId,
187 stats: SessionStats,
188 },
189 Completed {
191 session_id: SessionId,
192 stats: SessionStats,
193 },
194 Cancelled { session_id: SessionId },
196}
197
198pub struct Session {
200 id: SessionId,
201 config: SessionConfig,
202 state: Arc<RwLock<SessionState>>,
203 blocks: Arc<DashMap<Cid, BlockRequest>>,
204 stats: Arc<RwLock<SessionStats>>,
205 event_tx: Option<mpsc::UnboundedSender<SessionEvent>>,
206 state_rx: watch::Receiver<SessionState>,
207 state_tx: watch::Sender<SessionState>,
208}
209
210impl Session {
211 pub fn new(
213 id: SessionId,
214 config: SessionConfig,
215 event_tx: Option<mpsc::UnboundedSender<SessionEvent>>,
216 ) -> Self {
217 let (state_tx, state_rx) = watch::channel(SessionState::Active);
218
219 let session = Self {
220 id,
221 config,
222 state: Arc::new(RwLock::new(SessionState::Active)),
223 blocks: Arc::new(DashMap::new()),
224 stats: Arc::new(RwLock::new(SessionStats {
225 started_at: Some(Instant::now()),
226 ..Default::default()
227 })),
228 event_tx,
229 state_rx,
230 state_tx,
231 };
232
233 if let Some(tx) = &session.event_tx {
235 let _ = tx.send(SessionEvent::Started { session_id: id });
236 }
237
238 session
239 }
240
241 pub fn id(&self) -> SessionId {
243 self.id
244 }
245
246 pub fn state(&self) -> SessionState {
248 *self.state.read()
249 }
250
251 pub fn add_block(&self, cid: Cid, priority: Option<Priority>) -> Result<(), SessionError> {
253 let state = *self.state.read();
254 if state != SessionState::Active {
255 return Err(SessionError::Closed(self.id));
256 }
257
258 let priority = priority.unwrap_or(self.config.default_priority);
259
260 let request = BlockRequest {
261 cid,
262 priority,
263 requested_at: Instant::now(),
264 completed_at: None,
265 size: None,
266 };
267
268 self.blocks.insert(cid, request);
269
270 {
272 let mut stats = self.stats.write();
273 stats.total_blocks += 1;
274 }
275
276 debug!("Added block {} to session {}", cid, self.id);
277
278 Ok(())
279 }
280
281 pub fn add_blocks(&self, cids: &[Cid], priority: Option<Priority>) -> Result<(), SessionError> {
283 for cid in cids {
284 self.add_block(*cid, priority)?;
285 }
286 Ok(())
287 }
288
289 pub fn mark_received(&self, cid: &Cid, data: &Bytes) -> Result<(), SessionError> {
291 let mut block = self
292 .blocks
293 .get_mut(cid)
294 .ok_or_else(|| SessionError::BlockNotInSession(cid.to_string()))?;
295
296 block.completed_at = Some(Instant::now());
297 block.size = Some(data.len());
298
299 let should_complete = {
301 let mut stats = self.stats.write();
302 stats.blocks_received += 1;
303 stats.bytes_transferred += data.len() as u64;
304
305 let fetch_time = block
307 .completed_at
308 .unwrap()
309 .duration_since(block.requested_at);
310 stats.avg_block_time = Some(
311 stats
312 .avg_block_time
313 .map(|avg| (avg + fetch_time) / 2)
314 .unwrap_or(fetch_time),
315 );
316
317 let is_complete = stats.is_complete() && self.state() == SessionState::Active;
319 if is_complete {
320 stats.completed_at = Some(Instant::now());
321 }
322 is_complete
323 }; if should_complete {
327 self.transition_state(SessionState::Completed);
328 }
329
330 if let Some(tx) = &self.event_tx {
332 let _ = tx.send(SessionEvent::BlockReceived {
333 session_id: self.id,
334 cid: *cid,
335 size: data.len(),
336 });
337
338 if self.config.progress_notifications {
339 let _ = tx.send(SessionEvent::Progress {
340 session_id: self.id,
341 stats: self.stats.read().clone(),
342 });
343 }
344 }
345
346 debug!("Block {} received in session {}", cid, self.id);
347
348 Ok(())
349 }
350
351 pub fn mark_failed(&self, cid: &Cid, error: String) -> Result<(), SessionError> {
353 let _block = self
354 .blocks
355 .get(cid)
356 .ok_or_else(|| SessionError::BlockNotInSession(cid.to_string()))?;
357
358 {
360 let mut stats = self.stats.write();
361 stats.blocks_failed += 1;
362
363 if stats.is_complete() && self.state() == SessionState::Active {
365 stats.completed_at = Some(Instant::now());
366 self.transition_state(SessionState::Completed);
367 }
368 }
369
370 if let Some(tx) = &self.event_tx {
372 let _ = tx.send(SessionEvent::BlockFailed {
373 session_id: self.id,
374 cid: *cid,
375 error: error.clone(),
376 });
377 }
378
379 warn!("Block {} failed in session {}: {}", cid, self.id, error);
380
381 Ok(())
382 }
383
384 pub fn pause(&self) {
386 self.transition_state(SessionState::Paused);
387 info!("Session {} paused", self.id);
388 }
389
390 pub fn resume(&self) {
392 self.transition_state(SessionState::Active);
393 info!("Session {} resumed", self.id);
394 }
395
396 pub fn cancel(&self) {
398 self.transition_state(SessionState::Cancelled);
399
400 if let Some(tx) = &self.event_tx {
401 let _ = tx.send(SessionEvent::Cancelled {
402 session_id: self.id,
403 });
404 }
405
406 info!("Session {} cancelled", self.id);
407 }
408
409 pub fn stats(&self) -> SessionStats {
411 self.stats.read().clone()
412 }
413
414 pub fn pending_blocks(&self) -> Vec<Cid> {
416 self.blocks
417 .iter()
418 .filter(|entry| entry.value().completed_at.is_none())
419 .map(|entry| *entry.key())
420 .collect()
421 }
422
423 pub async fn wait_completion(&self) -> Result<SessionStats, SessionError> {
425 let mut rx = self.state_rx.clone();
426
427 let state = *self.state.read();
429 if state == SessionState::Completed || state == SessionState::Cancelled {
430 return Ok(self.stats.read().clone());
431 }
432
433 loop {
435 if rx.changed().await.is_err() {
436 return Err(SessionError::Closed(self.id));
437 }
438
439 let state = *rx.borrow();
440 if state == SessionState::Completed || state == SessionState::Cancelled {
441 return Ok(self.stats.read().clone());
442 }
443 }
444 }
445
446 fn transition_state(&self, new_state: SessionState) {
448 *self.state.write() = new_state;
449 let _ = self.state_tx.send(new_state);
450
451 if new_state == SessionState::Completed {
452 if let Some(tx) = &self.event_tx {
453 let _ = tx.send(SessionEvent::Completed {
454 session_id: self.id,
455 stats: self.stats.read().clone(),
456 });
457 }
458 }
459 }
460}
461
462pub struct SessionManager {
464 sessions: Arc<DashMap<SessionId, Arc<Session>>>,
465 next_session_id: Arc<RwLock<SessionId>>,
466 event_tx: mpsc::UnboundedSender<SessionEvent>,
467 event_rx: Arc<RwLock<mpsc::UnboundedReceiver<SessionEvent>>>,
468}
469
470impl SessionManager {
471 pub fn new() -> Self {
473 let (event_tx, event_rx) = mpsc::unbounded_channel();
474
475 Self {
476 sessions: Arc::new(DashMap::new()),
477 next_session_id: Arc::new(RwLock::new(1)),
478 event_tx,
479 event_rx: Arc::new(RwLock::new(event_rx)),
480 }
481 }
482
483 pub fn create_session(&self, config: SessionConfig) -> Arc<Session> {
485 let session_id = {
486 let mut id = self.next_session_id.write();
487 let current = *id;
488 *id += 1;
489 current
490 };
491
492 let session = Arc::new(Session::new(
493 session_id,
494 config,
495 Some(self.event_tx.clone()),
496 ));
497 self.sessions.insert(session_id, session.clone());
498
499 info!("Created session {}", session_id);
500
501 session
502 }
503
504 pub fn get_session(&self, session_id: SessionId) -> Option<Arc<Session>> {
506 self.sessions.get(&session_id).map(|s| s.clone())
507 }
508
509 pub fn remove_session(&self, session_id: SessionId) -> Option<Arc<Session>> {
511 self.sessions.remove(&session_id).map(|(_, s)| s)
512 }
513
514 pub fn active_sessions(&self) -> Vec<Arc<Session>> {
516 self.sessions
517 .iter()
518 .filter(|entry| entry.value().state() == SessionState::Active)
519 .map(|entry| entry.value().clone())
520 .collect()
521 }
522
523 pub fn cleanup_completed(&self) -> usize {
525 let to_remove: Vec<_> = self
526 .sessions
527 .iter()
528 .filter(|entry| {
529 let state = entry.value().state();
530 state == SessionState::Completed || state == SessionState::Cancelled
531 })
532 .map(|entry| *entry.key())
533 .collect();
534
535 let count = to_remove.len();
536 for session_id in to_remove {
537 self.sessions.remove(&session_id);
538 }
539
540 if count > 0 {
541 info!("Cleaned up {} completed sessions", count);
542 }
543
544 count
545 }
546
547 #[allow(clippy::await_holding_lock)]
549 pub async fn recv_event(&self) -> Option<SessionEvent> {
550 self.event_rx.write().recv().await
551 }
552}
553
554impl Default for SessionManager {
555 fn default() -> Self {
556 Self::new()
557 }
558}
559
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a deterministic CIDv1 (codec 0x55, hash code 0x12) whose 32-byte
    /// digest is `n` repeated.
    fn dummy_cid(n: u8) -> Cid {
        let digest = vec![n; 32];
        let hash = multihash::Multihash::wrap(0x12, &digest).unwrap();
        Cid::new_v1(0x55, hash)
    }

    /// A manager plus one default-config session. The manager is returned so
    /// it stays alive for the whole test body.
    fn fresh_session() -> (SessionManager, Arc<Session>) {
        let manager = SessionManager::new();
        let session = manager.create_session(SessionConfig::default());
        (manager, session)
    }

    #[test]
    fn test_session_creation() {
        let (_manager, session) = fresh_session();

        assert_eq!(session.state(), SessionState::Active);
        assert_eq!(session.stats().total_blocks, 0);
    }

    #[test]
    fn test_add_blocks() {
        let (_manager, session) = fresh_session();

        session.add_block(dummy_cid(1), None).unwrap();
        session.add_block(dummy_cid(2), Some(Priority::High)).unwrap();

        let snapshot = session.stats();
        assert_eq!(snapshot.total_blocks, 2);
        assert_eq!(snapshot.blocks_received, 0);
    }

    #[test]
    fn test_mark_received() {
        let (_manager, session) = fresh_session();
        let cid = dummy_cid(1);
        session.add_block(cid, None).unwrap();

        let payload = Bytes::from(vec![1, 2, 3, 4]);
        session.mark_received(&cid, &payload).unwrap();

        let snapshot = session.stats();
        assert_eq!(snapshot.blocks_received, 1);
        assert_eq!(snapshot.bytes_transferred, 4);
        assert!(snapshot.is_complete());
    }

    #[test]
    fn test_session_progress() {
        let (_manager, session) = fresh_session();
        session
            .add_blocks(&[dummy_cid(1), dummy_cid(2), dummy_cid(3)], None)
            .unwrap();

        // (block to receive, expected progress afterwards, label for failures)
        let steps = [
            (1u8, 100.0 / 3.0, "~33.33%"),
            (2u8, 200.0 / 3.0, "~66.67%"),
            (3u8, 100.0, "100%"),
        ];
        for &(n, expected, label) in &steps {
            session
                .mark_received(&dummy_cid(n), &Bytes::from(vec![n]))
                .unwrap();
            let progress = session.stats().progress();
            assert!(
                (progress - expected).abs() < 1e-6,
                "Expected {}, got {}",
                label,
                progress
            );
        }

        assert_eq!(session.state(), SessionState::Completed);
    }

    #[test]
    fn test_pause_resume() {
        let (_manager, session) = fresh_session();
        assert_eq!(session.state(), SessionState::Active);

        session.pause();
        assert_eq!(session.state(), SessionState::Paused);

        session.resume();
        assert_eq!(session.state(), SessionState::Active);
    }

    #[test]
    fn test_cancel() {
        let (_manager, session) = fresh_session();

        session.cancel();
        assert_eq!(session.state(), SessionState::Cancelled);

        // A cancelled session rejects new blocks.
        assert!(session.add_block(dummy_cid(1), None).is_err());
    }

    #[tokio::test]
    async fn test_wait_completion() {
        let (_manager, session) = fresh_session();
        session.add_block(dummy_cid(1), None).unwrap();

        let worker = Arc::clone(&session);
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(10)).await;
            worker
                .mark_received(&dummy_cid(1), &Bytes::from(vec![1]))
                .unwrap();
        });

        let snapshot = session.wait_completion().await.unwrap();
        assert_eq!(snapshot.blocks_received, 1);
    }
}