Skip to main content

doublecrypt_core/
network_store.rs

1//! Network-backed block store that connects to a `doublecrypt-server` over mTLS.
2//!
3//! Uses the 4-byte little-endian length-prefixed protobuf protocol defined in
4//! `proto/blockstore.proto`.  The connection is synchronous (matching the
5//! [`BlockStore`] trait) but supports:
6//!
7//! * **Request pipelining** — [`read_blocks`](BlockStore::read_blocks) and
8//!   [`write_blocks`](BlockStore::write_blocks) send a full batch of requests
9//!   before reading any responses, eliminating per-block round-trip latency.
10//! * **Automatic reconnection** — a single retry on I/O failure with a fresh
11//!   TLS handshake.
12//! * **Configurable timeouts** — connect, read, and write deadlines.
13//!
14//! # Quick start
15//!
16//! ```no_run
17//! use std::path::Path;
18//! use doublecrypt_core::network_store::NetworkBlockStore;
19//! use doublecrypt_core::block_store::BlockStore;
20//!
21//! let store = NetworkBlockStore::connect(
22//!     "127.0.0.1:9100",
23//!     "localhost",
24//!     Path::new("certs/client.pem"),
25//!     Path::new("certs/client-key.pem"),
26//!     Path::new("certs/ca.pem"),
27//! ).expect("connect to server");
28//!
29//! let data = store.read_block(0).expect("read block 0");
30//! ```
31//!
32//! # Builder
33//!
34//! ```no_run
35//! use std::time::Duration;
36//! use doublecrypt_core::network_store::{NetworkBlockStore, NetworkBlockStoreConfig};
37//! use doublecrypt_core::block_store::BlockStore;
38//!
39//! let store = NetworkBlockStore::from_config(
40//!     NetworkBlockStoreConfig::new("10.0.0.5:9100", "block-server")
41//!         .client_cert("certs/client.pem")
42//!         .client_key("certs/client-key.pem")
43//!         .ca_cert("certs/ca.pem")
44//!         .connect_timeout(Duration::from_secs(5))
45//!         .io_timeout(Duration::from_secs(60)),
46//! ).expect("connect to server");
47//! ```
48
49use std::io::{BufReader, Read, Write};
50use std::net::{TcpStream, ToSocketAddrs};
51use std::path::{Path, PathBuf};
52use std::sync::atomic::{AtomicU64, Ordering};
53use std::sync::{Arc, Mutex};
54use std::time::Duration;
55
56use prost::Message;
57use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName};
58use rustls::{ClientConfig, ClientConnection, StreamOwned};
59
60use crate::block_store::BlockStore;
61use crate::error::{FsError, FsResult};
62use crate::proto;
63
64/// Maximum number of requests to pipeline before reading responses.
65///
66/// Keeps TCP buffer usage bounded and avoids deadlocks when the kernel
67/// send/receive buffers are smaller than the total pipelined payload.
68const PIPELINE_BATCH: usize = 64;
69
70// ── Configuration ───────────────────────────────────────────
71
/// Connection parameters for a [`NetworkBlockStore`].
///
/// Built with [`NetworkBlockStoreConfig::new`] followed by the chained
/// setter methods; every setter consumes and returns the config.
#[derive(Debug, Clone)]
pub struct NetworkBlockStoreConfig {
    /// Server address in `"host:port"` form.
    addr: String,
    /// Name presented as TLS SNI and used to validate the server certificate.
    server_name: String,
    /// Path to the PEM file holding the client certificate chain.
    client_cert: PathBuf,
    /// Path to the PEM file holding the client private key.
    client_key: PathBuf,
    /// Path to the PEM file holding the CA certificate(s) to trust.
    ca_cert: PathBuf,
    /// Upper bound on establishing the TCP connection.
    connect_timeout: Duration,
    /// Read and write timeout applied to the connected socket.
    io_timeout: Duration,
}

impl NetworkBlockStoreConfig {
    /// Create a config targeting `addr` (`"host:port"`) with the given TLS
    /// server name (SNI).  Timeouts default to 10 s (connect) and 30 s (I/O).
    ///
    /// The certificate and key paths start out empty and must be supplied
    /// through [`client_cert`](Self::client_cert),
    /// [`client_key`](Self::client_key) and [`ca_cert`](Self::ca_cert).
    #[must_use]
    pub fn new(addr: impl Into<String>, server_name: impl Into<String>) -> Self {
        Self {
            addr: addr.into(),
            server_name: server_name.into(),
            client_cert: PathBuf::new(),
            client_key: PathBuf::new(),
            ca_cert: PathBuf::new(),
            connect_timeout: Duration::from_secs(10),
            io_timeout: Duration::from_secs(30),
        }
    }

    /// Set the path of the PEM file containing the client certificate chain.
    #[must_use]
    pub fn client_cert(mut self, path: impl Into<PathBuf>) -> Self {
        self.client_cert = path.into();
        self
    }

    /// Set the path of the PEM file containing the client private key.
    #[must_use]
    pub fn client_key(mut self, path: impl Into<PathBuf>) -> Self {
        self.client_key = path.into();
        self
    }

    /// Set the path of the PEM file containing the trusted CA certificate(s).
    #[must_use]
    pub fn ca_cert(mut self, path: impl Into<PathBuf>) -> Self {
        self.ca_cert = path.into();
        self
    }

    /// Override the TCP connect timeout (default: 10 s).
    #[must_use]
    pub fn connect_timeout(mut self, d: Duration) -> Self {
        self.connect_timeout = d;
        self
    }

    /// Override the socket read/write timeout (default: 30 s).
    #[must_use]
    pub fn io_timeout(mut self, d: Duration) -> Self {
        self.io_timeout = d;
        self
    }
}
123
124// ── Store ───────────────────────────────────────────────────
125
126/// A [`BlockStore`] backed by a remote `doublecrypt-server` reached over mTLS.
127///
128/// On construction the client performs a TLS handshake and issues a `GetInfo`
129/// RPC to learn the block size and total block count.  The connection is
130/// stored and reused; if it breaks, one automatic reconnect is attempted.
131pub struct NetworkBlockStore {
132    config: NetworkBlockStoreConfig,
133    tls_config: Arc<ClientConfig>,
134    stream: Mutex<Option<StreamOwned<ClientConnection, TcpStream>>>,
135    block_size: usize,
136    total_blocks: u64,
137    next_request_id: AtomicU64,
138}
139
140impl NetworkBlockStore {
141    /// Connect to a `doublecrypt-server` using mutual TLS (convenience wrapper
142    /// around [`from_config`](Self::from_config)).
143    pub fn connect(
144        addr: &str,
145        server_name: &str,
146        client_cert: &Path,
147        client_key: &Path,
148        ca_cert: &Path,
149    ) -> FsResult<Self> {
150        Self::from_config(
151            NetworkBlockStoreConfig::new(addr, server_name)
152                .client_cert(client_cert)
153                .client_key(client_key)
154                .ca_cert(ca_cert),
155        )
156    }
157
158    /// Connect using a [`NetworkBlockStoreConfig`].
159    pub fn from_config(config: NetworkBlockStoreConfig) -> FsResult<Self> {
160        let tls_config =
161            build_client_tls_config(&config.client_cert, &config.client_key, &config.ca_cert)?;
162
163        // Establish the initial connection.
164        let mut stream = establish_connection(&config, &tls_config)?;
165
166        // Issue GetInfo to learn block geometry.
167        let req = proto::Request {
168            request_id: 1,
169            command: Some(proto::request::Command::GetInfo(proto::GetInfoRequest {})),
170        };
171        send_message(&mut stream, &req)?;
172        let resp = recv_message(&mut stream)?;
173
174        match resp.result {
175            Some(proto::response::Result::GetInfo(info)) => Ok(Self {
176                config,
177                tls_config,
178                stream: Mutex::new(Some(stream)),
179                block_size: info.block_size as usize,
180                total_blocks: info.total_blocks,
181                next_request_id: AtomicU64::new(2),
182            }),
183            Some(proto::response::Result::Error(e)) => Err(FsError::Internal(format!(
184                "server error on GetInfo: {}",
185                e.message
186            ))),
187            _ => Err(FsError::Internal("unexpected response to GetInfo".into())),
188        }
189    }
190
191    /// Allocate a monotonically increasing request ID.
192    fn next_id(&self) -> u64 {
193        self.next_request_id.fetch_add(1, Ordering::Relaxed)
194    }
195
196    /// Establish a fresh TLS connection using stored config.
197    fn reconnect(&self) -> FsResult<StreamOwned<ClientConnection, TcpStream>> {
198        establish_connection(&self.config, &self.tls_config)
199    }
200
201    /// Send a single request and receive its response, retrying once on I/O
202    /// failure by reconnecting.
203    fn roundtrip(&self, req: &proto::Request) -> FsResult<proto::Response> {
204        let mut guard = self
205            .stream
206            .lock()
207            .map_err(|e| FsError::Internal(e.to_string()))?;
208
209        // Ensure we have a live connection.
210        if guard.is_none() {
211            *guard = Some(self.reconnect()?);
212        }
213
214        let stream = guard.as_mut().unwrap();
215        match send_and_recv(stream, req) {
216            Ok(resp) => Ok(resp),
217            Err(_) => {
218                // Connection may be dead — reconnect and retry once.
219                *guard = Some(self.reconnect()?);
220                send_and_recv(guard.as_mut().unwrap(), req)
221            }
222        }
223    }
224
225    // ── Pipelined helpers ───────────────────────────────────
226
227    /// Pipeline a batch of read requests on `stream`.
228    fn pipeline_reads(
229        &self,
230        stream: &mut StreamOwned<ClientConnection, TcpStream>,
231        block_ids: &[u64],
232    ) -> FsResult<Vec<Vec<u8>>> {
233        let mut results = Vec::with_capacity(block_ids.len());
234
235        for chunk in block_ids.chunks(PIPELINE_BATCH) {
236            // Send all requests in this batch.
237            for &block_id in chunk {
238                let id = self.next_id();
239                send_message(
240                    stream,
241                    &proto::Request {
242                        request_id: id,
243                        command: Some(proto::request::Command::ReadBlock(
244                            proto::ReadBlockRequest { block_id },
245                        )),
246                    },
247                )?;
248            }
249
250            // Read all responses.
251            for _ in chunk {
252                let resp = recv_message(stream)?;
253                match resp.result {
254                    Some(proto::response::Result::ReadBlock(r)) => results.push(r.data),
255                    Some(proto::response::Result::Error(e)) => {
256                        return Err(FsError::Internal(format!("server: {}", e.message)));
257                    }
258                    _ => return Err(FsError::Internal("unexpected response".into())),
259                }
260            }
261        }
262
263        Ok(results)
264    }
265
266    /// Pipeline a batch of write requests on `stream`.
267    fn pipeline_writes(
268        &self,
269        stream: &mut StreamOwned<ClientConnection, TcpStream>,
270        blocks: &[(u64, &[u8])],
271    ) -> FsResult<()> {
272        for chunk in blocks.chunks(PIPELINE_BATCH) {
273            for &(block_id, data) in chunk {
274                let id = self.next_id();
275                send_message(
276                    stream,
277                    &proto::Request {
278                        request_id: id,
279                        command: Some(proto::request::Command::WriteBlock(
280                            proto::WriteBlockRequest {
281                                block_id,
282                                data: data.to_vec(),
283                            },
284                        )),
285                    },
286                )?;
287            }
288
289            for _ in chunk {
290                let resp = recv_message(stream)?;
291                match resp.result {
292                    Some(proto::response::Result::WriteBlock(_)) => {}
293                    Some(proto::response::Result::Error(e)) => {
294                        return Err(FsError::Internal(format!("server: {}", e.message)));
295                    }
296                    _ => return Err(FsError::Internal("unexpected response".into())),
297                }
298            }
299        }
300
301        Ok(())
302    }
303
304    /// Run a pipelined operation with one reconnect attempt on failure.
305    fn with_pipeline<F, T>(&self, op: F) -> FsResult<T>
306    where
307        F: Fn(&Self, &mut StreamOwned<ClientConnection, TcpStream>) -> FsResult<T>,
308    {
309        let mut guard = self
310            .stream
311            .lock()
312            .map_err(|e| FsError::Internal(e.to_string()))?;
313
314        if guard.is_none() {
315            *guard = Some(self.reconnect()?);
316        }
317
318        match op(self, guard.as_mut().unwrap()) {
319            Ok(v) => Ok(v),
320            Err(_) => {
321                *guard = Some(self.reconnect()?);
322                op(self, guard.as_mut().unwrap())
323            }
324        }
325    }
326}
327
328impl BlockStore for NetworkBlockStore {
329    fn block_size(&self) -> usize {
330        self.block_size
331    }
332
333    fn total_blocks(&self) -> u64 {
334        self.total_blocks
335    }
336
337    fn read_block(&self, block_id: u64) -> FsResult<Vec<u8>> {
338        let id = self.next_id();
339        let req = proto::Request {
340            request_id: id,
341            command: Some(proto::request::Command::ReadBlock(
342                proto::ReadBlockRequest { block_id },
343            )),
344        };
345        let resp = self.roundtrip(&req)?;
346
347        match resp.result {
348            Some(proto::response::Result::ReadBlock(r)) => Ok(r.data),
349            Some(proto::response::Result::Error(e)) => {
350                Err(FsError::Internal(format!("server: {}", e.message)))
351            }
352            _ => Err(FsError::Internal("unexpected response".into())),
353        }
354    }
355
356    fn write_block(&self, block_id: u64, data: &[u8]) -> FsResult<()> {
357        let id = self.next_id();
358        let req = proto::Request {
359            request_id: id,
360            command: Some(proto::request::Command::WriteBlock(
361                proto::WriteBlockRequest {
362                    block_id,
363                    data: data.to_vec(),
364                },
365            )),
366        };
367        let resp = self.roundtrip(&req)?;
368
369        match resp.result {
370            Some(proto::response::Result::WriteBlock(_)) => Ok(()),
371            Some(proto::response::Result::Error(e)) => {
372                Err(FsError::Internal(format!("server: {}", e.message)))
373            }
374            _ => Err(FsError::Internal("unexpected response".into())),
375        }
376    }
377
378    fn sync(&self) -> FsResult<()> {
379        let id = self.next_id();
380        let req = proto::Request {
381            request_id: id,
382            command: Some(proto::request::Command::Sync(proto::SyncRequest {})),
383        };
384        let resp = self.roundtrip(&req)?;
385
386        match resp.result {
387            Some(proto::response::Result::Sync(_)) => Ok(()),
388            Some(proto::response::Result::Error(e)) => {
389                Err(FsError::Internal(format!("server: {}", e.message)))
390            }
391            _ => Err(FsError::Internal("unexpected response".into())),
392        }
393    }
394
395    fn read_blocks(&self, block_ids: &[u64]) -> FsResult<Vec<Vec<u8>>> {
396        if block_ids.is_empty() {
397            return Ok(Vec::new());
398        }
399        self.with_pipeline(|s, stream| s.pipeline_reads(stream, block_ids))
400    }
401
402    fn write_blocks(&self, blocks: &[(u64, &[u8])]) -> FsResult<()> {
403        if blocks.is_empty() {
404            return Ok(());
405        }
406        self.with_pipeline(|s, stream| s.pipeline_writes(stream, blocks))
407    }
408}
409
410// ── Wire helpers ────────────────────────────────────────────
411
412fn send_message<W: Write>(w: &mut W, msg: &proto::Request) -> FsResult<()> {
413    let payload = msg.encode_to_vec();
414    let len = payload.len() as u32;
415    w.write_all(&len.to_le_bytes())
416        .map_err(|e| FsError::Internal(format!("write length prefix: {e}")))?;
417    w.write_all(&payload)
418        .map_err(|e| FsError::Internal(format!("write payload: {e}")))?;
419    w.flush()
420        .map_err(|e| FsError::Internal(format!("flush: {e}")))?;
421    Ok(())
422}
423
424fn recv_message<R: Read>(r: &mut R) -> FsResult<proto::Response> {
425    let mut len_buf = [0u8; 4];
426    r.read_exact(&mut len_buf)
427        .map_err(|e| FsError::Internal(format!("read length prefix: {e}")))?;
428    let len = u32::from_le_bytes(len_buf) as usize;
429
430    if len > 16 * 1024 * 1024 {
431        return Err(FsError::Internal(format!(
432            "response too large: {len} bytes"
433        )));
434    }
435
436    let mut buf = vec![0u8; len];
437    r.read_exact(&mut buf)
438        .map_err(|e| FsError::Internal(format!("read payload: {e}")))?;
439
440    proto::Response::decode(&*buf).map_err(|e| FsError::Internal(format!("decode response: {e}")))
441}
442
443fn send_and_recv(
444    stream: &mut StreamOwned<ClientConnection, TcpStream>,
445    req: &proto::Request,
446) -> FsResult<proto::Response> {
447    send_message(stream, req)?;
448    recv_message(stream)
449}
450
451// ── Connection establishment ────────────────────────────────
452
453fn establish_connection(
454    config: &NetworkBlockStoreConfig,
455    tls_config: &Arc<ClientConfig>,
456) -> FsResult<StreamOwned<ClientConnection, TcpStream>> {
457    let addr = config
458        .addr
459        .to_socket_addrs()
460        .map_err(|e| FsError::Internal(format!("resolve {}: {e}", config.addr)))?
461        .next()
462        .ok_or_else(|| FsError::Internal(format!("no addresses for {}", config.addr)))?;
463
464    let tcp = TcpStream::connect_timeout(&addr, config.connect_timeout)
465        .map_err(|e| FsError::Internal(format!("connect to {}: {e}", config.addr)))?;
466
467    tcp.set_read_timeout(Some(config.io_timeout))
468        .map_err(|e| FsError::Internal(format!("set read timeout: {e}")))?;
469    tcp.set_write_timeout(Some(config.io_timeout))
470        .map_err(|e| FsError::Internal(format!("set write timeout: {e}")))?;
471
472    let sni = ServerName::try_from(config.server_name.clone())
473        .map_err(|e| FsError::Internal(format!("invalid SNI '{}': {e}", config.server_name)))?;
474
475    let tls_conn = ClientConnection::new(Arc::clone(tls_config), sni)
476        .map_err(|e| FsError::Internal(format!("TLS connection: {e}")))?;
477
478    Ok(StreamOwned::new(tls_conn, tcp))
479}
480
481// ── TLS configuration ───────────────────────────────────────
482
483fn build_client_tls_config(
484    cert_path: &Path,
485    key_path: &Path,
486    ca_path: &Path,
487) -> FsResult<Arc<ClientConfig>> {
488    let cert_pem = std::fs::read(cert_path)
489        .map_err(|e| FsError::Internal(format!("read client cert {}: {e}", cert_path.display())))?;
490    let certs: Vec<CertificateDer<'static>> =
491        rustls_pemfile::certs(&mut BufReader::new(&*cert_pem))
492            .collect::<std::result::Result<Vec<_>, _>>()
493            .map_err(|e| FsError::Internal(format!("parse client certs: {e}")))?;
494
495    let key_pem = std::fs::read(key_path)
496        .map_err(|e| FsError::Internal(format!("read client key {}: {e}", key_path.display())))?;
497    let key: PrivateKeyDer<'static> = rustls_pemfile::private_key(&mut BufReader::new(&*key_pem))
498        .map_err(|e| FsError::Internal(format!("parse client key: {e}")))?
499        .ok_or_else(|| FsError::Internal("no private key found in PEM file".into()))?;
500
501    let ca_pem = std::fs::read(ca_path)
502        .map_err(|e| FsError::Internal(format!("read CA cert {}: {e}", ca_path.display())))?;
503    let ca_certs: Vec<CertificateDer<'static>> =
504        rustls_pemfile::certs(&mut BufReader::new(&*ca_pem))
505            .collect::<std::result::Result<Vec<_>, _>>()
506            .map_err(|e| FsError::Internal(format!("parse CA certs: {e}")))?;
507
508    let mut root_store = rustls::RootCertStore::empty();
509    for cert in ca_certs {
510        root_store
511            .add(cert)
512            .map_err(|e| FsError::Internal(format!("add CA cert: {e}")))?;
513    }
514
515    let config = ClientConfig::builder()
516        .with_root_certificates(root_store)
517        .with_client_auth_cert(certs, key)
518        .map_err(|e| FsError::Internal(format!("build TLS client config: {e}")))?;
519
520    Ok(Arc::new(config))
521}