// docx_core/services.rs

1use std::collections::{HashMap, HashSet};
2use std::error::Error;
3use std::fmt;
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::{
7    Arc,
8    atomic::{AtomicU64, Ordering},
9};
10use std::time::{Duration, SystemTime, UNIX_EPOCH};
11
12use surrealdb::{Connection, Surreal};
13use tokio::sync::RwLock;
14
15use crate::control::DocxControlPlane;
16use crate::store::SurrealDocStore;
17
18/// Solution name reserved for internal namespace-discovery connections.
19/// Ingestion into this name must be rejected to prevent polluting the DB.
20pub const RESERVED_SOLUTION: &str = "__discovery__";
21
22/// Future returned by the solution handle builder.
23pub type BuildHandleFuture<C> =
24    Pin<Box<dyn Future<Output = Result<Arc<SolutionHandle<C>>, RegistryError>> + Send + 'static>>;
25/// Builder function that creates a solution handle for a solution name.
26pub type BuildHandleFn<C> = Arc<dyn Fn(String) -> BuildHandleFuture<C> + Send + Sync + 'static>;
27
28/// Future returned by the solution discovery function.
29pub type DiscoverSolutionsFuture = Pin<Box<dyn Future<Output = Vec<String>> + Send + 'static>>;
30/// Optional function that discovers existing solution names from the database
31/// without requiring a specific database to be selected.
32pub type DiscoverSolutionsFn = Arc<dyn Fn() -> DiscoverSolutionsFuture + Send + Sync + 'static>;
33
34/// Configuration for the solution registry cache and builder.
35#[derive(Clone)]
36pub struct SolutionRegistryConfig<C: Connection> {
37    /// Optional TTL for cached solutions.
38    pub ttl: Option<Duration>,
39    /// Sweep interval for the background eviction task.
40    pub sweep_interval: Duration,
41    /// Optional maximum number of cached solutions.
42    pub max_entries: Option<usize>,
43    /// Builder used to create solution handles.
44    pub build_handle: BuildHandleFn<C>,
45    /// Idle threshold before running a health check on next access.
46    pub health_check_after: Duration,
47    /// Optional function to discover existing solution names from the database
48    /// at the namespace level (no specific database required).
49    pub discover_solutions: Option<DiscoverSolutionsFn>,
50}
51
52impl<C: Connection> SolutionRegistryConfig<C> {
53    #[must_use]
54    pub fn new(build_handle: BuildHandleFn<C>) -> Self {
55        Self {
56            ttl: None,
57            sweep_interval: Duration::from_secs(60),
58            max_entries: None,
59            build_handle,
60            health_check_after: Duration::from_secs(60),
61            discover_solutions: None,
62        }
63    }
64
65    #[must_use]
66    pub fn with_discover_solutions(mut self, f: DiscoverSolutionsFn) -> Self {
67        self.discover_solutions = Some(f);
68        self
69    }
70
71    #[must_use]
72    pub const fn with_ttl(mut self, ttl: Duration) -> Self {
73        self.ttl = Some(ttl);
74        self
75    }
76
77    #[must_use]
78    pub const fn with_sweep_interval(mut self, sweep_interval: Duration) -> Self {
79        self.sweep_interval = sweep_interval;
80        self
81    }
82
83    #[must_use]
84    pub const fn with_max_entries(mut self, max_entries: usize) -> Self {
85        self.max_entries = Some(max_entries);
86        self
87    }
88
89    #[must_use]
90    pub const fn with_health_check_after(mut self, health_check_after: Duration) -> Self {
91        self.health_check_after = health_check_after;
92        self
93    }
94}
95
/// Failures reported by the solution registry.
#[derive(Debug)]
pub enum RegistryError {
    /// The requested solution name is not known.
    UnknownSolution(String),
    /// Adding another solution would exceed the configured capacity.
    CapacityReached { max: usize },
    /// The configured builder could not create a solution handle.
    BuildFailed(String),
}

impl fmt::Display for RegistryError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::UnknownSolution(name) => {
                f.write_str("unknown solution: ")?;
                f.write_str(name)
            }
            Self::CapacityReached { max } => {
                write!(f, "solution registry capacity reached (max {max})")
            }
            Self::BuildFailed(reason) => {
                f.write_str("failed to build solution handle: ")?;
                f.write_str(reason)
            }
        }
    }
}

impl Error for RegistryError {}
120
121/// Shared service handle for a single solution's database.
122pub struct SolutionHandle<C: Connection> {
123    db: Arc<Surreal<C>>,
124    store: SurrealDocStore<C>,
125    control: DocxControlPlane<C>,
126}
127
128impl<C: Connection> Clone for SolutionHandle<C> {
129    fn clone(&self) -> Self {
130        Self {
131            db: self.db.clone(),
132            store: self.store.clone(),
133            control: self.control.clone(),
134        }
135    }
136}
137
138impl<C: Connection> SolutionHandle<C> {
139    #[must_use]
140    pub fn new(db: Arc<Surreal<C>>) -> Self {
141        let store = SurrealDocStore::from_arc(db.clone());
142        let control = DocxControlPlane::with_store(store.clone());
143        Self { db, store, control }
144    }
145
146    #[must_use]
147    pub fn from_surreal(db: Surreal<C>) -> Self {
148        Self::new(Arc::new(db))
149    }
150
151    #[must_use]
152    pub fn db(&self) -> Arc<Surreal<C>> {
153        self.db.clone()
154    }
155
156    #[must_use]
157    pub fn store(&self) -> SurrealDocStore<C> {
158        self.store.clone()
159    }
160
161    #[must_use]
162    pub fn control(&self) -> DocxControlPlane<C> {
163        self.control.clone()
164    }
165
166    /// Lists all database names in the current namespace.
167    pub async fn list_databases(&self) -> Vec<String> {
168        self.store.list_databases().await.unwrap_or_default()
169    }
170
171    /// Runs a lightweight health check against the database connection.
172    pub async fn ping(&self) -> bool {
173        self.db
174            .query("SELECT 1")
175            .await
176            .is_ok_and(|r| r.check().is_ok())
177    }
178}
179
180/// Registry for dynamically created solution handles.
181#[derive(Clone)]
182pub struct SolutionRegistry<C: Connection> {
183    inner: Arc<SolutionRegistryInner<C>>,
184}
185
186/// Internal registry state shared across clones.
187struct SolutionRegistryInner<C: Connection> {
188    entries: RwLock<HashMap<String, Arc<SolutionEntry<C>>>>,
189    config: SolutionRegistryConfig<C>,
190}
191
192/// Cache entry that tracks a solution handle and last access time.
193struct SolutionEntry<C: Connection> {
194    handle: RwLock<Option<Arc<SolutionHandle<C>>>>,
195    last_used_ms: AtomicU64,
196}
197
198impl<C: Connection> SolutionEntry<C> {
199    fn new() -> Self {
200        Self {
201            handle: RwLock::new(None),
202            last_used_ms: AtomicU64::new(now_ms()),
203        }
204    }
205
206    fn touch(&self) {
207        self.last_used_ms.store(now_ms(), Ordering::Relaxed);
208    }
209
210    fn idle_for(&self, now_ms: u64) -> Duration {
211        let last = self.last_used_ms.load(Ordering::Relaxed);
212        Duration::from_millis(now_ms.saturating_sub(last))
213    }
214}
215
216impl<C: Connection> SolutionRegistry<C> {
217    #[must_use]
218    pub fn new(config: SolutionRegistryConfig<C>) -> Self {
219        Self {
220            inner: Arc::new(SolutionRegistryInner {
221                entries: RwLock::new(HashMap::new()),
222                config,
223            }),
224        }
225    }
226
227    /// Gets the solution handle or builds it once if missing.
228    ///
229    /// If the handle has been idle longer than `health_check_after`, a ping is
230    /// issued. On failure the stale handle is evicted and rebuilt.
231    ///
232    /// # Errors
233    /// Returns `RegistryError` if capacity is exceeded or the build fails.
234    pub async fn get_or_init(
235        &self,
236        solution: &str,
237    ) -> Result<Arc<SolutionHandle<C>>, RegistryError> {
238        let entry = {
239            let map = self.inner.entries.read().await;
240            map.get(solution).cloned()
241        };
242
243        let entry = if let Some(entry) = entry {
244            entry
245        } else {
246            let mut map = self.inner.entries.write().await;
247            if let Some(entry) = map.get(solution).cloned() {
248                entry
249            } else {
250                if let Some(max_entries) = self.inner.config.max_entries
251                    && map.len() >= max_entries
252                {
253                    return Err(RegistryError::CapacityReached { max: max_entries });
254                }
255                let entry = Arc::new(SolutionEntry::new());
256                map.insert(solution.to_string(), entry.clone());
257                entry
258            }
259        };
260
261        // Try to get existing handle
262        {
263            let guard = entry.handle.read().await;
264            if let Some(handle) = guard.as_ref() {
265                // Health check if idle long enough
266                let idle = entry.idle_for(now_ms());
267                if idle <= self.inner.config.health_check_after || handle.ping().await {
268                    entry.touch();
269                    return Ok(handle.clone());
270                }
271                // Ping failed — fall through to rebuild
272                tracing::debug!("health check failed for solution '{solution}', rebuilding");
273            }
274        }
275
276        // Build or rebuild under write lock
277        let mut guard = entry.handle.write().await;
278        // Double-check: another task may have rebuilt while we waited
279        if let Some(handle) = guard.as_ref()
280            && handle.ping().await
281        {
282            entry.touch();
283            return Ok(handle.clone());
284        }
285        let build_handle = self.inner.config.build_handle.clone();
286        let handle = (build_handle)(solution.to_string()).await?;
287        *guard = Some(handle.clone());
288        drop(guard);
289        entry.touch();
290        Ok(handle)
291    }
292
293    /// Lists known solutions by merging the in-memory cache with a live DB
294    /// discovery query (`INFO FOR NS`).
295    ///
296    /// When a `discover_solutions` function is configured it is called first;
297    /// otherwise any live cached handle is used for the namespace query.  If
298    /// neither is available the result falls back to the cache alone.
299    pub async fn list_solutions(&self) -> Vec<String> {
300        // Try the dedicated discovery function first (preferred path).
301        let db_names: Vec<String> = if let Some(discover) = &self.inner.config.discover_solutions {
302            (discover)().await
303        } else {
304            // Fallback: collect cached entries without holding the map lock, then
305            // find any live handle to run INFO FOR NS through.
306            let entries: Vec<Arc<SolutionEntry<C>>> = {
307                let map = self.inner.entries.read().await;
308                map.values().cloned().collect()
309            };
310            let mut live_handle: Option<Arc<SolutionHandle<C>>> = None;
311            for entry in &entries {
312                let guard = entry.handle.read().await;
313                if let Some(h) = guard.as_ref() {
314                    live_handle = Some(h.clone());
315                    break;
316                }
317            }
318            match live_handle {
319                Some(h) => h.list_databases().await,
320                None => vec![],
321            }
322        };
323
324        // Merge DB-discovered names with cached names.
325        let mut names: HashSet<String> = db_names.into_iter().collect();
326        {
327            let map = self.inner.entries.read().await;
328            names.extend(map.keys().cloned());
329        }
330        let mut result: Vec<String> = names.into_iter().collect();
331        result.sort();
332        result
333    }
334
335    /// Removes a cached solution handle entry.
336    pub async fn remove_solution(&self, solution: &str) -> bool {
337        let mut map = self.inner.entries.write().await;
338        map.remove(solution).is_some()
339    }
340
341    /// Evicts idle entries that exceed the configured TTL.
342    pub async fn evict_idle(&self) -> usize {
343        let Some(ttl) = self.inner.config.ttl else {
344            return 0;
345        };
346        let now = now_ms();
347        let mut map = self.inner.entries.write().await;
348        let before = map.len();
349        map.retain(|key, entry| {
350            let keep = entry.idle_for(now) <= ttl;
351            if !keep {
352                tracing::debug!("evicted idle solution: {key}");
353            }
354            keep
355        });
356        before.saturating_sub(map.len())
357    }
358
359    #[must_use]
360    /// Spawns a background task to evict idle entries on a schedule.
361    pub fn spawn_sweeper(self) -> Option<tokio::task::JoinHandle<()>>
362    where
363        C: Send + Sync + 'static,
364    {
365        let _ttl = self.inner.config.ttl?;
366        let interval = self.inner.config.sweep_interval;
367        let registry = self;
368        Some(tokio::spawn(async move {
369            let mut ticker = tokio::time::interval(interval);
370            loop {
371                ticker.tick().await;
372                let _ = registry.evict_idle().await;
373            }
374        }))
375    }
376}
377
/// Milliseconds since the Unix epoch according to the system clock.
///
/// Returns 0 if the clock reads earlier than the epoch and saturates at
/// `u64::MAX` if the value does not fit. Wall-clock time can move backwards,
/// so differences between readings should use saturating subtraction.
fn now_ms() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    since_epoch.as_millis().try_into().unwrap_or(u64::MAX)
}
385
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use surrealdb::engine::local::{Db, Mem};

    /// Config whose builder opens a fresh in-memory database per solution and
    /// counts how many times it was invoked.
    fn test_config(calls: Arc<AtomicUsize>) -> SolutionRegistryConfig<Db> {
        let build: BuildHandleFn<Db> = Arc::new(move |solution: String| {
            let calls = calls.clone();
            Box::pin(async move {
                calls.fetch_add(1, Ordering::SeqCst);
                let db = Surreal::new::<Mem>(())
                    .await
                    .map_err(|err| RegistryError::BuildFailed(err.to_string()))?;
                db.use_ns("docx")
                    .use_db(&solution)
                    .await
                    .map_err(|err| RegistryError::BuildFailed(err.to_string()))?;
                Ok(Arc::new(SolutionHandle::from_surreal(db)))
            })
        });
        SolutionRegistryConfig::new(build)
    }

    fn build_test_registry(calls: Arc<AtomicUsize>, ttl: Option<Duration>) -> SolutionRegistry<Db> {
        let mut config = test_config(calls);
        if let Some(ttl) = ttl {
            config = config
                .with_ttl(ttl)
                .with_sweep_interval(Duration::from_millis(1));
        }
        SolutionRegistry::new(config)
    }

    #[tokio::test]
    async fn registry_single_flight() {
        let calls = Arc::new(AtomicUsize::new(0));
        let registry = build_test_registry(calls.clone(), None);

        let r1 = registry.clone();
        let r2 = registry.clone();
        let (left, right) = tokio::join!(r1.get_or_init("alpha"), r2.get_or_init("alpha"));
        assert!(left.is_ok());
        assert!(right.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn registry_evicts_idle_entries() {
        let calls = Arc::new(AtomicUsize::new(0));
        let registry = build_test_registry(calls, Some(Duration::from_millis(1)));

        let _ = registry.get_or_init("alpha").await.unwrap();
        tokio::time::sleep(Duration::from_millis(5)).await;
        let evicted = registry.evict_idle().await;
        assert_eq!(evicted, 1);
    }

    #[tokio::test]
    async fn registry_remove_solution_drops_cache_entry() {
        let calls = Arc::new(AtomicUsize::new(0));
        let registry = build_test_registry(calls.clone(), None);

        let _ = registry.get_or_init("alpha").await.unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        assert!(registry.remove_solution("alpha").await);
        let _ = registry.get_or_init("alpha").await.unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn registry_enforces_capacity_for_new_solutions() {
        let calls = Arc::new(AtomicUsize::new(0));
        let registry = SolutionRegistry::new(test_config(calls.clone()).with_max_entries(1));

        assert!(registry.get_or_init("alpha").await.is_ok());
        assert!(matches!(
            registry.get_or_init("beta").await,
            Err(RegistryError::CapacityReached { max: 1 })
        ));
        // Solutions already cached stay reachable at capacity, without a rebuild.
        assert!(registry.get_or_init("alpha").await.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn registry_lists_discovered_and_cached_solutions_sorted() {
        let calls = Arc::new(AtomicUsize::new(0));
        let discover: DiscoverSolutionsFn = Arc::new(|| -> DiscoverSolutionsFuture {
            Box::pin(async { vec!["gamma".to_string(), "alpha".to_string()] })
        });
        let registry = SolutionRegistry::new(test_config(calls).with_discover_solutions(discover));

        let _ = registry.get_or_init("beta").await.unwrap();
        let _ = registry.get_or_init("alpha").await.unwrap();
        assert_eq!(registry.list_solutions().await, vec!["alpha", "beta", "gamma"]);
    }
}