1use std::sync::Arc;
2#[cfg(not(feature = "tokio"))]
3use std::{net::TcpStream, sync::Mutex, thread, time::Duration};
4
5use protobuf::CodedInputStream;
6#[cfg(feature = "tokio")]
7use tokio::{net::TcpStream, sync::Mutex};
8
9use crate::{
10 error::RpcError,
11 schema::{
12 self, connection_request, connection_response::Status,
13 ConnectionRequest, ConnectionResponse, DecodeUntagged, StreamUpdate,
14 },
15 stream::StreamWrangler,
16};
17
18pub struct Client {
46 rpc: Mutex<TcpStream>,
47 stream: Mutex<TcpStream>,
48 streams: StreamWrangler,
49}
50
51impl Client {
52 #[cfg(not(feature = "tokio"))]
61 pub fn new(
62 name: &str,
63 ip_addr: &str,
64 rpc_port: u16,
65 stream_port: u16,
66 ) -> Result<Arc<Self>, RpcError> {
67 let rpc_request = schema::ConnectionRequest {
68 type_: protobuf::EnumOrUnknown::new(connection_request::Type::RPC),
69 client_name: String::from(name),
70 ..Default::default()
71 };
72 let (rpc_stream, rpc_result) = connect(ip_addr, rpc_port, rpc_request)?;
73
74 let stream_request = schema::ConnectionRequest {
75 type_: protobuf::EnumOrUnknown::new(
76 connection_request::Type::STREAM,
77 ),
78 client_name: String::from(name),
79 client_identifier: rpc_result.client_identifier,
80 ..Default::default()
81 };
82 let (stream_stream, _) = connect(ip_addr, stream_port, stream_request)?;
83
84 let client = Arc::new(Self {
85 rpc: Mutex::new(rpc_stream),
86 stream: Mutex::new(stream_stream),
87 streams: StreamWrangler::default(),
88 });
89
90 let bg_client = client.clone();
92 thread::spawn(move || loop {
93 bg_client.update_streams().ok();
94 });
95
96 Ok(client)
97 }
98
99 #[cfg(feature = "tokio")]
108 pub async fn new(
109 name: &str,
110 ip_addr: &str,
111 rpc_port: u16,
112 stream_port: u16,
113 ) -> Result<Arc<Self>, RpcError> {
114 let rpc_request = schema::ConnectionRequest {
115 type_: protobuf::EnumOrUnknown::new(connection_request::Type::RPC),
116 client_name: String::from(name),
117 ..Default::default()
118 };
119 let (rpc_stream, rpc_result) =
120 connect(ip_addr, rpc_port, rpc_request).await?;
121
122 let stream_request = schema::ConnectionRequest {
123 type_: protobuf::EnumOrUnknown::new(
124 connection_request::Type::STREAM,
125 ),
126 client_name: String::from(name),
127 client_identifier: rpc_result.client_identifier,
128 ..Default::default()
129 };
130 let (stream_stream, _) =
131 connect(ip_addr, stream_port, stream_request).await?;
132
133 let client = Arc::new(Self {
134 rpc: Mutex::new(rpc_stream),
135 stream: Mutex::new(stream_stream),
136 streams: StreamWrangler::default(),
137 });
138
139 let bg_client = client.clone();
141 tokio::task::spawn(async move {
142 loop {
143 bg_client.update_streams().await.ok();
144 }
145 });
146
147 Ok(client)
148 }
149
150 #[cfg(not(feature = "tokio"))]
151 pub(crate) fn call(
152 &self,
153 request: schema::Request,
154 ) -> Result<schema::Response, RpcError> {
155 let mut rpc = self.rpc.lock().map_err(|_| RpcError::Client)?;
156
157 send(&mut rpc, request)?;
158 recv(&mut rpc)
159 }
160
161 #[cfg(feature = "tokio")]
162 pub(crate) async fn call(
163 &self,
164 request: schema::Request,
165 ) -> Result<schema::Response, RpcError> {
166 let mut rpc = self.rpc.lock().await;
167
168 send(&mut rpc, request).await?;
169 recv(&mut rpc).await
170 }
171
172 pub(crate) fn proc_call(
173 service: &str,
174 procedure: &str,
175 args: Vec<schema::Argument>,
176 ) -> schema::ProcedureCall {
177 schema::ProcedureCall {
178 service: service.into(),
179 procedure: procedure.into(),
180 arguments: args,
181 ..Default::default()
182 }
183 }
184
185 #[cfg(not(feature = "tokio"))]
186 pub(crate) fn update_streams(self: &Arc<Self>) -> Result<(), RpcError> {
187 let mut stream = self.stream.lock()?;
188 let update = recv::<StreamUpdate>(&mut stream)?;
189 for result in update.results {
190 self.streams.insert(
191 result.id,
192 result.result.into_option().ok_or(RpcError::Client)?,
193 )?;
194 }
195 Ok(())
196 }
197
198 #[cfg(feature = "tokio")]
199 pub(crate) fn register_stream(self: &Arc<Self>, stream_id: u64) -> u32 {
200 self.streams.increment_refcount(stream_id)
201 }
202
203 #[cfg(feature = "tokio")]
204 pub(crate) fn release_stream(self: &Arc<Self>, stream_id: u64) -> u32 {
205 self.streams.decrement_refcount(stream_id)
206 }
207
208 #[cfg(feature = "tokio")]
209 pub(crate) async fn update_streams(
210 self: &Arc<Self>,
211 ) -> Result<(), RpcError> {
212 let mut stream = self.stream.lock().await;
213 let update = recv::<StreamUpdate>(&mut stream).await?;
214 for result in update.results {
215 self.streams
216 .insert(
217 result.id,
218 result.result.into_option().ok_or(RpcError::Client)?,
219 )
220 .await?;
221 }
222 Ok(())
223 }
224
225 #[cfg(not(feature = "tokio"))]
226 pub(crate) fn read_stream<T: DecodeUntagged>(
227 self: &Arc<Self>,
228 id: u64,
229 ) -> Result<T, RpcError> {
230 self.streams.get(self.clone(), id)
231 }
232
233 #[cfg(feature = "tokio")]
234 pub(crate) async fn read_stream<T: DecodeUntagged>(
235 self: &Arc<Self>,
236 id: u64,
237 ) -> Result<T, RpcError> {
238 self.streams.get(self.clone(), id).await
239 }
240
241 #[cfg(not(feature = "tokio"))]
242 pub(crate) fn remove_stream(
243 self: &Arc<Self>,
244 id: u64,
245 ) -> Result<(), RpcError> {
246 self.streams.remove(id);
247 Ok(())
248 }
249
250 #[cfg(feature = "tokio")]
251 pub(crate) async fn remove_stream(
252 self: &Arc<Self>,
253 id: u64,
254 ) -> Result<(), RpcError> {
255 self.streams.remove(id).await;
256 Ok(())
257 }
258
259 #[cfg(not(feature = "tokio"))]
260 pub(crate) fn await_stream(&self, id: u64) {
261 self.streams.wait(id)
262 }
263
264 #[cfg(not(feature = "tokio"))]
265 pub(crate) fn await_stream_timeout(&self, id: u64, dur: Duration) {
266 self.streams.wait_timeout(id, dur)
267 }
268
269 #[cfg(feature = "tokio")]
270 pub(crate) async fn await_stream(&self, id: u64) {
271 self.streams.wait(id).await
272 }
273}
274
275#[cfg(not(feature = "tokio"))]
276fn connect(
277 ip_addr: &str,
278 port: u16,
279 request: ConnectionRequest,
280) -> Result<(TcpStream, ConnectionResponse), RpcError> {
281 let mut conn = TcpStream::connect(format!("{ip_addr}:{port}"))
282 .map_err(RpcError::Connection)?;
283
284 send(&mut conn, request)?;
285 let response = recv::<ConnectionResponse>(&mut conn)?;
286 if response.status.value() != Status::OK as i32 {
287 return Err(RpcError::Client);
288 }
289
290 Ok((conn, response))
291}
292
/// Opens a TCP connection to `ip_addr:port`, sends `request` and reads the
/// server's `ConnectionResponse` (tokio build).
///
/// On success the open socket is returned together with the response.
///
/// # Errors
///
/// * `RpcError::Connection` when the TCP connection cannot be established.
/// * Any error from `send` / `recv` during the handshake.
/// * `RpcError::Client` when the response status is not `Status::OK`.
#[cfg(feature = "tokio")]
async fn connect(
    ip_addr: &str,
    port: u16,
    request: ConnectionRequest,
) -> Result<(TcpStream, ConnectionResponse), RpcError> {
    let address = format!("{ip_addr}:{port}");
    let mut socket = match TcpStream::connect(address).await {
        Ok(socket) => socket,
        Err(cause) => return Err(RpcError::Connection(cause)),
    };

    send(&mut socket, request).await?;
    let reply: ConnectionResponse = recv(&mut socket).await?;

    if reply.status.value() == Status::OK as i32 {
        Ok((socket, reply))
    } else {
        Err(RpcError::Client)
    }
}
311
312#[cfg(not(feature = "tokio"))]
313fn send<T: protobuf::Message>(
314 rpc: &mut TcpStream,
315 message: T,
316) -> Result<(), RpcError> {
317 message
318 .write_length_delimited_to_writer(rpc)
319 .map_err(Into::into)
320}
321
/// Encodes `message` in length-delimited protobuf form and writes all of it
/// to `rpc` (tokio build).
///
/// # Errors
///
/// An encoding failure or a socket write failure, converted into `RpcError`.
#[cfg(feature = "tokio")]
async fn send<T: protobuf::Message>(
    rpc: &mut TcpStream,
    message: T,
) -> Result<(), RpcError> {
    use tokio::io::AsyncWriteExt;

    // Encode the whole frame up front so it can go out in one `write_all`.
    let encoded = match message.write_length_delimited_to_bytes() {
        Ok(bytes) => bytes,
        Err(cause) => return Err(Into::<RpcError>::into(cause)),
    };

    match rpc.write_all(&encoded).await {
        Ok(()) => Ok(()),
        Err(cause) => Err(cause.into()),
    }
}
334
335#[cfg(not(feature = "tokio"))]
336fn recv<T: protobuf::Message + Default>(
337 rpc: &mut TcpStream,
338) -> Result<T, RpcError> {
339 CodedInputStream::new(rpc)
340 .read_message()
341 .map_err(Into::into)
342}
343
/// Reads exactly one length-delimited protobuf message of type `T` from
/// `rpc` (tokio build).
///
/// The frame is a base-128 varint byte count followed by that many bytes of
/// encoded message. The varint is read one byte at a time and the payload is
/// read through a `take(length)` adaptor, so no byte belonging to a later
/// message is ever consumed from the socket.
///
/// # Errors
///
/// * An I/O error (including `UnexpectedEof` when the peer closes the
///   connection before a full frame has arrived), converted into `RpcError`.
/// * A protobuf error when the length prefix or the payload is malformed.
#[cfg(feature = "tokio")]
async fn recv<T: protobuf::Message + Default>(
    rpc: &mut TcpStream,
) -> Result<T, RpcError> {
    use bytes::Bytes;
    use tokio::io::AsyncReadExt;

    // A varint-encoded u64 never occupies more than ten bytes.
    const MAX_VARINT_LEN: usize = 10;

    // Collect the length prefix byte by byte: the last byte of a varint is
    // the first one whose continuation bit (0x80) is clear. `read_u8` fails
    // with `UnexpectedEof` on a closed connection instead of reporting zero
    // bytes, so a dead peer ends the call rather than spinning forever.
    let mut prefix = [0u8; MAX_VARINT_LEN];
    let mut prefix_len = 0;
    loop {
        let byte = rpc.read_u8().await.map_err(Into::<RpcError>::into)?;
        prefix[prefix_len] = byte;
        prefix_len += 1;
        if byte & 0x80 == 0 || prefix_len == MAX_VARINT_LEN {
            break;
        }
    }

    // Let the protobuf decoder validate and decode the collected prefix; an
    // over-long varint is rejected here.
    let length = CodedInputStream::from_bytes(&prefix[..prefix_len])
        .read_raw_varint64()?;

    // `take` caps the read at `length` bytes, and `read_to_end` grows the
    // buffer as data arrives, so a bogus huge prefix cannot force a huge
    // up-front allocation.
    let mut payload = Vec::new();
    (&mut *rpc)
        .take(length)
        .read_to_end(&mut payload)
        .await
        .map_err(Into::<RpcError>::into)?;

    // Fewer bytes than announced means the connection ended mid-frame.
    let complete =
        u64::try_from(payload.len()).map_or(false, |read| read == length);
    if !complete {
        let eof = std::io::Error::from(std::io::ErrorKind::UnexpectedEof);
        return Err(eof.into());
    }

    T::parse_from_tokio_bytes(&Bytes::from(payload)).map_err(Into::into)
}