use bytes::{Bytes, BytesMut};
use dragonfly_client_config::dfdaemon::Config;
use dragonfly_client_core::{
error::{ErrorType, OrErr},
Error as ClientError, Result as ClientResult,
};
use quinn::crypto::rustls::QuicClientConfig;
use quinn::{AckFrequencyConfig, ClientConfig, Endpoint, RecvStream, SendStream, TransportConfig};
use rustls_pki_types::{CertificateDer, ServerName, UnixTime};
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::AsyncRead;
use tokio::time;
use tracing::{error, instrument, Span};
use vortex_protocol::{
tlv::{
download_persistent_cache_piece::DownloadPersistentCachePiece,
download_persistent_piece::DownloadPersistentPiece, download_piece::DownloadPiece,
error::Error as VortexError, persistent_cache_piece_content, persistent_piece_content,
piece_content, Tag,
},
Header, Vortex, HEADER_SIZE,
};
/// QUIC client that fetches pieces from a single parent peer using the Vortex protocol.
///
/// Cloning shares the dfdaemon configuration through its `Arc` and copies the address string.
#[derive(Clone)]
pub struct QUICClient {
    /// Shared dfdaemon configuration; supplies the piece timeout and the local bind IP.
    config: Arc<Config>,

    /// Address of the parent peer, parsed as a `SocketAddr` each time a connection is made.
    addr: String,
}
impl QUICClient {
    /// Creates a client that downloads pieces from the parent at `addr`.
    ///
    /// `addr` is not validated here; it is parsed on every request, so a malformed address
    /// surfaces as a `ParseError` when a download is attempted.
    pub fn new(config: Arc<Config>, addr: String) -> Self {
        Self { config, addr }
    }

    /// Downloads piece `number` of task `task_id` from the parent.
    ///
    /// On success returns the receive stream positioned at the start of the piece content,
    /// followed by the piece offset and digest taken from the piece metadata. The whole
    /// exchange (connect, request, header and metadata read) is bounded by
    /// `config.download.piece_timeout`.
    ///
    /// # Errors
    ///
    /// Fails on timeout, connection or stream I/O failure, a protocol error reported by the
    /// parent (`ClientError::VortexProtocolStatus`), or an unexpected response tag.
    #[instrument(skip_all, fields(parent_addr))]
    pub async fn download_piece(
        &self,
        number: u32,
        task_id: &str,
    ) -> ClientResult<(impl AsyncRead, u64, String)> {
        Span::current().record("parent_addr", self.addr.as_str());
        time::timeout(
            self.config.download.piece_timeout,
            self.handle_download_piece(number, task_id),
        )
        .await
        .inspect_err(|err| {
            error!("connect timeout to {}: {}", self.addr, err);
        })?
    }

    /// Performs the un-timed `DownloadPiece` request/response exchange.
    #[instrument(skip_all)]
    async fn handle_download_piece(
        &self,
        number: u32,
        task_id: &str,
    ) -> ClientResult<(impl AsyncRead, u64, String)> {
        let request: Bytes = Vortex::DownloadPiece(
            Header::new_download_piece(),
            DownloadPiece::new(task_id.to_string(), number),
        )
        .into();

        // The send half is kept alive until this function returns.
        let (mut reader, _writer) = self.connect_and_write_request(request).await?;
        let header = self.read_header(&mut reader).await?;
        match header.tag() {
            Tag::PieceContent => {
                let piece_content: piece_content::PieceContent = self
                    .read_piece_content(&mut reader, piece_content::METADATA_LENGTH_SIZE)
                    .await?;
                let metadata = piece_content.metadata();
                Ok((reader, metadata.offset, metadata.digest))
            }
            Tag::Error => Err(self.read_error(&mut reader, header.length() as usize).await),
            _ => Err(ClientError::Unknown(format!(
                "unexpected tag: {:?}",
                header.tag()
            ))),
        }
    }

    /// Downloads piece `number` of persistent task `task_id` from the parent.
    ///
    /// Same contract as [`QUICClient::download_piece`], but issues a
    /// `DownloadPersistentPiece` request and expects a `PersistentPieceContent` response.
    ///
    /// # Errors
    ///
    /// Fails on timeout, connection or stream I/O failure, a protocol error reported by the
    /// parent, or an unexpected response tag.
    #[instrument(skip_all, fields(parent_addr))]
    pub async fn download_persistent_piece(
        &self,
        number: u32,
        task_id: &str,
    ) -> ClientResult<(impl AsyncRead, u64, String)> {
        Span::current().record("parent_addr", self.addr.as_str());
        time::timeout(
            self.config.download.piece_timeout,
            self.handle_download_persistent_piece(number, task_id),
        )
        .await
        .inspect_err(|err| {
            error!("connect timeout to {}: {}", self.addr, err);
        })?
    }

    /// Performs the un-timed `DownloadPersistentPiece` request/response exchange.
    #[instrument(skip_all)]
    async fn handle_download_persistent_piece(
        &self,
        number: u32,
        task_id: &str,
    ) -> ClientResult<(impl AsyncRead, u64, String)> {
        let request: Bytes = Vortex::DownloadPersistentPiece(
            Header::new_download_persistent_piece(),
            DownloadPersistentPiece::new(task_id.to_string(), number),
        )
        .into();

        // The send half is kept alive until this function returns.
        let (mut reader, _writer) = self.connect_and_write_request(request).await?;
        let header = self.read_header(&mut reader).await?;
        match header.tag() {
            Tag::PersistentPieceContent => {
                let persistent_piece_content: persistent_piece_content::PersistentPieceContent =
                    self.read_piece_content(
                        &mut reader,
                        persistent_piece_content::METADATA_LENGTH_SIZE,
                    )
                    .await?;
                let metadata = persistent_piece_content.metadata();
                Ok((reader, metadata.offset, metadata.digest))
            }
            Tag::Error => Err(self.read_error(&mut reader, header.length() as usize).await),
            _ => Err(ClientError::Unknown(format!(
                "unexpected tag: {:?}",
                header.tag()
            ))),
        }
    }

    /// Downloads piece `number` of persistent cache task `task_id` from the parent.
    ///
    /// Same contract as [`QUICClient::download_piece`], but issues a
    /// `DownloadPersistentCachePiece` request and expects a `PersistentCachePieceContent`
    /// response.
    ///
    /// # Errors
    ///
    /// Fails on timeout, connection or stream I/O failure, a protocol error reported by the
    /// parent, or an unexpected response tag.
    #[instrument(skip_all, fields(parent_addr))]
    pub async fn download_persistent_cache_piece(
        &self,
        number: u32,
        task_id: &str,
    ) -> ClientResult<(impl AsyncRead, u64, String)> {
        Span::current().record("parent_addr", self.addr.as_str());
        time::timeout(
            self.config.download.piece_timeout,
            self.handle_download_persistent_cache_piece(number, task_id),
        )
        .await
        .inspect_err(|err| {
            error!("connect timeout to {}: {}", self.addr, err);
        })?
    }

    /// Performs the un-timed `DownloadPersistentCachePiece` request/response exchange.
    #[instrument(skip_all)]
    async fn handle_download_persistent_cache_piece(
        &self,
        number: u32,
        task_id: &str,
    ) -> ClientResult<(impl AsyncRead, u64, String)> {
        let request: Bytes = Vortex::DownloadPersistentCachePiece(
            Header::new_download_persistent_cache_piece(),
            DownloadPersistentCachePiece::new(task_id.to_string(), number),
        )
        .into();

        // The send half is kept alive until this function returns.
        let (mut reader, _writer) = self.connect_and_write_request(request).await?;
        let header = self.read_header(&mut reader).await?;
        match header.tag() {
            Tag::PersistentCachePieceContent => {
                let persistent_cache_piece_content: persistent_cache_piece_content::PersistentCachePieceContent =
                    self.read_piece_content(
                        &mut reader,
                        persistent_cache_piece_content::METADATA_LENGTH_SIZE,
                    )
                    .await?;
                let metadata = persistent_cache_piece_content.metadata();
                Ok((reader, metadata.offset, metadata.digest))
            }
            Tag::Error => Err(self.read_error(&mut reader, header.length() as usize).await),
            _ => Err(ClientError::Unknown(format!(
                "unexpected tag: {:?}",
                header.tag()
            ))),
        }
    }

    /// Opens a fresh QUIC connection to the parent, opens a bidirectional stream on it and
    /// writes `request` to that stream.
    ///
    /// Returns the receive and send halves of the stream.
    ///
    /// # Errors
    ///
    /// Fails in these cases:
    /// - The TLS config cannot be built.
    /// - The local bind IP (`storage.server.ip`) is not configured.
    /// - The UDP socket cannot be bound.
    /// - `addr` does not parse.
    /// - The connection, the stream, or the write fails.
    #[instrument(skip_all)]
    async fn connect_and_write_request(
        &self,
        request: Bytes,
    ) -> ClientResult<(RecvStream, SendStream)> {
        // Server certificates are accepted unconditionally (`NoVerifier`), so TLS here gives
        // transport encryption but does not authenticate the parent.
        let crypto = QuicClientConfig::try_from(
            quinn::rustls::ClientConfig::builder()
                .dangerous()
                .with_custom_certificate_verifier(NoVerifier::new())
                .with_no_client_auth(),
        )
        .map_err(|err| {
            ClientError::Unknown(format!("failed to create quic client config: {}", err))
        })?;
        let mut client_config = ClientConfig::new(Arc::new(crypto));

        let mut transport = TransportConfig::default();
        transport.keep_alive_interval(Some(super::DEFAULT_KEEPALIVE_INTERVAL));
        transport.max_idle_timeout(Some(
            super::DEFAULT_MAX_IDLE_TIMEOUT
                .try_into()
                .expect("DEFAULT_MAX_IDLE_TIMEOUT must fit in a QUIC idle timeout"),
        ));
        transport.ack_frequency_config(Some(AckFrequencyConfig::default()));
        transport.send_window(super::DEFAULT_SEND_BUFFER_SIZE as u64);
        transport.receive_window((super::DEFAULT_RECV_BUFFER_SIZE as u32).into());
        transport.stream_receive_window((super::DEFAULT_RECV_BUFFER_SIZE as u32).into());
        client_config.transport_config(Arc::new(transport));

        // Bind the local UDP socket to the configured storage server IP and let the OS pick
        // the port. A missing IP is reported as an error instead of panicking.
        let local_ip = self.config.storage.server.ip.ok_or_else(|| {
            ClientError::Unknown("storage server ip is not configured".to_string())
        })?;
        let mut endpoint = Endpoint::client(SocketAddr::new(local_ip, 0))?;
        endpoint.set_default_client_config(client_config);

        let connection = endpoint
            .connect(self.addr.parse().or_err(ErrorType::ParseError)?, "d7y")
            .or_err(ErrorType::ConnectError)?
            .await
            .inspect_err(|err| error!("failed to connect to {}: {}", self.addr, err))
            .or_err(ErrorType::ConnectError)?;

        let (mut writer, reader) = connection
            .open_bi()
            .await
            .inspect_err(|err| error!("failed to open bi stream: {}", err))
            .or_err(ErrorType::ConnectError)?;

        writer
            .write_all(&request)
            .await
            .inspect_err(|err| error!("failed to send request: {}", err))
            .or_err(ErrorType::ConnectError)?;

        Ok((reader, writer))
    }

    /// Reads exactly `HEADER_SIZE` bytes from `reader` and decodes them as a Vortex header.
    #[instrument(skip_all)]
    async fn read_header(&self, reader: &mut RecvStream) -> ClientResult<Header> {
        let mut header_bytes = BytesMut::zeroed(HEADER_SIZE);
        reader
            .read_exact(&mut header_bytes)
            .await
            .inspect_err(|err| error!("failed to receive header: {}", err))
            .or_err(ErrorType::ConnectError)?;
        Header::try_from(header_bytes.freeze()).map_err(Into::into)
    }

    /// Reads a length-prefixed metadata frame from `reader` and decodes it as `T`.
    ///
    /// The frame consists of a big-endian `u32` length prefix (`metadata_length_size` bytes)
    /// followed by that many metadata bytes. `T` is decoded from the prefix and the metadata
    /// together. The piece payload that follows is left unread on `reader`.
    ///
    /// # Errors
    ///
    /// Fails on stream read errors, when `metadata_length_size` is not 4, or when `T`
    /// rejects the bytes.
    #[instrument(skip_all)]
    async fn read_piece_content<T>(
        &self,
        reader: &mut RecvStream,
        metadata_length_size: usize,
    ) -> ClientResult<T>
    where
        T: TryFrom<Bytes, Error: Into<ClientError>>,
    {
        let mut metadata_length_bytes = BytesMut::zeroed(metadata_length_size);
        reader
            .read_exact(&mut metadata_length_bytes)
            .await
            .inspect_err(|err| error!("failed to receive metadata length: {}", err))
            .or_err(ErrorType::ConnectError)?;
        // NOTE(review): this length comes from the peer and is not bounded before allocating.
        let metadata_length = u32::from_be_bytes(metadata_length_bytes[..].try_into()?) as usize;

        // Read the metadata directly behind the length prefix in one buffer. This avoids a
        // second allocation and copy before decoding.
        let mut content_bytes = BytesMut::zeroed(metadata_length_size + metadata_length);
        content_bytes[..metadata_length_size].copy_from_slice(&metadata_length_bytes);
        reader
            .read_exact(&mut content_bytes[metadata_length_size..])
            .await
            .inspect_err(|err| error!("failed to receive metadata: {}", err))
            .or_err(ErrorType::ConnectError)?;

        content_bytes.freeze().try_into().map_err(Into::into)
    }

    /// Reads an error body of `header_length` bytes and converts it into a `ClientError`.
    ///
    /// Never fails itself. A read or decode failure becomes a `ClientError::Unknown`. A
    /// successfully decoded body becomes `ClientError::VortexProtocolStatus`.
    #[instrument(skip_all)]
    async fn read_error(&self, reader: &mut RecvStream, header_length: usize) -> ClientError {
        let mut error_bytes = BytesMut::zeroed(header_length);
        if let Err(err) = reader.read_exact(&mut error_bytes).await {
            error!("failed to receive error: {}", err);
            return ClientError::Unknown(err.to_string());
        }

        error_bytes
            .freeze()
            .try_into()
            .map(|error: VortexError| {
                ClientError::VortexProtocolStatus(error.code(), error.message().to_string())
            })
            .unwrap_or_else(|err| {
                error!("failed to extract error: {}", err);
                ClientError::Unknown(format!("failed to extract error: {}", err))
            })
    }
}
/// rustls `ServerCertVerifier` that accepts every server certificate.
///
/// It wraps a `ring`-backed `CryptoProvider`. That provider's signature verification
/// algorithms are used to check handshake signatures and to advertise supported schemes.
#[derive(Debug)]
pub struct NoVerifier(Arc<quinn::rustls::crypto::CryptoProvider>);

impl NoVerifier {
    /// Builds a verifier on top of rustls' `ring` default provider.
    ///
    /// The verifier is returned already wrapped in an `Arc`, the form that
    /// `with_custom_certificate_verifier` takes.
    pub fn new() -> Arc<Self> {
        let provider = Arc::new(quinn::rustls::crypto::ring::default_provider());
        Arc::new(Self(provider))
    }
}
/// Certificate checks for the peer-to-peer QUIC transport.
///
/// The server certificate is never validated: chain, server name and validity period are
/// all ignored, so the peer is not authenticated. TLS 1.2 and 1.3 handshake signatures are
/// still verified with the wrapped provider's algorithms.
impl quinn::rustls::client::danger::ServerCertVerifier for NoVerifier {
    fn verify_server_cert(
        &self,
        _end_entity: &CertificateDer<'_>,
        _intermediates: &[CertificateDer<'_>],
        _server_name: &ServerName<'_>,
        _ocsp: &[u8],
        _now: UnixTime,
    ) -> Result<quinn::rustls::client::danger::ServerCertVerified, quinn::rustls::Error> {
        // Deliberately trust any certificate; no inputs are inspected.
        Ok(quinn::rustls::client::danger::ServerCertVerified::assertion())
    }

    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &quinn::rustls::DigitallySignedStruct,
    ) -> Result<quinn::rustls::client::danger::HandshakeSignatureValid, quinn::rustls::Error> {
        let algorithms = &self.0.signature_verification_algorithms;
        quinn::rustls::crypto::verify_tls12_signature(message, cert, dss, algorithms)
    }

    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &quinn::rustls::DigitallySignedStruct,
    ) -> Result<quinn::rustls::client::danger::HandshakeSignatureValid, quinn::rustls::Error> {
        let algorithms = &self.0.signature_verification_algorithms;
        quinn::rustls::crypto::verify_tls13_signature(message, cert, dss, algorithms)
    }

    fn supported_verify_schemes(&self) -> Vec<quinn::rustls::SignatureScheme> {
        let algorithms = &self.0.signature_verification_algorithms;
        algorithms.supported_schemes()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use quinn::rustls::client::danger::ServerCertVerifier;

    /// `NoVerifier` must accept even an empty certificate and still expose the signature
    /// schemes of its crypto provider.
    #[test]
    fn test_no_verifier() {
        let verifier = NoVerifier::new();

        let empty_cert = CertificateDer::from(vec![]);
        let server_name = ServerName::DnsName("d7y.io".try_into().unwrap());
        let verdict =
            verifier.verify_server_cert(&empty_cert, &[], &server_name, &[], UnixTime::now());
        assert!(verdict.is_ok());

        assert!(!verifier.supported_verify_schemes().is_empty());
    }
}