nut_client/tokio/
mod.rs

1use std::net::SocketAddr;
2
3use crate::cmd::{Command, Response};
4use crate::tokio::stream::ConnectionStream;
5use crate::{Config, Host, NutError};
6use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
7use tokio::net::TcpStream;
8
9mod stream;
10
/// An async NUT client connection.
///
/// Each variant wraps one transport; TCP is the only transport declared here.
pub enum Connection {
    /// A connection carried over a TCP socket (see [`TcpConnection`]).
    Tcp(TcpConnection),
}
16
17impl Connection {
18    /// Initializes a connection to a NUT server (upsd).
19    pub async fn new(config: &Config) -> crate::Result<Self> {
20        let mut conn = match &config.host {
21            Host::Tcp(host) => Self::Tcp(TcpConnection::new(config.clone(), &host.addr).await?),
22        };
23
24        conn.get_network_version().await?;
25        conn.login(config).await?;
26
27        Ok(conn)
28    }
29
30    /// Gracefully closes the connection.
31    pub async fn close(mut self) -> crate::Result<()> {
32        self.logout().await?;
33        Ok(())
34    }
35
36    /// Sends username and password, as applicable.
37    async fn login(&mut self, config: &Config) -> crate::Result<()> {
38        if let Some(auth) = config.auth.clone() {
39            // Pass username and check for 'OK'
40            self.set_username(&auth.username).await?;
41
42            // Pass password and check for 'OK'
43            if let Some(password) = &auth.password {
44                self.set_password(password).await?;
45            }
46        }
47        Ok(())
48    }
49}
50
/// An async TCP NUT client connection.
pub struct TcpConnection {
    /// Client configuration; its `ssl`, `ssl_insecure`, `host` and `debug`
    /// settings are read by the methods of this type.
    config: Config,
    /// The underlying stream. It starts as a plain TCP stream and is replaced
    /// by `enable_ssl` when an SSL upgrade is performed.
    stream: ConnectionStream,
}
56
57impl TcpConnection {
58    async fn new(config: Config, socket_addr: &SocketAddr) -> crate::Result<Self> {
59        // Create the TCP connection
60        let tcp_stream = TcpStream::connect(socket_addr).await?;
61        let mut connection = Self {
62            config,
63            stream: ConnectionStream::Plain(tcp_stream),
64        };
65        connection = connection.enable_ssl().await?;
66        Ok(connection)
67    }
68
69    #[cfg(feature = "async-ssl")]
70    async fn enable_ssl(mut self) -> crate::Result<Self> {
71        if self.config.ssl {
72            // Send TLS request and check for 'OK'
73            self.write_cmd(Command::StartTLS).await?;
74            self.read_response()
75                .await
76                .map_err(|e| {
77                    if let crate::ClientError::Nut(NutError::FeatureNotConfigured) = e {
78                        crate::ClientError::Nut(NutError::SslNotSupported)
79                    } else {
80                        e
81                    }
82                })?
83                .expect_ok()?;
84
85            let mut ssl_config = rustls::ClientConfig::new();
86            let dns_name: webpki::DNSName;
87
88            if self.config.ssl_insecure {
89                ssl_config
90                    .dangerous()
91                    .set_certificate_verifier(std::sync::Arc::new(
92                        crate::ssl::InsecureCertificateValidator::new(&self.config),
93                    ));
94
95                dns_name = webpki::DNSNameRef::try_from_ascii_str("www.google.com")
96                    .unwrap()
97                    .to_owned();
98            } else {
99                // Try to get hostname as given (e.g. localhost can be used for strict SSL, but not 127.0.0.1)
100                let hostname = self
101                    .config
102                    .host
103                    .hostname()
104                    .ok_or(crate::ClientError::Nut(NutError::SslInvalidHostname))?;
105
106                dns_name = webpki::DNSNameRef::try_from_ascii_str(&hostname)
107                    .map_err(|_| crate::ClientError::Nut(NutError::SslInvalidHostname))?
108                    .to_owned();
109
110                ssl_config
111                    .root_store
112                    .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
113            };
114
115            let config = tokio_rustls::TlsConnector::from(std::sync::Arc::new(ssl_config));
116
117            // Wrap and override the TCP stream
118            self.stream = self.stream.upgrade_ssl(config, dns_name.as_ref()).await?;
119        }
120        Ok(self)
121    }
122
    /// Variant compiled when the `async-ssl` feature is disabled: returns the
    /// connection unchanged and never fails. `config.ssl` is not consulted here.
    #[cfg(not(feature = "async-ssl"))]
    async fn enable_ssl(self) -> crate::Result<Self> {
        Ok(self)
    }
127
128    pub(crate) async fn write_cmd(&mut self, line: Command<'_>) -> crate::Result<()> {
129        let line = format!("{}\n", line);
130        if self.config.debug {
131            eprint!("DEBUG -> {}", line);
132        }
133        self.stream.write_all(line.as_bytes()).await?;
134        self.stream.flush().await?;
135        Ok(())
136    }
137
138    async fn parse_line(
139        reader: &mut BufReader<&mut ConnectionStream>,
140        debug: bool,
141    ) -> crate::Result<Vec<String>> {
142        let mut raw = String::new();
143        reader.read_line(&mut raw).await?;
144        if debug {
145            eprint!("DEBUG <- {}", raw);
146        }
147        raw = raw[..raw.len() - 1].to_string(); // Strip off \n
148
149        // Parse args by splitting whitespace, minding quotes for args with multiple words
150        let args = shell_words::split(&raw)
151            .map_err(|e| NutError::Generic(format!("Parsing server response failed: {}", e)))?;
152
153        Ok(args)
154    }
155
156    pub(crate) async fn read_response(&mut self) -> crate::Result<Response> {
157        let mut reader = BufReader::new(&mut self.stream);
158        let args = Self::parse_line(&mut reader, self.config.debug).await?;
159        Response::from_args(args)
160    }
161
162    pub(crate) async fn read_plain_response(&mut self) -> crate::Result<String> {
163        let mut reader = BufReader::new(&mut self.stream);
164        let args = Self::parse_line(&mut reader, self.config.debug).await?;
165        Ok(args.join(" "))
166    }
167
168    pub(crate) async fn read_list(&mut self, query: &[&str]) -> crate::Result<Vec<Response>> {
169        let mut reader = BufReader::new(&mut self.stream);
170        let args = Self::parse_line(&mut reader, self.config.debug).await?;
171
172        Response::from_args(args)?.expect_begin_list(query)?;
173        let mut lines: Vec<Response> = Vec::new();
174
175        loop {
176            let args = Self::parse_line(&mut reader, self.config.debug).await?;
177            let resp = Response::from_args(args)?;
178
179            match resp {
180                Response::EndList(_) => {
181                    break;
182                }
183                _ => lines.push(resp),
184            }
185        }
186
187        Ok(lines)
188    }
189}