1use std::collections::{HashMap, HashSet};
2use std::error::Error;
3use std::fmt;
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::{
7 Arc,
8 atomic::{AtomicU64, Ordering},
9};
10use std::time::{Duration, SystemTime, UNIX_EPOCH};
11
12use surrealdb::{Connection, Surreal};
13use tokio::sync::RwLock;
14
15use crate::control::DocxControlPlane;
16use crate::store::SurrealDocStore;
17
/// Reserved solution name.
///
/// NOTE(review): not referenced in this module; the value suggests it names a discovery
/// pseudo-solution that callers must not treat as a tenant — confirm with callers.
pub const RESERVED_SOLUTION: &str = "__discovery__";

/// Boxed, `Send + 'static` future produced by a [`BuildHandleFn`]; resolves to a shared
/// [`SolutionHandle`] or a [`RegistryError`].
pub type BuildHandleFuture<C> =
    Pin<Box<dyn Future<Output = Result<Arc<SolutionHandle<C>>, RegistryError>> + Send + 'static>>;
/// Shared factory the registry invokes with a solution name to build that solution's handle.
pub type BuildHandleFn<C> = Arc<dyn Fn(String) -> BuildHandleFuture<C> + Send + Sync + 'static>;

/// Boxed, `Send + 'static` future produced by a [`DiscoverSolutionsFn`]; resolves to a list of
/// solution names.
pub type DiscoverSolutionsFuture = Pin<Box<dyn Future<Output = Vec<String>> + Send + 'static>>;
/// Optional callback that `SolutionRegistry::list_solutions` uses to enumerate solution names.
pub type DiscoverSolutionsFn = Arc<dyn Fn() -> DiscoverSolutionsFuture + Send + Sync + 'static>;
33
34#[derive(Clone)]
36pub struct SolutionRegistryConfig<C: Connection> {
37 pub ttl: Option<Duration>,
39 pub sweep_interval: Duration,
41 pub max_entries: Option<usize>,
43 pub build_handle: BuildHandleFn<C>,
45 pub health_check_after: Duration,
47 pub discover_solutions: Option<DiscoverSolutionsFn>,
50}
51
52impl<C: Connection> SolutionRegistryConfig<C> {
53 #[must_use]
54 pub fn new(build_handle: BuildHandleFn<C>) -> Self {
55 Self {
56 ttl: None,
57 sweep_interval: Duration::from_secs(60),
58 max_entries: None,
59 build_handle,
60 health_check_after: Duration::from_secs(60),
61 discover_solutions: None,
62 }
63 }
64
65 #[must_use]
66 pub fn with_discover_solutions(mut self, f: DiscoverSolutionsFn) -> Self {
67 self.discover_solutions = Some(f);
68 self
69 }
70
71 #[must_use]
72 pub const fn with_ttl(mut self, ttl: Duration) -> Self {
73 self.ttl = Some(ttl);
74 self
75 }
76
77 #[must_use]
78 pub const fn with_sweep_interval(mut self, sweep_interval: Duration) -> Self {
79 self.sweep_interval = sweep_interval;
80 self
81 }
82
83 #[must_use]
84 pub const fn with_max_entries(mut self, max_entries: usize) -> Self {
85 self.max_entries = Some(max_entries);
86 self
87 }
88
89 #[must_use]
90 pub const fn with_health_check_after(mut self, health_check_after: Duration) -> Self {
91 self.health_check_after = health_check_after;
92 self
93 }
94}
95
/// Errors produced by the solution registry and by handle builders.
#[derive(Debug)]
pub enum RegistryError {
    /// A solution name that is not recognised; carries the offending name.
    UnknownSolution(String),
    /// A new entry could not be created because the registry already holds `max` entries.
    CapacityReached { max: usize },
    /// Building a solution handle failed; carries a human-readable reason.
    BuildFailed(String),
}

impl fmt::Display for RegistryError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::BuildFailed(message) => {
                write!(f, "failed to build solution handle: {message}")
            }
            Self::UnknownSolution(solution) => {
                write!(f, "unknown solution: {solution}")
            }
            Self::CapacityReached { max } => write!(
                f,
                "solution registry capacity reached (max {max})"
            ),
        }
    }
}

// No wrapped source error: every variant carries only plain data.
impl Error for RegistryError {}
120
121pub struct SolutionHandle<C: Connection> {
123 db: Arc<Surreal<C>>,
124 store: SurrealDocStore<C>,
125 control: DocxControlPlane<C>,
126}
127
128impl<C: Connection> Clone for SolutionHandle<C> {
129 fn clone(&self) -> Self {
130 Self {
131 db: self.db.clone(),
132 store: self.store.clone(),
133 control: self.control.clone(),
134 }
135 }
136}
137
138impl<C: Connection> SolutionHandle<C> {
139 #[must_use]
140 pub fn new(db: Arc<Surreal<C>>) -> Self {
141 let store = SurrealDocStore::from_arc(db.clone());
142 let control = DocxControlPlane::with_store(store.clone());
143 Self { db, store, control }
144 }
145
146 #[must_use]
147 pub fn from_surreal(db: Surreal<C>) -> Self {
148 Self::new(Arc::new(db))
149 }
150
151 #[must_use]
152 pub fn db(&self) -> Arc<Surreal<C>> {
153 self.db.clone()
154 }
155
156 #[must_use]
157 pub fn store(&self) -> SurrealDocStore<C> {
158 self.store.clone()
159 }
160
161 #[must_use]
162 pub fn control(&self) -> DocxControlPlane<C> {
163 self.control.clone()
164 }
165
166 pub async fn list_databases(&self) -> Vec<String> {
168 self.store.list_databases().await.unwrap_or_default()
169 }
170
171 pub async fn ping(&self) -> bool {
173 self.db
174 .query("SELECT 1")
175 .await
176 .is_ok_and(|r| r.check().is_ok())
177 }
178}
179
180#[derive(Clone)]
182pub struct SolutionRegistry<C: Connection> {
183 inner: Arc<SolutionRegistryInner<C>>,
184}
185
/// State shared by all clones of a [`SolutionRegistry`].
struct SolutionRegistryInner<C: Connection> {
    /// Cache entries keyed by solution name. Each entry is `Arc`-shared so it can be used after
    /// the map lock has been released.
    entries: RwLock<HashMap<String, Arc<SolutionEntry<C>>>>,
    /// Configuration supplied to `SolutionRegistry::new`.
    config: SolutionRegistryConfig<C>,
}
191
/// One cached solution: its lazily built handle plus a last-used timestamp.
struct SolutionEntry<C: Connection> {
    /// `None` until a build succeeds. `get_or_init` holds the write lock while building, which is
    /// what makes concurrent first requests wait for a single build.
    handle: RwLock<Option<Arc<SolutionHandle<C>>>>,
    /// Wall-clock milliseconds since the Unix epoch (from `now_ms`) of creation or of the last
    /// successful access. Drives TTL eviction and the health-check decision.
    last_used_ms: AtomicU64,
}
197
198impl<C: Connection> SolutionEntry<C> {
199 fn new() -> Self {
200 Self {
201 handle: RwLock::new(None),
202 last_used_ms: AtomicU64::new(now_ms()),
203 }
204 }
205
206 fn touch(&self) {
207 self.last_used_ms.store(now_ms(), Ordering::Relaxed);
208 }
209
210 fn idle_for(&self, now_ms: u64) -> Duration {
211 let last = self.last_used_ms.load(Ordering::Relaxed);
212 Duration::from_millis(now_ms.saturating_sub(last))
213 }
214}
215
216impl<C: Connection> SolutionRegistry<C> {
217 #[must_use]
218 pub fn new(config: SolutionRegistryConfig<C>) -> Self {
219 Self {
220 inner: Arc::new(SolutionRegistryInner {
221 entries: RwLock::new(HashMap::new()),
222 config,
223 }),
224 }
225 }
226
227 pub async fn get_or_init(
235 &self,
236 solution: &str,
237 ) -> Result<Arc<SolutionHandle<C>>, RegistryError> {
238 let entry = {
239 let map = self.inner.entries.read().await;
240 map.get(solution).cloned()
241 };
242
243 let entry = if let Some(entry) = entry {
244 entry
245 } else {
246 let mut map = self.inner.entries.write().await;
247 if let Some(entry) = map.get(solution).cloned() {
248 entry
249 } else {
250 if let Some(max_entries) = self.inner.config.max_entries
251 && map.len() >= max_entries
252 {
253 return Err(RegistryError::CapacityReached { max: max_entries });
254 }
255 let entry = Arc::new(SolutionEntry::new());
256 map.insert(solution.to_string(), entry.clone());
257 entry
258 }
259 };
260
261 {
263 let guard = entry.handle.read().await;
264 if let Some(handle) = guard.as_ref() {
265 let idle = entry.idle_for(now_ms());
267 if idle <= self.inner.config.health_check_after || handle.ping().await {
268 entry.touch();
269 return Ok(handle.clone());
270 }
271 tracing::debug!("health check failed for solution '{solution}', rebuilding");
273 }
274 }
275
276 let mut guard = entry.handle.write().await;
278 if let Some(handle) = guard.as_ref()
280 && handle.ping().await
281 {
282 entry.touch();
283 return Ok(handle.clone());
284 }
285 let build_handle = self.inner.config.build_handle.clone();
286 let handle = (build_handle)(solution.to_string()).await?;
287 *guard = Some(handle.clone());
288 drop(guard);
289 entry.touch();
290 Ok(handle)
291 }
292
293 pub async fn list_solutions(&self) -> Vec<String> {
300 let db_names: Vec<String> = if let Some(discover) = &self.inner.config.discover_solutions {
302 (discover)().await
303 } else {
304 let entries: Vec<Arc<SolutionEntry<C>>> = {
307 let map = self.inner.entries.read().await;
308 map.values().cloned().collect()
309 };
310 let mut live_handle: Option<Arc<SolutionHandle<C>>> = None;
311 for entry in &entries {
312 let guard = entry.handle.read().await;
313 if let Some(h) = guard.as_ref() {
314 live_handle = Some(h.clone());
315 break;
316 }
317 }
318 match live_handle {
319 Some(h) => h.list_databases().await,
320 None => vec![],
321 }
322 };
323
324 let mut names: HashSet<String> = db_names.into_iter().collect();
326 {
327 let map = self.inner.entries.read().await;
328 names.extend(map.keys().cloned());
329 }
330 let mut result: Vec<String> = names.into_iter().collect();
331 result.sort();
332 result
333 }
334
335 pub async fn remove_solution(&self, solution: &str) -> bool {
337 let mut map = self.inner.entries.write().await;
338 map.remove(solution).is_some()
339 }
340
341 pub async fn evict_idle(&self) -> usize {
343 let Some(ttl) = self.inner.config.ttl else {
344 return 0;
345 };
346 let now = now_ms();
347 let mut map = self.inner.entries.write().await;
348 let before = map.len();
349 map.retain(|key, entry| {
350 let keep = entry.idle_for(now) <= ttl;
351 if !keep {
352 tracing::debug!("evicted idle solution: {key}");
353 }
354 keep
355 });
356 before.saturating_sub(map.len())
357 }
358
359 #[must_use]
360 pub fn spawn_sweeper(self) -> Option<tokio::task::JoinHandle<()>>
362 where
363 C: Send + Sync + 'static,
364 {
365 let _ttl = self.inner.config.ttl?;
366 let interval = self.inner.config.sweep_interval;
367 let registry = self;
368 Some(tokio::spawn(async move {
369 let mut ticker = tokio::time::interval(interval);
370 loop {
371 ticker.tick().await;
372 let _ = registry.evict_idle().await;
373 }
374 }))
375 }
376}
377
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// A clock set before the epoch yields 0. A value too large for `u64` saturates to `u64::MAX`.
fn now_ms() -> u64 {
    let since_epoch = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        Err(_) => Duration::ZERO,
    };
    since_epoch.as_millis().try_into().unwrap_or(u64::MAX)
}
385
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use surrealdb::engine::local::{Db, Mem};

    /// Builds a registry backed by fresh in-memory SurrealDB instances.
    ///
    /// `calls` is incremented once per invocation of the build function, so tests can count
    /// builds. When `ttl` is given, it is configured together with a 1 ms sweep interval.
    fn build_test_registry(calls: Arc<AtomicUsize>, ttl: Option<Duration>) -> SolutionRegistry<Db> {
        let build: BuildHandleFn<Db> = Arc::new(move |solution: String| {
            let calls = calls.clone();
            Box::pin(async move {
                calls.fetch_add(1, Ordering::SeqCst);
                let db = Surreal::new::<Mem>(())
                    .await
                    .map_err(|err| RegistryError::BuildFailed(err.to_string()))?;
                // Namespace "docx", database named after the solution.
                db.use_ns("docx")
                    .use_db(&solution)
                    .await
                    .map_err(|err| RegistryError::BuildFailed(err.to_string()))?;
                Ok(Arc::new(SolutionHandle::from_surreal(db)))
            })
        });

        let mut config = SolutionRegistryConfig::new(build);
        if let Some(ttl) = ttl {
            config = config
                .with_ttl(ttl)
                .with_sweep_interval(Duration::from_millis(1));
        }
        SolutionRegistry::new(config)
    }

    // Two concurrent first requests for the same solution must trigger exactly one build.
    #[tokio::test]
    async fn registry_single_flight() {
        let calls = Arc::new(AtomicUsize::new(0));
        let registry = build_test_registry(calls.clone(), None);

        let r1 = registry.clone();
        let r2 = registry.clone();
        let (left, right) = tokio::join!(r1.get_or_init("alpha"), r2.get_or_init("alpha"));
        assert!(left.is_ok());
        assert!(right.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    // With a 1 ms TTL, an entry left unused for 5 ms is removed by a manual `evict_idle` call
    // (no sweeper task is spawned here).
    #[tokio::test]
    async fn registry_evicts_idle_entries() {
        let calls = Arc::new(AtomicUsize::new(0));
        let registry = build_test_registry(calls, Some(Duration::from_millis(1)));

        let _ = registry.get_or_init("alpha").await.unwrap();
        tokio::time::sleep(Duration::from_millis(5)).await;
        let evicted = registry.evict_idle().await;
        assert_eq!(evicted, 1);
    }

    // Removing a solution forces the next `get_or_init` to build a new handle.
    #[tokio::test]
    async fn registry_remove_solution_drops_cache_entry() {
        let calls = Arc::new(AtomicUsize::new(0));
        let registry = build_test_registry(calls.clone(), None);

        let _ = registry.get_or_init("alpha").await.unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        assert!(registry.remove_solution("alpha").await);
        let _ = registry.get_or_init("alpha").await.unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }
}