1use std::{
2 collections::BTreeMap,
3 fmt,
4 fs::OpenOptions,
5 io::{
6 self,
7 BufRead,
8 BufReader,
9 Write,
10 },
11 net::IpAddr,
12 path::{
13 Path,
14 PathBuf,
15 },
16 time::{
17 SystemTime,
18 UNIX_EPOCH,
19 },
20};
21
/// Result alias used throughout this module; the error type is always
/// [`HostsFileError`].
pub type Result<T> = std::result::Result<T, HostsFileError>;
23
/// Errors reported while locating, reading, or rewriting a hosts file.
///
/// Variants that carry detail store it as an already-rendered `String`.
#[derive(Debug, Clone)]
pub enum HostsFileError {
    /// An I/O operation failed; holds the text of the underlying `io::Error`.
    Io(String),
    /// A path could not be used; holds the reason.
    InvalidPath(String),
    /// Existing data was inconsistent; holds the reason.
    InvalidData(String),
    /// The current platform is not supported.
    UnsupportedPlatform,
}

impl fmt::Display for HostsFileError {
    /// Renders a short human-readable description, prefixed by the error kind.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            HostsFileError::UnsupportedPlatform => f.write_str("Unsupported platform"),
            HostsFileError::InvalidData(detail) => write!(f, "Invalid data: {}", detail),
            HostsFileError::InvalidPath(detail) => write!(f, "Invalid path: {}", detail),
            HostsFileError::Io(detail) => write!(f, "IO error: {}", detail),
        }
    }
}

impl std::error::Error for HostsFileError {}

impl From<io::Error> for HostsFileError {
    /// Converts an `io::Error` into [`HostsFileError::Io`], keeping only its
    /// display text.
    fn from(err: io::Error) -> Self {
        let message = err.to_string();
        HostsFileError::Io(message)
    }
}
50
51pub struct HostsFile {
52 entries: BTreeMap<IpAddr, Vec<String>>,
53 tag: String,
54}
55
56impl HostsFile {
57 pub fn new<S: Into<String>>(tag: S) -> Self {
58 Self {
59 entries: BTreeMap::new(),
60 tag: tag.into(),
61 }
62 }
63
64 pub fn add_entry<S: ToString>(&mut self, ip: IpAddr, hostname: S) -> &mut Self {
65 self.entries
66 .entry(ip)
67 .or_default()
68 .push(hostname.to_string());
69 self
70 }
71
72 pub fn add_entries<I, S>(&mut self, ip: IpAddr, hostnames: I) -> &mut Self
73 where
74 I: IntoIterator<Item = S>,
75 S: ToString,
76 {
77 self.entries
78 .entry(ip)
79 .or_default()
80 .extend(hostnames.into_iter().map(|h| h.to_string()));
81 self
82 }
83
84 pub fn is_empty(&self) -> bool {
85 self.entries.is_empty()
86 }
87
88 pub fn entry_count(&self) -> usize {
89 self.entries.len()
90 }
91
92 pub fn write(&self) -> Result<bool> {
93 self.write_to(get_default_hosts_path()?)
94 }
95
96 pub fn write_to<P: AsRef<Path>>(&self, path: P) -> Result<bool> {
97 let path = path.as_ref();
98 validate_hosts_path(path)?;
99
100 let writer = HostsFileWriter::new(path);
101 writer.update_section(&self.tag, &self.entries)
102 }
103}
104
/// Formats and locates the marker-delimited section of a hosts file that
/// belongs to one tag.
struct HostsSection {
    /// Tag embedded in the BEGIN/END marker lines.
    tag: String,
}

impl HostsSection {
    /// Creates a section helper for `tag`.
    fn new(tag: &str) -> Self {
        Self {
            tag: tag.to_string(),
        }
    }

    /// Marker line that opens the section.
    fn begin_marker(&self) -> String {
        format!("# DO NOT EDIT {} BEGIN", self.tag)
    }

    /// Marker line that closes the section.
    fn end_marker(&self) -> String {
        format!("# DO NOT EDIT {} END", self.tag)
    }

    /// Renders the full section (BEGIN marker, host lines, END marker).
    ///
    /// Returns an empty vector when `entries` is empty, meaning "no section".
    /// Addresses are emitted in the map's iteration order.
    fn format_entries(&self, entries: &BTreeMap<IpAddr, Vec<String>>) -> Vec<String> {
        if entries.is_empty() {
            return Vec::new();
        }

        let mut lines = vec![self.begin_marker()];

        for (ip, hostnames) in entries {
            lines.extend(self.format_host_entries(ip, hostnames));
        }

        lines.push(self.end_marker());
        lines
    }

    /// Renders the host line(s) for one address.
    ///
    /// On Windows each hostname gets its own line; elsewhere all hostnames
    /// share a single space-separated line. An empty `hostnames` slice yields
    /// no lines on every platform, so a dangling `"<ip> "` line is never
    /// produced.
    fn format_host_entries(&self, ip: &IpAddr, hostnames: &[String]) -> Vec<String> {
        if hostnames.is_empty() {
            return Vec::new();
        }

        if cfg!(windows) {
            // NOTE(review): presumably one-per-line because of Windows limits
            // on hostnames per hosts line; confirm.
            hostnames
                .iter()
                .map(|hostname| format!("{} {}", ip, hostname))
                .collect()
        } else {
            vec![format!("{} {}", ip, hostnames.join(" "))]
        }
    }

    /// Locates this section's marker lines in `lines`.
    ///
    /// `begin` is the index of the first BEGIN marker. `end` is the index of
    /// the first END marker that follows it, so a stray END marker above the
    /// section is never paired with a later BEGIN. When no BEGIN marker
    /// exists, `end` is the first END marker anywhere, which lets callers
    /// detect a lone END as a partial section. Lines are compared after
    /// trimming surrounding whitespace.
    fn find_section_bounds(&self, lines: &[String]) -> SectionBounds {
        let begin_marker = self.begin_marker();
        let end_marker = self.end_marker();

        let begin = lines.iter().position(|line| line.trim() == begin_marker);

        // Start looking for END just past BEGIN (or from the top if there is
        // no BEGIN at all).
        let search_from = begin.map_or(0, |index| index + 1);
        let end = lines
            .iter()
            .enumerate()
            .skip(search_from)
            .find(|(_, line)| line.trim() == end_marker)
            .map(|(index, _)| index);

        SectionBounds { begin, end }
    }
}

/// Line indices of a section's BEGIN and END markers, when present.
#[derive(Debug)]
struct SectionBounds {
    /// Index of the BEGIN marker line.
    begin: Option<usize>,
    /// Index of the END marker line.
    end: Option<usize>,
}

impl SectionBounds {
    /// Both markers were found.
    fn is_complete(&self) -> bool {
        self.begin.is_some() && self.end.is_some()
    }

    /// Neither marker was found.
    fn is_missing(&self) -> bool {
        self.begin.is_none() && self.end.is_none()
    }

    /// Exactly one of the two markers was found.
    fn is_partial(&self) -> bool {
        self.begin.is_some() != self.end.is_some()
    }
}
180
181struct HostsFileWriter<'a> {
182 path: &'a Path,
183}
184
185impl<'a> HostsFileWriter<'a> {
186 fn new(path: &'a Path) -> Self {
187 Self { path }
188 }
189
190 fn update_section(&self, tag: &str, entries: &BTreeMap<IpAddr, Vec<String>>) -> Result<bool> {
191 let mut lines = self.read_file_lines()?;
192 let section = HostsSection::new(tag);
193 let new_section_lines = section.format_entries(entries);
194
195 let changed = self.apply_section_update(&mut lines, §ion, new_section_lines)?;
196
197 if changed {
198 self.write_file_lines(&lines)?;
199 }
200
201 Ok(changed)
202 }
203
204 fn read_file_lines(&self) -> Result<Vec<String>> {
205 let file = OpenOptions::new()
206 .create(true)
207 .read(true)
208 .write(true)
209 .truncate(false)
210 .open(self.path)?;
211
212 Ok(BufReader::new(file)
213 .lines()
214 .collect::<io::Result<Vec<_>>>()?)
215 }
216
217 fn apply_section_update(
218 &self, lines: &mut Vec<String>, section: &HostsSection, new_section_lines: Vec<String>,
219 ) -> Result<bool> {
220 let bounds = section.find_section_bounds(lines);
221
222 if bounds.is_partial() {
223 return Err(HostsFileError::InvalidData(format!(
224 "Incomplete section markers for tag '{}'",
225 section.tag
226 )));
227 }
228
229 if bounds.is_complete() {
230 self.replace_existing_section(lines, &bounds, new_section_lines)
231 } else {
232 self.add_new_section(lines, new_section_lines)
233 }
234 }
235
236 fn replace_existing_section(
237 &self, lines: &mut Vec<String>, bounds: &SectionBounds, new_section_lines: Vec<String>,
238 ) -> Result<bool> {
239 let begin = bounds.begin.unwrap();
240 let end = bounds.end.unwrap();
241
242 let old_section: Vec<String> = lines.drain(begin..=end).collect();
243
244 if old_section == new_section_lines {
245 lines.splice(begin..begin, old_section);
246 return Ok(false);
247 }
248
249 lines.splice(begin..begin, new_section_lines);
250 Ok(true)
251 }
252
253 fn add_new_section(
254 &self, lines: &mut Vec<String>, new_section_lines: Vec<String>,
255 ) -> Result<bool> {
256 if new_section_lines.is_empty() {
257 return Ok(false);
258 }
259
260 if let Some(last_line) = lines.last()
261 && !last_line.is_empty()
262 {
263 lines.push(String::new());
264 }
265
266 lines.extend(new_section_lines);
267 Ok(true)
268 }
269
270 fn write_file_lines(&self, lines: &[String]) -> Result<()> {
271 let content = self.format_file_content(lines)?;
272 let writer = AtomicFileWriter::new(self.path);
273 writer.write_content(&content)
274 }
275
276 fn format_file_content(&self, lines: &[String]) -> Result<Vec<u8>> {
277 let mut buffer = Vec::new();
278 for line in lines {
279 writeln!(buffer, "{}", line)?;
280 }
281 Ok(buffer)
282 }
283}
284
285struct AtomicFileWriter<'a> {
286 target_path: &'a Path,
287}
288
289impl<'a> AtomicFileWriter<'a> {
290 fn new(path: &'a Path) -> Self {
291 Self { target_path: path }
292 }
293
294 fn write_content(&self, content: &[u8]) -> Result<()> {
295 match self.try_atomic_write(content) {
296 Ok(()) => {
297 log::debug!("Successfully wrote hosts file using atomic write");
298 Ok(())
299 }
300 Err(_) => {
301 log::debug!("Atomic write failed, falling back to direct write");
302 self.write_directly(content)
303 }
304 }
305 }
306
307 fn try_atomic_write(&self, content: &[u8]) -> Result<()> {
308 let temp_path = self.create_temp_path()?;
309
310 std::fs::copy(self.target_path, &temp_path)?;
311
312 #[cfg(target_os = "linux")]
313 self.preserve_selinux_context(&temp_path);
314
315 self.write_file(&temp_path, content)?;
316 std::fs::rename(&temp_path, self.target_path)?;
317
318 Ok(())
319 }
320
321 fn create_temp_path(&self) -> Result<PathBuf> {
322 let parent = self.target_path.parent().ok_or_else(|| {
323 HostsFileError::InvalidPath("Path has no parent directory".to_string())
324 })?;
325
326 let timestamp = SystemTime::now()
327 .duration_since(UNIX_EPOCH)
328 .expect("System time is before Unix epoch")
329 .as_millis();
330
331 let filename = self
332 .target_path
333 .file_name()
334 .ok_or_else(|| HostsFileError::InvalidPath("Path has no filename".to_string()))?;
335
336 let temp_filename = format!("{}.tmp{}", filename.to_string_lossy(), timestamp);
337 Ok(parent.join(temp_filename))
338 }
339
340 #[cfg(target_os = "linux")]
341 fn preserve_selinux_context(&self, _temp_path: &Path) {
342 log::trace!("SELinux context preservation not implemented");
343 }
344
345 fn write_directly(&self, content: &[u8]) -> Result<()> {
346 self.write_file(self.target_path, content)
347 }
348
349 fn write_file(&self, path: &Path, content: &[u8]) -> Result<()> {
350 OpenOptions::new()
351 .create(true)
352 .write(true)
353 .truncate(true)
354 .open(path)?
355 .write_all(content)?;
356 Ok(())
357 }
358}
359
360fn get_default_hosts_path() -> Result<PathBuf> {
361 let path = get_platform_hosts_path()?;
362
363 if !path.exists() {
364 return Err(HostsFileError::InvalidPath(format!(
365 "Hosts file not found at {}",
366 path.display()
367 )));
368 }
369
370 Ok(path)
371}
372
373fn get_platform_hosts_path() -> Result<PathBuf> {
374 if cfg!(unix) {
375 Ok(PathBuf::from("/etc/hosts"))
376 } else if cfg!(windows) {
377 let windir = std::env::var("WinDir").map_err(|_| {
378 HostsFileError::InvalidPath("WinDir environment variable not found".to_string())
379 })?;
380 Ok(PathBuf::from(format!(
381 "{}\\System32\\Drivers\\Etc\\hosts",
382 windir
383 )))
384 } else {
385 Err(HostsFileError::UnsupportedPlatform)
386 }
387}
388
389fn validate_hosts_path(path: &Path) -> Result<()> {
390 if path.is_dir() {
391 Err(HostsFileError::InvalidPath(
392 "Expected file path, got directory".to_string(),
393 ))
394 } else {
395 Ok(())
396 }
397}
398
#[cfg(test)]
mod tests {
    use std::io::Write;

    use super::*;

    /// Writing a section into a file with existing content reports a change
    /// the first time, reports no change the second time, and keeps the
    /// original content alongside the new section.
    #[test]
    fn test_hosts_file_write() {
        let (mut file, path) = tempfile::NamedTempFile::new().unwrap().into_parts();
        file.write_all(b"preexisting\ncontent").unwrap();

        let mut hosts = HostsFile::new("test");
        hosts.add_entry([1, 1, 1, 1].into(), "example.com");

        let first_write_changed = hosts.write_to(&path).unwrap();
        assert!(first_write_changed);

        let second_write_changed = hosts.write_to(&path).unwrap();
        assert!(!second_write_changed);

        let written = std::fs::read_to_string(&path).unwrap();
        assert!(written.contains("preexisting\ncontent"));
        assert!(written.contains("# DO NOT EDIT test BEGIN"));
        assert!(written.contains("1.1.1.1 example.com"));
        assert!(written.contains("# DO NOT EDIT test END"));
    }

    /// `add_entry` and `add_entries` can be chained, and each distinct address
    /// becomes its own entry.
    #[test]
    fn test_fluent_api() {
        let mut hosts = HostsFile::new("test");
        hosts
            .add_entry([127, 0, 0, 1].into(), "localhost")
            .add_entries([192, 168, 1, 1].into(), ["router", "gateway"]);

        let distinct_addresses = hosts.entries.len();
        assert_eq!(distinct_addresses, 2);
    }
}