Skip to main content

open_gpui_http_client/
github_download.rs

1use std::{
2    path::{Path, PathBuf},
3    pin::Pin,
4    task::Poll,
5};
6
7use anyhow::{Context, Result};
8use async_compression::futures::bufread::{BzDecoder, GzipDecoder};
9use futures::{AsyncRead, AsyncSeek, AsyncSeekExt, AsyncWrite, AsyncWriteExt, io::BufReader};
10use sha2::{Digest, Sha256};
11
12use crate::{HttpClient, github::AssetKind};
13
14#[derive(serde::Deserialize, serde::Serialize, Debug)]
15pub struct GithubBinaryMetadata {
16    pub metadata_version: u64,
17    pub digest: Option<String>,
18}
19
20impl GithubBinaryMetadata {
21    pub async fn read_from_file(metadata_path: &Path) -> Result<GithubBinaryMetadata> {
22        let metadata_content = async_fs::read_to_string(metadata_path)
23            .await
24            .with_context(|| format!("reading metadata file at {metadata_path:?}"))?;
25        serde_json::from_str(&metadata_content)
26            .with_context(|| format!("parsing metadata file at {metadata_path:?}"))
27    }
28
29    pub async fn write_to_file(&self, metadata_path: &Path) -> Result<()> {
30        let metadata_content = serde_json::to_string(self)
31            .with_context(|| format!("serializing metadata for {metadata_path:?}"))?;
32        async_fs::write(metadata_path, metadata_content.as_bytes())
33            .await
34            .with_context(|| format!("writing metadata file at {metadata_path:?}"))?;
35        Ok(())
36    }
37}
38
39pub async fn download_server_binary(
40    http_client: &dyn HttpClient,
41    url: &str,
42    digest: Option<&str>,
43    destination_path: &Path,
44    asset_kind: AssetKind,
45) -> Result<(), anyhow::Error> {
46    log::info!("downloading github artifact from {url}");
47    let Some(destination_parent) = destination_path.parent() else {
48        anyhow::bail!("destination path has no parent: {destination_path:?}");
49    };
50
51    let staging_path = staging_path(destination_parent, asset_kind)?;
52    let mut response = http_client
53        .get(url, Default::default(), true)
54        .await
55        .with_context(|| format!("downloading release from {url}"))?;
56    let body = response.body_mut();
57
58    if let Err(err) = extract_to_staging(body, digest, url, &staging_path, asset_kind).await {
59        cleanup_staging_path(&staging_path, asset_kind).await;
60        return Err(err);
61    }
62
63    if let Err(err) = finalize_download(&staging_path, destination_path).await {
64        cleanup_staging_path(&staging_path, asset_kind).await;
65        return Err(err);
66    }
67
68    Ok(())
69}
70
71pub async fn download_server_raw_binary(
72    http_client: &dyn HttpClient,
73    url: &str,
74    digest: Option<&str>,
75    destination_path: &Path,
76    binary_file_name: &str,
77) -> Result<(), anyhow::Error> {
78    log::info!("downloading raw binary from {url}");
79    let Some(destination_parent) = destination_path.parent() else {
80        anyhow::bail!("destination path has no parent: {destination_path:?}");
81    };
82
83    let staging_path = staging_dir_path(destination_parent)?;
84    let result = async {
85        let mut response = http_client
86            .get(url, Default::default(), true)
87            .await
88            .with_context(|| format!("downloading release from {url}"))?;
89
90        let binary_path = staging_path.join(binary_file_name);
91        let mut writer = HashingWriter {
92            writer: async_fs::File::create(&binary_path)
93                .await
94                .with_context(|| format!("creating a file {binary_path:?} for {url}"))?,
95            hasher: Sha256::new(),
96        };
97        futures::io::copy(&mut BufReader::new(response.body_mut()), &mut writer)
98            .await
99            .with_context(|| format!("saving binary contents from {url}"))?;
100        let asset_sha_256 = writer
101            .finish()
102            .await
103            .with_context(|| format!("flushing binary contents for {url}"))?;
104
105        if let Some(expected_sha_256) = digest {
106            anyhow::ensure!(
107                asset_sha_256 == expected_sha_256,
108                "{url} asset got SHA-256 mismatch. Expected: {expected_sha_256}, Got: {asset_sha_256}",
109            );
110        }
111
112        open_gpui_util::fs::make_file_executable(&binary_path)
113            .await
114            .with_context(|| format!("marking {binary_path:?} as executable"))?;
115        finalize_download(&staging_path, destination_path).await
116    }
117    .await;
118
119    if let Err(err) = result {
120        if let Err(err) = async_fs::remove_dir_all(&staging_path).await {
121            log::warn!("failed to remove staging directory {staging_path:?}: {err:?}");
122        }
123        return Err(err);
124    }
125
126    Ok(())
127}
128
129async fn extract_to_staging(
130    body: impl AsyncRead + Unpin,
131    digest: Option<&str>,
132    url: &str,
133    staging_path: &Path,
134    asset_kind: AssetKind,
135) -> Result<()> {
136    match digest {
137        Some(expected_sha_256) => {
138            let temp_asset_file = tempfile::NamedTempFile::new()
139                .with_context(|| format!("creating a temporary file for {url}"))?;
140            let (temp_asset_file, _temp_guard) = temp_asset_file.into_parts();
141            let mut writer = HashingWriter {
142                writer: async_fs::File::from(temp_asset_file),
143                hasher: Sha256::new(),
144            };
145            futures::io::copy(&mut BufReader::new(body), &mut writer)
146                .await
147                .with_context(|| {
148                    format!("saving archive contents into the temporary file for {url}")
149                })?;
150            let asset_sha_256 = format!("{:x}", writer.hasher.finalize());
151
152            anyhow::ensure!(
153                asset_sha_256 == expected_sha_256,
154                "{url} asset got SHA-256 mismatch. Expected: {expected_sha_256}, Got: {asset_sha_256}",
155            );
156            writer
157                .writer
158                .seek(std::io::SeekFrom::Start(0))
159                .await
160                .with_context(|| format!("seeking temporary file for {url}"))?;
161            stream_file_archive(&mut writer.writer, url, staging_path, asset_kind)
162                .await
163                .with_context(|| {
164                    format!("extracting downloaded asset for {url} into {staging_path:?}")
165                })?;
166        }
167        None => {
168            stream_response_archive(body, url, staging_path, asset_kind)
169                .await
170                .with_context(|| {
171                    format!("extracting response for asset {url} into {staging_path:?}")
172                })?;
173        }
174    }
175    Ok(())
176}
177
178fn staging_dir_path(parent: &Path) -> Result<PathBuf> {
179    let dir = tempfile::Builder::new()
180        .prefix(".tmp-github-download-")
181        .tempdir_in(parent)
182        .with_context(|| format!("creating staging directory in {parent:?}"))?;
183    Ok(dir.keep())
184}
185
186fn staging_path(parent: &Path, asset_kind: AssetKind) -> Result<PathBuf> {
187    match asset_kind {
188        AssetKind::TarGz | AssetKind::TarBz2 | AssetKind::Zip => staging_dir_path(parent),
189        AssetKind::Gz => {
190            let path = tempfile::Builder::new()
191                .prefix(".tmp-github-download-")
192                .tempfile_in(parent)
193                .with_context(|| format!("creating staging file in {parent:?}"))?
194                .into_temp_path()
195                .keep()
196                .with_context(|| format!("persisting staging file in {parent:?}"))?;
197            Ok(path)
198        }
199    }
200}
201
202async fn cleanup_staging_path(staging_path: &Path, asset_kind: AssetKind) {
203    match asset_kind {
204        AssetKind::TarGz | AssetKind::TarBz2 | AssetKind::Zip => {
205            if let Err(err) = async_fs::remove_dir_all(staging_path).await {
206                log::warn!("failed to remove staging directory {staging_path:?}: {err:?}");
207            }
208        }
209        AssetKind::Gz => {
210            if let Err(err) = async_fs::remove_file(staging_path).await {
211                log::warn!("failed to remove staging file {staging_path:?}: {err:?}");
212            }
213        }
214    }
215}
216
217async fn finalize_download(staging_path: &Path, destination_path: &Path) -> Result<()> {
218    _ = async_fs::remove_dir_all(destination_path).await;
219    async_fs::rename(staging_path, destination_path)
220        .await
221        .with_context(|| format!("renaming {staging_path:?} to {destination_path:?}"))?;
222    Ok(())
223}
224
225async fn stream_response_archive(
226    response: impl AsyncRead + Unpin,
227    url: &str,
228    destination_path: &Path,
229    asset_kind: AssetKind,
230) -> Result<()> {
231    match asset_kind {
232        AssetKind::TarGz => extract_tar_gz(destination_path, url, response).await?,
233        AssetKind::TarBz2 => extract_tar_bz2(destination_path, url, response).await?,
234        AssetKind::Gz => extract_gz(destination_path, url, response).await?,
235        AssetKind::Zip => {
236            open_gpui_util::archive::extract_zip(destination_path, response).await?;
237        }
238    };
239    Ok(())
240}
241
242async fn stream_file_archive(
243    file_archive: impl AsyncRead + AsyncSeek + Unpin,
244    url: &str,
245    destination_path: &Path,
246    asset_kind: AssetKind,
247) -> Result<()> {
248    match asset_kind {
249        AssetKind::TarGz => extract_tar_gz(destination_path, url, file_archive).await?,
250        AssetKind::TarBz2 => extract_tar_bz2(destination_path, url, file_archive).await?,
251        AssetKind::Gz => extract_gz(destination_path, url, file_archive).await?,
252        #[cfg(not(windows))]
253        AssetKind::Zip => {
254            open_gpui_util::archive::extract_seekable_zip(destination_path, file_archive).await?;
255        }
256        #[cfg(windows)]
257        AssetKind::Zip => {
258            open_gpui_util::archive::extract_zip(destination_path, file_archive).await?;
259        }
260    };
261    Ok(())
262}
263
264async fn extract_tar_gz(
265    destination_path: &Path,
266    url: &str,
267    from: impl AsyncRead + Unpin,
268) -> Result<(), anyhow::Error> {
269    let decompressed_bytes = GzipDecoder::new(BufReader::new(from));
270    unpack_tar_archive(destination_path, url, decompressed_bytes).await?;
271    Ok(())
272}
273
274async fn extract_tar_bz2(
275    destination_path: &Path,
276    url: &str,
277    from: impl AsyncRead + Unpin,
278) -> Result<(), anyhow::Error> {
279    let decompressed_bytes = BzDecoder::new(BufReader::new(from));
280    unpack_tar_archive(destination_path, url, decompressed_bytes).await?;
281    Ok(())
282}
283
284async fn unpack_tar_archive(
285    destination_path: &Path,
286    url: &str,
287    archive_bytes: impl AsyncRead + Unpin,
288) -> Result<(), anyhow::Error> {
289    // We don't need to set the modified time. It's irrelevant to downloaded
290    // archive verification, and some filesystems return errors when asked to
291    // apply it after extraction.
292    let archive = async_tar::ArchiveBuilder::new(archive_bytes)
293        .set_preserve_mtime(false)
294        .build();
295    archive
296        .unpack(&destination_path)
297        .await
298        .with_context(|| format!("extracting {url} to {destination_path:?}"))?;
299    Ok(())
300}
301
302async fn extract_gz(
303    destination_path: &Path,
304    url: &str,
305    from: impl AsyncRead + Unpin,
306) -> Result<(), anyhow::Error> {
307    let mut decompressed_bytes = GzipDecoder::new(BufReader::new(from));
308    let mut file = async_fs::File::create(&destination_path)
309        .await
310        .with_context(|| {
311            format!("creating a file {destination_path:?} for a download from {url}")
312        })?;
313    futures::io::copy(&mut decompressed_bytes, &mut file)
314        .await
315        .with_context(|| format!("extracting {url} to {destination_path:?}"))?;
316    Ok(())
317}
318
319struct HashingWriter<W: AsyncWrite + Unpin> {
320    writer: W,
321    hasher: Sha256,
322}
323
324impl<W: AsyncWrite + Unpin> HashingWriter<W> {
325    /// Closes and drops the inner writer, returning the hex SHA-256 digest of
326    /// everything written.
327    ///
328    /// Taking `self` by value guarantees the writer is dropped before this
329    /// returns. For file writers this releases the OS handle, which Windows
330    /// requires before an ancestor directory can be renamed or deleted; note
331    /// that closing alone is not enough, as `async_fs::File` holds its handle
332    /// until dropped.
333    async fn finish(mut self) -> std::io::Result<String> {
334        self.writer.close().await?;
335        drop(self.writer);
336        Ok(format!("{:x}", self.hasher.finalize()))
337    }
338}
339
340impl<W: AsyncWrite + Unpin> AsyncWrite for HashingWriter<W> {
341    fn poll_write(
342        mut self: Pin<&mut Self>,
343        cx: &mut std::task::Context<'_>,
344        buf: &[u8],
345    ) -> Poll<std::result::Result<usize, std::io::Error>> {
346        match Pin::new(&mut self.writer).poll_write(cx, buf) {
347            Poll::Ready(Ok(n)) => {
348                self.hasher.update(&buf[..n]);
349                Poll::Ready(Ok(n))
350            }
351            other => other,
352        }
353    }
354
355    fn poll_flush(
356        mut self: Pin<&mut Self>,
357        cx: &mut std::task::Context<'_>,
358    ) -> Poll<Result<(), std::io::Error>> {
359        Pin::new(&mut self.writer).poll_flush(cx)
360    }
361
362    fn poll_close(
363        mut self: Pin<&mut Self>,
364        cx: &mut std::task::Context<'_>,
365    ) -> Poll<std::result::Result<(), std::io::Error>> {
366        Pin::new(&mut self.writer).poll_close(cx)
367    }
368}
369
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{AsyncBody, Response};
    use futures::future::BoxFuture;
    use http::HeaderValue;
    use url::Url;

    /// Test double that answers every request with a `200` response carrying
    /// a copy of `body`.
    struct StaticResponseClient {
        body: Vec<u8>,
    }

    impl HttpClient for StaticResponseClient {
        fn send(
            &self,
            _req: http::Request<AsyncBody>,
        ) -> BoxFuture<'static, anyhow::Result<Response<AsyncBody>>> {
            let payload = self.body.clone();
            Box::pin(async move {
                let response = Response::builder()
                    .status(200)
                    .body(AsyncBody::from(payload))
                    .unwrap();
                Ok(response)
            })
        }

        fn user_agent(&self) -> Option<&HeaderValue> {
            None
        }

        fn proxy(&self) -> Option<&Url> {
            None
        }
    }

    /// A successful raw download ends up at `<destination>/<binary name>` with
    /// the served contents and, on Unix, the executable bits set.
    #[test]
    fn downloads_raw_binary_into_destination_dir() {
        let scratch = tempfile::tempdir().unwrap();
        let destination_path = scratch.path().join("v_1");
        let payload = b"#!/bin/sh\necho hello\n".to_vec();
        let payload_sha_256 = format!("{:x}", Sha256::digest(&payload));
        let client = StaticResponseClient { body: payload };

        futures::executor::block_on(download_server_raw_binary(
            &client,
            "https://example.com/agent-binary",
            Some(&payload_sha_256),
            &destination_path,
            "agent-binary",
        ))
        .unwrap();

        let installed_binary = destination_path.join("agent-binary");
        let installed_contents = std::fs::read(&installed_binary).unwrap();
        assert_eq!(installed_contents, b"#!/bin/sh\necho hello\n");

        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            let metadata = std::fs::metadata(&installed_binary).unwrap();
            let mode = metadata.permissions().mode();
            assert_eq!(mode & 0o111, 0o111, "binary should be executable");
        }
    }

    /// A digest mismatch is reported as an error, nothing is installed, and
    /// the staging directory is removed again.
    #[test]
    fn raw_binary_digest_mismatch_cleans_up_staging() {
        let scratch = tempfile::tempdir().unwrap();
        let destination_path = scratch.path().join("v_1");
        let client = StaticResponseClient {
            body: b"some binary".to_vec(),
        };

        let error = futures::executor::block_on(download_server_raw_binary(
            &client,
            "https://example.com/agent-binary",
            Some("0000000000000000000000000000000000000000000000000000000000000000"),
            &destination_path,
            "agent-binary",
        ))
        .unwrap_err();

        assert!(error.to_string().contains("SHA-256 mismatch"));
        assert!(!destination_path.exists());
        let leftover_entries = std::fs::read_dir(scratch.path()).unwrap().count();
        assert_eq!(leftover_entries, 0, "staging directory should be removed");
    }
}