use serde::{Deserialize, Serialize};
/// A single DHT node address, stored as a host string plus a port.
///
/// `host` may be a hostname (e.g. `router.bittorrent.com`) or a literal IP
/// address; no resolution or validation happens in this type.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub struct DhtNodeEntry {
    /// Hostname or IP address of the node.
    pub host: String,
    /// Port number. Stored as `i64` and not range-checked here — callers
    /// converting to a socket address must validate it fits in a `u16`.
    pub port: i64,
}
/// A persisted strike counter for one peer IP address.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub struct PeerStrikeEntry {
    /// The peer's IP address, in textual form.
    pub ip: String,
    /// Number of strikes recorded for `ip`.
    pub count: i64,
}
/// Session-wide state that is serialized and restored across runs.
///
/// Every field is marked `#[serde(default)]`, so a serialized document that
/// lacks any of the keys still deserializes; the missing field comes back
/// empty (or `None`). Keys use the kebab-case names given in the
/// `rename` attributes.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct SessionState {
    /// Known DHT nodes, serialized under `dht-nodes`.
    #[serde(rename = "dht-nodes", default)]
    pub dht_nodes: Vec<DhtNodeEntry>,
    /// Stored DHT node ID, serialized under `dht-node-id`. The key is left
    /// out of the output entirely when this is `None`.
    #[serde(rename = "dht-node-id", default, skip_serializing_if = "Option::is_none")]
    pub dht_node_id: Option<String>,
    /// Per-torrent fast-resume records, serialized under `torrents`.
    #[serde(rename = "torrents", default)]
    pub torrents: Vec<irontide_core::FastResumeData>,
    /// Banned peer IP addresses, serialized under `banned-peers`.
    #[serde(rename = "banned-peers", default)]
    pub banned_peers: Vec<String>,
    /// Per-IP strike counters, serialized under `peer-strikes`.
    #[serde(rename = "peer-strikes", default)]
    pub peer_strikes: Vec<PeerStrikeEntry>,
}
impl SessionState {
    /// Builds an empty session state.
    ///
    /// No DHT nodes, no stored DHT node ID, no torrents, no banned peers and
    /// no strike counters.
    #[must_use]
    pub fn new() -> Self {
        Self {
            dht_node_id: None,
            dht_nodes: vec![],
            torrents: vec![],
            banned_peers: vec![],
            peer_strikes: vec![],
        }
    }
}
impl Default for SessionState {
fn default() -> Self {
Self::new()
}
}
/// Checks that a resume-data piece bitfield has exactly the byte length
/// implied by `num_pieces`.
///
/// One bit is used per piece, so `num_pieces` pieces occupy
/// `ceil(num_pieces / 8)` bytes; with zero pieces only an empty slice is
/// accepted. Only the length is checked — the bit values (including any
/// padding bits in the final byte) are not inspected.
///
/// Returns `true` when `pieces.len()` equals the required byte count.
#[must_use]
pub fn validate_resume_bitfield(pieces: &[u8], num_pieces: u32) -> bool {
    let required_bytes = match num_pieces {
        0 => 0,
        n => n.div_ceil(8) as usize,
    };
    required_bytes == pieces.len()
}
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// A fully populated state (nodes + one torrent) survives a bencode round trip.
    #[test]
    fn session_state_bencode_round_trip() {
        let state = SessionState {
            dht_nodes: vec![
                DhtNodeEntry {
                    host: "router.bittorrent.com".into(),
                    port: 6881,
                },
                DhtNodeEntry {
                    host: "dht.transmissionbt.com".into(),
                    port: 6881,
                },
            ],
            dht_node_id: None,
            torrents: vec![irontide_core::FastResumeData::new(
                vec![0xAA; 20],
                "test-torrent".into(),
                "/downloads".into(),
            )],
            banned_peers: Vec::new(),
            peer_strikes: Vec::new(),
        };
        let encoded = irontide_bencode::to_bytes(&state).unwrap();
        let decoded: SessionState = irontide_bencode::from_bytes(&encoded).unwrap();
        assert_eq!(state, decoded);
    }

    /// A present DHT node ID is written under the `dht-node-id` key and read back.
    #[test]
    fn session_state_with_node_id_round_trip() {
        let state = SessionState {
            dht_nodes: vec![DhtNodeEntry {
                host: "1.2.3.4".into(),
                port: 6881,
            }],
            dht_node_id: Some("26d8457c04424098fd9e615b297745c772f49706".into()),
            torrents: vec![],
            banned_peers: vec![],
            peer_strikes: vec![],
        };
        let encoded = irontide_bencode::to_bytes(&state).unwrap();
        let encoded_str = String::from_utf8_lossy(&encoded);
        assert!(
            encoded_str.contains("dht-node-id"),
            "encoded bencode should contain dht-node-id key: {encoded_str}"
        );
        let decoded: SessionState = irontide_bencode::from_bytes(&encoded).unwrap();
        assert_eq!(state.dht_node_id, decoded.dht_node_id);
    }

    /// An empty state round-trips, and a `None` node ID is omitted from the
    /// encoding (`skip_serializing_if = "Option::is_none"`).
    #[test]
    fn empty_session_state_round_trip() {
        let state = SessionState::new();
        let encoded = irontide_bencode::to_bytes(&state).unwrap();
        let encoded_str = String::from_utf8_lossy(&encoded);
        assert!(
            !encoded_str.contains("dht-node-id"),
            "a None dht_node_id must not be serialized: {encoded_str}"
        );
        let decoded: SessionState = irontide_bencode::from_bytes(&encoded).unwrap();
        assert_eq!(state, decoded);
    }

    /// Bitfields whose byte length matches `ceil(num_pieces / 8)` are accepted.
    #[test]
    fn validate_resume_bitfield_correct_length() {
        assert!(validate_resume_bitfield(&[0xFF], 8));
        assert!(validate_resume_bitfield(&[0xFF, 0x80], 9));
        assert!(validate_resume_bitfield(&[0xFF, 0xFF], 16));
        assert!(validate_resume_bitfield(&[0x80], 1));
    }

    /// Bitfields that are too long or too short are rejected.
    #[test]
    fn validate_resume_bitfield_wrong_length() {
        assert!(!validate_resume_bitfield(&[0xFF, 0x00], 8));
        assert!(!validate_resume_bitfield(&[0xFF], 9));
        assert!(!validate_resume_bitfield(&[0x00], 0));
    }

    /// With zero pieces, only an empty bitfield is valid.
    #[test]
    fn validate_resume_bitfield_zero_pieces() {
        assert!(validate_resume_bitfield(&[], 0));
    }

    /// Ban list and strike counters round-trip intact.
    #[test]
    fn session_state_with_bans_round_trip() {
        let state = SessionState {
            dht_nodes: vec![],
            dht_node_id: None,
            torrents: vec![],
            banned_peers: vec!["10.0.0.1".into(), "192.168.1.5".into()],
            peer_strikes: vec![
                PeerStrikeEntry {
                    ip: "10.0.0.1".into(),
                    count: 3,
                },
                PeerStrikeEntry {
                    ip: "10.0.0.2".into(),
                    count: 1,
                },
            ],
        };
        let encoded = irontide_bencode::to_bytes(&state).unwrap();
        let decoded: SessionState = irontide_bencode::from_bytes(&encoded).unwrap();
        assert_eq!(state, decoded);
        assert_eq!(decoded.banned_peers.len(), 2);
        assert_eq!(decoded.peer_strikes.len(), 2);
    }

    /// A document that lacks the optional keys (as a writer predating those
    /// fields would produce) must still decode, with the missing fields
    /// falling back to their `#[serde(default)]` values.
    ///
    /// Encoding a full `SessionState` would always emit `banned-peers` and
    /// `peer-strikes` (as empty lists), so it would never exercise the
    /// missing-key path. A reduced struct that writes only `dht-nodes` is
    /// encoded instead.
    #[test]
    fn session_state_backward_compatible() {
        #[derive(serde::Serialize)]
        struct LegacySessionState {
            #[serde(rename = "dht-nodes")]
            dht_nodes: Vec<DhtNodeEntry>,
        }

        let legacy = LegacySessionState {
            dht_nodes: vec![DhtNodeEntry {
                host: "example.com".into(),
                port: 6881,
            }],
        };
        let encoded = irontide_bencode::to_bytes(&legacy).unwrap();
        let encoded_str = String::from_utf8_lossy(&encoded);
        // Guard the premise of the test: the newer keys really are absent.
        assert!(
            !encoded_str.contains("banned-peers") && !encoded_str.contains("peer-strikes"),
            "legacy encoding must not contain the newer keys: {encoded_str}"
        );

        let decoded: SessionState = irontide_bencode::from_bytes(&encoded).unwrap();
        assert!(decoded.banned_peers.is_empty());
        assert!(decoded.peer_strikes.is_empty());
        assert!(decoded.torrents.is_empty());
        assert_eq!(decoded.dht_node_id, None);
        assert_eq!(decoded.dht_nodes, legacy.dht_nodes);
    }
}