mod bootstrap;
pub(crate) mod client;
pub(crate) mod matchers;
pub(crate) mod models;
pub(crate) mod policy;
pub(crate) mod sanitizers;
use self::{
bootstrap::*, matchers::CustomDefaultMatcher, models::SanitizerList,
sanitizers::DEFAULT_SANITIZERS_TO_REMOVE,
};
use azure_core::{
error::ErrorKind,
http::{
headers::{Header, HeaderName, HeaderValue},
Url,
},
Result,
};
use client::Client;
use serde::Serializer;
use std::process::ExitStatus;
use std::{fmt, str::FromStr, sync::Arc};
use tokio::{io::Lines, process::Child};
const ABSTRACTION_IDENTIFIER: HeaderName = HeaderName::from_static("x-abstraction-identifier");
const RECORDING_ID: HeaderName = HeaderName::from_static("x-recording-id");
const RECORDING_MODE: HeaderName = HeaderName::from_static("x-recording-mode");
const RECORDING_UPSTREAM_BASE_URI: HeaderName =
HeaderName::from_static("x-recording-upstream-base-uri");
pub use bootstrap::start;
#[derive(Debug, Default)]
pub struct Proxy {
command: Option<Child>,
endpoint: Option<Url>,
client: Option<Client>,
}
impl Proxy {
async fn start<I: Iterator<Item = String>>(
&mut self,
git_dir: &Path,
executable_file_path: &Path,
args: I,
) -> Result<()> {
let mut command = Command::new(executable_file_path)
.args(args)
.env(
KESTREL_CERT_PATH_ENV,
git_dir.join("eng/common/testproxy/dotnet-devcert.pfx"),
)
.env(KESTREL_CERT_PASSWORD_ENV, KESTREL_CERT_PASSWORD)
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
.map_err(|err| {
azure_core::Error::with_error(
ErrorKind::Io,
err,
format!("{} failed to start", executable_file_path.display()),
)
})?;
let mut stdout = BufReader::new(
command
.stdout
.take()
.ok_or_else(|| azure_core::Error::with_message(ErrorKind::Io, "no stdout pipe"))?,
)
.lines();
self.command = Some(command);
self.wait_till_listening(&mut stdout).await?;
tokio::spawn(async move {
while let Some(line) = stdout.next_line().await.unwrap_or(None) {
trace_line(Level::TRACE, &line);
}
});
if let Some(endpoint) = &self.endpoint {
self.client = Some(Client::new(endpoint.clone())?);
}
Ok(())
}
pub(crate) async fn initialize(&self) -> Result<()> {
if let Some(client) = &self.client {
client
.set_matcher(CustomDefaultMatcher::default().into(), None)
.await?;
let body = SanitizerList {
sanitizers: Vec::from_iter(
DEFAULT_SANITIZERS_TO_REMOVE
.iter()
.map(|s| String::from(*s)),
),
};
client.remove_sanitizers(body.try_into()?, None).await?;
}
Ok(())
}
pub async fn stop(&mut self) -> Result<()> {
if let Some(command) = &mut self.command {
tracing::debug!(pid = ?command.id(), "stopping");
return Ok(command.kill().await?);
}
Ok(())
}
pub async fn wait(&mut self) -> Result<ExitStatus> {
if let Some(command) = &mut self.command {
return Ok(command.wait().await?);
}
Ok(ExitStatus::default())
}
async fn wait_till_listening(
&mut self,
stdout: &mut Lines<BufReader<ChildStdout>>,
) -> Result<()> {
let pid = self.command.as_ref().and_then(Child::id);
while let Some(line) = stdout.next_line().await? {
const RUNNING_PATTERN: &str = "Running proxy version is Azure.Sdk.Tools.TestProxy ";
const LISTENING_PATTERN: &str = "Now listening on: ";
trace_line(Level::TRACE, &line);
if let Some(idx) = line.find(RUNNING_PATTERN) {
let idx = idx + RUNNING_PATTERN.len();
let version: Version = line[idx..].parse()?;
tracing::info!(?pid, %version, "started test-proxy version {version}");
if version < MIN_VERSION {
return Err(azure_core::Error::with_message(
ErrorKind::Io,
format!("test-proxy older than required version {MIN_VERSION}"),
));
}
continue;
}
if let Some(idx) = line.find(LISTENING_PATTERN) {
let idx = idx + LISTENING_PATTERN.len();
let mut endpoint: Url = line[idx..].parse()?;
endpoint.set_host(Some("localhost"))?;
tracing::info!(?pid, %endpoint, "test-proxy listening on {endpoint}");
self.endpoint = Some(endpoint);
break;
}
}
Ok(())
}
}
impl Proxy {
pub fn existing() -> Result<Self> {
let endpoint: Url = "http://localhost:5000".parse()?;
Ok(Self {
command: None,
endpoint: Some(endpoint.clone()),
client: Some(Client::new(endpoint)?),
})
}
pub fn endpoint(&self) -> &Url {
self.endpoint
.as_ref()
.unwrap_or_else(|| panic!("endpoint not set"))
}
}
impl Drop for Proxy {
fn drop(&mut self) {
if let Some(command) = &mut self.command {
let _ = command.start_kill();
}
}
}
pub(crate) trait ProxyExt<'a> {
fn client(&'a self) -> Option<&'a Client>;
}
impl<'a> ProxyExt<'a> for Arc<Proxy> {
fn client(&'a self) -> Option<&'a Client> {
self.client.as_ref()
}
}
impl<'a> ProxyExt<'a> for Option<Arc<Proxy>> {
fn client(&'a self) -> Option<&'a Client> {
self.as_ref().and_then(|proxy| proxy.client.as_ref())
}
}
#[derive(Clone, Debug)]
pub struct ProxyOptions {
pub auto: bool,
pub insecure: bool,
pub auto_shutdown_in_seconds: u32,
}
impl ProxyOptions {
fn copy_to(&self, args: &mut Vec<String>) {
if self.insecure {
args.push("--insecure".into());
}
args.extend_from_slice(&[
"--auto-shutdown-in-seconds".into(),
self.auto_shutdown_in_seconds.to_string(),
]);
if self.auto {
args.extend_from_slice(&["--", "--urls", "http://0.0.0.0:0"].map(Into::into));
}
}
}
impl Default for ProxyOptions {
fn default() -> Self {
Self {
auto: true,
insecure: false,
auto_shutdown_in_seconds: 300,
}
}
}
#[derive(Clone, Debug)]
pub struct RecordingId(String);
impl Header for RecordingId {
fn name(&self) -> HeaderName {
RECORDING_ID
}
fn value(&self) -> HeaderValue {
self.0.clone().into()
}
}
impl Header for &RecordingId {
fn name(&self) -> HeaderName {
RECORDING_ID
}
fn value(&self) -> HeaderValue {
self.0.clone().into()
}
}
impl AsRef<str> for RecordingId {
fn as_ref(&self) -> &str {
self.0.as_str()
}
}
impl FromStr for RecordingId {
type Err = std::convert::Infallible;
fn from_str(value: &str) -> std::result::Result<Self, Self::Err> {
Ok(RecordingId(value.to_string()))
}
}
#[derive(Debug, Default, Eq, PartialEq, Ord, PartialOrd)]
struct Version {
major: i32,
minor: i32,
metadata: Option<String>,
}
impl fmt::Display for Version {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if let Some(metadata) = &self.metadata {
return write!(f, "{}.{}-{metadata}", self.major, self.minor);
}
write!(f, "{}.{}", self.major, self.minor)
}
}
impl FromStr for Version {
type Err = azure_core::Error;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
let mut v = Version::default();
let mut cur = s.splitn(2, '.');
if let Some(major) = cur.next() {
v.major = major.parse()?;
} else {
return Err(azure_core::Error::with_message(
ErrorKind::DataConversion,
"major version required",
));
}
if let Some(minor) = cur.next() {
let mut cur = minor.splitn(2, '-');
if let Some(minor) = cur.next() {
v.minor = minor.parse()?;
}
v.metadata = cur.next().map(String::from);
}
Ok(v)
}
}
#[test]
fn version_eq() {
let a = Version {
major: 1,
..Default::default()
};
let b = Version {
major: 1,
..Default::default()
};
assert_eq!(a, b);
let a = Version {
major: 1,
minor: 2,
metadata: Some("preview".into()),
};
let b = Version {
major: 1,
minor: 2,
metadata: Some("preview".into()),
};
assert_eq!(a, b);
}
#[test]
fn version_cmp() {
let a = Version {
major: 20240107,
minor: 1,
..Default::default()
};
let b = Version {
major: 20240107,
minor: 2,
..Default::default()
};
let c = Version {
major: 20240109,
minor: 1,
metadata: Some("1".into()),
};
let d = Version {
major: 20240109,
minor: 1,
metadata: Some("2".into()),
};
assert!(a == a);
assert!(a < b);
assert!(b > a);
assert!(b < c);
assert!(c != d);
assert!(c < d);
}
#[test]
fn version_fmt() {
let mut v = Version {
major: 1,
..Default::default()
};
assert_eq!(v.to_string(), "1.0");
v.minor = 2;
v.metadata = Some("preview".into());
assert_eq!(v.to_string(), "1.2-preview");
}
#[test]
fn version_parse() {
let mut v = Version {
major: 1,
..Default::default()
};
assert!(matches!("1".parse::<Version>(), Ok(ver) if ver == v));
assert!(matches!("1.0".parse::<Version>(), Ok(ver) if ver == v));
v.minor = 2;
assert!(matches!("1.2".parse::<Version>(), Ok(ver) if ver == v));
v.metadata = Some("preview".into());
assert!(matches!("1.2-preview".parse::<Version>(), Ok(ver) if ver == v));
}
fn join<S>(value: &[&str], serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: Serializer,
{
let s = value.join(",");
serializer.serialize_str(s.as_ref())
}