use std::collections::HashMap;
use anyhow::{Context, Result};
use colored::Colorize;
use tsafe_core::{audit::AuditEntry, events::emit_event};
use crate::helpers::*;
/// Number of extra attempts allowed after a retryable transport failure, so the
/// default policy issues at most `MAX_OTS_TRANSIENT_RETRIES + 1` requests per call.
const MAX_OTS_TRANSIENT_RETRIES: u32 = 3;
/// Initial backoff, in seconds, before the first transport retry; it doubles per retry.
const DEFAULT_OTS_RETRY_SECS: u64 = 2;
/// Outcome of a single OTS HTTP call. The `ureq` error is boxed — presumably to keep the
/// `Result` small, since `ureq::Error::Status` carries a whole `Response`.
type OtsHttpResult<T> = std::result::Result<T, Box<ureq::Error>>;
/// Retry settings for transport-level failures when talking to the OTS service.
///
/// `Debug`/`PartialEq`/`Eq` are derived so policies can be logged and compared in tests.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct OtsRetryPolicy {
    /// How many extra attempts are allowed after the first request fails with a
    /// retryable transport error.
    max_transient_retries: u32,
    /// Base backoff in seconds; the wait doubles on each subsequent retry
    /// (0 means retries happen without sleeping).
    base_retry_secs: u64,
}
impl OtsRetryPolicy {
    /// Production retry settings: up to [`MAX_OTS_TRANSIENT_RETRIES`] extra attempts,
    /// with backoff starting at [`DEFAULT_OTS_RETRY_SECS`] seconds.
    fn default_policy() -> Self {
        Self {
            max_transient_retries: MAX_OTS_TRANSIENT_RETRIES,
            base_retry_secs: DEFAULT_OTS_RETRY_SECS,
        }
    }
}

/// `Default` mirrors [`OtsRetryPolicy::default_policy`] so the policy works with
/// `..Default::default()` struct-update syntax and other `Default`-based APIs.
impl Default for OtsRetryPolicy {
    fn default() -> Self {
        Self::default_policy()
    }
}
/// Reads the OTS service origin from the environment and validates it.
///
/// `TSAFE_OTS_BASE_URL` is consulted first, with `TSAFE_OTS_URL` as a fallback. Trailing
/// slashes are stripped. Unless `TSAFE_OTS_ALLOW_INSECURE` is `1` or `true`
/// (case-insensitive), the origin must start with `https://`.
///
/// # Errors
/// Fails when neither variable is set, when the value is empty after stripping
/// slashes, or when a non-HTTPS origin is used without the insecure opt-in.
fn resolve_ots_base_url() -> Result<String> {
    let raw = match std::env::var("TSAFE_OTS_BASE_URL") {
        Ok(value) => value,
        Err(_) => std::env::var("TSAFE_OTS_URL").map_err(|_| {
            anyhow::anyhow!(
                "TSAFE_OTS_BASE_URL is not set. Export it to your one-time secret (OTS) service HTTPS origin, \
                 e.g. https://ots.example.com (no path). The CLI POSTs JSON to {{origin}}{{TSAFE_OTS_CREATE_PATH}} (default /create)."
            )
        })?,
    };
    let origin = raw.trim_end_matches('/');
    if origin.is_empty() {
        anyhow::bail!("TSAFE_OTS_BASE_URL must not be empty");
    }
    let insecure_ok = matches!(
        std::env::var("TSAFE_OTS_ALLOW_INSECURE"),
        Ok(v) if v == "1" || v.eq_ignore_ascii_case("true")
    );
    if insecure_ok || origin.starts_with("https://") {
        Ok(origin.to_owned())
    } else {
        Err(anyhow::anyhow!(
            "TSAFE_OTS_BASE_URL must start with https:// — secrets must not be sent over plain HTTP"
        ))
    }
}
/// Builds the "create secret" endpoint URL from an already-validated origin.
///
/// The path comes from `TSAFE_OTS_CREATE_PATH` (default `/create`). A leading `/` is
/// inserted when the configured path lacks one.
fn ots_create_url(base: &str) -> String {
    let configured =
        std::env::var("TSAFE_OTS_CREATE_PATH").unwrap_or_else(|_| String::from("/create"));
    let separator = if configured.starts_with('/') { "" } else { "/" };
    format!("{base}{separator}{configured}")
}
/// Extracts the plaintext secret from an OTS "consume" response body.
///
/// Two response shapes are understood:
/// * JSON object: the first string-valued field among `secret`, `plaintext` and `value`
///   (checked in that order) is returned verbatim.
/// * HTML: the text between the opening tag carrying `id="secret-content"` and the next
///   `<`, whitespace-trimmed. HTML character references (`&amp;`, `&lt;`, `&gt;`,
///   `&quot;`, `&apos;`, `&#NN;`, `&#xHH;`) are decoded, so secrets containing `&`, `<`,
///   `>` or quotes are not returned in their escaped form.
///
/// Returns `None` when neither shape yields a value.
fn parse_ots_response_body(body: &str) -> Option<String> {
    let trimmed = body.trim();
    if let Ok(json) = serde_json::from_str::<serde_json::Value>(trimmed) {
        let from_json = ["secret", "plaintext", "value"]
            .iter()
            .find_map(|field| json.get(*field).and_then(|x| x.as_str()));
        if let Some(secret) = from_json {
            return Some(secret.to_string());
        }
    }
    trimmed
        .split_once("id=\"secret-content\"")
        .and_then(|(_, after)| after.split_once('>'))
        .and_then(|(_, after)| after.split_once('<'))
        .map(|(val, _)| decode_html_character_references(val.trim()))
}

/// Decodes HTML character references in element text content.
///
/// Decoding is single-pass, so `&amp;lt;` becomes `&lt;` (not `<`). Unrecognised or
/// malformed references are copied through unchanged.
fn decode_html_character_references(text: &str) -> String {
    let mut out = String::with_capacity(text.len());
    let mut rest = text;
    while let Some(amp) = rest.find('&') {
        out.push_str(&rest[..amp]);
        let tail = &rest[amp..];
        match decode_one_html_reference(tail) {
            Some((ch, consumed)) => {
                out.push(ch);
                rest = &tail[consumed..];
            }
            None => {
                out.push('&');
                rest = &tail[1..];
            }
        }
    }
    out.push_str(rest);
    out
}

/// Decodes one reference at the start of `tail`, which must begin with `&`.
/// Returns the decoded character and the number of bytes consumed (including `;`).
fn decode_one_html_reference(tail: &str) -> Option<(char, usize)> {
    // The longest form handled is `&#x10FFFF;` (10 bytes). Bounding the `;` search keeps
    // a stray `&` from scanning the rest of the body.
    let semi = tail.bytes().take(12).position(|b| b == b';')?;
    let name = &tail[1..semi];
    let ch = match name {
        "amp" => '&',
        "lt" => '<',
        "gt" => '>',
        "quot" => '"',
        "apos" => '\'',
        _ => {
            let digits = name.strip_prefix('#')?;
            let code = match digits.strip_prefix('x').or_else(|| digits.strip_prefix('X')) {
                Some(hex) => u32::from_str_radix(hex, 16).ok()?,
                None => digits.parse::<u32>().ok()?,
            };
            char::from_u32(code)?
        }
    };
    Some((ch, semi + 1))
}
/// Reads the whole response body as a UTF-8 string.
///
/// # Errors
/// Any read failure is wrapped with the context "failed to read OTS service response".
fn read_ots_response_body(response: ureq::Response) -> Result<String> {
    let text = response.into_string();
    text.context("failed to read OTS service response")
}
/// Heuristically classifies a transport error message as transient (worth retrying).
///
/// This is a case-insensitive substring match over the error text. It covers POSIX-style
/// wording (timeouts, resets, refusals, "temporary"/"temporarily" failures) and Windows
/// wording, including WSA codes 10054 (WSAECONNRESET), 10060 (WSAETIMEDOUT),
/// 10061 (WSAECONNREFUSED) and 10065 (WSAEHOSTUNREACH).
fn is_retryable_ots_transport_error(message: &str) -> bool {
    const RETRYABLE_MARKERS: &[&str] = &[
        "timed out",
        "timeout",
        "connection reset",
        "connection refused",
        "econnreset",
        "econnrefused",
        "temporar",
        "actively refused",
        "forcibly closed",
        "os error 10054",
        "os error 10060",
        "os error 10061",
        "os error 10065",
    ];
    let lowered = message.to_ascii_lowercase();
    RETRYABLE_MARKERS
        .iter()
        .any(|&marker| lowered.contains(marker))
}
/// Adds positive jitter to a backoff delay.
///
/// Returns `base_secs` plus a value in `0..=max(1, base_secs / 4)` seconds. For
/// `base_secs >= 4` that is at most +25%; smaller bases may get up to +1s. The jitter comes
/// from the current clock's sub-second nanoseconds, which is not a cryptographic RNG
/// but is enough to spread retries. A base of 0 yields 0.
///
/// The sum saturates at `u64::MAX` instead of overflowing, which would panic in debug builds.
fn jittered_ots_delay_secs(base_secs: u64) -> u64 {
    if base_secs == 0 {
        return 0;
    }
    let jitter_cap = std::cmp::max(1, base_secs / 4);
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| d.subsec_nanos() as u64)
        .unwrap_or(0);
    base_secs.saturating_add(nanos % (jitter_cap + 1))
}
/// Runs `make_request` with the production retry policy
/// (`OtsRetryPolicy::default_policy`), retrying transient transport failures.
fn call_ots_with_transport_retry(
    make_request: impl FnMut() -> OtsHttpResult<ureq::Response>,
) -> OtsHttpResult<ureq::Response> {
    let policy = OtsRetryPolicy::default_policy();
    call_ots_with_transport_retry_policy(policy, make_request)
}
fn call_ots_with_transport_retry_policy(
policy: OtsRetryPolicy,
mut make_request: impl FnMut() -> OtsHttpResult<ureq::Response>,
) -> OtsHttpResult<ureq::Response> {
let mut transient_attempt = 0u32;
loop {
match make_request() {
Ok(resp) => return Ok(resp),
Err(err)
if matches!(err.as_ref(), ureq::Error::Transport(t) if transient_attempt < policy.max_transient_retries
&& is_retryable_ots_transport_error(t.to_string().as_str())) =>
{
let backoff = policy.base_retry_secs * 2u64.pow(transient_attempt);
let wait = std::cmp::min(jittered_ots_delay_secs(backoff), 30);
std::thread::sleep(std::time::Duration::from_secs(wait));
transient_attempt += 1;
}
Err(err) => return Err(err),
}
}
}
fn consume_ots_via_get_with(
policy: OtsRetryPolicy,
mut get_call: impl FnMut() -> OtsHttpResult<ureq::Response>,
) -> Result<String> {
match call_ots_with_transport_retry_policy(policy, &mut get_call) {
Ok(response) => read_ots_response_body(response),
Err(err) => match *err {
ureq::Error::Status(404 | 405 | 501, _) => {
anyhow::bail!("secret not found — already consumed or expired")
}
ureq::Error::Status(code, response) => {
let body = response.into_string().unwrap_or_default();
anyhow::bail!("OTS service returned HTTP {code} — {body}");
}
other => Err(anyhow::Error::new(other).context("failed to reach OTS service")),
},
}
}
fn consume_ots_response_body_with(
policy: OtsRetryPolicy,
mut post_call: impl FnMut() -> OtsHttpResult<ureq::Response>,
mut get_call: impl FnMut() -> OtsHttpResult<ureq::Response>,
) -> Result<String> {
match call_ots_with_transport_retry_policy(policy, &mut post_call) {
Ok(response) => return read_ots_response_body(response),
Err(err) => match *err {
ureq::Error::Status(404, _) => {
return consume_ots_via_get_with(policy, &mut get_call);
}
ureq::Error::Status(405 | 501, _) => {}
ureq::Error::Status(code, response) => {
let body = response.into_string().unwrap_or_default();
anyhow::bail!("OTS service returned HTTP {code} — {body}");
}
ureq::Error::Transport(t)
if is_retryable_ots_transport_error(t.to_string().as_str()) =>
{
return consume_ots_via_get_with(policy, &mut get_call);
}
other => return Err(anyhow::Error::new(other).context("failed to reach OTS service")),
},
}
consume_ots_via_get_with(policy, get_call)
}
/// Consumes the one-time secret at `consume_url` through `agent`, using the default
/// retry policy and the POST-then-GET strategy of `consume_ots_response_body_with`.
fn consume_ots_response_body(agent: &ureq::Agent, consume_url: &str) -> Result<String> {
    let post = || agent.post(consume_url).call().map_err(Box::new);
    let get = || agent.get(consume_url).call().map_err(Box::new);
    consume_ots_response_body_with(OtsRetryPolicy::default_policy(), post, get)
}
pub(crate) fn cmd_share_once(profile: &str, key: &str, ttl: &str) -> Result<()> {
let vault = open_vault(profile)?;
let value = vault.get(key)?;
let base = resolve_ots_base_url()?;
let url = ots_create_url(&base);
let body = serde_json::json!({ "secret": &*value, "ttl": ttl });
let agent = build_http_agent();
let resp = match call_ots_with_transport_retry(|| {
agent
.post(&url)
.set("Content-Type", "application/json")
.send_json(body.clone())
.map_err(Box::new)
}) {
Ok(r) => r,
Err(err) => match *err {
ureq::Error::Status(code, response) => {
let server_msg = response.into_string().unwrap_or_default();
anyhow::bail!("OTS service returned HTTP {code} — {server_msg}");
}
other => {
return Err(anyhow::Error::new(other).context(format!(
"failed to reach OTS service at {url} — check TSAFE_OTS_BASE_URL"
)));
}
},
};
let payload: serde_json::Value = resp.into_json().context(
"OTS service returned an unexpected response (expected JSON with a \"url\" field)",
)?;
let one_time_url = payload["url"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("OTS response missing 'url' field"))?;
audit(profile)
.append(&AuditEntry::success(profile, "share-once", Some(key)))
.ok();
emit_event(profile, "share-once", Some(key));
println!("{} One-time link (expires: {ttl}):", "✓".green());
println!("{one_time_url}");
println!(
"{} Share this URL — retrieve once per your OTS server policy.",
"i".blue()
);
Ok(())
}
/// Fetches a one-time secret from `url`, then stores it or prints it.
///
/// With `store = Some(key)`, the secret is saved under `key` in the profile's vault.
/// With `None`, it is printed to stdout. The URL fragment (anything after `#`) is dropped
/// before the request. Plain-HTTP URLs are refused unless `TSAFE_OTS_ALLOW_INSECURE` is
/// `1` or `true`. A successful receive is audited (best-effort: audit append failures are
/// ignored) and a `receive-once` event is emitted.
///
/// # Errors
/// Insecure URL, OTS request failures, an unparseable response, or vault errors.
pub(crate) fn cmd_receive_once(profile: &str, url: &str, store: Option<&str>) -> Result<()> {
    let consume_url = match url.split_once('#') {
        Some((without_fragment, _)) => without_fragment,
        None => url,
    };
    let insecure_ok = matches!(
        std::env::var("TSAFE_OTS_ALLOW_INSECURE"),
        Ok(v) if v == "1" || v.eq_ignore_ascii_case("true")
    );
    if !(insecure_ok || consume_url.starts_with("https://")) {
        return Err(anyhow::anyhow!("URL must use https://"));
    }

    let agent = build_http_agent();
    let body = consume_ots_response_body(&agent, consume_url)?;
    let plaintext = match parse_ots_response_body(&body) {
        Some(secret) => secret,
        None => {
            return Err(anyhow::anyhow!(
                "could not parse secret from response (expected JSON with \"secret\" or HTML id=\"secret-content\") — link may already be consumed"
            ))
        }
    };

    if let Some(key) = store {
        let mut vault = open_vault(profile)?;
        vault.set(key, &plaintext, HashMap::new())?;
        audit(profile)
            .append(&AuditEntry::success(profile, "receive-once", Some(key)))
            .ok();
        emit_event(profile, "receive-once", Some(key));
        println!("{} Stored received secret under key '{key}'.", "✓".green());
    } else {
        audit(profile)
            .append(&AuditEntry::success(profile, "receive-once", None))
            .ok();
        emit_event(profile, "receive-once", None);
        println!("{plaintext}");
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Retry policy that never sleeps, so retry tests finish instantly.
    fn no_wait_policy(max_transient_retries: u32) -> OtsRetryPolicy {
        OtsRetryPolicy {
            max_transient_retries,
            base_retry_secs: 0,
        }
    }

    /// Plain GET with the error boxed into the `OtsHttpResult` shape.
    ///
    /// The `http://127.0.0.1:1/...` URLs used below assume nothing listens on loopback
    /// port 1. The connection is then refused, which the classifier treats as transient.
    fn boxed_get(url: &str) -> OtsHttpResult<ureq::Response> {
        ureq::get(url).call().map_err(Box::new)
    }

    #[test]
    fn retryable_ots_transport_classifier_detects_timeout() {
        assert!(is_retryable_ots_transport_error("operation timed out"));
        assert!(is_retryable_ots_transport_error("Connection reset by peer"));
        assert!(!is_retryable_ots_transport_error("invalid request payload"));
    }

    #[test]
    fn retryable_ots_transport_classifier_detects_windows_wsa_wording() {
        let retryable = [
            "No connection could be made because the target machine actively refused it.",
            "An existing connection was forcibly closed by the remote host.",
            "transport error: io error: os error 10054",
            "transport error: io error: os error 10060",
            "transport error: io error: os error 10061",
            "transport error: io error: os error 10065",
        ];
        for msg in retryable.iter() {
            assert!(is_retryable_ots_transport_error(msg), "should retry: {msg}");
        }

        let permanent = [
            "HTTP 401 Unauthorized: token expired",
            "HTTP 422 Unprocessable Entity: invalid TTL",
        ];
        for msg in permanent.iter() {
            assert!(!is_retryable_ots_transport_error(msg), "should not retry: {msg}");
        }
    }

    #[test]
    fn jittered_ots_delay_stays_within_25_percent_bound() {
        let base = 20;
        let jittered = jittered_ots_delay_secs(base);
        assert!((base..=base + base / 4).contains(&jittered));
    }

    #[test]
    fn ots_transport_retry_succeeds_after_transient_failure() {
        let mut server = mockito::Server::new();
        let ok_mock = server
            .mock("GET", "/ok")
            .with_status(200)
            .with_body("ok")
            .create();
        let ok_url = format!("{}/ok", server.url());

        let mut attempts = 0usize;
        let outcome = call_ots_with_transport_retry_policy(no_wait_policy(1), || {
            attempts += 1;
            if attempts == 1 {
                boxed_get("http://127.0.0.1:1/retry-test")
            } else {
                boxed_get(&ok_url)
            }
        });

        assert!(outcome.is_ok());
        assert_eq!(attempts, 2);
        ok_mock.assert();
    }

    #[test]
    fn ots_transport_retry_exhaustion_returns_transport_error() {
        let mut attempts = 0usize;
        let outcome = call_ots_with_transport_retry_policy(no_wait_policy(2), || {
            attempts += 1;
            boxed_get("http://127.0.0.1:1/exhaustion")
        });

        let is_transport = match &outcome {
            Err(err) => matches!(**err, ureq::Error::Transport(_)),
            Ok(_) => false,
        };
        assert!(is_transport, "expected a transport error after exhausting retries");
        assert_eq!(attempts, 3);
    }

    #[test]
    fn consume_ots_falls_back_to_get_after_post_transport_exhaustion() {
        let mut server = mockito::Server::new();
        let get_url = format!("{}/s/fallback", server.url());
        let get_mock = server
            .mock("GET", "/s/fallback")
            .with_status(200)
            .with_header("Content-Type", "application/json")
            .with_body(r#"{"secret":"via-get"}"#)
            .create();

        let body = consume_ots_response_body_with(
            no_wait_policy(0),
            || boxed_get("http://127.0.0.1:1/post-fail"),
            || boxed_get(&get_url),
        )
        .unwrap();

        assert_eq!(body, r#"{"secret":"via-get"}"#);
        get_mock.assert();
    }

    #[test]
    fn consume_ots_reports_get_fallback_http_error_after_post_transport_exhaustion() {
        let mut server = mockito::Server::new();
        let get_url = format!("{}/s/fallback-error", server.url());
        let get_mock = server
            .mock("GET", "/s/fallback-error")
            .with_status(500)
            .with_body("backend failure")
            .create();

        let err = consume_ots_response_body_with(
            no_wait_policy(0),
            || boxed_get("http://127.0.0.1:1/post-fail"),
            || boxed_get(&get_url),
        )
        .unwrap_err();

        let msg = err.to_string();
        assert!(msg.contains("OTS service returned HTTP 500"));
        assert!(msg.contains("backend failure"));
        get_mock.assert();
    }
}