Skip to main content

gr/
test.rs

1#[cfg(test)]
2pub mod utils {
3    use crate::{
4        api_defaults::REST_API_MAX_PAGES,
5        api_traits::ApiOperation,
6        config::ConfigProperties,
7        error,
8        http::{
9            self,
10            throttle::{self, ThrottleStrategyType},
11            Headers, Request,
12        },
13        io::{self, HttpResponse, HttpRunner, ShellResponse, TaskRunner},
14        time::Milliseconds,
15        Result,
16    };
17    use lazy_static::lazy_static;
18    use log::{Level, LevelFilter, Metadata, Record};
19    use serde::Serialize;
20    use std::{
21        cell::{Ref, RefCell},
22        fmt::Write,
23        fs::File,
24        io::Read,
25        ops::Deref,
26        rc::Rc,
27        sync::{Arc, Mutex},
28    };
29
30    #[derive(Debug, Clone, Copy, PartialEq)]
31    pub enum ContractType {
32        Gitlab,
33        Github,
34        Git,
35    }
36
37    impl ContractType {
38        fn as_str(&self) -> &str {
39            match *self {
40                ContractType::Gitlab => "gitlab",
41                ContractType::Github => "github",
42                ContractType::Git => "git",
43            }
44        }
45    }
46
47    pub fn get_contract(contract_type: ContractType, filename: &str) -> String {
48        let contracts_path = format!("contracts/{}/{}", contract_type.as_str(), filename);
49        let mut file = File::open(contracts_path).unwrap();
50        let mut contents = String::new();
51        file.read_to_string(&mut contents).unwrap();
52        contents
53    }
54
55    pub struct MockRunner<R> {
56        responses: RefCell<Vec<R>>,
57        cmd: RefCell<String>,
58        headers: RefCell<Headers>,
59        url: RefCell<String>,
60        pub api_operation: RefCell<Option<ApiOperation>>,
61        pub config: ConfigMock,
62        pub http_method: RefCell<Vec<http::Method>>,
63        pub throttled: RefCell<u32>,
64        pub milliseconds_throttled: RefCell<Milliseconds>,
65        pub run_count: RefCell<u32>,
66        pub request_body: RefCell<String>,
67    }
68
69    impl<R> MockRunner<R> {
70        pub fn new(responses: Vec<R>) -> Self {
71            Self {
72                responses: RefCell::new(responses),
73                cmd: RefCell::new(String::new()),
74                headers: RefCell::new(Headers::new()),
75                url: RefCell::new(String::new()),
76                api_operation: RefCell::new(None),
77                config: ConfigMock::default(),
78                http_method: RefCell::new(Vec::new()),
79                throttled: RefCell::new(0),
80                milliseconds_throttled: RefCell::new(Milliseconds::new(0)),
81                run_count: RefCell::new(0),
82                request_body: RefCell::new(String::new()),
83            }
84        }
85
86        pub fn with_config(self, config: ConfigMock) -> Self {
87            Self { config, ..self }
88        }
89
90        pub fn cmd(&self) -> Ref<'_, String> {
91            self.cmd.borrow()
92        }
93
94        pub fn url(&self) -> Ref<'_, String> {
95            self.url.borrow()
96        }
97
98        pub fn headers(&self) -> Ref<'_, Headers> {
99            self.headers.borrow()
100        }
101
102        pub fn throttled(&self) -> Ref<'_, u32> {
103            self.throttled.borrow()
104        }
105
106        pub fn milliseconds_throttled(&self) -> Ref<'_, Milliseconds> {
107            self.milliseconds_throttled.borrow()
108        }
109
110        pub fn request_body(&self) -> Ref<'_, String> {
111            self.request_body.borrow()
112        }
113    }
114
115    impl TaskRunner for MockRunner<ShellResponse> {
116        type Response = ShellResponse;
117
118        fn run<T>(&self, cmd: T) -> Result<Self::Response>
119        where
120            T: IntoIterator,
121            T::Item: AsRef<std::ffi::OsStr>,
122        {
123            self.cmd.replace(
124                cmd.into_iter()
125                    .map(|s| s.as_ref().to_str().unwrap().to_string())
126                    .collect::<Vec<String>>()
127                    .join(" "),
128            );
129            let response = self.responses.borrow_mut().pop().unwrap();
130            *self.run_count.borrow_mut() += 1;
131            match response.status {
132                0 => Ok(response),
133                _ => Err(error::gen(&response.body)),
134            }
135        }
136    }
137
138    impl HttpRunner for MockRunner<HttpResponse> {
139        type Response = HttpResponse;
140
141        fn run<T: Serialize>(&self, cmd: &mut Request<T>) -> Result<Self::Response> {
142            self.url.replace(cmd.url().to_string());
143            self.headers.replace(cmd.headers().clone());
144            self.api_operation.replace(cmd.api_operation().clone());
145            let response = self.responses.borrow_mut().pop().unwrap();
146            let body = serde_json::to_string(&cmd.body).unwrap_or_default();
147            self.request_body.replace(body);
148            self.http_method.borrow_mut().push(cmd.method.clone());
149            match response.status {
150                // 409 Conflict - Merge request already exists. - Gitlab
151                // 422 Conflict - Merge request already exists. - Github
152                200 | 201 | 302 | 409 | 422 => Ok(response),
153                // RateLimit error code. 403 secondary rate limit, 429 primary
154                // rate limit.
155                403 | 429 => {
156                    let headers = response.get_ratelimit_headers().unwrap_or_default();
157                    Err(error::GRError::RateLimitExceeded(headers).into())
158                }
159                500..=599 => Err(error::GRError::RemoteServerError(response.body).into()),
160                // Just for testing purposes, if the test client sets a status
161                // code of -1 we return a HTTP transport error.
162                -1 => Err(error::GRError::HttpTransportError(response.body).into()),
163                _ => Err(error::gen(&response.body)),
164            }
165        }
166
167        fn api_max_pages<T: Serialize>(&self, _cmd: &Request<T>) -> u32 {
168            self.config.get_max_pages(
169                self.api_operation
170                    .borrow()
171                    .as_ref()
172                    // We set it to Project by default in cases where it does
173                    // not matter while testing.
174                    .unwrap_or(&ApiOperation::Project),
175            )
176        }
177    }
178
179    pub struct ConfigMock {
180        max_pages: u32,
181    }
182
183    impl ConfigMock {
184        pub fn new(max_pages: u32) -> Self {
185            ConfigMock { max_pages }
186        }
187    }
188
189    impl ConfigProperties for ConfigMock {
190        fn api_token(&self) -> &str {
191            "1234"
192        }
193        fn cache_location(&self) -> Option<&str> {
194            Some("")
195        }
196        fn get_max_pages(&self, _api_operation: &ApiOperation) -> u32 {
197            self.max_pages
198        }
199    }
200
201    pub fn config() -> Arc<dyn ConfigProperties> {
202        Arc::new(ConfigMock::default())
203    }
204
205    impl Default for ConfigMock {
206        fn default() -> Self {
207            ConfigMock {
208                max_pages: REST_API_MAX_PAGES,
209            }
210        }
211    }
212
213    impl ConfigProperties for Arc<ConfigMock> {
214        fn api_token(&self) -> &str {
215            "1234"
216        }
217        fn cache_location(&self) -> Option<&str> {
218            Some("")
219        }
220        fn get_max_pages(&self, _api_operation: &ApiOperation) -> u32 {
221            self.as_ref().max_pages
222        }
223    }
224
225    struct TestLogger;
226
227    lazy_static! {
228        pub static ref LOG_BUFFER: Mutex<String> = Mutex::new(String::new());
229    }
230
231    impl log::Log for TestLogger {
232        fn enabled(&self, metadata: &Metadata) -> bool {
233            metadata.level() <= Level::Trace
234        }
235
236        fn log(&self, record: &Record) {
237            if self.enabled(record.metadata()) {
238                let mut buffer = LOG_BUFFER.lock().unwrap();
239                writeln!(buffer, "{} - {}", record.level(), record.args())
240                    .expect("Failed to write to log buffer");
241            }
242        }
243
244        fn flush(&self) {}
245    }
246
247    pub fn init_test_logger() {
248        let logger = TestLogger;
249        log::set_boxed_logger(Box::new(logger)).unwrap_or(());
250        log::set_max_level(LevelFilter::Trace);
251    }
252
253    pub struct Domain(pub String);
254    pub struct BasePath(pub String);
255
256    impl Deref for Domain {
257        type Target = String;
258
259        fn deref(&self) -> &Self::Target {
260            &self.0
261        }
262    }
263
264    impl Deref for BasePath {
265        type Target = String;
266
267        fn deref(&self) -> &Self::Target {
268            &self.0
269        }
270    }
271
272    pub enum ClientType {
273        Gitlab(Domain, BasePath),
274        Github(Domain, BasePath),
275    }
276
277    pub fn default_gitlab() -> ClientType {
278        ClientType::Gitlab(
279            Domain("gitlab.com".to_string()),
280            BasePath("jordilin/gitlapi".to_string()),
281        )
282    }
283
284    pub fn default_github() -> ClientType {
285        ClientType::Github(
286            Domain("github.com".to_string()),
287            BasePath("jordilin/githapi".to_string()),
288        )
289    }
290
291    #[macro_export]
292    macro_rules! setup_client {
293        ($response_contracts:expr, $client_type:expr, $trait_type:ty) => {{
294            let config = $crate::test::utils::config();
295            let responses: Vec<_> = $response_contracts
296                .into_iter()
297                .map(|(status_code, get_contract_fn, headers)| {
298                    let body = get_contract_fn();
299                    let mut response = HttpResponse::builder();
300                    response.status(status_code);
301                    if headers.is_some() {
302                        response.headers(headers.clone().unwrap());
303                        let rate_limit_header =
304                            crate::io::parse_ratelimit_headers(headers.as_ref());
305                        let link_header = crate::io::parse_page_headers(headers.as_ref());
306                        let flow_control_headers = crate::io::FlowControlHeaders::new(
307                            std::rc::Rc::new(link_header),
308                            std::rc::Rc::new(rate_limit_header),
309                        );
310                        response.flow_control_headers(flow_control_headers);
311                    }
312                    if body.is_some() {
313                        response.body(body.unwrap());
314                    }
315                    response.build().unwrap()
316                })
317                .collect();
318            let client = std::sync::Arc::new(crate::test::utils::MockRunner::new(responses));
319            let remote: Box<$trait_type> = match $client_type {
320                crate::test::utils::ClientType::Gitlab(domain, path) => Box::new(
321                    crate::gitlab::Gitlab::new(config, &domain, &path, client.clone()),
322                ),
323                crate::test::utils::ClientType::Github(domain, path) => Box::new(
324                    crate::github::Github::new(config, &domain, &path, client.clone()),
325                ),
326            };
327
328            (client, remote)
329        }};
330    }
331
332    pub struct ResponseContracts {
333        contract_type: ContractType,
334        contracts: Vec<(i32, Box<dyn Fn() -> Option<String>>, Option<Headers>)>,
335    }
336
337    impl ResponseContracts {
338        pub fn new(contract_type: ContractType) -> Self {
339            Self {
340                contract_type,
341                contracts: Vec::new(),
342            }
343        }
344
345        pub fn add_body<B: Into<String> + Clone + 'static>(
346            mut self,
347            status_code: i32,
348            body: Option<B>,
349            headers: Option<Headers>,
350        ) -> Self {
351            self.contracts.push((
352                status_code,
353                Box::new(move || body.clone().map(|b| b.into())),
354                headers,
355            ));
356            self
357        }
358
359        pub fn add_contract<F: Into<String> + Clone + 'static>(
360            mut self,
361            status_code: i32,
362            contract_file: F,
363            headers: Option<Headers>,
364        ) -> Self {
365            self.contracts.push((
366                status_code,
367                Box::new(move || {
368                    Some(get_contract(
369                        self.contract_type,
370                        &contract_file.clone().into(),
371                    ))
372                }),
373                headers,
374            ));
375            self
376        }
377    }
378
379    impl IntoIterator for ResponseContracts {
380        type Item = (i32, Box<dyn Fn() -> Option<String>>, Option<Headers>);
381        type IntoIter = std::vec::IntoIter<Self::Item>;
382
383        fn into_iter(self) -> Self::IntoIter {
384            self.contracts.into_iter()
385        }
386    }
387
388    pub struct MockThrottler {
389        throttled: RefCell<u32>,
390        milliseconds_throttled: RefCell<Milliseconds>,
391        strategy: throttle::ThrottleStrategyType,
392    }
393
394    impl MockThrottler {
395        pub fn new(strategy_type: Option<ThrottleStrategyType>) -> Self {
396            Self {
397                throttled: RefCell::new(0),
398                milliseconds_throttled: RefCell::new(Milliseconds::new(0)),
399                strategy: strategy_type.unwrap_or(ThrottleStrategyType::NoThrottle),
400            }
401        }
402
403        pub fn throttled(&self) -> Ref<'_, u32> {
404            self.throttled.borrow()
405        }
406
407        pub fn milliseconds_throttled(&self) -> Ref<'_, Milliseconds> {
408            self.milliseconds_throttled.borrow()
409        }
410    }
411
412    impl http::throttle::ThrottleStrategy for Rc<MockThrottler> {
413        fn throttle(&self, _response: Option<&io::FlowControlHeaders>) {
414            let mut throttled = self.throttled.borrow_mut();
415            *throttled += 1;
416        }
417
418        fn throttle_for(&self, delay: Milliseconds) {
419            let mut throttled = self.throttled.borrow_mut();
420            *throttled += 1;
421            let mut milliseconds_throttled = self.milliseconds_throttled.borrow_mut();
422            *milliseconds_throttled += delay;
423        }
424
425        fn strategy(&self) -> ThrottleStrategyType {
426            self.strategy.clone()
427        }
428    }
429}