use bytesize::ByteSize;
use dragonfly_client_core::{
error::{ErrorType, OrErr},
Result,
};
use dragonfly_client_util::{
http::basic_auth,
http::query_params::default_proxy_rule_filtered_query_params,
ratelimiter::bbr::BBRConfig,
tls::{generate_ca_cert_from_pem, generate_cert_from_pem},
};
use local_ip_address::{local_ip, local_ipv6};
use rcgen::Certificate;
use regex::Regex;
use rustls_pki_types::CertificateDer;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::path::PathBuf;
use std::time::Duration;
use tokio::fs;
use tonic::transport::{
Certificate as TonicCertificate, ClientTlsConfig, Identity, ServerTlsConfig,
};
use tracing::{error, instrument};
use validator::Validate;
/// Name of the dfdaemon service. It is also used as the sub-directory name
/// under the crate-level log, plugin and cache directories.
pub const NAME: &str = "dfdaemon";
/// Returns the default configuration file path: `dfdaemon.yaml` inside
/// `crate::default_config_dir()`.
#[inline]
pub fn default_dfdaemon_config_path() -> PathBuf {
    crate::default_config_dir().join("dfdaemon.yaml")
}
/// Returns the default log directory: the [`NAME`] sub-directory of
/// `crate::default_log_dir()`.
#[inline]
pub fn default_dfdaemon_log_dir() -> PathBuf {
    crate::default_log_dir().join(NAME)
}
/// Default `download.server.socketPath`: `dfdaemon.sock` inside
/// `crate::default_root_dir()`.
pub fn default_download_unix_socket_path() -> PathBuf {
    crate::default_root_dir().join("dfdaemon.sock")
}
/// Default `download.protocol`: `"tcp"`.
#[inline]
fn default_download_protocol() -> String {
    String::from("tcp")
}

/// Default `download.server.requestRateLimit`: 4000.
///
/// NOTE(review): the unit is not defined in this file; presumably requests
/// per second — confirm against the rate limiter.
pub fn default_download_request_rate_limit() -> u64 {
    4_000_u64
}

/// Default `download.server.requestBufferSize`: 1000.
pub fn default_download_request_buffer_size() -> usize {
    1_000_usize
}
/// Default `host.hostname`: the hostname reported by the operating system.
///
/// Hostname bytes that are not valid UTF-8 are replaced lossily.
///
/// # Panics
///
/// Panics with a descriptive message if the hostname cannot be read from the
/// OS. The original used a bare `unwrap()`, which gave no context.
#[inline]
fn default_host_hostname() -> String {
    hostname::get()
        .expect("failed to get the local hostname")
        .to_string_lossy()
        .into_owned()
}
/// Default `server.pluginDir`: the [`NAME`] sub-directory of
/// `crate::default_plugin_dir()`.
#[inline]
fn default_dfdaemon_plugin_dir() -> PathBuf {
    crate::default_plugin_dir().join(NAME)
}
/// Default `server.cacheDir`: the [`NAME`] sub-directory of
/// `crate::default_cache_dir()`.
#[inline]
fn default_dfdaemon_cache_dir() -> PathBuf {
    crate::default_cache_dir().join(NAME)
}
/// Default `upload.server.port`: 4000.
#[inline]
fn default_upload_grpc_server_port() -> u16 {
    4_000_u16
}

/// Default `upload.server.requestRateLimit`: 4000.
///
/// NOTE(review): the unit is not defined in this file; presumably requests
/// per second — confirm.
pub fn default_upload_request_rate_limit() -> u64 {
    4_000_u64
}

/// Default `upload.server.requestBufferSize`: 1000.
pub fn default_upload_request_buffer_size() -> usize {
    1_000_usize
}
/// Default `upload.bandwidthLimit`: `ByteSize::gb(50)`.
#[inline]
fn default_upload_bandwidth_limit() -> ByteSize {
    ByteSize::gb(50)
}
/// Default `health.server.port`: 4003.
#[inline]
fn default_health_server_port() -> u16 {
    4_003_u16
}

/// Default `metrics.server.port`: 4002.
#[inline]
fn default_metrics_server_port() -> u16 {
    4_002_u16
}

/// Default `stats.server.port`: 4004.
#[inline]
fn default_stats_server_port() -> u16 {
    4_004_u16
}
/// Default `download.bandwidthLimit`: `ByteSize::gb(50)`.
#[inline]
fn default_download_bandwidth_limit() -> ByteSize {
    ByteSize::gb(50)
}
/// Default `download.backToSourceBandwidthLimit`: `ByteSize::gb(50)`.
#[inline]
fn default_back_to_source_bandwidth_limit() -> ByteSize {
    ByteSize::gb(50)
}
/// Default `download.pieceTimeout`: 6 minutes (360 s).
#[inline]
fn default_download_piece_timeout() -> Duration {
    Duration::from_secs(6 * 60)
}

/// Default `download.collectedPieceTimeout`: 6 minutes (360 s).
#[inline]
fn default_collected_download_piece_timeout() -> Duration {
    Duration::from_secs(6 * 60)
}

/// Default `download.concurrentPieceCount`: 8 (validated to be at least 1).
#[inline]
fn default_download_concurrent_piece_count() -> u32 {
    8_u32
}

/// Default `backend.enableCacheTemporaryRedirect`: enabled.
#[inline]
fn default_backend_enable_cache_temporary_redirect() -> bool {
    true
}

/// Default `backend.cacheTemporaryRedirectTTL`: 10 minutes (600 s).
#[inline]
fn default_backend_cache_temporary_redirect_ttl() -> Duration {
    Duration::from_secs(10 * 60)
}

/// Default `backend.putConcurrentChunkCount`: 16.
#[inline]
fn default_backend_put_concurrent_chunk_count() -> u32 {
    16_u32
}
/// Default `backend.putChunkSize`: `ByteSize::mib(8)`.
fn default_backend_put_chunk_size() -> ByteSize {
    ByteSize::mib(8)
}
/// Default `backend.putTimeout`: 15 minutes (900 s).
fn default_backend_put_timeout() -> Duration {
    Duration::from_secs(15 * 60)
}

/// Default `backend.enableHickoryDNS`: enabled.
fn default_backend_enable_hickory_dns() -> bool {
    true
}

/// Default `scheduler.maxScheduleCount`: 5 (validated to be at least 1).
#[inline]
fn default_download_max_schedule_count() -> u32 {
    5_u32
}

/// Default `tracing.path`: `/v1/traces`.
#[inline]
fn default_tracing_path() -> Option<PathBuf> {
    let path: PathBuf = "/v1/traces".into();
    Some(path)
}

/// Default `scheduler.announceInterval`: 5 minutes (300 s).
#[inline]
fn default_scheduler_announce_interval() -> Duration {
    Duration::from_secs(5 * 60)
}

/// Default `scheduler.scheduleTimeout`: 3 hours.
#[inline]
fn default_scheduler_schedule_timeout() -> Duration {
    Duration::from_secs(10_800)
}

/// Default `dynconfig.refreshInterval`: 5 minutes (300 s).
#[inline]
fn default_dynconfig_refresh_interval() -> Duration {
    Duration::from_secs(5 * 60)
}
/// Default `storage.server.tcpPort`: 4005.
#[inline]
fn default_storage_server_tcp_port() -> u16 {
    4_005_u16
}

/// Default `storage.server.quicPort`: 4006.
#[inline]
fn default_storage_server_quic_port() -> u16 {
    4_006_u16
}

/// Default `storage.keep`: disabled.
#[inline]
fn default_storage_keep() -> bool {
    false
}

/// Default `storage.writePieceTimeout`: 6 minutes (360 s).
#[inline]
fn default_storage_write_piece_timeout() -> Duration {
    Duration::from_secs(6 * 60)
}

/// Default `storage.writeBufferSize`: 4 MiB (4 << 20 bytes).
#[inline]
fn default_storage_write_buffer_size() -> usize {
    4 << 20
}

/// Default `storage.readBufferSize`: 4 MiB (4 << 20 bytes).
#[inline]
fn default_storage_read_buffer_size() -> usize {
    4 << 20
}
/// Default `storage.cacheCapacity`: `ByteSize::mib(64)`.
#[inline]
fn default_storage_cache_capacity() -> ByteSize {
    ByteSize::mib(64)
}
/// Default `gc.interval`: 15 minutes (900 s).
#[inline]
fn default_gc_interval() -> Duration {
    Duration::from_secs(15 * 60)
}

/// Default `gc.policy.taskTTL`: 30 days (2 592 000 s).
#[inline]
fn default_gc_policy_task_ttl() -> Duration {
    Duration::from_secs(30 * 24 * 60 * 60)
}

/// Default `gc.policy.persistentTaskTTL`: 1 day (86 400 s).
#[inline]
fn default_gc_policy_persistent_task_ttl() -> Duration {
    Duration::from_secs(24 * 60 * 60)
}

/// Default `gc.policy.persistentCacheTaskTTL`: 1 day (86 400 s).
#[inline]
fn default_gc_policy_persistent_cache_task_ttl() -> Duration {
    Duration::from_secs(24 * 60 * 60)
}
/// Default `gc.policy.diskThreshold`: `ByteSize::default()`.
///
/// NOTE(review): presumably this is zero bytes, meaning "no absolute
/// threshold" so the percentage thresholds apply instead — confirm against the
/// GC implementation.
#[inline]
fn default_gc_policy_disk_threshold() -> ByteSize {
    ByteSize::default()
}
/// Default `gc.policy.diskHighThresholdPercent`: 80.
#[inline]
fn default_gc_policy_disk_high_threshold_percent() -> u8 {
    80_u8
}

/// Default `gc.policy.diskLowThresholdPercent`: 60.
#[inline]
fn default_gc_policy_disk_low_threshold_percent() -> u8 {
    60_u8
}

/// Default `proxy.server.port`: 4001.
#[inline]
pub fn default_proxy_server_port() -> u16 {
    4_001_u16
}

/// Default `proxy.readBufferSize`: 4 MiB (4 << 20 bytes).
#[inline]
pub fn default_proxy_read_buffer_size() -> usize {
    4 << 20
}
/// Default `proxy.prefetchBandwidthLimit`: `ByteSize::gb(10)`.
#[inline]
fn default_prefetch_bandwidth_limit() -> ByteSize {
    ByteSize::gb(10)
}
/// Default `proxy.registryMirror.addr`: `https://index.docker.io`.
#[inline]
fn default_proxy_registry_mirror_addr() -> String {
    String::from("https://index.docker.io")
}

/// Default `proxy.registryMirror.enableTaskIDBasedBlobDigest`: disabled.
#[inline]
fn default_enable_task_id_based_blob_digest() -> bool {
    false
}
/// Host description of this dfdaemon (`host` section).
#[derive(Debug, Clone, Validate, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct Host {
    /// IDC of the host, from `idc`; unset by default.
    pub idc: Option<String>,

    /// Location of the host, from `location`; unset by default.
    pub location: Option<String>,

    /// Hostname, from `hostname`; defaults to `default_host_hostname()`.
    #[serde(default = "default_host_hostname")]
    pub hostname: String,

    /// Host IP, from `ip`. When unset, `Config::load` fills it in via
    /// `Config::convert`.
    pub ip: Option<IpAddr>,

    /// Scheduler cluster ID, deserialized from the `schedulerClusterID` key.
    #[serde(rename = "schedulerClusterID")]
    pub scheduler_cluster_id: Option<u64>,
}

impl Default for Host {
    /// Every optional field is unset; only `hostname` gets a value.
    fn default() -> Self {
        Self {
            hostname: default_host_hostname(),
            idc: None,
            location: None,
            ip: None,
            scheduler_cluster_id: None,
        }
    }
}
/// General server settings (`server` section).
#[derive(Debug, Clone, Validate, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct Server {
    /// Plugin directory, from `pluginDir`; defaults to
    /// `default_dfdaemon_plugin_dir()`.
    #[serde(default = "default_dfdaemon_plugin_dir")]
    pub plugin_dir: PathBuf,

    /// Cache directory, from `cacheDir`; defaults to
    /// `default_dfdaemon_cache_dir()`.
    #[serde(default = "default_dfdaemon_cache_dir")]
    pub cache_dir: PathBuf,

    /// Optional BBR adaptive rate-limit configuration, from
    /// `adaptiveRateLimit`; unset by default.
    pub adaptive_rate_limit: Option<BBRConfig>,
}

impl Default for Server {
    fn default() -> Self {
        let plugin_dir = default_dfdaemon_plugin_dir();
        let cache_dir = default_dfdaemon_cache_dir();
        Self {
            plugin_dir,
            cache_dir,
            adaptive_rate_limit: None,
        }
    }
}
/// Download server settings (`download.server` section).
#[derive(Debug, Clone, Validate, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct DownloadServer {
    /// Unix socket path, from `socketPath`; defaults to
    /// `default_download_unix_socket_path()`.
    #[serde(default = "default_download_unix_socket_path")]
    pub socket_path: PathBuf,

    /// Request rate limit, from `requestRateLimit`; defaults to 4000.
    #[serde(default = "default_download_request_rate_limit")]
    pub request_rate_limit: u64,

    /// Request buffer size, from `requestBufferSize`; defaults to 1000.
    #[serde(default = "default_download_request_buffer_size")]
    pub request_buffer_size: usize,
}

impl Default for DownloadServer {
    fn default() -> Self {
        Self {
            request_rate_limit: default_download_request_rate_limit(),
            request_buffer_size: default_download_request_buffer_size(),
            socket_path: default_download_unix_socket_path(),
        }
    }
}
/// Download settings (`download` section).
#[derive(Debug, Clone, Validate, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct Download {
    /// Download server settings, from `server`.
    pub server: DownloadServer,

    /// Download protocol, from `protocol`; defaults to `"tcp"`.
    #[serde(default = "default_download_protocol")]
    pub protocol: String,

    /// Bandwidth limit, from `bandwidthLimit` (e.g. `"50GB"`); defaults to
    /// `default_download_bandwidth_limit()`.
    #[serde(default = "default_download_bandwidth_limit", with = "bytesize_serde")]
    pub bandwidth_limit: ByteSize,

    /// Back-to-source bandwidth limit, from `backToSourceBandwidthLimit`;
    /// defaults to `default_back_to_source_bandwidth_limit()`.
    #[serde(
        default = "default_back_to_source_bandwidth_limit",
        with = "bytesize_serde"
    )]
    pub back_to_source_bandwidth_limit: ByteSize,

    /// Piece timeout, from `pieceTimeout` (humantime, e.g. `"30s"`); defaults
    /// to 360 s.
    #[serde(with = "humantime_serde", default = "default_download_piece_timeout")]
    pub piece_timeout: Duration,

    /// Collected-piece timeout, from `collectedPieceTimeout` (humantime);
    /// defaults to 360 s.
    #[serde(
        with = "humantime_serde",
        default = "default_collected_download_piece_timeout"
    )]
    pub collected_piece_timeout: Duration,

    /// Concurrent piece count, from `concurrentPieceCount`; defaults to 8 and
    /// must be at least 1.
    #[serde(default = "default_download_concurrent_piece_count")]
    #[validate(range(min = 1))]
    pub concurrent_piece_count: u32,
}

impl Default for Download {
    fn default() -> Self {
        Self {
            server: Default::default(),
            protocol: default_download_protocol(),
            bandwidth_limit: default_download_bandwidth_limit(),
            back_to_source_bandwidth_limit: default_back_to_source_bandwidth_limit(),
            piece_timeout: default_download_piece_timeout(),
            collected_piece_timeout: default_collected_download_piece_timeout(),
            concurrent_piece_count: default_download_concurrent_piece_count(),
        }
    }
}
/// Upload server settings (`upload.server` section).
#[derive(Debug, Clone, Validate, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct UploadServer {
    /// Server IP, from `ip`. When unset, `Config::convert` fills in the
    /// unspecified address (`::` or `0.0.0.0` depending on
    /// `network.enableIPv6`).
    pub ip: Option<IpAddr>,

    /// Server port, from `port`; defaults to 4000.
    #[serde(default = "default_upload_grpc_server_port")]
    pub port: u16,

    /// CA certificate path (PEM), from `caCert`.
    pub ca_cert: Option<PathBuf>,

    /// Server certificate path (PEM), from `cert`.
    pub cert: Option<PathBuf>,

    /// Server private key path (PEM), from `key`.
    pub key: Option<PathBuf>,

    /// Request rate limit, from `requestRateLimit`; defaults to 4000.
    #[serde(default = "default_upload_request_rate_limit")]
    pub request_rate_limit: u64,

    /// Request buffer size, from `requestBufferSize`; defaults to 1000.
    #[serde(default = "default_upload_request_buffer_size")]
    pub request_buffer_size: usize,
}

impl Default for UploadServer {
    /// TLS paths and IP are unset; port and request limits use their defaults.
    fn default() -> Self {
        Self {
            port: default_upload_grpc_server_port(),
            request_rate_limit: default_upload_request_rate_limit(),
            request_buffer_size: default_upload_request_buffer_size(),
            ip: None,
            ca_cert: None,
            cert: None,
            key: None,
        }
    }
}
impl UploadServer {
    /// Builds the TLS configuration for the upload server.
    ///
    /// Returns `Ok(None)` without touching the filesystem unless `caCert`,
    /// `cert` and `key` are all configured. Otherwise it reads the certificate,
    /// the key and then the CA certificate. It builds the server identity from
    /// the certificate/key pair and installs the CA certificate with
    /// `client_ca_root`.
    ///
    /// # Errors
    ///
    /// Returns an error if any of the three files cannot be read.
    pub async fn load_server_tls_config(&self) -> Result<Option<ServerTlsConfig>> {
        let (ca_cert_path, cert_path, key_path) = match (&self.ca_cert, &self.cert, &self.key)
        {
            (Some(ca_cert_path), Some(cert_path), Some(key_path)) => {
                (ca_cert_path, cert_path, key_path)
            }
            // A partial configuration leaves TLS disabled.
            _ => return Ok(None),
        };

        // Keep the read order cert -> key -> CA so the first error reported is
        // unchanged.
        let cert_pem = fs::read(cert_path).await?;
        let key_pem = fs::read(key_path).await?;
        let ca_pem = fs::read(ca_cert_path).await?;

        let tls_config = ServerTlsConfig::new()
            .identity(Identity::from_pem(cert_pem, key_pem))
            .client_ca_root(TonicCertificate::from_pem(ca_pem));
        Ok(Some(tls_config))
    }
}
/// TLS settings of the upload client (`upload.client` section).
#[derive(Debug, Clone, Default, Validate, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct UploadClient {
    /// CA certificate path (PEM), from `caCert`.
    pub ca_cert: Option<PathBuf>,

    /// Client certificate path (PEM), from `cert`.
    pub cert: Option<PathBuf>,

    /// Client private key path (PEM), from `key`.
    pub key: Option<PathBuf>,
}

impl UploadClient {
    /// Builds the client TLS configuration for `domain_name`.
    ///
    /// Returns `Ok(None)` without reading anything unless `caCert`, `cert` and
    /// `key` are all configured. Otherwise it reads the certificate, the key and
    /// then the CA certificate, and sets the domain name, CA certificate and
    /// client identity.
    ///
    /// # Errors
    ///
    /// Returns an error if any of the three files cannot be read.
    pub async fn load_client_tls_config(
        &self,
        domain_name: &str,
    ) -> Result<Option<ClientTlsConfig>> {
        let (ca_cert_path, cert_path, key_path) = match (&self.ca_cert, &self.cert, &self.key)
        {
            (Some(ca_cert_path), Some(cert_path), Some(key_path)) => {
                (ca_cert_path, cert_path, key_path)
            }
            _ => return Ok(None),
        };

        let cert_pem = fs::read(cert_path).await?;
        let key_pem = fs::read(key_path).await?;
        let ca_pem = fs::read(ca_cert_path).await?;

        let tls_config = ClientTlsConfig::new()
            .domain_name(domain_name)
            .ca_certificate(TonicCertificate::from_pem(ca_pem))
            .identity(Identity::from_pem(cert_pem, key_pem));
        Ok(Some(tls_config))
    }
}
/// Upload settings (`upload` section).
#[derive(Debug, Clone, Validate, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct Upload {
    /// Upload server settings, from `server`.
    pub server: UploadServer,

    /// Upload client TLS settings, from `client`.
    pub client: UploadClient,

    /// `disableShared` flag; `false` by default.
    pub disable_shared: bool,

    /// Bandwidth limit, from `bandwidthLimit`; defaults to
    /// `default_upload_bandwidth_limit()`.
    #[serde(default = "default_upload_bandwidth_limit", with = "bytesize_serde")]
    pub bandwidth_limit: ByteSize,
}

impl Default for Upload {
    fn default() -> Self {
        Self {
            server: Default::default(),
            client: Default::default(),
            disable_shared: false,
            bandwidth_limit: default_upload_bandwidth_limit(),
        }
    }
}
/// Manager client settings (`manager` section).
#[derive(Debug, Clone, Default, Validate, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct Manager {
    /// Manager address, from `addr`; empty by default.
    pub addr: String,

    /// CA certificate path (PEM), from `caCert`.
    pub ca_cert: Option<PathBuf>,

    /// Client certificate path (PEM), from `cert`.
    pub cert: Option<PathBuf>,

    /// Client private key path (PEM), from `key`.
    pub key: Option<PathBuf>,
}

impl Manager {
    /// Builds the client TLS configuration used to reach the manager as
    /// `domain_name`.
    ///
    /// Returns `Ok(None)` without reading anything unless `caCert`, `cert` and
    /// `key` are all configured. Otherwise it reads the certificate, the key and
    /// then the CA certificate.
    ///
    /// # Errors
    ///
    /// Returns an error if any of the three files cannot be read.
    pub async fn load_client_tls_config(
        &self,
        domain_name: &str,
    ) -> Result<Option<ClientTlsConfig>> {
        let (ca_cert_path, cert_path, key_path) = match (&self.ca_cert, &self.cert, &self.key)
        {
            (Some(ca_cert_path), Some(cert_path), Some(key_path)) => {
                (ca_cert_path, cert_path, key_path)
            }
            _ => return Ok(None),
        };

        let cert_pem = fs::read(cert_path).await?;
        let key_pem = fs::read(key_path).await?;
        let ca_pem = fs::read(ca_cert_path).await?;

        let tls_config = ClientTlsConfig::new()
            .domain_name(domain_name)
            .ca_certificate(TonicCertificate::from_pem(ca_pem))
            .identity(Identity::from_pem(cert_pem, key_pem));
        Ok(Some(tls_config))
    }
}
/// Scheduler client settings (`scheduler` section).
#[derive(Debug, Clone, Validate, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct Scheduler {
    /// Announce interval, from `announceInterval` (humantime); defaults to
    /// 300 s.
    #[serde(
        with = "humantime_serde",
        default = "default_scheduler_announce_interval"
    )]
    pub announce_interval: Duration,

    /// Schedule timeout, from `scheduleTimeout` (humantime); defaults to 3 h.
    #[serde(
        with = "humantime_serde",
        default = "default_scheduler_schedule_timeout"
    )]
    pub schedule_timeout: Duration,

    /// Maximum schedule count, from `maxScheduleCount`; defaults to 5 and must
    /// be at least 1.
    #[serde(default = "default_download_max_schedule_count")]
    #[validate(range(min = 1))]
    pub max_schedule_count: u32,

    /// CA certificate path (PEM), from `caCert`.
    pub ca_cert: Option<PathBuf>,

    /// Client certificate path (PEM), from `cert`.
    pub cert: Option<PathBuf>,

    /// Client private key path (PEM), from `key`.
    pub key: Option<PathBuf>,
}

impl Default for Scheduler {
    fn default() -> Self {
        Self {
            announce_interval: default_scheduler_announce_interval(),
            schedule_timeout: default_scheduler_schedule_timeout(),
            max_schedule_count: default_download_max_schedule_count(),
            ca_cert: None,
            cert: None,
            key: None,
        }
    }
}

impl Scheduler {
    /// Builds the client TLS configuration used to reach the scheduler as
    /// `domain_name`.
    ///
    /// Returns `Ok(None)` without reading anything unless `caCert`, `cert` and
    /// `key` are all configured. Otherwise it reads the certificate, the key and
    /// then the CA certificate.
    ///
    /// # Errors
    ///
    /// Returns an error if any of the three files cannot be read.
    pub async fn load_client_tls_config(
        &self,
        domain_name: &str,
    ) -> Result<Option<ClientTlsConfig>> {
        let (ca_cert_path, cert_path, key_path) = match (&self.ca_cert, &self.cert, &self.key)
        {
            (Some(ca_cert_path), Some(cert_path), Some(key_path)) => {
                (ca_cert_path, cert_path, key_path)
            }
            _ => return Ok(None),
        };

        let cert_pem = fs::read(cert_path).await?;
        let key_pem = fs::read(key_path).await?;
        let ca_pem = fs::read(ca_cert_path).await?;

        let tls_config = ClientTlsConfig::new()
            .domain_name(domain_name)
            .ca_certificate(TonicCertificate::from_pem(ca_pem))
            .identity(Identity::from_pem(cert_pem, key_pem));
        Ok(Some(tls_config))
    }
}
/// Host type, (de)serialized as `"normal"` or `"super"`.
///
/// `HostType::default()` is [`HostType::Super`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
pub enum HostType {
    /// (De)serialized as `"normal"`.
    #[serde(rename = "normal")]
    Normal,

    /// (De)serialized as `"super"`; the `Default` variant.
    #[default]
    #[serde(rename = "super")]
    Super,
}

impl fmt::Display for HostType {
    /// Writes the same lowercase name used by serde.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            HostType::Normal => "normal",
            HostType::Super => "super",
        };
        f.write_str(name)
    }
}
/// Seed peer settings (`seedPeer` section).
#[derive(Debug, Clone, Validate, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct SeedPeer {
    /// `enable` flag; `false` by default.
    pub enable: bool,

    /// Host type, deserialized from the `type` key.
    ///
    /// NOTE(review): the field-level `#[serde(default)]` fills a missing `type`
    /// with `HostType::default()` (`Super`), whereas `SeedPeer::default()`
    /// (used when the whole section is absent) sets `Normal`. Confirm this
    /// asymmetry is intended.
    #[serde(default, rename = "type")]
    pub kind: HostType,
}

impl Default for SeedPeer {
    fn default() -> Self {
        Self {
            kind: HostType::Normal,
            enable: false,
        }
    }
}
/// Dynamic configuration settings (`dynconfig` section).
#[derive(Debug, Clone, Validate, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct Dynconfig {
    /// Refresh interval, from `refreshInterval` (humantime); defaults to 300 s.
    #[serde(
        with = "humantime_serde",
        default = "default_dynconfig_refresh_interval"
    )]
    pub refresh_interval: Duration,
}

impl Default for Dynconfig {
    fn default() -> Self {
        let refresh_interval = default_dynconfig_refresh_interval();
        Self { refresh_interval }
    }
}
/// Storage server settings (`storage.server` section).
#[derive(Debug, Clone, Validate, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct StorageServer {
    /// Server IP, from `ip`. When unset, `Config::convert` fills in the
    /// unspecified address.
    pub ip: Option<IpAddr>,

    /// TCP port, from `tcpPort`; defaults to 4005.
    #[serde(default = "default_storage_server_tcp_port")]
    pub tcp_port: u16,

    /// `tcpFastopen` flag; `false` by default.
    pub tcp_fastopen: bool,

    /// QUIC port, from `quicPort`; defaults to 4006.
    #[serde(default = "default_storage_server_quic_port")]
    pub quic_port: u16,
}

impl Default for StorageServer {
    fn default() -> Self {
        Self {
            tcp_port: default_storage_server_tcp_port(),
            quic_port: default_storage_server_quic_port(),
            tcp_fastopen: false,
            ip: None,
        }
    }
}
/// Storage settings (`storage` section).
#[derive(Debug, Clone, Validate, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct Storage {
    /// Storage server settings, from `server`.
    pub server: StorageServer,

    /// Storage directory, from `dir`; defaults to `crate::default_storage_dir()`.
    #[serde(default = "crate::default_storage_dir")]
    pub dir: PathBuf,

    /// `keep` flag; defaults to `false`.
    #[serde(default = "default_storage_keep")]
    pub keep: bool,

    /// Write-piece timeout, from `writePieceTimeout` (humantime); defaults to
    /// 360 s.
    #[serde(
        with = "humantime_serde",
        default = "default_storage_write_piece_timeout"
    )]
    pub write_piece_timeout: Duration,

    /// Write buffer size in bytes, from `writeBufferSize`; defaults to 4 MiB.
    #[serde(default = "default_storage_write_buffer_size")]
    pub write_buffer_size: usize,

    /// Read buffer size in bytes, from `readBufferSize`; defaults to 4 MiB.
    #[serde(default = "default_storage_read_buffer_size")]
    pub read_buffer_size: usize,

    /// Cache capacity, from `cacheCapacity`; defaults to
    /// `default_storage_cache_capacity()`.
    #[serde(default = "default_storage_cache_capacity", with = "bytesize_serde")]
    pub cache_capacity: ByteSize,
}

impl Default for Storage {
    fn default() -> Self {
        Self {
            server: Default::default(),
            dir: crate::default_storage_dir(),
            keep: default_storage_keep(),
            write_piece_timeout: default_storage_write_piece_timeout(),
            write_buffer_size: default_storage_write_buffer_size(),
            read_buffer_size: default_storage_read_buffer_size(),
            cache_capacity: default_storage_cache_capacity(),
        }
    }
}
/// Garbage-collection policy (`gc.policy` section).
#[derive(Debug, Clone, Validate, Deserialize, Serialize)]
#[serde(default, rename_all = "camelCase")]
pub struct Policy {
    /// Task TTL, key `taskTTL` (humantime); defaults to 30 days.
    #[serde(
        rename = "taskTTL",
        with = "humantime_serde",
        default = "default_gc_policy_task_ttl"
    )]
    pub task_ttl: Duration,

    /// Persistent task TTL, key `persistentTaskTTL` (humantime); defaults to
    /// 1 day.
    #[serde(
        rename = "persistentTaskTTL",
        with = "humantime_serde",
        default = "default_gc_policy_persistent_task_ttl"
    )]
    pub persistent_task_ttl: Duration,

    /// Persistent cache task TTL, key `persistentCacheTaskTTL` (humantime);
    /// defaults to 1 day.
    #[serde(
        rename = "persistentCacheTaskTTL",
        with = "humantime_serde",
        default = "default_gc_policy_persistent_cache_task_ttl"
    )]
    pub persistent_cache_task_ttl: Duration,

    /// Disk threshold, key `diskThreshold` (alias `distThreshold`); defaults
    /// to `default_gc_policy_disk_threshold()`.
    #[serde(
        alias = "distThreshold",
        with = "bytesize_serde",
        default = "default_gc_policy_disk_threshold"
    )]
    pub disk_threshold: ByteSize,

    /// High disk threshold percentage, key `diskHighThresholdPercent` (alias
    /// `distHighThresholdPercent`); defaults to 80, must be within 1..=99.
    #[serde(
        alias = "distHighThresholdPercent",
        default = "default_gc_policy_disk_high_threshold_percent"
    )]
    #[validate(range(min = 1, max = 99))]
    pub disk_high_threshold_percent: u8,

    /// Low disk threshold percentage, key `diskLowThresholdPercent` (alias
    /// `distLowThresholdPercent`); defaults to 60, must be within 1..=99.
    #[serde(
        alias = "distLowThresholdPercent",
        default = "default_gc_policy_disk_low_threshold_percent"
    )]
    #[validate(range(min = 1, max = 99))]
    pub disk_low_threshold_percent: u8,
}

impl Default for Policy {
    fn default() -> Self {
        Self {
            task_ttl: default_gc_policy_task_ttl(),
            persistent_task_ttl: default_gc_policy_persistent_task_ttl(),
            persistent_cache_task_ttl: default_gc_policy_persistent_cache_task_ttl(),
            disk_threshold: default_gc_policy_disk_threshold(),
            disk_high_threshold_percent: default_gc_policy_disk_high_threshold_percent(),
            disk_low_threshold_percent: default_gc_policy_disk_low_threshold_percent(),
        }
    }
}
#[derive(Debug, Clone, Validate, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct GC {
#[serde(default = "default_gc_interval", with = "humantime_serde")]
pub interval: Duration,
pub policy: Policy,
}
impl Default for GC {
fn default() -> Self {
GC {
interval: default_gc_interval(),
policy: Policy::default(),
}
}
}
/// Basic-auth credentials (e.g. `proxy.server.basicAuth`).
#[derive(Debug, Clone, Default, Validate, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct BasicAuth {
    /// Username, from `username`; declared length 1..=20.
    #[validate(length(min = 1, max = 20))]
    pub username: String,

    /// Password, from `password`; declared length 1..=20.
    #[validate(length(min = 1, max = 20))]
    pub password: String,
}

impl BasicAuth {
    /// Builds `basic_auth::Credentials` from the configured username and
    /// password.
    pub fn credentials(&self) -> basic_auth::Credentials {
        let username = self.username.as_str();
        let password = self.password.as_str();
        basic_auth::Credentials::new(username, password)
    }
}
#[derive(Debug, Clone, Validate, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct ProxyServer {
pub ip: Option<IpAddr>,
#[serde(default = "default_proxy_server_port")]
pub port: u16,
pub ca_cert: Option<PathBuf>,
pub ca_key: Option<PathBuf>,
pub basic_auth: Option<BasicAuth>,
}
impl Default for ProxyServer {
fn default() -> Self {
Self {
ip: None,
port: default_proxy_server_port(),
ca_cert: None,
ca_key: None,
basic_auth: None,
}
}
}
impl ProxyServer {
pub fn load_cert(&self) -> Result<Option<Certificate>> {
if let (Some(server_ca_cert_path), Some(server_ca_key_path)) =
(self.ca_cert.clone(), self.ca_key.clone())
{
match generate_ca_cert_from_pem(&server_ca_cert_path, &server_ca_key_path) {
Ok(server_ca_cert) => return Ok(Some(server_ca_cert)),
Err(err) => {
error!("generate ca cert and key from pem failed: {}", err);
return Err(err);
}
}
}
Ok(None)
}
}
#[derive(Debug, Clone, Validate, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct Rule {
#[serde(with = "serde_regex")]
pub regex: Regex,
#[serde(rename = "useTLS")]
pub use_tls: bool,
pub redirect: Option<String>,
#[serde(default = "default_proxy_rule_filtered_query_params")]
pub filtered_query_params: Vec<String>,
}
impl Default for Rule {
fn default() -> Self {
Self {
regex: Regex::new(r".*").unwrap(),
use_tls: false,
redirect: None,
filtered_query_params: default_proxy_rule_filtered_query_params(),
}
}
}
#[derive(Debug, Clone, Validate, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct RegistryMirror {
#[serde(default = "default_proxy_registry_mirror_addr")]
pub addr: String,
pub cert: Option<PathBuf>,
#[serde(
default = "default_enable_task_id_based_blob_digest",
rename = "enableTaskIDBasedBlobDigest"
)]
pub enable_task_id_based_blob_digest: bool,
}
impl Default for RegistryMirror {
fn default() -> Self {
Self {
addr: default_proxy_registry_mirror_addr(),
cert: None,
enable_task_id_based_blob_digest: default_enable_task_id_based_blob_digest(),
}
}
}
impl RegistryMirror {
pub fn load_cert_der(&self) -> Result<Option<Vec<CertificateDer<'static>>>> {
if let Some(cert_path) = self.cert.clone() {
match generate_cert_from_pem(&cert_path) {
Ok(cert) => return Ok(Some(cert)),
Err(err) => {
error!("generate cert from pems failed: {}", err);
return Err(err);
}
}
};
Ok(None)
}
}
#[derive(Debug, Clone, Validate, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct Proxy {
pub server: ProxyServer,
pub rules: Option<Vec<Rule>>,
pub registry_mirror: RegistryMirror,
pub disable_back_to_source: bool,
pub prefetch: bool,
#[serde(with = "bytesize_serde", default = "default_prefetch_bandwidth_limit")]
pub prefetch_bandwidth_limit: ByteSize,
#[serde(default = "default_proxy_read_buffer_size")]
pub read_buffer_size: usize,
}
impl Default for Proxy {
fn default() -> Self {
Self {
server: ProxyServer::default(),
rules: None,
registry_mirror: RegistryMirror::default(),
disable_back_to_source: false,
prefetch: false,
prefetch_bandwidth_limit: default_prefetch_bandwidth_limit(),
read_buffer_size: default_proxy_read_buffer_size(),
}
}
}
/// Security settings (`security` section).
#[derive(Debug, Clone, Default, Validate, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct Security {
    /// `enable` flag; `false` by default.
    pub enable: bool,
}
/// Network settings (`network` section).
#[derive(Debug, Clone, Default, Validate, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct Network {
    /// `enableIPv6` flag; `false` by default. `Config::convert` uses it to pick
    /// IPv6 vs IPv4 defaults for the host IP and the server IPs.
    #[serde(rename = "enableIPv6")]
    pub enable_ipv6: bool,
}
/// Health server settings (`health.server` section).
#[derive(Debug, Clone, Validate, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct HealthServer {
    /// Server IP, from `ip`. When unset, `Config::convert` fills in the
    /// unspecified address.
    pub ip: Option<IpAddr>,

    /// Server port, from `port`; defaults to 4003.
    #[serde(default = "default_health_server_port")]
    pub port: u16,
}

impl Default for HealthServer {
    fn default() -> Self {
        HealthServer {
            port: default_health_server_port(),
            ip: None,
        }
    }
}

/// Health settings (`health` section).
#[derive(Debug, Clone, Default, Validate, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct Health {
    /// Health server settings, from `server`.
    pub server: HealthServer,
}
/// Metrics server settings (`metrics.server` section).
#[derive(Debug, Clone, Validate, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct MetricsServer {
    /// Server IP, from `ip`. When unset, `Config::convert` fills in the
    /// unspecified address.
    pub ip: Option<IpAddr>,

    /// Server port, from `port`; defaults to 4002.
    #[serde(default = "default_metrics_server_port")]
    pub port: u16,
}

impl Default for MetricsServer {
    fn default() -> Self {
        MetricsServer {
            port: default_metrics_server_port(),
            ip: None,
        }
    }
}

/// Metrics settings (`metrics` section).
#[derive(Debug, Clone, Default, Validate, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct Metrics {
    /// Metrics server settings, from `server`.
    pub server: MetricsServer,
}
/// Stats server settings (`stats.server` section).
#[derive(Debug, Clone, Validate, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct StatsServer {
    /// Server IP, from `ip`. When unset, `Config::convert` fills in the
    /// unspecified address.
    pub ip: Option<IpAddr>,

    /// Server port, from `port`; defaults to 4004.
    #[serde(default = "default_stats_server_port")]
    pub port: u16,
}

impl Default for StatsServer {
    fn default() -> Self {
        StatsServer {
            port: default_stats_server_port(),
            ip: None,
        }
    }
}

/// Stats settings (`stats` section).
#[derive(Debug, Clone, Default, Validate, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct Stats {
    /// Stats server settings, from `server`.
    pub server: StatsServer,
}
/// Tracing settings (`tracing` section).
#[derive(Debug, Clone, Validate, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct Tracing {
    /// Optional protocol, from `protocol`.
    pub protocol: Option<String>,

    /// Optional endpoint, from `endpoint`.
    pub endpoint: Option<String>,

    /// Path, from `path`; defaults to `/v1/traces`.
    #[serde(default = "default_tracing_path")]
    pub path: Option<PathBuf>,

    /// Extra headers, from `headers` (parsed with `http_serde::header_map`);
    /// empty by default.
    #[serde(with = "http_serde::header_map")]
    pub headers: reqwest::header::HeaderMap,
}

impl Default for Tracing {
    fn default() -> Self {
        Tracing {
            path: default_tracing_path(),
            headers: reqwest::header::HeaderMap::new(),
            protocol: None,
            endpoint: None,
        }
    }
}
/// Backend settings (`backend` section).
#[derive(Debug, Clone, Validate, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct Backend {
    /// Optional extra request headers, from `requestHeader`.
    pub request_header: Option<HashMap<String, String>>,

    /// `enableCacheTemporaryRedirect` flag; defaults to `true`.
    #[serde(default = "default_backend_enable_cache_temporary_redirect")]
    pub enable_cache_temporary_redirect: bool,

    /// Key `cacheTemporaryRedirectTTL` (humantime); defaults to 600 s.
    #[serde(
        rename = "cacheTemporaryRedirectTTL",
        with = "humantime_serde",
        default = "default_backend_cache_temporary_redirect_ttl"
    )]
    pub cache_temporary_redirect_ttl: Duration,

    /// Key `putConcurrentChunkCount`; defaults to 16.
    #[serde(default = "default_backend_put_concurrent_chunk_count")]
    pub put_concurrent_chunk_count: u32,

    /// Key `putChunkSize`; defaults to `default_backend_put_chunk_size()`.
    #[serde(with = "bytesize_serde", default = "default_backend_put_chunk_size")]
    pub put_chunk_size: ByteSize,

    /// Key `putTimeout` (humantime); defaults to 900 s.
    #[serde(with = "humantime_serde", default = "default_backend_put_timeout")]
    pub put_timeout: Duration,

    /// Key `enableHickoryDNS`; defaults to `true`.
    #[serde(
        rename = "enableHickoryDNS",
        default = "default_backend_enable_hickory_dns"
    )]
    pub enable_hickory_dns: bool,
}

impl Default for Backend {
    fn default() -> Self {
        Backend {
            enable_cache_temporary_redirect: default_backend_enable_cache_temporary_redirect(),
            cache_temporary_redirect_ttl: default_backend_cache_temporary_redirect_ttl(),
            put_concurrent_chunk_count: default_backend_put_concurrent_chunk_count(),
            put_chunk_size: default_backend_put_chunk_size(),
            put_timeout: default_backend_put_timeout(),
            enable_hickory_dns: default_backend_enable_hickory_dns(),
            request_header: None,
        }
    }
}
/// Top-level dfdaemon configuration, deserialized from YAML by `Config::load`.
///
/// Every section is optional (`#[serde(default)]`). Sections marked
/// `#[validate]` are validated recursively by `Config::validate`.
#[derive(Debug, Clone, Default, Validate, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct Config {
    /// `host` section.
    #[validate]
    pub host: Host,
    /// `server` section.
    #[validate]
    pub server: Server,
    /// `download` section.
    #[validate]
    pub download: Download,
    /// `upload` section.
    #[validate]
    pub upload: Upload,
    /// `manager` section.
    #[validate]
    pub manager: Manager,
    /// `scheduler` section.
    #[validate]
    pub scheduler: Scheduler,
    /// `seedPeer` section.
    #[validate]
    pub seed_peer: SeedPeer,
    /// `dynconfig` section.
    #[validate]
    pub dynconfig: Dynconfig,
    /// `storage` section.
    #[validate]
    pub storage: Storage,
    /// `backend` section.
    #[validate]
    pub backend: Backend,
    /// `gc` section.
    #[validate]
    pub gc: GC,
    /// `proxy` section.
    #[validate]
    pub proxy: Proxy,
    /// `security` section.
    #[validate]
    pub security: Security,
    /// `health` section.
    #[validate]
    pub health: Health,
    /// `metrics` section.
    #[validate]
    pub metrics: Metrics,
    /// `stats` section.
    #[validate]
    pub stats: Stats,
    /// `tracing` section.
    #[validate]
    pub tracing: Tracing,
    /// `network` section; `Config::convert` reads `enable_ipv6` from it.
    #[validate]
    pub network: Network,
}
impl Config {
#[instrument(skip_all)]
pub async fn load(path: &PathBuf) -> Result<Config> {
let content = fs::read_to_string(path).await?;
let mut config: Config = serde_yaml::from_str(&content).or_err(ErrorType::ConfigError)?;
config.convert();
config.validate().or_err(ErrorType::ValidationError)?;
Ok(config)
}
#[instrument(skip_all)]
fn convert(&mut self) {
if self.host.ip.is_none() {
self.host.ip = if self.network.enable_ipv6 {
Some(local_ipv6().unwrap())
} else {
local_ip().ok().or_else(|| local_ipv6().ok())
};
}
if self.upload.server.ip.is_none() {
self.upload.server.ip = if self.network.enable_ipv6 {
Some(Ipv6Addr::UNSPECIFIED.into())
} else {
Some(Ipv4Addr::UNSPECIFIED.into())
}
}
if self.storage.server.ip.is_none() {
self.storage.server.ip = if self.network.enable_ipv6 {
Some(Ipv6Addr::UNSPECIFIED.into())
} else {
Some(Ipv4Addr::UNSPECIFIED.into())
}
}
if self.health.server.ip.is_none() {
self.health.server.ip = if self.network.enable_ipv6 {
Some(Ipv6Addr::UNSPECIFIED.into())
} else {
Some(Ipv4Addr::UNSPECIFIED.into())
}
}
if self.metrics.server.ip.is_none() {
self.metrics.server.ip = if self.network.enable_ipv6 {
Some(Ipv6Addr::UNSPECIFIED.into())
} else {
Some(Ipv4Addr::UNSPECIFIED.into())
}
}
if self.stats.server.ip.is_none() {
self.stats.server.ip = if self.network.enable_ipv6 {
Some(Ipv6Addr::UNSPECIFIED.into())
} else {
Some(Ipv4Addr::UNSPECIFIED.into())
}
}
if self.proxy.server.ip.is_none() {
self.proxy.server.ip = if self.network.enable_ipv6 {
Some(Ipv6Addr::UNSPECIFIED.into())
} else {
Some(Ipv4Addr::UNSPECIFIED.into())
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::path::PathBuf;
use tempfile::NamedTempFile;
use tokio::fs;
#[test]
fn deserialize_server_correctly() {
let json_data = r#"
{
"pluginDir": "/custom/plugin/dir",
"cacheDir": "/custom/cache/dir"
}"#;
let server: Server = serde_json::from_str(json_data).unwrap();
assert_eq!(server.plugin_dir, PathBuf::from("/custom/plugin/dir"));
assert_eq!(server.cache_dir, PathBuf::from("/custom/cache/dir"));
}
#[test]
fn deserialize_download_correctly() {
let json_data = r#"
{
"server": {
"socketPath": "/var/run/dragonfly/dfdaemon.sock",
"requestRateLimit": 4000
},
"protocol": "quic",
"bandwidthLimit": "50GB",
"pieceTimeout": "30s",
"concurrentPieceCount": 10
}"#;
let download: Download = serde_json::from_str(json_data).unwrap();
assert_eq!(
download.server.socket_path,
PathBuf::from("/var/run/dragonfly/dfdaemon.sock")
);
assert_eq!(download.server.request_rate_limit, 4000);
assert_eq!(download.protocol, "quic".to_string());
assert_eq!(download.bandwidth_limit, ByteSize::gb(50));
assert_eq!(download.piece_timeout, Duration::from_secs(30));
assert_eq!(download.concurrent_piece_count, 10);
}
#[test]
fn deserialize_upload_correctly() {
let json_data = r#"
{
"server": {
"port": 4000,
"ip": "127.0.0.1",
"caCert": "/etc/ssl/certs/ca.crt",
"cert": "/etc/ssl/certs/server.crt",
"key": "/etc/ssl/private/server.pem"
},
"client": {
"caCert": "/etc/ssl/certs/ca.crt",
"cert": "/etc/ssl/certs/client.crt",
"key": "/etc/ssl/private/client.pem"
},
"disableShared": false,
"bandwidthLimit": "10GB"
}"#;
let upload: Upload = serde_json::from_str(json_data).unwrap();
assert_eq!(upload.server.port, 4000);
assert_eq!(
upload.server.ip,
Some("127.0.0.1".parse::<IpAddr>().unwrap())
);
assert_eq!(
upload.server.ca_cert,
Some(PathBuf::from("/etc/ssl/certs/ca.crt"))
);
assert_eq!(
upload.server.cert,
Some(PathBuf::from("/etc/ssl/certs/server.crt"))
);
assert_eq!(
upload.server.key,
Some(PathBuf::from("/etc/ssl/private/server.pem"))
);
assert_eq!(
upload.client.ca_cert,
Some(PathBuf::from("/etc/ssl/certs/ca.crt"))
);
assert_eq!(
upload.client.cert,
Some(PathBuf::from("/etc/ssl/certs/client.crt"))
);
assert_eq!(
upload.client.key,
Some(PathBuf::from("/etc/ssl/private/client.pem"))
);
assert!(!upload.disable_shared);
assert_eq!(upload.bandwidth_limit, ByteSize::gb(10));
}
#[test]
fn upload_server_default() {
let server = UploadServer::default();
assert!(server.ip.is_none());
assert_eq!(server.port, default_upload_grpc_server_port());
assert!(server.ca_cert.is_none());
assert!(server.cert.is_none());
assert!(server.key.is_none());
assert_eq!(
server.request_rate_limit,
default_upload_request_rate_limit()
);
}
#[tokio::test]
async fn upload_load_server_tls_config_success() {
let (ca_file, cert_file, key_file) = create_temp_certs().await;
let server = UploadServer {
ca_cert: Some(ca_file.path().to_path_buf()),
cert: Some(cert_file.path().to_path_buf()),
key: Some(key_file.path().to_path_buf()),
..Default::default()
};
let tls_config = server.load_server_tls_config().await.unwrap();
assert!(tls_config.is_some());
}
#[tokio::test]
async fn load_server_tls_config_missing_certs() {
let server = UploadServer {
ca_cert: Some(PathBuf::from("/invalid/path")),
cert: None,
key: None,
..Default::default()
};
let tls_config = server.load_server_tls_config().await.unwrap();
assert!(tls_config.is_none());
}
#[test]
fn upload_client_default() {
let client = UploadClient::default();
assert!(client.ca_cert.is_none());
assert!(client.cert.is_none());
assert!(client.key.is_none());
}
#[tokio::test]
async fn upload_client_load_tls_config_success() {
let (ca_file, cert_file, key_file) = create_temp_certs().await;
let client = UploadClient {
ca_cert: Some(ca_file.path().to_path_buf()),
cert: Some(cert_file.path().to_path_buf()),
key: Some(key_file.path().to_path_buf()),
};
let tls_config = client.load_client_tls_config("example.com").await.unwrap();
assert!(tls_config.is_some());
let cfg_string = format!("{:?}", tls_config.unwrap());
assert!(
cfg_string.contains("example.com"),
"Domain name not found in TLS config"
);
}
#[tokio::test]
async fn upload_server_load_tls_config_invalid_path() {
let server = UploadServer {
ca_cert: Some(PathBuf::from("/invalid/ca.crt")),
cert: Some(PathBuf::from("/invalid/server.crt")),
key: Some(PathBuf::from("/invalid/server.key")),
..Default::default()
};
let result = server.load_server_tls_config().await;
assert!(result.is_err());
}
async fn create_temp_certs() -> (NamedTempFile, NamedTempFile, NamedTempFile) {
let ca = NamedTempFile::new().unwrap();
let cert = NamedTempFile::new().unwrap();
let key = NamedTempFile::new().unwrap();
fs::write(ca.path(), "-----BEGIN CERT-----\n...\n-----END CERT-----\n")
.await
.unwrap();
fs::write(
cert.path(),
"-----BEGIN CERT-----\n...\n-----END CERT-----\n",
)
.await
.unwrap();
fs::write(
key.path(),
"-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----\n",
)
.await
.unwrap();
(ca, cert, key)
}
/// The manager's client TLS loader returns `Ok(Some(_))` when all three
/// configured files exist.
///
/// The files hold plain text rather than PEM. NOTE(review): this presumably
/// relies on the loader not parsing them eagerly; confirm.
#[tokio::test]
async fn manager_load_client_tls_config_success() {
    let temp_dir = tempfile::TempDir::new().unwrap();
    let base = temp_dir.path();
    let (ca_path, cert_path, key_path) = (
        base.join("ca.crt"),
        base.join("client.crt"),
        base.join("client.key"),
    );

    for (path, body) in [
        (&ca_path, "CA cert content"),
        (&cert_path, "Client cert content"),
        (&key_path, "Client key content"),
    ] {
        fs::write(path, body).await.unwrap();
    }

    let manager = Manager {
        addr: "http://example.com".to_string(),
        ca_cert: Some(ca_path),
        cert: Some(cert_path),
        key: Some(key_path),
    };

    let outcome = manager.load_client_tls_config("example.com").await;
    assert!(outcome.is_ok());
    assert!(outcome.unwrap().is_some());
}
/// A manager YAML document with only `addr` leaves every TLS path unset.
#[test]
fn deserialize_optional_fields_correctly() {
    let source = r#"
addr: http://another-service:8080
"#;
    let Manager {
        addr,
        ca_cert,
        cert,
        key,
    } = serde_yaml::from_str::<Manager>(source).unwrap();

    assert_eq!(addr, "http://another-service:8080");
    assert!(ca_cert.is_none());
    assert!(cert.is_none());
    assert!(key.is_none());
}
/// A fully populated manager YAML document (camelCase `caCert`) maps onto
/// every `Manager` field.
#[test]
fn deserialize_manager_correctly() {
    let source = r#"
addr: http://manager-service:65003
caCert: /etc/ssl/certs/ca.crt
cert: /etc/ssl/certs/client.crt
key: /etc/ssl/private/client.pem
"#;
    let Manager {
        addr,
        ca_cert,
        cert,
        key,
    } = serde_yaml::from_str::<Manager>(source).expect("Failed to deserialize");

    let path = |raw: &str| Some(PathBuf::from(raw));
    assert_eq!(addr, "http://manager-service:65003");
    assert_eq!(ca_cert, path("/etc/ssl/certs/ca.crt"));
    assert_eq!(cert, path("/etc/ssl/certs/client.crt"));
    assert_eq!(key, path("/etc/ssl/private/client.pem"));
}
/// `HostType` displays as lowercase text, and its `Default` is `Super`.
#[test]
fn default_host_type_correctly() {
    for (variant, expected) in [(HostType::Normal, "normal"), (HostType::Super, "super")] {
        assert_eq!(variant.to_string(), expected);
    }

    assert_eq!(HostType::default(), HostType::Super);
}
/// Despite the name, this checks *deserialization*: lowercase JSON strings
/// map onto the matching `HostType` variants.
#[test]
fn serialize_host_type_correctly() {
    let decode = |raw: &str| serde_json::from_str::<HostType>(raw).unwrap();

    assert_eq!(decode("\"normal\""), HostType::Normal);
    assert_eq!(decode("\"super\""), HostType::Super);
}
/// `HostType` serializes to a lowercase JSON string.
#[test]
fn serialize_host_type() {
    let cases = [
        (HostType::Normal, "\"normal\""),
        (HostType::Super, "\"super\""),
    ];

    for (variant, expected_json) in &cases {
        let encoded = serde_json::to_string(variant).unwrap();
        assert_eq!(encoded, *expected_json);
    }
}
/// A default `SeedPeer` is disabled and of kind `Normal`.
///
/// This differs from `HostType::default()`, which is `Super`.
#[test]
fn default_seed_peer() {
    let SeedPeer { enable, kind } = SeedPeer::default();
    assert!(!enable);
    assert_eq!(kind, HostType::Normal);
}
/// An enabled `Super` seed peer passes validation.
#[test]
fn validate_seed_peer() {
    let peer = SeedPeer {
        kind: HostType::Super,
        enable: true,
    };

    let verdict = peer.validate();
    assert!(verdict.is_ok());
}
/// Seed-peer JSON using the `type` key for `kind` deserializes correctly.
///
/// `clusterID` and `keepaliveInterval` have no corresponding `SeedPeer`
/// field as constructed in this file. They are presumably ignored by serde
/// (NOTE(review): confirm that `SeedPeer` is not `deny_unknown_fields`).
#[test]
fn deserialize_seed_peer_correctly() {
    let json_data = r#"
{
"enable": true,
"type": "super",
"clusterID": 2,
"keepaliveInterval": "60s"
}"#;
    let SeedPeer { enable, kind } = serde_json::from_str::<SeedPeer>(json_data).unwrap();

    assert!(enable);
    assert_eq!(kind, HostType::Super);
}
/// The default dynconfig refresh interval is five minutes.
#[test]
fn default_dynconfig() {
    let refresh = Dynconfig::default().refresh_interval;
    assert_eq!(refresh, Duration::from_secs(5 * 60));
}
/// A humantime string (`"5m"`) for `refreshInterval` parses into a `Duration`.
#[test]
fn deserialize_dynconfig_correctly() {
    let json_data = r#"
{
"refreshInterval": "5m"
}"#;
    let parsed = serde_json::from_str::<Dynconfig>(json_data).unwrap();
    assert_eq!(parsed.refresh_interval, Duration::from_secs(5 * 60));
}
/// A full storage JSON document (server address, directory, flags, buffer
/// sizes, humantime timeout and human-readable capacity) deserializes
/// field-for-field.
#[test]
fn deserialize_storage_correctly() {
    let json_data = r#"
{
"server": {
"ip": "128.0.0.1",
"tcpPort": 4005,
"quicPort": 4006
},
"dir": "/tmp/storage",
"keep": true,
"writePieceTimeout": "20s",
"writeBufferSize": 8388608,
"readBufferSize": 8388608,
"cacheCapacity": "256MB"
}"#;
    let storage = serde_json::from_str::<Storage>(json_data).unwrap();

    let server = &storage.server;
    let server_ip = server.ip.unwrap();
    assert_eq!(server_ip.to_string(), "128.0.0.1");
    assert_eq!(server.tcp_port, 4005);
    assert_eq!(server.quic_port, 4006);

    assert_eq!(storage.dir, PathBuf::from("/tmp/storage"));
    assert!(storage.keep);
    assert_eq!(storage.write_piece_timeout, Duration::from_secs(20));

    // 8388608 bytes == 8 MiB.
    assert_eq!(storage.write_buffer_size, 8 << 20);
    assert_eq!(storage.read_buffer_size, 8 << 20);
    assert_eq!(storage.cache_capacity, ByteSize::mb(256));
}
/// A high disk-threshold percent of 90 validates, while 100 is rejected.
///
/// The exact upper bound enforced by the validator is not visible here.
#[test]
fn validate_policy() {
    // Every field except the high threshold is held fixed.
    let policy_with_high_percent = |disk_high_threshold_percent| Policy {
        task_ttl: Duration::from_secs(12 * 3600),
        persistent_task_ttl: Duration::from_secs(24 * 3600),
        persistent_cache_task_ttl: Duration::from_secs(48 * 3600),
        disk_threshold: ByteSize::mb(100),
        disk_high_threshold_percent,
        disk_low_threshold_percent: 70,
    };

    assert!(policy_with_high_percent(90).validate().is_ok());
    assert!(policy_with_high_percent(100).validate().is_err());
}
/// A GC JSON document deserializes its interval and every TTL and threshold
/// field of the policy.
///
/// The threshold keys must be the camelCase forms of the field names
/// (`diskHighThresholdPercent` / `diskLowThresholdPercent`). Serde silently
/// ignores unknown keys and falls back to defaults. The values used here are
/// therefore deliberately unusual, so that a misspelled key makes the
/// assertions fail instead of passing by coincidence.
#[test]
fn deserialize_gc_correctly() {
    let json_data = r#"
{
"interval": "1h",
"policy": {
"taskTTL": "12h",
"persistentTaskTTL": "24h",
"persistentCacheTaskTTL": "48h",
"diskHighThresholdPercent": 85,
"diskLowThresholdPercent": 65
}
}"#;
    let gc: GC = serde_json::from_str(json_data).unwrap();
    assert_eq!(gc.interval, Duration::from_secs(3600));
    assert_eq!(gc.policy.task_ttl, Duration::from_secs(12 * 3600));
    assert_eq!(
        gc.policy.persistent_task_ttl,
        Duration::from_secs(24 * 3600)
    );
    assert_eq!(
        gc.policy.persistent_cache_task_ttl,
        Duration::from_secs(48 * 3600)
    );
    assert_eq!(gc.policy.disk_high_threshold_percent, 85);
    assert_eq!(gc.policy.disk_low_threshold_percent, 65);
}
/// A full proxy JSON document deserializes correctly. It covers the server
/// (port, CA material, basic auth), one rule, the registry mirror, and the
/// top-level flags and sizes.
///
/// `customHeaders` is present in the input but is not asserted on.
#[test]
fn deserialize_proxy_correctly() {
    let json_data = r#"
{
"server": {
"port": 8080,
"caCert": "/path/to/ca_cert.pem",
"caKey": "/path/to/ca_key.pem",
"basicAuth": {
"username": "admin",
"password": "password"
}
},
"rules": [
{
"regex": "^https?://example\\.com/.*$",
"useTLS": true,
"redirect": "https://mirror.example.com",
"filteredQueryParams": ["Signature", "Expires"]
}
],
"registryMirror": {
"enableTaskIDBasedBlobDigest": true,
"addr": "https://mirror.example.com",
"cert": "/path/to/cert.pem"
},
"disableBackToSource": true,
"prefetch": true,
"prefetchBandwidthLimit": "1GB",
"readBufferSize": 8388608,
"customHeaders": {
"X-Custom-Header": "custom-value"
}
}"#;
    let proxy = serde_json::from_str::<Proxy>(json_data).unwrap();

    // Server section.
    let server = &proxy.server;
    assert_eq!(server.port, 8080);
    assert_eq!(server.ca_cert, Some(PathBuf::from("/path/to/ca_cert.pem")));
    assert_eq!(server.ca_key, Some(PathBuf::from("/path/to/ca_key.pem")));
    let auth = server.basic_auth.as_ref().unwrap();
    assert_eq!(auth.username, "admin");
    assert_eq!(auth.password, "password");

    // First (and only) rule.
    let rule = proxy.rules.as_ref().and_then(|rules| rules.first()).unwrap();
    assert_eq!(rule.regex.as_str(), "^https?://example\\.com/.*$");
    assert!(rule.use_tls);
    assert_eq!(rule.redirect.as_deref(), Some("https://mirror.example.com"));
    assert_eq!(rule.filtered_query_params, vec!["Signature", "Expires"]);

    // Registry mirror section.
    let mirror = &proxy.registry_mirror;
    assert!(mirror.enable_task_id_based_blob_digest);
    assert_eq!(mirror.addr, "https://mirror.example.com");
    assert_eq!(mirror.cert, Some(PathBuf::from("/path/to/cert.pem")));

    // Top-level flags and sizes.
    assert!(proxy.disable_back_to_source);
    assert!(proxy.prefetch);
    assert_eq!(proxy.prefetch_bandwidth_limit, ByteSize::gb(1));
    assert_eq!(proxy.read_buffer_size, 8 << 20);
}
/// A tracing JSON document fills protocol, endpoint, path and headers.
#[test]
fn deserialize_tracing_correctly() {
    let json_data = r#"
{
"protocol": "http",
"endpoint": "tracing.example.com",
"path": "/v1/traces",
"headers": {
"X-Custom-Header": "value"
}
}"#;
    // Named `cfg` to avoid shadowing the `tracing` crate imported by this file.
    let cfg = serde_json::from_str::<Tracing>(json_data).unwrap();

    assert_eq!(cfg.protocol.as_deref(), Some("http"));
    assert_eq!(cfg.endpoint.as_deref(), Some("tracing.example.com"));
    assert_eq!(cfg.path, Some(PathBuf::from("/v1/traces")));
    assert!(cfg.headers.contains_key("X-Custom-Header"));
}
/// A metrics JSON document deserializes the server port and IP address.
#[test]
fn deserialize_metrics_correctly() {
    let json_data = r#"
{
"server": {
"port": 4002,
"ip": "127.0.0.1"
}
}"#;
    let server = serde_json::from_str::<Metrics>(json_data).unwrap().server;

    assert_eq!(server.port, 4002);
    assert_eq!(server.ip, Some(IpAddr::from([127, 0, 0, 1])));
}
/// A backend JSON document deserializes every field it sets: request
/// headers, redirect caching, chunked-put tuning and the DNS resolver toggle.
#[test]
fn deserialize_backend_correctly() {
    let json_data = r#"
{
"requestHeader": {
"X-Custom-Header": "value"
},
"enableCacheTemporaryRedirect": false,
"cacheTemporaryRedirectTTL": "15m",
"putConcurrentChunkCount": 2,
"putChunkSize": "2mib",
"putTimeout": "1m",
"enableHickoryDNS": false
}"#;
    let backend = serde_json::from_str::<Backend>(json_data).unwrap();

    assert!(backend.request_header.is_some());
    let headers = backend.request_header.as_ref().unwrap();
    assert_eq!(
        headers.get("X-Custom-Header"),
        Some(&String::from("value"))
    );

    assert!(!backend.enable_cache_temporary_redirect);
    assert_eq!(
        backend.cache_temporary_redirect_ttl,
        Duration::from_secs(15 * 60)
    );

    assert_eq!(backend.put_concurrent_chunk_count, 2);
    assert_eq!(backend.put_chunk_size, ByteSize::mib(2));
    assert_eq!(backend.put_timeout, Duration::from_secs(60));
    assert!(!backend.enable_hickory_dns);
}
}