1use bytes::{BufMut, Bytes, BytesMut};
32use std::collections::HashMap;
33use std::fmt;
34use std::net::SocketAddr;
35use std::sync::Arc;
36use tds_protocol::types::TypeId;
37use tds_protocol::{
38 DoneStatus, EnvChangeType, PACKET_HEADER_SIZE, PacketHeader, PacketStatus, PacketType,
39 TokenType,
40};
41use thiserror::Error;
42use tokio::io::{AsyncReadExt, AsyncWriteExt};
43use tokio::net::{TcpListener, TcpStream};
44use tokio::sync::{Mutex, broadcast};
45
46#[derive(Debug, Error)]
48pub enum MockServerError {
49 #[error("IO error: {0}")]
51 Io(#[from] std::io::Error),
52
53 #[error("Protocol error: {0}")]
55 Protocol(String),
56
57 #[error("Server already stopped")]
59 Stopped,
60}
61
62pub type Result<T> = std::result::Result<T, MockServerError>;
64
65#[derive(Clone)]
67pub enum MockResponse {
68 Scalar(ScalarValue),
70
71 Rows {
73 columns: Vec<MockColumn>,
75 rows: Vec<Vec<ScalarValue>>,
77 },
78
79 Error {
81 number: i32,
83 message: String,
85 severity: u8,
87 },
88
89 RowsAffected(u64),
91
92 Raw(Bytes),
94
95 Custom(Arc<dyn Fn(&str) -> MockResponse + Send + Sync>),
97}
98
99impl fmt::Debug for MockResponse {
100 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
101 match self {
102 Self::Scalar(v) => f.debug_tuple("Scalar").field(v).finish(),
103 Self::Rows { columns, rows } => f
104 .debug_struct("Rows")
105 .field("columns", columns)
106 .field("rows", rows)
107 .finish(),
108 Self::Error {
109 number,
110 message,
111 severity,
112 } => f
113 .debug_struct("Error")
114 .field("number", number)
115 .field("message", message)
116 .field("severity", severity)
117 .finish(),
118 Self::RowsAffected(n) => f.debug_tuple("RowsAffected").field(n).finish(),
119 Self::Raw(data) => f.debug_tuple("Raw").field(&data.len()).finish(),
120 Self::Custom(_) => f.debug_tuple("Custom").field(&"<fn>").finish(),
121 }
122 }
123}
124
125impl MockResponse {
126 pub fn scalar_int(value: i32) -> Self {
128 Self::Scalar(ScalarValue::Int(value))
129 }
130
131 pub fn scalar_string(value: impl Into<String>) -> Self {
133 Self::Scalar(ScalarValue::String(value.into()))
134 }
135
136 pub fn empty() -> Self {
138 Self::RowsAffected(0)
139 }
140
141 pub fn affected(count: u64) -> Self {
143 Self::RowsAffected(count)
144 }
145
146 pub fn error(number: i32, message: impl Into<String>) -> Self {
148 Self::Error {
149 number,
150 message: message.into(),
151 severity: 16,
152 }
153 }
154
155 pub fn rows(columns: Vec<MockColumn>, rows: Vec<Vec<ScalarValue>>) -> Self {
157 Self::Rows { columns, rows }
158 }
159}
160
/// A single cell value used in scalar and row responses.
#[derive(Debug, Clone)]
pub enum ScalarValue {
    /// SQL NULL.
    Null,
    /// Boolean value (described as BITN).
    Bool(bool),
    /// 32-bit integer (4-byte INTN).
    Int(i32),
    /// 64-bit integer (8-byte INTN).
    BigInt(i64),
    /// 32-bit float (4-byte FLTN).
    Float(f32),
    /// 64-bit float (8-byte FLTN).
    Double(f64),
    /// Unicode text, encoded as UTF-16LE NVARCHAR data.
    String(String),
    /// Raw bytes, encoded as VARBINARY data.
    Binary(Vec<u8>),
}
181
182impl ScalarValue {
183 fn type_id(&self) -> TypeId {
185 match self {
186 Self::Null => TypeId::Null,
187 Self::Bool(_) => TypeId::BitN,
188 Self::Int(_) => TypeId::IntN,
189 Self::BigInt(_) => TypeId::IntN,
190 Self::Float(_) => TypeId::FloatN,
191 Self::Double(_) => TypeId::FloatN,
192 Self::String(_) => TypeId::NVarChar,
193 Self::Binary(_) => TypeId::BigVarBinary,
194 }
195 }
196
197 fn encode(&self, dst: &mut BytesMut) {
199 match self {
200 Self::Null => {
201 dst.put_u8(0); }
203 Self::Bool(v) => {
204 dst.put_u8(1); dst.put_u8(if *v { 1 } else { 0 });
206 }
207 Self::Int(v) => {
208 dst.put_u8(4); dst.put_i32_le(*v);
210 }
211 Self::BigInt(v) => {
212 dst.put_u8(8); dst.put_i64_le(*v);
214 }
215 Self::Float(v) => {
216 dst.put_u8(4); dst.put_f32_le(*v);
218 }
219 Self::Double(v) => {
220 dst.put_u8(8); dst.put_f64_le(*v);
222 }
223 Self::String(s) => {
224 let utf16: Vec<u16> = s.encode_utf16().collect();
225 let byte_len = utf16.len() * 2;
226 if byte_len > 0xFFFF {
227 dst.put_u64_le(byte_len as u64);
229 dst.put_u32_le(byte_len as u32);
230 for c in utf16 {
231 dst.put_u16_le(c);
232 }
233 dst.put_u32_le(0); } else {
235 dst.put_u16_le(byte_len as u16);
236 for c in utf16 {
237 dst.put_u16_le(c);
238 }
239 }
240 }
241 Self::Binary(data) => {
242 if data.len() > 0xFFFF {
243 dst.put_u64_le(data.len() as u64);
245 dst.put_u32_le(data.len() as u32);
246 dst.extend_from_slice(data);
247 dst.put_u32_le(0); } else {
249 dst.put_u16_le(data.len() as u16);
250 dst.extend_from_slice(data);
251 }
252 }
253 }
254 }
255}
256
257#[derive(Debug, Clone)]
259pub struct MockColumn {
260 pub name: String,
262 pub type_id: TypeId,
264 pub max_length: Option<u32>,
266 pub nullable: bool,
268}
269
270impl MockColumn {
271 pub fn new(name: impl Into<String>, type_id: TypeId) -> Self {
273 Self {
274 name: name.into(),
275 type_id,
276 max_length: None,
277 nullable: true,
278 }
279 }
280
281 pub fn int(name: impl Into<String>) -> Self {
283 Self::new(name, TypeId::IntN).with_max_length(4)
284 }
285
286 pub fn bigint(name: impl Into<String>) -> Self {
288 Self::new(name, TypeId::IntN).with_max_length(8)
289 }
290
291 pub fn nvarchar(name: impl Into<String>, max_len: u32) -> Self {
293 Self::new(name, TypeId::NVarChar).with_max_length(max_len * 2)
294 }
295
296 pub fn with_max_length(mut self, len: u32) -> Self {
298 self.max_length = Some(len);
299 self
300 }
301
302 pub fn with_nullable(mut self, nullable: bool) -> Self {
304 self.nullable = nullable;
305 self
306 }
307}
308
309#[derive(Default)]
311pub struct MockServerConfig {
312 responses: HashMap<String, MockResponse>,
314 default_response: Option<MockResponse>,
316 server_name: String,
318 tds_version: u32,
320 database: String,
322}
323
324pub struct MockServerBuilder {
326 config: MockServerConfig,
327}
328
329impl MockServerBuilder {
330 pub fn new() -> Self {
332 Self {
333 config: MockServerConfig {
334 responses: HashMap::new(),
335 default_response: Some(MockResponse::empty()),
336 server_name: "MockSQLServer".to_string(),
337 tds_version: 0x74000004, database: "master".to_string(),
339 },
340 }
341 }
342
343 pub fn with_response(mut self, sql: impl Into<String>, response: MockResponse) -> Self {
345 self.config.responses.insert(sql.into(), response);
346 self
347 }
348
349 pub fn with_default_response(mut self, response: MockResponse) -> Self {
351 self.config.default_response = Some(response);
352 self
353 }
354
355 pub fn with_server_name(mut self, name: impl Into<String>) -> Self {
357 self.config.server_name = name.into();
358 self
359 }
360
361 pub fn with_database(mut self, db: impl Into<String>) -> Self {
363 self.config.database = db.into();
364 self
365 }
366
367 pub async fn build(self) -> Result<MockTdsServer> {
369 MockTdsServer::start(self.config).await
370 }
371}
372
373impl Default for MockServerBuilder {
374 fn default() -> Self {
375 Self::new()
376 }
377}
378
379pub struct MockTdsServer {
385 addr: SocketAddr,
387 shutdown_tx: broadcast::Sender<()>,
389 #[allow(dead_code)]
391 config: Arc<MockServerConfig>,
392 connection_count: Arc<Mutex<usize>>,
394}
395
396impl MockTdsServer {
397 pub fn builder() -> MockServerBuilder {
399 MockServerBuilder::new()
400 }
401
402 pub async fn start(config: MockServerConfig) -> Result<Self> {
404 let listener = TcpListener::bind("127.0.0.1:0").await?;
405 let addr = listener.local_addr()?;
406 let (shutdown_tx, _) = broadcast::channel(1);
407 let config = Arc::new(config);
408 let connection_count = Arc::new(Mutex::new(0usize));
409
410 let server = Self {
411 addr,
412 shutdown_tx: shutdown_tx.clone(),
413 config: config.clone(),
414 connection_count: connection_count.clone(),
415 };
416
417 let mut shutdown_rx = shutdown_tx.subscribe();
419 tokio::spawn(async move {
420 loop {
421 tokio::select! {
422 result = listener.accept() => {
423 match result {
424 Ok((stream, _peer_addr)) => {
425 let config = config.clone();
426 let count = connection_count.clone();
427 tokio::spawn(async move {
428 {
429 let mut c = count.lock().await;
430 *c += 1;
431 }
432 if let Err(e) = handle_connection(stream, config).await {
433 tracing::debug!("Connection error: {}", e);
434 }
435 {
436 let mut c = count.lock().await;
437 *c = c.saturating_sub(1);
438 }
439 });
440 }
441 Err(e) => {
442 tracing::error!("Accept error: {}", e);
443 break;
444 }
445 }
446 }
447 _ = shutdown_rx.recv() => {
448 break;
449 }
450 }
451 }
452 });
453
454 Ok(server)
455 }
456
457 pub fn addr(&self) -> SocketAddr {
459 self.addr
460 }
461
462 pub fn host(&self) -> String {
464 self.addr.ip().to_string()
465 }
466
467 pub fn port(&self) -> u16 {
469 self.addr.port()
470 }
471
472 pub async fn connection_count(&self) -> usize {
474 *self.connection_count.lock().await
475 }
476
477 pub fn stop(&self) {
479 let _ = self.shutdown_tx.send(());
480 }
481}
482
483impl Drop for MockTdsServer {
484 fn drop(&mut self) {
485 self.stop();
486 }
487}
488
489async fn handle_connection(mut stream: TcpStream, config: Arc<MockServerConfig>) -> Result<()> {
491 let prelogin_request = read_packet(&mut stream).await?;
493 if prelogin_request.packet_type != PacketType::PreLogin {
494 return Err(MockServerError::Protocol(format!(
495 "Expected PreLogin, got {:?}",
496 prelogin_request.packet_type
497 )));
498 }
499 send_prelogin_response(&mut stream).await?;
500
501 let login_request = read_packet(&mut stream).await?;
503 if login_request.packet_type != PacketType::Tds7Login {
504 return Err(MockServerError::Protocol(format!(
505 "Expected Tds7Login, got {:?}",
506 login_request.packet_type
507 )));
508 }
509 send_login_response(&mut stream, &config).await?;
510
511 loop {
513 let packet = match read_packet(&mut stream).await {
514 Ok(p) => p,
515 Err(MockServerError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
516 break;
518 }
519 Err(e) => return Err(e),
520 };
521
522 match packet.packet_type {
523 PacketType::SqlBatch => {
524 let sql = decode_sql_batch(&packet.payload)?;
525 let response = find_response(&sql, &config);
526 send_query_response(&mut stream, response).await?;
527 }
528 PacketType::Rpc => {
529 let response = config
532 .default_response
533 .clone()
534 .unwrap_or(MockResponse::empty());
535 send_query_response(&mut stream, response).await?;
536 }
537 PacketType::Attention => {
538 send_attention_ack(&mut stream).await?;
540 }
541 _ => {
542 tracing::debug!("Unexpected packet type: {:?}", packet.packet_type);
543 }
544 }
545 }
546
547 Ok(())
548}
549
550struct Packet {
552 packet_type: PacketType,
553 payload: Bytes,
554}
555
556async fn read_packet(stream: &mut TcpStream) -> Result<Packet> {
558 let mut header_buf = [0u8; PACKET_HEADER_SIZE];
559 stream.read_exact(&mut header_buf).await?;
560
561 let mut cursor = &header_buf[..];
562 let header =
563 PacketHeader::decode(&mut cursor).map_err(|e| MockServerError::Protocol(e.to_string()))?;
564
565 let payload_len = header.payload_length();
566 let mut payload = vec![0u8; payload_len];
567 if payload_len > 0 {
568 stream.read_exact(&mut payload).await?;
569 }
570
571 let mut full_payload = BytesMut::from(&payload[..]);
573
574 if !header.is_end_of_message() {
575 loop {
576 let mut next_header_buf = [0u8; PACKET_HEADER_SIZE];
577 stream.read_exact(&mut next_header_buf).await?;
578
579 let mut cursor = &next_header_buf[..];
580 let next_header = PacketHeader::decode(&mut cursor)
581 .map_err(|e| MockServerError::Protocol(e.to_string()))?;
582
583 let next_payload_len = next_header.payload_length();
584 let mut next_payload = vec![0u8; next_payload_len];
585 if next_payload_len > 0 {
586 stream.read_exact(&mut next_payload).await?;
587 }
588
589 full_payload.extend_from_slice(&next_payload);
590
591 if next_header.is_end_of_message() {
592 break;
593 }
594 }
595 }
596
597 Ok(Packet {
598 packet_type: header.packet_type,
599 payload: full_payload.freeze(),
600 })
601}
602
603async fn write_packet(
605 stream: &mut TcpStream,
606 packet_type: PacketType,
607 payload: &[u8],
608) -> Result<()> {
609 let total_len = PACKET_HEADER_SIZE + payload.len();
610 let header = PacketHeader {
611 packet_type,
612 status: PacketStatus::END_OF_MESSAGE,
613 length: total_len as u16,
614 spid: 0,
615 packet_id: 1,
616 window: 0,
617 };
618
619 let mut buf = BytesMut::with_capacity(total_len);
620 header.encode(&mut buf);
621 buf.extend_from_slice(payload);
622
623 stream.write_all(&buf).await?;
624 stream.flush().await?;
625 Ok(())
626}
627
628async fn send_prelogin_response(stream: &mut TcpStream) -> Result<()> {
630 let mut response = BytesMut::new();
645
646 response.put_u8(0x00); response.put_u16(11); response.put_u16(6); response.put_u8(0x01); response.put_u16(17); response.put_u16(1); response.put_u8(0xFF);
659
660 response.put_u8(16); response.put_u8(0); response.put_u16_le(0); response.put_u16_le(0); response.put_u8(0x00); write_packet(stream, PacketType::PreLogin, &response).await
670}
671
672async fn send_login_response(stream: &mut TcpStream, config: &MockServerConfig) -> Result<()> {
674 let mut response = BytesMut::new();
675
676 encode_env_change(&mut response, EnvChangeType::Database, &config.database, "");
678
679 encode_env_change(&mut response, EnvChangeType::PacketSize, "4096", "4096");
681
682 encode_login_ack(&mut response, &config.server_name, config.tds_version);
684
685 encode_done(&mut response, 0, false);
687
688 write_packet(stream, PacketType::TabularResult, &response).await
689}
690
691fn encode_env_change(dst: &mut BytesMut, env_type: EnvChangeType, new_val: &str, old_val: &str) {
693 let new_utf16: Vec<u16> = new_val.encode_utf16().collect();
694 let old_utf16: Vec<u16> = old_val.encode_utf16().collect();
695
696 let data_len = 1 + 1 + new_utf16.len() * 2 + 1 + old_utf16.len() * 2;
697
698 dst.put_u8(TokenType::EnvChange as u8);
699 dst.put_u16_le(data_len as u16);
700 dst.put_u8(env_type as u8);
701
702 dst.put_u8(new_utf16.len() as u8);
704 for c in &new_utf16 {
705 dst.put_u16_le(*c);
706 }
707
708 dst.put_u8(old_utf16.len() as u8);
710 for c in &old_utf16 {
711 dst.put_u16_le(*c);
712 }
713}
714
715fn encode_login_ack(dst: &mut BytesMut, server_name: &str, tds_version: u32) {
717 let name_utf16: Vec<u16> = server_name.encode_utf16().collect();
718
719 let data_len = 1 + 4 + 1 + name_utf16.len() * 2 + 4;
721
722 dst.put_u8(TokenType::LoginAck as u8);
723 dst.put_u16_le(data_len as u16);
724 dst.put_u8(1); dst.put_u32_le(tds_version);
726
727 dst.put_u8(name_utf16.len() as u8);
729 for c in &name_utf16 {
730 dst.put_u16_le(*c);
731 }
732
733 dst.put_u32_le(0x10000000); }
736
737fn encode_done(dst: &mut BytesMut, row_count: u64, more: bool) {
739 dst.put_u8(TokenType::Done as u8);
740
741 let status = DoneStatus {
742 count: row_count > 0,
743 more,
744 ..Default::default()
745 };
746
747 dst.put_u16_le(status.to_bits());
748 dst.put_u16_le(0xC1); dst.put_u64_le(row_count);
750}
751
752fn decode_sql_batch(payload: &Bytes) -> Result<String> {
754 let mut cursor = payload.as_ref();
758
759 if cursor.len() >= 4 {
761 let total_len = u32::from_le_bytes([cursor[0], cursor[1], cursor[2], cursor[3]]) as usize;
762
763 if total_len >= 4 && total_len < cursor.len() && total_len < 1000 {
765 cursor = &cursor[total_len..];
766 }
767 }
768
769 if cursor.len() % 2 != 0 {
771 return Err(MockServerError::Protocol(
772 "Invalid UTF-16 SQL text length".to_string(),
773 ));
774 }
775
776 let char_count = cursor.len() / 2;
777 let mut chars = Vec::with_capacity(char_count);
778 for i in 0..char_count {
779 let c = u16::from_le_bytes([cursor[i * 2], cursor[i * 2 + 1]]);
780 chars.push(c);
781 }
782
783 String::from_utf16(&chars)
784 .map_err(|_| MockServerError::Protocol("Invalid UTF-16 SQL text".to_string()))
785}
786
787fn find_response(sql: &str, config: &MockServerConfig) -> MockResponse {
789 let normalized = sql.trim().to_uppercase();
791
792 if let Some(response) = config.responses.get(&normalized) {
794 return response.clone();
795 }
796
797 for (key, response) in &config.responses {
799 if key.trim().to_uppercase() == normalized {
800 return response.clone();
801 }
802 }
803
804 config
806 .default_response
807 .clone()
808 .unwrap_or(MockResponse::empty())
809}
810
811async fn send_query_response(stream: &mut TcpStream, response: MockResponse) -> Result<()> {
813 let mut buf = BytesMut::new();
814
815 match response {
816 MockResponse::Scalar(value) => {
817 encode_colmetadata(&mut buf, &[MockColumn::new("", value.type_id())]);
819 encode_row(&mut buf, &[value.clone()]);
820 encode_done(&mut buf, 1, false);
821 }
822 MockResponse::Rows { columns, rows } => {
823 encode_colmetadata(&mut buf, &columns);
824 for row in &rows {
825 encode_row(&mut buf, row);
826 }
827 encode_done(&mut buf, rows.len() as u64, false);
828 }
829 MockResponse::Error {
830 number,
831 message,
832 severity,
833 } => {
834 encode_error(&mut buf, number, &message, severity);
835 encode_done(&mut buf, 0, false);
836 }
837 MockResponse::RowsAffected(count) => {
838 encode_done(&mut buf, count, false);
839 }
840 MockResponse::Raw(data) => {
841 buf.extend_from_slice(&data);
842 }
843 MockResponse::Custom(_handler) => {
844 encode_done(&mut buf, 0, false);
847 }
848 }
849
850 write_packet(stream, PacketType::TabularResult, &buf).await
851}
852
853fn encode_colmetadata(dst: &mut BytesMut, columns: &[MockColumn]) {
855 dst.put_u8(TokenType::ColMetaData as u8);
856 dst.put_u16_le(columns.len() as u16);
857
858 for col in columns {
859 dst.put_u32_le(0);
861
862 dst.put_u16_le(if col.nullable { 0x01 } else { 0x00 });
864
865 dst.put_u8(col.type_id as u8);
867
868 match col.type_id {
870 TypeId::IntN | TypeId::BitN | TypeId::FloatN | TypeId::MoneyN | TypeId::DateTimeN => {
871 dst.put_u8(col.max_length.unwrap_or(4) as u8);
872 }
873 TypeId::NVarChar | TypeId::NChar => {
874 dst.put_u16_le(col.max_length.unwrap_or(8000) as u16);
875 dst.put_u32_le(0x0904D000); dst.put_u8(0x34); }
879 TypeId::BigVarBinary | TypeId::BigBinary => {
880 dst.put_u16_le(col.max_length.unwrap_or(8000) as u16);
881 }
882 _ => {
883 }
885 }
886
887 let name_utf16: Vec<u16> = col.name.encode_utf16().collect();
889 dst.put_u8(name_utf16.len() as u8);
890 for c in &name_utf16 {
891 dst.put_u16_le(*c);
892 }
893 }
894}
895
896fn encode_row(dst: &mut BytesMut, values: &[ScalarValue]) {
898 dst.put_u8(TokenType::Row as u8);
899 for value in values {
900 value.encode(dst);
901 }
902}
903
904fn encode_error(dst: &mut BytesMut, number: i32, message: &str, severity: u8) {
906 let msg_utf16: Vec<u16> = message.encode_utf16().collect();
907 let server_utf16: Vec<u16> = "MockServer".encode_utf16().collect();
908
909 let data_len = (4 + 1 + 1 + 2 + msg_utf16.len() * 2 + 1 + server_utf16.len() * 2 + 1) + 4;
912
913 dst.put_u8(TokenType::Error as u8);
914 dst.put_u16_le(data_len as u16);
915 dst.put_i32_le(number);
916 dst.put_u8(1); dst.put_u8(severity); dst.put_u16_le(msg_utf16.len() as u16);
921 for c in &msg_utf16 {
922 dst.put_u16_le(*c);
923 }
924
925 dst.put_u8(server_utf16.len() as u8);
927 for c in &server_utf16 {
928 dst.put_u16_le(*c);
929 }
930
931 dst.put_u8(0);
933
934 dst.put_i32_le(1);
936}
937
938async fn send_attention_ack(stream: &mut TcpStream) -> Result<()> {
940 let mut buf = BytesMut::new();
941
942 buf.put_u8(TokenType::Done as u8);
944 let status = DoneStatus {
945 attn: true,
946 ..Default::default()
947 };
948 buf.put_u16_le(status.to_bits());
949 buf.put_u16_le(0);
950 buf.put_u64_le(0);
951
952 write_packet(stream, PacketType::TabularResult, &buf).await
953}
954
955#[derive(Debug, Clone)]
957pub struct RecordedPacket {
958 pub from_server: bool,
960 pub data: Bytes,
962}
963
964#[derive(Debug, Default)]
966pub struct PacketRecorder {
967 packets: Vec<RecordedPacket>,
968}
969
970impl PacketRecorder {
971 pub fn new() -> Self {
973 Self::default()
974 }
975
976 pub fn record(&mut self, from_server: bool, data: Bytes) {
978 self.packets.push(RecordedPacket { from_server, data });
979 }
980
981 pub fn packets(&self) -> &[RecordedPacket] {
983 &self.packets
984 }
985
986 pub async fn save(&self, path: &std::path::Path) -> std::io::Result<()> {
988 use tokio::fs::File;
989 use tokio::io::AsyncWriteExt;
990
991 let mut file = File::create(path).await?;
992
993 for packet in &self.packets {
994 file.write_u8(if packet.from_server { 1 } else { 0 })
996 .await?;
997 file.write_u32_le(packet.data.len() as u32).await?;
998 file.write_all(&packet.data).await?;
999 }
1000
1001 Ok(())
1002 }
1003
1004 pub async fn load(path: &std::path::Path) -> std::io::Result<Self> {
1006 use tokio::fs::File;
1007 use tokio::io::AsyncReadExt;
1008
1009 let mut file = File::open(path).await?;
1010 let mut recorder = Self::new();
1011
1012 loop {
1013 let from_server = match file.read_u8().await {
1014 Ok(b) => b != 0,
1015 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
1016 Err(e) => return Err(e),
1017 };
1018
1019 let len = file.read_u32_le().await? as usize;
1020 let mut data = vec![0u8; len];
1021 file.read_exact(&mut data).await?;
1022
1023 recorder.record(from_server, Bytes::from(data));
1024 }
1025
1026 Ok(recorder)
1027 }
1028}
1029
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_mock_server_starts() {
        let builder = MockTdsServer::builder().with_server_name("TestServer");
        let server = builder.build().await.unwrap();

        assert_eq!(server.host(), "127.0.0.1");
        assert!(server.port() > 0);
    }

    #[tokio::test]
    async fn test_mock_response_scalar() {
        let MockResponse::Scalar(ScalarValue::Int(v)) = MockResponse::scalar_int(42) else {
            panic!("Expected scalar int");
        };
        assert_eq!(v, 42);
    }

    #[tokio::test]
    async fn test_mock_response_error() {
        let MockResponse::Error {
            number,
            message,
            severity,
        } = MockResponse::error(50000, "Test error")
        else {
            panic!("Expected error response");
        };
        assert_eq!((number, message.as_str(), severity), (50000, "Test error", 16));
    }

    #[test]
    fn test_scalar_value_encode_int() {
        let mut buf = BytesMut::new();
        ScalarValue::Int(42).encode(&mut buf);

        // Length byte 4, then 42 as little-endian i32.
        assert_eq!(buf.to_vec(), vec![4, 42, 0, 0, 0]);
    }

    #[test]
    fn test_scalar_value_encode_string() {
        let mut buf = BytesMut::new();
        ScalarValue::String("test".to_string()).encode(&mut buf);

        // u16 byte length (4 chars * 2) followed by 8 bytes of UTF-16LE.
        assert_eq!(buf.len(), 10);
        assert_eq!(buf[..2], 8u16.to_le_bytes());
    }

    #[test]
    fn test_mock_column_int() {
        let col = MockColumn::int("id");
        assert_eq!(col.name, "id");
        assert_eq!(col.type_id, TypeId::IntN);
        assert_eq!(col.max_length, Some(4));
    }

    #[test]
    fn test_mock_column_nvarchar() {
        let col = MockColumn::nvarchar("name", 50);
        assert_eq!(col.name, "name");
        assert_eq!(col.type_id, TypeId::NVarChar);
        // 50 characters are stored as 100 bytes of UTF-16.
        assert_eq!(col.max_length, Some(100));
    }

    #[test]
    fn test_done_status_encoding() {
        let mut buf = BytesMut::new();
        encode_done(&mut buf, 5, false);

        assert_eq!(buf[0], TokenType::Done as u8);
        let status = u16::from_le_bytes([buf[1], buf[2]]);
        // The COUNT bit (0x0010) must be set for a non-zero row count.
        assert_ne!(status & 0x0010, 0);
    }
}