use crate::backend::CapabilitySet;
use std::collections::BTreeSet;
/// An ordered, de-duplicated set of capability tokens that a worker must
/// advertise for [`matches`] to accept it.
///
/// The constructors on this type never store empty tokens. An empty
/// requirement is satisfied by any worker.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
#[non_exhaustive]
pub struct CapabilityRequirement {
    /// Required tokens; compared exactly (case-sensitively) by [`matches`].
    pub tokens: BTreeSet<String>,
}

impl CapabilityRequirement {
    /// Builds a requirement from any sequence of string-like tokens.
    ///
    /// Empty strings are discarded and duplicates collapse into one entry,
    /// because the tokens are stored in a `BTreeSet`.
    pub fn new<I, S>(tokens: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        let mut kept = BTreeSet::new();
        for raw in tokens {
            let token: String = raw.into();
            if !token.is_empty() {
                kept.insert(token);
            }
        }
        Self { tokens: kept }
    }

    /// Parses a comma-separated list such as `"gpu,cuda"`.
    ///
    /// Empty segments (from leading, trailing, or doubled commas) are ignored.
    /// Segments are not trimmed, so `"gpu, cuda"` requires the token `" cuda"`.
    pub fn from_csv(csv: &str) -> Self {
        // `new` already discards the empty segments produced by stray commas.
        Self::new(csv.split(','))
    }

    /// Returns `true` when no tokens are required.
    pub fn is_empty(&self) -> bool {
        self.tokens.is_empty()
    }
}
/// Returns `true` when `worker` advertises every token in `required`.
///
/// Tokens are compared with exact, case-sensitive string equality. An empty
/// requirement matches any worker, because `all` over no tokens is vacuously
/// true.
pub fn matches(required: &CapabilityRequirement, worker: &CapabilitySet) -> bool {
    // Linear scan per required token; the worker's tokens are viewed as a slice.
    let offered: &[String] = &worker.tokens;
    required.tokens.iter().all(|needed| offered.contains(needed))
}
/// Returns `true` when every non-empty token in the comma-separated
/// `required_csv` is present in `worker_caps`.
///
/// Empty segments (from leading, trailing, or repeated commas) are skipped, so
/// `""` and `",,,"` match any worker. Tokens are compared exactly: no trimming,
/// case-sensitive.
pub fn matches_csv(required_csv: &str, worker_caps: &BTreeSet<String>) -> bool {
    for token in required_csv.split(',') {
        if token.is_empty() {
            continue;
        }
        // The first missing token disqualifies the worker.
        if !worker_caps.contains(token) {
            return false;
        }
    }
    true
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned worker capability set from string literals.
    fn worker_caps(tokens: &[&str]) -> BTreeSet<String> {
        tokens.iter().map(|t| (*t).to_owned()).collect()
    }

    #[test]
    fn empty_required_csv_matches_any_worker() {
        assert!(matches_csv("", &worker_caps(&[])));
        assert!(matches_csv("", &worker_caps(&["gpu"])));
    }

    #[test]
    fn all_separator_csv_matches_any_worker() {
        assert!(matches_csv(",,,", &worker_caps(&["gpu"])));
    }

    #[test]
    fn exact_match_csv() {
        let worker = worker_caps(&["gpu", "cuda"]);
        assert!(matches_csv("gpu,cuda", &worker));
    }

    #[test]
    fn subset_match_csv() {
        let worker = worker_caps(&["gpu", "cuda", "fp16"]);
        for required in ["gpu,cuda", "gpu"] {
            assert!(matches_csv(required, &worker));
        }
    }

    #[test]
    fn missing_token_rejects_csv() {
        let worker = worker_caps(&["gpu"]);
        for required in ["gpu,cuda", "cuda"] {
            assert!(!matches_csv(required, &worker));
        }
    }

    #[test]
    fn case_sensitive_csv() {
        let worker = worker_caps(&["gpu"]);
        assert!(!matches_csv("GPU", &worker));
        assert!(matches_csv("gpu", &worker));
    }

    #[test]
    fn structured_empty_required_matches_any() {
        let req = CapabilityRequirement::default();
        assert!(matches(&req, &CapabilitySet::default()));
        assert!(matches(&req, &CapabilitySet::new(["gpu"])));
    }

    #[test]
    fn structured_subset_match() {
        let req = CapabilityRequirement::new(["gpu", "cuda"]);
        assert!(matches(&req, &CapabilitySet::new(["gpu", "cuda", "fp16"])));
    }

    #[test]
    fn structured_missing_token_rejects() {
        let req = CapabilityRequirement::new(["gpu", "cuda"]);
        assert!(!matches(&req, &CapabilitySet::new(["gpu"])));
    }

    #[test]
    fn structured_case_sensitive() {
        let req = CapabilityRequirement::new(["GPU"]);
        assert!(!matches(&req, &CapabilitySet::new(["gpu"])));
    }

    #[test]
    fn from_csv_drops_empty_tokens() {
        let req = CapabilityRequirement::from_csv(",gpu,,cuda,");
        assert_eq!(req.tokens.len(), 2);
        for expected in ["gpu", "cuda"] {
            assert!(req.tokens.contains(expected));
        }
    }

    #[test]
    fn from_csv_empty_string_is_empty_requirement() {
        assert!(CapabilityRequirement::from_csv("").is_empty());
    }

    /// (required CSV, worker tokens, expected match result) — both matchers
    /// must agree on every row.
    const AGREEMENT_CASES: &[(&str, &[&str], bool)] = &[
        ("", &["gpu"], true),
        ("gpu", &["gpu"], true),
        ("gpu,cuda", &["gpu"], false),
        ("gpu,cuda", &["gpu", "cuda", "fp16"], true),
        (",gpu,", &["gpu"], true),
        ("GPU", &["gpu"], false),
    ];

    #[test]
    fn matches_and_matches_csv_agree() {
        for &(req_csv, worker_tokens, expected) in AGREEMENT_CASES {
            let worker_btree = worker_caps(worker_tokens);
            let worker_set = CapabilitySet::new(worker_tokens.iter().copied());
            let req = CapabilityRequirement::from_csv(req_csv);

            assert_eq!(
                matches_csv(req_csv, &worker_btree),
                expected,
                "matches_csv({req_csv:?}) mismatch"
            );
            assert_eq!(
                matches(&req, &worker_set),
                expected,
                "matches({req_csv:?}) mismatch"
            );
        }
    }
}