kubernetes_mock/
lib.rs

1// Copyright 2023 Cisco Systems, Inc. and its affiliates
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// SPDX-License-Identifier: Apache-2.0
16
17//! Mock Kubernetes client in Rust
18
19use k8s_openapi::apimachinery::pkg::apis::meta::v1::WatchEvent;
20use k8s_openapi::http::{Request, Response};
21use kube::core::ObjectList;
22use kube::Client;
23
24use hyper::Body;
25use std::future::Future;
26use std::pin::pin;
27use std::time::Duration;
28use tower_test::mock::{self, Handle};
29
/// Namespace given to the mocked [`Client`] as its default.
const DEFAULT_NS: &str = "default";
/// How many seconds any single mocked interaction may take before it is treated as missing.
const API_TIMEOUT_SECS: u64 = 1;
/// [`API_TIMEOUT_SECS`] expressed as a [`Duration`], for use with `tokio::time::timeout`.
const TIMEOUT_DURATION: Duration = Duration::from_secs(API_TIMEOUT_SECS);
33
34/// The main mocker struct. Holds all the information about expected API calls, and provides
35/// functionality to set up expected requests and responses before testing with `.run()`.
36pub struct KubernetesMocker {
37    handle: Handle<Request<Body>, Response<Body>>,
38    // Vec of (expected request, response)
39    // This allows us to represent the invariant that the number of
40    // requests/responses is the same (rather than having two vecs).
41    expected_requests: Vec<(Request<Body>, MockApiResponse)>,
42}
43
44/// Returns (kube client, mock struct).
45/// Use the kubernetes client as normal, both during calls to `mocker.expect()` and then during
46/// `mocker.run()`.
47///
48/// # Examples
49/// ```
50/// # use kubernetes_mock::*;
51/// # use kube::{Api, api::ListParams};
52/// # use k8s_openapi::api::core::v1::Node;
53/// # #[tokio::main]
54/// # async fn main() {
55/// let (client, mut mocker) = make_mocker();
56/// let api: Api<Node> = Api::all(client);
57/// mocker.expect(|| async {
58///     let nodes = api.list(&ListParams::default()).await;
59///   },
60///   MockReturn::List(&[Node::default()]),
61/// ).await.unwrap();
62/// // ...
63/// let handle = tokio::spawn(mocker.run());
64/// api.list(&ListParams::default()).await;
65/// handle.await.unwrap().unwrap(); // Assert tests pass with `unwrap()`.
66/// # }
67/// ```
68#[must_use]
69pub fn make_mocker() -> (Client, KubernetesMocker) {
70    let (mock_service, handle) = mock::pair::<Request<Body>, Response<Body>>();
71    let client = Client::new(mock_service, DEFAULT_NS);
72    let mocker = KubernetesMocker {
73        handle,
74        expected_requests: vec![],
75    };
76
77    (client, mocker)
78}
79
80/// Represents errors generated during `mocker.expect()`.
81#[derive(thiserror::Error, Debug)]
82pub enum MockError {
83    /// Serde failed to serialize the mock response.
84    #[error("serde_json failed to serialize mock response")]
85    Serde(#[from] serde_json::Error),
86
87    /// The given closure did not make an API request.
88    #[error("expected a request in closure passed to KubernetesMocker::expect()")]
89    NoRequest,
90
91    /// The given closure ran for longer than expected (uses internal `API_TIMEOUT_SECS` -
92    /// currently 1).
93    #[error("expected_api_call timed out")]
94    Timeout,
95
96    /// The given closure panicked during execution.
97    #[error("expected_api_call panicked")]
98    Panicked,
99
100    /// If expecting a `client.watch()` call, you must provide at least one [`WatchEvent`].
101    /// TODO: maybe this isn't necessary?
102    #[error("Watch events list must contain at least one WatchEvent")]
103    WatchNoItems,
104
105    /// Failed to turn the serialized bytes into an [`k8s_openapi::http::Response`].
106    /// (This error should be impossible in practice?)
107    #[error("failed to create an http::Response from the serialized bytes")]
108    HttpBody(#[from] k8s_openapi::http::Error),
109}
110
111/// Represents possible errors during `mocker.run()`.
112#[derive(thiserror::Error, Debug)]
113pub enum MockRunError {
114    /// The mocker had more `expect()` calls than API calls received during `run()`.
115    #[error(
116        "Mock API received too few API calls. Expected {expected} but only received {received}"
117    )]
118    TooFewApiCalls { expected: usize, received: usize },
119
120    /// The mocker had more API calls than calls to `expect()`. The `call` field holds a vector of
121    /// all extraneous calls received.
122    #[error("Mock API received too many API calls. Expected: {expected}, received {received}, call {call:#?}")]
123    TooManyApiCalls {
124        expected: usize,
125        received: usize,
126        call: Vec<Request<Body>>,
127    },
128
129    /// The mocker received an API call that's different to the corresponding expected API call at
130    /// index `idx`.
131    #[error("Mock API received a different call at {idx} than expected. Expected: {expected:#?}, received: {received:#?}")]
132    IncorrectApiCall {
133        received: Request<Body>,
134        expected: Request<Body>,
135        idx: usize,
136    },
137}
138
139/// An enum to represent the possible return values from the API.
140pub enum MockReturn<'a, T: kube::api::Resource + serde::Serialize> {
141    /// Return a single item. This should be used as the response for `api.get()` and
142    /// `api.create()` calls.
143    Single(T),
144    /// Return multiple items in a list. This should be used for `api.list()` calls.
145    List(&'a [T]),
146    /// Return [`WatchEvent`]s, with the option of adding delays in between. This should be used for
147    /// `api.watch()` calls.
148    Watch(&'a [MockWatch<T>]),
149    /// Catch-all for anything else - if it's not a [`kube::Resource`] or serializable, you can use
150    /// this to return whatever bytes you want.
151    Raw(Vec<u8>),
152}
153
154/// Used to represent a stream of `api.watch()` events as a list.
155pub enum MockWatch<T: kube::api::Resource + serde::Serialize> {
156    /// Wait for this long before sending the next [`WatchEvent`]. This keeps the watch socket
157    /// open, so feel free to use this at the end of [`MockReturn::Watch`] lists to stop a
158    /// controller failing due to the watch call being ended, and creating a new, unexpected API
159    /// call to set it up again.
160    ///
161    /// Note that `mocker.run()` waits for all of these to finish before returning, so don't make
162    /// it too long!
163    Wait(Duration),
164    /// Send a Kubernetes [`WatchEvent`] to whatever called `api.watch()`.
165    Event(WatchEvent<T>),
166}
167
/// Pre-serialized form of a [`MockWatch`] step, replayed by `run()`.
#[derive(Debug)]
enum MockApiWatchResponse {
    /// Sleep for this long before handling the next step.
    Wait(Duration),
    /// A JSON-encoded watch event, already terminated with a newline.
    Event(Vec<u8>),
}
173
174enum MockApiResponse {
175    Single(Vec<u8>),
176    Stream(Vec<MockApiWatchResponse>),
177}
178
179impl KubernetesMocker {
180    /// `mocker.expect()` - takes a closure producing a future and what to return when receiving
181    /// the API call given by the closure.
182    /// Each closure should only have 1 API call. More will be ignored, fewer will return an error.
183    ///
184    /// Will return an error on panic, timeout, or if the closure does not make a request - feel
185    /// free to use `unwrap()` liberally!
186    ///
187    /// Returns `Ok(&mut self)` on success, meaning you can chain
188    /// `.expect().await.unwrap().expect().await.unwrap()`...
189    ///
190    /// # Errors
191    /// * [`MockError::Serde`] - `serde_json` failed to serialize mock response
192    /// * [`MockError::NoRequest`] - closure passed in did not make an API request.
193    /// * [`MockError::Timeout`] - closure passed in timed out.
194    /// * [`MockError::Panicked`] - closure passed in panicked.
195    /// * [`MockError::WatchNoItems`] - Watch events list must contain at least one `WatchEvent`.
196    /// * [`MockError::HttpBody`] - failed to create an `http::Response` from the serialized bytes.
197    ///
198    /// # Examples
199    /// ```
200    /// # use kubernetes_mock::*;
201    /// # use kube::{Api, api::{ListParams, PostParams}};
202    /// # use k8s_openapi::api::core::v1::Node;
203    /// # #[tokio::main]
204    /// # async fn main() {
205    /// let (client, mut mocker) = make_mocker();
206    /// let api: Api<Node> = Api::all(client);
207    /// mocker.expect(|| async {
208    ///       api.list(&ListParams::default()).await.unwrap();
209    ///     },
210    ///     MockReturn::List(&[Node::default()]),
211    ///   ).await.unwrap()
212    ///   .expect(|| async {
213    ///       api.create(&PostParams::default(), &Node::default()).await.unwrap();
214    ///     },
215    ///     MockReturn::Single(Node::default()),
216    ///   ).await.unwrap();
217    /// # }
218    /// ```
219    ///
220    pub async fn expect<'a, Fut, T, F>(
221        &mut self,
222        expected_api_call: F,
223        result: MockReturn<'a, T>,
224    ) -> Result<&mut Self, MockError>
225    where
226        Fut: Future<Output = ()> + Send,
227        T: kube::Resource + serde::Serialize + Clone,
228        F: FnOnce() -> Fut,
229    {
230        let response = match result {
231            MockReturn::Single(ref item) => serde_json::to_vec(&item)?,
232            MockReturn::List(items) => serde_json::to_vec(&list(items))?,
233            MockReturn::Raw(ref vec) => vec.clone(),
234            // If we're mocking a `watch`, just send the first event back to the closure
235            MockReturn::Watch(items) => {
236                let first_event = match items
237                    .iter()
238                    .find(|i| matches!(i, MockWatch::Event(_)))
239                    .ok_or(MockError::WatchNoItems)?
240                {
241                    MockWatch::Wait(_) => unreachable!(), // Filtered out earlier by the `find`
242                    MockWatch::Event(e) => e,
243                };
244                serde_json::to_vec(first_event)?
245            }
246        };
247        let fut = tokio::time::timeout(TIMEOUT_DURATION, expected_api_call());
248        let handle_request = async {
249            let (request, send) =
250                tokio::time::timeout(TIMEOUT_DURATION, self.handle.next_request())
251                    .await
252                    .ok()
253                    .flatten()
254                    .ok_or(MockError::NoRequest)?;
255
256            // Looks intimidating, but let's step through this
257            let api_return = match result {
258                // If this is a `watch` request, we need to turn the Vec<MockWatch> into
259                // Vec<MockApiWatchResponse>. This is represented by Stream()
260                MockReturn::Watch(items) => MockApiResponse::Stream(
261                    items
262                        .iter()
263                        .map(|e| match e {
264                            MockWatch::Wait(duration) => Ok(MockApiWatchResponse::Wait(*duration)),
265                            MockWatch::Event(e) => {
266                                let mut vec = serde_json::to_vec(e)?;
267                                vec.push(b'\n');
268                                Ok(MockApiWatchResponse::Event(vec))
269                            }
270                        })
271                        // Collect as a result, to allow collecting errors from serde inside the
272                        // loop. We then return if there are any errors, and are left with a
273                        // Vec<MockApiResponse>.
274                        .collect::<Result<Vec<_>, MockError>>()?,
275                ),
276                // Otherwise, we can use the existing `response`.
277                _ => MockApiResponse::Single(response.clone()),
278            };
279            self.expected_requests.push((request, api_return));
280            send.send_response(Response::builder().body(Body::from(response))?);
281            Ok::<(), MockError>(())
282        };
283        let (expected_handle, handle_request) = futures::future::join(fut, handle_request).await;
284        handle_request?;
285        expected_handle
286            //.map_err(|_| MockError::Panicked)? // if using tokio::task::spawn (needs 'static)
287            .map_err(|_| MockError::Timeout)?;
288        Ok(self)
289    }
290
291    /// `KubernetesMocker::run()` - produces a future which will compare the received API calls
292    /// past this point, and compare them to the ones received during the `expect()` calls.
293    ///
294    /// # Errors
295    /// * [`MockRunError::TooFewApiCalls`] - it did not receive as many API calls as `expect()` got.
296    /// * [`MockRunError::TooManyApiCalls`] - received at least 1 more API call that was not expected.
297    /// * [`MockRunError::IncorrectApiCall`] - an API call did not match the one received in the
298    /// corresponding `expect()` call.
299    ///
300    /// Should be run with [`tokio::task::spawn`] so it runs concurrently with whatever is being tested.
301    ///
302    /// # Examples
303    /// See examples for [`make_mocker()`].
304    ///
305    /// # Panics
306    /// Should not panic, if an HTTP body fails to be made it should be presented as an error
307    /// during `expect()`. Still uses `unwrap()` for brevity, and to avoid a duplicate of
308    /// [`MockError::HttpBody`] in [`MockRunError`].
309    #[allow(clippy::similar_names)]
310    pub async fn run(self) -> Result<(), MockRunError> {
311        let KubernetesMocker {
312            handle,
313            expected_requests,
314        } = self;
315        let mut handle = pin!(handle);
316        let expected_num_requests = expected_requests.len();
317        let mut watch_handles = vec![];
318
319        for (i, (expected_api_call, result)) in expected_requests.into_iter().enumerate() {
320            println!("api call {i}, expected {}", expected_api_call.uri());
321            let (request, send) = tokio::time::timeout(TIMEOUT_DURATION, handle.next_request())
322                .await
323                .ok()
324                .flatten()
325                .ok_or(MockRunError::TooFewApiCalls {
326                    expected: expected_num_requests,
327                    received: i,
328                })?;
329            println!("Got request {request:#?}");
330
331            // Need to deconstruct/reconstruct to get body bytes without consuming the requests
332            let (eparts, ebody) = expected_api_call.into_parts();
333            let (aparts, abody) = request.into_parts();
334            let ebody = hyper::body::to_bytes(ebody).await.unwrap();
335            let abody = hyper::body::to_bytes(abody).await.unwrap();
336
337            let same_as_expected =
338                eparts.uri == aparts.uri && ebody == abody && eparts.method == aparts.method;
339
340            let expected_api_call = Request::from_parts(eparts, ebody.into());
341            let request = Request::from_parts(aparts, abody.into());
342
343            if !same_as_expected {
344                println!("NOT SAME AS EXPECTED");
345                return Err(MockRunError::IncorrectApiCall {
346                    received: request,
347                    expected: expected_api_call,
348                    idx: i,
349                });
350            }
351
352            match result {
353                MockApiResponse::Single(resp) => {
354                    send.send_response(Response::builder().body(Body::from(resp)).unwrap());
355                }
356                MockApiResponse::Stream(mut stream) => {
357                    // spawn a new future, feed items from stream into a body
358                    stream.reverse();
359                    let (mut sender, body) = Body::channel(); // Need to use a channel to
360                                                              // continuously send data
361                    let time = std::time::Instant::now();
362                    let fut = async move {
363                        while let Some(watch_response) = stream.pop() {
364                            println!(
365                                "Sending event {watch_response:?} at {:?}",
366                                std::time::Instant::now().duration_since(time)
367                            );
368                            match watch_response {
369                                MockApiWatchResponse::Wait(duration) => {
370                                    tokio::time::sleep(duration).await;
371                                }
372                                MockApiWatchResponse::Event(e) => {
373                                    sender.send_data(e.into()).await.unwrap();
374                                    println!("Sent data");
375                                }
376                            }
377                        }
378                    };
379                    watch_handles.push(tokio::task::spawn(fut));
380                    send.send_response(Response::new(body));
381                }
382            };
383        }
384        let mut extra_requests = Vec::new();
385        while let Ok(Some((request, _))) =
386            tokio::time::timeout(TIMEOUT_DURATION, handle.next_request()).await
387        {
388            extra_requests.push(request);
389        }
390        // Close all watches (we do this after waiting for extra requests, to make sure closing the
391        // watch does not trigger extra requests that we fail on).
392        futures::future::join_all(watch_handles.into_iter()).await;
393        if extra_requests.is_empty() {
394            Ok(())
395        } else {
396            Err(MockRunError::TooManyApiCalls {
397                expected: expected_num_requests,
398                received: expected_num_requests + extra_requests.len(),
399                call: extra_requests,
400            })
401        }
402    }
403}
404
405fn list<T: kube::Resource + Clone>(items: &[T]) -> ObjectList<T> {
406    use kube::core::ListMeta;
407    ObjectList::<T> {
408        items: items.iter().map(Clone::clone).collect::<Vec<_>>(),
409        metadata: ListMeta {
410            resource_version: Some("1".into()),
411            ..Default::default()
412        },
413    }
414}
415
#[cfg(test)]
mod tests {
    use super::*;
    use k8s_openapi::api::core::v1::{Node, Pod};
    use kube::{api::ListParams, Api};

    /// Registers one expected "list all nodes" call on `mocker`, answered with a single default
    /// `Node`. The call is made through `client`, which the closure consumes.
    // TODO: look into a DSL for expectations, e.g.
    // `.expect(list(Node, ListParams::default().labels("foo=bar")))`.
    async fn expect_node_list(mocker: &mut KubernetesMocker, client: Client) {
        mocker
            .expect(
                move || async {
                    let api: Api<Node> = Api::all(client);
                    let _nodes = api.list(&ListParams::default()).await;
                },
                MockReturn::List(&[Node::default()]),
            )
            .await
            .unwrap();
    }

    /// Lists all nodes through `client`, discarding the outcome.
    async fn list_nodes(client: Client) {
        let api: Api<Node> = Api::all(client);
        let _nodes = api.list(&ListParams::default()).await;
    }

    #[tokio::test]
    async fn test_mocker() {
        let (client, mut mocker) = make_mocker();
        expect_node_list(&mut mocker, client.clone()).await;
        let spawned = tokio::spawn(mocker.run());
        list_nodes(client).await;
        // Two unwraps: the outer one checks that the spawned task did not panic, and the inner
        // one checks the mocker's verdict (an `Err` means the expectations were not met).
        spawned.await.unwrap().unwrap();
    }

    #[tokio::test]
    #[should_panic(expected = "TooFewApiCalls")]
    async fn test_mocker_too_few_api_calls() {
        let (client, mut mocker) = make_mocker();
        expect_node_list(&mut mocker, client.clone()).await;
        // Deliberately make no call here.
        let spawned = tokio::spawn(mocker.run());
        spawned.await.unwrap().unwrap();
    }

    #[tokio::test]
    #[should_panic(expected = "TooManyApiCalls")]
    async fn test_mocker_too_many_api_calls() {
        let (client, mocker) = make_mocker();
        let spawned = tokio::spawn(mocker.run());
        list_nodes(client).await;
        spawned.await.unwrap().unwrap();
    }

    #[tokio::test]
    #[should_panic(expected = "IncorrectApiCall")]
    async fn wrong_api_call() {
        let (client, mut mocker) = make_mocker();
        expect_node_list(&mut mocker, client.clone()).await;
        let spawned = tokio::spawn(mocker.run());
        // Expected nodes, but list pods instead.
        let api: Api<Pod> = Api::all(client);
        let _pods = api.list(&ListParams::default()).await;
        spawned.await.unwrap().unwrap();
    }
}