nntp_proxy/cache/
session.rs

1//! Caching session handler that wraps ClientSession with caching logic
2
3use anyhow::Result;
4use std::net::SocketAddr;
5use std::sync::Arc;
6use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
7use tokio::net::TcpStream;
8use tracing::{debug, info};
9
10use crate::auth::AuthHandler;
11use crate::cache::article::{ArticleCache, CachedArticle};
12use crate::command::{CommandHandler, NntpCommand};
13use crate::constants::buffer;
14use crate::protocol::{NntpResponse, ResponseCode};
15use crate::types::BytesTransferred;
16
17/// Caching session that wraps standard session with article cache
18pub struct CachingSession {
19    client_addr: SocketAddr,
20    cache: Arc<ArticleCache>,
21    auth_handler: Arc<AuthHandler>,
22    authenticated: std::sync::atomic::AtomicBool,
23}
24
25impl CachingSession {
26    /// Create a new caching session
27    pub fn new(
28        client_addr: SocketAddr,
29        cache: Arc<ArticleCache>,
30        auth_handler: Arc<AuthHandler>,
31    ) -> Self {
32        Self {
33            client_addr,
34            cache,
35            auth_handler,
36            authenticated: std::sync::atomic::AtomicBool::new(false),
37        }
38    }
39
40    /// Handle client connection with caching support
41    pub async fn handle_with_pooled_backend<T>(
42        &self,
43        mut client_stream: TcpStream,
44        backend_conn: T,
45    ) -> Result<(u64, u64)>
46    where
47        T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
48    {
49        use tokio::io::BufReader;
50
51        let (client_read, mut client_write) = client_stream.split();
52        let (backend_read, mut backend_write) = tokio::io::split(backend_conn);
53        let mut client_reader = BufReader::new(client_read);
54        let mut backend_reader = BufReader::with_capacity(buffer::POOL, backend_read);
55
56        let mut client_to_backend_bytes = BytesTransferred::zero();
57        let mut backend_to_client_bytes = BytesTransferred::zero();
58        let mut line = String::with_capacity(buffer::COMMAND);
59        // Pre-allocate with typical NNTP response line size (most are < 512 bytes)
60        // Reduces reallocations during line reading
61        // `first_line` is a Vec<u8> because it is used for reading raw bytes from the backend,
62        // which may not always be valid UTF-8, whereas `line` is a String for text-based client input.
63        let mut first_line = Vec::with_capacity(buffer::COMMAND);
64
65        // Auth state: username from AUTHINFO USER command
66        let mut auth_username: Option<String> = None;
67
68        debug!("Caching session for client {} starting", self.client_addr);
69
70        loop {
71            line.clear();
72
73            tokio::select! {
74                result = client_reader.read_line(&mut line) => {
75                    match result {
76                        Ok(0) => {
77                            debug!("Client {} disconnected", self.client_addr);
78                            break;
79                        }
80                        Ok(_n) => {
81                            debug!("Client {} sent command: {}", self.client_addr, line.trim());
82
83                            // PERFORMANCE OPTIMIZATION: Fast path after authentication
84                            if self.authenticated.load(std::sync::atomic::Ordering::Acquire) || !self.auth_handler.is_enabled() {
85                                // Already authenticated OR auth disabled - process normally (HOT PATH)
86
87                                // Check if this is a cacheable command (article by message-ID)
88                                if matches!(NntpCommand::classify(&line), NntpCommand::ArticleByMessageId) {
89                                    if let Some(message_id) = NntpResponse::extract_message_id(&line) {
90                                        // Check cache first
91                                        if let Some(cached) = self.cache.get(&message_id).await {
92                                            info!("Cache HIT for message-ID: {} (size: {} bytes)", message_id, cached.response.len());
93                                            client_write.write_all(&cached.response).await?;
94                                            backend_to_client_bytes.add(cached.response.len());
95                                            continue;
96                                        }
97                                        info!("Cache MISS for message-ID: {}", message_id);
98                                    } else {
99                                        debug!("No message-ID extracted from command: {}", line.trim());
100                                    }
101                                }
102
103                                // Forward to backend and cache response
104                                backend_write.write_all(line.as_bytes()).await?;
105                                client_to_backend_bytes.add(line.len());
106
107                                // Read first line of response using read_until for efficiency
108                                first_line.clear();
109                                backend_reader.read_until(b'\n', &mut first_line).await?;
110
111                                if first_line.is_empty() {
112                                    debug!("Backend {} closed connection", self.client_addr);
113                                    break;
114                                }
115
116                                // Transfer ownership using mem::take (leaves first_line as empty Vec)
117                                let mut response_buffer = std::mem::take(&mut first_line);
118
119                                // Check for backend disconnect (205 status)
120                                if NntpResponse::is_disconnect(&response_buffer) {
121                                    debug!("Backend {} sent disconnect: {}", self.client_addr, String::from_utf8_lossy(&response_buffer));
122                                    client_write.write_all(&response_buffer).await?;
123                                    backend_to_client_bytes.add(response_buffer.len());
124                                    break;
125                                }
126
127                                // Parse response code once and reuse it (avoid redundant parsing)
128                                let response_code = ResponseCode::parse(&response_buffer);
129                                let is_multiline = response_code.is_multiline();
130
131                                // Read multiline data if needed (as raw bytes)
132                                if is_multiline {
133                                    loop {
134                                        let start_pos = response_buffer.len();
135                                        backend_reader.read_until(b'\n', &mut response_buffer).await?;
136
137                                        if response_buffer.len() == start_pos {
138                                            break;
139                                        }
140
141                                        // Check for end marker by examining just the new data
142                                        let new_data = &response_buffer[start_pos..];
143                                        if new_data == b".\r\n" || new_data == b".\n" {
144                                            break;
145                                        }
146                                    }
147                                }
148
149                                // Cache if it was a cacheable command (article by message-ID)
150                                if matches!(NntpCommand::classify(&line), NntpCommand::ArticleByMessageId)
151                                    && let Some(message_id) = NntpResponse::extract_message_id(&line) {
152                                        // Only cache successful responses (2xx) - reuse already-parsed response_code
153                                        if response_code.is_success() {
154                                            info!("Caching response for message-ID: {}", message_id);
155                                            self.cache.insert(
156                                                message_id,
157                                                CachedArticle {
158                                                    response: Arc::new(response_buffer.clone()),
159                                                }
160                                            ).await;
161                                        }
162                                }
163
164                                // Forward response to client
165                                client_write.write_all(&response_buffer).await?;
166                                backend_to_client_bytes.add(response_buffer.len());
167                            } else {
168                                // Not yet authenticated - check for auth commands
169                                use crate::command::CommandAction;
170                                let action = CommandHandler::handle_command(&line);
171                                match action {
172                                    CommandAction::ForwardStateless => {
173                                        // Reject all non-auth commands before authentication
174                                        let response = b"480 Authentication required\r\n";
175                                        client_write.write_all(response).await?;
176                                        backend_to_client_bytes.add(response.len());
177                                    }
178                                    CommandAction::InterceptAuth(auth_action) => {
179                                        // Store username if this is AUTHINFO USER
180                                        if let crate::command::AuthAction::RequestPassword(ref username) = auth_action {
181                                            auth_username = Some(username.clone());
182                                        }
183
184                                        // Handle auth and validate
185                                        let (bytes, auth_success) = self
186                                            .auth_handler
187                                            .handle_auth_command(auth_action, &mut client_write, auth_username.as_deref())
188                                            .await?;
189                                        backend_to_client_bytes.add(bytes);
190
191                                        if auth_success {
192                                            self.authenticated.store(true, std::sync::atomic::Ordering::Release);
193                                        }
194                                    }
195                                    CommandAction::Reject(response) => {
196                                        // Send rejection response inline
197                                        client_write.write_all(response.as_bytes()).await?;
198                                        backend_to_client_bytes.add(response.len());
199                                    }
200                                }
201                            }
202                        }
203                        Err(e) => {
204                            debug!("Error reading from client {}: {}", self.client_addr, e);
205                            break;
206                        }
207                    }
208                }
209            }
210        }
211
212        Ok((
213            client_to_backend_bytes.as_u64(),
214            backend_to_client_bytes.as_u64(),
215        ))
216    }
217}