Skip to main content

rserve_rust_client/
lib.rs

1use bytes::{Buf, BufMut, Bytes};
2use tokio;
3use tokio::io::AsyncWriteExt;
4
5pub enum RserveConnection {
6  Tcp(tokio::net::TcpStream),
7  Unix(tokio::net::UnixStream),
8}
9
/// A value decoded from an Rserve response by `RserveConnection::eval`.
///
/// Scalar variants hold a single decoded value; the `*Vec` variants hold the
/// elements of an R vector of the corresponding type.
#[derive(Debug, Clone, PartialEq)]
pub enum ReturnValue {
  /// A single character.
  Char(char),
  /// A single 32-bit integer.
  Int(i32),
  /// A single double-precision float.
  Double(f64),
  /// R's `NULL`; the payload is the text `"NULL"`.
  Null(String),
  /// A single logical value.
  Bool(bool),
  /// A single string.
  Str(String),

  /// An integer vector.
  IntVec(Vec<i32>),
  /// A numeric (double) vector.
  DoubleVec(Vec<f64>),
  /// A logical vector.
  BoolVec(Vec<bool>),
  /// A character (string) vector.
  StrVec(Vec<String>),
}
23
24impl RserveConnection {
25  async fn readable(&mut self) -> std::io::Result<()> {
26    match self {
27      RserveConnection::Tcp(stream) => stream.readable().await?,
28      RserveConnection::Unix(stream) => stream.readable().await?,
29    }
30    Ok(())
31  }
32
33  fn try_read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
34    match self {
35      RserveConnection::Tcp(stream) => stream.try_read(buf),
36      RserveConnection::Unix(stream) => stream.try_read(buf),
37    }
38  }
39
40  async fn writable(&self) -> std::io::Result<()> {
41    match self {
42      RserveConnection::Tcp(stream) => stream.writable().await?,
43      RserveConnection::Unix(stream) => stream.writable().await?,
44    }
45    Ok(())
46  }
47  fn try_write(&self, buf: &[u8]) -> std::io::Result<usize> {
48    match self {
49      RserveConnection::Tcp(stream) => stream.try_write(buf),
50      RserveConnection::Unix(stream) => stream.try_write(buf),
51    }
52  }
53  async fn shut_down(&mut self) -> std::io::Result<()> {
54    match self {
55      RserveConnection::Tcp(stream) => stream.shutdown().await?,
56      RserveConnection::Unix(stream) => stream.shutdown().await?,
57    }
58    Ok(())
59  }
60
61  // Implement other methods of RserveConnection similarly
62  // command: null terminated c-string, for example "1+1\0"
63  // if you got a error code, please run self.eval("geterrmessage()", false).await?
64  // to get R error information
65  pub async fn eval(
66    &mut self,
67    command: &str,
68    void: bool,
69  ) -> Result<ReturnValue, Box<dyn std::error::Error>> {
70    // write
71    self.writable().await?;
72    let cmd = Bytes::from(command.to_string());
73    let cmd_length = cmd.len() as i32;
74
75    let mut message_header = vec![];
76    if void {
77      message_header.put_i32_le(0x002_i32); // CMD_VOID_EVAL
78    } else {
79      message_header.put_i32_le(0x003_i32); // CMD_EVAL
80    }
81    message_header.put_i32_le(cmd_length + 4);
82    message_header.put_i32_le(0_i32);
83    message_header.put_i32_le(0_i32);
84
85    let mut data_header = vec![];
86    data_header.put_u8(0x04_u8);
87    data_header.put_i32_le(cmd_length);
88
89    let mut message = vec![];
90    message.put(&message_header[..]);
91    message.put(&data_header[..4]);
92    message.put(&cmd[..]);
93
94    match self.try_write(&message) {
95      Ok(n) => {
96        assert_eq!(n, message.len());
97      }
98      Err(ref e) if e.kind() == tokio::io::ErrorKind::WouldBlock => {}
99      Err(e) => {
100        return Err(e.into());
101      }
102    };
103
104    // read response
105    // dont support return value length larger than 1024
106    loop {
107      self.readable().await?;
108      let mut data = vec![0_u8; 1024];
109      match self.try_read(&mut data) {
110        Ok(n) => {
111          let mut res_data = &data[..n];
112          // response message header 16 bytes
113          let cmd_res = res_data.get_i32_le(); // 0-3
114          let err_code = (cmd_res >> 24) & 127;
115          let response_code = cmd_res & 0xfffff;
116
117          //error eval, return error info
118          if response_code != (0x10000 | 0x0001) {
119            /*
120              use async_recursion::async_recursion
121              let err_info = self.eval(stream, "geterrmessage()", false).await?;
122            */
123            let err_info = format!("error code: {}", err_code);
124            return Err(Box::new(std::io::Error::new(
125              std::io::ErrorKind::Other,
126              err_info,
127            )));
128          }
129
130          /*
131          ignore message header remain field
132
133          let data_length = res_data.get_i32_le();//4-7
134          let data_offset = res_data.get_i32_le();//8-11, 0
135          let data_header_length2 = res_data.get_i32_le(); //12-15
136          */
137          res_data.advance(12);
138
139          // response message data header 4 bytes
140          let data_type = res_data.get_u8(); //16
141                                             //let raw_data_header_length2 = res_data.take(3);//17-19
142          res_data.advance(3);
143          //let mut dst = vec![];
144          //dst.put(raw_data_header_length2);
145          //dst.put_u8(0_u8);
146          //let data_length2 = (&dst[..]).get_i32_le();
147
148          match data_type {
149            // DT_INT
150            1_u8 => {
151              let a = res_data.get_i32_le();
152              return Ok(ReturnValue::Int(a));
153            }
154            // DT_CHAR
155            2_u8 => {
156              let a = res_data.get_u8() as char;
157              return Ok(ReturnValue::Char(a));
158            }
159            // DT_DOUBLE
160            3_u8 => {
161              let a = res_data.get_f64_le();
162              return Ok(ReturnValue::Double(a));
163            }
164            // DT_STRING 0 terminted string
165            4_u8 => {
166              let a = res_data.chunk().to_vec();
167              let s = String::from_utf8(a).unwrap();
168              return Ok(ReturnValue::Str(s));
169            }
170
171            // DT_SEXP
172            10_u8 => {
173              let expression_type = res_data.get_u8(); // eXpression Type
174              let raw_data_header_length2 = res_data.take(3); // (24-bit int) length
175              let mut dst = vec![];
176              dst.put(raw_data_header_length2);
177              dst.put_u8(0_u8);
178              let data_length2 = (&dst[..]).get_i32_le();
179              // pass header part
180              res_data.advance(3);
181
182              match expression_type {
183                // XT_NULL
184                0_u8 => {
185                  return Ok(ReturnValue::Null("NULL".to_string()));
186                }
187                // XT_INT
188                1_u8 => {
189                  let a = res_data.get_i32_le();
190                  return Ok(ReturnValue::Int(a));
191                }
192                // XT_DOUBLE
193                2_u8 => {
194                  let a = res_data.get_f64_le();
195                  return Ok(ReturnValue::Double(a));
196                }
197                // XT_STR
198                3_u8 => {
199                  let a = res_data.chunk().to_vec();
200                  let s = String::from_utf8(a).unwrap();
201                  return Ok(ReturnValue::Str(s));
202                }
203                // XT_BOOL
204                6_u8 => {
205                  let a = res_data.get_u8();
206                  if a == 1 {
207                    return Ok(ReturnValue::Bool(true));
208                  } else {
209                    return Ok(ReturnValue::Bool(true));
210                  }
211                }
212                // XT_ARRAY_INT
213                32_u8 => {
214                  let mut a: Vec<i32> = vec![];
215                  for _ in 0..data_length2 / 4 {
216                    a.push(res_data.get_i32_le());
217                  }
218                  return Ok(ReturnValue::IntVec(a));
219                }
220                // XT_ARRAY_DOUBLE
221                33_u8 => {
222                  let mut a: Vec<f64> = vec![];
223                  for _ in 0..data_length2 / 8 {
224                    a.push(res_data.get_f64_le());
225                  }
226                  return Ok(ReturnValue::DoubleVec(a));
227                }
228                // XT_ARRAY_STR
229                34_u8 => {
230                  let a: Vec<String> =
231                    String::from_utf8(res_data.take(data_length2 as usize).chunk().to_vec())
232                      .unwrap()
233                      .split("\0")
234                      .map(|word| word.to_string())
235                      .collect();
236                  return Ok(ReturnValue::StrVec(a));
237                }
238                // XT_ARRAY_BOOL
239                36_u8 => {
240                  let mut a: Vec<bool> = vec![];
241                  for _ in 0..data_length2 {
242                    let b = res_data.get_u8();
243                    if b == 1 {
244                      a.push(true);
245                    } else {
246                      a.push(false);
247                    }
248                  }
249                  return Ok(ReturnValue::BoolVec(a));
250                }
251
252                _ => {
253                  return Err(Box::new(std::io::Error::new(
254                    std::io::ErrorKind::Unsupported,
255                    "unsupported outcome type!",
256                  )));
257                }
258              }
259            }
260            // DT_BYTE
261            // DT_ARRAY
262            // DT_CUSTOM
263            // DT_LARGE
264            _ => {
265              return Err(Box::new(std::io::Error::new(
266                std::io::ErrorKind::Unsupported,
267                "unsupported outcome type!",
268              )));
269            }
270          };
271        }
272        Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {}
273        Err(e) => {
274          return Err(e.into());
275        }
276      };
277    }
278  }
279}
280
281pub async fn connect(addr: &str) -> Result<RserveConnection, Box<dyn std::error::Error>> {
282  if addr.starts_with("tcp://") {
283    let addr = addr.trim_start_matches("tcp://");
284    let s = tokio::net::TcpStream::connect(addr).await?;
285    loop {
286      s.readable().await?;
287      let mut data = vec![0_u8; 1024];
288      match s.try_read(&mut data) {
289        Ok(n) => {
290          let string_result = String::from_utf8_lossy(&data[..n]);
291          assert!(string_result.starts_with("Rsrv01"));
292          break;
293        }
294        Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {}
295        Err(e) => {
296          return Err(e.into());
297        }
298      }
299    }
300    Ok(RserveConnection::Tcp(s))
301  } else if addr.starts_with("unix://") {
302    let path = addr.trim_start_matches("unix://");
303    let ss = tokio::net::UnixStream::connect(path).await?;
304    loop {
305      ss.readable().await?;
306      let mut data = vec![0_u8; 1024];
307      match ss.try_read(&mut data) {
308        Ok(n) => {
309          let string_result = String::from_utf8_lossy(&data[..n]);
310          assert!(string_result.starts_with("Rsrv01"));
311          break;
312        }
313        Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {}
314        Err(e) => {
315          return Err(e.into());
316        }
317      }
318    }
319    Ok(RserveConnection::Unix(ss))
320  } else {
321    Err(Box::new(std::io::Error::new(
322      std::io::ErrorKind::InvalidInput,
323      "Invalid address format",
324    )))
325  }
326}