Skip to main content

gitlab_runner/
client.rs

1use reqwest::multipart::{Form, Part};
2use reqwest::{Body, StatusCode};
3use serde::de::Deserializer;
4use serde::{Deserialize, Serialize};
5use serde_json::Value as JsonValue;
6use std::collections::HashMap;
7use std::ops::Not;
8use std::time::Duration;
9use thiserror::Error;
10use tokio::io::{AsyncWrite, AsyncWriteExt};
11use url::Url;
12use zip::result::ZipError;
13
14fn deserialize_null_default<'de, D, T>(deserializer: D) -> Result<T, D::Error>
15where
16    T: Default + Deserialize<'de>,
17    D: Deserializer<'de>,
18{
19    let opt = Option::deserialize(deserializer)?;
20    Ok(opt.unwrap_or_default())
21}
22
/// Response header carrying GitLab's suggested update interval; this client parses it as a
/// whole number of seconds.
const GITLAB_TRACE_UPDATE_INTERVAL: &str = "X-GitLab-Trace-Update-Interval";

/// Response header reporting the job status; the value `canceled` is treated as cancellation.
const JOB_STATUS: &str = "Job-Status";
25
26#[derive(Debug, Default, Clone, Serialize)]
27struct FeaturesInfo {
28    #[serde(skip_serializing_if = "Not::not")]
29    variables: bool,
30    #[serde(skip_serializing_if = "Not::not")]
31    image: bool,
32    #[serde(skip_serializing_if = "Not::not")]
33    services: bool,
34    #[serde(skip_serializing_if = "Not::not")]
35    artifacts: bool,
36    #[serde(skip_serializing_if = "Not::not")]
37    cache: bool,
38    #[serde(skip_serializing_if = "Not::not")]
39    shared: bool,
40    #[serde(skip_serializing_if = "Not::not")]
41    upload_multiple_artifacts: bool,
42    #[serde(skip_serializing_if = "Not::not")]
43    upload_raw_artifacts: bool,
44    #[serde(skip_serializing_if = "Not::not")]
45    session: bool,
46    #[serde(skip_serializing_if = "Not::not")]
47    terminal: bool,
48    #[serde(skip_serializing_if = "Not::not")]
49    refspecs: bool,
50    #[serde(skip_serializing_if = "Not::not")]
51    masking: bool,
52    #[serde(skip_serializing_if = "Not::not")]
53    proxy: bool,
54    #[serde(skip_serializing_if = "Not::not")]
55    raw_variables: bool,
56    #[serde(skip_serializing_if = "Not::not")]
57    artifacts_exclude: bool,
58    #[serde(skip_serializing_if = "Not::not")]
59    multi_build_steps: bool,
60    #[serde(skip_serializing_if = "Not::not")]
61    trace_reset: bool,
62    #[serde(skip_serializing_if = "Not::not")]
63    trace_checksum: bool,
64    #[serde(skip_serializing_if = "Not::not")]
65    trace_size: bool,
66    #[serde(skip_serializing_if = "Not::not")]
67    vault_secrets: bool,
68    #[serde(skip_serializing_if = "Not::not")]
69    cancelable: bool,
70    #[serde(skip_serializing_if = "Not::not")]
71    return_exit_code: bool,
72    #[serde(skip_serializing_if = "Not::not")]
73    service_variables: bool,
74}
75
76#[derive(Debug, Clone, Serialize)]
77struct VersionInfo<'a> {
78    #[serde(flatten)]
79    metadata: &'a ClientMetadata,
80    features: FeaturesInfo,
81}
82
83#[derive(Debug, Clone, Serialize)]
84struct JobRequest<'a> {
85    token: &'a str,
86    system_id: &'a str,
87    info: VersionInfo<'a>,
88}
89
90#[derive(Debug, Clone, Serialize)]
91#[serde(rename_all = "lowercase")]
92#[allow(dead_code)]
93pub enum JobState {
94    Pending,
95    Running,
96    Success,
97    Failed,
98}
99
100#[derive(Debug, Clone, Serialize)]
101#[serde(rename = "lower_case")]
102struct JobUpdate<'a> {
103    token: &'a str,
104    state: JobState,
105}
106
/// Result of a successful job state update.
#[derive(Clone, Debug)]
pub struct JobUpdateReply {
    // GitLab's job update endpoint can include a suggested request rate in an HTTP response
    // header. Currently this value is only used from trace calls (e.g. appending to the job's
    // log); the field is kept around to document its existence.
    #[allow(dead_code)]
    pub trace_update_interval: Option<Duration>,
}
115
/// Result of a successful trace (job log) upload.
#[derive(Clone, Debug)]
pub struct TraceReply {
    /// Update interval suggested by GitLab in the response headers, when present and parseable
    /// as whole seconds.
    pub trace_update_interval: Option<Duration>,
}
120
121#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
122pub(crate) struct JobVariable {
123    pub key: String,
124    #[serde(deserialize_with = "deserialize_null_default")]
125    pub value: String,
126    pub public: bool,
127    pub masked: bool,
128}
129
130fn variable_hash<'de, D>(deserializer: D) -> Result<HashMap<String, JobVariable>, D::Error>
131where
132    D: Deserializer<'de>,
133{
134    let hash = Vec::<JobVariable>::deserialize(deserializer)?
135        .drain(..)
136        .map(|v| (v.key.clone(), v))
137        .collect();
138    Ok(hash)
139}
140
141#[derive(Copy, Clone, Deserialize, Debug, Eq, PartialEq)]
142#[serde(rename_all = "snake_case")]
143pub enum JobStepWhen {
144    Always,
145    OnFailure,
146    OnSuccess,
147}
148
149/// Phase of the gitlab job steps
150#[derive(Copy, Clone, Deserialize, Debug, Eq, PartialEq)]
151#[serde(rename_all = "snake_case")]
152pub enum Phase {
153    /// script step; Practically this is before_script + script as defined in the gitlab job yaml
154    Script,
155    /// after_script step
156    AfterScript,
157}
158
159#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
160pub(crate) struct JobStep {
161    pub name: Phase,
162    pub script: Vec<String>,
163    pub timeout: u32,
164    pub when: JobStepWhen,
165    pub allow_failure: bool,
166}
167
168#[derive(Copy, Clone, Deserialize, Debug, Eq, PartialEq)]
169#[serde(rename_all = "snake_case")]
170#[derive(Default)]
171pub enum ArtifactWhen {
172    Always,
173    OnFailure,
174    #[default]
175    OnSuccess,
176}
177
178#[derive(Copy, Clone, Deserialize, Debug, Eq, PartialEq)]
179#[serde(rename_all = "snake_case")]
180#[derive(Default)]
181pub enum ArtifactFormat {
182    #[default]
183    Zip,
184    Gzip,
185    Raw,
186}
187
188impl std::fmt::Display for ArtifactFormat {
189    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190        let s = match self {
191            Self::Zip => "zip",
192            Self::Gzip => "gzip",
193            Self::Raw => "raw",
194        };
195        write!(f, "{s}")
196    }
197}
198
199#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
200pub(crate) struct JobArtifact {
201    pub name: Option<String>,
202    #[serde(default, deserialize_with = "deserialize_null_default")]
203    pub untracked: bool,
204    pub paths: Vec<String>,
205    #[serde(deserialize_with = "deserialize_null_default")]
206    pub when: ArtifactWhen,
207    pub artifact_type: String,
208    #[serde(deserialize_with = "deserialize_null_default")]
209    pub artifact_format: ArtifactFormat,
210    pub expire_in: Option<String>,
211}
212
213#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
214pub(crate) struct JobArtifactFile {
215    pub filename: String,
216    pub size: usize,
217}
218
219#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
220pub(crate) struct JobDependency {
221    pub id: u64,
222    pub name: String,
223    pub token: String,
224    pub artifacts_file: Option<JobArtifactFile>,
225}
226
227#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
228pub(crate) struct JobResponse {
229    pub id: u64,
230    pub token: String,
231    pub allow_git_fetch: bool,
232    #[serde(deserialize_with = "variable_hash")]
233    pub variables: HashMap<String, JobVariable>,
234    pub steps: Vec<JobStep>,
235    #[serde(deserialize_with = "deserialize_null_default")]
236    pub dependencies: Vec<JobDependency>,
237    #[serde(deserialize_with = "deserialize_null_default")]
238    pub artifacts: Vec<JobArtifact>,
239    #[serde(flatten)]
240    unparsed: JsonValue,
241}
242
243impl JobResponse {
244    pub fn step(&self, name: Phase) -> Option<&JobStep> {
245        self.steps.iter().find(|s| s.name == name)
246    }
247}
248
249#[derive(Error, Debug)]
250pub enum Error {
251    #[error("Unexpected reply code {0}")]
252    UnexpectedStatus(StatusCode),
253    #[error("Job cancelled")]
254    JobCancelled,
255    #[error("Request failure {0}")]
256    Request(#[from] reqwest::Error),
257    #[error("Failed to write to destination {0}")]
258    WriteFailure(#[source] futures::io::Error),
259    #[error("Failed to parse zip file: {0}")]
260    ZipFile(#[from] ZipError),
261    #[error("Empty trace")]
262    EmptyTrace,
263}
264
/// Metadata sent as multipart form fields alongside an artifact upload.
pub(crate) struct ArtifactInfo<'a> {
    /// File name given to the uploaded `file` part.
    pub name: &'a str,
    /// Value of the `artifact_format` form field.
    pub artifact_format: &'a str,
    /// Value of the `artifact_type` form field.
    pub artifact_type: &'a str,
    /// Value of the optional `expire_in` form field; the field is omitted when `None`.
    pub expire_in: Option<&'a str>,
}
271
272#[derive(Clone, Debug, Default, Serialize)]
273pub(crate) struct ClientMetadata {
274    #[serde(skip_serializing_if = "Option::is_none")]
275    pub(crate) version: Option<String>,
276    #[serde(skip_serializing_if = "Option::is_none")]
277    pub(crate) revision: Option<String>,
278    #[serde(skip_serializing_if = "Option::is_none")]
279    pub(crate) platform: Option<String>,
280    #[serde(skip_serializing_if = "Option::is_none")]
281    pub(crate) architecture: Option<String>,
282}
283
284#[derive(Clone, Debug)]
285pub(crate) struct Client {
286    client: reqwest::Client,
287    url: Url,
288    token: String,
289    system_id: String,
290    metadata: ClientMetadata,
291}
292
293impl Client {
294    pub fn new(url: Url, token: String, system_id: String, metadata: ClientMetadata) -> Self {
295        Self {
296            client: reqwest::Client::new(),
297            url,
298            token,
299            system_id,
300            metadata,
301        }
302    }
303
304    pub async fn request_job(&self) -> Result<Option<JobResponse>, Error> {
305        let request = JobRequest {
306            token: &self.token,
307            system_id: &self.system_id,
308            info: VersionInfo {
309                // Setting `refspecs` is required to run detached MR pipelines.
310                features: FeaturesInfo {
311                    refspecs: true,
312                    upload_multiple_artifacts: true,
313                    ..Default::default()
314                },
315                metadata: &self.metadata,
316            },
317        };
318
319        let mut url = self.url.clone();
320        url.path_segments_mut()
321            .unwrap()
322            .extend(&["api", "v4", "jobs", "request"]);
323
324        let r = self
325            .client
326            .post(url)
327            .json(&request)
328            .send()
329            .await?
330            .error_for_status()?;
331
332        match r.status() {
333            StatusCode::CREATED => Ok(Some(r.json().await?)),
334            StatusCode::NO_CONTENT => Ok(None),
335            _ => Err(Error::UnexpectedStatus(r.status())),
336        }
337    }
338
339    fn check_for_job_cancellation(&self, response: &reqwest::Response) -> Result<(), Error> {
340        match response.headers().get(JOB_STATUS) {
341            Some(header) if header == "canceled" => Err(Error::JobCancelled),
342            _ => Ok(()),
343        }
344    }
345
346    pub async fn update_job(
347        &self,
348        id: u64,
349        token: &str,
350        state: JobState,
351    ) -> Result<JobUpdateReply, Error> {
352        let mut url = self.url.clone();
353        let id_s = format!("{id}");
354        url.path_segments_mut()
355            .unwrap()
356            .extend(&["api", "v4", "jobs", &id_s]);
357
358        let update = JobUpdate { token, state };
359
360        let r = self.client.put(url).json(&update).send().await?;
361
362        self.check_for_job_cancellation(&r)?;
363
364        let trace_update_interval = r
365            .headers()
366            .get(GITLAB_TRACE_UPDATE_INTERVAL)
367            .and_then(|v| Some(Duration::from_secs(v.to_str().ok()?.parse().ok()?)));
368        match r.status() {
369            StatusCode::OK => Ok(JobUpdateReply {
370                trace_update_interval,
371            }),
372            _ => Err(Error::UnexpectedStatus(r.status())),
373        }
374    }
375
376    pub async fn trace<B>(
377        &self,
378        id: u64,
379        token: &str,
380        body: B,
381        start: usize,
382        length: usize,
383    ) -> Result<TraceReply, Error>
384    where
385        B: Into<Body>,
386    {
387        if length == 0 {
388            return Err(Error::EmptyTrace);
389        }
390
391        let mut url = self.url.clone();
392        let id_s = format!("{id}");
393        url.path_segments_mut()
394            .unwrap()
395            .extend(&["api", "v4", "jobs", &id_s, "trace"]);
396
397        let range = format!("{}-{}", start, start + length - 1);
398
399        let r = self
400            .client
401            .patch(url)
402            .header("JOB-TOKEN", token)
403            .header(reqwest::header::CONTENT_RANGE, range)
404            .body(body)
405            .send()
406            .await?;
407
408        self.check_for_job_cancellation(&r)?;
409
410        let trace_update_interval = r
411            .headers()
412            .get(GITLAB_TRACE_UPDATE_INTERVAL)
413            .and_then(|v| Some(Duration::from_secs(v.to_str().ok()?.parse().ok()?)));
414
415        match r.status() {
416            StatusCode::ACCEPTED => Ok(TraceReply {
417                trace_update_interval,
418            }),
419            _ => Err(Error::UnexpectedStatus(r.status())),
420        }
421    }
422
423    pub async fn download_artifact<D: AsyncWrite + Unpin>(
424        &self,
425        id: u64,
426        token: &str,
427        mut dest: D,
428    ) -> Result<(), Error> {
429        let mut url = self.url.clone();
430        let id_s = format!("{id}");
431        url.path_segments_mut()
432            .unwrap()
433            .extend(&["api", "v4", "jobs", &id_s, "artifacts"]);
434
435        let mut r = self
436            .client
437            .get(url)
438            .header("JOB-TOKEN", token)
439            .send()
440            .await?;
441
442        match r.status() {
443            StatusCode::OK => {
444                while let Some(ref chunk) = r.chunk().await? {
445                    dest.write_all(chunk).await.map_err(Error::WriteFailure)?
446                }
447                Ok(())
448            }
449            _ => Err(Error::UnexpectedStatus(r.status())),
450        }
451    }
452
453    pub async fn upload_artifact<D>(
454        &self,
455        id: u64,
456        token: &str,
457        info: ArtifactInfo<'_>,
458        data: D,
459    ) -> Result<(), Error>
460    where
461        D: Into<Body>,
462    {
463        let part = Part::stream(data).file_name(info.name.to_string());
464        let form = Form::new()
465            .part("file", part)
466            .text("artifact_format", info.artifact_format.to_string())
467            .text("artifact_type", info.artifact_type.to_string());
468
469        let form = if let Some(expiry) = info.expire_in {
470            form.text("expire_in", expiry.to_string())
471        } else {
472            form
473        };
474
475        let mut url = self.url.clone();
476        let id_s = format!("{id}");
477        url.path_segments_mut()
478            .unwrap()
479            .extend(&["api", "v4", "jobs", &id_s, "artifacts"]);
480
481        let r = self
482            .client
483            .post(url)
484            .header("JOB-TOKEN", token)
485            .multipart(form)
486            .send()
487            .await?;
488
489        match r.status() {
490            StatusCode::CREATED => Ok(()),
491            _ => Err(Error::UnexpectedStatus(r.status())),
492        }
493    }
494}
495
#[cfg(test)]
mod test {
    use super::*;
    use gitlab_runner_mock::GitlabRunnerMock;
    use serde_json::json;

    /// System id reported by every test client.
    const SYSTEM_ID: &str = "s_ystem_id1234";

    /// Build a [`Client`] for `mock` using its runner token and default metadata.
    fn client_for(mock: &GitlabRunnerMock) -> Client {
        Client::new(
            mock.uri(),
            mock.runner_token().to_string(),
            SYSTEM_ID.to_string(),
            ClientMetadata::default(),
        )
    }

    #[test]
    fn deserialize_variables() {
        #[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
        struct Test {
            #[serde(deserialize_with = "variable_hash")]
            variables: HashMap<String, JobVariable>,
        }

        let input = json!({
            "variables": [
                { "key": "VAR1", "value": "1", "public": true, "masked": false },
                { "key": "VAR2", "value": "2", "public": false, "masked": true }
            ]
        })
        .to_string();

        let parsed: Test = serde_json::from_str(&input).expect("Failed to deserialize json");
        assert_eq!(2, parsed.variables.len());

        // (key, value, public, masked) for every variable in the input above.
        let expected = [("VAR1", "1", true, false), ("VAR2", "2", false, true)];
        for &(key, value, public, masked) in expected.iter() {
            let want = JobVariable {
                key: key.to_string(),
                value: value.to_string(),
                public,
                masked,
            };
            assert_eq!(Some(&want), parsed.variables.get(key));
        }
    }

    #[tokio::test]
    async fn no_job() {
        let mock = GitlabRunnerMock::start().await;
        let client = client_for(&mock);

        assert_eq!(None, client.request_job().await.unwrap());
    }

    #[tokio::test]
    async fn process_job() {
        let mock = GitlabRunnerMock::start().await;
        mock.add_dummy_job("process job".to_string());
        let client = client_for(&mock);

        let job = client.request_job().await.unwrap().expect("No job!");
        client
            .update_job(job.id, &job.token, JobState::Success)
            .await
            .unwrap();

        // After the single queued job has been taken and marked successful, no further job is
        // expected.
        assert_eq!(None, client.request_job().await.unwrap());
    }
}