gitlab_runner/
client.rs

1use reqwest::multipart::{Form, Part};
2use reqwest::{Body, StatusCode};
3use serde::de::Deserializer;
4use serde::{Deserialize, Serialize};
5use serde_json::Value as JsonValue;
6use std::collections::HashMap;
7use std::ops::Not;
8use std::time::Duration;
9use thiserror::Error;
10use tokio::io::{AsyncWrite, AsyncWriteExt};
11use url::Url;
12use zip::result::ZipError;
13
14fn deserialize_null_default<'de, D, T>(deserializer: D) -> Result<T, D::Error>
15where
16    T: Default + Deserialize<'de>,
17    D: Deserializer<'de>,
18{
19    let opt = Option::deserialize(deserializer)?;
20    Ok(opt.unwrap_or_default())
21}
22
/// Response header carrying the job's status; a value of `canceled` means the job was canceled.
const JOB_STATUS: &str = "Job-Status";
/// Response header carrying GitLab's suggested delay, in whole seconds, between trace updates.
const GITLAB_TRACE_UPDATE_INTERVAL: &str = "X-GitLab-Trace-Update-Interval";
25
/// Feature flags the runner advertises to GitLab as part of a job request.
///
/// Every flag is a plain `bool` and `Default` leaves them all `false`. Flags that are `false`
/// are skipped during serialization (`skip_serializing_if = "Not::not"`), so only the enabled
/// features appear in the JSON that is sent.
#[derive(Debug, Default, Clone, Serialize)]
struct FeaturesInfo {
    #[serde(skip_serializing_if = "Not::not")]
    variables: bool,
    #[serde(skip_serializing_if = "Not::not")]
    image: bool,
    #[serde(skip_serializing_if = "Not::not")]
    services: bool,
    #[serde(skip_serializing_if = "Not::not")]
    artifacts: bool,
    #[serde(skip_serializing_if = "Not::not")]
    cache: bool,
    #[serde(skip_serializing_if = "Not::not")]
    shared: bool,
    #[serde(skip_serializing_if = "Not::not")]
    upload_multiple_artifacts: bool,
    #[serde(skip_serializing_if = "Not::not")]
    upload_raw_artifacts: bool,
    #[serde(skip_serializing_if = "Not::not")]
    session: bool,
    #[serde(skip_serializing_if = "Not::not")]
    terminal: bool,
    #[serde(skip_serializing_if = "Not::not")]
    refspecs: bool,
    #[serde(skip_serializing_if = "Not::not")]
    masking: bool,
    #[serde(skip_serializing_if = "Not::not")]
    proxy: bool,
    #[serde(skip_serializing_if = "Not::not")]
    raw_variables: bool,
    #[serde(skip_serializing_if = "Not::not")]
    artifacts_exclude: bool,
    #[serde(skip_serializing_if = "Not::not")]
    multi_build_steps: bool,
    #[serde(skip_serializing_if = "Not::not")]
    trace_reset: bool,
    #[serde(skip_serializing_if = "Not::not")]
    trace_checksum: bool,
    #[serde(skip_serializing_if = "Not::not")]
    trace_size: bool,
    #[serde(skip_serializing_if = "Not::not")]
    vault_secrets: bool,
    #[serde(skip_serializing_if = "Not::not")]
    cancelable: bool,
    #[serde(skip_serializing_if = "Not::not")]
    return_exit_code: bool,
    #[serde(skip_serializing_if = "Not::not")]
    service_variables: bool,
}
75
/// Runner information sent in the `info` object of a job request.
#[derive(Debug, Clone, Serialize)]
struct VersionInfo<'a> {
    // The metadata fields are flattened, i.e. serialized directly into this object rather
    // than nested under a `metadata` key.
    #[serde(flatten)]
    metadata: &'a ClientMetadata,
    /// Advertised feature flags, serialized under `features`.
    features: FeaturesInfo,
}
82
/// JSON body posted to `api/v4/jobs/request` to ask GitLab for a job.
#[derive(Debug, Clone, Serialize)]
struct JobRequest<'a> {
    /// Runner token.
    token: &'a str,
    /// Runner system id.
    system_id: &'a str,
    /// Runner metadata and advertised features.
    info: VersionInfo<'a>,
}
89
/// Job state reported to GitLab in a job update.
///
/// Serialized in lowercase (`"pending"`, `"running"`, `"success"`, `"failed"`).
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "lowercase")]
// NOTE(review): presumably not every variant is constructed inside the crate, hence the
// `dead_code` allowance — confirm before removing it.
#[allow(dead_code)]
pub enum JobState {
    Pending,
    Running,
    Success,
    Failed,
}
99
100#[derive(Debug, Clone, Serialize)]
101#[serde(rename = "lower_case")]
102struct JobUpdate<'a> {
103    token: &'a str,
104    state: JobState,
105}
106
/// Result of a successful job update.
#[derive(Debug, Clone)]
pub struct JobUpdateReply {
    // GitLab's job update endpoint can include a suggested request rate in the response's
    // HTTP header. Currently we only use this value from trace calls (e.g. appending to the
    // job's log). This field is kept around to document its existence.
    #[allow(dead_code)]
    pub trace_update_interval: Option<Duration>,
}
115
/// Result of a successful trace (job log) update.
#[derive(Debug, Clone)]
pub struct TraceReply {
    /// Suggested delay before the next trace update, taken from the
    /// `X-GitLab-Trace-Update-Interval` response header. `None` if the header is absent or
    /// cannot be parsed as whole seconds.
    pub trace_update_interval: Option<Duration>,
}
120
/// A single variable attached to a job in GitLab's job payload.
#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
pub(crate) struct JobVariable {
    /// Variable name; also used as the key when variables are collected into a map.
    pub key: String,
    /// Variable value; an explicit `null` becomes the empty string.
    #[serde(deserialize_with = "deserialize_null_default")]
    pub value: String,
    // `public` and `masked` are taken as sent by GitLab; how they are honoured is decided
    // by the code that consumes the job, not here.
    pub public: bool,
    pub masked: bool,
}
129
130fn variable_hash<'de, D>(deserializer: D) -> Result<HashMap<String, JobVariable>, D::Error>
131where
132    D: Deserializer<'de>,
133{
134    let hash = Vec::<JobVariable>::deserialize(deserializer)?
135        .drain(..)
136        .map(|v| (v.key.clone(), v))
137        .collect();
138    Ok(hash)
139}
140
/// Run condition attached to a job step, deserialized from snake_case names
/// (`"always"`, `"on_failure"`, `"on_success"`).
#[derive(Copy, Clone, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum JobStepWhen {
    Always,
    OnFailure,
    OnSuccess,
}
148
/// Phase of a GitLab job's steps, deserialized from snake_case names
/// (`"script"`, `"after_script"`).
#[derive(Copy, Clone, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum Phase {
    /// The script step; in practice this is `before_script` + `script` as defined in the
    /// GitLab job YAML.
    Script,
    /// The `after_script` step.
    AfterScript,
}
158
/// One step of a job as sent by GitLab.
#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
pub(crate) struct JobStep {
    /// Phase this step belongs to; steps are looked up by this value.
    pub name: Phase,
    /// Script lines of the step.
    pub script: Vec<String>,
    /// Step timeout — presumably in seconds; TODO confirm against the GitLab runner API.
    pub timeout: u32,
    /// Condition under which the step runs.
    pub when: JobStepWhen,
    pub allow_failure: bool,
}
167
168#[derive(Copy, Clone, Deserialize, Debug, Eq, PartialEq)]
169#[serde(rename_all = "snake_case")]
170pub enum ArtifactWhen {
171    Always,
172    OnFailure,
173    OnSuccess,
174}
175
176impl Default for ArtifactWhen {
177    fn default() -> Self {
178        Self::OnSuccess
179    }
180}
181
182#[derive(Copy, Clone, Deserialize, Debug, Eq, PartialEq)]
183#[serde(rename_all = "snake_case")]
184pub enum ArtifactFormat {
185    Zip,
186    Gzip,
187    Raw,
188}
189
190impl Default for ArtifactFormat {
191    fn default() -> Self {
192        Self::Zip
193    }
194}
195
196impl std::fmt::Display for ArtifactFormat {
197    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
198        let s = match self {
199            Self::Zip => "zip",
200            Self::Gzip => "gzip",
201            Self::Raw => "raw",
202        };
203        write!(f, "{}", s)
204    }
205}
206
207#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
208pub(crate) struct JobArtifact {
209    pub name: Option<String>,
210    #[serde(default, deserialize_with = "deserialize_null_default")]
211    pub untracked: bool,
212    pub paths: Vec<String>,
213    #[serde(deserialize_with = "deserialize_null_default")]
214    pub when: ArtifactWhen,
215    pub artifact_type: String,
216    #[serde(deserialize_with = "deserialize_null_default")]
217    pub artifact_format: ArtifactFormat,
218    pub expire_in: Option<String>,
219}
220
/// Artifact archive attached to a dependency job.
#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
pub(crate) struct JobArtifactFile {
    pub filename: String,
    /// Archive size — presumably in bytes; TODO confirm against the GitLab runner API.
    pub size: usize,
}
226
/// Another job this job depends on, as listed in the job payload.
#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
pub(crate) struct JobDependency {
    /// ID of the dependency job.
    pub id: u64,
    pub name: String,
    /// Token of the dependency job — presumably what authenticates downloading its
    /// artifacts; verify against the caller.
    pub token: String,
    /// The dependency's artifact archive, if it has one.
    pub artifacts_file: Option<JobArtifactFile>,
}
234
235#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
236pub(crate) struct JobResponse {
237    pub id: u64,
238    pub token: String,
239    pub allow_git_fetch: bool,
240    #[serde(deserialize_with = "variable_hash")]
241    pub variables: HashMap<String, JobVariable>,
242    pub steps: Vec<JobStep>,
243    #[serde(deserialize_with = "deserialize_null_default")]
244    pub dependencies: Vec<JobDependency>,
245    #[serde(deserialize_with = "deserialize_null_default")]
246    pub artifacts: Vec<JobArtifact>,
247    #[serde(flatten)]
248    unparsed: JsonValue,
249}
250
251impl JobResponse {
252    pub fn step(&self, name: Phase) -> Option<&JobStep> {
253        self.steps.iter().find(|s| s.name == name)
254    }
255}
256
257#[derive(Error, Debug)]
258pub enum Error {
259    #[error("Unexpected reply code {0}")]
260    UnexpectedStatus(StatusCode),
261    #[error("Job cancelled")]
262    JobCancelled,
263    #[error("Request failure {0}")]
264    Request(#[from] reqwest::Error),
265    #[error("Failed to write to destination {0}")]
266    WriteFailure(#[source] futures::io::Error),
267    #[error("Failed to parse zip file: {0}")]
268    ZipFile(#[from] ZipError),
269    #[error("Empty trace")]
270    EmptyTrace,
271}
272
/// Metadata sent as multipart form fields alongside an artifact upload.
pub(crate) struct ArtifactInfo<'a> {
    /// File name given to the uploaded `file` part.
    pub name: &'a str,
    /// Sent as the `artifact_format` form field.
    pub artifact_format: &'a str,
    /// Sent as the `artifact_type` form field.
    pub artifact_type: &'a str,
    /// Sent as the `expire_in` form field when present; omitted otherwise.
    pub expire_in: Option<&'a str>,
}
279
/// Optional runner details reported to GitLab with every job request.
///
/// The fields are flattened into the request's `info` object; `None` fields are left out.
#[derive(Clone, Debug, Default, Serialize)]
pub(crate) struct ClientMetadata {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) version: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) revision: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) platform: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) architecture: Option<String>,
}
291
/// HTTP client for GitLab's runner job API (`api/v4/jobs/...`).
#[derive(Clone, Debug)]
pub(crate) struct Client {
    /// Underlying HTTP client used for every request.
    client: reqwest::Client,
    /// Base URL of the GitLab instance; API path segments are appended to it.
    url: Url,
    /// Runner token, sent with job requests.
    token: String,
    /// Runner system id, sent with job requests.
    system_id: String,
    /// Runner details, sent with job requests.
    metadata: ClientMetadata,
}
300
301impl Client {
302    pub fn new(url: Url, token: String, system_id: String, metadata: ClientMetadata) -> Self {
303        Self {
304            client: reqwest::Client::new(),
305            url,
306            token,
307            system_id,
308            metadata,
309        }
310    }
311
312    pub async fn request_job(&self) -> Result<Option<JobResponse>, Error> {
313        let request = JobRequest {
314            token: &self.token,
315            system_id: &self.system_id,
316            info: VersionInfo {
317                // Setting `refspecs` is required to run detached MR pipelines.
318                features: FeaturesInfo {
319                    refspecs: true,
320                    upload_multiple_artifacts: true,
321                    ..Default::default()
322                },
323                metadata: &self.metadata,
324            },
325        };
326
327        let mut url = self.url.clone();
328        url.path_segments_mut()
329            .unwrap()
330            .extend(&["api", "v4", "jobs", "request"]);
331
332        let r = self
333            .client
334            .post(url)
335            .json(&request)
336            .send()
337            .await?
338            .error_for_status()?;
339
340        match r.status() {
341            StatusCode::CREATED => Ok(Some(r.json().await?)),
342            StatusCode::NO_CONTENT => Ok(None),
343            _ => Err(Error::UnexpectedStatus(r.status())),
344        }
345    }
346
347    fn check_for_job_cancellation(&self, response: &reqwest::Response) -> Result<(), Error> {
348        match response.headers().get(JOB_STATUS) {
349            Some(header) if header == "canceled" => Err(Error::JobCancelled),
350            _ => Ok(()),
351        }
352    }
353
354    pub async fn update_job(
355        &self,
356        id: u64,
357        token: &str,
358        state: JobState,
359    ) -> Result<JobUpdateReply, Error> {
360        let mut url = self.url.clone();
361        let id_s = format!("{}", id);
362        url.path_segments_mut()
363            .unwrap()
364            .extend(&["api", "v4", "jobs", &id_s]);
365
366        let update = JobUpdate { token, state };
367
368        let r = self.client.put(url).json(&update).send().await?;
369
370        self.check_for_job_cancellation(&r)?;
371
372        let trace_update_interval = r
373            .headers()
374            .get(GITLAB_TRACE_UPDATE_INTERVAL)
375            .and_then(|v| Some(Duration::from_secs(v.to_str().ok()?.parse().ok()?)));
376        match r.status() {
377            StatusCode::OK => Ok(JobUpdateReply {
378                trace_update_interval,
379            }),
380            _ => Err(Error::UnexpectedStatus(r.status())),
381        }
382    }
383
384    pub async fn trace<B>(
385        &self,
386        id: u64,
387        token: &str,
388        body: B,
389        start: usize,
390        length: usize,
391    ) -> Result<TraceReply, Error>
392    where
393        B: Into<Body>,
394    {
395        if length == 0 {
396            return Err(Error::EmptyTrace);
397        }
398
399        let mut url = self.url.clone();
400        let id_s = format!("{}", id);
401        url.path_segments_mut()
402            .unwrap()
403            .extend(&["api", "v4", "jobs", &id_s, "trace"]);
404
405        let range = format!("{}-{}", start, start + length - 1);
406
407        let r = self
408            .client
409            .patch(url)
410            .header("JOB-TOKEN", token)
411            .header(reqwest::header::CONTENT_RANGE, range)
412            .body(body)
413            .send()
414            .await?;
415
416        self.check_for_job_cancellation(&r)?;
417
418        let trace_update_interval = r
419            .headers()
420            .get(GITLAB_TRACE_UPDATE_INTERVAL)
421            .and_then(|v| Some(Duration::from_secs(v.to_str().ok()?.parse().ok()?)));
422
423        match r.status() {
424            StatusCode::ACCEPTED => Ok(TraceReply {
425                trace_update_interval,
426            }),
427            _ => Err(Error::UnexpectedStatus(r.status())),
428        }
429    }
430
431    pub async fn download_artifact<D: AsyncWrite + Unpin>(
432        &self,
433        id: u64,
434        token: &str,
435        mut dest: D,
436    ) -> Result<(), Error> {
437        let mut url = self.url.clone();
438        let id_s = format!("{}", id);
439        url.path_segments_mut()
440            .unwrap()
441            .extend(&["api", "v4", "jobs", &id_s, "artifacts"]);
442
443        let mut r = self
444            .client
445            .get(url)
446            .header("JOB-TOKEN", token)
447            .send()
448            .await?;
449
450        match r.status() {
451            StatusCode::OK => {
452                while let Some(ref chunk) = r.chunk().await? {
453                    dest.write_all(chunk).await.map_err(Error::WriteFailure)?
454                }
455                Ok(())
456            }
457            _ => Err(Error::UnexpectedStatus(r.status())),
458        }
459    }
460
461    pub async fn upload_artifact<D>(
462        &self,
463        id: u64,
464        token: &str,
465        info: ArtifactInfo<'_>,
466        data: D,
467    ) -> Result<(), Error>
468    where
469        D: Into<Body>,
470    {
471        let part = Part::stream(data).file_name(info.name.to_string());
472        let form = Form::new()
473            .part("file", part)
474            .text("artifact_format", info.artifact_format.to_string())
475            .text("artifact_type", info.artifact_type.to_string());
476
477        let form = if let Some(expiry) = info.expire_in {
478            form.text("expire_in", expiry.to_string())
479        } else {
480            form
481        };
482
483        let mut url = self.url.clone();
484        let id_s = format!("{}", id);
485        url.path_segments_mut()
486            .unwrap()
487            .extend(&["api", "v4", "jobs", &id_s, "artifacts"]);
488
489        let r = self
490            .client
491            .post(url)
492            .header("JOB-TOKEN", token)
493            .multipart(form)
494            .send()
495            .await?;
496
497        match r.status() {
498            StatusCode::CREATED => Ok(()),
499            _ => Err(Error::UnexpectedStatus(r.status())),
500        }
501    }
502}
503
#[cfg(test)]
mod test {
    use super::*;
    use gitlab_runner_mock::GitlabRunnerMock;
    use serde_json::json;

    /// Build a client aimed at `mock`, with a fixed system id and default metadata.
    fn mock_client(mock: &GitlabRunnerMock) -> Client {
        Client::new(
            mock.uri(),
            mock.runner_token().to_string(),
            "s_ystem_id1234".to_string(),
            ClientMetadata::default(),
        )
    }

    #[test]
    fn deserialize_variables() {
        #[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
        struct Test {
            #[serde(deserialize_with = "variable_hash")]
            variables: HashMap<String, JobVariable>,
        }

        let json = json!({
            "variables": [
                { "key": "VAR1", "value": "1", "public": true, "masked": false },
                { "key": "VAR2", "value": "2", "public": false, "masked": true }
            ]
        });

        let t: Test = serde_json::from_str(&json.to_string()).expect("Failed to deserialize json");
        assert_eq!(2, t.variables.len());

        let expected = [
            JobVariable {
                key: "VAR1".to_string(),
                value: "1".to_string(),
                public: true,
                masked: false,
            },
            JobVariable {
                key: "VAR2".to_string(),
                value: "2".to_string(),
                public: false,
                masked: true,
            },
        ];
        for want in &expected {
            let got = t.variables.get(&want.key).unwrap();
            assert_eq!(want, got);
        }
    }

    #[tokio::test]
    async fn no_job() {
        let mock = GitlabRunnerMock::start().await;
        let client = mock_client(&mock);

        assert_eq!(None, client.request_job().await.unwrap());
    }

    #[tokio::test]
    async fn process_job() {
        let mock = GitlabRunnerMock::start().await;
        mock.add_dummy_job("process job".to_string());
        let client = mock_client(&mock);

        let job = client.request_job().await.unwrap().expect("No job!");
        client
            .update_job(job.id, &job.token, JobState::Success)
            .await
            .unwrap();

        // Once the only job is finished, requesting again must yield nothing.
        assert_eq!(None, client.request_job().await.unwrap());
    }
}