kftray_commons/utils/hostsfile.rs

1use std::{
2    collections::BTreeMap,
3    fmt,
4    fs::OpenOptions,
5    io::{
6        self,
7        BufRead,
8        BufReader,
9        Write,
10    },
11    net::IpAddr,
12    path::{
13        Path,
14        PathBuf,
15    },
16    time::{
17        SystemTime,
18        UNIX_EPOCH,
19    },
20};
21
/// Convenience alias for results produced by hosts-file operations.
pub type Result<T> = std::result::Result<T, HostsFileError>;

/// Errors that can occur while reading or updating a hosts file.
///
/// Payloads are plain strings, so the type is cheap to clone and can be
/// compared in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HostsFileError {
    /// An underlying I/O operation failed; holds the rendered `io::Error` message.
    Io(String),
    /// The supplied or derived hosts-file path is unusable.
    InvalidPath(String),
    /// The hosts-file contents could not be processed (e.g. unmatched section markers).
    InvalidData(String),
    /// No hosts-file location is known for the current target platform.
    UnsupportedPlatform,
}

impl fmt::Display for HostsFileError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Io(msg) => write!(f, "IO error: {msg}"),
            Self::InvalidPath(msg) => write!(f, "Invalid path: {msg}"),
            Self::InvalidData(msg) => write!(f, "Invalid data: {msg}"),
            Self::UnsupportedPlatform => write!(f, "Unsupported platform"),
        }
    }
}

impl std::error::Error for HostsFileError {}

impl From<io::Error> for HostsFileError {
    /// Converts an I/O error into [`HostsFileError::Io`], keeping only its message
    /// (the `io::ErrorKind` is not preserved).
    fn from(err: io::Error) -> Self {
        Self::Io(err.to_string())
    }
}
50
51pub struct HostsFile {
52    entries: BTreeMap<IpAddr, Vec<String>>,
53    tag: String,
54}
55
56impl HostsFile {
57    pub fn new<S: Into<String>>(tag: S) -> Self {
58        Self {
59            entries: BTreeMap::new(),
60            tag: tag.into(),
61        }
62    }
63
64    pub fn add_entry<S: ToString>(&mut self, ip: IpAddr, hostname: S) -> &mut Self {
65        self.entries
66            .entry(ip)
67            .or_default()
68            .push(hostname.to_string());
69        self
70    }
71
72    pub fn add_entries<I, S>(&mut self, ip: IpAddr, hostnames: I) -> &mut Self
73    where
74        I: IntoIterator<Item = S>,
75        S: ToString,
76    {
77        self.entries
78            .entry(ip)
79            .or_default()
80            .extend(hostnames.into_iter().map(|h| h.to_string()));
81        self
82    }
83
84    pub fn is_empty(&self) -> bool {
85        self.entries.is_empty()
86    }
87
88    pub fn entry_count(&self) -> usize {
89        self.entries.len()
90    }
91
92    pub fn write(&self) -> Result<bool> {
93        self.write_to(get_default_hosts_path()?)
94    }
95
96    pub fn write_to<P: AsRef<Path>>(&self, path: P) -> Result<bool> {
97        let path = path.as_ref();
98        validate_hosts_path(path)?;
99
100        let writer = HostsFileWriter::new(path);
101        writer.update_section(&self.tag, &self.entries)
102    }
103}
104
/// Formats and locates one tagged, marker-delimited section of a hosts file.
struct HostsSection {
    tag: String,
}

impl HostsSection {
    fn new(tag: &str) -> Self {
        Self {
            tag: tag.to_string(),
        }
    }

    /// Comment line that opens the section.
    fn begin_marker(&self) -> String {
        format!("# DO NOT EDIT {} BEGIN", self.tag)
    }

    /// Comment line that closes the section.
    fn end_marker(&self) -> String {
        format!("# DO NOT EDIT {} END", self.tag)
    }

    /// Renders `entries` as marker-wrapped hosts lines, ordered by IP.
    ///
    /// Returns no lines at all (not even markers) when `entries` is empty. The
    /// caller relies on this to drop the section from the file.
    fn format_entries(&self, entries: &BTreeMap<IpAddr, Vec<String>>) -> Vec<String> {
        if entries.is_empty() {
            return Vec::new();
        }

        let mut lines = vec![self.begin_marker()];
        for (ip, hostnames) in entries {
            lines.extend(self.format_host_entries(ip, hostnames));
        }
        lines.push(self.end_marker());
        lines
    }

    /// Renders the lines for one IP address.
    ///
    /// On Windows, each hostname gets its own `"<ip> <hostname>"` line; elsewhere
    /// all hostnames share a single line. (NOTE(review): presumably the Windows
    /// split is there because of that platform's per-line alias handling;
    /// confirm.) An empty `hostnames` yields no lines. Otherwise an address-only
    /// line (`"<ip> "`) would be written, which is not a valid hosts entry.
    fn format_host_entries(&self, ip: &IpAddr, hostnames: &[String]) -> Vec<String> {
        if hostnames.is_empty() {
            return Vec::new();
        }

        if cfg!(windows) {
            hostnames
                .iter()
                .map(|hostname| format!("{} {}", ip, hostname))
                .collect()
        } else {
            vec![format!("{} {}", ip, hostnames.join(" "))]
        }
    }

    /// Finds the indices of this section's BEGIN and END marker lines.
    ///
    /// Lines are compared after trimming surrounding whitespace.
    ///
    /// The END marker is only searched for after the BEGIN marker. A stray END
    /// line earlier in the file therefore can never produce `end < begin`, an
    /// inverted range that would make the caller's `drain(begin..=end)` panic.
    /// Such a file is reported as partial instead. When BEGIN is absent, any END
    /// marker is still reported so the partial section can be detected.
    fn find_section_bounds(&self, lines: &[String]) -> SectionBounds {
        let begin_marker = self.begin_marker();
        let end_marker = self.end_marker();
        let is_end = |line: &String| line.trim() == end_marker;

        let begin = lines.iter().position(|line| line.trim() == begin_marker);
        let end = match begin {
            Some(begin_idx) => lines[begin_idx + 1..]
                .iter()
                .position(is_end)
                .map(|offset| begin_idx + 1 + offset),
            None => lines.iter().position(is_end),
        };

        SectionBounds { begin, end }
    }
}

/// Line indices of a section's BEGIN and END markers, when present.
#[derive(Debug)]
struct SectionBounds {
    begin: Option<usize>,
    end: Option<usize>,
}

impl SectionBounds {
    /// Both markers were found.
    fn is_complete(&self) -> bool {
        self.begin.is_some() && self.end.is_some()
    }

    /// Neither marker was found.
    fn is_missing(&self) -> bool {
        self.begin.is_none() && self.end.is_none()
    }

    /// Exactly one marker was found (a malformed section).
    fn is_partial(&self) -> bool {
        !self.is_complete() && !self.is_missing()
    }
}
180
181struct HostsFileWriter<'a> {
182    path: &'a Path,
183}
184
185impl<'a> HostsFileWriter<'a> {
186    fn new(path: &'a Path) -> Self {
187        Self { path }
188    }
189
190    fn update_section(&self, tag: &str, entries: &BTreeMap<IpAddr, Vec<String>>) -> Result<bool> {
191        let mut lines = self.read_file_lines()?;
192        let section = HostsSection::new(tag);
193        let new_section_lines = section.format_entries(entries);
194
195        let changed = self.apply_section_update(&mut lines, &section, new_section_lines)?;
196
197        if changed {
198            self.write_file_lines(&lines)?;
199        }
200
201        Ok(changed)
202    }
203
204    fn read_file_lines(&self) -> Result<Vec<String>> {
205        let file = OpenOptions::new()
206            .create(true)
207            .read(true)
208            .write(true)
209            .truncate(false)
210            .open(self.path)?;
211
212        Ok(BufReader::new(file)
213            .lines()
214            .collect::<io::Result<Vec<_>>>()?)
215    }
216
217    fn apply_section_update(
218        &self, lines: &mut Vec<String>, section: &HostsSection, new_section_lines: Vec<String>,
219    ) -> Result<bool> {
220        let bounds = section.find_section_bounds(lines);
221
222        if bounds.is_partial() {
223            return Err(HostsFileError::InvalidData(format!(
224                "Incomplete section markers for tag '{}'",
225                section.tag
226            )));
227        }
228
229        if bounds.is_complete() {
230            self.replace_existing_section(lines, &bounds, new_section_lines)
231        } else {
232            self.add_new_section(lines, new_section_lines)
233        }
234    }
235
236    fn replace_existing_section(
237        &self, lines: &mut Vec<String>, bounds: &SectionBounds, new_section_lines: Vec<String>,
238    ) -> Result<bool> {
239        let begin = bounds.begin.unwrap();
240        let end = bounds.end.unwrap();
241
242        let old_section: Vec<String> = lines.drain(begin..=end).collect();
243
244        if old_section == new_section_lines {
245            lines.splice(begin..begin, old_section);
246            return Ok(false);
247        }
248
249        lines.splice(begin..begin, new_section_lines);
250        Ok(true)
251    }
252
253    fn add_new_section(
254        &self, lines: &mut Vec<String>, new_section_lines: Vec<String>,
255    ) -> Result<bool> {
256        if new_section_lines.is_empty() {
257            return Ok(false);
258        }
259
260        if let Some(last_line) = lines.last()
261            && !last_line.is_empty()
262        {
263            lines.push(String::new());
264        }
265
266        lines.extend(new_section_lines);
267        Ok(true)
268    }
269
270    fn write_file_lines(&self, lines: &[String]) -> Result<()> {
271        let content = self.format_file_content(lines)?;
272        let writer = AtomicFileWriter::new(self.path);
273        writer.write_content(&content)
274    }
275
276    fn format_file_content(&self, lines: &[String]) -> Result<Vec<u8>> {
277        let mut buffer = Vec::new();
278        for line in lines {
279            writeln!(buffer, "{}", line)?;
280        }
281        Ok(buffer)
282    }
283}
284
285struct AtomicFileWriter<'a> {
286    target_path: &'a Path,
287}
288
289impl<'a> AtomicFileWriter<'a> {
290    fn new(path: &'a Path) -> Self {
291        Self { target_path: path }
292    }
293
294    fn write_content(&self, content: &[u8]) -> Result<()> {
295        match self.try_atomic_write(content) {
296            Ok(()) => {
297                log::debug!("Successfully wrote hosts file using atomic write");
298                Ok(())
299            }
300            Err(_) => {
301                log::debug!("Atomic write failed, falling back to direct write");
302                self.write_directly(content)
303            }
304        }
305    }
306
307    fn try_atomic_write(&self, content: &[u8]) -> Result<()> {
308        let temp_path = self.create_temp_path()?;
309
310        std::fs::copy(self.target_path, &temp_path)?;
311
312        #[cfg(target_os = "linux")]
313        self.preserve_selinux_context(&temp_path);
314
315        self.write_file(&temp_path, content)?;
316        std::fs::rename(&temp_path, self.target_path)?;
317
318        Ok(())
319    }
320
321    fn create_temp_path(&self) -> Result<PathBuf> {
322        let parent = self.target_path.parent().ok_or_else(|| {
323            HostsFileError::InvalidPath("Path has no parent directory".to_string())
324        })?;
325
326        let timestamp = SystemTime::now()
327            .duration_since(UNIX_EPOCH)
328            .expect("System time is before Unix epoch")
329            .as_millis();
330
331        let filename = self
332            .target_path
333            .file_name()
334            .ok_or_else(|| HostsFileError::InvalidPath("Path has no filename".to_string()))?;
335
336        let temp_filename = format!("{}.tmp{}", filename.to_string_lossy(), timestamp);
337        Ok(parent.join(temp_filename))
338    }
339
340    #[cfg(target_os = "linux")]
341    fn preserve_selinux_context(&self, _temp_path: &Path) {
342        log::trace!("SELinux context preservation not implemented");
343    }
344
345    fn write_directly(&self, content: &[u8]) -> Result<()> {
346        self.write_file(self.target_path, content)
347    }
348
349    fn write_file(&self, path: &Path, content: &[u8]) -> Result<()> {
350        OpenOptions::new()
351            .create(true)
352            .write(true)
353            .truncate(true)
354            .open(path)?
355            .write_all(content)?;
356        Ok(())
357    }
358}
359
360fn get_default_hosts_path() -> Result<PathBuf> {
361    let path = get_platform_hosts_path()?;
362
363    if !path.exists() {
364        return Err(HostsFileError::InvalidPath(format!(
365            "Hosts file not found at {}",
366            path.display()
367        )));
368    }
369
370    Ok(path)
371}
372
373fn get_platform_hosts_path() -> Result<PathBuf> {
374    if cfg!(unix) {
375        Ok(PathBuf::from("/etc/hosts"))
376    } else if cfg!(windows) {
377        let windir = std::env::var("WinDir").map_err(|_| {
378            HostsFileError::InvalidPath("WinDir environment variable not found".to_string())
379        })?;
380        Ok(PathBuf::from(format!(
381            "{}\\System32\\Drivers\\Etc\\hosts",
382            windir
383        )))
384    } else {
385        Err(HostsFileError::UnsupportedPlatform)
386    }
387}
388
389fn validate_hosts_path(path: &Path) -> Result<()> {
390    if path.is_dir() {
391        Err(HostsFileError::InvalidPath(
392            "Expected file path, got directory".to_string(),
393        ))
394    } else {
395        Ok(())
396    }
397}
398
#[cfg(test)]
mod tests {
    use std::io::Write;

    use super::*;

    /// Creates a temporary file containing `initial` and returns its path.
    /// The file is deleted when the returned `TempPath` is dropped.
    fn temp_hosts_file(initial: &[u8]) -> tempfile::TempPath {
        let (mut file, path) = tempfile::NamedTempFile::new().unwrap().into_parts();
        file.write_all(initial).unwrap();
        path
    }

    #[test]
    fn test_hosts_file_write() {
        let temp_path = temp_hosts_file(b"preexisting\ncontent");

        let mut hosts_file = HostsFile::new("test");
        hosts_file.add_entry([1, 1, 1, 1].into(), "example.com");

        // The first write adds the section; an identical second write is a no-op.
        assert!(hosts_file.write_to(&temp_path).unwrap());
        assert!(!hosts_file.write_to(&temp_path).unwrap());

        let contents = std::fs::read_to_string(&temp_path).unwrap();
        assert!(contents.contains("preexisting\ncontent"));
        assert!(contents.contains("# DO NOT EDIT test BEGIN"));
        assert!(contents.contains("1.1.1.1 example.com"));
        assert!(contents.contains("# DO NOT EDIT test END"));
    }

    #[test]
    fn test_writing_empty_entries_removes_section() {
        let temp_path = temp_hosts_file(b"127.0.0.1 localhost\n");

        let mut hosts_file = HostsFile::new("test");
        hosts_file.add_entry([1, 1, 1, 1].into(), "example.com");
        assert!(hosts_file.write_to(&temp_path).unwrap());

        // An empty HostsFile with the same tag removes the section...
        assert!(HostsFile::new("test").write_to(&temp_path).unwrap());
        let contents = std::fs::read_to_string(&temp_path).unwrap();
        assert!(contents.contains("127.0.0.1 localhost"));
        assert!(!contents.contains("DO NOT EDIT test"));
        assert!(!contents.contains("example.com"));

        // ...and once it is gone there is nothing left to change.
        assert!(!HostsFile::new("test").write_to(&temp_path).unwrap());
    }

    #[test]
    fn test_incomplete_markers_are_rejected() {
        let original = "# DO NOT EDIT test BEGIN\n1.1.1.1 stale.example\n";
        let temp_path = temp_hosts_file(original.as_bytes());

        let mut hosts_file = HostsFile::new("test");
        hosts_file.add_entry([1, 1, 1, 1].into(), "example.com");

        let result = hosts_file.write_to(&temp_path);
        assert!(matches!(result, Err(HostsFileError::InvalidData(_))));

        // The file must be left untouched when the update is rejected.
        assert_eq!(std::fs::read_to_string(&temp_path).unwrap(), original);
    }

    #[test]
    fn test_sections_with_different_tags_coexist() {
        let temp_path = temp_hosts_file(b"127.0.0.1 localhost\n");

        let mut first = HostsFile::new("a");
        first.add_entry([1, 1, 1, 1].into(), "one.example");
        assert!(first.write_to(&temp_path).unwrap());

        let mut second = HostsFile::new("b");
        second.add_entry([2, 2, 2, 2].into(), "two.example");
        assert!(second.write_to(&temp_path).unwrap());

        // Rewriting section "a" must not disturb section "b".
        let mut replacement = HostsFile::new("a");
        replacement.add_entry([3, 3, 3, 3].into(), "three.example");
        assert!(replacement.write_to(&temp_path).unwrap());

        let contents = std::fs::read_to_string(&temp_path).unwrap();
        assert!(contents.contains("127.0.0.1 localhost"));
        assert!(contents.contains("2.2.2.2 two.example"));
        assert!(contents.contains("3.3.3.3 three.example"));
        assert!(!contents.contains("one.example"));
    }

    #[test]
    fn test_fluent_api() {
        assert!(HostsFile::new("test").is_empty());

        let mut hosts_file = HostsFile::new("test");
        hosts_file
            .add_entry([127, 0, 0, 1].into(), "localhost")
            .add_entries([192, 168, 1, 1].into(), ["router", "gateway"]);

        assert_eq!(hosts_file.entries.len(), 2);
        assert_eq!(hosts_file.entry_count(), 2);
        assert!(!hosts_file.is_empty());
    }
}
430        assert_eq!(hosts_file.entries.len(), 2);
431    }
432}