agave_scheduling_utils/handshake/
server.rs1use {
2 crate::handshake::{
3 shared::{
4 GLOBAL_ALLOCATORS, LOGON_FAILURE, LOGON_SUCCESS, MAX_ALLOCATOR_HANDLES, MAX_WORKERS,
5 VERSION,
6 },
7 ClientLogon,
8 },
9 agave_scheduler_bindings::{
10 PackToWorkerMessage, ProgressMessage, TpuToPackMessage, WorkerToPackMessage,
11 },
12 nix::sys::socket::{self, ControlMessage, MsgFlags, UnixAddr},
13 rts_alloc::Allocator,
14 std::{
15 ffi::CStr,
16 fs::File,
17 io::{IoSlice, Read, Write},
18 os::{
19 fd::{AsRawFd, FromRawFd},
20 unix::net::{UnixListener, UnixStream},
21 },
22 path::Path,
23 time::{Duration, Instant},
24 },
25 thiserror::Error,
26};
27
28type ShaqError = shaq::error::Error;
29type RtsAllocError = rts_alloc::error::Error;
30
/// Upper bound on how long a client may take to deliver its logon message;
/// also used as the per-read timeout on the accepted connection.
const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(1);
/// Name given to every shared-memory object created for a session.
const SHMEM_NAME: &CStr = c"/agave-scheduler-bindings";
33
/// Accepts client connections on a Unix domain socket and performs the logon
/// handshake with each of them.
pub struct Server {
    /// Socket on which incoming client connections are accepted.
    listener: UnixListener,

    /// Scratch space reused across handshakes: it receives the client's logon
    /// message and, when a handshake fails, holds the rejection reply.
    buffer: [u8; 1024],
}
40
41impl Server {
42 pub fn new(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
43 let listener = UnixListener::bind(path)?;
44
45 Ok(Self {
46 listener,
47 buffer: [0; 1024],
48 })
49 }
50
51 pub fn accept(&mut self) -> Result<AgaveSession, AgaveHandshakeError> {
52 let (mut stream, _) = self.listener.accept()?;
54 stream.set_read_timeout(Some(HANDSHAKE_TIMEOUT))?;
55
56 match self.handle_logon(&mut stream) {
57 Ok(session) => Ok(session),
58 Err(err) => {
59 let reason = err.to_string();
60 let reason_len = u8::try_from(reason.len()).unwrap_or(u8::MAX);
61
62 let buffer_len = 2usize.checked_add(usize::from(reason_len)).unwrap();
63 self.buffer[0] = LOGON_FAILURE;
64 self.buffer[1] = reason_len;
65 self.buffer[2..buffer_len]
66 .copy_from_slice(&reason.as_bytes()[..usize::from(reason_len)]);
67
68 stream.set_nonblocking(true)?;
69 let _ = stream.write(&self.buffer[..buffer_len])?;
72
73 Err(err)
74 }
75 }
76 }
77
78 fn handle_logon(
79 &mut self,
80 stream: &mut UnixStream,
81 ) -> Result<AgaveSession, AgaveHandshakeError> {
82 let logon = self.recv_logon(stream)?;
84
85 let (session, files) = Self::setup_session(logon)?;
87
88 let fds_raw: Vec<_> = files.iter().map(|file| file.as_raw_fd()).collect();
90 let iov = [IoSlice::new(&[LOGON_SUCCESS])];
91 let cmsgs = [ControlMessage::ScmRights(&fds_raw)];
92 let sent =
93 socket::sendmsg::<UnixAddr>(stream.as_raw_fd(), &iov, &cmsgs, MsgFlags::empty(), None)
94 .map_err(std::io::Error::from)?;
95 debug_assert_eq!(sent, 1);
96
97 Ok(session)
98 }
99
100 fn recv_logon(&mut self, stream: &mut UnixStream) -> Result<ClientLogon, AgaveHandshakeError> {
101 let handshake_start = Instant::now();
103 let mut buffer_len = 0;
104 while buffer_len < self.buffer.len() {
105 let read = stream.read(&mut self.buffer[buffer_len..])?;
106 if read == 0 {
107 return Err(AgaveHandshakeError::EofDuringHandshake);
108 }
109
110 buffer_len = buffer_len.checked_add(read).unwrap();
112
113 if handshake_start.elapsed() > HANDSHAKE_TIMEOUT {
114 return Err(AgaveHandshakeError::Timeout);
115 }
116 }
117
118 let version = u64::from_le_bytes(self.buffer[..8].try_into().unwrap());
121 if version != VERSION {
122 return Err(AgaveHandshakeError::Version {
123 server: VERSION,
124 client: version,
125 });
126 }
127
128 const LOGON_END: usize = 8 + core::mem::size_of::<ClientLogon>();
131 let logon = ClientLogon::try_from_bytes(&self.buffer[8..LOGON_END]).unwrap();
132
133 if !(1..=MAX_WORKERS).contains(&logon.worker_count) {
135 return Err(AgaveHandshakeError::WorkerCount(logon.worker_count));
136 }
137
138 if !(1..=MAX_ALLOCATOR_HANDLES).contains(&logon.allocator_handles) {
140 return Err(AgaveHandshakeError::AllocatorHandles(
141 logon.allocator_handles,
142 ));
143 }
144
145 Ok(logon)
146 }
147
148 fn setup_session(logon: ClientLogon) -> Result<(AgaveSession, Vec<File>), AgaveHandshakeError> {
149 let allocator_count = GLOBAL_ALLOCATORS
152 .checked_add(logon.worker_count)
153 .unwrap()
154 .checked_add(logon.allocator_handles)
155 .unwrap();
156 let allocator_file = Self::create_shmem()?;
157 let tpu_to_pack_allocator = Allocator::create(
158 &allocator_file,
159 logon.allocator_size,
160 u32::try_from(allocator_count).unwrap(),
161 2 * 1024 * 1024,
162 0,
163 )?;
164
165 let (tpu_to_pack_file, tpu_to_pack_queue) = Self::create_producer(logon.tpu_to_pack_size)?;
167 let (progress_tracker_file, progress_tracker) =
168 Self::create_producer(logon.progress_tracker_size)?;
169
170 let (worker_files, workers) = (0..logon.worker_count).try_fold(
172 (Vec::default(), Vec::default()),
173 |(mut fds, mut workers), offset| {
174 let worker_index = GLOBAL_ALLOCATORS.checked_add(offset).unwrap();
177 let worker_index = u32::try_from(worker_index).unwrap();
178 let allocator = unsafe { Allocator::join(&allocator_file, worker_index) }?;
180
181 let (pack_to_worker_file, pack_to_worker) =
182 Self::create_consumer(logon.pack_to_worker_size)?;
183 let (worker_to_pack_file, worker_to_pack) =
184 Self::create_producer(logon.worker_to_pack_size)?;
185
186 fds.extend([pack_to_worker_file, worker_to_pack_file]);
187 workers.push(AgaveWorkerSession {
188 allocator,
189 pack_to_worker,
190 worker_to_pack,
191 });
192
193 Ok::<_, AgaveHandshakeError>((fds, workers))
194 },
195 )?;
196
197 Ok((
198 AgaveSession {
199 tpu_to_pack: AgaveTpuToPackSession {
200 allocator: tpu_to_pack_allocator,
201 producer: tpu_to_pack_queue,
202 },
203 progress_tracker,
204 workers,
205 },
206 [allocator_file, tpu_to_pack_file, progress_tracker_file]
207 .into_iter()
208 .chain(worker_files)
209 .collect(),
210 ))
211 }
212
213 fn create_producer<T>(size: usize) -> Result<(File, shaq::Producer<T>), ShaqError> {
214 let file = Self::create_shmem()?;
215 let queue = shaq::Producer::create(&file, size)?;
216
217 Ok((file, queue))
218 }
219
220 fn create_consumer(
221 size: usize,
222 ) -> Result<(File, shaq::Consumer<PackToWorkerMessage>), ShaqError> {
223 let file = Self::create_shmem()?;
224 let queue = shaq::Consumer::create(&file, size)?;
225
226 Ok((file, queue))
227 }
228
229 #[cfg(any(
230 target_os = "linux",
231 target_os = "l4re",
232 target_os = "android",
233 target_os = "emscripten"
234 ))]
235 fn create_shmem() -> Result<File, std::io::Error> {
236 unsafe {
237 let ret = libc::memfd_create(SHMEM_NAME.as_ptr(), 0);
238 if ret == -1 {
239 return Err(std::io::Error::last_os_error());
240 }
241
242 Ok(File::from_raw_fd(ret))
243 }
244 }
245
/// Creates a shared-memory file with `shm_open` on targets without
/// `memfd_create`.
///
/// Any stale object left under `SHMEM_NAME` is unlinked first, the object is
/// then created exclusively, and the name is unlinked again right away; the
/// returned [`File`] keeps the descriptor open.
///
/// # Errors
///
/// Returns the OS error if unlinking (other than "not found" on the first
/// attempt) or opening the object fails.
#[cfg(not(any(
    target_os = "linux",
    target_os = "l4re",
    target_os = "android",
    target_os = "emscripten"
)))]
fn create_shmem() -> Result<File, std::io::Error> {
    // The two `mode` definitions must be mutually exclusive: exactly one mode
    // argument may reach `shm_open`. On macOS/iOS the value is widened to
    // `c_uint` because the cast is required there.
    #[cfg(not(any(target_os = "macos", target_os = "ios")))]
    let mode = libc::S_IRUSR | libc::S_IWUSR;
    #[cfg(any(target_os = "macos", target_os = "ios"))]
    let mode = (libc::S_IRUSR | libc::S_IWUSR) as libc::c_uint;

    // Remove a leftover object from an earlier run so the exclusive create
    // below can succeed; a missing object is not an error.
    //
    // SAFETY: `SHMEM_NAME` is a `&'static CStr`, so the pointer refers to a
    // valid NUL-terminated string for the duration of the call.
    let ret = unsafe { libc::shm_unlink(SHMEM_NAME.as_ptr()) };
    if ret == -1 {
        let err = std::io::Error::last_os_error();
        if err.kind() != std::io::ErrorKind::NotFound {
            return Err(err);
        }
    }

    // SAFETY: same pointer validity as above; the remaining arguments are
    // plain integers.
    let fd = unsafe {
        libc::shm_open(
            SHMEM_NAME.as_ptr(),
            libc::O_CREAT | libc::O_EXCL | libc::O_RDWR,
            mode,
        )
    };
    if fd == -1 {
        return Err(std::io::Error::last_os_error());
    }
    // SAFETY: `fd` was just returned by a successful `shm_open` and has not
    // been handed to any other owner.
    let file = unsafe { File::from_raw_fd(fd) };

    // Drop the name immediately. If this fails, `file` is closed when it goes
    // out of scope on the early return.
    //
    // SAFETY: same pointer validity as above.
    let ret = unsafe { libc::shm_unlink(SHMEM_NAME.as_ptr()) };
    if ret == -1 {
        return Err(std::io::Error::last_os_error());
    }

    Ok(file)
}
290}
291
292pub struct AgaveSession {
294 pub tpu_to_pack: AgaveTpuToPackSession,
295 pub progress_tracker: shaq::Producer<ProgressMessage>,
296 pub workers: Vec<AgaveWorkerSession>,
297}
298
299pub struct AgaveTpuToPackSession {
301 pub allocator: Allocator,
302 pub producer: shaq::Producer<TpuToPackMessage>,
303}
304
305pub struct AgaveWorkerSession {
307 pub allocator: Allocator,
308 pub pack_to_worker: shaq::Consumer<PackToWorkerMessage>,
309 pub worker_to_pack: shaq::Producer<WorkerToPackMessage>,
310}
311
312#[derive(Debug, Error)]
318pub enum AgaveHandshakeError {
319 #[error("Io; err={0}")]
320 Io(#[from] std::io::Error),
321 #[error("Timeout")]
322 Timeout,
323 #[error("Close during handshake")]
324 EofDuringHandshake,
325 #[error("Version; server={server}; client={client}")]
326 Version { server: u64, client: u64 },
327 #[error("Worker count; count={0}")]
328 WorkerCount(usize),
329 #[error("Allocator handles; count={0}")]
330 AllocatorHandles(usize),
331 #[error("Rts alloc; err={0:?}")]
332 RtsAlloc(#[from] RtsAllocError),
333 #[error("Shaq; err={0:?}")]
334 Shaq(#[from] ShaqError),
335}