// dora_node_api/daemon_connection/tcp.rs
use dora_message::{
    daemon_to_node::DaemonReply,
    node_to_daemon::{DaemonRequest, Timestamped},
};
use eyre::{Context, eyre};
use std::{
    io::{Read, Write},
    net::TcpStream,
};
/// Encoding used to decode a reply frame received from the daemon.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Serializer {
    /// Decode the frame with `bincode`.
    Bincode,
    /// Decode the frame as JSON via `serde_json`.
    SerdeJson,
}
15pub fn request(
16 connection: &mut TcpStream,
17 request: &Timestamped<DaemonRequest>,
18) -> eyre::Result<DaemonReply> {
19 send_message(connection, request)?;
20 if request.inner.expects_tcp_bincode_reply() {
21 receive_reply(connection, Serializer::Bincode)
22 .and_then(|reply| reply.ok_or_else(|| eyre!("server disconnected unexpectedly")))
23 } else if request.inner.expects_tcp_json_reply() {
25 receive_reply(connection, Serializer::SerdeJson)
26 .and_then(|reply| reply.ok_or_else(|| eyre!("server disconnected unexpectedly")))
27 } else {
28 Ok(DaemonReply::Empty)
29 }
30}
31
32fn send_message(
33 connection: &mut TcpStream,
34 message: &Timestamped<DaemonRequest>,
35) -> eyre::Result<()> {
36 let serialized = bincode::serialize(&message).wrap_err("failed to serialize DaemonRequest")?;
37 tcp_send(connection, &serialized).wrap_err("failed to send DaemonRequest")?;
38 Ok(())
39}
40
41fn receive_reply(
42 connection: &mut TcpStream,
43 serializer: Serializer,
44) -> eyre::Result<Option<DaemonReply>> {
45 let raw =
46 match tcp_receive(connection) {
47 Ok(raw) => raw,
48 Err(err) => match err.kind() {
49 std::io::ErrorKind::UnexpectedEof | std::io::ErrorKind::ConnectionAborted => {
50 return Ok(None);
51 }
52 other => return Err(err).with_context(|| {
53 format!(
54 "unexpected I/O error (kind {other:?}) while trying to receive DaemonReply"
55 )
56 }),
57 },
58 };
59 match serializer {
60 Serializer::Bincode => bincode::deserialize(&raw)
61 .wrap_err("failed to deserialize DaemonReply")
62 .map(Some),
63 Serializer::SerdeJson => serde_json::from_slice(&raw)
64 .wrap_err("failed to deserialize DaemonReply")
65 .map(Some),
66 }
67}
68
/// Writes `message` as a single frame: an 8-byte little-endian `u64` length prefix
/// followed by the payload bytes, then flushes `connection`.
///
/// # Errors
///
/// Propagates any I/O error from writing or flushing.
fn tcp_send(connection: &mut (impl Write + Unpin), message: &[u8]) -> std::io::Result<()> {
    let header = u64::to_le_bytes(message.len() as u64);
    for chunk in [&header[..], message] {
        connection.write_all(chunk)?;
    }
    connection.flush()
}
/// Reads one frame written by the matching sender: an 8-byte little-endian `u64`
/// length prefix followed by exactly that many payload bytes.
///
/// # Errors
///
/// - `UnexpectedEof` if the stream ends before the full header or payload arrives.
/// - `InvalidData` if the announced length does not fit in `usize` on this target.
/// - Any other I/O error from the underlying reader.
fn tcp_receive(connection: &mut (impl Read + Unpin)) -> std::io::Result<Vec<u8>> {
    // The length prefix comes straight off the wire, so cap what we reserve up front.
    // A corrupt or desynchronized prefix then cannot trigger a giant allocation. Larger
    // genuine payloads still work: the buffer just grows as bytes actually arrive.
    const MAX_PREALLOC: usize = 1 << 20;

    let mut len_raw = [0; 8];
    connection.read_exact(&mut len_raw)?;
    let announced_len = u64::from_le_bytes(len_raw);
    // `as usize` would silently truncate on 32-bit targets and desync the stream.
    let reply_len = usize::try_from(announced_len).map_err(|_| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!("announced reply length {announced_len} does not fit in usize"),
        )
    })?;

    let mut reply = Vec::with_capacity(reply_len.min(MAX_PREALLOC));
    connection
        .by_ref()
        .take(announced_len)
        .read_to_end(&mut reply)?;
    if reply.len() != reply_len {
        // Same error kind `read_exact` reports for a short stream, so callers that treat
        // `UnexpectedEof` as "peer disconnected" keep working.
        return Err(std::io::Error::new(
            std::io::ErrorKind::UnexpectedEof,
            "stream ended before the full reply was received",
        ));
    }
    Ok(reply)
}