1use std::{collections::HashMap, sync::Arc};
16
17use crate::{
18 client::Client,
19 do_not_modify::comms::OwnedComms,
20 do_not_modify::{
21 alloc::TypedAlloc,
22 alloc_inline::{InlineAllocator, InlineTensorStorage},
23 types::{Device, RPCRequestData, RPCResponseData, SealHandle, Tensor},
24 },
25 types::{Handle, RunnerOpt, TensorStorage},
26};
27
28use futures::Stream;
29use lunchbox::types::{MaybeSend, MaybeSync};
30
31pub struct Runner {
32 client: Client,
33}
34
35impl Runner {
36 #[cfg(not(target_family = "wasm"))]
37 pub async fn new(
38 runner_path: &std::path::Path,
39 visible_device: Device,
40 ) -> Result<Runner, String> {
41 use tokio::process::Command;
42
43 if !runner_path.exists() {
45 return Err("Runner doesn't exist".into());
46 }
47
48 let (comms, uds_path) = OwnedComms::new().await;
50
51 let mut command = Command::new(runner_path);
53
54 if let Device::GPU { uuid: Some(uuid) } = visible_device {
56 command.env("CUDA_VISIBLE_DEVICES", uuid);
58 } else {
59 command.env("CUDA_VISIBLE_DEVICES", "");
61 }
62
63 command
64 .args(["--uds-path", uds_path.to_str().unwrap()])
65 .spawn()
66 .expect("Runner failed to start");
67
68 let client = Client::new(comms).await;
70
71 Ok(Self { client })
72 }
73
74 #[cfg(target_family = "wasm")]
75 pub async fn new() -> Result<Runner, String> {
76 let comms = OwnedComms::new().await;
78
79 let client = Client::new(comms).await;
81
82 Ok(Self { client })
83 }
84
85 pub async fn load<T>(
86 &self,
87 fs: &Arc<T>,
88 runner_name: String,
89 required_framework_version: semver::VersionReq,
90 runner_compat_version: u64,
91 runner_opts: Option<HashMap<String, RunnerOpt>>,
92 visible_device: Device,
93 carton_manifest_hash: Option<String>,
94 ) -> Result<(), String>
95 where
96 T: lunchbox::ReadableFileSystem + MaybeSend + MaybeSync + 'static,
97 T::FileType: lunchbox::types::ReadableFile + MaybeSend + MaybeSync + Unpin,
98 T::ReadDirPollerType: MaybeSend,
99 {
100 let token = self.client.serve_readonly_fs(fs.clone()).await;
102
103 match self
104 .client
105 .do_rpc(RPCRequestData::Load {
106 fs: token,
107 runner_name,
108 required_framework_version,
109 runner_compat_version,
110 runner_opts,
111 visible_device,
112 carton_manifest_hash,
113 })
114 .await
115 {
116 RPCResponseData::Load => Ok(()),
117 RPCResponseData::Error { e } => Err(e),
118 _ => panic!("Unexpected RPC response type!"),
119 }
120 }
121
122 pub async fn infer_with_inputs(
131 &self,
132 tensors_orig: HashMap<String, Tensor>,
133 ) -> Result<HashMap<String, Tensor>, String> {
134 let comms = self.client.get_comms();
136 let mut tensors = HashMap::new();
137 for (k, v) in tensors_orig.into_iter() {
138 tensors.insert(k, Handle::new(v, comms).await);
139 }
140
141 match self
142 .client
143 .do_rpc(RPCRequestData::InferWithTensors {
144 tensors,
145 streaming: false,
146 })
147 .await
148 {
149 RPCResponseData::Infer { tensors } => {
150 let mut out = HashMap::new();
151 for (k, v) in tensors.into_iter() {
152 out.insert(k, v.into_inner(comms).await);
153 }
154
155 Ok(out)
156 }
157 RPCResponseData::Error { e } => Err(e),
158 _ => panic!("Unexpected RPC response type!"),
159 }
160 }
161
162 pub async fn streaming_infer_with_inputs(
163 &self,
164 tensors_orig: HashMap<String, Tensor>,
165 ) -> impl Stream<Item = Result<HashMap<String, Tensor>, String>> + '_ {
166 let comms = self.client.get_comms();
168 let mut tensors = HashMap::new();
169 for (k, v) in tensors_orig.into_iter() {
170 tensors.insert(k, Handle::new(v, comms).await);
171 }
172
173 let mut res = self
174 .client
175 .do_streaming_rpc(RPCRequestData::InferWithTensors {
176 tensors,
177 streaming: true,
178 })
179 .await;
180
181 async_stream::stream! {
182 while let Some(v) = res.recv().await {
183 match v {
184 RPCResponseData::Infer { tensors } => {
185 let mut out = HashMap::new();
186 for (k, v) in tensors.into_iter() {
187 out.insert(k, v.into_inner(comms).await);
188 }
189
190 yield Ok(out)
191 }
192 RPCResponseData::Error { e } => yield Err(e),
193 RPCResponseData::Empty => { } _ => panic!("Unexpected RPC response type!"),
195 }
196 }
197 }
198 }
199
200 pub async fn seal(&self, tensors_orig: HashMap<String, Tensor>) -> Result<u64, String> {
201 let comms = self.client.get_comms();
203 let mut tensors = HashMap::new();
204 for (k, v) in tensors_orig.into_iter() {
205 tensors.insert(k, Handle::new(v, comms).await);
206 }
207
208 match self.client.do_rpc(RPCRequestData::Seal { tensors }).await {
209 RPCResponseData::Seal { handle } => Ok(handle.0),
210 RPCResponseData::Error { e } => Err(e),
211 _ => panic!("Unexpected RPC response type!"),
212 }
213 }
214
215 pub async fn infer_with_handle(&self, handle: u64) -> Result<HashMap<String, Tensor>, String> {
216 let comms = self.client.get_comms();
217
218 match self
219 .client
220 .do_rpc(RPCRequestData::InferWithHandle {
221 handle: SealHandle(handle),
222 streaming: false,
223 })
224 .await
225 {
226 RPCResponseData::Infer { tensors } => {
227 let mut out = HashMap::new();
228 for (k, v) in tensors.into_iter() {
229 out.insert(k, v.into_inner(comms).await);
230 }
231
232 Ok(out)
233 }
234 RPCResponseData::Error { e } => Err(e),
235 _ => panic!("Unexpected RPC response type!"),
236 }
237 }
238
239 pub async fn streaming_infer_with_handle(
240 &self,
241 handle: u64,
242 ) -> impl Stream<Item = Result<HashMap<String, Tensor>, String>> + '_ {
243 let comms = self.client.get_comms();
244
245 let mut res = self
246 .client
247 .do_streaming_rpc(RPCRequestData::InferWithHandle {
248 handle: SealHandle(handle),
249 streaming: true,
250 })
251 .await;
252
253 async_stream::stream! {
254 while let Some(v) = res.recv().await {
255 match v {
256 RPCResponseData::Infer { tensors } => {
257 let mut out = HashMap::new();
258 for (k, v) in tensors.into_iter() {
259 out.insert(k, v.into_inner(comms).await);
260 }
261
262 yield Ok(out)
263 }
264 RPCResponseData::Error { e } => yield Err(e),
265 RPCResponseData::Empty => { } _ => panic!("Unexpected RPC response type!"),
267 }
268 }
269 }
270 }
271
272 pub async fn pack<T>(
274 &self,
275 fs: &Arc<T>,
276 input_path: &lunchbox::path::Path,
277 temp_folder: &lunchbox::path::Path,
278 ) -> Result<lunchbox::path::PathBuf, String>
279 where
280 T: lunchbox::WritableFileSystem + MaybeSend + MaybeSync + 'static,
281 T::FileType: lunchbox::types::WritableFile + MaybeSend + MaybeSync + Unpin,
282 T::ReadDirPollerType: MaybeSend,
283 {
284 let token = self.client.serve_writable_fs(fs.clone()).await;
286
287 match self
288 .client
289 .do_rpc(RPCRequestData::Pack {
290 fs: token,
291 input_path: input_path.to_string(),
292 temp_folder: temp_folder.to_string(),
293 })
294 .await
295 {
296 RPCResponseData::Pack { output_path } => Ok(output_path.into()),
297 RPCResponseData::Error { e } => Err(e),
298 _ => panic!("Unexpected RPC response type!"),
299 }
300 }
301
302 pub async fn alloc_tensor<T: Clone + Default>(&self, shape: Vec<u64>) -> Result<Tensor, String>
303 where
304 InlineAllocator: TypedAlloc<T, Output = InlineTensorStorage>,
305 Tensor: From<TensorStorage<T>>,
306 {
307 Ok(TensorStorage::new(shape).into())
308 }
309
310 }