Skip to main content

reifydb_engine/
remote.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4#[cfg(not(reifydb_single_threaded))]
5use std::{
6	collections::HashMap,
7	sync::{Mutex, mpsc},
8};
9
10#[cfg(not(reifydb_single_threaded))]
11use reifydb_client::{GrpcClient, WireFormat};
12#[cfg(not(reifydb_single_threaded))]
13use reifydb_runtime::SharedRuntime;
14#[cfg(not(reifydb_single_threaded))]
15use reifydb_type::error::Diagnostic;
16use reifydb_type::error::Error;
17#[cfg(not(reifydb_single_threaded))]
18use reifydb_type::{params::Params, value::frame::frame::Frame};
19
20#[cfg(not(reifydb_single_threaded))]
21type CacheKey = (String, Option<String>);
22
23#[cfg(not(reifydb_single_threaded))]
24pub struct RemoteRegistry {
25	runtime: SharedRuntime,
26	clients: Mutex<HashMap<CacheKey, GrpcClient>>,
27}
28
29#[cfg(not(reifydb_single_threaded))]
30impl RemoteRegistry {
31	pub fn new(runtime: SharedRuntime) -> Self {
32		Self {
33			runtime,
34			clients: Mutex::new(HashMap::new()),
35		}
36	}
37
38	pub fn forward_query(
39		&self,
40		address: &str,
41		rql: &str,
42		params: Params,
43		token: Option<&str>,
44	) -> Result<Vec<Frame>, Error> {
45		let params_opt = match &params {
46			Params::None => None,
47			_ => Some(params),
48		};
49
50		let client = self.get_or_connect(address, token)?;
51		match self.run_query(&client, rql, params_opt.clone()) {
52			Ok(frames) => Ok(frames),
53			Err(e) if is_transport_error(&e) => {
54				self.evict(address, token);
55				let client = self.get_or_connect(address, token)?;
56				self.run_query(&client, rql, params_opt)
57			}
58			Err(e) => Err(e),
59		}
60	}
61
62	fn run_query(&self, client: &GrpcClient, rql: &str, params: Option<Params>) -> Result<Vec<Frame>, Error> {
63		let client = client.clone();
64		let rql = rql.to_string();
65		let (tx, rx) = mpsc::sync_channel(1);
66
67		self.runtime.spawn(async move {
68			let result = client.query(&rql, params).await;
69			let _ = tx.send(result);
70		});
71
72		rx.recv().map_err(|_| {
73			Error(Box::new(Diagnostic {
74				code: "REMOTE_002".to_string(),
75				message: "remote query channel closed".to_string(),
76				..Default::default()
77			}))
78		})?
79	}
80
81	fn get_or_connect(&self, address: &str, token: Option<&str>) -> Result<GrpcClient, Error> {
82		let key = cache_key(address, token);
83		if let Some(c) = self.clients.lock().unwrap().get(&key) {
84			return Ok(c.clone());
85		}
86		let client = self.connect(address, token)?;
87		self.clients.lock().unwrap().entry(key).or_insert_with(|| client.clone());
88		Ok(client)
89	}
90
91	fn evict(&self, address: &str, token: Option<&str>) {
92		self.clients.lock().unwrap().remove(&cache_key(address, token));
93	}
94
95	#[cfg(test)]
96	fn cache_len(&self) -> usize {
97		self.clients.lock().unwrap().len()
98	}
99
100	fn connect(&self, address: &str, token: Option<&str>) -> Result<GrpcClient, Error> {
101		let address_owned = address.to_string();
102		let (tx, rx) = mpsc::sync_channel(1);
103
104		self.runtime.spawn(async move {
105			let result = GrpcClient::connect(&address_owned, WireFormat::Proto).await;
106			let _ = tx.send(result);
107		});
108
109		let mut client = rx.recv().map_err(|_| {
110			Error(Box::new(Diagnostic {
111				code: "REMOTE_002".to_string(),
112				message: "remote connect channel closed".to_string(),
113				..Default::default()
114			}))
115		})??;
116		if let Some(token) = token {
117			client.authenticate(token);
118		}
119		Ok(client)
120	}
121}
122
123#[cfg(not(reifydb_single_threaded))]
124fn cache_key(address: &str, token: Option<&str>) -> CacheKey {
125	(address.to_string(), token.map(str::to_string))
126}
127
128#[cfg(not(reifydb_single_threaded))]
129fn is_transport_error(err: &Error) -> bool {
130	err.0.code.starts_with("GRPC_")
131}
132
133pub fn is_remote_query(err: &Error) -> bool {
134	err.0.code == "REMOTE_001"
135}
136
137pub fn extract_remote_address(err: &Error) -> Option<String> {
138	err.0.notes.iter().find_map(|n| n.strip_prefix("Remote gRPC address: ")).map(|s| s.to_string())
139}
140
141pub fn extract_remote_token(err: &Error) -> Option<String> {
142	err.0.notes.iter().find_map(|n| n.strip_prefix("Remote token: ")).map(|s| s.to_string())
143}
144
#[cfg(test)]
mod tests {
	// `RemoteRegistry`, `cache_key` and `is_transport_error` only exist outside
	// single-threaded builds, so the runtime import and the tests using them carry
	// the same gate; otherwise `cfg(reifydb_single_threaded)` test builds fail.
	#[cfg(not(reifydb_single_threaded))]
	use reifydb_runtime::{SharedRuntime, SharedRuntimeConfig, pool::PoolConfig};
	use reifydb_type::{error::Diagnostic, fragment::Fragment};

	use super::*;

	/// Builds a `REMOTE_001` diagnostic for namespace `remote_ns` at `address`,
	/// without a token note.
	fn make_remote_error(address: &str) -> Error {
		Error(Box::new(Diagnostic {
			code: "REMOTE_001".to_string(),
			message: format!(
				"Remote namespace 'remote_ns': source 'users' is on remote instance at {}",
				address
			),
			notes: vec![
				"Namespace 'remote_ns' is configured as a remote namespace".to_string(),
				format!("Remote gRPC address: {}", address),
			],
			fragment: Fragment::None,
			..Default::default()
		}))
	}

	/// Creates a registry backed by a default-configured runtime.
	#[cfg(not(reifydb_single_threaded))]
	fn new_registry() -> RemoteRegistry {
		let runtime = SharedRuntime::from_config(SharedRuntimeConfig::default(), PoolConfig::default());
		RemoteRegistry::new(runtime)
	}

	#[test]
	fn test_is_remote_query_true() {
		let err = make_remote_error("http://localhost:50051");
		assert!(is_remote_query(&err));
	}

	#[test]
	fn test_is_remote_query_false() {
		let err = Error(Box::new(Diagnostic {
			code: "CATALOG_001".to_string(),
			message: "Table not found".to_string(),
			fragment: Fragment::None,
			..Default::default()
		}));
		assert!(!is_remote_query(&err));
	}

	#[test]
	fn test_is_remote_query_requires_exact_code() {
		for code in ["REMOTE_002", "REMOTE_0010", "remote_001"] {
			let err = Error(Box::new(Diagnostic {
				code: code.to_string(),
				message: "Some error".to_string(),
				..Default::default()
			}));
			assert!(!is_remote_query(&err), "{code} must not be treated as a remote query");
		}
	}

	#[test]
	fn test_extract_remote_address() {
		let err = make_remote_error("http://localhost:50051");
		assert_eq!(extract_remote_address(&err), Some("http://localhost:50051".to_string()));
	}

	#[test]
	fn test_extract_remote_address_missing() {
		let err = Error(Box::new(Diagnostic {
			code: "REMOTE_001".to_string(),
			message: "Some error".to_string(),
			notes: vec![],
			fragment: Fragment::None,
			..Default::default()
		}));
		assert_eq!(extract_remote_address(&err), None);
	}

	#[test]
	fn test_extract_remote_address_takes_first_match() {
		let err = Error(Box::new(Diagnostic {
			code: "REMOTE_001".to_string(),
			message: "Remote namespace".to_string(),
			notes: vec![
				"Remote gRPC address: http://first:1".to_string(),
				"Remote gRPC address: http://second:2".to_string(),
			],
			fragment: Fragment::None,
			..Default::default()
		}));
		assert_eq!(extract_remote_address(&err), Some("http://first:1".to_string()));
	}

	#[test]
	fn test_extract_remote_token() {
		let err = Error(Box::new(Diagnostic {
			code: "REMOTE_001".to_string(),
			message: "Remote namespace".to_string(),
			notes: vec![
				"Namespace 'test' is configured as a remote namespace".to_string(),
				"Remote gRPC address: http://localhost:50051".to_string(),
				"Remote token: my-secret".to_string(),
			],
			fragment: Fragment::None,
			..Default::default()
		}));
		assert_eq!(extract_remote_token(&err), Some("my-secret".to_string()));
	}

	#[test]
	fn test_extract_remote_token_missing() {
		let err = make_remote_error("http://localhost:50051");
		assert_eq!(extract_remote_token(&err), None);
	}

	#[cfg(not(reifydb_single_threaded))]
	#[test]
	fn test_is_transport_error() {
		let grpc_err = Error(Box::new(Diagnostic {
			code: "GRPC_Unavailable".to_string(),
			message: "channel closed".to_string(),
			..Default::default()
		}));
		assert!(is_transport_error(&grpc_err));

		let app_err = Error(Box::new(Diagnostic {
			code: "CATALOG_001".to_string(),
			message: "Table not found".to_string(),
			..Default::default()
		}));
		assert!(!is_transport_error(&app_err));
	}

	#[cfg(not(reifydb_single_threaded))]
	#[test]
	fn test_cache_key_distinguishes_tokens() {
		assert_ne!(cache_key("addr", Some("a")), cache_key("addr", Some("b")));
		assert_ne!(cache_key("addr", None), cache_key("addr", Some("a")));
		assert_eq!(cache_key("addr", Some("a")), cache_key("addr", Some("a")));
		assert_eq!(cache_key("addr", Some("t")), ("addr".to_string(), Some("t".to_string())));
	}

	#[cfg(not(reifydb_single_threaded))]
	#[test]
	fn test_connect_failure_does_not_pollute_cache() {
		let registry = new_registry();

		// Nothing is expected to listen on loopback port 1 (assigned to TCPMUX, not
		// reserved), so the connect attempt should be refused quickly.
		let err = registry.forward_query("http://127.0.0.1:1", "FROM x", Params::None, None).unwrap_err();
		assert!(
			err.0.code.starts_with("GRPC_") || err.0.code == "REMOTE_002",
			"unexpected error code {}",
			err.0.code
		);
		assert_eq!(registry.cache_len(), 0);
	}

	#[cfg(not(reifydb_single_threaded))]
	#[test]
	fn test_evict_missing_key_is_noop() {
		let registry = new_registry();
		registry.evict("http://127.0.0.1:1", None);
		registry.evict("http://127.0.0.1:1", Some("tok"));
		assert_eq!(registry.cache_len(), 0);
	}
}