1pub use azure_core::{
14 error::Result as AzureResult,
15 tokio::fs::{FileStream, FileStreamBuilder},
16 Body, SeekableStream,
17};
18
19use anyhow::{anyhow, bail, Result};
20use azure_storage_blobs::prelude::BlobClient;
21use chrono::{DateTime, Utc};
22use derive_more::From;
23use futures::{stream::TryStreamExt as _, StreamExt as _};
24use serde::{de::DeserializeOwned, Deserialize, Serialize};
25use serde_with::{serde_as, DisplayFromStr};
26use std::str::FromStr;
27use tokio::io::AsyncRead;
28use tokio_util::compat::FuturesAsyncReadCompatExt as _;
29use url::Url;
30
31#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
32#[serde(rename_all = "camelCase")]
33pub struct BackendIds {
34 pub workflow_run_backend_id: String,
35 pub workflow_job_run_backend_id: String,
36}
37
38impl FromStr for BackendIds {
39 type Err = anyhow::Error;
40
41 fn from_str(token: &str) -> Result<BackendIds> {
42 use base64::Engine as _;
43
44 let mut token_parts = token.split(".").skip(1);
45 let b64_part = token_parts
46 .next()
47 .ok_or_else(|| anyhow!("missing period"))?;
48 let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
49 .decode(b64_part)
50 .map_err(|e| anyhow!("base64 invalid: {e}"))?;
51 let v = serde_json::from_slice::<serde_json::Value>(&decoded)?;
52
53 let scp = v
54 .get("scp")
55 .ok_or_else(|| anyhow!("missing 'scp' field"))?
56 .as_str()
57 .ok_or_else(|| anyhow!("'scp' field not a string"))?;
58
59 let scope_parts = scp
60 .split(" ")
61 .map(|p| p.split(":").collect::<Vec<_>>())
62 .find(|p| p[0] == "Actions.Results")
63 .ok_or_else(|| anyhow!("'Actions.Results' missing from 'scp' field"))?;
64
65 Ok(Self {
66 workflow_run_backend_id: scope_parts[1].into(),
67 workflow_job_run_backend_id: scope_parts[2].into(),
68 })
69 }
70}
71
72struct TwirpClient {
73 client: reqwest::Client,
74 token: String,
75 base_url: Url,
76 backend_ids: BackendIds,
77}
78
79impl TwirpClient {
80 fn new(token: &str, base_url: Url) -> Result<Self> {
81 Ok(Self {
82 client: reqwest::Client::new(),
83 token: token.into(),
84 base_url,
85 backend_ids: token.parse()?,
86 })
87 }
88
89 async fn request<BodyT: Serialize, RespT: DeserializeOwned>(
90 &self,
91 service: &str,
92 method: &str,
93 body: &BodyT,
94 ) -> Result<RespT> {
95 let req = self
96 .client
97 .post(
98 self.base_url
99 .join(&format!("twirp/{service}/{method}"))
100 .unwrap(),
101 )
102 .header("Content-Type", "application/json")
103 .header("User-Agent", "@actions/artifact-2.1.11")
104 .header(
105 "Authorization",
106 &format!("Bearer {token}", token = &self.token),
107 )
108 .json(body);
109
110 let resp = req.send().await?;
111 if !resp.status().is_success() {
112 bail!("{}", resp.text().await.unwrap());
113 }
114
115 Ok(resp.json().await?)
116 }
117}
118
119fn rfc3339_encode<S>(v: &Option<DateTime<Utc>>, s: S) -> std::result::Result<S::Ok, S::Error>
120where
121 S: serde::Serializer,
122{
123 s.serialize_str(
124 &v.unwrap()
125 .to_rfc3339_opts(chrono::SecondsFormat::Millis, true),
126 )
127}
128
/// Body of the `CreateArtifact` RPC.
///
/// Backend IDs are flattened into the top-level object, and all keys are camelCase.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct CreateArtifactRequest {
    #[serde(flatten)]
    backend_ids: BackendIds,
    /// Name of the artifact to create.
    name: String,
    /// Optional expiry. Omitted when `None`; otherwise serialized by `rfc3339_encode` as an
    /// RFC 3339 string with millisecond precision and a `Z` suffix.
    #[serde(
        skip_serializing_if = "Option::is_none",
        serialize_with = "rfc3339_encode"
    )]
    expires_at: Option<DateTime<Utc>>,
    /// Version number sent to the service; the caller in this file always sets it to 4.
    version: u32,
}
142
/// Pins the JSON wire format of `CreateArtifactRequest`:
/// - backend IDs are flattened into the top level;
/// - keys are camelCase;
/// - `expiresAt` is an RFC 3339 string with milliseconds.
#[test]
fn create_artifact_json() {
    use chrono::TimeZone as _;

    let expiry = Utc.with_ymd_and_hms(2020, 1, 1, 0, 0, 0).unwrap();
    let ids = BackendIds {
        workflow_run_backend_id: "run_id".into(),
        workflow_job_run_backend_id: "job_id".into(),
    };
    let request = CreateArtifactRequest {
        backend_ids: ids,
        name: "foo".into(),
        expires_at: Some(expiry),
        version: 4,
    };

    let expected = serde_json::json!({
        "workflowRunBackendId": "run_id",
        "workflowJobRunBackendId": "job_id",
        "name": "foo",
        "expiresAt": "2020-01-01T00:00:00.000Z",
        "version": 4
    });
    let actual = serde_json::to_value(&request).unwrap();
    assert_eq!(actual, expected);
}
168
/// Body of the `FinalizeArtifact` RPC, sent after the blob contents have been uploaded.
///
/// Backend IDs are flattened into the top-level object, and all keys are camelCase.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct FinalizeArtifactRequest {
    #[serde(flatten)]
    backend_ids: BackendIds,
    /// Name of the artifact being finalized.
    name: String,
    /// Length of the uploaded content, as passed to `GitHubClient::finish_upload`.
    size: usize,
}
177
/// Body of the `ListArtifacts` RPC.
///
/// Backend IDs are flattened into the top-level object. Each filter is omitted from the JSON
/// when it is `None`.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct ListArtifactsRequest {
    #[serde(flatten)]
    backend_ids: BackendIds,
    /// Artifact name to filter by (`nameFilter`).
    #[serde(skip_serializing_if = "Option::is_none")]
    name_filter: Option<String>,
    /// Artifact ID to filter by (`idFilter`, serialized as a decimal string).
    #[serde(skip_serializing_if = "Option::is_none")]
    id_filter: Option<DatabaseId>,
}
188
/// Numeric artifact ID, as reported in artifact listings and accepted as a list filter.
///
/// Serialized transparently as a decimal string (e.g. `"42"`) rather than a JSON number.
#[serde_as]
#[derive(Copy, Clone, Debug, From, Serialize, Deserialize, PartialEq, Eq)]
#[serde(transparent)]
pub struct DatabaseId(#[serde_as(as = "DisplayFromStr")] i64);
193
/// An artifact entry from a `ListArtifacts` response.
///
/// Unlike the request types, these fields are read with snake_case keys (no `rename_all`).
/// The backend IDs come from flattened `workflow_run_backend_id` /
/// `workflow_job_run_backend_id` keys via `BackendIdsSnakeCase`.
#[serde_as]
#[derive(Debug, Deserialize, PartialEq, Eq)]
pub struct Artifact {
    /// Backend IDs of the run/job that owns the artifact; pass these to `download`.
    #[serde(flatten, with = "BackendIdsSnakeCase")]
    pub backend_ids: BackendIds,
    /// Artifact name.
    pub name: String,
    /// Artifact size, transmitted as a decimal string.
    #[serde_as(as = "DisplayFromStr")]
    pub size: i64,
    /// Artifact ID, usable with `get_by_id`.
    pub database_id: DatabaseId,
}
204
/// Serde "remote" mirror of [`BackendIds`] that uses the Rust (snake_case) field names as keys.
///
/// This lets response types such as [`Artifact`] (de)serialize backend IDs without the camelCase
/// renaming applied to `BackendIds` itself. Serde's remote derive requires the field names and
/// types to match `BackendIds` exactly, so keep the two in sync.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(remote = "BackendIds")]
struct BackendIdsSnakeCase {
    workflow_run_backend_id: String,
    workflow_job_run_backend_id: String,
}
211
/// Body of the `ListArtifacts` RPC response.
// `rename_all` has no effect here: `artifacts` is the same in snake_case and camelCase.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ListArtifactsResponse {
    /// Artifacts matching the request's filters (all of them when no filter was given).
    artifacts: Vec<Artifact>,
}
217
/// Body of the `GetSignedArtifactURL` RPC, identifying the artifact to download.
///
/// Unlike the other requests, the backend IDs come from the caller (e.g. from an [`Artifact`])
/// rather than from the client's own token.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GetSignedArtifactUrlRequest {
    #[serde(flatten)]
    backend_ids: BackendIds,
    /// Name of the artifact to download.
    name: String,
}
225
/// Body of the `CreateArtifact` RPC response (snake_case keys; no `rename_all`).
#[derive(Debug, Deserialize)]
struct CreateArtifactResponse {
    /// SAS URL that `start_upload` turns into a `BlobClient` for writing the contents.
    signed_upload_url: String,
}
230
/// Body of the `GetSignedArtifactURL` RPC response (snake_case keys; no `rename_all`).
#[derive(Debug, Deserialize)]
struct GetSignedArtifactUrlResponse {
    /// SAS URL that `start_download` turns into a `BlobClient` for reading the contents.
    signed_url: String,
}
235
236pub struct GitHubClient {
237 client: TwirpClient,
238}
239
240impl GitHubClient {
241 pub fn new(token: &str, base_url: Url) -> Result<Self> {
242 Ok(Self {
243 client: TwirpClient::new(token, base_url)?,
244 })
245 }
246
247 pub async fn start_upload(
253 &self,
254 name: &str,
255 expires_at: Option<DateTime<Utc>>,
256 ) -> Result<BlobClient> {
257 let req = CreateArtifactRequest {
258 backend_ids: self.client.backend_ids.clone(),
259 name: name.into(),
260 expires_at,
261 version: 4,
262 };
263 let resp: CreateArtifactResponse = self
264 .client
265 .request(
266 "github.actions.results.api.v1.ArtifactService",
267 "CreateArtifact",
268 &req,
269 )
270 .await?;
271
272 let upload_url = url::Url::parse(&resp.signed_upload_url)?;
273 Ok(BlobClient::from_sas_url(&upload_url)?)
274 }
275
276 pub async fn finish_upload(&self, name: &str, content_length: usize) -> Result<()> {
281 let req = FinalizeArtifactRequest {
282 backend_ids: self.client.backend_ids.clone(),
283 name: name.into(),
284 size: content_length,
285 };
286 self.client
287 .request::<_, serde_json::Value>(
288 "github.actions.results.api.v1.ArtifactService",
289 "FinalizeArtifact",
290 &req,
291 )
292 .await?;
293 Ok(())
294 }
295
296 pub async fn upload(
300 &self,
301 name: &str,
302 expires_at: Option<DateTime<Utc>>,
303 content: impl Into<Body>,
304 ) -> Result<()> {
305 let blob_client = self.start_upload(name, expires_at).await?;
306 let body: Body = content.into();
307 let size = body.len();
308 blob_client
309 .put_block_blob(body)
310 .content_type("application/octet-stream")
311 .await?;
312 self.finish_upload(name, size).await?;
313 Ok(())
314 }
315
316 async fn list_internal(
317 &self,
318 name_filter: Option<String>,
319 id_filter: Option<DatabaseId>,
320 ) -> Result<Vec<Artifact>> {
321 let req = ListArtifactsRequest {
322 backend_ids: self.client.backend_ids.clone(),
323 name_filter,
324 id_filter,
325 };
326 let resp: ListArtifactsResponse = self
327 .client
328 .request(
329 "github.actions.results.api.v1.ArtifactService",
330 "ListArtifacts",
331 &req,
332 )
333 .await?;
334 Ok(resp.artifacts)
335 }
336
337 pub async fn list(&self) -> Result<Vec<Artifact>> {
339 self.list_internal(None, None).await
340 }
341
342 pub async fn get(&self, name: &str) -> Result<Option<Artifact>> {
344 let mut artifacts = self.list_internal(Some(name.into()), None).await?;
345 if artifacts.is_empty() {
346 return Ok(None);
347 }
348 if artifacts.len() > 1 {
349 bail!("invalid filtered list response");
350 }
351 Ok(Some(artifacts.remove(0)))
352 }
353
354 pub async fn get_by_id(&self, id: DatabaseId) -> Result<Option<Artifact>> {
356 let mut artifacts = self.list_internal(None, Some(id)).await?;
357 if artifacts.is_empty() {
358 return Ok(None);
359 }
360 if artifacts.len() > 1 {
361 bail!("invalid filtered list response");
362 }
363 Ok(Some(artifacts.remove(0)))
364 }
365
366 pub async fn start_download(&self, backend_ids: BackendIds, name: &str) -> Result<BlobClient> {
372 let req = GetSignedArtifactUrlRequest {
373 backend_ids,
374 name: name.into(),
375 };
376 let resp: GetSignedArtifactUrlResponse = self
377 .client
378 .request(
379 "github.actions.results.api.v1.ArtifactService",
380 "GetSignedArtifactURL",
381 &req,
382 )
383 .await?;
384 let url = Url::parse(&resp.signed_url)?;
385 Ok(BlobClient::from_sas_url(&url)?)
386 }
387
388 pub async fn download(
394 &self,
395 backend_ids: BackendIds,
396 name: &str,
397 ) -> Result<impl AsyncRead + Unpin + Send + Sync + 'static> {
398 let blob_client = self.start_download(backend_ids, name).await?;
399 let mut page_stream = blob_client.get().chunk_size(u64::MAX).into_stream();
400 let single_page = page_stream
401 .next()
402 .await
403 .ok_or_else(|| anyhow!("missing response"))??;
404 Ok(single_page
405 .data
406 .map_err(|e| futures::io::Error::new(futures::io::ErrorKind::Other, e))
407 .into_async_read()
408 .compat())
409 }
410}
411
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::two_hours_from_now;

    /// Canned runtime token whose `Actions.Results` scope carries known backend IDs.
    const TEST_TOKEN: &str = include_str!("test_token.b64");

    #[test]
    fn backend_ids_from_str_canned_example() {
        let ids = BackendIds::from_str(TEST_TOKEN).unwrap();
        assert_eq!(
            ids,
            BackendIds {
                workflow_run_backend_id: "a4c8893f-39a2-4108-b278-a7d0fb589276".into(),
                workflow_job_run_backend_id: "5264e576-3c6f-51f6-f055-fab409685f20".into()
            }
        );
    }

    #[test]
    fn backend_ids_errors() {
        /// Asserts that parsing `s` fails with a message containing `expected_error`.
        fn test_error(s: &str, expected_error: &str) {
            let actual_error = BackendIds::from_str(s).unwrap_err().to_string();
            assert!(actual_error.contains(expected_error), "{actual_error}");
        }
        test_error("foobar", "missing period");
        test_error("foo.bar", "base64 invalid");
        test_error("foo.e30=", "base64 invalid: Invalid padding");
        // "e30" decodes to `{}`.
        test_error("foo.e30", "missing 'scp' field");
        // Decodes to `{"scp":12}`.
        test_error("foo.eyJzY3AiOjEyfQ", "'scp' field not a string");
        // Decodes to `{"scp":"foo"}`.
        test_error(
            "foo.eyJzY3AiOiJmb28ifQ",
            "'Actions.Results' missing from 'scp' field",
        );
    }

    /// Payload for the upload/download round trip: the bytes of `lib.rs`.
    const TEST_DATA: &[u8] = include_bytes!("lib.rs");

    /// Builds a client from `ACTIONS_RUNTIME_TOKEN` and `ACTIONS_RESULTS_URL`.
    ///
    /// Returns `None` when either variable is unset.
    ///
    /// # Panics
    ///
    /// Panics if `ACTIONS_RESULTS_URL` is not a valid URL or the token cannot be parsed.
    pub fn client_factory() -> Option<GitHubClient> {
        let token = std::env::var("ACTIONS_RUNTIME_TOKEN").ok()?;
        let base_url = Url::parse(&std::env::var("ACTIONS_RESULTS_URL").ok()?).unwrap();
        Some(GitHubClient::new(&token, base_url).unwrap())
    }

    /// End-to-end round trip against the real service; skipped when credentials are absent.
    #[tokio::test]
    async fn real_github_integration_test() {
        let Some(client) = client_factory() else {
            println!("skipping due to missing GitHub credentials");
            return;
        };
        println!("test found GitHub credentials");

        // Only the job with TEST_ACTOR=1 talks to the service. NOTE(review): presumably this
        // keeps parallel jobs from racing on the fixed "test_data" artifact name.
        let test_actor = std::env::var("TEST_ACTOR")
            .expect("TEST_ACTOR must be set when GitHub credentials are present");
        if test_actor != "1" {
            return;
        }

        client
            .upload("test_data", Some(two_hours_from_now()), TEST_DATA)
            .await
            .unwrap();

        let listing = client.list().await.unwrap();
        println!("got artifact listing {listing:?}");
        assert!(listing.iter().any(|a| a.name == "test_data"));

        let artifact = client.get("test_data").await.unwrap().unwrap();

        let artifact2 = client
            .get_by_id(artifact.database_id)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(&artifact, &artifact2);

        assert_eq!(client.get("this_does_not_exist").await.unwrap(), None);

        let mut download_stream = client
            .download(artifact.backend_ids.clone(), "test_data")
            .await
            .unwrap();

        let mut downloaded = vec![];
        tokio::io::copy(&mut download_stream, &mut downloaded)
            .await
            .unwrap();

        assert_eq!(downloaded, TEST_DATA);
    }
}
501}