Skip to main content

relay_actions/cache/
agent.rs

1use anyhow::{Context, Result, bail};
2use serde::{Deserialize, Serialize};
3use std::{
4    collections::HashMap,
5    path::Path,
6    time::{SystemTime, UNIX_EPOCH},
7};
8
9use super::Cache;
10use relay_lib::prelude::AgentId;
11
/// DNS prefix for SRV discovery: agents are looked up at `_relay._tcp.<agent_id>`.
const SRV_PREFIX: &str = "_relay._tcp";
/// Name of the JSON file, inside the cache base directory, that holds the agent map.
const FILE: &str = "agents.json";
14
15#[derive(Debug, Default, Serialize, Deserialize)]
16pub struct AgentsCache {
17    pub agents: HashMap<AgentId, CachedAgent>,
18}
19
20impl Cache for AgentsCache {
21    fn load(base: &Path) -> Result<Self> {
22        let content = std::fs::read(base.join(FILE)).unwrap_or_default();
23        let agents = serde_json::from_slice(&content).unwrap_or_default();
24        Ok(Self { agents })
25    }
26
27    fn save(&self, base: &Path) -> Result<()> {
28        let content = serde_json::to_vec_pretty(&self.agents)?;
29        std::fs::create_dir_all(base)?;
30        std::fs::write(base.join(FILE), content)?;
31        Ok(())
32    }
33
34    fn cleanup(&mut self) -> Result<()> {
35        self.agents.retain(|_, agent| agent.is_valid());
36        Ok(())
37    }
38}
39
40impl AgentsCache {
41    pub fn fetch(&mut self, agent_id: &AgentId) -> Result<CachedAgent> {
42        if let Some(agent) = self.agents.get(agent_id)
43            && agent.is_valid()
44        {
45            return Ok(agent.clone());
46        }
47
48        let rt = tokio::runtime::Runtime::new()?;
49        let agent = rt.block_on(async {
50            match Self::fetch_via_srv(agent_id).await {
51                Ok(agent) => Ok(agent),
52                Err(srv_err) => Self::fetch_via_well_known(agent_id).await.with_context(|| {
53                    format!(
54                        "SRV failed: {} and well-known fallback also failed",
55                        srv_err
56                    )
57                }),
58            }
59        })?;
60
61        self.agents.insert(agent_id.clone(), agent.clone());
62        Ok(agent)
63    }
64
65    pub fn url(&mut self, agent_id: &AgentId) -> Result<reqwest::Url> {
66        let agent = self.fetch(agent_id)?;
67
68        if !agent.is_valid() {
69            bail!("Cached agent information is expired");
70        }
71
72        reqwest::Url::parse(&format!("https://{}:{}", agent.host, agent.port))
73            .context("Failed to construct agent URL")
74    }
75
76    async fn fetch_via_srv(agent_id: &AgentId) -> Result<CachedAgent> {
77        let resolver = hickory_resolver::Resolver::builder_tokio().unwrap().build();
78
79        let response = resolver
80            .srv_lookup(format!("{}.{}", SRV_PREFIX, agent_id))
81            .await
82            .context("SRV lookup failed")?;
83
84        let mut records: Vec<_> = response.iter().collect();
85
86        if records.is_empty() {
87            bail!("No SRV records found");
88        }
89
90        records.sort_by_key(|r| r.priority());
91
92        for record in records {
93            let host = record
94                .target()
95                .to_string()
96                .trim_end_matches('.')
97                .to_string();
98
99            let port = record.port();
100
101            if host.is_empty() || port == 0 {
102                continue;
103            }
104
105            if resolver.lookup_ip(host.clone()).await.is_err() {
106                continue;
107            }
108
109            return Ok(CachedAgent {
110                host,
111                port,
112                fetched_at: now_secs(),
113                ttl_secs: 24 * 3600,
114            });
115        }
116
117        bail!("No valid SRV records")
118    }
119
120    async fn fetch_via_well_known(agent_id: &AgentId) -> Result<CachedAgent> {
121        let url = format!("https://{}/.well-known/relay", agent_id);
122
123        let resp = reqwest::get(&url)
124            .await
125            .context("Failed to fetch /.well-known/relay")?;
126
127        if !resp.status().is_success() {
128            bail!("well-known returned non-success status: {}", resp.status());
129        }
130
131        #[derive(Deserialize)]
132        struct WellKnown {
133            agent: AgentInfo,
134        }
135
136        #[derive(Deserialize)]
137        struct AgentInfo {
138            url: String,
139        }
140
141        let data: WellKnown = resp
142            .json()
143            .await
144            .context("Invalid JSON in well-known response")?;
145
146        let url = reqwest::Url::parse(&data.agent.url)
147            .context("Invalid agent URL in well-known response")?;
148
149        let host = url
150            .host_str()
151            .ok_or_else(|| anyhow::anyhow!("Missing host in URL"))?
152            .to_string();
153
154        let port = url
155            .port_or_known_default()
156            .ok_or_else(|| anyhow::anyhow!("Missing port in URL"))?;
157
158        Ok(CachedAgent {
159            host,
160            port,
161            fetched_at: now_secs(),
162            ttl_secs: 24 * 3600,
163        })
164    }
165}
166
167#[derive(Debug, Clone, Serialize, Deserialize)]
168pub struct CachedAgent {
169    pub host: String,
170    pub port: u16,
171    pub fetched_at: u64,
172    pub ttl_secs: u64,
173}
174
175impl CachedAgent {
176    pub fn is_valid(&self) -> bool {
177        let now = now_secs();
178        now < self.fetched_at + self.ttl_secs
179    }
180}
181
/// Current wall-clock time, in whole seconds since the Unix epoch.
///
/// A system clock set before 1970 yields `0` instead of panicking. That is an environment
/// misconfiguration, not a program bug, and cache timestamps do not justify a crash.
fn now_secs() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs()
}