use {
crate::handshake::{
shared::{
GLOBAL_ALLOCATORS, LOGON_FAILURE, LOGON_SUCCESS, MAX_ALLOCATOR_HANDLES, MAX_WORKERS,
VERSION,
},
ClientLogon,
},
agave_scheduler_bindings::{
PackToWorkerMessage, ProgressMessage, TpuToPackMessage, WorkerToPackMessage,
},
nix::sys::socket::{self, ControlMessage, MsgFlags, UnixAddr},
rts_alloc::Allocator,
std::{
ffi::CStr,
fs::File,
io::{IoSlice, Read, Write},
os::{
fd::{AsRawFd, FromRawFd},
unix::net::{UnixListener, UnixStream},
},
path::Path,
time::{Duration, Instant},
},
thiserror::Error,
};
/// Error type returned when creating the `shaq` queues backed by the session's shared memory.
type ShaqError = shaq::error::Error;
/// Error type returned when creating or joining the `rts_alloc` shared-memory allocator.
type RtsAllocError = rts_alloc::error::Error;
/// Name given to every shared-memory object created while setting up a session.
const SHMEM_NAME: &CStr = c"/agave-scheduler-bindings";
/// Time budget for a client to deliver its logon; also used as the per-read socket timeout.
const HANDSHAKE_TIMEOUT: Duration = Duration::from_millis(1_000);
/// Accepts client connections on a Unix-domain socket and runs the logon handshake.
pub struct Server {
    /// Scratch space reused for the inbound logon message and the outbound failure reply.
    buffer: [u8; 1024],
    /// Socket that clients connect to.
    listener: UnixListener,
}
impl Server {
    /// Binds a Unix-domain socket listener at `path`.
    ///
    /// # Errors
    ///
    /// Returns the I/O error reported by [`UnixListener::bind`].
    pub fn new(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
        let listener = UnixListener::bind(path)?;
        Ok(Self {
            listener,
            buffer: [0; 1024],
        })
    }

    /// Accepts one client connection and performs the logon handshake on it.
    ///
    /// On success the client has been sent `LOGON_SUCCESS` together with the shared-memory file
    /// descriptors, and Agave's side of the session is returned.
    ///
    /// On failure, `LOGON_FAILURE`, a one-byte reason length and up to 255 bytes of the error's
    /// `Display` text are sent to the client on a best-effort basis. The handshake error itself
    /// is then returned.
    ///
    /// # Errors
    ///
    /// Returns [`AgaveHandshakeError::Io`] if accepting or configuring the connection fails,
    /// otherwise whatever error the handshake produced.
    pub fn accept(&mut self) -> Result<AgaveSession, AgaveHandshakeError> {
        let (mut stream, _) = self.listener.accept()?;
        stream.set_read_timeout(Some(HANDSHAKE_TIMEOUT))?;

        match self.handle_logon(&mut stream) {
            Ok(session) => Ok(session),
            Err(err) => {
                self.send_failure(&mut stream, &err);
                Err(err)
            }
        }
    }

    /// Best-effort notification telling the client why its logon was rejected.
    ///
    /// Errors are deliberately ignored. The client may already be gone (often the very reason
    /// the handshake failed), and the caller must see the original handshake error rather than
    /// a secondary I/O error from reporting it.
    fn send_failure(&mut self, stream: &mut UnixStream, err: &AgaveHandshakeError) {
        let reason = err.to_string();
        // The wire format carries the length in one byte, so longer reasons are truncated.
        let reason_len = u8::try_from(reason.len()).unwrap_or(u8::MAX);
        let reason_bytes = &reason.as_bytes()[..usize::from(reason_len)];
        let message_len = 2 + reason_bytes.len();

        self.buffer[0] = LOGON_FAILURE;
        self.buffer[1] = reason_len;
        self.buffer[2..message_len].copy_from_slice(reason_bytes);

        // Non-blocking so an unresponsive client cannot stall the server.
        let _ = stream
            .set_nonblocking(true)
            .and_then(|()| stream.write(&self.buffer[..message_len]));
    }

    /// Runs the server side of the handshake on an accepted `stream`.
    ///
    /// Reads and validates the client's logon and creates the shared-memory regions it asked
    /// for. It then replies with a single `LOGON_SUCCESS` byte that carries every region's file
    /// descriptor as `SCM_RIGHTS` ancillary data.
    fn handle_logon(
        &mut self,
        stream: &mut UnixStream,
    ) -> Result<AgaveSession, AgaveHandshakeError> {
        let logon = self.recv_logon(stream)?;
        let (session, files) = Self::setup_session(logon)?;

        // `files` must stay alive across `sendmsg`: the raw fds below only borrow from it.
        let fds_raw: Vec<_> = files.iter().map(|file| file.as_raw_fd()).collect();
        let iov = [IoSlice::new(&[LOGON_SUCCESS])];
        let cmsgs = [ControlMessage::ScmRights(&fds_raw)];
        let sent =
            socket::sendmsg::<UnixAddr>(stream.as_raw_fd(), &iov, &cmsgs, MsgFlags::empty(), None)
                .map_err(std::io::Error::from)?;
        debug_assert_eq!(sent, 1);

        Ok(session)
    }

    /// Reads the client's logon message into `self.buffer` and validates it.
    ///
    /// The message is an 8-byte little-endian protocol version followed by the raw bytes of a
    /// [`ClientLogon`]. Exactly that many bytes are read.
    ///
    /// # Errors
    ///
    /// - [`AgaveHandshakeError::EofDuringHandshake`] if the client closes before the message is
    ///   complete.
    /// - [`AgaveHandshakeError::Timeout`] if a read times out or the message is not complete
    ///   within `HANDSHAKE_TIMEOUT`.
    /// - [`AgaveHandshakeError::Version`], [`AgaveHandshakeError::WorkerCount`] or
    ///   [`AgaveHandshakeError::AllocatorHandles`] if the logon is rejected.
    /// - [`AgaveHandshakeError::Io`] for any other read error.
    fn recv_logon(&mut self, stream: &mut UnixStream) -> Result<ClientLogon, AgaveHandshakeError> {
        const LOGON_END: usize = 8 + core::mem::size_of::<ClientLogon>();

        let handshake_start = Instant::now();
        let mut buffer_len = 0;
        // Read only the logon message itself. Waiting to fill the whole buffer would stall
        // until the read timeout whenever the client sends just the message and awaits a reply.
        while buffer_len < LOGON_END {
            match stream.read(&mut self.buffer[buffer_len..LOGON_END]) {
                Ok(0) => return Err(AgaveHandshakeError::EofDuringHandshake),
                Ok(read) => buffer_len += read,
                Err(err) if err.kind() == std::io::ErrorKind::Interrupted => {}
                // The read timeout set in `accept` surfaces as `WouldBlock` (or `TimedOut`).
                Err(err)
                    if matches!(
                        err.kind(),
                        std::io::ErrorKind::WouldBlock | std::io::ErrorKind::TimedOut
                    ) =>
                {
                    return Err(AgaveHandshakeError::Timeout);
                }
                Err(err) => return Err(err.into()),
            }
            // The socket timeout bounds each read; this bounds the handshake as a whole.
            if handshake_start.elapsed() > HANDSHAKE_TIMEOUT {
                return Err(AgaveHandshakeError::Timeout);
            }
        }

        let version = u64::from_le_bytes(self.buffer[..8].try_into().unwrap());
        if version != VERSION {
            return Err(AgaveHandshakeError::Version {
                server: VERSION,
                client: version,
            });
        }

        let logon = ClientLogon::try_from_bytes(&self.buffer[8..LOGON_END]).unwrap();
        if !(1..=MAX_WORKERS).contains(&logon.worker_count) {
            return Err(AgaveHandshakeError::WorkerCount(logon.worker_count));
        }
        if !(1..=MAX_ALLOCATOR_HANDLES).contains(&logon.allocator_handles) {
            return Err(AgaveHandshakeError::AllocatorHandles(
                logon.allocator_handles,
            ));
        }

        Ok(logon)
    }

    /// Creates every shared-memory region requested by `logon` and Agave's handles to them.
    ///
    /// Returns the session plus the backing files, in the order they are sent to the client:
    /// 1. the allocator,
    /// 2. the tpu->pack queue,
    /// 3. the progress tracker,
    /// 4. then `(pack->worker, worker->pack)` for each worker.
    fn setup_session(logon: ClientLogon) -> Result<(AgaveSession, Vec<File>), AgaveHandshakeError> {
        // `GLOBAL_ALLOCATORS + worker_count + allocator_handles` allocator slots. `recv_logon`
        // has already bounded both client-supplied counts, so the additions cannot overflow.
        let allocator_count = GLOBAL_ALLOCATORS
            .checked_add(logon.worker_count)
            .unwrap()
            .checked_add(logon.allocator_handles)
            .unwrap();
        let allocator_file = Self::create_shmem()?;
        // NOTE(review): the 2 MiB argument and the trailing `0` are presumably the allocation
        // unit and this handle's slot index; confirm against `rts_alloc::Allocator::create`.
        let tpu_to_pack_allocator = Allocator::create(
            &allocator_file,
            logon.allocator_size,
            u32::try_from(allocator_count).unwrap(),
            2 * 1024 * 1024,
            0,
        )?;
        let (tpu_to_pack_file, tpu_to_pack_queue) = Self::create_producer(logon.tpu_to_pack_size)?;
        let (progress_tracker_file, progress_tracker) =
            Self::create_producer(logon.progress_tracker_size)?;
        // Worker `i` joins the allocator at slot `GLOBAL_ALLOCATORS + i`.
        let (worker_files, workers) = (0..logon.worker_count).try_fold(
            (Vec::default(), Vec::default()),
            |(mut fds, mut workers), offset| {
                let worker_index = GLOBAL_ALLOCATORS.checked_add(offset).unwrap();
                let worker_index = u32::try_from(worker_index).unwrap();
                let allocator = unsafe { Allocator::join(&allocator_file, worker_index) }?;
                let (pack_to_worker_file, pack_to_worker) =
                    Self::create_consumer(logon.pack_to_worker_size)?;
                let (worker_to_pack_file, worker_to_pack) =
                    Self::create_producer(logon.worker_to_pack_size)?;
                fds.extend([pack_to_worker_file, worker_to_pack_file]);
                workers.push(AgaveWorkerSession {
                    allocator,
                    pack_to_worker,
                    worker_to_pack,
                });
                Ok::<_, AgaveHandshakeError>((fds, workers))
            },
        )?;

        Ok((
            AgaveSession {
                tpu_to_pack: AgaveTpuToPackSession {
                    allocator: tpu_to_pack_allocator,
                    producer: tpu_to_pack_queue,
                },
                progress_tracker,
                workers,
            },
            [allocator_file, tpu_to_pack_file, progress_tracker_file]
                .into_iter()
                .chain(worker_files)
                .collect(),
        ))
    }

    /// Creates a fresh shared-memory file holding a `shaq` producer queue of `size`.
    fn create_producer<T>(size: usize) -> Result<(File, shaq::Producer<T>), ShaqError> {
        let file = Self::create_shmem()?;
        let queue = shaq::Producer::create(&file, size)?;
        Ok((file, queue))
    }

    /// Creates a fresh shared-memory file holding a pack->worker `shaq` consumer queue of `size`.
    fn create_consumer(
        size: usize,
    ) -> Result<(File, shaq::Consumer<PackToWorkerMessage>), ShaqError> {
        let file = Self::create_shmem()?;
        let queue = shaq::Consumer::create(&file, size)?;
        Ok((file, queue))
    }

    /// Creates an anonymous shared-memory file via `memfd_create`.
    ///
    /// memfd names are only labels, so reusing `SHMEM_NAME` for every region is fine.
    #[cfg(any(
        target_os = "linux",
        target_os = "l4re",
        target_os = "android",
        target_os = "emscripten"
    ))]
    fn create_shmem() -> Result<File, std::io::Error> {
        // SAFETY: `SHMEM_NAME` is a valid NUL-terminated C string.
        let fd = unsafe { libc::memfd_create(SHMEM_NAME.as_ptr(), 0) };
        if fd == -1 {
            return Err(std::io::Error::last_os_error());
        }
        // SAFETY: `fd` was just returned by `memfd_create` and nothing else owns it.
        Ok(unsafe { File::from_raw_fd(fd) })
    }

    /// Creates a shared-memory file via `shm_open` and immediately unlinks its name.
    #[cfg(not(any(
        target_os = "linux",
        target_os = "l4re",
        target_os = "android",
        target_os = "emscripten"
    )))]
    fn create_shmem() -> Result<File, std::io::Error> {
        // Clear out any object left behind under the fixed name; absence is fine.
        // SAFETY: `SHMEM_NAME` is a valid NUL-terminated C string.
        if unsafe { libc::shm_unlink(SHMEM_NAME.as_ptr()) } == -1 {
            let err = std::io::Error::last_os_error();
            if err.kind() != std::io::ErrorKind::NotFound {
                return Err(err);
            }
        }

        // Apple's libc declares `shm_open` as variadic, so the mode must be passed as a promoted
        // `c_uint`; elsewhere it takes `mode_t`. These two cfgs must be exact complements,
        // otherwise a target ends up with zero or two `mode` bindings.
        #[cfg(not(target_vendor = "apple"))]
        let mode = libc::S_IRUSR | libc::S_IWUSR;
        #[cfg(target_vendor = "apple")]
        let mode = (libc::S_IRUSR | libc::S_IWUSR) as libc::c_uint;

        // SAFETY: `SHMEM_NAME` is a valid NUL-terminated C string.
        let fd = unsafe {
            libc::shm_open(
                SHMEM_NAME.as_ptr(),
                libc::O_CREAT | libc::O_EXCL | libc::O_RDWR,
                mode,
            )
        };
        if fd == -1 {
            return Err(std::io::Error::last_os_error());
        }
        // SAFETY: `fd` was just returned by `shm_open` and nothing else owns it.
        let file = unsafe { File::from_raw_fd(fd) };

        // Unlink right away so the fixed name is free for the next call. The object itself
        // persists while descriptors or mappings still reference it.
        // SAFETY: `SHMEM_NAME` is a valid NUL-terminated C string.
        if unsafe { libc::shm_unlink(SHMEM_NAME.as_ptr()) } == -1 {
            return Err(std::io::Error::last_os_error());
        }
        Ok(file)
    }
}
/// Agave's side of an established scheduler session, as returned by `Server::accept`.
///
/// NOTE(review): fields are dropped in declaration order; keep this order unless the
/// teardown order of the shared-memory handles is known not to matter.
pub struct AgaveSession {
    /// Allocator and producer for `TpuToPackMessage`s.
    pub tpu_to_pack: AgaveTpuToPackSession,
    /// Producer for `ProgressMessage`s.
    pub progress_tracker: shaq::Producer<ProgressMessage>,
    /// One entry per worker requested in the client's logon.
    pub workers: Vec<AgaveWorkerSession>,
}
/// Agave's end of the TPU -> pack channel.
pub struct AgaveTpuToPackSession {
    /// Handle to the session's shared allocator (the one `Server` creates rather than joins).
    pub allocator: Allocator,
    /// Producer for `TpuToPackMessage`s.
    pub producer: shaq::Producer<TpuToPackMessage>,
}
/// Agave's end of one worker's channels.
pub struct AgaveWorkerSession {
    /// This worker's handle joined to the session's shared allocator.
    pub allocator: Allocator,
    /// Consumer for `PackToWorkerMessage`s addressed to this worker.
    pub pack_to_worker: shaq::Consumer<PackToWorkerMessage>,
    /// Producer for `WorkerToPackMessage`s emitted by this worker.
    pub worker_to_pack: shaq::Producer<WorkerToPackMessage>,
}
/// Reasons a logon handshake can fail.
///
/// The `Display` text is also what gets reported back to the client (truncated to 255 bytes).
#[derive(Debug, Error)]
pub enum AgaveHandshakeError {
    /// An OS-level socket or shared-memory operation failed.
    #[error("Io; err={0}")]
    Io(#[from] std::io::Error),
    /// The shared allocator could not be created or joined.
    #[error("Rts alloc; err={0:?}")]
    RtsAlloc(#[from] RtsAllocError),
    /// A shared-memory queue could not be created.
    #[error("Shaq; err={0:?}")]
    Shaq(#[from] ShaqError),

    /// The client did not deliver its logon in time.
    #[error("Timeout")]
    Timeout,
    /// The client closed the connection before the logon was complete.
    #[error("Close during handshake")]
    EofDuringHandshake,

    /// Client and server speak different protocol versions.
    #[error("Version; server={server}; client={client}")]
    Version { server: u64, client: u64 },
    /// The requested worker count is outside `1..=MAX_WORKERS`.
    #[error("Worker count; count={0}")]
    WorkerCount(usize),
    /// The requested allocator-handle count is outside `1..=MAX_ALLOCATOR_HANDLES`.
    #[error("Allocator handles; count={0}")]
    AllocatorHandles(usize),
}