Skip to main content

cfgd_csi/node/
mod.rs

1#![allow(clippy::result_large_err)] // tonic::Status is inherently 176 bytes
2
3use std::collections::HashMap;
4use std::path::Path;
5use std::sync::Arc;
6
7use cfgd_core::PathDisplayExt;
8use tonic::{Request, Response, Status};
9
10use crate::cache::Cache;
11use crate::csi::v1::node_server::Node;
12use crate::csi::v1::{
13    NodeExpandVolumeRequest, NodeExpandVolumeResponse, NodeGetCapabilitiesRequest,
14    NodeGetCapabilitiesResponse, NodeGetInfoRequest, NodeGetInfoResponse,
15    NodeGetVolumeStatsRequest, NodeGetVolumeStatsResponse, NodePublishVolumeRequest,
16    NodePublishVolumeResponse, NodeServiceCapability, NodeStageVolumeRequest,
17    NodeStageVolumeResponse, NodeUnpublishVolumeRequest, NodeUnpublishVolumeResponse,
18    NodeUnstageVolumeRequest, NodeUnstageVolumeResponse, VolumeUsage, node_service_capability,
19    volume_usage,
20};
21use crate::metrics::{CsiMetrics, ModuleLabels, PublishLabels, PullLabels};
22
/// Name of the environment variable that restricts which registries CSI
/// module pulls may target.
///
/// Format: comma-separated `host[:port]` entries. The single value `*` turns
/// the check off. Leaving the variable unset (or blank) also turns the check
/// off, but a warning is logged at startup.
pub const ALLOWED_REGISTRIES_ENV: &str = "CFGD_CSI_ALLOWED_REGISTRIES";
27
28pub struct CfgdNode {
29    cache: Arc<Cache>,
30    metrics: Arc<CsiMetrics>,
31    node_id: String,
32    // `None` means "no allow-list configured" (log a startup warn). `Some(empty)`
33    // means "allow-list explicitly empty — refuse every ref" (not a normal
34    // deploy, but useful for fail-closed dev). `Some(non-empty)` means the ref
35    // host must match one of the entries.
36    allowed_registries: Option<Vec<String>>,
37}
38
39impl CfgdNode {
40    pub fn new(cache: Arc<Cache>, metrics: Arc<CsiMetrics>, node_id: String) -> Self {
41        let allowed_registries = parse_allowed_registries_from_env();
42        match &allowed_registries {
43            None => tracing::warn!(
44                env = ALLOWED_REGISTRIES_ENV,
45                "CSI registry allow-list is not configured — accepting any ociRef from volume context. In multi-tenant clusters set this env (comma-separated host[:port]) to restrict pulls."
46            ),
47            Some(list) if list.is_empty() => tracing::warn!(
48                env = ALLOWED_REGISTRIES_ENV,
49                "CSI registry allow-list is explicitly empty — all module pulls will be refused."
50            ),
51            Some(list) => tracing::info!(
52                env = ALLOWED_REGISTRIES_ENV,
53                count = list.len(),
54                "CSI registry allow-list active"
55            ),
56        }
57        Self {
58            cache,
59            metrics,
60            node_id,
61            allowed_registries,
62        }
63    }
64}
65
66fn parse_allowed_registries_from_env() -> Option<Vec<String>> {
67    let raw = std::env::var(ALLOWED_REGISTRIES_ENV).ok()?;
68    let trimmed = raw.trim();
69    if trimmed.is_empty() {
70        return None;
71    }
72    if trimmed == "*" {
73        // Wildcard = explicit "any registry"; keep it distinct from "unset"
74        // for log clarity but collapse to no-check here.
75        return None;
76    }
77    Some(
78        trimmed
79            .split(',')
80            .map(str::trim)
81            .filter(|s| !s.is_empty())
82            .map(str::to_string)
83            .collect(),
84    )
85}
86
/// Return the registry `host[:port]` named by `oci_ref`, or `""` when the
/// reference does not name one.
///
/// Only the first `/`-separated segment is examined (the whole string if
/// there is no `/`). It counts as a registry when it is `localhost` or
/// contains a `.` or `:`. Examples:
/// - `ghcr.io/org/mod:tag` → `ghcr.io`
/// - `myregistry.local:5000/org/mod:tag` → `myregistry.local:5000`
/// - `cfgd-modules/foo:v1` → `""`
///
/// A ref with no `/` at all is returned whole if it contains `.` or `:`
/// (e.g. `foo:v1`). It is then treated as a registry name and must appear in
/// the allow-list.
fn registry_of(oci_ref: &str) -> &str {
    let head = match oci_ref.split_once('/') {
        Some((first_segment, _)) => first_segment,
        None => oci_ref,
    };
    let names_registry = head == "localhost" || head.contains(|c: char| c == '.' || c == ':');
    if names_registry { head } else { "" }
}
101
102fn require_attr<'a>(attrs: &'a HashMap<String, String>, key: &str) -> Result<&'a str, Status> {
103    attrs
104        .get(key)
105        .map(|v| v.as_str())
106        .filter(|v| !v.is_empty())
107        .ok_or_else(|| {
108            Status::invalid_argument(format!("missing required volume attribute: {key}"))
109        })
110}
111
112fn require_volume_id(volume_id: &str) -> Result<(), Status> {
113    if volume_id.is_empty() {
114        return Err(Status::invalid_argument("volume_id is required"));
115    }
116    Ok(())
117}
118
/// Pick the OCI reference to pull for a volume.
///
/// A non-empty `ociRef` attribute wins. Otherwise the ref is built in the
/// default `cfgd-modules` namespace as `cfgd-modules/{module}:{version}`.
fn resolve_oci_ref(attrs: &HashMap<String, String>, module: &str, version: &str) -> String {
    match attrs.get("ociRef") {
        Some(explicit) if !explicit.is_empty() => explicit.clone(),
        _ => format!("cfgd-modules/{module}:{version}"),
    }
}
126
127/// Check whether `oci_ref` is permitted under `allowed_registries`.
128/// `None` (unset) → allow (legacy behavior, with startup warn already emitted).
129/// `Some(list)` → the registry host of the ref must be in `list`. Refs with no
130/// explicit registry (default module namespace) are treated as the reserved
131/// `"cfgd-modules"` namespace and are always allowed — those can't reach the
132/// network without a caller-provided registry anyway.
133fn check_registry_allowed(
134    oci_ref: &str,
135    allowed_registries: Option<&[String]>,
136) -> Result<(), Status> {
137    let Some(list) = allowed_registries else {
138        return Ok(());
139    };
140    let registry = registry_of(oci_ref);
141    if registry.is_empty() {
142        return Ok(());
143    }
144    if list.iter().any(|r| r == registry) {
145        return Ok(());
146    }
147    Err(Status::permission_denied(format!(
148        "registry '{registry}' is not in the CSI allow-list (set {ALLOWED_REGISTRIES_ENV})"
149    )))
150}
151
152#[tonic::async_trait]
153impl Node for CfgdNode {
154    async fn node_stage_volume(
155        &self,
156        request: Request<NodeStageVolumeRequest>,
157    ) -> Result<Response<NodeStageVolumeResponse>, Status> {
158        let req = request.into_inner();
159        require_volume_id(&req.volume_id)?;
160        if req.staging_target_path.is_empty() {
161            return Err(Status::invalid_argument("staging_target_path is required"));
162        }
163        if let Err(e) = cfgd_core::validate_no_traversal(Path::new(&req.staging_target_path)) {
164            return Err(Status::invalid_argument(format!(
165                "staging_target_path traversal rejected: {e}"
166            )));
167        }
168
169        let attrs = &req.volume_context;
170        let module = require_attr(attrs, "module")?;
171        let version = require_attr(attrs, "version")?;
172        let oci_ref = resolve_oci_ref(attrs, module, version);
173        check_registry_allowed(&oci_ref, self.allowed_registries.as_deref())?;
174
175        tracing::info!(
176            module = module,
177            version = version,
178            volume_id = req.volume_id,
179            "staging volume — pulling to cache"
180        );
181
182        let start = std::time::Instant::now();
183        let cached = self.cache.get(module, version).is_some();
184        self.cache
185            .get_or_pull(module, version, &oci_ref)
186            .map_err(|e| Status::internal(format!("cache pull failed: {e}")))?;
187
188        let duration = start.elapsed().as_secs_f64();
189        self.metrics
190            .pull_duration_seconds
191            .get_or_create(&PullLabels {
192                module: module.to_string(),
193                cached: cached.to_string(),
194            })
195            .observe(duration);
196
197        if cached {
198            self.metrics
199                .cache_hits_total
200                .get_or_create(&ModuleLabels {
201                    module: module.to_string(),
202                })
203                .inc();
204        }
205
206        self.metrics
207            .cache_size_bytes
208            .set(self.cache.current_size_bytes() as i64);
209
210        Ok(Response::new(NodeStageVolumeResponse {}))
211    }
212
213    async fn node_unstage_volume(
214        &self,
215        request: Request<NodeUnstageVolumeRequest>,
216    ) -> Result<Response<NodeUnstageVolumeResponse>, Status> {
217        let req = request.into_inner();
218        require_volume_id(&req.volume_id)?;
219        tracing::debug!(
220            volume_id = req.volume_id,
221            "unstage volume (no-op, cache persists)"
222        );
223        Ok(Response::new(NodeUnstageVolumeResponse {}))
224    }
225
226    async fn node_publish_volume(
227        &self,
228        request: Request<NodePublishVolumeRequest>,
229    ) -> Result<Response<NodePublishVolumeResponse>, Status> {
230        let req = request.into_inner();
231        require_volume_id(&req.volume_id)?;
232
233        let attrs = &req.volume_context;
234        let module = require_attr(attrs, "module")?;
235        let version = require_attr(attrs, "version")?;
236        let target_path = &req.target_path;
237
238        if target_path.is_empty() {
239            return Err(Status::invalid_argument("target_path is required"));
240        }
241        if let Err(e) = cfgd_core::validate_no_traversal(Path::new(target_path)) {
242            return Err(Status::invalid_argument(format!(
243                "target_path traversal rejected: {e}"
244            )));
245        }
246
247        let target = Path::new(target_path);
248
249        // Idempotent: if already mounted as read-only, return success
250        if is_mountpoint(target) {
251            if is_readonly_mount(target) {
252                tracing::debug!(
253                    target = target_path,
254                    "already mounted read-only, returning success"
255                );
256                return Ok(Response::new(NodePublishVolumeResponse {}));
257            }
258            // Mounted but not read-only — attempt remount
259            tracing::warn!(
260                target = target_path,
261                "mount exists but is not read-only, attempting remount"
262            );
263            #[cfg(target_os = "linux")]
264            {
265                use nix::mount::{MsFlags, mount};
266                mount(
267                    None::<&str>,
268                    target,
269                    None::<&str>,
270                    MsFlags::MS_REMOUNT | MsFlags::MS_BIND | MsFlags::MS_RDONLY,
271                    None::<&str>,
272                )
273                .map_err(|e| Status::internal(format!("read-only remount failed: {e}")))?;
274            }
275            return Ok(Response::new(NodePublishVolumeResponse {}));
276        }
277
278        tracing::info!(
279            module = module,
280            version = version,
281            target = target_path,
282            volume_id = req.volume_id,
283            "publishing volume"
284        );
285
286        // Get cached content (should have been staged already, but pull if needed)
287        let oci_ref = resolve_oci_ref(attrs, module, version);
288        check_registry_allowed(&oci_ref, self.allowed_registries.as_deref())?;
289
290        // OCI pull, `std::fs::create_dir_all`, and the `bind_mount` syscall are
291        // all blocking. Running them inline on the tokio runtime starves other
292        // worker threads under kubelet concurrency. Move the whole blocking
293        // sequence into `spawn_blocking` with owned clones of the relevant state.
294        let cache = Arc::clone(&self.cache);
295        let metrics = Arc::clone(&self.metrics);
296        let module = module.to_string();
297        let version = version.to_string();
298        let oci_ref_owned = oci_ref.clone();
299        let target_path_owned: std::path::PathBuf = target.to_path_buf();
300        tokio::task::spawn_blocking(move || {
301            let source = cache
302                .get_or_pull(&module, &version, &oci_ref_owned)
303                .map_err(|e| Status::internal(format!("cache pull failed: {e}")))?;
304
305            std::fs::create_dir_all(&target_path_owned)
306                .map_err(|e| Status::internal(format!("cannot create target dir: {e}")))?;
307
308            match bind_mount_readonly(&source, &target_path_owned) {
309                Ok(()) => {
310                    metrics
311                        .volume_publish_total
312                        .get_or_create(&PublishLabels {
313                            module: module.clone(),
314                            result: "success".to_string(),
315                        })
316                        .inc();
317                    Ok(())
318                }
319                Err(e) => {
320                    metrics
321                        .volume_publish_total
322                        .get_or_create(&PublishLabels {
323                            module: module.clone(),
324                            result: "error".to_string(),
325                        })
326                        .inc();
327                    if let Err(rm_err) = std::fs::remove_dir(&target_path_owned) {
328                        tracing::warn!(
329                            error = %rm_err,
330                            target = %target_path_owned.posix(),
331                            "failed to remove mount target after bind_mount failure",
332                        );
333                    }
334                    Err(e)
335                }
336            }
337        })
338        .await
339        .map_err(|e| Status::internal(format!("publish task join failed: {e}")))??;
340
341        Ok(Response::new(NodePublishVolumeResponse {}))
342    }
343
344    async fn node_unpublish_volume(
345        &self,
346        request: Request<NodeUnpublishVolumeRequest>,
347    ) -> Result<Response<NodeUnpublishVolumeResponse>, Status> {
348        let req = request.into_inner();
349        require_volume_id(&req.volume_id)?;
350
351        let target_path = &req.target_path;
352        if target_path.is_empty() {
353            return Err(Status::invalid_argument("target_path is required"));
354        }
355        if let Err(e) = cfgd_core::validate_no_traversal(std::path::Path::new(target_path)) {
356            return Err(Status::invalid_argument(format!(
357                "target_path traversal rejected: {e}"
358            )));
359        }
360
361        tracing::info!(
362            target = target_path,
363            volume_id = req.volume_id,
364            "unpublishing volume"
365        );
366
367        // `unmount` (umount2 syscall) and `remove_dir` both block — run them
368        // off the tokio runtime so kubelet-driven concurrency cannot starve
369        // other csi workers.
370        let target_path_owned: std::path::PathBuf = target_path.into();
371        tokio::task::spawn_blocking(move || -> Result<(), Status> {
372            unmount(&target_path_owned)?;
373            if let Err(e) = std::fs::remove_dir(&target_path_owned) {
374                tracing::warn!(target = %target_path_owned.posix(), error = %e, "failed to remove target directory after unmount");
375            }
376            Ok(())
377        })
378        .await
379        .map_err(|e| Status::internal(format!("unpublish task join failed: {e}")))??;
380
381        Ok(Response::new(NodeUnpublishVolumeResponse {}))
382    }
383
384    async fn node_get_volume_stats(
385        &self,
386        request: Request<NodeGetVolumeStatsRequest>,
387    ) -> Result<Response<NodeGetVolumeStatsResponse>, Status> {
388        let req = request.into_inner();
389        require_volume_id(&req.volume_id)?;
390
391        let volume_path = &req.volume_path;
392        if volume_path.is_empty() {
393            return Err(Status::invalid_argument("volume_path is required"));
394        }
395
396        let path = Path::new(volume_path);
397        if !path.exists() {
398            return Err(Status::not_found(format!(
399                "volume path does not exist: {volume_path}"
400            )));
401        }
402
403        let (bytes, inodes) = walk_volume_stats(path);
404
405        tracing::debug!(
406            volume_id = req.volume_id,
407            volume_path = volume_path,
408            bytes = bytes,
409            inodes = inodes,
410            "volume stats"
411        );
412
413        Ok(Response::new(NodeGetVolumeStatsResponse {
414            usage: vec![
415                VolumeUsage {
416                    total: bytes as i64,
417                    used: bytes as i64,
418                    available: 0,
419                    unit: volume_usage::Unit::Bytes as i32,
420                },
421                VolumeUsage {
422                    total: inodes as i64,
423                    used: inodes as i64,
424                    available: 0,
425                    unit: volume_usage::Unit::Inodes as i32,
426                },
427            ],
428            volume_condition: None,
429        }))
430    }
431
432    async fn node_expand_volume(
433        &self,
434        _request: Request<NodeExpandVolumeRequest>,
435    ) -> Result<Response<NodeExpandVolumeResponse>, Status> {
436        Err(Status::unimplemented("NodeExpandVolume not supported"))
437    }
438
439    async fn node_get_capabilities(
440        &self,
441        _request: Request<NodeGetCapabilitiesRequest>,
442    ) -> Result<Response<NodeGetCapabilitiesResponse>, Status> {
443        tracing::debug!("NodeGetCapabilities called");
444        Ok(Response::new(NodeGetCapabilitiesResponse {
445            capabilities: vec![
446                NodeServiceCapability {
447                    r#type: Some(node_service_capability::Type::Rpc(
448                        node_service_capability::Rpc {
449                            r#type: node_service_capability::rpc::Type::StageUnstageVolume.into(),
450                        },
451                    )),
452                },
453                NodeServiceCapability {
454                    r#type: Some(node_service_capability::Type::Rpc(
455                        node_service_capability::Rpc {
456                            r#type: node_service_capability::rpc::Type::GetVolumeStats.into(),
457                        },
458                    )),
459                },
460            ],
461        }))
462    }
463
464    async fn node_get_info(
465        &self,
466        _request: Request<NodeGetInfoRequest>,
467    ) -> Result<Response<NodeGetInfoResponse>, Status> {
468        tracing::debug!(node_id = %self.node_id, "NodeGetInfo called");
469        Ok(Response::new(NodeGetInfoResponse {
470            node_id: self.node_id.clone(),
471            max_volumes_per_node: 0, // no limit
472            accessible_topology: None,
473        }))
474    }
475}
476
/// Report whether `path` is a mount point.
///
/// Fast path: `path` is a mount point if it lives on a different device than
/// its parent (`/`, which has no parent, always counts). That test misses bind
/// mounts whose source is on the same filesystem as the target, e.g. a cache
/// directory and a kubelet target directory on one host disk. When the
/// devices match, the canonical path is therefore looked up in
/// `/proc/self/mountinfo`. If that table cannot be read, the device answer
/// (`false`) stands. Unreadable metadata for `path` or its parent yields
/// `false`.
///
/// Always `false` on non-Linux targets.
fn is_mountpoint(path: &Path) -> bool {
    #[cfg(target_os = "linux")]
    {
        use std::os::unix::fs::MetadataExt;
        let Ok(path_meta) = std::fs::metadata(path) else {
            return false;
        };
        let Some(parent) = path.parent() else {
            return true; // root is always a mountpoint
        };
        let Ok(parent_meta) = std::fs::metadata(parent) else {
            return false;
        };
        if path_meta.dev() != parent_meta.dev() {
            return true;
        }
        mountinfo_lists_mountpoint(path).unwrap_or(false)
    }
    #[cfg(not(target_os = "linux"))]
    {
        let _ = path;
        false
    }
}

/// Look `path` (canonicalized) up in the mount-point column of
/// `/proc/self/mountinfo`.
///
/// Returns `None` when the path cannot be canonicalized or the table cannot
/// be read.
#[cfg(target_os = "linux")]
fn mountinfo_lists_mountpoint(path: &Path) -> Option<bool> {
    use std::os::unix::ffi::OsStrExt;

    let canonical = std::fs::canonicalize(path).ok()?;
    let wanted = canonical.as_os_str().as_bytes();
    let table = std::fs::read("/proc/self/mountinfo").ok()?;
    // Line layout: mount-id parent-id major:minor root mount-point options ...
    // so the mount point is the fifth space-separated field.
    Some(table.split(|&b| b == b'\n').any(|line| {
        line.split(|&b| b == b' ')
            .nth(4)
            .is_some_and(|field| unescape_mountinfo(field) == wanted)
    }))
}

/// Decode `\ooo` octal escapes in a mountinfo field. The kernel escapes
/// characters such as spaces this way in mountinfo paths. Backslashes that
/// do not start a valid three-digit octal byte are copied through unchanged.
#[cfg(target_os = "linux")]
fn unescape_mountinfo(field: &[u8]) -> Vec<u8> {
    let mut out = Vec::with_capacity(field.len());
    let mut i = 0;
    while i < field.len() {
        if field[i] == b'\\' && i + 3 < field.len() {
            let digits = &field[i + 1..i + 4];
            if digits.iter().all(|d| (b'0'..=b'7').contains(d)) {
                let value = digits
                    .iter()
                    .fold(0u32, |acc, d| acc * 8 + u32::from(d - b'0'));
                if let Ok(byte) = u8::try_from(value) {
                    out.push(byte);
                    i += 4;
                    continue;
                }
            }
        }
        out.push(field[i]);
        i += 1;
    }
    out
}
499
500/// Check if a path is mounted read-only using the statvfs syscall.
501#[cfg(target_os = "linux")]
502fn is_readonly_mount(path: &Path) -> bool {
503    use nix::sys::statvfs::{FsFlags, statvfs};
504    statvfs(path)
505        .map(|stat| stat.flags().contains(FsFlags::ST_RDONLY))
506        .unwrap_or(false)
507}
508
509#[cfg(not(target_os = "linux"))]
510fn is_readonly_mount(_path: &Path) -> bool {
511    false
512}
513
514/// Bind mount `source` to `target` as read-only.
515///
516/// Two-step operation on Linux: bind mount, then remount read-only.
517/// (MS_BIND | MS_RDONLY in a single call does NOT work.)
518#[cfg(target_os = "linux")]
519fn bind_mount_readonly(source: &Path, target: &Path) -> Result<(), Status> {
520    use nix::mount::{MsFlags, mount};
521
522    mount(
523        Some(source),
524        target,
525        None::<&str>,
526        MsFlags::MS_BIND,
527        None::<&str>,
528    )
529    .map_err(|e| Status::internal(format!("bind mount failed: {e}")))?;
530
531    if let Err(e) = mount(
532        None::<&str>,
533        target,
534        None::<&str>,
535        MsFlags::MS_REMOUNT | MsFlags::MS_BIND | MsFlags::MS_RDONLY,
536        None::<&str>,
537    ) {
538        if let Err(umount_err) = nix::mount::umount2(target, nix::mount::MntFlags::MNT_DETACH) {
539            tracing::debug!(
540                error = %umount_err,
541                target = %target.display(),
542                "best-effort cleanup umount2 failed after read-only remount error",
543            );
544        }
545        return Err(Status::internal(format!("read-only remount failed: {e}")));
546    }
547
548    Ok(())
549}
550
551#[cfg(not(target_os = "linux"))]
552fn bind_mount_readonly(_source: &Path, _target: &Path) -> Result<(), Status> {
553    Err(Status::unimplemented("bind mount only supported on Linux"))
554}
555
556/// Unmount a target path with MNT_DETACH (lazy unmount).
557#[cfg(target_os = "linux")]
558fn unmount(target: &Path) -> Result<(), Status> {
559    use nix::mount::{MntFlags, umount2};
560
561    match umount2(target, MntFlags::MNT_DETACH) {
562        Ok(()) => Ok(()),
563        Err(nix::errno::Errno::EINVAL)
564        | Err(nix::errno::Errno::ENOENT)
565        | Err(nix::errno::Errno::EPERM) => {
566            // EINVAL = not mounted, ENOENT = doesn't exist, EPERM = not a mount point
567            // (non-root gets EPERM instead of EINVAL) — all idempotent success
568            Ok(())
569        }
570        Err(e) => Err(Status::internal(format!("unmount failed: {e}"))),
571    }
572}
573
574#[cfg(not(target_os = "linux"))]
575fn unmount(_target: &Path) -> Result<(), Status> {
576    // Non-linux stub for compilation — CSI driver only runs on Linux
577    Ok(())
578}
579
/// Total the apparent size and entry count of the tree rooted at `path`.
///
/// Returns `(bytes, inodes)`:
/// - `inodes` counts `path` itself plus every directory entry beneath it.
///   Hard links are counted once per entry.
/// - `bytes` sums `len()` of every non-directory entry, including symlinks.
///
/// Entries are inspected with `symlink_metadata`, so symlinks are never
/// followed and link cycles cannot loop. Unreadable directories or entries
/// are skipped silently.
///
/// The traversal uses an explicit work stack instead of recursion, so a
/// deeply nested volume cannot overflow the calling thread's stack.
fn walk_volume_stats(path: &Path) -> (u64, u64) {
    let mut bytes = 0u64;
    // Count the root directory itself.
    let mut inodes = 1u64;
    let mut pending = vec![path.to_path_buf()];

    while let Some(dir) = pending.pop() {
        let Ok(entries) = std::fs::read_dir(&dir) else {
            continue;
        };
        for entry in entries.flatten() {
            inodes = inodes.saturating_add(1);
            let child = entry.path();
            let Ok(meta) = child.symlink_metadata() else {
                continue;
            };
            // symlink_metadata reports a symlink as a symlink, never as a
            // directory, so only real directories are descended into.
            if meta.is_dir() {
                pending.push(child);
            } else {
                bytes = bytes.saturating_add(meta.len());
            }
        }
    }

    (bytes, inodes)
}
613
614#[cfg(test)]
615mod tests;