generic_device_plugin/
lib.rs1use std::{fs, io::ErrorKind, marker::PhantomData, path::PathBuf};
2
3use anyhow::bail;
4use notify::{recommended_watcher, RecursiveMode, Watcher};
5use tokio::{
6 net::{UnixListener, UnixStream},
7 spawn,
8 sync::watch,
9};
10use tokio_stream::wrappers::UnixListenerStream;
11use tonic::{
12 transport::{Endpoint, Server, Uri},
13 Request,
14};
15use tower::service_fn;
16use tracing::{error, info, warn};
17
18use self::pb::{
19 device_plugin_server::DevicePluginServer, registration_client::RegistrationClient,
20 DevicePluginOptions, RegisterRequest,
21};
22pub use self::{
23 pb::{
24 CdiDevice, ContainerAllocateResponse, ContainerPreferredAllocationResponse, Device,
25 DeviceSpec, Mount, NumaNode, TopologyInfo,
26 },
27 service::GenericDevicePlugin,
28};
29
30mod service;
31mod pb {
32 tonic::include_proto!("v1beta1");
33}
34
/// Device plugin API version string sent in the `version` field of the
/// registration request.
35static VERSION: &str = "v1beta1";
/// File name of the kubelet registration socket; it is joined onto the
/// plugin directory when registering.
36static KUBELET_SOCK: &str = "kubelet.sock";
37
38pub struct GenericDevicePluginServer<DP: GenericDevicePlugin> {
39 dir_path: PathBuf,
40 socket_name: String,
41 _phantom: PhantomData<DP>,
42}
43
44impl<DP: GenericDevicePlugin> GenericDevicePluginServer<DP> {
45 pub fn new(dir_path: PathBuf, socket_name: String) -> Self {
46 Self {
47 dir_path,
48 socket_name,
49 _phantom: PhantomData,
50 }
51 }
52
53 pub async fn run(self) -> anyhow::Result<()> {
59 let socket_path = self.dir_path.join(&self.socket_name);
60
61 loop {
62 match std::os::unix::net::UnixStream::connect(&socket_path) {
63 Err(e) if e.kind() == ErrorKind::NotFound => {}
64 Err(e) if e.kind() == ErrorKind::ConnectionRefused => {
65 fs::remove_file(&socket_path)?
66 }
67 Err(e) => bail!("unable to ensure uds is available: {e:?}"),
68 Ok(_) => bail!("active unix socket connect exist on {socket_path:?}"),
69 }
70
71 let uds = UnixListener::bind(socket_path.clone())?;
72
73 let (tx, mut rx) = watch::channel(());
74 let mut watcher = recommended_watcher(move |res| {
75 if let Err(e) = res {
76 error!("failed to watch device plugin socket: {e}")
77 }
78 let _ = tx.send(());
79 })?;
80
81 watcher.watch(&socket_path, RecursiveMode::NonRecursive)?;
82
83 let handle = spawn(
84 Server::builder()
85 .add_service(DevicePluginServer::new(DP::default()))
86 .serve_with_incoming_shutdown(UnixListenerStream::new(uds), async move {
87 let _ = rx.changed().await;
88 warn!("socket file changed, restarting server...")
89 }),
90 );
91 info!("plugin server started on {socket_path:?}!");
92
93 self.register().await?;
94 info!("plugin registered!");
95
96 let _ = handle.await;
97 let _ = fs::remove_file(&socket_path);
98 }
99 }
100
101 async fn register(&self) -> anyhow::Result<()> {
102 let register_client_socket_path = self.dir_path.join(KUBELET_SOCK);
103 RegistrationClient::new(
104 Endpoint::try_from("http://[::]:50051")?
105 .connect_with_connector(service_fn(move |_: Uri| {
106 UnixStream::connect(register_client_socket_path.clone())
107 }))
108 .await?,
109 )
110 .register(Request::new(RegisterRequest {
111 endpoint: self.socket_name.clone(),
112 resource_name: DP::RESOURCE_NAME.to_string(),
113 version: VERSION.to_string(),
114 options: Some(DevicePluginOptions {
115 pre_start_required: DP::PRE_START_REQUIRED,
116 get_preferred_allocation_available: DP::GET_PREFERRED_ALLOCATION_AVAILABLE,
117 }),
118 }))
119 .await?;
120 Ok(())
121 }
122}