//! Cake protocol wire messages: `RawTensor`, `WorkerInfo`, and `Message`,
//! plus the framing used on the wire (magic + length-prefixed payloads).

use anyhow::{anyhow, Result};
use candle_core::{DType, Device, Tensor};
use safetensors::View;
use speedy::{BigEndian, Readable, Writable};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
6
7/// Map a candle DType to a compact u8 tag for wire encoding.
8fn dtype_to_u8(dtype: DType) -> u8 {
9    match dtype {
10        DType::U8 => 0,
11        DType::U32 => 1,
12        DType::I64 => 2,
13        DType::BF16 => 3,
14        DType::F16 => 4,
15        DType::F32 => 5,
16        DType::F64 => 6,
17        DType::F8E4M3 => 7,
18        // Catch-all for newer candle dtypes we don't use on the wire yet.
19        _ => 255,
20    }
21}
22
23/// Map a u8 wire tag back to a candle DType.
24fn u8_to_dtype(tag: u8) -> Result<DType> {
25    match tag {
26        0 => Ok(DType::U8),
27        1 => Ok(DType::U32),
28        2 => Ok(DType::I64),
29        3 => Ok(DType::BF16),
30        4 => Ok(DType::F16),
31        5 => Ok(DType::F32),
32        6 => Ok(DType::F64),
33        7 => Ok(DType::F8E4M3),
34        _ => Err(anyhow!("unknown dtype tag: {tag}")),
35    }
36}
37
/// Represents a tensor in Cake protocol.
///
/// NOTE(review): the speedy derive presumably encodes fields in declaration
/// order, so reordering these fields would change the wire format — confirm
/// before touching.
#[derive(Readable, Writable)]
pub struct RawTensor {
    /// Tensor data (raw bytes, layout as produced by `Tensor::data()`).
    pub data: Vec<u8>,
    /// The data type as a compact u8 tag (see dtype_to_u8 / u8_to_dtype).
    pub dtype: u8,
    /// The tensor shape.
    pub shape: Vec<usize>,
}
48
impl std::fmt::Debug for RawTensor {
    /// Manual Debug impl: prints `data` as its byte length (`data_len`)
    /// instead of dumping the full buffer, which can be huge per tensor.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RawTensor")
            .field("dtype", &self.dtype)
            .field("shape", &self.shape)
            .field("data_len", &self.data.len())
            .finish()
    }
}
58
59impl RawTensor {
60    /// Convert x into a RawTensor.
61    pub fn from_tensor(x: &Tensor) -> Self {
62        let data: Vec<u8> = x.data().into_owned();
63        let dtype = dtype_to_u8(x.dtype());
64        let shape = x.shape().clone().into_dims();
65        Self { data, dtype, shape }
66    }
67
68    /// Convert the raw tensor in a Tensor allocated on the given device.
69    pub fn to_tensor(&self, device: &Device) -> Result<Tensor> {
70        let dtype = u8_to_dtype(self.dtype)?;
71        Ok(Tensor::from_raw_buffer(
72            &self.data,
73            dtype,
74            &self.shape,
75            device,
76        )?)
77    }
78}
79
/// Diagnostic information about a worker.
#[derive(Debug, Default, Readable, Writable)]
pub struct WorkerInfo {
    /// Protocol version.
    pub version: String,
    /// Tensors data type.
    pub dtype: String,
    /// Operating system.
    pub os: String,
    /// Architecture.
    pub arch: String,
    /// Device.
    pub device: String,
    /// Device index for multi GPU environments.
    pub device_idx: usize,
    /// Latency in milliseconds.
    pub latency: u128,
}
98
/// A Cake protocol message.
///
/// NOTE(review): speedy derives appear to encode the enum variant by its
/// declared position, so inserting/reordering variants would break wire
/// compatibility — append new variants at the end. TODO confirm.
#[derive(Debug, Readable, Writable)]
pub enum Message {
    /// First message sent.
    Hello,
    /// Message that the worker sends when a master connects with runtime information.
    WorkerInfo(WorkerInfo),
    /// Single inference operation for a given layer.
    SingleOp {
        /// Name of the layer to run.
        layer_name: String,
        /// Input tensor.
        x: RawTensor,
        // index_pos/block_idx are forwarded opaquely here; their semantics
        // live in the inference code — presumably cache position and block
        // number. Verify against the worker's SingleOp handler.
        index_pos: usize,
        block_idx: usize,
    },
    /// Batched inference operations over a Tensor.
    Batch {
        /// Input tensor shared by all ops in the batch.
        x: RawTensor,
        /// One (layer_name, index_pos, block_idx) triple per operation.
        batch: Vec<(String, usize, usize)>,
    },
    /// A message to transmit tensors.
    Tensor(RawTensor),
    /// Last message sent.
    Goodbye,

    // ── Zero-config setup messages ──────────────────────────────

    /// Master tells worker which layers to serve.
    LayerAssignment {
        layers: Vec<String>,
        /// Short hash of model config for cache keying.
        model_hash: String,
    },
    /// Worker tells master whether it needs model data.
    LayerAssignmentAck { needs_data: bool },
    /// Chunk of model file data from master to worker.
    ModelDataChunk {
        filename: String,
        /// Byte offset of this chunk within the file.
        offset: u64,
        /// Total size of the file, so the receiver can preallocate/track progress.
        total_size: u64,
        data: Vec<u8>,
    },
    /// All model files have been sent.
    ModelDataDone,
    /// Worker has loaded all assigned layers and is ready for inference.
    WorkerReady,
    /// Worker encountered an error during inference.
    WorkerError { message: String },
}
147
148impl Message {
149    /// Create a Message::SingleOp message.
150    pub fn single_op(layer_name: &str, x: &Tensor, index_pos: usize, block_idx: usize) -> Self {
151        let layer_name = layer_name.to_owned();
152        let x = RawTensor::from_tensor(x);
153        Self::SingleOp {
154            layer_name,
155            x,
156            index_pos,
157            block_idx,
158        }
159    }
160
161    /// Create a Message::Tensor message.
162    pub fn from_tensor(x: &Tensor) -> Self {
163        Self::Tensor(RawTensor::from_tensor(x))
164    }
165
166    /// Create a Message::Batch message.
167    pub fn from_batch(x: &Tensor, batch: Vec<(String, usize, usize)>) -> Self {
168        Self::Batch {
169            x: RawTensor::from_tensor(x),
170            batch,
171        }
172    }
173
174    // Yes, I could use GRPC, but this is simpler and faster.
175    // Check speedy benchmarks ;)
176
177    /// Serializes the message to raw bytes.
178    fn to_bytes(&self) -> Result<Vec<u8>> {
179        Ok(self.write_to_vec_with_ctx(BigEndian::default())?)
180    }
181
182    /// Deserializes a Message from raw bytes.
183    fn from_bytes(raw: &[u8]) -> Result<Self> {
184        Ok(Self::read_from_buffer_with_ctx(BigEndian::default(), raw)?)
185    }
186
187    /// Read a Message with the provided reader.
188    pub async fn from_reader<R>(reader: &mut R) -> Result<(usize, Self)>
189    where
190        R: AsyncReadExt + Unpin,
191    {
192        let mut buf = Vec::new();
193        Self::from_reader_buf(reader, &mut buf).await
194    }
195
196    /// Read a Message, reusing `buf` to avoid per-message heap allocation.
197    pub async fn from_reader_buf<R>(reader: &mut R, buf: &mut Vec<u8>) -> Result<(usize, Self)>
198    where
199        R: AsyncReadExt + Unpin,
200    {
201        // read_u32() reads 4 bytes as big-endian and returns the native value.
202        let magic = reader.read_u32().await?;
203        if magic != super::PROTO_MAGIC {
204            return Err(anyhow!("invalid magic value: {magic}"));
205        }
206
207        let req_size = reader.read_u32().await?;
208        if req_size > super::MESSAGE_MAX_SIZE {
209            return Err(anyhow!("request size {req_size} > MESSAGE_MAX_SIZE"));
210        }
211
212        let req_size = req_size as usize;
213        buf.resize(req_size, 0);
214        reader.read_exact(&mut buf[..req_size]).await?;
215
216        Ok((req_size, Self::from_bytes(&buf[..req_size])?))
217    }
218
219    /// Write a Message with the provided writer.
220    pub async fn to_writer<W>(&self, writer: &mut W) -> Result<usize>
221    where
222        W: AsyncWriteExt + Unpin,
223    {
224        let mut buf = Vec::new();
225        self.to_writer_buf(writer, &mut buf).await
226    }
227
228    /// Write a Message, reusing `buf` to avoid per-message heap allocation.
229    pub async fn to_writer_buf<W>(&self, writer: &mut W, buf: &mut Vec<u8>) -> Result<usize>
230    where
231        W: AsyncWriteExt + Unpin,
232    {
233        let payload = self.to_bytes()?;
234        let payload_size = payload.len() as u32;
235        if payload_size > super::MESSAGE_MAX_SIZE {
236            return Err(anyhow!("request size {payload_size} > MESSAGE_MAX_SIZE"));
237        }
238
239        // Coalesce header + payload into a single write to avoid Nagle delays.
240        let frame_len = 8 + payload.len();
241        buf.clear();
242        buf.reserve(frame_len);
243        buf.extend_from_slice(&super::PROTO_MAGIC.to_be_bytes());
244        buf.extend_from_slice(&payload_size.to_be_bytes());
245        buf.extend_from_slice(&payload);
246        writer.write_all(buf).await?;
247
248        Ok(frame_len)
249    }
250}
251
/// Unit tests: RawTensor round-trips, Message (de)serialization, and the
/// framed wire format over an in-memory duplex pipe.
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{DType, Device, Tensor};

    /// Build a CPU f32 tensor filled with a deterministic ramp (i * 0.1).
    fn make_f32_tensor(shape: &[usize]) -> Tensor {
        let n: usize = shape.iter().product();
        let data: Vec<f32> = (0..n).map(|i| i as f32 * 0.1).collect();
        Tensor::from_vec(data, shape, &Device::Cpu).unwrap()
    }

    /// Same ramp tensor, downcast to f16 (the dtype used on the wire).
    fn make_f16_tensor(shape: &[usize]) -> Tensor {
        make_f32_tensor(shape).to_dtype(DType::F16).unwrap()
    }

    // ── RawTensor round-trips ──────────────────────────────────

    #[test]
    fn test_raw_tensor_roundtrip_f32() {
        let original = make_f32_tensor(&[2, 64]);
        let raw = RawTensor::from_tensor(&original);
        assert_eq!(raw.dtype, dtype_to_u8(DType::F32));
        assert_eq!(raw.shape, vec![2, 64]);
        // 4 bytes per f32 element.
        assert_eq!(raw.data.len(), 2 * 64 * 4);

        let recovered = raw.to_tensor(&Device::Cpu).unwrap();
        assert_eq!(recovered.dtype(), DType::F32);
        assert_eq!(recovered.shape().dims(), &[2, 64]);
    }

    #[test]
    fn test_raw_tensor_roundtrip_f16() {
        let original = make_f16_tensor(&[1, 128]);
        let raw = RawTensor::from_tensor(&original);
        assert_eq!(raw.dtype, dtype_to_u8(DType::F16));
        assert_eq!(raw.shape, vec![1, 128]);
        // 2 bytes per f16 element.
        assert_eq!(raw.data.len(), 1 * 128 * 2);

        let recovered = raw.to_tensor(&Device::Cpu).unwrap();
        assert_eq!(recovered.dtype(), DType::F16);
        assert_eq!(recovered.shape().dims(), &[1, 128]);
    }

    /// The raw bytes must survive the from_tensor -> to_tensor round-trip
    /// unchanged, not just the dtype/shape metadata.
    #[test]
    fn test_raw_tensor_data_integrity() {
        let original = make_f16_tensor(&[1, 256]);
        let orig_bytes: Vec<u8> = original.data().to_vec();
        let raw = RawTensor::from_tensor(&original);
        let recovered = raw.to_tensor(&Device::Cpu).unwrap();
        let recv_bytes: Vec<u8> = recovered.data().to_vec();
        assert_eq!(orig_bytes, recv_bytes);
    }

    // ── Message serialization (to_bytes / from_bytes) ──────────

    #[test]
    fn test_message_hello_roundtrip() {
        let bytes = Message::Hello.to_bytes().unwrap();
        let decoded = Message::from_bytes(&bytes).unwrap();
        assert!(matches!(decoded, Message::Hello));
    }

    #[test]
    fn test_message_goodbye_roundtrip() {
        let bytes = Message::Goodbye.to_bytes().unwrap();
        let decoded = Message::from_bytes(&bytes).unwrap();
        assert!(matches!(decoded, Message::Goodbye));
    }

    #[test]
    fn test_message_worker_info_roundtrip() {
        let info = WorkerInfo {
            version: "0.1.0".into(),
            dtype: "F16".into(),
            os: "linux".into(),
            arch: "x86_64".into(),
            device: "cuda".into(),
            device_idx: 2,
            latency: 42,
        };
        let bytes = Message::WorkerInfo(info).to_bytes().unwrap();
        let decoded = Message::from_bytes(&bytes).unwrap();
        match decoded {
            Message::WorkerInfo(wi) => {
                assert_eq!(wi.version, "0.1.0");
                assert_eq!(wi.dtype, "F16");
                assert_eq!(wi.os, "linux");
                assert_eq!(wi.arch, "x86_64");
                assert_eq!(wi.device, "cuda");
                assert_eq!(wi.device_idx, 2);
                assert_eq!(wi.latency, 42);
            }
            other => panic!("expected WorkerInfo, got {:?}", other),
        }
    }

    #[test]
    fn test_message_tensor_roundtrip() {
        let tensor = make_f16_tensor(&[1, 64]);
        let bytes = Message::from_tensor(&tensor).to_bytes().unwrap();
        let decoded = Message::from_bytes(&bytes).unwrap();
        match decoded {
            Message::Tensor(raw) => {
                let t = raw.to_tensor(&Device::Cpu).unwrap();
                assert_eq!(t.dtype(), DType::F16);
                assert_eq!(t.shape().dims(), &[1, 64]);
            }
            other => panic!("expected Tensor, got {:?}", other),
        }
    }

    #[test]
    fn test_message_single_op_roundtrip() {
        let tensor = make_f16_tensor(&[1, 64]);
        let msg = Message::single_op("model.layers.5", &tensor, 42, 7);
        let bytes = msg.to_bytes().unwrap();
        let decoded = Message::from_bytes(&bytes).unwrap();
        match decoded {
            Message::SingleOp {
                layer_name,
                x,
                index_pos,
                block_idx,
            } => {
                assert_eq!(layer_name, "model.layers.5");
                assert_eq!(index_pos, 42);
                assert_eq!(block_idx, 7);
                let t = x.to_tensor(&Device::Cpu).unwrap();
                assert_eq!(t.shape().dims(), &[1, 64]);
            }
            other => panic!("expected SingleOp, got {:?}", other),
        }
    }

    #[test]
    fn test_message_batch_roundtrip() {
        let tensor = make_f16_tensor(&[1, 128]);
        let batch = vec![
            ("model.layers.0".into(), 0usize, 0usize),
            ("model.layers.1".into(), 1, 1),
            ("model.layers.2".into(), 2, 2),
        ];
        let msg = Message::from_batch(&tensor, batch);
        let bytes = msg.to_bytes().unwrap();
        let decoded = Message::from_bytes(&bytes).unwrap();
        match decoded {
            Message::Batch { x, batch } => {
                assert_eq!(batch.len(), 3);
                assert_eq!(batch[0].0, "model.layers.0");
                assert_eq!(batch[1].1, 1);
                assert_eq!(batch[2].2, 2);
                let t = x.to_tensor(&Device::Cpu).unwrap();
                assert_eq!(t.shape().dims(), &[1, 128]);
            }
            other => panic!("expected Batch, got {:?}", other),
        }
    }

    #[test]
    fn test_message_worker_error_roundtrip() {
        let msg = Message::WorkerError {
            message: "layer not found".into(),
        };
        let bytes = msg.to_bytes().unwrap();
        let decoded = Message::from_bytes(&bytes).unwrap();
        match decoded {
            Message::WorkerError { message } => {
                assert_eq!(message, "layer not found");
            }
            other => panic!("expected WorkerError, got {:?}", other),
        }
    }

    // ── Wire format (to_writer / from_reader) ──────────────────

    #[tokio::test]
    async fn test_wire_hello() {
        let (mut writer, mut reader) = tokio::io::duplex(1024);
        let written = Message::Hello.to_writer(&mut writer).await.unwrap();
        drop(writer);

        // to_writer reports header (8 bytes) + payload; from_reader reports
        // payload only.
        let (payload_size, decoded) = Message::from_reader(&mut reader).await.unwrap();
        assert!(matches!(decoded, Message::Hello));
        assert_eq!(written, 8 + payload_size);
    }

    #[tokio::test]
    async fn test_wire_tensor() {
        let (mut writer, mut reader) = tokio::io::duplex(64 * 1024);
        let tensor = make_f16_tensor(&[1, 128]);
        let orig_bytes: Vec<u8> = tensor.data().to_vec();

        Message::from_tensor(&tensor)
            .to_writer(&mut writer)
            .await
            .unwrap();
        drop(writer);

        let (_, decoded) = Message::from_reader(&mut reader).await.unwrap();
        match decoded {
            Message::Tensor(raw) => {
                assert_eq!(raw.dtype, dtype_to_u8(DType::F16));
                assert_eq!(raw.shape, vec![1, 128]);
                let t = raw.to_tensor(&Device::Cpu).unwrap();
                assert_eq!(t.data().to_vec(), orig_bytes);
            }
            other => panic!("expected Tensor, got {:?}", other),
        }
    }

    /// A frame with the wrong magic must be rejected before any payload read.
    #[tokio::test]
    async fn test_wire_invalid_magic() {
        let (mut writer, mut reader) = tokio::io::duplex(1024);
        writer.write_u32(0xDEADBEEF_u32).await.unwrap();
        writer.write_u32(4_u32).await.unwrap();
        writer.write_all(&[0, 0, 0, 0]).await.unwrap();
        drop(writer);

        let result = Message::from_reader(&mut reader).await;
        assert!(result.is_err());
        assert!(
            result.unwrap_err().to_string().contains("invalid magic"),
            "should report invalid magic"
        );
    }

    #[tokio::test]
    async fn test_wire_oversized_message() {
        let (mut writer, mut reader) = tokio::io::duplex(1024);
        // Write valid magic, then size > MESSAGE_MAX_SIZE
        // write_u32 already writes in big-endian, so pass native values directly.
        writer
            .write_u32(crate::cake::proto::PROTO_MAGIC)
            .await
            .unwrap();
        writer
            .write_u32(crate::cake::proto::MESSAGE_MAX_SIZE + 1)
            .await
            .unwrap();
        drop(writer);

        let result = Message::from_reader(&mut reader).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("MESSAGE_MAX_SIZE"),
            "should report oversized message"
        );
    }

    /// Multiple frames on one stream must decode in order with no residue.
    #[tokio::test]
    async fn test_wire_multiple_messages() {
        let (mut writer, mut reader) = tokio::io::duplex(64 * 1024);
        let tensor = make_f16_tensor(&[1, 32]);

        Message::Hello.to_writer(&mut writer).await.unwrap();
        Message::single_op("model.layers.0", &tensor, 0, 0)
            .to_writer(&mut writer)
            .await
            .unwrap();
        Message::from_tensor(&tensor)
            .to_writer(&mut writer)
            .await
            .unwrap();
        Message::Goodbye.to_writer(&mut writer).await.unwrap();
        drop(writer);

        let (_, m1) = Message::from_reader(&mut reader).await.unwrap();
        assert!(matches!(m1, Message::Hello));

        let (_, m2) = Message::from_reader(&mut reader).await.unwrap();
        assert!(matches!(m2, Message::SingleOp { .. }));

        let (_, m3) = Message::from_reader(&mut reader).await.unwrap();
        assert!(matches!(m3, Message::Tensor(_)));

        let (_, m4) = Message::from_reader(&mut reader).await.unwrap();
        assert!(matches!(m4, Message::Goodbye));
    }
}