relay_actions/cache/
agent.rs1use anyhow::{Context, Result, bail};
2use serde::{Deserialize, Serialize};
3use std::{
4 collections::HashMap,
5 path::Path,
6 time::{SystemTime, UNIX_EPOCH},
7};
8
9use super::Cache;
10use relay_lib::prelude::AgentId;
11
/// File name of the agent cache, stored inside the cache base directory.
const FILE: &str = "agents.json";
/// Service/protocol labels prepended to the agent id to form the SRV query name.
const SRV_PREFIX: &str = "_relay._tcp";
14
15#[derive(Debug, Default, Serialize, Deserialize)]
16pub struct AgentsCache {
17 pub agents: HashMap<AgentId, CachedAgent>,
18}
19
20impl Cache for AgentsCache {
21 fn load(base: &Path) -> Result<Self> {
22 let content = std::fs::read(base.join(FILE)).unwrap_or_default();
23 let agents = serde_json::from_slice(&content).unwrap_or_default();
24 Ok(Self { agents })
25 }
26
27 fn save(&self, base: &Path) -> Result<()> {
28 let content = serde_json::to_vec_pretty(&self.agents)?;
29 std::fs::create_dir_all(base)?;
30 std::fs::write(base.join(FILE), content)?;
31 Ok(())
32 }
33
34 fn cleanup(&mut self) -> Result<()> {
35 self.agents.retain(|_, agent| agent.is_valid());
36 Ok(())
37 }
38}
39
40impl AgentsCache {
41 pub fn fetch(&mut self, agent_id: &AgentId) -> Result<CachedAgent> {
42 if let Some(agent) = self.agents.get(agent_id)
43 && agent.is_valid()
44 {
45 return Ok(agent.clone());
46 }
47
48 let rt = tokio::runtime::Runtime::new()?;
49 let agent = rt.block_on(async {
50 match Self::fetch_via_srv(agent_id).await {
51 Ok(agent) => Ok(agent),
52 Err(srv_err) => Self::fetch_via_well_known(agent_id).await.with_context(|| {
53 format!(
54 "SRV failed: {} and well-known fallback also failed",
55 srv_err
56 )
57 }),
58 }
59 })?;
60
61 self.agents.insert(agent_id.clone(), agent.clone());
62 Ok(agent)
63 }
64
65 pub fn url(&mut self, agent_id: &AgentId) -> Result<reqwest::Url> {
66 let agent = self.fetch(agent_id)?;
67
68 if !agent.is_valid() {
69 bail!("Cached agent information is expired");
70 }
71
72 reqwest::Url::parse(&format!("https://{}:{}", agent.host, agent.port))
73 .context("Failed to construct agent URL")
74 }
75
76 async fn fetch_via_srv(agent_id: &AgentId) -> Result<CachedAgent> {
77 let resolver = hickory_resolver::Resolver::builder_tokio().unwrap().build();
78
79 let response = resolver
80 .srv_lookup(format!("{}.{}", SRV_PREFIX, agent_id))
81 .await
82 .context("SRV lookup failed")?;
83
84 let mut records: Vec<_> = response.iter().collect();
85
86 if records.is_empty() {
87 bail!("No SRV records found");
88 }
89
90 records.sort_by_key(|r| r.priority());
91
92 for record in records {
93 let host = record
94 .target()
95 .to_string()
96 .trim_end_matches('.')
97 .to_string();
98
99 let port = record.port();
100
101 if host.is_empty() || port == 0 {
102 continue;
103 }
104
105 if resolver.lookup_ip(host.clone()).await.is_err() {
106 continue;
107 }
108
109 return Ok(CachedAgent {
110 host,
111 port,
112 fetched_at: now_secs(),
113 ttl_secs: 24 * 3600,
114 });
115 }
116
117 bail!("No valid SRV records")
118 }
119
120 async fn fetch_via_well_known(agent_id: &AgentId) -> Result<CachedAgent> {
121 let url = format!("https://{}/.well-known/relay", agent_id);
122
123 let resp = reqwest::get(&url)
124 .await
125 .context("Failed to fetch /.well-known/relay")?;
126
127 if !resp.status().is_success() {
128 bail!("well-known returned non-success status: {}", resp.status());
129 }
130
131 #[derive(Deserialize)]
132 struct WellKnown {
133 agent: AgentInfo,
134 }
135
136 #[derive(Deserialize)]
137 struct AgentInfo {
138 url: String,
139 }
140
141 let data: WellKnown = resp
142 .json()
143 .await
144 .context("Invalid JSON in well-known response")?;
145
146 let url = reqwest::Url::parse(&data.agent.url)
147 .context("Invalid agent URL in well-known response")?;
148
149 let host = url
150 .host_str()
151 .ok_or_else(|| anyhow::anyhow!("Missing host in URL"))?
152 .to_string();
153
154 let port = url
155 .port_or_known_default()
156 .ok_or_else(|| anyhow::anyhow!("Missing port in URL"))?;
157
158 Ok(CachedAgent {
159 host,
160 port,
161 fetched_at: now_secs(),
162 ttl_secs: 24 * 3600,
163 })
164 }
165}
166
167#[derive(Debug, Clone, Serialize, Deserialize)]
168pub struct CachedAgent {
169 pub host: String,
170 pub port: u16,
171 pub fetched_at: u64,
172 pub ttl_secs: u64,
173}
174
175impl CachedAgent {
176 pub fn is_valid(&self) -> bool {
177 let now = now_secs();
178 now < self.fetched_at + self.ttl_secs
179 }
180}
181
/// Returns the current wall-clock time in whole seconds since the Unix epoch.
///
/// If the system clock is set before 1970, this yields `0` instead of panicking.
/// The value only feeds coarse cache-expiry bookkeeping.
fn now_secs() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_secs())
}