1pub mod peers;
4pub mod tokens;
5
6use std::{fmt::Debug, net::SocketAddrV4, num::NonZeroUsize};
7
8use dyn_clone::DynClone;
9use lru::LruCache;
10use tracing::debug;
11
12use crate::common::{
13 validate_immutable, AnnouncePeerRequestArguments, ErrorSpecific, FindNodeRequestArguments,
14 FindNodeResponseArguments, GetImmutableResponseArguments, GetMutableResponseArguments,
15 GetPeersRequestArguments, GetPeersResponseArguments, GetValueRequestArguments, Id, MutableItem,
16 NoMoreRecentValueResponseArguments, NoValuesResponseArguments, PingResponseArguments,
17 PutImmutableRequestArguments, PutMutableRequestArguments, PutRequest, PutRequestSpecific,
18 RequestTypeSpecific, ResponseSpecific, RoutingTable,
19};
20
21use peers::PeersStore;
22use tokens::Tokens;
23
24pub use crate::common::{MessageType, RequestSpecific};
25
/// Default capacity for the number of info hashes tracked by the server's peer store.
pub const MAX_INFO_HASHES: usize = 2_000;

/// Default capacity for the number of peers kept per info hash.
pub const MAX_PEERS: usize = 500;

/// Default capacity for stored values; used for both the immutable and the
/// mutable value caches.
pub const MAX_VALUES: usize = 1_000;
32
33pub trait RequestFilter: Send + Sync + Debug + DynClone {
37 fn allow_request(&self, request: &RequestSpecific, from: SocketAddrV4) -> bool;
39}
40
41dyn_clone::clone_trait_object!(RequestFilter);
42
43#[derive(Debug, Clone)]
44struct DefaultFilter;
45
46impl RequestFilter for DefaultFilter {
47 fn allow_request(&self, _request: &RequestSpecific, _from: SocketAddrV4) -> bool {
48 true
49 }
50}
51
/// Handles incoming DHT requests.
///
/// Announced peers and immutable/mutable values are kept in bounded stores:
/// a `PeersStore` and two `LruCache`s, sized from `ServerSettings`.
#[derive(Debug)]
pub struct Server {
    /// Issues the tokens returned in `get_peers` and `get` responses, and
    /// validates the tokens echoed back in `put`-style requests.
    tokens: Tokens,
    /// Peers announced per info hash.
    peers: PeersStore,
    /// Immutable values keyed by their target id.
    immutable_values: LruCache<Id, Box<[u8]>>,
    /// Mutable items keyed by their target id.
    mutable_values: LruCache<Id, MutableItem>,
    /// Consulted first for every request. A rejected request gets no reply.
    filter: Box<dyn RequestFilter>,
}
70
71impl Default for Server {
72 fn default() -> Self {
73 Self::new(ServerSettings::default())
74 }
75}
76
/// Configuration consumed by `Server::new`.
#[derive(Debug, Clone)]
pub struct ServerSettings {
    /// Maximum number of info hashes the peer store tracks.
    ///
    /// Defaults to [`MAX_INFO_HASHES`]. A value of `0` is replaced by that default.
    pub max_info_hashes: usize,
    /// Maximum number of peers stored per info hash.
    ///
    /// Defaults to [`MAX_PEERS`]. A value of `0` is replaced by that default.
    pub max_peers_per_info_hash: usize,
    /// Capacity of the immutable value cache.
    ///
    /// Defaults to [`MAX_VALUES`]. A value of `0` is replaced by that default.
    pub max_immutable_values: usize,
    /// Capacity of the mutable item cache.
    ///
    /// Defaults to [`MAX_VALUES`]. A value of `0` is replaced by that default.
    pub max_mutable_values: usize,
    /// Filter consulted before any request is handled.
    ///
    /// The default accepts every request.
    pub filter: Box<dyn RequestFilter>,
}
101
102impl Default for ServerSettings {
103 fn default() -> Self {
104 Self {
105 max_info_hashes: MAX_INFO_HASHES,
106 max_peers_per_info_hash: MAX_PEERS,
107 max_mutable_values: MAX_VALUES,
108 max_immutable_values: MAX_VALUES,
109
110 filter: Box::new(DefaultFilter),
111 }
112 }
113}
114
115impl Server {
116 pub fn new(settings: ServerSettings) -> Self {
118 let tokens = Tokens::new();
119
120 Self {
121 tokens,
122 peers: PeersStore::new(
123 NonZeroUsize::new(settings.max_info_hashes).unwrap_or(
124 NonZeroUsize::new(MAX_INFO_HASHES).expect("MAX_PEERS is NonZeroUsize"),
125 ),
126 NonZeroUsize::new(settings.max_peers_per_info_hash)
127 .unwrap_or(NonZeroUsize::new(MAX_PEERS).expect("MAX_PEERS is NonZeroUsize")),
128 ),
129
130 immutable_values: LruCache::new(
131 NonZeroUsize::new(settings.max_immutable_values)
132 .unwrap_or(NonZeroUsize::new(MAX_VALUES).expect("MAX_VALUES is NonZeroUsize")),
133 ),
134 mutable_values: LruCache::new(
135 NonZeroUsize::new(settings.max_mutable_values)
136 .unwrap_or(NonZeroUsize::new(MAX_VALUES).expect("MAX_VALUES is NonZeroUsize")),
137 ),
138 filter: settings.filter,
139 }
140 }
141
142 pub fn handle_request(
146 &mut self,
147 routing_table: &RoutingTable,
148 from: SocketAddrV4,
149 request: RequestSpecific,
150 ) -> Option<MessageType> {
151 if !self.filter.allow_request(&request, from) {
152 return None;
153 }
154
155 if self.tokens.should_update() {
157 self.tokens.rotate()
158 }
159
160 let requester_id = request.requester_id;
161
162 Some(match request.request_type {
163 RequestTypeSpecific::Ping => {
164 MessageType::Response(ResponseSpecific::Ping(PingResponseArguments {
165 responder_id: *routing_table.id(),
166 }))
167 }
168 RequestTypeSpecific::FindNode(FindNodeRequestArguments { target, .. }) => {
169 MessageType::Response(ResponseSpecific::FindNode(FindNodeResponseArguments {
170 responder_id: *routing_table.id(),
171 nodes: routing_table.closest(target),
172 }))
173 }
174 RequestTypeSpecific::GetPeers(GetPeersRequestArguments { info_hash, .. }) => {
175 MessageType::Response(match self.peers.get_random_peers(&info_hash) {
176 Some(peers) => ResponseSpecific::GetPeers(GetPeersResponseArguments {
177 responder_id: *routing_table.id(),
178 token: self.tokens.generate_token(from).into(),
179 nodes: Some(routing_table.closest(info_hash)),
180 values: peers,
181 }),
182 None => ResponseSpecific::NoValues(NoValuesResponseArguments {
183 responder_id: *routing_table.id(),
184 token: self.tokens.generate_token(from).into(),
185 nodes: Some(routing_table.closest(info_hash)),
186 }),
187 })
188 }
189 RequestTypeSpecific::GetValue(GetValueRequestArguments { target, seq, .. }) => {
190 if seq.is_some() {
191 MessageType::Response(self.handle_get_mutable(routing_table, from, target, seq))
192 } else if let Some(v) = self.immutable_values.get(&target) {
193 MessageType::Response(ResponseSpecific::GetImmutable(
194 GetImmutableResponseArguments {
195 responder_id: *routing_table.id(),
196 token: self.tokens.generate_token(from).into(),
197 nodes: Some(routing_table.closest(target)),
198 v: v.clone(),
199 },
200 ))
201 } else {
202 MessageType::Response(self.handle_get_mutable(routing_table, from, target, seq))
203 }
204 }
205 RequestTypeSpecific::Put(PutRequest {
206 token,
207 put_request_type,
208 }) => match put_request_type {
209 PutRequestSpecific::AnnouncePeer(AnnouncePeerRequestArguments {
210 info_hash,
211 port,
212 implied_port,
213 ..
214 }) => {
215 if !self.tokens.validate(from, &token) {
216 debug!(
217 ?info_hash,
218 ?requester_id,
219 ?from,
220 request_type = "announce_peer",
221 "Invalid token"
222 );
223
224 return Some(MessageType::Error(ErrorSpecific {
225 code: 203,
226 description: "Bad token".to_string(),
227 }));
228 }
229
230 let peer = match implied_port {
231 Some(true) => from,
232 _ => SocketAddrV4::new(*from.ip(), port),
233 };
234
235 self.peers
236 .add_peer(info_hash, (&request.requester_id, peer));
237
238 return Some(MessageType::Response(ResponseSpecific::Ping(
239 PingResponseArguments {
240 responder_id: *routing_table.id(),
241 },
242 )));
243 }
244 PutRequestSpecific::PutImmutable(PutImmutableRequestArguments {
245 v,
246 target,
247 ..
248 }) => {
249 if !self.tokens.validate(from, &token) {
250 debug!(
251 ?target,
252 ?requester_id,
253 ?from,
254 request_type = "put_immutable",
255 "Invalid token"
256 );
257
258 return Some(MessageType::Error(ErrorSpecific {
259 code: 203,
260 description: "Bad token".to_string(),
261 }));
262 }
263
264 if v.len() > 1000 {
265 debug!(?target, ?requester_id, ?from, size = ?v.len(), "Message (v field) too big.");
266
267 return Some(MessageType::Error(ErrorSpecific {
268 code: 205,
269 description: "Message (v field) too big.".to_string(),
270 }));
271 }
272 if !validate_immutable(&v, target) {
273 debug!(?target, ?requester_id, ?from, v = ?v, "Target doesn't match the sha1 hash of v field.");
274
275 return Some(MessageType::Error(ErrorSpecific {
276 code: 203,
277 description: "Target doesn't match the sha1 hash of v field"
278 .to_string(),
279 }));
280 }
281
282 self.immutable_values.put(target, v);
283
284 return Some(MessageType::Response(ResponseSpecific::Ping(
285 PingResponseArguments {
286 responder_id: *routing_table.id(),
287 },
288 )));
289 }
290 PutRequestSpecific::PutMutable(PutMutableRequestArguments {
291 target,
292 v,
293 k,
294 seq,
295 sig,
296 salt,
297 cas,
298 ..
299 }) => {
300 if !self.tokens.validate(from, &token) {
301 debug!(
302 ?target,
303 ?requester_id,
304 ?from,
305 request_type = "put_mutable",
306 "Invalid token"
307 );
308 return Some(MessageType::Error(ErrorSpecific {
309 code: 203,
310 description: "Bad token".to_string(),
311 }));
312 }
313 if v.len() > 1000 {
314 return Some(MessageType::Error(ErrorSpecific {
315 code: 205,
316 description: "Message (v field) too big.".to_string(),
317 }));
318 }
319 if let Some(ref salt) = salt {
320 if salt.len() > 64 {
321 return Some(MessageType::Error(ErrorSpecific {
322 code: 207,
323 description: "salt (salt field) too big.".to_string(),
324 }));
325 }
326 }
327 if let Some(previous) = self.mutable_values.get(&target) {
328 if let Some(cas) = cas {
329 if previous.seq() != cas {
330 debug!(
331 ?target,
332 ?requester_id,
333 ?from,
334 "CAS mismatched, re-read value and try again."
335 );
336
337 return Some(MessageType::Error(ErrorSpecific {
338 code: 301,
339 description: "CAS mismatched, re-read value and try again."
340 .to_string(),
341 }));
342 }
343 };
344
345 if seq < previous.seq() {
346 debug!(
347 ?target,
348 ?requester_id,
349 ?from,
350 "Sequence number less than current."
351 );
352
353 return Some(MessageType::Error(ErrorSpecific {
354 code: 302,
355 description: "Sequence number less than current.".to_string(),
356 }));
357 }
358 }
359
360 match MutableItem::from_dht_message(target, &k, v, seq, &sig, salt) {
361 Ok(item) => {
362 self.mutable_values.put(target, item);
363
364 MessageType::Response(ResponseSpecific::Ping(PingResponseArguments {
365 responder_id: *routing_table.id(),
366 }))
367 }
368 Err(error) => {
369 debug!(?target, ?requester_id, ?from, ?error, "Invalid signature");
370
371 MessageType::Error(ErrorSpecific {
372 code: 206,
373 description: "Invalid signature".to_string(),
374 })
375 }
376 }
377 }
378 },
379 })
380 }
381
382 fn handle_get_mutable(
384 &mut self,
385 routing_table: &RoutingTable,
386 from: SocketAddrV4,
387 target: Id,
388 seq: Option<i64>,
389 ) -> ResponseSpecific {
390 match self.mutable_values.get(&target) {
391 Some(item) => {
392 let no_more_recent_values = seq.map(|request_seq| item.seq() <= request_seq);
393
394 match no_more_recent_values {
395 Some(true) => {
396 ResponseSpecific::NoMoreRecentValue(NoMoreRecentValueResponseArguments {
397 responder_id: *routing_table.id(),
398 token: self.tokens.generate_token(from).into(),
399 nodes: Some(routing_table.closest(target)),
400 seq: item.seq(),
401 })
402 }
403 _ => ResponseSpecific::GetMutable(GetMutableResponseArguments {
404 responder_id: *routing_table.id(),
405 token: self.tokens.generate_token(from).into(),
406 nodes: Some(routing_table.closest(target)),
407 v: item.value().into(),
408 k: *item.key(),
409 seq: item.seq(),
410 sig: *item.signature(),
411 }),
412 }
413 }
414 None => ResponseSpecific::NoValues(NoValuesResponseArguments {
415 responder_id: *routing_table.id(),
416 token: self.tokens.generate_token(from).into(),
417 nodes: Some(routing_table.closest(target)),
418 }),
419 }
420 }
421}