agave_scheduling_utils/handshake/
client.rs

1use {
2    crate::handshake::{
3        shared::{GLOBAL_ALLOCATORS, LOGON_FAILURE, MAX_WORKERS, VERSION},
4        ClientLogon,
5    },
6    agave_scheduler_bindings::{
7        PackToWorkerMessage, ProgressMessage, TpuToPackMessage, WorkerToPackMessage,
8    },
9    libc::CMSG_LEN,
10    nix::sys::socket::{self, ControlMessageOwned, MsgFlags, UnixAddr},
11    rts_alloc::Allocator,
12    std::{
13        fs::File,
14        io::{IoSliceMut, Write},
15        os::{
16            fd::{AsRawFd, FromRawFd},
17            unix::net::UnixStream,
18        },
19        path::Path,
20        time::Duration,
21    },
22    thiserror::Error,
23};
24
// Short aliases for the error types of the shared-memory helper crates.
type RtsError = rts_alloc::error::Error;
type ShaqError = shaq::error::Error;
27
/// Number of global shared memory objects (in addition to per worker objects).
///
/// The three global objects are the allocator file, the TPU-to-pack queue, and the progress
/// tracker queue (see `setup_session` for how they are consumed).
const GLOBAL_SHMEM: usize = 3;

/// The maximum size in bytes of the control message containing the queues assuming [`MAX_WORKERS`]
/// is respected.
///
/// Each FD is 4 bytes so we simply multiply the number of shmem objects by 4 to get the control
/// message buffer size.
// NB: each worker contributes two queues (pack->worker and worker->pack), hence `MAX_WORKERS * 2`.
const CMSG_MAX_SIZE: usize = (GLOBAL_SHMEM + MAX_WORKERS * 2) * 4;
37
38/// Connects to the scheduler server on the given IPC path.
39///
40/// # Timeout
41///
42/// Timeout is enforced at the syscall level. In the typical case, this function will do two
43/// syscalls, one to send the logon message and one to receive the response. However, if for
44/// whatever reason the OS does not accept 1024 bytes in a single syscall, then multiple writes
45/// could be needed. As such this timeout is meant to guard against a broken server but not
46/// necessarily ensure this function always returns before the timeout (this is somewhat in line
47/// with typical timeouts because you have no guarantee of being rescheduled).
48pub fn connect(
49    path: impl AsRef<Path>,
50    logon: ClientLogon,
51    timeout: Duration,
52) -> Result<ClientSession, ClientHandshakeError> {
53    connect_path(path.as_ref(), logon, timeout)
54}
55
56fn connect_path(
57    path: &Path,
58    logon: ClientLogon,
59    timeout: Duration,
60) -> Result<ClientSession, ClientHandshakeError> {
61    // NB: Technically this connect call can block indefinitely if the receiver's connection queue
62    // is full. In practice this should almost never happen. If it does work arounds are:
63    //
64    // - Users can spawn off a thread to handle the connect call and then just poll that thread
65    //   exiting.
66    // - This library could drop to raw unix sockets and use select/poll to enforce a timeout on the
67    //   IO operation.
68    let mut stream = UnixStream::connect(path)?;
69    stream.set_read_timeout(Some(timeout))?;
70    stream.set_write_timeout(Some(timeout))?;
71
72    // Send the logon message to the server.
73    send_logon(&mut stream, logon)?;
74
75    // Receive the server's response & on success the FDs for the newly allocated shared memory.
76    let fds = recv_response(&mut stream)?;
77
78    // Join the shared memory regions.
79    let session = setup_session(&logon, fds)?;
80
81    Ok(session)
82}
83
84fn send_logon(stream: &mut UnixStream, logon: ClientLogon) -> Result<(), ClientHandshakeError> {
85    // Send the logon message.
86    let mut buf = [0; 1024];
87    buf[..8].copy_from_slice(&VERSION.to_le_bytes());
88    const LOGON_END: usize = 8 + core::mem::size_of::<ClientLogon>();
89    let ptr = buf[8..LOGON_END].as_mut_ptr().cast::<ClientLogon>();
90    // SAFETY:
91    // - `buf` is valid for writes.
92    // - `buf.len()` has enough space for logon's size in memory.
93    unsafe {
94        core::ptr::write_unaligned(ptr, logon);
95    }
96    stream.write_all(&buf)?;
97
98    Ok(())
99}
100
101fn recv_response(stream: &mut UnixStream) -> Result<Vec<i32>, ClientHandshakeError> {
102    // Receive the requested FDs.
103    let mut buf = [0; 1024];
104    let mut iov = [IoSliceMut::new(&mut buf)];
105    // SAFETY: CMSG_LEN is always safe (const expression).
106    let mut cmsgs = [0u8; unsafe { CMSG_LEN(CMSG_MAX_SIZE as u32) as usize }];
107    let msg = socket::recvmsg::<UnixAddr>(
108        stream.as_raw_fd(),
109        &mut iov,
110        Some(&mut cmsgs),
111        MsgFlags::empty(),
112    )?;
113
114    // Check for failure.
115    let buf = msg.iovs().next().unwrap();
116    if buf[0] == LOGON_FAILURE {
117        let reason_len = usize::from(buf[1]);
118        #[allow(clippy::arithmetic_side_effects)]
119        let reason = std::str::from_utf8(&buf[2..2 + reason_len]).unwrap();
120
121        return Err(ClientHandshakeError::Rejected(reason.to_string()));
122    }
123
124    // Extract FDs.
125    let mut cmsgs = msg.cmsgs().unwrap();
126    let fds = match cmsgs.next() {
127        Some(ControlMessageOwned::ScmRights(fds)) => fds,
128        Some(msg) => panic!("Unexpected; msg={msg:?}"),
129        None => panic!(),
130    };
131
132    Ok(fds)
133}
134
135fn setup_session(
136    logon: &ClientLogon,
137    fds: Vec<i32>,
138) -> Result<ClientSession, ClientHandshakeError> {
139    let [allocator_fd, tpu_to_pack_fd, progress_tracker_fd] = fds[..GLOBAL_SHMEM] else {
140        panic!();
141    };
142    // SAFETY: `allocator_fd` represents a valid file descriptor that was just returned to us via
143    // `ScmRights`.
144    let allocator_file = unsafe { File::from_raw_fd(allocator_fd) };
145    let worker_fds = &fds[GLOBAL_SHMEM..];
146
147    // Setup requested allocators.
148    let allocators = (0..logon.allocator_handles)
149        .map(|offset| {
150            // NB: Server validates all requested counts are within expected bands so this should
151            // never panic.
152            let id = GLOBAL_ALLOCATORS
153                .checked_add(logon.worker_count)
154                .unwrap()
155                .checked_add(offset)
156                .unwrap();
157
158            unsafe { Allocator::join(&allocator_file, u32::try_from(id).unwrap()) }
159        })
160        .collect::<Result<Vec<_>, _>>()?;
161
162    // Ensure worker_fds length matches expectations.
163    if worker_fds.is_empty()
164        || worker_fds.len() % 2 != 0
165        || worker_fds.len() / 2 != logon.worker_count
166    {
167        return Err(ClientHandshakeError::ProtocolViolation);
168    }
169
170    // NB: After creating & mapping the queues we are fine to drop the FDs as mmap will keep the
171    // underlying object alive until process exit or munmap.
172    let session = ClientSession {
173        allocators,
174        tpu_to_pack: unsafe { shaq::Consumer::join(&File::from_raw_fd(tpu_to_pack_fd))? },
175        progress_tracker: unsafe { shaq::Consumer::join(&File::from_raw_fd(progress_tracker_fd))? },
176        workers: worker_fds
177            .chunks(2)
178            .map(|window| {
179                let [pack_to_worker, worker_to_pack] = window else {
180                    panic!();
181                };
182
183                Ok(ClientWorkerSession {
184                    pack_to_worker: unsafe {
185                        shaq::Producer::join(&File::from_raw_fd(*pack_to_worker))?
186                    },
187                    worker_to_pack: unsafe {
188                        shaq::Consumer::join(&File::from_raw_fd(*worker_to_pack))?
189                    },
190                })
191            })
192            .collect::<Result<_, ClientHandshakeError>>()?,
193    };
194
195    Ok(session)
196}
197
/// The complete initialized scheduling session.
pub struct ClientSession {
    // Shared memory allocators; one per handle requested in the logon.
    pub allocators: Vec<Allocator>,
    // Consumer side of the TPU-to-pack message queue.
    pub tpu_to_pack: shaq::Consumer<TpuToPackMessage>,
    // Consumer side of the progress-message queue.
    pub progress_tracker: shaq::Consumer<ProgressMessage>,
    // Per-worker queue pairs; one entry per worker requested in the logon.
    pub workers: Vec<ClientWorkerSession>,
}
205
/// A per-worker scheduling session.
pub struct ClientWorkerSession {
    // Producer side of the pack-to-worker queue.
    pub pack_to_worker: shaq::Producer<PackToWorkerMessage>,
    // Consumer side of the worker-to-pack queue.
    pub worker_to_pack: shaq::Consumer<WorkerToPackMessage>,
}
211
/// Potential errors that can occur during the client's side of the handshake.
#[derive(Debug, Error)]
pub enum ClientHandshakeError {
    /// Underlying socket IO failed (nix errors are also funneled into this variant).
    #[error("Io; err={0}")]
    Io(#[from] std::io::Error),
    /// The handshake did not complete in time.
    // NOTE(review): not constructed anywhere in this file — confirm a caller/server path uses it.
    #[error("Timed out")]
    TimedOut,
    /// The server's response did not match the expected wire format.
    #[error("Protocol violation")]
    ProtocolViolation,
    /// The server explicitly rejected the logon, with its stated reason.
    #[error("Rejected; reason={0}")]
    Rejected(String),
    /// Joining a shared memory allocator failed.
    #[error("Rts alloc; err={0}")]
    RtsAlloc(#[from] RtsError),
    /// Joining a shared memory queue failed.
    #[error("Shaq; err={0}")]
    Shaq(#[from] ShaqError),
}
228
229impl From<nix::Error> for ClientHandshakeError {
230    fn from(value: nix::Error) -> Self {
231        Self::Io(value.into())
232    }
233}