updatehub_cloud_sdk/
client.rs

1// Copyright (C) 2020 O.S. Systems Software LTDA
2//
3// SPDX-License-Identifier: Apache-2.0
4
5use crate::{api, Error, Result};
6use reqwest::{header, StatusCode};
7use slog_scope::{debug, error};
8use std::{
9    convert::{TryFrom, TryInto},
10    path::Path,
11};
12use tokio::{fs, io};
13
14pub struct Client<'a> {
15    client: reqwest::Client,
16    server: &'a str,
17}
18
19pub async fn get<W>(url: &str, handle: &mut W) -> Result<()>
20where
21    W: io::AsyncWrite + Unpin,
22{
23    let url = reqwest::Url::parse(url)?;
24    save_body_to(reqwest::get(url).await?, handle).await
25}
26
27async fn save_body_to<W>(mut resp: reqwest::Response, handle: &mut W) -> Result<()>
28where
29    W: io::AsyncWrite + Unpin,
30{
31    use io::AsyncWriteExt;
32    use std::str::FromStr;
33
34    if !resp.status().is_success() {
35        return Err(Error::InvalidStatusResponse(resp.status()));
36    }
37
38    let mut written: f32 = 0.;
39    let mut threshold = 10;
40    let length = match resp.headers().get(header::CONTENT_LENGTH) {
41        Some(v) => usize::from_str(v.to_str()?)?,
42        None => 0,
43    };
44
45    while let Some(chunk) = resp.chunk().await? {
46        let read = chunk.len();
47        handle.write_all(&chunk).await?;
48        if length > 0 {
49            written += read as f32 / (length as f32 / 100.);
50            if written as usize >= threshold {
51                threshold += 20;
52                debug!("{}% of the file has been downloaded", std::cmp::min(written as usize, 100));
53            }
54        }
55    }
56
57    Ok(())
58}
59
60impl<'a> Client<'a> {
61    pub fn new(server: &'a str) -> Self {
62        let mut headers = header::HeaderMap::new();
63        headers.insert(header::USER_AGENT, header::HeaderValue::from_static("updatehub/2.0 Linux"));
64        headers.insert(header::CONTENT_TYPE, header::HeaderValue::from_static("application/json"));
65        headers.insert(
66            "api-content-type",
67            header::HeaderValue::from_static("application/vnd.updatehub-v1+json"),
68        );
69
70        let client = reqwest::Client::builder()
71            .connect_timeout(std::time::Duration::from_secs(10))
72            .default_headers(headers)
73            .build()
74            .unwrap();
75
76        Self { server, client }
77    }
78
79    pub async fn probe(
80        &self,
81        num_retries: usize,
82        firmware: api::FirmwareMetadata<'_>,
83    ) -> Result<api::ProbeResponse> {
84        reqwest::Url::parse(self.server)?;
85
86        let response = self
87            .client
88            .post(&format!("{}/upgrades", &self.server))
89            .header("api-retries", num_retries.to_string())
90            .json(&firmware)
91            .send()
92            .await?;
93
94        match response.status() {
95            StatusCode::NOT_FOUND => Ok(api::ProbeResponse::NoUpdate),
96            StatusCode::OK => {
97                match response
98                    .headers()
99                    .get("add-extra-poll")
100                    .and_then(|extra_poll| extra_poll.to_str().ok())
101                    .and_then(|extra_poll| extra_poll.parse().ok())
102                {
103                    Some(extra_poll) => Ok(api::ProbeResponse::ExtraPoll(extra_poll)),
104                    None => {
105                        let signature = response
106                            .headers()
107                            .get("UH-Signature")
108                            .map(TryInto::try_into)
109                            .transpose()?;
110                        Ok(api::ProbeResponse::Update(
111                            api::UpdatePackage::parse(&response.bytes().await?)?,
112                            signature,
113                        ))
114                    }
115                }
116            }
117            s => Err(Error::InvalidStatusResponse(s)),
118        }
119    }
120
121    pub async fn download_object(
122        &self,
123        product_uid: &str,
124        package_uid: &str,
125        download_dir: &Path,
126        object: &str,
127    ) -> Result<()> {
128        validate_url(self.server)?;
129
130        // FIXME: Discuss the need of packages inside the route
131        let mut request = self.client.get(format!(
132            "{}/products/{}/packages/{}/objects/{}",
133            &self.server, product_uid, package_uid, object
134        ));
135
136        if !download_dir.exists() {
137            fs::create_dir_all(download_dir).await.map_err(|e| {
138                error!("fail to create {:?} directory, error: {}", download_dir, e);
139                e
140            })?;
141        }
142
143        let file = download_dir.join(object);
144        if file.exists() {
145            request = request
146                .header("RANGE", format!("bytes={}-", file.metadata()?.len().saturating_sub(1)));
147        }
148
149        let mut file = fs::OpenOptions::new().create(true).append(true).open(&file).await?;
150
151        save_body_to(request.send().await?, &mut file).await
152    }
153
154    pub async fn report(
155        &self,
156        state: &str,
157        firmware: api::FirmwareMetadata<'_>,
158        package_uid: &str,
159        previous_state: Option<&str>,
160        error_message: Option<String>,
161        current_log: Option<String>,
162    ) -> Result<()> {
163        validate_url(self.server)?;
164
165        #[derive(serde::Serialize)]
166        #[serde(rename_all = "kebab-case")]
167        struct Payload<'a> {
168            #[serde(rename = "status")]
169            state: &'a str,
170            #[serde(flatten)]
171            firmware: api::FirmwareMetadata<'a>,
172            package_uid: &'a str,
173            #[serde(skip_serializing_if = "Option::is_none")]
174            previous_state: Option<&'a str>,
175            #[serde(skip_serializing_if = "Option::is_none")]
176            error_message: Option<String>,
177            #[serde(skip_serializing_if = "Option::is_none")]
178            current_log: Option<String>,
179        }
180
181        let payload =
182            Payload { state, firmware, package_uid, previous_state, error_message, current_log };
183
184        self.client.post(&format!("{}/report", &self.server)).json(&payload).send().await?;
185        Ok(())
186    }
187}
188
189impl TryFrom<&header::HeaderValue> for api::Signature {
190    type Error = Error;
191
192    fn try_from(value: &header::HeaderValue) -> Result<Self> {
193        let value = value.to_str()?;
194
195        // Workarround for https://github.com/sfackler/rust-openssl/issues/1325
196        if value.is_empty() {
197            return Self::from_base64_str("");
198        }
199
200        Self::from_base64_str(value)
201    }
202}
203
204fn validate_url(url: &str) -> Result<()> {
205    url::Url::parse(url)?;
206    Ok(())
207}