// maelstrom_github/client.rs

1//! This module contains code that can communicate with GitHub's artifact API.
2//!
3//! This API works by doing HTTP requests to GitHub to manage the artifacts, but the actual data is
4//! stored in Azure, and the API returns signed-URLs to upload or download the data.
5//!
6//! The GitHub API is using the TWIRP RPC framework <https://github.com/twitchtv/twirp>.
7//!
8//! It seems that it uses protobuf definitions to define the body of requests. I'm not sure where
9//! to find these protobuf definitions, but we do have access to typescript that seems to be
10//! generated from them, which can be found here:
11//! <https://github.com/actions/toolkit/blob/main/packages/artifact/src/generated/results/api/v1/artifact.ts>
12
13pub use azure_core::{
14    error::Result as AzureResult,
15    tokio::fs::{FileStream, FileStreamBuilder},
16    Body, SeekableStream,
17};
18
19use anyhow::{anyhow, bail, Result};
20use azure_storage_blobs::prelude::BlobClient;
21use chrono::{DateTime, Utc};
22use derive_more::From;
23use futures::{stream::TryStreamExt as _, StreamExt as _};
24use serde::{de::DeserializeOwned, Deserialize, Serialize};
25use serde_with::{serde_as, DisplayFromStr};
26use std::str::FromStr;
27use tokio::io::AsyncRead;
28use tokio_util::compat::FuturesAsyncReadCompatExt as _;
29use url::Url;
30
/// The pair of opaque backend ids the artifact API requires on every request. They are
/// extracted from the Actions runtime token (see the [`FromStr`] impl below) and
/// serialize as camelCase (`workflowRunBackendId` / `workflowJobRunBackendId`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct BackendIds {
    pub workflow_run_backend_id: String,
    pub workflow_job_run_backend_id: String,
}
37
38impl FromStr for BackendIds {
39    type Err = anyhow::Error;
40
41    fn from_str(token: &str) -> Result<BackendIds> {
42        use base64::Engine as _;
43
44        let mut token_parts = token.split(".").skip(1);
45        let b64_part = token_parts
46            .next()
47            .ok_or_else(|| anyhow!("missing period"))?;
48        let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
49            .decode(b64_part)
50            .map_err(|e| anyhow!("base64 invalid: {e}"))?;
51        let v = serde_json::from_slice::<serde_json::Value>(&decoded)?;
52
53        let scp = v
54            .get("scp")
55            .ok_or_else(|| anyhow!("missing 'scp' field"))?
56            .as_str()
57            .ok_or_else(|| anyhow!("'scp' field not a string"))?;
58
59        let scope_parts = scp
60            .split(" ")
61            .map(|p| p.split(":").collect::<Vec<_>>())
62            .find(|p| p[0] == "Actions.Results")
63            .ok_or_else(|| anyhow!("'Actions.Results' missing from 'scp' field"))?;
64
65        Ok(Self {
66            workflow_run_backend_id: scope_parts[1].into(),
67            workflow_job_run_backend_id: scope_parts[2].into(),
68        })
69    }
70}
71
/// Minimal TWIRP-over-HTTP client for GitHub's artifact service.
struct TwirpClient {
    client: reqwest::Client,
    // Actions runtime token, sent as a bearer token on every request.
    token: String,
    // Base URL of the results service; "twirp/<service>/<method>" paths are joined on.
    base_url: Url,
    // Ids parsed out of `token`; callers include them in most request bodies.
    backend_ids: BackendIds,
}
78
79impl TwirpClient {
80    fn new(token: &str, base_url: Url) -> Result<Self> {
81        Ok(Self {
82            client: reqwest::Client::new(),
83            token: token.into(),
84            base_url,
85            backend_ids: token.parse()?,
86        })
87    }
88
89    async fn request<BodyT: Serialize, RespT: DeserializeOwned>(
90        &self,
91        service: &str,
92        method: &str,
93        body: &BodyT,
94    ) -> Result<RespT> {
95        let req = self
96            .client
97            .post(
98                self.base_url
99                    .join(&format!("twirp/{service}/{method}"))
100                    .unwrap(),
101            )
102            .header("Content-Type", "application/json")
103            .header("User-Agent", "@actions/artifact-2.1.11")
104            .header(
105                "Authorization",
106                &format!("Bearer {token}", token = &self.token),
107            )
108            .json(body);
109
110        let resp = req.send().await?;
111        if !resp.status().is_success() {
112            bail!("{}", resp.text().await.unwrap());
113        }
114
115        Ok(resp.json().await?)
116    }
117}
118
119fn rfc3339_encode<S>(v: &Option<DateTime<Utc>>, s: S) -> std::result::Result<S::Ok, S::Error>
120where
121    S: serde::Serializer,
122{
123    s.serialize_str(
124        &v.unwrap()
125            .to_rfc3339_opts(chrono::SecondsFormat::Millis, true),
126    )
127}
128
/// Body of the `CreateArtifact` RPC.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct CreateArtifactRequest {
    #[serde(flatten)]
    backend_ids: BackendIds,
    name: String,
    // Omitted from the JSON entirely when `None`; otherwise RFC 3339 with millisecond
    // precision (see `rfc3339_encode`).
    #[serde(
        skip_serializing_if = "Option::is_none",
        serialize_with = "rfc3339_encode"
    )]
    expires_at: Option<DateTime<Utc>>,
    version: u32,
}
142
#[test]
fn create_artifact_json() {
    use chrono::TimeZone as _;
    use serde_json::json;

    let request = CreateArtifactRequest {
        backend_ids: BackendIds {
            workflow_run_backend_id: "run_id".into(),
            workflow_job_run_backend_id: "job_id".into(),
        },
        name: "foo".into(),
        expires_at: Some(Utc.with_ymd_and_hms(2020, 1, 1, 0, 0, 0).unwrap()),
        version: 4,
    };

    // The flattened backend ids surface as top-level camelCase keys, and the
    // expiration serializes as RFC 3339 with milliseconds.
    let actual = serde_json::to_value(&request).unwrap();
    let expected = json!({
        "workflowRunBackendId": "run_id",
        "workflowJobRunBackendId": "job_id",
        "name": "foo",
        "expiresAt": "2020-01-01T00:00:00.000Z",
        "version": 4
    });
    assert_eq!(actual, expected);
}
168
/// Body of the `FinalizeArtifact` RPC, sent once all the artifact's data has been
/// uploaded to blob storage.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct FinalizeArtifactRequest {
    #[serde(flatten)]
    backend_ids: BackendIds,
    name: String,
    // Total number of bytes uploaded.
    size: usize,
}
177
/// Body of the `ListArtifacts` RPC. Both filters are optional and are left out of the
/// JSON entirely when `None`.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct ListArtifactsRequest {
    #[serde(flatten)]
    backend_ids: BackendIds,
    #[serde(skip_serializing_if = "Option::is_none")]
    name_filter: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    id_filter: Option<DatabaseId>,
}
188
/// An artifact's numeric id. Carried as a decimal string on the wire
/// (`DisplayFromStr`) but handled as an `i64` in memory.
#[serde_as]
#[derive(Copy, Clone, Debug, From, Serialize, Deserialize, PartialEq, Eq)]
#[serde(transparent)]
pub struct DatabaseId(#[serde_as(as = "DisplayFromStr")] i64);
193
/// A single artifact as returned by the `ListArtifacts` RPC.
#[serde_as]
#[derive(Debug, Deserialize, PartialEq, Eq)]
pub struct Artifact {
    // In this response the ids appear under snake_case keys, hence the remote-derive
    // helper rather than `BackendIds`' own camelCase impl.
    #[serde(flatten, with = "BackendIdsSnakeCase")]
    pub backend_ids: BackendIds,
    pub name: String,
    // Size in bytes; a decimal string on the wire.
    #[serde_as(as = "DisplayFromStr")]
    pub size: i64,
    pub database_id: DatabaseId,
}
204
/// Mirror of [`BackendIds`] used via `#[serde(remote = "BackendIds")]` so that
/// [`Artifact`] can (de)serialize the ids with snake_case keys (no
/// `rename_all = "camelCase"` here, unlike `BackendIds` itself).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(remote = "BackendIds")]
struct BackendIdsSnakeCase {
    workflow_run_backend_id: String,
    workflow_job_run_backend_id: String,
}
211
/// Response to the `ListArtifacts` RPC.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ListArtifactsResponse {
    artifacts: Vec<Artifact>,
}
217
/// Body of the `GetSignedArtifactURL` RPC, used to start a download.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GetSignedArtifactUrlRequest {
    #[serde(flatten)]
    backend_ids: BackendIds,
    name: String,
}
225
/// Response to the `CreateArtifact` RPC. No `rename_all` here: the key is expected in
/// snake_case form.
#[derive(Debug, Deserialize)]
struct CreateArtifactResponse {
    // Signed Azure blob-storage URL to upload the artifact data to.
    signed_upload_url: String,
}
230
/// Response to the `GetSignedArtifactURL` RPC. No `rename_all` here: the key is
/// expected in snake_case form.
#[derive(Debug, Deserialize)]
struct GetSignedArtifactUrlResponse {
    // Signed Azure blob-storage URL to download the artifact data from.
    signed_url: String,
}
235
/// High-level client for GitHub's artifact API: upload, list, and download artifacts
/// for the current workflow run.
pub struct GitHubClient {
    client: TwirpClient,
}
239
240impl GitHubClient {
241    pub fn new(token: &str, base_url: Url) -> Result<Self> {
242        Ok(Self {
243            client: TwirpClient::new(token, base_url)?,
244        })
245    }
246
247    /// Start an upload of an artifact. It returns a [`BlobClient`] which should be used to upload
248    /// your data. Once all the data has been written, [`Self::finish_upload`] must be called to
249    /// finalize the upload.
250    ///
251    /// The given name needs to be something unique, an error should be returned on collision.
252    pub async fn start_upload(
253        &self,
254        name: &str,
255        expires_at: Option<DateTime<Utc>>,
256    ) -> Result<BlobClient> {
257        let req = CreateArtifactRequest {
258            backend_ids: self.client.backend_ids.clone(),
259            name: name.into(),
260            expires_at,
261            version: 4,
262        };
263        let resp: CreateArtifactResponse = self
264            .client
265            .request(
266                "github.actions.results.api.v1.ArtifactService",
267                "CreateArtifact",
268                &req,
269            )
270            .await?;
271
272        let upload_url = url::Url::parse(&resp.signed_upload_url)?;
273        Ok(BlobClient::from_sas_url(&upload_url)?)
274    }
275
276    /// Meant to be called on an upload which was started via [`Self::start_upload`] which has had
277    /// all its data uploaded with the returned [`BlobClient`]. Once it returns success, the
278    /// artifact should be immediately available to be downloaded. If called on an artifact not in
279    /// this state, an error should be returned.
280    pub async fn finish_upload(&self, name: &str, content_length: usize) -> Result<()> {
281        let req = FinalizeArtifactRequest {
282            backend_ids: self.client.backend_ids.clone(),
283            name: name.into(),
284            size: content_length,
285        };
286        self.client
287            .request::<_, serde_json::Value>(
288                "github.actions.results.api.v1.ArtifactService",
289                "FinalizeArtifact",
290                &req,
291            )
292            .await?;
293        Ok(())
294    }
295
296    /// Upload the given content as an artifact. Once it returns success, the artifact should be
297    /// immediately available for download. The given content can be an in-memory buffer or a
298    /// [`FileStream`] created using [`FileStreamBuilder`].
299    pub async fn upload(
300        &self,
301        name: &str,
302        expires_at: Option<DateTime<Utc>>,
303        content: impl Into<Body>,
304    ) -> Result<()> {
305        let blob_client = self.start_upload(name, expires_at).await?;
306        let body: Body = content.into();
307        let size = body.len();
308        blob_client
309            .put_block_blob(body)
310            .content_type("application/octet-stream")
311            .await?;
312        self.finish_upload(name, size).await?;
313        Ok(())
314    }
315
316    async fn list_internal(
317        &self,
318        name_filter: Option<String>,
319        id_filter: Option<DatabaseId>,
320    ) -> Result<Vec<Artifact>> {
321        let req = ListArtifactsRequest {
322            backend_ids: self.client.backend_ids.clone(),
323            name_filter,
324            id_filter,
325        };
326        let resp: ListArtifactsResponse = self
327            .client
328            .request(
329                "github.actions.results.api.v1.ArtifactService",
330                "ListArtifacts",
331                &req,
332            )
333            .await?;
334        Ok(resp.artifacts)
335    }
336
337    /// List all the given artifacts accessible to the current workflow run.
338    pub async fn list(&self) -> Result<Vec<Artifact>> {
339        self.list_internal(None, None).await
340    }
341
342    /// Get the artifact represented by the given name if it exists.
343    pub async fn get(&self, name: &str) -> Result<Option<Artifact>> {
344        let mut artifacts = self.list_internal(Some(name.into()), None).await?;
345        if artifacts.is_empty() {
346            return Ok(None);
347        }
348        if artifacts.len() > 1 {
349            bail!("invalid filtered list response");
350        }
351        Ok(Some(artifacts.remove(0)))
352    }
353
354    /// Get the artifact represented by the given id if it exists.
355    pub async fn get_by_id(&self, id: DatabaseId) -> Result<Option<Artifact>> {
356        let mut artifacts = self.list_internal(None, Some(id)).await?;
357        if artifacts.is_empty() {
358            return Ok(None);
359        }
360        if artifacts.len() > 1 {
361            bail!("invalid filtered list response");
362        }
363        Ok(Some(artifacts.remove(0)))
364    }
365
366    /// Start a download of an artifact identified by the given name. The returned [`BlobClient`]
367    /// should be used to download all or part of the data.
368    ///
369    /// The `backend_ids` must be the ones for the artifact obtained from [`Self::list`]. An
370    /// individual uploader should end up with the same `backend_ids` for all artifacts it uploads.
371    pub async fn start_download(&self, backend_ids: BackendIds, name: &str) -> Result<BlobClient> {
372        let req = GetSignedArtifactUrlRequest {
373            backend_ids,
374            name: name.into(),
375        };
376        let resp: GetSignedArtifactUrlResponse = self
377            .client
378            .request(
379                "github.actions.results.api.v1.ArtifactService",
380                "GetSignedArtifactURL",
381                &req,
382            )
383            .await?;
384        let url = Url::parse(&resp.signed_url)?;
385        Ok(BlobClient::from_sas_url(&url)?)
386    }
387
388    /// Return a stream that downloads all the contents of the artifacts represented by the given
389    /// name.
390    ///
391    /// The `backend_ids` must be the ones for the artifact obtained from [`Self::list`]. An
392    /// individual uploader should end up with the same `backend_ids` for all artifacts it uploads.
393    pub async fn download(
394        &self,
395        backend_ids: BackendIds,
396        name: &str,
397    ) -> Result<impl AsyncRead + Unpin + Send + Sync + 'static> {
398        let blob_client = self.start_download(backend_ids, name).await?;
399        let mut page_stream = blob_client.get().chunk_size(u64::MAX).into_stream();
400        let single_page = page_stream
401            .next()
402            .await
403            .ok_or_else(|| anyhow!("missing response"))??;
404        Ok(single_page
405            .data
406            .map_err(|e| futures::io::Error::new(futures::io::ErrorKind::Other, e))
407            .into_async_read()
408            .compat())
409    }
410}
411
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::two_hours_from_now;

    // A canned Actions runtime token checked into the repository.
    const TEST_TOKEN: &str = include_str!("test_token.b64");

    #[test]
    fn backend_ids_from_str_canned_example() {
        let ids = BackendIds::from_str(TEST_TOKEN).unwrap();
        assert_eq!(
            ids,
            BackendIds {
                workflow_run_backend_id: "a4c8893f-39a2-4108-b278-a7d0fb589276".into(),
                workflow_job_run_backend_id: "5264e576-3c6f-51f6-f055-fab409685f20".into()
            }
        );
    }

    #[test]
    fn backend_ids_errors() {
        // Asserts parsing `s` fails with a message containing `expected_error`.
        fn test_error(s: &str, expected_error: &str) {
            let actual_error = BackendIds::from_str(s).unwrap_err().to_string();
            assert!(actual_error.contains(expected_error), "{actual_error}");
        }
        test_error("foobar", "missing period");
        test_error("foo.bar", "base64 invalid");
        test_error("foo.e30=", "base64 invalid: Invalid padding");
        test_error("foo.e30", "missing 'scp' field");
        test_error("foo.eyJzY3AiOjEyfQ", "'scp' field not a string");
        test_error(
            "foo.eyJzY3AiOiJmb28ifQ",
            "'Actions.Results' missing from 'scp' field",
        );
    }

    // Arbitrary payload for the round-trip test below.
    const TEST_DATA: &[u8] = include_bytes!("lib.rs");

    /// Builds a client from the workflow-run environment, or `None` when the GitHub
    /// credentials aren't present (e.g. when running locally).
    pub fn client_factory() -> Option<GitHubClient> {
        let token = std::env::var("ACTIONS_RUNTIME_TOKEN").ok()?;
        let base_url = Url::parse(&std::env::var("ACTIONS_RESULTS_URL").ok()?).unwrap();
        Some(GitHubClient::new(&token, base_url).unwrap())
    }

    /// Round-trips an artifact through the real GitHub API: upload, list, get by name
    /// and id, then download and compare. Skips itself outside of CI.
    #[tokio::test]
    async fn real_github_integration_test() {
        let Some(client) = client_factory() else {
            println!("skipping due to missing GitHub credentials");
            return;
        };
        println!("test found GitHub credentials");

        // Only the designated actor runs the mutating part. Skip — rather than panic
        // via unwrap — when TEST_ACTOR is unset, so runs with credentials but no actor
        // designation still pass.
        if std::env::var("TEST_ACTOR").as_deref() != Ok("1") {
            return;
        }

        client
            .upload("test_data", Some(two_hours_from_now()), TEST_DATA)
            .await
            .unwrap();

        let listing = client.list().await.unwrap();
        println!("got artifact listing {listing:?}");
        assert!(listing.iter().any(|a| a.name == "test_data"));

        let artifact = client.get("test_data").await.unwrap().unwrap();

        let artifact2 = client
            .get_by_id(artifact.database_id)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(&artifact, &artifact2);

        assert_eq!(client.get("this_does_not_exist").await.unwrap(), None);

        let backend_ids = &artifact.backend_ids;
        let mut download_stream = client
            .download(backend_ids.clone(), "test_data")
            .await
            .unwrap();

        let mut downloaded = vec![];
        tokio::io::copy(&mut download_stream, &mut downloaded)
            .await
            .unwrap();

        assert_eq!(downloaded, TEST_DATA);
    }
}
501}