use std::fmt;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio_rustls::TlsConnector;
use crate::error::{Error, ErrorKind};
use super::{AnnounceEvent, AnnounceRequest, AnnounceResponse, IntoUrl, Url};
use super::DEFAULT_TIMEOUT;
/// Starting value of the redirect countdown used by `HttpTracker::announce`.
const MAX_REDIRECTS: u32 = 5;
/// Size cap (256 KiB) for the raw HTTP response that `send_http_request` buffers.
const MAX_RESPONSE_SIZE: u64 = 256 * 1024;
/// Umbrella trait for any duplex async byte stream usable for a tracker
/// exchange, so plain TCP and TLS streams can share one `Box<dyn TrackerStream>`.
trait TrackerStream: AsyncRead + AsyncWrite + Unpin + Send {}

/// Blanket implementation: every type meeting the supertrait bounds qualifies.
impl<S> TrackerStream for S where S: AsyncRead + AsyncWrite + Unpin + Send {}
/// Client for announcing to a single HTTP or HTTPS tracker.
///
/// `Clone` is derived: every field is itself `Clone`, so a hand-written
/// field-by-field copy would only be something to keep in sync with the struct.
#[derive(Clone)]
pub struct HttpTracker {
    /// Announce URL the tracker was constructed with.
    url: Url,
    /// Host component of `url`, captured at construction time.
    host: String,
    /// Port of `url`: explicit port, else the scheme's known default, else 80.
    port: u16,
    /// TLS connector; `Some` only when `url` uses the `https` scheme.
    tls: Option<TlsConnector>,
    /// Timeout applied to each HTTP request issued by `announce`.
    timeout: Duration,
}

// `Debug` is implemented by hand so that the TLS connector is reported only
// as a presence flag instead of being formatted itself.
impl fmt::Debug for HttpTracker {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("HttpTracker")
            .field("url", &self.url)
            .field("tls", &self.tls.is_some())
            .finish()
    }
}
impl HttpTracker {
pub fn new(url: impl IntoUrl) -> Result<Self, Error> {
HttpTracker::with_timeout(url, DEFAULT_TIMEOUT)
}
pub fn with_timeout(url: impl IntoUrl, timeout: Duration) -> Result<Self, Error> {
let url = url.into_url()?;
let host = url
.host_str()
.ok_or(Error::new(ErrorKind::InvalidInput))?
.to_owned();
let port = url.port_or_known_default().unwrap_or(80);
let tls = if url.scheme() == "https" {
Some(build_tls_connector()?)
} else {
None
};
Ok(HttpTracker {
url,
host,
port,
tls,
timeout,
})
}
pub fn url(&self) -> &Url {
&self.url
}
pub async fn announce(&self, req: &AnnounceRequest) -> Result<AnnounceResponse, Error> {
tracing::info!("HTTP announce to {} (event: {:?})", self.url, req.event);
let mut current_url = self.url.clone();
let mut tls = self.tls.clone();
let mut redirects_remaining = MAX_REDIRECTS;
loop {
let path_and_query = format!("{}?{}", current_url.path(), build_query_string(req));
let buf = send_http_request(¤t_url, &tls, &path_and_query, self.timeout).await?;
let Some(header_end) = buf.windows(4).position(|w| w == b"\r\n\r\n") else {
tracing::warn!("HTTP announce: missing header separator");
return Err(Error::new(ErrorKind::TrackerInvalidResponse));
};
let body = &buf[header_end + 4..];
let headers_str = std::str::from_utf8(&buf[..header_end])
.map_err(|_| Error::new(ErrorKind::TrackerInvalidResponse))?;
let first_line = headers_str.lines().next().unwrap_or("");
let status_code = first_line.split_whitespace().nth(1).unwrap_or("");
match status_code {
"301" | "302" => {
redirects_remaining -= 1;
if redirects_remaining == 0 {
tracing::warn!("HTTP announce: too many redirects");
return Err(Error::new(ErrorKind::TrackerRequestFailed));
}
let location = headers_str
.lines()
.find(|l| l.to_ascii_lowercase().starts_with("location: "))
.and_then(|l| l.split_once(": ").map(|x| x.1))
.map(|l| l.trim())
.unwrap_or("");
if location.is_empty() {
return Err(Error::new(ErrorKind::TrackerProtocolError));
}
let new_url = resolve_redirect_url(¤t_url, location)?;
tracing::info!(
"HTTP redirect #{}/{}: {} -> {}",
MAX_REDIRECTS - redirects_remaining,
MAX_REDIRECTS,
current_url,
new_url,
);
if new_url.scheme() == "https" && tls.is_none() {
tls = Some(build_tls_connector()?);
} else if new_url.scheme() == "http" {
tls = None;
}
current_url = new_url;
continue;
}
"200" => return AnnounceResponse::from_bencode(body),
_ => {
tracing::warn!("HTTP announce: unexpected status {}", status_code);
return Err(Error::new(ErrorKind::TrackerRequestFailed));
}
}
}
}
}
fn build_tls_connector() -> Result<TlsConnector, Error> {
let mut root_store = rustls::RootCertStore::empty();
let native_certs = rustls_native_certs::load_native_certs();
for cert in native_certs.certs {
root_store.add(cert).map_err(Error::invalid_input)?;
}
let config = rustls::ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth();
Ok(TlsConnector::from(Arc::new(config)))
}
async fn send_http_request(
url: &Url, tls: &Option<TlsConnector>, path_and_query: &str, timeout: Duration,
) -> Result<Vec<u8>, Error> {
let host = url
.host_str()
.ok_or(Error::new(ErrorKind::InvalidInput))?
.to_owned();
let port = url.port_or_known_default().unwrap_or(80);
let tls = tls.clone();
let path_and_query = path_and_query.to_owned();
let response = tokio::time::timeout(timeout, async move {
let tcp_stream = TcpStream::connect((&*host, port))
.await
.map_err(Error::tracker_failed)?;
let mut stream: Box<dyn TrackerStream> = if let Some(ref connector) = tls {
use rustls::pki_types::ServerName;
let domain = ServerName::try_from(host.clone())
.map_err(Error::invalid_input)?;
let tls_stream = connector
.connect(domain, tcp_stream)
.await
.map_err(Error::tracker_failed)?;
Box::new(tls_stream)
} else {
Box::new(tcp_stream)
};
let request = format!(
"GET {} HTTP/1.1\r\nHost: {}\r\nUser-Agent: torrent-rs/0.1.0\r\nAccept-Encoding: identity\r\nConnection: close\r\n\r\n",
path_and_query, host
);
stream
.write_all(request.as_bytes())
.await
.map_err(Error::tracker_failed)?;
let mut buf = Vec::new();
let mut limited = AsyncReadExt::take(&mut stream, MAX_RESPONSE_SIZE);
limited
.read_to_end(&mut buf)
.await
.map_err(Error::tracker_failed)?;
Ok(buf)
})
.await;
match response {
Ok(Ok(buf)) => Ok(buf),
Ok(Err(e)) => Err(e),
Err(_) => Err(Error::new(ErrorKind::TrackerRequestFailed)),
}
}
fn resolve_redirect_url(base: &Url, location: &str) -> Result<Url, Error> {
let new_url = Url::options()
.base_url(Some(base))
.parse(location)
.map_err(|_| Error::new(ErrorKind::TrackerProtocolError))?;
match new_url.scheme() {
"http" | "https" => Ok(new_url),
_ => {
tracing::warn!(
"HTTP redirect: unsupported scheme '{}' in redirect URL {}",
new_url.scheme(),
new_url,
);
Err(Error::new(ErrorKind::TrackerProtocolError))
}
}
}
fn build_query_string(req: &AnnounceRequest) -> String {
use url::form_urlencoded::byte_serialize;
let mut q = String::new();
q.push_str("info_hash=");
q.push_str(&byte_serialize(&req.info_hash).collect::<String>());
q.push_str("&peer_id=");
q.push_str(&byte_serialize(&req.peer_id.0).collect::<String>());
q.push_str(&format!("&port={}", req.port));
q.push_str(&format!("&uploaded={}", req.uploaded));
q.push_str(&format!("&downloaded={}", req.downloaded));
q.push_str(&format!("&left={}", req.left));
if req.compact {
q.push_str("&compact=1");
}
let event_str = match req.event {
AnnounceEvent::Started => "started",
AnnounceEvent::Stopped => "stopped",
AnnounceEvent::Completed => "completed",
AnnounceEvent::None => "empty",
_ => "empty", };
q.push_str("&event=");
q.push_str(event_str);
if let Some(numwant) = req.numwant {
q.push_str(&format!("&numwant={}", numwant));
}
if let Some(key) = req.key {
q.push_str(&format!("&key={}", key));
}
if let Some(ref trackerid) = req.trackerid {
q.push_str("&trackerid=");
q.push_str(trackerid);
}
if let Some(ip) = req.ip {
q.push_str(&format!("&ip={ip}"));
}
if let Some(ipv6) = req.ipv6 {
q.push_str(&format!("&ipv6={ipv6}"));
}
q
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::peer::PeerId;

    /// The basic announce parameters all show up in the query string.
    #[test]
    fn test_build_query_string() {
        let mut req = AnnounceRequest::new([0x01; 20], PeerId::random(), 6881);
        req.left = 1024;
        req.event = AnnounceEvent::Started;
        let q = build_query_string(&req);
        assert!(q.starts_with("info_hash="));
        assert!(q.contains("&peer_id="));
        assert!(q.contains("&port=6881"));
        assert!(q.contains("&compact=1"));
        assert!(q.contains("&event=started"));
        assert!(q.contains("&left=1024"));
    }

    /// A string without a scheme is rejected at construction time.
    #[test]
    fn test_new_invalid_url() {
        assert!(HttpTracker::new("not-a-valid-url").is_err());
    }

    /// Non-printable info-hash bytes are emitted as uppercase `%XX` escapes.
    #[test]
    fn test_build_query_string_binary_info_hash() {
        let info_hash = [
            0x00, 0x01, 0x7F, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        ];
        let mut req = AnnounceRequest::new(info_hash, PeerId::random(), 6881);
        req.compact = false;
        req.numwant = None;
        req.left = 100;
        let q = build_query_string(&req);
        assert!(q.contains("%00%01%7F%FF"));
    }

    #[test]
    fn redirect_absolute_url() {
        let base = Url::parse("http://tracker.example.com:6969/announce").unwrap();
        let resolved = resolve_redirect_url(&base, "http://new.example.com/announce").unwrap();
        assert_eq!(resolved.as_str(), "http://new.example.com/announce");
    }

    #[test]
    fn redirect_relative_path() {
        let base = Url::parse("http://tracker.example.com:6969/announce").unwrap();
        let resolved = resolve_redirect_url(&base, "/new-announce").unwrap();
        assert_eq!(
            resolved.as_str(),
            "http://tracker.example.com:6969/new-announce"
        );
    }

    #[test]
    fn redirect_https_to_http() {
        let base = Url::parse("https://tracker.example.com/announce").unwrap();
        let resolved = resolve_redirect_url(&base, "http://other.example.com/announce").unwrap();
        assert_eq!(resolved.as_str(), "http://other.example.com/announce");
    }

    #[test]
    fn redirect_http_to_https() {
        let base = Url::parse("http://tracker.example.com/announce").unwrap();
        let resolved =
            resolve_redirect_url(&base, "https://tracker.example.com/announce").unwrap();
        assert_eq!(resolved.as_str(), "https://tracker.example.com/announce");
    }

    #[test]
    fn redirect_rejects_udp_scheme() {
        let base = Url::parse("http://tracker.example.com/announce").unwrap();
        assert!(resolve_redirect_url(&base, "udp://tracker.example.com:6969").is_err());
    }

    /// `resolve_redirect_url` itself does not reject an empty `Location`: an
    /// empty relative reference resolves to the base URL. (`announce` checks
    /// for an empty `Location` before it calls the resolver.)
    #[test]
    fn redirect_empty_location_resolves_to_base() {
        let base = Url::parse("http://tracker.example.com/announce").unwrap();
        let resolved = resolve_redirect_url(&base, "").unwrap();
        assert_eq!(resolved.as_str(), "http://tracker.example.com/announce");
    }

    /// A scheme-relative `Location` inherits the scheme of the base URL.
    #[test]
    fn redirect_scheme_relative() {
        let base = Url::parse("http://tracker.example.com/announce").unwrap();
        let resolved = resolve_redirect_url(&base, "//other.example.com/announce").unwrap();
        assert_eq!(resolved.as_str(), "http://other.example.com/announce");
    }
}