1use anyhow::Result;
2use async_trait::async_trait;
3use candle_core::{Device, Tensor};
4use tokio::net::TcpStream;
5
6use super::{Context, Message, WorkerInfo};
7
/// A [`super::Forwarder`] that runs layer `layer_name` on a remote worker.
///
/// Forward calls are serialized and sent over a TCP connection to the worker. The tensor
/// the worker sends back is deserialized onto `device`.
#[derive(Debug)]
pub struct Client {
    /// Device on which tensors received from the worker are materialized.
    device: Device,
    /// Worker address passed to `TcpStream::connect`; also identifies the worker in logs,
    /// error messages and `ident()`.
    address: String,
    /// Name of the layer this client forwards; sent along with single-op requests.
    layer_name: String,
    /// Connection to the worker (`TCP_NODELAY` is enabled when the client is created).
    stream: TcpStream,
    /// Worker metadata, filled from the worker's reply to the `Hello` handshake.
    info: WorkerInfo,
    /// Buffer owned by the client and passed to `Message::from_reader_buf` when decoding responses.
    read_buf: Vec<u8>,
    /// Buffer owned by the client and passed to `to_writer_buf` when encoding forward requests.
    write_buf: Vec<u8>,
}
21
22impl Client {
23 pub async fn new(
28 device: Device,
29 address: &str,
30 layer_name: &str,
31 cluster_key: Option<&str>,
32 ) -> Result<Self> {
33 let address = address.to_string();
34 let layer_name = layer_name.to_string();
35 let stream = TcpStream::connect(&address)
36 .await
37 .map_err(|e| anyhow!("can't connect to {address}: {e}"))?;
38 stream.set_nodelay(true)?;
39 let worker_info = WorkerInfo::default();
40
41 let mut client = Self {
42 address,
43 device,
44 stream,
45 layer_name,
46 info: worker_info,
47 read_buf: Vec::new(),
48 write_buf: Vec::new(),
49 };
50
51 if let Some(key) = cluster_key {
53 super::auth::authenticate_as_master(&mut client.stream, key).await?;
54 }
55
56 let resp = client.request(Message::Hello).await?;
57 client.info = if let Message::WorkerInfo(info) = resp {
58 info
59 } else {
60 return Err(anyhow!("unexpected worker info message: {:?}", &resp));
61 };
62
63 Ok(client)
64 }
65
66 async fn request(&mut self, req: Message) -> Result<Message> {
68 req.to_writer(&mut self.stream)
69 .await
70 .map_err(|e| anyhow!("error sending message {:?}: {}", req, e))?;
71
72 let (_, msg) = super::Message::from_reader_buf(&mut self.stream, &mut self.read_buf)
73 .await
74 .map_err(|e| anyhow!("error receiving response for {:?}: {}", req, e))?;
75 Ok(msg)
76 }
77
78 async fn forward_request(&mut self, req: Message) -> Result<Tensor> {
79 let send_start = std::time::Instant::now();
80 req.to_writer_buf(&mut self.stream, &mut self.write_buf)
81 .await
82 .map_err(|e| anyhow!("error sending message {:?}: {}", req, e))?;
83 let send_elapsed = send_start.elapsed();
84
85 let recv_start = std::time::Instant::now();
86 let (resp_size, msg) = super::Message::from_reader_buf(&mut self.stream, &mut self.read_buf)
87 .await
88 .map_err(|e| anyhow!("error receiving response for {:?}: {}", req, e))?;
89 let recv_elapsed = recv_start.elapsed();
90
91 log::debug!(
92 " {} send={:.1}ms recv={:.1}ms ({})",
93 &self.address,
94 send_elapsed.as_secs_f64() * 1000.0,
95 recv_elapsed.as_secs_f64() * 1000.0,
96 human_bytes::human_bytes(resp_size as f64),
97 );
98
99 match msg {
100 Message::Tensor(raw) => Ok(raw.to_tensor(&self.device)?),
101 Message::WorkerError { message } => Err(anyhow!(
102 "worker {} reported error: {}",
103 &self.address,
104 message
105 )),
106 _ => Err(anyhow!("unexpected response {:?}", &msg)),
107 }
108 }
109}
110
111impl std::fmt::Display for Client {
112 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113 write!(
114 f,
115 "{}@{} [{}<{}> {}-{} latency={}ms]",
116 &self.layer_name,
117 &self.address,
118 &self.info.device,
119 &self.info.device_idx,
120 &self.info.os,
121 &self.info.arch,
122 self.info.latency
123 )
124 }
125}
126
127#[async_trait]
128impl super::Forwarder for Client {
129 fn load(_: String, _: &Context) -> Result<Box<Self>> {
130 Err(anyhow!("load should never be called on cake::Client"))
131 }
132
133 async fn forward(&self, _: &Tensor, _: usize, _: usize, _: &mut Context) -> Result<Tensor> {
134 Err(anyhow!(
135 "immutable forward should never be called on cake::Client"
136 ))
137 }
138
139 async fn forward_mut(
141 &mut self,
142 x: &Tensor,
143 index_pos: usize,
144 block_idx: usize,
145 _: &mut Context,
146 ) -> Result<Tensor> {
147 log::debug!("forwarding single op");
148 self.forward_request(super::Message::single_op(
149 &self.layer_name,
150 x,
151 index_pos,
152 block_idx,
153 ))
154 .await
155 }
156
157 async fn forward_batch(
159 &mut self,
160 x: &Tensor,
161 batch: Vec<(String, usize, usize)>,
162 _: &mut Context,
163 ) -> Result<Tensor> {
164 log::debug!("forwarding batch of {} elements", batch.len());
165 self.forward_request(super::Message::from_batch(x, batch))
166 .await
167 }
168
169 async fn goodbye(&mut self) -> Result<()> {
170 self.request(Message::Goodbye).await?;
171 Ok(())
172 }
173
174 fn layer_name(&self) -> &str {
175 &self.layer_name
176 }
177
178 fn ident(&self) -> &str {
179 &self.address
180 }
181}