nntp_proxy/cache/
session.rs

1//! Caching session handler that wraps ClientSession with caching logic
2
3use anyhow::Result;
4use std::net::SocketAddr;
5use std::sync::Arc;
6use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
7use tokio::net::TcpStream;
8use tracing::{debug, info};
9
10use crate::cache::article::{ArticleCache, CachedArticle};
11use crate::command::{CommandAction, CommandHandler};
12use crate::constants::buffer;
13use crate::constants::stateless_proxy::NNTP_COMMAND_NOT_SUPPORTED;
14
15/// Extract message-ID from command arguments using fast byte searching
16/// 
17/// Message-IDs are ASCII and must be in the format <...@...>
18/// See RFC 5536 Section 3.1.3: https://datatracker.ietf.org/doc/html/rfc5536#section-3.1.3
19fn extract_message_id(command: &str) -> Option<String> {
20    let trimmed = command.trim();
21    let bytes = trimmed.as_bytes();
22
23    // Find opening '<'
24    let start = memchr::memchr(b'<', bytes)?;
25    
26    // Find closing '>' after the '<'
27    // Since end is relative to &bytes[start..], the actual position is start + end
28    let end = memchr::memchr(b'>', &bytes[start + 1..])?;
29    let msgid_end = start + end + 2; // +1 for the slice offset, +1 to include '>'
30    
31    // Safety: Message-IDs are ASCII, so no need for is_char_boundary checks
32    // We already know msgid_end is valid since memchr found '>' at that position
33    Some(trimmed[start..msgid_end].to_string())
34}
35
36/// Check if command is cacheable (ARTICLE/BODY/HEAD/STAT with message-ID)
37fn is_cacheable_command(command: &str) -> bool {
38    let upper = command.trim().to_uppercase();
39    (upper.starts_with("ARTICLE ")
40        || upper.starts_with("BODY ")
41        || upper.starts_with("HEAD ")
42        || upper.starts_with("STAT "))
43        && extract_message_id(command).is_some()
44}
45
46/// Parse status code from binary data and determine if it's a multiline response
47fn parse_multiline_status(data: &[u8]) -> bool {
48    std::str::from_utf8(data)
49        .ok()
50        .and_then(parse_status_code)
51        .map(is_multiline_status)
52        .unwrap_or(false)
53}
54
55/// Caching session that wraps standard session with article cache
56pub struct CachingSession {
57    client_addr: SocketAddr,
58    cache: Arc<ArticleCache>,
59}
60
61impl CachingSession {
62    /// Create a new caching session
63    pub fn new(client_addr: SocketAddr, cache: Arc<ArticleCache>) -> Self {
64        Self { client_addr, cache }
65    }
66
67    /// Handle client connection with caching support
68    pub async fn handle_with_pooled_backend<T>(
69        &self,
70        mut client_stream: TcpStream,
71        backend_conn: T,
72    ) -> Result<(u64, u64)>
73    where
74        T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
75    {
76        use tokio::io::BufReader;
77
78        let (client_read, mut client_write) = client_stream.split();
79        let (backend_read, mut backend_write) = tokio::io::split(backend_conn);
80        let mut client_reader = BufReader::new(client_read);
81        let mut backend_reader = BufReader::with_capacity(buffer::MEDIUM_BUFFER_SIZE, backend_read);
82
83        let mut client_to_backend_bytes = 0u64;
84        let mut backend_to_client_bytes = 0u64;
85        let mut line = String::with_capacity(buffer::COMMAND_SIZE);
86        // Pre-allocate with typical NNTP response line size (most are < 512 bytes)
87        // Reduces reallocations during line reading
88        let mut first_line = Vec::with_capacity(512);
89
90        debug!("Caching session for client {} starting", self.client_addr);
91
92        loop {
93            line.clear();
94
95            tokio::select! {
96                result = client_reader.read_line(&mut line) => {
97                    match result {
98                        Ok(0) => {
99                            debug!("Client {} disconnected", self.client_addr);
100                            break;
101                        }
102                        Ok(_n) => {
103                            debug!("Client {} sent command: {}", self.client_addr, line.trim());
104
105                            // Check if this is a cacheable command
106                            if is_cacheable_command(&line) {
107                                if let Some(message_id) = extract_message_id(&line) {
108                                    // Check cache first
109                                    if let Some(cached) = self.cache.get(&message_id).await {
110                                        info!("Cache HIT for message-ID: {} (size: {} bytes)", message_id, cached.response.len());
111                                        client_write.write_all(&cached.response).await?;
112                                        backend_to_client_bytes += cached.response.len() as u64;
113                                        continue;
114                                    } else {
115                                        info!("Cache MISS for message-ID: {}", message_id);
116                                    }
117                                } else {
118                                    debug!("No message-ID extracted from command: {}", line.trim());
119                                }
120                            }
121
122                            // Handle command using standard handler
123                            match CommandHandler::handle_command(&line) {
124                                CommandAction::InterceptAuth(auth_action) => {
125                                    use crate::auth::AuthHandler;
126                                    use crate::command::AuthAction;
127                                    let response = match auth_action {
128                                        AuthAction::RequestPassword => AuthHandler::user_response(),
129                                        AuthAction::AcceptAuth => AuthHandler::pass_response(),
130                                    };
131                                    client_write.write_all(response).await?;
132                                    backend_to_client_bytes += response.len() as u64;
133                                }
134                                CommandAction::Reject(_reason) => {
135                                    client_write.write_all(NNTP_COMMAND_NOT_SUPPORTED).await?;
136                                    backend_to_client_bytes += NNTP_COMMAND_NOT_SUPPORTED.len() as u64;
137                                }
138                                CommandAction::ForwardStateless => {
139                                    // Forward stateless commands to backend
140                                    backend_write.write_all(line.as_bytes()).await?;
141                                    client_to_backend_bytes += line.len() as u64;
142
143                                    // Read response using read_until for efficiency
144                                    first_line.clear();
145                                    backend_reader.read_until(b'\n', &mut first_line).await?;
146
147                                    if first_line.is_empty() {
148                                        debug!("Backend {} closed connection", self.client_addr);
149                                        break;
150                                    }
151
152                                    // Use mem::take to transfer ownership from first_line to response_buffer
153                                    // More idiomatic than swap - explicitly shows we're taking the value and leaving default
154                                    let mut response_buffer = std::mem::take(&mut first_line);
155
156                                    // Check for backend disconnect (205 status)
157                                    if response_buffer.len() >= 3 && &response_buffer[0..3] == b"205" {
158                                        debug!("Backend {} sent disconnect: {}", self.client_addr, String::from_utf8_lossy(&response_buffer));
159                                        client_write.write_all(&response_buffer).await?;
160                                        backend_to_client_bytes += response_buffer.len() as u64;
161                                        break;
162                                    }
163
164                                    let is_multiline = parse_multiline_status(&response_buffer);
165
166                                    if is_multiline {
167                                        loop {
168                                            let start_pos = response_buffer.len();
169                                            backend_reader.read_until(b'\n', &mut response_buffer).await?;
170
171                                            if response_buffer.len() == start_pos {
172                                                break;
173                                            }
174
175                                            // Check for end marker by examining just the new data
176                                            let new_data = &response_buffer[start_pos..];
177                                            if new_data == b".\r\n" || new_data == b".\n" {
178                                                break;
179                                            }
180                                        }
181                                    }
182
183                                    client_write.write_all(&response_buffer).await?;
184                                    backend_to_client_bytes += response_buffer.len() as u64;
185                                }
186                                CommandAction::ForwardHighThroughput => {
187                                    // Forward to backend
188                                    backend_write.write_all(line.as_bytes()).await?;
189                                    client_to_backend_bytes += line.len() as u64;
190
191                                    // Read first line of response using read_until for efficiency
192                                    first_line.clear();
193                                    backend_reader.read_until(b'\n', &mut first_line).await?;
194
195                                    if first_line.is_empty() {
196                                        debug!("Backend {} closed connection", self.client_addr);
197                                        break;
198                                    }
199
200                                    // Transfer ownership using mem::take (leaves first_line as empty Vec)
201                                    let mut response_buffer = std::mem::take(&mut first_line);
202
203                                    // Check for backend disconnect (205 status)
204                                    if response_buffer.len() >= 3 && &response_buffer[0..3] == b"205" {
205                                        debug!("Backend {} sent disconnect: {}", self.client_addr, String::from_utf8_lossy(&response_buffer));
206                                        client_write.write_all(&response_buffer).await?;
207                                        backend_to_client_bytes += response_buffer.len() as u64;
208                                        break;
209                                    }
210
211                                    // Check if this is a multiline response by parsing status code
212                                    let is_multiline = parse_multiline_status(&response_buffer);
213
214                                    // Read multiline data if needed (as raw bytes)
215                                    if is_multiline {
216                                        loop {
217                                            let start_pos = response_buffer.len();
218                                            backend_reader.read_until(b'\n', &mut response_buffer).await?;
219
220                                            if response_buffer.len() == start_pos {
221                                                break;
222                                            }
223
224                                            // Check for end marker by examining just the new data
225                                            let new_data = &response_buffer[start_pos..];
226                                            if new_data == b".\r\n" || new_data == b".\n" {
227                                                break;
228                                            }
229                                        }
230                                    }
231
232                                    // Cache if it was a cacheable command
233                                    if is_cacheable_command(&line)
234                                        && let Some(message_id) = extract_message_id(&line) {
235                                            // Only cache successful responses (2xx)
236                                            if !response_buffer.is_empty() && response_buffer[0] == b'2' {
237                                                info!("Caching response for message-ID: {}", message_id);
238                                                self.cache.insert(
239                                                    message_id,
240                                                    CachedArticle {
241                                                        response: Arc::new(response_buffer.clone()),
242                                                    }
243                                                ).await;
244                                            }
245                                    }
246
247                                    // Forward response to client
248                                    client_write.write_all(&response_buffer).await?;
249                                    backend_to_client_bytes += response_buffer.len() as u64;
250                                }
251                            }
252                        }
253                        Err(e) => {
254                            debug!("Error reading from client {}: {}", self.client_addr, e);
255                            break;
256                        }
257                    }
258                }
259            }
260        }
261
262        Ok((client_to_backend_bytes, backend_to_client_bytes))
263    }
264}
265
/// Parse status code from NNTP response line
///
/// Surrounding whitespace is ignored. Returns `Some(code)` only when the first
/// three bytes of the trimmed line are ASCII digits; otherwise `None`.
fn parse_status_code(line: &str) -> Option<u16> {
    // `get` returns `None` (instead of panicking like `[0..3]`) when the line
    // is shorter than three bytes or byte 3 falls inside a multi-byte char.
    let code = line.trim().get(..3)?;

    // `str::parse::<u16>` accepts a leading '+', which is not a valid status
    // code, so require three plain digits first.
    if code.bytes().all(|byte| byte.is_ascii_digit()) {
        code.parse().ok()
    } else {
        None
    }
}
274
/// Check if a status code indicates a multiline response
///
/// Multiline codes recognised here:
/// - `100` (help text) and `101` (capability list) — the only 1xx codes that
///   carry a multiline block; `111` (DATE) is a single line, so treating the
///   whole 1xx range as multiline would leave the proxy waiting for a `.`
///   terminator that never arrives.
/// - `215`, `220`, `221`, `222`, `224`, `225`, `230`, `231`, `282` — list and
///   article data.
///
/// `211` is deliberately not listed: it is single-line for GROUP and multiline
/// for LISTGROUP, which cannot be told apart from the code alone.
fn is_multiline_status(status_code: u16) -> bool {
    matches!(
        status_code,
        100 | 101 | 215 | 220 | 221 | 222 | 224 | 225 | 230 | 231 | 282
    )
}
283}