1use std::io::{BufReader, Read, Write};
54use std::net::{TcpStream, ToSocketAddrs};
55use std::path::{Path, PathBuf};
56use std::sync::atomic::{AtomicU64, Ordering};
57use std::sync::{Arc, Mutex};
58use std::time::Duration;
59
60use prost::Message;
61use rustls::pki_types::{CertificateDer, ServerName};
62use rustls::{ClientConfig, ClientConnection, StreamOwned};
63
64use crate::block_store::BlockStore;
65use crate::crypto;
66use crate::error::{FsError, FsResult};
67use crate::proto;
68
/// Largest number of requests written to the connection before their
/// responses are read back by the pipelined read/write helpers.
const PIPELINE_BATCH: usize = 64;
74
/// Connection settings for a [`NetworkBlockStore`].
///
/// Values are created by [`NetworkBlockStoreConfig::new`] and adjusted with
/// the consuming builder methods on this type.
pub struct NetworkBlockStoreConfig {
    /// Server address; resolved through `ToSocketAddrs` when connecting.
    addr: String,
    /// Name handed to TLS as the server name for the connection.
    server_name: String,
    /// Path of the PEM file whose certificates become the trusted roots.
    ca_cert: PathBuf,
    /// 32-byte token sent in the `Authenticate` request.
    auth_token: [u8; 32],
    /// Timeout applied when opening the TCP connection.
    connect_timeout: Duration,
    /// Read and write timeout set on the TCP socket.
    io_timeout: Duration,
}
86
87impl NetworkBlockStoreConfig {
88 pub fn new(addr: impl Into<String>, server_name: impl Into<String>) -> Self {
91 Self {
92 addr: addr.into(),
93 server_name: server_name.into(),
94 ca_cert: PathBuf::new(),
95 auth_token: [0u8; 32],
96 connect_timeout: Duration::from_secs(10),
97 io_timeout: Duration::from_secs(30),
98 }
99 }
100
101 pub fn ca_cert(mut self, path: impl Into<PathBuf>) -> Self {
102 self.ca_cert = path.into();
103 self
104 }
105
106 pub fn auth_token(mut self, master_key: &[u8]) -> Self {
108 self.auth_token = crypto::derive_auth_token(master_key)
109 .expect("HKDF auth-token derivation should not fail with valid key material");
110 self
111 }
112
113 pub fn auth_token_raw(mut self, token: [u8; 32]) -> Self {
115 self.auth_token = token;
116 self
117 }
118
119 pub fn connect_timeout(mut self, d: Duration) -> Self {
120 self.connect_timeout = d;
121 self
122 }
123
124 pub fn io_timeout(mut self, d: Duration) -> Self {
125 self.io_timeout = d;
126 self
127 }
128}
129
130pub struct NetworkBlockStore {
141 config: NetworkBlockStoreConfig,
142 tls_config: Arc<ClientConfig>,
143 stream: Mutex<Option<StreamOwned<ClientConnection, TcpStream>>>,
144 block_size: usize,
145 total_blocks: u64,
146 next_request_id: AtomicU64,
147}
148
149impl NetworkBlockStore {
150 pub fn connect(
153 addr: &str,
154 server_name: &str,
155 ca_cert: &Path,
156 master_key: &[u8],
157 ) -> FsResult<Self> {
158 Self::from_config(
159 NetworkBlockStoreConfig::new(addr, server_name)
160 .ca_cert(ca_cert)
161 .auth_token(master_key),
162 )
163 }
164
165 pub fn from_config(config: NetworkBlockStoreConfig) -> FsResult<Self> {
167 let tls_config = build_client_tls_config(&config.ca_cert)?;
168
169 let mut stream = establish_connection(&config, &tls_config)?;
171
172 authenticate(&mut stream, &config.auth_token)?;
174
175 let req = proto::Request {
177 request_id: 2,
178 command: Some(proto::request::Command::GetInfo(proto::GetInfoRequest {})),
179 };
180 send_message(&mut stream, &req)?;
181 let resp = recv_message(&mut stream)?;
182
183 let (block_size, total_blocks) = match resp.result {
184 Some(proto::response::Result::GetInfo(info)) => {
185 (info.block_size as usize, info.total_blocks)
186 }
187 Some(proto::response::Result::Error(e)) => {
188 return Err(FsError::Internal(format!(
189 "server error on GetInfo: {}",
190 e.message
191 )))
192 }
193 _ => return Err(FsError::Internal("unexpected response to GetInfo".into())),
194 };
195
196 Ok(Self {
197 config,
198 tls_config,
199 stream: Mutex::new(Some(stream)),
200 block_size,
201 total_blocks,
202 next_request_id: AtomicU64::new(3),
203 })
204 }
205
206 fn next_id(&self) -> u64 {
208 self.next_request_id.fetch_add(1, Ordering::Relaxed)
209 }
210
211 fn reconnect(&self) -> FsResult<StreamOwned<ClientConnection, TcpStream>> {
214 let mut stream = establish_connection(&self.config, &self.tls_config)?;
215 authenticate(&mut stream, &self.config.auth_token)?;
216 Ok(stream)
217 }
218
219 fn roundtrip(&self, req: &proto::Request) -> FsResult<proto::Response> {
222 let mut guard = self
223 .stream
224 .lock()
225 .map_err(|e| FsError::Internal(e.to_string()))?;
226
227 if guard.is_none() {
229 *guard = Some(self.reconnect()?);
230 }
231
232 let stream = guard.as_mut().unwrap();
233 match send_and_recv(stream, req) {
234 Ok(resp) => Ok(resp),
235 Err(_) => {
236 *guard = Some(self.reconnect()?);
238 send_and_recv(guard.as_mut().unwrap(), req)
239 }
240 }
241 }
242
243 fn pipeline_reads(
247 &self,
248 stream: &mut StreamOwned<ClientConnection, TcpStream>,
249 block_ids: &[u64],
250 ) -> FsResult<Vec<Vec<u8>>> {
251 let mut results = Vec::with_capacity(block_ids.len());
252
253 for chunk in block_ids.chunks(PIPELINE_BATCH) {
254 for &block_id in chunk {
256 let id = self.next_id();
257 send_message(
258 stream,
259 &proto::Request {
260 request_id: id,
261 command: Some(proto::request::Command::ReadBlock(
262 proto::ReadBlockRequest { block_id },
263 )),
264 },
265 )?;
266 }
267
268 for _ in chunk {
270 let resp = recv_message(stream)?;
271 match resp.result {
272 Some(proto::response::Result::ReadBlock(r)) => results.push(r.data),
273 Some(proto::response::Result::Error(e)) => {
274 return Err(FsError::Internal(format!("server: {}", e.message)));
275 }
276 _ => return Err(FsError::Internal("unexpected response".into())),
277 }
278 }
279 }
280
281 Ok(results)
282 }
283
284 fn pipeline_writes(
286 &self,
287 stream: &mut StreamOwned<ClientConnection, TcpStream>,
288 blocks: &[(u64, &[u8])],
289 ) -> FsResult<()> {
290 for chunk in blocks.chunks(PIPELINE_BATCH) {
291 for &(block_id, data) in chunk {
292 let id = self.next_id();
293 send_message(
294 stream,
295 &proto::Request {
296 request_id: id,
297 command: Some(proto::request::Command::WriteBlock(
298 proto::WriteBlockRequest {
299 block_id,
300 data: data.to_vec(),
301 },
302 )),
303 },
304 )?;
305 }
306
307 for _ in chunk {
308 let resp = recv_message(stream)?;
309 match resp.result {
310 Some(proto::response::Result::WriteBlock(_)) => {}
311 Some(proto::response::Result::Error(e)) => {
312 return Err(FsError::Internal(format!("server: {}", e.message)));
313 }
314 _ => return Err(FsError::Internal("unexpected response".into())),
315 }
316 }
317 }
318
319 Ok(())
320 }
321
322 fn with_pipeline<F, T>(&self, op: F) -> FsResult<T>
324 where
325 F: Fn(&Self, &mut StreamOwned<ClientConnection, TcpStream>) -> FsResult<T>,
326 {
327 let mut guard = self
328 .stream
329 .lock()
330 .map_err(|e| FsError::Internal(e.to_string()))?;
331
332 if guard.is_none() {
333 *guard = Some(self.reconnect()?);
334 }
335
336 match op(self, guard.as_mut().unwrap()) {
337 Ok(v) => Ok(v),
338 Err(_) => {
339 *guard = Some(self.reconnect()?);
340 op(self, guard.as_mut().unwrap())
341 }
342 }
343 }
344}
345
346impl BlockStore for NetworkBlockStore {
347 fn block_size(&self) -> usize {
348 self.block_size
349 }
350
351 fn total_blocks(&self) -> u64 {
352 self.total_blocks
353 }
354
355 fn read_block(&self, block_id: u64) -> FsResult<Vec<u8>> {
356 let id = self.next_id();
357 let req = proto::Request {
358 request_id: id,
359 command: Some(proto::request::Command::ReadBlock(
360 proto::ReadBlockRequest { block_id },
361 )),
362 };
363 let resp = self.roundtrip(&req)?;
364
365 match resp.result {
366 Some(proto::response::Result::ReadBlock(r)) => Ok(r.data),
367 Some(proto::response::Result::Error(e)) => {
368 Err(FsError::Internal(format!("server: {}", e.message)))
369 }
370 _ => Err(FsError::Internal("unexpected response".into())),
371 }
372 }
373
374 fn write_block(&self, block_id: u64, data: &[u8]) -> FsResult<()> {
375 let id = self.next_id();
376 let req = proto::Request {
377 request_id: id,
378 command: Some(proto::request::Command::WriteBlock(
379 proto::WriteBlockRequest {
380 block_id,
381 data: data.to_vec(),
382 },
383 )),
384 };
385 let resp = self.roundtrip(&req)?;
386
387 match resp.result {
388 Some(proto::response::Result::WriteBlock(_)) => Ok(()),
389 Some(proto::response::Result::Error(e)) => {
390 Err(FsError::Internal(format!("server: {}", e.message)))
391 }
392 _ => Err(FsError::Internal("unexpected response".into())),
393 }
394 }
395
396 fn sync(&self) -> FsResult<()> {
397 let id = self.next_id();
398 let req = proto::Request {
399 request_id: id,
400 command: Some(proto::request::Command::Sync(proto::SyncRequest {})),
401 };
402 let resp = self.roundtrip(&req)?;
403
404 match resp.result {
405 Some(proto::response::Result::Sync(_)) => Ok(()),
406 Some(proto::response::Result::Error(e)) => {
407 Err(FsError::Internal(format!("server: {}", e.message)))
408 }
409 _ => Err(FsError::Internal("unexpected response".into())),
410 }
411 }
412
413 fn read_blocks(&self, block_ids: &[u64]) -> FsResult<Vec<Vec<u8>>> {
414 if block_ids.is_empty() {
415 return Ok(Vec::new());
416 }
417 self.with_pipeline(|s, stream| s.pipeline_reads(stream, block_ids))
418 }
419
420 fn write_blocks(&self, blocks: &[(u64, &[u8])]) -> FsResult<()> {
421 if blocks.is_empty() {
422 return Ok(());
423 }
424 self.with_pipeline(|s, stream| s.pipeline_writes(stream, blocks))
425 }
426}
427
428fn authenticate(
432 stream: &mut StreamOwned<ClientConnection, TcpStream>,
433 auth_token: &[u8; 32],
434) -> FsResult<()> {
435 let req = proto::Request {
436 request_id: 2,
437 command: Some(proto::request::Command::Authenticate(
438 proto::AuthenticateRequest {
439 auth_token: auth_token.to_vec(),
440 },
441 )),
442 };
443 send_message(stream, &req)?;
444 let resp = recv_message(stream)?;
445
446 match resp.result {
447 Some(proto::response::Result::Authenticate(_)) => Ok(()),
448 Some(proto::response::Result::Error(e)) => Err(FsError::Internal(format!(
449 "authentication failed: {}",
450 e.message
451 ))),
452 _ => Err(FsError::Internal(
453 "unexpected response to Authenticate".into(),
454 )),
455 }
456}
457
458fn send_message<W: Write>(w: &mut W, msg: &proto::Request) -> FsResult<()> {
461 let payload = msg.encode_to_vec();
462 let len = payload.len() as u32;
463 w.write_all(&len.to_le_bytes())
464 .map_err(|e| FsError::Internal(format!("write length prefix: {e}")))?;
465 w.write_all(&payload)
466 .map_err(|e| FsError::Internal(format!("write payload: {e}")))?;
467 w.flush()
468 .map_err(|e| FsError::Internal(format!("flush: {e}")))?;
469 Ok(())
470}
471
472fn recv_message<R: Read>(r: &mut R) -> FsResult<proto::Response> {
473 let mut len_buf = [0u8; 4];
474 r.read_exact(&mut len_buf)
475 .map_err(|e| FsError::Internal(format!("read length prefix: {e}")))?;
476 let len = u32::from_le_bytes(len_buf) as usize;
477
478 if len > 16 * 1024 * 1024 {
479 return Err(FsError::Internal(format!(
480 "response too large: {len} bytes"
481 )));
482 }
483
484 let mut buf = vec![0u8; len];
485 r.read_exact(&mut buf)
486 .map_err(|e| FsError::Internal(format!("read payload: {e}")))?;
487
488 proto::Response::decode(&*buf).map_err(|e| FsError::Internal(format!("decode response: {e}")))
489}
490
491fn send_and_recv(
492 stream: &mut StreamOwned<ClientConnection, TcpStream>,
493 req: &proto::Request,
494) -> FsResult<proto::Response> {
495 send_message(stream, req)?;
496 recv_message(stream)
497}
498
499fn establish_connection(
502 config: &NetworkBlockStoreConfig,
503 tls_config: &Arc<ClientConfig>,
504) -> FsResult<StreamOwned<ClientConnection, TcpStream>> {
505 let addr = config
506 .addr
507 .to_socket_addrs()
508 .map_err(|e| FsError::Internal(format!("resolve {}: {e}", config.addr)))?
509 .next()
510 .ok_or_else(|| FsError::Internal(format!("no addresses for {}", config.addr)))?;
511
512 let tcp = TcpStream::connect_timeout(&addr, config.connect_timeout)
513 .map_err(|e| FsError::Internal(format!("connect to {}: {e}", config.addr)))?;
514
515 tcp.set_read_timeout(Some(config.io_timeout))
516 .map_err(|e| FsError::Internal(format!("set read timeout: {e}")))?;
517 tcp.set_write_timeout(Some(config.io_timeout))
518 .map_err(|e| FsError::Internal(format!("set write timeout: {e}")))?;
519
520 let sni = ServerName::try_from(config.server_name.clone())
521 .map_err(|e| FsError::Internal(format!("invalid SNI '{}': {e}", config.server_name)))?;
522
523 let tls_conn = ClientConnection::new(Arc::clone(tls_config), sni)
524 .map_err(|e| FsError::Internal(format!("TLS connection: {e}")))?;
525
526 Ok(StreamOwned::new(tls_conn, tcp))
527}
528
529fn build_client_tls_config(ca_path: &Path) -> FsResult<Arc<ClientConfig>> {
532 let ca_pem = std::fs::read(ca_path)
533 .map_err(|e| FsError::Internal(format!("read CA cert {}: {e}", ca_path.display())))?;
534 let ca_certs: Vec<CertificateDer<'static>> =
535 rustls_pemfile::certs(&mut BufReader::new(&*ca_pem))
536 .collect::<std::result::Result<Vec<_>, _>>()
537 .map_err(|e| FsError::Internal(format!("parse CA certs: {e}")))?;
538
539 let mut root_store = rustls::RootCertStore::empty();
540 for cert in ca_certs {
541 root_store
542 .add(cert)
543 .map_err(|e| FsError::Internal(format!("add CA cert: {e}")))?;
544 }
545
546 let config = ClientConfig::builder()
547 .with_root_certificates(root_store)
548 .with_no_client_auth();
549
550 Ok(Arc::new(config))
551}