Skip to main content

carton_runner_interface/
server.rs

1// Copyright 2023 Vivek Panyam
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    any::Any,
17    collections::HashMap,
18    path::{Path, PathBuf},
19    sync::{atomic::AtomicBool, Arc, Mutex},
20};
21
22use anywhere::types::{AnywhereFS, ReadOnlyFS, ReadWriteFS};
23use clap::Parser;
24use tokio::sync::mpsc::{self, error::SendError};
25use tracing_chrome::ChromeLayerBuilder;
26use tracing_subscriber::prelude::*;
27
28use crate::{
29    do_not_modify::comms::Comms,
30    do_not_modify::types::{ChannelId, FsToken, RPCRequest, RPCResponse},
31    multiplexer::Multiplexer,
32    types::{Device, Handle, LogRecord, RPCRequestData, RPCResponseData, RpcId, RunnerOpt, Tensor},
33};
34
/// Runner-side endpoint for talking to the core library.
///
/// Bundles the underlying `Comms` connection, a multiplexer for filesystem
/// traffic, and both halves of the RPC channel.
pub struct Server {
    // The underlying connection. It is also handed to the tensor <-> handle
    // conversions performed when requests are received and responses are sent.
    comms: Comms,
    // Provides a message stream per filesystem token (see `get_filesystem_internal`)
    fs_multiplexer: Multiplexer<
        anywhere::transport::serde::RequestMessageType,
        anywhere::transport::serde::ResponseMessageType,
    >,

    // Sending half of the RPC channel: responses and forwarded log messages
    outgoing: mpsc::Sender<RPCResponse>,
    // Receiving half of the RPC channel: requests to be handled by the runner
    incoming: mpsc::Receiver<RPCRequest>,

    // Keep this alive while the server is up
    // (`init_runner` stores the tracing guard here when tracing is enabled)
    _keepalive: Vec<Box<dyn Any + Send + Sync>>,

    // A flag that stops us from attempting to send log messages after shutdown
    // Set to `true` when the server is dropped
    is_shutdown: Arc<AtomicBool>,
}
51
/// An opaque identifier for a map of tensors that was previously sealed
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub struct SealHandle(pub(crate) u64);

impl SealHandle {
    /// Wraps the raw value `v` in a handle
    pub fn new(v: u64) -> Self {
        Self(v)
    }

    /// Returns the raw value this handle was created with
    pub fn get(&self) -> u64 {
        let Self(raw) = *self;
        raw
    }
}
65
66impl From<crate::types::SealHandle> for SealHandle {
67    fn from(value: crate::types::SealHandle) -> Self {
68        Self(value.0)
69    }
70}
71
72impl From<SealHandle> for crate::types::SealHandle {
73    fn from(value: SealHandle) -> Self {
74        Self(value.0)
75    }
76}
77
/// A request from the core library
///
/// Built from an `RPCRequest` once any tensor handles it carries have been
/// resolved (see `RequestData`).
#[derive(Debug)]
pub struct Request {
    /// The id of the underlying RPC request
    pub id: RpcId,

    /// What the core library is asking the runner to do
    pub data: RequestData,
}
85
86impl Request {
87    async fn from(req: RPCRequest, comms: &Comms) -> Self {
88        Request {
89            id: req.id,
90            data: RequestData::from(req.data, comms).await,
91        }
92    }
93}
94
/// The payload of a `Request`
///
/// Mirrors the variants of `RPCRequestData`, except that tensors which arrive
/// as handles have already been resolved into `Tensor` values.
#[derive(Debug)]
pub enum RequestData {
    /// Load a model
    Load {
        /// This filesystem points to a folder that is of the same structure as the output of `Pack` (for a particular runner)
        /// For a readonly filesystem
        fs: FsToken,

        /// Load options
        runner_name: String,
        required_framework_version: semver::VersionReq,
        runner_compat_version: u64,
        runner_opts: Option<HashMap<String, RunnerOpt>>,
        visible_device: Device,

        // The hash of the model
        // This should always be available unless we're loading an unpacked model
        carton_manifest_hash: Option<String>,
    },

    // Pack a model
    Pack {
        /// A token for a read/write filesystem that the below paths reference
        fs: FsToken,

        // The path to user input data
        // If this is a folder, the runner is allowed to place data in a `.carton` subfolder
        // This can be used if it wants to generate a lockfile for example
        input_path: String,

        // A temporary folder generated by the core library. The runner can use this if it needs
        // to generate output in a new folder.
        // (In some cases, the input can be wrapped as-is and doesn't need to be copied into a new folder)
        // This folder is owned by the core library and will be deleted by it
        temp_folder: String,
    },

    // Seal a map of tensors
    // (the matching response carries a `SealHandle` that refers to them)
    Seal {
        // The tensors to seal, keyed by name
        tensors: HashMap<String, Tensor>,
    },

    // Run inference on tensors passed in directly
    InferWithTensors {
        // The input tensors, keyed by name
        tensors: HashMap<String, Tensor>,

        // Do we support a streaming response
        streaming: bool,
    },

    // Run inference using a handle instead of tensors
    // (presumably one returned in response to an earlier `Seal` - confirm with the core library)
    InferWithHandle {
        handle: SealHandle,

        // Do we support a streaming response
        streaming: bool,
    },
}
149
150impl RequestData {
151    async fn from(value: RPCRequestData, comms: &Comms) -> Self {
152        let from_handles = |tensors: HashMap<String, Handle<Tensor>>| async {
153            let mut out = HashMap::new();
154            for (k, v) in tensors {
155                out.insert(k, v.into_inner(comms).await);
156            }
157
158            out
159        };
160
161        match value {
162            RPCRequestData::Load {
163                fs,
164                runner_name,
165                required_framework_version,
166                runner_compat_version,
167                runner_opts,
168                visible_device,
169                carton_manifest_hash,
170            } => Self::Load {
171                fs,
172                runner_name,
173                required_framework_version,
174                runner_compat_version,
175                runner_opts,
176                visible_device,
177                carton_manifest_hash,
178            },
179            RPCRequestData::Pack {
180                fs,
181                input_path,
182                temp_folder,
183            } => Self::Pack {
184                fs,
185                input_path,
186                temp_folder,
187            },
188            RPCRequestData::Seal { tensors } => Self::Seal {
189                tensors: from_handles(tensors).await,
190            },
191            RPCRequestData::InferWithTensors { tensors, streaming } => Self::InferWithTensors {
192                tensors: from_handles(tensors).await,
193                streaming,
194            },
195            RPCRequestData::InferWithHandle { handle, streaming } => Self::InferWithHandle {
196                handle: handle.into(),
197                streaming,
198            },
199        }
200    }
201}
202
/// The payload of a response sent back to the core library
///
/// Converted into an `RPCResponseData` (see `to_rpc`) before it is sent.
#[derive(Debug)]
pub enum ResponseData {
    /// Successful load
    Load,

    /// Successful pack
    Pack {
        // The path to the output directory. This can be in the temp folder passed into `Pack`
        // Note: this must be a *directory* even if the input was a file
        // This references a path on the FS that was passed in
        // during the request
        output_path: String,
    },

    /// Successful seal
    Seal {
        // A handle that refers to the sealed tensors
        handle: SealHandle,
    },

    /// Inference output
    Infer {
        // The output tensors, keyed by name
        tensors: HashMap<String, Tensor>,
    },

    /// Something went wrong
    Error {
        // A description of what went wrong
        e: String,
    },

    /// Logging
    LogMessage {
        record: LogRecord,
    },

    /// A response that carries no data
    Empty,
}
236
237impl ResponseData {
238    async fn to_rpc(self, comms: &Comms) -> RPCResponseData {
239        let into_handles = |tensors: HashMap<String, Tensor>| async {
240            let mut out = HashMap::new();
241            for (k, v) in tensors {
242                out.insert(k, Handle::new(v, comms).await);
243            }
244
245            out
246        };
247
248        match self {
249            ResponseData::Load => RPCResponseData::Load,
250            ResponseData::Pack { output_path } => RPCResponseData::Pack { output_path },
251            ResponseData::Seal { handle } => RPCResponseData::Seal {
252                handle: handle.into(),
253            },
254            ResponseData::Infer { tensors } => RPCResponseData::Infer {
255                tensors: into_handles(tensors).await,
256            },
257            ResponseData::Error { e } => RPCResponseData::Error { e },
258            ResponseData::LogMessage { record } => RPCResponseData::LogMessage { record },
259            ResponseData::Empty => RPCResponseData::Empty,
260        }
261    }
262}
263
264impl Server {
265    async fn connect(path: &Path, logger: Option<&PassThroughLogger>) -> Self {
266        let comms = Comms::connect(path).await;
267
268        // Set up filesystem handling
269        let (tx, rx) = comms.get_channel(ChannelId::FileSystem).await;
270        let fs_multiplexer = Multiplexer::new(tx, rx).await;
271
272        let (tx, rx) = comms.get_channel(ChannelId::Rpc).await;
273
274        let is_shutdown = Arc::new(AtomicBool::new(false));
275        if let Some(logger) = logger {
276            let mut messages = logger.get_rx();
277            let out = tx.clone();
278            let is_shutdown = is_shutdown.clone();
279            tokio::spawn(async move {
280                while let Some(record) = messages.recv().await {
281                    if is_shutdown.load(std::sync::atomic::Ordering::Relaxed) {
282                        break;
283                    }
284
285                    // TODO: don't hardcode 0
286                    let status = out
287                        .send(RPCResponse {
288                            id: 0,
289                            complete: true,
290                            data: RPCResponseData::LogMessage { record },
291                        })
292                        .await;
293
294                    // Ignore send errors only when we're shutting down
295                    if let Err(s) = status {
296                        if is_shutdown.load(std::sync::atomic::Ordering::Relaxed) {
297                            break;
298                        } else {
299                            Err(s).unwrap()
300                        }
301                    }
302                }
303            });
304        }
305
306        Server {
307            comms,
308            fs_multiplexer,
309            incoming: rx,
310            outgoing: tx,
311            _keepalive: Vec::new(),
312            is_shutdown,
313        }
314    }
315
316    pub async fn get_next_request(&mut self) -> Option<Request> {
317        match self.incoming.recv().await {
318            Some(req) => Some(Request::from(req, &self.comms).await),
319            None => None,
320        }
321    }
322
323    pub async fn send_response_for_request(
324        &self,
325        req_id: u64,
326        res: ResponseData,
327    ) -> Result<(), SendError<()>> {
328        self.outgoing
329            .send(RPCResponse {
330                id: req_id,
331                complete: true,
332                data: res.to_rpc(&self.comms).await,
333            })
334            .await
335            .map_err(|_| SendError(()))
336    }
337
338    pub async fn send_streaming_response_for_request(
339        &self,
340        req_id: u64,
341        complete: bool,
342        res: ResponseData,
343    ) -> Result<(), SendError<()>> {
344        self.outgoing
345            .send(RPCResponse {
346                id: req_id,
347                complete,
348                data: res.to_rpc(&self.comms).await,
349            })
350            .await
351            .map_err(|_| SendError(()))
352    }
353
354    pub async fn get_writable_filesystem(&self, token: FsToken) -> std::io::Result<ReadWriteFS> {
355        self.get_filesystem_internal(token).await
356    }
357
358    pub async fn get_readonly_filesystem(&self, token: FsToken) -> std::io::Result<ReadOnlyFS> {
359        self.get_filesystem_internal(token).await
360    }
361
362    async fn get_filesystem_internal<const W: bool, const S: bool>(
363        &self,
364        token: FsToken,
365    ) -> std::io::Result<AnywhereFS<W, S>> {
366        let (tx, rx) = self.fs_multiplexer.get_stream_for_id(token.0).await;
367
368        anywhere::transport::serde::connect(tx, rx).await
369    }
370}
371
372impl Drop for Server {
373    fn drop(&mut self) {
374        // Mark that we shutdown
375        // TODO: we should be able to remove this once we remove the `unwrap`s in comms
376        self.is_shutdown
377            .store(true, std::sync::atomic::Ordering::Relaxed);
378    }
379}
380
// Command line arguments accepted by the runner process
//
// NOTE: plain `//` comments are used here on purpose. clap turns `///` doc
// comments into help text, which would change the program's output.
#[derive(Parser, Debug)]
struct Args {
    // The path passed to `Server::connect` in order to reach the core library
    // (presumably a Unix domain socket path, going by the flag name - confirm with `Comms`)
    #[arg(long)]
    uds_path: String,
}
386
387/// Initialize the runner from command line args and return two queues to use to communicate
388pub async fn init_runner() -> Server {
389    let args = Args::parse();
390
391    // Shutdown the runner if the parent process dies
392    // NOTE: this technically shuts down if the thread that forked this process dies, but since
393    // the parent should be running in tokio, this should be okay because if the parent's tokio
394    // runtime goes down, we should go down.
395    #[cfg(not(target_os = "macos"))]
396    if unsafe { libc::prctl(libc::PR_SET_PDEATHSIG, libc::SIGKILL) } != 0 {
397        panic!("prctl failed")
398    }
399
400    // Watchdog on macos where we can't use PR_SET_PDEATHSIG
401    #[cfg(target_os = "macos")]
402    std::thread::spawn(|| {
403        loop {
404            let ppid = unsafe { libc::getppid() };
405            if ppid == 1 {
406                // The parent exited so we should exit
407                std::process::exit(0);
408            }
409
410            std::thread::sleep(std::time::Duration::from_secs(1));
411        }
412    });
413
414    // TODO: this is a little messy. Clean it up
415    let mut keepalive = None;
416    let mut pass_through_logger = None;
417    match std::env::var("CARTON_RUNNER_TRACE_FILE") {
418        Ok(path) => {
419            // Setup tracing
420            let (chrome_layer, _guard) = ChromeLayerBuilder::new()
421                .file(path)
422                .include_args(true)
423                .build();
424            tracing_subscriber::registry().with(chrome_layer).init();
425
426            keepalive = Some(_guard);
427        }
428        Err(_) => {
429            // Initialize logging
430            let logger: &'static PassThroughLogger = Box::leak(Box::new(PassThroughLogger::new()));
431            log::set_logger(logger).unwrap();
432            log::set_max_level(log::LevelFilter::Trace);
433
434            pass_through_logger = Some(logger);
435        }
436    };
437
438    // TODO: run the FD passing channel on top of UDS and get the appropriate channels out
439    let mut s = Server::connect(&PathBuf::from(args.uds_path), pass_through_logger).await;
440
441    if let Some(ka) = keepalive {
442        s._keepalive.push(Box::new(Mutex::new(ka)));
443    }
444
445    s
446}
447
/// A logging implementation that passes through to the main process
///
/// Records are queued on an unbounded channel by the `log::Log` impl; the
/// receiving end is handed out once via `get_rx`.
struct PassThroughLogger {
    // Sending half, used by `log` to queue records
    tx: mpsc::UnboundedSender<LogRecord>,
    // Receiving half. `None` once it has been taken by `get_rx`
    rx: std::sync::Mutex<Option<mpsc::UnboundedReceiver<LogRecord>>>,
}
453
454impl PassThroughLogger {
455    fn new() -> Self {
456        let (tx, rx) = mpsc::unbounded_channel();
457        Self {
458            tx,
459            rx: std::sync::Mutex::new(Some(rx)),
460        }
461    }
462
463    // Can only be called once
464    fn get_rx(&self) -> mpsc::UnboundedReceiver<LogRecord> {
465        self.rx.lock().unwrap().take().unwrap()
466    }
467}
468
469impl log::Log for PassThroughLogger {
470    fn enabled(&self, _metadata: &log::Metadata) -> bool {
471        // This isn't ideal, but for now, lets always return true and let the
472        // main process handle it
473        // TODO: improve this
474        true
475    }
476
477    fn log(&self, record: &log::Record) {
478        // TODO: check if this is reasonably efficient
479        // Ignore send failures
480        let _ = self.tx.send(record.into());
481    }
482
483    fn flush(&self) {
484        // Noop for now
485    }
486}