kdb_plus_fixed/ipc/connection.rs
1//++++++++++++++++++++++++++++++++++++++++++++++++++//
2// >> Load Libraries
3//++++++++++++++++++++++++++++++++++++++++++++++++++//
4
5use super::serialize::ENCODING;
6use super::Result;
7use super::{qtype, K};
8use async_trait::async_trait;
9use io::BufRead;
10use once_cell::sync::Lazy;
11use sha1_smol::Sha1;
12use std::collections::HashMap;
13use std::convert::TryInto;
14use std::net::{IpAddr, Ipv4Addr};
15use std::path::Path;
16use std::{env, fs, io, str};
17use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
18use tokio::net::{TcpListener, TcpStream};
19#[cfg(unix)]
20use tokio::net::{UnixListener, UnixStream};
21use tokio_native_tls::native_tls::{
22 Identity, TlsAcceptor as TlsAcceptorInner, TlsConnector as TlsConnectorInner,
23};
24use tokio_native_tls::{TlsAcceptor, TlsConnector, TlsStream};
25use trust_dns_resolver::TokioAsyncResolver;
26
27//++++++++++++++++++++++++++++++++++++++++++++++++++//
28// >> Global Variable
29//++++++++++++++++++++++++++++++++++++++++++++++++++//
30
31//%% QStream %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
32
pub mod qmsg_type {
    //! Message types used in the q IPC protocol.
    //!
    //! The constants live in their own module so that they read as one related set instead of
    //! scattered values; refer to them with the `qmsg_type::` prefix, e.g.
    //! `qmsg_type::asynchronous`.
    //!
    //! # Example
    //! ```no_run
    //! use kdbplus::ipc::*;
    //!
    //! // Print `K` object.
    //! fn print(obj: &K) {
    //!     println!("{}", obj);
    //! }
    //!
    //! // Calculate something from two long arguments.
    //! fn nonsense(arg1: i64, arg2: i64) -> i64 {
    //!     arg1 * arg2
    //! }
    //!
    //! #[tokio::main]
    //! async fn main() -> Result<()> {
    //!     // Connect to a q process running on localhost:5000 via TCP.
    //!     let mut socket =
    //!         QStream::connect(ConnectionMethod::TCP, "localhost", 5000_u16, "ideal:person").await?;
    //!
    //!     // Define a function which sends back non-response messages during its execution.
    //!     socket
    //!         .send_async_message(
    //!             &"complex:{neg[.z.w](`print; \"counter\"); what: .z.w (`nonsense; 1; 2); what*100}",
    //!         )
    //!         .await?;
    //!
    //!     // Send the query `(`complex; ::)` without waiting for a response.
    //!     socket
    //!         .send_message(
    //!             &K::new_compound_list(vec![K::new_symbol(String::from("complex")), K::new_null()]),
    //!             qmsg_type::synchronous,
    //!         )
    //!         .await?;
    //!
    //!     // Receive the asynchronous call made by the function.
    //!     match socket.receive_message().await {
    //!         Ok((qmsg_type::asynchronous, message)) => {
    //!             println!("asynchronous call: {}", message);
    //!             let list = message.as_vec::<K>().unwrap();
    //!             if list[0].get_symbol().unwrap() == "print" {
    //!                 print(&list[1])
    //!             }
    //!         }
    //!         _ => unreachable!(),
    //!     }
    //!
    //!     // Receive the synchronous call made by the function.
    //!     match socket.receive_message().await {
    //!         Ok((qmsg_type::synchronous, message)) => {
    //!             println!("synchronous call: {}", message);
    //!             let list = message.as_vec::<K>().unwrap();
    //!             if list[0].get_symbol().unwrap() == "nonsense" {
    //!                 let res = nonsense(list[1].get_long().unwrap(), list[2].get_long().unwrap());
    //!                 // Send back a response.
    //!                 socket
    //!                     .send_message(&K::new_long(res), qmsg_type::response)
    //!                     .await?;
    //!             }
    //!         }
    //!         _ => unreachable!(),
    //!     }
    //!
    //!     // Receive the final result.
    //!     match socket.receive_message().await {
    //!         Ok((qmsg_type::response, message)) => {
    //!             println!("final: {}", message);
    //!         }
    //!         _ => unreachable!(),
    //!     }
    //!
    //!     Ok(())
    //! }
    //! ```

    /// Message type for sending a message to q/kdb+ asynchronously (no response is expected).
    pub const asynchronous: u8 = 0x00;
    /// Message type for sending a message to q/kdb+ synchronously.
    pub const synchronous: u8 = 0x01;
    /// Message type q/kdb+ uses to mark the response to a synchronous query.
    pub const response: u8 = 0x02;
}
119
120//%% QStream Acceptor %%//vvvvvvvvvvvvvvvvvvvvvvvvvvv/
121
122/// Map from user name to password hashed with SHA1.
123#[allow(clippy::declare_interior_mutable_const)]
124const ACCOUNTS: Lazy<HashMap<String, String>> = Lazy::new(|| {
125 // Map from user to password
126 let mut map: HashMap<String, String> = HashMap::new();
127 // Open credential file
128 let file = fs::OpenOptions::new()
129 .read(true)
130 .open(env::var("KDBPLUS_ACCOUNT_FILE").expect("KDBPLUS_ACCOUNT_FILE is not set"))
131 .expect("failed to open account file");
132 let mut reader = io::BufReader::new(file);
133 let mut line = String::new();
134 loop {
135 match reader.read_line(&mut line) {
136 Ok(0) => break,
137 Ok(_) => {
138 let credential = line.as_str().split(':').collect::<Vec<&str>>();
139 let mut password = credential[1];
140 if password.ends_with('\n') {
141 password = &password[0..password.len() - 1];
142 }
143 map.insert(credential[0].to_string(), password.to_string());
144 line.clear();
145 }
146 _ => unreachable!(),
147 }
148 }
149 map
150});
151
152//++++++++++++++++++++++++++++++++++++++++++++++++++//
153// >> Structs
154//++++++++++++++++++++++++++++++++++++++++++++++++++//
155
156//%% ConnectionMethod %%//vvvvvvvvvvvvvvvvvvvvvvvvvvv/
157
/// Connection method to q/kdb+.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConnectionMethod {
    /// Plain TCP socket.
    TCP = 0,
    /// TCP socket secured with TLS.
    TLS = 1,
    /// Unix domain socket.
    UDS = 2,
}
165
166//%% Query %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
167
168/// Feature of query object.
169#[async_trait]
170pub trait Query: Send + Sync {
171 /// Serialize into q IPC bytes including a header (encoding, message type, compresssion flag and total message length).
172 /// If the connection is within the same host, the message is not compressed under any conditions.
173 /// # Parameters
174 /// - `message_type`: Message type. One of followings:
175 /// - `qmsg_type::asynchronous`
176 /// - `qmsg_type::synchronous`
177 /// - `qmsg_type::response`
178 /// - `is_local`: Flag of whether the connection is within the same host.
179 async fn serialize(&self, message_type: u8, is_local: bool) -> Vec<u8>;
180}
181
182//%% QStreamInner %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
183
184/// Features which streams communicating with q must have.
185#[async_trait]
186trait QStreamInner: Send + Sync {
187 /// Shutdown underlying stream.
188 async fn shutdown(&mut self, is_server: bool) -> Result<()>;
189 /// Send a message with a specified message type without waiting for a response.
190 /// # Parameters
191 /// - `message`: q command to execute on the remote q process.
192 /// - `message_type`: Asynchronous or synchronous.
193 /// - `is_local`: Flag of whether the connection is within the same host.
194 async fn send_message(
195 &mut self,
196 message: &dyn Query,
197 message_type: u8,
198 is_local: bool,
199 ) -> Result<()>;
200 /// Send a message asynchronously.
201 /// # Parameters
202 /// - `message`: q command in two ways:
203 /// - `&str`: q command in a string form.
204 /// - `K`: Query in a functional form.
205 /// - `is_local`: Flag of whether the connection is within the same host.
206 async fn send_async_message(&mut self, message: &dyn Query, is_local: bool) -> Result<()>;
207 /// Send a message asynchronously.
208 /// # Parameters
209 /// - `message`: q command in two ways:
210 /// - `&str`: q command in a string form.
211 /// - `K`: Query in a functional form.
212 /// - `is_local`: Flag of whether the connection is within the same host.
213 async fn send_sync_message(&mut self, message: &dyn Query, is_local: bool) -> Result<K>;
214 /// Receive a message from a remote q process. The received message is parsed as `K` and message type is
215 /// stored in the first returned value.
216 async fn receive_message(&mut self) -> Result<(u8, K)>;
217}
218
219//%% QStream %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
220
221/// Stream to communicate with q/kdb+.
222pub struct QStream {
223 /// Actual stream to communicate.
224 stream: Box<dyn QStreamInner>,
225 /// Connection method. One of followings:
226 /// - TCP
227 /// - TLS
228 /// - UDS
229 method: ConnectionMethod,
230 /// Indicator of whether the stream is an acceptor or client.
231 /// - `true`: Acceptor
232 /// - `false`: Client
233 listener: bool,
234 /// Indicator of whether the connection is within the same host.
235 /// - `true`: Connection within the same host.
236 /// - `false`: Connection with outseide.
237 local: bool,
238}
239
240//%% MessageHeader %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
241
/// The fixed 8-byte header that starts every q IPC data frame.
#[derive(Clone, Copy, Debug)]
struct MessageHeader {
    /// Byte order of the frame: 0 for big endian, 1 for little endian.
    encoding: u8,
    /// Kind of message: 0 asynchronous, 1 synchronous, 2 response.
    message_type: u8,
    /// Compression flag: 0 uncompressed, 1 compressed.
    compressed: u8,
    /// Reserved byte.
    _unused: u8,
    /// Total length of the uncompressed message.
    length: u32,
}
263
264//++++++++++++++++++++++++++++++++++++++++++++++++++//
265// >> Implementation
266//++++++++++++++++++++++++++++++++++++++++++++++++++//
267
268//%% Query %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
269
270/// Text query.
271#[async_trait]
272impl Query for &str {
273 async fn serialize(&self, message_type: u8, _: bool) -> Vec<u8> {
274 // Build header //--------------------------------/
275 // Message header + (type indicator of string + header of string type) + string length
276 let byte_message = self.as_bytes();
277 let message_length = byte_message.len() as u32;
278 let total_length = MessageHeader::size() as u32 + 6 + message_length;
279
280 let total_length_bytes = match ENCODING {
281 0 => total_length.to_be_bytes(),
282 _ => total_length.to_le_bytes(),
283 };
284
285 // encode, message type, 0x00 for compression and 0x00 for reserved.
286 // Do not compress string data because it is highly unlikely that the length of the string query
287 // is greater than 2000.
288 let mut message = Vec::with_capacity(message_length as usize + MessageHeader::size());
289 message.extend_from_slice(&[ENCODING, message_type, 0, 0]);
290 // total body length
291 message.extend_from_slice(&total_length_bytes);
292 // vector type and 0x00 for attribute
293 message.extend_from_slice(&[qtype::STRING as u8, 0]);
294
295 // Build body //---------------------------------/
296 let length_info = match ENCODING {
297 0 => message_length.to_be_bytes(),
298 _ => message_length.to_le_bytes(),
299 };
300
301 // length of vector(message)
302 message.extend_from_slice(&length_info);
303 // message
304 message.extend_from_slice(byte_message);
305
306 message
307 }
308}
309
310/// Functional query.
311#[async_trait]
312impl Query for K {
313 async fn serialize(&self, message_type: u8, is_local: bool) -> Vec<u8> {
314 // Build header //--------------------------------/
315 // Message header + encoded data size
316 let mut byte_message = self.q_ipc_encode();
317 let message_length = byte_message.len();
318 let total_length = (MessageHeader::size() + message_length) as u32;
319
320 let total_length_bytes = match ENCODING {
321 0 => total_length.to_be_bytes(),
322 _ => total_length.to_le_bytes(),
323 };
324
325 // Compression is trigerred when entire message size is more than 2000 bytes
326 // and the connection is with outseide.
327 if message_length > 1992 && !is_local {
328 // encode, message type, 0x00 for compression, 0x00 for reserved and 0x00000000 for total size
329 let mut message = Vec::with_capacity(message_length + 8);
330 message.extend_from_slice(&[ENCODING, message_type, 0, 0, 0, 0, 0, 0]);
331 message.append(&mut byte_message);
332 // Try to encode entire message.
333 match compress(message).await {
334 (true, compressed) => {
335 // Message was compressed
336 return compressed;
337 }
338 (false, mut uncompressed) => {
339 // Message was not compressed.
340 // Write original total data size.
341 uncompressed[4..8].copy_from_slice(&total_length_bytes);
342 return uncompressed;
343 }
344 }
345 } else {
346 // encode, message type, 0x00 for compression and 0x00 for reserved
347 let mut message = Vec::with_capacity(message_length + MessageHeader::size());
348 message.extend_from_slice(&[ENCODING, message_type, 0, 0]);
349 // Total length of body
350 message.extend_from_slice(&total_length_bytes);
351 message.append(&mut byte_message);
352 return message;
353 }
354 }
355}
356
357//%% QStream %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
358
359impl QStream {
360 /// General constructor of `QStream`.
361 fn new(
362 stream: Box<dyn QStreamInner>,
363 method: ConnectionMethod,
364 is_listener: bool,
365 is_local: bool,
366 ) -> Self {
367 QStream {
368 stream,
369 method,
370 listener: is_listener,
371 local: is_local,
372 }
373 }
374
375 /// Connect to q/kdb+ specifying a connection method, destination host, destination port and access credential.
376 /// # Parameters
377 /// - `method`: Connection method. One of followings:
378 /// - TCP
379 /// - TLS
380 /// - UDS
381 /// - `host`: Hostname or IP address of the target q process. Empty `str` for Unix domain socket.
382 /// - `port`: Port of the target q process.
383 /// - `credential`: Credential in the form of `username:password` to connect to the target q process.
384 /// # Example
385 /// ```
386 /// use kdbplus::qattribute;
387 /// use kdbplus::ipc::*;
388 ///
389 /// #[tokio::main(flavor = "multi_thread", worker_threads = 2)]
390 /// async fn main() -> Result<()> {
391 /// let mut socket =
392 /// QStream::connect(ConnectionMethod::UDS, "", 5000_u16, "ideal:person").await?;
393 /// println!("Connection type: {}", socket.get_connection_type());
394 ///
395 /// // Set remote function with asynchronous message
396 /// socket.send_async_message(&"collatz:{[n] seq:enlist n; while[not n = 1; seq,: n:$[n mod 2; 1 + 3 * n; `long$n % 2]]; seq}").await?;
397 ///
398 /// // Send a query synchronously
399 /// let mut result = socket.send_sync_message(&"collatz[12]").await?;
400 /// println!("collatz[12]: {}", result);
401 ///
402 /// // Send a functional query.
403 /// let mut message = K::new_compound_list(vec![
404 /// K::new_symbol(String::from("collatz")),
405 /// K::new_long(100),
406 /// ]);
407 /// result = socket.send_sync_message(&message).await?;
408 /// println!("collatz[100]: {}", result);
409 ///
410 /// // Send a functional asynchronous query.
411 /// message = K::new_compound_list(vec![
412 /// K::new_string(String::from("show"), qattribute::NONE),
413 /// K::new_symbol(String::from("goodbye")),
414 /// ]);
415 /// socket.send_async_message(&message).await?;
416 ///
417 /// socket.shutdown().await?;
418 ///
419 /// Ok(())
420 /// }
421 /// ```
422 pub async fn connect(
423 method: ConnectionMethod,
424 host: &str,
425 port: u16,
426 credential: &str,
427 ) -> Result<Self> {
428 match method {
429 ConnectionMethod::TCP => {
430 let stream = connect_tcp(host, port, credential).await?;
431 let is_local = matches!(host, "localhost" | "127.0.0.1");
432 Ok(QStream::new(
433 Box::new(stream),
434 ConnectionMethod::TCP,
435 false,
436 is_local,
437 ))
438 }
439 ConnectionMethod::TLS => {
440 let stream = connect_tls(host, port, credential).await?;
441 Ok(QStream::new(
442 Box::new(stream),
443 ConnectionMethod::TLS,
444 false,
445 false,
446 ))
447 }
448 ConnectionMethod::UDS => {
449 let stream = connect_uds(port, credential).await?;
450 Ok(QStream::new(
451 Box::new(stream),
452 ConnectionMethod::UDS,
453 false,
454 true,
455 ))
456 }
457 }
458 }
459
460 /// Accept connection and does handshake.
461 /// # Parameters
462 /// - `method`: Connection method. One of followings:
463 /// - TCP
464 /// - TLS
465 /// - UDS
466 /// - host: Hostname or IP address of this listener. Empty `str` for Unix domain socket.
467 /// - port: Listening port.
468 /// # Example
469 /// ```no_run
470 /// use kdbplus::ipc::*;
471 ///
472 /// #[tokio::main]
473 /// async fn main() -> Result<()> {
474 /// // Start listenening over UDS at the port 7000 with authentication enabled.
475 /// while let Ok(mut socket) = QStream::accept(ConnectionMethod::UDS, "", 7000).await {
476 /// tokio::task::spawn(async move {
477 /// loop {
478 /// match socket.receive_message().await {
479 /// Ok((_, message)) => {
480 /// println!("request: {}", message);
481 /// }
482 /// _ => {
483 /// socket.shutdown().await.unwrap();
484 /// break;
485 /// }
486 /// }
487 /// }
488 /// });
489 /// }
490 ///
491 /// Ok(())
492 /// }
493 /// ```
494 /// q processes can connect and send messages to this acceptor.
495 /// ```q
496 /// q)// Process1
497 /// q)h:hopen `:unix://7000:reluctant:slowday
498 /// q)neg[h] (`monalizza; 3.8)
499 /// q)neg[h] (`pizza; 125)
500 /// ```
501 /// ```q
502 /// q)// Process2
503 /// q)h:hopen `:unix://7000:mattew:oracle
504 /// q)neg[h] (`teddy; "bear")
505 /// ```
506 /// # Note
507 /// - TLS acceptor sets `.kdbplus.close_tls_connection_` on q clien via an asynchronous message. This function is necessary to close
508 /// the socket from the server side without crashing server side application.
509 /// - TLS acceptor and UDS acceptor use specific environmental variables to work. See the [Environmental Variable](../ipc/index.html#environmentl-variables) section for details.
510 pub async fn accept(method: ConnectionMethod, host: &str, port: u16) -> Result<Self> {
511 match method {
512 ConnectionMethod::TCP => {
513 // Bind to the endpoint.
514 let listener = TcpListener::bind(&format!("{}:{}", host, port)).await?;
515 // Listen to the endpoint.
516 let (mut socket, ip_address) = listener.accept().await?;
517 // Read untill null bytes and send back capacity.
518 while (read_client_input(&mut socket).await).is_err() {
519 // Continue to listen in case of error.
520 socket = listener.accept().await?.0;
521 }
522 // Check if the connection is local
523 Ok(QStream::new(
524 Box::new(socket),
525 ConnectionMethod::TCP,
526 true,
527 ip_address.ip() == IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
528 ))
529 }
530 ConnectionMethod::TLS => {
531 // Bind to the endpoint.
532 let listener = TcpListener::bind(&format!("{}:{}", host, port)).await?;
533 // Check if key exists and decode an identity with a given password.
534 let identity = build_identity_from_cert().await?;
535 // Build TLS acceptor.
536 let tls_acceptor = TlsAcceptor::from(TlsAcceptorInner::new(identity).unwrap());
537 // Listen to the endpoint.
538 let (mut socket, _) = listener.accept().await?;
539 // TLS processing.
540 let mut tls_socket = tls_acceptor
541 .accept(socket)
542 .await
543 .expect("failed to accept TLS connection");
544 // Read untill null bytes and send back a capacity.
545 while (read_client_input(&mut tls_socket).await).is_err() {
546 // Continue to listen in case of error.
547 socket = listener.accept().await?.0;
548 tls_socket = tls_acceptor
549 .accept(socket)
550 .await
551 .expect("failed to accept TLS connection");
552 }
553 // TLS is always a remote connection
554 let mut qstream =
555 QStream::new(Box::new(tls_socket), ConnectionMethod::TCP, true, false);
556 // In order to close the connection from the server side, it needs to tell a client to close the connection.
557 // The `kdbplus_close_tls_connection_` will be called from the server at shutdown.
558 qstream
559 .send_async_message(&".kdbplus.close_tls_connection_:{[] hclose .z.w;}")
560 .await?;
561 Ok(qstream)
562 }
563 ConnectionMethod::UDS => {
564 // uild a sockt file path.
565 let uds_path = create_sockfile_path(port)?;
566 let abstract_sockfile_ = format!("\x00{}", uds_path);
567 let abstract_sockfile = Path::new(&abstract_sockfile_);
568 // Bind to the file
569 let listener = UnixListener::bind(abstract_sockfile).unwrap();
570 // Listen to the endpoint
571 let (mut socket, _) = listener.accept().await?;
572 // Read untill null bytes and send back capacity.
573 while (read_client_input(&mut socket).await).is_err() {
574 // Continue to listen in case of error.
575 socket = listener.accept().await?.0;
576 }
577 // UDS is always a local connection
578 Ok(QStream::new(Box::new(socket), method, true, true))
579 }
580 }
581 }
582
583 /// Shutdown the socket for a q process.
584 /// # Example
585 /// See the example of [`connect`](#method.connect).
586 pub async fn shutdown(mut self) -> Result<()> {
587 self.stream.shutdown(self.listener).await
588 }
589
590 /// Send a message with a specified message type without waiting for a response even for a synchronous message.
591 /// If you need to receive a response you need to use [`receive_message`](#method.receive_message).
592 /// # Note
593 /// The usage of this function for a synchronous message is to handle an asynchronous message or a synchronous message
594 /// sent by a remote function during its execution.
595 /// # Parameters
596 /// - `message`: q command to execute on the remote q process.
597 /// - `&str`: q command in a string form.
598 /// - `K`: Query in a functional form.
599 /// - `message_type`: Asynchronous or synchronous.
600 /// # Example
601 /// See the example of [`connect`](#method.connect).
602 pub async fn send_message(&mut self, message: &dyn Query, message_type: u8) -> Result<()> {
603 self.stream
604 .send_message(message, message_type, self.local)
605 .await
606 }
607
608 /// Send a message asynchronously.
609 /// # Parameters
610 /// - `message`: q command to execute on the remote q process.
611 /// - `&str`: q command in a string form.
612 /// - `K`: Query in a functional form.
613 /// # Example
614 /// See the example of [`connect`](#method.connect).
615 pub async fn send_async_message(&mut self, message: &dyn Query) -> Result<()> {
616 self.stream.send_async_message(message, self.local).await
617 }
618
619 /// Send a message synchronously.
620 /// # Note
621 /// Remote function must NOT send back a message of asynchronous or synchronous type durning execution of the function.
622 /// # Parameters
623 /// - `message`: q command to execute on the remote q process.
624 /// - `&str`: q command in a string form.
625 /// - `K`: Query in a functional form.
626 /// # Example
627 /// See the example of [`connect`](#method.connect).
628 pub async fn send_sync_message(&mut self, message: &dyn Query) -> Result<K> {
629 self.stream.send_sync_message(message, self.local).await
630 }
631
632 /// Receive a message from a remote q process. The received message is parsed as `K` and message type is
633 /// stored in the first returned value.
634 /// # Example
635 /// See the example of [`accept`](#method.accept).
636 pub async fn receive_message(&mut self) -> Result<(u8, K)> {
637 self.stream.receive_message().await
638 }
639
640 /// Return underlying connection type. One of `TCP`, `TLS` or `UDS`.
641 /// # Example
642 /// See the example of [`connect`](#method.connect).
643 pub fn get_connection_type(&self) -> &str {
644 match self.method {
645 ConnectionMethod::TCP => "TCP",
646 ConnectionMethod::TLS => "TLS",
647 ConnectionMethod::UDS => "UDS",
648 }
649 }
650
651 /// Enforce compression if the size of a message exceeds 2000 regardless of locality of the connection.
652 /// This flag is not revertible intentionally.
653 pub fn enforce_compression(&mut self) {
654 self.local = false;
655 }
656}
657
658//%% QStreamInner %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
659
660#[async_trait]
661impl QStreamInner for TcpStream {
662 async fn shutdown(&mut self, _: bool) -> Result<()> {
663 AsyncWriteExt::shutdown(self).await?;
664 Ok(())
665 }
666
667 async fn send_message(
668 &mut self,
669 message: &dyn Query,
670 message_type: u8,
671 is_local: bool,
672 ) -> Result<()> {
673 // Serialize a message
674 let byte_message = message.serialize(message_type, is_local).await;
675 // Send the message
676 self.write_all(&byte_message).await?;
677 Ok(())
678 }
679
680 async fn send_async_message(&mut self, message: &dyn Query, is_local: bool) -> Result<()> {
681 // Serialize a message
682 let byte_message = message.serialize(qmsg_type::asynchronous, is_local).await;
683 // Send the message
684 self.write_all(&byte_message).await?;
685 Ok(())
686 }
687
688 async fn send_sync_message(&mut self, message: &dyn Query, is_local: bool) -> Result<K> {
689 // Serialize a message
690 let byte_message = message.serialize(qmsg_type::synchronous, is_local).await;
691 // Send the message
692 self.write_all(&byte_message).await?;
693 // Receive a response. If message type is not response it returns an error.
694 match receive_message(self).await {
695 Ok((qmsg_type::response, response)) => Ok(response),
696 Err(error) => Err(error),
697 Ok((_, message)) => Err(io::Error::new(
698 io::ErrorKind::InvalidData,
699 format!("expected a response: {}", message),
700 )
701 .into()),
702 }
703 }
704
705 async fn receive_message(&mut self) -> Result<(u8, K)> {
706 receive_message(self).await
707 }
708}
709
710#[async_trait]
711impl QStreamInner for TlsStream<TcpStream> {
712 async fn shutdown(&mut self, is_listener: bool) -> Result<()> {
713 if is_listener {
714 // Closing the handle from the server side by `self.get_mut().shutdown()` crashes due to 'assertion failed: !self.context.is_null()'.
715 // No reason to compress.
716 self.send_async_message(&".kdbplus.close_tls_connection_[]", false)
717 .await
718 } else {
719 self.get_mut().shutdown()?;
720 Ok(())
721 }
722 }
723
724 async fn send_message(
725 &mut self,
726 message: &dyn Query,
727 message_type: u8,
728 is_local: bool,
729 ) -> Result<()> {
730 // Serialize a message
731 let byte_message = message.serialize(message_type, is_local).await;
732 // Send the message
733 self.write_all(&byte_message).await?;
734 Ok(())
735 }
736
737 async fn send_async_message(&mut self, message: &dyn Query, is_local: bool) -> Result<()> {
738 // Serialize a message
739 let byte_message = message.serialize(qmsg_type::asynchronous, is_local).await;
740 // Send the message
741 self.write_all(&byte_message).await?;
742 Ok(())
743 }
744
745 async fn send_sync_message(&mut self, message: &dyn Query, is_local: bool) -> Result<K> {
746 // Serialize a message
747 let byte_message = message.serialize(qmsg_type::synchronous, is_local).await;
748 // Send the message
749 self.write_all(&byte_message).await?;
750 // Receive a response. If message type is not response it returns an error.
751 match receive_message(self).await {
752 Ok((qmsg_type::response, response)) => Ok(response),
753 Err(error) => Err(error),
754 Ok((_, message)) => Err(io::Error::new(
755 io::ErrorKind::InvalidData,
756 format!("expected a response: {}", message),
757 )
758 .into()),
759 }
760 }
761
762 async fn receive_message(&mut self) -> Result<(u8, K)> {
763 receive_message(self).await
764 }
765}
766
767#[async_trait]
768impl QStreamInner for UnixStream {
769 /// Close a handle to a q process which is connected with Unix Domain Socket.
770 /// Socket file is removed.
771 async fn shutdown(&mut self, _: bool) -> Result<()> {
772 AsyncWriteExt::shutdown(self).await?;
773 Ok(())
774 }
775
776 async fn send_message(
777 &mut self,
778 message: &dyn Query,
779 message_type: u8,
780 is_local: bool,
781 ) -> Result<()> {
782 // Serialize a message
783 let byte_message = message.serialize(message_type, is_local).await;
784 // Send the message
785 self.write_all(&byte_message).await?;
786 Ok(())
787 }
788
789 async fn send_async_message(&mut self, message: &dyn Query, is_local: bool) -> Result<()> {
790 // Serialize a message
791 let byte_message = message.serialize(qmsg_type::asynchronous, is_local).await;
792 // Send the message
793 self.write_all(&byte_message).await?;
794 Ok(())
795 }
796
797 async fn send_sync_message(&mut self, message: &dyn Query, is_local: bool) -> Result<K> {
798 // Serialize a message
799 let byte_message = message.serialize(qmsg_type::synchronous, is_local).await;
800 // Send the message
801 self.write_all(&byte_message).await?;
802 // Receive a response. If message type is not response it returns an error.
803 match receive_message(self).await {
804 Ok((qmsg_type::response, response)) => Ok(response),
805 Err(error) => Err(error),
806 Ok((_, message)) => Err(io::Error::new(
807 io::ErrorKind::InvalidData,
808 format!("expected a response: {}", message),
809 )
810 .into()),
811 }
812 }
813
814 async fn receive_message(&mut self) -> Result<(u8, K)> {
815 receive_message(self).await
816 }
817}
818
819//%% MessageHeader %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
820
821impl MessageHeader {
822 /// Constructor.
823 fn new(encoding: u8, message_type: u8, compressed: u8, length: u32) -> Self {
824 MessageHeader {
825 encoding,
826 message_type,
827 compressed,
828 _unused: 0,
829 length,
830 }
831 }
832
833 /// Constructor from bytes.
834 fn from_bytes(bytes: [u8; 8]) -> Self {
835 let encoding = bytes[0];
836
837 let length = match encoding {
838 0 => u32::from_be_bytes(bytes[4..8].try_into().unwrap()),
839 _ => u32::from_le_bytes(bytes[4..8].try_into().unwrap()),
840 };
841
842 // Build header
843 MessageHeader::new(encoding, bytes[1], bytes[2], length)
844 }
845
846 /// Length of bytes for a header.
847 fn size() -> usize {
848 8
849 }
850}
851
852//++++++++++++++++++++++++++++++++++++++++++++++++++//
853// >> Private Functions
854//++++++++++++++++++++++++++++++++++++++++++++++++++//
855
856//%% QStream Connector %%//vvvvvvvvvvvvvvvvvvvvvvvvvv/
857
858/// Inner function of `connect_tcp` and `connect_tls` to establish a TCP connection with the sepcified
859/// endpoint. The hostname is resolved to an IP address with a system DNS resolver or parsed directly
860/// as an IP address.
861///
862/// Tries to connect to multiple resolved IP addresses until the first successful connection. Error is
863/// returned if none of them are valid.
864/// # Parameters
865/// - `host`: Hostname or IP address of the target q/kdb+ process.
866/// - `port`: Port of the target q process
867async fn connect_tcp_impl(host: &str, port: u16) -> Result<TcpStream> {
868 // DNS system resolver (should not fail)
869 let resolver =
870 TokioAsyncResolver::tokio_from_system_conf().expect("failed to create a resolver");
871
872 // Check if we were given an IP address
873 let ips;
874 if let Ok(ip) = host.parse::<IpAddr>() {
875 ips = vec![ip.to_string()]
876 } else {
877 // Resolve the given hostname
878 ips = resolver
879 .ipv4_lookup(format!("{}.", host).as_str())
880 .await
881 .unwrap()
882 .iter()
883 .map(|result| result.to_string())
884 .collect()
885 };
886
887 for answer in ips {
888 let host_port = format!("{}:{}", answer, port);
889 // Return if this IP address is valid
890 match TcpStream::connect(&host_port).await {
891 Ok(socket) => {
892 //println!("connected: {}", host_port);
893 return Ok(socket);
894 }
895 Err(_) => {
896 eprintln!("connection refused: {}. try next.", host_port);
897 }
898 }
899 }
900 // All addresses failed.
901 Err(io::Error::new(io::ErrorKind::ConnectionRefused, "failed to connect").into())
902}
903
904/// Send a credential and receive a common capacity.
905async fn handshake<S>(socket: &mut S, credential_: &str, method_bytes: &str) -> Result<()>
906where
907 S: Unpin + AsyncWriteExt + AsyncReadExt,
908{
909 // Send credential
910 let credential = credential_.to_string() + method_bytes;
911 socket.write_all(credential.as_bytes()).await?;
912
913 // Placeholder of common capablility
914 let mut cap = [0u8; 1];
915 if (socket.read_exact(&mut cap).await).is_err() {
916 // Connection is closed in case of authentication failure
917 Err(io::Error::new(io::ErrorKind::ConnectionAborted, "authentication failure").into())
918 } else {
919 Ok(())
920 }
921}
922
923/// Connect to q process running on a specified `host` and `port` via TCP with a credential `username:password`.
924/// # Parameters
925/// - `host`: Hostname or IP address of the target q process.
926/// - `port`: Port of the target q process.
927/// - `credential`: Credential in the form of `username:password` to connect to the target q process.
928async fn connect_tcp(host: &str, port: u16, credential: &str) -> Result<TcpStream> {
929 // Connect via TCP
930 let mut socket = connect_tcp_impl(host, port).await?;
931 // Handshake
932 handshake(&mut socket, credential, "\x03\x00").await?;
933 Ok(socket)
934}
935
936/// TLS version of `connect_tcp`.
937/// # Parameters
938/// - `host`: Hostname or IP address of the target q process.
939/// - `port`: Port of the target q process.
940/// - `credential`: Credential in the form of `username:password` to connect to the target q process.
941async fn connect_tls(host: &str, port: u16, credential: &str) -> Result<TlsStream<TcpStream>> {
942 // Connect via TCP
943 let socket_ = connect_tcp_impl(host, port).await?;
944 // Use TLS
945 let connector = TlsConnector::from(TlsConnectorInner::new().unwrap());
946 let mut socket = connector
947 .connect(host, socket_)
948 .await
949 .expect("failed to create TLS session");
950 // Handshake
951 handshake(&mut socket, credential, "\x03\x00").await?;
952 Ok(socket)
953}
954
955/// Build a path of a socket file.
956fn create_sockfile_path(port: u16) -> Result<String> {
957 // Create file path
958 let udspath = match env::var("QUDSPATH") {
959 Ok(dir) => format!("{}/kx.{}", dir, port),
960 Err(_) => format!("/tmp/kx.{}", port),
961 };
962
963 Ok(udspath)
964}
965
966/// Connect to q process running on the specified `port` via Unix domain socket with a credential `username:password`.
967/// # Parameters
968/// - `port`: Port of the target q process.
969/// - `credential`: Credential in the form of `username:password` to connect to the target q process.
970#[cfg(unix)]
971async fn connect_uds(port: u16, credential: &str) -> Result<UnixStream> {
972 // Create a file path.
973 let uds_path = create_sockfile_path(port)?;
974 let abstract_sockfile_ = format!("\x00{}", uds_path);
975 let abstract_sockfile = Path::new(&abstract_sockfile_);
976 // Connect to kdb+.
977 let mut socket = UnixStream::connect(&abstract_sockfile).await?;
978 // Handshake
979 handshake(&mut socket, credential, "\x06\x00").await?;
980
981 Ok(socket)
982}
983
984//%% QStream Acceptor %%//vvvvvvvvvvvvvvvvvvvvvvvvvvv/
985
986/// Read username, password, capacity and null byte from q client at the connection and does authentication.
987/// Close the handle if the authentication fails.
988async fn read_client_input<S>(socket: &mut S) -> Result<()>
989where
990 S: Unpin + AsyncWriteExt + AsyncReadExt,
991{
992 // Buffer to read inputs.
993 let mut client_input = [0u8; 32];
994 // credential will be built from small fractions of bytes.
995 let mut passed_credential = String::new();
996 loop {
997 // Read a client credential input.
998 match socket.read(&mut client_input).await {
999 Ok(0) => {
1000 // No bytes were read
1001 }
1002 Ok(_) => {
1003 // Locate a byte denoting a capacity
1004 if let Some(index) = client_input
1005 .iter()
1006 .position(|byte| *byte == 0x03 || *byte == 0x06)
1007 {
1008 let capacity = client_input[index];
1009 passed_credential
1010 .push_str(str::from_utf8(&client_input[0..index]).expect("invalid bytes"));
1011 let credential = passed_credential.as_str().split(':').collect::<Vec<&str>>();
1012 #[allow(clippy::borrow_interior_mutable_const)]
1013 if let Some(encoded) = ACCOUNTS.get(&credential[0].to_string()) {
1014 // User exists
1015 let mut hasher = Sha1::new();
1016 hasher.update(credential[1].as_bytes());
1017 let encoded_password = hasher.digest().to_string();
1018 if encoded == &encoded_password {
1019 // Client passed correct credential
1020 socket.write_all(&[capacity; 1]).await?;
1021 return Ok(());
1022 } else {
1023 // Authentication failure.
1024 // Close connection.
1025 socket.shutdown().await?;
1026 return Err(io::Error::new(
1027 io::ErrorKind::InvalidData,
1028 "authentication failed",
1029 )
1030 .into());
1031 }
1032 } else {
1033 // Authentication failure.
1034 // Close connection.
1035 socket.shutdown().await?;
1036 return Err(io::Error::new(
1037 io::ErrorKind::InvalidData,
1038 "authentication failed",
1039 )
1040 .into());
1041 }
1042 } else {
1043 // Append a fraction of credential
1044 passed_credential
1045 .push_str(str::from_utf8(&client_input).expect("invalid bytes"));
1046 }
1047 }
1048 Err(error) => {
1049 return Err(error.into());
1050 }
1051 }
1052 }
1053}
1054
1055/// Check if server key exists and return teh contents.
1056async fn build_identity_from_cert() -> Result<Identity> {
1057 // Check if server key exists.
1058 if let Ok(path) = env::var("KDBPLUS_TLS_KEY_FILE") {
1059 if let Ok(password) = env::var("KDBPLUS_TLS_KEY_FILE_SECRET") {
1060 let cert_file = tokio::fs::File::open(Path::new(&path)).await.unwrap();
1061 let mut reader = BufReader::new(cert_file);
1062 let mut der: Vec<u8> = Vec::new();
1063 // Read the key file.
1064 reader.read_to_end(&mut der).await?;
1065 // Create identity.
1066 if let Ok(identity) = Identity::from_pkcs12(&der, &password) {
1067 Ok(identity)
1068 } else {
1069 Err(io::Error::new(io::ErrorKind::InvalidData, "authentication failed").into())
1070 }
1071 } else {
1072 Err(io::Error::new(
1073 io::ErrorKind::NotFound,
1074 "KDBPLUS_TLS_KEY_FILE_SECRET is not set",
1075 )
1076 .into())
1077 }
1078 } else {
1079 Err(io::Error::new(io::ErrorKind::NotFound, "KDBPLUS_TLS_KEY_FILE is not set").into())
1080 }
1081}
1082
1083//%% QStream Query %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
1084
1085/// Receive a message from q process with decompression if necessary. The received message is parsed as `K` and message type is
1086/// stored in the first returned value.
1087/// # Parameters
1088/// - `socket`: Socket to communicate with a q process. Either of `TcpStream`, `TlsStream<TcpStream>` or `UnixStream`.
1089async fn receive_message<S>(socket: &mut S) -> Result<(u8, K)>
1090where
1091 S: Unpin + AsyncReadExt,
1092{
1093 // Read header
1094 let mut header_buffer = [0u8; 8];
1095 if let Err(err) = socket.read_exact(&mut header_buffer).await {
1096 // The expected message is header or EOF (close due to q process failure resulting from a bad query)
1097 return Err(io::Error::new(
1098 io::ErrorKind::ConnectionAborted,
1099 format!("Connection dropped: {}", err),
1100 )
1101 .into());
1102 }
1103
1104 // Parse message header
1105 let header = MessageHeader::from_bytes(header_buffer);
1106
1107 // Read body
1108 let body_length = header.length as usize - MessageHeader::size();
1109 let mut body: Vec<u8> = vec![0; body_length];
1110 if let Err(err) = socket.read_exact(&mut body).await {
1111 // Fails if q process fails before reading the body
1112 return Err(io::Error::new(
1113 io::ErrorKind::UnexpectedEof,
1114 format!("Failed to read body of message: {}", err),
1115 )
1116 .into());
1117 }
1118
1119 // Decompress if necessary
1120 if header.compressed == 0x01 {
1121 body = decompress(body, header.encoding).await;
1122 }
1123
1124 Ok((
1125 header.message_type,
1126 K::q_ipc_decode(&body, header.encoding).await,
1127 ))
1128}
1129
1130/// Compress body. The combination of serializing the data and compressing will result in
1131/// the same output as shown in the q language by using the -18! function e.g.
1132/// serializing 2000 bools set to true, then compressing, will have the same output as `-18!2000#1b`.
1133/// # Parameter
1134/// - `raw`: Serialized message.
1135/// - `encode`: `0` if Big Endian; `1` if Little Endian.
1136async fn compress(raw: Vec<u8>) -> (bool, Vec<u8>) {
1137 let mut i = 0_u8;
1138 let mut f = 0_u8;
1139 let mut h0 = 0_usize;
1140 let mut h = 0_usize;
1141 let mut g: bool;
1142 let mut compressed: Vec<u8> = vec![0; (raw.len()) / 2];
1143
1144 // Start index of compressed body
1145 // 12 bytes are reserved for the header + size of raw bytes
1146 let mut c = 12;
1147 let mut d = c;
1148 let e = compressed.len();
1149 let mut p = 0_usize;
1150 let mut q: usize;
1151 let mut r: usize;
1152 let mut s0 = 0_usize;
1153
1154 // Body starts from index 8
1155 let mut s = 8_usize;
1156 let t = raw.len();
1157 let mut a = [0_i32; 256];
1158
1159 // Copy encode, message type, compressed and reserved
1160 compressed[0..4].copy_from_slice(&raw[0..4]);
1161 // Set compressed flag
1162 compressed[2] = 1;
1163
1164 // Write size of raw bytes including a header
1165 let raw_size = match ENCODING {
1166 0 => (t as u32).to_be_bytes(),
1167 _ => (t as u32).to_le_bytes(),
1168 };
1169 compressed[8..12].copy_from_slice(&raw_size);
1170
1171 while s < t {
1172 if i == 0 {
1173 if d > e - 17 {
1174 // Early return when compressing to less than half failed
1175 return (false, raw);
1176 }
1177 i = 1;
1178 compressed[c] = f;
1179 c = d;
1180 d += 1;
1181 f = 0;
1182 }
1183 g = s > t - 3;
1184 if !g {
1185 h = (raw[s] ^ raw[s + 1]) as usize;
1186 p = a[h] as usize;
1187 g = (0 == p) || (0 != (raw[s] ^ raw[p]));
1188 }
1189 if 0 < s0 {
1190 a[h0] = s0 as i32;
1191 s0 = 0;
1192 }
1193 if g {
1194 h0 = h;
1195 s0 = s;
1196 compressed[d] = raw[s];
1197 d += 1;
1198 s += 1;
1199 } else {
1200 a[h] = s as i32;
1201 f |= i;
1202 p += 2;
1203 s += 2;
1204 r = s;
1205 q = if s + 255 > t { t } else { s + 255 };
1206 while (s < q) && (raw[p] == raw[s]) {
1207 s += 1;
1208 if s < q {
1209 p += 1;
1210 }
1211 }
1212 compressed[d] = h as u8;
1213 d += 1;
1214 compressed[d] = (s - r) as u8;
1215 d += 1;
1216 }
1217 i = i.wrapping_mul(2);
1218 }
1219 compressed[c] = f;
1220 // Final compressed data size
1221 let compressed_size = match ENCODING {
1222 0 => (d as u32).to_be_bytes(),
1223 _ => (d as u32).to_le_bytes(),
1224 };
1225 compressed[4..8].copy_from_slice(&compressed_size);
1226 let _ = compressed.split_off(d);
1227 (true, compressed)
1228}
1229
/// Decompress a message body. The combination of decompressing and deserializing the data
/// will result in the same output as shown in the q language by using the `-19!` function.
/// # Parameter
/// - `compressed`: Compressed serialized message with its 8-byte message header already
///   removed. Its first 4 bytes hold the uncompressed size, which still counts those 8 header
///   bytes; the compressed stream starts at index 4.
/// - `encoding`: Byte order of that size field:
///   - `0`: Big Endian
///   - any other value: Little Endian.
/// # Returns
/// The decompressed body, without the 8-byte message header.
/// # Panics
/// Indexes directly into both buffers, so it panics if `compressed` is shorter than 4 bytes and
/// may panic on a truncated or malformed compressed stream.
/// NOTE(review): a size field smaller than 8 makes `size` negative, and `size as usize` then
/// wraps to a huge allocation; consider validating the size before calling.
async fn decompress(compressed: Vec<u8>, encoding: u8) -> Vec<u8> {
    // Number of extra bytes copied by the current back-reference (after its first two bytes).
    let mut n = 0;
    // Read position of a back-reference inside `decompressed`.
    let mut r: usize;
    // Current control byte: a set bit `i` means the matching token is a back-reference.
    let mut f = 0_usize;

    // Header has already been removed.
    // Start index of decompressed bytes is 0
    let mut s = 0_usize;
    // Next position whose byte pair has not yet been recorded in `aa`.
    let mut p = s;
    // Mask selecting the current bit of `f`; 0 means a new control byte must be read.
    let mut i = 0_usize;

    // Subtract 8 bytes from decoded bytes size as 8 bytes have already been taken as header
    let size = match encoding {
        0 => {
            i32::from_be_bytes(
                compressed[0..4]
                    .try_into()
                    .expect("slice does not have length 4"),
            ) - 8
        }
        _ => {
            i32::from_le_bytes(
                compressed[0..4]
                    .try_into()
                    .expect("slice does not have length 4"),
            ) - 8
        }
    };
    let mut decompressed: Vec<u8> = vec![0; size as usize];

    // Start index of compressed body.
    // 8 bytes have already been removed as header
    let mut d = 4;
    // Maps the XOR of two consecutive output bytes to the latest position where that pair began.
    let mut aa = [0_i32; 256];
    while s < decompressed.len() {
        if i == 0 {
            // Start of a group of up to 8 tokens: read its control byte.
            f = compressed[d] as usize;
            d += 1;
            i = 1;
        }
        if (f & i) != 0 {
            // Back-reference token: a 1-byte hash selects an earlier position via `aa`; two
            // bytes are copied from there, then `n` (next input byte) more.
            r = aa[compressed[d] as usize] as usize;
            d += 1;
            decompressed[s] = decompressed[r];
            s += 1;
            r += 1;
            decompressed[s] = decompressed[r];
            s += 1;
            r += 1;
            n = compressed[d] as usize;
            d += 1;
            // Copied byte by byte: when the source range overlaps the destination, bytes written
            // earlier in this loop are re-read (a memmove such as `copy_within` would differ).
            for m in 0..n {
                decompressed[s + m] = decompressed[r + m];
            }
        } else {
            // Literal token: copy one byte as-is.
            decompressed[s] = compressed[d];
            s += 1;
            d += 1;
        }
        // Record the position of every newly completed byte pair, keyed by the pair's XOR.
        while p < s - 1 {
            aa[(decompressed[p] ^ decompressed[p + 1]) as usize] = p as i32;
            p += 1;
        }
        if (f & i) != 0 {
            // Step over the `n` extra bytes; pairs inside them are not recorded in `aa`.
            s += n;
            p = s;
        }
        // Next bit of the control byte; after 8 tokens a new control byte is needed.
        i *= 2;
        if i == 256 {
            i = 0;
        }
    }
    decompressed
}