use super::{PROXY_PORT_ENV_KEY, ProxyHandle, pick_ephemeral_port, wait_for_bind};
use anyhow::{Context, Result, bail};
use std::collections::HashMap;
use std::time::Duration;
use tokio::process::Command;
use tracing::debug;
/// Configuration for launching an external proxy process (e.g. `mitmdump`) and
/// waiting until it listens on its port.
///
/// Build it with [`ExternalProxy::new`] to get defaults, or as a struct literal, then
/// call [`ExternalProxy::spawn`].
#[derive(Debug, Clone)]
pub struct ExternalProxy {
    /// Program followed by its arguments: `command[0]` is the executable, the rest are
    /// passed as arguments. Must not be empty (`spawn` rejects an empty command).
    pub command: Vec<String>,
    /// Extra environment variables for the child process. They are applied after the
    /// port variable (`PROXY_PORT_ENV_KEY`), so an entry with that key overrides it.
    pub env: HashMap<String, String>,
    /// Port the proxy is expected to bind. `None` picks an ephemeral port at spawn time.
    /// The child learns the port only through the `PROXY_PORT_ENV_KEY` environment
    /// variable, so the command must honor it (or hard-code the same port).
    pub port: Option<u16>,
    /// How long `spawn` waits for the port to become bound before failing.
    pub startup_timeout: Duration,
}
impl ExternalProxy {
    /// Creates a launcher for `command` (program followed by its arguments).
    ///
    /// Defaults: no extra environment variables, an ephemeral port chosen at spawn
    /// time, and a 5 second startup timeout.
    pub fn new<S: Into<String>>(command: impl IntoIterator<Item = S>) -> Self {
        Self {
            command: command.into_iter().map(Into::into).collect(),
            env: HashMap::new(),
            port: None,
            startup_timeout: Duration::from_secs(5),
        }
    }

    /// Spawns the proxy process and waits until its port is bound.
    ///
    /// The port (either [`Self::port`] or a freshly picked ephemeral one) is passed to
    /// the child via the `PROXY_PORT_ENV_KEY` environment variable. Entries in
    /// [`Self::env`] are applied afterwards and can override it. The child's stdin is
    /// set to null.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - `command` is empty;
    /// - `port` is `None` and no ephemeral port could be picked;
    /// - the process could not be spawned;
    /// - the port was not bound within `startup_timeout`. In this case the child is
    ///   killed and reaped before the error is returned, so no process is leaked.
    pub async fn spawn(&self) -> Result<ProxyHandle> {
        let (program, args) = match self.command.split_first() {
            Some(parts) => parts,
            None => bail!("ExternalProxy::command must not be empty"),
        };
        let port = match self.port {
            Some(p) => p,
            None => pick_ephemeral_port().context("pick ephemeral port for proxy")?,
        };

        let mut cmd = Command::new(program);
        // The port variable is set first so that user-supplied `env` entries win.
        cmd.args(args)
            .env(PROXY_PORT_ENV_KEY, port.to_string())
            .envs(&self.env)
            .stdin(std::process::Stdio::null());

        let mut child = cmd
            .spawn()
            .with_context(|| format!("spawn external proxy: {:?}", self.command))?;
        debug!(
            "external proxy spawned: cmd={:?} port={} pid={:?}",
            self.command,
            port,
            child.id()
        );

        let bind_result = wait_for_bind(port, self.startup_timeout).await;
        if bind_result.is_err() {
            // Dropping a tokio `Child` does not terminate the process (kill_on_drop is
            // not set), so kill and reap it explicitly; otherwise a failed start would
            // leave an orphaned proxy running. A kill failure (e.g. the child already
            // exited) is only logged: the bind failure is the error worth reporting.
            if let Err(kill_err) = child.kill().await {
                debug!("failed to kill external proxy after bind failure: {kill_err}");
            }
        }
        bind_result.with_context(|| format!("external proxy did not bind 127.0.0.1:{port}"))?;

        Ok(ProxyHandle::from_child(port, child))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` stores the command verbatim and fills in the documented defaults.
    #[test]
    fn external_proxy_new_sets_defaults() {
        let p = ExternalProxy::new(["mitmdump", "--listen-port", "8877"]);
        assert_eq!(p.command, vec!["mitmdump", "--listen-port", "8877"]);
        assert!(p.env.is_empty());
        assert!(p.port.is_none());
        assert_eq!(p.startup_timeout, Duration::from_secs(5));
    }

    /// An empty command is rejected before anything is spawned.
    #[tokio::test]
    async fn external_proxy_empty_command_errors() {
        let p = ExternalProxy {
            command: vec![],
            env: HashMap::new(),
            port: None,
            startup_timeout: Duration::from_millis(100),
        };
        let err = p.spawn().await.expect_err("must error on empty command");
        assert!(err.to_string().contains("must not be empty"));
    }

    /// `true` exits immediately without binding, so `spawn` must fail with the bind
    /// error. Unix-only: it relies on a `true` executable being on PATH; elsewhere the
    /// spawn itself would fail and the assertion would test the wrong error.
    #[cfg(unix)]
    #[tokio::test]
    async fn external_proxy_unbound_command_times_out() {
        let p = ExternalProxy {
            command: vec!["true".to_string()],
            env: HashMap::new(),
            port: None,
            startup_timeout: Duration::from_millis(150),
        };
        let err = p.spawn().await.expect_err("must time out");
        let msg = format!("{err:#}");
        assert!(msg.contains("did not bind"), "got: {msg}");
    }

    /// A command that keeps running but never binds must also fail with the bind error
    /// once the startup timeout elapses, while the child is still alive.
    #[cfg(unix)]
    #[tokio::test]
    async fn external_proxy_long_running_unbound_command_times_out() {
        let p = ExternalProxy {
            command: vec!["sleep".to_string(), "30".to_string()],
            env: HashMap::new(),
            port: None,
            startup_timeout: Duration::from_millis(150),
        };
        let err = p.spawn().await.expect_err("must time out");
        let msg = format!("{err:#}");
        assert!(msg.contains("did not bind"), "got: {msg}");
    }
}