use anyhow::anyhow;
use flate2::read::GzDecoder;
use futures_util::StreamExt;
use serde::Deserialize;
use std::{
fs::OpenOptions,
io,
path::{Path, PathBuf},
};
use temporalio_client::{Connection, ConnectionOptions};
use tokio::{
task::spawn_blocking,
time::{Duration, sleep},
};
use tokio_util::io::{StreamReader, SyncIoBridge};
use url::Url;
use zip::read::read_zipfile_from_stream;
#[cfg(target_family = "unix")]
use std::os::unix::fs::OpenOptionsExt;
use std::process::Stdio;
/// Configuration for launching a Temporal CLI dev server (`temporal server start-dev`).
///
/// Derives a `bon` builder; `String` fields accept any `impl Into<String>` in that builder.
#[derive(Debug, Clone, bon::Builder)]
#[builder(on(String, into))]
pub struct TemporalDevServerConfig {
/// Where to find (or how to download) the Temporal CLI executable.
pub exe: EphemeralExe,
/// Namespace passed via `--namespace`.
#[builder(default = "default".to_owned())]
pub namespace: String,
/// IP passed via `--ip`; also the address a free port is probed on when `port` is `None`.
#[builder(default = "127.0.0.1".to_owned())]
pub ip: String,
/// Server port passed via `--port`; a free port is picked when `None`.
pub port: Option<u16>,
/// UI port passed via `--ui-port`; takes precedence over `ui`.
pub ui_port: Option<u16>,
/// Passed via `--db-filename` when set.
pub db_filename: Option<String>,
/// When `ui_port` is unset, serve the UI on a port derived from the server port.
/// With neither `ui_port` nor `ui`, the server is started with `--headless`.
#[builder(default)]
pub ui: bool,
/// `(log format, log level)`, passed as `--log-format` and `--log-level`.
#[builder(default = ("pretty".to_owned(), "warn".to_owned()))]
pub log: (String, String),
/// Additional CLI arguments appended after all generated ones.
#[builder(default)]
pub extra_args: Vec<String>,
}
impl TemporalDevServerConfig {
    /// Starts a Temporal CLI dev server with this configuration, inheriting this process's
    /// stdout and stderr.
    ///
    /// # Errors
    /// Same as [`Self::start_server_with_output`].
    pub async fn start_server(&self) -> anyhow::Result<EphemeralServer> {
        self.start_server_with_output(Stdio::inherit(), Stdio::inherit())
            .await
    }

    /// Starts `temporal server start-dev`, sending its stdout to `output` and its stderr to
    /// `err_output`, and waits until it accepts connections.
    ///
    /// * The CLI executable is resolved (and downloaded if needed) from `self.exe`, preferring
    ///   a `tar.gz` archive.
    /// * When `port` is `None`, a free port on `ip` is picked.
    /// * The UI is served on `ui_port` if set. Otherwise, when `ui` is true, it uses the server
    ///   port + 1000, or a freshly picked free port if that would overflow `u16`. Otherwise
    ///   the server runs with `--headless`.
    /// * `extra_args` are appended after all generated arguments.
    ///
    /// # Errors
    /// Fails if the executable can't be obtained, no free port can be found, the process can't
    /// be spawned, or the server never becomes reachable.
    pub async fn start_server_with_output(
        &self,
        output: Stdio,
        err_output: Stdio,
    ) -> anyhow::Result<EphemeralServer> {
        let exe_path = self
            .exe
            .get_or_download("cli", "temporal", Some("tar.gz"))
            .await?;
        let port = match self.port {
            Some(p) => p,
            None => get_free_port(&self.ip)?,
        };
        let mut args = vec![
            "server".to_owned(),
            "start-dev".to_owned(),
            "--port".to_owned(),
            port.to_string(),
            "--namespace".to_owned(),
            self.namespace.clone(),
            "--ip".to_owned(),
            self.ip.clone(),
            "--log-format".to_owned(),
            self.log.0.clone(),
            "--log-level".to_owned(),
            self.log.1.clone(),
            "--dynamic-config-value".to_owned(),
            "frontend.enableServerVersionCheck=false".to_owned(),
            "--dynamic-config-value".to_owned(),
            "frontend.enableUpdateWorkflowExecution=true".to_owned(),
            "--dynamic-config-value".to_owned(),
            "frontend.enableUpdateWorkflowExecutionAsyncAccepted=true".to_owned(),
        ];
        if let Some(db_filename) = &self.db_filename {
            args.push("--db-filename".to_owned());
            args.push(db_filename.clone());
        }
        // An explicit UI port wins; otherwise `ui` derives one from the server port.
        let ui_port = match (self.ui_port, self.ui) {
            (Some(ui_port), _) => Some(ui_port),
            (None, true) => Some(match port.checked_add(1000) {
                Some(offset_port) => offset_port,
                // Offsetting a port above 64535 overflows u16. Clamping to 65535 could hit the
                // server's own port, so pick an unused one instead.
                None => get_free_port(&self.ip)?,
            }),
            (None, false) => None,
        };
        match ui_port {
            Some(ui_port) => {
                args.push("--ui-port".to_owned());
                args.push(ui_port.to_string());
            }
            None => args.push("--headless".to_owned()),
        }
        args.extend(self.extra_args.iter().cloned());
        EphemeralServer::start(EphemeralServerConfig {
            exe_path,
            port,
            args,
            has_test_service: false,
            output,
            err_output,
        })
        .await
    }
}
/// Configuration for launching the standalone Temporal test server (`temporal-test-server`).
///
/// Servers started from this configuration report `has_test_service == true`.
#[derive(Debug, Clone, bon::Builder)]
pub struct TestServerConfig {
/// Where to find (or how to download) the test server executable.
pub exe: EphemeralExe,
/// Port passed as the first CLI argument; a free port is picked when `None`.
pub port: Option<u16>,
/// Additional CLI arguments appended after the port.
#[builder(default)]
pub extra_args: Vec<String>,
}
impl TestServerConfig {
    /// Starts the Temporal test server, letting it write to this process's stdout/stderr.
    pub async fn start_server(&self) -> anyhow::Result<EphemeralServer> {
        let (stdout, stderr) = (Stdio::inherit(), Stdio::inherit());
        self.start_server_with_output(stdout, stderr).await
    }

    /// Starts the Temporal test server with its stdout sent to `output` and its stderr sent to
    /// `err_output`, returning once it accepts connections.
    ///
    /// The listening port is the first CLI argument. When none is configured, a free port is
    /// found by binding `0.0.0.0`. `extra_args` follow the port.
    pub async fn start_server_with_output(
        &self,
        output: Stdio,
        err_output: Stdio,
    ) -> anyhow::Result<EphemeralServer> {
        let exe_path = self
            .exe
            .get_or_download("temporal-test-server", "temporal-test-server", None)
            .await?;
        let port = if let Some(configured) = self.port {
            configured
        } else {
            get_free_port("0.0.0.0")?
        };
        let args: Vec<String> = std::iter::once(port.to_string())
            .chain(self.extra_args.iter().cloned())
            .collect();
        let launch = EphemeralServerConfig {
            exe_path,
            port,
            args,
            has_test_service: true,
            output,
            err_output,
        };
        EphemeralServer::start(launch).await
    }
}
/// Everything needed to spawn an ephemeral server process and wait for it to come up.
struct EphemeralServerConfig {
/// Executable to run.
exe_path: PathBuf,
/// Port the server listens on; readiness is checked against `127.0.0.1:{port}`.
port: u16,
/// Complete CLI argument list for the executable.
args: Vec<String>,
/// Copied into the resulting server's `has_test_service`.
has_test_service: bool,
/// Destination for the child's stdout.
output: Stdio,
/// Destination for the child's stderr.
err_output: Stdio,
}
/// Handle to a running ephemeral Temporal server process.
///
/// The child is spawned without `kill_on_drop`, so dropping this handle does not stop the
/// process; call `shutdown` to kill it.
#[derive(Debug)]
pub struct EphemeralServer {
/// Address the server accepted a connection on, formatted `127.0.0.1:{port}` (no scheme).
pub target: String,
/// Whether the server exposes the test service; set for servers started via `TestServerConfig`.
pub has_test_service: bool,
/// The spawned server process.
child: tokio::process::Child,
}
impl EphemeralServer {
async fn start(config: EphemeralServerConfig) -> anyhow::Result<EphemeralServer> {
let child = tokio::process::Command::new(config.exe_path)
.args(config.args)
.stdin(Stdio::null())
.stdout(config.output)
.stderr(config.err_output)
.spawn()?;
let target = format!("127.0.0.1:{}", config.port);
let target_url = format!("http://{target}");
let success = Ok(EphemeralServer {
target,
has_test_service: config.has_test_service,
child,
});
let connection_options = ConnectionOptions::new(Url::parse(&target_url)?)
.identity("online_checker".to_owned())
.client_name("online-checker".to_owned())
.client_version("0.1.0".to_owned())
.build();
let mut last_error = None;
for _ in 0..50 {
sleep(Duration::from_millis(100)).await;
let connect_res = Connection::connect(connection_options.clone()).await;
if let Err(err) = connect_res {
last_error = Some(err);
} else {
return success;
}
}
Err(anyhow!(
"Failed connecting to test server after 5 seconds, last error: {last_error:?}"
))
}
pub async fn shutdown(&mut self) -> anyhow::Result<()> {
if self.child.id().is_some() {
Ok(self.child.kill().await?)
} else {
Ok(())
}
}
pub fn child_process_id(&self) -> Option<u32> {
self.child.id()
}
}
/// Where an ephemeral server executable comes from.
#[derive(Debug, Clone)]
pub enum EphemeralExe {
/// Use the executable at this path as-is; it must already exist.
ExistingPath(String),
/// Download the executable on demand and cache it on disk.
CachedDownload {
/// Which version to download.
version: EphemeralExeVersion,
/// Directory the executable is cached in; the system temp directory when `None`.
dest_dir: Option<String>,
/// Maximum age of a cached executable before it is deleted and re-downloaded.
/// `None` reuses a cached copy indefinitely.
ttl: Option<Duration>,
},
}
/// Default executable source for this SDK.
///
/// Downloads the server's `default` version for `sdk-rust` 0.1.0 into the system temp
/// directory and reuses the cached copy for up to 15 days.
pub fn default_cached_download() -> EphemeralExe {
    const FIFTEEN_DAYS: Duration = Duration::from_secs(15 * 24 * 60 * 60);
    let version = EphemeralExeVersion::SDKDefault {
        sdk_name: String::from("sdk-rust"),
        sdk_version: String::from("0.1.0"),
    };
    EphemeralExe::CachedDownload {
        version,
        dest_dir: None,
        ttl: Some(FIFTEEN_DAYS),
    }
}
/// Which version of an executable to download.
#[derive(Debug, Clone)]
pub enum EphemeralExeVersion {
/// Request the `default` version from the download service, identifying the SDK through the
/// `sdk-name` and `sdk-version` query parameters.
SDKDefault {
sdk_name: String,
sdk_version: String,
},
/// Request exactly this version; the string is used verbatim in the download URL path.
Fixed(String),
}
/// JSON answer from `https://temporal.download/{artifact}/{version}` describing which archive
/// to fetch. Keys are camelCase on the wire (`archiveUrl`, `fileToExtract`).
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
struct DownloadInfo {
/// URL of the archive; must end in `.tar.gz` or `.zip` to be extracted.
archive_url: String,
/// Path of the executable inside the archive.
file_to_extract: String,
}
impl EphemeralExe {
    /// Resolves the executable this value describes, downloading it when necessary.
    ///
    /// * `ExistingPath`: the path is returned unchanged after checking that it exists.
    /// * `CachedDownload`: the cached file is looked up in `dest_dir` (default: system temp
    ///   directory) under the name `{downloaded_name_prefix}-{version}{ext}`.
    ///   - A cached copy within its TTL is reused.
    ///   - Otherwise download info is requested from
    ///     `https://temporal.download/{artifact_name}/{version}` for this platform and arch,
    ///     plus `preferred_format` if given.
    ///   - The archive is then downloaded and the executable extracted to the cache path.
    async fn get_or_download(
        &self,
        artifact_name: &str,
        downloaded_name_prefix: &str,
        preferred_format: Option<&str>,
    ) -> anyhow::Result<PathBuf> {
        let (version, dest_dir, ttl) = match self {
            EphemeralExe::ExistingPath(exe_path) => {
                let existing = PathBuf::from(exe_path);
                return if existing.exists() {
                    Ok(existing)
                } else {
                    Err(anyhow!("Exe path does not exist"))
                };
            }
            EphemeralExe::CachedDownload {
                version,
                dest_dir,
                ttl,
            } => (version, dest_dir, ttl),
        };

        let cache_dir = match dest_dir {
            Some(dir) => PathBuf::from(dir),
            None => std::env::temp_dir(),
        };
        let (platform, out_ext) = match std::env::consts::OS {
            "windows" => ("windows", ".exe"),
            "macos" => ("darwin", ""),
            _ => ("linux", ""),
        };
        let cached_name = match version {
            EphemeralExeVersion::SDKDefault {
                sdk_name,
                sdk_version,
            } => format!("{downloaded_name_prefix}-{sdk_name}-{sdk_version}{out_ext}"),
            EphemeralExeVersion::Fixed(fixed) => {
                format!("{downloaded_name_prefix}-{fixed}{out_ext}")
            }
        };
        let dest = cache_dir.join(cached_name);
        debug!(
            "Lazily downloading or using existing exe at {}",
            dest.display()
        );
        // Reuse the cached copy unless it has outlived its TTL (an expired copy is deleted).
        if dest.exists() && remove_file_past_ttl(ttl, &dest)? {
            return Ok(dest);
        }

        let arch = match std::env::consts::ARCH {
            "x86_64" => "amd64",
            "arm" | "aarch64" => "arm64",
            other => return Err(anyhow!("Unsupported arch: {other}")),
        };
        let mut query = vec![("arch", arch), ("platform", platform)];
        if let Some(format) = preferred_format {
            query.push(("format", format));
        }
        let version_name: &str = match version {
            EphemeralExeVersion::SDKDefault {
                sdk_name,
                sdk_version,
            } => {
                query.push(("sdk-name", sdk_name.as_str()));
                query.push(("sdk-version", sdk_version.as_str()));
                "default"
            }
            EphemeralExeVersion::Fixed(fixed) => fixed,
        };

        let client = reqwest::Client::new();
        let info: DownloadInfo = client
            .get(format!(
                "https://temporal.download/{artifact_name}/{version_name}"
            ))
            .query(&query)
            .send()
            .await?
            .error_for_status()?
            .json()
            .await?;
        let exe_in_archive = Path::new(&info.file_to_extract);
        // `lazy_download_exe` returns false when a concurrent download vanished; try again.
        loop {
            let ready =
                lazy_download_exe(&client, &info.archive_url, exe_in_archive, &dest, false)
                    .await?;
            if ready {
                break Ok(dest);
            }
        }
    }
}
/// Asks the OS for a currently unused TCP port on `bind_ip`.
///
/// A listener is bound to port 0 and the port the OS assigned is reported. The listener is
/// dropped on return, so the port is released, not reserved.
///
/// # Errors
/// Propagates any I/O error from binding, connecting, or accepting.
fn get_free_port(bind_ip: &str) -> io::Result<u16> {
    let listener = std::net::TcpListener::bind((bind_ip, 0))?;
    let assigned = listener.local_addr()?;
    // Outside Windows/macOS, complete one connection to the port and close the accepted side
    // first before releasing it. Presumably this parks the port in TIME_WAIT, so the OS won't
    // hand it out again as an ephemeral port, while a server that rebinds it can still do so.
    #[cfg(not(any(target_os = "windows", target_os = "macos")))]
    {
        let _client = std::net::TcpStream::connect(assigned)?;
        let (server_side, _peer) = listener.accept()?;
        drop(server_side);
    }
    Ok(assigned.port())
}
async fn lazy_download_exe(
client: &reqwest::Client,
uri: &str,
file_to_extract: &Path,
dest: &Path,
already_tried_cleaning_old: bool,
) -> anyhow::Result<bool> {
if dest.exists() {
return Ok(true);
}
let temp_dest_str = format!("{}{}", dest.to_str().unwrap(), ".downloading");
let temp_dest = Path::new(&temp_dest_str);
#[cfg(target_family = "unix")]
let file = OpenOptions::new()
.create_new(true)
.write(true)
.mode(0o755)
.open(temp_dest);
#[cfg(not(target_family = "unix"))]
let file = OpenOptions::new()
.create_new(true)
.write(true)
.open(temp_dest);
match file {
Err(err) if err.kind() == io::ErrorKind::AlreadyExists => {
if !already_tried_cleaning_old
&& temp_dest.metadata()?.modified()?.elapsed()?.as_secs() > 90
{
std::fs::remove_file(temp_dest)?;
return Box::pin(lazy_download_exe(client, uri, file_to_extract, dest, true)).await;
}
for _ in 0..20 {
sleep(Duration::from_secs(1)).await;
if !temp_dest.exists() {
return Ok(false);
}
}
Err(anyhow!(
"Temp download file at {} not complete after 20 seconds. \
Make sure another download isn't running for too long and delete the temp file.",
temp_dest.display()
))
}
Err(err) => Err(err.into()),
Ok(_) if dest.exists() => {
std::fs::remove_file(temp_dest)?;
return Ok(true);
}
Ok(mut temp_file) => {
info!("Downloading {} to {}", uri, dest.display());
download_and_extract(client, uri, file_to_extract, &mut temp_file)
.await
.inspect_err(|_| {
if let Err(err) = std::fs::remove_file(temp_dest) {
warn!(
"Failed removing temp file at {}: {:?}",
temp_dest.display(),
err
);
}
})
}
}?;
std::fs::rename(temp_dest, dest)?;
Ok(true)
}
/// Downloads the archive at `uri` and writes its entry named `file_to_extract` into `dest`.
///
/// The archive format comes from the URI suffix: `.tar.gz` (gzip-compressed tar) or `.zip`.
/// The response body is streamed into the decoder on a blocking thread rather than being
/// buffered whole.
///
/// # Errors
/// Fails if:
/// * the URI has neither suffix (checked before any request is sent);
/// * the request fails or returns an error status;
/// * the archive can't be read;
/// * the archive has no entry matching `file_to_extract`.
async fn download_and_extract(
    client: &reqwest::Client,
    uri: &str,
    file_to_extract: &Path,
    dest: &mut std::fs::File,
) -> anyhow::Result<()> {
    // Decide the format up front so an unsupported URI fails without making a request.
    let tarball = if uri.ends_with(".tar.gz") {
        true
    } else if uri.ends_with(".zip") {
        false
    } else {
        return Err(anyhow!("URI not .tar.gz or .zip"));
    };
    let resp = client.get(uri).send().await?.error_for_status()?;
    let stream = resp
        .bytes_stream()
        .map(|item| item.map_err(io::Error::other));
    // Bridge the async body into a blocking `Read` for the synchronous tar/zip readers below.
    let mut reader = SyncIoBridge::new(StreamReader::new(stream));
    let file_to_extract = file_to_extract.to_path_buf();
    let mut dest = dest.try_clone()?;
    spawn_blocking(move || {
        if tarball {
            for entry in tar::Archive::new(GzDecoder::new(reader)).entries()? {
                let mut entry = entry?;
                if entry.path()? == file_to_extract {
                    std::io::copy(&mut entry, &mut dest)?;
                    return Ok(());
                }
            }
            Err(anyhow!("Unable to find file in tarball"))
        } else {
            loop {
                if let Some(mut file) = read_zipfile_from_stream(&mut reader)? {
                    if file.enclosed_name().as_ref() == Some(&file_to_extract) {
                        std::io::copy(&mut file, &mut dest)?;
                        return Ok(());
                    }
                } else {
                    return Err(anyhow!("Unable to find file in zip"));
                }
            }
        }
    })
    .await?
}
/// Decides whether the cached file at `dest` may be reused, deleting it if it has expired.
///
/// Returns:
/// * `Ok(true)`: use the file as-is. Either no `ttl` is set, or the file was modified less
///   than `ttl` ago. A future modification time counts as just modified.
/// * `Ok(false)`: the file must be (re)created. Either it was older than `ttl` and has been
///   deleted, or its modification time could not be read (e.g. it doesn't exist).
///
/// # Errors
/// Returns an error if deleting an expired file fails for a reason other than the file already
/// being gone. Concurrent callers may race to delete the same file.
fn remove_file_past_ttl(ttl: &Option<Duration>, dest: &Path) -> io::Result<bool> {
    let Some(ttl) = ttl else {
        return Ok(true);
    };
    let Ok(mtime) = dest.metadata().and_then(|meta| meta.modified()) else {
        return Ok(false);
    };
    if mtime.elapsed().unwrap_or_default() < *ttl {
        return Ok(true);
    }
    match std::fs::remove_file(dest) {
        // Another process removing the expired file first is just as good.
        Err(err) if err.kind() != io::ErrorKind::NotFound => Err(err),
        _ => Ok(false),
    }
}
#[cfg(test)]
mod tests {
    use super::{get_free_port, remove_file_past_ttl};
    use std::{
        env::temp_dir,
        fs::File,
        net::{TcpListener, TcpStream},
        path::PathBuf,
        time::{Duration, SystemTime},
    };

    /// Ports handed out by `get_free_port` must be bindable right away, over many rounds.
    #[test]
    fn get_free_port_can_bind_immediately() {
        let host = "127.0.0.1";
        for _ in 0..500 {
            let port = get_free_port(host).unwrap();
            try_listen_and_dial_on(host, port).expect("Failed to bind to port");
        }
    }

    /// A file older than the TTL is deleted and reported as not reusable.
    #[tokio::test]
    async fn respects_file_ttl() {
        let dest_file_path = temp_file_aged(Duration::from_secs(100));
        let reusable =
            remove_file_past_ttl(&Some(Duration::from_secs(60)), &dest_file_path).unwrap();
        assert!(!reusable);
        assert!(!dest_file_path.exists());
    }

    /// A file within its TTL, or with no TTL at all, is kept and reported as reusable.
    #[test]
    fn keeps_file_within_ttl_or_without_ttl() {
        let dest_file_path = temp_file_aged(Duration::from_secs(100));
        assert!(remove_file_past_ttl(&Some(Duration::from_secs(600)), &dest_file_path).unwrap());
        assert!(dest_file_path.exists());
        assert!(remove_file_past_ttl(&None, &dest_file_path).unwrap());
        assert!(dest_file_path.exists());
        std::fs::remove_file(&dest_file_path).unwrap();
    }

    /// Creates a uniquely named temp file whose modification time lies `age` in the past.
    fn temp_file_aged(age: Duration) -> PathBuf {
        let rand_fname = format!("{}", rand::random::<u64>());
        let dest_file_path = temp_dir().join(format!("core-test-{}", &rand_fname));
        let dest_file = File::create(&dest_file_path).unwrap();
        dest_file.set_modified(SystemTime::now() - age).unwrap();
        dest_file_path
    }

    /// Binds `host:port`, connects to it, and accepts that connection.
    fn try_listen_and_dial_on(host: &str, port: u16) -> std::io::Result<()> {
        let listener = TcpListener::bind((host, port))?;
        let _stream = TcpStream::connect((host, port))?;
        let (socket, _addr) = listener.accept()?;
        drop(socket);
        Ok(())
    }
}