use anyhow::{anyhow, Result};
use candle_core::{DType, Device, Tensor};
use safetensors::View;
use speedy::{BigEndian, Readable, Writable};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
6
7fn dtype_to_u8(dtype: DType) -> u8 {
9 match dtype {
10 DType::U8 => 0,
11 DType::U32 => 1,
12 DType::I64 => 2,
13 DType::BF16 => 3,
14 DType::F16 => 4,
15 DType::F32 => 5,
16 DType::F64 => 6,
17 DType::F8E4M3 => 7,
18 _ => 255,
20 }
21}
22
23fn u8_to_dtype(tag: u8) -> Result<DType> {
25 match tag {
26 0 => Ok(DType::U8),
27 1 => Ok(DType::U32),
28 2 => Ok(DType::I64),
29 3 => Ok(DType::BF16),
30 4 => Ok(DType::F16),
31 5 => Ok(DType::F32),
32 6 => Ok(DType::F64),
33 7 => Ok(DType::F8E4M3),
34 _ => Err(anyhow!("unknown dtype tag: {tag}")),
35 }
36}
37
/// A tensor flattened for transport: raw buffer bytes plus enough metadata
/// (dtype tag and shape) to reconstruct it on the other end.
#[derive(Readable, Writable)]
pub struct RawTensor {
    // Raw tensor bytes as produced by `safetensors::View::data()`.
    pub data: Vec<u8>,
    // Wire dtype tag; see `dtype_to_u8` / `u8_to_dtype` for the mapping.
    pub dtype: u8,
    // Tensor dimensions, outermost first.
    pub shape: Vec<usize>,
}
48
// Manual `Debug` so logs show the buffer's length instead of dumping the
// (potentially huge) raw byte vector that a derived impl would print.
impl std::fmt::Debug for RawTensor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RawTensor")
            .field("dtype", &self.dtype)
            .field("shape", &self.shape)
            .field("data_len", &self.data.len())
            .finish()
    }
}
58
59impl RawTensor {
60 pub fn from_tensor(x: &Tensor) -> Self {
62 let data: Vec<u8> = x.data().into_owned();
63 let dtype = dtype_to_u8(x.dtype());
64 let shape = x.shape().clone().into_dims();
65 Self { data, dtype, shape }
66 }
67
68 pub fn to_tensor(&self, device: &Device) -> Result<Tensor> {
70 let dtype = u8_to_dtype(self.dtype)?;
71 Ok(Tensor::from_raw_buffer(
72 &self.data,
73 dtype,
74 &self.shape,
75 device,
76 )?)
77 }
78}
79
/// Descriptive information a worker reports about itself.
#[derive(Debug, Default, Readable, Writable)]
pub struct WorkerInfo {
    // Worker software version string (e.g. "0.1.0").
    pub version: String,
    // Compute dtype as a string (e.g. "F16").
    pub dtype: String,
    // Operating system name (e.g. "linux").
    pub os: String,
    // CPU architecture (e.g. "x86_64").
    pub arch: String,
    // Device kind (e.g. "cuda").
    pub device: String,
    // Index of the device when several of the same kind are present.
    pub device_idx: usize,
    // Measured latency. NOTE(review): units are not specified in this
    // file — confirm (ms vs µs) against the code that fills it in.
    pub latency: u128,
}
98
/// Every frame exchanged on the wire carries exactly one of these variants,
/// speedy-encoded in big-endian (see `Message::to_bytes`/`from_bytes`).
#[derive(Debug, Readable, Writable)]
pub enum Message {
    /// Payload-less greeting; sent first in the wire tests below.
    Hello,
    /// A worker's self-description.
    WorkerInfo(WorkerInfo),
    /// Run one layer on the given input tensor.
    SingleOp {
        layer_name: String,
        x: RawTensor,
        index_pos: usize,
        block_idx: usize,
    },
    /// Run several ops on one input tensor; each batch entry is
    /// (layer_name, index_pos, block_idx).
    Batch {
        x: RawTensor,
        batch: Vec<(String, usize, usize)>,
    },
    /// A bare tensor (e.g. a computation result).
    Tensor(RawTensor),
    /// Payload-less farewell; closes a session.
    Goodbye,

    /// Assign a set of layers (by name) to a worker, identified by a
    /// model hash so the worker can tell whether it already has the data.
    LayerAssignment {
        layers: Vec<String>,
        model_hash: String,
    },
    /// Worker's reply: whether it still needs the model data streamed over.
    LayerAssignmentAck { needs_data: bool },
    /// One chunk of a model file being streamed to a worker.
    ModelDataChunk {
        filename: String,
        offset: u64,
        total_size: u64,
        data: Vec<u8>,
    },
    /// Marks the end of the model data stream.
    ModelDataDone,
    /// Worker signals it is ready to serve requests.
    WorkerReady,
    /// Worker-side failure, with a human-readable description.
    WorkerError { message: String },
}
147
148impl Message {
149 pub fn single_op(layer_name: &str, x: &Tensor, index_pos: usize, block_idx: usize) -> Self {
151 let layer_name = layer_name.to_owned();
152 let x = RawTensor::from_tensor(x);
153 Self::SingleOp {
154 layer_name,
155 x,
156 index_pos,
157 block_idx,
158 }
159 }
160
161 pub fn from_tensor(x: &Tensor) -> Self {
163 Self::Tensor(RawTensor::from_tensor(x))
164 }
165
166 pub fn from_batch(x: &Tensor, batch: Vec<(String, usize, usize)>) -> Self {
168 Self::Batch {
169 x: RawTensor::from_tensor(x),
170 batch,
171 }
172 }
173
174 fn to_bytes(&self) -> Result<Vec<u8>> {
179 Ok(self.write_to_vec_with_ctx(BigEndian::default())?)
180 }
181
182 fn from_bytes(raw: &[u8]) -> Result<Self> {
184 Ok(Self::read_from_buffer_with_ctx(BigEndian::default(), raw)?)
185 }
186
187 pub async fn from_reader<R>(reader: &mut R) -> Result<(usize, Self)>
189 where
190 R: AsyncReadExt + Unpin,
191 {
192 let mut buf = Vec::new();
193 Self::from_reader_buf(reader, &mut buf).await
194 }
195
196 pub async fn from_reader_buf<R>(reader: &mut R, buf: &mut Vec<u8>) -> Result<(usize, Self)>
198 where
199 R: AsyncReadExt + Unpin,
200 {
201 let magic = reader.read_u32().await?;
203 if magic != super::PROTO_MAGIC {
204 return Err(anyhow!("invalid magic value: {magic}"));
205 }
206
207 let req_size = reader.read_u32().await?;
208 if req_size > super::MESSAGE_MAX_SIZE {
209 return Err(anyhow!("request size {req_size} > MESSAGE_MAX_SIZE"));
210 }
211
212 let req_size = req_size as usize;
213 buf.resize(req_size, 0);
214 reader.read_exact(&mut buf[..req_size]).await?;
215
216 Ok((req_size, Self::from_bytes(&buf[..req_size])?))
217 }
218
219 pub async fn to_writer<W>(&self, writer: &mut W) -> Result<usize>
221 where
222 W: AsyncWriteExt + Unpin,
223 {
224 let mut buf = Vec::new();
225 self.to_writer_buf(writer, &mut buf).await
226 }
227
228 pub async fn to_writer_buf<W>(&self, writer: &mut W, buf: &mut Vec<u8>) -> Result<usize>
230 where
231 W: AsyncWriteExt + Unpin,
232 {
233 let payload = self.to_bytes()?;
234 let payload_size = payload.len() as u32;
235 if payload_size > super::MESSAGE_MAX_SIZE {
236 return Err(anyhow!("request size {payload_size} > MESSAGE_MAX_SIZE"));
237 }
238
239 let frame_len = 8 + payload.len();
241 buf.clear();
242 buf.reserve(frame_len);
243 buf.extend_from_slice(&super::PROTO_MAGIC.to_be_bytes());
244 buf.extend_from_slice(&payload_size.to_be_bytes());
245 buf.extend_from_slice(&payload);
246 writer.write_all(buf).await?;
247
248 Ok(frame_len)
249 }
250}
251
// Unit tests: RawTensor round-trips, Message payload round-trips, and
// framed wire I/O over an in-memory tokio duplex pipe.
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{DType, Device, Tensor};

    // Build a CPU F32 tensor of the given shape with deterministic values.
    fn make_f32_tensor(shape: &[usize]) -> Tensor {
        let n: usize = shape.iter().product();
        let data: Vec<f32> = (0..n).map(|i| i as f32 * 0.1).collect();
        Tensor::from_vec(data, shape, &Device::Cpu).unwrap()
    }

    // Same as above but converted to F16.
    fn make_f16_tensor(shape: &[usize]) -> Tensor {
        make_f32_tensor(shape).to_dtype(DType::F16).unwrap()
    }

    // F32 tensor -> RawTensor -> tensor keeps dtype, shape, and byte count
    // (2 * 64 elements * 4 bytes per f32).
    #[test]
    fn test_raw_tensor_roundtrip_f32() {
        let original = make_f32_tensor(&[2, 64]);
        let raw = RawTensor::from_tensor(&original);
        assert_eq!(raw.dtype, dtype_to_u8(DType::F32));
        assert_eq!(raw.shape, vec![2, 64]);
        assert_eq!(raw.data.len(), 2 * 64 * 4);

        let recovered = raw.to_tensor(&Device::Cpu).unwrap();
        assert_eq!(recovered.dtype(), DType::F32);
        assert_eq!(recovered.shape().dims(), &[2, 64]);
    }

    // Same round-trip for F16 (2 bytes per element).
    #[test]
    fn test_raw_tensor_roundtrip_f16() {
        let original = make_f16_tensor(&[1, 128]);
        let raw = RawTensor::from_tensor(&original);
        assert_eq!(raw.dtype, dtype_to_u8(DType::F16));
        assert_eq!(raw.shape, vec![1, 128]);
        assert_eq!(raw.data.len(), 1 * 128 * 2);

        let recovered = raw.to_tensor(&Device::Cpu).unwrap();
        assert_eq!(recovered.dtype(), DType::F16);
        assert_eq!(recovered.shape().dims(), &[1, 128]);
    }

    // The exact buffer bytes survive the RawTensor round-trip.
    #[test]
    fn test_raw_tensor_data_integrity() {
        let original = make_f16_tensor(&[1, 256]);
        let orig_bytes: Vec<u8> = original.data().to_vec();
        let raw = RawTensor::from_tensor(&original);
        let recovered = raw.to_tensor(&Device::Cpu).unwrap();
        let recv_bytes: Vec<u8> = recovered.data().to_vec();
        assert_eq!(orig_bytes, recv_bytes);
    }

    // Payload-only (no frame header) encode/decode of the unit variants.
    #[test]
    fn test_message_hello_roundtrip() {
        let bytes = Message::Hello.to_bytes().unwrap();
        let decoded = Message::from_bytes(&bytes).unwrap();
        assert!(matches!(decoded, Message::Hello));
    }

    #[test]
    fn test_message_goodbye_roundtrip() {
        let bytes = Message::Goodbye.to_bytes().unwrap();
        let decoded = Message::from_bytes(&bytes).unwrap();
        assert!(matches!(decoded, Message::Goodbye));
    }

    // All WorkerInfo fields survive the payload round-trip.
    #[test]
    fn test_message_worker_info_roundtrip() {
        let info = WorkerInfo {
            version: "0.1.0".into(),
            dtype: "F16".into(),
            os: "linux".into(),
            arch: "x86_64".into(),
            device: "cuda".into(),
            device_idx: 2,
            latency: 42,
        };
        let bytes = Message::WorkerInfo(info).to_bytes().unwrap();
        let decoded = Message::from_bytes(&bytes).unwrap();
        match decoded {
            Message::WorkerInfo(wi) => {
                assert_eq!(wi.version, "0.1.0");
                assert_eq!(wi.dtype, "F16");
                assert_eq!(wi.os, "linux");
                assert_eq!(wi.arch, "x86_64");
                assert_eq!(wi.device, "cuda");
                assert_eq!(wi.device_idx, 2);
                assert_eq!(wi.latency, 42);
            }
            other => panic!("expected WorkerInfo, got {:?}", other),
        }
    }

    // Tensor variant round-trip preserves dtype and shape.
    #[test]
    fn test_message_tensor_roundtrip() {
        let tensor = make_f16_tensor(&[1, 64]);
        let bytes = Message::from_tensor(&tensor).to_bytes().unwrap();
        let decoded = Message::from_bytes(&bytes).unwrap();
        match decoded {
            Message::Tensor(raw) => {
                let t = raw.to_tensor(&Device::Cpu).unwrap();
                assert_eq!(t.dtype(), DType::F16);
                assert_eq!(t.shape().dims(), &[1, 64]);
            }
            other => panic!("expected Tensor, got {:?}", other),
        }
    }

    // SingleOp round-trip preserves layer name, positions, and the tensor.
    #[test]
    fn test_message_single_op_roundtrip() {
        let tensor = make_f16_tensor(&[1, 64]);
        let msg = Message::single_op("model.layers.5", &tensor, 42, 7);
        let bytes = msg.to_bytes().unwrap();
        let decoded = Message::from_bytes(&bytes).unwrap();
        match decoded {
            Message::SingleOp {
                layer_name,
                x,
                index_pos,
                block_idx,
            } => {
                assert_eq!(layer_name, "model.layers.5");
                assert_eq!(index_pos, 42);
                assert_eq!(block_idx, 7);
                let t = x.to_tensor(&Device::Cpu).unwrap();
                assert_eq!(t.shape().dims(), &[1, 64]);
            }
            other => panic!("expected SingleOp, got {:?}", other),
        }
    }

    // Batch round-trip preserves the op list order/content and the tensor.
    #[test]
    fn test_message_batch_roundtrip() {
        let tensor = make_f16_tensor(&[1, 128]);
        let batch = vec![
            ("model.layers.0".into(), 0usize, 0usize),
            ("model.layers.1".into(), 1, 1),
            ("model.layers.2".into(), 2, 2),
        ];
        let msg = Message::from_batch(&tensor, batch);
        let bytes = msg.to_bytes().unwrap();
        let decoded = Message::from_bytes(&bytes).unwrap();
        match decoded {
            Message::Batch { x, batch } => {
                assert_eq!(batch.len(), 3);
                assert_eq!(batch[0].0, "model.layers.0");
                assert_eq!(batch[1].1, 1);
                assert_eq!(batch[2].2, 2);
                let t = x.to_tensor(&Device::Cpu).unwrap();
                assert_eq!(t.shape().dims(), &[1, 128]);
            }
            other => panic!("expected Batch, got {:?}", other),
        }
    }

    // WorkerError carries its message string intact.
    #[test]
    fn test_message_worker_error_roundtrip() {
        let msg = Message::WorkerError {
            message: "layer not found".into(),
        };
        let bytes = msg.to_bytes().unwrap();
        let decoded = Message::from_bytes(&bytes).unwrap();
        match decoded {
            Message::WorkerError { message } => {
                assert_eq!(message, "layer not found");
            }
            other => panic!("expected WorkerError, got {:?}", other),
        }
    }

    // Full framed write/read over an in-memory pipe; the writer's byte
    // count equals header (8) + payload size reported by the reader.
    #[tokio::test]
    async fn test_wire_hello() {
        let (mut writer, mut reader) = tokio::io::duplex(1024);
        let written = Message::Hello.to_writer(&mut writer).await.unwrap();
        drop(writer);

        let (payload_size, decoded) = Message::from_reader(&mut reader).await.unwrap();
        assert!(matches!(decoded, Message::Hello));
        assert_eq!(written, 8 + payload_size);
    }

    // Tensor bytes survive the full framed wire round-trip.
    #[tokio::test]
    async fn test_wire_tensor() {
        let (mut writer, mut reader) = tokio::io::duplex(64 * 1024);
        let tensor = make_f16_tensor(&[1, 128]);
        let orig_bytes: Vec<u8> = tensor.data().to_vec();

        Message::from_tensor(&tensor)
            .to_writer(&mut writer)
            .await
            .unwrap();
        drop(writer);

        let (_, decoded) = Message::from_reader(&mut reader).await.unwrap();
        match decoded {
            Message::Tensor(raw) => {
                assert_eq!(raw.dtype, dtype_to_u8(DType::F16));
                assert_eq!(raw.shape, vec![1, 128]);
                let t = raw.to_tensor(&Device::Cpu).unwrap();
                assert_eq!(t.data().to_vec(), orig_bytes);
            }
            other => panic!("expected Tensor, got {:?}", other),
        }
    }

    // A frame with the wrong magic is rejected before the payload is read.
    #[tokio::test]
    async fn test_wire_invalid_magic() {
        let (mut writer, mut reader) = tokio::io::duplex(1024);
        writer.write_u32(0xDEADBEEF_u32).await.unwrap();
        writer.write_u32(4_u32).await.unwrap();
        writer.write_all(&[0, 0, 0, 0]).await.unwrap();
        drop(writer);

        let result = Message::from_reader(&mut reader).await;
        assert!(result.is_err());
        assert!(
            result.unwrap_err().to_string().contains("invalid magic"),
            "should report invalid magic"
        );
    }

    // A length prefix beyond MESSAGE_MAX_SIZE is rejected without reading
    // any payload bytes.
    #[tokio::test]
    async fn test_wire_oversized_message() {
        let (mut writer, mut reader) = tokio::io::duplex(1024);
        writer
            .write_u32(crate::cake::proto::PROTO_MAGIC)
            .await
            .unwrap();
        writer
            .write_u32(crate::cake::proto::MESSAGE_MAX_SIZE + 1)
            .await
            .unwrap();
        drop(writer);

        let result = Message::from_reader(&mut reader).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("MESSAGE_MAX_SIZE"),
            "should report oversized message"
        );
    }

    // Several back-to-back frames on one stream decode in order.
    #[tokio::test]
    async fn test_wire_multiple_messages() {
        let (mut writer, mut reader) = tokio::io::duplex(64 * 1024);
        let tensor = make_f16_tensor(&[1, 32]);

        Message::Hello.to_writer(&mut writer).await.unwrap();
        Message::single_op("model.layers.0", &tensor, 0, 0)
            .to_writer(&mut writer)
            .await
            .unwrap();
        Message::from_tensor(&tensor)
            .to_writer(&mut writer)
            .await
            .unwrap();
        Message::Goodbye.to_writer(&mut writer).await.unwrap();
        drop(writer);

        let (_, m1) = Message::from_reader(&mut reader).await.unwrap();
        assert!(matches!(m1, Message::Hello));

        let (_, m2) = Message::from_reader(&mut reader).await.unwrap();
        assert!(matches!(m2, Message::SingleOp { .. }));

        let (_, m3) = Message::from_reader(&mut reader).await.unwrap();
        assert!(matches!(m3, Message::Tensor(_)));

        let (_, m4) = Message::from_reader(&mut reader).await.unwrap();
        assert!(matches!(m4, Message::Goodbye));
    }
}