krpc_client/
client.rs

1use std::sync::Arc;
2#[cfg(not(feature = "tokio"))]
3use std::{net::TcpStream, sync::Mutex, thread, time::Duration};
4
5use protobuf::CodedInputStream;
6#[cfg(feature = "tokio")]
7use tokio::{net::TcpStream, sync::Mutex};
8
9use crate::{
10    error::RpcError,
11    schema::{
12        self, connection_request, connection_response::Status,
13        ConnectionRequest, ConnectionResponse, DecodeUntagged, StreamUpdate,
14    },
15    stream::StreamWrangler,
16};
17
18/// The base kRPC client type.
19///
20/// ## Connecting to the kRPC server
21///
22/// Call [`new`][new] to establish a connection with the
23/// kRPC server.
24///
25/// ```
26/// use krpc_client::Client;
27/// let client = Client::new("Test KRPC", "127.0.0.1", 50000, 50001);
28/// ```
29///
30/// ## Using RPC services
31///
32/// Pass or clone the client instance returned by
33/// [`Client::new`][new] to any RPC service in
34/// [`krpc_client::services::*`][services].
35///
36/// ```
37/// use krpc_client::{services::space_center::SpaceCenter, Client};
38/// let space_center = SpaceCenter::new(client);
39/// // Then call procedures with the created service.
40/// println!("Hello, {}!", space_center.get_active_vessel()?.get_name()?);
41/// ```
42///
43/// [new]: Client::new
44/// [services]: crate::services
45pub struct Client {
46    rpc: Mutex<TcpStream>,
47    stream: Mutex<TcpStream>,
48    streams: StreamWrangler,
49}
50
51impl Client {
52    /// Constructs a new `Client`.
53    ///
54    /// # Examples
55    ///
56    /// ```
57    /// use krpc_client::Client;
58    /// let client = Client::new("Test KRPC", "127.0.0.1", 50000, 50001);
59    /// ```
60    #[cfg(not(feature = "tokio"))]
61    pub fn new(
62        name: &str,
63        ip_addr: &str,
64        rpc_port: u16,
65        stream_port: u16,
66    ) -> Result<Arc<Self>, RpcError> {
67        let rpc_request = schema::ConnectionRequest {
68            type_: protobuf::EnumOrUnknown::new(connection_request::Type::RPC),
69            client_name: String::from(name),
70            ..Default::default()
71        };
72        let (rpc_stream, rpc_result) = connect(ip_addr, rpc_port, rpc_request)?;
73
74        let stream_request = schema::ConnectionRequest {
75            type_: protobuf::EnumOrUnknown::new(
76                connection_request::Type::STREAM,
77            ),
78            client_name: String::from(name),
79            client_identifier: rpc_result.client_identifier,
80            ..Default::default()
81        };
82        let (stream_stream, _) = connect(ip_addr, stream_port, stream_request)?;
83
84        let client = Arc::new(Self {
85            rpc: Mutex::new(rpc_stream),
86            stream: Mutex::new(stream_stream),
87            streams: StreamWrangler::default(),
88        });
89
90        // Spawn a thread to receive stream updates.
91        let bg_client = client.clone();
92        thread::spawn(move || loop {
93            bg_client.update_streams().ok();
94        });
95
96        Ok(client)
97    }
98
99    /// Constructs a new `Client`.
100    ///
101    /// # Examples
102    ///
103    /// ```
104    /// use krpc_client::Client;
105    /// let client = Client::new("Test KRPC", "127.0.0.1", 50000, 50001);
106    /// ```
107    #[cfg(feature = "tokio")]
108    pub async fn new(
109        name: &str,
110        ip_addr: &str,
111        rpc_port: u16,
112        stream_port: u16,
113    ) -> Result<Arc<Self>, RpcError> {
114        let rpc_request = schema::ConnectionRequest {
115            type_: protobuf::EnumOrUnknown::new(connection_request::Type::RPC),
116            client_name: String::from(name),
117            ..Default::default()
118        };
119        let (rpc_stream, rpc_result) =
120            connect(ip_addr, rpc_port, rpc_request).await?;
121
122        let stream_request = schema::ConnectionRequest {
123            type_: protobuf::EnumOrUnknown::new(
124                connection_request::Type::STREAM,
125            ),
126            client_name: String::from(name),
127            client_identifier: rpc_result.client_identifier,
128            ..Default::default()
129        };
130        let (stream_stream, _) =
131            connect(ip_addr, stream_port, stream_request).await?;
132
133        let client = Arc::new(Self {
134            rpc: Mutex::new(rpc_stream),
135            stream: Mutex::new(stream_stream),
136            streams: StreamWrangler::default(),
137        });
138
139        // Spawn a thread to receive stream updates.
140        let bg_client = client.clone();
141        tokio::task::spawn(async move {
142            loop {
143                bg_client.update_streams().await.ok();
144            }
145        });
146
147        Ok(client)
148    }
149
150    #[cfg(not(feature = "tokio"))]
151    pub(crate) fn call(
152        &self,
153        request: schema::Request,
154    ) -> Result<schema::Response, RpcError> {
155        let mut rpc = self.rpc.lock().map_err(|_| RpcError::Client)?;
156
157        send(&mut rpc, request)?;
158        recv(&mut rpc)
159    }
160
161    #[cfg(feature = "tokio")]
162    pub(crate) async fn call(
163        &self,
164        request: schema::Request,
165    ) -> Result<schema::Response, RpcError> {
166        let mut rpc = self.rpc.lock().await;
167
168        send(&mut rpc, request).await?;
169        recv(&mut rpc).await
170    }
171
172    pub(crate) fn proc_call(
173        service: &str,
174        procedure: &str,
175        args: Vec<schema::Argument>,
176    ) -> schema::ProcedureCall {
177        schema::ProcedureCall {
178            service: service.into(),
179            procedure: procedure.into(),
180            arguments: args,
181            ..Default::default()
182        }
183    }
184
185    #[cfg(not(feature = "tokio"))]
186    pub(crate) fn update_streams(self: &Arc<Self>) -> Result<(), RpcError> {
187        let mut stream = self.stream.lock()?;
188        let update = recv::<StreamUpdate>(&mut stream)?;
189        for result in update.results {
190            self.streams.insert(
191                result.id,
192                result.result.into_option().ok_or(RpcError::Client)?,
193            )?;
194        }
195        Ok(())
196    }
197
198    #[cfg(feature = "tokio")]
199    pub(crate) fn register_stream(self: &Arc<Self>, stream_id: u64) -> u32 {
200        self.streams.increment_refcount(stream_id)
201    }
202
203    #[cfg(feature = "tokio")]
204    pub(crate) fn release_stream(self: &Arc<Self>, stream_id: u64) -> u32 {
205       self.streams.decrement_refcount(stream_id)
206    }
207
208    #[cfg(feature = "tokio")]
209    pub(crate) async fn update_streams(
210        self: &Arc<Self>,
211    ) -> Result<(), RpcError> {
212        let mut stream = self.stream.lock().await;
213        let update = recv::<StreamUpdate>(&mut stream).await?;
214        for result in update.results {
215            self.streams
216                .insert(
217                    result.id,
218                    result.result.into_option().ok_or(RpcError::Client)?,
219                )
220                .await?;
221        }
222        Ok(())
223    }
224
225    #[cfg(not(feature = "tokio"))]
226    pub(crate) fn read_stream<T: DecodeUntagged>(
227        self: &Arc<Self>,
228        id: u64,
229    ) -> Result<T, RpcError> {
230        self.streams.get(self.clone(), id)
231    }
232
233    #[cfg(feature = "tokio")]
234    pub(crate) async fn read_stream<T: DecodeUntagged>(
235        self: &Arc<Self>,
236        id: u64,
237    ) -> Result<T, RpcError> {
238        self.streams.get(self.clone(), id).await
239    }
240
241    #[cfg(not(feature = "tokio"))]
242    pub(crate) fn remove_stream(
243        self: &Arc<Self>,
244        id: u64,
245    ) -> Result<(), RpcError> {
246        self.streams.remove(id);
247        Ok(())
248    }
249
250    #[cfg(feature = "tokio")]
251    pub(crate) async fn remove_stream(
252        self: &Arc<Self>,
253        id: u64,
254    ) -> Result<(), RpcError> {
255        self.streams.remove(id).await;
256        Ok(())
257    }
258
259    #[cfg(not(feature = "tokio"))]
260    pub(crate) fn await_stream(&self, id: u64) {
261        self.streams.wait(id)
262    }
263
264    #[cfg(not(feature = "tokio"))]
265    pub(crate) fn await_stream_timeout(&self, id: u64, dur: Duration) {
266        self.streams.wait_timeout(id, dur)
267    }
268
269    #[cfg(feature = "tokio")]
270    pub(crate) async fn await_stream(&self, id: u64) {
271        self.streams.wait(id).await
272    }
273}
274
275#[cfg(not(feature = "tokio"))]
276fn connect(
277    ip_addr: &str,
278    port: u16,
279    request: ConnectionRequest,
280) -> Result<(TcpStream, ConnectionResponse), RpcError> {
281    let mut conn = TcpStream::connect(format!("{ip_addr}:{port}"))
282        .map_err(RpcError::Connection)?;
283
284    send(&mut conn, request)?;
285    let response = recv::<ConnectionResponse>(&mut conn)?;
286    if response.status.value() != Status::OK as i32 {
287        return Err(RpcError::Client);
288    }
289
290    Ok((conn, response))
291}
292
/// Opens a TCP connection to `ip_addr:port` and performs the kRPC handshake
/// (async variant).
///
/// Sends `request` over the new socket and reads the server's
/// [`ConnectionResponse`]. On success, returns the open socket together with
/// that response.
///
/// # Errors
///
/// Returns [`RpcError::Connection`] if the socket cannot be opened, any error
/// raised while exchanging the handshake messages, or [`RpcError::Client`]
/// if the server answers with a status other than `OK`.
#[cfg(feature = "tokio")]
async fn connect(
    ip_addr: &str,
    port: u16,
    request: ConnectionRequest,
) -> Result<(TcpStream, ConnectionResponse), RpcError> {
    let address = format!("{ip_addr}:{port}");
    let opened = TcpStream::connect(address).await;
    let mut socket = opened.map_err(RpcError::Connection)?;

    send(&mut socket, request).await?;
    let handshake = recv::<ConnectionResponse>(&mut socket).await?;

    let accepted = handshake.status.value() == Status::OK as i32;
    if accepted {
        Ok((socket, handshake))
    } else {
        Err(RpcError::Client)
    }
}
311
312#[cfg(not(feature = "tokio"))]
313fn send<T: protobuf::Message>(
314    rpc: &mut TcpStream,
315    message: T,
316) -> Result<(), RpcError> {
317    message
318        .write_length_delimited_to_writer(rpc)
319        .map_err(Into::into)
320}
321
/// Writes `message` to `rpc` as one length-delimited protobuf frame (async
/// variant).
///
/// The length prefix and body are serialised into a single buffer first, so
/// the whole frame goes out through one `write_all`.
///
/// # Errors
///
/// Returns the converted error if serialising or writing fails.
#[cfg(feature = "tokio")]
async fn send<T: protobuf::Message>(
    rpc: &mut TcpStream,
    message: T,
) -> Result<(), RpcError> {
    use tokio::io::AsyncWriteExt;

    let frame = message.write_length_delimited_to_bytes()?;
    let written = rpc.write_all(&frame).await;
    written.map_err(Into::into)
}
334
335#[cfg(not(feature = "tokio"))]
336fn recv<T: protobuf::Message + Default>(
337    rpc: &mut TcpStream,
338) -> Result<T, RpcError> {
339    CodedInputStream::new(rpc)
340        .read_message()
341        .map_err(Into::into)
342}
343
/// Reads one length-delimited protobuf message of type `T` from `rpc` (async
/// variant).
///
/// The varint length prefix is read one byte at a time and the body with an
/// exact-length read. This design:
/// * fails cleanly on EOF instead of looping on `Ok(0)` reads;
/// * handles a length prefix that arrives split across TCP reads;
/// * never consumes (and loses) bytes belonging to the next message, which
///   matters on the stream connection where updates can arrive back to back.
///
/// # Errors
///
/// Returns the converted I/O error if reading from the socket fails,
/// including EOF before a whole message has arrived. Returns the converted
/// protobuf error if the length prefix or the message body is malformed.
#[cfg(feature = "tokio")]
async fn recv<T: protobuf::Message + Default>(
    rpc: &mut TcpStream,
) -> Result<T, RpcError> {
    use tokio::io::AsyncReadExt;

    // A varint is at most 10 bytes long and its final byte has the high bit
    // clear; stop there so nothing past the prefix is read.
    let mut header = [0u8; 10];
    let mut header_len = 0;
    loop {
        let byte = rpc.read_u8().await.map_err(Into::<RpcError>::into)?;
        header[header_len] = byte;
        header_len += 1;
        if byte & 0x80 == 0 || header_len == header.len() {
            break;
        }
    }
    // Let protobuf decode and validate the prefix; an unterminated 10-byte
    // prefix is rejected here.
    let length = CodedInputStream::from_bytes(&header[..header_len])
        .read_raw_varint64()?;

    // `take` caps the read at exactly `length` bytes; `read_to_end` grows the
    // buffer as data arrives instead of trusting `length` for one allocation.
    let mut body = Vec::new();
    (&mut *rpc)
        .take(length)
        .read_to_end(&mut body)
        .await
        .map_err(Into::<RpcError>::into)?;
    if body.len() as u64 != length {
        return Err(Into::<RpcError>::into(std::io::Error::from(
            std::io::ErrorKind::UnexpectedEof,
        )));
    }

    T::parse_from_bytes(&body).map_err(Into::into)
}