use std::collections::{HashMap, VecDeque};
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use parking_lot::Mutex;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::net::TcpListener;
use crate::config::ServerConfig;
/// Sliding-window limiter for `AUTHINFO PASS` attempts on the mock server.
///
/// Every attempt is recorded, including ones that end up rejected, so a client
/// that keeps retrying stays limited until its older attempts age out of
/// `window`. Clones share the same attempt history because the state lives
/// behind an [`Arc`].
#[derive(Clone, Debug)]
pub struct AuthRateLimit {
    /// Maximum number of attempts allowed inside `window`; the next one is rejected.
    pub max_attempts: u32,
    /// Length of the sliding window.
    pub window: Duration,
    /// Timestamps of attempts still inside the window, oldest first.
    state: Arc<Mutex<VecDeque<Instant>>>,
}

impl AuthRateLimit {
    /// Creates a limiter with an empty attempt history.
    pub fn new(max_attempts: u32, window: Duration) -> Self {
        Self {
            max_attempts,
            window,
            state: Arc::new(Mutex::new(VecDeque::new())),
        }
    }

    /// Records an attempt made now and reports whether it exceeds the limit.
    ///
    /// Attempts older than `window` are discarded first. Returns `true` when,
    /// counting this attempt, more than `max_attempts` attempts fall inside the
    /// window.
    fn check_and_record(&self) -> bool {
        let now = Instant::now();
        let mut attempts = self.state.lock();
        while attempts
            .front()
            .is_some_and(|&oldest| now.duration_since(oldest) > self.window)
        {
            attempts.pop_front();
        }
        attempts.push_back(now);
        // Compare in `usize`: widening `u32` is lossless, whereas the old
        // `len() as u32` could truncate the count.
        attempts.len() > self.max_attempts as usize
    }
}
/// Scenario description for [`MockNntpServer`]: canned data plus fault-injection knobs.
///
/// Start from [`MockConfig::default`] and override only the fields a test needs.
#[derive(Clone)]
pub struct MockConfig {
    /// Status code of the greeting line sent on connect.
    pub welcome_code: u16,
    /// Text that follows the code in the greeting line.
    pub welcome_message: String,
    /// When set, GROUP, XOVER, XHDR, XPAT, LIST, ARTICLE, BODY and STAT answer
    /// `480` until an `AUTHINFO PASS` is accepted.
    pub auth_required: bool,
    /// Expected `(username, password)` for AUTHINFO. When `None`, any
    /// `AUTHINFO PASS` is accepted unless `fail_auth` or `auth_rate_limit` rejects it.
    pub valid_credentials: Option<(String, String)>,
    /// Reject authentication: `AUTHINFO USER` gets `482`, `AUTHINFO PASS` gets `481`.
    pub fail_auth: bool,
    /// Reply `502` and close immediately instead of sending the greeting.
    pub service_unavailable: bool,
    /// Group name -> `(count, first, last)` as reported in the `211` reply to GROUP.
    pub groups: HashMap<String, (u64, u64, u64)>,
    /// Message-id (without angle brackets) -> raw article bytes. ARTICLE and
    /// BODY both send these bytes dot-stuffed; STAT only checks for presence.
    pub articles: HashMap<String, Vec<u8>>,
    /// Lines returned verbatim by XOVER (a group must be selected); empty -> `420`.
    pub xover_entries: Vec<String>,
    /// Lines returned verbatim by XHDR; empty -> `420`.
    pub xhdr_entries: Vec<String>,
    /// Lines returned verbatim by XPAT; empty -> `420`.
    pub xpat_entries: Vec<String>,
    /// Lines returned verbatim by LIST, regardless of any LIST keyword.
    pub list_active_entries: Vec<String>,
    /// Total output budget in bytes (greeting included). Once it is exhausted,
    /// the current write is truncated and the connection is dropped.
    pub silent_close_after_bytes: Option<usize>,
    /// Command name (case-insensitive). From the moment it is received, the
    /// server writes nothing more, including the reply to that command, but
    /// keeps reading input.
    pub hang_after_command: Option<String>,
    /// Drop the connection after this many commands have been handled.
    pub close_after_n_commands: Option<u32>,
    /// Sleep inserted before every individual socket write, so multi-line
    /// responses are delayed once per write.
    pub response_delay: Option<Duration>,
    /// Message-id (without angle brackets) -> status code sent as a bare
    /// `<code> <id>` line for ARTICLE, BODY and STAT. Takes precedence over `articles`.
    pub article_response_overrides: HashMap<String, u16>,
    /// Optional limiter consulted on every `AUTHINFO PASS`; over the limit -> `481`.
    pub auth_rate_limit: Option<AuthRateLimit>,
    /// Answer CAPABILITIES with `500`.
    pub capabilities_unsupported: bool,
    /// Advertise `MODE-READER`/`IHAVE` instead of the reader capability set.
    pub capabilities_mode_reader: bool,
}

impl Default for MockConfig {
    /// A healthy, open server: greeting `200 Mock NNTP Ready`, no
    /// authentication, no canned data and every fault-injection knob disabled.
    fn default() -> Self {
        MockConfig {
            // Greeting and availability.
            welcome_code: 200,
            welcome_message: String::from("Mock NNTP Ready"),
            service_unavailable: false,
            // Authentication.
            auth_required: false,
            valid_credentials: None,
            fail_auth: false,
            auth_rate_limit: None,
            // Capability advertisement.
            capabilities_unsupported: false,
            capabilities_mode_reader: false,
            // Canned data.
            groups: HashMap::new(),
            articles: HashMap::new(),
            article_response_overrides: HashMap::new(),
            xover_entries: Vec::new(),
            xhdr_entries: Vec::new(),
            xpat_entries: Vec::new(),
            list_active_entries: Vec::new(),
            // Fault injection.
            silent_close_after_bytes: None,
            hang_after_command: None,
            close_after_n_commands: None,
            response_delay: None,
        }
    }
}
/// Handle to a mock NNTP server listening on an ephemeral loopback port.
///
/// Dropping the handle drops the shutdown sender, which ends the accept loop.
/// Connections that were already accepted keep running in their own tasks.
pub struct MockNntpServer {
    /// Address the listener is bound to (`127.0.0.1` with an OS-chosen port).
    pub addr: SocketAddr,
    /// Held only so that dropping the server signals the accept loop to stop.
    _shutdown: tokio::sync::watch::Sender<bool>,
}

impl MockNntpServer {
    /// Binds a listener on `127.0.0.1:0` and spawns the accept loop.
    ///
    /// Each accepted socket is served by `handle_connection` on its own task,
    /// sharing one `Arc` of `config`.
    ///
    /// # Panics
    /// Panics if the listener cannot be bound or its local address cannot be read.
    pub async fn start(config: MockConfig) -> Self {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let shared_config = Arc::new(config);
        let (stop_tx, mut stop_rx) = tokio::sync::watch::channel(false);

        let accept_loop = async move {
            loop {
                tokio::select! {
                    // Completes when the sender is dropped (or a value is sent).
                    _ = stop_rx.changed() => break,
                    accepted = listener.accept() => {
                        // Accept errors are ignored; the loop simply keeps listening.
                        if let Ok((socket, _peer)) = accepted {
                            tokio::spawn(handle_connection(socket, Arc::clone(&shared_config)));
                        }
                    }
                }
            }
        };
        tokio::spawn(accept_loop);

        MockNntpServer {
            addr,
            _shutdown: stop_tx,
        }
    }

    /// Port component of [`Self::addr`].
    pub fn port(&self) -> u16 {
        self.addr.port()
    }
}
/// Builds a [`ServerConfig`] for a mock server on `127.0.0.1:port`.
///
/// `ssl` is off, no credentials are set, `connections` is 4, `pipelining` is 1,
/// and every remaining option is left at its zero / `false` / `None` value.
pub fn test_config(port: u16) -> ServerConfig {
    ServerConfig {
        // Identity and selection.
        id: "test-server".into(),
        name: "Test Server".into(),
        enabled: true,
        optional: false,
        priority: 0,
        // Endpoint.
        host: "127.0.0.1".into(),
        port,
        proxy_url: None,
        // Transport security.
        ssl: false,
        ssl_verify: false,
        trusted_fingerprint: None,
        // Credentials.
        username: None,
        password: None,
        // Connection tuning.
        connections: 4,
        pipelining: 1,
        compress: false,
        ramp_up_delay_ms: 0,
        recv_buffer_size: 0,
        retention: 0,
    }
}
pub fn test_config_with_auth(port: u16, user: &str, pass: &str) -> ServerConfig {
let mut config = test_config(port);
config.username = Some(user.to_string());
config.password = Some(pass.to_string());
config
}
/// Per-connection writer state.
///
/// Wraps the client socket and applies the fault-injection knobs from
/// [`MockConfig`] (response delay, output byte budget, hang) to every write.
struct ConnState<'a> {
    /// Buffered client socket: reads go through the buffer, writes go straight
    /// to the inner stream.
    stream: &'a mut BufReader<tokio::net::TcpStream>,
    /// Scenario configuration for this server.
    config: &'a MockConfig,
    /// Bytes successfully written so far; compared against `silent_close_after_bytes`.
    bytes_written: usize,
    /// Once set, writes and flushes are skipped while writes still report success.
    hung: bool,
}

impl<'a> ConnState<'a> {
    fn new(stream: &'a mut BufReader<tokio::net::TcpStream>, config: &'a MockConfig) -> Self {
        ConnState {
            bytes_written: 0,
            hung: false,
            stream,
            config,
        }
    }

    /// Writes `data` to the client, honouring the configured fault injection.
    ///
    /// Returns `false` when the caller should drop the connection. That happens
    /// when the write fails or the `silent_close_after_bytes` budget is used up;
    /// in the latter case as many bytes as still fit are written first.
    async fn write(&mut self, data: &[u8]) -> bool {
        if self.hung {
            // Pretend success so the caller keeps reading commands.
            return true;
        }
        if let Some(pause) = self.config.response_delay {
            tokio::time::sleep(pause).await;
        }
        if let Some(cap) = self.config.silent_close_after_bytes {
            let remaining = cap.saturating_sub(self.bytes_written);
            if remaining == 0 {
                return false;
            }
            if data.len() > remaining {
                // Emit only the bytes that fit in the budget, then signal a close.
                let socket = self.stream.get_mut();
                let _ = socket.write_all(&data[..remaining]).await;
                let _ = socket.flush().await;
                self.bytes_written += remaining;
                return false;
            }
        }
        match self.stream.get_mut().write_all(data).await {
            Ok(()) => {
                self.bytes_written += data.len();
                true
            }
            Err(_) => false,
        }
    }

    /// Flushes the socket, ignoring errors; does nothing once the connection is hung.
    async fn flush(&mut self) {
        if self.hung {
            return;
        }
        let _ = self.stream.get_mut().flush().await;
    }
}
/// Writes `$bytes` through `ConnState::write` and returns from the enclosing
/// function when the write reports that the connection should be dropped.
/// Only usable inside async functions that return `()`.
macro_rules! mwrite {
    ($conn:expr, $bytes:expr) => {
        match $conn.write($bytes).await {
            true => {}
            false => return,
        }
    };
}
/// Serves one client connection.
///
/// The session ends when the client sends QUIT or disconnects, or when a
/// fault-injection knob in [`MockConfig`] ends it.
///
/// Commands are read one line at a time and answered with canned NNTP
/// responses driven entirely by `config`. Every write goes through
/// `ConnState::write`. As soon as a write reports `false` (write error or
/// exhausted `silent_close_after_bytes` budget), the function returns and the
/// socket is dropped.
///
/// Authentication: when `valid_credentials` is `Some((user, pass))`, an
/// `AUTHINFO PASS` succeeds only if the preceding `AUTHINFO USER` named `user`
/// and the password equals `pass`. Otherwise the reply is `481`.
async fn handle_connection(stream: tokio::net::TcpStream, config: Arc<MockConfig>) {
    let mut stream = BufReader::new(stream);
    if config.service_unavailable {
        // Bypasses ConnState: no delay or byte-budget handling for this reply.
        let _ = stream
            .get_mut()
            .write_all(b"502 Service unavailable\r\n")
            .await;
        let _ = stream.get_mut().flush().await;
        return;
    }
    let mut conn = ConnState::new(&mut stream, &config);
    let welcome = format!("{} {}\r\n", config.welcome_code, config.welcome_message);
    mwrite!(conn, welcome.as_bytes());
    conn.flush().await;

    let mut authenticated = !config.auth_required;
    // Argument of the most recent `AUTHINFO USER`, checked on `AUTHINFO PASS`.
    let mut pending_user: Option<String> = None;
    let mut selected_group: Option<String> = None;
    let mut commands_processed: u32 = 0;
    let mut line = String::new();
    loop {
        line.clear();
        match conn.stream.read_line(&mut line).await {
            // EOF or a read error ends the session.
            Ok(0) | Err(_) => break,
            Ok(_) => {}
        }
        let trimmed = line.trim();
        if trimmed.is_empty() {
            continue;
        }
        // At most three parts, so the last argument (e.g. a password) may contain spaces.
        let parts: Vec<&str> = trimmed.splitn(3, ' ').collect();
        let cmd = parts[0].to_uppercase();
        // Once the configured command arrives, stop writing (including the reply
        // to this command) but keep consuming input.
        if let Some(ref hang_cmd) = config.hang_after_command
            && !conn.hung
            && cmd == hang_cmd.to_uppercase()
        {
            conn.hung = true;
        }
        match cmd.as_str() {
            "GROUP" | "XOVER" | "XHDR" | "XPAT" | "LIST" | "ARTICLE" | "BODY" | "STAT"
                if !authenticated =>
            {
                mwrite!(conn, b"480 Authentication required\r\n");
            }
            "QUIT" => {
                mwrite!(conn, b"205 Goodbye\r\n");
                conn.flush().await;
                break;
            }
            "CAPABILITIES" => {
                if config.capabilities_unsupported {
                    mwrite!(conn, b"500 Unknown command\r\n");
                } else {
                    mwrite!(conn, b"101 Capability list:\r\n");
                    mwrite!(conn, b"VERSION 2\r\n");
                    if config.capabilities_mode_reader {
                        mwrite!(conn, b"MODE-READER\r\n");
                        mwrite!(conn, b"IHAVE\r\n");
                    } else {
                        mwrite!(conn, b"READER\r\n");
                        mwrite!(conn, b"POST\r\n");
                        mwrite!(conn, b"HDR\r\n");
                        mwrite!(conn, b"OVER MSGID\r\n");
                        mwrite!(conn, b"LIST ACTIVE NEWSGROUPS OVERVIEW.FMT\r\n");
                    }
                    mwrite!(conn, b"IMPLEMENTATION nzb-nntp-testutil 1.0\r\n");
                    mwrite!(conn, b".\r\n");
                }
            }
            "MODE" => {
                let sub = parts.get(1).map(|s| s.to_uppercase()).unwrap_or_default();
                if sub == "READER" {
                    mwrite!(conn, b"200 Reader mode, posting allowed\r\n");
                } else {
                    let resp = format!("501 Unknown MODE subcommand: {}\r\n", sub);
                    mwrite!(conn, resp.as_bytes());
                }
            }
            "AUTHINFO" => {
                let sub = parts.get(1).map(|s| s.to_uppercase()).unwrap_or_default();
                match sub.as_str() {
                    "USER" => {
                        pending_user = parts.get(2).map(|user| user.to_string());
                        if config.fail_auth {
                            mwrite!(conn, b"482 Authentication rejected\r\n");
                        } else {
                            mwrite!(conn, b"381 Password required\r\n");
                        }
                    }
                    "PASS" => {
                        // Every PASS counts against the limiter, even if it would fail anyway.
                        let rate_limited = config
                            .auth_rate_limit
                            .as_ref()
                            .is_some_and(|limiter| limiter.check_and_record());
                        if rate_limited {
                            mwrite!(conn, b"481 Authentication rate-limited\r\n");
                        } else if config.fail_auth {
                            mwrite!(conn, b"481 Authentication failed\r\n");
                        } else if let Some((valid_user, valid_pass)) = &config.valid_credentials {
                            let given_pass = parts.get(2).copied().unwrap_or("");
                            let user_ok = pending_user.as_deref() == Some(valid_user.as_str());
                            if user_ok && given_pass == valid_pass.as_str() {
                                authenticated = true;
                                mwrite!(conn, b"281 Authentication accepted\r\n");
                            } else {
                                mwrite!(conn, b"481 Authentication failed\r\n");
                            }
                        } else {
                            authenticated = true;
                            mwrite!(conn, b"281 Authentication accepted\r\n");
                        }
                    }
                    _ => {
                        mwrite!(conn, b"500 Unknown AUTHINFO subcommand\r\n");
                    }
                }
            }
            "GROUP" => {
                let name = parts.get(1).copied().unwrap_or("");
                if let Some(&(count, first, last)) = config.groups.get(name) {
                    selected_group = Some(name.to_string());
                    let resp = format!("211 {} {} {} {}\r\n", count, first, last, name);
                    mwrite!(conn, resp.as_bytes());
                } else {
                    mwrite!(conn, b"411 No such group\r\n");
                }
            }
            "XOVER" => {
                if selected_group.is_none() {
                    mwrite!(conn, b"412 No newsgroup selected\r\n");
                } else if config.xover_entries.is_empty() {
                    mwrite!(conn, b"420 No articles in range\r\n");
                } else if !write_entry_list(
                    &mut conn,
                    b"224 Overview data follows\r\n",
                    &config.xover_entries,
                )
                .await
                {
                    return;
                }
            }
            "XHDR" => {
                if config.xhdr_entries.is_empty() {
                    mwrite!(conn, b"420 No articles in range\r\n");
                } else if !write_entry_list(
                    &mut conn,
                    b"221 Header data follows\r\n",
                    &config.xhdr_entries,
                )
                .await
                {
                    return;
                }
            }
            "XPAT" => {
                if config.xpat_entries.is_empty() {
                    mwrite!(conn, b"420 No articles matched\r\n");
                } else if !write_entry_list(
                    &mut conn,
                    b"221 Header data follows\r\n",
                    &config.xpat_entries,
                )
                .await
                {
                    return;
                }
            }
            "LIST" => {
                // An empty list still gets the status line and terminator.
                if !write_entry_list(
                    &mut conn,
                    b"215 List of newsgroups follows\r\n",
                    &config.list_active_entries,
                )
                .await
                {
                    return;
                }
            }
            "ARTICLE" => {
                let arg = parts.get(1).copied().unwrap_or("");
                if !write_article_response(&mut conn, &config, arg, 220, true).await {
                    return;
                }
            }
            "BODY" => {
                let arg = parts.get(1).copied().unwrap_or("");
                if !write_article_response(&mut conn, &config, arg, 222, true).await {
                    return;
                }
            }
            "STAT" => {
                let arg = parts.get(1).copied().unwrap_or("");
                if !write_article_response(&mut conn, &config, arg, 223, false).await {
                    return;
                }
            }
            _ => {
                let resp = format!("500 Unknown command: {}\r\n", cmd);
                mwrite!(conn, resp.as_bytes());
            }
        }
        conn.flush().await;
        commands_processed += 1;
        if let Some(limit) = config.close_after_n_commands
            && commands_processed >= limit
        {
            return;
        }
    }
}

/// Sends a multi-line response.
///
/// Writes `status`, then each entry verbatim followed by CRLF, then the `.`
/// terminator. Entries are not dot-stuffed. Returns `false` as soon as a write
/// fails, telling the caller to drop the connection.
async fn write_entry_list(conn: &mut ConnState<'_>, status: &[u8], entries: &[String]) -> bool {
    if !conn.write(status).await {
        return false;
    }
    for entry in entries {
        if !conn.write(entry.as_bytes()).await || !conn.write(b"\r\n").await {
            return false;
        }
    }
    conn.write(b".\r\n").await
}

/// Answers ARTICLE, BODY or STAT for the message-id in `arg` (angle brackets optional).
///
/// Precedence:
/// 1. An entry in `article_response_overrides` wins and is sent as a bare
///    status line.
/// 2. A stored article yields `ok_code`, followed by the dot-stuffed article
///    bytes when `send_body` is set.
/// 3. Otherwise the reply is `430`.
///
/// Returns `false` if the connection should be dropped.
async fn write_article_response(
    conn: &mut ConnState<'_>,
    config: &MockConfig,
    arg: &str,
    ok_code: u16,
    send_body: bool,
) -> bool {
    let mid = arg.trim_matches(|c| c == '<' || c == '>');
    if let Some(&code) = config.article_response_overrides.get(mid) {
        let resp = format!("{} <{}>\r\n", code, mid);
        return conn.write(resp.as_bytes()).await;
    }
    let Some(data) = config.articles.get(mid) else {
        let resp = format!("430 No article: <{}>\r\n", mid);
        return conn.write(resp.as_bytes()).await;
    };
    let header = format!("{} 0 <{}>\r\n", ok_code, mid);
    if !conn.write(header.as_bytes()).await {
        return false;
    }
    !send_body || write_multiline_body(conn, data).await
}
/// Streams `data` as a dot-stuffed NNTP multi-line block and terminates it with `.`.
///
/// Line handling:
/// - `data` is split on `\n`, one trailing `\r` per line is dropped, and every
///   line is re-terminated with CRLF.
/// - Lines starting with `.` get an extra leading `.`.
/// - A trailing `\n` does not produce an extra empty line.
/// - Empty `data` is sent as a single empty line.
///
/// Returns `false` if a write fails, in which case the caller should drop the
/// connection. On success the output is flushed and `true` is returned.
async fn write_multiline_body(conn: &mut ConnState<'_>, data: &[u8]) -> bool {
    if data.is_empty() {
        if !conn.write(b"\r\n").await {
            return false;
        }
    } else {
        // Dropping one final '\n' keeps `split` from yielding a spurious empty last line.
        let body = data.strip_suffix(b"\n").unwrap_or(data);
        for raw in body.split(|&byte| byte == b'\n') {
            let text = raw.strip_suffix(b"\r").unwrap_or(raw);
            if text.starts_with(b".") && !conn.write(b".").await {
                return false;
            }
            if !conn.write(text).await || !conn.write(b"\r\n").await {
                return false;
            }
        }
    }
    if !conn.write(b".\r\n").await {
        return false;
    }
    conn.flush().await;
    true
}