use std::{
any::Any,
collections::HashMap,
path::{Path, PathBuf},
sync::{atomic::AtomicBool, Arc, Mutex},
};
use anywhere::types::{AnywhereFS, ReadOnlyFS, ReadWriteFS};
use clap::Parser;
use tokio::sync::mpsc::{self, error::SendError};
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
use crate::{
do_not_modify::comms::Comms,
do_not_modify::types::{ChannelId, FsToken, RPCRequest, RPCResponse},
multiplexer::Multiplexer,
types::{Device, Handle, LogRecord, RPCRequestData, RPCResponseData, RpcId, RunnerOpt, Tensor},
};
/// Runner-side endpoint built on top of a `Comms` connection.
///
/// Holds both halves of the RPC channel plus a multiplexer over the
/// filesystem channel. Instances are created by `Server::connect`.
pub struct Server {
/// Underlying connection; also handed to the tensor <-> handle conversions.
comms: Comms,
/// Multiplexer over the `ChannelId::FileSystem` channel; streams are looked
/// up by the numeric value inside an `FsToken`.
fs_multiplexer: Multiplexer<
anywhere::transport::serde::RequestMessageType,
anywhere::transport::serde::ResponseMessageType,
>,
/// Sending half of the `ChannelId::Rpc` channel (responses and forwarded
/// log records).
outgoing: mpsc::Sender<RPCResponse>,
/// Receiving half of the `ChannelId::Rpc` channel.
incoming: mpsc::Receiver<RPCRequest>,
/// Opaque values that are only stored so they get dropped together with the
/// server (`init_runner` pushes the trace-file guard here).
_keepalive: Vec<Box<dyn Any + Send + Sync>>,
/// Set to `true` when the server is dropped; the log-forwarding task reads
/// it to decide whether to stop quietly.
is_shutdown: Arc<AtomicBool>,
}
/// Opaque identifier for a set of sealed tensors.
///
/// A thin, copyable wrapper around a raw `u64`.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub struct SealHandle(pub(crate) u64);

impl SealHandle {
    /// Wraps the raw value `v` in a handle.
    pub fn new(v: u64) -> Self {
        Self(v)
    }

    /// Returns the raw value this handle wraps.
    pub fn get(&self) -> u64 {
        let Self(raw) = *self;
        raw
    }
}
impl From<crate::types::SealHandle> for SealHandle {
fn from(value: crate::types::SealHandle) -> Self {
Self(value.0)
}
}
impl From<SealHandle> for crate::types::SealHandle {
fn from(value: SealHandle) -> Self {
Self(value.0)
}
}
/// A decoded request: the id of the incoming `RPCRequest` together with its
/// payload converted into [`RequestData`].
#[derive(Debug)]
pub struct Request {
/// Identifier copied unchanged from the incoming `RPCRequest`.
pub id: RpcId,
/// Payload of the request after conversion by `RequestData::from`.
pub data: RequestData,
}
impl Request {
async fn from(req: RPCRequest, comms: &Comms) -> Self {
Request {
id: req.id,
data: RequestData::from(req.data, comms).await,
}
}
}
/// Payload of a [`Request`].
///
/// Mirrors `RPCRequestData` variant for variant. The difference visible in
/// this file is that tensor-carrying variants hold `Tensor` values instead of
/// `Handle<Tensor>`, and `InferWithHandle` uses this module's `SealHandle`.
#[derive(Debug)]
pub enum RequestData {
/// Request to load a model (presumably; semantics are defined by the caller
/// of this interface — confirm against the protocol docs).
Load {
/// Token that can be exchanged for a filesystem via
/// `Server::get_readonly_filesystem` / `Server::get_writable_filesystem`.
fs: FsToken,
runner_name: String,
required_framework_version: semver::VersionReq,
runner_compat_version: u64,
/// Optional runner-specific options keyed by name.
runner_opts: Option<HashMap<String, RunnerOpt>>,
visible_device: Device,
carton_manifest_hash: Option<String>,
},
/// Request to pack a model (presumably; confirm against the protocol docs).
Pack {
/// Token that can be exchanged for a filesystem via
/// `Server::get_readonly_filesystem` / `Server::get_writable_filesystem`.
fs: FsToken,
input_path: String,
temp_folder: String,
},
/// Tensors to seal; the matching response variant is `ResponseData::Seal`.
Seal {
/// Tensors keyed by name.
tensors: HashMap<String, Tensor>,
},
/// Inference request that carries its input tensors inline.
InferWithTensors {
/// Input tensors keyed by name.
tensors: HashMap<String, Tensor>,
/// NOTE(review): presumably selects streamed responses (see
/// `Server::send_streaming_response_for_request`) — confirm.
streaming: bool,
},
/// Inference request that refers to tensors by a `SealHandle`.
InferWithHandle {
handle: SealHandle,
/// NOTE(review): presumably selects streamed responses — confirm.
streaming: bool,
},
}
impl RequestData {
async fn from(value: RPCRequestData, comms: &Comms) -> Self {
let from_handles = |tensors: HashMap<String, Handle<Tensor>>| async {
let mut out = HashMap::new();
for (k, v) in tensors {
out.insert(k, v.into_inner(comms).await);
}
out
};
match value {
RPCRequestData::Load {
fs,
runner_name,
required_framework_version,
runner_compat_version,
runner_opts,
visible_device,
carton_manifest_hash,
} => Self::Load {
fs,
runner_name,
required_framework_version,
runner_compat_version,
runner_opts,
visible_device,
carton_manifest_hash,
},
RPCRequestData::Pack {
fs,
input_path,
temp_folder,
} => Self::Pack {
fs,
input_path,
temp_folder,
},
RPCRequestData::Seal { tensors } => Self::Seal {
tensors: from_handles(tensors).await,
},
RPCRequestData::InferWithTensors { tensors, streaming } => Self::InferWithTensors {
tensors: from_handles(tensors).await,
streaming,
},
RPCRequestData::InferWithHandle { handle, streaming } => Self::InferWithHandle {
handle: handle.into(),
streaming,
},
}
}
}
/// Payload of a response sent back over the RPC channel.
///
/// Converted into `RPCResponseData` by `ResponseData::to_rpc`; every variant
/// maps to the `RPCResponseData` variant of the same name.
#[derive(Debug)]
pub enum ResponseData {
/// Response without a payload; maps to `RPCResponseData::Load`.
Load,
/// Maps to `RPCResponseData::Pack`.
Pack {
output_path: String,
},
/// Carries the handle for a seal operation; maps to `RPCResponseData::Seal`.
Seal {
handle: SealHandle,
},
/// Output tensors keyed by name; each tensor is wrapped in a `Handle`
/// when converted for the wire.
Infer {
tensors: HashMap<String, Tensor>,
},
/// An error described by a message string.
Error {
e: String,
},
/// A log record to forward.
LogMessage {
record: LogRecord,
},
/// Response without a payload; maps to `RPCResponseData::Empty`.
Empty,
}
impl ResponseData {
async fn to_rpc(self, comms: &Comms) -> RPCResponseData {
let into_handles = |tensors: HashMap<String, Tensor>| async {
let mut out = HashMap::new();
for (k, v) in tensors {
out.insert(k, Handle::new(v, comms).await);
}
out
};
match self {
ResponseData::Load => RPCResponseData::Load,
ResponseData::Pack { output_path } => RPCResponseData::Pack { output_path },
ResponseData::Seal { handle } => RPCResponseData::Seal {
handle: handle.into(),
},
ResponseData::Infer { tensors } => RPCResponseData::Infer {
tensors: into_handles(tensors).await,
},
ResponseData::Error { e } => RPCResponseData::Error { e },
ResponseData::LogMessage { record } => RPCResponseData::LogMessage { record },
ResponseData::Empty => RPCResponseData::Empty,
}
}
}
impl Server {
    /// Connects to `path` and sets up the filesystem and RPC channels.
    ///
    /// When `logger` is provided, its receiver is taken and a background
    /// tokio task is spawned that forwards every queued log record over the
    /// RPC channel as an `RPCResponseData::LogMessage` (sent with `id: 0` and
    /// `complete: true`).
    ///
    /// # Panics
    ///
    /// * If `logger`'s receiver was already taken by an earlier call.
    /// * Inside the spawned task, if forwarding a record fails while the
    ///   server has not been dropped yet.
    async fn connect(path: &Path, logger: Option<&PassThroughLogger>) -> Self {
        let comms = Comms::connect(path).await;

        let (fs_tx, fs_rx) = comms.get_channel(ChannelId::FileSystem).await;
        let fs_multiplexer = Multiplexer::new(fs_tx, fs_rx).await;

        let (outgoing, incoming) = comms.get_channel(ChannelId::Rpc).await;
        let is_shutdown = Arc::new(AtomicBool::new(false));

        if let Some(logger) = logger {
            let mut messages = logger.get_rx();
            let out = outgoing.clone();
            let is_shutdown = is_shutdown.clone();
            tokio::spawn(async move {
                while let Some(record) = messages.recv().await {
                    // Once the server has been dropped there is nobody left
                    // to deliver log records to.
                    if is_shutdown.load(std::sync::atomic::Ordering::Relaxed) {
                        break;
                    }

                    let status = out
                        .send(RPCResponse {
                            id: 0,
                            complete: true,
                            data: RPCResponseData::LogMessage { record },
                        })
                        .await;

                    if let Err(err) = status {
                        // The server may have been dropped while the send
                        // was in flight; a failure is only treated as fatal
                        // when the shutdown flag is still unset.
                        if is_shutdown.load(std::sync::atomic::Ordering::Relaxed) {
                            break;
                        }
                        panic!(
                            "failed to forward a log record over the RPC channel: {:?}",
                            err
                        );
                    }
                }
            });
        }

        Server {
            comms,
            fs_multiplexer,
            incoming,
            outgoing,
            _keepalive: Vec::new(),
            is_shutdown,
        }
    }

    /// Waits for the next incoming request and decodes it.
    ///
    /// Returns `None` once the incoming RPC channel yields no more messages.
    pub async fn get_next_request(&mut self) -> Option<Request> {
        let raw = self.incoming.recv().await?;
        Some(Request::from(raw, &self.comms).await)
    }

    /// Sends `res` as the complete (final) response for request `req_id`.
    ///
    /// # Errors
    ///
    /// Returns `SendError(())` if the response could not be queued on the
    /// outgoing channel; the unsent message is discarded.
    pub async fn send_response_for_request(
        &self,
        req_id: u64,
        res: ResponseData,
    ) -> Result<(), SendError<()>> {
        // A non-streaming response is a streaming response that is
        // immediately complete.
        self.send_streaming_response_for_request(req_id, true, res)
            .await
    }

    /// Sends `res` as a response for request `req_id`, with the `complete`
    /// flag of the outgoing message set to `complete`.
    ///
    /// # Errors
    ///
    /// Returns `SendError(())` if the response could not be queued on the
    /// outgoing channel; the unsent message is discarded.
    pub async fn send_streaming_response_for_request(
        &self,
        req_id: u64,
        complete: bool,
        res: ResponseData,
    ) -> Result<(), SendError<()>> {
        let data = res.to_rpc(&self.comms).await;
        self.outgoing
            .send(RPCResponse {
                id: req_id,
                complete,
                data,
            })
            .await
            .map_err(|_| SendError(()))
    }

    /// Opens the filesystem identified by `token` as a `ReadWriteFS`.
    pub async fn get_writable_filesystem(&self, token: FsToken) -> std::io::Result<ReadWriteFS> {
        self.get_filesystem_internal(token).await
    }

    /// Opens the filesystem identified by `token` as a `ReadOnlyFS`.
    pub async fn get_readonly_filesystem(&self, token: FsToken) -> std::io::Result<ReadOnlyFS> {
        self.get_filesystem_internal(token).await
    }

    /// Shared implementation of the filesystem getters.
    ///
    /// Fetches the multiplexed stream for `token` and connects an
    /// `AnywhereFS` over it. `W` and `S` are the `AnywhereFS` const
    /// parameters and are inferred from the caller's return type.
    async fn get_filesystem_internal<const W: bool, const S: bool>(
        &self,
        token: FsToken,
    ) -> std::io::Result<AnywhereFS<W, S>> {
        let (tx, rx) = self.fs_multiplexer.get_stream_for_id(token.0).await;
        anywhere::transport::serde::connect(tx, rx).await
    }
}
/// Raises the shutdown flag so the log-forwarding task spawned in
/// `Server::connect` can stop instead of panicking on a failed send.
impl Drop for Server {
    fn drop(&mut self) {
        use std::sync::atomic::Ordering;

        self.is_shutdown.store(true, Ordering::Relaxed);
    }
}
// Command-line arguments accepted by the runner process.
//
// NOTE: plain `//` comments are used on purpose. With clap's derive API,
// `///` doc comments become part of the generated help output.
#[derive(Parser, Debug)]
struct Args {
// Path passed to `Server::connect` (supplied as `--uds-path <PATH>`).
#[arg(long)]
uds_path: String,
}
pub async fn init_runner() -> Server {
let args = Args::parse();
#[cfg(not(target_os = "macos"))]
if unsafe { libc::prctl(libc::PR_SET_PDEATHSIG, libc::SIGKILL) } != 0 {
panic!("prctl failed")
}
#[cfg(target_os = "macos")]
std::thread::spawn(|| {
loop {
let ppid = unsafe { libc::getppid() };
if ppid == 1 {
std::process::exit(0);
}
std::thread::sleep(std::time::Duration::from_secs(1));
}
});
let mut keepalive = None;
let mut pass_through_logger = None;
match std::env::var("CARTON_RUNNER_TRACE_FILE") {
Ok(path) => {
let (chrome_layer, _guard) = ChromeLayerBuilder::new()
.file(path)
.include_args(true)
.build();
tracing_subscriber::registry().with(chrome_layer).init();
keepalive = Some(_guard);
}
Err(_) => {
let logger: &'static PassThroughLogger = Box::leak(Box::new(PassThroughLogger::new()));
log::set_logger(logger).unwrap();
log::set_max_level(log::LevelFilter::Trace);
pass_through_logger = Some(logger);
}
};
let mut s = Server::connect(&PathBuf::from(args.uds_path), pass_through_logger).await;
if let Some(ka) = keepalive {
s._keepalive.push(Box::new(Mutex::new(ka)));
}
s
}
/// `log::Log` implementation that queues every record on an unbounded
/// channel so it can be forwarded elsewhere (see `Server::connect`).
struct PassThroughLogger {
/// Sending half; `log` pushes each converted record here.
tx: mpsc::UnboundedSender<LogRecord>,
/// Receiving half, parked until `get_rx` takes it. It can be taken only
/// once; afterwards the slot holds `None`.
rx: std::sync::Mutex<Option<mpsc::UnboundedReceiver<LogRecord>>>,
}
impl PassThroughLogger {
fn new() -> Self {
let (tx, rx) = mpsc::unbounded_channel();
Self {
tx,
rx: std::sync::Mutex::new(Some(rx)),
}
}
fn get_rx(&self) -> mpsc::UnboundedReceiver<LogRecord> {
self.rx.lock().unwrap().take().unwrap()
}
}
impl log::Log for PassThroughLogger {
fn enabled(&self, _metadata: &log::Metadata) -> bool {
true
}
fn log(&self, record: &log::Record) {
let _ = self.tx.send(record.into());
}
fn flush(&self) {
}
}