baracuda-forge 0.0.1-alpha.68

Build-time CUDA kernel compiler for the baracuda ecosystem: nvcc-driven incremental builds, parallel compilation, GPU auto-detection, and CUTLASS / custom git dependency support.
Documentation
//! Source file selection and filtering.

use crate::error::{Error, Result};
use std::path::{Path, PathBuf};
use walkdir::WalkDir;

/// Source file selection configuration.
///
/// Populated through the chained `add_*`, `exclude`, and `watch` methods and
/// turned into a concrete list of files by `resolve`.
#[derive(Debug, Clone, Default)]
pub struct SourceSelector {
    /// Explicitly requested sources. When this is empty, `resolve` falls back
    /// to the `src/**/*.cu` glob.
    includes: Vec<SourcePath>,
    /// Exclusion patterns; every candidate file is tested against each one.
    excludes: Vec<String>,
    /// Extra paths recorded by `watch` and returned by `watch_paths`; they are
    /// not consulted by `resolve`.
    watch_paths: Vec<PathBuf>,
}

/// One entry in a selector's include list.
#[derive(Debug, Clone)]
enum SourcePath {
    /// A single file. It must exist when resolved; no extension check is
    /// applied to it.
    File(PathBuf),
    /// A directory that is walked recursively for files with the `.cu`
    /// extension. It must exist when resolved.
    Directory(PathBuf),
    /// A glob pattern; only matches with the `.cu` extension are kept.
    Glob(String),
}

impl SourceSelector {
    /// Create a new empty source selector.
    pub fn new() -> Self {
        Self::default()
    }

    /// Add a directory to search for `.cu` files (recursive).
    pub fn add_directory<P: AsRef<Path>>(mut self, dir: P) -> Self {
        self.includes
            .push(SourcePath::Directory(dir.as_ref().to_path_buf()));
        self
    }

    /// Add specific files.
    pub fn add_files<I, P>(mut self, files: I) -> Self
    where
        I: IntoIterator<Item = P>,
        P: AsRef<Path>,
    {
        for file in files {
            self.includes
                .push(SourcePath::File(file.as_ref().to_path_buf()));
        }
        self
    }

    /// Add files matching a glob pattern.
    pub fn add_glob(mut self, pattern: &str) -> Self {
        self.includes.push(SourcePath::Glob(pattern.to_string()));
        self
    }

    /// Exclude files matching patterns.
    ///
    /// Patterns can be:
    /// - `"*_test.cu"` — files ending with `_test.cu`.
    /// - `"deprecated/*"` — files in a directory.
    /// - `"test_*.cu"` — files starting with `test_`.
    pub fn exclude(mut self, patterns: &[&str]) -> Self {
        for pattern in patterns {
            self.excludes.push(pattern.to_string());
        }
        self
    }

    /// Add paths to watch for changes (headers, etc.).
    pub fn watch<I, P>(mut self, paths: I) -> Self
    where
        I: IntoIterator<Item = P>,
        P: AsRef<Path>,
    {
        for path in paths {
            self.watch_paths.push(path.as_ref().to_path_buf());
        }
        self
    }

    /// Resolve all sources to a list of kernel files.
    pub fn resolve(&self) -> Result<Vec<PathBuf>> {
        let mut files = Vec::new();

        if self.includes.is_empty() {
            if let Ok(entries) = glob::glob("src/**/*.cu") {
                for entry in entries.flatten() {
                    if !self.is_excluded(&entry) {
                        files.push(entry);
                    }
                }
            }
        } else {
            for source in &self.includes {
                match source {
                    SourcePath::File(path) => {
                        if !path.exists() {
                            return Err(Error::SourcePathNotFound(path.clone()));
                        }
                        if !self.is_excluded(path) {
                            files.push(path.clone());
                        }
                    }
                    SourcePath::Directory(dir) => {
                        if !dir.exists() {
                            return Err(Error::SourcePathNotFound(dir.clone()));
                        }
                        self.collect_from_directory(dir, &mut files)?;
                    }
                    SourcePath::Glob(pattern) => {
                        if let Ok(entries) = glob::glob(pattern) {
                            for entry in entries.flatten() {
                                if entry.extension().is_some_and(|e| e == "cu")
                                    && !self.is_excluded(&entry)
                                {
                                    files.push(entry);
                                }
                            }
                        }
                    }
                }
            }
        }

        files.sort();
        files.dedup();
        Ok(files)
    }

    /// Get watch paths.
    pub fn watch_paths(&self) -> &[PathBuf] {
        &self.watch_paths
    }

    fn collect_from_directory(&self, dir: &Path, files: &mut Vec<PathBuf>) -> Result<()> {
        for entry in WalkDir::new(dir).into_iter().filter_map(|e| e.ok()) {
            let path = entry.path();
            if path.is_file()
                && path.extension().is_some_and(|e| e == "cu")
                && !self.is_excluded(path)
            {
                files.push(path.to_path_buf());
            }
        }
        Ok(())
    }

    fn is_excluded(&self, path: &Path) -> bool {
        let filename = path.file_name().and_then(|n| n.to_str()).unwrap_or("");
        let path_str = path.to_string_lossy();

        for pattern in &self.excludes {
            if matches_exclusion_pattern(filename, &path_str, pattern) {
                return true;
            }
        }
        false
    }
}

/// Decide whether a single exclusion `pattern` applies to a file.
///
/// `filename` is the final path component and `path_str` the full path as
/// text. Two pattern shapes are recognised:
///
/// - `"dir/*"` — matches when `dir` is the name of any directory component
///   leading to the file, at any depth. Both `/` and `\` are accepted as path
///   separators, and a relative path that *starts* with `dir` matches too.
///   An empty directory name never matches.
/// - anything else — compared against `filename` only, where each `*` stands
///   for any (possibly empty) run of characters. A pattern without `*` must
///   equal the file name exactly.
///
/// Patterns containing `/` in any shape other than `dir/*` are still compared
/// against the file name alone.
fn matches_exclusion_pattern(filename: &str, path_str: &str, pattern: &str) -> bool {
    // Match `text` against `glob`, where `*` is the only wildcard.
    fn wildcard_match(text: &str, glob: &str) -> bool {
        let mut literals = glob.split('*');

        // The text before the first `*` must be a prefix.
        let head = literals.next().unwrap_or("");
        let Some(mut rest) = text.strip_prefix(head) else {
            return false;
        };

        // No `*` at all: the whole text must have been consumed.
        let Some(tail) = literals.next_back() else {
            return rest.is_empty();
        };

        // The text after the last `*` must be a suffix of what remains, so a
        // prefix and suffix can never claim the same characters.
        let Some(trimmed) = rest.strip_suffix(tail) else {
            return false;
        };
        rest = trimmed;

        // Literals between stars must appear in order; taking the leftmost
        // occurrence each time leaves the most room for the following ones.
        for literal in literals {
            match rest.find(literal) {
                Some(at) => rest = &rest[at + literal.len()..],
                None => return false,
            }
        }
        true
    }

    if let Some((dir, "*")) = pattern.split_once('/') {
        let mut components: Vec<&str> = path_str.split(|c| c == '/' || c == '\\').collect();
        // The last component is the file itself, not a directory.
        components.pop();
        return !dir.is_empty() && components.contains(&dir);
    }

    wildcard_match(filename, pattern)
}

/// Collect header files (`.cuh`) from directories.
pub fn collect_headers<P: AsRef<Path>>(dirs: &[P]) -> Vec<PathBuf> {
    let mut headers = Vec::new();

    for dir in dirs {
        if let Ok(entries) = glob::glob(&format!("{}/**/*.cuh", dir.as_ref().display())) {
            for entry in entries.flatten() {
                headers.push(entry);
            }
        }
    }

    headers.sort();
    headers.dedup();
    headers
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises prefix-wildcard and suffix-wildcard file name patterns,
    /// including one name that must not match.
    #[test]
    fn test_exclusion_patterns() {
        // (file name, full path, pattern)
        let cases = [
            ("test_kernel.cu", "src/test_kernel.cu", "test_*.cu"),
            ("kernel_test.cu", "src/kernel_test.cu", "*_test.cu"),
            ("kernel.cu", "src/kernel.cu", "*_test.cu"),
        ];

        let outcomes: Vec<bool> = cases
            .iter()
            .map(|&(name, path, pattern)| matches_exclusion_pattern(name, path, pattern))
            .collect();

        assert_eq!(outcomes, [true, true, false]);
    }
}