agave_scheduling_utils/handshake/
server.rs

1use {
2    crate::handshake::{
3        shared::{
4            GLOBAL_ALLOCATORS, LOGON_FAILURE, LOGON_SUCCESS, MAX_ALLOCATOR_HANDLES, MAX_WORKERS,
5            VERSION,
6        },
7        ClientLogon,
8    },
9    agave_scheduler_bindings::{
10        PackToWorkerMessage, ProgressMessage, TpuToPackMessage, WorkerToPackMessage,
11    },
12    nix::sys::socket::{self, ControlMessage, MsgFlags, UnixAddr},
13    rts_alloc::Allocator,
14    std::{
15        ffi::CStr,
16        fs::File,
17        io::{IoSlice, Read, Write},
18        os::{
19            fd::{AsRawFd, FromRawFd},
20            unix::net::{UnixListener, UnixStream},
21        },
22        path::Path,
23        time::{Duration, Instant},
24    },
25    thiserror::Error,
26};
27
// Local shorthands for the error types of the shared-memory queue and allocator crates.
type ShaqError = shaq::error::Error;
type RtsAllocError = rts_alloc::error::Error;

// Upper bound on the whole logon exchange; also used as the per-read socket timeout.
const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(1);
// Name passed to memfd_create / shm_open when creating shared memory objects.
const SHMEM_NAME: &CStr = c"/agave-scheduler-bindings";
33
/// Implements the Agave side of the scheduler bindings handshake protocol.
pub struct Server {
    /// Unix domain socket listener that clients connect to for the handshake.
    listener: UnixListener,

    /// Scratch buffer, used both to receive the logon frame (`recv_logon`
    /// reads until all 1024 bytes are filled) and to build failure replies.
    buffer: [u8; 1024],
}
40
41impl Server {
42    pub fn new(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
43        let listener = UnixListener::bind(path)?;
44
45        Ok(Self {
46            listener,
47            buffer: [0; 1024],
48        })
49    }
50
51    pub fn accept(&mut self) -> Result<AgaveSession, AgaveHandshakeError> {
52        // Wait for next stream.
53        let (mut stream, _) = self.listener.accept()?;
54        stream.set_read_timeout(Some(HANDSHAKE_TIMEOUT))?;
55
56        match self.handle_logon(&mut stream) {
57            Ok(session) => Ok(session),
58            Err(err) => {
59                let reason = err.to_string();
60                let reason_len = u8::try_from(reason.len()).unwrap_or(u8::MAX);
61
62                let buffer_len = 2usize.checked_add(usize::from(reason_len)).unwrap();
63                self.buffer[0] = LOGON_FAILURE;
64                self.buffer[1] = reason_len;
65                self.buffer[2..buffer_len]
66                    .copy_from_slice(&reason.as_bytes()[..usize::from(reason_len)]);
67
68                stream.set_nonblocking(true)?;
69                // NB: Caller will still error out even if our write fails so it's fine to ignore the
70                // result.
71                let _ = stream.write(&self.buffer[..buffer_len])?;
72
73                Err(err)
74            }
75        }
76    }
77
78    fn handle_logon(
79        &mut self,
80        stream: &mut UnixStream,
81    ) -> Result<AgaveSession, AgaveHandshakeError> {
82        // Receive & validate the logon message.
83        let logon = self.recv_logon(stream)?;
84
85        // Setup the requested shared memory regions.
86        let (session, files) = Self::setup_session(logon)?;
87
88        // Send the file descriptors to the client.
89        let fds_raw: Vec<_> = files.iter().map(|file| file.as_raw_fd()).collect();
90        let iov = [IoSlice::new(&[LOGON_SUCCESS])];
91        let cmsgs = [ControlMessage::ScmRights(&fds_raw)];
92        let sent =
93            socket::sendmsg::<UnixAddr>(stream.as_raw_fd(), &iov, &cmsgs, MsgFlags::empty(), None)
94                .map_err(std::io::Error::from)?;
95        debug_assert_eq!(sent, 1);
96
97        Ok(session)
98    }
99
100    fn recv_logon(&mut self, stream: &mut UnixStream) -> Result<ClientLogon, AgaveHandshakeError> {
101        // Read the logon message.
102        let handshake_start = Instant::now();
103        let mut buffer_len = 0;
104        while buffer_len < self.buffer.len() {
105            let read = stream.read(&mut self.buffer[buffer_len..])?;
106            if read == 0 {
107                return Err(AgaveHandshakeError::EofDuringHandshake);
108            }
109
110            // SAFETY: We cannot read a value greater than buffer.len() which itself is a usize.
111            buffer_len = buffer_len.checked_add(read).unwrap();
112
113            if handshake_start.elapsed() > HANDSHAKE_TIMEOUT {
114                return Err(AgaveHandshakeError::Timeout);
115            }
116        }
117
118        // Ensure exact version match, version will be bumped any time a backwards incompatible
119        // change is made to handshake/shared memory objects.
120        let version = u64::from_le_bytes(self.buffer[..8].try_into().unwrap());
121        if version != VERSION {
122            return Err(AgaveHandshakeError::Version {
123                server: VERSION,
124                client: version,
125            });
126        }
127
128        // Read the logon message, cannot panic as we ensure the correct buf size at compile time
129        // (hence the const just below).
130        const LOGON_END: usize = 8 + core::mem::size_of::<ClientLogon>();
131        let logon = ClientLogon::try_from_bytes(&self.buffer[8..LOGON_END]).unwrap();
132
133        // Put a hard limit of 64 worker threads for now.
134        if !(1..=MAX_WORKERS).contains(&logon.worker_count) {
135            return Err(AgaveHandshakeError::WorkerCount(logon.worker_count));
136        }
137
138        // Hard limit allocator handles to 128.
139        if !(1..=MAX_ALLOCATOR_HANDLES).contains(&logon.allocator_handles) {
140            return Err(AgaveHandshakeError::AllocatorHandles(
141                logon.allocator_handles,
142            ));
143        }
144
145        Ok(logon)
146    }
147
148    fn setup_session(logon: ClientLogon) -> Result<(AgaveSession, Vec<File>), AgaveHandshakeError> {
149        // Setup the allocator in shared memory (`worker_count` & `allocator_handles` have been
150        // validated so this won't panic).
151        let allocator_count = GLOBAL_ALLOCATORS
152            .checked_add(logon.worker_count)
153            .unwrap()
154            .checked_add(logon.allocator_handles)
155            .unwrap();
156        let allocator_file = Self::create_shmem()?;
157        let tpu_to_pack_allocator = Allocator::create(
158            &allocator_file,
159            logon.allocator_size,
160            u32::try_from(allocator_count).unwrap(),
161            2 * 1024 * 1024,
162            0,
163        )?;
164
165        // Setup the global queues.
166        let (tpu_to_pack_file, tpu_to_pack_queue) = Self::create_producer(logon.tpu_to_pack_size)?;
167        let (progress_tracker_file, progress_tracker) =
168            Self::create_producer(logon.progress_tracker_size)?;
169
170        // Setup the worker sessions.
171        let (worker_files, workers) = (0..logon.worker_count).try_fold(
172            (Vec::default(), Vec::default()),
173            |(mut fds, mut workers), offset| {
174                // NB: Server validates all requested counts are within expected bands so this
175                // should never panic.
176                let worker_index = GLOBAL_ALLOCATORS.checked_add(offset).unwrap();
177                let worker_index = u32::try_from(worker_index).unwrap();
178                // SAFETY: Worker index is guaranteed to be unique.
179                let allocator = unsafe { Allocator::join(&allocator_file, worker_index) }?;
180
181                let (pack_to_worker_file, pack_to_worker) =
182                    Self::create_consumer(logon.pack_to_worker_size)?;
183                let (worker_to_pack_file, worker_to_pack) =
184                    Self::create_producer(logon.worker_to_pack_size)?;
185
186                fds.extend([pack_to_worker_file, worker_to_pack_file]);
187                workers.push(AgaveWorkerSession {
188                    allocator,
189                    pack_to_worker,
190                    worker_to_pack,
191                });
192
193                Ok::<_, AgaveHandshakeError>((fds, workers))
194            },
195        )?;
196
197        Ok((
198            AgaveSession {
199                tpu_to_pack: AgaveTpuToPackSession {
200                    allocator: tpu_to_pack_allocator,
201                    producer: tpu_to_pack_queue,
202                },
203                progress_tracker,
204                workers,
205            },
206            [allocator_file, tpu_to_pack_file, progress_tracker_file]
207                .into_iter()
208                .chain(worker_files)
209                .collect(),
210        ))
211    }
212
213    fn create_producer<T>(size: usize) -> Result<(File, shaq::Producer<T>), ShaqError> {
214        let file = Self::create_shmem()?;
215        let queue = shaq::Producer::create(&file, size)?;
216
217        Ok((file, queue))
218    }
219
220    fn create_consumer(
221        size: usize,
222    ) -> Result<(File, shaq::Consumer<PackToWorkerMessage>), ShaqError> {
223        let file = Self::create_shmem()?;
224        let queue = shaq::Consumer::create(&file, size)?;
225
226        Ok((file, queue))
227    }
228
229    #[cfg(any(
230        target_os = "linux",
231        target_os = "l4re",
232        target_os = "android",
233        target_os = "emscripten"
234    ))]
235    fn create_shmem() -> Result<File, std::io::Error> {
236        unsafe {
237            let ret = libc::memfd_create(SHMEM_NAME.as_ptr(), 0);
238            if ret == -1 {
239                return Err(std::io::Error::last_os_error());
240            }
241
242            Ok(File::from_raw_fd(ret))
243        }
244    }
245
246    #[cfg(not(any(
247        target_os = "linux",
248        target_os = "l4re",
249        target_os = "android",
250        target_os = "emscripten"
251    )))]
252    fn create_shmem() -> Result<File, std::io::Error> {
253        unsafe {
254            // Clean up the previous link if one exists.
255            let ret = libc::shm_unlink(SHMEM_NAME.as_ptr());
256            if ret == -1 {
257                let err = std::io::Error::last_os_error();
258                if err.kind() != std::io::ErrorKind::NotFound {
259                    return Err(err);
260                }
261            }
262
263            // Create a new shared memory object.
264            let ret = libc::shm_open(
265                SHMEM_NAME.as_ptr(),
266                libc::O_CREAT | libc::O_EXCL | libc::O_RDWR,
267                #[cfg(not(target_os = "macos"))]
268                {
269                    libc::S_IRUSR | libc::S_IWUSR
270                },
271                #[cfg(any(target_os = "macos", target_os = "ios"))]
272                {
273                    (libc::S_IRUSR | libc::S_IWUSR) as libc::c_uint
274                },
275            );
276            if ret == -1 {
277                return Err(std::io::Error::last_os_error());
278            }
279            let file = File::from_raw_fd(ret);
280
281            // Clean up after ourself.
282            let ret = libc::shm_unlink(SHMEM_NAME.as_ptr());
283            if ret == -1 {
284                return Err(std::io::Error::last_os_error());
285            }
286
287            Ok(file)
288        }
289    }
290}
291
/// An initialized scheduling session.
pub struct AgaveSession {
    /// Allocator + producer queue for the tpu-to-pack path.
    pub tpu_to_pack: AgaveTpuToPackSession,
    /// Producer queue for progress messages.
    pub progress_tracker: shaq::Producer<ProgressMessage>,
    /// One session per banking worker requested in the client logon.
    pub workers: Vec<AgaveWorkerSession>,
}
298
/// Shared memory objects for the tpu to pack worker.
pub struct AgaveTpuToPackSession {
    /// Shared-memory allocator handle for this path.
    pub allocator: Allocator,
    /// Producer end of the tpu-to-pack message queue.
    pub producer: shaq::Producer<TpuToPackMessage>,
}
304
/// Shared memory objects for a single banking worker.
pub struct AgaveWorkerSession {
    /// Allocator handle joined at this worker's unique index.
    pub allocator: Allocator,
    /// Consumer end of the pack-to-worker message queue.
    pub pack_to_worker: shaq::Consumer<PackToWorkerMessage>,
    /// Producer end of the worker-to-pack message queue.
    pub worker_to_pack: shaq::Producer<WorkerToPackMessage>,
}
311
/// Potential errors that can occur during the Agave side of the handshake.
///
/// # Note
///
/// These errors are stringified (truncated to 255 bytes, the max a u8 length
/// prefix can describe) and sent to the client.
#[derive(Debug, Error)]
pub enum AgaveHandshakeError {
    /// Underlying socket/file I/O failure.
    #[error("Io; err={0}")]
    Io(#[from] std::io::Error),
    /// The logon exchange exceeded the handshake timeout.
    #[error("Timeout")]
    Timeout,
    /// The client closed the stream before sending a full logon frame.
    #[error("Close during handshake")]
    EofDuringHandshake,
    /// Client and server handshake versions do not match exactly.
    #[error("Version; server={server}; client={client}")]
    Version { server: u64, client: u64 },
    /// Requested worker count is outside the accepted range.
    #[error("Worker count; count={0}")]
    WorkerCount(usize),
    /// Requested allocator handle count is outside the accepted range.
    #[error("Allocator handles; count={0}")]
    AllocatorHandles(usize),
    /// Failure creating/joining the shared-memory allocator.
    #[error("Rts alloc; err={0:?}")]
    RtsAlloc(#[from] RtsAllocError),
    /// Failure creating a shared-memory queue.
    #[error("Shaq; err={0:?}")]
    Shaq(#[from] ShaqError),
}
335}