ipfrs_transport/
session.rs

1//! Session management for grouping related block requests
2//!
3//! Sessions provide a way to group related block requests together,
4//! enabling features like:
5//! - Batch completion notifications
6//! - Session-level prioritization
7//! - Coordinated cancellation
8//! - Progress tracking across multiple blocks
9//!
10//! # Example
11//!
12//! ```
13//! use ipfrs_transport::{SessionManager, SessionConfig, Priority};
14//! use ipfrs_core::Cid;
15//! use multihash::Multihash;
16//!
17//! // Create a session manager
18//! let manager = SessionManager::new();
19//!
20//! // Create test CIDs
21//! let hash1 = Multihash::wrap(0x12, &[1, 2, 3]).unwrap();
22//! let cid1 = Cid::new_v1(0x55, hash1);
23//! let hash2 = Multihash::wrap(0x12, &[4, 5, 6]).unwrap();
24//! let cid2 = Cid::new_v1(0x55, hash2);
25//!
26//! // Create a session
27//! let config = SessionConfig::default();
28//! let session = manager.create_session(config);
29//!
30//! // Add blocks to the session
31//! session.add_block(cid1, Some(Priority::Normal)).unwrap();
32//! session.add_block(cid2, Some(Priority::High)).unwrap();
33//!
34//! // Check session status
35//! let stats = session.stats();
36//! println!("Total blocks: {}, received: {}", stats.total_blocks, stats.blocks_received);
37//! ```
38
39use crate::want_list::Priority;
40use bytes::Bytes;
41use cid::Cid;
42use dashmap::DashMap;
43use parking_lot::RwLock;
44use std::sync::Arc;
45use std::time::{Duration, Instant};
46use thiserror::Error;
47use tokio::sync::{mpsc, watch};
48use tracing::{debug, info, warn};
49
/// Numeric identifier of a [`Session`].
///
/// `SessionManager::create_session` hands these out sequentially, starting at 1.
pub type SessionId = u64;
52
53/// Session error types
54#[derive(Error, Debug)]
55pub enum SessionError {
56    #[error("Session not found: {0}")]
57    NotFound(SessionId),
58
59    #[error("Session already exists: {0}")]
60    AlreadyExists(SessionId),
61
62    #[error("Session closed: {0}")]
63    Closed(SessionId),
64
65    #[error("Block not in session: {0}")]
66    BlockNotInSession(String),
67
68    #[error("Timeout waiting for session completion")]
69    Timeout,
70}
71
/// Lifecycle state of a session.
///
/// `Completed` and `Cancelled` are terminal; they are the states
/// `Session::wait_completion` waits for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionState {
    /// Accepting new blocks; the only state in which `add_block` succeeds.
    Active,
    /// Temporarily halted: no new requests should be issued.
    Paused,
    /// Draining: no new blocks are accepted, outstanding ones are still awaited.
    Completing,
    /// Every requested block was received or failed.
    Completed,
    /// Stopped early by an explicit cancellation.
    Cancelled,
}
86
87/// Session configuration
88#[derive(Debug, Clone)]
89pub struct SessionConfig {
90    /// Session timeout (0 = no timeout)
91    pub timeout: Duration,
92    /// Default priority for blocks in this session
93    pub default_priority: Priority,
94    /// Maximum concurrent blocks per session
95    pub max_concurrent_blocks: usize,
96    /// Enable progress notifications
97    pub progress_notifications: bool,
98}
99
100impl Default for SessionConfig {
101    fn default() -> Self {
102        Self {
103            timeout: Duration::from_secs(300), // 5 minutes
104            default_priority: Priority::Normal,
105            max_concurrent_blocks: 100,
106            progress_notifications: true,
107        }
108    }
109}
110
/// Counters and timings for one session.
///
/// Snapshots of this struct are returned by `Session::stats` and carried in
/// the `Progress` and `Completed` session events.
#[derive(Debug, Clone, Default)]
pub struct SessionStats {
    /// Number of blocks added to the session.
    pub total_blocks: usize,
    /// Number of blocks successfully received.
    pub blocks_received: usize,
    /// Number of blocks reported as failed.
    pub blocks_failed: usize,
    /// Sum of the sizes, in bytes, of all received blocks.
    pub bytes_transferred: u64,
    /// When the session was created.
    pub started_at: Option<Instant>,
    /// When the session completed, if it has.
    pub completed_at: Option<Instant>,
    /// Average time from a block being added to it being received.
    pub avg_block_time: Option<Duration>,
}

impl SessionStats {
    /// Percentage of requested blocks that have been received.
    ///
    /// Failed blocks do not count as progress. Returns `0.0` when nothing
    /// has been requested yet.
    pub fn progress(&self) -> f64 {
        match self.total_blocks {
            0 => 0.0,
            total => self.blocks_received as f64 / total as f64 * 100.0,
        }
    }

    /// Bytes per second over the session's whole lifetime.
    ///
    /// `None` until both the start and completion instants are known, or when
    /// they coincide (zero elapsed time).
    pub fn throughput(&self) -> Option<f64> {
        let started = self.started_at?;
        let completed = self.completed_at?;
        let elapsed_secs = completed.duration_since(started).as_secs_f64();
        (elapsed_secs > 0.0).then(|| self.bytes_transferred as f64 / elapsed_secs)
    }

    /// Whether every requested block has either arrived or failed.
    ///
    /// Vacuously `true` when no blocks were requested.
    pub fn is_complete(&self) -> bool {
        let settled = self.blocks_received + self.blocks_failed;
        settled >= self.total_blocks
    }
}
155
156/// Block request within a session
157#[derive(Debug, Clone)]
158#[allow(dead_code)]
159struct BlockRequest {
160    cid: Cid,
161    priority: Priority,
162    requested_at: Instant,
163    completed_at: Option<Instant>,
164    size: Option<usize>,
165}
166
167/// Session progress event
168#[derive(Debug, Clone)]
169pub enum SessionEvent {
170    /// Session started
171    Started { session_id: SessionId },
172    /// Block received
173    BlockReceived {
174        session_id: SessionId,
175        cid: Cid,
176        size: usize,
177    },
178    /// Block failed
179    BlockFailed {
180        session_id: SessionId,
181        cid: Cid,
182        error: String,
183    },
184    /// Session progress update
185    Progress {
186        session_id: SessionId,
187        stats: SessionStats,
188    },
189    /// Session completed
190    Completed {
191        session_id: SessionId,
192        stats: SessionStats,
193    },
194    /// Session cancelled
195    Cancelled { session_id: SessionId },
196}
197
198/// A session for grouping related block requests
199pub struct Session {
200    id: SessionId,
201    config: SessionConfig,
202    state: Arc<RwLock<SessionState>>,
203    blocks: Arc<DashMap<Cid, BlockRequest>>,
204    stats: Arc<RwLock<SessionStats>>,
205    event_tx: Option<mpsc::UnboundedSender<SessionEvent>>,
206    state_rx: watch::Receiver<SessionState>,
207    state_tx: watch::Sender<SessionState>,
208}
209
210impl Session {
211    /// Create a new session
212    pub fn new(
213        id: SessionId,
214        config: SessionConfig,
215        event_tx: Option<mpsc::UnboundedSender<SessionEvent>>,
216    ) -> Self {
217        let (state_tx, state_rx) = watch::channel(SessionState::Active);
218
219        let session = Self {
220            id,
221            config,
222            state: Arc::new(RwLock::new(SessionState::Active)),
223            blocks: Arc::new(DashMap::new()),
224            stats: Arc::new(RwLock::new(SessionStats {
225                started_at: Some(Instant::now()),
226                ..Default::default()
227            })),
228            event_tx,
229            state_rx,
230            state_tx,
231        };
232
233        // Send started event
234        if let Some(tx) = &session.event_tx {
235            let _ = tx.send(SessionEvent::Started { session_id: id });
236        }
237
238        session
239    }
240
241    /// Get session ID
242    pub fn id(&self) -> SessionId {
243        self.id
244    }
245
246    /// Get current state
247    pub fn state(&self) -> SessionState {
248        *self.state.read()
249    }
250
251    /// Add a block to the session
252    pub fn add_block(&self, cid: Cid, priority: Option<Priority>) -> Result<(), SessionError> {
253        let state = *self.state.read();
254        if state != SessionState::Active {
255            return Err(SessionError::Closed(self.id));
256        }
257
258        let priority = priority.unwrap_or(self.config.default_priority);
259
260        let request = BlockRequest {
261            cid,
262            priority,
263            requested_at: Instant::now(),
264            completed_at: None,
265            size: None,
266        };
267
268        self.blocks.insert(cid, request);
269
270        // Update stats
271        {
272            let mut stats = self.stats.write();
273            stats.total_blocks += 1;
274        }
275
276        debug!("Added block {} to session {}", cid, self.id);
277
278        Ok(())
279    }
280
281    /// Add multiple blocks to the session
282    pub fn add_blocks(&self, cids: &[Cid], priority: Option<Priority>) -> Result<(), SessionError> {
283        for cid in cids {
284            self.add_block(*cid, priority)?;
285        }
286        Ok(())
287    }
288
289    /// Mark a block as received
290    pub fn mark_received(&self, cid: &Cid, data: &Bytes) -> Result<(), SessionError> {
291        let mut block = self
292            .blocks
293            .get_mut(cid)
294            .ok_or_else(|| SessionError::BlockNotInSession(cid.to_string()))?;
295
296        block.completed_at = Some(Instant::now());
297        block.size = Some(data.len());
298
299        // Update stats and check for completion
300        let should_complete = {
301            let mut stats = self.stats.write();
302            stats.blocks_received += 1;
303            stats.bytes_transferred += data.len() as u64;
304
305            // Update average block time
306            let fetch_time = block
307                .completed_at
308                .unwrap()
309                .duration_since(block.requested_at);
310            stats.avg_block_time = Some(
311                stats
312                    .avg_block_time
313                    .map(|avg| (avg + fetch_time) / 2)
314                    .unwrap_or(fetch_time),
315            );
316
317            // Check if session is complete
318            let is_complete = stats.is_complete() && self.state() == SessionState::Active;
319            if is_complete {
320                stats.completed_at = Some(Instant::now());
321            }
322            is_complete
323        }; // Release stats write lock here
324
325        // Transition state outside the stats lock to avoid deadlock
326        if should_complete {
327            self.transition_state(SessionState::Completed);
328        }
329
330        // Send events
331        if let Some(tx) = &self.event_tx {
332            let _ = tx.send(SessionEvent::BlockReceived {
333                session_id: self.id,
334                cid: *cid,
335                size: data.len(),
336            });
337
338            if self.config.progress_notifications {
339                let _ = tx.send(SessionEvent::Progress {
340                    session_id: self.id,
341                    stats: self.stats.read().clone(),
342                });
343            }
344        }
345
346        debug!("Block {} received in session {}", cid, self.id);
347
348        Ok(())
349    }
350
351    /// Mark a block as failed
352    pub fn mark_failed(&self, cid: &Cid, error: String) -> Result<(), SessionError> {
353        let _block = self
354            .blocks
355            .get(cid)
356            .ok_or_else(|| SessionError::BlockNotInSession(cid.to_string()))?;
357
358        // Update stats
359        {
360            let mut stats = self.stats.write();
361            stats.blocks_failed += 1;
362
363            // Check if session is complete
364            if stats.is_complete() && self.state() == SessionState::Active {
365                stats.completed_at = Some(Instant::now());
366                self.transition_state(SessionState::Completed);
367            }
368        }
369
370        // Send events
371        if let Some(tx) = &self.event_tx {
372            let _ = tx.send(SessionEvent::BlockFailed {
373                session_id: self.id,
374                cid: *cid,
375                error: error.clone(),
376            });
377        }
378
379        warn!("Block {} failed in session {}: {}", cid, self.id, error);
380
381        Ok(())
382    }
383
384    /// Pause the session
385    pub fn pause(&self) {
386        self.transition_state(SessionState::Paused);
387        info!("Session {} paused", self.id);
388    }
389
390    /// Resume the session
391    pub fn resume(&self) {
392        self.transition_state(SessionState::Active);
393        info!("Session {} resumed", self.id);
394    }
395
396    /// Cancel the session
397    pub fn cancel(&self) {
398        self.transition_state(SessionState::Cancelled);
399
400        if let Some(tx) = &self.event_tx {
401            let _ = tx.send(SessionEvent::Cancelled {
402                session_id: self.id,
403            });
404        }
405
406        info!("Session {} cancelled", self.id);
407    }
408
409    /// Get session statistics
410    pub fn stats(&self) -> SessionStats {
411        self.stats.read().clone()
412    }
413
414    /// Get pending blocks (not yet received)
415    pub fn pending_blocks(&self) -> Vec<Cid> {
416        self.blocks
417            .iter()
418            .filter(|entry| entry.value().completed_at.is_none())
419            .map(|entry| *entry.key())
420            .collect()
421    }
422
423    /// Wait for session completion
424    pub async fn wait_completion(&self) -> Result<SessionStats, SessionError> {
425        let mut rx = self.state_rx.clone();
426
427        // Check if already complete
428        let state = *self.state.read();
429        if state == SessionState::Completed || state == SessionState::Cancelled {
430            return Ok(self.stats.read().clone());
431        }
432
433        // Wait for state change
434        loop {
435            if rx.changed().await.is_err() {
436                return Err(SessionError::Closed(self.id));
437            }
438
439            let state = *rx.borrow();
440            if state == SessionState::Completed || state == SessionState::Cancelled {
441                return Ok(self.stats.read().clone());
442            }
443        }
444    }
445
446    /// Transition to a new state
447    fn transition_state(&self, new_state: SessionState) {
448        *self.state.write() = new_state;
449        let _ = self.state_tx.send(new_state);
450
451        if new_state == SessionState::Completed {
452            if let Some(tx) = &self.event_tx {
453                let _ = tx.send(SessionEvent::Completed {
454                    session_id: self.id,
455                    stats: self.stats.read().clone(),
456                });
457            }
458        }
459    }
460}
461
462/// Session manager for managing multiple sessions
463pub struct SessionManager {
464    sessions: Arc<DashMap<SessionId, Arc<Session>>>,
465    next_session_id: Arc<RwLock<SessionId>>,
466    event_tx: mpsc::UnboundedSender<SessionEvent>,
467    event_rx: Arc<RwLock<mpsc::UnboundedReceiver<SessionEvent>>>,
468}
469
470impl SessionManager {
471    /// Create a new session manager
472    pub fn new() -> Self {
473        let (event_tx, event_rx) = mpsc::unbounded_channel();
474
475        Self {
476            sessions: Arc::new(DashMap::new()),
477            next_session_id: Arc::new(RwLock::new(1)),
478            event_tx,
479            event_rx: Arc::new(RwLock::new(event_rx)),
480        }
481    }
482
483    /// Create a new session
484    pub fn create_session(&self, config: SessionConfig) -> Arc<Session> {
485        let session_id = {
486            let mut id = self.next_session_id.write();
487            let current = *id;
488            *id += 1;
489            current
490        };
491
492        let session = Arc::new(Session::new(
493            session_id,
494            config,
495            Some(self.event_tx.clone()),
496        ));
497        self.sessions.insert(session_id, session.clone());
498
499        info!("Created session {}", session_id);
500
501        session
502    }
503
504    /// Get a session by ID
505    pub fn get_session(&self, session_id: SessionId) -> Option<Arc<Session>> {
506        self.sessions.get(&session_id).map(|s| s.clone())
507    }
508
509    /// Remove a session
510    pub fn remove_session(&self, session_id: SessionId) -> Option<Arc<Session>> {
511        self.sessions.remove(&session_id).map(|(_, s)| s)
512    }
513
514    /// Get all active sessions
515    pub fn active_sessions(&self) -> Vec<Arc<Session>> {
516        self.sessions
517            .iter()
518            .filter(|entry| entry.value().state() == SessionState::Active)
519            .map(|entry| entry.value().clone())
520            .collect()
521    }
522
523    /// Clean up completed sessions
524    pub fn cleanup_completed(&self) -> usize {
525        let to_remove: Vec<_> = self
526            .sessions
527            .iter()
528            .filter(|entry| {
529                let state = entry.value().state();
530                state == SessionState::Completed || state == SessionState::Cancelled
531            })
532            .map(|entry| *entry.key())
533            .collect();
534
535        let count = to_remove.len();
536        for session_id in to_remove {
537            self.sessions.remove(&session_id);
538        }
539
540        if count > 0 {
541            info!("Cleaned up {} completed sessions", count);
542        }
543
544        count
545    }
546
547    /// Receive session events
548    #[allow(clippy::await_holding_lock)]
549    pub async fn recv_event(&self) -> Option<SessionEvent> {
550        self.event_rx.write().recv().await
551    }
552}
553
554impl Default for SessionManager {
555    fn default() -> Self {
556        Self::new()
557    }
558}
559
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic CIDv1 (codec 0x55, multihash code 0x12) whose 32-byte
    /// digest is `seed` repeated.
    fn dummy_cid(seed: u8) -> Cid {
        let digest = [seed; 32];
        let hash = multihash::Multihash::wrap(0x12, &digest).unwrap();
        Cid::new_v1(0x55, hash)
    }

    /// A manager plus one freshly created default session. The manager is
    /// returned too, so it stays alive for the whole test.
    fn new_session() -> (SessionManager, Arc<Session>) {
        let manager = SessionManager::new();
        let session = manager.create_session(SessionConfig::default());
        (manager, session)
    }

    #[test]
    fn test_session_creation() {
        let (_manager, session) = new_session();

        assert_eq!(session.state(), SessionState::Active);
        assert_eq!(session.stats().total_blocks, 0);
    }

    #[test]
    fn test_add_blocks() {
        let (_manager, session) = new_session();

        session.add_block(dummy_cid(1), None).unwrap();
        session
            .add_block(dummy_cid(2), Some(Priority::High))
            .unwrap();

        let SessionStats {
            total_blocks,
            blocks_received,
            ..
        } = session.stats();
        assert_eq!(total_blocks, 2);
        assert_eq!(blocks_received, 0);
    }

    #[test]
    fn test_mark_received() {
        let (_manager, session) = new_session();
        let cid = dummy_cid(1);
        session.add_block(cid, None).unwrap();

        session
            .mark_received(&cid, &Bytes::from(vec![1, 2, 3, 4]))
            .unwrap();

        let stats = session.stats();
        assert_eq!(stats.blocks_received, 1);
        assert_eq!(stats.bytes_transferred, 4);
        assert!(stats.is_complete());
    }

    #[test]
    fn test_session_progress() {
        let (_manager, session) = new_session();
        session
            .add_blocks(&[dummy_cid(1), dummy_cid(2), dummy_cid(3)], None)
            .unwrap();

        // Deliver the block seeded with `seed` (one byte of payload) and
        // report the resulting progress.
        let receive = |seed: u8| {
            session
                .mark_received(&dummy_cid(seed), &Bytes::from(vec![seed]))
                .unwrap();
            session.stats().progress()
        };

        let progress1 = receive(1);
        assert!(
            (progress1 - 100.0 / 3.0).abs() < 1e-6,
            "Expected ~33.33%, got {}",
            progress1
        );

        let progress2 = receive(2);
        assert!(
            (progress2 - 200.0 / 3.0).abs() < 1e-6,
            "Expected ~66.67%, got {}",
            progress2
        );

        let progress3 = receive(3);
        assert!(
            (progress3 - 100.0).abs() < 1e-6,
            "Expected 100%, got {}",
            progress3
        );
        assert_eq!(session.state(), SessionState::Completed);
    }

    #[test]
    fn test_pause_resume() {
        let (_manager, session) = new_session();
        assert_eq!(session.state(), SessionState::Active);

        let steps: [(fn(&Session), SessionState); 2] = [
            (Session::pause, SessionState::Paused),
            (Session::resume, SessionState::Active),
        ];
        for (action, expected) in steps {
            action(&*session);
            assert_eq!(session.state(), expected);
        }
    }

    #[test]
    fn test_cancel() {
        let (_manager, session) = new_session();

        session.cancel();
        assert_eq!(session.state(), SessionState::Cancelled);

        // A cancelled session must reject new blocks.
        let result = session.add_block(dummy_cid(1), None);
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_wait_completion() {
        let (_manager, session) = new_session();
        session.add_block(dummy_cid(1), None).unwrap();

        let worker = Arc::clone(&session);
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(10)).await;
            worker
                .mark_received(&dummy_cid(1), &Bytes::from(vec![1]))
                .unwrap();
        });

        let stats = session.wait_completion().await.unwrap();
        assert_eq!(stats.blocks_received, 1);
    }
}