use std::path::{Path, PathBuf};
use bytes::Bytes;
use futures::{prelude::*, TryStreamExt};
use reqwest::{Client, Response};
use serde::{Deserialize, Serialize};
use tokio::{
fs::{DirBuilder, File},
io::AsyncWriteExt,
};
use url::Url;
use crate::ddi::client::Error;
use crate::ddi::common::{Execution, Finished, Link};
use crate::ddi::feedback::Feedback;
/// A pending request for update details: holds the HTTP client and the URL
/// of the update resource, but performs no I/O until [`UpdatePreFetch::fetch`]
/// is called.
#[derive(Debug)]
pub struct UpdatePreFetch {
    // HTTP client used to issue the fetch request.
    client: Client,
    // URL of the update resource on the server.
    url: String,
}
impl UpdatePreFetch {
    /// Wrap a client and the URL of an update resource; no request is sent yet.
    pub(crate) fn new(client: Client, url: String) -> Self {
        Self { client, url }
    }

    /// Retrieve the update details from the server.
    ///
    /// Issues a GET on the stored URL, fails on a non-success HTTP status,
    /// and deserializes the JSON body into the internal reply structure.
    pub async fn fetch(self) -> Result<Update, Error> {
        let response = self.client.get(&self.url).send().await?;
        response.error_for_status_ref()?;
        let info: Reply = response.json().await?;
        Ok(Update::new(self.client, info, self.url))
    }
}
/// Top-level JSON reply describing a pending deployment.
#[derive(Debug, Deserialize)]
struct Reply {
    // Server-side identifier of this action; echoed back when sending feedback.
    id: String,
    deployment: Deployment,
    #[serde(rename = "actionHistory")]
    action_history: Option<ActionHistory>,
}
/// The `deployment` object of a reply: how to handle the download and
/// update phases, plus the software chunks to install.
#[derive(Debug, Deserialize)]
struct Deployment {
    // Requested handling of the download phase.
    download: Type,
    // Requested handling of the update phase.
    update: Type,
    #[serde(rename = "maintenanceWindow")]
    maintenance_window: Option<MaintenanceWindow>,
    chunks: Vec<ChunkInternal>,
}
/// How the server asks the target to handle a phase (download or update).
/// Serialized in lowercase: `"skip"`, `"attempt"`, `"forced"`.
#[derive(Debug, Deserialize, Serialize, Copy, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Type {
    Skip,
    Attempt,
    Forced,
}
/// Maintenance-window state reported by the server.
/// Serialized in lowercase: `"available"`, `"unavailable"`.
#[derive(Debug, Deserialize, Serialize, Copy, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum MaintenanceWindow {
    Available,
    Unavailable,
}
/// One software chunk of a deployment, as received from the server.
#[derive(Debug, Deserialize)]
struct ChunkInternal {
    // Defaults to an empty vector when the field is absent from the JSON.
    #[serde(default)]
    metadata: Vec<Metadata>,
    part: String,
    name: String,
    version: String,
    artifacts: Vec<ArtifactInternal>,
}
/// A single key/value metadata entry attached to a chunk.
#[derive(Debug, Deserialize)]
struct Metadata {
    key: String,
    value: String,
}
/// One downloadable artifact of a chunk, as received from the server.
#[derive(Debug, Deserialize)]
struct ArtifactInternal {
    filename: String,
    // Expected digests used for post-download verification.
    hashes: Hashes,
    // Size as reported by the server (presumably bytes — not verified locally here).
    size: u32,
    #[serde(rename = "_links")]
    links: Links,
}
/// Expected digests of an artifact. The checksum code compares these
/// against a lowercase-hex rendering of the computed digest.
#[derive(Debug, Deserialize, Clone)]
struct Hashes {
    sha1: String,
    md5: String,
    sha256: String,
}
/// Download links of an artifact.
#[derive(Debug, Deserialize)]
struct Links {
    // HTTP URL the artifact bytes are fetched from.
    #[serde(rename = "download-http")]
    download_http: Link,
    // Deserialized but not read anywhere in this file.
    #[serde(rename = "md5sum-http")]
    md5sum_http: Link,
}
/// Optional `actionHistory` object of a reply.
/// NOTE(review): deserialized but not exposed by any accessor in this file.
#[derive(Debug, Deserialize)]
struct ActionHistory {
    status: String,
    // Defaults to an empty vector when absent from the JSON.
    #[serde(default)]
    messages: Vec<String>,
}
/// A fetched update: the deserialized server reply plus the client and URL
/// needed to download artifacts and send feedback.
#[derive(Debug)]
pub struct Update {
    client: Client,
    info: Reply,
    // URL of the update resource; feedback is posted to `<url>/feedback`.
    url: String,
}
impl Update {
fn new(client: Client, info: Reply, url: String) -> Self {
Self { client, info, url }
}
pub fn download_type(&self) -> Type {
self.info.deployment.download
}
pub fn update_type(&self) -> Type {
self.info.deployment.update
}
pub fn maintenance_window(&self) -> Option<MaintenanceWindow> {
self.info.deployment.maintenance_window
}
pub fn chunks(&self) -> impl Iterator<Item = Chunk> {
let client = self.client.clone();
self.info
.deployment
.chunks
.iter()
.map(move |c| Chunk::new(c, client.clone()))
}
pub async fn download(&self, dir: &Path) -> Result<Vec<DownloadedArtifact>, Error> {
let mut result = Vec::new();
for c in self.chunks() {
let downloaded = c.download(dir).await?;
result.extend(downloaded);
}
Ok(result)
}
async fn send_feedback_internal<T: Serialize>(
&self,
execution: Execution,
finished: Finished,
progress: Option<T>,
details: Vec<&str>,
) -> Result<(), Error> {
let mut url: Url = self.url.parse()?;
{
match url.path_segments_mut() {
Err(_) => {
return Err(Error::ParseUrlError(
url::ParseError::SetHostOnCannotBeABaseUrl,
))
}
Ok(mut paths) => {
paths.push("feedback");
}
}
}
url.set_query(None);
let details = details.iter().map(|m| m.to_string()).collect();
let feedback = Feedback::new(&self.info.id, execution, finished, progress, details);
let reply = self
.client
.post(&url.to_string())
.json(&feedback)
.send()
.await?;
reply.error_for_status()?;
Ok(())
}
pub async fn send_feedback_with_progress<T: Serialize>(
&self,
execution: Execution,
finished: Finished,
progress: T,
details: Vec<&str>,
) -> Result<(), Error> {
self.send_feedback_internal(execution, finished, Some(progress), details)
.await
}
pub async fn send_feedback(
&self,
execution: Execution,
finished: Finished,
details: Vec<&str>,
) -> Result<(), Error> {
self.send_feedback_internal::<bool>(execution, finished, None, details)
.await
}
}
/// Public view over one software chunk, borrowing the deserialized data and
/// carrying a client clone so artifacts can be downloaded.
#[derive(Debug)]
pub struct Chunk<'a> {
    chunk: &'a ChunkInternal,
    client: Client,
}
impl<'a> Chunk<'a> {
    fn new(chunk: &'a ChunkInternal, client: Client) -> Self {
        Self { chunk, client }
    }

    /// The part of the target this chunk applies to.
    pub fn part(&self) -> &str {
        &self.chunk.part
    }

    /// Name of the chunk.
    pub fn name(&self) -> &str {
        &self.chunk.name
    }

    /// Version of the chunk.
    pub fn version(&self) -> &str {
        &self.chunk.version
    }

    /// Key/value metadata the server attached to this chunk.
    ///
    /// The metadata was previously deserialized but never exposed.
    pub fn metadata(&self) -> impl Iterator<Item = (&str, &str)> {
        self.chunk
            .metadata
            .iter()
            .map(|m| (m.key.as_str(), m.value.as_str()))
    }

    /// Iterate over the artifacts of this chunk.
    pub fn artifacts(&self) -> impl Iterator<Item = Artifact> {
        let client = self.client.clone();
        self.chunk
            .artifacts
            .iter()
            .map(move |a| Artifact::new(a, client.clone()))
    }

    /// Download all artifacts of this chunk into `dir/<chunk-name>/`.
    ///
    /// Stops and returns the error of the first failing download.
    pub async fn download(&'a self, dir: &Path) -> Result<Vec<DownloadedArtifact>, Error> {
        let mut dir = dir.to_path_buf();
        dir.push(self.name());

        let mut result = Vec::new();
        for a in self.artifacts() {
            result.push(a.download(&dir).await?);
        }
        Ok(result)
    }
}
/// Public view over one downloadable artifact, borrowing the deserialized
/// data and carrying a client clone to perform the download.
#[derive(Debug)]
pub struct Artifact<'a> {
    artifact: &'a ArtifactInternal,
    client: Client,
}
impl<'a> Artifact<'a> {
    fn new(artifact: &'a ArtifactInternal, client: Client) -> Self {
        Self { artifact, client }
    }

    /// File name of the artifact as reported by the server.
    pub fn filename(&self) -> &str {
        &self.artifact.filename
    }

    /// Size of the artifact as reported by the server.
    pub fn size(&self) -> u32 {
        self.artifact.size
    }

    // GET the artifact's `download-http` link and fail on a non-success status.
    async fn download_response(&'a self) -> Result<Response, Error> {
        let resp = self
            .client
            // `String` implements `IntoUrl`; no need to borrow the temporary.
            .get(self.artifact.links.download_http.to_string())
            .send()
            .await?;
        resp.error_for_status_ref()?;
        Ok(resp)
    }

    /// Download this artifact into `dir` (created recursively if needed),
    /// returning a handle that can later verify the file's checksums.
    pub async fn download(&'a self, dir: &Path) -> Result<DownloadedArtifact, Error> {
        let mut resp = self.download_response().await?;

        // A recursive create succeeds when the directory already exists, so
        // the previous racy `dir.exists()` pre-check is unnecessary.
        DirBuilder::new().recursive(true).create(dir).await?;

        let mut file_name = dir.to_path_buf();
        file_name.push(self.filename());
        let mut dest = File::create(&file_name).await?;
        while let Some(chunk) = resp.chunk().await? {
            dest.write_all(&chunk).await?;
        }
        // Ensure buffered data reaches the OS before the handle is dropped;
        // dropping a tokio `File` does not guarantee the write completes.
        dest.flush().await?;

        Ok(DownloadedArtifact::new(
            file_name,
            self.artifact.hashes.clone(),
        ))
    }

    /// Stream the artifact's bytes without writing them to disk.
    pub async fn download_stream(
        &'a self,
    ) -> Result<impl Stream<Item = Result<Bytes, Error>>, Error> {
        let resp = self.download_response().await?;
        Ok(resp.bytes_stream().map_err(|e| e.into()))
    }

    /// Stream the artifact's bytes, verifying the MD5 checksum when the
    /// stream completes (a mismatch is yielded as a final `Err` item).
    #[cfg(feature = "hash-md5")]
    pub async fn download_stream_with_md5_check(
        &'a self,
    ) -> Result<impl Stream<Item = Result<Bytes, Error>>, Error> {
        let stream = self.download_stream().await?;
        let hasher = DownloadHasher::new_md5(self.artifact.hashes.md5.clone());
        let stream = DownloadStreamHash {
            stream: Box::new(stream),
            hasher,
        };
        Ok(stream)
    }

    /// Stream the artifact's bytes, verifying the SHA-1 checksum when the
    /// stream completes (a mismatch is yielded as a final `Err` item).
    #[cfg(feature = "hash-sha1")]
    pub async fn download_stream_with_sha1_check(
        &'a self,
    ) -> Result<impl Stream<Item = Result<Bytes, Error>>, Error> {
        let stream = self.download_stream().await?;
        let hasher = DownloadHasher::new_sha1(self.artifact.hashes.sha1.clone());
        let stream = DownloadStreamHash {
            stream: Box::new(stream),
            hasher,
        };
        Ok(stream)
    }

    /// Stream the artifact's bytes, verifying the SHA-256 checksum when the
    /// stream completes (a mismatch is yielded as a final `Err` item).
    #[cfg(feature = "hash-sha256")]
    pub async fn download_stream_with_sha256_check(
        &'a self,
    ) -> Result<impl Stream<Item = Result<Bytes, Error>>, Error> {
        let stream = self.download_stream().await?;
        let hasher = DownloadHasher::new_sha256(self.artifact.hashes.sha256.clone());
        let stream = DownloadStreamHash {
            stream: Box::new(stream),
            hasher,
        };
        Ok(stream)
    }
}
/// A file that has been downloaded to disk, together with the expected
/// digests so its integrity can be checked afterwards.
#[derive(Debug)]
pub struct DownloadedArtifact {
    file: PathBuf,
    hashes: Hashes,
}
cfg_if::cfg_if! {
    if #[cfg(feature = "hash-digest")] {
        use std::{
            pin::Pin,
            task::Poll,
        };
        use digest::Digest;

        // Read-buffer size used when hashing an already-downloaded file.
        const HASH_BUFFER_SIZE: usize = 4096;

        /// Which checksum algorithm failed verification; carried inside
        /// `Error::ChecksumError`.
        #[derive(Debug, strum::Display, Clone)]
        pub enum ChecksumType {
            #[cfg(feature = "hash-md5")]
            Md5,
            #[cfg(feature = "hash-sha1")]
            Sha1,
            #[cfg(feature = "hash-sha256")]
            Sha256,
        }

        // Pairs a digest state with the expected hex string and the checksum
        // type to report on mismatch. The where-clauses are the bounds needed
        // so the finalized digest can be rendered with `{:x}`.
        #[derive(Clone)]
        struct DownloadHasher<T>
        where
            T: Digest,
            <T as Digest>::OutputSize: core::ops::Add,
            <<T as Digest>::OutputSize as core::ops::Add>::Output: generic_array::ArrayLength<u8>,
        {
            hasher: T,
            expected: String,
            error: ChecksumType,
        }

        impl<T> DownloadHasher<T>
        where
            T: Digest,
            <T as Digest>::OutputSize: core::ops::Add,
            <<T as Digest>::OutputSize as core::ops::Add>::Output: generic_array::ArrayLength<u8>
        {
            // Feed more bytes into the digest.
            fn update(&mut self, data: impl AsRef<[u8]>) {
                self.hasher.update(data);
            }

            // Consume the hasher and compare the lowercase-hex digest with
            // `expected`. NOTE: `expected` must therefore be lowercase hex.
            fn finalize(self) -> Result<(), Error> {
                let digest = self.hasher.finalize();
                if format!("{:x}", digest) == self.expected {
                    Ok(())
                } else {
                    Err(Error::ChecksumError(self.error))
                }
            }
        }

        #[cfg(feature = "hash-md5")]
        impl DownloadHasher<md5::Md5> {
            // MD5 hasher checking against `expected` (lowercase hex).
            fn new_md5(expected: String) -> Self {
                Self {
                    hasher: md5::Md5::new(),
                    expected,
                    error: ChecksumType::Md5,
                }
            }
        }

        #[cfg(feature = "hash-sha1")]
        impl DownloadHasher<sha1::Sha1> {
            // SHA-1 hasher checking against `expected` (lowercase hex).
            fn new_sha1(expected: String) -> Self {
                Self {
                    hasher: sha1::Sha1::new(),
                    expected,
                    error: ChecksumType::Sha1,
                }
            }
        }

        #[cfg(feature = "hash-sha256")]
        impl DownloadHasher<sha2::Sha256> {
            // SHA-256 hasher checking against `expected` (lowercase hex).
            fn new_sha256(expected: String) -> Self {
                Self {
                    hasher: sha2::Sha256::new(),
                    expected,
                    error: ChecksumType::Sha256,
                }
            }
        }

        // Pass-through stream adaptor: forwards every chunk of the inner byte
        // stream unchanged while feeding it to the hasher; when the inner
        // stream ends, the checksum is verified and a mismatch is surfaced as
        // a final `Err` item.
        struct DownloadStreamHash<T>
        where
            T: Digest,
            <T as Digest>::OutputSize: core::ops::Add,
            <<T as Digest>::OutputSize as core::ops::Add>::Output: generic_array::ArrayLength<u8>,
        {
            stream: Box<dyn Stream<Item = Result<Bytes, Error>> + Unpin + Send + Sync>,
            hasher: DownloadHasher<T>,
        }

        impl<T> Stream for DownloadStreamHash<T>
        where
            T: Digest,
            <T as Digest>::OutputSize: core::ops::Add,
            <<T as Digest>::OutputSize as core::ops::Add>::Output: generic_array::ArrayLength<u8>,
            T: Unpin,
            T: Clone,
        {
            type Item = Result<Bytes, Error>;

            fn poll_next(
                self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<Option<Self::Item>> {
                let me = Pin::into_inner(self);
                match Pin::new(&mut me.stream).poll_next(cx) {
                    Poll::Ready(Some(Ok(data))) => {
                        // Hash the chunk, then hand it through untouched.
                        me.hasher.update(&data);
                        Poll::Ready(Some(Ok(data)))
                    }
                    Poll::Ready(None) => {
                        // End of stream: verify the checksum. `finalize`
                        // consumes the hasher, so run it on a clone in case
                        // the stream is polled again after completion.
                        match me.hasher.clone().finalize() {
                            Ok(_) => Poll::Ready(None),
                            Err(e) => Poll::Ready(Some(Err(e))),
                        }
                    }
                    Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
                    Poll::Pending => Poll::Pending,
                }
            }
        }
    }
}
// The previous `impl<'a>` declared a lifetime parameter that no associated
// item used; it has been removed.
impl DownloadedArtifact {
    fn new(file: PathBuf, hashes: Hashes) -> Self {
        Self { file, hashes }
    }

    /// Path of the downloaded file on disk.
    pub fn file(&self) -> &PathBuf {
        &self.file
    }

    /// Hash the file in `HASH_BUFFER_SIZE` chunks and compare the digest
    /// against the expected value stored in `hasher`.
    #[cfg(feature = "hash-digest")]
    async fn hash<T>(&self, mut hasher: DownloadHasher<T>) -> Result<(), Error>
    where
        T: Digest,
        <T as Digest>::OutputSize: core::ops::Add,
        <<T as Digest>::OutputSize as core::ops::Add>::Output: generic_array::ArrayLength<u8>,
    {
        use tokio::io::AsyncReadExt;

        let mut file = File::open(&self.file).await?;
        let mut buffer = [0; HASH_BUFFER_SIZE];
        loop {
            let n = file.read(&mut buffer[..]).await?;
            if n == 0 {
                break;
            }
            hasher.update(&buffer[..n]);
        }
        hasher.finalize()
    }

    /// Verify the file against the MD5 digest reported by the server.
    #[cfg(feature = "hash-md5")]
    pub async fn check_md5(&self) -> Result<(), Error> {
        let hasher = DownloadHasher::new_md5(self.hashes.md5.clone());
        self.hash(hasher).await
    }

    /// Verify the file against the SHA-1 digest reported by the server.
    #[cfg(feature = "hash-sha1")]
    pub async fn check_sha1(&self) -> Result<(), Error> {
        let hasher = DownloadHasher::new_sha1(self.hashes.sha1.clone());
        self.hash(hasher).await
    }

    /// Verify the file against the SHA-256 digest reported by the server.
    #[cfg(feature = "hash-sha256")]
    pub async fn check_sha256(&self) -> Result<(), Error> {
        let hasher = DownloadHasher::new_sha256(self.hashes.sha256.clone());
        self.hash(hasher).await
    }
}