agave_scheduling_utils/handshake/
client.rs1use {
2 crate::handshake::{
3 shared::{GLOBAL_ALLOCATORS, LOGON_FAILURE, MAX_WORKERS, VERSION},
4 ClientLogon,
5 },
6 agave_scheduler_bindings::{
7 PackToWorkerMessage, ProgressMessage, TpuToPackMessage, WorkerToPackMessage,
8 },
9 libc::CMSG_LEN,
10 nix::sys::socket::{self, ControlMessageOwned, MsgFlags, UnixAddr},
11 rts_alloc::Allocator,
12 std::{
13 fs::File,
14 io::{IoSliceMut, Write},
15 os::{
16 fd::{AsRawFd, FromRawFd},
17 unix::net::UnixStream,
18 },
19 path::Path,
20 time::Duration,
21 },
22 thiserror::Error,
23};
24
25type RtsError = rts_alloc::error::Error;
26type ShaqError = shaq::error::Error;
27
28const GLOBAL_SHMEM: usize = 3;
30
31const CMSG_MAX_SIZE: usize = (GLOBAL_SHMEM + MAX_WORKERS * 2) * 4;
37
38pub fn connect(
49 path: impl AsRef<Path>,
50 logon: ClientLogon,
51 timeout: Duration,
52) -> Result<ClientSession, ClientHandshakeError> {
53 connect_path(path.as_ref(), logon, timeout)
54}
55
56fn connect_path(
57 path: &Path,
58 logon: ClientLogon,
59 timeout: Duration,
60) -> Result<ClientSession, ClientHandshakeError> {
61 let mut stream = UnixStream::connect(path)?;
69 stream.set_read_timeout(Some(timeout))?;
70 stream.set_write_timeout(Some(timeout))?;
71
72 send_logon(&mut stream, logon)?;
74
75 let fds = recv_response(&mut stream)?;
77
78 let session = setup_session(&logon, fds)?;
80
81 Ok(session)
82}
83
84fn send_logon(stream: &mut UnixStream, logon: ClientLogon) -> Result<(), ClientHandshakeError> {
85 let mut buf = [0; 1024];
87 buf[..8].copy_from_slice(&VERSION.to_le_bytes());
88 const LOGON_END: usize = 8 + core::mem::size_of::<ClientLogon>();
89 let ptr = buf[8..LOGON_END].as_mut_ptr().cast::<ClientLogon>();
90 unsafe {
94 core::ptr::write_unaligned(ptr, logon);
95 }
96 stream.write_all(&buf)?;
97
98 Ok(())
99}
100
101fn recv_response(stream: &mut UnixStream) -> Result<Vec<i32>, ClientHandshakeError> {
102 let mut buf = [0; 1024];
104 let mut iov = [IoSliceMut::new(&mut buf)];
105 let mut cmsgs = [0u8; unsafe { CMSG_LEN(CMSG_MAX_SIZE as u32) as usize }];
107 let msg = socket::recvmsg::<UnixAddr>(
108 stream.as_raw_fd(),
109 &mut iov,
110 Some(&mut cmsgs),
111 MsgFlags::empty(),
112 )?;
113
114 let buf = msg.iovs().next().unwrap();
116 if buf[0] == LOGON_FAILURE {
117 let reason_len = usize::from(buf[1]);
118 #[allow(clippy::arithmetic_side_effects)]
119 let reason = std::str::from_utf8(&buf[2..2 + reason_len]).unwrap();
120
121 return Err(ClientHandshakeError::Rejected(reason.to_string()));
122 }
123
124 let mut cmsgs = msg.cmsgs().unwrap();
126 let fds = match cmsgs.next() {
127 Some(ControlMessageOwned::ScmRights(fds)) => fds,
128 Some(msg) => panic!("Unexpected; msg={msg:?}"),
129 None => panic!(),
130 };
131
132 Ok(fds)
133}
134
135fn setup_session(
136 logon: &ClientLogon,
137 fds: Vec<i32>,
138) -> Result<ClientSession, ClientHandshakeError> {
139 let [allocator_fd, tpu_to_pack_fd, progress_tracker_fd] = fds[..GLOBAL_SHMEM] else {
140 panic!();
141 };
142 let allocator_file = unsafe { File::from_raw_fd(allocator_fd) };
145 let worker_fds = &fds[GLOBAL_SHMEM..];
146
147 let allocators = (0..logon.allocator_handles)
149 .map(|offset| {
150 let id = GLOBAL_ALLOCATORS
153 .checked_add(logon.worker_count)
154 .unwrap()
155 .checked_add(offset)
156 .unwrap();
157
158 unsafe { Allocator::join(&allocator_file, u32::try_from(id).unwrap()) }
159 })
160 .collect::<Result<Vec<_>, _>>()?;
161
162 if worker_fds.is_empty()
164 || worker_fds.len() % 2 != 0
165 || worker_fds.len() / 2 != logon.worker_count
166 {
167 return Err(ClientHandshakeError::ProtocolViolation);
168 }
169
170 let session = ClientSession {
173 allocators,
174 tpu_to_pack: unsafe { shaq::Consumer::join(&File::from_raw_fd(tpu_to_pack_fd))? },
175 progress_tracker: unsafe { shaq::Consumer::join(&File::from_raw_fd(progress_tracker_fd))? },
176 workers: worker_fds
177 .chunks(2)
178 .map(|window| {
179 let [pack_to_worker, worker_to_pack] = window else {
180 panic!();
181 };
182
183 Ok(ClientWorkerSession {
184 pack_to_worker: unsafe {
185 shaq::Producer::join(&File::from_raw_fd(*pack_to_worker))?
186 },
187 worker_to_pack: unsafe {
188 shaq::Consumer::join(&File::from_raw_fd(*worker_to_pack))?
189 },
190 })
191 })
192 .collect::<Result<_, ClientHandshakeError>>()?,
193 };
194
195 Ok(session)
196}
197
198pub struct ClientSession {
200 pub allocators: Vec<Allocator>,
201 pub tpu_to_pack: shaq::Consumer<TpuToPackMessage>,
202 pub progress_tracker: shaq::Consumer<ProgressMessage>,
203 pub workers: Vec<ClientWorkerSession>,
204}
205
206pub struct ClientWorkerSession {
208 pub pack_to_worker: shaq::Producer<PackToWorkerMessage>,
209 pub worker_to_pack: shaq::Consumer<WorkerToPackMessage>,
210}
211
212#[derive(Debug, Error)]
214pub enum ClientHandshakeError {
215 #[error("Io; err={0}")]
216 Io(#[from] std::io::Error),
217 #[error("Timed out")]
218 TimedOut,
219 #[error("Protocol violation")]
220 ProtocolViolation,
221 #[error("Rejected; reason={0}")]
222 Rejected(String),
223 #[error("Rts alloc; err={0}")]
224 RtsAlloc(#[from] RtsError),
225 #[error("Shaq; err={0}")]
226 Shaq(#[from] ShaqError),
227}
228
229impl From<nix::Error> for ClientHandshakeError {
230 fn from(value: nix::Error) -> Self {
231 Self::Io(value.into())
232 }
233}