Skip to main content

doublecrypt_core/
network_store.rs

1//! Network-backed block store that connects to a `doublecrypt-server` over TLS.
2//!
3//! Uses the 4-byte little-endian length-prefixed protobuf protocol defined in
4//! `proto/blockstore.proto`.  The connection is synchronous (matching the
5//! [`BlockStore`] trait) but supports:
6//!
//! * **Request pipelining** — [`read_blocks`](BlockStore::read_blocks) and
//!   [`write_blocks`](BlockStore::write_blocks) send requests in batches of up
//!   to 64 before reading the corresponding responses, so each batch costs one
//!   round trip instead of one round trip per block.
//! * **Automatic reconnection** — if an operation fails, the connection is
//!   replaced via a fresh TLS handshake (re-authenticating automatically) and
//!   the operation is retried once.
12//! * **Configurable timeouts** — connect, read, and write deadlines.
13//! * **Key-derived authentication** — after the TLS handshake, the client sends
14//!   an `Authenticate` request containing a token derived from the master key
15//!   via HKDF (see [`derive_auth_token`](crate::crypto::derive_auth_token)).
16//!   This proves possession of the encryption key without revealing it.
17//!
18//! # Quick start
19//!
20//! ```no_run
21//! use std::path::Path;
22//! use doublecrypt_core::network_store::NetworkBlockStore;
23//! use doublecrypt_core::block_store::BlockStore;
24//!
25//! let master_key = [0u8; 32];
26//! let store = NetworkBlockStore::connect(
27//!     "127.0.0.1:9100",
28//!     "localhost",
29//!     Path::new("certs/ca.pem"),
30//!     &master_key,
31//! ).expect("connect to server");
32//!
33//! let data = store.read_block(0).expect("read block 0");
34//! ```
35//!
36//! # Builder
37//!
38//! ```no_run
39//! use std::time::Duration;
40//! use doublecrypt_core::network_store::{NetworkBlockStore, NetworkBlockStoreConfig};
41//! use doublecrypt_core::block_store::BlockStore;
42//!
43//! let master_key = [0u8; 32];
44//! let store = NetworkBlockStore::from_config(
45//!     NetworkBlockStoreConfig::new("10.0.0.5:9100", "block-server")
46//!         .ca_cert("certs/ca.pem")
47//!         .auth_token(&master_key)
48//!         .connect_timeout(Duration::from_secs(5))
49//!         .io_timeout(Duration::from_secs(60)),
50//! ).expect("connect to server");
51//! ```
52
53use std::io::{BufReader, Read, Write};
54use std::net::{TcpStream, ToSocketAddrs};
55use std::path::{Path, PathBuf};
56use std::sync::atomic::{AtomicU64, Ordering};
57use std::sync::{Arc, Mutex};
58use std::time::Duration;
59
60use prost::Message;
61use rustls::pki_types::{CertificateDer, ServerName};
62use rustls::{ClientConfig, ClientConnection, StreamOwned};
63
64use crate::block_store::BlockStore;
65use crate::crypto;
66use crate::error::{FsError, FsResult};
67use crate::proto;
68
/// Upper bound on how many requests are written back-to-back before the
/// client stops to read their responses.
///
/// Bounding the batch keeps the amount of in-flight data small, so a batch
/// never needs more buffering than the kernel's socket send/receive buffers
/// can hold (which would otherwise risk both peers blocking on writes).
const PIPELINE_BATCH: usize = 64;
74
75// ── Configuration ───────────────────────────────────────────
76
77/// Connection parameters for a [`NetworkBlockStore`].
78pub struct NetworkBlockStoreConfig {
79    addr: String,
80    server_name: String,
81    ca_cert: PathBuf,
82    auth_token: [u8; 32],
83    connect_timeout: Duration,
84    io_timeout: Duration,
85}
86
87impl NetworkBlockStoreConfig {
88    /// Create a config targeting `addr` (`"host:port"`) with the given TLS
89    /// server name (SNI).  Timeouts default to 10 s (connect) and 30 s (I/O).
90    pub fn new(addr: impl Into<String>, server_name: impl Into<String>) -> Self {
91        Self {
92            addr: addr.into(),
93            server_name: server_name.into(),
94            ca_cert: PathBuf::new(),
95            auth_token: [0u8; 32],
96            connect_timeout: Duration::from_secs(10),
97            io_timeout: Duration::from_secs(30),
98        }
99    }
100
101    pub fn ca_cert(mut self, path: impl Into<PathBuf>) -> Self {
102        self.ca_cert = path.into();
103        self
104    }
105
106    /// Set the auth token by deriving it from the given master key.
107    pub fn auth_token(mut self, master_key: &[u8]) -> Self {
108        self.auth_token = crypto::derive_auth_token(master_key)
109            .expect("HKDF auth-token derivation should not fail with valid key material");
110        self
111    }
112
113    /// Set a pre-derived auth token directly.
114    pub fn auth_token_raw(mut self, token: [u8; 32]) -> Self {
115        self.auth_token = token;
116        self
117    }
118
119    pub fn connect_timeout(mut self, d: Duration) -> Self {
120        self.connect_timeout = d;
121        self
122    }
123
124    pub fn io_timeout(mut self, d: Duration) -> Self {
125        self.io_timeout = d;
126        self
127    }
128}
129
130// ── Store ───────────────────────────────────────────────────
131
132/// A [`BlockStore`] backed by a remote `doublecrypt-server` reached over TLS
133/// with key-derived authentication.
134///
135/// On construction the client performs a TLS handshake, issues a `GetInfo`
136/// RPC to learn the block size and total block count, then sends an
137/// `Authenticate` request with a token derived from the master key.
138/// The connection is stored and reused; if it breaks, one automatic
139/// reconnect (including re-authentication) is attempted.
140pub struct NetworkBlockStore {
141    config: NetworkBlockStoreConfig,
142    tls_config: Arc<ClientConfig>,
143    stream: Mutex<Option<StreamOwned<ClientConnection, TcpStream>>>,
144    block_size: usize,
145    total_blocks: u64,
146    next_request_id: AtomicU64,
147}
148
149impl NetworkBlockStore {
150    /// Connect to a `doublecrypt-server` using TLS with key-derived
151    /// authentication (convenience wrapper around [`from_config`](Self::from_config)).
152    pub fn connect(
153        addr: &str,
154        server_name: &str,
155        ca_cert: &Path,
156        master_key: &[u8],
157    ) -> FsResult<Self> {
158        Self::from_config(
159            NetworkBlockStoreConfig::new(addr, server_name)
160                .ca_cert(ca_cert)
161                .auth_token(master_key),
162        )
163    }
164
165    /// Connect using a [`NetworkBlockStoreConfig`].
166    pub fn from_config(config: NetworkBlockStoreConfig) -> FsResult<Self> {
167        let tls_config = build_client_tls_config(&config.ca_cert)?;
168
169        // Establish the initial connection.
170        let mut stream = establish_connection(&config, &tls_config)?;
171
172        // Authenticate with the key-derived token.
173        authenticate(&mut stream, &config.auth_token)?;
174
175        // Issue GetInfo to learn block geometry.
176        let req = proto::Request {
177            request_id: 2,
178            command: Some(proto::request::Command::GetInfo(proto::GetInfoRequest {})),
179        };
180        send_message(&mut stream, &req)?;
181        let resp = recv_message(&mut stream)?;
182
183        let (block_size, total_blocks) = match resp.result {
184            Some(proto::response::Result::GetInfo(info)) => {
185                (info.block_size as usize, info.total_blocks)
186            }
187            Some(proto::response::Result::Error(e)) => {
188                return Err(FsError::Internal(format!(
189                    "server error on GetInfo: {}",
190                    e.message
191                )))
192            }
193            _ => return Err(FsError::Internal("unexpected response to GetInfo".into())),
194        };
195
196        Ok(Self {
197            config,
198            tls_config,
199            stream: Mutex::new(Some(stream)),
200            block_size,
201            total_blocks,
202            next_request_id: AtomicU64::new(3),
203        })
204    }
205
206    /// Allocate a monotonically increasing request ID.
207    fn next_id(&self) -> u64 {
208        self.next_request_id.fetch_add(1, Ordering::Relaxed)
209    }
210
211    /// Establish a fresh TLS connection using stored config, including
212    /// re-authentication.
213    fn reconnect(&self) -> FsResult<StreamOwned<ClientConnection, TcpStream>> {
214        let mut stream = establish_connection(&self.config, &self.tls_config)?;
215        authenticate(&mut stream, &self.config.auth_token)?;
216        Ok(stream)
217    }
218
219    /// Send a single request and receive its response, retrying once on I/O
220    /// failure by reconnecting.
221    fn roundtrip(&self, req: &proto::Request) -> FsResult<proto::Response> {
222        let mut guard = self
223            .stream
224            .lock()
225            .map_err(|e| FsError::Internal(e.to_string()))?;
226
227        // Ensure we have a live connection.
228        if guard.is_none() {
229            *guard = Some(self.reconnect()?);
230        }
231
232        let stream = guard.as_mut().unwrap();
233        match send_and_recv(stream, req) {
234            Ok(resp) => Ok(resp),
235            Err(_) => {
236                // Connection may be dead — reconnect and retry once.
237                *guard = Some(self.reconnect()?);
238                send_and_recv(guard.as_mut().unwrap(), req)
239            }
240        }
241    }
242
243    // ── Pipelined helpers ───────────────────────────────────
244
245    /// Pipeline a batch of read requests on `stream`.
246    fn pipeline_reads(
247        &self,
248        stream: &mut StreamOwned<ClientConnection, TcpStream>,
249        block_ids: &[u64],
250    ) -> FsResult<Vec<Vec<u8>>> {
251        let mut results = Vec::with_capacity(block_ids.len());
252
253        for chunk in block_ids.chunks(PIPELINE_BATCH) {
254            // Wrap in a BufWriter so all requests in this chunk are
255            // coalesced into fewer TLS records (one flush at the end).
256            {
257                let mut bw = std::io::BufWriter::with_capacity(
258                    chunk.len() * 32, // read requests are small
259                    &mut *stream,
260                );
261                for &block_id in chunk {
262                    let id = self.next_id();
263                    send_message_no_flush(
264                        &mut bw,
265                        &proto::Request {
266                            request_id: id,
267                            command: Some(proto::request::Command::ReadBlock(
268                                proto::ReadBlockRequest { block_id },
269                            )),
270                        },
271                    )?;
272                }
273                bw.flush()
274                    .map_err(|e| FsError::Internal(format!("flush pipeline: {e}")))?;
275            }
276
277            // Read all responses.
278            for _ in chunk {
279                let resp = recv_message(stream)?;
280                match resp.result {
281                    Some(proto::response::Result::ReadBlock(r)) => results.push(r.data),
282                    Some(proto::response::Result::Error(e)) => {
283                        return Err(FsError::Internal(format!("server: {}", e.message)));
284                    }
285                    _ => return Err(FsError::Internal("unexpected response".into())),
286                }
287            }
288        }
289
290        Ok(results)
291    }
292
293    /// Pipeline a batch of write requests on `stream`.
294    fn pipeline_writes(
295        &self,
296        stream: &mut StreamOwned<ClientConnection, TcpStream>,
297        blocks: &[(u64, &[u8])],
298    ) -> FsResult<()> {
299        for chunk in blocks.chunks(PIPELINE_BATCH) {
300            // Wrap in a BufWriter so multiple write requests are batched
301            // into fewer TLS records.
302            {
303                let mut bw = std::io::BufWriter::with_capacity(
304                    chunk.len() * (32 + chunk.first().map_or(0, |(_, d)| d.len())),
305                    &mut *stream,
306                );
307                for &(block_id, data) in chunk {
308                    let id = self.next_id();
309                    send_message_no_flush(
310                        &mut bw,
311                        &proto::Request {
312                            request_id: id,
313                            command: Some(proto::request::Command::WriteBlock(
314                                proto::WriteBlockRequest {
315                                    block_id,
316                                    data: data.to_vec(),
317                                },
318                            )),
319                        },
320                    )?;
321                }
322                bw.flush()
323                    .map_err(|e| FsError::Internal(format!("flush pipeline: {e}")))?;
324            }
325
326            for _ in chunk {
327                let resp = recv_message(stream)?;
328                match resp.result {
329                    Some(proto::response::Result::WriteBlock(_)) => {}
330                    Some(proto::response::Result::Error(e)) => {
331                        return Err(FsError::Internal(format!("server: {}", e.message)));
332                    }
333                    _ => return Err(FsError::Internal("unexpected response".into())),
334                }
335            }
336        }
337
338        Ok(())
339    }
340
341    /// Run a pipelined operation with one reconnect attempt on failure.
342    fn with_pipeline<F, T>(&self, op: F) -> FsResult<T>
343    where
344        F: Fn(&Self, &mut StreamOwned<ClientConnection, TcpStream>) -> FsResult<T>,
345    {
346        let mut guard = self
347            .stream
348            .lock()
349            .map_err(|e| FsError::Internal(e.to_string()))?;
350
351        if guard.is_none() {
352            *guard = Some(self.reconnect()?);
353        }
354
355        match op(self, guard.as_mut().unwrap()) {
356            Ok(v) => Ok(v),
357            Err(_) => {
358                *guard = Some(self.reconnect()?);
359                op(self, guard.as_mut().unwrap())
360            }
361        }
362    }
363}
364
365impl BlockStore for NetworkBlockStore {
366    fn block_size(&self) -> usize {
367        self.block_size
368    }
369
370    fn total_blocks(&self) -> u64 {
371        self.total_blocks
372    }
373
374    fn read_block(&self, block_id: u64) -> FsResult<Vec<u8>> {
375        let id = self.next_id();
376        let req = proto::Request {
377            request_id: id,
378            command: Some(proto::request::Command::ReadBlock(
379                proto::ReadBlockRequest { block_id },
380            )),
381        };
382        let resp = self.roundtrip(&req)?;
383
384        match resp.result {
385            Some(proto::response::Result::ReadBlock(r)) => Ok(r.data),
386            Some(proto::response::Result::Error(e)) => {
387                Err(FsError::Internal(format!("server: {}", e.message)))
388            }
389            _ => Err(FsError::Internal("unexpected response".into())),
390        }
391    }
392
393    fn write_block(&self, block_id: u64, data: &[u8]) -> FsResult<()> {
394        let id = self.next_id();
395        let req = proto::Request {
396            request_id: id,
397            command: Some(proto::request::Command::WriteBlock(
398                proto::WriteBlockRequest {
399                    block_id,
400                    data: data.to_vec(),
401                },
402            )),
403        };
404        let resp = self.roundtrip(&req)?;
405
406        match resp.result {
407            Some(proto::response::Result::WriteBlock(_)) => Ok(()),
408            Some(proto::response::Result::Error(e)) => {
409                Err(FsError::Internal(format!("server: {}", e.message)))
410            }
411            _ => Err(FsError::Internal("unexpected response".into())),
412        }
413    }
414
415    fn sync(&self) -> FsResult<()> {
416        let id = self.next_id();
417        let req = proto::Request {
418            request_id: id,
419            command: Some(proto::request::Command::Sync(proto::SyncRequest {})),
420        };
421        let resp = self.roundtrip(&req)?;
422
423        match resp.result {
424            Some(proto::response::Result::Sync(_)) => Ok(()),
425            Some(proto::response::Result::Error(e)) => {
426                Err(FsError::Internal(format!("server: {}", e.message)))
427            }
428            _ => Err(FsError::Internal("unexpected response".into())),
429        }
430    }
431
432    fn read_blocks(&self, block_ids: &[u64]) -> FsResult<Vec<Vec<u8>>> {
433        if block_ids.is_empty() {
434            return Ok(Vec::new());
435        }
436        self.with_pipeline(|s, stream| s.pipeline_reads(stream, block_ids))
437    }
438
439    fn write_blocks(&self, blocks: &[(u64, &[u8])]) -> FsResult<()> {
440        if blocks.is_empty() {
441            return Ok(());
442        }
443        self.with_pipeline(|s, stream| s.pipeline_writes(stream, blocks))
444    }
445}
446
447// ── Authentication ──────────────────────────────────────────
448
449/// Send an Authenticate request and verify the server accepts it.
450fn authenticate(
451    stream: &mut StreamOwned<ClientConnection, TcpStream>,
452    auth_token: &[u8; 32],
453) -> FsResult<()> {
454    let req = proto::Request {
455        request_id: 2,
456        command: Some(proto::request::Command::Authenticate(
457            proto::AuthenticateRequest {
458                auth_token: auth_token.to_vec(),
459            },
460        )),
461    };
462    send_message(stream, &req)?;
463    let resp = recv_message(stream)?;
464
465    match resp.result {
466        Some(proto::response::Result::Authenticate(_)) => Ok(()),
467        Some(proto::response::Result::Error(e)) => Err(FsError::Internal(format!(
468            "authentication failed: {}",
469            e.message
470        ))),
471        _ => Err(FsError::Internal(
472            "unexpected response to Authenticate".into(),
473        )),
474    }
475}
476
477// ── Wire helpers ────────────────────────────────────────────
478
479fn send_message<W: Write>(w: &mut W, msg: &proto::Request) -> FsResult<()> {
480    let payload = msg.encode_to_vec();
481    let len = payload.len() as u32;
482    w.write_all(&len.to_le_bytes())
483        .map_err(|e| FsError::Internal(format!("write length prefix: {e}")))?;
484    w.write_all(&payload)
485        .map_err(|e| FsError::Internal(format!("write payload: {e}")))?;
486    w.flush()
487        .map_err(|e| FsError::Internal(format!("flush: {e}")))?;
488    Ok(())
489}
490
491/// Like `send_message` but without flushing.  Used by pipelined operations
492/// that batch many messages and flush once at the end.
493fn send_message_no_flush<W: Write>(w: &mut W, msg: &proto::Request) -> FsResult<()> {
494    let payload = msg.encode_to_vec();
495    let len = payload.len() as u32;
496    w.write_all(&len.to_le_bytes())
497        .map_err(|e| FsError::Internal(format!("write length prefix: {e}")))?;
498    w.write_all(&payload)
499        .map_err(|e| FsError::Internal(format!("write payload: {e}")))?;
500    Ok(())
501}
502
503fn recv_message<R: Read>(r: &mut R) -> FsResult<proto::Response> {
504    let mut len_buf = [0u8; 4];
505    r.read_exact(&mut len_buf)
506        .map_err(|e| FsError::Internal(format!("read length prefix: {e}")))?;
507    let len = u32::from_le_bytes(len_buf) as usize;
508
509    if len > 16 * 1024 * 1024 {
510        return Err(FsError::Internal(format!(
511            "response too large: {len} bytes"
512        )));
513    }
514
515    let mut buf = vec![0u8; len];
516    r.read_exact(&mut buf)
517        .map_err(|e| FsError::Internal(format!("read payload: {e}")))?;
518
519    proto::Response::decode(&*buf).map_err(|e| FsError::Internal(format!("decode response: {e}")))
520}
521
522fn send_and_recv(
523    stream: &mut StreamOwned<ClientConnection, TcpStream>,
524    req: &proto::Request,
525) -> FsResult<proto::Response> {
526    send_message(stream, req)?;
527    recv_message(stream)
528}
529
530// ── Connection establishment ────────────────────────────────
531
532fn establish_connection(
533    config: &NetworkBlockStoreConfig,
534    tls_config: &Arc<ClientConfig>,
535) -> FsResult<StreamOwned<ClientConnection, TcpStream>> {
536    let addr = config
537        .addr
538        .to_socket_addrs()
539        .map_err(|e| FsError::Internal(format!("resolve {}: {e}", config.addr)))?
540        .next()
541        .ok_or_else(|| FsError::Internal(format!("no addresses for {}", config.addr)))?;
542
543    let tcp = TcpStream::connect_timeout(&addr, config.connect_timeout)
544        .map_err(|e| FsError::Internal(format!("connect to {}: {e}", config.addr)))?;
545
546    tcp.set_read_timeout(Some(config.io_timeout))
547        .map_err(|e| FsError::Internal(format!("set read timeout: {e}")))?;
548    tcp.set_write_timeout(Some(config.io_timeout))
549        .map_err(|e| FsError::Internal(format!("set write timeout: {e}")))?;
550
551    let sni = ServerName::try_from(config.server_name.clone())
552        .map_err(|e| FsError::Internal(format!("invalid SNI '{}': {e}", config.server_name)))?;
553
554    let tls_conn = ClientConnection::new(Arc::clone(tls_config), sni)
555        .map_err(|e| FsError::Internal(format!("TLS connection: {e}")))?;
556
557    Ok(StreamOwned::new(tls_conn, tcp))
558}
559
560// ── TLS configuration ───────────────────────────────────────
561
562fn build_client_tls_config(ca_path: &Path) -> FsResult<Arc<ClientConfig>> {
563    let mut root_store = rustls::RootCertStore::empty();
564
565    if ca_path.as_os_str().is_empty() {
566        // Use system CA certificates.
567        let native_certs = rustls_native_certs::load_native_certs();
568        for cert in native_certs.certs {
569            root_store
570                .add(cert)
571                .map_err(|e| FsError::Internal(format!("add native CA cert: {e}")))?;
572        }
573        if root_store.is_empty() {
574            return Err(FsError::Internal("no system CA certificates found".into()));
575        }
576    } else {
577        // Use the provided custom CA certificate file.
578        let ca_pem = std::fs::read(ca_path)
579            .map_err(|e| FsError::Internal(format!("read CA cert {}: {e}", ca_path.display())))?;
580        let ca_certs: Vec<CertificateDer<'static>> =
581            rustls_pemfile::certs(&mut BufReader::new(&*ca_pem))
582                .collect::<std::result::Result<Vec<_>, _>>()
583                .map_err(|e| FsError::Internal(format!("parse CA certs: {e}")))?;
584
585        for cert in ca_certs {
586            root_store
587                .add(cert)
588                .map_err(|e| FsError::Internal(format!("add CA cert: {e}")))?;
589        }
590    }
591
592    let config = ClientConfig::builder()
593        .with_root_certificates(root_store)
594        .with_no_client_auth();
595
596    Ok(Arc::new(config))
597}