// irontide_session/persistence.rs

1use serde::{Deserialize, Serialize};
2
3/// A DHT bootstrap node entry for session persistence.
4#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
5pub struct DhtNodeEntry {
6    /// Hostname or IP address of the DHT node.
7    pub host: String,
8    /// Port number of the DHT node.
9    pub port: i64,
10}
11
12/// A peer strike entry for session persistence.
13#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
14pub struct PeerStrikeEntry {
15    /// IP address of the peer that received strikes.
16    pub ip: String,
17    /// Number of accumulated strikes.
18    pub count: i64,
19}
20
21/// Persisted session state containing a DHT node cache and torrent resume data.
22///
23/// Serializes to bencode for on-disk persistence. The DHT node list allows
24/// faster bootstrapping on restart, and the torrent list holds
25/// [`irontide_core::FastResumeData`] entries so torrents can skip piece
26/// verification when the bitfield matches.
27#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
28pub struct SessionState {
29    /// Cached DHT routing table nodes for faster bootstrap on restart.
30    #[serde(rename = "dht-nodes", default)]
31    pub dht_nodes: Vec<DhtNodeEntry>,
32    /// BEP 42-compliant DHT node ID (hex). Persisted so the routing table
33    /// survives across sessions without regeneration.
34    #[serde(
35        rename = "dht-node-id",
36        default,
37        skip_serializing_if = "Option::is_none"
38    )]
39    pub dht_node_id: Option<String>,
40    /// Fast resume data for each torrent in the session.
41    #[serde(rename = "torrents", default)]
42    pub torrents: Vec<irontide_core::FastResumeData>,
43    /// IP addresses of permanently banned peers.
44    #[serde(rename = "banned-peers", default)]
45    pub banned_peers: Vec<String>,
46    /// Per-peer strike counts for the smart ban system.
47    #[serde(rename = "peer-strikes", default)]
48    pub peer_strikes: Vec<PeerStrikeEntry>,
49}
50
51impl SessionState {
52    /// Create a new empty `SessionState`.
53    pub fn new() -> Self {
54        Self {
55            dht_nodes: Vec::new(),
56            dht_node_id: None,
57            torrents: Vec::new(),
58            banned_peers: Vec::new(),
59            peer_strikes: Vec::new(),
60        }
61    }
62}
63
64impl Default for SessionState {
65    fn default() -> Self {
66        Self::new()
67    }
68}
69
/// Checks that a resume-file piece bitfield is exactly as long as a bitfield
/// for `num_pieces` pieces must be.
///
/// Each byte holds eight pieces, so the expected length is
/// `ceil(num_pieces / 8)` bytes, and zero pieces requires an empty slice.
/// Only the length is compared. The bit values, including any padding bits in
/// the final byte, are not inspected.
///
/// The result is used to decide whether a resume file's piece bitfield is
/// trustworthy and hash verification can be skipped on restart.
pub fn validate_resume_bitfield(pieces: &[u8], num_pieces: u32) -> bool {
    // Eight pieces per byte, rounded up. `0.div_ceil(8)` is 0, so the
    // zero-piece case falls out naturally: only an empty slice matches.
    // The u32 -> usize widening is lossless on 32- and 64-bit targets.
    let bytes_needed = num_pieces.div_ceil(8) as usize;
    bytes_needed == pieces.len()
}
82
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use serde::Serialize;

    /// Session-state layout without the `banned-peers` / `peer-strikes` keys,
    /// used to simulate state files written before those fields existed.
    #[derive(Serialize)]
    struct LegacySessionState {
        #[serde(rename = "dht-nodes")]
        dht_nodes: Vec<DhtNodeEntry>,
        #[serde(rename = "torrents")]
        torrents: Vec<irontide_core::FastResumeData>,
    }

    #[test]
    fn session_state_bencode_round_trip() {
        let state = SessionState {
            dht_nodes: vec![
                DhtNodeEntry {
                    host: "router.bittorrent.com".into(),
                    port: 6881,
                },
                DhtNodeEntry {
                    host: "dht.transmissionbt.com".into(),
                    port: 6881,
                },
            ],
            dht_node_id: None,
            torrents: vec![irontide_core::FastResumeData::new(
                vec![0xAA; 20],
                "test-torrent".into(),
                "/downloads".into(),
            )],
            banned_peers: Vec::new(),
            peer_strikes: Vec::new(),
        };

        let encoded = irontide_bencode::to_bytes(&state).unwrap();
        let decoded: SessionState = irontide_bencode::from_bytes(&encoded).unwrap();
        assert_eq!(state, decoded);
    }

    #[test]
    fn session_state_with_node_id_round_trip() {
        let state = SessionState {
            dht_nodes: vec![DhtNodeEntry {
                host: "1.2.3.4".into(),
                port: 6881,
            }],
            dht_node_id: Some("26d8457c04424098fd9e615b297745c772f49706".into()),
            torrents: vec![],
            banned_peers: vec![],
            peer_strikes: vec![],
        };

        let encoded = irontide_bencode::to_bytes(&state).unwrap();
        let encoded_str = String::from_utf8_lossy(&encoded);
        assert!(
            encoded_str.contains("dht-node-id"),
            "encoded bencode should contain dht-node-id key: {encoded_str}"
        );

        let decoded: SessionState = irontide_bencode::from_bytes(&encoded).unwrap();
        // Compare the whole state, not just the ID, so a regression in any
        // neighbouring field is also caught.
        assert_eq!(state, decoded);
    }

    #[test]
    fn empty_session_state_round_trip() {
        let state = SessionState::new();

        let encoded = irontide_bencode::to_bytes(&state).unwrap();
        // `skip_serializing_if = "Option::is_none"` must drop the key entirely.
        let encoded_str = String::from_utf8_lossy(&encoded);
        assert!(
            !encoded_str.contains("dht-node-id"),
            "dht-node-id should be omitted when unset: {encoded_str}"
        );

        let decoded: SessionState = irontide_bencode::from_bytes(&encoded).unwrap();
        assert_eq!(state, decoded);
    }

    #[test]
    fn session_state_decodes_from_empty_dict() {
        // Every field carries `#[serde(default)]`, so an empty bencode
        // dictionary must decode to the empty state.
        let decoded: SessionState = irontide_bencode::from_bytes(&b"de"[..]).unwrap();
        assert_eq!(decoded, SessionState::new());
    }

    #[test]
    fn validate_resume_bitfield_correct_length() {
        // 8 pieces -> 1 byte
        assert!(validate_resume_bitfield(&[0xFF], 8));
        // 9 pieces -> 2 bytes
        assert!(validate_resume_bitfield(&[0xFF, 0x80], 9));
        // 16 pieces -> 2 bytes
        assert!(validate_resume_bitfield(&[0xFF, 0xFF], 16));
        // 1 piece -> 1 byte
        assert!(validate_resume_bitfield(&[0x80], 1));
    }

    #[test]
    fn validate_resume_bitfield_wrong_length() {
        // 8 pieces with 2 bytes -> wrong
        assert!(!validate_resume_bitfield(&[0xFF, 0x00], 8));
        // 9 pieces with 1 byte -> wrong
        assert!(!validate_resume_bitfield(&[0xFF], 9));
        // 0 pieces with 1 byte of data -> wrong
        assert!(!validate_resume_bitfield(&[0x00], 0));
    }

    #[test]
    fn validate_resume_bitfield_zero_pieces() {
        // 0 pieces with empty data -> true
        assert!(validate_resume_bitfield(&[], 0));
    }

    #[test]
    fn session_state_with_bans_round_trip() {
        let state = SessionState {
            dht_nodes: vec![],
            dht_node_id: None,
            torrents: vec![],
            banned_peers: vec!["10.0.0.1".into(), "192.168.1.5".into()],
            peer_strikes: vec![
                PeerStrikeEntry {
                    ip: "10.0.0.1".into(),
                    count: 3,
                },
                PeerStrikeEntry {
                    ip: "10.0.0.2".into(),
                    count: 1,
                },
            ],
        };

        let encoded = irontide_bencode::to_bytes(&state).unwrap();
        let decoded: SessionState = irontide_bencode::from_bytes(&encoded).unwrap();
        assert_eq!(state, decoded);
        assert_eq!(decoded.banned_peers.len(), 2);
        assert_eq!(decoded.peer_strikes.len(), 2);
    }

    #[test]
    fn session_state_backward_compatible() {
        // Encode a dictionary that genuinely lacks the ban-related keys. This
        // way the `#[serde(default)]` attributes on those fields are actually
        // exercised. Round-tripping a full `SessionState` would not do that,
        // because it always writes both keys.
        let legacy = LegacySessionState {
            dht_nodes: vec![DhtNodeEntry {
                host: "example.com".into(),
                port: 6881,
            }],
            torrents: vec![],
        };
        let encoded = irontide_bencode::to_bytes(&legacy).unwrap();
        let encoded_str = String::from_utf8_lossy(&encoded);
        assert!(
            !encoded_str.contains("banned-peers") && !encoded_str.contains("peer-strikes"),
            "legacy encoding must not contain ban keys: {encoded_str}"
        );

        let decoded: SessionState = irontide_bencode::from_bytes(&encoded).unwrap();
        assert!(decoded.banned_peers.is_empty());
        assert!(decoded.peer_strikes.is_empty());
        assert!(decoded.dht_node_id.is_none());
        assert_eq!(decoded.dht_nodes, legacy.dht_nodes);
    }
}