// generic_device_plugin/lib.rs

use std::{fs, io::ErrorKind, marker::PhantomData, path::PathBuf};

use anyhow::{bail, Context};
use notify::{recommended_watcher, RecursiveMode, Watcher};
use tokio::{
    net::{UnixListener, UnixStream},
    spawn,
    sync::watch,
};
use tokio_stream::wrappers::UnixListenerStream;
use tonic::{
    transport::{Endpoint, Server, Uri},
    Request,
};
use tower::service_fn;
use tracing::{error, info, warn};

use self::pb::{
    device_plugin_server::DevicePluginServer, registration_client::RegistrationClient,
    DevicePluginOptions, RegisterRequest,
};
pub use self::{
    pb::{
        CdiDevice, ContainerAllocateResponse, ContainerPreferredAllocationResponse, Device,
        DeviceSpec, Mount, NumaNode, TopologyInfo,
    },
    service::GenericDevicePlugin,
};
29
30mod service;
31mod pb {
32    tonic::include_proto!("v1beta1");
33}
34
/// API version string sent in the registration request.
const VERSION: &str = "v1beta1";
/// File name of the kubelet registration socket, resolved inside the plugin directory.
const KUBELET_SOCK: &str = "kubelet.sock";
37
/// Server that hosts a device plugin of type `DP` on a Unix socket and registers it
/// through the kubelet socket found in the same directory.
///
/// `DP` is used purely at the type level: the plugin instance is created on demand,
/// so no value of `DP` is stored here. The trait bound lives on the `impl` block
/// rather than on the struct definition.
pub struct GenericDevicePluginServer<DP> {
    /// Directory that holds both this plugin's socket and the kubelet socket.
    dir_path: PathBuf,
    /// File name of this plugin's socket inside `dir_path`; also sent as the
    /// registration endpoint.
    socket_name: String,
    /// Binds the server to the plugin type without owning an instance of it.
    _phantom: PhantomData<DP>,
}
43
44impl<DP: GenericDevicePlugin> GenericDevicePluginServer<DP> {
45    pub fn new(dir_path: PathBuf, socket_name: String) -> Self {
46        Self {
47            dir_path,
48            socket_name,
49            _phantom: PhantomData,
50        }
51    }
52
53    /// 1. clean up & bind socket
54    /// 2. watch socket file (kubelet restart)
55    /// 3. start device plugin server
56    /// 4. register to kubelet
57    /// 5. clean up & goto 1 if socket file changed (graceful)
58    pub async fn run(self) -> anyhow::Result<()> {
59        let socket_path = self.dir_path.join(&self.socket_name);
60
61        loop {
62            match std::os::unix::net::UnixStream::connect(&socket_path) {
63                Err(e) if e.kind() == ErrorKind::NotFound => {}
64                Err(e) if e.kind() == ErrorKind::ConnectionRefused => {
65                    fs::remove_file(&socket_path)?
66                }
67                Err(e) => bail!("unable to ensure uds is available: {e:?}"),
68                Ok(_) => bail!("active unix socket connect exist on {socket_path:?}"),
69            }
70
71            let uds = UnixListener::bind(socket_path.clone())?;
72
73            let (tx, mut rx) = watch::channel(());
74            let mut watcher = recommended_watcher(move |res| {
75                if let Err(e) = res {
76                    error!("failed to watch device plugin socket: {e}")
77                }
78                let _ = tx.send(());
79            })?;
80
81            watcher.watch(&socket_path, RecursiveMode::NonRecursive)?;
82
83            let handle = spawn(
84                Server::builder()
85                    .add_service(DevicePluginServer::new(DP::default()))
86                    .serve_with_incoming_shutdown(UnixListenerStream::new(uds), async move {
87                        let _ = rx.changed().await;
88                        warn!("socket file changed, restarting server...")
89                    }),
90            );
91            info!("plugin server started on {socket_path:?}!");
92
93            self.register().await?;
94            info!("plugin registered!");
95
96            let _ = handle.await;
97            let _ = fs::remove_file(&socket_path);
98        }
99    }
100
101    async fn register(&self) -> anyhow::Result<()> {
102        let register_client_socket_path = self.dir_path.join(KUBELET_SOCK);
103        RegistrationClient::new(
104            Endpoint::try_from("http://[::]:50051")?
105                .connect_with_connector(service_fn(move |_: Uri| {
106                    UnixStream::connect(register_client_socket_path.clone())
107                }))
108                .await?,
109        )
110        .register(Request::new(RegisterRequest {
111            endpoint: self.socket_name.clone(),
112            resource_name: DP::RESOURCE_NAME.to_string(),
113            version: VERSION.to_string(),
114            options: Some(DevicePluginOptions {
115                pre_start_required: DP::PRE_START_REQUIRED,
116                get_preferred_allocation_available: DP::GET_PREFERRED_ALLOCATION_AVAILABLE,
117            }),
118        }))
119        .await?;
120        Ok(())
121    }
122}