Skip to main content

kdb_plus_fixed/ipc/
connection.rs

1//++++++++++++++++++++++++++++++++++++++++++++++++++//
2// >> Load Libraries
3//++++++++++++++++++++++++++++++++++++++++++++++++++//
4
5use super::serialize::ENCODING;
6use super::Result;
7use super::{qtype, K};
8use async_trait::async_trait;
9use hickory_resolver::TokioAsyncResolver;
10use io::BufRead;
11use once_cell::sync::Lazy;
12use sha1_smol::Sha1;
13use std::collections::HashMap;
14use std::convert::TryInto;
15use std::net::{IpAddr, Ipv4Addr};
16use std::path::Path;
17use std::{env, fs, io, str};
18use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
19use tokio::net::{TcpListener, TcpStream};
20#[cfg(unix)]
21use tokio::net::{UnixListener, UnixStream};
22use tokio_native_tls::native_tls::{
23    Identity, TlsAcceptor as TlsAcceptorInner, TlsConnector as TlsConnectorInner,
24};
25use tokio_native_tls::{TlsAcceptor, TlsConnector, TlsStream};
26
27//++++++++++++++++++++++++++++++++++++++++++++++++++//
28// >> Global Variable
29//++++++++++++++++++++++++++++++++++++++++++++++++++//
30
31//%% QStream %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
32
pub mod qmsg_type {
    //! This module provides the q message types used for IPC.
    //!  They are grouped in a module to tie them together as related items rather than
    //!  scattered values. Use them with the `qmsg_type::` prefix, e.g., `qmsg_type::asynchronous`.
    //!
    //! # Example
    //! ```no_run
    //! use kdbplus::ipc::*;
    //!
    //! // Print `K` object.
    //! fn print(obj: &K) {
    //!     println!("{}", obj);
    //! }
    //!
    //! // Calculate something from two long arguments.
    //! fn nonsense(arg1: i64, arg2: i64) -> i64 {
    //!     arg1 * arg2
    //! }
    //!
    //! #[tokio::main]
    //! async fn main() -> Result<()> {
    //!     // Connect to qprocess running on localhost:5000 via TCP
    //!     let mut socket =
    //!         QStream::connect(ConnectionMethod::TCP, "localhost", 5000_u16, "ideal:person").await?;
    //!
    //!     // Set a function which sends back a non-response message during its execution.
    //!     socket
    //!         .send_async_message(
    //!             &"complex:{neg[.z.w](`print; \"counter\"); what: .z.w (`nonsense; 1; 2); what*100}",
    //!         )
    //!         .await?;
    //!
    //!     // Send a query `(`complex; ::)` without waiting for a response.
    //!     socket
    //!         .send_message(
    //!             &K::new_compound_list(vec![K::new_symbol(String::from("complex")), K::new_null()]),
    //!             qmsg_type::synchronous,
    //!         )
    //!         .await?;
    //!
    //!     // Receive an asynchronous call from the function.
    //!     match socket.receive_message().await {
    //!         Ok((qmsg_type::asynchronous, message)) => {
    //!             println!("asynchronous call: {}", message);
    //!             let list = message.as_vec::<K>().unwrap();
    //!             if list[0].get_symbol().unwrap() == "print" {
    //!                 print(&list[1])
    //!             }
    //!         }
    //!         _ => unreachable!(),
    //!     }
    //!
    //!     // Receive a synchronous call from the function.
    //!     match socket.receive_message().await {
    //!         Ok((qmsg_type::synchronous, message)) => {
    //!             println!("synchronous call: {}", message);
    //!             let list = message.as_vec::<K>().unwrap();
    //!             if list[0].get_symbol().unwrap() == "nonsense" {
    //!                 let res = nonsense(list[1].get_long().unwrap(), list[2].get_long().unwrap());
    //!                 // Send back a response.
    //!                 socket
    //!                     .send_message(&K::new_long(res), qmsg_type::response)
    //!                     .await?;
    //!             }
    //!         }
    //!         _ => unreachable!(),
    //!     }
    //!
    //!     // Receive a final result.
    //!     match socket.receive_message().await {
    //!         Ok((qmsg_type::response, message)) => {
    //!             println!("final: {}", message);
    //!         }
    //!         _ => unreachable!(),
    //!     }
    //!
    //!     Ok(())
    //! }
    //!```

    // The lowercase names are part of the public API (used as match patterns, e.g.
    //  `Ok((qmsg_type::response, message))`), so silence the naming lint here only.
    #![allow(non_upper_case_globals)]

    /// Used to send a message to q/kdb+ asynchronously.
    pub const asynchronous: u8 = 0;
    /// Used to send a message to q/kdb+ synchronously.
    pub const synchronous: u8 = 1;
    /// Used by q/kdb+ to identify a response for a synchronous query.
    pub const response: u8 = 2;
}
119
120//%% QStream Acceptor %%//vvvvvvvvvvvvvvvvvvvvvvvvvvv/
121
122/// Map from user name to password hashed with SHA1.
123#[allow(clippy::declare_interior_mutable_const)]
124const ACCOUNTS: Lazy<HashMap<String, String>> = Lazy::new(|| {
125    // Map from user to password
126    let mut map: HashMap<String, String> = HashMap::new();
127    // Open credential file
128    let file = fs::OpenOptions::new()
129        .read(true)
130        .open(env::var("KDBPLUS_ACCOUNT_FILE").expect("KDBPLUS_ACCOUNT_FILE is not set"))
131        .expect("failed to open account file");
132    let mut reader = io::BufReader::new(file);
133    let mut line = String::new();
134    loop {
135        match reader.read_line(&mut line) {
136            Ok(0) => break,
137            Ok(_) => {
138                let credential = line.as_str().split(':').collect::<Vec<&str>>();
139                let mut password = credential[1];
140                if password.ends_with('\n') {
141                    password = &password[0..password.len() - 1];
142                }
143                map.insert(credential[0].to_string(), password.to_string());
144                line.clear();
145            }
146            _ => unreachable!(),
147        }
148    }
149    map
150});
151
152//++++++++++++++++++++++++++++++++++++++++++++++++++//
153// >> Structs
154//++++++++++++++++++++++++++++++++++++++++++++++++++//
155
156//%% ConnectionMethod %%//vvvvvvvvvvvvvvvvvvvvvvvvvvv/
157
/// Connection method to q/kdb+.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConnectionMethod {
    /// Plain TCP socket.
    TCP = 0,
    /// TCP socket secured with TLS.
    TLS = 1,
    /// Unix domain socket.
    UDS = 2,
}
165
166//%% Query %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
167
168/// Feature of query object.
169#[async_trait]
170pub trait Query: Send + Sync {
171    /// Serialize into q IPC bytes including a header (encoding, message type, compresssion flag and total message length).
172    ///   If the connection is within the same host, the message is not compressed under any conditions.
173    /// # Parameters
174    /// - `message_type`: Message type. One of followings:
175    ///   - `qmsg_type::asynchronous`
176    ///   - `qmsg_type::synchronous`
177    ///   - `qmsg_type::response`
178    /// - `is_local`: Flag of whether the connection is within the same host.
179    async fn serialize(&self, message_type: u8, is_local: bool) -> Vec<u8>;
180}
181
182//%% QStreamInner %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
183
184/// Features which streams communicating with q must have.
185#[async_trait]
186trait QStreamInner: Send + Sync {
187    /// Shutdown underlying stream.
188    async fn shutdown(&mut self, is_server: bool) -> Result<()>;
189    /// Send a message with a specified message type without waiting for a response.
190    /// # Parameters
191    /// - `message`: q command to execute on the remote q process.
192    /// - `message_type`: Asynchronous or synchronous.
193    /// - `is_local`: Flag of whether the connection is within the same host.
194    async fn send_message(
195        &mut self,
196        message: &dyn Query,
197        message_type: u8,
198        is_local: bool,
199    ) -> Result<()>;
200    /// Send a message asynchronously.
201    /// # Parameters
202    /// - `message`: q command in two ways:
203    ///   - `&str`: q command in a string form.
204    ///   - `K`: Query in a functional form.
205    /// - `is_local`: Flag of whether the connection is within the same host.
206    async fn send_async_message(&mut self, message: &dyn Query, is_local: bool) -> Result<()>;
207    /// Send a message asynchronously.
208    /// # Parameters
209    /// - `message`: q command in two ways:
210    ///   - `&str`: q command in a string form.
211    ///   - `K`: Query in a functional form.
212    /// - `is_local`: Flag of whether the connection is within the same host.
213    async fn send_sync_message(&mut self, message: &dyn Query, is_local: bool) -> Result<K>;
214    /// Receive a message from a remote q process. The received message is parsed as `K` and message type is
215    ///  stored in the first returned value.
216    async fn receive_message(&mut self) -> Result<(u8, K)>;
217}
218
219//%% QStream %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
220
221/// Stream to communicate with q/kdb+.
222pub struct QStream {
223    /// Actual stream to communicate.
224    stream: Box<dyn QStreamInner>,
225    /// Connection method. One of followings:
226    /// - TCP
227    /// - TLS
228    /// - UDS
229    method: ConnectionMethod,
230    /// Indicator of whether the stream is an acceptor or client.
231    /// - `true`: Acceptor
232    /// - `false`: Client
233    listener: bool,
234    /// Indicator of whether the connection is within the same host.
235    /// - `true`: Connection within the same host.
236    /// - `false`: Connection with outseide.
237    local: bool,
238}
239
240//%% MessageHeader %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
241
/// Header of a q IPC data frame.
#[derive(Clone, Copy, Debug)]
struct MessageHeader {
    /// Encoding.
    /// - 0: Big Endian
    /// - 1: Little Endian
    encoding: u8,
    /// Message type. One of the following:
    /// - 0: Asynchronous
    /// - 1: Synchronous
    /// - 2: Response
    message_type: u8,
    /// Whether the message is compressed.
    /// - 0: Uncompressed
    /// - 1: Compressed
    compressed: u8,
    /// Reserved byte.
    _unused: u8,
    /// Total length of the uncompressed message.
    length: u32,
}
263
264//++++++++++++++++++++++++++++++++++++++++++++++++++//
265// >> Implementation
266//++++++++++++++++++++++++++++++++++++++++++++++++++//
267
268//%% Query %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
269
270/// Text query.
271#[async_trait]
272impl Query for &str {
273    async fn serialize(&self, message_type: u8, _: bool) -> Vec<u8> {
274        //  Build header //--------------------------------/
275        // Message header + (type indicator of string + header of string type) + string length
276        let byte_message = self.as_bytes();
277        let message_length = byte_message.len() as u32;
278        let total_length = MessageHeader::size() as u32 + 6 + message_length;
279
280        let total_length_bytes = match ENCODING {
281            0 => total_length.to_be_bytes(),
282            _ => total_length.to_le_bytes(),
283        };
284
285        // encode, message type, 0x00 for compression and 0x00 for reserved.
286        // Do not compress string data because it is highly unlikely that the length of the string query
287        //  is greater than 2000.
288        let mut message = Vec::with_capacity(message_length as usize + MessageHeader::size());
289        message.extend_from_slice(&[ENCODING, message_type, 0, 0]);
290        // total body length
291        message.extend_from_slice(&total_length_bytes);
292        // vector type and 0x00 for attribute
293        message.extend_from_slice(&[qtype::STRING as u8, 0]);
294
295        //  Build body //---------------------------------/
296        let length_info = match ENCODING {
297            0 => message_length.to_be_bytes(),
298            _ => message_length.to_le_bytes(),
299        };
300
301        // length of vector(message)
302        message.extend_from_slice(&length_info);
303        // message
304        message.extend_from_slice(byte_message);
305
306        message
307    }
308}
309
310/// Functional query.
311#[async_trait]
312impl Query for K {
313    async fn serialize(&self, message_type: u8, is_local: bool) -> Vec<u8> {
314        //  Build header //--------------------------------/
315        // Message header + encoded data size
316        let mut byte_message = self.q_ipc_encode();
317        let message_length = byte_message.len();
318        let total_length = (MessageHeader::size() + message_length) as u32;
319
320        let total_length_bytes = match ENCODING {
321            0 => total_length.to_be_bytes(),
322            _ => total_length.to_le_bytes(),
323        };
324
325        // Compression is trigerred when entire message size is more than 2000 bytes
326        //  and the connection is with outseide.
327        if message_length > 1992 && !is_local {
328            // encode, message type, 0x00 for compression, 0x00 for reserved and 0x00000000 for total size
329            let mut message = Vec::with_capacity(message_length + 8);
330            message.extend_from_slice(&[ENCODING, message_type, 0, 0, 0, 0, 0, 0]);
331            message.append(&mut byte_message);
332            // Try to encode entire message.
333            match compress(message).await {
334                (true, compressed) => {
335                    // Message was compressed
336                    return compressed;
337                }
338                (false, mut uncompressed) => {
339                    // Message was not compressed.
340                    // Write original total data size.
341                    uncompressed[4..8].copy_from_slice(&total_length_bytes);
342                    return uncompressed;
343                }
344            }
345        } else {
346            // encode, message type, 0x00 for compression and 0x00 for reserved
347            let mut message = Vec::with_capacity(message_length + MessageHeader::size());
348            message.extend_from_slice(&[ENCODING, message_type, 0, 0]);
349            // Total length of body
350            message.extend_from_slice(&total_length_bytes);
351            message.append(&mut byte_message);
352            return message;
353        }
354    }
355}
356
357//%% QStream %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
358
359impl QStream {
360    /// General constructor of `QStream`.
361    fn new(
362        stream: Box<dyn QStreamInner>,
363        method: ConnectionMethod,
364        is_listener: bool,
365        is_local: bool,
366    ) -> Self {
367        QStream {
368            stream,
369            method,
370            listener: is_listener,
371            local: is_local,
372        }
373    }
374
375    /// Connect to q/kdb+ specifying a connection method, destination host, destination port and access credential.
376    /// # Parameters
377    /// - `method`: Connection method. One of followings:
378    ///   - TCP
379    ///   - TLS
380    ///   - UDS
381    /// - `host`: Hostname or IP address of the target q process. Empty `str` for Unix domain socket.
382    /// - `port`: Port of the target q process.
383    /// - `credential`: Credential in the form of `username:password` to connect to the target q process.
384    /// # Example
385    /// ```
386    /// use kdbplus::qattribute;
387    /// use kdbplus::ipc::*;
388    ///
389    /// #[tokio::main(flavor = "multi_thread", worker_threads = 2)]
390    /// async fn main() -> Result<()> {
391    ///     let mut socket =
392    ///         QStream::connect(ConnectionMethod::UDS, "", 5000_u16, "ideal:person").await?;
393    ///     println!("Connection type: {}", socket.get_connection_type());
394    ///
395    ///     // Set remote function with asynchronous message
396    ///     socket.send_async_message(&"collatz:{[n] seq:enlist n; while[not n = 1; seq,: n:$[n mod 2; 1 + 3 * n; `long$n % 2]]; seq}").await?;
397    ///
398    ///     // Send a query synchronously
399    ///     let mut result = socket.send_sync_message(&"collatz[12]").await?;
400    ///     println!("collatz[12]: {}", result);
401    ///
402    ///     // Send a functional query.
403    ///     let mut message = K::new_compound_list(vec![
404    ///         K::new_symbol(String::from("collatz")),
405    ///         K::new_long(100),
406    ///     ]);
407    ///     result = socket.send_sync_message(&message).await?;
408    ///     println!("collatz[100]: {}", result);
409    ///
410    ///     // Send a functional asynchronous query.
411    ///     message = K::new_compound_list(vec![
412    ///         K::new_string(String::from("show"), qattribute::NONE),
413    ///         K::new_symbol(String::from("goodbye")),
414    ///     ]);
415    ///     socket.send_async_message(&message).await?;
416    ///
417    ///     socket.shutdown().await?;
418    ///
419    ///     Ok(())
420    /// }
421    /// ```
422    pub async fn connect(
423        method: ConnectionMethod,
424        host: &str,
425        port: u16,
426        credential: &str,
427    ) -> Result<Self> {
428        match method {
429            ConnectionMethod::TCP => {
430                let stream = connect_tcp(host, port, credential).await?;
431                let is_local = matches!(host, "localhost" | "127.0.0.1");
432                Ok(QStream::new(
433                    Box::new(stream),
434                    ConnectionMethod::TCP,
435                    false,
436                    is_local,
437                ))
438            }
439            ConnectionMethod::TLS => {
440                let stream = connect_tls(host, port, credential).await?;
441                Ok(QStream::new(
442                    Box::new(stream),
443                    ConnectionMethod::TLS,
444                    false,
445                    false,
446                ))
447            }
448            #[cfg(unix)]
449            ConnectionMethod::UDS => {
450                let stream = connect_uds(port, credential).await?;
451                Ok(QStream::new(
452                    Box::new(stream),
453                    ConnectionMethod::UDS,
454                    false,
455                    true,
456                ))
457            }
458            #[cfg(not(unix))]
459            ConnectionMethod::UDS => Err(io::Error::new(
460                io::ErrorKind::Unsupported,
461                "Unix Domain Socket is not supported on this platform",
462            )
463            .into()),
464        }
465    }
466
467    /// Accept connection and does handshake.
468    /// # Parameters
469    /// - `method`: Connection method. One of followings:
470    ///   - TCP
471    ///   - TLS
472    ///   - UDS
473    /// - host: Hostname or IP address of this listener. Empty `str` for Unix domain socket.
474    /// - port: Listening port.
475    /// # Example
476    /// ```no_run
477    /// use kdbplus::ipc::*;
478    ///  
479    /// #[tokio::main]
480    /// async fn main() -> Result<()> {
481    ///     // Start listenening over UDS at the port 7000 with authentication enabled.
482    ///     while let Ok(mut socket) = QStream::accept(ConnectionMethod::UDS, "", 7000).await {
483    ///         tokio::task::spawn(async move {
484    ///             loop {
485    ///                 match socket.receive_message().await {
486    ///                     Ok((_, message)) => {
487    ///                         println!("request: {}", message);
488    ///                     }
489    ///                     _ => {
490    ///                         socket.shutdown().await.unwrap();
491    ///                         break;
492    ///                     }
493    ///                 }
494    ///             }
495    ///         });
496    ///     }
497    ///
498    ///     Ok(())
499    /// }
500    /// ```
501    /// q processes can connect and send messages to this acceptor.
502    /// ```q
503    /// q)// Process1
504    /// q)h:hopen `:unix://7000:reluctant:slowday
505    /// q)neg[h] (`monalizza; 3.8)
506    /// q)neg[h] (`pizza; 125)
507    /// ```
508    /// ```q
509    /// q)// Process2
510    /// q)h:hopen `:unix://7000:mattew:oracle
511    /// q)neg[h] (`teddy; "bear")
512    /// ```
513    /// # Note
514    /// - TLS acceptor sets `.kdbplus.close_tls_connection_` on q clien via an asynchronous message. This function is necessary to close
515    ///   the socket from the server side without crashing server side application.
516    /// - TLS acceptor and UDS acceptor use specific environmental variables to work. See the [Environmental Variable](../ipc/index.html#environmentl-variables) section for details.
517    pub async fn accept(method: ConnectionMethod, host: &str, port: u16) -> Result<Self> {
518        match method {
519            ConnectionMethod::TCP => {
520                // Bind to the endpoint.
521                let listener = TcpListener::bind(&format!("{}:{}", host, port)).await?;
522                // Listen to the endpoint.
523                let (mut socket, ip_address) = listener.accept().await?;
524                // Read untill null bytes and send back capacity.
525                while (read_client_input(&mut socket).await).is_err() {
526                    // Continue to listen in case of error.
527                    socket = listener.accept().await?.0;
528                }
529                // Check if the connection is local
530                Ok(QStream::new(
531                    Box::new(socket),
532                    ConnectionMethod::TCP,
533                    true,
534                    ip_address.ip() == IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
535                ))
536            }
537            ConnectionMethod::TLS => {
538                // Bind to the endpoint.
539                let listener = TcpListener::bind(&format!("{}:{}", host, port)).await?;
540                // Check if key exists and decode an identity with a given password.
541                let identity = build_identity_from_cert().await?;
542                // Build TLS acceptor.
543                let tls_acceptor = TlsAcceptor::from(TlsAcceptorInner::new(identity).unwrap());
544                // Listen to the endpoint.
545                let (mut socket, _) = listener.accept().await?;
546                // TLS processing.
547                let mut tls_socket = tls_acceptor
548                    .accept(socket)
549                    .await
550                    .expect("failed to accept TLS connection");
551                // Read untill null bytes and send back a capacity.
552                while (read_client_input(&mut tls_socket).await).is_err() {
553                    // Continue to listen in case of error.
554                    socket = listener.accept().await?.0;
555                    tls_socket = tls_acceptor
556                        .accept(socket)
557                        .await
558                        .expect("failed to accept TLS connection");
559                }
560                // TLS is always a remote connection
561                let mut qstream =
562                    QStream::new(Box::new(tls_socket), ConnectionMethod::TCP, true, false);
563                // In order to close the connection from the server side, it needs to tell a client to close the connection.
564                // The `kdbplus_close_tls_connection_` will be called from the server at shutdown.
565                qstream
566                    .send_async_message(&".kdbplus.close_tls_connection_:{[] hclose .z.w;}")
567                    .await?;
568                Ok(qstream)
569            }
570            #[cfg(unix)]
571            ConnectionMethod::UDS => {
572                // uild a sockt file path.
573                let uds_path = create_sockfile_path(port)?;
574                let abstract_sockfile_ = format!("\x00{}", uds_path);
575                let abstract_sockfile = Path::new(&abstract_sockfile_);
576                // Bind to the file
577                let listener = UnixListener::bind(abstract_sockfile).unwrap();
578                // Listen to the endpoint
579                let (mut socket, _) = listener.accept().await?;
580                // Read untill null bytes and send back capacity.
581                while (read_client_input(&mut socket).await).is_err() {
582                    // Continue to listen in case of error.
583                    socket = listener.accept().await?.0;
584                }
585                // UDS is always a local connection
586                Ok(QStream::new(Box::new(socket), method, true, true))
587            }
588            #[cfg(not(unix))]
589            ConnectionMethod::UDS => Err(io::Error::new(
590                io::ErrorKind::Unsupported,
591                "Unix Domain Socket is not supported on this platform",
592            )
593            .into()),
594        }
595    }
596
597    /// Shutdown the socket for a q process.
598    /// # Example
599    /// See the example of [`connect`](#method.connect).
600    pub async fn shutdown(mut self) -> Result<()> {
601        self.stream.shutdown(self.listener).await
602    }
603
604    /// Send a message with a specified message type without waiting for a response even for a synchronous message.
605    ///  If you need to receive a response you need to use [`receive_message`](#method.receive_message).
606    /// # Note
607    /// The usage of this function for a synchronous message is to handle an asynchronous message or a synchronous message
608    ///   sent by a remote function during its execution.
609    /// # Parameters
610    /// - `message`: q command to execute on the remote q process.
611    ///   - `&str`: q command in a string form.
612    ///   - `K`: Query in a functional form.
613    /// - `message_type`: Asynchronous or synchronous.
614    /// # Example
615    /// See the example of [`connect`](#method.connect).
616    pub async fn send_message(&mut self, message: &dyn Query, message_type: u8) -> Result<()> {
617        self.stream
618            .send_message(message, message_type, self.local)
619            .await
620    }
621
622    /// Send a message asynchronously.
623    /// # Parameters
624    /// - `message`: q command to execute on the remote q process.
625    ///   - `&str`: q command in a string form.
626    ///   - `K`: Query in a functional form.
627    /// # Example
628    /// See the example of [`connect`](#method.connect).
629    pub async fn send_async_message(&mut self, message: &dyn Query) -> Result<()> {
630        self.stream.send_async_message(message, self.local).await
631    }
632
633    /// Send a message synchronously.
634    /// # Note
635    /// Remote function must NOT send back a message of asynchronous or synchronous type durning execution of the function.
636    /// # Parameters
637    /// - `message`: q command to execute on the remote q process.
638    ///   - `&str`: q command in a string form.
639    ///   - `K`: Query in a functional form.
640    /// # Example
641    /// See the example of [`connect`](#method.connect).
642    pub async fn send_sync_message(&mut self, message: &dyn Query) -> Result<K> {
643        self.stream.send_sync_message(message, self.local).await
644    }
645
646    /// Receive a message from a remote q process. The received message is parsed as `K` and message type is
647    ///  stored in the first returned value.
648    /// # Example
649    /// See the example of [`accept`](#method.accept).
650    pub async fn receive_message(&mut self) -> Result<(u8, K)> {
651        self.stream.receive_message().await
652    }
653
654    /// Return underlying connection type. One of `TCP`, `TLS` or `UDS`.
655    /// # Example
656    /// See the example of [`connect`](#method.connect).
657    pub fn get_connection_type(&self) -> &str {
658        match self.method {
659            ConnectionMethod::TCP => "TCP",
660            ConnectionMethod::TLS => "TLS",
661            ConnectionMethod::UDS => "UDS",
662        }
663    }
664
665    /// Enforce compression if the size of a message exceeds 2000 regardless of locality of the connection.
666    ///  This flag is not revertible intentionally.
667    pub fn enforce_compression(&mut self) {
668        self.local = false;
669    }
670}
671
672//%% QStreamInner %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
673
674#[async_trait]
675impl QStreamInner for TcpStream {
676    async fn shutdown(&mut self, _: bool) -> Result<()> {
677        AsyncWriteExt::shutdown(self).await?;
678        Ok(())
679    }
680
681    async fn send_message(
682        &mut self,
683        message: &dyn Query,
684        message_type: u8,
685        is_local: bool,
686    ) -> Result<()> {
687        // Serialize a message
688        let byte_message = message.serialize(message_type, is_local).await;
689        // Send the message
690        self.write_all(&byte_message).await?;
691        Ok(())
692    }
693
694    async fn send_async_message(&mut self, message: &dyn Query, is_local: bool) -> Result<()> {
695        // Serialize a message
696        let byte_message = message.serialize(qmsg_type::asynchronous, is_local).await;
697        // Send the message
698        self.write_all(&byte_message).await?;
699        Ok(())
700    }
701
702    async fn send_sync_message(&mut self, message: &dyn Query, is_local: bool) -> Result<K> {
703        // Serialize a message
704        let byte_message = message.serialize(qmsg_type::synchronous, is_local).await;
705        // Send the message
706        self.write_all(&byte_message).await?;
707        // Receive a response. If message type is not response it returns an error.
708        match receive_message(self).await {
709            Ok((qmsg_type::response, response)) => Ok(response),
710            Err(error) => Err(error),
711            Ok((_, message)) => Err(io::Error::new(
712                io::ErrorKind::InvalidData,
713                format!("expected a response: {}", message),
714            )
715            .into()),
716        }
717    }
718
719    async fn receive_message(&mut self) -> Result<(u8, K)> {
720        receive_message(self).await
721    }
722}
723
724#[async_trait]
725impl QStreamInner for TlsStream<TcpStream> {
726    async fn shutdown(&mut self, is_listener: bool) -> Result<()> {
727        if is_listener {
728            // Closing the handle from the server side by `self.get_mut().shutdown()` crashes due to 'assertion failed: !self.context.is_null()'.
729            // No reason to compress.
730            self.send_async_message(&".kdbplus.close_tls_connection_[]", false)
731                .await
732        } else {
733            self.get_mut().shutdown()?;
734            Ok(())
735        }
736    }
737
738    async fn send_message(
739        &mut self,
740        message: &dyn Query,
741        message_type: u8,
742        is_local: bool,
743    ) -> Result<()> {
744        // Serialize a message
745        let byte_message = message.serialize(message_type, is_local).await;
746        // Send the message
747        self.write_all(&byte_message).await?;
748        Ok(())
749    }
750
751    async fn send_async_message(&mut self, message: &dyn Query, is_local: bool) -> Result<()> {
752        // Serialize a message
753        let byte_message = message.serialize(qmsg_type::asynchronous, is_local).await;
754        // Send the message
755        self.write_all(&byte_message).await?;
756        Ok(())
757    }
758
759    async fn send_sync_message(&mut self, message: &dyn Query, is_local: bool) -> Result<K> {
760        // Serialize a message
761        let byte_message = message.serialize(qmsg_type::synchronous, is_local).await;
762        // Send the message
763        self.write_all(&byte_message).await?;
764        // Receive a response. If message type is not response it returns an error.
765        match receive_message(self).await {
766            Ok((qmsg_type::response, response)) => Ok(response),
767            Err(error) => Err(error),
768            Ok((_, message)) => Err(io::Error::new(
769                io::ErrorKind::InvalidData,
770                format!("expected a response: {}", message),
771            )
772            .into()),
773        }
774    }
775
776    async fn receive_message(&mut self) -> Result<(u8, K)> {
777        receive_message(self).await
778    }
779}
780
781#[cfg(unix)]
782#[async_trait]
783impl QStreamInner for UnixStream {
784    /// Close a handle to a q process which is connected with Unix Domain Socket.
785    ///  Socket file is removed.
786    async fn shutdown(&mut self, _: bool) -> Result<()> {
787        AsyncWriteExt::shutdown(self).await?;
788        Ok(())
789    }
790
791    async fn send_message(
792        &mut self,
793        message: &dyn Query,
794        message_type: u8,
795        is_local: bool,
796    ) -> Result<()> {
797        // Serialize a message
798        let byte_message = message.serialize(message_type, is_local).await;
799        // Send the message
800        self.write_all(&byte_message).await?;
801        Ok(())
802    }
803
804    async fn send_async_message(&mut self, message: &dyn Query, is_local: bool) -> Result<()> {
805        // Serialize a message
806        let byte_message = message.serialize(qmsg_type::asynchronous, is_local).await;
807        // Send the message
808        self.write_all(&byte_message).await?;
809        Ok(())
810    }
811
812    async fn send_sync_message(&mut self, message: &dyn Query, is_local: bool) -> Result<K> {
813        // Serialize a message
814        let byte_message = message.serialize(qmsg_type::synchronous, is_local).await;
815        // Send the message
816        self.write_all(&byte_message).await?;
817        // Receive a response. If message type is not response it returns an error.
818        match receive_message(self).await {
819            Ok((qmsg_type::response, response)) => Ok(response),
820            Err(error) => Err(error),
821            Ok((_, message)) => Err(io::Error::new(
822                io::ErrorKind::InvalidData,
823                format!("expected a response: {}", message),
824            )
825            .into()),
826        }
827    }
828
829    async fn receive_message(&mut self) -> Result<(u8, K)> {
830        receive_message(self).await
831    }
832}
833
834//%% MessageHeader %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
835
836impl MessageHeader {
837    /// Constructor.
838    fn new(encoding: u8, message_type: u8, compressed: u8, length: u32) -> Self {
839        MessageHeader {
840            encoding,
841            message_type,
842            compressed,
843            _unused: 0,
844            length,
845        }
846    }
847
848    /// Constructor from bytes.
849    fn from_bytes(bytes: [u8; 8]) -> Self {
850        let encoding = bytes[0];
851
852        let length = match encoding {
853            0 => u32::from_be_bytes(bytes[4..8].try_into().unwrap()),
854            _ => u32::from_le_bytes(bytes[4..8].try_into().unwrap()),
855        };
856
857        // Build header
858        MessageHeader::new(encoding, bytes[1], bytes[2], length)
859    }
860
861    /// Length of bytes for a header.
862    fn size() -> usize {
863        8
864    }
865}
866
867//++++++++++++++++++++++++++++++++++++++++++++++++++//
868// >> Private Functions
869//++++++++++++++++++++++++++++++++++++++++++++++++++//
870
871//%% QStream Connector %%//vvvvvvvvvvvvvvvvvvvvvvvvvv/
872
873/// Inner function of `connect_tcp` and `connect_tls` to establish a TCP connection with the sepcified
874///  endpoint. The hostname is resolved to an IP address with a system DNS resolver or parsed directly
875///  as an IP address.
876///
877/// Tries to connect to multiple resolved IP addresses until the first successful connection. Error is
878///  returned if none of them are valid.
879/// # Parameters
880/// - `host`: Hostname or IP address of the target q/kdb+ process.
881/// - `port`: Port of the target q process
882async fn connect_tcp_impl(host: &str, port: u16) -> Result<TcpStream> {
883    // DNS system resolver (should not fail)
884    let resolver =
885        TokioAsyncResolver::tokio_from_system_conf().expect("failed to create a resolver");
886
887    // Check if we were given an IP address
888    let ips;
889    if let Ok(ip) = host.parse::<IpAddr>() {
890        ips = vec![ip.to_string()]
891    } else {
892        // Resolve the given hostname
893        ips = resolver
894            .ipv4_lookup(format!("{}.", host).as_str())
895            .await
896            .unwrap()
897            .iter()
898            .map(|result| result.to_string())
899            .collect()
900    };
901
902    for answer in ips {
903        let host_port = format!("{}:{}", answer, port);
904        // Return if this IP address is valid
905        match TcpStream::connect(&host_port).await {
906            Ok(socket) => {
907                //println!("connected: {}", host_port);
908                return Ok(socket);
909            }
910            Err(_) => {
911                eprintln!("connection refused: {}. try next.", host_port);
912            }
913        }
914    }
915    // All addresses failed.
916    Err(io::Error::new(io::ErrorKind::ConnectionRefused, "failed to connect").into())
917}
918
919/// Send a credential and receive a common capacity.
920async fn handshake<S>(socket: &mut S, credential_: &str, method_bytes: &str) -> Result<()>
921where
922    S: Unpin + AsyncWriteExt + AsyncReadExt,
923{
924    // Send credential
925    let credential = credential_.to_string() + method_bytes;
926    socket.write_all(credential.as_bytes()).await?;
927
928    // Placeholder of common capablility
929    let mut cap = [0u8; 1];
930    if (socket.read_exact(&mut cap).await).is_err() {
931        // Connection is closed in case of authentication failure
932        Err(io::Error::new(io::ErrorKind::ConnectionAborted, "authentication failure").into())
933    } else {
934        Ok(())
935    }
936}
937
938/// Connect to q process running on a specified `host` and `port` via TCP with a credential `username:password`.
939/// # Parameters
940/// - `host`: Hostname or IP address of the target q process.
941/// - `port`: Port of the target q process.
942/// - `credential`: Credential in the form of `username:password` to connect to the target q process.
943async fn connect_tcp(host: &str, port: u16, credential: &str) -> Result<TcpStream> {
944    // Connect via TCP
945    let mut socket = connect_tcp_impl(host, port).await?;
946    // Handshake
947    handshake(&mut socket, credential, "\x03\x00").await?;
948    Ok(socket)
949}
950
951/// TLS version of `connect_tcp`.
952/// # Parameters
953/// - `host`: Hostname or IP address of the target q process.
954/// - `port`: Port of the target q process.
955/// - `credential`: Credential in the form of `username:password` to connect to the target q process.
956async fn connect_tls(host: &str, port: u16, credential: &str) -> Result<TlsStream<TcpStream>> {
957    // Connect via TCP
958    let socket_ = connect_tcp_impl(host, port).await?;
959    // Use TLS
960    let connector = TlsConnector::from(TlsConnectorInner::new().unwrap());
961    let mut socket = connector
962        .connect(host, socket_)
963        .await
964        .expect("failed to create TLS session");
965    // Handshake
966    handshake(&mut socket, credential, "\x03\x00").await?;
967    Ok(socket)
968}
969
970/// Build a path of a socket file.
971fn create_sockfile_path(port: u16) -> Result<String> {
972    // Create file path
973    let udspath = match env::var("QUDSPATH") {
974        Ok(dir) => format!("{}/kx.{}", dir, port),
975        Err(_) => format!("/tmp/kx.{}", port),
976    };
977
978    Ok(udspath)
979}
980
981/// Connect to q process running on the specified `port` via Unix domain socket with a credential `username:password`.
982/// # Parameters
983/// - `port`: Port of the target q process.
984/// - `credential`: Credential in the form of `username:password` to connect to the target q process.
985#[cfg(unix)]
986async fn connect_uds(port: u16, credential: &str) -> Result<UnixStream> {
987    // Create a file path.
988    let uds_path = create_sockfile_path(port)?;
989    let abstract_sockfile_ = format!("\x00{}", uds_path);
990    let abstract_sockfile = Path::new(&abstract_sockfile_);
991    // Connect to kdb+.
992    let mut socket = UnixStream::connect(&abstract_sockfile).await?;
993    // Handshake
994    handshake(&mut socket, credential, "\x06\x00").await?;
995
996    Ok(socket)
997}
998
999//%% QStream Acceptor %%//vvvvvvvvvvvvvvvvvvvvvvvvvvv/
1000
1001/// Read username, password, capacity and null byte from q client at the connection and does authentication.
1002///  Close the handle if the authentication fails.
1003async fn read_client_input<S>(socket: &mut S) -> Result<()>
1004where
1005    S: Unpin + AsyncWriteExt + AsyncReadExt,
1006{
1007    // Buffer to read inputs.
1008    let mut client_input = [0u8; 32];
1009    // credential will be built from small fractions of bytes.
1010    let mut passed_credential = String::new();
1011    loop {
1012        // Read a client credential input.
1013        match socket.read(&mut client_input).await {
1014            Ok(0) => {
1015                // No bytes were read
1016            }
1017            Ok(_) => {
1018                // Locate a byte denoting a capacity
1019                if let Some(index) = client_input
1020                    .iter()
1021                    .position(|byte| *byte == 0x03 || *byte == 0x06)
1022                {
1023                    let capacity = client_input[index];
1024                    passed_credential
1025                        .push_str(str::from_utf8(&client_input[0..index]).expect("invalid bytes"));
1026                    let credential = passed_credential.as_str().split(':').collect::<Vec<&str>>();
1027                    #[allow(clippy::borrow_interior_mutable_const)]
1028                    if let Some(encoded) = ACCOUNTS.get(&credential[0].to_string()) {
1029                        // User exists
1030                        let mut hasher = Sha1::new();
1031                        hasher.update(credential[1].as_bytes());
1032                        let encoded_password = hasher.digest().to_string();
1033                        if encoded == &encoded_password {
1034                            // Client passed correct credential
1035                            socket.write_all(&[capacity; 1]).await?;
1036                            return Ok(());
1037                        } else {
1038                            // Authentication failure.
1039                            // Close connection.
1040                            socket.shutdown().await?;
1041                            return Err(io::Error::new(
1042                                io::ErrorKind::InvalidData,
1043                                "authentication failed",
1044                            )
1045                            .into());
1046                        }
1047                    } else {
1048                        // Authentication failure.
1049                        // Close connection.
1050                        socket.shutdown().await?;
1051                        return Err(io::Error::new(
1052                            io::ErrorKind::InvalidData,
1053                            "authentication failed",
1054                        )
1055                        .into());
1056                    }
1057                } else {
1058                    // Append a fraction of credential
1059                    passed_credential
1060                        .push_str(str::from_utf8(&client_input).expect("invalid bytes"));
1061                }
1062            }
1063            Err(error) => {
1064                return Err(error.into());
1065            }
1066        }
1067    }
1068}
1069
1070/// Check if server key exists and return teh contents.
1071async fn build_identity_from_cert() -> Result<Identity> {
1072    // Check if server key exists.
1073    if let Ok(path) = env::var("KDBPLUS_TLS_KEY_FILE") {
1074        if let Ok(password) = env::var("KDBPLUS_TLS_KEY_FILE_SECRET") {
1075            let cert_file = tokio::fs::File::open(Path::new(&path)).await.unwrap();
1076            let mut reader = BufReader::new(cert_file);
1077            let mut der: Vec<u8> = Vec::new();
1078            // Read the key file.
1079            reader.read_to_end(&mut der).await?;
1080            // Create identity.
1081            if let Ok(identity) = Identity::from_pkcs12(&der, &password) {
1082                Ok(identity)
1083            } else {
1084                Err(io::Error::new(io::ErrorKind::InvalidData, "authentication failed").into())
1085            }
1086        } else {
1087            Err(io::Error::new(
1088                io::ErrorKind::NotFound,
1089                "KDBPLUS_TLS_KEY_FILE_SECRET is not set",
1090            )
1091            .into())
1092        }
1093    } else {
1094        Err(io::Error::new(io::ErrorKind::NotFound, "KDBPLUS_TLS_KEY_FILE is not set").into())
1095    }
1096}
1097
1098//%% QStream Query %%//vvvvvvvvvvvvvvvvvvvvvvvvvvvvvv/
1099
1100/// Receive a message from q process with decompression if necessary. The received message is parsed as `K` and message type is
1101///  stored in the first returned value.
1102/// # Parameters
1103/// - `socket`: Socket to communicate with a q process. Either of `TcpStream`, `TlsStream<TcpStream>` or `UnixStream`.
1104async fn receive_message<S>(socket: &mut S) -> Result<(u8, K)>
1105where
1106    S: Unpin + AsyncReadExt,
1107{
1108    // Read header
1109    let mut header_buffer = [0u8; 8];
1110    if let Err(err) = socket.read_exact(&mut header_buffer).await {
1111        // The expected message is header or EOF (close due to q process failure resulting from a bad query)
1112        return Err(io::Error::new(
1113            io::ErrorKind::ConnectionAborted,
1114            format!("Connection dropped: {}", err),
1115        )
1116        .into());
1117    }
1118
1119    // Parse message header
1120    let header = MessageHeader::from_bytes(header_buffer);
1121
1122    // Read body
1123    let body_length = header.length as usize - MessageHeader::size();
1124    let mut body: Vec<u8> = vec![0; body_length];
1125    if let Err(err) = socket.read_exact(&mut body).await {
1126        // Fails if q process fails before reading the body
1127        return Err(io::Error::new(
1128            io::ErrorKind::UnexpectedEof,
1129            format!("Failed to read body of message: {}", err),
1130        )
1131        .into());
1132    }
1133
1134    // Decompress if necessary
1135    if header.compressed == 0x01 {
1136        body = decompress(body, header.encoding).await;
1137    }
1138
1139    Ok((
1140        header.message_type,
1141        K::q_ipc_decode(&body, header.encoding).await,
1142    ))
1143}
1144
1145/// Compress body. The combination of serializing the data and compressing will result in
1146/// the same output as shown in the q language by using the -18! function e.g.
1147/// serializing 2000 bools set to true, then compressing, will have the same output as `-18!2000#1b`.
1148/// # Parameter
1149/// - `raw`: Serialized message.
1150/// - `encode`: `0` if Big Endian; `1` if Little Endian.
1151async fn compress(raw: Vec<u8>) -> (bool, Vec<u8>) {
1152    let mut i = 0_u8;
1153    let mut f = 0_u8;
1154    let mut h0 = 0_usize;
1155    let mut h = 0_usize;
1156    let mut g: bool;
1157    let mut compressed: Vec<u8> = vec![0; (raw.len()) / 2];
1158
1159    // Start index of compressed body
1160    // 12 bytes are reserved for the header + size of raw bytes
1161    let mut c = 12;
1162    let mut d = c;
1163    let e = compressed.len();
1164    let mut p = 0_usize;
1165    let mut q: usize;
1166    let mut r: usize;
1167    let mut s0 = 0_usize;
1168
1169    // Body starts from index 8
1170    let mut s = 8_usize;
1171    let t = raw.len();
1172    let mut a = [0_i32; 256];
1173
1174    // Copy encode, message type, compressed and reserved
1175    compressed[0..4].copy_from_slice(&raw[0..4]);
1176    // Set compressed flag
1177    compressed[2] = 1;
1178
1179    // Write size of raw bytes including a header
1180    let raw_size = match ENCODING {
1181        0 => (t as u32).to_be_bytes(),
1182        _ => (t as u32).to_le_bytes(),
1183    };
1184    compressed[8..12].copy_from_slice(&raw_size);
1185
1186    while s < t {
1187        if i == 0 {
1188            if d > e - 17 {
1189                // Early return when compressing to less than half failed
1190                return (false, raw);
1191            }
1192            i = 1;
1193            compressed[c] = f;
1194            c = d;
1195            d += 1;
1196            f = 0;
1197        }
1198        g = s > t - 3;
1199        if !g {
1200            h = (raw[s] ^ raw[s + 1]) as usize;
1201            p = a[h] as usize;
1202            g = (0 == p) || (0 != (raw[s] ^ raw[p]));
1203        }
1204        if 0 < s0 {
1205            a[h0] = s0 as i32;
1206            s0 = 0;
1207        }
1208        if g {
1209            h0 = h;
1210            s0 = s;
1211            compressed[d] = raw[s];
1212            d += 1;
1213            s += 1;
1214        } else {
1215            a[h] = s as i32;
1216            f |= i;
1217            p += 2;
1218            s += 2;
1219            r = s;
1220            q = if s + 255 > t { t } else { s + 255 };
1221            while (s < q) && (raw[p] == raw[s]) {
1222                s += 1;
1223                if s < q {
1224                    p += 1;
1225                }
1226            }
1227            compressed[d] = h as u8;
1228            d += 1;
1229            compressed[d] = (s - r) as u8;
1230            d += 1;
1231        }
1232        i = i.wrapping_mul(2);
1233    }
1234    compressed[c] = f;
1235    // Final compressed data size
1236    let compressed_size = match ENCODING {
1237        0 => (d as u32).to_be_bytes(),
1238        _ => (d as u32).to_le_bytes(),
1239    };
1240    compressed[4..8].copy_from_slice(&compressed_size);
1241    let _ = compressed.split_off(d);
1242    (true, compressed)
1243}
1244
/// Decompress body. The combination of decompressing and deserializing the data
///  will result in the same output as shown in the q language by using the `-19!` function.
///
/// `compressed` is the message body with the 8-byte IPC header already stripped. Its first 4 bytes
///  hold the size of the uncompressed message *including* that 8-byte header, in the byte order
///  selected by `encoding`. The returned buffer is that size minus 8, i.e. the uncompressed body only.
///
/// The remaining bytes form groups of one flag byte followed by up to 8 items. Bit `k` of the flag
///  byte selects the form of the `k`-th item:
///  - A clear bit means one literal byte.
///  - A set bit means a two-byte back-reference (a hash-table index and an extra run length `n`)
///    that copies `2 + n` previously decompressed bytes.
///
/// # Parameter
/// - `compressed`: Compressed serialized message (header removed).
/// - `encoding`:
///   - `0`: Big Endian
///   - `1`: Little Endian. Any non-zero value is treated as little endian.
///
/// # Panics
/// The input is not validated. A truncated or corrupt body can index past the end of either
///  buffer and panic. A declared size below 8 makes `size` negative, so `size as usize` requests
///  an enormous allocation.
async fn decompress(compressed: Vec<u8>, encoding: u8) -> Vec<u8> {
    // `n`: extra run length of the current back-reference; `r`: source index of a copy;
    // `f`: the current flag byte.
    let mut n = 0;
    let mut r: usize;
    let mut f = 0_usize;

    // Header has already been removed.
    // Start index of decompressed bytes is 0
    // `s`: write index into `decompressed`; `p`: next output position to register in `aa`;
    // `i`: the current flag bit (0 means a new flag byte must be read first).
    let mut s = 0_usize;
    let mut p = s;
    let mut i = 0_usize;

    // Subtract 8 bytes from decoded bytes size as 8 bytes have already been taken as header
    let size = match encoding {
        0 => {
            i32::from_be_bytes(
                compressed[0..4]
                    .try_into()
                    .expect("slice does not have length 4"),
            ) - 8
        }
        _ => {
            i32::from_le_bytes(
                compressed[0..4]
                    .try_into()
                    .expect("slice does not have length 4"),
            ) - 8
        }
    };
    let mut decompressed: Vec<u8> = vec![0; size as usize];

    // Start index of compressed body.
    // 8 bytes have already been removed as header
    let mut d = 4;
    // Maps the XOR of two adjacent output bytes to the latest position where that pair starts.
    let mut aa = [0_i32; 256];
    while s < decompressed.len() {
        if i == 0 {
            // Start of a group: read its flag byte.
            f = compressed[d] as usize;
            d += 1;
            i = 1;
        }
        if (f & i) != 0 {
            // Back-reference: copy two bytes from the position stored in `aa`...
            r = aa[compressed[d] as usize] as usize;
            d += 1;
            decompressed[s] = decompressed[r];
            s += 1;
            r += 1;
            decompressed[s] = decompressed[r];
            s += 1;
            r += 1;
            // ...then `n` more bytes. `s` is not advanced past them until after the `aa`
            // update below.
            n = compressed[d] as usize;
            d += 1;
            for m in 0..n {
                decompressed[s + m] = decompressed[r + m];
            }
        } else {
            // Literal byte.
            decompressed[s] = compressed[d];
            s += 1;
            d += 1;
        }
        // Register every byte pair that starts before the last written byte.
        while p < s - 1 {
            aa[(decompressed[p] ^ decompressed[p + 1]) as usize] = p as i32;
            p += 1;
        }
        if (f & i) != 0 {
            // Skip the `n`-byte run; its positions are never registered in `aa`.
            s += n;
            p = s;
        }
        // Move to the next flag bit; after 8 items a new flag byte is required.
        i *= 2;
        if i == 256 {
            i = 0;
        }
    }
    decompressed
}