// arm_toolchain/toolchain/client.rs

use std::{
    fmt::Debug,
    io::{ErrorKind, SeekFrom},
    path::{Path, PathBuf},
    sync::{Arc, PoisonError, RwLock},
};

use camino::Utf8Path;
use data_encoding::HEXLOWER;
use futures::{TryStreamExt, future::join_all};
use octocrab::{Octocrab, models::repos::Asset};
use reqwest::{StatusCode, header};
use sha2::{Digest, Sha256};
use tokio::io::{self, AsyncReadExt, AsyncSeekExt, AsyncWriteExt, BufReader, BufWriter};
use tokio_util::{future::FutureExt as _, sync::CancellationToken};
use tracing::{debug, info, instrument, trace, warn};

use crate::{
    CheckCancellation, DIRS, TRASH, fs,
    toolchain::{
        APP_USER_AGENT, InstallState, InstalledToolchain, ToolchainError, ToolchainRelease,
        ToolchainVersion, extract,
        remove::{RemoveProgress, remove_dir_progress},
    },
};

27/// A client for downloading and installing the Arm Toolchain for Embedded (ATfE).
28#[derive(Clone)]
29pub struct ToolchainClient {
30    gh_client: Arc<Octocrab>,
31    client: reqwest::Client,
32    cache_path: PathBuf,
33    toolchains_path: PathBuf,
34    current_version: Arc<RwLock<Option<ToolchainVersion>>>,
35}
36
37impl Debug for ToolchainClient {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        f.debug_struct("ToolchainClient")
40            .field("cache_path", &self.cache_path)
41            .field("toolchains_path", &self.toolchains_path)
42            .finish()
43    }
44}
45
46impl ToolchainClient {
47    pub const REPO_OWNER: &str = "arm";
48    pub const REPO_NAME: &str = "arm-toolchain";
49    pub const RELEASE_PREFIX: &str = "release-";
50    pub const RELEASE_SUFFIX: &str = "-ATfE"; // arm toolchain for embedded
51    pub const CURRENT_TOOLCHAIN_FILENAME: &str = "current.txt";
52
53    /// Creates a new toolchain client that installs to a platform-specific data directory.
54    ///
55    /// For example, on macOS this is
56    /// `~/Library/Application Support/dev.vexide.arm-toolchain/llvm-toolchains`.
57    pub async fn using_data_dir() -> Result<Self, ToolchainError> {
58        Self::new(
59            DIRS.data_local_dir().join("llvm-toolchains"),
60            DIRS.cache_dir().join("downloads/llvm-toolchains"),
61        )
62        .await
63    }
64
65    /// Creates a client that installs toolchains in the specified folder.
66    pub async fn new(
67        toolchains_path: impl Into<PathBuf>,
68        cache_path: impl Into<PathBuf>,
69    ) -> Result<Self, ToolchainError> {
70        let toolchains_path = toolchains_path.into();
71        let cache_path = cache_path.into();
72        trace!(
73            ?toolchains_path,
74            ?cache_path,
75            "Initializing toolchain downloader"
76        );
77
78        let (current_version, setup_fut) = tokio::join!(
79            fs::read_to_string(toolchains_path.join(Self::CURRENT_TOOLCHAIN_FILENAME)),
80            async {
81                tokio::try_join!(
82                    fs::create_dir_all(&toolchains_path),
83                    fs::create_dir_all(&cache_path),
84                )
85            },
86        );
87
88        setup_fut?;
89
90        let current_version = current_version
91            .map(|name| ToolchainVersion::named(name.trim()))
92            .ok();
93
94        Ok(Self {
95            gh_client: octocrab::instance(),
96            client: reqwest::Client::builder()
97                .user_agent(APP_USER_AGENT)
98                .build()
99                .unwrap(),
100            toolchains_path,
101            cache_path,
102            current_version: Arc::new(RwLock::new(current_version)),
103        })
104    }
105
106    /// Fetches the latest release of the Arm Toolchain for Embedded (ATfE) from the ARM GitHub repository.
107    #[instrument(skip(self))]
108    pub async fn latest_release(&self) -> Result<ToolchainRelease, ToolchainError> {
109        debug!("Fetching latest release from GitHub repo");
110
111        let releases = self
112            .gh_client
113            .repos(Self::REPO_OWNER, Self::REPO_NAME)
114            .releases()
115            .list()
116            .per_page(10)
117            .send()
118            .await?;
119
120        let Some(latest_embedded_release) = releases
121            .items
122            .iter()
123            .find(|r| r.tag_name.ends_with(Self::RELEASE_SUFFIX))
124        else {
125            return Err(ToolchainError::LatestReleaseMissing {
126                candidates: releases.items.into_iter().map(|r| r.tag_name).collect(),
127            });
128        };
129
130        Ok(ToolchainRelease::new(latest_embedded_release.clone()))
131    }
132
133    /// Fetches the given release of the Arm Toolchain for Embedded (ATfE) from the ARM GitHub repository.
134    #[instrument(skip(self))]
135    pub async fn get_release(
136        &self,
137        version: &ToolchainVersion,
138    ) -> Result<ToolchainRelease, ToolchainError> {
139        let tag_name = version.to_tag_name();
140        info!(%tag_name, "Fetching release data from GitHub");
141
142        let release = self
143            .gh_client
144            .repos(Self::REPO_OWNER, Self::REPO_NAME)
145            .releases()
146            .get_by_tag(&tag_name)
147            .await?;
148
149        Ok(ToolchainRelease::new(release.clone()))
150    }
151
152    /// Returns the path where the given toolchain version would be installed.
153    pub fn install_path_for(&self, version: &ToolchainVersion) -> PathBuf {
154        self.toolchains_path.join(&version.name)
155    }
156
157    /// Checks if the specified toolchain version is already installed.
158    pub fn version_is_installed(&self, version: &ToolchainVersion) -> bool {
159        self.install_path_for(version).exists()
160    }
161
162    /// Downloads the specified toolchain asset, verifies its checksum, extracts it,
163    /// and installs it to the appropriate location.
164    ///
165    /// The downloaded toolchain will be activated if there is no other active toolchain. Returns
166    /// the path to the extracted toolchain directory.
167    ///
168    /// # Resuming downloads
169    ///
170    /// This method will also handle resuming downloads if the file already exists and is partially downloaded.
171    /// If the partially-downloaded file contains invalid bytes, a checksum error will be returned and the file
172    /// will be deleted.
173    #[instrument(
174        skip(self, release, asset, progress, cancel_token),
175        fields(version = release.version().name, asset.name)
176    )]
177    pub async fn download_and_install(
178        &self,
179        release: &ToolchainRelease,
180        asset: &Asset,
181        progress: Arc<dyn Fn(InstallState) + Send + Sync>,
182        cancel_token: CancellationToken,
183    ) -> Result<PathBuf, ToolchainError> {
184        let file_name = Utf8Path::new(&asset.name).file_name().ok_or_else(|| {
185            ToolchainError::InvalidAssetName {
186                name: asset.name.to_string(),
187            }
188        })?;
189        let archive_destination = self.cache_path.join(file_name);
190
191        debug!(asset.name, ?archive_destination, "Downloading asset");
192
193        // Begin downloading the checksum file in parallel so it's ready when we need it.
194        let checksum_future = self.fetch_asset_checksum(asset);
195
196        // Meanwhile, either begin or resume the asset download.
197        let download_task = async {
198            let mut downloaded_file = self
199                .download_asset(asset, &archive_destination, progress.clone())
200                .await?;
201
202            debug!("Calculating checksum for downloaded file");
203            let checksum_bytes =
204                calculate_file_checksum(&mut downloaded_file, progress.clone()).await?;
205            let checksum_hex = HEXLOWER.encode(&checksum_bytes);
206            trace!(?checksum_hex, "Checksum calculated");
207
208            Ok::<_, ToolchainError>((downloaded_file, checksum_hex))
209        };
210
211        let ((mut downloaded_file, real_checksum), expected_checksum) =
212            async { tokio::try_join!(download_task, checksum_future) }
213                .with_cancellation_token(&cancel_token)
214                .await
215                .ok_or(ToolchainError::Cancelled)??;
216
217        // Verify the checksum to make sure the download was successful and the file is not corrupted.
218
219        let checksums_match = real_checksum.eq_ignore_ascii_case(&expected_checksum);
220        debug!(
221            ?real_checksum,
222            ?expected_checksum,
223            "Checksum verification: {checksums_match}"
224        );
225        if !checksums_match {
226            fs::remove_file(archive_destination).await?;
227            return Err(ToolchainError::ChecksumMismatch {
228                expected: expected_checksum,
229                actual: real_checksum,
230            });
231        }
232
233        debug!("Download finished");
234
235        // Now choose the extraction method based on the file extension.
236
237        let extract_location = self.install_path_for(release.version());
238
239        cancel_token.check_cancellation(ToolchainError::Cancelled)?;
240
241        debug!(archive = ?archive_destination, ?extract_location, "Extracting downloaded archive");
242        progress(InstallState::ExtractBegin);
243
244        if extract_location.exists() {
245            debug!("Destination folder already exists, removing it");
246            TRASH.delete(&extract_location)?;
247        }
248
249        downloaded_file.seek(SeekFrom::Start(0)).await?;
250        if file_name.ends_with(".dmg") {
251            extract::macos::extract_dmg(
252                archive_destination.clone(),
253                &extract_location,
254                progress.clone(),
255                cancel_token,
256            )
257            .await?;
258        } else if file_name.ends_with(".zip") {
259            extract::extract_zip(downloaded_file, extract_location.clone()).await?;
260        } else if file_name.ends_with(".tar.xz") {
261            let progress = progress.clone();
262            extract::extract_tar_xz(
263                downloaded_file,
264                extract_location.clone(),
265                progress.clone(),
266                cancel_token,
267            )
268            .await?;
269        } else {
270            unreachable!("Unsupported file format");
271        }
272
273        progress(InstallState::ExtractCleanUp);
274        fs::remove_file(archive_destination).await?;
275
276        progress(InstallState::ExtractDone);
277
278        debug!("Updating current toolchain if necessary.");
279        if self.active_toolchain().is_none() {
280            let new_version = release.version().clone();
281            info!(%new_version, "Updating current toolchain");
282            self.set_active_toolchain(Some(release.version().clone()))
283                .await?;
284        }
285
286        Ok(extract_location)
287    }
288
289    /// Downloads the asset to the specified destination path without checksum verification or extraction.
290    ///
291    /// If the destination path already has a partially downloaded file, it will resume the download from where it left off.
292    #[instrument(skip(self, asset, progress))]
293    async fn download_asset(
294        &self,
295        asset: &Asset,
296        destination: &Path,
297        progress: Arc<dyn Fn(InstallState) + Send + Sync>,
298    ) -> Result<fs::File, ToolchainError> {
299        if let Some(parent) = destination.parent() {
300            fs::create_dir_all(parent).await?;
301        }
302
303        let mut file = fs::File::options()
304            .read(true)
305            .append(true)
306            .create(true)
307            .open(&destination)
308            .await?;
309
310        let mut current_file_length = file.seek(SeekFrom::End(0)).await?;
311
312        // Some initial checks before we start downloading to see if it makes sense to continue.
313
314        if current_file_length > asset.size as u64 {
315            // Having *too much* data doesn't make any sense... just restart the download from scratch.
316            warn!(
317                ?current_file_length,
318                ?asset.size,
319                "File size mismatch: existing file is larger than expected. Truncating file and starting over."
320            );
321
322            file.set_len(0).await?;
323            current_file_length = file.seek(SeekFrom::End(0)).await?;
324        }
325
326        if current_file_length == asset.size as u64 {
327            debug!("File already downloaded, skipping download");
328            return Ok(file);
329        }
330
331        // If there's already data in the file, we will assume that's from the last download attempt and
332        // set the Range header to continue downloading from where we left off.
333
334        let next_byte_index = current_file_length;
335        let last_byte_index = asset.size as u64 - 1;
336        let range_header = format!("bytes={next_byte_index}-{last_byte_index}");
337        trace!(?range_header, "Setting Range header for download");
338
339        if next_byte_index > 0 {
340            debug!("Resuming an existing download");
341        }
342
343        progress(InstallState::DownloadBegin {
344            asset_size: asset.size as u64,
345            bytes_read: current_file_length,
346        });
347
348        // At this point, we're all good to just start copying bytes from the stream to the file.
349
350        let mut stream = self
351            .client
352            .get(asset.browser_download_url.clone())
353            .header(header::RANGE, range_header)
354            .header(header::ACCEPT, "*/*")
355            .send()
356            .await?
357            .error_for_status()?
358            .bytes_stream();
359
360        let mut writer = BufWriter::new(file);
361
362        while let Some(chunk) = stream.try_next().await? {
363            writer.write_all(&chunk).await?;
364
365            current_file_length += chunk.len() as u64;
366            progress(InstallState::Download {
367                bytes_read: current_file_length,
368            });
369        }
370
371        writer.flush().await?;
372        progress(InstallState::DownloadFinish);
373        debug!(?destination, "Download completed");
374
375        Ok(writer.into_inner())
376    }
377
378    /// Downloads the expected SHA256 checksum for the asset.
379    ///
380    /// The resulting string contains the checksum in hex format.
381    async fn fetch_asset_checksum(&self, asset: &Asset) -> Result<String, ToolchainError> {
382        let mut sha256_url = asset.browser_download_url.clone();
383        sha256_url.set_path(&format!("{}.sha256", sha256_url.path()));
384
385        let mut checksum_file = self
386            .client
387            .get(sha256_url)
388            .send()
389            .await?
390            .error_for_status()?
391            .text()
392            .await?;
393
394        // Trim off the filename from the checksum file, which is usually in the format:
395        // `<checksum> <filename>`
396
397        let mut parts = checksum_file.split_ascii_whitespace();
398        let hash_part = parts.next().unwrap_or("");
399        checksum_file.truncate(hash_part.len());
400
401        Ok(checksum_file)
402    }
403
404    pub async fn installed_versions(&self) -> Result<Vec<ToolchainVersion>, ToolchainError> {
405        let mut futs = vec![];
406
407        let mut dir = fs::read_dir(&self.toolchains_path).await?;
408        while let Some(entry) = dir.next_entry().await? {
409            futs.push(async move {
410                if let Ok(ty) = entry.file_type().await
411                    && ty.is_dir()
412                {
413                    let name = entry.file_name();
414                    return Some(ToolchainVersion::named(name.to_string_lossy()));
415                }
416
417                None
418            });
419        }
420
421        let versions = join_all(futs).await.into_iter().flatten().collect();
422        Ok(versions)
423    }
424
425    /// Delete all files related to the given toolchain version.
426    pub async fn remove(
427        &self,
428        version: &ToolchainVersion,
429        progress: impl FnMut(RemoveProgress),
430        cancel_token: &CancellationToken,
431    ) -> Result<(), ToolchainError> {
432        if let Ok(toolchain) = self.toolchain(version).await {
433            remove_dir_progress(toolchain.path, progress, cancel_token).await?;
434        }
435
436        if self.active_toolchain().as_ref() == Some(version) {
437            self.set_active_toolchain(None).await?;
438        }
439
440        Ok(())
441    }
442
443    /// Delete the cache directory, returning the number of bytes deleted.
444    pub async fn purge_cache(&self) -> Result<u64, ToolchainError> {
445        let bytes = async {
446            let mut bytes = 0;
447
448            let mut read_dir = fs::read_dir(&self.cache_path).await?;
449            while let Some(item) = read_dir.next_entry().await? {
450                let meta = item.metadata().await?;
451                bytes += meta.len();
452            }
453
454            Ok::<u64, ToolchainError>(bytes)
455        };
456
457        let bytes = bytes.await.unwrap_or(0);
458        fs::remove_dir_all(&self.cache_path).await?;
459        Ok(bytes)
460    }
461
462    /// Get the version of the active (default) toolchain.
463    pub fn active_toolchain(&self) -> Option<ToolchainVersion> {
464        self.current_version.read().unwrap().clone()
465    }
466
467    /// Set the version of the active (default) toolchain.
468    ///
469    /// This will write the given value to disk.
470    pub async fn set_active_toolchain(
471        &self,
472        version: Option<ToolchainVersion>,
473    ) -> Result<(), ToolchainError> {
474        let path = self.toolchains_path.join(Self::CURRENT_TOOLCHAIN_FILENAME);
475
476        if let Some(version) = &version {
477            fs::write(path, &version.name).await?;
478        } else {
479            match fs::remove_file(path).await {
480                Ok(()) => Ok(()),
481                Err(e) if e.kind() == ErrorKind::NotFound => Ok(()),
482                other => other,
483            }?;
484        }
485
486        *self.current_version.write().unwrap() = version;
487
488        Ok(())
489    }
490
491    /// Returns a struct used to access paths of an installed toolchain.
492    ///
493    /// This doesn't check whether the specified version is actually installed,
494    /// so make sure the paths exist before using them.
495    pub async fn toolchain(
496        &self,
497        version: &ToolchainVersion,
498    ) -> Result<InstalledToolchain, ToolchainError> {
499        let toolchain = InstalledToolchain::new(self.toolchains_path.join(&version.name));
500        toolchain.check_installed().await?;
501        Ok(toolchain)
502    }
503}
504
505/// Scans an entire file and calculates its SHA256 checksum.
506async fn calculate_file_checksum(
507    file: &mut fs::File,
508    progress: Arc<dyn Fn(InstallState) + Send + Sync>,
509) -> Result<[u8; 32], io::Error> {
510    let file_size = file.metadata().await?.len();
511    progress(InstallState::VerifyingBegin {
512        asset_size: file_size,
513    });
514
515    file.seek(SeekFrom::Start(0)).await?;
516    let mut reader = BufReader::new(file);
517
518    let mut hasher = Sha256::default();
519    let mut data = vec![0; 64 * 1024];
520
521    let mut bytes_read = 0;
522    loop {
523        let len = reader.read(&mut data).await?;
524        if len == 0 {
525            break;
526        }
527
528        hasher.update(&data[..len]);
529
530        bytes_read += len as u64;
531        progress(InstallState::Verifying { bytes_read });
532    }
533
534    let checksum = hasher.finalize().into();
535
536    progress(InstallState::VerifyingFinish);
537
538    Ok(checksum)
539}