1use std::io::{BufReader, Read, Write};
54use std::net::{TcpStream, ToSocketAddrs};
55use std::path::{Path, PathBuf};
56use std::sync::atomic::{AtomicU64, Ordering};
57use std::sync::{Arc, Mutex};
58use std::time::Duration;
59
60use prost::Message;
61use rustls::pki_types::{CertificateDer, ServerName};
62use rustls::{ClientConfig, ClientConnection, StreamOwned};
63
64use crate::block_store::BlockStore;
65use crate::crypto;
66use crate::error::{FsError, FsResult};
67use crate::proto;
68
/// Upper bound on how many requests `read_blocks` / `write_blocks` write
/// back-to-back before stopping to read that batch's responses.
const PIPELINE_BATCH: usize = 64;
74
/// Connection settings for a [`NetworkBlockStore`], assembled builder-style.
pub struct NetworkBlockStoreConfig {
    /// Address handed to `to_socket_addrs` when connecting (e.g. `host:port`).
    addr: String,
    /// Name used for TLS SNI and server-certificate verification.
    server_name: String,
    /// Timeout applied to establishing the TCP connection.
    connect_timeout: Duration,
    /// Read and write timeout set on the TCP socket.
    io_timeout: Duration,
    /// PEM file of trusted CA certificates; an empty path means "use the
    /// system trust store".
    ca_cert: PathBuf,
    /// 32-byte token sent in the `Authenticate` request.
    auth_token: [u8; 32],
}
86
87impl NetworkBlockStoreConfig {
88 pub fn new(addr: impl Into<String>, server_name: impl Into<String>) -> Self {
91 Self {
92 addr: addr.into(),
93 server_name: server_name.into(),
94 ca_cert: PathBuf::new(),
95 auth_token: [0u8; 32],
96 connect_timeout: Duration::from_secs(10),
97 io_timeout: Duration::from_secs(30),
98 }
99 }
100
101 pub fn ca_cert(mut self, path: impl Into<PathBuf>) -> Self {
102 self.ca_cert = path.into();
103 self
104 }
105
106 pub fn auth_token(mut self, master_key: &[u8]) -> Self {
108 self.auth_token = crypto::derive_auth_token(master_key)
109 .expect("HKDF auth-token derivation should not fail with valid key material");
110 self
111 }
112
113 pub fn auth_token_raw(mut self, token: [u8; 32]) -> Self {
115 self.auth_token = token;
116 self
117 }
118
119 pub fn connect_timeout(mut self, d: Duration) -> Self {
120 self.connect_timeout = d;
121 self
122 }
123
124 pub fn io_timeout(mut self, d: Duration) -> Self {
125 self.io_timeout = d;
126 self
127 }
128}
129
130pub struct NetworkBlockStore {
141 config: NetworkBlockStoreConfig,
142 tls_config: Arc<ClientConfig>,
143 stream: Mutex<Option<StreamOwned<ClientConnection, TcpStream>>>,
144 block_size: usize,
145 total_blocks: u64,
146 next_request_id: AtomicU64,
147}
148
149impl NetworkBlockStore {
150 pub fn connect(
153 addr: &str,
154 server_name: &str,
155 ca_cert: &Path,
156 master_key: &[u8],
157 ) -> FsResult<Self> {
158 Self::from_config(
159 NetworkBlockStoreConfig::new(addr, server_name)
160 .ca_cert(ca_cert)
161 .auth_token(master_key),
162 )
163 }
164
165 pub fn from_config(config: NetworkBlockStoreConfig) -> FsResult<Self> {
167 let tls_config = build_client_tls_config(&config.ca_cert)?;
168
169 let mut stream = establish_connection(&config, &tls_config)?;
171
172 authenticate(&mut stream, &config.auth_token)?;
174
175 let req = proto::Request {
177 request_id: 2,
178 command: Some(proto::request::Command::GetInfo(proto::GetInfoRequest {})),
179 };
180 send_message(&mut stream, &req)?;
181 let resp = recv_message(&mut stream)?;
182
183 let (block_size, total_blocks) = match resp.result {
184 Some(proto::response::Result::GetInfo(info)) => {
185 (info.block_size as usize, info.total_blocks)
186 }
187 Some(proto::response::Result::Error(e)) => {
188 return Err(FsError::Internal(format!(
189 "server error on GetInfo: {}",
190 e.message
191 )))
192 }
193 _ => return Err(FsError::Internal("unexpected response to GetInfo".into())),
194 };
195
196 Ok(Self {
197 config,
198 tls_config,
199 stream: Mutex::new(Some(stream)),
200 block_size,
201 total_blocks,
202 next_request_id: AtomicU64::new(3),
203 })
204 }
205
206 fn next_id(&self) -> u64 {
208 self.next_request_id.fetch_add(1, Ordering::Relaxed)
209 }
210
211 fn reconnect(&self) -> FsResult<StreamOwned<ClientConnection, TcpStream>> {
214 let mut stream = establish_connection(&self.config, &self.tls_config)?;
215 authenticate(&mut stream, &self.config.auth_token)?;
216 Ok(stream)
217 }
218
219 fn roundtrip(&self, req: &proto::Request) -> FsResult<proto::Response> {
222 let mut guard = self
223 .stream
224 .lock()
225 .map_err(|e| FsError::Internal(e.to_string()))?;
226
227 if guard.is_none() {
229 *guard = Some(self.reconnect()?);
230 }
231
232 let stream = guard.as_mut().unwrap();
233 match send_and_recv(stream, req) {
234 Ok(resp) => Ok(resp),
235 Err(_) => {
236 *guard = Some(self.reconnect()?);
238 send_and_recv(guard.as_mut().unwrap(), req)
239 }
240 }
241 }
242
243 fn pipeline_reads(
247 &self,
248 stream: &mut StreamOwned<ClientConnection, TcpStream>,
249 block_ids: &[u64],
250 ) -> FsResult<Vec<Vec<u8>>> {
251 let mut results = Vec::with_capacity(block_ids.len());
252
253 for chunk in block_ids.chunks(PIPELINE_BATCH) {
254 {
257 let mut bw = std::io::BufWriter::with_capacity(
258 chunk.len() * 32, &mut *stream,
260 );
261 for &block_id in chunk {
262 let id = self.next_id();
263 send_message_no_flush(
264 &mut bw,
265 &proto::Request {
266 request_id: id,
267 command: Some(proto::request::Command::ReadBlock(
268 proto::ReadBlockRequest { block_id },
269 )),
270 },
271 )?;
272 }
273 bw.flush()
274 .map_err(|e| FsError::Internal(format!("flush pipeline: {e}")))?;
275 }
276
277 for _ in chunk {
279 let resp = recv_message(stream)?;
280 match resp.result {
281 Some(proto::response::Result::ReadBlock(r)) => results.push(r.data),
282 Some(proto::response::Result::Error(e)) => {
283 return Err(FsError::Internal(format!("server: {}", e.message)));
284 }
285 _ => return Err(FsError::Internal("unexpected response".into())),
286 }
287 }
288 }
289
290 Ok(results)
291 }
292
293 fn pipeline_writes(
295 &self,
296 stream: &mut StreamOwned<ClientConnection, TcpStream>,
297 blocks: &[(u64, &[u8])],
298 ) -> FsResult<()> {
299 for chunk in blocks.chunks(PIPELINE_BATCH) {
300 {
303 let mut bw = std::io::BufWriter::with_capacity(
304 chunk.len() * (32 + chunk.first().map_or(0, |(_, d)| d.len())),
305 &mut *stream,
306 );
307 for &(block_id, data) in chunk {
308 let id = self.next_id();
309 send_message_no_flush(
310 &mut bw,
311 &proto::Request {
312 request_id: id,
313 command: Some(proto::request::Command::WriteBlock(
314 proto::WriteBlockRequest {
315 block_id,
316 data: data.to_vec(),
317 },
318 )),
319 },
320 )?;
321 }
322 bw.flush()
323 .map_err(|e| FsError::Internal(format!("flush pipeline: {e}")))?;
324 }
325
326 for _ in chunk {
327 let resp = recv_message(stream)?;
328 match resp.result {
329 Some(proto::response::Result::WriteBlock(_)) => {}
330 Some(proto::response::Result::Error(e)) => {
331 return Err(FsError::Internal(format!("server: {}", e.message)));
332 }
333 _ => return Err(FsError::Internal("unexpected response".into())),
334 }
335 }
336 }
337
338 Ok(())
339 }
340
341 fn with_pipeline<F, T>(&self, op: F) -> FsResult<T>
343 where
344 F: Fn(&Self, &mut StreamOwned<ClientConnection, TcpStream>) -> FsResult<T>,
345 {
346 let mut guard = self
347 .stream
348 .lock()
349 .map_err(|e| FsError::Internal(e.to_string()))?;
350
351 if guard.is_none() {
352 *guard = Some(self.reconnect()?);
353 }
354
355 match op(self, guard.as_mut().unwrap()) {
356 Ok(v) => Ok(v),
357 Err(_) => {
358 *guard = Some(self.reconnect()?);
359 op(self, guard.as_mut().unwrap())
360 }
361 }
362 }
363}
364
365impl BlockStore for NetworkBlockStore {
366 fn block_size(&self) -> usize {
367 self.block_size
368 }
369
370 fn total_blocks(&self) -> u64 {
371 self.total_blocks
372 }
373
374 fn read_block(&self, block_id: u64) -> FsResult<Vec<u8>> {
375 let id = self.next_id();
376 let req = proto::Request {
377 request_id: id,
378 command: Some(proto::request::Command::ReadBlock(
379 proto::ReadBlockRequest { block_id },
380 )),
381 };
382 let resp = self.roundtrip(&req)?;
383
384 match resp.result {
385 Some(proto::response::Result::ReadBlock(r)) => Ok(r.data),
386 Some(proto::response::Result::Error(e)) => {
387 Err(FsError::Internal(format!("server: {}", e.message)))
388 }
389 _ => Err(FsError::Internal("unexpected response".into())),
390 }
391 }
392
393 fn write_block(&self, block_id: u64, data: &[u8]) -> FsResult<()> {
394 let id = self.next_id();
395 let req = proto::Request {
396 request_id: id,
397 command: Some(proto::request::Command::WriteBlock(
398 proto::WriteBlockRequest {
399 block_id,
400 data: data.to_vec(),
401 },
402 )),
403 };
404 let resp = self.roundtrip(&req)?;
405
406 match resp.result {
407 Some(proto::response::Result::WriteBlock(_)) => Ok(()),
408 Some(proto::response::Result::Error(e)) => {
409 Err(FsError::Internal(format!("server: {}", e.message)))
410 }
411 _ => Err(FsError::Internal("unexpected response".into())),
412 }
413 }
414
415 fn sync(&self) -> FsResult<()> {
416 let id = self.next_id();
417 let req = proto::Request {
418 request_id: id,
419 command: Some(proto::request::Command::Sync(proto::SyncRequest {})),
420 };
421 let resp = self.roundtrip(&req)?;
422
423 match resp.result {
424 Some(proto::response::Result::Sync(_)) => Ok(()),
425 Some(proto::response::Result::Error(e)) => {
426 Err(FsError::Internal(format!("server: {}", e.message)))
427 }
428 _ => Err(FsError::Internal("unexpected response".into())),
429 }
430 }
431
432 fn read_blocks(&self, block_ids: &[u64]) -> FsResult<Vec<Vec<u8>>> {
433 if block_ids.is_empty() {
434 return Ok(Vec::new());
435 }
436 self.with_pipeline(|s, stream| s.pipeline_reads(stream, block_ids))
437 }
438
439 fn write_blocks(&self, blocks: &[(u64, &[u8])]) -> FsResult<()> {
440 if blocks.is_empty() {
441 return Ok(());
442 }
443 self.with_pipeline(|s, stream| s.pipeline_writes(stream, blocks))
444 }
445}
446
447fn authenticate(
451 stream: &mut StreamOwned<ClientConnection, TcpStream>,
452 auth_token: &[u8; 32],
453) -> FsResult<()> {
454 let req = proto::Request {
455 request_id: 2,
456 command: Some(proto::request::Command::Authenticate(
457 proto::AuthenticateRequest {
458 auth_token: auth_token.to_vec(),
459 },
460 )),
461 };
462 send_message(stream, &req)?;
463 let resp = recv_message(stream)?;
464
465 match resp.result {
466 Some(proto::response::Result::Authenticate(_)) => Ok(()),
467 Some(proto::response::Result::Error(e)) => Err(FsError::Internal(format!(
468 "authentication failed: {}",
469 e.message
470 ))),
471 _ => Err(FsError::Internal(
472 "unexpected response to Authenticate".into(),
473 )),
474 }
475}
476
477fn send_message<W: Write>(w: &mut W, msg: &proto::Request) -> FsResult<()> {
480 let payload = msg.encode_to_vec();
481 let len = payload.len() as u32;
482 w.write_all(&len.to_le_bytes())
483 .map_err(|e| FsError::Internal(format!("write length prefix: {e}")))?;
484 w.write_all(&payload)
485 .map_err(|e| FsError::Internal(format!("write payload: {e}")))?;
486 w.flush()
487 .map_err(|e| FsError::Internal(format!("flush: {e}")))?;
488 Ok(())
489}
490
491fn send_message_no_flush<W: Write>(w: &mut W, msg: &proto::Request) -> FsResult<()> {
494 let payload = msg.encode_to_vec();
495 let len = payload.len() as u32;
496 w.write_all(&len.to_le_bytes())
497 .map_err(|e| FsError::Internal(format!("write length prefix: {e}")))?;
498 w.write_all(&payload)
499 .map_err(|e| FsError::Internal(format!("write payload: {e}")))?;
500 Ok(())
501}
502
503fn recv_message<R: Read>(r: &mut R) -> FsResult<proto::Response> {
504 let mut len_buf = [0u8; 4];
505 r.read_exact(&mut len_buf)
506 .map_err(|e| FsError::Internal(format!("read length prefix: {e}")))?;
507 let len = u32::from_le_bytes(len_buf) as usize;
508
509 if len > 16 * 1024 * 1024 {
510 return Err(FsError::Internal(format!(
511 "response too large: {len} bytes"
512 )));
513 }
514
515 let mut buf = vec![0u8; len];
516 r.read_exact(&mut buf)
517 .map_err(|e| FsError::Internal(format!("read payload: {e}")))?;
518
519 proto::Response::decode(&*buf).map_err(|e| FsError::Internal(format!("decode response: {e}")))
520}
521
522fn send_and_recv(
523 stream: &mut StreamOwned<ClientConnection, TcpStream>,
524 req: &proto::Request,
525) -> FsResult<proto::Response> {
526 send_message(stream, req)?;
527 recv_message(stream)
528}
529
530fn establish_connection(
533 config: &NetworkBlockStoreConfig,
534 tls_config: &Arc<ClientConfig>,
535) -> FsResult<StreamOwned<ClientConnection, TcpStream>> {
536 let addr = config
537 .addr
538 .to_socket_addrs()
539 .map_err(|e| FsError::Internal(format!("resolve {}: {e}", config.addr)))?
540 .next()
541 .ok_or_else(|| FsError::Internal(format!("no addresses for {}", config.addr)))?;
542
543 let tcp = TcpStream::connect_timeout(&addr, config.connect_timeout)
544 .map_err(|e| FsError::Internal(format!("connect to {}: {e}", config.addr)))?;
545
546 tcp.set_read_timeout(Some(config.io_timeout))
547 .map_err(|e| FsError::Internal(format!("set read timeout: {e}")))?;
548 tcp.set_write_timeout(Some(config.io_timeout))
549 .map_err(|e| FsError::Internal(format!("set write timeout: {e}")))?;
550
551 let sni = ServerName::try_from(config.server_name.clone())
552 .map_err(|e| FsError::Internal(format!("invalid SNI '{}': {e}", config.server_name)))?;
553
554 let tls_conn = ClientConnection::new(Arc::clone(tls_config), sni)
555 .map_err(|e| FsError::Internal(format!("TLS connection: {e}")))?;
556
557 Ok(StreamOwned::new(tls_conn, tcp))
558}
559
560fn build_client_tls_config(ca_path: &Path) -> FsResult<Arc<ClientConfig>> {
563 let mut root_store = rustls::RootCertStore::empty();
564
565 if ca_path.as_os_str().is_empty() {
566 let native_certs = rustls_native_certs::load_native_certs();
568 for cert in native_certs.certs {
569 root_store
570 .add(cert)
571 .map_err(|e| FsError::Internal(format!("add native CA cert: {e}")))?;
572 }
573 if root_store.is_empty() {
574 return Err(FsError::Internal("no system CA certificates found".into()));
575 }
576 } else {
577 let ca_pem = std::fs::read(ca_path)
579 .map_err(|e| FsError::Internal(format!("read CA cert {}: {e}", ca_path.display())))?;
580 let ca_certs: Vec<CertificateDer<'static>> =
581 rustls_pemfile::certs(&mut BufReader::new(&*ca_pem))
582 .collect::<std::result::Result<Vec<_>, _>>()
583 .map_err(|e| FsError::Internal(format!("parse CA certs: {e}")))?;
584
585 for cert in ca_certs {
586 root_store
587 .add(cert)
588 .map_err(|e| FsError::Internal(format!("add CA cert: {e}")))?;
589 }
590 }
591
592 let config = ClientConfig::builder()
593 .with_root_certificates(root_store)
594 .with_no_client_auth();
595
596 Ok(Arc::new(config))
597}