relay_actions/cache/
agent.rs1use anyhow::{Context, Result, bail};
2use serde::{Deserialize, Serialize};
3use std::{
4 collections::HashMap,
5 path::Path,
6 time::{SystemTime, UNIX_EPOCH},
7};
8
9use super::Cache;
10use relay_lib::prelude::AgentId;
11
/// Name of the RON file, inside the cache directory, that stores the agent map.
const FILE: &str = "agents.ron";
/// Service and protocol labels prepended to an agent id for the SRV discovery query.
const SRV_PREFIX: &str = "_relay._tcp";
14
15#[derive(Debug, Default, Serialize, Deserialize)]
16pub struct AgentsCache {
17 pub agents: HashMap<AgentId, CachedAgent>,
18}
19
20impl Cache for AgentsCache {
21 fn load(base: &Path) -> Result<Self> {
22 let content = std::fs::read_to_string(base.join(FILE)).unwrap_or_default();
23 let agents = ron::from_str(&content).unwrap_or_default();
24 Ok(Self { agents })
25 }
26
27 fn save(&self, base: &Path) -> Result<()> {
28 let content = ron::to_string(&self.agents)?.into_bytes();
29 std::fs::create_dir_all(base)?;
30 std::fs::write(base.join(FILE), content)?;
31 Ok(())
32 }
33
34 fn cleanup(&mut self) -> Result<()> {
35 self.agents.retain(|_, agent| agent.is_valid());
36 Ok(())
37 }
38}
39
40impl AgentsCache {
41 pub fn fetch(&mut self, agent_id: &AgentId) -> Result<CachedAgent> {
42 if let Some(agent) = self.agents.get(agent_id)
43 && agent.is_valid()
44 {
45 return Ok(agent.clone());
46 }
47
48 let rt = tokio::runtime::Runtime::new()?;
49 let agent = rt.block_on(async {
50 match Self::fetch_via_srv(agent_id).await {
51 Ok(agent) => Ok(agent),
52 Err(srv_err) => Self::fetch_via_well_known(agent_id).await.with_context(|| {
53 format!(
54 "SRV failed: {} and well-known fallback also failed",
55 srv_err
56 )
57 }),
58 }
59 })?;
60
61 self.agents.insert(agent_id.clone(), agent.clone());
62 Ok(agent)
63 }
64
65 pub fn url(&mut self, agent_id: &AgentId) -> Result<reqwest::Url> {
66 let agent = self.fetch(agent_id)?;
67
68 if !agent.is_valid() {
69 bail!("Cached agent information is expired");
70 }
71
72 reqwest::Url::parse(&format!("{}://{}:{}", protocol(), agent.host, agent.port))
73 .context("Failed to construct agent URL")
74 }
75
76 async fn fetch_via_srv(agent_id: &AgentId) -> Result<CachedAgent> {
77 let resolver = hickory_resolver::Resolver::builder_tokio().unwrap().build();
78
79 let response = resolver
80 .srv_lookup(format!("{}.{}", SRV_PREFIX, agent_id))
81 .await
82 .context("SRV lookup failed")?;
83
84 let mut records: Vec<_> = response.iter().collect();
85
86 if records.is_empty() {
87 bail!("No SRV records found");
88 }
89
90 records.sort_by_key(|r| r.priority());
91
92 for record in records {
93 let host = record
94 .target()
95 .to_string()
96 .trim_end_matches('.')
97 .to_string();
98
99 let port = record.port();
100
101 if host.is_empty() || port == 0 {
102 continue;
103 }
104
105 if resolver.lookup_ip(host.clone()).await.is_err() {
106 continue;
107 }
108
109 return Ok(CachedAgent {
110 host,
111 port,
112 fetched_at: now_secs(),
113 ttl_secs: 24 * 3600,
114 });
115 }
116
117 bail!("No valid SRV records")
118 }
119
120 async fn fetch_via_well_known(agent_id: &AgentId) -> Result<CachedAgent> {
121 let host = agent_id.to_string();
122 let ports = if protocol() == "https" {
123 [433, 8433, 2525]
124 } else {
125 [80, 8080, 2525]
126 };
127
128 for port in ports {
129 let url = format!("{}://{}:{}/.well-known/relay", protocol(), host, port);
130
131 match Self::try_fetch(&url).await {
132 Ok(agent) => return Ok(agent),
133 Err(_) => {
134 continue;
135 }
136 }
137 }
138
139 bail!("well-known failed on all known ports")
140 }
141
142 async fn try_fetch(url: &str) -> Result<CachedAgent> {
143 let resp = reqwest::get(url).await.context("request failed")?;
144
145 if !resp.status().is_success() {
146 bail!("status {}", resp.status());
147 }
148
149 #[derive(Deserialize)]
150 struct WellKnown {
151 agent: AgentInfo,
152 }
153
154 #[derive(Deserialize)]
155 struct AgentInfo {
156 url: String,
157 }
158
159 let data: WellKnown = resp.json().await?;
160
161 let url = reqwest::Url::parse(&data.agent.url)?;
162
163 let host = url
164 .host_str()
165 .ok_or_else(|| anyhow::anyhow!("Missing host"))?;
166 let port = url
167 .port_or_known_default()
168 .ok_or_else(|| anyhow::anyhow!("Missing port"))?;
169
170 Ok(CachedAgent {
171 host: host.to_string(),
172 port,
173 fetched_at: now_secs(),
174 ttl_secs: 24 * 3600,
175 })
176 }
177}
178
179#[derive(Debug, Clone, Serialize, Deserialize)]
180pub struct CachedAgent {
181 pub host: String,
182 pub port: u16,
183 pub fetched_at: u64,
184 pub ttl_secs: u64,
185}
186
187impl CachedAgent {
188 pub fn is_valid(&self) -> bool {
189 let now = now_secs();
190 now < self.fetched_at + self.ttl_secs
191 }
192}
193
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// # Panics
///
/// Panics if the system clock reports a time before the Unix epoch.
fn now_secs() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("system clock is set before the Unix epoch")
        .as_secs()
}
200
/// Returns the URL scheme used to reach agents.
///
/// Release builds always use `"https"`. Debug builds use `"http"` only when the
/// `RELAY_NO_SSL` environment variable is set to exactly `"1"`.
#[inline]
fn protocol() -> &'static str {
    // `&&` short-circuits, so release builds never consult the environment.
    let plaintext_requested =
        cfg!(debug_assertions) && matches!(std::env::var("RELAY_NO_SSL").as_deref(), Ok("1"));
    if plaintext_requested { "http" } else { "https" }
}