use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::Arc;
use std::sync::RwLock;
use k8s_openapi::NamespaceResourceScope;
use kube::Api;
use kube::Client;
use kube::Resource;
use serde::de::DeserializeOwned;
use crate::error::Error;
use crate::error::Result;
/// How a [`StaticApiProvider`] treats namespaces that were not supplied when it
/// was constructed.
///
/// Derives `Hash` in addition to `Eq` so the strategy can be used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CachingStrategy {
    /// Only namespaces supplied at construction are served; any other namespace
    /// is rejected with an error.
    Strict,
    /// Pre-registered namespaces are served from the cache; any other namespace
    /// gets a freshly built, uncached `Api` on every call.
    Adhoc,
    /// Unknown namespaces are built once, added to the cache, and reused on
    /// subsequent calls.
    Extendable,
}
/// Source of shared, namespace-scoped [`Api`] handles for the resource type `R`.
pub trait ProvideApi<R>
where
    R: Resource<Scope = NamespaceResourceScope>
        + Clone
        + DeserializeOwned
        + Debug
        + Send
        + Sync
        + 'static,
    R::DynamicType: Default,
{
    /// Returns the `Api<R>` handle to use for `namespace`, wrapped in an `Arc`
    /// so it can be shared cheaply.
    ///
    /// # Errors
    ///
    /// Implementation-specific; see the implementing type.
    fn get(&self, namespace: &str) -> Result<Arc<Api<R>>>;
}
/// Provider that builds one namespaced `Api<R>` per namespace on first request
/// and hands out the same `Arc` on every later request for that namespace.
///
/// The cache sits behind an `RwLock`, so `get` works through a shared reference.
pub struct CachedApiProvider<R>
where
    R: Resource<Scope = NamespaceResourceScope>
        + Clone
        + DeserializeOwned
        + Debug
        + Send
        + Sync
        + 'static,
    R::DynamicType: Default,
{
    /// Client cloned into every newly built `Api`.
    client: Client,
    /// Namespace name -> shared `Api` handle.
    cache: RwLock<HashMap<String, Arc<Api<R>>>>,
}
impl<R> CachedApiProvider<R>
where
    R: Resource<Scope = NamespaceResourceScope>
        + Clone
        + DeserializeOwned
        + Debug
        + Send
        + Sync
        + 'static,
    R::DynamicType: Default,
{
    /// Creates a provider with an empty cache.
    ///
    /// No `Api` handle is built here; handles are created by `get` the first
    /// time a namespace is requested.
    pub fn new(client: Client) -> Self {
        let cache = RwLock::new(HashMap::new());
        Self { client, cache }
    }
}
impl<R> ProvideApi<R> for CachedApiProvider<R>
where
    R: Resource<Scope = NamespaceResourceScope>
        + Clone
        + DeserializeOwned
        + Debug
        + Send
        + Sync
        + 'static,
    R::DynamicType: Default,
{
    /// Returns the cached `Api` for `namespace`, building and caching it on first use.
    ///
    /// Hits are served under the read lock only. On a miss, the write lock is taken and
    /// the map is consulted again, so callers racing on the same namespace all receive
    /// the same `Arc`.
    ///
    /// # Errors
    ///
    /// Propagates (via `?`, converted into the crate `Error`) the poison error returned
    /// by the cache lock if another thread panicked while holding it.
    fn get(&self, namespace: &str) -> Result<Arc<Api<R>>> {
        // The read guard is a temporary of this statement, so it is released
        // before the write lock below is requested.
        let hit = self.cache.read()?.get(namespace).cloned();
        if let Some(api) = hit {
            return Ok(api);
        }

        // `entry` re-checks under the write lock: an entry inserted by a racing
        // thread in the meantime is returned instead of being replaced.
        let mut entries = self.cache.write()?;
        let api = entries
            .entry(namespace.to_owned())
            .or_insert_with(|| Arc::new(Api::<R>::namespaced(self.client.clone(), namespace)));
        Ok(Arc::clone(api))
    }
}
/// Backing store of a [`StaticApiProvider`].
///
/// `StaticApiProvider::new` selects `Static` for the `Strict` and `Adhoc`
/// strategies and `Dynamic` for `Extendable`.
enum CacheStorage<R>
where
    R: Resource<Scope = NamespaceResourceScope>
        + Clone
        + DeserializeOwned
        + Debug
        + Send
        + Sync
        + 'static,
    R::DynamicType: Default,
{
    /// Map that is only read after construction.
    Static(HashMap<String, Arc<Api<R>>>),
    /// Lock-guarded map to which `get` may add new namespaces.
    Dynamic(RwLock<HashMap<String, Arc<Api<R>>>>),
}
/// Provider seeded with a fixed set of namespaces at construction time; how it
/// handles any other namespace is governed by its [`CachingStrategy`].
pub struct StaticApiProvider<R>
where
    R: Resource<Scope = NamespaceResourceScope>
        + Clone
        + DeserializeOwned
        + Debug
        + Send
        + Sync
        + 'static,
    R::DynamicType: Default,
{
    /// Client cloned into `Api` handles that `get` builds for unseeded namespaces.
    client: Client,
    /// Policy for namespaces that are not already cached.
    strategy: CachingStrategy,
    /// Cached handles; `new` picks the variant that matches `strategy`.
    cache: CacheStorage<R>,
}
impl<R> StaticApiProvider<R>
where
R: Resource<Scope = NamespaceResourceScope> + Clone + DeserializeOwned + Debug + Send + Sync + 'static,
R::DynamicType: Default,
{
pub fn new<I, S>(client: Client, namespaces: I, strategy: CachingStrategy) -> Self
where
I: IntoIterator<Item = S>,
S: AsRef<str>,
{
let mut map = HashMap::new();
for namespace in namespaces {
let api = Arc::new(Api::<R>::namespaced(client.clone(), namespace.as_ref()));
map.insert(namespace.as_ref().to_string(), api);
}
let cache = match strategy {
CachingStrategy::Strict | CachingStrategy::Adhoc => CacheStorage::Static(map),
CachingStrategy::Extendable => CacheStorage::Dynamic(RwLock::new(map)),
};
Self {
client,
strategy,
cache,
}
}
}
impl<R> ProvideApi<R> for StaticApiProvider<R>
where
R: Resource<Scope = NamespaceResourceScope> + Clone + DeserializeOwned + Debug + Send + Sync + 'static,
R::DynamicType: Default,
{
fn get(&self, namespace: &str) -> Result<Arc<Api<R>>> {
match (&self.cache, self.strategy) {
(CacheStorage::Static(map), CachingStrategy::Strict) => {
map.get(namespace).map(Arc::clone).ok_or_else(|| {
Error::UserInput(format!(
"Namespace '{namespace}' not found in static cache. Did you include it during initialization?"
))
})
}
(CacheStorage::Static(map), CachingStrategy::Adhoc) => {
if let Some(api) = map.get(namespace) {
return Ok(Arc::clone(api));
}
Ok(Arc::new(Api::<R>::namespaced(self.client.clone(), namespace)))
}
(CacheStorage::Dynamic(lock), CachingStrategy::Extendable) => {
{
let cache = lock.read()?;
if let Some(api) = cache.get(namespace) {
return Ok(Arc::clone(api));
}
}
let mut cache = lock.write()?;
if let Some(api) = cache.get(namespace) {
return Ok(Arc::clone(api));
}
let api = Arc::new(Api::<R>::namespaced(self.client.clone(), namespace));
cache.insert(namespace.to_string(), Arc::clone(&api));
Ok(api)
}
_ => Err(Error::InvalidApiProviderConfig),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use k8s_openapi::api::core::v1::ConfigMap;
    use kube::client::Body;
    use kube::Client;
    use std::sync::Arc;
    use std::sync::Barrier;
    use std::thread;
    use tower_test::mock;

    /// Builds a `Client` on top of a `tower_test` mock service.
    ///
    /// The mock handle is dropped when this function returns; the tests below only
    /// construct `Api` handles through the providers and never send a request.
    fn test_client() -> Client {
        let (mock_service, _handle) =
            mock::pair::<http::Request<Body>, http::Response<hyper::body::Incoming>>();
        Client::new(mock_service, "default")
    }

    /// Shorthand for a `StaticApiProvider<ConfigMap>` seeded with `namespaces`.
    fn static_provider(
        namespaces: &[&str],
        strategy: CachingStrategy,
    ) -> StaticApiProvider<ConfigMap> {
        StaticApiProvider::new(test_client(), namespaces.iter().copied(), strategy)
    }

    /// Requests `namespace` twice and returns both handles, failing the test if
    /// either request errors.
    fn get_twice<P: ProvideApi<ConfigMap>>(
        provider: &P,
        namespace: &str,
    ) -> (Arc<Api<ConfigMap>>, Arc<Api<ConfigMap>>) {
        let first = provider.get(namespace);
        let second = provider.get(namespace);
        assert!(first.is_ok());
        assert!(second.is_ok());
        (first.unwrap(), second.unwrap())
    }

    #[tokio::test]
    async fn test_static_provider_strict_cached_namespace() {
        let provider = static_provider(&["default", "kube-system"], CachingStrategy::Strict);
        let default_ns = provider.get("default");
        let kube_system_ns = provider.get("kube-system");
        assert!(default_ns.is_ok());
        assert!(kube_system_ns.is_ok());
    }

    #[tokio::test]
    async fn test_static_provider_strict_uncached_namespace() {
        let provider = static_provider(&["default"], CachingStrategy::Strict);
        let outcome = provider.get("unknown-namespace");
        assert!(outcome.is_err());
        match outcome {
            Err(Error::UserInput(msg)) => {
                assert!(msg.contains("unknown-namespace"));
                assert!(msg.contains("not found in static cache"));
            }
            _ => panic!("Expected UserInputError"),
        }
    }

    #[tokio::test]
    async fn test_static_provider_adhoc_cached_namespace() {
        let provider = static_provider(&["default", "kube-system"], CachingStrategy::Adhoc);
        assert!(provider.get("default").is_ok());
    }

    #[tokio::test]
    async fn test_static_provider_adhoc_uncached_namespace() {
        let provider = static_provider(&["default"], CachingStrategy::Adhoc);
        let (first, second) = get_twice(&provider, "unknown-namespace");
        // Adhoc never caches misses, so each call builds a distinct handle.
        assert!(!Arc::ptr_eq(&first, &second));
    }

    #[tokio::test]
    async fn test_static_provider_extendable_cached_namespace() {
        let provider =
            static_provider(&["default", "kube-system"], CachingStrategy::Extendable);
        assert!(provider.get("default").is_ok());
    }

    #[tokio::test]
    async fn test_static_provider_extendable_uncached_namespace() {
        let provider = static_provider(&["default"], CachingStrategy::Extendable);
        let (first, second) = get_twice(&provider, "new-namespace");
        assert!(Arc::ptr_eq(&first, &second));
    }

    #[tokio::test]
    async fn test_static_provider_extendable_thread_safety() {
        const NUM_THREADS: usize = 10;
        let provider = Arc::new(static_provider(&["default"], CachingStrategy::Extendable));
        let barrier = Arc::new(Barrier::new(NUM_THREADS));

        // Collected eagerly: every thread must be spawned before any join, otherwise
        // the first join would wait on a thread still parked at the barrier.
        let workers: Vec<_> = (0..NUM_THREADS)
            .map(|_| {
                let provider = Arc::clone(&provider);
                let barrier = Arc::clone(&barrier);
                thread::spawn(move || {
                    barrier.wait();
                    provider.get("concurrent-namespace")
                })
            })
            .collect();

        let results: Vec<_> = workers
            .into_iter()
            .map(|worker| worker.join().unwrap())
            .collect();
        assert!(results.iter().all(|r| r.is_ok()));

        let first_api = results[0].as_ref().unwrap();
        for other in results.iter().skip(1) {
            assert!(Arc::ptr_eq(first_api, other.as_ref().unwrap()));
        }
    }

    #[tokio::test]
    async fn test_cached_provider_lazy_loading() {
        let provider: CachedApiProvider<ConfigMap> = CachedApiProvider::new(test_client());
        let (first, second) = get_twice(&provider, "lazy-namespace");
        assert!(Arc::ptr_eq(&first, &second));
    }

    #[tokio::test]
    async fn test_cached_provider_multiple_namespaces() {
        let provider: CachedApiProvider<ConfigMap> = CachedApiProvider::new(test_client());
        let ns1 = provider.get("namespace-1");
        let ns2 = provider.get("namespace-2");
        let ns1_again = provider.get("namespace-1");
        assert!(ns1.is_ok());
        assert!(ns2.is_ok());
        assert!(ns1_again.is_ok());
        assert!(Arc::ptr_eq(&ns1.unwrap(), &ns1_again.unwrap()));
    }
}