Skip to main content

cake_core/cake/
client.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use candle_core::{Device, Tensor};
4use tokio::net::TcpStream;
5
6use super::{Context, Message, WorkerInfo};
7
/// A client object used by the master to connect and orchestrate the workers.
/// From the Cake perspective, each worker is a server and the master uses
/// multiple Client instances to connect to them.
///
/// Each instance owns one TCP connection to a single worker, plus the
/// scratch buffers that are handed to the message (de)serialization helpers.
#[derive(Debug)]
pub struct Client {
    /// Device passed to `to_tensor` when a worker replies with a tensor.
    device: Device,
    /// Worker address this client connected to; also used in log lines,
    /// error messages and as the value returned by `ident()`.
    address: String,
    /// Layer name sent with single-op requests and returned by `layer_name()`.
    layer_name: String,
    /// The TCP connection to the worker.
    stream: TcpStream,
    /// Worker information received in reply to the initial `Hello` message
    /// (starts as `WorkerInfo::default()` until that reply arrives).
    info: WorkerInfo,
    /// Buffer handed to `Message::from_reader_buf` on every receive.
    read_buf: Vec<u8>,
    /// Buffer handed to `Message::to_writer_buf` when sending forward requests.
    write_buf: Vec<u8>,
}
21
22impl Client {
23    /// Connects to the given worker address.
24    /// NOTE: device and layer_name here are only passed for std::fmt::Display.
25    /// If `cluster_key` is provided, mutual PSK authentication is performed
26    /// before any protocol messages.
27    pub async fn new(
28        device: Device,
29        address: &str,
30        layer_name: &str,
31        cluster_key: Option<&str>,
32    ) -> Result<Self> {
33        let address = address.to_string();
34        let layer_name = layer_name.to_string();
35        let stream = TcpStream::connect(&address)
36            .await
37            .map_err(|e| anyhow!("can't connect to {address}: {e}"))?;
38        stream.set_nodelay(true)?;
39        let worker_info = WorkerInfo::default();
40
41        let mut client = Self {
42            address,
43            device,
44            stream,
45            layer_name,
46            info: worker_info,
47            read_buf: Vec::new(),
48            write_buf: Vec::new(),
49        };
50
51        // Authenticate if cluster key is set
52        if let Some(key) = cluster_key {
53            super::auth::authenticate_as_master(&mut client.stream, key).await?;
54        }
55
56        let resp = client.request(Message::Hello).await?;
57        client.info = if let Message::WorkerInfo(info) = resp {
58            info
59        } else {
60            return Err(anyhow!("unexpected worker info message: {:?}", &resp));
61        };
62
63        Ok(client)
64    }
65
66    /// Send a Message to the worker and return a response.
67    async fn request(&mut self, req: Message) -> Result<Message> {
68        req.to_writer(&mut self.stream)
69            .await
70            .map_err(|e| anyhow!("error sending message {:?}: {}", req, e))?;
71
72        let (_, msg) = super::Message::from_reader_buf(&mut self.stream, &mut self.read_buf)
73            .await
74            .map_err(|e| anyhow!("error receiving response for {:?}: {}", req, e))?;
75        Ok(msg)
76    }
77
78    async fn forward_request(&mut self, req: Message) -> Result<Tensor> {
79        let send_start = std::time::Instant::now();
80        req.to_writer_buf(&mut self.stream, &mut self.write_buf)
81            .await
82            .map_err(|e| anyhow!("error sending message {:?}: {}", req, e))?;
83        let send_elapsed = send_start.elapsed();
84
85        let recv_start = std::time::Instant::now();
86        let (resp_size, msg) = super::Message::from_reader_buf(&mut self.stream, &mut self.read_buf)
87            .await
88            .map_err(|e| anyhow!("error receiving response for {:?}: {}", req, e))?;
89        let recv_elapsed = recv_start.elapsed();
90
91        log::debug!(
92            "    {} send={:.1}ms recv={:.1}ms ({})",
93            &self.address,
94            send_elapsed.as_secs_f64() * 1000.0,
95            recv_elapsed.as_secs_f64() * 1000.0,
96            human_bytes::human_bytes(resp_size as f64),
97        );
98
99        match msg {
100            Message::Tensor(raw) => Ok(raw.to_tensor(&self.device)?),
101            Message::WorkerError { message } => Err(anyhow!(
102                "worker {} reported error: {}",
103                &self.address,
104                message
105            )),
106            _ => Err(anyhow!("unexpected response {:?}", &msg)),
107        }
108    }
109}
110
111impl std::fmt::Display for Client {
112    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113        write!(
114            f,
115            "{}@{} [{}<{}> {}-{} latency={}ms]",
116            &self.layer_name,
117            &self.address,
118            &self.info.device,
119            &self.info.device_idx,
120            &self.info.os,
121            &self.info.arch,
122            self.info.latency
123        )
124    }
125}
126
127#[async_trait]
128impl super::Forwarder for Client {
129    fn load(_: String, _: &Context) -> Result<Box<Self>> {
130        Err(anyhow!("load should never be called on cake::Client"))
131    }
132
133    async fn forward(&self, _: &Tensor, _: usize, _: usize, _: &mut Context) -> Result<Tensor> {
134        Err(anyhow!(
135            "immutable forward should never be called on cake::Client"
136        ))
137    }
138
139    /// Executes the worker's pipeline for this tensor.
140    async fn forward_mut(
141        &mut self,
142        x: &Tensor,
143        index_pos: usize,
144        block_idx: usize,
145        _: &mut Context,
146    ) -> Result<Tensor> {
147        log::debug!("forwarding single op");
148        self.forward_request(super::Message::single_op(
149            &self.layer_name,
150            x,
151            index_pos,
152            block_idx,
153        ))
154        .await
155    }
156
157    /// Executes the worker's pipeline with multiple batched steps for this tensor.
158    async fn forward_batch(
159        &mut self,
160        x: &Tensor,
161        batch: Vec<(String, usize, usize)>,
162        _: &mut Context,
163    ) -> Result<Tensor> {
164        log::debug!("forwarding batch of {} elements", batch.len());
165        self.forward_request(super::Message::from_batch(x, batch))
166            .await
167    }
168
169    async fn goodbye(&mut self) -> Result<()> {
170        self.request(Message::Goodbye).await?;
171        Ok(())
172    }
173
174    fn layer_name(&self) -> &str {
175        &self.layer_name
176    }
177
178    fn ident(&self) -> &str {
179        &self.address
180    }
181}