pub mod errors;
mod selector;
use crate::http::{headermap_to_hashmap, query_params::default_proxy_rule_filtered_query_params};
use crate::id_generator::{IDGenerator, TaskIDParameter};
use crate::net::format_url;
use crate::net::preferred_local_ip;
use crate::pool::{Builder as PoolBuilder, Entry, Factory, Pool};
use bytes::BytesMut;
use dragonfly_api::scheduler::v2::scheduler_client::SchedulerClient;
use errors::{BackendError, DfdaemonError, Error, ProxyError};
use futures::TryStreamExt;
use hostname;
use reqwest::{header::HeaderMap, header::HeaderValue, Client};
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use reqwest_tracing::TracingMiddleware;
use rustix::path::Arg;
use rustls_pki_types::CertificateDer;
use selector::{SeedPeerSelector, Selector};
use std::io::{Error as IOError, ErrorKind};
use std::net::IpAddr;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::AsyncRead;
use tokio_util::io::StreamReader;
use tonic::transport::Endpoint;
use tracing::{debug, error};
/// Maximum number of idle HTTP connections kept alive per host in the
/// underlying reqwest connection pool.
const POOL_MAX_IDLE_PER_HOST: usize = 1024;

/// TCP keep-alive probe interval for pooled connections.
const KEEP_ALIVE_INTERVAL: Duration = Duration::from_secs(60);

/// How long an unused pooled HTTP client may sit idle before eviction.
const DEFAULT_CLIENT_POOL_IDLE_TIMEOUT: Duration = Duration::from_secs(30 * 60);

/// Maximum number of HTTP clients (one per proxy address) held in the pool.
const DEFAULT_CLIENT_POOL_CAPACITY: usize = 128;

/// Default connect/request timeout for scheduler calls.
const DEFAULT_SCHEDULER_REQUEST_TIMEOUT: Duration = Duration::from_secs(5);

/// Module-local result type using this module's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;

/// Streaming response body returned by [`Request::get`].
pub type Body = Box<dyn AsyncRead + Send + Unpin>;
/// Abstraction over fetching resources through the Dragonfly proxy.
#[tonic::async_trait]
pub trait Request {
    /// Fetches `request.url` and returns a response whose body is exposed
    /// as a streaming async reader.
    async fn get(&self, request: GetRequest) -> Result<GetResponse<Body>>;

    /// Fetches `request.url` and, on success, appends the whole body into
    /// `buf`; the returned response carries no reader.
    async fn get_into(&self, request: GetRequest, buf: &mut BytesMut) -> Result<GetResponse>;
}
/// Parameters describing a single download request sent through the proxy.
pub struct GetRequest {
    /// URL of the resource to download.
    pub url: String,
    /// Optional headers forwarded with the request.
    pub header: Option<HeaderMap>,
    /// Optional piece length hint (bytes) sent as `X-Dragonfly-Piece-Length`.
    pub piece_length: Option<u64>,
    /// Optional tag used when computing the task id.
    pub tag: Option<String>,
    /// Optional application name used when computing the task id.
    pub application: Option<String>,
    /// Query parameters excluded from task id calculation; when empty, the
    /// default proxy-rule filter list is used instead.
    pub filtered_query_params: Vec<String>,
    /// When set, the task id is derived from this content instead of the URL.
    pub content_for_calculating_task_id: Option<String>,
    /// Optional scheduling priority sent as `X-Dragonfly-Priority`.
    pub priority: Option<i32>,
    /// Overall timeout for the request.
    pub timeout: Duration,
    /// Optional client certificate chain.
    /// NOTE(review): not applied when building the HTTP client in
    /// `HTTPClientFactory::make_client` — confirm whether this is intended.
    pub client_cert: Option<Vec<CertificateDer<'static>>>,
}
/// Response returned by [`Request`] methods.
///
/// `reader` is populated by [`Request::get`] (streaming) and left `None` by
/// [`Request::get_into`] (body buffered by the caller).
pub struct GetResponse<R = Body>
where
    R: AsyncRead + Unpin,
{
    /// Whether the HTTP status code indicated success (2xx).
    pub success: bool,
    /// Response headers, when available.
    pub header: Option<HeaderMap>,
    /// HTTP status code, when available.
    pub status_code: Option<reqwest::StatusCode>,
    /// Streaming body reader, when the caller requested streaming.
    pub reader: Option<R>,
}
/// Factory that builds proxied, tracing-enabled HTTP clients for the pool.
#[derive(Debug, Clone, Default)]
struct HTTPClientFactory {}

#[tonic::async_trait]
impl Factory<String, ClientWithMiddleware> for HTTPClientFactory {
    type Error = Error;

    /// Builds an HTTP client that routes every request through `proxy_addr`.
    async fn make_client(&self, proxy_addr: &String) -> Result<ClientWithMiddleware> {
        // All traffic for this client is tunneled through the peer's proxy.
        let proxy = reqwest::Proxy::all(proxy_addr).map_err(|err| {
            Error::Internal(format!("failed to set proxy {}: {}", proxy_addr, err))
        })?;

        // NOTE: certificate validation is disabled for proxied connections.
        let inner = Client::builder()
            .hickory_dns(true)
            .danger_accept_invalid_certs(true)
            .pool_max_idle_per_host(POOL_MAX_IDLE_PER_HOST)
            .tcp_keepalive(KEEP_ALIVE_INTERVAL)
            .proxy(proxy)
            .build()
            .map_err(|err| Error::Internal(format!("failed to build reqwest client: {}", err)))?;

        // Wrap with tracing middleware so outgoing requests carry spans.
        let client = ClientBuilder::new(inner).with(TracingMiddleware::default());
        Ok(client.build())
    }
}
/// Builder for [`Proxy`]; obtained via [`Proxy::builder`].
pub struct Builder {
    /// Scheduler gRPC endpoint, e.g. `http://127.0.0.1:8002`.
    scheduler_endpoint: String,
    /// Connect/request timeout for scheduler calls (must be >= 100ms).
    scheduler_request_timeout: Duration,
    /// Seed peer health-check interval (must be 1..=600 seconds).
    health_check_interval: Duration,
    /// Number of seed peers tried per request (must be <= 10).
    max_retries: u8,
}
impl Default for Builder {
    /// Returns a builder with the default scheduler timeout, a 60-second
    /// health-check interval, one retry, and no endpoint configured.
    fn default() -> Self {
        let scheduler_endpoint = String::new();
        Self {
            scheduler_endpoint,
            scheduler_request_timeout: DEFAULT_SCHEDULER_REQUEST_TIMEOUT,
            health_check_interval: Duration::from_secs(60),
            max_retries: 1,
        }
    }
}
impl Builder {
pub fn scheduler_endpoint(mut self, endpoint: String) -> Self {
self.scheduler_endpoint = endpoint;
self
}
pub fn scheduler_request_timeout(mut self, timeout: Duration) -> Self {
self.scheduler_request_timeout = timeout;
self
}
pub fn health_check_interval(mut self, interval: Duration) -> Self {
self.health_check_interval = interval;
self
}
pub fn max_retries(mut self, retries: u8) -> Self {
self.max_retries = retries;
self
}
pub async fn build(self) -> Result<Proxy> {
self.validate()?;
let scheduler_channel = Endpoint::from_shared(self.scheduler_endpoint.to_string())
.map_err(|err| Error::InvalidArgument(err.to_string()))?
.connect_timeout(self.scheduler_request_timeout)
.timeout(self.scheduler_request_timeout)
.connect()
.await
.map_err(|err| {
Error::Internal(format!(
"failed to connect to scheduler {}: {}",
self.scheduler_endpoint, err
))
})?;
let scheduler_client = SchedulerClient::new(scheduler_channel);
let seed_peer_selector = Arc::new(
SeedPeerSelector::new(scheduler_client, self.health_check_interval)
.await
.map_err(|err| {
Error::Internal(format!("failed to create seed peer selector: {}", err))
})?,
);
let seed_peer_selector_clone = seed_peer_selector.clone();
tokio::spawn(async move {
seed_peer_selector_clone.run().await;
});
let local_ip = preferred_local_ip().unwrap().to_string();
let hostname = hostname::get().unwrap().to_string_lossy().to_string();
let id_generator = IDGenerator::new(local_ip, hostname, true);
let proxy = Proxy {
seed_peer_selector,
max_retries: self.max_retries,
client_pool: PoolBuilder::new(HTTPClientFactory::default())
.capacity(DEFAULT_CLIENT_POOL_CAPACITY)
.idle_timeout(DEFAULT_CLIENT_POOL_IDLE_TIMEOUT)
.build(),
id_generator: Arc::new(id_generator),
};
Ok(proxy)
}
fn validate(&self) -> Result<()> {
if let Err(err) = url::Url::parse(&self.scheduler_endpoint) {
return Err(Error::InvalidArgument(err.to_string()));
};
if self.scheduler_request_timeout.as_millis() < 100 {
return Err(Error::InvalidArgument(
"scheduler request timeout must be at least 100 milliseconds".to_string(),
));
}
if self.health_check_interval.as_secs() < 1 || self.health_check_interval.as_secs() > 600 {
return Err(Error::InvalidArgument(
"health check interval must be between 1 and 600 seconds".to_string(),
));
}
if self.max_retries > 10 {
return Err(Error::InvalidArgument(
"max retries must be between 0 and 10".to_string(),
));
}
Ok(())
}
}
/// Client that downloads resources through Dragonfly seed peer proxies,
/// selecting peers via the scheduler and pooling per-proxy HTTP clients.
pub struct Proxy {
    /// Selects seed peers for a task id and health-checks them periodically.
    seed_peer_selector: Arc<SeedPeerSelector>,
    /// Number of seed peers tried per request.
    max_retries: u8,
    /// Pool of HTTP clients keyed by proxy address.
    client_pool: Pool<String, String, ClientWithMiddleware, HTTPClientFactory>,
    /// Generates task ids from URL/content parameters.
    id_generator: Arc<IDGenerator>,
}
impl Proxy {
pub fn builder() -> Builder {
Builder::default()
}
}
#[tonic::async_trait]
impl Request for Proxy {
async fn get(&self, request: GetRequest) -> Result<GetResponse> {
let response = self.try_send(&request).await?;
let header = response.headers().clone();
let status_code = response.status();
let reader = Box::new(StreamReader::new(
response
.bytes_stream()
.map_err(|err| IOError::new(ErrorKind::Other, err)),
));
Ok(GetResponse {
success: status_code.is_success(),
header: Some(header),
status_code: Some(status_code),
reader: Some(reader),
})
}
async fn get_into(&self, request: GetRequest, buf: &mut BytesMut) -> Result<GetResponse> {
let get_into = async {
let response = self.try_send(&request).await?;
let status = response.status();
let headers = response.headers().clone();
if status.is_success() {
let bytes = response.bytes().await.map_err(|err| {
Error::Internal(format!("failed to read response body: {}", err))
})?;
buf.extend_from_slice(&bytes);
}
Ok(GetResponse {
success: status.is_success(),
header: Some(headers),
status_code: Some(status),
reader: None,
})
};
tokio::time::timeout(request.timeout, get_into)
.await
.map_err(|err| Error::RequestTimeout(err.to_string()))?
}
}
impl Proxy {
async fn client_entries(
&self,
request: &GetRequest,
) -> Result<Vec<Entry<ClientWithMiddleware>>> {
let filtered_query_params = if request.filtered_query_params.is_empty() {
default_proxy_rule_filtered_query_params()
} else {
request.filtered_query_params.clone()
};
let task_id = self
.id_generator
.task_id(match request.content_for_calculating_task_id.as_ref() {
Some(content) => TaskIDParameter::Content(content.clone()),
None => TaskIDParameter::URLBased {
url: request.url.clone(),
piece_length: request.piece_length,
tag: request.tag.clone(),
application: request.application.clone(),
filtered_query_params,
},
})
.map_err(|err| Error::Internal(format!("failed to generate task id: {}", err)))?;
let seed_peers = self
.seed_peer_selector
.select(task_id.clone(), self.max_retries as u32)
.await
.map_err(|err| {
Error::Internal(format!(
"failed to select seed peers from scheduler: {}",
err
))
})?;
debug!("task {} selected seed peers: {:?}", task_id, seed_peers);
let mut client_entries = Vec::with_capacity(seed_peers.len());
for peer in seed_peers.iter() {
let addr = format_url(
"http",
IpAddr::from_str(&peer.ip).map_err(|err| Error::Internal(err.to_string()))?,
peer.proxy_port as u16,
);
let client_entry = self.client_pool.entry(&addr, &addr).await?;
client_entries.push(client_entry);
}
Ok(client_entries)
}
async fn try_send(&self, request: &GetRequest) -> Result<reqwest::Response> {
let entries = self.client_entries(request).await?;
if entries.is_empty() {
return Err(Error::Internal(
"no available client entries to send request".to_string(),
));
}
for (index, entry) in entries.iter().enumerate() {
match self.send(entry, request).await {
Ok(response) => return Ok(response),
Err(err) => {
error!(
"failed to send request to client entry {:?}: {:?}",
entry.client, err
);
if index == entries.len() - 1 {
return Err(err);
}
}
}
}
Err(Error::Internal(
"failed to send request to any client entry".to_string(),
))
}
async fn send(
&self,
entry: &Entry<ClientWithMiddleware>,
request: &GetRequest,
) -> Result<reqwest::Response> {
let headers = self.make_request_headers(request)?;
let response = entry
.client
.get(&request.url)
.headers(headers.clone())
.timeout(request.timeout)
.send()
.await
.map_err(|err| Error::Internal(err.to_string()))?;
let status = response.status();
if status.is_success() {
return Ok(response);
}
let response_headers = response.headers().clone();
let header_map = headermap_to_hashmap(&response_headers);
let message = response.text().await.ok();
let error_type = response_headers
.get("X-Dragonfly-Error-Type")
.and_then(|v| v.to_str().ok());
match error_type {
Some("backend") => Err(Error::BackendError(BackendError {
message,
header: header_map,
status_code: Some(status),
})),
Some("proxy") => Err(Error::ProxyError(ProxyError {
message,
header: header_map,
status_code: Some(status),
})),
Some("dfdaemon") => Err(Error::DfdaemonError(DfdaemonError { message })),
Some(other) => Err(Error::ProxyError(ProxyError {
message: Some(format!("unknown error type from proxy: {}", other)),
header: header_map,
status_code: Some(status),
})),
None => Err(Error::ProxyError(ProxyError {
message: Some(format!("unexpected status code from proxy: {}", status)),
header: header_map,
status_code: Some(status),
})),
}
}
fn make_request_headers(&self, request: &GetRequest) -> Result<HeaderMap> {
let mut headers = request.header.clone().unwrap_or_default();
if let Some(piece_length) = request.piece_length {
headers.insert(
"X-Dragonfly-Piece-Length",
piece_length.to_string().parse().map_err(|err| {
Error::InvalidArgument(format!("invalid piece length: {}", err))
})?,
);
}
if let Some(tag) = request.tag.clone() {
headers.insert(
"X-Dragonfly-Tag",
tag.to_string()
.parse()
.map_err(|err| Error::InvalidArgument(format!("invalid tag: {}", err)))?,
);
}
if let Some(application) = request.application.clone() {
headers.insert(
"X-Dragonfly-Application",
application.to_string().parse().map_err(|err| {
Error::InvalidArgument(format!("invalid application: {}", err))
})?,
);
}
if let Some(content_for_calculating_task_id) =
request.content_for_calculating_task_id.clone()
{
headers.insert(
"X-Dragonfly-Content-For-Calculating-Task-ID",
content_for_calculating_task_id
.to_string()
.parse()
.map_err(|err| {
Error::InvalidArgument(format!(
"invalid content for calculating task id: {}",
err
))
})?,
);
}
if let Some(priority) = request.priority {
headers.insert(
"X-Dragonfly-Priority",
priority
.to_string()
.parse()
.map_err(|err| Error::InvalidArgument(format!("invalid priority: {}", err)))?,
);
}
if !request.filtered_query_params.is_empty() {
let value = request.filtered_query_params.join(",");
headers.insert(
"X-Dragonfly-Filtered-Query-Params",
value.parse().map_err(|err| {
Error::InvalidArgument(format!("invalid filtered query params: {}", err))
})?,
);
}
headers.insert("X-Dragonfly-Use-P2P", HeaderValue::from_static("true"));
Ok(headers)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use dragonfly_api::scheduler::v2::ListHostsResponse;
    use mocktail::prelude::*;
    use std::time::Duration;

    /// Starts a mock gRPC scheduler whose `ListHosts` returns no hosts.
    async fn setup_mock_scheduler() -> Result<mocktail::server::MockServer> {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.path("/scheduler.v2.Scheduler/ListHosts");
            then.pb(ListHostsResponse { hosts: vec![] });
        });
        let server = MockServer::new_grpc("scheduler.v2.Scheduler").with_mocks(mocks);
        server.start().await.map_err(|err| {
            Error::Internal(format!("failed to start mock scheduler server: {}", err))
        })?;
        Ok(server)
    }

    #[tokio::test]
    async fn test_proxy_new_success() {
        let mock_server = setup_mock_scheduler().await.unwrap();
        let scheduler_endpoint = format!("http://0.0.0.0:{}", mock_server.port().unwrap());
        let result = Proxy::builder()
            .scheduler_endpoint(scheduler_endpoint)
            .build()
            .await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().max_retries, 1);
    }

    #[tokio::test]
    async fn test_proxy_new_empty_endpoint() {
        let result = Proxy::builder()
            .scheduler_endpoint("".to_string())
            .build()
            .await;
        assert!(result.is_err());
        assert!(matches!(result, Err(Error::InvalidArgument(_))));
    }

    #[tokio::test]
    async fn test_proxy_new_invalid_retry_times() {
        let mock_server = setup_mock_scheduler().await.unwrap();
        let scheduler_endpoint = format!("http://0.0.0.0:{}", mock_server.port().unwrap());
        let result = Proxy::builder()
            .scheduler_endpoint(scheduler_endpoint)
            .max_retries(11)
            .build()
            .await;
        assert!(result.is_err());
        assert!(matches!(result, Err(Error::InvalidArgument(_))));
    }

    #[tokio::test]
    async fn test_proxy_new_invalid_health_check_interval() {
        let mock_server = setup_mock_scheduler().await.unwrap();
        let scheduler_endpoint = format!("http://0.0.0.0:{}", mock_server.port().unwrap());
        // Fix: exercise the health-check-interval bound (601s > 600s max).
        // Previously this test duplicated the retry-times case by setting
        // max_retries(11) and never touched the health check interval.
        let result = Proxy::builder()
            .scheduler_endpoint(scheduler_endpoint)
            .health_check_interval(Duration::from_secs(601))
            .build()
            .await;
        assert!(result.is_err());
        assert!(matches!(result, Err(Error::InvalidArgument(_))));
    }

    #[tokio::test]
    async fn test_client_pool_get_or_create() {
        let pool = PoolBuilder::new(HTTPClientFactory {})
            .capacity(10)
            .idle_timeout(Duration::from_secs(600))
            .build();
        assert_eq!(pool.size().await, 0);

        // Requesting the same address twice reuses the existing client.
        let addr = "http://proxy1.com".to_string();
        let _ = pool.entry(&addr, &addr).await.unwrap();
        assert_eq!(pool.size().await, 1);
        let _ = pool.entry(&addr, &addr).await.unwrap();
        assert_eq!(pool.size().await, 1);

        // A new address creates a second client.
        let addr = "http://proxy2.com".to_string();
        let _ = pool.entry(&addr, &addr).await.unwrap();
        assert_eq!(pool.size().await, 2);
    }

    #[tokio::test]
    async fn test_client_pool_cleanup() {
        let pool = PoolBuilder::new(HTTPClientFactory {})
            .capacity(10)
            .idle_timeout(Duration::from_millis(10))
            .build();
        let addr = "http://proxy1.com".to_string();
        let _ = pool.entry(&addr, &addr).await.unwrap();
        assert_eq!(pool.size().await, 1);

        // After the idle timeout elapses, the stale entry is evicted when a
        // new entry is created, leaving only the fresh client.
        tokio::time::sleep(Duration::from_millis(50)).await;
        let addr = "http://proxy2.com".to_string();
        let _ = pool.entry(&addr, &addr).await.unwrap();
        assert_eq!(pool.size().await, 1);
    }
}
}