1use std::collections::HashMap;
2use std::error::Error;
3use std::fmt;
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::{
7 Arc,
8 atomic::{AtomicU64, Ordering},
9};
10use std::time::{Duration, SystemTime, UNIX_EPOCH};
11
12use surrealdb::{Connection, Surreal};
13use tokio::sync::{OnceCell, RwLock};
14
15use crate::control::DocxControlPlane;
16use crate::store::SurrealDocStore;
17
18pub type BuildHandleFuture<C> =
20 Pin<Box<dyn Future<Output = Result<Arc<SolutionHandle<C>>, RegistryError>> + Send + 'static>>;
21pub type BuildHandleFn<C> =
23 Arc<dyn Fn(String) -> BuildHandleFuture<C> + Send + Sync + 'static>;
24
25#[derive(Clone)]
27pub struct SolutionRegistryConfig<C: Connection> {
28 pub ttl: Option<Duration>,
30 pub sweep_interval: Duration,
32 pub max_entries: Option<usize>,
34 pub build_handle: BuildHandleFn<C>,
36}
37
38impl<C: Connection> SolutionRegistryConfig<C> {
39 #[must_use]
40 pub fn new(build_handle: BuildHandleFn<C>) -> Self {
41 Self {
42 ttl: None,
43 sweep_interval: Duration::from_secs(60),
44 max_entries: None,
45 build_handle,
46 }
47 }
48
49 #[must_use]
50 pub const fn with_ttl(mut self, ttl: Duration) -> Self {
51 self.ttl = Some(ttl);
52 self
53 }
54
55 #[must_use]
56 pub const fn with_sweep_interval(mut self, sweep_interval: Duration) -> Self {
57 self.sweep_interval = sweep_interval;
58 self
59 }
60
61 #[must_use]
62 pub const fn with_max_entries(mut self, max_entries: usize) -> Self {
63 self.max_entries = Some(max_entries);
64 self
65 }
66}
67
/// Errors produced while resolving a solution handle through the registry.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RegistryError {
    /// The named solution is not known. Not constructed by the registry itself; presumably
    /// returned by `build_handle` callbacks — TODO confirm with callers.
    UnknownSolution(String),
    /// A new solution could not be registered because `max` entries are already held.
    CapacityReached { max: usize },
    /// Building the handle failed; the payload is a human-readable reason.
    BuildFailed(String),
}

impl fmt::Display for RegistryError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::UnknownSolution(solution) => write!(f, "unknown solution: {solution}"),
            Self::CapacityReached { max } => {
                write!(f, "solution registry capacity reached (max {max})")
            }
            Self::BuildFailed(message) => write!(f, "failed to build solution handle: {message}"),
        }
    }
}

// No variant wraps another error, so the default `source()` (returning `None`) is correct.
impl Error for RegistryError {}
92
93pub struct SolutionHandle<C: Connection> {
95 db: Arc<Surreal<C>>,
96 store: SurrealDocStore<C>,
97 control: DocxControlPlane<C>,
98}
99
100impl<C: Connection> Clone for SolutionHandle<C> {
101 fn clone(&self) -> Self {
102 Self {
103 db: self.db.clone(),
104 store: self.store.clone(),
105 control: self.control.clone(),
106 }
107 }
108}
109
110impl<C: Connection> SolutionHandle<C> {
111 #[must_use]
112 pub fn new(db: Arc<Surreal<C>>) -> Self {
113 let store = SurrealDocStore::from_arc(db.clone());
114 let control = DocxControlPlane::with_store(store.clone());
115 Self { db, store, control }
116 }
117
118 #[must_use]
119 pub fn from_surreal(db: Surreal<C>) -> Self {
120 Self::new(Arc::new(db))
121 }
122
123 #[must_use]
124 pub fn db(&self) -> Arc<Surreal<C>> {
125 self.db.clone()
126 }
127
128 #[must_use]
129 pub fn store(&self) -> SurrealDocStore<C> {
130 self.store.clone()
131 }
132
133 #[must_use]
134 pub fn control(&self) -> DocxControlPlane<C> {
135 self.control.clone()
136 }
137}
138
139#[derive(Clone)]
141pub struct SolutionRegistry<C: Connection> {
142 inner: Arc<SolutionRegistryInner<C>>,
143}
144
145struct SolutionRegistryInner<C: Connection> {
147 entries: RwLock<HashMap<String, Arc<SolutionEntry<C>>>>,
148 config: SolutionRegistryConfig<C>,
149}
150
151struct SolutionEntry<C: Connection> {
153 handle: OnceCell<Arc<SolutionHandle<C>>>,
154 last_used_ms: AtomicU64,
155}
156
157impl<C: Connection> SolutionEntry<C> {
158 fn new() -> Self {
159 Self {
160 handle: OnceCell::new(),
161 last_used_ms: AtomicU64::new(now_ms()),
162 }
163 }
164
165 fn touch(&self) {
166 self.last_used_ms.store(now_ms(), Ordering::Relaxed);
167 }
168
169 fn idle_for(&self, now_ms: u64) -> Duration {
170 let last = self.last_used_ms.load(Ordering::Relaxed);
171 Duration::from_millis(now_ms.saturating_sub(last))
172 }
173}
174
175impl<C: Connection> SolutionRegistry<C> {
176 #[must_use]
177 pub fn new(config: SolutionRegistryConfig<C>) -> Self {
178 Self {
179 inner: Arc::new(SolutionRegistryInner {
180 entries: RwLock::new(HashMap::new()),
181 config,
182 }),
183 }
184 }
185
186 pub async fn get_or_init(
191 &self,
192 solution: &str,
193 ) -> Result<Arc<SolutionHandle<C>>, RegistryError> {
194 let entry = {
195 let map = self.inner.entries.read().await;
196 map.get(solution).cloned()
197 };
198
199 let entry = if let Some(entry) = entry {
200 entry
201 } else {
202 let mut map = self.inner.entries.write().await;
203 if let Some(entry) = map.get(solution).cloned() {
204 entry
205 } else {
206 if let Some(max_entries) = self.inner.config.max_entries
207 && map.len() >= max_entries
208 {
209 return Err(RegistryError::CapacityReached { max: max_entries });
210 }
211 let entry = Arc::new(SolutionEntry::new());
212 map.insert(solution.to_string(), entry.clone());
213 entry
214 }
215 };
216
217 entry.touch();
218
219 let build_handle = self.inner.config.build_handle.clone();
220 let handle = entry
221 .handle
222 .get_or_try_init(|| (build_handle)(solution.to_string()))
223 .await?;
224 Ok(handle.clone())
225 }
226
227 pub async fn list_solutions(&self) -> Vec<String> {
229 let map = self.inner.entries.read().await;
230 map.keys().cloned().collect()
231 }
232
233 pub async fn evict_idle(&self) -> usize {
235 let Some(ttl) = self.inner.config.ttl else {
236 return 0;
237 };
238 let now = now_ms();
239 let mut map = self.inner.entries.write().await;
240 let before = map.len();
241 map.retain(|_, entry| entry.idle_for(now) <= ttl);
242 before.saturating_sub(map.len())
243 }
244
245 #[must_use]
246 pub fn spawn_sweeper(self) -> Option<tokio::task::JoinHandle<()>>
248 where
249 C: Send + Sync + 'static,
250 {
251 let _ttl = self.inner.config.ttl?;
252 let interval = self.inner.config.sweep_interval;
253 let registry = self;
254 Some(tokio::spawn(async move {
255 let mut ticker = tokio::time::interval(interval);
256 loop {
257 ticker.tick().await;
258 let _ = registry.evict_idle().await;
259 }
260 }))
261 }
262}
263
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns 0 if the system clock reads earlier than the epoch, and saturates at `u64::MAX`.
fn now_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX),
        Err(_) => 0,
    }
}
271
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use surrealdb::engine::local::{Db, Mem};

    /// Builds a registry backed by fresh in-memory SurrealDB instances.
    ///
    /// `calls` is incremented once per `build_handle` invocation. When `ttl` is given, it is
    /// installed together with a 1 ms sweep interval.
    fn build_test_registry(
        calls: Arc<AtomicUsize>,
        ttl: Option<Duration>,
    ) -> SolutionRegistry<Db> {
        let build: BuildHandleFn<Db> = Arc::new(move |solution: String| {
            let counter = Arc::clone(&calls);
            Box::pin(async move {
                counter.fetch_add(1, Ordering::SeqCst);
                let db = Surreal::new::<Mem>(())
                    .await
                    .map_err(|e| RegistryError::BuildFailed(e.to_string()))?;
                db.use_ns("docx")
                    .use_db(&solution)
                    .await
                    .map_err(|e| RegistryError::BuildFailed(e.to_string()))?;
                Ok(Arc::new(SolutionHandle::from_surreal(db)))
            })
        });

        let base = SolutionRegistryConfig::new(build);
        let config = match ttl {
            Some(ttl) => base
                .with_ttl(ttl)
                .with_sweep_interval(Duration::from_millis(1)),
            None => base,
        };
        SolutionRegistry::new(config)
    }

    #[tokio::test]
    async fn registry_single_flight() {
        let calls = Arc::new(AtomicUsize::new(0));
        let registry = build_test_registry(Arc::clone(&calls), None);

        let (first, second) = (registry.clone(), registry.clone());
        let (left, right) = tokio::join!(first.get_or_init("alpha"), second.get_or_init("alpha"));

        assert!(left.is_ok());
        assert!(right.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn registry_evicts_idle_entries() {
        let ttl = Duration::from_millis(1);
        let registry = build_test_registry(Arc::new(AtomicUsize::new(0)), Some(ttl));

        drop(registry.get_or_init("alpha").await.unwrap());
        tokio::time::sleep(Duration::from_millis(5)).await;

        assert_eq!(registry.evict_idle().await, 1);
    }
}