Skip to main content

carton_runner_interface/
runner.rs

1// Copyright 2023 Vivek Panyam
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{collections::HashMap, sync::Arc};
16
17use crate::{
18    client::Client,
19    do_not_modify::comms::OwnedComms,
20    do_not_modify::{
21        alloc::TypedAlloc,
22        alloc_inline::{InlineAllocator, InlineTensorStorage},
23        types::{Device, RPCRequestData, RPCResponseData, SealHandle, Tensor},
24    },
25    types::{Handle, RunnerOpt, TensorStorage},
26};
27
28use futures::Stream;
29use lunchbox::types::{MaybeSend, MaybeSync};
30
31pub struct Runner {
32    client: Client,
33}
34
35impl Runner {
36    #[cfg(not(target_family = "wasm"))]
37    pub async fn new(
38        runner_path: &std::path::Path,
39        visible_device: Device,
40    ) -> Result<Runner, String> {
41        use tokio::process::Command;
42
43        // Make sure the runner exists
44        if !runner_path.exists() {
45            return Err("Runner doesn't exist".into());
46        }
47
48        // Create comms
49        let (comms, uds_path) = OwnedComms::new().await;
50
51        // Create a command to start the runner
52        let mut command = Command::new(runner_path);
53
54        // Check if we have a UUID for a GPU
55        if let Device::GPU { uuid: Some(uuid) } = visible_device {
56            // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars
57            command.env("CUDA_VISIBLE_DEVICES", uuid);
58        } else {
59            // Hide all GPUs
60            command.env("CUDA_VISIBLE_DEVICES", "");
61        }
62
63        command
64            .args(["--uds-path", uds_path.to_str().unwrap()])
65            .spawn()
66            .expect("Runner failed to start");
67
68        // Create a client
69        let client = Client::new(comms).await;
70
71        Ok(Self { client })
72    }
73
74    #[cfg(target_family = "wasm")]
75    pub async fn new() -> Result<Runner, String> {
76        // Create comms
77        let comms = OwnedComms::new().await;
78
79        // Create a client
80        let client = Client::new(comms).await;
81
82        Ok(Self { client })
83    }
84
85    pub async fn load<T>(
86        &self,
87        fs: &Arc<T>,
88        runner_name: String,
89        required_framework_version: semver::VersionReq,
90        runner_compat_version: u64,
91        runner_opts: Option<HashMap<String, RunnerOpt>>,
92        visible_device: Device,
93        carton_manifest_hash: Option<String>,
94    ) -> Result<(), String>
95    where
96        T: lunchbox::ReadableFileSystem + MaybeSend + MaybeSync + 'static,
97        T::FileType: lunchbox::types::ReadableFile + MaybeSend + MaybeSync + Unpin,
98        T::ReadDirPollerType: MaybeSend,
99    {
100        // Serve the filesystem
101        let token = self.client.serve_readonly_fs(fs.clone()).await;
102
103        match self
104            .client
105            .do_rpc(RPCRequestData::Load {
106                fs: token,
107                runner_name,
108                required_framework_version,
109                runner_compat_version,
110                runner_opts,
111                visible_device,
112                carton_manifest_hash,
113            })
114            .await
115        {
116            RPCResponseData::Load => Ok(()),
117            RPCResponseData::Error { e } => Err(e),
118            _ => panic!("Unexpected RPC response type!"),
119        }
120    }
121
122    // pub async fn seal(&self, tensors: HashMap<String, Tensor>) -> Result<SealHandle, String> {
123    //     match self.client.do_rpc(RPCRequestData::Seal { tensors }).await {
124    //         RPCResponseData::Seal { handle } => Ok(handle),
125    //         RPCResponseData::Error { e } => Err(e),
126    //         _ => panic!("Unexpected RPC response type!"),
127    //     }
128    // }
129
130    pub async fn infer_with_inputs(
131        &self,
132        tensors_orig: HashMap<String, Tensor>,
133    ) -> Result<HashMap<String, Tensor>, String> {
134        // Wrap each tensor in a handle (this possibly sends the fd for backing SHM chunks to the other process)
135        let comms = self.client.get_comms();
136        let mut tensors = HashMap::new();
137        for (k, v) in tensors_orig.into_iter() {
138            tensors.insert(k, Handle::new(v, comms).await);
139        }
140
141        match self
142            .client
143            .do_rpc(RPCRequestData::InferWithTensors {
144                tensors,
145                streaming: false,
146            })
147            .await
148        {
149            RPCResponseData::Infer { tensors } => {
150                let mut out = HashMap::new();
151                for (k, v) in tensors.into_iter() {
152                    out.insert(k, v.into_inner(comms).await);
153                }
154
155                Ok(out)
156            }
157            RPCResponseData::Error { e } => Err(e),
158            _ => panic!("Unexpected RPC response type!"),
159        }
160    }
161
162    pub async fn streaming_infer_with_inputs(
163        &self,
164        tensors_orig: HashMap<String, Tensor>,
165    ) -> impl Stream<Item = Result<HashMap<String, Tensor>, String>> + '_ {
166        // Wrap each tensor in a handle (this possibly sends the fd for backing SHM chunks to the other process)
167        let comms = self.client.get_comms();
168        let mut tensors = HashMap::new();
169        for (k, v) in tensors_orig.into_iter() {
170            tensors.insert(k, Handle::new(v, comms).await);
171        }
172
173        let mut res = self
174            .client
175            .do_streaming_rpc(RPCRequestData::InferWithTensors {
176                tensors,
177                streaming: true,
178            })
179            .await;
180
181        async_stream::stream! {
182            while let Some(v) = res.recv().await {
183                match v {
184                    RPCResponseData::Infer { tensors } => {
185                        let mut out = HashMap::new();
186                        for (k, v) in tensors.into_iter() {
187                            out.insert(k, v.into_inner(comms).await);
188                        }
189
190                        yield Ok(out)
191                    }
192                    RPCResponseData::Error { e } => yield Err(e),
193                    RPCResponseData::Empty => { } // We can get this on the last message. Do nothing
194                    _ => panic!("Unexpected RPC response type!"),
195                }
196            }
197        }
198    }
199
200    pub async fn seal(&self, tensors_orig: HashMap<String, Tensor>) -> Result<u64, String> {
201        // Wrap each tensor in a handle (this possibly sends the fd for backing SHM chunks to the other process)
202        let comms = self.client.get_comms();
203        let mut tensors = HashMap::new();
204        for (k, v) in tensors_orig.into_iter() {
205            tensors.insert(k, Handle::new(v, comms).await);
206        }
207
208        match self.client.do_rpc(RPCRequestData::Seal { tensors }).await {
209            RPCResponseData::Seal { handle } => Ok(handle.0),
210            RPCResponseData::Error { e } => Err(e),
211            _ => panic!("Unexpected RPC response type!"),
212        }
213    }
214
215    pub async fn infer_with_handle(&self, handle: u64) -> Result<HashMap<String, Tensor>, String> {
216        let comms = self.client.get_comms();
217
218        match self
219            .client
220            .do_rpc(RPCRequestData::InferWithHandle {
221                handle: SealHandle(handle),
222                streaming: false,
223            })
224            .await
225        {
226            RPCResponseData::Infer { tensors } => {
227                let mut out = HashMap::new();
228                for (k, v) in tensors.into_iter() {
229                    out.insert(k, v.into_inner(comms).await);
230                }
231
232                Ok(out)
233            }
234            RPCResponseData::Error { e } => Err(e),
235            _ => panic!("Unexpected RPC response type!"),
236        }
237    }
238
239    pub async fn streaming_infer_with_handle(
240        &self,
241        handle: u64,
242    ) -> impl Stream<Item = Result<HashMap<String, Tensor>, String>> + '_ {
243        let comms = self.client.get_comms();
244
245        let mut res = self
246            .client
247            .do_streaming_rpc(RPCRequestData::InferWithHandle {
248                handle: SealHandle(handle),
249                streaming: true,
250            })
251            .await;
252
253        async_stream::stream! {
254            while let Some(v) = res.recv().await {
255                match v {
256                    RPCResponseData::Infer { tensors } => {
257                        let mut out = HashMap::new();
258                        for (k, v) in tensors.into_iter() {
259                            out.insert(k, v.into_inner(comms).await);
260                        }
261
262                        yield Ok(out)
263                    }
264                    RPCResponseData::Error { e } => yield Err(e),
265                    RPCResponseData::Empty => { } // We can get this on the last message. Do nothing
266                    _ => panic!("Unexpected RPC response type!"),
267                }
268            }
269        }
270    }
271
272    /// Pack a model and return a path to the output directory
273    pub async fn pack<T>(
274        &self,
275        fs: &Arc<T>,
276        input_path: &lunchbox::path::Path,
277        temp_folder: &lunchbox::path::Path,
278    ) -> Result<lunchbox::path::PathBuf, String>
279    where
280        T: lunchbox::WritableFileSystem + MaybeSend + MaybeSync + 'static,
281        T::FileType: lunchbox::types::WritableFile + MaybeSend + MaybeSync + Unpin,
282        T::ReadDirPollerType: MaybeSend,
283    {
284        // Serve the filesystem
285        let token = self.client.serve_writable_fs(fs.clone()).await;
286
287        match self
288            .client
289            .do_rpc(RPCRequestData::Pack {
290                fs: token,
291                input_path: input_path.to_string(),
292                temp_folder: temp_folder.to_string(),
293            })
294            .await
295        {
296            RPCResponseData::Pack { output_path } => Ok(output_path.into()),
297            RPCResponseData::Error { e } => Err(e),
298            _ => panic!("Unexpected RPC response type!"),
299        }
300    }
301
302    pub async fn alloc_tensor<T: Clone + Default>(&self, shape: Vec<u64>) -> Result<Tensor, String>
303    where
304        InlineAllocator: TypedAlloc<T, Output = InlineTensorStorage>,
305        Tensor: From<TensorStorage<T>>,
306    {
307        Ok(TensorStorage::new(shape).into())
308    }
309
310    // pub async fn infer_with_handle(
311    //     &self,
312    //     handle: SealHandle,
313    // ) -> Result<HashMap<String, Tensor>, String> {
314    //     match self
315    //         .client
316    //         .do_rpc(RPCRequestData::InferWithHandle { handle })
317    //         .await
318    //     {
319    //         RPCResponseData::Infer { tensors } => Ok(tensors),
320    //         RPCResponseData::Error { e } => Err(e),
321    //         _ => panic!("Unexpected RPC response type!"),
322    //     }
323    // }
324}