use std::collections::BTreeSet;
use std::io;
use std::net::UdpSocket;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
/// Extract the first question name (QNAME) from a raw DNS query, lowercased
/// and dot-joined (e.g. `"ghcr.io"`).
///
/// Returns `None` when:
/// - the packet is shorter than the 12-byte DNS header, or its QDCOUNT is 0;
/// - any label uses a non-plain label type (top two length bits set), which
///   includes compression pointers;
/// - a label runs past the end of the packet, or no terminating zero-length
///   label is found;
/// - a label contains a byte other than an ASCII letter, digit, `-` or `_`;
/// - the name is the root (no labels at all).
pub fn extract_qname(packet: &[u8]) -> Option<String> {
    // The fixed header is 12 bytes; QDCOUNT is the big-endian u16 at offset 4.
    let header = packet.get(..12)?;
    if header[4] == 0 && header[5] == 0 {
        return None;
    }

    let mut cursor = 12;
    let mut name = String::new();
    loop {
        let length = usize::from(*packet.get(cursor)?);
        if length == 0 {
            break;
        }
        // 0b11xx_xxxx is a compression pointer; 0b01/0b10 are reserved/extended types.
        if length & 0xC0 != 0 {
            return None;
        }
        let start = cursor + 1;
        let stop = start.checked_add(length)?;
        let label = packet.get(start..stop)?;
        let acceptable = label
            .iter()
            .all(|&byte| byte.is_ascii_alphanumeric() || byte == b'-' || byte == b'_');
        if !acceptable {
            return None;
        }
        // Labels are never empty here, so a non-empty `name` means a previous label exists.
        if !name.is_empty() {
            name.push('.');
        }
        name.extend(label.iter().map(|&byte| char::from(byte.to_ascii_lowercase())));
        cursor = stop;
    }

    if name.is_empty() {
        None
    } else {
        Some(name)
    }
}
/// Return the address of the first `nameserver` entry in `resolv` (the text of
/// a resolv.conf) other than `127.0.0.1`, or `None` if there is none.
///
/// A line counts as a nameserver entry when its first whitespace-separated
/// token is exactly `nameserver`. The address is the following token, and
/// spaces and tabs are both accepted as separators. Anything after the address
/// token is ignored, and an inline `#` or `;` ends the address. An entry with
/// no address is skipped. Comment lines (`#...`/`;...`) never match because
/// their first token is not `nameserver`.
///
/// `127.0.0.1` is skipped, presumably because that is where the observing
/// relay itself listens, so it must not be chosen as its own upstream.
///
/// The address is returned as written, without a port.
pub fn first_nameserver(resolv: &str) -> Option<String> {
    resolv.lines().find_map(|line| {
        let mut tokens = line.split_whitespace();
        if tokens.next()? != "nameserver" {
            return None;
        }
        // Treat an inline `#`/`;` as the start of a comment.
        let addr = tokens.next()?.split(['#', ';']).next()?;
        (!addr.is_empty() && addr != "127.0.0.1").then(|| addr.to_string())
    })
}
/// Build the resolv.conf contents to install while observing a job.
///
/// The output starts with a marker comment and `nameserver 127.0.0.1`. After
/// that come the lines of `original`, in order, minus two kinds:
/// - comment lines (first non-blank character `#`), and
/// - existing `nameserver 127.0.0.1` entries, so the loopback server appears
///   only once. These are recognised by their whitespace-separated tokens, so
///   tab- or multi-space-separated forms are caught too.
///
/// The remaining original entries (nameservers, `options`, `search`, ...) follow
/// the loopback entry, so the original nameservers act as fallbacks.
///
/// NOTE: glibc honours at most three `nameserver` lines (MAXNS). If `original`
/// already lists three upstreams, the last one is effectively ignored once
/// 127.0.0.1 is prepended.
pub fn observing_resolv(original: &str) -> String {
    let mut out = String::from("# just-shield observe — 127.0.0.1 우선, 원본은 폴백.\n");
    out.push_str("nameserver 127.0.0.1\n");
    for line in original.lines() {
        if line.trim().starts_with('#') {
            continue;
        }
        // Compare tokens rather than the raw text so `nameserver\t127.0.0.1` is caught too.
        let mut tokens = line.split_whitespace();
        let is_loopback_entry =
            tokens.next() == Some("nameserver") && tokens.next() == Some("127.0.0.1");
        if is_loopback_entry {
            continue;
        }
        out.push_str(line);
        out.push('\n');
    }
    out
}
/// Render the observe record for `job`.
///
/// The record consists of a marker comment line, then `job <name>`, then one
/// line per domain. Domains come out in `BTreeSet` (sorted) order, and every
/// line, including the last, ends with `\n`.
///
/// NOTE(review): `job` is embedded verbatim, so a job name containing a newline
/// would split the header. Confirm job names are single-line.
pub fn render_record(job: &str, domains: &BTreeSet<String>) -> String {
    let header = format!("# just-shield observe 기록 — 잡 '{job}'이 조회한 도메인.\njob {job}\n");
    domains.iter().fold(header, |mut record, domain| {
        record.push_str(domain);
        record.push('\n');
        record
    })
}
/// Settings for [`serve`].
#[derive(Debug)]
pub struct RelayConfig {
    /// Local address the relay's UDP socket binds to (passed to `UdpSocket::bind`).
    pub listen: String,
    /// Upstream resolver every query is forwarded to. It is resolved through
    /// `ToSocketAddrs`, so it must include a port, e.g. `"8.8.8.8:53"`.
    pub upstream: String,
    /// Job name written into the observe record.
    pub job: String,
    /// File rewritten with the record of observed domains.
    pub record_path: std::path::PathBuf,
    /// Set to `true` to make `serve` return. It is checked before each receive.
    pub stop: Arc<AtomicBool>,
}
/// Run the observing DNS relay until `config.stop` becomes true.
///
/// The relay receives UDP datagrams on `config.listen`. For each datagram:
/// - its question name (via `extract_qname`) is added to the set of seen
///   domains;
/// - whenever that set gains a new entry, `config.record_path` is rewritten
///   with `render_record`;
/// - the datagram is handed to `forward` for `config.upstream`, which relays
///   the reply to the original sender.
///
/// An empty record is written once before the loop starts. Record-write,
/// receive and forwarding errors are all ignored and the loop continues. The
/// 200 ms read timeout bounds how long an idle relay takes to notice `stop`.
/// Queries are handled one at a time, so a slow `forward` delays the next one.
///
/// # Errors
/// Fails only if the listening socket cannot be bound or its read timeout
/// cannot be set.
pub fn serve(config: &RelayConfig) -> io::Result<()> {
    let socket = UdpSocket::bind(&config.listen)?;
    socket.set_read_timeout(Some(Duration::from_millis(200)))?;

    let domains = Arc::new(Mutex::new(BTreeSet::<String>::new()));
    // Best-effort persistence: a failed write must not stop DNS relaying.
    let persist = |set: &BTreeSet<String>| {
        let _ = std::fs::write(&config.record_path, render_record(&config.job, set));
    };
    persist(&domains.lock().unwrap());

    let mut packet = [0u8; 1500];
    while !config.stop.load(Ordering::Relaxed) {
        // Read timeouts and any other receive error are treated alike: poll again.
        let Ok((len, client)) = socket.recv_from(&mut packet) else {
            continue;
        };
        let query = &packet[..len];
        if let Some(name) = extract_qname(query)
            && let Ok(mut set) = domains.lock()
            && set.insert(name)
        {
            persist(&set);
        }
        let _ = forward(&socket, query, &config.upstream, client);
    }
    Ok(())
}
/// Forward one DNS query to `upstream` and relay the single reply datagram
/// back to `reply_to` through `listen_sock`.
///
/// Each call uses a fresh ephemeral socket, with these properties:
/// - It is bound to the unspecified address of the same family as the
///   resolved upstream, so IPv6 upstreams work. An IPv4-bound socket cannot
///   send to an IPv6 address.
/// - It is `connect`ed to the upstream, so the kernel only delivers datagrams
///   whose source is that address. An unconnected socket would accept a forged
///   "reply" from any host that hits the ephemeral port.
/// - Its receive buffer is as large as any UDP payload. A smaller buffer
///   silently truncates large (e.g. EDNS0) responses into malformed DNS
///   messages.
///
/// If `upstream` resolves to several addresses, only the first is used.
///
/// # Errors
/// Returns any error from resolving `upstream`, or from binding, connecting,
/// sending or receiving. This includes a timeout when no reply arrives within
/// 3 seconds, and an error if the reply cannot be sent to `reply_to`.
fn forward(
    listen_sock: &UdpSocket,
    query: &[u8],
    upstream: &str,
    reply_to: std::net::SocketAddr,
) -> io::Result<()> {
    // Largest payload a UDP datagram can carry (16-bit length field).
    const MAX_UDP_PAYLOAD: usize = 65_535;

    let upstream_addr = std::net::ToSocketAddrs::to_socket_addrs(upstream)?
        .next()
        .ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "upstream resolved to no address",
            )
        })?;
    let local: std::net::SocketAddr = if upstream_addr.is_ipv4() {
        (std::net::Ipv4Addr::UNSPECIFIED, 0).into()
    } else {
        (std::net::Ipv6Addr::UNSPECIFIED, 0).into()
    };

    let up = UdpSocket::bind(local)?;
    up.set_read_timeout(Some(Duration::from_secs(3)))?;
    up.connect(upstream_addr)?;
    up.send(query)?;

    let mut resp = vec![0u8; MAX_UDP_PAYLOAD];
    let n = up.recv(&mut resp)?;
    listen_sock.send_to(&resp[..n], reply_to)?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// DNS header used by the test packets: ID 0x1234, RD set, QDCOUNT 1,
    /// all other counts 0.
    const HEADER: [u8; 12] = [
        0x12, 0x34, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    ];

    /// Build a single-question query for `name` (QTYPE 1, QCLASS 1).
    fn query_packet(name: &str) -> Vec<u8> {
        let mut p = HEADER.to_vec();
        for label in name.split('.') {
            p.push(label.len() as u8);
            p.extend_from_slice(label.as_bytes());
        }
        p.push(0x00); // zero-length root label terminates the QNAME
        p.extend_from_slice(&[0x00, 0x01, 0x00, 0x01]); // QTYPE, QCLASS
        p
    }

    #[test]
    fn extracts_simple_and_multi_label_names() {
        assert_eq!(
            extract_qname(&query_packet("ghcr.io")).as_deref(),
            Some("ghcr.io")
        );
        assert_eq!(
            extract_qname(&query_packet("abc123.blob.core.windows.net")).as_deref(),
            Some("abc123.blob.core.windows.net")
        );
    }

    #[test]
    fn lowercases_names() {
        assert_eq!(
            extract_qname(&query_packet("GHCR.IO")).as_deref(),
            Some("ghcr.io")
        );
    }

    #[test]
    fn rejects_compression_pointer_in_question() {
        let mut p = query_packet("evil.net");
        p[12] = 0xC0;
        assert_eq!(extract_qname(&p), None);
    }

    #[test]
    fn rejects_truncated_and_empty() {
        // Shorter than the 12-byte header.
        assert_eq!(extract_qname(&[0u8; 5]), None);

        // QDCOUNT = 0.
        let mut p = query_packet("x.com");
        p[4] = 0;
        p[5] = 0;
        assert_eq!(extract_qname(&p), None);

        // A label length (63) that runs past the end of the packet.
        let mut truncated = HEADER.to_vec();
        truncated.push(0x3F);
        assert_eq!(extract_qname(&truncated), None);

        // Labels with no terminating root label.
        let mut unterminated = HEADER.to_vec();
        unterminated.extend_from_slice(&[3, b'c', b'o', b'm']);
        assert_eq!(extract_qname(&unterminated), None);
    }

    #[test]
    fn rejects_reserved_label_types() {
        // Length bytes with 0b01 / 0b10 in the top two bits are not plain labels.
        for kind in [0x40u8, 0x80] {
            let mut p = HEADER.to_vec();
            p.push(kind);
            assert_eq!(extract_qname(&p), None);
        }
    }

    #[test]
    fn rejects_invalid_label_bytes() {
        assert_eq!(extract_qname(&query_packet("bad name.com")), None);
        assert_eq!(extract_qname(&query_packet("café.com")), None);
    }

    #[test]
    fn rejects_root_only_name() {
        let mut p = HEADER.to_vec();
        p.extend_from_slice(&[0x00, 0x00, 0x01, 0x00, 0x01]);
        assert_eq!(extract_qname(&p), None);
    }

    #[test]
    fn first_nameserver_skips_localhost() {
        let resolv = "# comment\nnameserver 127.0.0.1\nnameserver 8.8.8.8\noptions edns0\n";
        assert_eq!(first_nameserver(resolv).as_deref(), Some("8.8.8.8"));
        assert_eq!(first_nameserver("options edns0\n"), None);
    }

    #[test]
    fn observing_resolv_keeps_original_as_fallback() {
        let original = "nameserver 8.8.8.8\nnameserver 1.1.1.1\noptions edns0\n";
        let out = observing_resolv(original);
        let lines: Vec<&str> = out
            .lines()
            .filter(|l| l.starts_with("nameserver"))
            .collect();
        assert_eq!(lines[0], "nameserver 127.0.0.1");
        assert!(lines.contains(&"nameserver 8.8.8.8"));
        assert!(lines.contains(&"nameserver 1.1.1.1"));
        assert!(out.contains("options edns0"));
    }

    #[test]
    fn record_format_matches_observe_reader() {
        let mut set = BTreeSet::new();
        set.insert("ghcr.io".to_string());
        set.insert("crates.io".to_string());
        let text = render_record("release", &set);
        let parsed = crate::observe::parse_record(&text).unwrap();
        assert_eq!(parsed.job, "release");
        assert!(parsed.domains.contains("ghcr.io"));
        assert!(parsed.domains.contains("crates.io"));
    }
}