1use std::{
2 path::{Path, PathBuf},
3 pin::Pin,
4 task::Poll,
5};
6
7use anyhow::{Context, Result};
8use async_compression::futures::bufread::{BzDecoder, GzipDecoder};
9use futures::{AsyncRead, AsyncSeek, AsyncSeekExt, AsyncWrite, AsyncWriteExt, io::BufReader};
10use sha2::{Digest, Sha256};
11
12use crate::{HttpClient, github::AssetKind};
13
14#[derive(serde::Deserialize, serde::Serialize, Debug)]
15pub struct GithubBinaryMetadata {
16 pub metadata_version: u64,
17 pub digest: Option<String>,
18}
19
20impl GithubBinaryMetadata {
21 pub async fn read_from_file(metadata_path: &Path) -> Result<GithubBinaryMetadata> {
22 let metadata_content = async_fs::read_to_string(metadata_path)
23 .await
24 .with_context(|| format!("reading metadata file at {metadata_path:?}"))?;
25 serde_json::from_str(&metadata_content)
26 .with_context(|| format!("parsing metadata file at {metadata_path:?}"))
27 }
28
29 pub async fn write_to_file(&self, metadata_path: &Path) -> Result<()> {
30 let metadata_content = serde_json::to_string(self)
31 .with_context(|| format!("serializing metadata for {metadata_path:?}"))?;
32 async_fs::write(metadata_path, metadata_content.as_bytes())
33 .await
34 .with_context(|| format!("writing metadata file at {metadata_path:?}"))?;
35 Ok(())
36 }
37}
38
39pub async fn download_server_binary(
40 http_client: &dyn HttpClient,
41 url: &str,
42 digest: Option<&str>,
43 destination_path: &Path,
44 asset_kind: AssetKind,
45) -> Result<(), anyhow::Error> {
46 log::info!("downloading github artifact from {url}");
47 let Some(destination_parent) = destination_path.parent() else {
48 anyhow::bail!("destination path has no parent: {destination_path:?}");
49 };
50
51 let staging_path = staging_path(destination_parent, asset_kind)?;
52 let mut response = http_client
53 .get(url, Default::default(), true)
54 .await
55 .with_context(|| format!("downloading release from {url}"))?;
56 let body = response.body_mut();
57
58 if let Err(err) = extract_to_staging(body, digest, url, &staging_path, asset_kind).await {
59 cleanup_staging_path(&staging_path, asset_kind).await;
60 return Err(err);
61 }
62
63 if let Err(err) = finalize_download(&staging_path, destination_path).await {
64 cleanup_staging_path(&staging_path, asset_kind).await;
65 return Err(err);
66 }
67
68 Ok(())
69}
70
71pub async fn download_server_raw_binary(
72 http_client: &dyn HttpClient,
73 url: &str,
74 digest: Option<&str>,
75 destination_path: &Path,
76 binary_file_name: &str,
77) -> Result<(), anyhow::Error> {
78 log::info!("downloading raw binary from {url}");
79 let Some(destination_parent) = destination_path.parent() else {
80 anyhow::bail!("destination path has no parent: {destination_path:?}");
81 };
82
83 let staging_path = staging_dir_path(destination_parent)?;
84 let result = async {
85 let mut response = http_client
86 .get(url, Default::default(), true)
87 .await
88 .with_context(|| format!("downloading release from {url}"))?;
89
90 let binary_path = staging_path.join(binary_file_name);
91 let mut writer = HashingWriter {
92 writer: async_fs::File::create(&binary_path)
93 .await
94 .with_context(|| format!("creating a file {binary_path:?} for {url}"))?,
95 hasher: Sha256::new(),
96 };
97 futures::io::copy(&mut BufReader::new(response.body_mut()), &mut writer)
98 .await
99 .with_context(|| format!("saving binary contents from {url}"))?;
100 let asset_sha_256 = writer
101 .finish()
102 .await
103 .with_context(|| format!("flushing binary contents for {url}"))?;
104
105 if let Some(expected_sha_256) = digest {
106 anyhow::ensure!(
107 asset_sha_256 == expected_sha_256,
108 "{url} asset got SHA-256 mismatch. Expected: {expected_sha_256}, Got: {asset_sha_256}",
109 );
110 }
111
112 open_gpui_util::fs::make_file_executable(&binary_path)
113 .await
114 .with_context(|| format!("marking {binary_path:?} as executable"))?;
115 finalize_download(&staging_path, destination_path).await
116 }
117 .await;
118
119 if let Err(err) = result {
120 if let Err(err) = async_fs::remove_dir_all(&staging_path).await {
121 log::warn!("failed to remove staging directory {staging_path:?}: {err:?}");
122 }
123 return Err(err);
124 }
125
126 Ok(())
127}
128
129async fn extract_to_staging(
130 body: impl AsyncRead + Unpin,
131 digest: Option<&str>,
132 url: &str,
133 staging_path: &Path,
134 asset_kind: AssetKind,
135) -> Result<()> {
136 match digest {
137 Some(expected_sha_256) => {
138 let temp_asset_file = tempfile::NamedTempFile::new()
139 .with_context(|| format!("creating a temporary file for {url}"))?;
140 let (temp_asset_file, _temp_guard) = temp_asset_file.into_parts();
141 let mut writer = HashingWriter {
142 writer: async_fs::File::from(temp_asset_file),
143 hasher: Sha256::new(),
144 };
145 futures::io::copy(&mut BufReader::new(body), &mut writer)
146 .await
147 .with_context(|| {
148 format!("saving archive contents into the temporary file for {url}")
149 })?;
150 let asset_sha_256 = format!("{:x}", writer.hasher.finalize());
151
152 anyhow::ensure!(
153 asset_sha_256 == expected_sha_256,
154 "{url} asset got SHA-256 mismatch. Expected: {expected_sha_256}, Got: {asset_sha_256}",
155 );
156 writer
157 .writer
158 .seek(std::io::SeekFrom::Start(0))
159 .await
160 .with_context(|| format!("seeking temporary file for {url}"))?;
161 stream_file_archive(&mut writer.writer, url, staging_path, asset_kind)
162 .await
163 .with_context(|| {
164 format!("extracting downloaded asset for {url} into {staging_path:?}")
165 })?;
166 }
167 None => {
168 stream_response_archive(body, url, staging_path, asset_kind)
169 .await
170 .with_context(|| {
171 format!("extracting response for asset {url} into {staging_path:?}")
172 })?;
173 }
174 }
175 Ok(())
176}
177
178fn staging_dir_path(parent: &Path) -> Result<PathBuf> {
179 let dir = tempfile::Builder::new()
180 .prefix(".tmp-github-download-")
181 .tempdir_in(parent)
182 .with_context(|| format!("creating staging directory in {parent:?}"))?;
183 Ok(dir.keep())
184}
185
186fn staging_path(parent: &Path, asset_kind: AssetKind) -> Result<PathBuf> {
187 match asset_kind {
188 AssetKind::TarGz | AssetKind::TarBz2 | AssetKind::Zip => staging_dir_path(parent),
189 AssetKind::Gz => {
190 let path = tempfile::Builder::new()
191 .prefix(".tmp-github-download-")
192 .tempfile_in(parent)
193 .with_context(|| format!("creating staging file in {parent:?}"))?
194 .into_temp_path()
195 .keep()
196 .with_context(|| format!("persisting staging file in {parent:?}"))?;
197 Ok(path)
198 }
199 }
200}
201
202async fn cleanup_staging_path(staging_path: &Path, asset_kind: AssetKind) {
203 match asset_kind {
204 AssetKind::TarGz | AssetKind::TarBz2 | AssetKind::Zip => {
205 if let Err(err) = async_fs::remove_dir_all(staging_path).await {
206 log::warn!("failed to remove staging directory {staging_path:?}: {err:?}");
207 }
208 }
209 AssetKind::Gz => {
210 if let Err(err) = async_fs::remove_file(staging_path).await {
211 log::warn!("failed to remove staging file {staging_path:?}: {err:?}");
212 }
213 }
214 }
215}
216
217async fn finalize_download(staging_path: &Path, destination_path: &Path) -> Result<()> {
218 _ = async_fs::remove_dir_all(destination_path).await;
219 async_fs::rename(staging_path, destination_path)
220 .await
221 .with_context(|| format!("renaming {staging_path:?} to {destination_path:?}"))?;
222 Ok(())
223}
224
225async fn stream_response_archive(
226 response: impl AsyncRead + Unpin,
227 url: &str,
228 destination_path: &Path,
229 asset_kind: AssetKind,
230) -> Result<()> {
231 match asset_kind {
232 AssetKind::TarGz => extract_tar_gz(destination_path, url, response).await?,
233 AssetKind::TarBz2 => extract_tar_bz2(destination_path, url, response).await?,
234 AssetKind::Gz => extract_gz(destination_path, url, response).await?,
235 AssetKind::Zip => {
236 open_gpui_util::archive::extract_zip(destination_path, response).await?;
237 }
238 };
239 Ok(())
240}
241
242async fn stream_file_archive(
243 file_archive: impl AsyncRead + AsyncSeek + Unpin,
244 url: &str,
245 destination_path: &Path,
246 asset_kind: AssetKind,
247) -> Result<()> {
248 match asset_kind {
249 AssetKind::TarGz => extract_tar_gz(destination_path, url, file_archive).await?,
250 AssetKind::TarBz2 => extract_tar_bz2(destination_path, url, file_archive).await?,
251 AssetKind::Gz => extract_gz(destination_path, url, file_archive).await?,
252 #[cfg(not(windows))]
253 AssetKind::Zip => {
254 open_gpui_util::archive::extract_seekable_zip(destination_path, file_archive).await?;
255 }
256 #[cfg(windows)]
257 AssetKind::Zip => {
258 open_gpui_util::archive::extract_zip(destination_path, file_archive).await?;
259 }
260 };
261 Ok(())
262}
263
264async fn extract_tar_gz(
265 destination_path: &Path,
266 url: &str,
267 from: impl AsyncRead + Unpin,
268) -> Result<(), anyhow::Error> {
269 let decompressed_bytes = GzipDecoder::new(BufReader::new(from));
270 unpack_tar_archive(destination_path, url, decompressed_bytes).await?;
271 Ok(())
272}
273
274async fn extract_tar_bz2(
275 destination_path: &Path,
276 url: &str,
277 from: impl AsyncRead + Unpin,
278) -> Result<(), anyhow::Error> {
279 let decompressed_bytes = BzDecoder::new(BufReader::new(from));
280 unpack_tar_archive(destination_path, url, decompressed_bytes).await?;
281 Ok(())
282}
283
284async fn unpack_tar_archive(
285 destination_path: &Path,
286 url: &str,
287 archive_bytes: impl AsyncRead + Unpin,
288) -> Result<(), anyhow::Error> {
289 let archive = async_tar::ArchiveBuilder::new(archive_bytes)
293 .set_preserve_mtime(false)
294 .build();
295 archive
296 .unpack(&destination_path)
297 .await
298 .with_context(|| format!("extracting {url} to {destination_path:?}"))?;
299 Ok(())
300}
301
302async fn extract_gz(
303 destination_path: &Path,
304 url: &str,
305 from: impl AsyncRead + Unpin,
306) -> Result<(), anyhow::Error> {
307 let mut decompressed_bytes = GzipDecoder::new(BufReader::new(from));
308 let mut file = async_fs::File::create(&destination_path)
309 .await
310 .with_context(|| {
311 format!("creating a file {destination_path:?} for a download from {url}")
312 })?;
313 futures::io::copy(&mut decompressed_bytes, &mut file)
314 .await
315 .with_context(|| format!("extracting {url} to {destination_path:?}"))?;
316 Ok(())
317}
318
319struct HashingWriter<W: AsyncWrite + Unpin> {
320 writer: W,
321 hasher: Sha256,
322}
323
324impl<W: AsyncWrite + Unpin> HashingWriter<W> {
325 async fn finish(mut self) -> std::io::Result<String> {
334 self.writer.close().await?;
335 drop(self.writer);
336 Ok(format!("{:x}", self.hasher.finalize()))
337 }
338}
339
340impl<W: AsyncWrite + Unpin> AsyncWrite for HashingWriter<W> {
341 fn poll_write(
342 mut self: Pin<&mut Self>,
343 cx: &mut std::task::Context<'_>,
344 buf: &[u8],
345 ) -> Poll<std::result::Result<usize, std::io::Error>> {
346 match Pin::new(&mut self.writer).poll_write(cx, buf) {
347 Poll::Ready(Ok(n)) => {
348 self.hasher.update(&buf[..n]);
349 Poll::Ready(Ok(n))
350 }
351 other => other,
352 }
353 }
354
355 fn poll_flush(
356 mut self: Pin<&mut Self>,
357 cx: &mut std::task::Context<'_>,
358 ) -> Poll<Result<(), std::io::Error>> {
359 Pin::new(&mut self.writer).poll_flush(cx)
360 }
361
362 fn poll_close(
363 mut self: Pin<&mut Self>,
364 cx: &mut std::task::Context<'_>,
365 ) -> Poll<std::result::Result<(), std::io::Error>> {
366 Pin::new(&mut self.writer).poll_close(cx)
367 }
368}
369
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{AsyncBody, Response};
    use futures::future::BoxFuture;
    use http::HeaderValue;
    use url::Url;

    const BINARY_URL: &str = "https://example.com/agent-binary";
    const BINARY_FILE_NAME: &str = "agent-binary";

    /// Test double for `HttpClient`: every request answers with status 200 and a copy
    /// of `body`.
    struct StaticResponseClient {
        body: Vec<u8>,
    }

    impl HttpClient for StaticResponseClient {
        fn send(
            &self,
            _req: http::Request<AsyncBody>,
        ) -> BoxFuture<'static, anyhow::Result<Response<AsyncBody>>> {
            let payload = self.body.clone();
            Box::pin(async move {
                let response = Response::builder()
                    .status(200)
                    .body(AsyncBody::from(payload))
                    .unwrap();
                Ok(response)
            })
        }

        fn user_agent(&self) -> Option<&HeaderValue> {
            None
        }

        fn proxy(&self) -> Option<&Url> {
            None
        }
    }

    /// With a matching digest, the binary lands under the destination directory. On
    /// Unix it also has all execute bits set.
    #[test]
    fn downloads_raw_binary_into_destination_dir() {
        let contents: &[u8] = b"#!/bin/sh\necho hello\n";
        let temp_dir = tempfile::tempdir().unwrap();
        let destination_path = temp_dir.path().join("v_1");
        let expected_sha_256 = format!("{:x}", Sha256::digest(contents));
        let client = StaticResponseClient {
            body: contents.to_vec(),
        };

        futures::executor::block_on(download_server_raw_binary(
            &client,
            BINARY_URL,
            Some(expected_sha_256.as_str()),
            &destination_path,
            BINARY_FILE_NAME,
        ))
        .unwrap();

        let binary_path = destination_path.join(BINARY_FILE_NAME);
        assert_eq!(std::fs::read(&binary_path).unwrap(), contents);
        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            let mode = std::fs::metadata(&binary_path)
                .unwrap()
                .permissions()
                .mode();
            assert_eq!(mode & 0o111, 0o111, "binary should be executable");
        }
    }

    /// A digest mismatch fails. Neither the destination nor the staging directory
    /// may be left behind.
    #[test]
    fn raw_binary_digest_mismatch_cleans_up_staging() {
        let temp_dir = tempfile::tempdir().unwrap();
        let destination_path = temp_dir.path().join("v_1");
        let client = StaticResponseClient {
            body: b"some binary".to_vec(),
        };
        let bogus_digest = "0000000000000000000000000000000000000000000000000000000000000000";

        let error = futures::executor::block_on(download_server_raw_binary(
            &client,
            BINARY_URL,
            Some(bogus_digest),
            &destination_path,
            BINARY_FILE_NAME,
        ))
        .unwrap_err();

        assert!(error.to_string().contains("SHA-256 mismatch"));
        assert!(!destination_path.exists());
        let leftover_entries = std::fs::read_dir(temp_dir.path()).unwrap().count();
        assert_eq!(leftover_entries, 0, "staging directory should be removed");
    }
}