use std::fmt;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio_rustls::TlsConnector;
use crate::error::{Error, ErrorKind};
use super::{AnnounceEvent, AnnounceRequest, AnnounceResponse, IntoUrl, Url};
use super::DEFAULT_TIMEOUT;
/// Upper bound, in bytes, on how much of a tracker's HTTP response (status line,
/// headers and body together) is read; nothing past this many bytes is consumed.
const MAX_RESPONSE_SIZE: u64 = 256 << 10;
/// Object-safe bundle of the async I/O traits a tracker connection needs, so a
/// plain TCP stream and a TLS-wrapped stream can both live in one
/// `Box<dyn TrackerStream>`.
trait TrackerStream: AsyncRead + AsyncWrite + Unpin + Send {}

// Blanket impl: every stream type meeting the bounds is automatically a `TrackerStream`.
impl<S> TrackerStream for S where S: AsyncRead + AsyncWrite + Unpin + Send {}
/// Client for a single HTTP or HTTPS BitTorrent tracker.
///
/// Holds the announce URL together with what is needed to reach it:
/// - the host string taken from the URL;
/// - the port, which is either explicit or the scheme's default;
/// - a TLS connector, set for `https` URLs;
/// - the timeout applied to each announce.
#[derive(Clone)]
pub struct HttpTracker {
    url: Url,
    host: String,
    port: u16,
    tls: Option<TlsConnector>,
    timeout: Duration,
}

impl fmt::Debug for HttpTracker {
    /// Prints the URL and whether TLS is in use; the connector itself is not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("HttpTracker")
            .field("url", &self.url)
            .field("tls", &self.tls.is_some())
            .finish()
    }
}
impl HttpTracker {
pub fn new(url: impl IntoUrl) -> Result<Self, Error> {
Self::with_timeout(url, DEFAULT_TIMEOUT)
}
pub fn with_timeout(url: impl IntoUrl, timeout: Duration) -> Result<Self, Error> {
let url = url.into_url()?;
let host = url
.host_str()
.ok_or(Error::new(ErrorKind::InvalidInput))?
.to_owned();
let port = url.port_or_known_default().unwrap_or(80);
let tls = if url.scheme() == "https" {
Some(build_tls_connector()?)
} else {
None
};
Ok(HttpTracker {
url,
host,
port,
tls,
timeout,
})
}
pub fn url(&self) -> &Url {
&self.url
}
pub async fn announce(&self, req: &AnnounceRequest) -> Result<AnnounceResponse, Error> {
tracing::info!("HTTP announce to {} (event: {:?})", self.url, req.event);
let path_and_query = format!("{}?{}", self.url.path(), build_query_string(req));
let host = self.host.clone();
let port = self.port;
let tls = self.tls.clone();
let response = tokio::time::timeout(self.timeout, async move {
let tcp_stream = TcpStream::connect((&*host, port))
.await
.map_err(Error::tracker_failed)?;
let mut stream: Box<dyn TrackerStream> = if let Some(ref connector) = tls {
use rustls::pki_types::ServerName;
let domain = ServerName::try_from(host.clone())
.map_err(Error::invalid_input)?;
let tls_stream = connector
.connect(domain, tcp_stream)
.await
.map_err(Error::tracker_failed)?;
Box::new(tls_stream)
} else {
Box::new(tcp_stream)
};
let request = format!(
"GET {} HTTP/1.1\r\nHost: {}\r\nUser-Agent: torrent-rs/0.1.0\r\nAccept-Encoding: identity\r\nConnection: close\r\n\r\n",
path_and_query, host
);
stream
.write_all(request.as_bytes())
.await
.map_err(Error::tracker_failed)?;
let mut buf = Vec::new();
let mut limited = AsyncReadExt::take(&mut stream, MAX_RESPONSE_SIZE);
limited
.read_to_end(&mut buf)
.await
.map_err(Error::tracker_failed)?;
Ok(buf)
})
.await;
let buf = match response {
Ok(Ok(buf)) => buf,
Ok(Err(e)) => return Err(e),
Err(_) => return Err(Error::new(ErrorKind::TrackerRequestFailed)),
};
let Some(header_end) = buf.windows(4).position(|w| w == b"\r\n\r\n") else {
tracing::warn!("HTTP announce: missing header separator");
return Err(Error::new(ErrorKind::TrackerInvalidResponse));
};
let body = &buf[header_end + 4..];
let headers_str = std::str::from_utf8(&buf[..header_end])
.map_err(|_| Error::new(ErrorKind::TrackerInvalidResponse))?;
let first_line = headers_str.lines().next().unwrap_or("");
let status_code = first_line.split_whitespace().nth(1).unwrap_or("");
match status_code {
"301" | "302" => {
let location = headers_str
.lines()
.find(|l| l.to_ascii_lowercase().starts_with("location: "))
.and_then(|l| l.split_once(": ").map(|x| x.1))
.map(|l| l.trim())
.unwrap_or("");
if location.is_empty() {
return Err(Error::new(ErrorKind::TrackerProtocolError));
}
tracing::info!("HTTP redirect to {} — not yet supported", location);
Err(Error::new(ErrorKind::TrackerProtocolError))
}
"200" => AnnounceResponse::from_bencode(body),
_ => Err(Error::new(ErrorKind::TrackerRequestFailed)),
}
}
}
fn build_tls_connector() -> Result<TlsConnector, Error> {
let mut root_store = rustls::RootCertStore::empty();
let native_certs = rustls_native_certs::load_native_certs();
for cert in native_certs.certs {
root_store.add(cert).map_err(Error::invalid_input)?;
}
let config = rustls::ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth();
Ok(TlsConnector::from(Arc::new(config)))
}
fn build_query_string(req: &AnnounceRequest) -> String {
use url::form_urlencoded::byte_serialize;
let mut q = String::new();
q.push_str("info_hash=");
q.push_str(&byte_serialize(&req.info_hash).collect::<String>());
q.push_str("&peer_id=");
q.push_str(&byte_serialize(&req.peer_id.0).collect::<String>());
q.push_str(&format!("&port={}", req.port));
q.push_str(&format!("&uploaded={}", req.uploaded));
q.push_str(&format!("&downloaded={}", req.downloaded));
q.push_str(&format!("&left={}", req.left));
if req.compact {
q.push_str("&compact=1");
}
let event_str = match req.event {
AnnounceEvent::Started => "started",
AnnounceEvent::Stopped => "stopped",
AnnounceEvent::Completed => "completed",
AnnounceEvent::None => "empty",
_ => "empty", };
q.push_str("&event=");
q.push_str(event_str);
if let Some(numwant) = req.numwant {
q.push_str(&format!("&numwant={}", numwant));
}
if let Some(key) = req.key {
q.push_str(&format!("&key={}", key));
}
if let Some(ref trackerid) = req.trackerid {
q.push_str("&trackerid=");
q.push_str(trackerid);
}
q
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::peer::PeerId;

    /// A `started` announce with default options carries every expected parameter.
    #[test]
    fn test_build_query_string() {
        let mut request = AnnounceRequest::new([0x01; 20], PeerId::random(), 6881);
        request.left = 1024;
        request.event = AnnounceEvent::Started;

        let query = build_query_string(&request);

        assert!(query.starts_with("info_hash="));
        for expected in [
            "&peer_id=",
            "&port=6881",
            "&compact=1",
            "&event=started",
            "&left=1024",
        ] {
            assert!(query.contains(expected));
        }
    }

    /// A string that cannot be parsed as a URL is rejected by the constructor.
    #[test]
    fn test_new_invalid_url() {
        let result = HttpTracker::new("not-a-valid-url");
        assert!(result.is_err());
    }

    /// Binary info-hash bytes (NUL, control, DEL, 0xFF) end up percent-encoded.
    #[test]
    fn test_build_query_string_binary_info_hash() {
        let mut info_hash = [0u8; 20];
        info_hash[..4].copy_from_slice(&[0x00, 0x01, 0x7F, 0xFF]);

        let mut request = AnnounceRequest::new(info_hash, PeerId::random(), 6881);
        request.compact = false;
        request.numwant = None;
        request.left = 100;

        assert!(build_query_string(&request).contains("%00%01%7F%FF"));
    }
}