Skip to main content

relay_actions/cache/
agent.rs

1use anyhow::{Context, Result, bail};
2use serde::{Deserialize, Serialize};
3use std::{
4    collections::HashMap,
5    path::Path,
6    time::{SystemTime, UNIX_EPOCH},
7};
8
9use super::Cache;
10use relay_lib::prelude::AgentId;
11
/// Service and protocol labels that are prefixed to an agent id to form its SRV query name.
const SRV_PREFIX: &str = "_relay._tcp";
/// Name of the agent cache file inside the cache base directory.
const FILE: &str = "agents.ron";
14
15#[derive(Debug, Default, Serialize, Deserialize)]
16pub struct AgentsCache {
17    pub agents: HashMap<AgentId, CachedAgent>,
18}
19
20impl Cache for AgentsCache {
21    fn load(base: &Path) -> Result<Self> {
22        let content = std::fs::read_to_string(base.join(FILE)).unwrap_or_default();
23        let agents = ron::from_str(&content).unwrap_or_default();
24        Ok(Self { agents })
25    }
26
27    fn save(&self, base: &Path) -> Result<()> {
28        let content = ron::to_string(&self.agents)?.into_bytes();
29        std::fs::create_dir_all(base)?;
30        std::fs::write(base.join(FILE), content)?;
31        Ok(())
32    }
33
34    fn cleanup(&mut self) -> Result<()> {
35        self.agents.retain(|_, agent| agent.is_valid());
36        Ok(())
37    }
38}
39
40impl AgentsCache {
41    pub fn fetch(&mut self, agent_id: &AgentId) -> Result<CachedAgent> {
42        if let Some(agent) = self.agents.get(agent_id)
43            && agent.is_valid()
44        {
45            return Ok(agent.clone());
46        }
47
48        let rt = tokio::runtime::Runtime::new()?;
49        let agent = rt.block_on(async {
50            match Self::fetch_via_srv(agent_id).await {
51                Ok(agent) => Ok(agent),
52                Err(srv_err) => Self::fetch_via_well_known(agent_id).await.with_context(|| {
53                    format!(
54                        "SRV failed: {} and well-known fallback also failed",
55                        srv_err
56                    )
57                }),
58            }
59        })?;
60
61        self.agents.insert(agent_id.clone(), agent.clone());
62        Ok(agent)
63    }
64
65    pub fn url(&mut self, agent_id: &AgentId) -> Result<reqwest::Url> {
66        let agent = self.fetch(agent_id)?;
67
68        if !agent.is_valid() {
69            bail!("Cached agent information is expired");
70        }
71
72        reqwest::Url::parse(&format!("{}://{}:{}", protocol(), agent.host, agent.port))
73            .context("Failed to construct agent URL")
74    }
75
76    async fn fetch_via_srv(agent_id: &AgentId) -> Result<CachedAgent> {
77        let resolver = hickory_resolver::Resolver::builder_tokio().unwrap().build();
78
79        let response = resolver
80            .srv_lookup(format!("{}.{}", SRV_PREFIX, agent_id))
81            .await
82            .context("SRV lookup failed")?;
83
84        let mut records: Vec<_> = response.iter().collect();
85
86        if records.is_empty() {
87            bail!("No SRV records found");
88        }
89
90        records.sort_by_key(|r| r.priority());
91
92        for record in records {
93            let host = record
94                .target()
95                .to_string()
96                .trim_end_matches('.')
97                .to_string();
98
99            let port = record.port();
100
101            if host.is_empty() || port == 0 {
102                continue;
103            }
104
105            if resolver.lookup_ip(host.clone()).await.is_err() {
106                continue;
107            }
108
109            return Ok(CachedAgent {
110                host,
111                port,
112                fetched_at: now_secs(),
113                ttl_secs: 24 * 3600,
114            });
115        }
116
117        bail!("No valid SRV records")
118    }
119
120    async fn fetch_via_well_known(agent_id: &AgentId) -> Result<CachedAgent> {
121        let host = agent_id.to_string();
122        let ports = if protocol() == "https" {
123            [433, 8433, 2525]
124        } else {
125            [80, 8080, 2525]
126        };
127
128        for port in ports {
129            let url = format!("{}://{}:{}/.well-known/relay", protocol(), host, port);
130
131            match Self::try_fetch(&url).await {
132                Ok(agent) => return Ok(agent),
133                Err(_) => {
134                    continue;
135                }
136            }
137        }
138
139        bail!("well-known failed on all known ports")
140    }
141
142    async fn try_fetch(url: &str) -> Result<CachedAgent> {
143        let resp = reqwest::get(url).await.context("request failed")?;
144
145        if !resp.status().is_success() {
146            bail!("status {}", resp.status());
147        }
148
149        #[derive(Deserialize)]
150        struct WellKnown {
151            agent: AgentInfo,
152        }
153
154        #[derive(Deserialize)]
155        struct AgentInfo {
156            url: String,
157        }
158
159        let data: WellKnown = resp.json().await?;
160
161        let url = reqwest::Url::parse(&data.agent.url)?;
162
163        let host = url
164            .host_str()
165            .ok_or_else(|| anyhow::anyhow!("Missing host"))?;
166        let port = url
167            .port_or_known_default()
168            .ok_or_else(|| anyhow::anyhow!("Missing port"))?;
169
170        Ok(CachedAgent {
171            host: host.to_string(),
172            port,
173            fetched_at: now_secs(),
174            ttl_secs: 24 * 3600,
175        })
176    }
177}
178
179#[derive(Debug, Clone, Serialize, Deserialize)]
180pub struct CachedAgent {
181    pub host: String,
182    pub port: u16,
183    pub fetched_at: u64,
184    pub ttl_secs: u64,
185}
186
187impl CachedAgent {
188    pub fn is_valid(&self) -> bool {
189        let now = now_secs();
190        now < self.fetched_at + self.ttl_secs
191    }
192}
193
/// Returns the current Unix time in whole seconds.
///
/// If the system clock is set before the Unix epoch, this returns `0` instead of
/// panicking, so a misconfigured clock cannot crash cache validation.
fn now_secs() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_secs())
}
200
/// Returns the URL scheme used to reach agents.
///
/// Useful for testing local agents without SSL: in debug builds the scheme can be
/// overridden to "http" by setting the environment variable `RELAY_NO_SSL=1`. Release
/// builds never read the variable and always use "https".
#[inline]
fn protocol() -> &'static str {
    let plain_http_requested =
        cfg!(debug_assertions) && std::env::var("RELAY_NO_SSL").is_ok_and(|v| v == "1");
    if plain_http_requested { "http" } else { "https" }
}