1use std::{
4 mem::{size_of, size_of_val},
5 sync::atomic::{AtomicU32, Ordering},
6};
7
8use bincode::Options;
9use eyre::WrapErr;
10use serde::{de::DeserializeOwned, Serialize};
11use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
12use treasury_id::AssetId;
13
/// Fixed-size greeting exchanged on a connection before any framed message.
///
/// `send_handshake` writes `magic` followed by `version`, each as a
/// little-endian `u32`, and `recv_handshake` reads them back in that order.
/// `size_of::<Handshake>()` is used as the byte length of that exchange.
#[repr(C)]
pub struct Handshake {
    /// Protocol marker; `recv_handshake` rejects any value other than `MAGIC`.
    pub magic: u32,

    /// Sender's API version; `recv_handshake` requires it to equal `version()`.
    pub version: u32,
}
23
/// Serializable request carrying a path and an `init` flag.
///
/// NOTE(review): presumably the first message a client sends to open a
/// treasury, answered with `OpenResponse` — confirm against the server code.
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub struct OpenRequest {
    /// Path supplied by the requester (presumably the treasury location — TODO confirm).
    pub path: Box<str>,

    /// Presumably asks for the treasury to be initialized if missing — TODO confirm.
    pub init: bool,
}
35
/// Outcome of an open operation (presumably the reply to `OpenRequest` — TODO confirm).
///
/// Variant order is part of the serialized representation; do not reorder.
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub enum OpenResponse {
    /// The operation succeeded; no payload is carried.
    Success,

    /// The operation failed.
    Failure {
        /// Human-readable description of the failure.
        description: Box<str>,
    },
}
47
/// Requests that can be serialized and sent over the connection.
///
/// NOTE(review): the matching replies are presumably `StoreResponse`,
/// `FetchUrlResponse` and `FindResponse` respectively — confirm against the
/// server code. Variant order is part of the serialized representation.
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub enum Request {
    /// Store an asset.
    Store {
        /// Source of the asset (presumably a path or URL — TODO confirm).
        source: Box<str>,

        /// Optional source format name; `None` when not specified.
        format: Option<Box<str>>,

        /// Target of the asset (presumably the target format name — TODO confirm).
        target: Box<str>,
    },

    /// Look up the artifact URL for the asset with the given id.
    FetchUrl { id: AssetId },

    /// Look up an asset by its `source` and `target` strings.
    FindAsset { source: Box<str>, target: Box<str> },
}
69
/// Outcome of a store operation (presumably the reply to `Request::Store` — TODO confirm).
///
/// Variant order is part of the serialized representation; do not reorder.
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub enum StoreResponse {
    /// The asset is stored; carries its id and a path string.
    Success { id: AssetId, path: Box<str> },

    /// More data is needed; carries a URL
    /// (presumably where the missing data must be provided from/to — TODO confirm).
    NeedData { url: Box<str> },

    /// The operation failed; carries a human-readable description.
    Failure { description: Box<str> },
}
84
/// Outcome of a URL fetch (presumably the reply to `Request::FetchUrl` — TODO confirm).
///
/// Variant order is part of the serialized representation; do not reorder.
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub enum FetchUrlResponse {
    /// Lookup succeeded; carries the artifact string
    /// (presumably the artifact's URL — TODO confirm).
    Success { artifact: Box<str> },

    /// Nothing matched the request.
    NotFound,

    /// The operation failed; carries a human-readable description.
    Failure { description: Box<str> },
}
97
/// Outcome of an asset search (presumably the reply to `Request::FindAsset` — TODO confirm).
///
/// Variant order is part of the serialized representation; do not reorder.
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub enum FindResponse {
    /// An asset was found; carries its id and a path string.
    Success { id: AssetId, path: Box<str> },

    /// Nothing matched the request.
    NotFound,

    /// The operation failed; carries a human-readable description.
    Failure { description: Box<str> },
}
110
/// Handshake marker: the ASCII bytes `TRES` read as a big-endian `u32`.
///
/// `send_handshake` writes this value with `to_le_bytes`, so the bytes appear
/// on the wire in reverse order.
pub const MAGIC: u32 = u32::from_be_bytes(*b"TRES");
112
113pub fn version() -> u32 {
114 static VERSION: AtomicU32 = AtomicU32::new(u32::MAX);
115
116 #[cold]
117 fn init_version() -> u32 {
118 env!("CARGO_PKG_VERSION_MAJOR")
120 .parse()
121 .expect("Bad major version")
122 }
123
124 let mut version = VERSION.load(Ordering::Relaxed);
125 if version == u32::MAX {
126 version = init_version();
127 VERSION.store(version, Ordering::Relaxed);
128 }
129 version
130}
131
/// Length prefix that precedes every framed message.
///
/// `send_message` writes `size` as a little-endian `u32` immediately before
/// the serialized payload, and `next_message_header` reads it back the same way.
#[derive(Debug)]
#[repr(C)]
pub struct MessageHeader {
    /// Payload length in bytes, not counting the header itself.
    pub size: u32,
}
137
/// Port returned by `get_port` when `TREASURY_SERVICE_PORT` is unset or cannot be parsed.
pub const DEFAULT_PORT: u16 = 12345;
139
140pub fn get_port() -> u16 {
141 match std::env::var("TREASURY_SERVICE_PORT") {
142 Ok(port_string) => match port_string.parse() {
143 Ok(port) => port,
144 Err(_) => {
145 tracing::error!(
146 "Failed to parse desired treasury port from env '{}'. Using default {}",
147 port_string,
148 DEFAULT_PORT
149 );
150 DEFAULT_PORT
151 }
152 },
153 Err(_) => DEFAULT_PORT,
154 }
155}
156
157const INLINE_MESSAGE_LIMIT: usize = 1 << 12; const MESSAGE_LIMIT: usize = 1 << 28; pub async fn send_message<T: Serialize>(
161 stream: &mut (impl AsyncWrite + Unpin),
162 message: T,
163) -> eyre::Result<()> {
164 let size = bincode_options()
165 .serialized_size(&message)
166 .wrap_err("Failed to determine serialized size of the message")?;
167
168 eyre::ensure!(size <= MESSAGE_LIMIT as u64, "Message is too large");
169
170 let size = size as u32;
171 let header = MessageHeader { size };
172 tracing::debug!("Sending message header {:?}", header);
173
174 let mut buffer = [0; INLINE_MESSAGE_LIMIT];
175 if size > INLINE_MESSAGE_LIMIT as u32 {
176 let mut buffer = vec![0; size_of::<MessageHeader>() + size as usize];
177
178 buffer[..size_of::<MessageHeader>()].copy_from_slice(&header.size.to_le_bytes());
179
180 bincode_options()
181 .serialize_into(&mut buffer[size_of::<MessageHeader>()..], &message)
182 .wrap_err("Failed to serialize message")?;
183
184 stream
185 .write_all(&buffer)
186 .await
187 .wrap_err("Failed to send message")?;
188
189 tracing::debug!("{} bytes sent", buffer.len());
190 } else {
191 let buffer = &mut buffer[..size_of::<MessageHeader>() + size as usize];
192
193 buffer[..size_of::<MessageHeader>()].copy_from_slice(&header.size.to_le_bytes());
194
195 bincode_options()
196 .serialize_into(&mut buffer[size_of::<MessageHeader>()..], &message)
197 .wrap_err("Failed to serialize message")?;
198
199 stream
200 .write_all(buffer)
201 .await
202 .wrap_err("Failed to send message")?;
203
204 tracing::debug!("{} bytes sent", buffer.len());
205 }
206
207 Ok(())
208}
209
210async fn next_message_header(
211 stream: &mut (impl AsyncRead + Unpin),
212) -> std::io::Result<Option<MessageHeader>> {
213 let mut buffer = [0; size_of::<MessageHeader>()];
214 match stream.read_exact(&mut buffer).await {
215 Ok(_) => Ok(Some(MessageHeader {
216 size: u32::from_le_bytes(buffer),
217 })),
218 Err(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => Ok(None),
219 Err(err) => Err(err),
220 }
221}
222
223pub async fn recv_message<T: DeserializeOwned>(
224 stream: &mut (impl AsyncRead + Unpin),
225) -> eyre::Result<Option<T>> {
226 let header = match next_message_header(stream).await? {
227 None => {
228 tracing::debug!("Connection closed");
229 return Ok(None);
230 }
231 Some(header) => header,
232 };
233
234 tracing::debug!("Next message header {:?}", header);
235
236 eyre::ensure!(header.size <= MESSAGE_LIMIT as u32, "Message is too large");
237
238 let mut buffer = [0; INLINE_MESSAGE_LIMIT];
239
240 if header.size > INLINE_MESSAGE_LIMIT as u32 {
241 let mut buffer = vec![0; header.size as usize];
242 stream.read_exact(&mut buffer).await?;
243
244 tracing::debug!(
245 "{} bytes received",
246 size_of::<MessageHeader>() + header.size as usize
247 );
248
249 let message = bincode_options()
250 .deserialize(&buffer)
251 .wrap_err("Failed to parse request")?;
252
253 Ok(Some(message))
254 } else {
255 let buffer = &mut buffer[..header.size as usize];
256 stream.read_exact(buffer).await?;
257
258 tracing::debug!(
259 "{} bytes received",
260 size_of::<MessageHeader>() + header.size as usize
261 );
262
263 let message = bincode_options()
264 .deserialize(buffer)
265 .wrap_err("Failed to parse request")?;
266
267 Ok(Some(message))
268 }
269}
270
271pub async fn recv_handshake(stream: &mut (impl AsyncRead + Unpin)) -> eyre::Result<()> {
272 let mut buffer = [0; size_of::<Handshake>()];
273
274 stream
275 .read_exact(&mut buffer)
276 .await
277 .wrap_err("Handshake failed")?;
278
279 let handshake = Handshake {
280 magic: u32::from_le_bytes([buffer[0], buffer[1], buffer[2], buffer[3]]),
281 version: u32::from_le_bytes([buffer[4], buffer[5], buffer[6], buffer[7]]),
282 };
283
284 tracing::debug!(
285 "Handshake received {}:{}",
286 handshake.magic,
287 handshake.version
288 );
289
290 eyre::ensure!(
291 handshake.magic == MAGIC,
292 "Wrong MAGIC number. Expected '{}', found '{}'",
293 MAGIC,
294 handshake.magic
295 );
296
297 let version = version();
298
299 eyre::ensure!(
300 handshake.version == version,
301 "Treasury API version mismatch. Expected '{}', found '{}'",
302 version,
303 handshake.version,
304 );
305
306 tracing::info!("Handshake valid");
307
308 Ok(())
309}
310
311pub async fn send_handshake(stream: &mut (impl AsyncWrite + Unpin)) -> eyre::Result<()> {
312 let mut buffer = [0; size_of::<Handshake>()];
313
314 buffer[..size_of_val(&MAGIC)].copy_from_slice(&MAGIC.to_le_bytes());
315 buffer[size_of_val(&MAGIC)..].copy_from_slice(&version().to_le_bytes());
316
317 stream
318 .write_all(&buffer)
319 .await
320 .wrap_err("Handshake failed")?;
321
322 tracing::debug!("Handshake sent {}:{}", MAGIC, version());
323
324 Ok(())
325}
326
327fn bincode_options() -> impl Options {
328 bincode::options()
329 .with_big_endian()
330 .with_fixint_encoding()
331 .allow_trailing_bytes()
332}