stately_arrow/
registry.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use crate::backend::{Backend, ConnectionMetadata};
6use crate::error::Result;
7
/// Registry responsible for supplying connectors to the viewer.
///
/// Implementations must be `Send + Sync`. Every method is async (through `async_trait`) and
/// fallible, returning the crate-wide `Result`.
#[async_trait]
pub trait ConnectorRegistry: Send + Sync {
    /// Returns the backend for the connection identified by `id`.
    async fn get(&self, id: &str) -> Result<Arc<dyn Backend>>;
    /// Returns metadata describing the connections this registry can serve.
    ///
    /// In the bundled `generic::Registry` this is every connector found in the state.
    async fn list(&self) -> Result<Vec<ConnectionMetadata>>;
    /// Returns metadata for the connections the registry considers registered.
    ///
    /// In the bundled `generic::Registry` these are the connections whose backend has already
    /// been created and cached by `get`.
    async fn registered(&self) -> Result<Vec<ConnectionMetadata>>;
}
15
16// The following are provided for convenience, if using the state types directly
17
18#[cfg(feature = "registry")]
19pub mod generic {
20    use std::collections::HashMap;
21    use std::hash::RandomState;
22    use std::sync::Arc;
23    use std::time::{SystemTime, UNIX_EPOCH};
24
25    use async_trait::async_trait;
26    use serde::{Deserialize, Serialize};
27    use tokio::sync::RwLock;
28
29    use super::ConnectorRegistry;
30    use crate::backend::{Backend, BackendMetadata, ConnectionKind, ConnectionMetadata};
31    #[cfg(feature = "clickhouse")]
32    use crate::database::Database as DatabaseType;
33    #[cfg(feature = "clickhouse")]
34    use crate::database::clickhouse::{CLICKHOUSE_CATALOG, ClickHouseBackend};
35    use crate::error::{Error, Result};
36    use crate::object_store::ObjectStoreBackend;
37
38    fn default_connector_name() -> String {
39        let id = uuid::Uuid::now_v7().to_string();
40        format!("connection-{}", &id[..8])
41    }
42
    // NOTE: Using this struct is optional; a customized structure can be used instead. This
    // provides a simple default implementation.
    /// Connector Stately `entity` type.
    ///
    /// Use this with [`Connectors`] and [`Registry`] to create a turnkey connector registry.
    ///
    /// [`Registry`] hashes the whole connector (name and config) to key its backend cache, so
    /// the derived `Hash` implementation determines which lookups share a cached backend.
    #[stately::entity]
    #[derive(
        Debug, Clone, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize, utoipa::ToSchema,
    )]
    pub struct Connector {
        /// Human-readable name for this connection.
        ///
        /// When the field is missing during deserialization, `default_connector_name` supplies a
        /// generated `connection-…` name.
        #[serde(default = "default_connector_name")]
        pub name:   String,
        /// Backend configuration; the variant decides which kind of backend is built.
        pub config: Type,
    }
58
    /// Backend configuration carried by a [`Connector`].
    ///
    /// Variant names are serialized in `snake_case`. The `schema(as = ConnectorType)` attribute
    /// presumably exposes this type as `ConnectorType` in the generated schema; confirm against
    /// the utoipa documentation.
    #[allow(missing_copy_implementations)] // This will grow, so Copy is a breaking change
    #[non_exhaustive]
    #[derive(
        Debug, Clone, PartialEq, Hash, Eq, serde::Deserialize, serde::Serialize, utoipa::ToSchema,
    )]
    #[schema(as = ConnectorType)]
    #[serde(rename_all = "snake_case")]
    pub enum Type {
        /// Object store configuration (boxed).
        ObjectStore(Box<crate::object_store::Config>),
        /// Database configuration (boxed). Only available with the `database` feature.
        #[cfg(feature = "database")]
        Database(Box<crate::database::Config>),
    }
71
    /// Trait for state types that provide read-only access to connectors.
    ///
    /// Implement this trait on your state type to use the generic [`Registry`].
    ///
    /// Both methods return owned values, so implementations clone the connectors out of the
    /// state.
    ///
    /// # Example
    ///
    /// ```ignore
    /// use stately_arrow::registry::generic::{Connector, Connectors};
    ///
    /// #[stately::state]
    /// pub struct MyState {
    ///     pub connectors: Connector,
    ///     // ... other fields
    /// }
    ///
    /// impl Connectors for MyState {
    ///     fn list(&self) -> Vec<(String, Connector)> {
    ///         self.connectors
    ///             .iter()
    ///             .map(|(id, c)| {
    ///                 let id: &str = id.as_ref();
    ///                 (id.to_owned(), c.clone())
    ///             })
    ///             .collect()
    ///     }
    ///
    ///     fn get(&self, id: &str) -> Option<Connector> {
    ///         self.connectors.get_by_name(id).map(|(_, c)| c.clone())
    ///     }
    /// }
    /// ```
    pub trait Connectors {
        /// Returns all (id, connector) pairs as owned values.
        fn list(&self) -> Vec<(String, Connector)>;

        /// Gets a connector by ID or name, returning `None` when there is no match.
        fn get(&self, id: &str) -> Option<Connector>;
    }
104
105    fn metadata_from_connector(id: String, connector: &Connector) -> ConnectionMetadata {
106        let (metadata, catalog) = match &connector.config {
107            Type::ObjectStore(c) => (ObjectStoreBackend::metadata(), Some(c.store.url())),
108            #[cfg(feature = "database")]
109            #[cfg_attr(not(feature = "clickhouse"), allow(unused))]
110            Type::Database(c) => {
111                #[allow(unused_mut)]
112                let mut metadata =
113                    BackendMetadata::new(ConnectionKind::Database).with_capabilities(vec![]);
114                #[allow(unused_mut)]
115                let mut catalog = None;
116
117                #[cfg(feature = "clickhouse")]
118                #[cfg_attr(feature = "clickhouse", allow(clippy::single_match))]
119                match &c.driver {
120                    DatabaseType::ClickHouse(_) => {
121                        metadata = ClickHouseBackend::metadata();
122                        catalog = Some(CLICKHOUSE_CATALOG.to_string());
123                    }
124                    #[allow(unreachable_patterns)]
125                    _ => {}
126                }
127
128                (metadata, catalog)
129            }
130        };
131
132        ConnectionMetadata { id, name: connector.name.clone(), catalog, metadata }
133    }
134
    /// Generic registry options.
    ///
    /// Provided as a convenience if using state entity types directly, i.e. [`Connector`].
    #[derive(
        Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize, utoipa::ToSchema,
    )]
    pub struct RegistryOptions {
        /// Set the maximum lifetime, in seconds, that a cached connection should be kept around
        /// for. When unset, [`Registry`] falls back to 30 minutes.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_lifetime:  Option<u64>,
        /// Set the maximum size any connector will use for its pool. Set to 0 to disable pooling;
        /// [`Registry`] then creates database connectors with a pool size of 1.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_pool_size: Option<u32>,
    }
149
    /// A backend cached by [`Registry`], together with the bookkeeping needed to expire it.
    #[derive(Clone)]
    struct BackendEntry {
        /// Seconds since the Unix epoch at which the backend was created.
        registered_at: u64,
        /// Metadata reported for this backend by `ConnectorRegistry::registered`.
        connection:    ConnectionMetadata,
        /// The shared backend handed out to callers of `ConnectorRegistry::get`.
        backend:       Arc<dyn Backend>,
    }
156
    /// Generic registry implementation for state types implementing [`Connectors`].
    ///
    /// This provides a default [`ConnectorRegistry`] implementation for users who:
    /// - Use the provided [`Connector`] entity type in their stately state
    /// - Implement the [`Connectors`] trait on their state
    ///
    /// For custom connector types or more complex needs, implement [`ConnectorRegistry`] directly.
    pub struct Registry<S: Connectors + Send + Sync> {
        /// Shared application state the connectors are read from.
        state:      Arc<RwLock<S>>,
        /// Backends created so far, keyed by a hash of the connector they were built from.
        registered: Arc<RwLock<HashMap<u64, BackendEntry>>>,
        /// Cache lifetime and pool sizing options applied when creating backends.
        options:    RegistryOptions,
    }
169
170    impl<S: Connectors + Send + Sync> Registry<S> {
171        pub fn new(state: Arc<RwLock<S>>) -> Self {
172            Self {
173                state,
174                registered: Arc::new(RwLock::new(HashMap::default())),
175                options: RegistryOptions::default(),
176            }
177        }
178
179        #[must_use]
180        pub fn with_options(mut self, options: RegistryOptions) -> Self {
181            self.options = options;
182            self
183        }
184    }
185
186    #[async_trait]
187    impl<S: Connectors + Send + Sync + 'static> ConnectorRegistry for Registry<S> {
188        async fn list(&self) -> Result<Vec<ConnectionMetadata>> {
189            Ok(self
190                .state
191                .read()
192                .await
193                .list()
194                .into_iter()
195                .map(|(id, conn)| metadata_from_connector(id, &conn))
196                .collect())
197        }
198
199        async fn registered(&self) -> Result<Vec<ConnectionMetadata>> {
200            Ok(self
201                .registered
202                .read()
203                .await
204                .values()
205                .map(|entry| entry.connection.clone())
206                .collect())
207        }
208
209        async fn get(&self, id: &str) -> Result<Arc<dyn Backend>> {
210            use std::hash::BuildHasher;
211
212            let connector = {
213                self.state
214                    .read()
215                    .await
216                    .get(id)
217                    .ok_or_else(|| Error::ConnectionNotFound(id.to_string()))?
218            };
219
220            let key = RandomState::new().hash_one(&connector);
221            let now = SystemTime::now()
222                .duration_since(UNIX_EPOCH)
223                .expect("Time went backwards")
224                .as_secs();
225
226            if let Some(BackendEntry { backend, .. }) = self
227                .registered
228                .read()
229                .await
230                .get(&key)
231                .filter(|entry| {
232                    // Keep using cached connection if not too old
233                    entry.registered_at
234                        >= (now
235                            - self.options.max_lifetime.unwrap_or(60 * 30 /* 30 Minutes */))
236                })
237                .cloned()
238            {
239                tracing::debug!(key, name = connector.name, "Connector cached");
240                return Ok(backend);
241            }
242
243            let metadata = metadata_from_connector(id.to_string(), &connector);
244            let backend: Arc<dyn Backend> = match connector.config {
245                Type::ObjectStore(config) => {
246                    Arc::new(ObjectStoreBackend::try_new(id, &connector.name, &config)?)
247                }
248                #[cfg(feature = "database")]
249                Type::Database(config) => {
250                    let mut pool = config.pool;
251                    // Ensure connection does not create a pool
252                    let pool_disabled = self.options.max_pool_size.is_some_and(|p| p == 0);
253                    if pool_disabled {
254                        pool.pool_size = Some(1);
255                    } else {
256                        pool.pool_size = pool
257                            .pool_size
258                            .map(|s| self.options.max_pool_size.map_or(s, |m| s.min(m).max(1)))
259                            .or(self.options.max_pool_size);
260                    }
261
262                    #[allow(unreachable_code)]
263                    match config.driver {
264                        #[cfg(feature = "clickhouse")]
265                        DatabaseType::ClickHouse(clickhouse_conf) => {
266                            let backend = ClickHouseBackend::try_new(
267                                id,
268                                &connector.name,
269                                &config.options,
270                                clickhouse_conf,
271                                pool,
272                            )
273                            .await?;
274                            Arc::new(backend)
275                        }
276                        #[allow(unreachable_patterns)]
277                        _ => return Err(Error::UnsupportedConnector(id.to_string())),
278                    }
279                }
280            };
281
282            // Write to cache
283            let mut connectors = self.registered.write().await;
284
285            // Cleanup any connection there might be
286            drop(connectors.remove(&key));
287            tracing::debug!(
288                key,
289                name = connector.name,
290                metadata = ?backend.connection(),
291                "Connector not cached, creating",
292            );
293
294            // Insert
295            drop(connectors.insert(key, BackendEntry {
296                registered_at: now,
297                connection:    metadata,
298                backend:       Arc::clone(&backend),
299            }));
300
301            Ok(backend)
302        }
303    }
304}