use crate::context::GlobalParams;
use crate::error::{Error, ErrorKind, Result};
use crate::modules::{Module, ModuleResult, parse_params};
#[cfg(feature = "docs")]
use rash_derive::DocJsonSchema;
use minijinja::Value;
#[cfg(feature = "docs")]
use schemars::{JsonSchema, Schema};
use serde::Deserialize;
use serde_norway::Value as YamlValue;
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4, TcpStream, ToSocketAddrs};
use std::time::{Duration, Instant};
/// Default limit for a single connection attempt, in seconds.
const DEFAULT_CONNECT_TIMEOUT: u64 = 5;
/// Pause between two consecutive connection attempts, in milliseconds.
const DEFAULT_SLEEP_MS: u64 = 100;

/// Serde default for the `connect_timeout` parameter.
fn default_connect_timeout() -> u64 {
    DEFAULT_CONNECT_TIMEOUT
}
/// Parameters accepted by the `wait_for` module.
#[derive(Debug, PartialEq, Deserialize)]
#[cfg_attr(feature = "docs", derive(JsonSchema, DocJsonSchema))]
#[serde(deny_unknown_fields)]
pub struct Params {
    /// Maximum number of seconds a single connection attempt may take.
    /// Defaults to 5.
    #[serde(default = "default_connect_timeout")]
    connect_timeout: u64,
    /// TCP port to wait for.
    port: u16,
    /// Number of seconds after which the module stops retrying and fails.
    timeout: u64,
    /// Host to connect to. Defaults to `127.0.0.1`.
    #[serde(default = "default_host")]
    host: String,
}
/// Serde default for the `host` parameter: the IPv4 loopback address.
fn default_host() -> String {
    String::from("127.0.0.1")
}
/// Makes one TCP connection attempt to `host:port` and closes it again.
///
/// `host` may be an IPv4 literal, an IPv6 literal or a hostname. Every
/// address it resolves to is tried in order, and the first successful
/// connection wins. For example, `localhost` works whether it resolves to
/// `127.0.0.1`, `::1` or both.
///
/// `connect_timeout` bounds each individual connection attempt, in seconds.
/// `0` means "use the operating system's default connect timeout", because
/// `TcpStream::connect_timeout` rejects a zero duration outright. Hostname
/// resolution itself is not bounded by `connect_timeout`.
///
/// # Errors
///
/// Returns the resolution error if `host` cannot be resolved. Otherwise, if
/// no address accepted the connection, returns the error from the last
/// address tried.
fn check_port(host: &str, port: u16, connect_timeout: u64) -> std::io::Result<()> {
    let per_attempt = Duration::from_secs(connect_timeout);
    let candidates: Vec<SocketAddr> = (host, port).to_socket_addrs()?.collect();

    let mut last_err = None;
    for addr in &candidates {
        let attempt = if per_attempt.is_zero() {
            TcpStream::connect(addr)
        } else {
            TcpStream::connect_timeout(addr, per_attempt)
        };
        match attempt {
            // Only reachability matters: the stream is dropped (closing the
            // connection) as soon as this arm ends.
            Ok(_stream) => return Ok(()),
            Err(e) => last_err = Some(e),
        }
    }

    Err(last_err.unwrap_or_else(|| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("no addresses found for {}:{}", host, port),
        )
    }))
}
/// Repeatedly tries to connect to `params.host:params.port` until a
/// connection succeeds or `params.timeout` seconds have elapsed.
///
/// At least one attempt is always made, even when `timeout` is 0. After each
/// failed attempt the loop sleeps for `DEFAULT_SLEEP_MS` milliseconds, but
/// never past the deadline. The last attempt therefore happens right at the
/// deadline rather than up to one sleep interval later.
///
/// NOTE: a single attempt may itself take up to `params.connect_timeout`
/// seconds, so the total wall-clock time can exceed `params.timeout` by up to
/// that amount.
///
/// On success returns `ModuleResult::new(false, None, Some(port))`, with the
/// port rendered as a string.
///
/// # Errors
///
/// Returns an `ErrorKind::SubprocessFail` error that includes the last
/// connection error once the deadline passes without a successful connection.
fn wait_for_port(params: Params) -> Result<ModuleResult> {
    let start = Instant::now();
    let timeout = Duration::from_secs(params.timeout);
    let sleep_duration = Duration::from_millis(DEFAULT_SLEEP_MS);

    loop {
        match check_port(&params.host, params.port, params.connect_timeout) {
            Ok(()) => {
                return Ok(ModuleResult::new(
                    false,
                    None,
                    Some(params.port.to_string()),
                ));
            }
            Err(e) => {
                let elapsed = start.elapsed();
                if elapsed >= timeout {
                    return Err(Error::new(
                        ErrorKind::SubprocessFail,
                        format!(
                            "Timeout waiting for port {} on {}: {}",
                            params.port, params.host, e
                        ),
                    ));
                }
                // Clamp the pause to the time left so we don't oversleep the deadline.
                std::thread::sleep(sleep_duration.min(timeout.saturating_sub(elapsed)));
            }
        }
    }
}
/// The `wait_for` module: waits until a TCP port accepts connections or a
/// timeout expires.
#[derive(Debug)]
pub struct WaitFor;
impl Module for WaitFor {
fn get_name(&self) -> &str {
"wait_for"
}
fn exec(
&self,
_: &GlobalParams,
optional_params: YamlValue,
_vars: &Value,
_check_mode: bool,
) -> Result<(ModuleResult, Option<Value>)> {
Ok((wait_for_port(parse_params(optional_params)?)?, None))
}
#[cfg(feature = "docs")]
fn get_json_schema(&self) -> Option<Schema> {
Some(Params::get_json_schema())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_params() {
        let yaml: YamlValue = serde_norway::from_str(
            r#"
            port: 8080
            timeout: 30
            "#,
        )
        .unwrap();
        let params: Params = parse_params(yaml).unwrap();
        assert_eq!(
            params,
            Params {
                port: 8080,
                timeout: 30,
                connect_timeout: DEFAULT_CONNECT_TIMEOUT,
                host: "127.0.0.1".to_owned(),
            }
        );
    }

    #[test]
    fn test_parse_params_with_all_fields() {
        let yaml: YamlValue = serde_norway::from_str(
            r#"
            port: 5432
            timeout: 60
            connect_timeout: 10
            host: "192.168.1.1"
            "#,
        )
        .unwrap();
        let params: Params = parse_params(yaml).unwrap();
        assert_eq!(
            params,
            Params {
                port: 5432,
                timeout: 60,
                connect_timeout: 10,
                host: "192.168.1.1".to_owned(),
            }
        );
    }

    #[test]
    fn test_parse_params_missing_required() {
        let yaml: YamlValue = serde_norway::from_str(
            r#"
            port: 8080
            "#,
        )
        .unwrap();
        let error = parse_params::<Params>(yaml).unwrap_err();
        assert_eq!(error.kind(), ErrorKind::InvalidData);
    }

    #[test]
    fn test_wait_for_port_timeout() {
        // Assumes nothing listens on port 1 of the loopback interface.
        let params = Params {
            port: 1,
            timeout: 1,
            connect_timeout: 1,
            host: "127.0.0.1".to_owned(),
        };
        let result = wait_for_port(params);
        assert!(result.is_err());
    }

    #[test]
    fn test_wait_for_port_success() {
        // Bind an ephemeral port so the first connection attempt succeeds.
        let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        let params = Params {
            port,
            timeout: 5,
            connect_timeout: 1,
            host: "127.0.0.1".to_owned(),
        };
        let result = wait_for_port(params);
        assert!(result.is_ok());
        drop(listener);
    }
}