Skip to main content

kdb_plus_fixed/ipc/
connection.rs

1//++++++++++++++++++++++++++++++++++++++++++++++++++//
2// >> Load Libraries
3//++++++++++++++++++++++++++++++++++++++++++++++++++//
4
5use super::serialize::ENCODING;
6use super::Result;
7use super::{qtype, K};
8use async_trait::async_trait;
9use io::BufRead;
10use once_cell::sync::Lazy;
11use sha1_smol::Sha1;
12use std::collections::HashMap;
13use std::convert::TryInto;
14use std::net::{IpAddr, Ipv4Addr};
15use std::path::Path;
16use std::{env, fs, io, str};
17use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
18use tokio::net::{TcpListener, TcpStream};
19#[cfg(unix)]
20use tokio::net::{UnixListener, UnixStream};
21use tokio_native_tls::native_tls::{
22    Identity, TlsAcceptor as TlsAcceptorInner, TlsConnector as TlsConnectorInner,
23};
24use tokio_native_tls::{TlsAcceptor, TlsConnector, TlsStream};
25use trust_dns_resolver::TokioAsyncResolver;
26
27//++++++++++++++++++++++++++++++++++++++++++++++++++//
28// >> Global Variable
29//++++++++++++++++++++++++++++++++++++++++++++++++++//
30
31//%% QStream %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
32
pub mod qmsg_type {
    //! This module provides the list of q message types used for IPC.
    //!  The motivation for grouping them in a module is to tie them together as related items rather
    //!  than scattered values. Hence users should refer to these indicators with the `qmsg_type::` prefix, e.g., `qmsg_type::asynchronous`.
    //!
    //! # Example
    //! ```no_run
    //! use kdbplus::ipc::*;
    //!
    //! // Print `K` object.
    //! fn print(obj: &K) {
    //!     println!("{}", obj);
    //! }
    //!
    //! // Calculate something from two long arguments.
    //! fn nonsense(arg1: i64, arg2: i64) -> i64 {
    //!     arg1 * arg2
    //! }
    //!
    //! #[tokio::main]
    //! async fn main() -> Result<()> {
    //!     // Connect to a q process running on localhost:5000 via TCP
    //!     let mut socket =
    //!         QStream::connect(ConnectionMethod::TCP, "localhost", 5000_u16, "ideal:person").await?;
    //!
    //!     // Set a function which sends back a non-response message during its execution.
    //!     socket
    //!         .send_async_message(
    //!             &"complex:{neg[.z.w](`print; \"counter\"); what: .z.w (`nonsense; 1; 2); what*100}",
    //!         )
    //!         .await?;
    //!
    //!     // Send a synchronous query `(`complex; ::)` without waiting for its response.
    //!     socket
    //!         .send_message(
    //!             &K::new_compound_list(vec![K::new_symbol(String::from("complex")), K::new_null()]),
    //!             qmsg_type::synchronous,
    //!         )
    //!         .await?;
    //!
    //!     // Receive an asynchronous call from the function.
    //!     match socket.receive_message().await {
    //!         Ok((qmsg_type::asynchronous, message)) => {
    //!             println!("asynchronous call: {}", message);
    //!             let list = message.as_vec::<K>().unwrap();
    //!             if list[0].get_symbol().unwrap() == "print" {
    //!                 print(&list[1])
    //!             }
    //!         }
    //!         _ => unreachable!(),
    //!     }
    //!
    //!     // Receive a synchronous call from the function.
    //!     match socket.receive_message().await {
    //!         Ok((qmsg_type::synchronous, message)) => {
    //!             println!("synchronous call: {}", message);
    //!             let list = message.as_vec::<K>().unwrap();
    //!             if list[0].get_symbol().unwrap() == "nonsense" {
    //!                 let res = nonsense(list[1].get_long().unwrap(), list[2].get_long().unwrap());
    //!                 // Send back a response.
    //!                 socket
    //!                     .send_message(&K::new_long(res), qmsg_type::response)
    //!                     .await?;
    //!             }
    //!         }
    //!         _ => unreachable!(),
    //!     }
    //!
    //!     // Receive the final result.
    //!     match socket.receive_message().await {
    //!         Ok((qmsg_type::response, message)) => {
    //!             println!("final: {}", message);
    //!         }
    //!         _ => unreachable!(),
    //!     }
    //!
    //!     Ok(())
    //! }
    //!```

    // The lowercase names are part of the public API (`qmsg_type::asynchronous` etc.),
    //  so the naming lint is silenced deliberately.
    #![allow(non_upper_case_globals)]

    /// Used to send a message to q/kdb+ asynchronously.
    pub const asynchronous: u8 = 0;
    /// Used to send a message to q/kdb+ synchronously.
    pub const synchronous: u8 = 1;
    /// Used by q/kdb+ to identify a response for a synchronous query.
    pub const response: u8 = 2;
}
119
120//%% QStream Acceptor %%//vvvvvvvvvvvvvvvvvvvvvvvvvvv/
121
122/// Map from user name to password hashed with SHA1.
123#[allow(clippy::declare_interior_mutable_const)]
124const ACCOUNTS: Lazy<HashMap<String, String>> = Lazy::new(|| {
125    // Map from user to password
126    let mut map: HashMap<String, String> = HashMap::new();
127    // Open credential file
128    let file = fs::OpenOptions::new()
129        .read(true)
130        .open(env::var("KDBPLUS_ACCOUNT_FILE").expect("KDBPLUS_ACCOUNT_FILE is not set"))
131        .expect("failed to open account file");
132    let mut reader = io::BufReader::new(file);
133    let mut line = String::new();
134    loop {
135        match reader.read_line(&mut line) {
136            Ok(0) => break,
137            Ok(_) => {
138                let credential = line.as_str().split(':').collect::<Vec<&str>>();
139                let mut password = credential[1];
140                if password.ends_with('\n') {
141                    password = &password[0..password.len() - 1];
142                }
143                map.insert(credential[0].to_string(), password.to_string());
144                line.clear();
145            }
146            _ => unreachable!(),
147        }
148    }
149    map
150});
151
152//++++++++++++++++++++++++++++++++++++++++++++++++++//
153// >> Structs
154//++++++++++++++++++++++++++++++++++++++++++++++++++//
155
156//%% ConnectionMethod %%//vvvvvvvvvvvvvvvvvvvvvvvvvvv/
157
/// Connection method to q/kdb+.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConnectionMethod {
    /// Plain TCP socket.
    TCP = 0,
    /// TCP socket wrapped in TLS.
    TLS = 1,
    /// Unix domain socket.
    UDS = 2,
}
165
166//%% Query %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
167
168/// Feature of query object.
169#[async_trait]
170pub trait Query: Send + Sync {
171    /// Serialize into q IPC bytes including a header (encoding, message type, compresssion flag and total message length).
172    ///   If the connection is within the same host, the message is not compressed under any conditions.
173    /// # Parameters
174    /// - `message_type`: Message type. One of followings:
175    ///   - `qmsg_type::asynchronous`
176    ///   - `qmsg_type::synchronous`
177    ///   - `qmsg_type::response`
178    /// - `is_local`: Flag of whether the connection is within the same host.
179    async fn serialize(&self, message_type: u8, is_local: bool) -> Vec<u8>;
180}
181
182//%% QStreamInner %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
183
184/// Features which streams communicating with q must have.
185#[async_trait]
186trait QStreamInner: Send + Sync {
187    /// Shutdown underlying stream.
188    async fn shutdown(&mut self, is_server: bool) -> Result<()>;
189    /// Send a message with a specified message type without waiting for a response.
190    /// # Parameters
191    /// - `message`: q command to execute on the remote q process.
192    /// - `message_type`: Asynchronous or synchronous.
193    /// - `is_local`: Flag of whether the connection is within the same host.
194    async fn send_message(
195        &mut self,
196        message: &dyn Query,
197        message_type: u8,
198        is_local: bool,
199    ) -> Result<()>;
200    /// Send a message asynchronously.
201    /// # Parameters
202    /// - `message`: q command in two ways:
203    ///   - `&str`: q command in a string form.
204    ///   - `K`: Query in a functional form.
205    /// - `is_local`: Flag of whether the connection is within the same host.
206    async fn send_async_message(&mut self, message: &dyn Query, is_local: bool) -> Result<()>;
207    /// Send a message asynchronously.
208    /// # Parameters
209    /// - `message`: q command in two ways:
210    ///   - `&str`: q command in a string form.
211    ///   - `K`: Query in a functional form.
212    /// - `is_local`: Flag of whether the connection is within the same host.
213    async fn send_sync_message(&mut self, message: &dyn Query, is_local: bool) -> Result<K>;
214    /// Receive a message from a remote q process. The received message is parsed as `K` and message type is
215    ///  stored in the first returned value.
216    async fn receive_message(&mut self) -> Result<(u8, K)>;
217}
218
219//%% QStream %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
220
221/// Stream to communicate with q/kdb+.
222pub struct QStream {
223    /// Actual stream to communicate.
224    stream: Box<dyn QStreamInner>,
225    /// Connection method. One of followings:
226    /// - TCP
227    /// - TLS
228    /// - UDS
229    method: ConnectionMethod,
230    /// Indicator of whether the stream is an acceptor or client.
231    /// - `true`: Acceptor
232    /// - `false`: Client
233    listener: bool,
234    /// Indicator of whether the connection is within the same host.
235    /// - `true`: Connection within the same host.
236    /// - `false`: Connection with outseide.
237    local: bool,
238}
239
240//%% MessageHeader %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
241
/// Header of q IPC data frame.
#[derive(Clone, Copy, Debug)]
struct MessageHeader {
    /// Encoding.
    /// - 0: Big Endian
    /// - 1: Little Endian
    encoding: u8,
    /// Message type. One of the following:
    /// - 0: Asynchronous
    /// - 1: Synchronous
    /// - 2: Response
    message_type: u8,
    /// Indicator of whether the message is compressed or not.
    /// - 0: Uncompressed
    /// - 1: Compressed
    compressed: u8,
    /// Reserved byte.
    _unused: u8,
    /// Total length of the message in bytes, including this header.
    ///  NOTE(review): for compressed messages, confirm whether this holds the compressed or
    ///  the uncompressed size.
    length: u32,
}
263
264//++++++++++++++++++++++++++++++++++++++++++++++++++//
265// >> Implementation
266//++++++++++++++++++++++++++++++++++++++++++++++++++//
267
268//%% Query %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
269
270/// Text query.
271#[async_trait]
272impl Query for &str {
273    async fn serialize(&self, message_type: u8, _: bool) -> Vec<u8> {
274        //  Build header //--------------------------------/
275        // Message header + (type indicator of string + header of string type) + string length
276        let byte_message = self.as_bytes();
277        let message_length = byte_message.len() as u32;
278        let total_length = MessageHeader::size() as u32 + 6 + message_length;
279
280        let total_length_bytes = match ENCODING {
281            0 => total_length.to_be_bytes(),
282            _ => total_length.to_le_bytes(),
283        };
284
285        // encode, message type, 0x00 for compression and 0x00 for reserved.
286        // Do not compress string data because it is highly unlikely that the length of the string query
287        //  is greater than 2000.
288        let mut message = Vec::with_capacity(message_length as usize + MessageHeader::size());
289        message.extend_from_slice(&[ENCODING, message_type, 0, 0]);
290        // total body length
291        message.extend_from_slice(&total_length_bytes);
292        // vector type and 0x00 for attribute
293        message.extend_from_slice(&[qtype::STRING as u8, 0]);
294
295        //  Build body //---------------------------------/
296        let length_info = match ENCODING {
297            0 => message_length.to_be_bytes(),
298            _ => message_length.to_le_bytes(),
299        };
300
301        // length of vector(message)
302        message.extend_from_slice(&length_info);
303        // message
304        message.extend_from_slice(byte_message);
305
306        message
307    }
308}
309
310/// Functional query.
311#[async_trait]
312impl Query for K {
313    async fn serialize(&self, message_type: u8, is_local: bool) -> Vec<u8> {
314        //  Build header //--------------------------------/
315        // Message header + encoded data size
316        let mut byte_message = self.q_ipc_encode();
317        let message_length = byte_message.len();
318        let total_length = (MessageHeader::size() + message_length) as u32;
319
320        let total_length_bytes = match ENCODING {
321            0 => total_length.to_be_bytes(),
322            _ => total_length.to_le_bytes(),
323        };
324
325        // Compression is trigerred when entire message size is more than 2000 bytes
326        //  and the connection is with outseide.
327        if message_length > 1992 && !is_local {
328            // encode, message type, 0x00 for compression, 0x00 for reserved and 0x00000000 for total size
329            let mut message = Vec::with_capacity(message_length + 8);
330            message.extend_from_slice(&[ENCODING, message_type, 0, 0, 0, 0, 0, 0]);
331            message.append(&mut byte_message);
332            // Try to encode entire message.
333            match compress(message).await {
334                (true, compressed) => {
335                    // Message was compressed
336                    return compressed;
337                }
338                (false, mut uncompressed) => {
339                    // Message was not compressed.
340                    // Write original total data size.
341                    uncompressed[4..8].copy_from_slice(&total_length_bytes);
342                    return uncompressed;
343                }
344            }
345        } else {
346            // encode, message type, 0x00 for compression and 0x00 for reserved
347            let mut message = Vec::with_capacity(message_length + MessageHeader::size());
348            message.extend_from_slice(&[ENCODING, message_type, 0, 0]);
349            // Total length of body
350            message.extend_from_slice(&total_length_bytes);
351            message.append(&mut byte_message);
352            return message;
353        }
354    }
355}
356
357//%% QStream %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
358
359impl QStream {
360    /// General constructor of `QStream`.
361    fn new(
362        stream: Box<dyn QStreamInner>,
363        method: ConnectionMethod,
364        is_listener: bool,
365        is_local: bool,
366    ) -> Self {
367        QStream {
368            stream,
369            method,
370            listener: is_listener,
371            local: is_local,
372        }
373    }
374
375    /// Connect to q/kdb+ specifying a connection method, destination host, destination port and access credential.
376    /// # Parameters
377    /// - `method`: Connection method. One of followings:
378    ///   - TCP
379    ///   - TLS
380    ///   - UDS
381    /// - `host`: Hostname or IP address of the target q process. Empty `str` for Unix domain socket.
382    /// - `port`: Port of the target q process.
383    /// - `credential`: Credential in the form of `username:password` to connect to the target q process.
384    /// # Example
385    /// ```
386    /// use kdbplus::qattribute;
387    /// use kdbplus::ipc::*;
388    ///
389    /// #[tokio::main(flavor = "multi_thread", worker_threads = 2)]
390    /// async fn main() -> Result<()> {
391    ///     let mut socket =
392    ///         QStream::connect(ConnectionMethod::UDS, "", 5000_u16, "ideal:person").await?;
393    ///     println!("Connection type: {}", socket.get_connection_type());
394    ///
395    ///     // Set remote function with asynchronous message
396    ///     socket.send_async_message(&"collatz:{[n] seq:enlist n; while[not n = 1; seq,: n:$[n mod 2; 1 + 3 * n; `long$n % 2]]; seq}").await?;
397    ///
398    ///     // Send a query synchronously
399    ///     let mut result = socket.send_sync_message(&"collatz[12]").await?;
400    ///     println!("collatz[12]: {}", result);
401    ///
402    ///     // Send a functional query.
403    ///     let mut message = K::new_compound_list(vec![
404    ///         K::new_symbol(String::from("collatz")),
405    ///         K::new_long(100),
406    ///     ]);
407    ///     result = socket.send_sync_message(&message).await?;
408    ///     println!("collatz[100]: {}", result);
409    ///
410    ///     // Send a functional asynchronous query.
411    ///     message = K::new_compound_list(vec![
412    ///         K::new_string(String::from("show"), qattribute::NONE),
413    ///         K::new_symbol(String::from("goodbye")),
414    ///     ]);
415    ///     socket.send_async_message(&message).await?;
416    ///
417    ///     socket.shutdown().await?;
418    ///
419    ///     Ok(())
420    /// }
421    /// ```
422    pub async fn connect(
423        method: ConnectionMethod,
424        host: &str,
425        port: u16,
426        credential: &str,
427    ) -> Result<Self> {
428        match method {
429            ConnectionMethod::TCP => {
430                let stream = connect_tcp(host, port, credential).await?;
431                let is_local = matches!(host, "localhost" | "127.0.0.1");
432                Ok(QStream::new(
433                    Box::new(stream),
434                    ConnectionMethod::TCP,
435                    false,
436                    is_local,
437                ))
438            }
439            ConnectionMethod::TLS => {
440                let stream = connect_tls(host, port, credential).await?;
441                Ok(QStream::new(
442                    Box::new(stream),
443                    ConnectionMethod::TLS,
444                    false,
445                    false,
446                ))
447            }
448            ConnectionMethod::UDS => {
449                let stream = connect_uds(port, credential).await?;
450                Ok(QStream::new(
451                    Box::new(stream),
452                    ConnectionMethod::UDS,
453                    false,
454                    true,
455                ))
456            }
457        }
458    }
459
460    /// Accept connection and does handshake.
461    /// # Parameters
462    /// - `method`: Connection method. One of followings:
463    ///   - TCP
464    ///   - TLS
465    ///   - UDS
466    /// - host: Hostname or IP address of this listener. Empty `str` for Unix domain socket.
467    /// - port: Listening port.
468    /// # Example
469    /// ```no_run
470    /// use kdbplus::ipc::*;
471    ///  
472    /// #[tokio::main]
473    /// async fn main() -> Result<()> {
474    ///     // Start listenening over UDS at the port 7000 with authentication enabled.
475    ///     while let Ok(mut socket) = QStream::accept(ConnectionMethod::UDS, "", 7000).await {
476    ///         tokio::task::spawn(async move {
477    ///             loop {
478    ///                 match socket.receive_message().await {
479    ///                     Ok((_, message)) => {
480    ///                         println!("request: {}", message);
481    ///                     }
482    ///                     _ => {
483    ///                         socket.shutdown().await.unwrap();
484    ///                         break;
485    ///                     }
486    ///                 }
487    ///             }
488    ///         });
489    ///     }
490    ///
491    ///     Ok(())
492    /// }
493    /// ```
494    /// q processes can connect and send messages to this acceptor.
495    /// ```q
496    /// q)// Process1
497    /// q)h:hopen `:unix://7000:reluctant:slowday
498    /// q)neg[h] (`monalizza; 3.8)
499    /// q)neg[h] (`pizza; 125)
500    /// ```
501    /// ```q
502    /// q)// Process2
503    /// q)h:hopen `:unix://7000:mattew:oracle
504    /// q)neg[h] (`teddy; "bear")
505    /// ```
506    /// # Note
507    /// - TLS acceptor sets `.kdbplus.close_tls_connection_` on q clien via an asynchronous message. This function is necessary to close
508    ///   the socket from the server side without crashing server side application.
509    /// - TLS acceptor and UDS acceptor use specific environmental variables to work. See the [Environmental Variable](../ipc/index.html#environmentl-variables) section for details.
510    pub async fn accept(method: ConnectionMethod, host: &str, port: u16) -> Result<Self> {
511        match method {
512            ConnectionMethod::TCP => {
513                // Bind to the endpoint.
514                let listener = TcpListener::bind(&format!("{}:{}", host, port)).await?;
515                // Listen to the endpoint.
516                let (mut socket, ip_address) = listener.accept().await?;
517                // Read untill null bytes and send back capacity.
518                while (read_client_input(&mut socket).await).is_err() {
519                    // Continue to listen in case of error.
520                    socket = listener.accept().await?.0;
521                }
522                // Check if the connection is local
523                Ok(QStream::new(
524                    Box::new(socket),
525                    ConnectionMethod::TCP,
526                    true,
527                    ip_address.ip() == IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
528                ))
529            }
530            ConnectionMethod::TLS => {
531                // Bind to the endpoint.
532                let listener = TcpListener::bind(&format!("{}:{}", host, port)).await?;
533                // Check if key exists and decode an identity with a given password.
534                let identity = build_identity_from_cert().await?;
535                // Build TLS acceptor.
536                let tls_acceptor = TlsAcceptor::from(TlsAcceptorInner::new(identity).unwrap());
537                // Listen to the endpoint.
538                let (mut socket, _) = listener.accept().await?;
539                // TLS processing.
540                let mut tls_socket = tls_acceptor
541                    .accept(socket)
542                    .await
543                    .expect("failed to accept TLS connection");
544                // Read untill null bytes and send back a capacity.
545                while (read_client_input(&mut tls_socket).await).is_err() {
546                    // Continue to listen in case of error.
547                    socket = listener.accept().await?.0;
548                    tls_socket = tls_acceptor
549                        .accept(socket)
550                        .await
551                        .expect("failed to accept TLS connection");
552                }
553                // TLS is always a remote connection
554                let mut qstream =
555                    QStream::new(Box::new(tls_socket), ConnectionMethod::TCP, true, false);
556                // In order to close the connection from the server side, it needs to tell a client to close the connection.
557                // The `kdbplus_close_tls_connection_` will be called from the server at shutdown.
558                qstream
559                    .send_async_message(&".kdbplus.close_tls_connection_:{[] hclose .z.w;}")
560                    .await?;
561                Ok(qstream)
562            }
563            ConnectionMethod::UDS => {
564                // uild a sockt file path.
565                let uds_path = create_sockfile_path(port)?;
566                let abstract_sockfile_ = format!("\x00{}", uds_path);
567                let abstract_sockfile = Path::new(&abstract_sockfile_);
568                // Bind to the file
569                let listener = UnixListener::bind(abstract_sockfile).unwrap();
570                // Listen to the endpoint
571                let (mut socket, _) = listener.accept().await?;
572                // Read untill null bytes and send back capacity.
573                while (read_client_input(&mut socket).await).is_err() {
574                    // Continue to listen in case of error.
575                    socket = listener.accept().await?.0;
576                }
577                // UDS is always a local connection
578                Ok(QStream::new(Box::new(socket), method, true, true))
579            }
580        }
581    }
582
583    /// Shutdown the socket for a q process.
584    /// # Example
585    /// See the example of [`connect`](#method.connect).
586    pub async fn shutdown(mut self) -> Result<()> {
587        self.stream.shutdown(self.listener).await
588    }
589
590    /// Send a message with a specified message type without waiting for a response even for a synchronous message.
591    ///  If you need to receive a response you need to use [`receive_message`](#method.receive_message).
592    /// # Note
593    /// The usage of this function for a synchronous message is to handle an asynchronous message or a synchronous message
594    ///   sent by a remote function during its execution.
595    /// # Parameters
596    /// - `message`: q command to execute on the remote q process.
597    ///   - `&str`: q command in a string form.
598    ///   - `K`: Query in a functional form.
599    /// - `message_type`: Asynchronous or synchronous.
600    /// # Example
601    /// See the example of [`connect`](#method.connect).
602    pub async fn send_message(&mut self, message: &dyn Query, message_type: u8) -> Result<()> {
603        self.stream
604            .send_message(message, message_type, self.local)
605            .await
606    }
607
608    /// Send a message asynchronously.
609    /// # Parameters
610    /// - `message`: q command to execute on the remote q process.
611    ///   - `&str`: q command in a string form.
612    ///   - `K`: Query in a functional form.
613    /// # Example
614    /// See the example of [`connect`](#method.connect).
615    pub async fn send_async_message(&mut self, message: &dyn Query) -> Result<()> {
616        self.stream.send_async_message(message, self.local).await
617    }
618
619    /// Send a message synchronously.
620    /// # Note
621    /// Remote function must NOT send back a message of asynchronous or synchronous type durning execution of the function.
622    /// # Parameters
623    /// - `message`: q command to execute on the remote q process.
624    ///   - `&str`: q command in a string form.
625    ///   - `K`: Query in a functional form.
626    /// # Example
627    /// See the example of [`connect`](#method.connect).
628    pub async fn send_sync_message(&mut self, message: &dyn Query) -> Result<K> {
629        self.stream.send_sync_message(message, self.local).await
630    }
631
632    /// Receive a message from a remote q process. The received message is parsed as `K` and message type is
633    ///  stored in the first returned value.
634    /// # Example
635    /// See the example of [`accept`](#method.accept).
636    pub async fn receive_message(&mut self) -> Result<(u8, K)> {
637        self.stream.receive_message().await
638    }
639
640    /// Return underlying connection type. One of `TCP`, `TLS` or `UDS`.
641    /// # Example
642    /// See the example of [`connect`](#method.connect).
643    pub fn get_connection_type(&self) -> &str {
644        match self.method {
645            ConnectionMethod::TCP => "TCP",
646            ConnectionMethod::TLS => "TLS",
647            ConnectionMethod::UDS => "UDS",
648        }
649    }
650
651    /// Enforce compression if the size of a message exceeds 2000 regardless of locality of the connection.
652    ///  This flag is not revertible intentionally.
653    pub fn enforce_compression(&mut self) {
654        self.local = false;
655    }
656}
657
658//%% QStreamInner %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
659
660#[async_trait]
661impl QStreamInner for TcpStream {
662    async fn shutdown(&mut self, _: bool) -> Result<()> {
663        AsyncWriteExt::shutdown(self).await?;
664        Ok(())
665    }
666
667    async fn send_message(
668        &mut self,
669        message: &dyn Query,
670        message_type: u8,
671        is_local: bool,
672    ) -> Result<()> {
673        // Serialize a message
674        let byte_message = message.serialize(message_type, is_local).await;
675        // Send the message
676        self.write_all(&byte_message).await?;
677        Ok(())
678    }
679
680    async fn send_async_message(&mut self, message: &dyn Query, is_local: bool) -> Result<()> {
681        // Serialize a message
682        let byte_message = message.serialize(qmsg_type::asynchronous, is_local).await;
683        // Send the message
684        self.write_all(&byte_message).await?;
685        Ok(())
686    }
687
688    async fn send_sync_message(&mut self, message: &dyn Query, is_local: bool) -> Result<K> {
689        // Serialize a message
690        let byte_message = message.serialize(qmsg_type::synchronous, is_local).await;
691        // Send the message
692        self.write_all(&byte_message).await?;
693        // Receive a response. If message type is not response it returns an error.
694        match receive_message(self).await {
695            Ok((qmsg_type::response, response)) => Ok(response),
696            Err(error) => Err(error),
697            Ok((_, message)) => Err(io::Error::new(
698                io::ErrorKind::InvalidData,
699                format!("expected a response: {}", message),
700            )
701            .into()),
702        }
703    }
704
705    async fn receive_message(&mut self) -> Result<(u8, K)> {
706        receive_message(self).await
707    }
708}
709
710#[async_trait]
711impl QStreamInner for TlsStream<TcpStream> {
712    async fn shutdown(&mut self, is_listener: bool) -> Result<()> {
713        if is_listener {
714            // Closing the handle from the server side by `self.get_mut().shutdown()` crashes due to 'assertion failed: !self.context.is_null()'.
715            // No reason to compress.
716            self.send_async_message(&".kdbplus.close_tls_connection_[]", false)
717                .await
718        } else {
719            self.get_mut().shutdown()?;
720            Ok(())
721        }
722    }
723
724    async fn send_message(
725        &mut self,
726        message: &dyn Query,
727        message_type: u8,
728        is_local: bool,
729    ) -> Result<()> {
730        // Serialize a message
731        let byte_message = message.serialize(message_type, is_local).await;
732        // Send the message
733        self.write_all(&byte_message).await?;
734        Ok(())
735    }
736
737    async fn send_async_message(&mut self, message: &dyn Query, is_local: bool) -> Result<()> {
738        // Serialize a message
739        let byte_message = message.serialize(qmsg_type::asynchronous, is_local).await;
740        // Send the message
741        self.write_all(&byte_message).await?;
742        Ok(())
743    }
744
745    async fn send_sync_message(&mut self, message: &dyn Query, is_local: bool) -> Result<K> {
746        // Serialize a message
747        let byte_message = message.serialize(qmsg_type::synchronous, is_local).await;
748        // Send the message
749        self.write_all(&byte_message).await?;
750        // Receive a response. If message type is not response it returns an error.
751        match receive_message(self).await {
752            Ok((qmsg_type::response, response)) => Ok(response),
753            Err(error) => Err(error),
754            Ok((_, message)) => Err(io::Error::new(
755                io::ErrorKind::InvalidData,
756                format!("expected a response: {}", message),
757            )
758            .into()),
759        }
760    }
761
762    async fn receive_message(&mut self) -> Result<(u8, K)> {
763        receive_message(self).await
764    }
765}
766
767#[async_trait]
768impl QStreamInner for UnixStream {
769    /// Close a handle to a q process which is connected with Unix Domain Socket.
770    ///  Socket file is removed.
771    async fn shutdown(&mut self, _: bool) -> Result<()> {
772        AsyncWriteExt::shutdown(self).await?;
773        Ok(())
774    }
775
776    async fn send_message(
777        &mut self,
778        message: &dyn Query,
779        message_type: u8,
780        is_local: bool,
781    ) -> Result<()> {
782        // Serialize a message
783        let byte_message = message.serialize(message_type, is_local).await;
784        // Send the message
785        self.write_all(&byte_message).await?;
786        Ok(())
787    }
788
789    async fn send_async_message(&mut self, message: &dyn Query, is_local: bool) -> Result<()> {
790        // Serialize a message
791        let byte_message = message.serialize(qmsg_type::asynchronous, is_local).await;
792        // Send the message
793        self.write_all(&byte_message).await?;
794        Ok(())
795    }
796
797    async fn send_sync_message(&mut self, message: &dyn Query, is_local: bool) -> Result<K> {
798        // Serialize a message
799        let byte_message = message.serialize(qmsg_type::synchronous, is_local).await;
800        // Send the message
801        self.write_all(&byte_message).await?;
802        // Receive a response. If message type is not response it returns an error.
803        match receive_message(self).await {
804            Ok((qmsg_type::response, response)) => Ok(response),
805            Err(error) => Err(error),
806            Ok((_, message)) => Err(io::Error::new(
807                io::ErrorKind::InvalidData,
808                format!("expected a response: {}", message),
809            )
810            .into()),
811        }
812    }
813
814    async fn receive_message(&mut self) -> Result<(u8, K)> {
815        receive_message(self).await
816    }
817}
818
819//%% MessageHeader %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
820
821impl MessageHeader {
822    /// Constructor.
823    fn new(encoding: u8, message_type: u8, compressed: u8, length: u32) -> Self {
824        MessageHeader {
825            encoding,
826            message_type,
827            compressed,
828            _unused: 0,
829            length,
830        }
831    }
832
833    /// Constructor from bytes.
834    fn from_bytes(bytes: [u8; 8]) -> Self {
835        let encoding = bytes[0];
836
837        let length = match encoding {
838            0 => u32::from_be_bytes(bytes[4..8].try_into().unwrap()),
839            _ => u32::from_le_bytes(bytes[4..8].try_into().unwrap()),
840        };
841
842        // Build header
843        MessageHeader::new(encoding, bytes[1], bytes[2], length)
844    }
845
846    /// Length of bytes for a header.
847    fn size() -> usize {
848        8
849    }
850}
851
852//++++++++++++++++++++++++++++++++++++++++++++++++++//
853// >> Private Functions
854//++++++++++++++++++++++++++++++++++++++++++++++++++//
855
856//%% QStream Connector %%//vvvvvvvvvvvvvvvvvvvvvvvvvv/
857
858/// Inner function of `connect_tcp` and `connect_tls` to establish a TCP connection with the sepcified
859///  endpoint. The hostname is resolved to an IP address with a system DNS resolver or parsed directly
860///  as an IP address.
861///
862/// Tries to connect to multiple resolved IP addresses until the first successful connection. Error is
863///  returned if none of them are valid.
864/// # Parameters
865/// - `host`: Hostname or IP address of the target q/kdb+ process.
866/// - `port`: Port of the target q process
867async fn connect_tcp_impl(host: &str, port: u16) -> Result<TcpStream> {
868    // DNS system resolver (should not fail)
869    let resolver =
870        TokioAsyncResolver::tokio_from_system_conf().expect("failed to create a resolver");
871
872    // Check if we were given an IP address
873    let ips;
874    if let Ok(ip) = host.parse::<IpAddr>() {
875        ips = vec![ip.to_string()]
876    } else {
877        // Resolve the given hostname
878        ips = resolver
879            .ipv4_lookup(format!("{}.", host).as_str())
880            .await
881            .unwrap()
882            .iter()
883            .map(|result| result.to_string())
884            .collect()
885    };
886
887    for answer in ips {
888        let host_port = format!("{}:{}", answer, port);
889        // Return if this IP address is valid
890        match TcpStream::connect(&host_port).await {
891            Ok(socket) => {
892                //println!("connected: {}", host_port);
893                return Ok(socket);
894            }
895            Err(_) => {
896                eprintln!("connection refused: {}. try next.", host_port);
897            }
898        }
899    }
900    // All addresses failed.
901    Err(io::Error::new(io::ErrorKind::ConnectionRefused, "failed to connect").into())
902}
903
904/// Send a credential and receive a common capacity.
905async fn handshake<S>(socket: &mut S, credential_: &str, method_bytes: &str) -> Result<()>
906where
907    S: Unpin + AsyncWriteExt + AsyncReadExt,
908{
909    // Send credential
910    let credential = credential_.to_string() + method_bytes;
911    socket.write_all(credential.as_bytes()).await?;
912
913    // Placeholder of common capablility
914    let mut cap = [0u8; 1];
915    if (socket.read_exact(&mut cap).await).is_err() {
916        // Connection is closed in case of authentication failure
917        Err(io::Error::new(io::ErrorKind::ConnectionAborted, "authentication failure").into())
918    } else {
919        Ok(())
920    }
921}
922
923/// Connect to q process running on a specified `host` and `port` via TCP with a credential `username:password`.
924/// # Parameters
925/// - `host`: Hostname or IP address of the target q process.
926/// - `port`: Port of the target q process.
927/// - `credential`: Credential in the form of `username:password` to connect to the target q process.
928async fn connect_tcp(host: &str, port: u16, credential: &str) -> Result<TcpStream> {
929    // Connect via TCP
930    let mut socket = connect_tcp_impl(host, port).await?;
931    // Handshake
932    handshake(&mut socket, credential, "\x03\x00").await?;
933    Ok(socket)
934}
935
936/// TLS version of `connect_tcp`.
937/// # Parameters
938/// - `host`: Hostname or IP address of the target q process.
939/// - `port`: Port of the target q process.
940/// - `credential`: Credential in the form of `username:password` to connect to the target q process.
941async fn connect_tls(host: &str, port: u16, credential: &str) -> Result<TlsStream<TcpStream>> {
942    // Connect via TCP
943    let socket_ = connect_tcp_impl(host, port).await?;
944    // Use TLS
945    let connector = TlsConnector::from(TlsConnectorInner::new().unwrap());
946    let mut socket = connector
947        .connect(host, socket_)
948        .await
949        .expect("failed to create TLS session");
950    // Handshake
951    handshake(&mut socket, credential, "\x03\x00").await?;
952    Ok(socket)
953}
954
955/// Build a path of a socket file.
956fn create_sockfile_path(port: u16) -> Result<String> {
957    // Create file path
958    let udspath = match env::var("QUDSPATH") {
959        Ok(dir) => format!("{}/kx.{}", dir, port),
960        Err(_) => format!("/tmp/kx.{}", port),
961    };
962
963    Ok(udspath)
964}
965
966/// Connect to q process running on the specified `port` via Unix domain socket with a credential `username:password`.
967/// # Parameters
968/// - `port`: Port of the target q process.
969/// - `credential`: Credential in the form of `username:password` to connect to the target q process.
970#[cfg(unix)]
971async fn connect_uds(port: u16, credential: &str) -> Result<UnixStream> {
972    // Create a file path.
973    let uds_path = create_sockfile_path(port)?;
974    let abstract_sockfile_ = format!("\x00{}", uds_path);
975    let abstract_sockfile = Path::new(&abstract_sockfile_);
976    // Connect to kdb+.
977    let mut socket = UnixStream::connect(&abstract_sockfile).await?;
978    // Handshake
979    handshake(&mut socket, credential, "\x06\x00").await?;
980
981    Ok(socket)
982}
983
984//%% QStream Acceptor %%//vvvvvvvvvvvvvvvvvvvvvvvvvvv/
985
986/// Read username, password, capacity and null byte from q client at the connection and does authentication.
987///  Close the handle if the authentication fails.
988async fn read_client_input<S>(socket: &mut S) -> Result<()>
989where
990    S: Unpin + AsyncWriteExt + AsyncReadExt,
991{
992    // Buffer to read inputs.
993    let mut client_input = [0u8; 32];
994    // credential will be built from small fractions of bytes.
995    let mut passed_credential = String::new();
996    loop {
997        // Read a client credential input.
998        match socket.read(&mut client_input).await {
999            Ok(0) => {
1000                // No bytes were read
1001            }
1002            Ok(_) => {
1003                // Locate a byte denoting a capacity
1004                if let Some(index) = client_input
1005                    .iter()
1006                    .position(|byte| *byte == 0x03 || *byte == 0x06)
1007                {
1008                    let capacity = client_input[index];
1009                    passed_credential
1010                        .push_str(str::from_utf8(&client_input[0..index]).expect("invalid bytes"));
1011                    let credential = passed_credential.as_str().split(':').collect::<Vec<&str>>();
1012                    #[allow(clippy::borrow_interior_mutable_const)]
1013                    if let Some(encoded) = ACCOUNTS.get(&credential[0].to_string()) {
1014                        // User exists
1015                        let mut hasher = Sha1::new();
1016                        hasher.update(credential[1].as_bytes());
1017                        let encoded_password = hasher.digest().to_string();
1018                        if encoded == &encoded_password {
1019                            // Client passed correct credential
1020                            socket.write_all(&[capacity; 1]).await?;
1021                            return Ok(());
1022                        } else {
1023                            // Authentication failure.
1024                            // Close connection.
1025                            socket.shutdown().await?;
1026                            return Err(io::Error::new(
1027                                io::ErrorKind::InvalidData,
1028                                "authentication failed",
1029                            )
1030                            .into());
1031                        }
1032                    } else {
1033                        // Authentication failure.
1034                        // Close connection.
1035                        socket.shutdown().await?;
1036                        return Err(io::Error::new(
1037                            io::ErrorKind::InvalidData,
1038                            "authentication failed",
1039                        )
1040                        .into());
1041                    }
1042                } else {
1043                    // Append a fraction of credential
1044                    passed_credential
1045                        .push_str(str::from_utf8(&client_input).expect("invalid bytes"));
1046                }
1047            }
1048            Err(error) => {
1049                return Err(error.into());
1050            }
1051        }
1052    }
1053}
1054
1055/// Check if server key exists and return teh contents.
1056async fn build_identity_from_cert() -> Result<Identity> {
1057    // Check if server key exists.
1058    if let Ok(path) = env::var("KDBPLUS_TLS_KEY_FILE") {
1059        if let Ok(password) = env::var("KDBPLUS_TLS_KEY_FILE_SECRET") {
1060            let cert_file = tokio::fs::File::open(Path::new(&path)).await.unwrap();
1061            let mut reader = BufReader::new(cert_file);
1062            let mut der: Vec<u8> = Vec::new();
1063            // Read the key file.
1064            reader.read_to_end(&mut der).await?;
1065            // Create identity.
1066            if let Ok(identity) = Identity::from_pkcs12(&der, &password) {
1067                Ok(identity)
1068            } else {
1069                Err(io::Error::new(io::ErrorKind::InvalidData, "authentication failed").into())
1070            }
1071        } else {
1072            Err(io::Error::new(
1073                io::ErrorKind::NotFound,
1074                "KDBPLUS_TLS_KEY_FILE_SECRET is not set",
1075            )
1076            .into())
1077        }
1078    } else {
1079        Err(io::Error::new(io::ErrorKind::NotFound, "KDBPLUS_TLS_KEY_FILE is not set").into())
1080    }
1081}
1082
1083//%% QStream Query %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
1084
1085/// Receive a message from q process with decompression if necessary. The received message is parsed as `K` and message type is
1086///  stored in the first returned value.
1087/// # Parameters
1088/// - `socket`: Socket to communicate with a q process. Either of `TcpStream`, `TlsStream<TcpStream>` or `UnixStream`.
1089async fn receive_message<S>(socket: &mut S) -> Result<(u8, K)>
1090where
1091    S: Unpin + AsyncReadExt,
1092{
1093    // Read header
1094    let mut header_buffer = [0u8; 8];
1095    if let Err(err) = socket.read_exact(&mut header_buffer).await {
1096        // The expected message is header or EOF (close due to q process failure resulting from a bad query)
1097        return Err(io::Error::new(
1098            io::ErrorKind::ConnectionAborted,
1099            format!("Connection dropped: {}", err),
1100        )
1101        .into());
1102    }
1103
1104    // Parse message header
1105    let header = MessageHeader::from_bytes(header_buffer);
1106
1107    // Read body
1108    let body_length = header.length as usize - MessageHeader::size();
1109    let mut body: Vec<u8> = vec![0; body_length];
1110    if let Err(err) = socket.read_exact(&mut body).await {
1111        // Fails if q process fails before reading the body
1112        return Err(io::Error::new(
1113            io::ErrorKind::UnexpectedEof,
1114            format!("Failed to read body of message: {}", err),
1115        )
1116        .into());
1117    }
1118
1119    // Decompress if necessary
1120    if header.compressed == 0x01 {
1121        body = decompress(body, header.encoding).await;
1122    }
1123
1124    Ok((
1125        header.message_type,
1126        K::q_ipc_decode(&body, header.encoding).await,
1127    ))
1128}
1129
1130/// Compress body. The combination of serializing the data and compressing will result in
1131/// the same output as shown in the q language by using the -18! function e.g.
1132/// serializing 2000 bools set to true, then compressing, will have the same output as `-18!2000#1b`.
1133/// # Parameter
1134/// - `raw`: Serialized message.
1135/// - `encode`: `0` if Big Endian; `1` if Little Endian.
1136async fn compress(raw: Vec<u8>) -> (bool, Vec<u8>) {
1137    let mut i = 0_u8;
1138    let mut f = 0_u8;
1139    let mut h0 = 0_usize;
1140    let mut h = 0_usize;
1141    let mut g: bool;
1142    let mut compressed: Vec<u8> = vec![0; (raw.len()) / 2];
1143
1144    // Start index of compressed body
1145    // 12 bytes are reserved for the header + size of raw bytes
1146    let mut c = 12;
1147    let mut d = c;
1148    let e = compressed.len();
1149    let mut p = 0_usize;
1150    let mut q: usize;
1151    let mut r: usize;
1152    let mut s0 = 0_usize;
1153
1154    // Body starts from index 8
1155    let mut s = 8_usize;
1156    let t = raw.len();
1157    let mut a = [0_i32; 256];
1158
1159    // Copy encode, message type, compressed and reserved
1160    compressed[0..4].copy_from_slice(&raw[0..4]);
1161    // Set compressed flag
1162    compressed[2] = 1;
1163
1164    // Write size of raw bytes including a header
1165    let raw_size = match ENCODING {
1166        0 => (t as u32).to_be_bytes(),
1167        _ => (t as u32).to_le_bytes(),
1168    };
1169    compressed[8..12].copy_from_slice(&raw_size);
1170
1171    while s < t {
1172        if i == 0 {
1173            if d > e - 17 {
1174                // Early return when compressing to less than half failed
1175                return (false, raw);
1176            }
1177            i = 1;
1178            compressed[c] = f;
1179            c = d;
1180            d += 1;
1181            f = 0;
1182        }
1183        g = s > t - 3;
1184        if !g {
1185            h = (raw[s] ^ raw[s + 1]) as usize;
1186            p = a[h] as usize;
1187            g = (0 == p) || (0 != (raw[s] ^ raw[p]));
1188        }
1189        if 0 < s0 {
1190            a[h0] = s0 as i32;
1191            s0 = 0;
1192        }
1193        if g {
1194            h0 = h;
1195            s0 = s;
1196            compressed[d] = raw[s];
1197            d += 1;
1198            s += 1;
1199        } else {
1200            a[h] = s as i32;
1201            f |= i;
1202            p += 2;
1203            s += 2;
1204            r = s;
1205            q = if s + 255 > t { t } else { s + 255 };
1206            while (s < q) && (raw[p] == raw[s]) {
1207                s += 1;
1208                if s < q {
1209                    p += 1;
1210                }
1211            }
1212            compressed[d] = h as u8;
1213            d += 1;
1214            compressed[d] = (s - r) as u8;
1215            d += 1;
1216        }
1217        i = i.wrapping_mul(2);
1218    }
1219    compressed[c] = f;
1220    // Final compressed data size
1221    let compressed_size = match ENCODING {
1222        0 => (d as u32).to_be_bytes(),
1223        _ => (d as u32).to_le_bytes(),
1224    };
1225    compressed[4..8].copy_from_slice(&compressed_size);
1226    let _ = compressed.split_off(d);
1227    (true, compressed)
1228}
1229
/// Decompress body. The combination of decompressing and deserializing the data
///  will result in the same output as shown in the q language by using the `-19!` function.
///
/// Input layout:
/// - `compressed` is the message body that follows the 8-byte message header.
/// - Its first 4 bytes hold the uncompressed message size, header included; 8 is subtracted to
///   size the output.
/// - The rest is a sequence of groups. Each group starts with one flag byte whose bits, least
///   significant first, describe the next (up to) 8 items.
/// - A clear bit marks a literal byte.
/// - A set bit marks a back-reference of two bytes: an index into the hash table `aa`, then the
///   number of bytes to copy beyond the two that are always copied.
/// # Parameter
/// - `compressed`: Compressed serialized message.
/// - `encoding`:
///   - `0`: Big Endian
///   - `1`: Little Endian.
/// # Panics
/// The input is not validated:
/// - indices outside `compressed` or `decompressed` panic;
/// - a size field below 8 gives a negative `size`, and `size as usize` then requests an enormous
///   allocation.
///
/// NOTE(review): the input comes from the network.
async fn decompress(compressed: Vec<u8>, encoding: u8) -> Vec<u8> {
    // Number of extra bytes copied by the current back-reference.
    let mut n = 0;
    // Read position in `decompressed` for a back-reference copy.
    let mut r: usize;
    // Flag byte of the current group.
    let mut f = 0_usize;

    // Header has already been removed.
    // Start index of decompressed bytes is 0
    let mut s = 0_usize;
    // Next position whose adjacent-byte hash has not yet been recorded in `aa`.
    let mut p = s;
    // Mask selecting the current item's bit in `f`; 0 means a new flag byte must be read.
    let mut i = 0_usize;

    // Subtract 8 bytes from decoded bytes size as 8 bytes have already been taken as header
    let size = match encoding {
        0 => {
            i32::from_be_bytes(
                compressed[0..4]
                    .try_into()
                    .expect("slice does not have length 4"),
            ) - 8
        }
        _ => {
            i32::from_le_bytes(
                compressed[0..4]
                    .try_into()
                    .expect("slice does not have length 4"),
            ) - 8
        }
    };
    let mut decompressed: Vec<u8> = vec![0; size as usize];

    // Start index of compressed body.
    // 8 bytes have already been removed as header
    let mut d = 4;
    // Hash table: XOR of two adjacent decompressed bytes -> position of the first of them.
    let mut aa = [0_i32; 256];
    while s < decompressed.len() {
        if i == 0 {
            // Start of a new group: read its flag byte.
            f = compressed[d] as usize;
            d += 1;
            i = 1;
        }
        if (f & i) != 0 {
            // Back-reference: look up the earlier position by hash, copy two bytes, then `n` more.
            r = aa[compressed[d] as usize] as usize;
            d += 1;
            decompressed[s] = decompressed[r];
            s += 1;
            r += 1;
            decompressed[s] = decompressed[r];
            s += 1;
            r += 1;
            n = compressed[d] as usize;
            d += 1;
            // Byte-by-byte copy, so overlapping source and destination ranges repeat earlier data.
            for m in 0..n {
                decompressed[s + m] = decompressed[r + m];
            }
        } else {
            // Literal byte.
            decompressed[s] = compressed[d];
            s += 1;
            d += 1;
        }
        // Record hashes of adjacent pairs that start before `s - 1`.
        while p < s - 1 {
            aa[(decompressed[p] ^ decompressed[p + 1]) as usize] = p as i32;
            p += 1;
        }
        // After a back-reference, skip past the `n` extra bytes; their pair hashes are not recorded.
        if (f & i) != 0 {
            s += n;
            p = s;
        }
        // Advance to the next flag bit; after 8 items a new flag byte is needed.
        i *= 2;
        if i == 256 {
            i = 0;
        }
    }
    decompressed
}
1306            i = 0;
1307        }
1308    }
1309    decompressed
1310}