Skip to main content

mockforge_tcp/
server.rs

1//! TCP server implementation
2
3use crate::{TcpConfig, TcpSpecRegistry};
4use mockforge_core::Result;
5use std::net::SocketAddr;
6use std::sync::Arc;
7use tokio::io::{AsyncReadExt, AsyncWriteExt};
8use tokio::net::{TcpListener, TcpStream};
9use tokio::time::{sleep, timeout, Duration};
10use tracing::{debug, error, info, warn};
11
12/// TCP server
13pub struct TcpServer {
14    config: TcpConfig,
15    spec_registry: Arc<TcpSpecRegistry>,
16}
17
18impl TcpServer {
19    /// Create a new TCP server
20    pub fn new(config: TcpConfig, spec_registry: Arc<TcpSpecRegistry>) -> Result<Self> {
21        Ok(Self {
22            config,
23            spec_registry,
24        })
25    }
26
27    /// Start the TCP server
28    pub async fn start(&self) -> Result<()> {
29        let addr = format!("{}:{}", self.config.host, self.config.port);
30        let listener = TcpListener::bind(&addr).await?;
31
32        info!("TCP server listening on {}", addr);
33
34        loop {
35            match listener.accept().await {
36                Ok((stream, peer_addr)) => {
37                    debug!("New TCP connection from {}", peer_addr);
38
39                    let registry = self.spec_registry.clone();
40                    let config = self.config.clone();
41
42                    tokio::spawn(async move {
43                        if let Err(e) =
44                            handle_tcp_connection(stream, peer_addr, registry, config).await
45                        {
46                            error!("TCP connection error from {}: {}", peer_addr, e);
47                        }
48                    });
49                }
50                Err(e) => {
51                    error!("Failed to accept TCP connection: {}", e);
52                }
53            }
54        }
55    }
56}
57
58/// Handle a single TCP connection
59async fn handle_tcp_connection(
60    mut stream: TcpStream,
61    peer_addr: SocketAddr,
62    registry: Arc<TcpSpecRegistry>,
63    config: TcpConfig,
64) -> Result<()> {
65    debug!("Handling TCP connection from {}", peer_addr);
66
67    let mut buffer = vec![0u8; config.read_buffer_size];
68    let mut accumulated_data = Vec::new();
69
70    loop {
71        // Set read timeout
72        let read_timeout = Duration::from_secs(config.timeout_secs);
73
74        match timeout(read_timeout, stream.read(&mut buffer)).await {
75            Ok(Ok(0)) => {
76                // Connection closed by client
77                debug!("TCP connection closed by client: {}", peer_addr);
78                break;
79            }
80            Ok(Ok(n)) => {
81                let received_data = &buffer[..n];
82                accumulated_data.extend_from_slice(received_data);
83
84                debug!("Received {} bytes from {}", n, peer_addr);
85
86                // Try to find matching fixture
87                let response_data =
88                    if let Some(fixture) = registry.find_matching_fixture(&accumulated_data) {
89                        debug!("Found matching fixture: {}", fixture.identifier);
90
91                        // Apply delay if configured
92                        if fixture.response.delay_ms > 0 {
93                            sleep(Duration::from_millis(fixture.response.delay_ms)).await;
94                        }
95
96                        // Generate response data
97                        generate_response_data(&fixture.response)?
98                    } else if config.echo_mode {
99                        // Echo mode: echo back received data
100                        debug!("No fixture match, echoing data back");
101                        accumulated_data.clone()
102                    } else {
103                        // No match and echo mode disabled - close connection
104                        warn!("No fixture match and echo mode disabled, closing connection");
105                        break;
106                    };
107
108                // Send response
109                if !response_data.is_empty() {
110                    if let Err(e) = stream.write_all(&response_data).await {
111                        error!("Failed to write response to {}: {}", peer_addr, e);
112                        break;
113                    }
114
115                    if let Err(e) = stream.flush().await {
116                        error!("Failed to flush response to {}: {}", peer_addr, e);
117                        break;
118                    }
119                }
120
121                // Check if we should close after response
122                if let Some(fixture) = registry.find_matching_fixture(&accumulated_data) {
123                    if fixture.response.close_after_response {
124                        debug!("Closing connection after response as configured");
125                        break;
126                    }
127
128                    if !fixture.response.keep_alive {
129                        debug!("Closing connection (keep_alive=false)");
130                        break;
131                    }
132                } else if !config.echo_mode {
133                    // Close if echo mode disabled and no fixture matched
134                    break;
135                }
136
137                // If delimiter is configured, check if we've received complete message
138                if let Some(ref delimiter) = config.delimiter {
139                    if accumulated_data.ends_with(delimiter) {
140                        debug!("Received complete message (matched delimiter), resetting buffer");
141                        accumulated_data.clear();
142                    }
143                } else {
144                    // Stream mode: reset buffer for next read
145                    accumulated_data.clear();
146                }
147            }
148            Ok(Err(e)) => {
149                error!("TCP read error from {}: {}", peer_addr, e);
150                break;
151            }
152            Err(_) => {
153                warn!("TCP read timeout from {}", peer_addr);
154                break;
155            }
156        }
157    }
158
159    debug!("TCP connection handler finished for {}", peer_addr);
160    Ok(())
161}
162
163/// Generate response data from fixture configuration
164fn generate_response_data(response: &crate::fixtures::TcpResponse) -> Result<Vec<u8>> {
165    match response.encoding.as_str() {
166        "hex" => hex::decode(&response.data)
167            .map_err(|e| mockforge_core::Error::generic(format!("Invalid hex data: {}", e))),
168        "base64" => {
169            use base64::Engine;
170            base64::engine::general_purpose::STANDARD
171                .decode(&response.data)
172                .map_err(|e| mockforge_core::Error::generic(format!("Invalid base64 data: {e}")))
173        }
174        "text" => Ok(response.data.as_bytes().to_vec()),
175        "file" => {
176            let file_path = response.file_path.as_ref().ok_or_else(|| {
177                mockforge_core::Error::generic("file_path not specified for file encoding")
178            })?;
179
180            std::fs::read(file_path).map_err(|e| {
181                mockforge_core::Error::generic(format!(
182                    "Failed to read file {:?}: {}",
183                    file_path, e
184                ))
185            })
186        }
187        _ => Err(mockforge_core::Error::generic(format!(
188            "Unknown encoding: {}. Supported: hex, base64, text, file",
189            response.encoding
190        ))),
191    }
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fixtures::TcpResponse;
    use std::io::Write;
    use std::path::PathBuf;

    /// Build a keep-alive response with no delay and no file path.
    fn create_test_response(data: &str, encoding: &str) -> TcpResponse {
        TcpResponse {
            data: data.to_string(),
            encoding: encoding.to_string(),
            file_path: None,
            delay_ms: 0,
            close_after_response: false,
            keep_alive: true,
        }
    }

    /// Build a `"file"`-encoded response that points at `path`.
    fn file_response(path: &std::path::Path) -> TcpResponse {
        let mut response = create_test_response("", "file");
        response.file_path = Some(path.to_path_buf());
        response
    }

    /// Create a temporary file holding `contents`; it is removed when dropped.
    fn temp_file_with(contents: &[u8]) -> tempfile::NamedTempFile {
        let mut file = tempfile::NamedTempFile::new().unwrap();
        file.write_all(contents).unwrap();
        file.flush().unwrap();
        file
    }

    /// Run `generate_response_data` on an inline (non-file) response.
    fn decode(data: &str, encoding: &str) -> Result<Vec<u8>> {
        generate_response_data(&create_test_response(data, encoding))
    }

    /// Run `generate_response_data`, expect failure, and return the message.
    fn decode_error(data: &str, encoding: &str) -> String {
        decode(data, encoding).unwrap_err().to_string()
    }

    // --- TcpServer construction -------------------------------------------

    #[test]
    fn test_tcp_server_new() {
        let config = TcpConfig::default();
        let registry = Arc::new(TcpSpecRegistry::new());

        let server = TcpServer::new(config.clone(), registry.clone());
        assert!(server.is_ok());

        let server = server.unwrap();
        assert_eq!(server.config.port, config.port);
        assert_eq!(server.config.host, config.host);
    }

    #[test]
    fn test_tcp_server_new_with_custom_config() {
        let config = TcpConfig {
            port: 8080,
            host: "127.0.0.1".to_string(),
            timeout_secs: 60,
            echo_mode: false,
            ..Default::default()
        };
        let registry = Arc::new(TcpSpecRegistry::new());

        let server = TcpServer::new(config.clone(), registry).unwrap();
        assert_eq!(server.config.port, 8080);
        assert_eq!(server.config.host, "127.0.0.1");
        assert_eq!(server.config.timeout_secs, 60);
        assert!(!server.config.echo_mode);
    }

    #[test]
    fn test_tcp_server_config_fields() {
        let config = TcpConfig {
            port: 9000,
            host: "localhost".to_string(),
            fixtures_dir: Some(PathBuf::from("/tmp/fixtures")),
            timeout_secs: 120,
            max_connections: 50,
            read_buffer_size: 4096,
            write_buffer_size: 4096,
            enable_tls: true,
            tls_cert_path: Some(PathBuf::from("/path/to/cert.pem")),
            tls_key_path: Some(PathBuf::from("/path/to/key.pem")),
            echo_mode: false,
            delimiter: Some(b"\r\n".to_vec()),
        };

        let registry = Arc::new(TcpSpecRegistry::new());
        let server = TcpServer::new(config, registry).unwrap();

        assert_eq!(server.config.port, 9000);
        assert_eq!(server.config.host, "localhost");
        assert_eq!(server.config.timeout_secs, 120);
        assert_eq!(server.config.max_connections, 50);
        assert_eq!(server.config.read_buffer_size, 4096);
        assert_eq!(server.config.write_buffer_size, 4096);
        assert!(server.config.enable_tls);
        assert!(!server.config.echo_mode);
        assert_eq!(server.config.delimiter, Some(b"\r\n".to_vec()));
    }

    // --- "text" encoding ---------------------------------------------------

    #[test]
    fn test_generate_response_data_text_encoding() {
        let data = decode("Hello, World!", "text").unwrap();

        assert_eq!(data, b"Hello, World!");
        assert_eq!(String::from_utf8(data).unwrap(), "Hello, World!");
    }

    #[test]
    fn test_generate_response_data_text_encoding_empty() {
        assert_eq!(decode("", "text").unwrap(), b"");
    }

    #[test]
    fn test_generate_response_data_text_encoding_unicode() {
        // Multi-byte UTF-8 (CJK and an emoji) must pass through unchanged.
        let text = "Hello 世界 🌍";
        let data = decode(text, "text").unwrap();

        assert_eq!(data, text.as_bytes());
        assert_eq!(String::from_utf8(data).unwrap(), text);
    }

    #[test]
    fn test_generate_response_data_text_with_special_chars() {
        let data = decode("Line1\nLine2\r\nLine3\t\0End", "text").unwrap();

        assert_eq!(data, b"Line1\nLine2\r\nLine3\t\0End");
    }

    #[test]
    fn test_generate_response_data_large_text() {
        let large_text = "x".repeat(100_000);
        let data = decode(&large_text, "text").unwrap();

        assert_eq!(data.len(), 100_000);
        assert_eq!(data, large_text.as_bytes());
    }

    // --- "hex" encoding ----------------------------------------------------

    #[test]
    fn test_generate_response_data_hex_encoding() {
        // "Hello" in lowercase hex.
        assert_eq!(decode("48656c6c6f", "hex").unwrap(), b"Hello");
    }

    #[test]
    fn test_generate_response_data_hex_encoding_uppercase() {
        // "Hello" in uppercase hex.
        assert_eq!(decode("48656C6C6F", "hex").unwrap(), b"Hello");
    }

    #[test]
    fn test_generate_response_data_hex_encoding_mixed_case() {
        // "Hello" in mixed-case hex.
        assert_eq!(decode("48656c6C6f", "hex").unwrap(), b"Hello");
    }

    #[test]
    fn test_generate_response_data_hex_encoding_invalid() {
        // 'G' is not a hex digit.
        assert!(decode_error("GGGG", "hex").contains("Invalid hex data"));
    }

    #[test]
    fn test_generate_response_data_hex_encoding_odd_length() {
        // An odd number of hex digits cannot form whole bytes.
        assert!(decode_error("123", "hex").contains("Invalid hex data"));
    }

    #[test]
    fn test_generate_response_data_hex_empty() {
        assert_eq!(decode("", "hex").unwrap(), b"");
    }

    #[test]
    fn test_generate_response_data_hex_with_spaces() {
        // The hex decoder does not skip whitespace, so this must fail.
        assert!(decode("48 65 6c 6c 6f", "hex").is_err());
    }

    #[test]
    fn test_generate_response_data_large_hex() {
        // 10_000 zero bytes, i.e. 20_000 hex characters.
        let hex_data = "00".repeat(10_000);
        let data = decode(&hex_data, "hex").unwrap();

        assert_eq!(data.len(), 10_000);
        assert!(data.iter().all(|&b| b == 0));
    }

    // --- "base64" encoding -------------------------------------------------

    #[test]
    fn test_generate_response_data_base64_encoding() {
        // "Hello World" in base64.
        assert_eq!(decode("SGVsbG8gV29ybGQ=", "base64").unwrap(), b"Hello World");
    }

    #[test]
    fn test_generate_response_data_base64_encoding_with_padding() {
        // "Hello" in base64, with one padding character.
        assert_eq!(decode("SGVsbG8=", "base64").unwrap(), b"Hello");
    }

    #[test]
    fn test_generate_response_data_base64_url_safe() {
        // Despite the test name, the input uses the standard alphabet; it
        // decodes to the four bytes "<BOD".
        assert_eq!(decode("PEJPRA==", "base64").unwrap(), b"<BOD");
    }

    #[test]
    fn test_generate_response_data_base64_encoding_invalid() {
        assert!(decode_error("!!!invalid@@@", "base64").contains("Invalid base64 data"));
    }

    #[test]
    fn test_generate_response_data_base64_empty() {
        assert_eq!(decode("", "base64").unwrap(), b"");
    }

    // --- "file" encoding ---------------------------------------------------

    #[test]
    fn test_generate_response_data_file_encoding() {
        let temp_file = temp_file_with(b"File content");

        let data = generate_response_data(&file_response(temp_file.path())).unwrap();

        assert_eq!(data, b"File content");
    }

    #[test]
    fn test_generate_response_data_file_encoding_binary() {
        let binary_data = vec![0x00, 0x01, 0x02, 0xFF, 0xFE, 0xFD];
        let temp_file = temp_file_with(&binary_data);

        let data = generate_response_data(&file_response(temp_file.path())).unwrap();

        assert_eq!(data, binary_data);
    }

    #[test]
    fn test_generate_response_data_file_encoding_no_path() {
        // `create_test_response` leaves `file_path` as `None`.
        assert!(decode_error("", "file").contains("file_path not specified"));
    }

    #[test]
    fn test_generate_response_data_file_encoding_nonexistent_file() {
        let response = file_response(&PathBuf::from("/nonexistent/path/to/file.txt"));

        let error = generate_response_data(&response).unwrap_err();

        assert!(error.to_string().contains("Failed to read file"));
    }

    #[test]
    fn test_file_encoding_empty_file() {
        // Nothing is written, so the file stays empty.
        let temp_file = tempfile::NamedTempFile::new().unwrap();

        let data = generate_response_data(&file_response(temp_file.path())).unwrap();

        assert_eq!(data, b"");
    }

    #[test]
    fn test_file_encoding_large_file() {
        let large_data = vec![0xAB; 50_000]; // 50 KB of data
        let temp_file = temp_file_with(&large_data);

        let data = generate_response_data(&file_response(temp_file.path())).unwrap();

        assert_eq!(data.len(), 50_000);
        assert_eq!(data, large_data);
    }

    // --- encoding selection ------------------------------------------------

    #[test]
    fn test_generate_response_data_unknown_encoding() {
        let message = decode_error("data", "unknown");

        assert!(message.contains("Unknown encoding: unknown"));
        assert!(message.contains("Supported: hex, base64, text, file"));
    }

    #[test]
    fn test_generate_response_data_case_sensitive_encoding() {
        // Encoding names are matched case-sensitively, so "BASE64" is unknown.
        assert!(decode_error("SGVsbG8=", "BASE64").contains("Unknown encoding"));
    }

    // --- response flags ----------------------------------------------------

    #[test]
    fn test_tcp_response_with_delay() {
        let response = TcpResponse {
            data: "delayed".to_string(),
            encoding: "text".to_string(),
            file_path: None,
            delay_ms: 500,
            close_after_response: true,
            keep_alive: false,
        };

        // The delay is applied in `handle_tcp_connection`, not here, so this
        // returns immediately with the plain payload.
        assert_eq!(generate_response_data(&response).unwrap(), b"delayed");
    }

    #[test]
    fn test_tcp_response_close_after_response() {
        let response = TcpResponse {
            data: "close me".to_string(),
            encoding: "text".to_string(),
            file_path: None,
            delay_ms: 0,
            close_after_response: true,
            keep_alive: false,
        };

        assert!(response.close_after_response);
        assert!(!response.keep_alive);
        assert!(generate_response_data(&response).is_ok());
    }
}
601}