// gr/test.rs

#[cfg(test)]
pub mod utils {
    //! Shared helpers for unit tests: recorded API response fixtures
    //! ("contracts"), mock HTTP/shell runners, a mock configuration, an
    //! in-memory logger and a mock throttler.

    use crate::{
        api_defaults::REST_API_MAX_PAGES,
        api_traits::ApiOperation,
        config::ConfigProperties,
        error,
        http::{
            self,
            throttle::{self, ThrottleStrategyType},
            Headers, Request,
        },
        io::{self, HttpResponse, HttpRunner, ShellResponse, TaskRunner},
        time::Milliseconds,
        Result,
    };
    use lazy_static::lazy_static;
    use log::{Level, LevelFilter, Metadata, Record};
    use serde::Serialize;
    use std::{
        cell::{Ref, RefCell},
        fmt::Write,
        fs,
        ops::Deref,
        rc::Rc,
        sync::{Arc, Mutex, PoisonError},
    };

    /// Remote flavour whose recorded responses a test loads; each variant maps
    /// to a sub-directory of `contracts/`.
    #[derive(Debug, Clone, Copy, PartialEq)]
    pub enum ContractType {
        Gitlab,
        Github,
        Git,
    }

    impl ContractType {
        /// Directory name under `contracts/` holding this flavour's fixtures.
        fn as_str(&self) -> &'static str {
            match self {
                ContractType::Gitlab => "gitlab",
                ContractType::Github => "github",
                ContractType::Git => "git",
            }
        }
    }

    /// Reads the fixture file `contracts/<contract_type>/<filename>` and
    /// returns its contents.
    ///
    /// The path is relative to the process working directory (presumably the
    /// crate root when run via `cargo test` — confirm if tests run elsewhere).
    ///
    /// # Panics
    ///
    /// Panics if the file cannot be read or is not valid UTF-8. The message
    /// includes the path so a missing or misnamed fixture is easy to spot.
    pub fn get_contract(contract_type: ContractType, filename: &str) -> String {
        let contracts_path = format!("contracts/{}/{}", contract_type.as_str(), filename);
        fs::read_to_string(&contracts_path)
            .unwrap_or_else(|err| panic!("failed to read contract {}: {}", contracts_path, err))
    }

    /// Test double for both [`TaskRunner`] (shell) and [`HttpRunner`] (HTTP).
    ///
    /// It replays the canned `responses` and records what the code under test
    /// sent (command line, URL, headers, body, method, ...) so tests can assert
    /// on it. Responses are served from the *end* of the vector (`Vec::pop`),
    /// so the last element is returned by the first call.
    pub struct MockRunner<R> {
        /// Canned responses, consumed from the back.
        responses: RefCell<Vec<R>>,
        /// Last command line passed to `TaskRunner::run`, arguments joined by
        /// single spaces.
        cmd: RefCell<String>,
        /// Headers of the last HTTP request.
        headers: RefCell<Headers>,
        /// URL of the last HTTP request.
        url: RefCell<String>,
        /// API operation of the last HTTP request (`None` before the first one).
        pub api_operation: RefCell<Option<ApiOperation>>,
        /// Configuration consulted by `HttpRunner::api_max_pages`.
        pub config: ConfigMock,
        /// Method of every HTTP request, in call order.
        pub http_method: RefCell<Vec<http::Method>>,
        /// Starts at 0; not modified by the runner code in this module.
        pub throttled: RefCell<u32>,
        /// Starts at 0 ms; not modified by the runner code in this module.
        pub milliseconds_throttled: RefCell<Milliseconds>,
        /// Number of `TaskRunner::run` calls (HTTP requests are not counted).
        pub run_count: RefCell<u32>,
        /// JSON serialization of the last HTTP request body (empty string if it
        /// could not be serialized).
        pub request_body: RefCell<String>,
    }

    impl<R> MockRunner<R> {
        /// Creates a runner that will replay `responses` (last element first)
        /// using the default [`ConfigMock`].
        pub fn new(responses: Vec<R>) -> Self {
            Self {
                responses: RefCell::new(responses),
                cmd: RefCell::new(String::new()),
                headers: RefCell::new(Headers::new()),
                url: RefCell::new(String::new()),
                api_operation: RefCell::new(None),
                config: ConfigMock::default(),
                http_method: RefCell::new(Vec::new()),
                throttled: RefCell::new(0),
                milliseconds_throttled: RefCell::new(Milliseconds::new(0)),
                run_count: RefCell::new(0),
                request_body: RefCell::new(String::new()),
            }
        }

        /// Replaces the mock configuration, keeping everything else.
        pub fn with_config(self, config: ConfigMock) -> Self {
            Self { config, ..self }
        }

        /// Last recorded shell command line.
        pub fn cmd(&self) -> Ref<String> {
            self.cmd.borrow()
        }

        /// URL of the last HTTP request.
        pub fn url(&self) -> Ref<String> {
            self.url.borrow()
        }

        /// Headers of the last HTTP request.
        pub fn headers(&self) -> Ref<Headers> {
            self.headers.borrow()
        }

        /// Current value of the `throttled` counter.
        pub fn throttled(&self) -> Ref<u32> {
            self.throttled.borrow()
        }

        /// Current value of the `milliseconds_throttled` accumulator.
        pub fn milliseconds_throttled(&self) -> Ref<Milliseconds> {
            self.milliseconds_throttled.borrow()
        }

        /// JSON body of the last HTTP request.
        pub fn request_body(&self) -> Ref<String> {
            self.request_body.borrow()
        }

        /// Pops the next canned response.
        ///
        /// # Panics
        ///
        /// Panics if the test queued fewer responses than the code under test
        /// consumed.
        fn next_response(&self) -> R {
            self.responses
                .borrow_mut()
                .pop()
                .expect("MockRunner: no more canned responses queued")
        }
    }

    impl TaskRunner for MockRunner<ShellResponse> {
        type Response = ShellResponse;

        /// Records `cmd`, bumps `run_count` and returns the next canned
        /// response. A status of 0 is returned as `Ok`; any other status
        /// becomes an error built from the response body.
        ///
        /// # Panics
        ///
        /// Panics if an argument is not valid UTF-8 or no response is queued.
        fn run<T>(&self, cmd: T) -> Result<Self::Response>
        where
            T: IntoIterator,
            T::Item: AsRef<std::ffi::OsStr>,
        {
            let cmd_line = cmd
                .into_iter()
                .map(|arg| {
                    arg.as_ref()
                        .to_str()
                        .expect("command arguments must be valid UTF-8")
                        .to_string()
                })
                .collect::<Vec<String>>()
                .join(" ");
            self.cmd.replace(cmd_line);
            let response = self.next_response();
            *self.run_count.borrow_mut() += 1;
            if response.status == 0 {
                Ok(response)
            } else {
                Err(error::gen(&response.body))
            }
        }
    }

    impl HttpRunner for MockRunner<HttpResponse> {
        type Response = HttpResponse;

        /// Records the request (URL, headers, operation, JSON body, method) and
        /// maps the next canned response's status code to a result.
        ///
        /// # Panics
        ///
        /// Panics if no response is queued.
        fn run<T: Serialize>(&self, cmd: &mut Request<T>) -> Result<Self::Response> {
            self.url.replace(cmd.url().to_string());
            self.headers.replace(cmd.headers().clone());
            self.api_operation.replace(cmd.api_operation().clone());
            let response = self.next_response();
            let body = serde_json::to_string(&cmd.body).unwrap_or_default();
            self.request_body.replace(body);
            self.http_method.borrow_mut().push(cmd.method.clone());
            match response.status {
                // 409 Conflict - Merge request already exists. - Gitlab
                // 422 Conflict - Merge request already exists. - Github
                200 | 201 | 302 | 409 | 422 => Ok(response),
                // RateLimit error code. 403 secondary rate limit, 429 primary
                // rate limit.
                403 | 429 => {
                    let headers = response.get_ratelimit_headers().unwrap_or_default();
                    Err(error::GRError::RateLimitExceeded(headers).into())
                }
                500..=599 => Err(error::GRError::RemoteServerError(response.body).into()),
                // Just for testing purposes, if the test client sets a status
                // code of -1 we return a HTTP transport error.
                -1 => Err(error::GRError::HttpTransportError(response.body).into()),
                _ => Err(error::gen(&response.body)),
            }
        }

        /// Max pages from the mock config for the operation of the last
        /// request run through this mock.
        fn api_max_pages<T: Serialize>(&self, _cmd: &Request<T>) -> u32 {
            self.config.get_max_pages(
                self.api_operation
                    .borrow()
                    .as_ref()
                    // Default to Project when no request has been recorded yet;
                    // the operation does not matter for such tests.
                    .unwrap_or(&ApiOperation::Project),
            )
        }
    }

    /// Fixed configuration for tests: token `"1234"`, empty cache location and
    /// a configurable page limit applied to every API operation.
    pub struct ConfigMock {
        max_pages: u32,
    }

    impl ConfigMock {
        /// Creates a mock that reports `max_pages` for every operation.
        pub fn new(max_pages: u32) -> Self {
            ConfigMock { max_pages }
        }
    }

    impl Default for ConfigMock {
        /// Uses the production default `REST_API_MAX_PAGES`.
        fn default() -> Self {
            ConfigMock {
                max_pages: REST_API_MAX_PAGES,
            }
        }
    }

    impl ConfigProperties for ConfigMock {
        fn api_token(&self) -> &str {
            "1234"
        }
        fn cache_location(&self) -> Option<&str> {
            Some("")
        }
        fn get_max_pages(&self, _api_operation: &ApiOperation) -> u32 {
            self.max_pages
        }
    }

    /// Shared default mock configuration as a trait object.
    pub fn config() -> Arc<dyn ConfigProperties> {
        Arc::new(ConfigMock::default())
    }

    /// Delegates to the [`ConfigMock`] impl so the two cannot drift apart.
    impl ConfigProperties for Arc<ConfigMock> {
        fn api_token(&self) -> &str {
            <ConfigMock as ConfigProperties>::api_token(self)
        }
        fn cache_location(&self) -> Option<&str> {
            <ConfigMock as ConfigProperties>::cache_location(self)
        }
        fn get_max_pages(&self, api_operation: &ApiOperation) -> u32 {
            <ConfigMock as ConfigProperties>::get_max_pages(self, api_operation)
        }
    }

    /// Logger that appends every record to `LOG_BUFFER` as
    /// `"<LEVEL> - <message>"`, one line per record.
    struct TestLogger;

    lazy_static! {
        /// Log output captured by `TestLogger` for the whole test process.
        pub static ref LOG_BUFFER: Mutex<String> = Mutex::new(String::new());
    }

    impl log::Log for TestLogger {
        fn enabled(&self, metadata: &Metadata) -> bool {
            // Trace is the most verbose level, so every record passes here;
            // filtering is left to `log::set_max_level`.
            metadata.level() <= Level::Trace
        }

        fn log(&self, record: &Record) {
            if !self.enabled(record.metadata()) {
                return;
            }
            // Recover from poisoning: a test that panicked while holding the
            // buffer lock (e.g. a failed assertion) must not make every later
            // log call in other tests panic as well.
            let mut buffer = LOG_BUFFER.lock().unwrap_or_else(PoisonError::into_inner);
            writeln!(buffer, "{} - {}", record.level(), record.args())
                .expect("Failed to write to log buffer");
        }

        fn flush(&self) {}
    }

    /// Installs the in-memory test logger and enables all log levels.
    ///
    /// Safe to call from many tests. The global logger can only be set once,
    /// so the error returned on later calls is deliberately ignored.
    pub fn init_test_logger() {
        let _ = log::set_boxed_logger(Box::new(TestLogger));
        log::set_max_level(LevelFilter::Trace);
    }

    /// Host name of the remote, e.g. `gitlab.com`.
    pub struct Domain(pub String);
    /// Repository path on the remote, e.g. `owner/repo`.
    pub struct BasePath(pub String);

    impl Deref for Domain {
        type Target = String;

        fn deref(&self) -> &Self::Target {
            &self.0
        }
    }

    impl Deref for BasePath {
        type Target = String;

        fn deref(&self) -> &Self::Target {
            &self.0
        }
    }

    /// Which remote client `setup_client!` should construct, and for which
    /// domain/project.
    pub enum ClientType {
        Gitlab(Domain, BasePath),
        Github(Domain, BasePath),
    }

    /// GitLab client type pointing at `gitlab.com/jordilin/gitlapi`.
    pub fn default_gitlab() -> ClientType {
        ClientType::Gitlab(
            Domain("gitlab.com".to_string()),
            BasePath("jordilin/gitlapi".to_string()),
        )
    }

    /// GitHub client type pointing at `github.com/jordilin/githapi`.
    pub fn default_github() -> ClientType {
        ClientType::Github(
            Domain("github.com".to_string()),
            BasePath("jordilin/githapi".to_string()),
        )
    }

    /// Builds a remote client backed by a `MockRunner<HttpResponse>`.
    ///
    /// * `$response_contracts` — iterable of `(status_code, body_fn, headers)`
    ///   tuples, e.g. a `ResponseContracts`.
    /// * `$client_type` — a `ClientType` selecting GitLab or GitHub.
    /// * `$trait_type` — the trait the remote is boxed as.
    ///
    /// Evaluates to `(Arc<MockRunner<HttpResponse>>, Box<$trait_type>)`. All
    /// paths go through `$crate`, so call sites need no extra imports.
    ///
    /// # Panics
    ///
    /// Panics if a mock `HttpResponse` cannot be built.
    #[macro_export]
    macro_rules! setup_client {
        ($response_contracts:expr, $client_type:expr, $trait_type:ty) => {{
            let config = $crate::test::utils::config();
            let responses: Vec<_> = $response_contracts
                .into_iter()
                .map(|(status_code, get_contract_fn, headers)| {
                    let body = get_contract_fn();
                    let mut response = $crate::io::HttpResponse::builder();
                    response.status(status_code);
                    if let Some(headers) = headers {
                        response.headers(headers.clone());
                        let rate_limit_header =
                            $crate::io::parse_ratelimit_headers(Some(&headers));
                        let link_header = $crate::io::parse_page_headers(Some(&headers));
                        let flow_control_headers = $crate::io::FlowControlHeaders::new(
                            ::std::rc::Rc::new(link_header),
                            ::std::rc::Rc::new(rate_limit_header),
                        );
                        response.flow_control_headers(flow_control_headers);
                    }
                    if let Some(body) = body {
                        response.body(body);
                    }
                    response.build().unwrap()
                })
                .collect();
            let client = ::std::sync::Arc::new($crate::test::utils::MockRunner::new(responses));
            let remote: Box<$trait_type> = match $client_type {
                $crate::test::utils::ClientType::Gitlab(domain, path) => Box::new(
                    $crate::gitlab::Gitlab::new(config, &domain, &path, client.clone()),
                ),
                $crate::test::utils::ClientType::Github(domain, path) => Box::new(
                    $crate::github::Github::new(config, &domain, &path, client.clone()),
                ),
            };

            (client, remote)
        }};
    }

    /// One queued mock response: status code, lazy body producer and optional
    /// response headers.
    pub type ResponseContract = (i32, Box<dyn Fn() -> Option<String>>, Option<Headers>);

    /// Builder for the list of mock responses consumed by `setup_client!`.
    pub struct ResponseContracts {
        contract_type: ContractType,
        contracts: Vec<ResponseContract>,
    }

    impl ResponseContracts {
        /// Starts an empty list whose fixture files are read from
        /// `contracts/<contract_type>/`.
        pub fn new(contract_type: ContractType) -> Self {
            Self {
                contract_type,
                contracts: Vec::new(),
            }
        }

        /// Queues a response with an inline `body` (or no body when `None`).
        pub fn add_body<B: Into<String> + Clone + 'static>(
            mut self,
            status_code: i32,
            body: Option<B>,
            headers: Option<Headers>,
        ) -> Self {
            self.contracts.push((
                status_code,
                Box::new(move || body.clone().map(Into::into)),
                headers,
            ));
            self
        }

        /// Queues a response whose body is the fixture `contract_file`, read
        /// lazily (when the body producer is called) via `get_contract`.
        pub fn add_contract<F: Into<String> + Clone + 'static>(
            mut self,
            status_code: i32,
            contract_file: F,
            headers: Option<Headers>,
        ) -> Self {
            // `ContractType` is `Copy`; copy it out so the closure does not
            // need to reference `self`.
            let contract_type = self.contract_type;
            self.contracts.push((
                status_code,
                Box::new(move || {
                    let filename: String = contract_file.clone().into();
                    Some(get_contract(contract_type, &filename))
                }),
                headers,
            ));
            self
        }
    }

    impl IntoIterator for ResponseContracts {
        type Item = ResponseContract;
        type IntoIter = std::vec::IntoIter<Self::Item>;

        /// Yields the queued responses in insertion order.
        fn into_iter(self) -> Self::IntoIter {
            self.contracts.into_iter()
        }
    }

    /// Throttle strategy double that only counts calls and accumulated delay.
    pub struct MockThrottler {
        throttled: RefCell<u32>,
        milliseconds_throttled: RefCell<Milliseconds>,
        strategy: ThrottleStrategyType,
    }

    impl MockThrottler {
        /// Creates a throttler reporting `strategy_type`, defaulting to
        /// `ThrottleStrategyType::NoThrottle`.
        pub fn new(strategy_type: Option<ThrottleStrategyType>) -> Self {
            Self {
                throttled: RefCell::new(0),
                milliseconds_throttled: RefCell::new(Milliseconds::new(0)),
                strategy: strategy_type.unwrap_or(ThrottleStrategyType::NoThrottle),
            }
        }

        /// Number of `throttle`/`throttle_for` calls so far.
        pub fn throttled(&self) -> Ref<u32> {
            self.throttled.borrow()
        }

        /// Sum of all delays passed to `throttle_for`.
        pub fn milliseconds_throttled(&self) -> Ref<Milliseconds> {
            self.milliseconds_throttled.borrow()
        }
    }

    impl http::throttle::ThrottleStrategy for Rc<MockThrottler> {
        /// Counts the call; never sleeps.
        fn throttle(&self, _response: Option<&io::FlowControlHeaders>) {
            *self.throttled.borrow_mut() += 1;
        }

        /// Counts the call and accumulates `delay`; never sleeps.
        fn throttle_for(&self, delay: Milliseconds) {
            *self.throttled.borrow_mut() += 1;
            *self.milliseconds_throttled.borrow_mut() += delay;
        }

        fn strategy(&self) -> ThrottleStrategyType {
            self.strategy.clone()
        }
    }
}
428        }
429    }
430}