Skip to main content

gitlab_runner/
client.rs

1use reqwest::multipart::{Form, Part};
2use reqwest::{Body, StatusCode};
3use serde::de::Deserializer;
4use serde::{Deserialize, Serialize};
5use serde_json::Value as JsonValue;
6use std::collections::HashMap;
7use std::ops::Not;
8use std::time::Duration;
9use thiserror::Error;
10use tokio::io::{AsyncWrite, AsyncWriteExt};
11use url::Url;
12use zip::result::ZipError;
13
14fn deserialize_null_default<'de, D, T>(deserializer: D) -> Result<T, D::Error>
15where
16    T: Default + Deserialize<'de>,
17    D: Deserializer<'de>,
18{
19    let opt = Option::deserialize(deserializer)?;
20    Ok(opt.unwrap_or_default())
21}
22
23const GITLAB_TRACE_UPDATE_INTERVAL: &str = "X-GitLab-Trace-Update-Interval";
24const JOB_STATUS: &str = "Job-Status";
25
26#[derive(Debug, Default, Clone, Serialize)]
27struct FeaturesInfo {
28    #[serde(skip_serializing_if = "Not::not")]
29    variables: bool,
30    #[serde(skip_serializing_if = "Not::not")]
31    image: bool,
32    #[serde(skip_serializing_if = "Not::not")]
33    services: bool,
34    #[serde(skip_serializing_if = "Not::not")]
35    artifacts: bool,
36    #[serde(skip_serializing_if = "Not::not")]
37    cache: bool,
38    #[serde(skip_serializing_if = "Not::not")]
39    shared: bool,
40    #[serde(skip_serializing_if = "Not::not")]
41    upload_multiple_artifacts: bool,
42    #[serde(skip_serializing_if = "Not::not")]
43    upload_raw_artifacts: bool,
44    #[serde(skip_serializing_if = "Not::not")]
45    session: bool,
46    #[serde(skip_serializing_if = "Not::not")]
47    terminal: bool,
48    #[serde(skip_serializing_if = "Not::not")]
49    refspecs: bool,
50    #[serde(skip_serializing_if = "Not::not")]
51    masking: bool,
52    #[serde(skip_serializing_if = "Not::not")]
53    proxy: bool,
54    #[serde(skip_serializing_if = "Not::not")]
55    raw_variables: bool,
56    #[serde(skip_serializing_if = "Not::not")]
57    artifacts_exclude: bool,
58    #[serde(skip_serializing_if = "Not::not")]
59    multi_build_steps: bool,
60    #[serde(skip_serializing_if = "Not::not")]
61    trace_reset: bool,
62    #[serde(skip_serializing_if = "Not::not")]
63    trace_checksum: bool,
64    #[serde(skip_serializing_if = "Not::not")]
65    trace_size: bool,
66    #[serde(skip_serializing_if = "Not::not")]
67    vault_secrets: bool,
68    #[serde(skip_serializing_if = "Not::not")]
69    cancelable: bool,
70    #[serde(skip_serializing_if = "Not::not")]
71    return_exit_code: bool,
72    #[serde(skip_serializing_if = "Not::not")]
73    service_variables: bool,
74}
75
76#[derive(Debug, Clone, Serialize)]
77struct VersionInfo<'a> {
78    #[serde(flatten)]
79    metadata: &'a ClientMetadata,
80    features: FeaturesInfo,
81}
82
83#[derive(Debug, Clone, Serialize)]
84struct JobRequest<'a> {
85    token: &'a str,
86    system_id: &'a str,
87    info: VersionInfo<'a>,
88}
89
90#[derive(Debug, Clone, Serialize)]
91#[serde(rename_all = "lowercase")]
92#[allow(dead_code)]
93pub enum JobState {
94    Pending,
95    Running,
96    Success,
97    Failed,
98}
99
100#[derive(Debug, Clone, Serialize)]
101#[serde(rename = "lower_case")]
102struct JobUpdate<'a> {
103    token: &'a str,
104    state: JobState,
105}
106
107#[derive(Debug, Clone)]
108pub struct JobUpdateReply {
109    // GitLabs job update endpoint can include a suggested request rate in the response's HTTP header.
110    // Currently we only use this value from trace calls (e.g. appending to the jobs log).
111    // This field is kept around to document it's existence.
112    #[allow(dead_code)]
113    pub trace_update_interval: Option<Duration>,
114}
115
116#[derive(Debug, Clone)]
117pub struct TraceReply {
118    pub trace_update_interval: Option<Duration>,
119}
120
121#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
122pub(crate) struct JobVariable {
123    pub key: String,
124    #[serde(deserialize_with = "deserialize_null_default")]
125    pub value: String,
126    pub public: bool,
127    pub masked: bool,
128}
129
130fn variable_hash<'de, D>(deserializer: D) -> Result<HashMap<String, JobVariable>, D::Error>
131where
132    D: Deserializer<'de>,
133{
134    let hash = Vec::<JobVariable>::deserialize(deserializer)?
135        .drain(..)
136        .map(|v| (v.key.clone(), v))
137        .collect();
138    Ok(hash)
139}
140
141#[derive(Copy, Clone, Deserialize, Debug, Eq, PartialEq)]
142#[serde(rename_all = "snake_case")]
143pub enum JobStepWhen {
144    Always,
145    OnFailure,
146    OnSuccess,
147}
148
149/// Phase of the gitlab job steps
150#[derive(Copy, Clone, Deserialize, Debug, Eq, PartialEq)]
151#[serde(rename_all = "snake_case")]
152pub enum Phase {
153    /// script step; Practically this is before_script + script as defined in the gitlab job yaml
154    Script,
155    /// after_script step
156    AfterScript,
157}
158
159#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
160pub(crate) struct JobStep {
161    pub name: Phase,
162    pub script: Vec<String>,
163    pub timeout: u32,
164    pub when: JobStepWhen,
165    pub allow_failure: bool,
166}
167
168#[derive(Copy, Clone, Deserialize, Debug, Eq, PartialEq)]
169#[serde(rename_all = "snake_case")]
170#[derive(Default)]
171pub enum ArtifactWhen {
172    Always,
173    OnFailure,
174    #[default]
175    OnSuccess,
176}
177
178#[derive(Copy, Clone, Deserialize, Debug, Eq, PartialEq)]
179#[serde(rename_all = "snake_case")]
180#[derive(Default)]
181pub enum ArtifactFormat {
182    #[default]
183    Zip,
184    Gzip,
185    Raw,
186}
187
188impl std::fmt::Display for ArtifactFormat {
189    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190        let s = match self {
191            Self::Zip => "zip",
192            Self::Gzip => "gzip",
193            Self::Raw => "raw",
194        };
195        write!(f, "{s}")
196    }
197}
198
199#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
200pub(crate) struct JobArtifact {
201    pub name: Option<String>,
202    #[serde(default, deserialize_with = "deserialize_null_default")]
203    pub untracked: bool,
204    pub paths: Vec<String>,
205    #[serde(deserialize_with = "deserialize_null_default")]
206    pub when: ArtifactWhen,
207    pub artifact_type: String,
208    #[serde(deserialize_with = "deserialize_null_default")]
209    pub artifact_format: ArtifactFormat,
210    pub expire_in: Option<String>,
211}
212
213#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
214pub(crate) struct JobArtifactFile {
215    pub filename: String,
216    pub size: usize,
217}
218
219#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
220pub(crate) struct JobDependency {
221    pub id: u64,
222    pub name: String,
223    pub token: String,
224    pub artifacts_file: Option<JobArtifactFile>,
225}
226
227#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
228pub(crate) struct JobResponse {
229    pub id: u64,
230    pub token: String,
231    pub allow_git_fetch: bool,
232    #[serde(deserialize_with = "variable_hash")]
233    pub variables: HashMap<String, JobVariable>,
234    pub steps: Vec<JobStep>,
235    #[serde(deserialize_with = "deserialize_null_default")]
236    pub dependencies: Vec<JobDependency>,
237    #[serde(deserialize_with = "deserialize_null_default")]
238    pub artifacts: Vec<JobArtifact>,
239    #[serde(flatten)]
240    unparsed: JsonValue,
241}
242
243impl JobResponse {
244    pub fn step(&self, name: Phase) -> Option<&JobStep> {
245        self.steps.iter().find(|s| s.name == name)
246    }
247}
248
249#[derive(Error, Debug)]
250pub enum Error {
251    #[error("Unexpected reply code {0}")]
252    UnexpectedStatus(StatusCode),
253    #[error("Job cancelled")]
254    JobCancelled,
255    #[error("Request failure {0}")]
256    Request(#[from] reqwest::Error),
257    #[error("Failed to write to destination {0}")]
258    WriteFailure(#[source] futures::io::Error),
259    #[error("Failed to parse zip file: {0}")]
260    ZipFile(#[from] ZipError),
261    #[error("Empty trace")]
262    EmptyTrace,
263}
264
265pub(crate) struct ArtifactInfo<'a> {
266    pub name: &'a str,
267    pub artifact_format: &'a str,
268    pub artifact_type: &'a str,
269    pub expire_in: Option<&'a str>,
270}
271
272#[derive(Clone, Debug, Default, Serialize)]
273pub(crate) struct ClientMetadata {
274    #[serde(skip_serializing_if = "Option::is_none")]
275    pub(crate) version: Option<String>,
276    #[serde(skip_serializing_if = "Option::is_none")]
277    pub(crate) revision: Option<String>,
278    #[serde(skip_serializing_if = "Option::is_none")]
279    pub(crate) platform: Option<String>,
280    #[serde(skip_serializing_if = "Option::is_none")]
281    pub(crate) architecture: Option<String>,
282}
283
284#[derive(Clone, Debug)]
285pub(crate) struct Client {
286    client: reqwest::Client,
287    url: Url,
288    token: String,
289    system_id: String,
290    metadata: ClientMetadata,
291}
292
293impl Client {
294    pub fn new(url: Url, token: String, system_id: String, metadata: ClientMetadata) -> Self {
295        Self {
296            client: reqwest::Client::new(),
297            url,
298            token,
299            system_id,
300            metadata,
301        }
302    }
303
304    pub async fn request_job(&self) -> Result<Option<JobResponse>, Error> {
305        let request = JobRequest {
306            token: &self.token,
307            system_id: &self.system_id,
308            info: VersionInfo {
309                // Setting `refspecs` is required to run detached MR pipelines.
310                features: FeaturesInfo {
311                    refspecs: true,
312                    upload_multiple_artifacts: true,
313                    ..Default::default()
314                },
315                metadata: &self.metadata,
316            },
317        };
318
319        let mut url = self.url.clone();
320        url.path_segments_mut()
321            .unwrap()
322            .extend(&["api", "v4", "jobs", "request"]);
323
324        let r = self
325            .client
326            .post(url)
327            .json(&request)
328            .send()
329            .await?
330            .error_for_status()?;
331
332        match r.status() {
333            StatusCode::CREATED => Ok(Some(r.json().await?)),
334            StatusCode::NO_CONTENT => Ok(None),
335            _ => Err(Error::UnexpectedStatus(r.status())),
336        }
337    }
338
339    fn check_for_job_cancellation(&self, response: &reqwest::Response) -> Result<(), Error> {
340        match response.headers().get(JOB_STATUS) {
341            Some(header) if header == "canceled" => Err(Error::JobCancelled),
342            _ => Ok(()),
343        }
344    }
345
346    pub async fn update_job(
347        &self,
348        id: u64,
349        token: &str,
350        state: JobState,
351    ) -> Result<JobUpdateReply, Error> {
352        let mut url = self.url.clone();
353        let id_s = format!("{id}");
354        url.path_segments_mut()
355            .unwrap()
356            .extend(&["api", "v4", "jobs", &id_s]);
357
358        let update = JobUpdate { token, state };
359
360        let r = self.client.put(url).json(&update).send().await?;
361
362        self.check_for_job_cancellation(&r)?;
363
364        let trace_update_interval = r
365            .headers()
366            .get(GITLAB_TRACE_UPDATE_INTERVAL)
367            .and_then(|v| Some(Duration::from_secs(v.to_str().ok()?.parse().ok()?)));
368        match r.status() {
369            StatusCode::OK => Ok(JobUpdateReply {
370                trace_update_interval,
371            }),
372            _ => Err(Error::UnexpectedStatus(r.status())),
373        }
374    }
375
376    pub async fn trace<B>(
377        &self,
378        id: u64,
379        token: &str,
380        body: B,
381        start: usize,
382        length: usize,
383    ) -> Result<TraceReply, Error>
384    where
385        B: Into<Body>,
386    {
387        if length == 0 {
388            return Err(Error::EmptyTrace);
389        }
390
391        let mut url = self.url.clone();
392        let id_s = format!("{id}");
393        url.path_segments_mut()
394            .unwrap()
395            .extend(&["api", "v4", "jobs", &id_s, "trace"]);
396
397        let range = format!("{}-{}", start, start + length - 1);
398
399        let r = self
400            .client
401            .patch(url)
402            .header("JOB-TOKEN", token)
403            .header(reqwest::header::CONTENT_RANGE, range)
404            .header(reqwest::header::CONTENT_TYPE, "text/plain")
405            .body(body)
406            .send()
407            .await?;
408
409        self.check_for_job_cancellation(&r)?;
410
411        let trace_update_interval = r
412            .headers()
413            .get(GITLAB_TRACE_UPDATE_INTERVAL)
414            .and_then(|v| Some(Duration::from_secs(v.to_str().ok()?.parse().ok()?)));
415
416        match r.status() {
417            StatusCode::ACCEPTED => Ok(TraceReply {
418                trace_update_interval,
419            }),
420            _ => Err(Error::UnexpectedStatus(r.status())),
421        }
422    }
423
424    pub async fn download_artifact<D: AsyncWrite + Unpin>(
425        &self,
426        id: u64,
427        token: &str,
428        mut dest: D,
429    ) -> Result<(), Error> {
430        let mut url = self.url.clone();
431        let id_s = format!("{id}");
432        url.path_segments_mut()
433            .unwrap()
434            .extend(&["api", "v4", "jobs", &id_s, "artifacts"]);
435
436        let mut r = self
437            .client
438            .get(url)
439            .header("JOB-TOKEN", token)
440            .send()
441            .await?;
442
443        match r.status() {
444            StatusCode::OK => {
445                while let Some(ref chunk) = r.chunk().await? {
446                    dest.write_all(chunk).await.map_err(Error::WriteFailure)?
447                }
448                Ok(())
449            }
450            _ => Err(Error::UnexpectedStatus(r.status())),
451        }
452    }
453
454    pub async fn upload_artifact<D>(
455        &self,
456        id: u64,
457        token: &str,
458        info: ArtifactInfo<'_>,
459        data: D,
460    ) -> Result<(), Error>
461    where
462        D: Into<Body>,
463    {
464        let part = Part::stream(data).file_name(info.name.to_string());
465        let form = Form::new()
466            .part("file", part)
467            .text("artifact_format", info.artifact_format.to_string())
468            .text("artifact_type", info.artifact_type.to_string());
469
470        let form = if let Some(expiry) = info.expire_in {
471            form.text("expire_in", expiry.to_string())
472        } else {
473            form
474        };
475
476        let mut url = self.url.clone();
477        let id_s = format!("{id}");
478        url.path_segments_mut()
479            .unwrap()
480            .extend(&["api", "v4", "jobs", &id_s, "artifacts"]);
481
482        let r = self
483            .client
484            .post(url)
485            .header("JOB-TOKEN", token)
486            .multipart(form)
487            .send()
488            .await?;
489
490        match r.status() {
491            StatusCode::CREATED => Ok(()),
492            _ => Err(Error::UnexpectedStatus(r.status())),
493        }
494    }
495}
496
497#[cfg(test)]
498mod test {
499    use super::*;
500    use gitlab_runner_mock::GitlabRunnerMock;
501    use serde_json::json;
502
503    #[test]
504    fn deserialize_variables() {
505        #[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
506        struct Test {
507            #[serde(deserialize_with = "variable_hash")]
508            variables: HashMap<String, JobVariable>,
509        }
510
511        let json = json!({
512            "variables": [
513                { "key": "VAR1", "value": "1", "public": true, "masked": false },
514                { "key": "VAR2", "value": "2", "public": false, "masked": true }
515            ]
516        });
517
518        let t: Test = serde_json::from_str(&json.to_string()).expect("Failed to deserialize json");
519        assert_eq!(2, t.variables.len());
520        let v = t.variables.get("VAR1").unwrap();
521        assert_eq!(
522            &JobVariable {
523                key: "VAR1".to_string(),
524                value: "1".to_string(),
525                public: true,
526                masked: false
527            },
528            v
529        );
530        let v = t.variables.get("VAR2").unwrap();
531        assert_eq!(
532            &JobVariable {
533                key: "VAR2".to_string(),
534                value: "2".to_string(),
535                public: false,
536                masked: true
537            },
538            v
539        );
540    }
541
542    #[tokio::test]
543    async fn no_job() {
544        let mock = GitlabRunnerMock::start().await;
545
546        let client = Client::new(
547            mock.uri(),
548            mock.runner_token().to_string(),
549            "s_ystem_id1234".to_string(),
550            ClientMetadata::default(),
551        );
552
553        let job = client.request_job().await.unwrap();
554
555        assert_eq!(None, job);
556    }
557
558    #[tokio::test]
559    async fn process_job() {
560        let mock = GitlabRunnerMock::start().await;
561        mock.add_dummy_job("process job".to_string());
562
563        let client = Client::new(
564            mock.uri(),
565            mock.runner_token().to_string(),
566            "s_ystem_id1234".to_string(),
567            ClientMetadata::default(),
568        );
569
570        if let Some(job) = client.request_job().await.unwrap() {
571            client
572                .update_job(job.id, &job.token, JobState::Success)
573                .await
574                .unwrap();
575        } else {
576            panic!("No job!")
577        }
578
579        let job = client.request_job().await.unwrap();
580        assert_eq!(None, job);
581    }
582}