mssql_testing/
mock_server.rs

//! Mock TDS server for unit testing.
//!
//! This module provides a mock SQL Server implementation that can be used
//! for unit testing without requiring a real database instance.
//!
//! ## Features
//!
//! - Simulates the TDS protocol handshake (PRELOGIN, LOGIN7)
//! - Configurable responses for SQL queries
//! - Support for multiple concurrent connections
//! - Packet recording (save/load) for replay-based regression testing
//!
//! ## Example
//!
//! ```rust,ignore
//! use mssql_testing::mock_server::{MockTdsServer, MockResponse};
//!
//! #[tokio::test]
//! async fn test_query() {
//!     let server = MockTdsServer::builder()
//!         .with_response("SELECT 1", MockResponse::scalar_int(1))
//!         .build()
//!         .await
//!         .unwrap();
//!
//!     let addr = server.addr();
//!     // Connect your client to addr...
//! }
//! ```
30
31use bytes::{BufMut, Bytes, BytesMut};
32use std::collections::HashMap;
33use std::fmt;
34use std::net::SocketAddr;
35use std::sync::Arc;
36use tds_protocol::types::TypeId;
37use tds_protocol::{
38    DoneStatus, EnvChangeType, PACKET_HEADER_SIZE, PacketHeader, PacketStatus, PacketType,
39    TokenType,
40};
41use thiserror::Error;
42use tokio::io::{AsyncReadExt, AsyncWriteExt};
43use tokio::net::{TcpListener, TcpStream};
44use tokio::sync::{Mutex, broadcast};
45
46/// Error type for mock server operations.
47#[derive(Debug, Error)]
48pub enum MockServerError {
49    /// IO error.
50    #[error("IO error: {0}")]
51    Io(#[from] std::io::Error),
52
53    /// Protocol error.
54    #[error("Protocol error: {0}")]
55    Protocol(String),
56
57    /// Server already stopped.
58    #[error("Server already stopped")]
59    Stopped,
60}
61
62/// Result type for mock server operations.
63pub type Result<T> = std::result::Result<T, MockServerError>;
64
65/// Mock response configuration.
66#[derive(Clone)]
67pub enum MockResponse {
68    /// Return a single scalar value.
69    Scalar(ScalarValue),
70
71    /// Return multiple rows with columns.
72    Rows {
73        /// Column definitions.
74        columns: Vec<MockColumn>,
75        /// Row data.
76        rows: Vec<Vec<ScalarValue>>,
77    },
78
79    /// Return an error.
80    Error {
81        /// Error number.
82        number: i32,
83        /// Error message.
84        message: String,
85        /// Severity class.
86        severity: u8,
87    },
88
89    /// Return rows affected count (for INSERT/UPDATE/DELETE).
90    RowsAffected(u64),
91
92    /// Return raw pre-encoded TDS tokens.
93    Raw(Bytes),
94
95    /// Execute a custom handler.
96    Custom(Arc<dyn Fn(&str) -> MockResponse + Send + Sync>),
97}
98
99impl fmt::Debug for MockResponse {
100    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
101        match self {
102            Self::Scalar(v) => f.debug_tuple("Scalar").field(v).finish(),
103            Self::Rows { columns, rows } => f
104                .debug_struct("Rows")
105                .field("columns", columns)
106                .field("rows", rows)
107                .finish(),
108            Self::Error {
109                number,
110                message,
111                severity,
112            } => f
113                .debug_struct("Error")
114                .field("number", number)
115                .field("message", message)
116                .field("severity", severity)
117                .finish(),
118            Self::RowsAffected(n) => f.debug_tuple("RowsAffected").field(n).finish(),
119            Self::Raw(data) => f.debug_tuple("Raw").field(&data.len()).finish(),
120            Self::Custom(_) => f.debug_tuple("Custom").field(&"<fn>").finish(),
121        }
122    }
123}
124
125impl MockResponse {
126    /// Create a scalar integer response.
127    pub fn scalar_int(value: i32) -> Self {
128        Self::Scalar(ScalarValue::Int(value))
129    }
130
131    /// Create a scalar string response.
132    pub fn scalar_string(value: impl Into<String>) -> Self {
133        Self::Scalar(ScalarValue::String(value.into()))
134    }
135
136    /// Create an empty result response.
137    pub fn empty() -> Self {
138        Self::RowsAffected(0)
139    }
140
141    /// Create a rows affected response.
142    pub fn affected(count: u64) -> Self {
143        Self::RowsAffected(count)
144    }
145
146    /// Create an error response.
147    pub fn error(number: i32, message: impl Into<String>) -> Self {
148        Self::Error {
149            number,
150            message: message.into(),
151            severity: 16,
152        }
153    }
154
155    /// Create a multi-row response.
156    pub fn rows(columns: Vec<MockColumn>, rows: Vec<Vec<ScalarValue>>) -> Self {
157        Self::Rows { columns, rows }
158    }
159}
160
/// A single value that the mock server can put in a result row.
#[derive(Debug, Clone)]
pub enum ScalarValue {
    /// SQL NULL.
    Null,
    /// `BIT` value.
    Bool(bool),
    /// `INT` (32-bit signed) value.
    Int(i32),
    /// `BIGINT` (64-bit signed) value.
    BigInt(i64),
    /// `REAL` (32-bit floating point) value.
    Float(f32),
    /// `FLOAT` (64-bit floating point) value.
    Double(f64),
    /// Text, sent as UTF-16LE.
    String(String),
    /// Raw bytes.
    Binary(Vec<u8>),
}
181
182impl ScalarValue {
183    /// Get the TDS type ID for this value.
184    fn type_id(&self) -> TypeId {
185        match self {
186            Self::Null => TypeId::Null,
187            Self::Bool(_) => TypeId::BitN,
188            Self::Int(_) => TypeId::IntN,
189            Self::BigInt(_) => TypeId::IntN,
190            Self::Float(_) => TypeId::FloatN,
191            Self::Double(_) => TypeId::FloatN,
192            Self::String(_) => TypeId::NVarChar,
193            Self::Binary(_) => TypeId::BigVarBinary,
194        }
195    }
196
197    /// Encode this value to TDS format.
198    fn encode(&self, dst: &mut BytesMut) {
199        match self {
200            Self::Null => {
201                dst.put_u8(0); // NULL length
202            }
203            Self::Bool(v) => {
204                dst.put_u8(1); // length
205                dst.put_u8(if *v { 1 } else { 0 });
206            }
207            Self::Int(v) => {
208                dst.put_u8(4); // length
209                dst.put_i32_le(*v);
210            }
211            Self::BigInt(v) => {
212                dst.put_u8(8); // length
213                dst.put_i64_le(*v);
214            }
215            Self::Float(v) => {
216                dst.put_u8(4); // length
217                dst.put_f32_le(*v);
218            }
219            Self::Double(v) => {
220                dst.put_u8(8); // length
221                dst.put_f64_le(*v);
222            }
223            Self::String(s) => {
224                let utf16: Vec<u16> = s.encode_utf16().collect();
225                let byte_len = utf16.len() * 2;
226                if byte_len > 0xFFFF {
227                    // PLP format for large strings
228                    dst.put_u64_le(byte_len as u64);
229                    dst.put_u32_le(byte_len as u32);
230                    for c in utf16 {
231                        dst.put_u16_le(c);
232                    }
233                    dst.put_u32_le(0); // terminator
234                } else {
235                    dst.put_u16_le(byte_len as u16);
236                    for c in utf16 {
237                        dst.put_u16_le(c);
238                    }
239                }
240            }
241            Self::Binary(data) => {
242                if data.len() > 0xFFFF {
243                    // PLP format
244                    dst.put_u64_le(data.len() as u64);
245                    dst.put_u32_le(data.len() as u32);
246                    dst.extend_from_slice(data);
247                    dst.put_u32_le(0); // terminator
248                } else {
249                    dst.put_u16_le(data.len() as u16);
250                    dst.extend_from_slice(data);
251                }
252            }
253        }
254    }
255}
256
257/// Mock column definition.
258#[derive(Debug, Clone)]
259pub struct MockColumn {
260    /// Column name.
261    pub name: String,
262    /// Column type.
263    pub type_id: TypeId,
264    /// Maximum length (for variable-length types).
265    pub max_length: Option<u32>,
266    /// Whether the column is nullable.
267    pub nullable: bool,
268}
269
270impl MockColumn {
271    /// Create a new column definition.
272    pub fn new(name: impl Into<String>, type_id: TypeId) -> Self {
273        Self {
274            name: name.into(),
275            type_id,
276            max_length: None,
277            nullable: true,
278        }
279    }
280
281    /// Create an INT column.
282    pub fn int(name: impl Into<String>) -> Self {
283        Self::new(name, TypeId::IntN).with_max_length(4)
284    }
285
286    /// Create a BIGINT column.
287    pub fn bigint(name: impl Into<String>) -> Self {
288        Self::new(name, TypeId::IntN).with_max_length(8)
289    }
290
291    /// Create an NVARCHAR column.
292    pub fn nvarchar(name: impl Into<String>, max_len: u32) -> Self {
293        Self::new(name, TypeId::NVarChar).with_max_length(max_len * 2)
294    }
295
296    /// Set the maximum length.
297    pub fn with_max_length(mut self, len: u32) -> Self {
298        self.max_length = Some(len);
299        self
300    }
301
302    /// Set nullable flag.
303    pub fn with_nullable(mut self, nullable: bool) -> Self {
304        self.nullable = nullable;
305        self
306    }
307}
308
309/// Configuration for the mock TDS server.
310#[derive(Default)]
311pub struct MockServerConfig {
312    /// Pre-configured responses for specific SQL queries.
313    responses: HashMap<String, MockResponse>,
314    /// Default response for unmatched queries.
315    default_response: Option<MockResponse>,
316    /// Server name to report in LoginAck.
317    server_name: String,
318    /// TDS version to report.
319    tds_version: u32,
320    /// Default database name.
321    database: String,
322}
323
324/// Builder for `MockTdsServer`.
325pub struct MockServerBuilder {
326    config: MockServerConfig,
327}
328
329impl MockServerBuilder {
330    /// Create a new builder with default settings.
331    pub fn new() -> Self {
332        Self {
333            config: MockServerConfig {
334                responses: HashMap::new(),
335                default_response: Some(MockResponse::empty()),
336                server_name: "MockSQLServer".to_string(),
337                tds_version: 0x74000004, // TDS 7.4
338                database: "master".to_string(),
339            },
340        }
341    }
342
343    /// Add a response for a specific SQL query.
344    pub fn with_response(mut self, sql: impl Into<String>, response: MockResponse) -> Self {
345        self.config.responses.insert(sql.into(), response);
346        self
347    }
348
349    /// Set the default response for unmatched queries.
350    pub fn with_default_response(mut self, response: MockResponse) -> Self {
351        self.config.default_response = Some(response);
352        self
353    }
354
355    /// Set the server name reported in LoginAck.
356    pub fn with_server_name(mut self, name: impl Into<String>) -> Self {
357        self.config.server_name = name.into();
358        self
359    }
360
361    /// Set the default database.
362    pub fn with_database(mut self, db: impl Into<String>) -> Self {
363        self.config.database = db.into();
364        self
365    }
366
367    /// Build and start the mock server.
368    pub async fn build(self) -> Result<MockTdsServer> {
369        MockTdsServer::start(self.config).await
370    }
371}
372
373impl Default for MockServerBuilder {
374    fn default() -> Self {
375        Self::new()
376    }
377}
378
379/// A mock TDS server for testing.
380///
381/// This server simulates a SQL Server instance for unit testing purposes.
382/// It handles the TDS protocol handshake and responds to queries based on
383/// pre-configured responses.
384pub struct MockTdsServer {
385    /// Server address.
386    addr: SocketAddr,
387    /// Shutdown signal sender.
388    shutdown_tx: broadcast::Sender<()>,
389    /// Server configuration (stored for potential introspection).
390    #[allow(dead_code)]
391    config: Arc<MockServerConfig>,
392    /// Connection count.
393    connection_count: Arc<Mutex<usize>>,
394}
395
396impl MockTdsServer {
397    /// Create a new builder for the mock server.
398    pub fn builder() -> MockServerBuilder {
399        MockServerBuilder::new()
400    }
401
402    /// Start the mock server on an available port.
403    pub async fn start(config: MockServerConfig) -> Result<Self> {
404        let listener = TcpListener::bind("127.0.0.1:0").await?;
405        let addr = listener.local_addr()?;
406        let (shutdown_tx, _) = broadcast::channel(1);
407        let config = Arc::new(config);
408        let connection_count = Arc::new(Mutex::new(0usize));
409
410        let server = Self {
411            addr,
412            shutdown_tx: shutdown_tx.clone(),
413            config: config.clone(),
414            connection_count: connection_count.clone(),
415        };
416
417        // Spawn the accept loop
418        let mut shutdown_rx = shutdown_tx.subscribe();
419        tokio::spawn(async move {
420            loop {
421                tokio::select! {
422                    result = listener.accept() => {
423                        match result {
424                            Ok((stream, _peer_addr)) => {
425                                let config = config.clone();
426                                let count = connection_count.clone();
427                                tokio::spawn(async move {
428                                    {
429                                        let mut c = count.lock().await;
430                                        *c += 1;
431                                    }
432                                    if let Err(e) = handle_connection(stream, config).await {
433                                        tracing::debug!("Connection error: {}", e);
434                                    }
435                                    {
436                                        let mut c = count.lock().await;
437                                        *c = c.saturating_sub(1);
438                                    }
439                                });
440                            }
441                            Err(e) => {
442                                tracing::error!("Accept error: {}", e);
443                                break;
444                            }
445                        }
446                    }
447                    _ = shutdown_rx.recv() => {
448                        break;
449                    }
450                }
451            }
452        });
453
454        Ok(server)
455    }
456
457    /// Get the server's listening address.
458    pub fn addr(&self) -> SocketAddr {
459        self.addr
460    }
461
462    /// Get the host string for connection configuration.
463    pub fn host(&self) -> String {
464        self.addr.ip().to_string()
465    }
466
467    /// Get the port number.
468    pub fn port(&self) -> u16 {
469        self.addr.port()
470    }
471
472    /// Get the current connection count.
473    pub async fn connection_count(&self) -> usize {
474        *self.connection_count.lock().await
475    }
476
477    /// Stop the server.
478    pub fn stop(&self) {
479        let _ = self.shutdown_tx.send(());
480    }
481}
482
483impl Drop for MockTdsServer {
484    fn drop(&mut self) {
485        self.stop();
486    }
487}
488
489/// Handle a single client connection.
490async fn handle_connection(mut stream: TcpStream, config: Arc<MockServerConfig>) -> Result<()> {
491    // Step 1: Handle PRELOGIN
492    let prelogin_request = read_packet(&mut stream).await?;
493    if prelogin_request.packet_type != PacketType::PreLogin {
494        return Err(MockServerError::Protocol(format!(
495            "Expected PreLogin, got {:?}",
496            prelogin_request.packet_type
497        )));
498    }
499    send_prelogin_response(&mut stream).await?;
500
501    // Step 2: Handle LOGIN7
502    let login_request = read_packet(&mut stream).await?;
503    if login_request.packet_type != PacketType::Tds7Login {
504        return Err(MockServerError::Protocol(format!(
505            "Expected Tds7Login, got {:?}",
506            login_request.packet_type
507        )));
508    }
509    send_login_response(&mut stream, &config).await?;
510
511    // Step 3: Handle SQL batches and RPC requests
512    loop {
513        let packet = match read_packet(&mut stream).await {
514            Ok(p) => p,
515            Err(MockServerError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
516                // Client disconnected
517                break;
518            }
519            Err(e) => return Err(e),
520        };
521
522        match packet.packet_type {
523            PacketType::SqlBatch => {
524                let sql = decode_sql_batch(&packet.payload)?;
525                let response = find_response(&sql, &config);
526                send_query_response(&mut stream, response).await?;
527            }
528            PacketType::Rpc => {
529                // For RPC requests (sp_executesql, sp_prepare, etc.)
530                // Extract the SQL from the RPC payload and handle similarly
531                let response = config
532                    .default_response
533                    .clone()
534                    .unwrap_or(MockResponse::empty());
535                send_query_response(&mut stream, response).await?;
536            }
537            PacketType::Attention => {
538                // Client sent attention/cancel signal
539                send_attention_ack(&mut stream).await?;
540            }
541            _ => {
542                tracing::debug!("Unexpected packet type: {:?}", packet.packet_type);
543            }
544        }
545    }
546
547    Ok(())
548}
549
550/// Parsed TDS packet.
551struct Packet {
552    packet_type: PacketType,
553    payload: Bytes,
554}
555
556/// Read a complete TDS packet from the stream.
557async fn read_packet(stream: &mut TcpStream) -> Result<Packet> {
558    let mut header_buf = [0u8; PACKET_HEADER_SIZE];
559    stream.read_exact(&mut header_buf).await?;
560
561    let mut cursor = &header_buf[..];
562    let header =
563        PacketHeader::decode(&mut cursor).map_err(|e| MockServerError::Protocol(e.to_string()))?;
564
565    let payload_len = header.payload_length();
566    let mut payload = vec![0u8; payload_len];
567    if payload_len > 0 {
568        stream.read_exact(&mut payload).await?;
569    }
570
571    // Handle multi-packet messages
572    let mut full_payload = BytesMut::from(&payload[..]);
573
574    if !header.is_end_of_message() {
575        loop {
576            let mut next_header_buf = [0u8; PACKET_HEADER_SIZE];
577            stream.read_exact(&mut next_header_buf).await?;
578
579            let mut cursor = &next_header_buf[..];
580            let next_header = PacketHeader::decode(&mut cursor)
581                .map_err(|e| MockServerError::Protocol(e.to_string()))?;
582
583            let next_payload_len = next_header.payload_length();
584            let mut next_payload = vec![0u8; next_payload_len];
585            if next_payload_len > 0 {
586                stream.read_exact(&mut next_payload).await?;
587            }
588
589            full_payload.extend_from_slice(&next_payload);
590
591            if next_header.is_end_of_message() {
592                break;
593            }
594        }
595    }
596
597    Ok(Packet {
598        packet_type: header.packet_type,
599        payload: full_payload.freeze(),
600    })
601}
602
603/// Write a TDS packet to the stream.
604async fn write_packet(
605    stream: &mut TcpStream,
606    packet_type: PacketType,
607    payload: &[u8],
608) -> Result<()> {
609    let total_len = PACKET_HEADER_SIZE + payload.len();
610    let header = PacketHeader {
611        packet_type,
612        status: PacketStatus::END_OF_MESSAGE,
613        length: total_len as u16,
614        spid: 0,
615        packet_id: 1,
616        window: 0,
617    };
618
619    let mut buf = BytesMut::with_capacity(total_len);
620    header.encode(&mut buf);
621    buf.extend_from_slice(payload);
622
623    stream.write_all(&buf).await?;
624    stream.flush().await?;
625    Ok(())
626}
627
628/// Send PRELOGIN response.
629async fn send_prelogin_response(stream: &mut TcpStream) -> Result<()> {
630    // PRELOGIN response format:
631    // Option tokens (5 bytes each: type + offset + length) followed by data
632    // VERSION (0x00), ENCRYPTION (0x01)
633    //
634    // Options section layout:
635    //   VERSION token:    1 + 2 + 2 = 5 bytes
636    //   ENCRYPTION token: 1 + 2 + 2 = 5 bytes
637    //   Terminator:       1 byte
638    //   Total header:     11 bytes
639    //
640    // Data section layout:
641    //   VERSION data at offset 11: 6 bytes
642    //   ENCRYPTION data at offset 17: 1 byte
643
644    let mut response = BytesMut::new();
645
646    // Option header area (offsets are big-endian per TDS spec)
647    // VERSION token
648    response.put_u8(0x00); // VERSION
649    response.put_u16(11); // offset (header size)
650    response.put_u16(6); // length
651
652    // ENCRYPTION token
653    response.put_u8(0x01); // ENCRYPTION
654    response.put_u16(17); // offset (11 + 6)
655    response.put_u16(1); // length
656
657    // Terminator
658    response.put_u8(0xFF);
659
660    // VERSION data (at offset 11)
661    response.put_u8(16); // major version
662    response.put_u8(0); // minor version
663    response.put_u16_le(0); // build number
664    response.put_u16_le(0); // sub-build number
665
666    // ENCRYPTION data (at offset 17)
667    response.put_u8(0x00); // ENCRYPT_OFF (no encryption)
668
669    write_packet(stream, PacketType::PreLogin, &response).await
670}
671
672/// Send LOGIN7 response (LoginAck + EnvChange + Done).
673async fn send_login_response(stream: &mut TcpStream, config: &MockServerConfig) -> Result<()> {
674    let mut response = BytesMut::new();
675
676    // EnvChange: Database
677    encode_env_change(&mut response, EnvChangeType::Database, &config.database, "");
678
679    // EnvChange: PacketSize
680    encode_env_change(&mut response, EnvChangeType::PacketSize, "4096", "4096");
681
682    // LoginAck
683    encode_login_ack(&mut response, &config.server_name, config.tds_version);
684
685    // Done
686    encode_done(&mut response, 0, false);
687
688    write_packet(stream, PacketType::TabularResult, &response).await
689}
690
691/// Encode an EnvChange token.
692fn encode_env_change(dst: &mut BytesMut, env_type: EnvChangeType, new_val: &str, old_val: &str) {
693    let new_utf16: Vec<u16> = new_val.encode_utf16().collect();
694    let old_utf16: Vec<u16> = old_val.encode_utf16().collect();
695
696    let data_len = 1 + 1 + new_utf16.len() * 2 + 1 + old_utf16.len() * 2;
697
698    dst.put_u8(TokenType::EnvChange as u8);
699    dst.put_u16_le(data_len as u16);
700    dst.put_u8(env_type as u8);
701
702    // New value (B_VARCHAR format)
703    dst.put_u8(new_utf16.len() as u8);
704    for c in &new_utf16 {
705        dst.put_u16_le(*c);
706    }
707
708    // Old value (B_VARCHAR format)
709    dst.put_u8(old_utf16.len() as u8);
710    for c in &old_utf16 {
711        dst.put_u16_le(*c);
712    }
713}
714
715/// Encode a LoginAck token.
716fn encode_login_ack(dst: &mut BytesMut, server_name: &str, tds_version: u32) {
717    let name_utf16: Vec<u16> = server_name.encode_utf16().collect();
718
719    // LoginAck: interface (1) + tds_version (4) + prog_name (b_varchar) + prog_version (4)
720    let data_len = 1 + 4 + 1 + name_utf16.len() * 2 + 4;
721
722    dst.put_u8(TokenType::LoginAck as u8);
723    dst.put_u16_le(data_len as u16);
724    dst.put_u8(1); // interface: SQL
725    dst.put_u32_le(tds_version);
726
727    // Program name (B_VARCHAR)
728    dst.put_u8(name_utf16.len() as u8);
729    for c in &name_utf16 {
730        dst.put_u16_le(*c);
731    }
732
733    // Program version
734    dst.put_u32_le(0x10000000); // 16.0.0.0
735}
736
737/// Encode a Done token.
738fn encode_done(dst: &mut BytesMut, row_count: u64, more: bool) {
739    dst.put_u8(TokenType::Done as u8);
740
741    let status = DoneStatus {
742        count: row_count > 0,
743        more,
744        ..Default::default()
745    };
746
747    dst.put_u16_le(status.to_bits());
748    dst.put_u16_le(0xC1); // cur_cmd: SELECT
749    dst.put_u64_le(row_count);
750}
751
752/// Decode SQL from a SQL_BATCH packet payload.
753fn decode_sql_batch(payload: &Bytes) -> Result<String> {
754    // SQL Batch format: ALL_HEADERS (optional) + SQL text in UTF-16LE
755    // For simplicity, assume no ALL_HEADERS (check first 4 bytes)
756
757    let mut cursor = payload.as_ref();
758
759    // Check if ALL_HEADERS is present
760    if cursor.len() >= 4 {
761        let total_len = u32::from_le_bytes([cursor[0], cursor[1], cursor[2], cursor[3]]) as usize;
762
763        // If total_len looks like a header length (reasonable size), skip headers
764        if total_len >= 4 && total_len < cursor.len() && total_len < 1000 {
765            cursor = &cursor[total_len..];
766        }
767    }
768
769    // Read UTF-16LE SQL text
770    if cursor.len() % 2 != 0 {
771        return Err(MockServerError::Protocol(
772            "Invalid UTF-16 SQL text length".to_string(),
773        ));
774    }
775
776    let char_count = cursor.len() / 2;
777    let mut chars = Vec::with_capacity(char_count);
778    for i in 0..char_count {
779        let c = u16::from_le_bytes([cursor[i * 2], cursor[i * 2 + 1]]);
780        chars.push(c);
781    }
782
783    String::from_utf16(&chars)
784        .map_err(|_| MockServerError::Protocol("Invalid UTF-16 SQL text".to_string()))
785}
786
787/// Find the response for a SQL query.
788fn find_response(sql: &str, config: &MockServerConfig) -> MockResponse {
789    // Normalize SQL for matching
790    let normalized = sql.trim().to_uppercase();
791
792    // Check exact match first
793    if let Some(response) = config.responses.get(&normalized) {
794        return response.clone();
795    }
796
797    // Check case-insensitive match
798    for (key, response) in &config.responses {
799        if key.trim().to_uppercase() == normalized {
800            return response.clone();
801        }
802    }
803
804    // Use default response
805    config
806        .default_response
807        .clone()
808        .unwrap_or(MockResponse::empty())
809}
810
811/// Send a query response based on the MockResponse.
812async fn send_query_response(stream: &mut TcpStream, response: MockResponse) -> Result<()> {
813    let mut buf = BytesMut::new();
814
815    match response {
816        MockResponse::Scalar(value) => {
817            // Single column, single row result
818            encode_colmetadata(&mut buf, &[MockColumn::new("", value.type_id())]);
819            encode_row(&mut buf, &[value.clone()]);
820            encode_done(&mut buf, 1, false);
821        }
822        MockResponse::Rows { columns, rows } => {
823            encode_colmetadata(&mut buf, &columns);
824            for row in &rows {
825                encode_row(&mut buf, row);
826            }
827            encode_done(&mut buf, rows.len() as u64, false);
828        }
829        MockResponse::Error {
830            number,
831            message,
832            severity,
833        } => {
834            encode_error(&mut buf, number, &message, severity);
835            encode_done(&mut buf, 0, false);
836        }
837        MockResponse::RowsAffected(count) => {
838            encode_done(&mut buf, count, false);
839        }
840        MockResponse::Raw(data) => {
841            buf.extend_from_slice(&data);
842        }
843        MockResponse::Custom(_handler) => {
844            // For custom handlers, we'd need the SQL here
845            // For now, just send empty result
846            encode_done(&mut buf, 0, false);
847        }
848    }
849
850    write_packet(stream, PacketType::TabularResult, &buf).await
851}
852
853/// Encode COLMETADATA token.
854fn encode_colmetadata(dst: &mut BytesMut, columns: &[MockColumn]) {
855    dst.put_u8(TokenType::ColMetaData as u8);
856    dst.put_u16_le(columns.len() as u16);
857
858    for col in columns {
859        // UserType (4 bytes)
860        dst.put_u32_le(0);
861
862        // Flags (2 bytes) - nullable = 0x01
863        dst.put_u16_le(if col.nullable { 0x01 } else { 0x00 });
864
865        // Type ID (1 byte)
866        dst.put_u8(col.type_id as u8);
867
868        // Type-specific metadata
869        match col.type_id {
870            TypeId::IntN | TypeId::BitN | TypeId::FloatN | TypeId::MoneyN | TypeId::DateTimeN => {
871                dst.put_u8(col.max_length.unwrap_or(4) as u8);
872            }
873            TypeId::NVarChar | TypeId::NChar => {
874                dst.put_u16_le(col.max_length.unwrap_or(8000) as u16);
875                // Collation (5 bytes)
876                dst.put_u32_le(0x0904D000); // LCID
877                dst.put_u8(0x34); // Sort ID
878            }
879            TypeId::BigVarBinary | TypeId::BigBinary => {
880                dst.put_u16_le(col.max_length.unwrap_or(8000) as u16);
881            }
882            _ => {
883                // Fixed-length types have no additional metadata
884            }
885        }
886
887        // Column name (B_VARCHAR)
888        let name_utf16: Vec<u16> = col.name.encode_utf16().collect();
889        dst.put_u8(name_utf16.len() as u8);
890        for c in &name_utf16 {
891            dst.put_u16_le(*c);
892        }
893    }
894}
895
896/// Encode ROW token.
897fn encode_row(dst: &mut BytesMut, values: &[ScalarValue]) {
898    dst.put_u8(TokenType::Row as u8);
899    for value in values {
900        value.encode(dst);
901    }
902}
903
904/// Encode ERROR token.
905fn encode_error(dst: &mut BytesMut, number: i32, message: &str, severity: u8) {
906    let msg_utf16: Vec<u16> = message.encode_utf16().collect();
907    let server_utf16: Vec<u16> = "MockServer".encode_utf16().collect();
908
909    // ERROR: number (4) + state (1) + class (1) + message (us_varchar) +
910    //        server (b_varchar) + procedure (b_varchar) + line (4)
911    let data_len = (4 + 1 + 1 + 2 + msg_utf16.len() * 2 + 1 + server_utf16.len() * 2 + 1) + 4;
912
913    dst.put_u8(TokenType::Error as u8);
914    dst.put_u16_le(data_len as u16);
915    dst.put_i32_le(number);
916    dst.put_u8(1); // state
917    dst.put_u8(severity); // class
918
919    // Message (US_VARCHAR)
920    dst.put_u16_le(msg_utf16.len() as u16);
921    for c in &msg_utf16 {
922        dst.put_u16_le(*c);
923    }
924
925    // Server name (B_VARCHAR)
926    dst.put_u8(server_utf16.len() as u8);
927    for c in &server_utf16 {
928        dst.put_u16_le(*c);
929    }
930
931    // Procedure name (B_VARCHAR) - empty
932    dst.put_u8(0);
933
934    // Line number
935    dst.put_i32_le(1);
936}
937
938/// Send attention acknowledgment.
939async fn send_attention_ack(stream: &mut TcpStream) -> Result<()> {
940    let mut buf = BytesMut::new();
941
942    // DONE with ATTN flag
943    buf.put_u8(TokenType::Done as u8);
944    let status = DoneStatus {
945        attn: true,
946        ..Default::default()
947    };
948    buf.put_u16_le(status.to_bits());
949    buf.put_u16_le(0);
950    buf.put_u64_le(0);
951
952    write_packet(stream, PacketType::TabularResult, &buf).await
953}
954
955/// Recorded packet for replay testing.
956#[derive(Debug, Clone)]
957pub struct RecordedPacket {
958    /// Packet direction (true = server to client).
959    pub from_server: bool,
960    /// Raw packet data including header.
961    pub data: Bytes,
962}
963
964/// Packet recorder for capturing and replaying TDS sessions.
965#[derive(Debug, Default)]
966pub struct PacketRecorder {
967    packets: Vec<RecordedPacket>,
968}
969
970impl PacketRecorder {
971    /// Create a new packet recorder.
972    pub fn new() -> Self {
973        Self::default()
974    }
975
976    /// Record a packet.
977    pub fn record(&mut self, from_server: bool, data: Bytes) {
978        self.packets.push(RecordedPacket { from_server, data });
979    }
980
981    /// Get all recorded packets.
982    pub fn packets(&self) -> &[RecordedPacket] {
983        &self.packets
984    }
985
986    /// Save recorded packets to a file.
987    pub async fn save(&self, path: &std::path::Path) -> std::io::Result<()> {
988        use tokio::fs::File;
989        use tokio::io::AsyncWriteExt;
990
991        let mut file = File::create(path).await?;
992
993        for packet in &self.packets {
994            // Direction (1 byte) + length (4 bytes) + data
995            file.write_u8(if packet.from_server { 1 } else { 0 })
996                .await?;
997            file.write_u32_le(packet.data.len() as u32).await?;
998            file.write_all(&packet.data).await?;
999        }
1000
1001        Ok(())
1002    }
1003
1004    /// Load recorded packets from a file.
1005    pub async fn load(path: &std::path::Path) -> std::io::Result<Self> {
1006        use tokio::fs::File;
1007        use tokio::io::AsyncReadExt;
1008
1009        let mut file = File::open(path).await?;
1010        let mut recorder = Self::new();
1011
1012        loop {
1013            let from_server = match file.read_u8().await {
1014                Ok(b) => b != 0,
1015                Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
1016                Err(e) => return Err(e),
1017            };
1018
1019            let len = file.read_u32_le().await? as usize;
1020            let mut data = vec![0u8; len];
1021            file.read_exact(&mut data).await?;
1022
1023            recorder.record(from_server, Bytes::from(data));
1024        }
1025
1026        Ok(recorder)
1027    }
1028}
1029
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_mock_server_starts() {
        let server = MockTdsServer::builder()
            .with_server_name("TestServer")
            .build()
            .await
            .unwrap();

        assert!(server.port() > 0);
        assert_eq!(server.host(), "127.0.0.1");
    }

    // Pure constructor checks: no async runtime needed.
    #[test]
    fn test_mock_response_scalar() {
        let response = MockResponse::scalar_int(42);
        match response {
            MockResponse::Scalar(ScalarValue::Int(v)) => assert_eq!(v, 42),
            _ => panic!("Expected scalar int"),
        }
    }

    #[test]
    fn test_mock_response_error() {
        let response = MockResponse::error(50000, "Test error");
        match response {
            MockResponse::Error {
                number,
                message,
                severity,
            } => {
                assert_eq!(number, 50000);
                assert_eq!(message, "Test error");
                assert_eq!(severity, 16);
            }
            _ => panic!("Expected error response"),
        }
    }

    #[test]
    fn test_scalar_value_encode_int() {
        let value = ScalarValue::Int(42);
        let mut buf = BytesMut::new();
        value.encode(&mut buf);

        assert_eq!(buf.len(), 5); // 1 byte length + 4 bytes value
        assert_eq!(buf[0], 4); // length
        assert_eq!(i32::from_le_bytes([buf[1], buf[2], buf[3], buf[4]]), 42);
    }

    #[test]
    fn test_scalar_value_encode_string() {
        let value = ScalarValue::String("test".to_string());
        let mut buf = BytesMut::new();
        value.encode(&mut buf);

        // 2 bytes length + 8 bytes UTF-16
        assert_eq!(buf.len(), 10);
        assert_eq!(u16::from_le_bytes([buf[0], buf[1]]), 8);
    }

    #[test]
    fn test_mock_column_int() {
        let col = MockColumn::int("id");
        assert_eq!(col.name, "id");
        assert_eq!(col.type_id, TypeId::IntN);
        assert_eq!(col.max_length, Some(4));
    }

    #[test]
    fn test_mock_column_nvarchar() {
        let col = MockColumn::nvarchar("name", 50);
        assert_eq!(col.name, "name");
        assert_eq!(col.type_id, TypeId::NVarChar);
        assert_eq!(col.max_length, Some(100)); // 50 chars * 2 bytes
    }

    #[test]
    fn test_done_status_encoding() {
        let mut buf = BytesMut::new();
        encode_done(&mut buf, 5, false);

        assert_eq!(buf[0], TokenType::Done as u8);
        // Status should have COUNT flag set
        let status = u16::from_le_bytes([buf[1], buf[2]]);
        assert_eq!(status & 0x0010, 0x0010); // DONE_COUNT
    }

    #[test]
    fn test_error_token_length_matches_payload() {
        let mut buf = BytesMut::new();
        encode_error(&mut buf, 50000, "boom", 16);

        assert_eq!(buf[0], TokenType::Error as u8);
        // The length field counts everything after token type (1) + length field (2).
        let declared = usize::from(u16::from_le_bytes([buf[1], buf[2]]));
        assert_eq!(declared, buf.len() - 3);
        assert_eq!(i32::from_le_bytes([buf[3], buf[4], buf[5], buf[6]]), 50000);
        assert_eq!(buf[7], 1); // state
        assert_eq!(buf[8], 16); // class
        // Message character count (US_VARCHAR prefix).
        assert_eq!(u16::from_le_bytes([buf[9], buf[10]]), 4);
    }

    #[test]
    fn test_packet_recorder_records_in_order() {
        let mut recorder = PacketRecorder::new();
        assert!(recorder.packets().is_empty());

        recorder.record(false, Bytes::from_static(b"\x12\x01"));
        recorder.record(true, Bytes::from_static(b"\x04\x01"));

        let packets = recorder.packets();
        assert_eq!(packets.len(), 2);
        assert!(!packets[0].from_server);
        assert!(packets[1].from_server);
        assert_eq!(&packets[1].data[..], b"\x04\x01");
    }
}