use async_trait::async_trait;
use dragonfly_client_config::dfdaemon::Config;
use dragonfly_client_core::{Error, Result};
use dragonfly_client_storage::{client::quic::QUICClient, client::tcp::TCPClient};
use dragonfly_client_util::pool::{Builder as PoolBuilder, Entry, Factory, Pool};
use std::sync::Arc;
use std::time::Duration;
use tokio::io::AsyncRead;
use tracing::{error, instrument};
/// Capacity handed to the client-pool builder of every downloader created by
/// `DownloaderFactory::new`.
const DEFAULT_DOWNLOADER_CAPACITY: usize = 2_000;

/// Idle timeout handed to the client-pool builder of every downloader created
/// by `DownloaderFactory::new` (seven minutes).
const DEFAULT_DOWNLOADER_IDLE_TIMEOUT: Duration = Duration::from_secs(7 * 60);
/// Transport-agnostic interface for fetching a single piece from a remote peer.
///
/// The `Send + Sync` supertraits allow implementations to be shared as
/// `Arc<dyn Downloader>` across threads. Every method resolves to a
/// `(reader, offset, digest)` tuple; those names are taken from how the
/// implementations in this file destructure it. NOTE(review): the exact
/// meaning of `offset` and `digest` is defined by the storage client, so
/// confirm it there.
#[async_trait]
pub trait Downloader: Send + Sync {
/// Downloads piece `number` of task `task_id` from the peer at `addr`.
///
/// `host_id` is part of the interface; the TCP and QUIC implementations in
/// this file ignore it.
async fn download_piece(
&self,
addr: &str,
number: u32,
host_id: &str,
task_id: &str,
) -> Result<(Box<dyn AsyncRead + Send + Unpin>, u64, String)>;
/// Downloads piece `number` of persistent task `task_id` from the peer at `addr`.
///
/// NOTE(review): how a "persistent" piece differs from a regular one is
/// decided by the storage client, not by this trait.
async fn download_persistent_piece(
&self,
addr: &str,
number: u32,
host_id: &str,
task_id: &str,
) -> Result<(Box<dyn AsyncRead + Send + Unpin>, u64, String)>;
/// Downloads piece `number` of persistent-cache task `task_id` from the peer
/// at `addr`.
///
/// NOTE(review): the persistent-cache semantics live in the storage client.
async fn download_persistent_cache_piece(
&self,
addr: &str,
number: u32,
host_id: &str,
task_id: &str,
) -> Result<(Box<dyn AsyncRead + Send + Unpin>, u64, String)>;
}
/// Owns one protocol-specific [`Downloader`] and hands out shared handles to it.
pub struct DownloaderFactory {
    /// The downloader shared by every caller of `DownloaderFactory::build`.
    ///
    /// `Downloader` already requires `Send + Sync`, so the trait object needs
    /// no extra auto-trait bounds. This is also exactly the type that `new`
    /// builds and `build` returns.
    downloader: Arc<dyn Downloader>,
}
impl DownloaderFactory {
    /// Creates a factory whose downloader speaks `protocol`.
    ///
    /// `"tcp"` selects `TCPDownloader` and `"quic"` selects `QUICDownloader`.
    /// Both are built with `DEFAULT_DOWNLOADER_CAPACITY` and
    /// `DEFAULT_DOWNLOADER_IDLE_TIMEOUT`.
    ///
    /// # Errors
    ///
    /// Logs the offending value and returns `Error::InvalidParameter` for any
    /// other protocol string.
    pub fn new(protocol: &str, config: Arc<Config>) -> Result<Self> {
        // Exactly one arm runs, so `config` is moved into it rather than cloned.
        let downloader: Arc<dyn Downloader> = match protocol {
            "tcp" => Arc::new(TCPDownloader::new(
                config,
                DEFAULT_DOWNLOADER_CAPACITY,
                DEFAULT_DOWNLOADER_IDLE_TIMEOUT,
            )),
            "quic" => Arc::new(QUICDownloader::new(
                config,
                DEFAULT_DOWNLOADER_CAPACITY,
                DEFAULT_DOWNLOADER_IDLE_TIMEOUT,
            )),
            _ => {
                error!("unsupported protocol: {}", protocol);
                return Err(Error::InvalidParameter);
            }
        };

        Ok(Self { downloader })
    }

    /// Returns a handle to the shared downloader.
    ///
    /// Every call clones the same `Arc`, so all callers share one downloader
    /// and therefore one client pool.
    pub fn build(&self) -> Arc<dyn Downloader> {
        self.downloader.clone()
    }
}
/// Piece downloader that talks QUIC, reusing `QUICClient`s from a keyed pool.
pub struct QUICDownloader {
/// Pool of `QUICClient`s created by `QUICClientFactory`.
/// NOTE(review): the type parameters presumably are (key, address, client,
/// factory). Keys are produced by `get_entry_key` and addresses are the peer
/// `addr` strings; confirm against `Pool`.
client_pool: Pool<String, String, QUICClient, QUICClientFactory>,
}
/// Pool factory that creates a `QUICClient` for a peer address.
struct QUICClientFactory {
/// Daemon configuration passed to every client this factory creates.
config: Arc<Config>,
}
#[async_trait]
impl Factory<String, QUICClient> for QUICClientFactory {
type Error = Error;
async fn make_client(&self, addr: &String) -> Result<QUICClient> {
Ok(QUICClient::new(self.config.clone(), addr.clone()))
}
}
impl QUICDownloader {
    /// Number of distinct pool slots that a single peer address is spread across.
    const MAX_CONNECTIONS_PER_ADDRESS: usize = 32;

    /// Creates a QUIC downloader whose client pool is built with the given
    /// `capacity` and `idle_timeout`.
    pub fn new(config: Arc<Config>, capacity: usize, idle_timeout: Duration) -> Self {
        Self {
            // `config` is owned and unused afterwards, so move it into the
            // factory instead of making a redundant `Arc` clone.
            client_pool: PoolBuilder::new(QUICClientFactory { config })
                .capacity(capacity)
                .idle_timeout(idle_timeout)
                .build(),
        }
    }

    /// Looks up the pooled client stored under `key`, passing `addr` to the pool.
    ///
    /// NOTE(review): presumably the pool only calls the factory with `addr`
    /// when no entry exists for `key`; confirm against `Pool::entry`.
    async fn get_client_entry(&self, key: String, addr: String) -> Result<Entry<QUICClient>> {
        self.client_pool.entry(&key, &addr).await
    }

    /// Asks the pool to remove the client stored under `key`.
    async fn remove_client_entry(&self, key: String) {
        self.client_pool.remove_entry(&key).await;
    }

    /// Builds a pool key of the form `"{addr}-{slot}"`.
    ///
    /// `slot` is drawn at random from `0..MAX_CONNECTIONS_PER_ADDRESS`, so
    /// requests to one address are spread over several pooled clients.
    fn get_entry_key(&self, addr: &str) -> String {
        let slot = fastrand::usize(..Self::MAX_CONNECTIONS_PER_ADDRESS);
        format!("{}-{}", addr, slot)
    }
}
#[async_trait]
impl Downloader for QUICDownloader {
    /// Downloads a regular piece over QUIC.
    ///
    /// If the download fails, the pool is asked to drop the client for the
    /// chosen key before the error is returned. `_host_id` is unused here.
    #[instrument(skip_all)]
    async fn download_piece(
        &self,
        addr: &str,
        number: u32,
        _host_id: &str,
        task_id: &str,
    ) -> Result<(Box<dyn AsyncRead + Send + Unpin>, u64, String)> {
        let pool_key = self.get_entry_key(addr);
        let entry = self
            .get_client_entry(pool_key.clone(), addr.to_string())
            .await?;
        let in_flight = entry.request_guard();

        let outcome = entry.client.download_piece(number, task_id).await;
        let (reader, offset, digest) = match outcome {
            Ok(parts) => parts,
            Err(err) => {
                // Release this request's guard before asking the pool to drop the entry.
                drop(in_flight);
                self.remove_client_entry(pool_key).await;
                return Err(err);
            }
        };

        Ok((Box::new(reader), offset, digest))
    }

    /// Downloads a persistent piece over QUIC.
    ///
    /// On failure, the pooled client for the chosen key is removed before
    /// the error is returned. `_host_id` is unused here.
    #[instrument(skip_all)]
    async fn download_persistent_piece(
        &self,
        addr: &str,
        number: u32,
        _host_id: &str,
        task_id: &str,
    ) -> Result<(Box<dyn AsyncRead + Send + Unpin>, u64, String)> {
        let pool_key = self.get_entry_key(addr);
        let entry = self
            .get_client_entry(pool_key.clone(), addr.to_string())
            .await?;
        let in_flight = entry.request_guard();

        let outcome = entry
            .client
            .download_persistent_piece(number, task_id)
            .await;
        let (reader, offset, digest) = match outcome {
            Ok(parts) => parts,
            Err(err) => {
                // Release this request's guard before asking the pool to drop the entry.
                drop(in_flight);
                self.remove_client_entry(pool_key).await;
                return Err(err);
            }
        };

        Ok((Box::new(reader), offset, digest))
    }

    /// Downloads a persistent-cache piece over QUIC.
    ///
    /// On failure, the pooled client for the chosen key is removed before
    /// the error is returned. `_host_id` is unused here.
    #[instrument(skip_all)]
    async fn download_persistent_cache_piece(
        &self,
        addr: &str,
        number: u32,
        _host_id: &str,
        task_id: &str,
    ) -> Result<(Box<dyn AsyncRead + Send + Unpin>, u64, String)> {
        let pool_key = self.get_entry_key(addr);
        let entry = self
            .get_client_entry(pool_key.clone(), addr.to_string())
            .await?;
        let in_flight = entry.request_guard();

        let outcome = entry
            .client
            .download_persistent_cache_piece(number, task_id)
            .await;
        let (reader, offset, digest) = match outcome {
            Ok(parts) => parts,
            Err(err) => {
                // Release this request's guard before asking the pool to drop the entry.
                drop(in_flight);
                self.remove_client_entry(pool_key).await;
                return Err(err);
            }
        };

        Ok((Box::new(reader), offset, digest))
    }
}
/// Piece downloader that talks TCP, reusing `TCPClient`s from a keyed pool.
pub struct TCPDownloader {
/// Pool of `TCPClient`s created by `TCPClientFactory`.
/// NOTE(review): the type parameters presumably are (key, address, client,
/// factory). Keys are produced by `get_entry_key` and addresses are the peer
/// `addr` strings; confirm against `Pool`.
client_pool: Pool<String, String, TCPClient, TCPClientFactory>,
}
/// Pool factory that creates a `TCPClient` for a peer address.
struct TCPClientFactory {
/// Daemon configuration passed to every client this factory creates.
config: Arc<Config>,
}
#[async_trait]
impl Factory<String, TCPClient> for TCPClientFactory {
type Error = Error;
async fn make_client(&self, addr: &String) -> Result<TCPClient> {
Ok(TCPClient::new(self.config.clone(), addr.clone()))
}
}
impl TCPDownloader {
    /// Number of distinct pool slots that a single peer address is spread across.
    const MAX_CONNECTIONS_PER_ADDRESS: usize = 32;

    /// Creates a TCP downloader whose client pool is built with the given
    /// `capacity` and `idle_timeout`.
    pub fn new(config: Arc<Config>, capacity: usize, idle_timeout: Duration) -> Self {
        Self {
            // `config` is owned and unused afterwards, so move it into the
            // factory instead of making a redundant `Arc` clone.
            client_pool: PoolBuilder::new(TCPClientFactory { config })
                .capacity(capacity)
                .idle_timeout(idle_timeout)
                .build(),
        }
    }

    /// Looks up the pooled client stored under `key`, passing `addr` to the pool.
    ///
    /// NOTE(review): presumably the pool only calls the factory with `addr`
    /// when no entry exists for `key`; confirm against `Pool::entry`.
    async fn get_client_entry(&self, key: String, addr: String) -> Result<Entry<TCPClient>> {
        self.client_pool.entry(&key, &addr).await
    }

    /// Asks the pool to remove the client stored under `key`.
    async fn remove_client_entry(&self, key: String) {
        self.client_pool.remove_entry(&key).await;
    }

    /// Builds a pool key of the form `"{addr}-{slot}"`.
    ///
    /// `slot` is drawn at random from `0..MAX_CONNECTIONS_PER_ADDRESS`, so
    /// requests to one address are spread over several pooled clients.
    fn get_entry_key(&self, addr: &str) -> String {
        let slot = fastrand::usize(..Self::MAX_CONNECTIONS_PER_ADDRESS);
        format!("{}-{}", addr, slot)
    }
}
#[async_trait]
impl Downloader for TCPDownloader {
    /// Downloads a regular piece over TCP.
    ///
    /// If the download fails, the pool is asked to drop the client for the
    /// chosen key before the error is returned. `_host_id` is unused here.
    #[instrument(skip_all)]
    async fn download_piece(
        &self,
        addr: &str,
        number: u32,
        _host_id: &str,
        task_id: &str,
    ) -> Result<(Box<dyn AsyncRead + Send + Unpin>, u64, String)> {
        let pool_key = self.get_entry_key(addr);
        let entry = self
            .get_client_entry(pool_key.clone(), addr.to_string())
            .await?;
        let in_flight = entry.request_guard();

        let outcome = entry.client.download_piece(number, task_id).await;
        let (reader, offset, digest) = match outcome {
            Ok(parts) => parts,
            Err(err) => {
                // Release this request's guard before asking the pool to drop the entry.
                drop(in_flight);
                self.remove_client_entry(pool_key).await;
                return Err(err);
            }
        };

        Ok((Box::new(reader), offset, digest))
    }

    /// Downloads a persistent piece over TCP.
    ///
    /// On failure, the pooled client for the chosen key is removed before
    /// the error is returned. `_host_id` is unused here.
    #[instrument(skip_all)]
    async fn download_persistent_piece(
        &self,
        addr: &str,
        number: u32,
        _host_id: &str,
        task_id: &str,
    ) -> Result<(Box<dyn AsyncRead + Send + Unpin>, u64, String)> {
        let pool_key = self.get_entry_key(addr);
        let entry = self
            .get_client_entry(pool_key.clone(), addr.to_string())
            .await?;
        let in_flight = entry.request_guard();

        let outcome = entry
            .client
            .download_persistent_piece(number, task_id)
            .await;
        let (reader, offset, digest) = match outcome {
            Ok(parts) => parts,
            Err(err) => {
                // Release this request's guard before asking the pool to drop the entry.
                drop(in_flight);
                self.remove_client_entry(pool_key).await;
                return Err(err);
            }
        };

        Ok((Box::new(reader), offset, digest))
    }

    /// Downloads a persistent-cache piece over TCP.
    ///
    /// On failure, the pooled client for the chosen key is removed before
    /// the error is returned. `_host_id` is unused here.
    #[instrument(skip_all)]
    async fn download_persistent_cache_piece(
        &self,
        addr: &str,
        number: u32,
        _host_id: &str,
        task_id: &str,
    ) -> Result<(Box<dyn AsyncRead + Send + Unpin>, u64, String)> {
        let pool_key = self.get_entry_key(addr);
        let entry = self
            .get_client_entry(pool_key.clone(), addr.to_string())
            .await?;
        let in_flight = entry.request_guard();

        let outcome = entry
            .client
            .download_persistent_cache_piece(number, task_id)
            .await;
        let (reader, offset, digest) = match outcome {
            Ok(parts) => parts,
            Err(err) => {
                // Release this request's guard before asking the pool to drop the entry.
                drop(in_flight);
                self.remove_client_entry(pool_key).await;
                return Err(err);
            }
        };

        Ok((Box::new(reader), offset, digest))
    }
}