1use std::io::{Read, Write};
9use std::net::{SocketAddr, TcpStream};
10use std::time::{Duration, Instant, SystemTime};
11
12use super::common::{filetime_to_system_time, system_time_to_us};
13use crate::time_src::{OffsetMicros, TimeSource, TimeSourceError};
14
/// Stateless time source that estimates clock offset via an SMB2 NEGOTIATE exchange.
///
/// The [`TimeSource`] implementation in this file always contacts port 445 on
/// the target's IP address and reads the `SystemTime` field of the server's
/// NEGOTIATE response.
pub struct SmbSource;
16
/// Value written into the 4-byte capabilities slot (body offset 8) of the
/// NEGOTIATE request built by `build_negotiate_request`.
///
/// `0x7F` sets the seven lowest bits.
/// NOTE(review): these presumably map to the `SMB2_GLOBAL_CAP_*` flags listed
/// in MS-SMB2 section 2.2.3 (DFS through ENCRYPTION) — confirm against the spec.
const SMB2_CAPABILITIES: u32 = 0x7F;
19
20struct FieldReader<'a> {
22 buf: &'a [u8],
23 pos: usize,
24}
25
26impl<'a> FieldReader<'a> {
27 fn new(buf: &'a [u8]) -> Self {
28 Self { buf, pos: 0 }
29 }
30
31 fn read_u16_le(&mut self) -> Result<u16, TimeSourceError> {
32 let b = self.next_bytes(2)?;
33 Ok(u16::from_le_bytes([b[0], b[1]]))
34 }
35
36 fn read_u32_le(&mut self) -> Result<u32, TimeSourceError> {
37 let b = self.next_bytes(4)?;
38 Ok(u32::from_le_bytes([b[0], b[1], b[2], b[3]]))
39 }
40
41 fn read_u64_le(&mut self) -> Result<u64, TimeSourceError> {
42 let b = self.next_bytes(8)?;
43 Ok(u64::from_le_bytes([
44 b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7],
45 ]))
46 }
47
48 fn skip(&mut self, n: usize) -> Result<(), TimeSourceError> {
49 self.next_bytes(n)?;
50 Ok(())
51 }
52
53 fn next_bytes(&mut self, n: usize) -> Result<&'a [u8], TimeSourceError> {
54 let end = self
55 .pos
56 .checked_add(n)
57 .ok_or_else(|| TimeSourceError::Parse("FieldReader overflow".into()))?;
58 if end > self.buf.len() {
59 return Err(TimeSourceError::Parse("SMB body overruns buffer".into()));
60 }
61 let b = &self.buf[self.pos..end];
62 self.pos = end;
63 Ok(b)
64 }
65}
66
67impl TimeSource for SmbSource {
68 fn name(&self) -> &'static str {
69 "smb"
70 }
71
72 fn fetch(
73 &self,
74 target: SocketAddr,
75 timeout: Duration,
76 ) -> Result<OffsetMicros, TimeSourceError> {
77 let smb_addr: SocketAddr = (target.ip(), 445).into();
78 fetch_smb(smb_addr, timeout)
79 }
80}
81
82fn fetch_smb(addr: SocketAddr, timeout: Duration) -> Result<OffsetMicros, TimeSourceError> {
83 let mut stream = TcpStream::connect_timeout(&addr, timeout).map_err(map_io_err)?;
84 stream
85 .set_read_timeout(Some(timeout))
86 .map_err(|e| TimeSourceError::Protocol(e.to_string()))?;
87 stream
88 .set_write_timeout(Some(timeout))
89 .map_err(|e| TimeSourceError::Protocol(e.to_string()))?;
90
91 let t_send = Instant::now();
92 let t_send_sys = SystemTime::now();
93
94 let request = build_negotiate_request();
95 stream
96 .write_all(&request)
97 .map_err(|e| TimeSourceError::Protocol(e.to_string()))?;
98
99 let mut nb_header = [0u8; 4];
101 stream.read_exact(&mut nb_header).map_err(map_io_err)?;
102 let msg_len = u32::from_be_bytes(nb_header) & 0x00FF_FFFF;
104 if msg_len > 65536 {
105 return Err(TimeSourceError::Protocol(format!(
106 "implausibly large SMB2 response: {} bytes",
107 msg_len
108 )));
109 }
110 if msg_len < 64 + 65 {
111 return Err(TimeSourceError::Parse(format!(
112 "SMB2 response too short: {} bytes",
113 msg_len
114 )));
115 }
116
117 let mut body = vec![0u8; msg_len as usize];
118 stream.read_exact(&mut body).map_err(map_io_err)?;
119
120 let rtt = t_send.elapsed();
121
122 let negotiate = &body[64..];
124 let server_time = parse_negotiate_response(negotiate)?;
125
126 let t_mid_us = system_time_to_us(t_send_sys)? + (rtt.as_micros() as i64) / 2;
129 let server_us = system_time_to_us(server_time)?;
130
131 Ok(server_us - t_mid_us)
132}
133
134fn build_negotiate_request() -> Vec<u8> {
136 let dialects: &[u16] = &[0x0300, 0x0210, 0x0202];
138 let dialect_count = dialects.len() as u16;
139
140 let body_size = 2 + 2 + 2 + 2 + 4 + 16 + 8 + (2 * dialect_count as usize);
145 let smb2_header_size = 64usize;
146 let total = smb2_header_size + body_size;
147
148 let mut pkt = vec![0u8; 4 + total]; pkt[1] = ((total >> 16) & 0xFF) as u8;
152 pkt[2] = ((total >> 8) & 0xFF) as u8;
153 pkt[3] = (total & 0xFF) as u8;
154
155 let h = &mut pkt[4..4 + smb2_header_size];
156 h[0..4].copy_from_slice(b"\xfeSMB");
158 h[4..6].copy_from_slice(&64u16.to_le_bytes());
160 h[12..14].copy_from_slice(&0u16.to_le_bytes());
162 h[18..20].copy_from_slice(&1u16.to_le_bytes());
165 h[28..36].copy_from_slice(&1u64.to_le_bytes());
167
168 let b = &mut pkt[4 + smb2_header_size..];
169 b[0..2].copy_from_slice(&36u16.to_le_bytes());
171 b[2..4].copy_from_slice(&dialect_count.to_le_bytes());
173 b[4..6].copy_from_slice(&1u16.to_le_bytes());
175 b[8..12].copy_from_slice(&SMB2_CAPABILITIES.to_le_bytes());
176 let mut guid = [0u8; 16];
178 for b_out in guid.iter_mut() {
179 *b_out = rand::random();
180 }
181 guid[6] = (guid[6] & 0x0F) | 0x40; guid[8] = (guid[8] & 0x3F) | 0x80; b[12..28].copy_from_slice(&guid);
184 for (i, &d) in dialects.iter().enumerate() {
186 let off = 36 + i * 2;
187 b[off..off + 2].copy_from_slice(&d.to_le_bytes());
188 }
189
190 pkt
191}
192
193fn parse_negotiate_response(b: &[u8]) -> Result<SystemTime, TimeSourceError> {
195 let mut r = FieldReader::new(b);
196 let structure_size = r.read_u16_le()?; let _security_mode = r.read_u16_le()?; let _dialect_revision = r.read_u16_le()?; let _negotiate_ctx_cnt = r.read_u16_le()?; r.skip(16)?; let _capabilities = r.read_u32_le()?; let _max_transact = r.read_u32_le()?; let _max_read = r.read_u32_le()?; let _max_write = r.read_u32_le()?; let system_time = r.read_u64_le()?; if structure_size != 65 {
209 return Err(TimeSourceError::Protocol(format!(
210 "unexpected SMB2 NEGOTIATE_RESPONSE StructureSize: {}",
211 structure_size
212 )));
213 }
214
215 filetime_to_system_time(system_time)
216}
217
218fn map_io_err(e: std::io::Error) -> TimeSourceError {
219 use std::io::ErrorKind::*;
220 match e.kind() {
221 TimedOut | WouldBlock => TimeSourceError::Timeout,
222 ConnectionRefused => TimeSourceError::Refused,
223 _ => TimeSourceError::Protocol(e.to_string()),
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Read, Write};
    use std::net::TcpListener;
    use std::thread;
    use std::time::{Duration, UNIX_EPOCH};

    /// FILETIME value of 1970-01-01T00:00:00Z.
    const FILETIME_UNIX_EPOCH: u64 = 116_444_736_000_000_000;
    /// FILETIME value of 2024-01-01T00:00:00Z.
    const FILETIME_2024_01_01: u64 = 133_485_408_000_000_000;

    /// Builds a 65-byte NEGOTIATE response body with the given `StructureSize`
    /// (offset 0) and `SystemTime` (offset 40); every other byte is zero.
    fn negotiate_body(structure_size: u16, filetime: u64) -> Vec<u8> {
        let mut b = vec![0u8; 65];
        b[0..2].copy_from_slice(&structure_size.to_le_bytes());
        b[40..48].copy_from_slice(&filetime.to_le_bytes());
        b
    }

    #[test]
    fn filetime_unix_epoch() {
        let st = filetime_to_system_time(FILETIME_UNIX_EPOCH).unwrap();
        assert_eq!(st, UNIX_EPOCH);
    }

    #[test]
    fn filetime_2024_01_01() {
        let st = filetime_to_system_time(FILETIME_2024_01_01).unwrap();
        let unix_secs = st.duration_since(UNIX_EPOCH).unwrap().as_secs();
        assert_eq!(unix_secs, 1_704_067_200);
    }

    #[test]
    fn filetime_before_unix_epoch_errors() {
        assert!(filetime_to_system_time(0).is_err());
        assert!(filetime_to_system_time(100).is_err());
    }

    #[test]
    fn negotiate_response_too_short() {
        match parse_negotiate_response(&[0u8; 10]) {
            Err(TimeSourceError::Parse(_)) => {}
            other => panic!("expected Parse error, got {:?}", other),
        }
    }

    #[test]
    fn negotiate_response_bad_structure_size() {
        // Use a valid SystemTime so that the StructureSize check is the only
        // thing that can reject this body.
        let b = negotiate_body(99, FILETIME_2024_01_01);
        match parse_negotiate_response(&b) {
            Err(TimeSourceError::Protocol(_)) => {}
            other => panic!("expected Protocol error, got {:?}", other),
        }
    }

    #[test]
    fn negotiate_response_valid_system_time() {
        let b = negotiate_body(65, FILETIME_2024_01_01);
        let st = parse_negotiate_response(&b).unwrap();
        let unix_secs = st.duration_since(UNIX_EPOCH).unwrap().as_secs();
        assert_eq!(unix_secs, 1_704_067_200);
    }

    #[test]
    fn build_negotiate_request_has_random_guid() {
        let r1 = build_negotiate_request();
        let r2 = build_negotiate_request();
        // ClientGuid: 4 (length prefix) + 64 (header) + 12 (body offset).
        assert_ne!(
            &r1[80..96],
            &r2[80..96],
            "ClientGuid must differ between calls"
        );
        assert_ne!(&r1[80..96], &[0u8; 16]);
    }

    #[test]
    fn build_negotiate_request_framing() {
        let r = build_negotiate_request();
        // 4-byte prefix + 64-byte header + 36-byte fixed body + 3 dialects.
        assert_eq!(r.len(), 4 + 64 + 36 + 6);

        let declared = u32::from_be_bytes([r[0], r[1], r[2], r[3]]) as usize;
        assert_eq!(declared, r.len() - 4);

        assert_eq!(&r[4..8], b"\xfeSMB");
        assert_eq!(u16::from_le_bytes([r[8], r[9]]), 64);
        assert_eq!(u16::from_le_bytes([r[68], r[69]]), 36);
        assert_eq!(u16::from_le_bytes([r[70], r[71]]), 3);
    }

    #[test]
    fn fetch_smb_rejects_large_msg_len() {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        let server = thread::spawn(move || {
            let (mut conn, _) = listener.accept().unwrap();
            // Drain the whole request first; closing a socket with unread
            // data could reset the connection before the client reads.
            let mut prefix = [0u8; 4];
            conn.read_exact(&mut prefix).unwrap();
            let len = (u32::from_be_bytes(prefix) & 0x00FF_FFFF) as usize;
            let mut request = vec![0u8; len];
            conn.read_exact(&mut request).unwrap();
            // Announce a 131072-byte message, above the 65536-byte limit.
            conn.write_all(&[0x00, 0x02, 0x00, 0x00]).unwrap();
        });

        let outcome = fetch_smb(addr, Duration::from_secs(5)).err();
        server.join().unwrap();

        match outcome {
            Some(TimeSourceError::Protocol(msg)) => assert!(
                msg.contains("implausibly large"),
                "unexpected message: {}",
                msg
            ),
            other => panic!("expected Protocol error, got {:?}", other),
        }
    }
}