use crate::fs_watch::FileSystemWatcher;
use crate::grpc_sock;
use crate::plugin_registration_api::v1::{
registration_client::RegistrationClient, InfoRequest, PluginInfo, RegistrationStatus,
API_VERSION,
};
use anyhow::Context;
use notify::Event;
use tokio::fs::{create_dir_all, read_dir};
use tokio::sync::{RwLock, RwLockWriteGuard};
use tokio_stream::wrappers::ReadDirStream;
use tokio_stream::StreamExt;
use tonic::Request;
use tracing::{debug, error, trace, warn};
use std::collections::HashMap;
use std::convert::TryFrom;
use std::path::{Path, PathBuf};
/// Directory watched for plugin registration sockets when no explicit directory is given.
#[cfg(target_family = "unix")]
const DEFAULT_PLUGIN_PATH: &str = "/var/lib/kubelet/plugins_registry/";
/// Directory watched for plugin registration sockets when no explicit directory is given.
#[cfg(target_family = "windows")]
const DEFAULT_PLUGIN_PATH: &str = r"c:\ProgramData\kubelet\plugins_registry";
/// File extension a directory entry must carry to be treated as a plugin socket.
const SOCKET_EXTENSION: &str = "sock";
/// Plugin types this registry accepts. A plugin reporting any other (known) type fails
/// type validation and is not registered.
const ALLOWED_PLUGIN_TYPES: &[PluginType] = &[PluginType::CSIPlugin];

/// The plugin types a plugin may report in the `type` field of its `PluginInfo`.
///
/// Fieldless, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PluginType {
    /// A CSI plugin; currently the only entry in [`ALLOWED_PLUGIN_TYPES`].
    CSIPlugin,
    /// A device plugin; recognized, but not in [`ALLOWED_PLUGIN_TYPES`].
    DevicePlugin,
}

impl TryFrom<&str> for PluginType {
    type Error = anyhow::Error;

    /// Parses the plugin type string reported by a plugin.
    ///
    /// # Errors
    /// Returns an error naming `value` if it is neither `"CSIPlugin"` nor `"DevicePlugin"`.
    /// Recognizing a type does not mean it is accepted; see [`ALLOWED_PLUGIN_TYPES`].
    fn try_from(value: &str) -> Result<Self, Self::Error> {
        match value {
            "CSIPlugin" => Ok(PluginType::CSIPlugin),
            "DevicePlugin" => Ok(PluginType::DevicePlugin),
            // The message is relayed to the plugin, so it must not claim that types are
            // "allowed" when only a subset of the known types is actually accepted.
            _ => Err(anyhow::anyhow!(
                "Unknown plugin type {}. Known types are 'CSIPlugin' and 'DevicePlugin'",
                value
            )),
        }
    }
}
/// Bookkeeping for a single registered plugin.
#[derive(Debug)]
struct PluginEntry {
    /// Path of the registration socket the plugin was discovered at.
    plugin_path: PathBuf,
    /// Endpoint the plugin advertised during registration, or `None` when it advertised an
    /// empty endpoint.
    endpoint: Option<PathBuf>,
}

/// Watches a directory for plugin registration sockets and keeps track of the plugins that
/// register successfully, keyed by the name each plugin reports.
#[derive(Debug)]
pub struct PluginRegistry {
    /// Registered plugins, keyed by plugin name.
    plugins: RwLock<HashMap<String, PluginEntry>>,
    /// Directory watched for registration sockets.
    plugin_dir: PathBuf,
}

impl Default for PluginRegistry {
    /// Creates an empty registry that watches [`DEFAULT_PLUGIN_PATH`].
    fn default() -> Self {
        PluginRegistry {
            plugin_dir: PathBuf::from(DEFAULT_PLUGIN_PATH),
            plugins: RwLock::new(HashMap::new()),
        }
    }
}
impl PluginRegistry {
pub fn new<P: AsRef<Path>>(plugin_dir: P) -> Self {
PluginRegistry {
plugin_dir: PathBuf::from(plugin_dir.as_ref()),
..Default::default()
}
}
#[allow(dead_code)]
pub async fn get_endpoint(&self, plugin_name: &str) -> Option<PathBuf> {
let plugins = self.plugins.read().await;
plugins
.get(plugin_name)
.map(|v| v.endpoint.as_ref().unwrap_or(&v.plugin_path).to_owned())
}
pub async fn run(&self) -> anyhow::Result<()> {
create_dir_all(&self.plugin_dir).await?;
let dir_entries: Vec<PathBuf> = ReadDirStream::new(read_dir(&self.plugin_dir).await?)
.map(|res| res.map(|entry| entry.path()))
.collect::<Result<Vec<PathBuf>, _>>()
.await?;
self.handle_create(Event {
paths: dir_entries,
..Default::default()
})
.await?;
let mut event_stream = FileSystemWatcher::new(&self.plugin_dir)?;
while let Some(res) = event_stream.next().await {
match res {
Ok(event) if event.kind.is_create() => {
if let Err(e) = self.handle_create(event).await {
error!("An error occurred while processing a new plugin: {:?}", e);
}
}
Ok(event) if event.kind.is_remove() => self.handle_delete(event).await,
Ok(_) => continue,
Err(e) => error!("An error occurred while watching the plugin directory. Will continue to retry: {:?}", e),
}
}
Ok(())
}
async fn handle_create(&self, event: Event) -> anyhow::Result<()> {
for discovered_path in plugin_paths(event.paths) {
debug!(
"Beginning plugin registration for plugin discovered at {}",
discovered_path.display()
);
let plugin_info = get_plugin_info(&discovered_path).await?;
debug!(
"Successfully retrieved information for plugin discovered at {}:\n {:?}",
discovered_path.display(),
plugin_info
);
if let Err(e) = self.validate(&plugin_info, &discovered_path).await {
inform_plugin(&discovered_path, Some(e.to_string())).await?;
return Err(e).with_context(|| {
format!(
"Validation step failed for plugin discovered at {}",
discovered_path.display()
)
});
}
debug!(
"Successfully validated plugin discovered at {}:\n {:?}",
discovered_path.display(),
plugin_info
);
self.register(&plugin_info, &discovered_path).await;
inform_plugin(&discovered_path, None).await?;
debug!("Plugin registration complete for {:?}", plugin_info)
}
Ok(())
}
async fn handle_delete(&self, event: Event) {
let mut plugins = self.plugins.write().await;
for deleted_plugin in plugin_paths(event.paths) {
remove_plugin(&mut plugins, deleted_plugin);
}
}
async fn register(&self, info: &PluginInfo, discovered_path: &PathBuf) {
let mut lock = self.plugins.write().await;
lock.insert(
info.name.clone(),
PluginEntry {
plugin_path: discovered_path.to_owned(),
endpoint: match info.endpoint.is_empty() {
true => None,
false => Some(PathBuf::from(&info.endpoint)),
},
},
);
}
async fn validate(&self, info: &PluginInfo, discovered_path: &PathBuf) -> anyhow::Result<()> {
trace!(
"Starting validation for plugin {:?} discovered at path {}",
info,
discovered_path.display()
);
self.validate_plugin_type(info.r#type.as_str())?;
trace!("Type validation complete for plugin {:?}", info);
trace!("Checking supported versions for plugin {:?}", info);
self.validate_plugin_version(&info.supported_versions)?;
trace!("Supported version check complete for plugin {:?}", info);
trace!("Checking for naming collisions for plugin {:?}", info);
self.validate_is_unique(info, discovered_path).await?;
trace!("Naming collision check complete for plugin {:?}", info);
Ok(())
}
fn validate_plugin_type(&self, plugin_type: &str) -> anyhow::Result<()> {
let plugin_type = PluginType::try_from(plugin_type)?;
if !is_allowed_plugin_type(plugin_type) {
warn!("DevicePlugins are not currently supported");
return Err(anyhow::anyhow!("DevicePlugins are not currently supported"));
}
Ok(())
}
fn validate_plugin_version(&self, supported_versions: &[String]) -> anyhow::Result<()> {
if !supported_versions.iter().any(|s| s == API_VERSION) {
return Err(anyhow::anyhow!(
"Plugin doesn't support version {}",
API_VERSION
));
}
Ok(())
}
async fn validate_is_unique(
&self,
info: &PluginInfo,
discovered_path: &PathBuf,
) -> anyhow::Result<()> {
let plugins = self.plugins.read().await;
if let Some(current_path) = plugins.get(&info.name) {
if !info.endpoint.is_empty()
&& Some(PathBuf::from(&info.endpoint)) != current_path.endpoint
{
return Err(anyhow::format_err!(
"Plugin already exists with an endpoint of {:?}, which differs from the new endpoint of {}",
current_path.endpoint,
info.endpoint
));
} else if *discovered_path != current_path.plugin_path {
return Err(anyhow::anyhow!(
"Plugin already exists with an endpoint of {}, which differs from the new endpoint of {}",
current_path.plugin_path.display(),
discovered_path.display()
));
}
}
Ok(())
}
}
/// Removes every registry entry whose registration socket is `deleted_plugin`.
///
/// Uniqueness is only enforced per name. A plugin that re-registers under a new name on the
/// same socket therefore leaves several entries pointing at one path. All of them are dropped,
/// not just the first one found.
fn remove_plugin(
    plugins: &mut RwLockWriteGuard<HashMap<String, PluginEntry>>,
    deleted_plugin: PathBuf,
) {
    plugins.retain(|_, entry| entry.plugin_path != deleted_plugin);
}
/// Reports whether `t` is listed in [`ALLOWED_PLUGIN_TYPES`].
fn is_allowed_plugin_type(t: PluginType) -> bool {
    ALLOWED_PLUGIN_TYPES.contains(&t)
}
/// Connects to the plugin socket at `path` and calls its `GetInfo` RPC.
///
/// # Errors
/// Fails if the socket cannot be connected to, or if the RPC returns a non-OK status. In the
/// latter case the error includes the status code and message.
async fn get_plugin_info(path: &PathBuf) -> anyhow::Result<PluginInfo> {
    trace!("Connecting to plugin at {:?} for GetInfo", path);
    let channel = grpc_sock::client::socket_channel(path).await?;
    let mut registration_client = RegistrationClient::new(channel);

    trace!("Calling GetInfo at {:?}", path);
    let response = registration_client
        .get_info(Request::new(InfoRequest {}))
        .await
        .map_err(|status| {
            anyhow::anyhow!(
                "GetInfo call to {} failed with error code {} and message {}",
                path.display(),
                status.code(),
                status.message()
            )
        })?;
    Ok(response.into_inner())
}
/// Tells the plugin at `path` whether it was registered, via `NotifyRegistrationStatus`.
///
/// `error == None` reports success. `Some(reason)` reports failure and forwards `reason` to
/// the plugin.
///
/// # Errors
/// Fails if the socket cannot be connected to, or if the RPC returns a non-OK status.
async fn inform_plugin(path: &PathBuf, error: Option<String>) -> anyhow::Result<()> {
    trace!(
        "Connecting to plugin at {:?} for NotifyRegistrationStatus",
        path
    );
    let channel = grpc_sock::client::socket_channel(path).await?;
    let mut registration_client = RegistrationClient::new(channel);

    let status = match error {
        Some(reason) => RegistrationStatus {
            plugin_registered: false,
            error: reason,
        },
        None => RegistrationStatus {
            plugin_registered: true,
            error: String::new(),
        },
    };

    trace!("Calling NotifyRegistrationStatus at {:?}", path);
    registration_client
        .notify_registration_status(Request::new(status))
        .await
        .map(|_| ())
        .map_err(|status| {
            anyhow::anyhow!(
                "NotifyRegistrationStatus call to {} failed with error code {} and message {}",
                path.display(),
                status.code(),
                status.message()
            )
        })
}
/// Keeps only the paths that look like plugin sockets: entries with the
/// [`SOCKET_EXTENSION`] extension that are not directories.
fn plugin_paths(paths: Vec<PathBuf>) -> impl Iterator<Item = PathBuf> {
    paths.into_iter().filter(|candidate| {
        let has_socket_extension = candidate
            .extension()
            .map_or(false, |ext| ext == SOCKET_EXTENSION);
        has_socket_extension && !candidate.is_dir()
    })
}
// Tests for the plugin registry. The registration tests start a fake plugin gRPC server on a
// socket inside a temporary plugin directory and run the real registrar against it. The
// validation tests call `register`/`validate` directly without any sockets.
#[cfg(test)]
mod test {
    use super::*;
    use crate::plugin_registration_api::v1::{
        registration_server::{Registration, RegistrationServer},
        InfoRequest, PluginInfo, RegistrationStatusResponse, API_VERSION,
    };
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::sync::mpsc::{self, Receiver, Sender};
    use tokio::sync::Mutex;
    use tokio::time::timeout;
    // NOTE(review): presumably needed because the Windows socket server future has to be
    // bridged onto this runtime via `.compat()` (see `setup_server`); confirm.
    #[cfg(target_family = "windows")]
    use tokio_compat_02::FutureExt;
    use tonic::{transport::Server, Request, Response, Status};

    /// Endpoint advertised by the fake plugins in their `PluginInfo`.
    const FAKE_ENDPOINT: &str = "/tmp/foo.sock";

    /// Fake plugin that reports a valid CSIPlugin `PluginInfo` and forwards every
    /// `RegistrationStatus` it is sent to the test through `registration_response`.
    #[derive(Debug)]
    struct TestCSIPlugin {
        name: String,
        registration_response: Mutex<Sender<RegistrationStatus>>,
    }

    #[tonic::async_trait]
    impl Registration for TestCSIPlugin {
        async fn get_info(
            &self,
            _req: Request<InfoRequest>,
        ) -> Result<Response<PluginInfo>, Status> {
            Ok(Response::new(PluginInfo {
                r#type: "CSIPlugin".to_string(),
                name: self.name.clone(),
                endpoint: FAKE_ENDPOINT.to_string(),
                supported_versions: vec![API_VERSION.to_string()],
            }))
        }

        // Hands the status the registrar sent to the waiting test.
        async fn notify_registration_status(
            &self,
            req: Request<RegistrationStatus>,
        ) -> Result<Response<RegistrationStatusResponse>, Status> {
            self.registration_response
                .lock()
                .await
                .send(req.into_inner())
                .await
                .expect("should be able to send registration status on channel");
            Ok(Response::new(RegistrationStatusResponse {}))
        }
    }

    /// Fake plugin identical to `TestCSIPlugin` except that it only advertises the version
    /// "nope", so version validation is expected to reject it.
    #[derive(Debug)]
    struct InvalidCSIPlugin {
        name: String,
        registration_response: Mutex<Sender<RegistrationStatus>>,
    }

    #[tonic::async_trait]
    impl Registration for InvalidCSIPlugin {
        async fn get_info(
            &self,
            _req: Request<InfoRequest>,
        ) -> Result<Response<PluginInfo>, Status> {
            Ok(Response::new(PluginInfo {
                r#type: "CSIPlugin".to_string(),
                name: self.name.clone(),
                endpoint: FAKE_ENDPOINT.to_string(),
                supported_versions: vec!["nope".to_string()],
            }))
        }

        // Hands the status the registrar sent to the waiting test.
        async fn notify_registration_status(
            &self,
            req: Request<RegistrationStatus>,
        ) -> Result<Response<RegistrationStatusResponse>, Status> {
            self.registration_response
                .lock()
                .await
                .send(req.into_inner())
                .await
                .expect("should be able to send registration status on channel");
            Ok(Response::new(RegistrationStatusResponse {}))
        }
    }

    /// Creates a temporary plugin directory and a registry watching it. The `TempDir` is
    /// returned so the caller keeps it alive (and the directory on disk) for the whole test.
    fn setup() -> (tempfile::TempDir, Arc<PluginRegistry>) {
        let tempdir = tempfile::tempdir().expect("should be able to create tempdir");
        let registrar = PluginRegistry::new(&tempdir);
        (tempdir, Arc::new(registrar))
    }

    /// Binds a socket at `path` and serves `plugin` on it from a background task.
    async fn setup_server(plugin: impl Registration, path: impl AsRef<Path>) {
        let socket = grpc_sock::server::Socket::new(&path)
            .expect("unable to setup server listening on socket");
        tokio::spawn(async move {
            let serv = Server::builder()
                .add_service(RegistrationServer::new(plugin))
                .serve_with_incoming(socket);
            #[cfg(target_family = "windows")]
            let serv = serv.compat();
            serv.await.expect("Unable to serve test plugin");
            println!("server exited");
        });
    }

    /// Spawns `registrar.run()` in the background, then sleeps for 2 seconds — presumably to
    /// give it time to scan the directory and start watching before the test continues.
    async fn start_registrar(registrar: Arc<PluginRegistry>) {
        tokio::spawn(async move {
            registrar
                .run()
                .await
                .expect("registrar didn't run successfully");
            println!("registrar exited");
        });
        tokio::time::sleep(std::time::Duration::from_secs(2)).await;
    }

    /// Waits up to 10 seconds for a fake plugin to report the status it was sent; panics on
    /// timeout or if the channel closes.
    async fn get_registration_response(mut rx: Receiver<RegistrationStatus>) -> RegistrationStatus {
        timeout(Duration::from_secs(10), rx.recv())
            .await
            .expect("timed out while waiting for registration status response")
            .expect("Should have received a valid response in the channel")
    }

    // A socket created after the registrar starts is picked up by the directory watcher.
    #[tokio::test]
    async fn test_successful_registration() {
        let (tempdir, registrar) = setup();
        let (tx, rx) = mpsc::channel(1);
        let plugin = TestCSIPlugin {
            name: "foo".to_string(),
            registration_response: Mutex::new(tx),
        };
        start_registrar(registrar.clone()).await;
        setup_server(plugin, tempdir.path().join("foo.sock")).await;
        let registration_status = get_registration_response(rx).await;
        assert!(
            registration_status.plugin_registered,
            "Plugin did not receive successful registration request"
        );
        assert!(
            registration_status.error.is_empty(),
            "Error message should be empty"
        );
        let plugin_endpoint = registrar
            .get_endpoint("foo")
            .await
            .expect("Should be able to get plugin info");
        assert_eq!(
            plugin_endpoint,
            PathBuf::from(FAKE_ENDPOINT),
            "Incorrect endpoint configured"
        );
    }

    // A plugin failing validation is told why and is not stored.
    #[tokio::test]
    async fn test_unsuccessful_registration() {
        let (tempdir, registrar) = setup();
        let (tx, rx) = mpsc::channel(1);
        let plugin = InvalidCSIPlugin {
            name: "foo".to_string(),
            registration_response: Mutex::new(tx),
        };
        start_registrar(registrar.clone()).await;
        setup_server(plugin, tempdir.path().join("foo.sock")).await;
        let registration_status = get_registration_response(rx).await;
        assert!(
            !registration_status.plugin_registered,
            "Plugin should not have been registered"
        );
        assert!(
            !registration_status.error.is_empty(),
            "Error message should be set"
        );
        assert!(
            registrar.get_endpoint("foo").await.is_none(),
            "Plugin shouldn't be registered in memory"
        );
    }

    // A socket that already exists when the registrar starts is registered by the initial
    // directory scan. Same as `test_successful_registration`, but the server starts first.
    #[tokio::test]
    async fn test_existing_socket() {
        let (tempdir, registrar) = setup();
        let (tx, rx) = mpsc::channel(1);
        let plugin = TestCSIPlugin {
            name: "foo".to_string(),
            registration_response: Mutex::new(tx),
        };
        setup_server(plugin, tempdir.path().join("foo.sock")).await;
        tokio::time::sleep(Duration::from_secs(1)).await;
        start_registrar(registrar.clone()).await;
        let registration_status = get_registration_response(rx).await;
        assert!(
            registration_status.plugin_registered,
            "Plugin did not receive successful registration request"
        );
        assert!(
            registration_status.error.is_empty(),
            "Error message should be empty"
        );
        let plugin_endpoint = registrar
            .get_endpoint("foo")
            .await
            .expect("Should be able to get plugin info");
        assert_eq!(
            plugin_endpoint,
            PathBuf::from(FAKE_ENDPOINT),
            "Incorrect endpoint configured"
        );
    }

    // Deleting a registered plugin's socket removes it from the registry. The server is built
    // inline (not via `setup_server`) so it can be stopped through a oneshot before the socket
    // file is removed.
    #[tokio::test]
    async fn test_unregister() {
        let (tempdir, registrar) = setup();
        let (tx, rx) = mpsc::channel(1);
        let plugin = TestCSIPlugin {
            name: "foo".to_string(),
            registration_response: Mutex::new(tx),
        };
        let sock_path = tempdir.path().join("foo.sock");
        let socket = grpc_sock::server::Socket::new(&sock_path)
            .expect("unable to setup server listening on socket");
        let (stop_tx, stop_rx) = tokio::sync::oneshot::channel();
        tokio::spawn(async move {
            let server = Server::builder()
                .add_service(RegistrationServer::new(plugin))
                .serve_with_incoming(socket);
            tokio::select! {
                res = server => {
                    res.expect("Unable to serve test plugin");
                    println!("server exited");
                }
                _ = stop_rx => {}
            }
        });
        tokio::time::sleep(Duration::from_secs(1)).await;
        start_registrar(registrar.clone()).await;
        let registration_status = get_registration_response(rx).await;
        assert!(
            registration_status.plugin_registered,
            "Plugin did not receive successful registration request"
        );
        stop_tx.send(()).expect("Unable to send stop signal");
        tokio::fs::remove_file(sock_path)
            .await
            .expect("Unable to remove socket");
        // Give the watcher time to deliver the remove event.
        tokio::time::sleep(Duration::from_secs(3)).await;
        assert!(
            registrar.get_endpoint("foo").await.is_none(),
            "Plugin shouldn't be registered in memory"
        );
    }

    /// A `PluginInfo` that passes every validation check on an empty registry.
    fn valid_info() -> PluginInfo {
        PluginInfo {
            r#type: "CSIPlugin".to_string(),
            name: "test".to_string(),
            endpoint: FAKE_ENDPOINT.to_string(),
            supported_versions: vec![API_VERSION.to_string()],
        }
    }

    // Recognized-but-disallowed types and unknown types are both rejected.
    #[tokio::test]
    async fn test_invalid_type() {
        let registrar = PluginRegistry::new("/tmp/foo");
        let mut info = valid_info();
        info.r#type = "DevicePlugin".to_string();
        assert!(
            registrar
                .validate(&info, &PathBuf::from("/fake"))
                .await
                .is_err(),
            "DevicePlugin type should error"
        );
        info.r#type = "NonExistent".to_string();
        assert!(
            registrar
                .validate(&info, &PathBuf::from("/fake"))
                .await
                .is_err(),
            "Invalid type should error"
        );
    }

    #[tokio::test]
    async fn test_invalid_plugin_version() {
        let registrar = PluginRegistry::new("/tmp/foo");
        let mut info = valid_info();
        info.supported_versions = vec!["v1beta1".to_string()];
        assert!(
            registrar
                .validate(&info, &PathBuf::from("/fake"))
                .await
                .is_err(),
            "Unsupported version should error"
        );
    }

    // Same name and socket, but a different advertised endpoint, is a collision.
    #[tokio::test]
    async fn test_invalid_name_with_different_endpoint() {
        let registrar = PluginRegistry::new("/tmp/foo");
        let mut info = valid_info();
        let discovered_path = PathBuf::from("/tmp/foo/bar.sock");
        registrar.register(&info, &discovered_path).await;
        info.endpoint = "/another/path.sock".to_string();
        assert!(
            registrar.validate(&info, &discovered_path).await.is_err(),
            "Different endpoint with same name should error"
        );
    }

    // Same name with an empty endpoint, discovered on a different socket, is a collision.
    #[tokio::test]
    async fn test_invalid_name_with_different_discovered_path() {
        let registrar = PluginRegistry::new("/tmp/foo");
        let mut info = valid_info();
        info.endpoint = String::new();
        registrar
            .register(&info, &PathBuf::from("/tmp/foo/bar.sock"))
            .await;
        assert!(
            registrar
                .validate(&info, &PathBuf::from("/tmp/foo/another.sock"))
                .await
                .is_err(),
            "Different discovered path with same name should error"
        );
    }

    // Validating exactly what is already registered succeeds.
    #[tokio::test]
    async fn test_reregistration() {
        let registrar = PluginRegistry::new("/tmp/foo");
        let info = valid_info();
        let discovered_path = PathBuf::from("/tmp/foo/bar.sock");
        registrar.register(&info, &discovered_path).await;
        assert!(
            registrar.validate(&info, &discovered_path).await.is_ok(),
            "Exact same plugin info shouldn't fail"
        );
    }
}