kdb_plus_fixed/ipc/connection.rs
1//++++++++++++++++++++++++++++++++++++++++++++++++++//
2// >> Load Libraries
3//++++++++++++++++++++++++++++++++++++++++++++++++++//
4
5use super::serialize::ENCODING;
6use super::Result;
7use super::{qtype, K};
8use async_trait::async_trait;
9use hickory_resolver::TokioAsyncResolver;
10use io::BufRead;
11use once_cell::sync::Lazy;
12use sha1_smol::Sha1;
13use std::collections::HashMap;
14use std::convert::TryInto;
15use std::net::{IpAddr, Ipv4Addr};
16use std::path::Path;
17use std::{env, fs, io, str};
18use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
19use tokio::net::{TcpListener, TcpStream};
20#[cfg(unix)]
21use tokio::net::{UnixListener, UnixStream};
22use tokio_native_tls::native_tls::{
23 Identity, TlsAcceptor as TlsAcceptorInner, TlsConnector as TlsConnectorInner,
24};
25use tokio_native_tls::{TlsAcceptor, TlsConnector, TlsStream};
26
27//++++++++++++++++++++++++++++++++++++++++++++++++++//
28// >> Global Variable
29//++++++++++++++++++++++++++++++++++++++++++++++++++//
30
31//%% QStream %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
32
pub mod qmsg_type {
    //! This module provides the list of q message types used for IPC.
    //! The motivation to contain them in a module is to tie them up as related items rather
    //! than scattered values. Hence users should use these indicators with the `qmsg_type::` prefix,
    //! e.g., `qmsg_type::asynchronous`.
    //!
    //! # Example
    //! ```no_run
    //! use kdbplus::ipc::*;
    //!
    //! // Print `K` object.
    //! fn print(obj: &K) {
    //!     println!("{}", obj);
    //! }
    //!
    //! // Calculate something from two long arguments.
    //! fn nonsense(arg1: i64, arg2: i64) -> i64 {
    //!     arg1 * arg2
    //! }
    //!
    //! #[tokio::main]
    //! async fn main() -> Result<()> {
    //!     // Connect to qprocess running on localhost:5000 via TCP
    //!     let mut socket =
    //!         QStream::connect(ConnectionMethod::TCP, "localhost", 5000_u16, "ideal:person").await?;
    //!
    //!     // Set a function which sends back a non-response message during its execution.
    //!     socket
    //!         .send_async_message(
    //!             &"complex:{neg[.z.w](`print; \"counter\"); what: .z.w (`nonsense; 1; 2); what*100}",
    //!         )
    //!         .await?;
    //!
    //!     // Send a query `(`complex; ::)` without waiting for a response.
    //!     socket
    //!         .send_message(
    //!             &K::new_compound_list(vec![K::new_symbol(String::from("complex")), K::new_null()]),
    //!             qmsg_type::synchronous,
    //!         )
    //!         .await?;
    //!
    //!     // Receive an asynchronous call from the function.
    //!     match socket.receive_message().await {
    //!         Ok((qmsg_type::asynchronous, message)) => {
    //!             println!("asynchronous call: {}", message);
    //!             let list = message.as_vec::<K>().unwrap();
    //!             if list[0].get_symbol().unwrap() == "print" {
    //!                 print(&list[1])
    //!             }
    //!         }
    //!         _ => unreachable!(),
    //!     }
    //!
    //!     // Receive a synchronous call from the function.
    //!     match socket.receive_message().await {
    //!         Ok((qmsg_type::synchronous, message)) => {
    //!             println!("synchronous call: {}", message);
    //!             let list = message.as_vec::<K>().unwrap();
    //!             if list[0].get_symbol().unwrap() == "nonsense" {
    //!                 let res = nonsense(list[1].get_long().unwrap(), list[2].get_long().unwrap());
    //!                 // Send back a response.
    //!                 socket
    //!                     .send_message(&K::new_long(res), qmsg_type::response)
    //!                     .await?;
    //!             }
    //!         }
    //!         _ => unreachable!(),
    //!     }
    //!
    //!     // Receive a final result.
    //!     match socket.receive_message().await {
    //!         Ok((qmsg_type::response, message)) => {
    //!             println!("final: {}", message);
    //!         }
    //!         _ => unreachable!(),
    //!     }
    //!
    //!     Ok(())
    //! }
    //! ```
    /// Used to send a message to q/kdb+ asynchronously.
    pub const asynchronous: u8 = 0;
    /// Used to send a message to q/kdb+ synchronously.
    pub const synchronous: u8 = 1;
    /// Used by q/kdb+ to identify a response for a synchronous query.
    pub const response: u8 = 2;
}
119
120//%% QStream Acceptor %%//vvvvvvvvvvvvvvvvvvvvvvvvvvv/
121
122/// Map from user name to password hashed with SHA1.
123#[allow(clippy::declare_interior_mutable_const)]
124const ACCOUNTS: Lazy<HashMap<String, String>> = Lazy::new(|| {
125 // Map from user to password
126 let mut map: HashMap<String, String> = HashMap::new();
127 // Open credential file
128 let file = fs::OpenOptions::new()
129 .read(true)
130 .open(env::var("KDBPLUS_ACCOUNT_FILE").expect("KDBPLUS_ACCOUNT_FILE is not set"))
131 .expect("failed to open account file");
132 let mut reader = io::BufReader::new(file);
133 let mut line = String::new();
134 loop {
135 match reader.read_line(&mut line) {
136 Ok(0) => break,
137 Ok(_) => {
138 let credential = line.as_str().split(':').collect::<Vec<&str>>();
139 let mut password = credential[1];
140 if password.ends_with('\n') {
141 password = &password[0..password.len() - 1];
142 }
143 map.insert(credential[0].to_string(), password.to_string());
144 line.clear();
145 }
146 _ => unreachable!(),
147 }
148 }
149 map
150});
151
152//++++++++++++++++++++++++++++++++++++++++++++++++++//
153// >> Structs
154//++++++++++++++++++++++++++++++++++++++++++++++++++//
155
156//%% ConnectionMethod %%//vvvvvvvvvvvvvvvvvvvvvvvvvvv/
157
/// Connection method to q/kdb+.
pub enum ConnectionMethod {
    /// Plain TCP connection.
    TCP = 0,
    /// TCP connection secured with TLS.
    TLS = 1,
    /// Unix domain socket.
    UDS = 2,
}
165
166//%% Query %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
167
/// Feature of query object.
#[async_trait]
pub trait Query: Send + Sync {
    /// Serialize into q IPC bytes including a header (encoding, message type, compression flag and total message length).
    /// If the connection is within the same host, the message is not compressed under any conditions.
    /// # Parameters
    /// - `message_type`: Message type. One of the following:
    ///   - `qmsg_type::asynchronous`
    ///   - `qmsg_type::synchronous`
    ///   - `qmsg_type::response`
    /// - `is_local`: Flag of whether the connection is within the same host.
    async fn serialize(&self, message_type: u8, is_local: bool) -> Vec<u8>;
}
181
182//%% QStreamInner %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
183
/// Features which streams communicating with q must have.
#[async_trait]
trait QStreamInner: Send + Sync {
    /// Shutdown underlying stream.
    /// # Parameters
    /// - `is_server`: Flag of whether this side of the connection is the acceptor (`QStream::shutdown`
    ///   passes its `listener` flag here).
    async fn shutdown(&mut self, is_server: bool) -> Result<()>;
    /// Send a message with a specified message type without waiting for a response.
    /// # Parameters
    /// - `message`: q command to execute on the remote q process.
    /// - `message_type`: Asynchronous or synchronous.
    /// - `is_local`: Flag of whether the connection is within the same host.
    async fn send_message(
        &mut self,
        message: &dyn Query,
        message_type: u8,
        is_local: bool,
    ) -> Result<()>;
    /// Send a message asynchronously.
    /// # Parameters
    /// - `message`: q command in two ways:
    ///   - `&str`: q command in a string form.
    ///   - `K`: Query in a functional form.
    /// - `is_local`: Flag of whether the connection is within the same host.
    async fn send_async_message(&mut self, message: &dyn Query, is_local: bool) -> Result<()>;
    /// Send a message synchronously and return the response.
    /// # Parameters
    /// - `message`: q command in two ways:
    ///   - `&str`: q command in a string form.
    ///   - `K`: Query in a functional form.
    /// - `is_local`: Flag of whether the connection is within the same host.
    async fn send_sync_message(&mut self, message: &dyn Query, is_local: bool) -> Result<K>;
    /// Receive a message from a remote q process. The received message is parsed as `K` and message type is
    /// stored in the first returned value.
    async fn receive_message(&mut self) -> Result<(u8, K)>;
}
218
219//%% QStream %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
220
/// Stream to communicate with q/kdb+.
pub struct QStream {
    /// Actual stream to communicate.
    stream: Box<dyn QStreamInner>,
    /// Connection method. One of the following:
    /// - TCP
    /// - TLS
    /// - UDS
    method: ConnectionMethod,
    /// Indicator of whether the stream is an acceptor or client.
    /// - `true`: Acceptor
    /// - `false`: Client
    listener: bool,
    /// Indicator of whether the connection is within the same host.
    /// - `true`: Connection within the same host.
    /// - `false`: Connection with outside.
    local: bool,
}
239
240//%% MessageHeader %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
241
/// Header of q IPC data frame.
#[derive(Clone, Copy, Debug)]
struct MessageHeader {
    /// Encoding.
    /// - 0: Big Endian
    /// - 1: Little Endian
    encoding: u8,
    /// Message type. One of the following:
    /// - 0: Asynchronous
    /// - 1: Synchronous
    /// - 2: Response
    message_type: u8,
    /// Indicator of whether the message is compressed or not.
    /// - 0: Uncompressed
    /// - 1: Compressed
    compressed: u8,
    /// Reserved byte.
    _unused: u8,
    /// Total length of the uncompressed message.
    length: u32,
}
263
264//++++++++++++++++++++++++++++++++++++++++++++++++++//
265// >> Implementation
266//++++++++++++++++++++++++++++++++++++++++++++++++++//
267
268//%% Query %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
269
270/// Text query.
271#[async_trait]
272impl Query for &str {
273 async fn serialize(&self, message_type: u8, _: bool) -> Vec<u8> {
274 // Build header //--------------------------------/
275 // Message header + (type indicator of string + header of string type) + string length
276 let byte_message = self.as_bytes();
277 let message_length = byte_message.len() as u32;
278 let total_length = MessageHeader::size() as u32 + 6 + message_length;
279
280 let total_length_bytes = match ENCODING {
281 0 => total_length.to_be_bytes(),
282 _ => total_length.to_le_bytes(),
283 };
284
285 // encode, message type, 0x00 for compression and 0x00 for reserved.
286 // Do not compress string data because it is highly unlikely that the length of the string query
287 // is greater than 2000.
288 let mut message = Vec::with_capacity(message_length as usize + MessageHeader::size());
289 message.extend_from_slice(&[ENCODING, message_type, 0, 0]);
290 // total body length
291 message.extend_from_slice(&total_length_bytes);
292 // vector type and 0x00 for attribute
293 message.extend_from_slice(&[qtype::STRING as u8, 0]);
294
295 // Build body //---------------------------------/
296 let length_info = match ENCODING {
297 0 => message_length.to_be_bytes(),
298 _ => message_length.to_le_bytes(),
299 };
300
301 // length of vector(message)
302 message.extend_from_slice(&length_info);
303 // message
304 message.extend_from_slice(byte_message);
305
306 message
307 }
308}
309
310/// Functional query.
311#[async_trait]
312impl Query for K {
313 async fn serialize(&self, message_type: u8, is_local: bool) -> Vec<u8> {
314 // Build header //--------------------------------/
315 // Message header + encoded data size
316 let mut byte_message = self.q_ipc_encode();
317 let message_length = byte_message.len();
318 let total_length = (MessageHeader::size() + message_length) as u32;
319
320 let total_length_bytes = match ENCODING {
321 0 => total_length.to_be_bytes(),
322 _ => total_length.to_le_bytes(),
323 };
324
325 // Compression is trigerred when entire message size is more than 2000 bytes
326 // and the connection is with outseide.
327 if message_length > 1992 && !is_local {
328 // encode, message type, 0x00 for compression, 0x00 for reserved and 0x00000000 for total size
329 let mut message = Vec::with_capacity(message_length + 8);
330 message.extend_from_slice(&[ENCODING, message_type, 0, 0, 0, 0, 0, 0]);
331 message.append(&mut byte_message);
332 // Try to encode entire message.
333 match compress(message).await {
334 (true, compressed) => {
335 // Message was compressed
336 return compressed;
337 }
338 (false, mut uncompressed) => {
339 // Message was not compressed.
340 // Write original total data size.
341 uncompressed[4..8].copy_from_slice(&total_length_bytes);
342 return uncompressed;
343 }
344 }
345 } else {
346 // encode, message type, 0x00 for compression and 0x00 for reserved
347 let mut message = Vec::with_capacity(message_length + MessageHeader::size());
348 message.extend_from_slice(&[ENCODING, message_type, 0, 0]);
349 // Total length of body
350 message.extend_from_slice(&total_length_bytes);
351 message.append(&mut byte_message);
352 return message;
353 }
354 }
355}
356
357//%% QStream %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
358
359impl QStream {
360 /// General constructor of `QStream`.
361 fn new(
362 stream: Box<dyn QStreamInner>,
363 method: ConnectionMethod,
364 is_listener: bool,
365 is_local: bool,
366 ) -> Self {
367 QStream {
368 stream,
369 method,
370 listener: is_listener,
371 local: is_local,
372 }
373 }
374
375 /// Connect to q/kdb+ specifying a connection method, destination host, destination port and access credential.
376 /// # Parameters
377 /// - `method`: Connection method. One of followings:
378 /// - TCP
379 /// - TLS
380 /// - UDS
381 /// - `host`: Hostname or IP address of the target q process. Empty `str` for Unix domain socket.
382 /// - `port`: Port of the target q process.
383 /// - `credential`: Credential in the form of `username:password` to connect to the target q process.
384 /// # Example
385 /// ```
386 /// use kdbplus::qattribute;
387 /// use kdbplus::ipc::*;
388 ///
389 /// #[tokio::main(flavor = "multi_thread", worker_threads = 2)]
390 /// async fn main() -> Result<()> {
391 /// let mut socket =
392 /// QStream::connect(ConnectionMethod::UDS, "", 5000_u16, "ideal:person").await?;
393 /// println!("Connection type: {}", socket.get_connection_type());
394 ///
395 /// // Set remote function with asynchronous message
396 /// socket.send_async_message(&"collatz:{[n] seq:enlist n; while[not n = 1; seq,: n:$[n mod 2; 1 + 3 * n; `long$n % 2]]; seq}").await?;
397 ///
398 /// // Send a query synchronously
399 /// let mut result = socket.send_sync_message(&"collatz[12]").await?;
400 /// println!("collatz[12]: {}", result);
401 ///
402 /// // Send a functional query.
403 /// let mut message = K::new_compound_list(vec![
404 /// K::new_symbol(String::from("collatz")),
405 /// K::new_long(100),
406 /// ]);
407 /// result = socket.send_sync_message(&message).await?;
408 /// println!("collatz[100]: {}", result);
409 ///
410 /// // Send a functional asynchronous query.
411 /// message = K::new_compound_list(vec![
412 /// K::new_string(String::from("show"), qattribute::NONE),
413 /// K::new_symbol(String::from("goodbye")),
414 /// ]);
415 /// socket.send_async_message(&message).await?;
416 ///
417 /// socket.shutdown().await?;
418 ///
419 /// Ok(())
420 /// }
421 /// ```
422 pub async fn connect(
423 method: ConnectionMethod,
424 host: &str,
425 port: u16,
426 credential: &str,
427 ) -> Result<Self> {
428 match method {
429 ConnectionMethod::TCP => {
430 let stream = connect_tcp(host, port, credential).await?;
431 let is_local = matches!(host, "localhost" | "127.0.0.1");
432 Ok(QStream::new(
433 Box::new(stream),
434 ConnectionMethod::TCP,
435 false,
436 is_local,
437 ))
438 }
439 ConnectionMethod::TLS => {
440 let stream = connect_tls(host, port, credential).await?;
441 Ok(QStream::new(
442 Box::new(stream),
443 ConnectionMethod::TLS,
444 false,
445 false,
446 ))
447 }
448 #[cfg(unix)]
449 ConnectionMethod::UDS => {
450 let stream = connect_uds(port, credential).await?;
451 Ok(QStream::new(
452 Box::new(stream),
453 ConnectionMethod::UDS,
454 false,
455 true,
456 ))
457 }
458 #[cfg(not(unix))]
459 ConnectionMethod::UDS => Err(io::Error::new(
460 io::ErrorKind::Unsupported,
461 "Unix Domain Socket is not supported on this platform",
462 )
463 .into()),
464 }
465 }
466
467 /// Accept connection and does handshake.
468 /// # Parameters
469 /// - `method`: Connection method. One of followings:
470 /// - TCP
471 /// - TLS
472 /// - UDS
473 /// - host: Hostname or IP address of this listener. Empty `str` for Unix domain socket.
474 /// - port: Listening port.
475 /// # Example
476 /// ```no_run
477 /// use kdbplus::ipc::*;
478 ///
479 /// #[tokio::main]
480 /// async fn main() -> Result<()> {
481 /// // Start listenening over UDS at the port 7000 with authentication enabled.
482 /// while let Ok(mut socket) = QStream::accept(ConnectionMethod::UDS, "", 7000).await {
483 /// tokio::task::spawn(async move {
484 /// loop {
485 /// match socket.receive_message().await {
486 /// Ok((_, message)) => {
487 /// println!("request: {}", message);
488 /// }
489 /// _ => {
490 /// socket.shutdown().await.unwrap();
491 /// break;
492 /// }
493 /// }
494 /// }
495 /// });
496 /// }
497 ///
498 /// Ok(())
499 /// }
500 /// ```
501 /// q processes can connect and send messages to this acceptor.
502 /// ```q
503 /// q)// Process1
504 /// q)h:hopen `:unix://7000:reluctant:slowday
505 /// q)neg[h] (`monalizza; 3.8)
506 /// q)neg[h] (`pizza; 125)
507 /// ```
508 /// ```q
509 /// q)// Process2
510 /// q)h:hopen `:unix://7000:mattew:oracle
511 /// q)neg[h] (`teddy; "bear")
512 /// ```
513 /// # Note
514 /// - TLS acceptor sets `.kdbplus.close_tls_connection_` on q clien via an asynchronous message. This function is necessary to close
515 /// the socket from the server side without crashing server side application.
516 /// - TLS acceptor and UDS acceptor use specific environmental variables to work. See the [Environmental Variable](../ipc/index.html#environmentl-variables) section for details.
517 pub async fn accept(method: ConnectionMethod, host: &str, port: u16) -> Result<Self> {
518 match method {
519 ConnectionMethod::TCP => {
520 // Bind to the endpoint.
521 let listener = TcpListener::bind(&format!("{}:{}", host, port)).await?;
522 // Listen to the endpoint.
523 let (mut socket, ip_address) = listener.accept().await?;
524 // Read untill null bytes and send back capacity.
525 while (read_client_input(&mut socket).await).is_err() {
526 // Continue to listen in case of error.
527 socket = listener.accept().await?.0;
528 }
529 // Check if the connection is local
530 Ok(QStream::new(
531 Box::new(socket),
532 ConnectionMethod::TCP,
533 true,
534 ip_address.ip() == IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
535 ))
536 }
537 ConnectionMethod::TLS => {
538 // Bind to the endpoint.
539 let listener = TcpListener::bind(&format!("{}:{}", host, port)).await?;
540 // Check if key exists and decode an identity with a given password.
541 let identity = build_identity_from_cert().await?;
542 // Build TLS acceptor.
543 let tls_acceptor = TlsAcceptor::from(TlsAcceptorInner::new(identity).unwrap());
544 // Listen to the endpoint.
545 let (mut socket, _) = listener.accept().await?;
546 // TLS processing.
547 let mut tls_socket = tls_acceptor
548 .accept(socket)
549 .await
550 .expect("failed to accept TLS connection");
551 // Read untill null bytes and send back a capacity.
552 while (read_client_input(&mut tls_socket).await).is_err() {
553 // Continue to listen in case of error.
554 socket = listener.accept().await?.0;
555 tls_socket = tls_acceptor
556 .accept(socket)
557 .await
558 .expect("failed to accept TLS connection");
559 }
560 // TLS is always a remote connection
561 let mut qstream =
562 QStream::new(Box::new(tls_socket), ConnectionMethod::TCP, true, false);
563 // In order to close the connection from the server side, it needs to tell a client to close the connection.
564 // The `kdbplus_close_tls_connection_` will be called from the server at shutdown.
565 qstream
566 .send_async_message(&".kdbplus.close_tls_connection_:{[] hclose .z.w;}")
567 .await?;
568 Ok(qstream)
569 }
570 #[cfg(unix)]
571 ConnectionMethod::UDS => {
572 // uild a sockt file path.
573 let uds_path = create_sockfile_path(port)?;
574 let abstract_sockfile_ = format!("\x00{}", uds_path);
575 let abstract_sockfile = Path::new(&abstract_sockfile_);
576 // Bind to the file
577 let listener = UnixListener::bind(abstract_sockfile).unwrap();
578 // Listen to the endpoint
579 let (mut socket, _) = listener.accept().await?;
580 // Read untill null bytes and send back capacity.
581 while (read_client_input(&mut socket).await).is_err() {
582 // Continue to listen in case of error.
583 socket = listener.accept().await?.0;
584 }
585 // UDS is always a local connection
586 Ok(QStream::new(Box::new(socket), method, true, true))
587 }
588 #[cfg(not(unix))]
589 ConnectionMethod::UDS => Err(io::Error::new(
590 io::ErrorKind::Unsupported,
591 "Unix Domain Socket is not supported on this platform",
592 )
593 .into()),
594 }
595 }
596
597 /// Shutdown the socket for a q process.
598 /// # Example
599 /// See the example of [`connect`](#method.connect).
600 pub async fn shutdown(mut self) -> Result<()> {
601 self.stream.shutdown(self.listener).await
602 }
603
604 /// Send a message with a specified message type without waiting for a response even for a synchronous message.
605 /// If you need to receive a response you need to use [`receive_message`](#method.receive_message).
606 /// # Note
607 /// The usage of this function for a synchronous message is to handle an asynchronous message or a synchronous message
608 /// sent by a remote function during its execution.
609 /// # Parameters
610 /// - `message`: q command to execute on the remote q process.
611 /// - `&str`: q command in a string form.
612 /// - `K`: Query in a functional form.
613 /// - `message_type`: Asynchronous or synchronous.
614 /// # Example
615 /// See the example of [`connect`](#method.connect).
616 pub async fn send_message(&mut self, message: &dyn Query, message_type: u8) -> Result<()> {
617 self.stream
618 .send_message(message, message_type, self.local)
619 .await
620 }
621
622 /// Send a message asynchronously.
623 /// # Parameters
624 /// - `message`: q command to execute on the remote q process.
625 /// - `&str`: q command in a string form.
626 /// - `K`: Query in a functional form.
627 /// # Example
628 /// See the example of [`connect`](#method.connect).
629 pub async fn send_async_message(&mut self, message: &dyn Query) -> Result<()> {
630 self.stream.send_async_message(message, self.local).await
631 }
632
633 /// Send a message synchronously.
634 /// # Note
635 /// Remote function must NOT send back a message of asynchronous or synchronous type durning execution of the function.
636 /// # Parameters
637 /// - `message`: q command to execute on the remote q process.
638 /// - `&str`: q command in a string form.
639 /// - `K`: Query in a functional form.
640 /// # Example
641 /// See the example of [`connect`](#method.connect).
642 pub async fn send_sync_message(&mut self, message: &dyn Query) -> Result<K> {
643 self.stream.send_sync_message(message, self.local).await
644 }
645
646 /// Receive a message from a remote q process. The received message is parsed as `K` and message type is
647 /// stored in the first returned value.
648 /// # Example
649 /// See the example of [`accept`](#method.accept).
650 pub async fn receive_message(&mut self) -> Result<(u8, K)> {
651 self.stream.receive_message().await
652 }
653
654 /// Return underlying connection type. One of `TCP`, `TLS` or `UDS`.
655 /// # Example
656 /// See the example of [`connect`](#method.connect).
657 pub fn get_connection_type(&self) -> &str {
658 match self.method {
659 ConnectionMethod::TCP => "TCP",
660 ConnectionMethod::TLS => "TLS",
661 ConnectionMethod::UDS => "UDS",
662 }
663 }
664
665 /// Enforce compression if the size of a message exceeds 2000 regardless of locality of the connection.
666 /// This flag is not revertible intentionally.
667 pub fn enforce_compression(&mut self) {
668 self.local = false;
669 }
670}
671
672//%% QStreamInner %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
673
674#[async_trait]
675impl QStreamInner for TcpStream {
676 async fn shutdown(&mut self, _: bool) -> Result<()> {
677 AsyncWriteExt::shutdown(self).await?;
678 Ok(())
679 }
680
681 async fn send_message(
682 &mut self,
683 message: &dyn Query,
684 message_type: u8,
685 is_local: bool,
686 ) -> Result<()> {
687 // Serialize a message
688 let byte_message = message.serialize(message_type, is_local).await;
689 // Send the message
690 self.write_all(&byte_message).await?;
691 Ok(())
692 }
693
694 async fn send_async_message(&mut self, message: &dyn Query, is_local: bool) -> Result<()> {
695 // Serialize a message
696 let byte_message = message.serialize(qmsg_type::asynchronous, is_local).await;
697 // Send the message
698 self.write_all(&byte_message).await?;
699 Ok(())
700 }
701
702 async fn send_sync_message(&mut self, message: &dyn Query, is_local: bool) -> Result<K> {
703 // Serialize a message
704 let byte_message = message.serialize(qmsg_type::synchronous, is_local).await;
705 // Send the message
706 self.write_all(&byte_message).await?;
707 // Receive a response. If message type is not response it returns an error.
708 match receive_message(self).await {
709 Ok((qmsg_type::response, response)) => Ok(response),
710 Err(error) => Err(error),
711 Ok((_, message)) => Err(io::Error::new(
712 io::ErrorKind::InvalidData,
713 format!("expected a response: {}", message),
714 )
715 .into()),
716 }
717 }
718
719 async fn receive_message(&mut self) -> Result<(u8, K)> {
720 receive_message(self).await
721 }
722}
723
724#[async_trait]
725impl QStreamInner for TlsStream<TcpStream> {
726 async fn shutdown(&mut self, is_listener: bool) -> Result<()> {
727 if is_listener {
728 // Closing the handle from the server side by `self.get_mut().shutdown()` crashes due to 'assertion failed: !self.context.is_null()'.
729 // No reason to compress.
730 self.send_async_message(&".kdbplus.close_tls_connection_[]", false)
731 .await
732 } else {
733 self.get_mut().shutdown()?;
734 Ok(())
735 }
736 }
737
738 async fn send_message(
739 &mut self,
740 message: &dyn Query,
741 message_type: u8,
742 is_local: bool,
743 ) -> Result<()> {
744 // Serialize a message
745 let byte_message = message.serialize(message_type, is_local).await;
746 // Send the message
747 self.write_all(&byte_message).await?;
748 Ok(())
749 }
750
751 async fn send_async_message(&mut self, message: &dyn Query, is_local: bool) -> Result<()> {
752 // Serialize a message
753 let byte_message = message.serialize(qmsg_type::asynchronous, is_local).await;
754 // Send the message
755 self.write_all(&byte_message).await?;
756 Ok(())
757 }
758
759 async fn send_sync_message(&mut self, message: &dyn Query, is_local: bool) -> Result<K> {
760 // Serialize a message
761 let byte_message = message.serialize(qmsg_type::synchronous, is_local).await;
762 // Send the message
763 self.write_all(&byte_message).await?;
764 // Receive a response. If message type is not response it returns an error.
765 match receive_message(self).await {
766 Ok((qmsg_type::response, response)) => Ok(response),
767 Err(error) => Err(error),
768 Ok((_, message)) => Err(io::Error::new(
769 io::ErrorKind::InvalidData,
770 format!("expected a response: {}", message),
771 )
772 .into()),
773 }
774 }
775
776 async fn receive_message(&mut self) -> Result<(u8, K)> {
777 receive_message(self).await
778 }
779}
780
781#[cfg(unix)]
782#[async_trait]
783impl QStreamInner for UnixStream {
784 /// Close a handle to a q process which is connected with Unix Domain Socket.
785 /// Socket file is removed.
786 async fn shutdown(&mut self, _: bool) -> Result<()> {
787 AsyncWriteExt::shutdown(self).await?;
788 Ok(())
789 }
790
791 async fn send_message(
792 &mut self,
793 message: &dyn Query,
794 message_type: u8,
795 is_local: bool,
796 ) -> Result<()> {
797 // Serialize a message
798 let byte_message = message.serialize(message_type, is_local).await;
799 // Send the message
800 self.write_all(&byte_message).await?;
801 Ok(())
802 }
803
804 async fn send_async_message(&mut self, message: &dyn Query, is_local: bool) -> Result<()> {
805 // Serialize a message
806 let byte_message = message.serialize(qmsg_type::asynchronous, is_local).await;
807 // Send the message
808 self.write_all(&byte_message).await?;
809 Ok(())
810 }
811
812 async fn send_sync_message(&mut self, message: &dyn Query, is_local: bool) -> Result<K> {
813 // Serialize a message
814 let byte_message = message.serialize(qmsg_type::synchronous, is_local).await;
815 // Send the message
816 self.write_all(&byte_message).await?;
817 // Receive a response. If message type is not response it returns an error.
818 match receive_message(self).await {
819 Ok((qmsg_type::response, response)) => Ok(response),
820 Err(error) => Err(error),
821 Ok((_, message)) => Err(io::Error::new(
822 io::ErrorKind::InvalidData,
823 format!("expected a response: {}", message),
824 )
825 .into()),
826 }
827 }
828
829 async fn receive_message(&mut self) -> Result<(u8, K)> {
830 receive_message(self).await
831 }
832}
833
834//%% MessageHeader %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
835
836impl MessageHeader {
837 /// Constructor.
838 fn new(encoding: u8, message_type: u8, compressed: u8, length: u32) -> Self {
839 MessageHeader {
840 encoding,
841 message_type,
842 compressed,
843 _unused: 0,
844 length,
845 }
846 }
847
848 /// Constructor from bytes.
849 fn from_bytes(bytes: [u8; 8]) -> Self {
850 let encoding = bytes[0];
851
852 let length = match encoding {
853 0 => u32::from_be_bytes(bytes[4..8].try_into().unwrap()),
854 _ => u32::from_le_bytes(bytes[4..8].try_into().unwrap()),
855 };
856
857 // Build header
858 MessageHeader::new(encoding, bytes[1], bytes[2], length)
859 }
860
861 /// Length of bytes for a header.
862 fn size() -> usize {
863 8
864 }
865}
866
867//++++++++++++++++++++++++++++++++++++++++++++++++++//
868// >> Private Functions
869//++++++++++++++++++++++++++++++++++++++++++++++++++//
870
871//%% QStream Connector %%//vvvvvvvvvvvvvvvvvvvvvvvvvv/
872
873/// Inner function of `connect_tcp` and `connect_tls` to establish a TCP connection with the sepcified
874/// endpoint. The hostname is resolved to an IP address with a system DNS resolver or parsed directly
875/// as an IP address.
876///
877/// Tries to connect to multiple resolved IP addresses until the first successful connection. Error is
878/// returned if none of them are valid.
879/// # Parameters
880/// - `host`: Hostname or IP address of the target q/kdb+ process.
881/// - `port`: Port of the target q process
882async fn connect_tcp_impl(host: &str, port: u16) -> Result<TcpStream> {
883 // DNS system resolver (should not fail)
884 let resolver =
885 TokioAsyncResolver::tokio_from_system_conf().expect("failed to create a resolver");
886
887 // Check if we were given an IP address
888 let ips;
889 if let Ok(ip) = host.parse::<IpAddr>() {
890 ips = vec![ip.to_string()]
891 } else {
892 // Resolve the given hostname
893 ips = resolver
894 .ipv4_lookup(format!("{}.", host).as_str())
895 .await
896 .unwrap()
897 .iter()
898 .map(|result| result.to_string())
899 .collect()
900 };
901
902 for answer in ips {
903 let host_port = format!("{}:{}", answer, port);
904 // Return if this IP address is valid
905 match TcpStream::connect(&host_port).await {
906 Ok(socket) => {
907 //println!("connected: {}", host_port);
908 return Ok(socket);
909 }
910 Err(_) => {
911 eprintln!("connection refused: {}. try next.", host_port);
912 }
913 }
914 }
915 // All addresses failed.
916 Err(io::Error::new(io::ErrorKind::ConnectionRefused, "failed to connect").into())
917}
918
919/// Send a credential and receive a common capacity.
920async fn handshake<S>(socket: &mut S, credential_: &str, method_bytes: &str) -> Result<()>
921where
922 S: Unpin + AsyncWriteExt + AsyncReadExt,
923{
924 // Send credential
925 let credential = credential_.to_string() + method_bytes;
926 socket.write_all(credential.as_bytes()).await?;
927
928 // Placeholder of common capablility
929 let mut cap = [0u8; 1];
930 if (socket.read_exact(&mut cap).await).is_err() {
931 // Connection is closed in case of authentication failure
932 Err(io::Error::new(io::ErrorKind::ConnectionAborted, "authentication failure").into())
933 } else {
934 Ok(())
935 }
936}
937
938/// Connect to q process running on a specified `host` and `port` via TCP with a credential `username:password`.
939/// # Parameters
940/// - `host`: Hostname or IP address of the target q process.
941/// - `port`: Port of the target q process.
942/// - `credential`: Credential in the form of `username:password` to connect to the target q process.
943async fn connect_tcp(host: &str, port: u16, credential: &str) -> Result<TcpStream> {
944 // Connect via TCP
945 let mut socket = connect_tcp_impl(host, port).await?;
946 // Handshake
947 handshake(&mut socket, credential, "\x03\x00").await?;
948 Ok(socket)
949}
950
951/// TLS version of `connect_tcp`.
952/// # Parameters
953/// - `host`: Hostname or IP address of the target q process.
954/// - `port`: Port of the target q process.
955/// - `credential`: Credential in the form of `username:password` to connect to the target q process.
956async fn connect_tls(host: &str, port: u16, credential: &str) -> Result<TlsStream<TcpStream>> {
957 // Connect via TCP
958 let socket_ = connect_tcp_impl(host, port).await?;
959 // Use TLS
960 let connector = TlsConnector::from(TlsConnectorInner::new().unwrap());
961 let mut socket = connector
962 .connect(host, socket_)
963 .await
964 .expect("failed to create TLS session");
965 // Handshake
966 handshake(&mut socket, credential, "\x03\x00").await?;
967 Ok(socket)
968}
969
970/// Build a path of a socket file.
971fn create_sockfile_path(port: u16) -> Result<String> {
972 // Create file path
973 let udspath = match env::var("QUDSPATH") {
974 Ok(dir) => format!("{}/kx.{}", dir, port),
975 Err(_) => format!("/tmp/kx.{}", port),
976 };
977
978 Ok(udspath)
979}
980
981/// Connect to q process running on the specified `port` via Unix domain socket with a credential `username:password`.
982/// # Parameters
983/// - `port`: Port of the target q process.
984/// - `credential`: Credential in the form of `username:password` to connect to the target q process.
985#[cfg(unix)]
986async fn connect_uds(port: u16, credential: &str) -> Result<UnixStream> {
987 // Create a file path.
988 let uds_path = create_sockfile_path(port)?;
989 let abstract_sockfile_ = format!("\x00{}", uds_path);
990 let abstract_sockfile = Path::new(&abstract_sockfile_);
991 // Connect to kdb+.
992 let mut socket = UnixStream::connect(&abstract_sockfile).await?;
993 // Handshake
994 handshake(&mut socket, credential, "\x06\x00").await?;
995
996 Ok(socket)
997}
998
999//%% QStream Acceptor %%//vvvvvvvvvvvvvvvvvvvvvvvvvvv/
1000
1001/// Read username, password, capacity and null byte from q client at the connection and does authentication.
1002/// Close the handle if the authentication fails.
1003async fn read_client_input<S>(socket: &mut S) -> Result<()>
1004where
1005 S: Unpin + AsyncWriteExt + AsyncReadExt,
1006{
1007 // Buffer to read inputs.
1008 let mut client_input = [0u8; 32];
1009 // credential will be built from small fractions of bytes.
1010 let mut passed_credential = String::new();
1011 loop {
1012 // Read a client credential input.
1013 match socket.read(&mut client_input).await {
1014 Ok(0) => {
1015 // No bytes were read
1016 }
1017 Ok(_) => {
1018 // Locate a byte denoting a capacity
1019 if let Some(index) = client_input
1020 .iter()
1021 .position(|byte| *byte == 0x03 || *byte == 0x06)
1022 {
1023 let capacity = client_input[index];
1024 passed_credential
1025 .push_str(str::from_utf8(&client_input[0..index]).expect("invalid bytes"));
1026 let credential = passed_credential.as_str().split(':').collect::<Vec<&str>>();
1027 #[allow(clippy::borrow_interior_mutable_const)]
1028 if let Some(encoded) = ACCOUNTS.get(&credential[0].to_string()) {
1029 // User exists
1030 let mut hasher = Sha1::new();
1031 hasher.update(credential[1].as_bytes());
1032 let encoded_password = hasher.digest().to_string();
1033 if encoded == &encoded_password {
1034 // Client passed correct credential
1035 socket.write_all(&[capacity; 1]).await?;
1036 return Ok(());
1037 } else {
1038 // Authentication failure.
1039 // Close connection.
1040 socket.shutdown().await?;
1041 return Err(io::Error::new(
1042 io::ErrorKind::InvalidData,
1043 "authentication failed",
1044 )
1045 .into());
1046 }
1047 } else {
1048 // Authentication failure.
1049 // Close connection.
1050 socket.shutdown().await?;
1051 return Err(io::Error::new(
1052 io::ErrorKind::InvalidData,
1053 "authentication failed",
1054 )
1055 .into());
1056 }
1057 } else {
1058 // Append a fraction of credential
1059 passed_credential
1060 .push_str(str::from_utf8(&client_input).expect("invalid bytes"));
1061 }
1062 }
1063 Err(error) => {
1064 return Err(error.into());
1065 }
1066 }
1067 }
1068}
1069
1070/// Check if server key exists and return teh contents.
1071async fn build_identity_from_cert() -> Result<Identity> {
1072 // Check if server key exists.
1073 if let Ok(path) = env::var("KDBPLUS_TLS_KEY_FILE") {
1074 if let Ok(password) = env::var("KDBPLUS_TLS_KEY_FILE_SECRET") {
1075 let cert_file = tokio::fs::File::open(Path::new(&path)).await.unwrap();
1076 let mut reader = BufReader::new(cert_file);
1077 let mut der: Vec<u8> = Vec::new();
1078 // Read the key file.
1079 reader.read_to_end(&mut der).await?;
1080 // Create identity.
1081 if let Ok(identity) = Identity::from_pkcs12(&der, &password) {
1082 Ok(identity)
1083 } else {
1084 Err(io::Error::new(io::ErrorKind::InvalidData, "authentication failed").into())
1085 }
1086 } else {
1087 Err(io::Error::new(
1088 io::ErrorKind::NotFound,
1089 "KDBPLUS_TLS_KEY_FILE_SECRET is not set",
1090 )
1091 .into())
1092 }
1093 } else {
1094 Err(io::Error::new(io::ErrorKind::NotFound, "KDBPLUS_TLS_KEY_FILE is not set").into())
1095 }
1096}
1097
1098//%% QStream Query %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
1099
1100/// Receive a message from q process with decompression if necessary. The received message is parsed as `K` and message type is
1101/// stored in the first returned value.
1102/// # Parameters
1103/// - `socket`: Socket to communicate with a q process. Either of `TcpStream`, `TlsStream<TcpStream>` or `UnixStream`.
1104async fn receive_message<S>(socket: &mut S) -> Result<(u8, K)>
1105where
1106 S: Unpin + AsyncReadExt,
1107{
1108 // Read header
1109 let mut header_buffer = [0u8; 8];
1110 if let Err(err) = socket.read_exact(&mut header_buffer).await {
1111 // The expected message is header or EOF (close due to q process failure resulting from a bad query)
1112 return Err(io::Error::new(
1113 io::ErrorKind::ConnectionAborted,
1114 format!("Connection dropped: {}", err),
1115 )
1116 .into());
1117 }
1118
1119 // Parse message header
1120 let header = MessageHeader::from_bytes(header_buffer);
1121
1122 // Read body
1123 let body_length = header.length as usize - MessageHeader::size();
1124 let mut body: Vec<u8> = vec![0; body_length];
1125 if let Err(err) = socket.read_exact(&mut body).await {
1126 // Fails if q process fails before reading the body
1127 return Err(io::Error::new(
1128 io::ErrorKind::UnexpectedEof,
1129 format!("Failed to read body of message: {}", err),
1130 )
1131 .into());
1132 }
1133
1134 // Decompress if necessary
1135 if header.compressed == 0x01 {
1136 body = decompress(body, header.encoding).await;
1137 }
1138
1139 Ok((
1140 header.message_type,
1141 K::q_ipc_decode(&body, header.encoding).await,
1142 ))
1143}
1144
1145/// Compress body. The combination of serializing the data and compressing will result in
1146/// the same output as shown in the q language by using the -18! function e.g.
1147/// serializing 2000 bools set to true, then compressing, will have the same output as `-18!2000#1b`.
1148/// # Parameter
1149/// - `raw`: Serialized message.
1150/// - `encode`: `0` if Big Endian; `1` if Little Endian.
1151async fn compress(raw: Vec<u8>) -> (bool, Vec<u8>) {
1152 let mut i = 0_u8;
1153 let mut f = 0_u8;
1154 let mut h0 = 0_usize;
1155 let mut h = 0_usize;
1156 let mut g: bool;
1157 let mut compressed: Vec<u8> = vec![0; (raw.len()) / 2];
1158
1159 // Start index of compressed body
1160 // 12 bytes are reserved for the header + size of raw bytes
1161 let mut c = 12;
1162 let mut d = c;
1163 let e = compressed.len();
1164 let mut p = 0_usize;
1165 let mut q: usize;
1166 let mut r: usize;
1167 let mut s0 = 0_usize;
1168
1169 // Body starts from index 8
1170 let mut s = 8_usize;
1171 let t = raw.len();
1172 let mut a = [0_i32; 256];
1173
1174 // Copy encode, message type, compressed and reserved
1175 compressed[0..4].copy_from_slice(&raw[0..4]);
1176 // Set compressed flag
1177 compressed[2] = 1;
1178
1179 // Write size of raw bytes including a header
1180 let raw_size = match ENCODING {
1181 0 => (t as u32).to_be_bytes(),
1182 _ => (t as u32).to_le_bytes(),
1183 };
1184 compressed[8..12].copy_from_slice(&raw_size);
1185
1186 while s < t {
1187 if i == 0 {
1188 if d > e - 17 {
1189 // Early return when compressing to less than half failed
1190 return (false, raw);
1191 }
1192 i = 1;
1193 compressed[c] = f;
1194 c = d;
1195 d += 1;
1196 f = 0;
1197 }
1198 g = s > t - 3;
1199 if !g {
1200 h = (raw[s] ^ raw[s + 1]) as usize;
1201 p = a[h] as usize;
1202 g = (0 == p) || (0 != (raw[s] ^ raw[p]));
1203 }
1204 if 0 < s0 {
1205 a[h0] = s0 as i32;
1206 s0 = 0;
1207 }
1208 if g {
1209 h0 = h;
1210 s0 = s;
1211 compressed[d] = raw[s];
1212 d += 1;
1213 s += 1;
1214 } else {
1215 a[h] = s as i32;
1216 f |= i;
1217 p += 2;
1218 s += 2;
1219 r = s;
1220 q = if s + 255 > t { t } else { s + 255 };
1221 while (s < q) && (raw[p] == raw[s]) {
1222 s += 1;
1223 if s < q {
1224 p += 1;
1225 }
1226 }
1227 compressed[d] = h as u8;
1228 d += 1;
1229 compressed[d] = (s - r) as u8;
1230 d += 1;
1231 }
1232 i = i.wrapping_mul(2);
1233 }
1234 compressed[c] = f;
1235 // Final compressed data size
1236 let compressed_size = match ENCODING {
1237 0 => (d as u32).to_be_bytes(),
1238 _ => (d as u32).to_le_bytes(),
1239 };
1240 compressed[4..8].copy_from_slice(&compressed_size);
1241 let _ = compressed.split_off(d);
1242 (true, compressed)
1243}
1244
/// Decompress body. The combination of decompressing and deserializing the data
/// will result in the same output as shown in the q language by using the `-19!` function.
/// # Parameter
/// - `compressed`: Compressed serialized message without the 8-byte message header. Its first
///   four bytes hold the size of the uncompressed message including the header.
/// - `encoding`:
///   - `0`: Big Endian
///   - `1`: Little Endian.
async fn decompress(compressed: Vec<u8>, encoding: u8) -> Vec<u8> {
    // Subtract 8 bytes from decoded bytes size as 8 bytes have already been taken as header
    let size_field: [u8; 4] = compressed[0..4]
        .try_into()
        .expect("slice does not have length 4");
    let body_size = match encoding {
        0 => i32::from_be_bytes(size_field) - 8,
        _ => i32::from_le_bytes(size_field) - 8,
    };
    let mut output: Vec<u8> = vec![0; body_size as usize];

    // Read position in `compressed`; the payload starts right after the size field.
    let mut source = 4_usize;
    // Write position in `output`.
    let mut target = 0_usize;
    // Every position below this one has been recorded in `positions`.
    let mut indexed = 0_usize;
    // Current flag byte and the bit of the item being decoded (0 means "fetch a new flag byte").
    let mut flags = 0_usize;
    let mut mask = 0_usize;
    // Maps the XOR of two adjacent output bytes to the position of that pair.
    let mut positions = [0_i32; 256];

    while target < output.len() {
        if mask == 0 {
            flags = compressed[source] as usize;
            source += 1;
            mask = 1;
        }

        let is_back_reference = (flags & mask) != 0;
        let mut run_length = 0_usize;

        if is_back_reference {
            // A back reference is a table key followed by the number of extra bytes.
            let mut copy_from = positions[compressed[source] as usize] as usize;
            source += 1;
            // The referenced pair itself is always copied.
            output[target] = output[copy_from];
            target += 1;
            copy_from += 1;
            output[target] = output[copy_from];
            target += 1;
            copy_from += 1;
            run_length = compressed[source] as usize;
            source += 1;
            // Copied byte by byte on purpose: source and destination may overlap, in which
            // case bytes written by this loop are read again later in the same loop.
            for offset in 0..run_length {
                output[target + offset] = output[copy_from + offset];
            }
        } else {
            // Literal byte.
            output[target] = compressed[source];
            target += 1;
            source += 1;
        }

        // Record every adjacent pair that became complete.
        while indexed + 1 < target {
            positions[(output[indexed] ^ output[indexed + 1]) as usize] = indexed as i32;
            indexed += 1;
        }

        if is_back_reference {
            // Skip over the extra bytes; they are not recorded in the table.
            target += run_length;
            indexed = target;
        }

        // Move to the next bit; after the eighth item a new flag byte is fetched.
        mask = if mask == 128 { 0 } else { mask * 2 };
    }
    output
}