use std::{cmp::max, time::Duration};
use anyhow::bail;
use colored::Colorize;
use is_terminal::IsTerminal;
use queries::domains::DomainsDomains;
use serde_json::json;
use crate::{
consts::TICK_STRING,
controllers::project::{ensure_project_and_environment_exist, get_project},
errors::RailwayError,
};
use super::*;
// Command-line arguments for the domain command.
//
// NOTE: these are deliberately plain `//` comments. clap's derive macro uses `///` doc
// comments as help text, so adding them here would change the generated `--help` output.
#[derive(Parser)]
pub struct Args {
// Port for the domain. Only forwarded to `create_custom_domain`; it is not used when no
// `domain` is given and a service domain is generated instead.
#[clap(short, long)]
port: Option<u16>,
// Name of the service to act on; passed to `get_service` on both code paths.
#[clap(short, long)]
service: Option<String>,
// Custom domain to attach. When absent, `command` generates a service domain instead.
domain: Option<String>,
// Emit JSON instead of human-readable text; this also disables the progress spinner.
#[clap(long)]
json: bool,
}
pub async fn command(args: Args) -> Result<()> {
if let Some(domain) = args.domain {
create_custom_domain(domain, args.port, args.service, args.json).await?;
} else {
create_service_domain(args.service, args.json).await?;
}
Ok(())
}
/// Generates a service domain for the resolved service, unless it already has domains.
///
/// The service is resolved with [`get_service`]. If the service already has any service or
/// custom domain, those are printed via [`print_existing_domains`] and nothing is created.
/// Otherwise the `ServiceDomainCreate` mutation is sent and the resulting `https://` URL is
/// printed, as JSON when `json` is set and as coloured text otherwise.
///
/// # Errors
///
/// Propagates configuration, project lookup, service resolution, GraphQL and JSON
/// serialization failures.
async fn create_service_domain(service_name: Option<String>, json: bool) -> Result<()> {
    let configs = Configs::new()?;
    let client = GQLClient::new_authorized(&configs)?;
    let linked_project = configs.get_linked_project().await?;

    ensure_project_and_environment_exist(&client, &configs, &linked_project).await?;

    let project = get_project(&client, &configs, linked_project.project.clone()).await?;
    let service = get_service(&linked_project, &project, service_name)?;

    let lookup_vars = queries::domains::Variables {
        project_id: linked_project.project.clone(),
        environment_id: linked_project.environment.clone(),
        service_id: service.id.clone(),
    };
    let existing = post_graphql::<queries::Domains, _>(&client, configs.get_backboard(), lookup_vars)
        .await?
        .domains;

    // Any pre-existing domain short-circuits creation.
    if !(existing.service_domains.is_empty() && existing.custom_domains.is_empty()) {
        return print_existing_domains(&existing, json);
    }

    // The spinner is shown only for interactive, non-JSON runs; a spinner that fails to
    // build is silently skipped rather than failing the command.
    let spinner = if std::io::stdout().is_terminal() && !json {
        creating_domain_spiner(None).ok()
    } else {
        None
    };

    let create_vars = mutations::service_domain_create::Variables {
        service_id: service.id.clone(),
        environment_id: linked_project.environment.clone(),
    };
    let domain = post_graphql::<mutations::ServiceDomainCreate, _>(
        &client,
        configs.get_backboard(),
        create_vars,
    )
    .await?
    .service_domain_create
    .domain;

    if let Some(bar) = spinner {
        bar.finish_and_clear();
    }

    let url = format!("https://{domain}");
    if json {
        let payload = json!({
            "domain": url
        });
        println!("{}", serde_json::to_string_pretty(&payload)?);
    } else {
        println!("Service Domain created:\n🚀 {}", url.magenta().bold());
    }

    Ok(())
}
/// Prints the domains that already exist on a service.
///
/// * JSON mode: a `{"domains": [...]}` object listing service domains followed by custom
///   domains, each as an `https://` URL.
/// * Text mode: a header line, then either the single domain on a `🚀` line or, for any
///   other count, one `- ` line per domain with custom domains listed before service domains.
///
/// # Errors
///
/// Fails only if JSON serialization fails.
fn print_existing_domains(domains: &DomainsDomains, json: bool) -> Result<()> {
    // Lazy URL views over both lists; each branch below consumes them in the order it needs.
    let service_urls = domains
        .service_domains
        .iter()
        .map(|d| format!("https://{}", d.domain));
    let custom_urls = domains
        .custom_domains
        .iter()
        .map(|d| format!("https://{}", d.domain));

    if json {
        let all_domains: Vec<String> = service_urls.chain(custom_urls).collect();
        let rendered = serde_json::to_string_pretty(&json!({ "domains": all_domains }))?;
        println!("{}", rendered);
        return Ok(());
    }

    println!("Domains already exist on the service:");

    let total = domains.service_domains.len() + domains.custom_domains.len();
    if total == 1 {
        // Exactly one entry exists across both lists, so the chained iterator must yield it.
        let only_url = service_urls
            .chain(custom_urls)
            .next()
            .unwrap_or_else(|| unreachable!());
        println!("🚀 {}", only_url.magenta().bold());
        return Ok(());
    }

    for url in custom_urls.chain(service_urls) {
        println!("- {}", url.magenta().bold());
    }

    Ok(())
}
/// Resolves which service of `project` a domain command should act on.
///
/// Resolution order:
/// 1. The project has no services: fails with [`RailwayError::NoServices`].
/// 2. The project has exactly one service: that service is returned.
/// 3. `service_name` is given: the service whose name matches exactly is returned,
///    otherwise the call fails with [`RailwayError::ServiceNotFound`].
/// 4. Otherwise the service whose id equals `linked_project.service` is returned, if any.
///
/// NOTE(review): step 2 runs before `service_name` is consulted, so a name that does not
/// match the only service is silently ignored. This is kept as-is for compatibility.
///
/// # Errors
///
/// Besides the cases above, fails with [`RailwayError::NoServices`] when several services
/// exist, no name was given, and the linked service is unset or not part of `project`.
pub fn get_service<'a>(
    linked_project: &'a LinkedProject,
    project: &'a queries::project::ProjectProject,
    service_name: Option<String>,
) -> anyhow::Result<&'a queries::project::ProjectProjectServicesEdgesNode> {
    // Work on the edge list directly instead of collecting a throwaway Vec of references.
    let edges = &project.services.edges;

    if edges.is_empty() {
        bail!(RailwayError::NoServices);
    }
    if edges.len() == 1 {
        return Ok(&edges[0].node);
    }

    if let Some(service_name) = service_name {
        if let Some(edge) = edges.iter().find(|edge| edge.node.name == service_name) {
            return Ok(&edge.node);
        }
        bail!(RailwayError::ServiceNotFound(service_name));
    }

    // Borrow the linked service id (no clone) and look it up in a single pass.
    if let Some(linked_id) = &linked_project.service {
        if let Some(edge) = edges.iter().find(|edge| edge.node.id == *linked_id) {
            return Ok(&edge.node);
        }
    }

    bail!(RailwayError::NoServices);
}
/// Builds a steadily ticking spinner used while a domain is being created.
///
/// The spinner shows `message`, or `"Creating domain..."` when no message is given, and
/// ticks every 100 ms.
///
/// # Errors
///
/// Fails if the spinner template is rejected by `indicatif`.
pub fn creating_domain_spiner(message: Option<String>) -> anyhow::Result<indicatif::ProgressBar> {
    let style = indicatif::ProgressStyle::default_spinner()
        .tick_chars(TICK_STRING)
        .template("{spinner:.green} {msg}")?;

    let text = match message {
        Some(custom) => custom,
        None => "Creating domain...".to_string(),
    };

    let bar = indicatif::ProgressBar::new_spinner()
        .with_style(style)
        .with_message(text);
    bar.enable_steady_tick(Duration::from_millis(100));

    Ok(bar)
}
/// Attaches the custom `domain` to the resolved service and prints the DNS setup steps.
///
/// Steps, in order:
/// 1. Resolve the linked project and the target service (see [`get_service`]).
/// 2. Ask the backend whether `domain` is available; bail out if it is not.
/// 3. Send the `CustomDomainCreate` mutation, passing `port` as the target port when given.
/// 4. In JSON mode print the raw mutation response; otherwise print the created domain and
///    a table of the DNS records that must be added (see [`print_dns`]).
///
/// # Errors
///
/// Propagates configuration, lookup, GraphQL and serialization failures, and fails when the
/// domain is unavailable or when the response contains no DNS records (text mode only).
async fn create_custom_domain(
    domain: String,
    port: Option<u16>,
    service_name: Option<String>,
    json: bool,
) -> Result<()> {
    let configs = Configs::new()?;
    let client = GQLClient::new_authorized(&configs)?;
    let linked_project = configs.get_linked_project().await?;

    ensure_project_and_environment_exist(&client, &configs, &linked_project).await?;

    let project = get_project(&client, &configs, linked_project.project.clone()).await?;
    let service = get_service(&linked_project, &project, service_name)?;

    // The spinner is shown only for interactive, non-JSON runs; a spinner that fails to
    // build is silently skipped rather than failing the command.
    let spinner = if std::io::stdout().is_terminal() && !json {
        let port_note = port.map(|p| format!(" on port {p}")).unwrap_or_default();
        creating_domain_spiner(Some(format!(
            "Creating custom domain for service {}{}...",
            service.name, port_note
        )))
        .ok()
    } else {
        None
    };

    let availability = post_graphql::<queries::CustomDomainAvailable, _>(
        &client,
        configs.get_backboard(),
        queries::custom_domain_available::Variables {
            domain: domain.clone(),
        },
    )
    .await?
    .custom_domain_available;

    if !availability.available {
        bail!("Domain is not available:\n\t{}", domain);
    }

    let create_vars = mutations::custom_domain_create::Variables {
        input: mutations::custom_domain_create::CustomDomainCreateInput {
            domain: domain.clone(),
            environment_id: linked_project.environment.clone(),
            project_id: linked_project.project.clone(),
            service_id: service.id.clone(),
            // u16 -> i64 is a lossless widening.
            target_port: port.map(i64::from),
        },
    };
    let response = post_graphql::<mutations::CustomDomainCreate, _>(
        &client,
        configs.get_backboard(),
        create_vars,
    )
    .await?;

    if let Some(bar) = spinner {
        bar.finish_and_clear();
    }

    if json {
        println!("{}", serde_json::to_string_pretty(&response)?);
        return Ok(());
    }

    let created = response.custom_domain_create;
    println!("Domain created: {}", created.domain);

    let status = created.status;
    // The zone of the first record names the DNS zone the user has to edit.
    let zone = match status.dns_records.first() {
        Some(record) => record.zone.clone(),
        None => {
            bail!("No DNS records found. Please check the Railway dashboard for more information.")
        }
    };

    println!(
        "To finish setting up your custom domain, add the following DNS records to {}:\n",
        &zone
    );

    print_dns(
        status.dns_records,
        &status.verification_dns_host,
        &status.verification_token,
        status.verified,
        &zone,
    );

    println!(
        "\nNote: if the Name is \"@\", the DNS record should be created for the root of the domain."
    );
    println!("DNS changes can take up to 72 hours to propagate worldwide.");

    Ok(())
}
/// Prints a left-aligned, tab-indented table of the DNS records the user must create.
///
/// Each entry of `domains` becomes one row (`Type`, `Name`, `Value`); an empty host label
/// is shown as `@`. When the domain is not yet `verified` and both the verification host
/// and token are present, an extra `TXT` row is appended whose name is the verification
/// host with a trailing `.{zone}` removed and whose value is `railway-verify=<token>`.
///
/// Every column is at least 5 characters wide, grows to fit its longest cell, and is
/// followed by 3 characters of spacing.
fn print_dns(
    domains: Vec<
        mutations::custom_domain_create::CustomDomainCreateCustomDomainCreateStatusDnsRecords,
    >,
    verification_dns_host: &Option<String>,
    verification_token: &Option<String>,
    verified: bool,
    zone: &str,
) {
    /// Prints a single table row using the given `(type, host, value)` column widths.
    fn print_row(
        kind: impl std::fmt::Display,
        name: impl std::fmt::Display,
        value: impl std::fmt::Display,
        (type_width, host_width, value_width): (usize, usize, usize),
    ) {
        println!(
            "\t{:<width_type$}{:<width_host$}{:<width_value$}",
            kind,
            name,
            value,
            width_type = type_width,
            width_host = host_width,
            width_value = value_width
        );
    }

    // The ownership-verification row exists only for unverified domains with host and token.
    let txt_verification = match (verified, verification_dns_host, verification_token) {
        (false, Some(host), Some(token)) => {
            let zone_suffix = format!(".{}", zone);
            let host_label = host.strip_suffix(zone_suffix.as_str()).unwrap_or(host);
            Some((host_label.to_string(), format!("railway-verify={}", token)))
        }
        _ => None,
    };

    // Start from the minimum column width and widen to the longest cell of each column.
    let mut type_width = 5;
    let mut host_width = 5;
    let mut value_width = 5;
    for record in &domains {
        type_width = max(type_width, record.record_type.to_string().len());
        host_width = max(host_width, record.hostlabel.len());
        value_width = max(value_width, record.required_value.len());
    }
    if let Some((host, value)) = &txt_verification {
        type_width = max(type_width, 3);
        host_width = max(host_width, host.len());
        value_width = max(value_width, value.len());
    }

    // Three characters of spacing after every column.
    let widths = (type_width + 3, host_width + 3, value_width + 3);

    print_row("Type", "Name", "Value", widths);

    for record in &domains {
        let name: &str = if record.hostlabel.is_empty() {
            "@"
        } else {
            &record.hostlabel
        };
        print_row(
            record.record_type.to_string(),
            name,
            &record.required_value,
            widths,
        );
    }

    if let Some((host, value)) = txt_verification {
        print_row("TXT", host, value, widths);
    }
}