1use std::{
16 any::Any,
17 collections::HashMap,
18 path::{Path, PathBuf},
19 sync::{atomic::AtomicBool, Arc, Mutex},
20};
21
22use anywhere::types::{AnywhereFS, ReadOnlyFS, ReadWriteFS};
23use clap::Parser;
24use tokio::sync::mpsc::{self, error::SendError};
25use tracing_chrome::ChromeLayerBuilder;
26use tracing_subscriber::prelude::*;
27
28use crate::{
29 do_not_modify::comms::Comms,
30 do_not_modify::types::{ChannelId, FsToken, RPCRequest, RPCResponse},
31 multiplexer::Multiplexer,
32 types::{Device, Handle, LogRecord, RPCRequestData, RPCResponseData, RpcId, RunnerOpt, Tensor},
33};
34
35pub struct Server {
36 comms: Comms,
37 fs_multiplexer: Multiplexer<
38 anywhere::transport::serde::RequestMessageType,
39 anywhere::transport::serde::ResponseMessageType,
40 >,
41
42 outgoing: mpsc::Sender<RPCResponse>,
43 incoming: mpsc::Receiver<RPCRequest>,
44
45 _keepalive: Vec<Box<dyn Any + Send + Sync>>,
47
48 is_shutdown: Arc<AtomicBool>,
50}
51
/// Thin wrapper around a raw `u64` identifier.
///
/// It converts to and from `crate::types::SealHandle` and is carried by the
/// `Seal` response and the `InferWithHandle` request.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub struct SealHandle(pub(crate) u64);

impl SealHandle {
    /// Wraps the raw value `v` in a handle.
    pub fn new(v: u64) -> Self {
        Self(v)
    }

    /// Returns the raw value stored in this handle.
    pub fn get(&self) -> u64 {
        let Self(raw) = *self;
        raw
    }
}
65
66impl From<crate::types::SealHandle> for SealHandle {
67 fn from(value: crate::types::SealHandle) -> Self {
68 Self(value.0)
69 }
70}
71
72impl From<SealHandle> for crate::types::SealHandle {
73 fn from(value: SealHandle) -> Self {
74 Self(value.0)
75 }
76}
77
78#[derive(Debug)]
80pub struct Request {
81 pub id: RpcId,
82
83 pub data: RequestData,
84}
85
86impl Request {
87 async fn from(req: RPCRequest, comms: &Comms) -> Self {
88 Request {
89 id: req.id,
90 data: RequestData::from(req.data, comms).await,
91 }
92 }
93}
94
95#[derive(Debug)]
96pub enum RequestData {
97 Load {
98 fs: FsToken,
101
102 runner_name: String,
104 required_framework_version: semver::VersionReq,
105 runner_compat_version: u64,
106 runner_opts: Option<HashMap<String, RunnerOpt>>,
107 visible_device: Device,
108
109 carton_manifest_hash: Option<String>,
112 },
113
114 Pack {
116 fs: FsToken,
118
119 input_path: String,
123
124 temp_folder: String,
129 },
130
131 Seal {
132 tensors: HashMap<String, Tensor>,
133 },
134
135 InferWithTensors {
136 tensors: HashMap<String, Tensor>,
137
138 streaming: bool,
140 },
141
142 InferWithHandle {
143 handle: SealHandle,
144
145 streaming: bool,
147 },
148}
149
150impl RequestData {
151 async fn from(value: RPCRequestData, comms: &Comms) -> Self {
152 let from_handles = |tensors: HashMap<String, Handle<Tensor>>| async {
153 let mut out = HashMap::new();
154 for (k, v) in tensors {
155 out.insert(k, v.into_inner(comms).await);
156 }
157
158 out
159 };
160
161 match value {
162 RPCRequestData::Load {
163 fs,
164 runner_name,
165 required_framework_version,
166 runner_compat_version,
167 runner_opts,
168 visible_device,
169 carton_manifest_hash,
170 } => Self::Load {
171 fs,
172 runner_name,
173 required_framework_version,
174 runner_compat_version,
175 runner_opts,
176 visible_device,
177 carton_manifest_hash,
178 },
179 RPCRequestData::Pack {
180 fs,
181 input_path,
182 temp_folder,
183 } => Self::Pack {
184 fs,
185 input_path,
186 temp_folder,
187 },
188 RPCRequestData::Seal { tensors } => Self::Seal {
189 tensors: from_handles(tensors).await,
190 },
191 RPCRequestData::InferWithTensors { tensors, streaming } => Self::InferWithTensors {
192 tensors: from_handles(tensors).await,
193 streaming,
194 },
195 RPCRequestData::InferWithHandle { handle, streaming } => Self::InferWithHandle {
196 handle: handle.into(),
197 streaming,
198 },
199 }
200 }
201}
202
203#[derive(Debug)]
204pub enum ResponseData {
205 Load,
207
208 Pack {
209 output_path: String,
214 },
215
216 Seal {
217 handle: SealHandle,
218 },
219
220 Infer {
221 tensors: HashMap<String, Tensor>,
222 },
223
224 Error {
226 e: String,
227 },
228
229 LogMessage {
231 record: LogRecord,
232 },
233
234 Empty,
235}
236
237impl ResponseData {
238 async fn to_rpc(self, comms: &Comms) -> RPCResponseData {
239 let into_handles = |tensors: HashMap<String, Tensor>| async {
240 let mut out = HashMap::new();
241 for (k, v) in tensors {
242 out.insert(k, Handle::new(v, comms).await);
243 }
244
245 out
246 };
247
248 match self {
249 ResponseData::Load => RPCResponseData::Load,
250 ResponseData::Pack { output_path } => RPCResponseData::Pack { output_path },
251 ResponseData::Seal { handle } => RPCResponseData::Seal {
252 handle: handle.into(),
253 },
254 ResponseData::Infer { tensors } => RPCResponseData::Infer {
255 tensors: into_handles(tensors).await,
256 },
257 ResponseData::Error { e } => RPCResponseData::Error { e },
258 ResponseData::LogMessage { record } => RPCResponseData::LogMessage { record },
259 ResponseData::Empty => RPCResponseData::Empty,
260 }
261 }
262}
263
264impl Server {
265 async fn connect(path: &Path, logger: Option<&PassThroughLogger>) -> Self {
266 let comms = Comms::connect(path).await;
267
268 let (tx, rx) = comms.get_channel(ChannelId::FileSystem).await;
270 let fs_multiplexer = Multiplexer::new(tx, rx).await;
271
272 let (tx, rx) = comms.get_channel(ChannelId::Rpc).await;
273
274 let is_shutdown = Arc::new(AtomicBool::new(false));
275 if let Some(logger) = logger {
276 let mut messages = logger.get_rx();
277 let out = tx.clone();
278 let is_shutdown = is_shutdown.clone();
279 tokio::spawn(async move {
280 while let Some(record) = messages.recv().await {
281 if is_shutdown.load(std::sync::atomic::Ordering::Relaxed) {
282 break;
283 }
284
285 let status = out
287 .send(RPCResponse {
288 id: 0,
289 complete: true,
290 data: RPCResponseData::LogMessage { record },
291 })
292 .await;
293
294 if let Err(s) = status {
296 if is_shutdown.load(std::sync::atomic::Ordering::Relaxed) {
297 break;
298 } else {
299 Err(s).unwrap()
300 }
301 }
302 }
303 });
304 }
305
306 Server {
307 comms,
308 fs_multiplexer,
309 incoming: rx,
310 outgoing: tx,
311 _keepalive: Vec::new(),
312 is_shutdown,
313 }
314 }
315
316 pub async fn get_next_request(&mut self) -> Option<Request> {
317 match self.incoming.recv().await {
318 Some(req) => Some(Request::from(req, &self.comms).await),
319 None => None,
320 }
321 }
322
323 pub async fn send_response_for_request(
324 &self,
325 req_id: u64,
326 res: ResponseData,
327 ) -> Result<(), SendError<()>> {
328 self.outgoing
329 .send(RPCResponse {
330 id: req_id,
331 complete: true,
332 data: res.to_rpc(&self.comms).await,
333 })
334 .await
335 .map_err(|_| SendError(()))
336 }
337
338 pub async fn send_streaming_response_for_request(
339 &self,
340 req_id: u64,
341 complete: bool,
342 res: ResponseData,
343 ) -> Result<(), SendError<()>> {
344 self.outgoing
345 .send(RPCResponse {
346 id: req_id,
347 complete,
348 data: res.to_rpc(&self.comms).await,
349 })
350 .await
351 .map_err(|_| SendError(()))
352 }
353
354 pub async fn get_writable_filesystem(&self, token: FsToken) -> std::io::Result<ReadWriteFS> {
355 self.get_filesystem_internal(token).await
356 }
357
358 pub async fn get_readonly_filesystem(&self, token: FsToken) -> std::io::Result<ReadOnlyFS> {
359 self.get_filesystem_internal(token).await
360 }
361
362 async fn get_filesystem_internal<const W: bool, const S: bool>(
363 &self,
364 token: FsToken,
365 ) -> std::io::Result<AnywhereFS<W, S>> {
366 let (tx, rx) = self.fs_multiplexer.get_stream_for_id(token.0).await;
367
368 anywhere::transport::serde::connect(tx, rx).await
369 }
370}
371
372impl Drop for Server {
373 fn drop(&mut self) {
374 self.is_shutdown
377 .store(true, std::sync::atomic::Ordering::Relaxed);
378 }
379}
380
381#[derive(Parser, Debug)]
382struct Args {
383 #[arg(long)]
384 uds_path: String,
385}
386
387pub async fn init_runner() -> Server {
389 let args = Args::parse();
390
391 #[cfg(not(target_os = "macos"))]
396 if unsafe { libc::prctl(libc::PR_SET_PDEATHSIG, libc::SIGKILL) } != 0 {
397 panic!("prctl failed")
398 }
399
400 #[cfg(target_os = "macos")]
402 std::thread::spawn(|| {
403 loop {
404 let ppid = unsafe { libc::getppid() };
405 if ppid == 1 {
406 std::process::exit(0);
408 }
409
410 std::thread::sleep(std::time::Duration::from_secs(1));
411 }
412 });
413
414 let mut keepalive = None;
416 let mut pass_through_logger = None;
417 match std::env::var("CARTON_RUNNER_TRACE_FILE") {
418 Ok(path) => {
419 let (chrome_layer, _guard) = ChromeLayerBuilder::new()
421 .file(path)
422 .include_args(true)
423 .build();
424 tracing_subscriber::registry().with(chrome_layer).init();
425
426 keepalive = Some(_guard);
427 }
428 Err(_) => {
429 let logger: &'static PassThroughLogger = Box::leak(Box::new(PassThroughLogger::new()));
431 log::set_logger(logger).unwrap();
432 log::set_max_level(log::LevelFilter::Trace);
433
434 pass_through_logger = Some(logger);
435 }
436 };
437
438 let mut s = Server::connect(&PathBuf::from(args.uds_path), pass_through_logger).await;
440
441 if let Some(ka) = keepalive {
442 s._keepalive.push(Box::new(Mutex::new(ka)));
443 }
444
445 s
446}
447
448struct PassThroughLogger {
450 tx: mpsc::UnboundedSender<LogRecord>,
451 rx: std::sync::Mutex<Option<mpsc::UnboundedReceiver<LogRecord>>>,
452}
453
454impl PassThroughLogger {
455 fn new() -> Self {
456 let (tx, rx) = mpsc::unbounded_channel();
457 Self {
458 tx,
459 rx: std::sync::Mutex::new(Some(rx)),
460 }
461 }
462
463 fn get_rx(&self) -> mpsc::UnboundedReceiver<LogRecord> {
465 self.rx.lock().unwrap().take().unwrap()
466 }
467}
468
469impl log::Log for PassThroughLogger {
470 fn enabled(&self, _metadata: &log::Metadata) -> bool {
471 true
475 }
476
477 fn log(&self, record: &log::Record) {
478 let _ = self.tx.send(record.into());
481 }
482
483 fn flush(&self) {
484 }
486}