use std::collections::HashSet;
use crate::error::ValueError;
use crate::parser;
use crate::types::{
AUTORELABEL_KEY, Line, REQUIRESEUSERS_KEY, SELINUX_KEY, SELINUXTYPE_DEFAULT, SELINUXTYPE_KEY,
SETLOCALDEFS_KEY, SelinuxMode,
};
/// An SELinux configuration file kept as an ordered sequence of [`Line`]s.
///
/// Comment and blank lines are stored alongside key/value entries, and each
/// entry keeps its raw leading text, separator and suffix, so the file's
/// layout is retained next to its settings.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ConfigFile {
    /// All lines of the file, in file order.
    pub(crate) lines: Vec<Line>,
}
impl ConfigFile {
    /// Creates an empty configuration containing no lines at all.
    #[must_use]
    pub fn new() -> Self {
        ConfigFile { lines: Vec::new() }
    }

    /// Parses `input` into a `ConfigFile` by delegating to `parser::parse`.
    ///
    /// # Errors
    ///
    /// Returns the `ParseError` produced by the parser when `input` is malformed.
    pub fn parse(input: &str) -> Result<Self, crate::error::ParseError> {
        parser::parse(input)
    }

    /// Returns every line (entries, comments and blank lines) in file order.
    #[must_use]
    pub fn lines(&self) -> &[Line] {
        &self.lines
    }

    /// Returns `true` when the file holds no key/value entries.
    ///
    /// Comment and blank lines are ignored, so a file consisting only of
    /// comments counts as empty.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        !self.lines.iter().any(|l| matches!(l, Line::Entry { .. }))
    }

    /// Returns `true` if at least one entry's key matches `key`
    /// (ASCII case-insensitive).
    #[must_use]
    pub fn contains(&self, key: &str) -> bool {
        self.get(key).is_some()
    }

    /// Returns the value of the *last* entry whose key matches `key`
    /// (ASCII case-insensitive), or `None` if there is no such entry.
    #[must_use]
    pub fn get(&self, key: &str) -> Option<&str> {
        self.lines.iter().rev().find_map(|line| match line {
            Line::Entry { key_raw, value, .. } if key_raw.eq_ignore_ascii_case(key) => {
                Some(value.as_str())
            }
            _ => None,
        })
    }

    /// Returns the effective `SELINUX` mode.
    ///
    /// Returns `None` when the key is absent or its last value does not parse
    /// as a [`SelinuxMode`].
    #[must_use]
    pub fn selinux(&self) -> Option<SelinuxMode> {
        self.get(SELINUX_KEY).and_then(|v| v.parse().ok())
    }

    /// Returns the raw value of the last `SELINUXTYPE` entry, if any.
    #[must_use]
    pub fn selinuxtype(&self) -> Option<&str> {
        self.get(SELINUXTYPE_KEY)
    }

    /// Returns `REQUIRESEUSERS` interpreted by [`ConfigFile::get_bool`].
    #[must_use]
    pub fn require_seusers(&self) -> Option<bool> {
        self.get_bool(REQUIRESEUSERS_KEY)
    }

    /// Returns `AUTORELABEL` interpreted by [`ConfigFile::get_bool`].
    #[must_use]
    pub fn autorelabel(&self) -> Option<bool> {
        self.get_bool(AUTORELABEL_KEY)
    }

    /// Returns `SETLOCALDEFS` interpreted by [`ConfigFile::get_bool`].
    #[must_use]
    pub fn setlocaldefs(&self) -> Option<bool> {
        self.get_bool(SETLOCALDEFS_KEY)
    }

    /// Interprets the value of `key` as a boolean.
    ///
    /// `1` / `true` yield `Some(true)` and `0` / `false` yield `Some(false)`.
    /// The words are compared ASCII case-insensitively. A missing key or any
    /// other value yields `None`.
    #[must_use]
    pub fn get_bool(&self, key: &str) -> Option<bool> {
        let value = self.get(key)?;
        if value == "1" || value.eq_ignore_ascii_case("true") {
            Some(true)
        } else if value == "0" || value.eq_ignore_ascii_case("false") {
            Some(false)
        } else {
            None
        }
    }

    /// Sets `SELINUX` to the textual form of `mode`.
    pub fn set_selinux(&mut self, mode: SelinuxMode) {
        self.set_inner(SELINUX_KEY, &mode.to_string());
    }

    /// Sets `SELINUXTYPE` to `value` with surrounding whitespace trimmed.
    ///
    /// # Errors
    ///
    /// Returns the first problem reported by `validate_selinuxtype_value`, and
    /// leaves the configuration unchanged in that case.
    pub fn set_selinuxtype(&mut self, value: &str) -> Result<(), ValueError> {
        if let Some(err) = validate_selinuxtype_value(value).into_iter().next() {
            return Err(err);
        }
        self.set_inner(SELINUXTYPE_KEY, value.trim());
        Ok(())
    }

    /// Sets `REQUIRESEUSERS` to `1` or `0`.
    pub fn set_require_seusers(&mut self, value: bool) {
        self.set_inner(REQUIRESEUSERS_KEY, Self::bool_token(value));
    }

    /// Sets `AUTORELABEL` to `1` or `0`.
    pub fn set_autorelabel(&mut self, value: bool) {
        self.set_inner(AUTORELABEL_KEY, Self::bool_token(value));
    }

    /// Sets `SETLOCALDEFS` to `1` or `0`.
    pub fn set_setlocaldefs(&mut self, value: bool) {
        self.set_inner(SETLOCALDEFS_KEY, Self::bool_token(value));
    }

    /// Sets an arbitrary `key` to `value`.
    ///
    /// When `key` matches a known SELinux key case-insensitively, a newly
    /// appended entry uses the canonical spelling. An existing entry keeps its
    /// original key spelling and only has its value replaced. An empty `key`
    /// is ignored.
    ///
    /// NOTE(review): `value` is stored verbatim and is not validated. A value
    /// containing a line break would presumably corrupt the written file.
    pub fn set(&mut self, key: &str, value: &str) {
        if key.is_empty() {
            return;
        }
        self.set_inner(&canonical_key_name(key), value);
    }

    /// Removes every entry whose key matches `key` (ASCII case-insensitive).
    ///
    /// Returns `true` if at least one line was removed.
    pub fn remove(&mut self, key: &str) -> bool {
        let len_before = self.lines.len();
        self.lines.retain(|line| match line {
            Line::Entry { key_raw, .. } => !key_raw.eq_ignore_ascii_case(key),
            _ => true,
        });
        self.lines.len() != len_before
    }

    /// Comments out every entry whose key matches `key` (ASCII case-insensitive).
    ///
    /// Each matching entry becomes a comment line. The `# ` marker is inserted
    /// after the entry's leading text, and the rest of the raw text is kept.
    /// Returns `true` if any entry was disabled.
    pub fn disable(&mut self, key: &str) -> bool {
        let mut disabled = false;
        for line in self.lines.iter_mut() {
            let commented = match line {
                Line::Entry {
                    key_raw,
                    value,
                    raw_leading,
                    raw_separator,
                    raw_suffix,
                } if key_raw.eq_ignore_ascii_case(key) => format!(
                    "{}# {}{}{}{}",
                    raw_leading, key_raw, raw_separator, value, raw_suffix
                ),
                _ => continue,
            };
            *line = Line::Comment(commented);
            disabled = true;
        }
        disabled
    }

    /// Returns the distinct entry keys in order of first appearance.
    ///
    /// Keys are de-duplicated ASCII case-insensitively, and the spelling of the
    /// first occurrence is reported.
    #[must_use]
    pub fn keys(&self) -> Vec<&str> {
        let mut seen = HashSet::new();
        let mut result = Vec::new();
        for line in &self.lines {
            if let Line::Entry { key_raw, .. } = line {
                if seen.insert(key_raw.to_ascii_lowercase()) {
                    result.push(key_raw.as_str());
                }
            }
        }
        result
    }

    /// Appends `comment` as `# `-prefixed comment line(s) at the end of the file.
    ///
    /// A multi-line `comment` is split on `\n`, with any trailing `\r`
    /// removed, and each part becomes its own comment line. Without this, the
    /// second and later lines would be written uncommented and read back as
    /// live configuration.
    pub fn add_comment_line(&mut self, comment: &str) {
        self.terminate_last_line();
        for part in comment.split('\n') {
            let part = part.strip_suffix('\r').unwrap_or(part);
            self.lines.push(Line::Comment(format!("# {}\n", part)));
        }
    }

    /// Appends an empty line at the end of the file.
    pub fn add_blank_line(&mut self) {
        self.terminate_last_line();
        self.lines.push(Line::Blank(String::from("\n")));
    }

    /// Validates the values of the known SELinux keys and returns every problem found.
    ///
    /// Every entry is checked, including earlier duplicates that a later entry
    /// overrides. Each error's `key` uses the canonical key spelling, whatever
    /// case the file uses. Unknown keys are not checked.
    #[must_use]
    pub fn validate(&self) -> Vec<ValueError> {
        let mut errors = Vec::new();
        for line in &self.lines {
            let Line::Entry { key_raw, value, .. } = line else {
                continue;
            };
            let key = canonical_key_name(key_raw);
            if key == SELINUX_KEY {
                if value.parse::<SelinuxMode>().is_err() {
                    errors.push(ValueError {
                        key: SELINUX_KEY.into(),
                        message: format!("invalid SELinux mode: '{}'", value),
                    });
                }
            } else if key == SELINUXTYPE_KEY {
                errors.extend(validate_selinuxtype_value(value));
            } else if key == REQUIRESEUSERS_KEY
                || key == AUTORELABEL_KEY
                || key == SETLOCALDEFS_KEY
            {
                errors.extend(validate_boolean_value(&key, value));
            }
        }
        errors
    }

    /// Replaces the value of the last entry matching `key` (ASCII
    /// case-insensitive), or appends a new `key=value` line if none exists.
    ///
    /// The key is stored exactly as given when a line is appended. Callers
    /// that want canonical spelling must canonicalize it first.
    #[doc(hidden)]
    pub(crate) fn set_inner(&mut self, key: &str, value: &str) {
        for line in self.lines.iter_mut().rev() {
            if let Line::Entry {
                key_raw,
                value: current,
                ..
            } = line
                && key_raw.eq_ignore_ascii_case(key)
            {
                *current = value.to_string();
                return;
            }
        }
        self.terminate_last_line();
        self.lines.push(Line::Entry {
            key_raw: key.to_string(),
            value: value.to_string(),
            raw_leading: String::new(),
            raw_separator: "=".to_string(),
            raw_suffix: "\n".to_string(),
        });
    }

    /// Returns the token written for a boolean setting: `"1"` or `"0"`.
    fn bool_token(value: bool) -> &'static str {
        if value { "1" } else { "0" }
    }

    /// Appends `\n` to the final line if it does not already end with one.
    ///
    /// Every line this type creates carries its own terminator: `raw_suffix`
    /// for entries, and the text itself for comments and blanks. A file parsed
    /// without a trailing newline may end in an unterminated line (assumption
    /// about the parser; confirm). Without this step, a newly appended line
    /// would be glued onto that last line when written back out.
    fn terminate_last_line(&mut self) {
        let text = match self.lines.last_mut() {
            Some(Line::Entry { raw_suffix, .. }) => raw_suffix,
            Some(Line::Comment(text) | Line::Blank(text)) => text,
            _ => return,
        };
        if !text.ends_with('\n') {
            text.push('\n');
        }
    }
}
impl Default for ConfigFile {
fn default() -> Self {
ConfigFile::new()
}
}
impl ConfigFile {
    /// Builds a two-entry configuration: `SELINUX=enforcing` followed by
    /// `SELINUXTYPE=<SELINUXTYPE_DEFAULT>`.
    ///
    /// Each entry has no leading text, uses `=` as the separator and ends with `\n`.
    pub fn minimal() -> Self {
        let entry = |key: &str, value: &str| Line::Entry {
            key_raw: key.to_string(),
            value: value.to_string(),
            raw_leading: String::new(),
            raw_separator: "=".to_string(),
            raw_suffix: "\n".to_string(),
        };
        let mut config = Self::new();
        config.lines.extend([
            entry(SELINUX_KEY, "enforcing"),
            entry(SELINUXTYPE_KEY, SELINUXTYPE_DEFAULT),
        ]);
        config
    }
}
/// Returns the canonical spelling of `key` if it matches one of the known
/// SELinux config keys (ASCII case-insensitive).
///
/// Any other key is returned unchanged as an owned `String`.
fn canonical_key_name(key: &str) -> String {
    let known = [
        SELINUX_KEY,
        SELINUXTYPE_KEY,
        REQUIRESEUSERS_KEY,
        AUTORELABEL_KEY,
        SETLOCALDEFS_KEY,
    ];
    match known.iter().find(|candidate| key.eq_ignore_ascii_case(candidate)) {
        Some(canonical) => (*canonical).into(),
        None => key.to_string(),
    }
}
/// Checks a candidate `SELINUXTYPE` value and returns every problem found.
///
/// An empty vector means the value is acceptable. The value is trimmed
/// first, then rejected if it:
/// - is empty;
/// - contains `/`;
/// - is exactly `.` or `..`;
/// - contains control characters, including non-ASCII (C1) controls.
///
/// The `/` check shows the value must be a single path component (presumably
/// a policy directory name; confirm). `.` and `..` are rejected for the same
/// reason, because as a path component they would refer to a different
/// directory.
fn validate_selinuxtype_value(value: &str) -> Vec<ValueError> {
    let trimmed = value.trim();
    let mut errors = Vec::new();
    if trimmed.is_empty() {
        errors.push(ValueError {
            key: SELINUXTYPE_KEY.into(),
            message: "SELINUXTYPE value must not be empty".into(),
        });
    }
    if trimmed.contains('/') {
        errors.push(ValueError {
            key: SELINUXTYPE_KEY.into(),
            message: format!("SELINUXTYPE value must not contain '/': '{}'", trimmed),
        });
    }
    if trimmed == "." || trimmed == ".." {
        errors.push(ValueError {
            key: SELINUXTYPE_KEY.into(),
            message: format!("SELINUXTYPE value must not be '.' or '..': '{}'", trimmed),
        });
    }
    if trimmed.chars().any(char::is_control) {
        errors.push(ValueError {
            key: SELINUXTYPE_KEY.into(),
            message: format!(
                "SELINUXTYPE value contains control characters: '{}'",
                trimmed
            ),
        });
    }
    errors
}
/// Returns an error for `key` unless `value` is an accepted boolean token.
///
/// The accepted tokens are `1`, `0`, `true` and `false`, compared ASCII
/// case-insensitively. The error's `key` is `key` exactly as passed in.
fn validate_boolean_value(key: &str, value: &str) -> Option<ValueError> {
    const ACCEPTED: [&str; 4] = ["1", "0", "true", "false"];
    if ACCEPTED.iter().any(|token| value.eq_ignore_ascii_case(token)) {
        return None;
    }
    Some(ValueError {
        key: key.into(),
        message: format!("invalid boolean value: '{}'", value),
    })
}