distill_loader/
rpc_io.rs

1use std::{error::Error, path::PathBuf, sync::Mutex};
2
3use capnp::message::ReaderOptions;
4use capnp_rpc::{pry, rpc_twoparty_capnp, twoparty, RpcSystem};
5use crossbeam_channel::{unbounded, Receiver, Sender};
6use distill_core::{utils, AssetMetadata, AssetUuid};
7use distill_schema::{data::asset_change_event, parse_db_metadata, service::asset_hub};
8use futures_util::AsyncReadExt;
9use tokio::{
10    net::TcpStream,
11    runtime::{Builder, Runtime},
12    sync::oneshot,
13};
14
15use crate::{
16    io::{DataRequest, LoaderIO, MetadataRequest, MetadataRequestResult, ResolveRequest},
17    loader::LoaderState,
18};
19
20type Promise<T> = capnp::capability::Promise<T, capnp::Error>;
21
/// A live capnp RPC connection: the asset hub snapshot currently used for requests, plus the
/// receiving end of the channel into which `ListenerImpl` pushes [`SnapshotChange`] events.
struct RpcConnection {
    /// Snapshot cloned into every outgoing request; replaced whenever a change is received.
    snapshot: asset_hub::snapshot::Client,
    /// Change notifications produced by the listener registered during connection setup.
    snapshot_rx: Receiver<SnapshotChange>,
}
27
/// An event describing a change to the assets, sent by `ListenerImpl` on each hub update.
///
/// The event sent for the first update after registration carries the new snapshot with
/// all four change lists empty.
struct SnapshotChange {
    /// Snapshot the changes were read from; becomes the connection's current snapshot.
    snapshot: asset_hub::snapshot::Client,
    /// Assets reported by a `ContentUpdateEvent`.
    changed_assets: Vec<AssetUuid>,
    /// Assets reported by a `RemoveEvent`.
    deleted_assets: Vec<AssetUuid>,
    /// Paths reported by a `PathUpdateEvent`.
    changed_paths: Vec<PathBuf>,
    /// Paths reported by a `PathRemoveEvent`.
    deleted_paths: Vec<PathBuf>,
}
36
/// Lifecycle of the RPC connection, advanced by `RpcIO`'s `tick`.
enum InternalConnectionState {
    /// No connection and no attempt in progress (the initial state).
    None,
    /// A connection task is running; its outcome arrives on this receiver.
    Connecting(oneshot::Receiver<Result<RpcConnection, Box<dyn Error>>>),
    /// Connected; queued requests are dispatched against this connection.
    Connected(RpcConnection),
    /// The last connection attempt failed; the next tick logs the error and reconnects.
    Error(Box<dyn Error>),
}
43
/// The tokio runtime and task set that drive all RPC work, plus the current connection state.
///
/// Connection attempts, the capnp RPC system and request tasks are all spawned onto `local`
/// (capnp_rpc types are `!Send`), and `local` is driven on the current-thread `runtime`.
struct RpcRuntime {
    /// Current-thread runtime used to drive `local`.
    runtime: Runtime,
    /// Task set holding the connection, RPC-system and request futures.
    local: tokio::task::LocalSet,
    /// Where the connection lifecycle currently stands.
    connection: InternalConnectionState,
}
50
// SAFETY: capnp_rpc types do not implement `Send` because their internal object lifetime
// management is unsafe when multiple threads access data belonging to separate objects.
// In our usage `RpcRuntime` is only reached through the `Mutex` inside `RpcIO`, so at most
// one thread can access the internal state at any time.
// NOTE(review): soundness also relies on no capnp client (or other `Rc`-backed handle)
// escaping this struct, and on the `LocalSet` tolerating being driven from different
// threads on different ticks — confirm both still hold when changing this file.
unsafe impl Send for RpcRuntime {}
56
57impl RpcRuntime {
58    fn check_asset_changes(&mut self, loader: &LoaderState) {
59        self.connection =
60            match std::mem::replace(&mut self.connection, InternalConnectionState::None) {
61                InternalConnectionState::Connected(mut conn) => {
62                    if let Ok(change) = conn.snapshot_rx.try_recv() {
63                        log::trace!("RpcRuntime check_asset_changes Ok(change)");
64                        conn.snapshot = change.snapshot;
65                        let mut changed_assets = Vec::new();
66                        for asset in change.changed_assets {
67                            log::trace!(
68                                "RpcRuntime check_asset_changes changed asset.id: {:?}",
69                                asset
70                            );
71                            changed_assets.push(asset);
72                        }
73                        for asset in change.deleted_assets {
74                            log::trace!(
75                                "RpcRuntime check_asset_changes deleted asset.id: {:?}",
76                                asset
77                            );
78                            changed_assets.push(asset);
79                        }
80                        loader.invalidate_assets(&changed_assets);
81                        let mut changed_paths = Vec::new();
82                        for path in change.changed_paths {
83                            changed_paths.push(path);
84                        }
85                        for path in change.deleted_paths {
86                            changed_paths.push(path);
87                        }
88                        loader.invalidate_paths(&changed_paths);
89                    }
90                    InternalConnectionState::Connected(conn)
91                }
92                c => c,
93            };
94    }
95
96    fn connect(&mut self, connect_string: String) {
97        match self.connection {
98            InternalConnectionState::Connected(_) | InternalConnectionState::Connecting(_) => {
99                panic!("Trying to connect while already connected or connecting")
100            }
101            _ => {}
102        };
103
104        let (conn_tx, conn_rx) = oneshot::channel();
105
106        self.local.spawn_local(async move {
107            let result = async move {
108                log::trace!("Tcp connect to {:?}", connect_string);
109                let stream = TcpStream::connect(connect_string).await?;
110                stream.set_nodelay(true)?;
111
112                use tokio_util::compat::*;
113                let (reader, writer) = stream.compat().split();
114
115                log::trace!("Creating capnp VatNetwork");
116                let rpc_network = Box::new(twoparty::VatNetwork::new(
117                    reader,
118                    writer,
119                    rpc_twoparty_capnp::Side::Client,
120                    *ReaderOptions::new()
121                        .nesting_limit(64)
122                        .traversal_limit_in_words(Some(256 * 1024 * 1024)),
123                ));
124
125                let mut rpc_system = RpcSystem::new(rpc_network, None);
126
127                let hub: asset_hub::Client = rpc_system.bootstrap(rpc_twoparty_capnp::Side::Server);
128
129                let _disconnector = rpc_system.get_disconnector();
130
131                tokio::task::spawn_local(rpc_system);
132
133                log::trace!("Requesting RPC snapshot..");
134                let response = hub.get_snapshot_request().send().promise.await?;
135
136                let snapshot = response.get()?.get_snapshot()?;
137                log::trace!("Received snapshot, registering listener..");
138                let (snapshot_tx, snapshot_rx) = unbounded();
139                let listener: asset_hub::listener::Client = capnp_rpc::new_client(ListenerImpl {
140                    snapshot_channel: snapshot_tx,
141                    snapshot_change: None,
142                });
143
144                let mut request = hub.register_listener_request();
145                request.get().set_listener(listener);
146                let rpc_conn = request.send().promise.await.map(|_| RpcConnection {
147                    snapshot,
148                    snapshot_rx,
149                })?;
150                log::trace!("Registered listener, done connecting RPC loader.");
151
152                Ok(rpc_conn)
153            }
154            .await;
155            let _ = conn_tx.send(result);
156        });
157
158        self.connection = InternalConnectionState::Connecting(conn_rx)
159    }
160}
161
/// [`LoaderIO`] implementation that serves asset data and metadata over a capnp RPC
/// connection to an asset hub.
pub struct RpcIO {
    /// Address passed to `TcpStream::connect` on every (re)connection attempt.
    connect_string: String,
    /// Runtime, task set and connection state, locked for every access.
    runtime: Mutex<RpcRuntime>,
    /// Requests received but not yet dispatched (held until a connection exists).
    requests: QueuedRequests,
}
167
/// Requests received through `LoaderIO` that have not been dispatched yet.
///
/// They are drained by `process_requests` once the connection is `Connected`.
#[derive(Default)]
struct QueuedRequests {
    /// Artifact data requests.
    data_requests: Vec<DataRequest>,
    /// Metadata-with-dependencies requests.
    metadata_requests: Vec<MetadataRequest>,
    /// Path resolution requests.
    resolve_requests: Vec<ResolveRequest>,
}
174
175impl Default for RpcIO {
176    fn default() -> RpcIO {
177        RpcIO::new("127.0.0.1:9999".to_string()).unwrap()
178    }
179}
180
181impl RpcIO {
182    pub fn new(connect_string: String) -> std::io::Result<RpcIO> {
183        Ok(RpcIO {
184            connect_string,
185            runtime: Mutex::new(RpcRuntime {
186                runtime: Builder::new_current_thread().enable_all().build()?,
187                local: tokio::task::LocalSet::new(),
188                connection: InternalConnectionState::None,
189            }),
190            requests: Default::default(),
191        })
192    }
193}
194
195impl LoaderIO for RpcIO {
196    fn get_asset_metadata_with_dependencies(&mut self, request: MetadataRequest) {
197        self.requests.metadata_requests.push(request);
198        let mut runtime = self.runtime.lock().unwrap();
199        process_requests(&mut runtime, &mut self.requests);
200    }
201
202    fn get_asset_candidates(&mut self, requests: Vec<ResolveRequest>) {
203        self.requests.resolve_requests.extend(requests);
204        let mut runtime = self.runtime.lock().unwrap();
205        process_requests(&mut runtime, &mut self.requests);
206    }
207
208    fn get_artifacts(&mut self, requests: Vec<DataRequest>) {
209        self.requests.data_requests.extend(requests);
210        let mut runtime = self.runtime.lock().unwrap();
211        process_requests(&mut runtime, &mut self.requests);
212    }
213
214    fn tick(&mut self, loader: &mut LoaderState) {
215        let mut runtime = self.runtime.lock().unwrap();
216
217        match &runtime.connection {
218            InternalConnectionState::Error(err) => {
219                log::error!("Error connecting RpcIO: {}", err);
220                runtime.connect(self.connect_string.clone());
221            }
222            InternalConnectionState::None => {
223                runtime.connect(self.connect_string.clone());
224            }
225            _ => {}
226        };
227
228        process_requests(&mut runtime, &mut self.requests);
229
230        runtime.connection =
231            match std::mem::replace(&mut runtime.connection, InternalConnectionState::None) {
232                // update connection state
233                InternalConnectionState::Connecting(mut pending_connection) => {
234                    match pending_connection.try_recv() {
235                        Ok(connection_result) => match connection_result {
236                            Ok(conn) => InternalConnectionState::Connected(conn),
237                            Err(err) => InternalConnectionState::Error(err),
238                        },
239                        Err(oneshot::error::TryRecvError::Closed) => {
240                            InternalConnectionState::Error(Box::new(
241                                oneshot::error::TryRecvError::Closed,
242                            ))
243                        }
244                        Err(oneshot::error::TryRecvError::Empty) => {
245                            InternalConnectionState::Connecting(pending_connection)
246                        }
247                    }
248                }
249                c => c,
250            };
251
252        runtime
253            .local
254            .block_on(&runtime.runtime, tokio::task::yield_now());
255
256        runtime.check_asset_changes(loader);
257    }
258
259    fn with_runtime(&self, f: &mut dyn FnMut(&tokio::runtime::Runtime)) {
260        let runtime = self.runtime.lock().unwrap();
261        f(&runtime.runtime)
262    }
263}
264
265async fn do_metadata_request(
266    asset: &MetadataRequest,
267    snapshot: &asset_hub::snapshot::Client,
268) -> Result<Vec<MetadataRequestResult>, capnp::Error> {
269    let mut request = snapshot.get_asset_metadata_with_dependencies_request();
270    let mut assets = request
271        .get()
272        .init_assets(asset.requested_assets().count() as u32);
273    for (idx, asset) in asset.requested_assets().enumerate() {
274        assets.reborrow().get(idx as u32).set_id(&asset.0);
275    }
276    let response = request.send().promise.await?;
277    let reader = response.get()?;
278    let artifacts = reader
279        .get_assets()?
280        .into_iter()
281        .map(|a| parse_db_metadata(&a))
282        .filter(|a| a.artifact.is_some())
283        .map(|a| MetadataRequestResult {
284            artifact_metadata: a.artifact.clone().unwrap(),
285            asset_metadata: if asset.include_asset_metadata() {
286                Some(a)
287            } else {
288                None
289            },
290        })
291        .collect::<Vec<_>>();
292    Ok(artifacts)
293}
294
295async fn do_import_artifact_request(
296    asset: &DataRequest,
297    snapshot: &asset_hub::snapshot::Client,
298) -> Result<Vec<u8>, capnp::Error> {
299    let mut request = snapshot.get_import_artifacts_request();
300    let mut assets = request.get().init_assets(1);
301    assets.reborrow().get(0).set_id(&asset.asset_id().0);
302    let response = request.send().promise.await?;
303    let reader = response.get()?;
304    let artifact = reader.get_artifacts()?.get(0);
305    Ok(Vec::from(artifact.get_data()?))
306}
307
308async fn do_resolve_request(
309    resolve: &ResolveRequest,
310    snapshot: &asset_hub::snapshot::Client,
311) -> Result<Vec<(PathBuf, Vec<AssetMetadata>)>, capnp::Error> {
312    let path = resolve.identifier().path();
313    // get asset IDs at path
314    let mut request = snapshot.get_assets_for_paths_request();
315    let mut paths = request.get().init_paths(1);
316    paths.reborrow().set(0, path.as_bytes());
317    let response = request.send().promise.await?;
318    let reader = response.get()?;
319    let mut results = Vec::new();
320    for reader in reader.get_assets()? {
321        let path = PathBuf::from(std::str::from_utf8(reader.get_path()?)?);
322        let asset_ids = reader.get_assets()?;
323        // get metadata for the assetIDs
324        let mut request = snapshot.get_asset_metadata_request();
325        request.get().set_assets(asset_ids)?;
326        let response = request.send().promise.await?;
327        let reader = response.get()?;
328        results.push((
329            path,
330            reader
331                .get_assets()?
332                .into_iter()
333                .map(|a| parse_db_metadata(&a))
334                .collect::<Vec<_>>(),
335        ));
336    }
337    Ok(results)
338}
339
340fn process_requests(runtime: &mut RpcRuntime, requests: &mut QueuedRequests) {
341    if let InternalConnectionState::Connected(connection) = &runtime.connection {
342        let len = requests.data_requests.len();
343        for asset in requests.data_requests.drain(0..len) {
344            let snapshot = connection.snapshot.clone();
345            runtime.local.spawn_local(async move {
346                match do_import_artifact_request(&asset, &snapshot).await {
347                    Ok(data) => {
348                        asset.complete(data);
349                    }
350                    Err(e) => {
351                        asset.error(e);
352                    }
353                }
354            });
355        }
356
357        let len = requests.metadata_requests.len();
358        for m in requests.metadata_requests.drain(0..len) {
359            let snapshot = connection.snapshot.clone();
360            runtime.local.spawn_local(async move {
361                match do_metadata_request(&m, &snapshot).await {
362                    Ok(data) => {
363                        m.complete(data);
364                    }
365                    Err(e) => {
366                        m.error(e);
367                    }
368                }
369            });
370        }
371
372        let len = requests.resolve_requests.len();
373        for m in requests.resolve_requests.drain(0..len) {
374            let snapshot = connection.snapshot.clone();
375            runtime.local.spawn_local(async move {
376                match do_resolve_request(&m, &snapshot).await {
377                    Ok(data) => {
378                        m.complete(data);
379                    }
380                    Err(e) => {
381                        m.error(e);
382                    }
383                }
384            });
385        }
386    }
387}
388
/// capnp listener registered with the asset hub; turns hub update notifications into
/// [`SnapshotChange`] events on `snapshot_channel`.
struct ListenerImpl {
    /// Sending half of the channel read by `RpcRuntime::check_asset_changes`.
    snapshot_channel: Sender<SnapshotChange>,
    /// Change number recorded by `update`, used as the start of the next change query;
    /// `None` until the first update has been handled.
    snapshot_change: Option<u64>,
}
393impl asset_hub::listener::Server for ListenerImpl {
394    fn update(
395        &mut self,
396        params: asset_hub::listener::UpdateParams,
397        _results: asset_hub::listener::UpdateResults,
398    ) -> Promise<()> {
399        let params = pry!(params.get());
400        let snapshot = pry!(params.get_snapshot());
401        log::trace!(
402            "ListenerImpl::update self.snapshot_change: {:?}",
403            self.snapshot_change
404        );
405        if let Some(change_num) = self.snapshot_change {
406            let channel = self.snapshot_channel.clone();
407            let mut request = snapshot.get_asset_changes_request();
408            request.get().set_start(change_num);
409            request
410                .get()
411                .set_count(params.get_latest_change() - change_num);
412            return Promise::from_future(async move {
413                let response = request.send().promise.await?;
414                let response = response.get()?;
415
416                let mut changed_assets = Vec::new();
417                let mut deleted_assets = Vec::new();
418                let mut changed_paths = Vec::new();
419                let mut deleted_paths = Vec::new();
420
421                for change in response.get_changes()? {
422                    match change.get_event()?.which()? {
423                        asset_change_event::ContentUpdateEvent(evt) => {
424                            let id = utils::make_array(evt?.get_id()?.get_id()?);
425                            log::trace!("ListenerImpl::update asset_change_event::ContentUpdateEvent(evt) id: {:?}", id);
426                            changed_assets.push(id);
427                        }
428                        asset_change_event::RemoveEvent(evt) => {
429                            let id = utils::make_array(evt?.get_id()?.get_id()?);
430                            log::trace!(
431                                "ListenerImpl::update asset_change_event::RemoveEvent(evt) id: {:?}",
432                                id
433                            );
434                            deleted_assets.push(id);
435                        }
436                        asset_change_event::PathRemoveEvent(evt) => {
437                            deleted_paths
438                                .push(PathBuf::from(std::str::from_utf8(evt?.get_path()?)?));
439                        }
440                        asset_change_event::PathUpdateEvent(evt) => {
441                            changed_paths
442                                .push(PathBuf::from(std::str::from_utf8(evt?.get_path()?)?));
443                        }
444                    }
445                }
446
447                channel
448                    .send(SnapshotChange {
449                        snapshot,
450                        changed_assets,
451                        deleted_assets,
452                        deleted_paths,
453                        changed_paths,
454                    })
455                    .map_err(|_| capnp::Error::failed("Could not send SnapshotChange".into()))
456            });
457        } else {
458            let _ = self.snapshot_channel.try_send(SnapshotChange {
459                snapshot,
460                changed_assets: Vec::new(),
461                deleted_assets: Vec::new(),
462                changed_paths: Vec::new(),
463                deleted_paths: Vec::new(),
464            });
465        }
466        self.snapshot_change = Some(params.get_latest_change());
467        Promise::ok(())
468    }
469}