nntp_proxy/cache/session.rs
1//! Caching session handler that wraps ClientSession with caching logic
2
3use anyhow::Result;
4use std::net::SocketAddr;
5use std::sync::Arc;
6use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
7use tokio::net::TcpStream;
8use tracing::{debug, info};
9
10use crate::cache::article::{ArticleCache, CachedArticle};
11use crate::command::{CommandAction, CommandHandler};
12use crate::constants::buffer;
13use crate::constants::stateless_proxy::NNTP_COMMAND_NOT_SUPPORTED;
14
15/// Extract message-ID from command arguments using fast byte searching
16///
17/// Message-IDs are ASCII and must be in the format <...@...>
18/// See RFC 5536 Section 3.1.3: https://datatracker.ietf.org/doc/html/rfc5536#section-3.1.3
19fn extract_message_id(command: &str) -> Option<String> {
20 let trimmed = command.trim();
21 let bytes = trimmed.as_bytes();
22
23 // Find opening '<'
24 let start = memchr::memchr(b'<', bytes)?;
25
26 // Find closing '>' after the '<'
27 // Since end is relative to &bytes[start..], the actual position is start + end
28 let end = memchr::memchr(b'>', &bytes[start + 1..])?;
29 let msgid_end = start + end + 2; // +1 for the slice offset, +1 to include '>'
30
31 // Safety: Message-IDs are ASCII, so no need for is_char_boundary checks
32 // We already know msgid_end is valid since memchr found '>' at that position
33 Some(trimmed[start..msgid_end].to_string())
34}
35
36/// Check if command is cacheable (ARTICLE/BODY/HEAD/STAT with message-ID)
37fn is_cacheable_command(command: &str) -> bool {
38 let upper = command.trim().to_uppercase();
39 (upper.starts_with("ARTICLE ")
40 || upper.starts_with("BODY ")
41 || upper.starts_with("HEAD ")
42 || upper.starts_with("STAT "))
43 && extract_message_id(command).is_some()
44}
45
46/// Parse status code from binary data and determine if it's a multiline response
47fn parse_multiline_status(data: &[u8]) -> bool {
48 std::str::from_utf8(data)
49 .ok()
50 .and_then(parse_status_code)
51 .map(is_multiline_status)
52 .unwrap_or(false)
53}
54
55/// Caching session that wraps standard session with article cache
56pub struct CachingSession {
57 client_addr: SocketAddr,
58 cache: Arc<ArticleCache>,
59}
60
61impl CachingSession {
62 /// Create a new caching session
63 pub fn new(client_addr: SocketAddr, cache: Arc<ArticleCache>) -> Self {
64 Self { client_addr, cache }
65 }
66
67 /// Handle client connection with caching support
68 pub async fn handle_with_pooled_backend<T>(
69 &self,
70 mut client_stream: TcpStream,
71 backend_conn: T,
72 ) -> Result<(u64, u64)>
73 where
74 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
75 {
76 use tokio::io::BufReader;
77
78 let (client_read, mut client_write) = client_stream.split();
79 let (backend_read, mut backend_write) = tokio::io::split(backend_conn);
80 let mut client_reader = BufReader::new(client_read);
81 let mut backend_reader = BufReader::with_capacity(buffer::MEDIUM_BUFFER_SIZE, backend_read);
82
83 let mut client_to_backend_bytes = 0u64;
84 let mut backend_to_client_bytes = 0u64;
85 let mut line = String::with_capacity(buffer::COMMAND_SIZE);
86 // Pre-allocate with typical NNTP response line size (most are < 512 bytes)
87 // Reduces reallocations during line reading
88 let mut first_line = Vec::with_capacity(512);
89
90 debug!("Caching session for client {} starting", self.client_addr);
91
92 loop {
93 line.clear();
94
95 tokio::select! {
96 result = client_reader.read_line(&mut line) => {
97 match result {
98 Ok(0) => {
99 debug!("Client {} disconnected", self.client_addr);
100 break;
101 }
102 Ok(_n) => {
103 debug!("Client {} sent command: {}", self.client_addr, line.trim());
104
105 // Check if this is a cacheable command
106 if is_cacheable_command(&line) {
107 if let Some(message_id) = extract_message_id(&line) {
108 // Check cache first
109 if let Some(cached) = self.cache.get(&message_id).await {
110 info!("Cache HIT for message-ID: {} (size: {} bytes)", message_id, cached.response.len());
111 client_write.write_all(&cached.response).await?;
112 backend_to_client_bytes += cached.response.len() as u64;
113 continue;
114 } else {
115 info!("Cache MISS for message-ID: {}", message_id);
116 }
117 } else {
118 debug!("No message-ID extracted from command: {}", line.trim());
119 }
120 }
121
122 // Handle command using standard handler
123 match CommandHandler::handle_command(&line) {
124 CommandAction::InterceptAuth(auth_action) => {
125 use crate::auth::AuthHandler;
126 use crate::command::AuthAction;
127 let response = match auth_action {
128 AuthAction::RequestPassword => AuthHandler::user_response(),
129 AuthAction::AcceptAuth => AuthHandler::pass_response(),
130 };
131 client_write.write_all(response).await?;
132 backend_to_client_bytes += response.len() as u64;
133 }
134 CommandAction::Reject(_reason) => {
135 client_write.write_all(NNTP_COMMAND_NOT_SUPPORTED).await?;
136 backend_to_client_bytes += NNTP_COMMAND_NOT_SUPPORTED.len() as u64;
137 }
138 CommandAction::ForwardStateless => {
139 // Forward stateless commands to backend
140 backend_write.write_all(line.as_bytes()).await?;
141 client_to_backend_bytes += line.len() as u64;
142
143 // Read response using read_until for efficiency
144 first_line.clear();
145 backend_reader.read_until(b'\n', &mut first_line).await?;
146
147 if first_line.is_empty() {
148 debug!("Backend {} closed connection", self.client_addr);
149 break;
150 }
151
152 // Use mem::take to transfer ownership from first_line to response_buffer
153 // More idiomatic than swap - explicitly shows we're taking the value and leaving default
154 let mut response_buffer = std::mem::take(&mut first_line);
155
156 // Check for backend disconnect (205 status)
157 if response_buffer.len() >= 3 && &response_buffer[0..3] == b"205" {
158 debug!("Backend {} sent disconnect: {}", self.client_addr, String::from_utf8_lossy(&response_buffer));
159 client_write.write_all(&response_buffer).await?;
160 backend_to_client_bytes += response_buffer.len() as u64;
161 break;
162 }
163
164 let is_multiline = parse_multiline_status(&response_buffer);
165
166 if is_multiline {
167 loop {
168 let start_pos = response_buffer.len();
169 backend_reader.read_until(b'\n', &mut response_buffer).await?;
170
171 if response_buffer.len() == start_pos {
172 break;
173 }
174
175 // Check for end marker by examining just the new data
176 let new_data = &response_buffer[start_pos..];
177 if new_data == b".\r\n" || new_data == b".\n" {
178 break;
179 }
180 }
181 }
182
183 client_write.write_all(&response_buffer).await?;
184 backend_to_client_bytes += response_buffer.len() as u64;
185 }
186 CommandAction::ForwardHighThroughput => {
187 // Forward to backend
188 backend_write.write_all(line.as_bytes()).await?;
189 client_to_backend_bytes += line.len() as u64;
190
191 // Read first line of response using read_until for efficiency
192 first_line.clear();
193 backend_reader.read_until(b'\n', &mut first_line).await?;
194
195 if first_line.is_empty() {
196 debug!("Backend {} closed connection", self.client_addr);
197 break;
198 }
199
200 // Transfer ownership using mem::take (leaves first_line as empty Vec)
201 let mut response_buffer = std::mem::take(&mut first_line);
202
203 // Check for backend disconnect (205 status)
204 if response_buffer.len() >= 3 && &response_buffer[0..3] == b"205" {
205 debug!("Backend {} sent disconnect: {}", self.client_addr, String::from_utf8_lossy(&response_buffer));
206 client_write.write_all(&response_buffer).await?;
207 backend_to_client_bytes += response_buffer.len() as u64;
208 break;
209 }
210
211 // Check if this is a multiline response by parsing status code
212 let is_multiline = parse_multiline_status(&response_buffer);
213
214 // Read multiline data if needed (as raw bytes)
215 if is_multiline {
216 loop {
217 let start_pos = response_buffer.len();
218 backend_reader.read_until(b'\n', &mut response_buffer).await?;
219
220 if response_buffer.len() == start_pos {
221 break;
222 }
223
224 // Check for end marker by examining just the new data
225 let new_data = &response_buffer[start_pos..];
226 if new_data == b".\r\n" || new_data == b".\n" {
227 break;
228 }
229 }
230 }
231
232 // Cache if it was a cacheable command
233 if is_cacheable_command(&line)
234 && let Some(message_id) = extract_message_id(&line) {
235 // Only cache successful responses (2xx)
236 if !response_buffer.is_empty() && response_buffer[0] == b'2' {
237 info!("Caching response for message-ID: {}", message_id);
238 self.cache.insert(
239 message_id,
240 CachedArticle {
241 response: Arc::new(response_buffer.clone()),
242 }
243 ).await;
244 }
245 }
246
247 // Forward response to client
248 client_write.write_all(&response_buffer).await?;
249 backend_to_client_bytes += response_buffer.len() as u64;
250 }
251 }
252 }
253 Err(e) => {
254 debug!("Error reading from client {}: {}", self.client_addr, e);
255 break;
256 }
257 }
258 }
259 }
260 }
261
262 Ok((client_to_backend_bytes, backend_to_client_bytes))
263 }
264}
265
/// Parse the three-digit status code at the start of an NNTP response line.
///
/// Surrounding whitespace is ignored. Returns `None` unless the line starts
/// with exactly three ASCII digits; any text after them is not examined.
///
/// The digits are read from the byte representation, so a line whose first
/// three bytes end in the middle of a multi-byte character yields `None`
/// rather than panicking, and a sign prefix such as `+12` is not accepted.
fn parse_status_code(line: &str) -> Option<u16> {
    let digits = line.trim().as_bytes().get(..3)?;

    if !digits.iter().all(u8::is_ascii_digit) {
        return None;
    }

    // Three decimal digits are at most 999, which always fits in a u16.
    Some(
        digits
            .iter()
            .fold(0u16, |code, digit| code * 10 + u16::from(digit - b'0')),
    )
}
274
/// Check whether a status code announces a multiline response, i.e. one that
/// continues until a line containing only `.`.
///
/// The decision is made from the code alone. `211` is therefore reported as
/// single-line, although RFC 3977 makes it multiline in reply to `LISTGROUP`.
fn is_multiline_status(status_code: u16) -> bool {
    match status_code {
        // 111 is the single-line reply to DATE (RFC 3977, Section 7.1);
        // waiting for a terminator after it would stall the session.
        111 => false,
        // Remaining 1xx informational replies (100 help text, 101 capabilities)
        100..=199 => true,
        // Article and list data
        215 | 220 | 221 | 222 | 224 | 225 | 230 | 231 | 282 => true,
        _ => false,
    }
}