1use crate::MssqlConnectOptions;
2
3use super::packet::{encode_message, PacketFrameError, PacketType};
4use thiserror::Error;
5
/// Size in bytes of the fixed LOGIN7 header that precedes the variable-length data:
/// 36 bytes of scalar fields, nine 4-byte offset/length slots for the text fields,
/// the 6-byte client ID, three more 4-byte slots (SSPI, attach-db-file,
/// change-password) and one trailing 4-byte field that is always written as zero.
const LOGIN7_FIXED_LEN: usize = 36 + 9 * 4 + 6 + 3 * 4 + 4;

/// TDS protocol version advertised in the LOGIN7 header (TDS 7.4).
const TDS_VERSION_74: u32 = 0x74000004;

/// LOGIN7 `OptionFlags1` with bits 5, 6 and 7 set
/// (fUseDB, fDatabase, fSetLang in MS-TDS naming; verify against the spec).
const OPTION_FLAGS_1: u8 = 0b1110_0000;
/// LOGIN7 `OptionFlags2` with bits 0 and 1 set
/// (fLanguage, fODBC in MS-TDS naming; verify against the spec).
const OPTION_FLAGS_2: u8 = 0b0000_0011;
/// LOGIN7 `TypeFlags`; every bit is clear.
const TYPE_FLAGS: u8 = 0b0000_0000;
/// LOGIN7 `OptionFlags3`; every bit is clear.
const OPTION_FLAGS_3: u8 = 0b0000_0000;
13
14pub fn build_login7_payload(options: &MssqlConnectOptions) -> Result<Vec<u8>, Login7Error> {
16 let mut fields = Login7Fields::new(LOGIN7_FIXED_LEN);
17
18 let hostname = fields.push_text(options.hostname(), false)?;
19 let username = fields.push_text(options.username(), false)?;
20 let password = fields.push_text(options.password().unwrap_or_default(), true)?;
21 let app_name = fields.push_text(options.app_name(), false)?;
22 let server_name = fields.push_text(options.server_name(), false)?;
23 let unused = Login7FieldOffset::empty(fields.next_offset);
24 let client_interface_name = fields.push_text(options.client_interface_name(), false)?;
25 let language = fields.push_text(options.language(), false)?;
26 let database = fields.push_text(options.database(), false)?;
27 let sspi = Login7FieldOffset::empty(fields.next_offset);
28 let attach_db_file = Login7FieldOffset::empty(fields.next_offset);
29 let change_password = Login7FieldOffset::empty(fields.next_offset);
30
31 let total_len = u32::from(fields.next_offset);
32 let mut out = Vec::with_capacity(usize::from(fields.next_offset));
33
34 write_u32_le(&mut out, total_len);
35 write_u32_le(&mut out, TDS_VERSION_74);
36 write_u32_le(&mut out, options.requested_packet_size());
37 write_u32_le(&mut out, options.client_program_version());
38 write_u32_le(&mut out, options.client_pid());
39 write_u32_le(&mut out, 0);
40 out.extend_from_slice(&[OPTION_FLAGS_1, OPTION_FLAGS_2, TYPE_FLAGS, OPTION_FLAGS_3]);
41 write_i32_le(&mut out, 0);
42 write_u32_le(&mut out, 0);
43
44 for offset in [
45 hostname,
46 username,
47 password,
48 app_name,
49 server_name,
50 unused,
51 client_interface_name,
52 language,
53 database,
54 ] {
55 offset.write_to(&mut out);
56 }
57
58 out.extend_from_slice(&[0; 6]);
59 sspi.write_to(&mut out);
60 attach_db_file.write_to(&mut out);
61 change_password.write_to(&mut out);
62 write_u32_le(&mut out, 0);
63
64 debug_assert_eq!(LOGIN7_FIXED_LEN, out.len());
65 out.extend_from_slice(&fields.data);
66
67 Ok(out)
68}
69
70pub fn build_login7_packet(options: &MssqlConnectOptions) -> Result<Vec<u8>, Login7Error> {
72 let payload = build_login7_payload(options)?;
73
74 encode_message(
75 PacketType::LOGIN7,
76 &payload,
77 usize::try_from(options.requested_packet_size())
78 .map_err(|_| Login7Error::MessageTooLarge)?,
79 )
80 .map_err(Login7Error::Packet)
81}
82
/// Header slot describing one variable-length LOGIN7 field.
///
/// It is serialized as two little-endian `u16`s: first the data offset, then the length.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Login7FieldOffset {
    /// Byte offset of the field's data, counted from the start of the LOGIN7 message.
    offset: u16,
    /// Field length in characters (UTF-16 code units), not bytes.
    len_chars: u16,
}

impl Login7FieldOffset {
    /// Returns a zero-length slot located at `offset`.
    fn empty(offset: u16) -> Self {
        Self { offset, len_chars: 0 }
    }

    /// Appends the slot to `out`: `offset` first, then `len_chars`, each little-endian.
    fn write_to(self, out: &mut Vec<u8>) {
        let Self { offset, len_chars } = self;
        for half in [offset, len_chars] {
            out.extend_from_slice(&half.to_le_bytes());
        }
    }
}
102
103struct Login7Fields {
104 data: Vec<u8>,
105 next_offset: u16,
106}
107
108impl Login7Fields {
109 fn new(base_offset: usize) -> Self {
110 Self {
111 data: Vec::new(),
112 next_offset: u16::try_from(base_offset).expect("LOGIN7 fixed header fits in u16"),
113 }
114 }
115
116 fn push_text(
117 &mut self,
118 value: &str,
119 obfuscate: bool,
120 ) -> Result<Login7FieldOffset, Login7Error> {
121 let offset = self.next_offset;
122 let len_chars =
123 u16::try_from(value.encode_utf16().count()).map_err(|_| Login7Error::FieldTooLong)?;
124 let encoded = encode_utf16_le(value, obfuscate);
125 let encoded_len = u16::try_from(encoded.len()).map_err(|_| Login7Error::MessageTooLarge)?;
126
127 self.next_offset = self
128 .next_offset
129 .checked_add(encoded_len)
130 .ok_or(Login7Error::MessageTooLarge)?;
131 self.data.extend_from_slice(&encoded);
132
133 Ok(Login7FieldOffset { offset, len_chars })
134 }
135}
136
/// Encodes `value` as UTF-16 little-endian bytes, two per code unit.
///
/// Characters outside the BMP become surrogate pairs and take four bytes. When
/// `obfuscate` is set, every output byte is also scrambled: its high and low
/// nibbles are swapped and the result is XORed with `0xA5`. This is the scheme
/// applied to the LOGIN7 password.
fn encode_utf16_le(value: &str, obfuscate: bool) -> Vec<u8> {
    let scramble = move |byte: u8| {
        if obfuscate {
            byte.rotate_left(4) ^ 0xa5
        } else {
            byte
        }
    };

    // Twice the UTF-8 length is an upper bound on the UTF-16 byte length.
    let mut encoded = Vec::with_capacity(value.len() * 2);
    encoded.extend(
        value
            .encode_utf16()
            .flat_map(u16::to_le_bytes)
            .map(scramble),
    );
    encoded
}
152
/// Appends `value` to `out` as two little-endian bytes.
fn write_u16_le(out: &mut Vec<u8>, value: u16) {
    out.extend(value.to_le_bytes());
}
156
/// Appends `value` to `out` as four little-endian bytes.
fn write_u32_le(out: &mut Vec<u8>, value: u32) {
    out.extend(value.to_le_bytes());
}
160
/// Appends `value` to `out` as four little-endian two's-complement bytes.
fn write_i32_le(out: &mut Vec<u8>, value: i32) {
    out.extend(value.to_le_bytes());
}
164
165#[derive(Debug, Error, PartialEq, Eq)]
167pub enum Login7Error {
168 #[error("TDS LOGIN7 text field is too long")]
170 FieldTooLong,
171 #[error("TDS LOGIN7 message is too large")]
173 MessageTooLarge,
174 #[error(transparent)]
176 Packet(#[from] PacketFrameError),
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::packet::{PacketHeader, PacketStatus, PACKET_HEADER_LEN};

    #[test]
    fn builds_login7_payload_with_little_endian_fixed_fields() {
        let options = MssqlConnectOptions::parse_url(
            "mssql://alice:secret@example.com/appdb?packet_size=512&client_program_version=42&client_pid=7",
        )
        .unwrap();

        let payload = build_login7_payload(&options).unwrap();

        assert_eq!(
            payload.len() as u32,
            u32::from_le_bytes(payload[0..4].try_into().unwrap())
        );
        assert_eq!(
            TDS_VERSION_74,
            u32::from_le_bytes(payload[4..8].try_into().unwrap())
        );
        assert_eq!(512, u32::from_le_bytes(payload[8..12].try_into().unwrap()));
        assert_eq!(42, u32::from_le_bytes(payload[12..16].try_into().unwrap()));
        assert_eq!(7, u32::from_le_bytes(payload[16..20].try_into().unwrap()));
        assert_eq!(
            [OPTION_FLAGS_1, OPTION_FLAGS_2, TYPE_FLAGS, OPTION_FLAGS_3],
            payload[24..28]
        );
    }

    #[test]
    fn encodes_variable_fields_as_utf16_with_character_lengths() {
        let options = MssqlConnectOptions::parse_url(
            "mssql://al:pw@example.com/db?hostname=client&app_name=sqlx",
        )
        .unwrap();
        let payload = build_login7_payload(&options).unwrap();

        let hostname = field_at(&payload, 36);
        let username = field_at(&payload, 40);
        let password = field_at(&payload, 44);
        let app_name = field_at(&payload, 48);
        let database = field_at(&payload, 68);

        assert_eq!((94, 6), hostname);
        assert_eq!(b"c\0l\0i\0e\0n\0t\0", field_bytes(&payload, hostname));
        assert_eq!((106, 2), username);
        assert_eq!(b"a\0l\0", field_bytes(&payload, username));
        assert_eq!((114, 4), app_name);
        assert_eq!(b"s\0q\0l\0x\0", field_bytes(&payload, app_name));
        assert_eq!((122, 2), database);
        assert_eq!(b"d\0b\0", field_bytes(&payload, database));

        // Pin the obfuscated bytes themselves. Re-deriving them with
        // `encode_utf16_le` would make the check tautological. Each byte is
        // nibble-swapped, then XORed with 0xA5:
        // 'p' 0x70 -> 0x07 -> 0xA2, NUL 0x00 -> 0xA5, 'w' 0x77 -> 0x77 -> 0xD2.
        assert_eq!((110, 2), password);
        assert_eq!(b"\xa2\xa5\xd2\xa5", field_bytes(&payload, password));
        assert_ne!(b"p\0w\0", field_bytes(&payload, password));
    }

    #[test]
    fn leaves_client_id_and_trailing_slots_empty() {
        let options = MssqlConnectOptions::parse_url(
            "mssql://al:pw@example.com/db?hostname=client&app_name=sqlx",
        )
        .unwrap();
        let payload = build_login7_payload(&options).unwrap();

        // Client ID (6 bytes) follows the nine text-field slots.
        assert_eq!([0u8; 6], payload[72..78]);
        // SSPI, attach-db-file and change-password slots are empty and point at
        // the end of the message.
        for slot in [78, 82, 86] {
            assert_eq!((payload.len(), 0), field_at(&payload, slot));
        }
        // The last fixed-header field is zero.
        assert_eq!(0, u32::from_le_bytes(payload[90..94].try_into().unwrap()));
    }

    #[test]
    fn counts_surrogate_pairs_as_two_characters() {
        let mut fields = Login7Fields::new(LOGIN7_FIXED_LEN);
        let field = fields.push_text("\u{1F600}", false).unwrap();

        assert_eq!(
            Login7FieldOffset {
                offset: 94,
                len_chars: 2
            },
            field
        );
        assert_eq!(vec![0x3d, 0xd8, 0x00, 0xde], fields.data);
        assert_eq!(98, fields.next_offset);
    }

    #[test]
    fn frames_login7_payload_as_login7_packet() {
        let options = MssqlConnectOptions::parse_url(
            "mssql://alice:secret@example.com/master?packet_size=512",
        )
        .unwrap();
        let packet = build_login7_packet(&options).unwrap();
        let header = PacketHeader::decode(&packet[..PACKET_HEADER_LEN]).unwrap();

        assert_eq!(PacketType::LOGIN7, header.packet_type);
        assert_eq!(PacketStatus::END_OF_MESSAGE, header.status);
        assert_eq!(packet.len(), usize::from(header.length));
        assert_eq!(
            packet.len() - PACKET_HEADER_LEN,
            u32::from_le_bytes(
                packet[PACKET_HEADER_LEN..PACKET_HEADER_LEN + 4]
                    .try_into()
                    .unwrap()
            ) as usize
        );
    }

    #[test]
    fn rejects_text_fields_that_do_not_fit_login7_lengths() {
        let mut options = MssqlConnectOptions::new();
        options.set_hostname_for_test("a".repeat(usize::from(u16::MAX) + 1));

        let err = build_login7_payload(&options).unwrap_err();

        assert_eq!(Login7Error::FieldTooLong, err);
    }

    /// Reads the `(offset, len_chars)` slot stored at byte `offset` of `payload`.
    fn field_at(payload: &[u8], offset: usize) -> (usize, usize) {
        let start = usize::from(u16::from_le_bytes(
            payload[offset..offset + 2].try_into().unwrap(),
        ));
        let len_chars = usize::from(u16::from_le_bytes(
            payload[offset + 2..offset + 4].try_into().unwrap(),
        ));

        (start, len_chars)
    }

    /// Returns the UTF-16LE bytes that a `(offset, len_chars)` slot points at.
    fn field_bytes(payload: &[u8], field: (usize, usize)) -> &[u8] {
        &payload[field.0..field.0 + field.1 * 2]
    }
}