1use std::io::{BufRead, BufReader, Read, Write};
29use std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs};
30use std::sync::atomic::{AtomicBool, Ordering};
31use std::sync::Arc;
32use std::time::Duration;
33
34#[derive(thiserror::Error, Debug)]
36pub enum PeerError {
37 #[error("peer request was not authorized")]
39 Unauthorized,
40 #[error("peer handle not found")]
42 NotFound,
43 #[error("peer returned HTTP {0}")]
45 Status(u16),
46 #[error("peer protocol error: {0}")]
48 Protocol(String),
49 #[error("{0}")]
51 Io(String),
52}
53
54impl From<std::io::Error> for PeerError {
55 fn from(error: std::io::Error) -> Self {
56 PeerError::Io(error.to_string())
57 }
58}
59
/// Read-only, offset-addressed access to a run of bytes.
///
/// Implementations must be `Send + Sync`.
pub trait ByteSource: Send + Sync {
    /// Total number of bytes, or `None` when the size is not known.
    fn len(&self) -> Option<u64>;

    /// Copies bytes starting at `offset` into `buf` and returns how many were
    /// written. The callers in this file treat a return of `0` as end of data.
    fn read_at(&self, offset: u64, buf: &mut [u8]) -> std::io::Result<usize>;

    /// `true` only when the length is known and equal to zero; an unknown
    /// length is never reported as empty.
    fn is_empty(&self) -> bool {
        matches!(self.len(), Some(0))
    }
}
78
/// A byte buffer held entirely in memory.
pub struct BytesSource {
    /// The owned payload.
    bytes: Vec<u8>,
}

impl BytesSource {
    /// Wraps `bytes`, taking ownership of the vector without copying it.
    pub fn new(bytes: Vec<u8>) -> Self {
        BytesSource { bytes }
    }
}
90
91impl ByteSource for BytesSource {
92 fn len(&self) -> Option<u64> {
93 Some(self.bytes.len() as u64)
94 }
95
96 fn read_at(&self, offset: u64, buf: &mut [u8]) -> std::io::Result<usize> {
97 let offset = offset.min(self.bytes.len() as u64) as usize;
98 let available = &self.bytes[offset..];
99 let n = available.len().min(buf.len());
100 buf[..n].copy_from_slice(&available[..n]);
101 Ok(n)
102 }
103}
104
105pub type SourceResolver = Arc<dyn Fn(&str) -> Option<Arc<dyn ByteSource>> + Send + Sync>;
107
108pub struct PeerServer {
110 addr: SocketAddr,
111 running: Arc<AtomicBool>,
112}
113
114impl PeerServer {
115 pub fn start(
118 bind_addr: impl ToSocketAddrs,
119 token: impl Into<String>,
120 resolver: SourceResolver,
121 ) -> Result<PeerServer, PeerError> {
122 let listener = TcpListener::bind(bind_addr)?;
123 let addr = listener.local_addr()?;
124 let running = Arc::new(AtomicBool::new(true));
125 let token = token.into();
126
127 let loop_running = running.clone();
128 std::thread::Builder::new()
129 .name("cranpose-peer".to_string())
130 .spawn(move || {
131 for stream in listener.incoming() {
132 if !loop_running.load(Ordering::SeqCst) {
133 break;
134 }
135 let Ok(stream) = stream else { continue };
136 let token = token.clone();
137 let resolver = resolver.clone();
138 let _ = std::thread::Builder::new()
141 .name("cranpose-peer-conn".to_string())
142 .spawn(move || {
143 let _ = handle_connection(stream, &token, &resolver);
144 });
145 }
146 })
147 .map_err(|error| PeerError::Io(error.to_string()))?;
148
149 Ok(PeerServer { addr, running })
150 }
151
152 pub fn local_addr(&self) -> SocketAddr {
154 self.addr
155 }
156
157 pub fn port(&self) -> u16 {
159 self.addr.port()
160 }
161}
162
163impl Drop for PeerServer {
164 fn drop(&mut self) {
165 self.running.store(false, Ordering::SeqCst);
166 let _ = TcpStream::connect(self.addr);
168 }
169}
170
171fn handle_connection(
172 mut stream: TcpStream,
173 token: &str,
174 resolver: &SourceResolver,
175) -> Result<(), PeerError> {
176 stream.set_read_timeout(Some(Duration::from_secs(30)))?;
177 let mut reader = BufReader::new(stream.try_clone()?);
178
179 let mut request_line = String::new();
180 if reader.read_line(&mut request_line)? == 0 {
181 return Ok(()); }
183 let mut parts = request_line.split_whitespace();
184 let method = parts.next().unwrap_or("");
185 let path = parts.next().unwrap_or("");
186
187 let mut authorization = None;
188 let mut range = None;
189 loop {
190 let mut line = String::new();
191 if reader.read_line(&mut line)? == 0 {
192 break;
193 }
194 let line = line.trim_end();
195 if line.is_empty() {
196 break;
197 }
198 if let Some((name, value)) = line.split_once(':') {
199 let value = value.trim();
200 match name.trim().to_ascii_lowercase().as_str() {
201 "authorization" => authorization = Some(value.to_string()),
202 "range" => range = parse_range_header(value),
203 _ => {}
204 }
205 }
206 }
207
208 if method != "GET" {
209 return write_status(&mut stream, 405, "Method Not Allowed");
210 }
211 if authorization.as_deref() != Some(&format!("Bearer {token}")) {
212 return write_status(&mut stream, 401, "Unauthorized");
213 }
214 let Some(handle) = path.strip_prefix("/track/") else {
215 return write_status(&mut stream, 404, "Not Found");
216 };
217 let handle = percent_decode(handle);
218 let Some(source) = resolver(&handle) else {
219 return write_status(&mut stream, 404, "Not Found");
220 };
221
222 serve_source(&mut stream, source.as_ref(), range)
223}
224
225fn serve_source(
226 stream: &mut TcpStream,
227 source: &dyn ByteSource,
228 range: Option<(u64, Option<u64>)>,
229) -> Result<(), PeerError> {
230 let total = source.len();
231
232 let (status, reason, start, length) = match (range, total) {
233 (Some((start, end)), Some(total)) if start < total => {
234 let last = end.unwrap_or(total - 1).min(total - 1);
235 if last < start {
236 return write_status(stream, 416, "Range Not Satisfiable");
237 }
238 (206, "Partial Content", start, last - start + 1)
239 }
240 (Some((start, _)), Some(total)) if start >= total => {
241 return write_status(stream, 416, "Range Not Satisfiable");
242 }
243 (_, Some(total)) => (200, "OK", 0, total),
244 (Some(_), None) => return write_status(stream, 416, "Range Not Satisfiable"),
246 (None, None) => {
247 return serve_unknown_length(stream, source);
249 }
250 };
251
252 let mut header = format!(
253 "HTTP/1.1 {status} {reason}\r\nContent-Length: {length}\r\nAccept-Ranges: bytes\r\nContent-Type: application/octet-stream\r\nConnection: close\r\n"
254 );
255 if status == 206 {
256 if let Some(total) = total {
257 let end = start + length - 1;
258 header.push_str(&format!("Content-Range: bytes {start}-{end}/{total}\r\n"));
259 }
260 }
261 header.push_str("\r\n");
262 stream.write_all(header.as_bytes())?;
263
264 stream_bytes(stream, source, start, length)
265}
266
267fn serve_unknown_length(stream: &mut TcpStream, source: &dyn ByteSource) -> Result<(), PeerError> {
268 let header =
270 "HTTP/1.1 200 OK\r\nContent-Type: application/octet-stream\r\nConnection: close\r\n\r\n";
271 stream.write_all(header.as_bytes())?;
272 let mut buf = vec![0u8; 64 * 1024];
273 let mut offset = 0u64;
274 loop {
275 let n = source.read_at(offset, &mut buf)?;
276 if n == 0 {
277 break;
278 }
279 stream.write_all(&buf[..n])?;
280 offset += n as u64;
281 }
282 Ok(())
283}
284
285fn stream_bytes(
286 stream: &mut TcpStream,
287 source: &dyn ByteSource,
288 start: u64,
289 length: u64,
290) -> Result<(), PeerError> {
291 let mut buf = vec![0u8; 64 * 1024];
292 let mut sent = 0u64;
293 while sent < length {
294 let want = ((length - sent) as usize).min(buf.len());
295 let n = source.read_at(start + sent, &mut buf[..want])?;
296 if n == 0 {
297 break;
298 }
299 stream.write_all(&buf[..n])?;
300 sent += n as u64;
301 }
302 Ok(())
303}
304
305fn write_status(stream: &mut TcpStream, code: u16, reason: &str) -> Result<(), PeerError> {
306 let response =
307 format!("HTTP/1.1 {code} {reason}\r\nContent-Length: 0\r\nConnection: close\r\n\r\n");
308 stream.write_all(response.as_bytes())?;
309 Ok(())
310}
311
/// Parses a single-range `Range` header value of the form `bytes=START-[END]`.
///
/// Returns `(start, Some(end))` for a closed range and `(start, None)` for an
/// open-ended one. Any other form, including one without a start, is `None`.
fn parse_range_header(value: &str) -> Option<(u64, Option<u64>)> {
    let (first, last) = value.trim().strip_prefix("bytes=")?.split_once('-')?;
    let first: u64 = first.trim().parse().ok()?;
    let last = match last.trim() {
        "" => None,
        digits => Some(digits.parse::<u64>().ok()?),
    };
    Some((first, last))
}
325
326fn percent_decode(input: &str) -> String {
327 let bytes = input.as_bytes();
328 let mut out = Vec::with_capacity(bytes.len());
329 let mut i = 0;
330 while i < bytes.len() {
331 if bytes[i] == b'%' && i + 2 < bytes.len() {
332 if let (Some(h), Some(l)) = (hex_val(bytes[i + 1]), hex_val(bytes[i + 2])) {
333 out.push((h << 4) | l);
334 i += 3;
335 continue;
336 }
337 }
338 out.push(bytes[i]);
339 i += 1;
340 }
341 String::from_utf8_lossy(&out).into_owned()
342}
343
/// Value of an ASCII hex digit (`0-9`, `a-f`, `A-F`), or `None` for any other byte.
fn hex_val(byte: u8) -> Option<u8> {
    char::from(byte).to_digit(16).map(|digit| digit as u8)
}
352
/// Outcome of a buffered fetch from a peer.
///
/// Derives `Debug`, `Clone` and equality so results can be logged, compared
/// in assertions, and used with `Result::unwrap_err` and similar helpers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FetchResult {
    /// Total size of the resource as reported by the response's
    /// `Content-Range` header, when one was present and parsable.
    pub total_len: Option<u64>,
    /// The response body bytes that were received.
    pub bytes: Vec<u8>,
}
364
/// Parsed response head, plus the reader positioned at the start of the body.
struct ResponseHead {
    /// Total resource size taken from `Content-Range`, if present.
    total_len: Option<u64>,
    /// Value of `Content-Length`, if present and numeric.
    content_length: Option<u64>,
    /// Buffered socket reader; the status line and headers are already consumed.
    reader: BufReader<TcpStream>,
}
370
371fn open_request(
375 base: &str,
376 token: &str,
377 handle: &str,
378 start: u64,
379 len: Option<u64>,
380) -> Result<ResponseHead, PeerError> {
381 let mut stream = TcpStream::connect(base)?;
382 stream.set_read_timeout(Some(Duration::from_secs(30)))?;
383
384 let range = match len {
385 Some(len) if len > 0 => format!("bytes={start}-{}", start + len - 1),
386 Some(_) => format!("bytes={start}-{start}"),
387 None => format!("bytes={start}-"),
388 };
389 let request = format!(
390 "GET /track/{} HTTP/1.1\r\nHost: {base}\r\nAuthorization: Bearer {token}\r\nRange: {range}\r\nConnection: close\r\n\r\n",
391 encode_handle(handle)
392 );
393 stream.write_all(request.as_bytes())?;
394
395 let mut reader = BufReader::new(stream);
396 let mut status_line = String::new();
397 reader.read_line(&mut status_line)?;
398 let status = parse_status(&status_line)?;
399
400 let mut total_len = None;
401 let mut content_length = None;
402 loop {
403 let mut line = String::new();
404 if reader.read_line(&mut line)? == 0 {
405 break;
406 }
407 let line = line.trim_end();
408 if line.is_empty() {
409 break;
410 }
411 if let Some((name, value)) = line.split_once(':') {
412 match name.trim().to_ascii_lowercase().as_str() {
413 "content-length" => content_length = value.trim().parse::<u64>().ok(),
414 "content-range" => total_len = parse_content_range_total(value.trim()),
415 _ => {}
416 }
417 }
418 }
419
420 match status {
421 401 => Err(PeerError::Unauthorized),
422 404 => Err(PeerError::NotFound),
423 200 | 206 => Ok(ResponseHead {
424 total_len,
425 content_length,
426 reader,
427 }),
428 other => Err(PeerError::Status(other)),
429 }
430}
431
432pub fn fetch_range(
435 base: &str,
436 token: &str,
437 handle: &str,
438 start: u64,
439 len: Option<u64>,
440) -> Result<FetchResult, PeerError> {
441 let mut head = open_request(base, token, handle, start, len)?;
442 let mut bytes = Vec::new();
443 match head.content_length {
444 Some(length) => {
445 bytes.resize(length as usize, 0);
446 head.reader.read_exact(&mut bytes)?;
447 }
448 None => {
449 head.reader.read_to_end(&mut bytes)?;
450 }
451 }
452 Ok(FetchResult {
453 total_len: head.total_len,
454 bytes,
455 })
456}
457
458pub fn fetch_to_writer(
462 base: &str,
463 token: &str,
464 handle: &str,
465 start: u64,
466 len: Option<u64>,
467 writer: &mut dyn Write,
468) -> Result<Option<u64>, PeerError> {
469 let mut head = open_request(base, token, handle, start, len)?;
470 let mut buf = vec![0u8; 64 * 1024];
471 let mut remaining = head.content_length;
472 loop {
473 let want = match remaining {
474 Some(0) => break,
475 Some(r) => (r as usize).min(buf.len()),
476 None => buf.len(),
477 };
478 let n = head.reader.read(&mut buf[..want])?;
479 if n == 0 {
480 break;
481 }
482 writer.write_all(&buf[..n])?;
483 if let Some(r) = remaining.as_mut() {
484 *r -= n as u64;
485 }
486 }
487 Ok(head.total_len)
488}
489
490pub fn content_length(base: &str, token: &str, handle: &str) -> Result<Option<u64>, PeerError> {
492 Ok(fetch_range(base, token, handle, 0, Some(1))?.total_len)
493}
494
495fn parse_status(line: &str) -> Result<u16, PeerError> {
496 line.split_whitespace()
497 .nth(1)
498 .and_then(|code| code.parse::<u16>().ok())
499 .ok_or_else(|| PeerError::Protocol(format!("bad status line: {line:?}")))
500}
501
/// Extracts the total size from a `Content-Range` value such as
/// `bytes 0-99/5000`.
///
/// Returns `None` when there is no `/`, or when the part after the last `/`
/// is not a number (for example `*`). Requiring the slash keeps a bare number
/// with no range part from being accepted as a total.
fn parse_content_range_total(value: &str) -> Option<u64> {
    let (_, total) = value.rsplit_once('/')?;
    total.trim().parse().ok()
}
506
/// Percent-encodes `handle` for use as a path segment.
///
/// ASCII letters, digits and `-`, `_`, `.`, `~` pass through; every other
/// byte becomes `%XX` with uppercase hex digits.
fn encode_handle(handle: &str) -> String {
    let mut encoded = String::with_capacity(handle.len());
    for &b in handle.as_bytes() {
        let unreserved = b.is_ascii_alphanumeric() || matches!(b, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            encoded.push(char::from(b));
        } else {
            encoded.push_str(&format!("%{b:02X}"));
        }
    }
    encoded
}
519
#[cfg(test)]
mod tests {
    use super::*;

    /// Resolver that serves a copy of `bytes` for exactly `handle` and nothing else.
    fn resolver_for(handle: &'static str, bytes: Vec<u8>) -> SourceResolver {
        Arc::new(move |requested: &str| {
            if requested != handle {
                return None;
            }
            let source: Arc<dyn ByteSource> = Arc::new(BytesSource::new(bytes.clone()));
            Some(source)
        })
    }

    /// Loopback address string for reaching `server`.
    fn base_of(server: &PeerServer) -> String {
        format!("127.0.0.1:{}", server.port())
    }

    #[test]
    fn round_trips_full_and_partial() {
        let payload: Vec<u8> = (0..=255u8).cycle().take(5000).collect();
        let server =
            PeerServer::start("127.0.0.1:0", "secret", resolver_for("song", payload.clone()))
                .expect("start");
        let base = base_of(&server);

        // Open-ended request from byte 0 returns everything.
        let whole = fetch_range(&base, "secret", "song", 0, None).expect("full");
        assert_eq!(whole.bytes, payload);
        assert_eq!(whole.total_len, Some(5000));

        // Bounded request returns just that window.
        let window = fetch_range(&base, "secret", "song", 1000, Some(256)).expect("part");
        assert_eq!(window.bytes, payload[1000..1256]);
        assert_eq!(window.total_len, Some(5000));

        assert_eq!(content_length(&base, "secret", "song").unwrap(), Some(5000));
    }

    #[test]
    fn streams_to_writer_without_buffering() {
        let payload: Vec<u8> = (0..2000u32).map(|i| i as u8).collect();
        let server = PeerServer::start("127.0.0.1:0", "k", resolver_for("s", payload.clone()))
            .expect("start");
        let base = base_of(&server);

        let mut sink = Vec::new();
        let total = fetch_to_writer(&base, "k", "s", 0, None, &mut sink).expect("stream");
        assert_eq!(sink, payload);
        assert_eq!(total, Some(2000));
    }

    #[test]
    fn rejects_wrong_token() {
        let server = PeerServer::start("127.0.0.1:0", "right", resolver_for("a", vec![1, 2, 3]))
            .expect("s");
        let base = base_of(&server);

        let outcome = fetch_range(&base, "wrong", "a", 0, None);
        assert!(matches!(outcome, Err(PeerError::Unauthorized)));
    }

    #[test]
    fn unknown_handle_is_not_found() {
        let server =
            PeerServer::start("127.0.0.1:0", "t", resolver_for("a", vec![1, 2, 3])).expect("s");
        let base = base_of(&server);

        let outcome = fetch_range(&base, "t", "missing", 0, None);
        assert!(matches!(outcome, Err(PeerError::NotFound)));
    }

    #[test]
    fn handle_is_percent_encoded_round_trip() {
        let server =
            PeerServer::start("127.0.0.1:0", "t", resolver_for("a b/c.mp3", vec![9, 8, 7]))
                .expect("s");
        let base = base_of(&server);

        let fetched = fetch_range(&base, "t", "a b/c.mp3", 0, None).expect("fetch");
        assert_eq!(fetched.bytes, vec![9, 8, 7]);
    }
}