// docx_core/services.rs

use std::collections::HashMap;
use std::error::Error;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::{
    Arc,
    atomic::{AtomicU64, Ordering},
};
use std::time::{Duration, SystemTime, UNIX_EPOCH};

use surrealdb::{Connection, Surreal};
use tokio::sync::{OnceCell, RwLock};

use crate::control::DocxControlPlane;
use crate::store::SurrealDocStore;

18/// Future returned by the solution handle builder.
19pub type BuildHandleFuture<C> =
20    Pin<Box<dyn Future<Output = Result<Arc<SolutionHandle<C>>, RegistryError>> + Send + 'static>>;
21/// Builder function that creates a solution handle for a solution name.
22pub type BuildHandleFn<C> =
23    Arc<dyn Fn(String) -> BuildHandleFuture<C> + Send + Sync + 'static>;
24
25/// Configuration for the solution registry cache and builder.
26#[derive(Clone)]
27pub struct SolutionRegistryConfig<C: Connection> {
28    /// Optional TTL for cached solutions.
29    pub ttl: Option<Duration>,
30    /// Sweep interval for the background eviction task.
31    pub sweep_interval: Duration,
32    /// Optional maximum number of cached solutions.
33    pub max_entries: Option<usize>,
34    /// Builder used to create solution handles.
35    pub build_handle: BuildHandleFn<C>,
36}
37
38impl<C: Connection> SolutionRegistryConfig<C> {
39    #[must_use]
40    pub fn new(build_handle: BuildHandleFn<C>) -> Self {
41        Self {
42            ttl: None,
43            sweep_interval: Duration::from_secs(60),
44            max_entries: None,
45            build_handle,
46        }
47    }
48
49    #[must_use]
50    pub const fn with_ttl(mut self, ttl: Duration) -> Self {
51        self.ttl = Some(ttl);
52        self
53    }
54
55    #[must_use]
56    pub const fn with_sweep_interval(mut self, sweep_interval: Duration) -> Self {
57        self.sweep_interval = sweep_interval;
58        self
59    }
60
61    #[must_use]
62    pub const fn with_max_entries(mut self, max_entries: usize) -> Self {
63        self.max_entries = Some(max_entries);
64        self
65    }
66}
67
/// Errors produced by the solution registry.
#[derive(Debug)]
pub enum RegistryError {
    /// The requested solution name is not known (presumably reported by `build_handle`
    /// implementations; the registry itself never constructs it).
    UnknownSolution(String),
    /// A new entry was needed but the registry already holds `max` entries.
    CapacityReached { max: usize },
    /// Building a solution handle failed; carries the underlying error message.
    BuildFailed(String),
}

impl fmt::Display for RegistryError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::UnknownSolution(name) => {
                f.write_str("unknown solution: ")?;
                f.write_str(name)
            }
            Self::CapacityReached { max } => {
                write!(f, "solution registry capacity reached (max {max})")
            }
            Self::BuildFailed(reason) => {
                f.write_str("failed to build solution handle: ")?;
                f.write_str(reason)
            }
        }
    }
}

/// Lets `RegistryError` be used as a `dyn Error`; it reports no underlying `source`.
impl Error for RegistryError {}

93/// Shared service handle for a single solution's database.
94pub struct SolutionHandle<C: Connection> {
95    db: Arc<Surreal<C>>,
96    store: SurrealDocStore<C>,
97    control: DocxControlPlane<C>,
98}
99
100impl<C: Connection> Clone for SolutionHandle<C> {
101    fn clone(&self) -> Self {
102        Self {
103            db: self.db.clone(),
104            store: self.store.clone(),
105            control: self.control.clone(),
106        }
107    }
108}
109
110impl<C: Connection> SolutionHandle<C> {
111    #[must_use]
112    pub fn new(db: Arc<Surreal<C>>) -> Self {
113        let store = SurrealDocStore::from_arc(db.clone());
114        let control = DocxControlPlane::with_store(store.clone());
115        Self { db, store, control }
116    }
117
118    #[must_use]
119    pub fn from_surreal(db: Surreal<C>) -> Self {
120        Self::new(Arc::new(db))
121    }
122
123    #[must_use]
124    pub fn db(&self) -> Arc<Surreal<C>> {
125        self.db.clone()
126    }
127
128    #[must_use]
129    pub fn store(&self) -> SurrealDocStore<C> {
130        self.store.clone()
131    }
132
133    #[must_use]
134    pub fn control(&self) -> DocxControlPlane<C> {
135        self.control.clone()
136    }
137}
138
139/// Registry for dynamically created solution handles.
140#[derive(Clone)]
141pub struct SolutionRegistry<C: Connection> {
142    inner: Arc<SolutionRegistryInner<C>>,
143}
144
145/// Internal registry state shared across clones.
146struct SolutionRegistryInner<C: Connection> {
147    entries: RwLock<HashMap<String, Arc<SolutionEntry<C>>>>,
148    config: SolutionRegistryConfig<C>,
149}
150
151/// Cache entry that tracks a solution handle and last access time.
152struct SolutionEntry<C: Connection> {
153    handle: OnceCell<Arc<SolutionHandle<C>>>,
154    last_used_ms: AtomicU64,
155}
156
157impl<C: Connection> SolutionEntry<C> {
158    fn new() -> Self {
159        Self {
160            handle: OnceCell::new(),
161            last_used_ms: AtomicU64::new(now_ms()),
162        }
163    }
164
165    fn touch(&self) {
166        self.last_used_ms.store(now_ms(), Ordering::Relaxed);
167    }
168
169    fn idle_for(&self, now_ms: u64) -> Duration {
170        let last = self.last_used_ms.load(Ordering::Relaxed);
171        Duration::from_millis(now_ms.saturating_sub(last))
172    }
173}
174
175impl<C: Connection> SolutionRegistry<C> {
176    #[must_use]
177    pub fn new(config: SolutionRegistryConfig<C>) -> Self {
178        Self {
179            inner: Arc::new(SolutionRegistryInner {
180                entries: RwLock::new(HashMap::new()),
181                config,
182            }),
183        }
184    }
185
186    /// Gets the solution handle or builds it once if missing.
187    ///
188    /// # Errors
189    /// Returns `RegistryError` if capacity is exceeded or the build fails.
190    pub async fn get_or_init(
191        &self,
192        solution: &str,
193    ) -> Result<Arc<SolutionHandle<C>>, RegistryError> {
194        let entry = {
195            let map = self.inner.entries.read().await;
196            map.get(solution).cloned()
197        };
198
199        let entry = if let Some(entry) = entry {
200            entry
201        } else {
202            let mut map = self.inner.entries.write().await;
203            if let Some(entry) = map.get(solution).cloned() {
204                entry
205            } else {
206                if let Some(max_entries) = self.inner.config.max_entries
207                    && map.len() >= max_entries
208                {
209                    return Err(RegistryError::CapacityReached { max: max_entries });
210                }
211                let entry = Arc::new(SolutionEntry::new());
212                map.insert(solution.to_string(), entry.clone());
213                entry
214            }
215        };
216
217        entry.touch();
218
219        let build_handle = self.inner.config.build_handle.clone();
220        let handle = entry
221            .handle
222            .get_or_try_init(|| (build_handle)(solution.to_string()))
223            .await?;
224        Ok(handle.clone())
225    }
226
227    /// Lists known solutions from the cache.
228    pub async fn list_solutions(&self) -> Vec<String> {
229        let map = self.inner.entries.read().await;
230        map.keys().cloned().collect()
231    }
232
233    /// Evicts idle entries that exceed the configured TTL.
234    pub async fn evict_idle(&self) -> usize {
235        let Some(ttl) = self.inner.config.ttl else {
236            return 0;
237        };
238        let now = now_ms();
239        let mut map = self.inner.entries.write().await;
240        let before = map.len();
241        map.retain(|_, entry| entry.idle_for(now) <= ttl);
242        before.saturating_sub(map.len())
243    }
244
245    #[must_use]
246    /// Spawns a background task to evict idle entries on a schedule.
247    pub fn spawn_sweeper(self) -> Option<tokio::task::JoinHandle<()>>
248    where
249        C: Send + Sync + 'static,
250    {
251        let _ttl = self.inner.config.ttl?;
252        let interval = self.inner.config.sweep_interval;
253        let registry = self;
254        Some(tokio::spawn(async move {
255            let mut ticker = tokio::time::interval(interval);
256            loop {
257                ticker.tick().await;
258                let _ = registry.evict_idle().await;
259            }
260        }))
261    }
262}
263
/// Current wall-clock time in whole milliseconds since the Unix epoch.
///
/// A clock set before 1970 reads as 0, and a value too large for `u64` saturates to `u64::MAX`.
/// `SystemTime` is not monotonic, so clock adjustments shift the idle times derived from these
/// readings.
fn now_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis().try_into().unwrap_or(u64::MAX),
        Err(_) => 0,
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use surrealdb::engine::local::{Db, Mem};

    /// Builds a registry whose builder opens a fresh in-memory database per solution and
    /// increments `calls` each time it runs. Passing a `ttl` enables idle eviction with a
    /// 1 ms sweep interval.
    fn build_test_registry(
        calls: Arc<AtomicUsize>,
        ttl: Option<Duration>,
    ) -> SolutionRegistry<Db> {
        let build: BuildHandleFn<Db> = Arc::new(move |solution: String| {
            let counter = Arc::clone(&calls);
            Box::pin(async move {
                counter.fetch_add(1, Ordering::SeqCst);
                let db = Surreal::new::<Mem>(())
                    .await
                    .map_err(|err| RegistryError::BuildFailed(err.to_string()))?;
                db.use_ns("docx")
                    .use_db(&solution)
                    .await
                    .map_err(|err| RegistryError::BuildFailed(err.to_string()))?;
                Ok(Arc::new(SolutionHandle::from_surreal(db)))
            })
        });

        let base = SolutionRegistryConfig::new(build);
        let config = match ttl {
            Some(ttl) => base
                .with_ttl(ttl)
                .with_sweep_interval(Duration::from_millis(1)),
            None => base,
        };
        SolutionRegistry::new(config)
    }

    #[tokio::test]
    async fn registry_single_flight() {
        let calls = Arc::new(AtomicUsize::new(0));
        let registry = build_test_registry(Arc::clone(&calls), None);

        // Both lookups race on the same name; only one build may run.
        let (left, right) = tokio::join!(
            registry.get_or_init("alpha"),
            registry.get_or_init("alpha")
        );
        assert!(left.is_ok());
        assert!(right.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn registry_evicts_idle_entries() {
        let registry =
            build_test_registry(Arc::new(AtomicUsize::new(0)), Some(Duration::from_millis(1)));

        registry.get_or_init("alpha").await.unwrap();
        // Sleep well past the 1 ms TTL so the entry counts as idle.
        tokio::time::sleep(Duration::from_millis(5)).await;
        assert_eq!(registry.evict_idle().await, 1);
    }
}