1use std::io::{BufReader, Read, Write};
50use std::net::{TcpStream, ToSocketAddrs};
51use std::path::{Path, PathBuf};
52use std::sync::atomic::{AtomicU64, Ordering};
53use std::sync::{Arc, Mutex};
54use std::time::Duration;
55
56use prost::Message;
57use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName};
58use rustls::{ClientConfig, ClientConnection, StreamOwned};
59
60use crate::block_store::BlockStore;
61use crate::error::{FsError, FsResult};
62use crate::proto;
63
/// Maximum number of requests that `pipeline_reads` / `pipeline_writes` send back-to-back
/// before stopping to read the matching responses.
const PIPELINE_BATCH: usize = 64;
69
/// Connection settings for a `NetworkBlockStore`.
///
/// Create one with [`NetworkBlockStoreConfig::new`] and refine it with the chained setters.
/// The three PEM paths start out empty and must be set before connecting, because TLS setup
/// reads them from disk.
#[derive(Debug, Clone)]
pub struct NetworkBlockStoreConfig {
    /// Server address in a form accepted by `ToSocketAddrs` (e.g. `"host:port"`).
    addr: String,
    /// TLS server name sent as SNI and used to verify the server certificate.
    server_name: String,
    /// PEM file holding the client certificate chain presented to the server.
    client_cert: PathBuf,
    /// PEM file holding the client private key.
    client_key: PathBuf,
    /// PEM file holding the CA certificate(s) trusted to sign the server certificate.
    ca_cert: PathBuf,
    /// Upper bound for establishing the TCP connection (per resolved address).
    connect_timeout: Duration,
    /// Read and write timeout applied to the socket.
    io_timeout: Duration,
}

impl NetworkBlockStoreConfig {
    const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
    const DEFAULT_IO_TIMEOUT: Duration = Duration::from_secs(30);

    /// Creates a configuration for `addr`, verifying the server as `server_name`.
    ///
    /// Timeouts default to 10 s for connecting and 30 s for socket reads/writes. The
    /// certificate and key paths are empty and must be set with the corresponding setters.
    #[must_use]
    pub fn new(addr: impl Into<String>, server_name: impl Into<String>) -> Self {
        Self {
            addr: addr.into(),
            server_name: server_name.into(),
            client_cert: PathBuf::new(),
            client_key: PathBuf::new(),
            ca_cert: PathBuf::new(),
            connect_timeout: Self::DEFAULT_CONNECT_TIMEOUT,
            io_timeout: Self::DEFAULT_IO_TIMEOUT,
        }
    }

    /// Sets the PEM file containing the client certificate chain.
    #[must_use]
    pub fn client_cert(mut self, path: impl Into<PathBuf>) -> Self {
        self.client_cert = path.into();
        self
    }

    /// Sets the PEM file containing the client private key.
    #[must_use]
    pub fn client_key(mut self, path: impl Into<PathBuf>) -> Self {
        self.client_key = path.into();
        self
    }

    /// Sets the PEM file containing the CA certificate(s) used to verify the server.
    #[must_use]
    pub fn ca_cert(mut self, path: impl Into<PathBuf>) -> Self {
        self.ca_cert = path.into();
        self
    }

    /// Sets the TCP connect timeout.
    ///
    /// `std` rejects a zero duration for `TcpStream::connect_timeout`, so a zero value makes
    /// every connection attempt fail.
    #[must_use]
    pub fn connect_timeout(mut self, d: Duration) -> Self {
        self.connect_timeout = d;
        self
    }

    /// Sets the socket read/write timeout.
    ///
    /// `std` rejects a zero duration for socket timeouts, so a zero value makes connecting
    /// fail.
    #[must_use]
    pub fn io_timeout(mut self, d: Duration) -> Self {
        self.io_timeout = d;
        self
    }
}
123
/// A [`BlockStore`] backed by a remote block server reached over TLS with a client
/// certificate.
///
/// All requests go through a single connection held behind a mutex, and each exchange holds
/// the lock for its whole duration, so concurrent callers are serialized.
pub struct NetworkBlockStore {
    // Kept so the connection can be re-established after a failure.
    config: NetworkBlockStoreConfig,
    // TLS client configuration built once at construction and reused for every reconnect.
    tls_config: Arc<ClientConfig>,
    // The cached connection; `None` makes the next request open a fresh one.
    stream: Mutex<Option<StreamOwned<ClientConnection, TcpStream>>>,
    // Geometry reported by the server's `GetInfo` reply at construction time.
    block_size: usize,
    total_blocks: u64,
    // Source of request ids. Id 1 is used by the `GetInfo` handshake, so this starts at 2.
    next_request_id: AtomicU64,
}
139
140impl NetworkBlockStore {
141 pub fn connect(
144 addr: &str,
145 server_name: &str,
146 client_cert: &Path,
147 client_key: &Path,
148 ca_cert: &Path,
149 ) -> FsResult<Self> {
150 Self::from_config(
151 NetworkBlockStoreConfig::new(addr, server_name)
152 .client_cert(client_cert)
153 .client_key(client_key)
154 .ca_cert(ca_cert),
155 )
156 }
157
158 pub fn from_config(config: NetworkBlockStoreConfig) -> FsResult<Self> {
160 let tls_config =
161 build_client_tls_config(&config.client_cert, &config.client_key, &config.ca_cert)?;
162
163 let mut stream = establish_connection(&config, &tls_config)?;
165
166 let req = proto::Request {
168 request_id: 1,
169 command: Some(proto::request::Command::GetInfo(proto::GetInfoRequest {})),
170 };
171 send_message(&mut stream, &req)?;
172 let resp = recv_message(&mut stream)?;
173
174 match resp.result {
175 Some(proto::response::Result::GetInfo(info)) => Ok(Self {
176 config,
177 tls_config,
178 stream: Mutex::new(Some(stream)),
179 block_size: info.block_size as usize,
180 total_blocks: info.total_blocks,
181 next_request_id: AtomicU64::new(2),
182 }),
183 Some(proto::response::Result::Error(e)) => Err(FsError::Internal(format!(
184 "server error on GetInfo: {}",
185 e.message
186 ))),
187 _ => Err(FsError::Internal("unexpected response to GetInfo".into())),
188 }
189 }
190
191 fn next_id(&self) -> u64 {
193 self.next_request_id.fetch_add(1, Ordering::Relaxed)
194 }
195
196 fn reconnect(&self) -> FsResult<StreamOwned<ClientConnection, TcpStream>> {
198 establish_connection(&self.config, &self.tls_config)
199 }
200
201 fn roundtrip(&self, req: &proto::Request) -> FsResult<proto::Response> {
204 let mut guard = self
205 .stream
206 .lock()
207 .map_err(|e| FsError::Internal(e.to_string()))?;
208
209 if guard.is_none() {
211 *guard = Some(self.reconnect()?);
212 }
213
214 let stream = guard.as_mut().unwrap();
215 match send_and_recv(stream, req) {
216 Ok(resp) => Ok(resp),
217 Err(_) => {
218 *guard = Some(self.reconnect()?);
220 send_and_recv(guard.as_mut().unwrap(), req)
221 }
222 }
223 }
224
225 fn pipeline_reads(
229 &self,
230 stream: &mut StreamOwned<ClientConnection, TcpStream>,
231 block_ids: &[u64],
232 ) -> FsResult<Vec<Vec<u8>>> {
233 let mut results = Vec::with_capacity(block_ids.len());
234
235 for chunk in block_ids.chunks(PIPELINE_BATCH) {
236 for &block_id in chunk {
238 let id = self.next_id();
239 send_message(
240 stream,
241 &proto::Request {
242 request_id: id,
243 command: Some(proto::request::Command::ReadBlock(
244 proto::ReadBlockRequest { block_id },
245 )),
246 },
247 )?;
248 }
249
250 for _ in chunk {
252 let resp = recv_message(stream)?;
253 match resp.result {
254 Some(proto::response::Result::ReadBlock(r)) => results.push(r.data),
255 Some(proto::response::Result::Error(e)) => {
256 return Err(FsError::Internal(format!("server: {}", e.message)));
257 }
258 _ => return Err(FsError::Internal("unexpected response".into())),
259 }
260 }
261 }
262
263 Ok(results)
264 }
265
266 fn pipeline_writes(
268 &self,
269 stream: &mut StreamOwned<ClientConnection, TcpStream>,
270 blocks: &[(u64, &[u8])],
271 ) -> FsResult<()> {
272 for chunk in blocks.chunks(PIPELINE_BATCH) {
273 for &(block_id, data) in chunk {
274 let id = self.next_id();
275 send_message(
276 stream,
277 &proto::Request {
278 request_id: id,
279 command: Some(proto::request::Command::WriteBlock(
280 proto::WriteBlockRequest {
281 block_id,
282 data: data.to_vec(),
283 },
284 )),
285 },
286 )?;
287 }
288
289 for _ in chunk {
290 let resp = recv_message(stream)?;
291 match resp.result {
292 Some(proto::response::Result::WriteBlock(_)) => {}
293 Some(proto::response::Result::Error(e)) => {
294 return Err(FsError::Internal(format!("server: {}", e.message)));
295 }
296 _ => return Err(FsError::Internal("unexpected response".into())),
297 }
298 }
299 }
300
301 Ok(())
302 }
303
304 fn with_pipeline<F, T>(&self, op: F) -> FsResult<T>
306 where
307 F: Fn(&Self, &mut StreamOwned<ClientConnection, TcpStream>) -> FsResult<T>,
308 {
309 let mut guard = self
310 .stream
311 .lock()
312 .map_err(|e| FsError::Internal(e.to_string()))?;
313
314 if guard.is_none() {
315 *guard = Some(self.reconnect()?);
316 }
317
318 match op(self, guard.as_mut().unwrap()) {
319 Ok(v) => Ok(v),
320 Err(_) => {
321 *guard = Some(self.reconnect()?);
322 op(self, guard.as_mut().unwrap())
323 }
324 }
325 }
326}
327
328impl BlockStore for NetworkBlockStore {
329 fn block_size(&self) -> usize {
330 self.block_size
331 }
332
333 fn total_blocks(&self) -> u64 {
334 self.total_blocks
335 }
336
337 fn read_block(&self, block_id: u64) -> FsResult<Vec<u8>> {
338 let id = self.next_id();
339 let req = proto::Request {
340 request_id: id,
341 command: Some(proto::request::Command::ReadBlock(
342 proto::ReadBlockRequest { block_id },
343 )),
344 };
345 let resp = self.roundtrip(&req)?;
346
347 match resp.result {
348 Some(proto::response::Result::ReadBlock(r)) => Ok(r.data),
349 Some(proto::response::Result::Error(e)) => {
350 Err(FsError::Internal(format!("server: {}", e.message)))
351 }
352 _ => Err(FsError::Internal("unexpected response".into())),
353 }
354 }
355
356 fn write_block(&self, block_id: u64, data: &[u8]) -> FsResult<()> {
357 let id = self.next_id();
358 let req = proto::Request {
359 request_id: id,
360 command: Some(proto::request::Command::WriteBlock(
361 proto::WriteBlockRequest {
362 block_id,
363 data: data.to_vec(),
364 },
365 )),
366 };
367 let resp = self.roundtrip(&req)?;
368
369 match resp.result {
370 Some(proto::response::Result::WriteBlock(_)) => Ok(()),
371 Some(proto::response::Result::Error(e)) => {
372 Err(FsError::Internal(format!("server: {}", e.message)))
373 }
374 _ => Err(FsError::Internal("unexpected response".into())),
375 }
376 }
377
378 fn sync(&self) -> FsResult<()> {
379 let id = self.next_id();
380 let req = proto::Request {
381 request_id: id,
382 command: Some(proto::request::Command::Sync(proto::SyncRequest {})),
383 };
384 let resp = self.roundtrip(&req)?;
385
386 match resp.result {
387 Some(proto::response::Result::Sync(_)) => Ok(()),
388 Some(proto::response::Result::Error(e)) => {
389 Err(FsError::Internal(format!("server: {}", e.message)))
390 }
391 _ => Err(FsError::Internal("unexpected response".into())),
392 }
393 }
394
395 fn read_blocks(&self, block_ids: &[u64]) -> FsResult<Vec<Vec<u8>>> {
396 if block_ids.is_empty() {
397 return Ok(Vec::new());
398 }
399 self.with_pipeline(|s, stream| s.pipeline_reads(stream, block_ids))
400 }
401
402 fn write_blocks(&self, blocks: &[(u64, &[u8])]) -> FsResult<()> {
403 if blocks.is_empty() {
404 return Ok(());
405 }
406 self.with_pipeline(|s, stream| s.pipeline_writes(stream, blocks))
407 }
408}
409
410fn send_message<W: Write>(w: &mut W, msg: &proto::Request) -> FsResult<()> {
413 let payload = msg.encode_to_vec();
414 let len = payload.len() as u32;
415 w.write_all(&len.to_le_bytes())
416 .map_err(|e| FsError::Internal(format!("write length prefix: {e}")))?;
417 w.write_all(&payload)
418 .map_err(|e| FsError::Internal(format!("write payload: {e}")))?;
419 w.flush()
420 .map_err(|e| FsError::Internal(format!("flush: {e}")))?;
421 Ok(())
422}
423
424fn recv_message<R: Read>(r: &mut R) -> FsResult<proto::Response> {
425 let mut len_buf = [0u8; 4];
426 r.read_exact(&mut len_buf)
427 .map_err(|e| FsError::Internal(format!("read length prefix: {e}")))?;
428 let len = u32::from_le_bytes(len_buf) as usize;
429
430 if len > 16 * 1024 * 1024 {
431 return Err(FsError::Internal(format!(
432 "response too large: {len} bytes"
433 )));
434 }
435
436 let mut buf = vec![0u8; len];
437 r.read_exact(&mut buf)
438 .map_err(|e| FsError::Internal(format!("read payload: {e}")))?;
439
440 proto::Response::decode(&*buf).map_err(|e| FsError::Internal(format!("decode response: {e}")))
441}
442
443fn send_and_recv(
444 stream: &mut StreamOwned<ClientConnection, TcpStream>,
445 req: &proto::Request,
446) -> FsResult<proto::Response> {
447 send_message(stream, req)?;
448 recv_message(stream)
449}
450
451fn establish_connection(
454 config: &NetworkBlockStoreConfig,
455 tls_config: &Arc<ClientConfig>,
456) -> FsResult<StreamOwned<ClientConnection, TcpStream>> {
457 let addr = config
458 .addr
459 .to_socket_addrs()
460 .map_err(|e| FsError::Internal(format!("resolve {}: {e}", config.addr)))?
461 .next()
462 .ok_or_else(|| FsError::Internal(format!("no addresses for {}", config.addr)))?;
463
464 let tcp = TcpStream::connect_timeout(&addr, config.connect_timeout)
465 .map_err(|e| FsError::Internal(format!("connect to {}: {e}", config.addr)))?;
466
467 tcp.set_read_timeout(Some(config.io_timeout))
468 .map_err(|e| FsError::Internal(format!("set read timeout: {e}")))?;
469 tcp.set_write_timeout(Some(config.io_timeout))
470 .map_err(|e| FsError::Internal(format!("set write timeout: {e}")))?;
471
472 let sni = ServerName::try_from(config.server_name.clone())
473 .map_err(|e| FsError::Internal(format!("invalid SNI '{}': {e}", config.server_name)))?;
474
475 let tls_conn = ClientConnection::new(Arc::clone(tls_config), sni)
476 .map_err(|e| FsError::Internal(format!("TLS connection: {e}")))?;
477
478 Ok(StreamOwned::new(tls_conn, tcp))
479}
480
481fn build_client_tls_config(
484 cert_path: &Path,
485 key_path: &Path,
486 ca_path: &Path,
487) -> FsResult<Arc<ClientConfig>> {
488 let cert_pem = std::fs::read(cert_path)
489 .map_err(|e| FsError::Internal(format!("read client cert {}: {e}", cert_path.display())))?;
490 let certs: Vec<CertificateDer<'static>> =
491 rustls_pemfile::certs(&mut BufReader::new(&*cert_pem))
492 .collect::<std::result::Result<Vec<_>, _>>()
493 .map_err(|e| FsError::Internal(format!("parse client certs: {e}")))?;
494
495 let key_pem = std::fs::read(key_path)
496 .map_err(|e| FsError::Internal(format!("read client key {}: {e}", key_path.display())))?;
497 let key: PrivateKeyDer<'static> = rustls_pemfile::private_key(&mut BufReader::new(&*key_pem))
498 .map_err(|e| FsError::Internal(format!("parse client key: {e}")))?
499 .ok_or_else(|| FsError::Internal("no private key found in PEM file".into()))?;
500
501 let ca_pem = std::fs::read(ca_path)
502 .map_err(|e| FsError::Internal(format!("read CA cert {}: {e}", ca_path.display())))?;
503 let ca_certs: Vec<CertificateDer<'static>> =
504 rustls_pemfile::certs(&mut BufReader::new(&*ca_pem))
505 .collect::<std::result::Result<Vec<_>, _>>()
506 .map_err(|e| FsError::Internal(format!("parse CA certs: {e}")))?;
507
508 let mut root_store = rustls::RootCertStore::empty();
509 for cert in ca_certs {
510 root_store
511 .add(cert)
512 .map_err(|e| FsError::Internal(format!("add CA cert: {e}")))?;
513 }
514
515 let config = ClientConfig::builder()
516 .with_root_certificates(root_store)
517 .with_client_auth_cert(certs, key)
518 .map_err(|e| FsError::Internal(format!("build TLS client config: {e}")))?;
519
520 Ok(Arc::new(config))
521}