just_shield/
dns_observer.rs1use std::collections::BTreeSet;
11use std::io;
12use std::net::UdpSocket;
13use std::sync::atomic::{AtomicBool, Ordering};
14use std::sync::{Arc, Mutex};
15use std::time::Duration;
16
/// Extracts the first question name (QNAME) from a raw DNS query packet.
///
/// Returns the name as lowercase, dot-joined labels (e.g. `"ghcr.io"`), or
/// `None` when the packet:
/// - is shorter than the 12-byte DNS header,
/// - has a QDCOUNT of zero,
/// - starts a label with a compression pointer or reserved label type,
/// - ends in the middle of the name,
/// - contains a label byte outside `[A-Za-z0-9_-]`,
/// - encodes a name longer than 255 octets on the wire (RFC 1035 §2.3.4),
/// - asks for the root name (no labels at all).
pub fn extract_qname(packet: &[u8]) -> Option<String> {
    // Fixed DNS header size; the question section starts right after it.
    const HEADER_LEN: usize = 12;
    // RFC 1035 limit on a name's wire encoding, counting every length octet
    // and the terminating zero octet.
    const MAX_NAME_WIRE_LEN: usize = 255;

    if packet.len() < HEADER_LEN {
        return None;
    }
    let qdcount = u16::from_be_bytes([packet[4], packet[5]]);
    if qdcount == 0 {
        return None;
    }

    let mut pos = HEADER_LEN;
    let mut labels = Vec::new();
    loop {
        // Octets consumed so far, including the length octet at `pos`. Checking
        // here enforces the limit on every label and on the final zero octet.
        if pos - HEADER_LEN + 1 > MAX_NAME_WIRE_LEN {
            return None;
        }
        let len = usize::from(*packet.get(pos)?);
        if len == 0 {
            break;
        }
        // Non-zero top two bits mean a compression pointer (0b11) or a reserved
        // label type. Reject them rather than following pointers in the question.
        if len & 0xC0 != 0 {
            return None;
        }
        pos += 1;
        let end = pos.checked_add(len)?;
        let label = packet.get(pos..end)?;
        if !label
            .iter()
            .all(|&b| b.is_ascii_alphanumeric() || b == b'-' || b == b'_')
        {
            return None;
        }
        // The label was validated as ASCII above, so the lossy conversion never substitutes.
        labels.push(String::from_utf8_lossy(label).to_ascii_lowercase());
        pos = end;
    }

    if labels.is_empty() {
        return None;
    }
    Some(labels.join("."))
}
60
/// Returns the address of the first `nameserver` entry in resolv.conf-style
/// text that is not `127.0.0.1`, or `None` if there is no such entry.
///
/// Lines are split on any whitespace, so `nameserver\t1.1.1.1` and extra
/// spacing are accepted. Only the first field after the keyword is returned,
/// so trailing text on the line is ignored. Comment lines never match because
/// their first field is not `nameserver`.
///
/// `127.0.0.1` is skipped because `observing_resolv` points that entry at the
/// local relay (presumably the one started by `serve`). Using it as the
/// upstream would make the relay forward to itself.
pub fn first_nameserver(resolv: &str) -> Option<String> {
    resolv.lines().find_map(|line| {
        let mut fields = line.split_whitespace();
        match (fields.next(), fields.next()) {
            (Some("nameserver"), Some(addr)) if addr != "127.0.0.1" => Some(addr.to_string()),
            _ => None,
        }
    })
}
73
/// Builds resolv.conf text that sends lookups to `127.0.0.1` first and keeps
/// the original configuration as a fallback.
///
/// The output starts with a marker comment and `nameserver 127.0.0.1`. It is
/// followed by every line of `original` except:
/// - comment lines (first non-blank character `#`), and
/// - any existing `nameserver 127.0.0.1` entry, which would duplicate the
///   one just added.
///
/// All other lines, including other nameservers and `options`, are copied
/// verbatim.
pub fn observing_resolv(original: &str) -> String {
    let mut out = String::from("# just-shield observe — 127.0.0.1 우선, 원본은 폴백.\n");
    out.push_str("nameserver 127.0.0.1\n");
    for line in original.lines() {
        let trimmed = line.trim();
        if trimmed.starts_with('#') {
            continue;
        }
        // Compare fields rather than the raw text so tab- or multi-space-
        // separated loopback entries are also recognised as duplicates.
        let mut fields = trimmed.split_whitespace();
        if fields.next() == Some("nameserver") && fields.next() == Some("127.0.0.1") {
            continue;
        }
        out.push_str(line);
        out.push('\n');
    }
    out
}
95
/// Serialises an observation record for `job`.
///
/// The text is a comment header line, then a `job <name>` line, then one
/// domain per line in the set's (sorted) iteration order. Every line,
/// including the last, ends with `\n`.
pub fn render_record(job: &str, domains: &BTreeSet<String>) -> String {
    let header = format!("# just-shield observe 기록 — 잡 '{job}'이 조회한 도메인.\njob {job}\n");
    domains.iter().fold(header, |mut text, domain| {
        text.push_str(domain);
        text.push('\n');
        text
    })
}
105
/// Settings for [`serve`], the observing DNS relay.
#[derive(Debug)]
pub struct RelayConfig {
    /// Local address the relay's UDP socket binds to (passed to `UdpSocket::bind`).
    pub listen: String,
    /// Upstream resolver each query is forwarded to. It is resolved through
    /// `ToSocketAddrs` for `str`, so it must be a `host:port` string such as
    /// `"8.8.8.8:53"`.
    pub upstream: String,
    /// Job name written into the record header.
    pub job: String,
    /// File the observed-domain record is (re)written to.
    pub record_path: std::path::PathBuf,
    /// Set to `true` to make [`serve`] return; it is checked between socket reads.
    pub stop: Arc<AtomicBool>,
}
118
119pub fn serve(config: &RelayConfig) -> io::Result<()> {
124 let sock = UdpSocket::bind(&config.listen)?;
125 sock.set_read_timeout(Some(Duration::from_millis(200)))?;
126 let seen = Arc::new(Mutex::new(BTreeSet::new()));
127 let _ = std::fs::write(
129 &config.record_path,
130 render_record(&config.job, &seen.lock().unwrap()),
131 );
132 let mut buf = [0u8; 1500];
133 while !config.stop.load(Ordering::Relaxed) {
134 let (n, from) = match sock.recv_from(&mut buf) {
135 Ok(v) => v,
136 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
137 Err(ref e) if e.kind() == io::ErrorKind::TimedOut => continue,
138 Err(_) => continue, };
140 let query = &buf[..n];
141 if let Some(name) = extract_qname(query)
142 && let Ok(mut set) = seen.lock()
143 && set.insert(name)
144 {
145 let _ = std::fs::write(&config.record_path, render_record(&config.job, &set));
147 }
148 let _ = forward(&sock, query, &config.upstream, from);
151 }
152 Ok(())
153}
154
/// Forwards one DNS query to `upstream` and relays the single reply back to
/// `reply_to` through `listen_sock`.
///
/// Each call uses a fresh ephemeral socket bound in the upstream's address
/// family, so IPv6 upstreams work as well as IPv4 ones. The socket is
/// `connect`ed to the upstream, so the OS discards datagrams from any other
/// source instead of relaying a spoofed reply. The reply buffer holds the
/// largest possible UDP payload, so large EDNS0 responses are not truncated.
///
/// # Errors
/// Returns an error in these cases:
/// - `upstream` does not resolve to a `host:port` address,
/// - binding, connecting, sending, or receiving fails, including the 3-second
///   read timeout when the upstream does not answer,
/// - relaying the reply to `reply_to` fails.
fn forward(
    listen_sock: &UdpSocket,
    query: &[u8],
    upstream: &str,
    reply_to: std::net::SocketAddr,
) -> io::Result<()> {
    // Maximum UDP payload size. Anything smaller can silently cut off large
    // EDNS0 replies (Unix truncates oversized datagrams on recv).
    const MAX_UDP_PAYLOAD: usize = 65_535;

    // Like `send_to`, only the first resolved address is used.
    let upstream_addr = std::net::ToSocketAddrs::to_socket_addrs(upstream)?
        .next()
        .ok_or_else(|| {
            io::Error::new(io::ErrorKind::InvalidInput, "upstream resolved to no address")
        })?;
    let bind_addr = if upstream_addr.is_ipv4() { "0.0.0.0:0" } else { "[::]:0" };

    let up = UdpSocket::bind(bind_addr)?;
    up.set_read_timeout(Some(Duration::from_secs(3)))?;
    up.connect(upstream_addr)?;
    up.send(query)?;

    let mut resp = vec![0u8; MAX_UDP_PAYLOAD];
    let n = up.recv(&mut resp)?;
    listen_sock.send_to(&resp[..n], reply_to)?;
    Ok(())
}
170
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes a single-question DNS query for `name` with QTYPE=A and QCLASS=IN.
    fn query_packet(name: &str) -> Vec<u8> {
        // Header: ID=0x1234, flags=RD, QDCOUNT=1, then ANCOUNT/NSCOUNT/ARCOUNT=0.
        let mut packet: Vec<u8> = vec![0x12, 0x34, 0x01, 0x00, 0x00, 0x01];
        packet.extend_from_slice(&[0x00; 6]);
        for label in name.split('.') {
            packet.push(label.len() as u8);
            packet.extend(label.bytes());
        }
        // Root terminator, then QTYPE=1 and QCLASS=1.
        packet.extend_from_slice(&[0x00, 0x00, 0x01, 0x00, 0x01]);
        packet
    }

    #[test]
    fn extracts_simple_and_multi_label_names() {
        for name in ["ghcr.io", "abc123.blob.core.windows.net"] {
            assert_eq!(extract_qname(&query_packet(name)).as_deref(), Some(name));
        }
    }

    #[test]
    fn lowercases_names() {
        let extracted = extract_qname(&query_packet("GHCR.IO"));
        assert_eq!(extracted.as_deref(), Some("ghcr.io"));
    }

    #[test]
    fn rejects_compression_pointer_in_question() {
        let mut packet = query_packet("evil.net");
        // Overwrite the first label's length octet with a compression-pointer prefix.
        packet[12] = 0xC0;
        assert!(extract_qname(&packet).is_none());
    }

    #[test]
    fn rejects_truncated_and_empty() {
        // Shorter than the 12-byte header.
        assert!(extract_qname(&[0u8; 5]).is_none());

        // QDCOUNT forced to zero.
        let mut no_questions = query_packet("x.com");
        no_questions[4..6].copy_from_slice(&[0, 0]);
        assert!(extract_qname(&no_questions).is_none());

        // Header followed only by a length octet using a reserved label type.
        let mut reserved = vec![0x12, 0x34, 0x01, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0];
        reserved.push(0x40);
        assert!(extract_qname(&reserved).is_none());
    }

    #[test]
    fn first_nameserver_skips_localhost() {
        let with_loopback =
            "# comment\nnameserver 127.0.0.1\nnameserver 8.8.8.8\noptions edns0\n";
        assert_eq!(first_nameserver(with_loopback).as_deref(), Some("8.8.8.8"));
        assert!(first_nameserver("options edns0\n").is_none());
    }

    #[test]
    fn observing_resolv_keeps_original_as_fallback() {
        let rendered =
            observing_resolv("nameserver 8.8.8.8\nnameserver 1.1.1.1\noptions edns0\n");
        let nameservers: Vec<&str> = rendered
            .lines()
            .filter(|line| line.starts_with("nameserver"))
            .collect();
        // The relay must come first; the originals remain as fallbacks.
        assert_eq!(nameservers.first(), Some(&"nameserver 127.0.0.1"));
        for expected in ["nameserver 8.8.8.8", "nameserver 1.1.1.1"] {
            assert!(nameservers.contains(&expected));
        }
        assert!(rendered.contains("options edns0"));
    }

    #[test]
    fn record_format_matches_observe_reader() {
        let domains: BTreeSet<String> = ["ghcr.io", "crates.io"]
            .into_iter()
            .map(String::from)
            .collect();
        let text = render_record("release", &domains);
        let parsed = crate::observe::parse_record(&text).unwrap();
        assert_eq!(parsed.job, "release");
        for domain in ["ghcr.io", "crates.io"] {
            assert!(parsed.domains.contains(domain));
        }
    }
}