use super::error::HostlistError;
use std::path::Path;
/// One comma-separated element inside a bracketed range expression.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum RangeItem {
    /// A lone value such as `7` or `-3`.
    Single(i64),
    /// An inclusive span such as `1-5` (both `start` and `end` are included).
    Range { start: i64, end: i64 },
}
/// The parsed contents of one `[...]` group in a host pattern.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RangeExpression {
    /// Items in the order they were written inside the brackets.
    pub items: Vec<RangeItem>,
    /// Zero-padding width applied when formatting values; `0` means no padding.
    pub padding: usize,
}
/// One piece of a parsed host pattern: verbatim text or a bracketed range.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum PatternSegment {
    /// Text taken verbatim from the pattern (outside range brackets).
    Literal(String),
    /// A parsed `[...]` range group.
    Range(RangeExpression),
}
/// A host pattern split into its literal and range segments.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct HostPattern {
    /// Segments in the order they appear in the pattern text.
    pub segments: Vec<PatternSegment>,
}
impl HostPattern {
    /// Returns `true` when at least one segment is a bracketed range.
    pub fn has_ranges(&self) -> bool {
        self.segments.iter().any(|segment| match segment {
            PatternSegment::Range(_) => true,
            PatternSegment::Literal(_) => false,
        })
    }

    /// Number of host names this pattern expands to.
    ///
    /// This is the product of every range segment's item count (literal
    /// segments contribute a factor of 1). It is `1` when there are no range
    /// segments, and it saturates at `usize::MAX` instead of overflowing.
    pub fn expansion_count(&self) -> usize {
        self.segments
            .iter()
            .filter_map(|segment| match segment {
                PatternSegment::Range(range) => Some(range.item_count()),
                PatternSegment::Literal(_) => None,
            })
            .fold(1usize, usize::saturating_mul)
    }
}
impl RangeExpression {
    /// Total number of values across all items.
    ///
    /// Saturates at `usize::MAX` rather than overflowing. A plain `sum()` would
    /// panic in debug builds and wrap in release builds for very wide ranges.
    pub fn item_count(&self) -> usize {
        self.items
            .iter()
            .fold(0usize, |total, item| total.saturating_add(item.count()))
    }

    /// Expands every item, in order, into the concrete values it covers.
    ///
    /// Ranges are inclusive, and duplicates between items are kept.
    /// NOTE(review): this allocates one `i64` per value. Callers handling
    /// untrusted patterns should check `item_count` first.
    pub fn values(&self) -> Vec<i64> {
        let mut result = Vec::new();
        for item in &self.items {
            match *item {
                RangeItem::Single(v) => result.push(v),
                RangeItem::Range { start, end } => result.extend(start..=end),
            }
        }
        result
    }

    /// Formats `value`, zero-padding its digits to `self.padding` characters.
    ///
    /// The width counts digits only, matching how padding is measured when a
    /// number such as `-005` is parsed. The sign is therefore written before
    /// the padding: with `padding == 3`, `-5` becomes `"-005"`, not `"0-5"`.
    /// With `padding == 0` the plain decimal form is returned.
    pub fn format_value(&self, value: i64) -> String {
        if self.padding == 0 {
            return value.to_string();
        }
        let sign = if value < 0 { "-" } else { "" };
        // `unsigned_abs` is total, so this is correct even for i64::MIN.
        format!(
            "{}{:0>width$}",
            sign,
            value.unsigned_abs(),
            width = self.padding
        )
    }
}
impl RangeItem {
    /// Number of concrete values this item expands to.
    ///
    /// A `Single` is always 1. A `Range` counts both endpoints and contributes
    /// 0 when `end < start`. Spans too large for `usize` saturate at
    /// `usize::MAX` instead of overflowing.
    pub fn count(&self) -> usize {
        match *self {
            RangeItem::Single(_) => 1,
            RangeItem::Range { start, end } if end >= start => {
                // Widen to i128: `end - start + 1` overflows i64 for spans such
                // as i64::MIN..=i64::MAX, which would panic in debug builds.
                let span = i128::from(end) - i128::from(start) + 1;
                usize::try_from(span).unwrap_or(usize::MAX)
            }
            RangeItem::Range { .. } => 0,
        }
    }
}
/// Parses a host pattern such as `node[01-04].example.com` into segments.
///
/// Text outside brackets becomes `PatternSegment::Literal`. Each `[...]`
/// group is handed to `parse_range_expression` and becomes a
/// `PatternSegment::Range`. A bracket group that `is_ipv6_start` identifies
/// as an IPv6 address (e.g. `[::1]`) is instead kept verbatim, brackets
/// included, as part of the surrounding literal text. An empty pattern yields
/// a `HostPattern` with no segments.
///
/// # Errors
///
/// * `NestedBrackets` if a `[` appears inside a bracket group.
/// * `UnclosedBracket` if a `[` has no matching `]`.
/// * `UnmatchedBracket` if a `]` appears outside any bracket group.
/// * `EmptyBracket` if a range group is empty (`[]`).
/// * Any error returned by `parse_range_expression` for the group contents.
pub fn parse_host_pattern(pattern: &str) -> Result<HostPattern, HostlistError> {
    let mut segments = Vec::new();
    let mut current_literal = String::new();
    let mut chars = pattern.chars().peekable();

    while let Some(ch) = chars.next() {
        match ch {
            '[' => {
                // Copy the peeked char out so the `&mut` borrow taken by `peek`
                // ends before `chars` is lent to the lookahead helper.
                let next = chars.peek().copied();
                let is_ipv6 = matches!(next, Some(c) if is_ipv6_start(c, &chars));
                let body = read_bracket_body(&mut chars, pattern)?;
                if is_ipv6 {
                    // Consume the whole address group, including its closing
                    // `]`, as literal text. Otherwise the `]` would be reported
                    // as unmatched.
                    current_literal.push('[');
                    current_literal.push_str(&body);
                    current_literal.push(']');
                } else {
                    if !current_literal.is_empty() {
                        segments.push(PatternSegment::Literal(std::mem::take(
                            &mut current_literal,
                        )));
                    }
                    if body.is_empty() {
                        return Err(HostlistError::EmptyBracket {
                            expression: pattern.to_string(),
                        });
                    }
                    let range_expr = parse_range_expression(&body, pattern)?;
                    segments.push(PatternSegment::Range(range_expr));
                }
            }
            // Bracket groups are consumed whole by the arm above, so a `]`
            // reaching this point never has an opening `[`.
            ']' => {
                return Err(HostlistError::UnmatchedBracket {
                    expression: pattern.to_string(),
                });
            }
            _ => current_literal.push(ch),
        }
    }

    if !current_literal.is_empty() {
        segments.push(PatternSegment::Literal(current_literal));
    }
    Ok(HostPattern { segments })
}

/// Consumes characters up to and including the next `]`, returning the text
/// between the brackets. The opening `[` must already have been consumed.
///
/// Fails with `NestedBrackets` on an inner `[`, and with `UnclosedBracket` if
/// the input ends before a `]`.
fn read_bracket_body(
    chars: &mut std::iter::Peekable<std::str::Chars>,
    pattern: &str,
) -> Result<String, HostlistError> {
    let mut body = String::new();
    for c in chars.by_ref() {
        match c {
            '[' => {
                return Err(HostlistError::NestedBrackets {
                    expression: pattern.to_string(),
                });
            }
            ']' => return Ok(body),
            _ => body.push(c),
        }
    }
    Err(HostlistError::UnclosedBracket {
        expression: pattern.to_string(),
    })
}
/// Decides whether the bracket group just opened holds an IPv6 address
/// literal (such as `[::1]` or `[fe80::1]`) rather than a numeric range.
///
/// The `[` has already been consumed. `next_ch` is the character right after
/// it, and `chars` is the iterator positioned there. `chars` is cloned for
/// lookahead, so the caller's position is unchanged.
///
/// The group counts as IPv6 when it contains a `:` before its closing `]`.
/// Range syntax never includes `:` (`parse_number` rejects it), so such a
/// group could not parse as a range anyway. A group with no closing `]`, or
/// one containing another `[`, is not treated as IPv6, so the range parser
/// reports the bracket error instead.
fn is_ipv6_start(next_ch: char, chars: &std::iter::Peekable<std::str::Chars>) -> bool {
    let mut saw_colon = next_ch == ':';
    for c in chars.clone() {
        match c {
            ':' => saw_colon = true,
            ']' => return saw_colon,
            '[' => return false,
            _ => {}
        }
    }
    false
}
/// Parses the text between `[` and `]` into a `RangeExpression`.
///
/// `content` is a comma-separated list of items. Each item is either a
/// single integer (`7`, `-3`, `007`) or an inclusive range `start-end`. A
/// leading `-` on an item is a sign, so `-5--1` is the range -5 to -1. Blank
/// items (e.g. in `1,,2`) are skipped. The expression's padding is the
/// largest padding `parse_number` reports for any number in it.
///
/// `pattern` is the full host pattern and is only used in error values.
///
/// # Errors
///
/// * `ReversedRange` when a range's start is greater than its end.
/// * `EmptyBracket` when no non-blank items remain.
/// * Any error from `parse_number` for a malformed number.
fn parse_range_expression(content: &str, pattern: &str) -> Result<RangeExpression, HostlistError> {
    let mut items = Vec::new();
    let mut max_padding = 0;
    for item_str in content.split(',').map(str::trim).filter(|s| !s.is_empty()) {
        match split_range_bounds(item_str) {
            Some((start_str, end_str)) => {
                let (start, start_padding) = parse_number(start_str, pattern)?;
                let (end, end_padding) = parse_number(end_str, pattern)?;
                if start > end {
                    return Err(HostlistError::ReversedRange {
                        expression: pattern.to_string(),
                        start,
                        end,
                    });
                }
                max_padding = max_padding.max(start_padding).max(end_padding);
                items.push(RangeItem::Range { start, end });
            }
            None => {
                let (value, padding) = parse_number(item_str, pattern)?;
                max_padding = max_padding.max(padding);
                items.push(RangeItem::Single(value));
            }
        }
    }
    if items.is_empty() {
        return Err(HostlistError::EmptyBracket {
            expression: pattern.to_string(),
        });
    }
    Ok(RangeExpression {
        items,
        padding: max_padding,
    })
}

/// Splits a range item at its separating `-`, which is the first `-` that is
/// not the item's leading character (a leading `-` is the start value's
/// sign). Returns `None` when the item has no separator and is a single value.
fn split_range_bounds(item: &str) -> Option<(&str, &str)> {
    // `char_indices` keeps the slice boundaries valid even when the first
    // character is multi-byte. `-` itself is one byte, so `sep + 1` is safe.
    let (sep, _) = item.char_indices().skip(1).find(|&(_, c)| c == '-')?;
    Some((&item[..sep], &item[sep + 1..]))
}
/// Parses one optionally negative integer from a range item and reports its
/// zero-padding width.
///
/// Surrounding whitespace is ignored. The padding is the digit count
/// (excluding a leading `-`) when there is more than one digit and the first
/// is `0`: `007` gives 3, `-05` gives 2, and `0` or `42` give 0.
///
/// Returns `InvalidNumber`, carrying `pattern` and the trimmed text, when the
/// text is empty or is not a valid `i64`.
fn parse_number(s: &str, pattern: &str) -> Result<(i64, usize), HostlistError> {
    let text = s.trim();
    let invalid = || HostlistError::InvalidNumber {
        expression: pattern.to_string(),
        value: text.to_string(),
    };
    if text.is_empty() {
        return Err(invalid());
    }

    let magnitude = text.strip_prefix('-').unwrap_or(text);
    let has_leading_zero = magnitude.len() > 1 && magnitude.starts_with('0');
    let padding = if has_leading_zero { magnitude.len() } else { 0 };

    text.parse::<i64>()
        .map(|value| (value, padding))
        .map_err(|_| invalid())
}
/// Largest hostfile, in bytes, that `parse_hostfile` accepts (1 MiB).
const MAX_HOSTFILE_SIZE: u64 = 1 << 20;
/// Largest number of lines a hostfile may contain before `parse_hostfile` rejects it.
const MAX_HOSTFILE_LINES: usize = 100_000;
/// Reads a hostfile and returns its host entries, one per line.
///
/// Every line is trimmed. Blank lines and lines whose trimmed text starts
/// with `#` are skipped, and the remaining lines are returned as-is, in file
/// order.
///
/// # Errors
///
/// * `FileNotFound` if `path` does not exist.
/// * `FileReadError` if the file cannot be opened or read, is larger than
///   `MAX_HOSTFILE_SIZE` bytes, is not valid UTF-8, or has more than
///   `MAX_HOSTFILE_LINES` lines.
pub fn parse_hostfile(path: &Path) -> Result<Vec<String>, HostlistError> {
    use std::io::Read;

    let read_error = |reason: String| HostlistError::FileReadError {
        path: path.display().to_string(),
        reason,
    };

    let file = std::fs::File::open(path).map_err(|e| {
        if e.kind() == std::io::ErrorKind::NotFound {
            HostlistError::FileNotFound {
                path: path.display().to_string(),
            }
        } else {
            read_error(e.to_string())
        }
    })?;

    // Fast rejection with the exact size when the filesystem reports one.
    let file_size = file
        .metadata()
        .map_err(|e| read_error(e.to_string()))?
        .len();
    if file_size > MAX_HOSTFILE_SIZE {
        return Err(read_error(format!(
            "file size {} bytes exceeds maximum allowed size of {} bytes",
            file_size, MAX_HOSTFILE_SIZE
        )));
    }

    // Metadata length can be 0 for pipes and special files (e.g. /dev/zero),
    // and the file may grow after the check above. So the read itself is also
    // capped; one byte past the limit is enough to detect an oversized file.
    let mut bytes = Vec::new();
    file.take(MAX_HOSTFILE_SIZE + 1)
        .read_to_end(&mut bytes)
        .map_err(|e| read_error(e.to_string()))?;
    // `bytes.len()` is at most MAX_HOSTFILE_SIZE + 1, so the widening cast is lossless.
    if bytes.len() as u64 > MAX_HOSTFILE_SIZE {
        return Err(read_error(format!(
            "file size exceeds maximum allowed size of {} bytes",
            MAX_HOSTFILE_SIZE
        )));
    }
    let content = String::from_utf8(bytes).map_err(|e| read_error(e.to_string()))?;

    // Enforce the line limit before building the host list. `nth` stops as
    // soon as the limit is exceeded instead of counting every line.
    if content.lines().nth(MAX_HOSTFILE_LINES).is_some() {
        return Err(read_error(format!(
            "file contains more than {} lines (limit exceeded)",
            MAX_HOSTFILE_LINES
        )));
    }

    Ok(content
        .lines()
        .map(str::trim)
        .filter(|line| !line.is_empty() && !line.starts_with('#'))
        .map(String::from)
        .collect())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input`, which the calling test expects to succeed.
    fn parse_ok(input: &str) -> HostPattern {
        parse_host_pattern(input).unwrap()
    }

    /// Returns the text of the literal segment at `index`; panics otherwise.
    fn literal_at(pattern: &HostPattern, index: usize) -> &str {
        match &pattern.segments[index] {
            PatternSegment::Literal(text) => text.as_str(),
            _ => panic!("Expected literal"),
        }
    }

    /// Returns the range segment at `index`; panics otherwise.
    fn range_at(pattern: &HostPattern, index: usize) -> &RangeExpression {
        match &pattern.segments[index] {
            PatternSegment::Range(range) => range,
            _ => panic!("Expected range"),
        }
    }

    #[test]
    fn test_parse_simple_range() {
        let pattern = parse_ok("node[1-3]");
        assert_eq!(pattern.segments.len(), 2);
        assert_eq!(literal_at(&pattern, 0), "node");
        let range = range_at(&pattern, 1);
        assert_eq!(range.items.len(), 1);
        assert_eq!(range.padding, 0);
        assert_eq!(range.items[0], RangeItem::Range { start: 1, end: 3 });
    }

    #[test]
    fn test_parse_zero_padded_range() {
        let pattern = parse_ok("node[01-05]");
        let range = range_at(&pattern, 1);
        assert_eq!(range.padding, 2);
        assert_eq!(range.values(), vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_parse_comma_separated_values() {
        let pattern = parse_ok("node[1,3,5]");
        let range = range_at(&pattern, 1);
        assert_eq!(range.items.len(), 3);
        assert_eq!(range.values(), vec![1, 3, 5]);
    }

    #[test]
    fn test_parse_mixed_range() {
        let pattern = parse_ok("node[1-3,7,9-10]");
        assert_eq!(range_at(&pattern, 1).values(), vec![1, 2, 3, 7, 9, 10]);
    }

    #[test]
    fn test_parse_multiple_ranges() {
        let pattern = parse_ok("rack[1-2]-node[1-3]");
        assert_eq!(pattern.segments.len(), 4);
        assert!(pattern.has_ranges());
        assert_eq!(pattern.expansion_count(), 6);
    }

    #[test]
    fn test_parse_with_domain() {
        let pattern = parse_ok("web[1-3].example.com");
        assert_eq!(pattern.segments.len(), 3);
        assert_eq!(literal_at(&pattern, 2), ".example.com");
    }

    #[test]
    fn test_parse_no_range() {
        let pattern = parse_ok("simple.host.com");
        assert_eq!(pattern.segments.len(), 1);
        assert!(!pattern.has_ranges());
        assert_eq!(pattern.expansion_count(), 1);
    }

    #[test]
    fn test_parse_empty_bracket_error() {
        assert!(matches!(
            parse_host_pattern("node[]"),
            Err(HostlistError::EmptyBracket { .. })
        ));
    }

    #[test]
    fn test_parse_unclosed_bracket_error() {
        assert!(matches!(
            parse_host_pattern("node[1-5"),
            Err(HostlistError::UnclosedBracket { .. })
        ));
    }

    #[test]
    fn test_parse_unmatched_bracket_error() {
        assert!(matches!(
            parse_host_pattern("node]1-5["),
            Err(HostlistError::UnmatchedBracket { .. })
        ));
    }

    #[test]
    fn test_parse_reversed_range_error() {
        assert!(matches!(
            parse_host_pattern("node[5-1]"),
            Err(HostlistError::ReversedRange { .. })
        ));
    }

    #[test]
    fn test_parse_invalid_number_error() {
        assert!(matches!(
            parse_host_pattern("node[a-z]"),
            Err(HostlistError::InvalidNumber { .. })
        ));
    }

    #[test]
    fn test_parse_nested_brackets_error() {
        assert!(matches!(
            parse_host_pattern("node[[1-2]]"),
            Err(HostlistError::NestedBrackets { .. })
        ));
    }

    #[test]
    fn test_range_expression_format_value() {
        let expr = RangeExpression {
            items: vec![RangeItem::Range { start: 1, end: 5 }],
            padding: 3,
        };
        for &(value, expected) in &[(1, "001"), (12, "012"), (123, "123")] {
            assert_eq!(expr.format_value(value), expected);
        }
    }

    #[test]
    fn test_range_item_count() {
        let cases = [
            (RangeItem::Single(5), 1),
            (RangeItem::Range { start: 1, end: 5 }, 5),
            (RangeItem::Range { start: 0, end: 0 }, 1),
        ];
        for (item, expected) in cases.iter() {
            assert_eq!(item.count(), *expected);
        }
    }
}