1use std::{error::Error, path::PathBuf, sync::Mutex};
2
3use capnp::message::ReaderOptions;
4use capnp_rpc::{pry, rpc_twoparty_capnp, twoparty, RpcSystem};
5use crossbeam_channel::{unbounded, Receiver, Sender};
6use distill_core::{utils, AssetMetadata, AssetUuid};
7use distill_schema::{data::asset_change_event, parse_db_metadata, service::asset_hub};
8use futures_util::AsyncReadExt;
9use tokio::{
10 net::TcpStream,
11 runtime::{Builder, Runtime},
12 sync::oneshot,
13};
14
15use crate::{
16 io::{DataRequest, LoaderIO, MetadataRequest, MetadataRequestResult, ResolveRequest},
17 loader::LoaderState,
18};
19
/// Shorthand for the capnp capability promise returned by RPC server callbacks
/// (e.g. `ListenerImpl::update`), with the error type fixed to `capnp::Error`.
type Promise<T> = capnp::capability::Promise<T, capnp::Error>;
21
/// State of an established RPC connection to the asset hub.
struct RpcConnection {
    /// Snapshot client that outgoing data/metadata/resolve requests are issued against.
    /// Replaced by the snapshot carried in each `SnapshotChange` received on `snapshot_rx`.
    snapshot: asset_hub::snapshot::Client,
    /// Receiving end of the channel fed by the `ListenerImpl` registered with the hub.
    snapshot_rx: Receiver<SnapshotChange>,
}
27
/// A snapshot update produced by `ListenerImpl::update` and consumed by
/// `RpcRuntime::check_asset_changes`.
///
/// Carries the hub's new snapshot together with the assets and paths reported as changed or
/// removed. The very first update after registering the listener carries empty lists.
struct SnapshotChange {
    /// New snapshot to issue subsequent requests against.
    snapshot: asset_hub::snapshot::Client,
    /// Assets reported by `ContentUpdateEvent`s.
    changed_assets: Vec<AssetUuid>,
    /// Assets reported by `RemoveEvent`s.
    deleted_assets: Vec<AssetUuid>,
    /// Paths reported by `PathUpdateEvent`s.
    changed_paths: Vec<PathBuf>,
    /// Paths reported by `PathRemoveEvent`s.
    deleted_paths: Vec<PathBuf>,
}
36
/// Connection lifecycle of `RpcRuntime`.
enum InternalConnectionState {
    /// Not connected and no attempt in flight. Also used as a temporary placeholder while the
    /// state is swapped out with `std::mem::replace`.
    None,
    /// A connection attempt is running; the receiver yields its outcome when it finishes.
    Connecting(oneshot::Receiver<Result<RpcConnection, Box<dyn Error>>>),
    /// Connected, with the listener registered for snapshot updates.
    Connected(RpcConnection),
    /// The last connection attempt failed; `RpcIO::tick` logs this and starts a new attempt.
    Error(Box<dyn Error>),
}
43
/// Owns the tokio runtime and `LocalSet` used for all RPC work, plus the connection state.
///
/// All RPC tasks are spawned with `spawn_local` onto `local` (presumably because the
/// capnp-rpc objects they hold are not `Send` — confirm).
struct RpcRuntime {
    /// Runtime passed to `LocalSet::block_on`; `RpcIO::new` builds it as a current-thread
    /// runtime with all drivers enabled.
    runtime: Runtime,
    /// Task set for connection and request futures; driven by the `block_on` call in
    /// `RpcIO::tick`.
    local: tokio::task::LocalSet,
    /// Current connection lifecycle state.
    connection: InternalConnectionState,
}
50
// SAFETY (review note): `RpcRuntime` contains `!Send` state — the `LocalSet` and the
// capnp-rpc clients/tasks spawned onto it. This impl exists so it can sit inside the `Mutex`
// in `RpcIO`. Soundness relies on every access going through that mutex and on all of the
// `!Send` state (including any `Rc` handles inside tasks) always moving between threads
// together, never split apart. `LoaderIO::with_runtime` only exposes the `Runtime`.
// TODO: confirm no `Rc`-based handle escapes the struct.
unsafe impl Send for RpcRuntime {}
56
57impl RpcRuntime {
58 fn check_asset_changes(&mut self, loader: &LoaderState) {
59 self.connection =
60 match std::mem::replace(&mut self.connection, InternalConnectionState::None) {
61 InternalConnectionState::Connected(mut conn) => {
62 if let Ok(change) = conn.snapshot_rx.try_recv() {
63 log::trace!("RpcRuntime check_asset_changes Ok(change)");
64 conn.snapshot = change.snapshot;
65 let mut changed_assets = Vec::new();
66 for asset in change.changed_assets {
67 log::trace!(
68 "RpcRuntime check_asset_changes changed asset.id: {:?}",
69 asset
70 );
71 changed_assets.push(asset);
72 }
73 for asset in change.deleted_assets {
74 log::trace!(
75 "RpcRuntime check_asset_changes deleted asset.id: {:?}",
76 asset
77 );
78 changed_assets.push(asset);
79 }
80 loader.invalidate_assets(&changed_assets);
81 let mut changed_paths = Vec::new();
82 for path in change.changed_paths {
83 changed_paths.push(path);
84 }
85 for path in change.deleted_paths {
86 changed_paths.push(path);
87 }
88 loader.invalidate_paths(&changed_paths);
89 }
90 InternalConnectionState::Connected(conn)
91 }
92 c => c,
93 };
94 }
95
96 fn connect(&mut self, connect_string: String) {
97 match self.connection {
98 InternalConnectionState::Connected(_) | InternalConnectionState::Connecting(_) => {
99 panic!("Trying to connect while already connected or connecting")
100 }
101 _ => {}
102 };
103
104 let (conn_tx, conn_rx) = oneshot::channel();
105
106 self.local.spawn_local(async move {
107 let result = async move {
108 log::trace!("Tcp connect to {:?}", connect_string);
109 let stream = TcpStream::connect(connect_string).await?;
110 stream.set_nodelay(true)?;
111
112 use tokio_util::compat::*;
113 let (reader, writer) = stream.compat().split();
114
115 log::trace!("Creating capnp VatNetwork");
116 let rpc_network = Box::new(twoparty::VatNetwork::new(
117 reader,
118 writer,
119 rpc_twoparty_capnp::Side::Client,
120 *ReaderOptions::new()
121 .nesting_limit(64)
122 .traversal_limit_in_words(Some(256 * 1024 * 1024)),
123 ));
124
125 let mut rpc_system = RpcSystem::new(rpc_network, None);
126
127 let hub: asset_hub::Client = rpc_system.bootstrap(rpc_twoparty_capnp::Side::Server);
128
129 let _disconnector = rpc_system.get_disconnector();
130
131 tokio::task::spawn_local(rpc_system);
132
133 log::trace!("Requesting RPC snapshot..");
134 let response = hub.get_snapshot_request().send().promise.await?;
135
136 let snapshot = response.get()?.get_snapshot()?;
137 log::trace!("Received snapshot, registering listener..");
138 let (snapshot_tx, snapshot_rx) = unbounded();
139 let listener: asset_hub::listener::Client = capnp_rpc::new_client(ListenerImpl {
140 snapshot_channel: snapshot_tx,
141 snapshot_change: None,
142 });
143
144 let mut request = hub.register_listener_request();
145 request.get().set_listener(listener);
146 let rpc_conn = request.send().promise.await.map(|_| RpcConnection {
147 snapshot,
148 snapshot_rx,
149 })?;
150 log::trace!("Registered listener, done connecting RPC loader.");
151
152 Ok(rpc_conn)
153 }
154 .await;
155 let _ = conn_tx.send(result);
156 });
157
158 self.connection = InternalConnectionState::Connecting(conn_rx)
159 }
160}
161
/// `LoaderIO` implementation that talks to an asset hub over Cap'n Proto RPC.
///
/// Incoming requests are queued and dispatched once a connection exists. The connection is
/// (re)started from `LoaderIO::tick`.
pub struct RpcIO {
    /// Address handed to `TcpStream::connect` (the `Default` impl uses `"127.0.0.1:9999"`).
    connect_string: String,
    /// Runtime, task set and connection state; always accessed through this mutex.
    runtime: Mutex<RpcRuntime>,
    /// Requests waiting to be dispatched by `process_requests`.
    requests: QueuedRequests,
}
167
/// Requests received through `LoaderIO` that have not been dispatched yet.
///
/// `process_requests` drains all three queues once the runtime is `Connected`; until then
/// they accumulate.
#[derive(Default)]
struct QueuedRequests {
    /// Artifact data requests (`LoaderIO::get_artifacts`).
    data_requests: Vec<DataRequest>,
    /// Metadata requests (`LoaderIO::get_asset_metadata_with_dependencies`).
    metadata_requests: Vec<MetadataRequest>,
    /// Path-resolution requests (`LoaderIO::get_asset_candidates`).
    resolve_requests: Vec<ResolveRequest>,
}
174
175impl Default for RpcIO {
176 fn default() -> RpcIO {
177 RpcIO::new("127.0.0.1:9999".to_string()).unwrap()
178 }
179}
180
181impl RpcIO {
182 pub fn new(connect_string: String) -> std::io::Result<RpcIO> {
183 Ok(RpcIO {
184 connect_string,
185 runtime: Mutex::new(RpcRuntime {
186 runtime: Builder::new_current_thread().enable_all().build()?,
187 local: tokio::task::LocalSet::new(),
188 connection: InternalConnectionState::None,
189 }),
190 requests: Default::default(),
191 })
192 }
193}
194
195impl LoaderIO for RpcIO {
196 fn get_asset_metadata_with_dependencies(&mut self, request: MetadataRequest) {
197 self.requests.metadata_requests.push(request);
198 let mut runtime = self.runtime.lock().unwrap();
199 process_requests(&mut runtime, &mut self.requests);
200 }
201
202 fn get_asset_candidates(&mut self, requests: Vec<ResolveRequest>) {
203 self.requests.resolve_requests.extend(requests);
204 let mut runtime = self.runtime.lock().unwrap();
205 process_requests(&mut runtime, &mut self.requests);
206 }
207
208 fn get_artifacts(&mut self, requests: Vec<DataRequest>) {
209 self.requests.data_requests.extend(requests);
210 let mut runtime = self.runtime.lock().unwrap();
211 process_requests(&mut runtime, &mut self.requests);
212 }
213
214 fn tick(&mut self, loader: &mut LoaderState) {
215 let mut runtime = self.runtime.lock().unwrap();
216
217 match &runtime.connection {
218 InternalConnectionState::Error(err) => {
219 log::error!("Error connecting RpcIO: {}", err);
220 runtime.connect(self.connect_string.clone());
221 }
222 InternalConnectionState::None => {
223 runtime.connect(self.connect_string.clone());
224 }
225 _ => {}
226 };
227
228 process_requests(&mut runtime, &mut self.requests);
229
230 runtime.connection =
231 match std::mem::replace(&mut runtime.connection, InternalConnectionState::None) {
232 InternalConnectionState::Connecting(mut pending_connection) => {
234 match pending_connection.try_recv() {
235 Ok(connection_result) => match connection_result {
236 Ok(conn) => InternalConnectionState::Connected(conn),
237 Err(err) => InternalConnectionState::Error(err),
238 },
239 Err(oneshot::error::TryRecvError::Closed) => {
240 InternalConnectionState::Error(Box::new(
241 oneshot::error::TryRecvError::Closed,
242 ))
243 }
244 Err(oneshot::error::TryRecvError::Empty) => {
245 InternalConnectionState::Connecting(pending_connection)
246 }
247 }
248 }
249 c => c,
250 };
251
252 runtime
253 .local
254 .block_on(&runtime.runtime, tokio::task::yield_now());
255
256 runtime.check_asset_changes(loader);
257 }
258
259 fn with_runtime(&self, f: &mut dyn FnMut(&tokio::runtime::Runtime)) {
260 let runtime = self.runtime.lock().unwrap();
261 f(&runtime.runtime)
262 }
263}
264
265async fn do_metadata_request(
266 asset: &MetadataRequest,
267 snapshot: &asset_hub::snapshot::Client,
268) -> Result<Vec<MetadataRequestResult>, capnp::Error> {
269 let mut request = snapshot.get_asset_metadata_with_dependencies_request();
270 let mut assets = request
271 .get()
272 .init_assets(asset.requested_assets().count() as u32);
273 for (idx, asset) in asset.requested_assets().enumerate() {
274 assets.reborrow().get(idx as u32).set_id(&asset.0);
275 }
276 let response = request.send().promise.await?;
277 let reader = response.get()?;
278 let artifacts = reader
279 .get_assets()?
280 .into_iter()
281 .map(|a| parse_db_metadata(&a))
282 .filter(|a| a.artifact.is_some())
283 .map(|a| MetadataRequestResult {
284 artifact_metadata: a.artifact.clone().unwrap(),
285 asset_metadata: if asset.include_asset_metadata() {
286 Some(a)
287 } else {
288 None
289 },
290 })
291 .collect::<Vec<_>>();
292 Ok(artifacts)
293}
294
295async fn do_import_artifact_request(
296 asset: &DataRequest,
297 snapshot: &asset_hub::snapshot::Client,
298) -> Result<Vec<u8>, capnp::Error> {
299 let mut request = snapshot.get_import_artifacts_request();
300 let mut assets = request.get().init_assets(1);
301 assets.reborrow().get(0).set_id(&asset.asset_id().0);
302 let response = request.send().promise.await?;
303 let reader = response.get()?;
304 let artifact = reader.get_artifacts()?.get(0);
305 Ok(Vec::from(artifact.get_data()?))
306}
307
308async fn do_resolve_request(
309 resolve: &ResolveRequest,
310 snapshot: &asset_hub::snapshot::Client,
311) -> Result<Vec<(PathBuf, Vec<AssetMetadata>)>, capnp::Error> {
312 let path = resolve.identifier().path();
313 let mut request = snapshot.get_assets_for_paths_request();
315 let mut paths = request.get().init_paths(1);
316 paths.reborrow().set(0, path.as_bytes());
317 let response = request.send().promise.await?;
318 let reader = response.get()?;
319 let mut results = Vec::new();
320 for reader in reader.get_assets()? {
321 let path = PathBuf::from(std::str::from_utf8(reader.get_path()?)?);
322 let asset_ids = reader.get_assets()?;
323 let mut request = snapshot.get_asset_metadata_request();
325 request.get().set_assets(asset_ids)?;
326 let response = request.send().promise.await?;
327 let reader = response.get()?;
328 results.push((
329 path,
330 reader
331 .get_assets()?
332 .into_iter()
333 .map(|a| parse_db_metadata(&a))
334 .collect::<Vec<_>>(),
335 ));
336 }
337 Ok(results)
338}
339
340fn process_requests(runtime: &mut RpcRuntime, requests: &mut QueuedRequests) {
341 if let InternalConnectionState::Connected(connection) = &runtime.connection {
342 let len = requests.data_requests.len();
343 for asset in requests.data_requests.drain(0..len) {
344 let snapshot = connection.snapshot.clone();
345 runtime.local.spawn_local(async move {
346 match do_import_artifact_request(&asset, &snapshot).await {
347 Ok(data) => {
348 asset.complete(data);
349 }
350 Err(e) => {
351 asset.error(e);
352 }
353 }
354 });
355 }
356
357 let len = requests.metadata_requests.len();
358 for m in requests.metadata_requests.drain(0..len) {
359 let snapshot = connection.snapshot.clone();
360 runtime.local.spawn_local(async move {
361 match do_metadata_request(&m, &snapshot).await {
362 Ok(data) => {
363 m.complete(data);
364 }
365 Err(e) => {
366 m.error(e);
367 }
368 }
369 });
370 }
371
372 let len = requests.resolve_requests.len();
373 for m in requests.resolve_requests.drain(0..len) {
374 let snapshot = connection.snapshot.clone();
375 runtime.local.spawn_local(async move {
376 match do_resolve_request(&m, &snapshot).await {
377 Ok(data) => {
378 m.complete(data);
379 }
380 Err(e) => {
381 m.error(e);
382 }
383 }
384 });
385 }
386 }
387}
388
/// capnp server object registered with the asset hub (via `register_listener`) to receive
/// snapshot update notifications.
struct ListenerImpl {
    /// Forwards snapshot changes to the `RpcConnection::snapshot_rx` consumer.
    snapshot_channel: Sender<SnapshotChange>,
    /// Change number the next `update` starts fetching changes from; `None` until the first
    /// update has been received.
    snapshot_change: Option<u64>,
}
393impl asset_hub::listener::Server for ListenerImpl {
394 fn update(
395 &mut self,
396 params: asset_hub::listener::UpdateParams,
397 _results: asset_hub::listener::UpdateResults,
398 ) -> Promise<()> {
399 let params = pry!(params.get());
400 let snapshot = pry!(params.get_snapshot());
401 log::trace!(
402 "ListenerImpl::update self.snapshot_change: {:?}",
403 self.snapshot_change
404 );
405 if let Some(change_num) = self.snapshot_change {
406 let channel = self.snapshot_channel.clone();
407 let mut request = snapshot.get_asset_changes_request();
408 request.get().set_start(change_num);
409 request
410 .get()
411 .set_count(params.get_latest_change() - change_num);
412 return Promise::from_future(async move {
413 let response = request.send().promise.await?;
414 let response = response.get()?;
415
416 let mut changed_assets = Vec::new();
417 let mut deleted_assets = Vec::new();
418 let mut changed_paths = Vec::new();
419 let mut deleted_paths = Vec::new();
420
421 for change in response.get_changes()? {
422 match change.get_event()?.which()? {
423 asset_change_event::ContentUpdateEvent(evt) => {
424 let id = utils::make_array(evt?.get_id()?.get_id()?);
425 log::trace!("ListenerImpl::update asset_change_event::ContentUpdateEvent(evt) id: {:?}", id);
426 changed_assets.push(id);
427 }
428 asset_change_event::RemoveEvent(evt) => {
429 let id = utils::make_array(evt?.get_id()?.get_id()?);
430 log::trace!(
431 "ListenerImpl::update asset_change_event::RemoveEvent(evt) id: {:?}",
432 id
433 );
434 deleted_assets.push(id);
435 }
436 asset_change_event::PathRemoveEvent(evt) => {
437 deleted_paths
438 .push(PathBuf::from(std::str::from_utf8(evt?.get_path()?)?));
439 }
440 asset_change_event::PathUpdateEvent(evt) => {
441 changed_paths
442 .push(PathBuf::from(std::str::from_utf8(evt?.get_path()?)?));
443 }
444 }
445 }
446
447 channel
448 .send(SnapshotChange {
449 snapshot,
450 changed_assets,
451 deleted_assets,
452 deleted_paths,
453 changed_paths,
454 })
455 .map_err(|_| capnp::Error::failed("Could not send SnapshotChange".into()))
456 });
457 } else {
458 let _ = self.snapshot_channel.try_send(SnapshotChange {
459 snapshot,
460 changed_assets: Vec::new(),
461 deleted_assets: Vec::new(),
462 changed_paths: Vec::new(),
463 deleted_paths: Vec::new(),
464 });
465 }
466 self.snapshot_change = Some(params.get_latest_change());
467 Promise::ok(())
468 }
469}