1pub mod db;
32pub mod drivers;
33#[cfg(feature = "db-sqlx")]
36pub mod migrate;
37pub mod outbox;
38pub mod tx;
39
40use std::collections::HashMap;
41use std::sync::atomic::{AtomicBool, Ordering};
42
43use futures::future::BoxFuture;
44
45use crate::web::tenant::TenantConfig;
46
/// Whether a data-access operation only reads or may also write.
///
/// Passed to [`DataSource::acquire`]; what an implementation does differently
/// for each intent is up to that implementation.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum AccessIntent {
    /// The operation only reads data.
    Read,
    /// The operation may modify data.
    Write,
}
58
/// Error returned by data-source operations.
///
/// `kind` classifies the failure (see [`DataErrorKind`]); `message` carries a
/// description. Construct values with [`DataError::new`] or one of the per-kind
/// helpers such as [`DataError::connection`].
#[derive(Debug)]
#[non_exhaustive]
pub struct DataError {
    /// Category of the failure; [`DataError::is_retryable`] is decided from it.
    pub kind: DataErrorKind,
    /// Free-form description of what went wrong.
    pub message: String,
}
68
/// Coarse classification of a [`DataError`].
///
/// Only `Connection` and `Timeout` are reported as retryable by
/// [`DataError::is_retryable`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum DataErrorKind {
    /// Configuration problem (e.g. a bad URL).
    Config,
    /// Connection problem (e.g. the pool is down).
    Connection,
    /// The query itself failed (e.g. a syntax error).
    Query,
    /// The operation took too long.
    Timeout,
    /// Conflicting data (e.g. a duplicate key).
    Conflict,
    /// Anything not covered by the other kinds.
    Other,
}
90
91impl DataError {
92 pub fn new(kind: DataErrorKind, message: impl Into<String>) -> Self {
93 Self {
94 kind,
95 message: message.into(),
96 }
97 }
98 pub fn config(m: impl Into<String>) -> Self {
99 Self::new(DataErrorKind::Config, m)
100 }
101 pub fn connection(m: impl Into<String>) -> Self {
102 Self::new(DataErrorKind::Connection, m)
103 }
104 pub fn query(m: impl Into<String>) -> Self {
105 Self::new(DataErrorKind::Query, m)
106 }
107 pub fn timeout(m: impl Into<String>) -> Self {
108 Self::new(DataErrorKind::Timeout, m)
109 }
110 pub fn conflict(m: impl Into<String>) -> Self {
111 Self::new(DataErrorKind::Conflict, m)
112 }
113 pub fn other(m: impl Into<String>) -> Self {
114 Self::new(DataErrorKind::Other, m)
115 }
116
117 pub fn is_retryable(&self) -> bool {
119 matches!(
120 self.kind,
121 DataErrorKind::Connection | DataErrorKind::Timeout
122 )
123 }
124}
125
126impl std::fmt::Display for DataError {
127 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
128 write!(f, "datasource error ({:?}): {}", self.kind, self.message)
129 }
130}
131impl std::error::Error for DataError {}
132
/// A backend that hands out connections for data access.
///
/// Implementors must be `Send + Sync + 'static`, so a single instance can be
/// shared across threads (for example inside a [`DataSourceRegistry`]).
pub trait DataSource: Send + Sync + 'static {
    /// Connection handle produced by [`DataSource::acquire`].
    type Conn: Send;

    /// Obtains a connection for an operation with the given `intent`.
    ///
    /// When called through [`DataSourceRegistry::acquire`], `intent` has already
    /// been passed through a [`ReadAfterWritePin`].
    ///
    /// # Errors
    ///
    /// Resolves to a [`DataError`] when no connection can be obtained.
    fn acquire(&self, intent: AccessIntent) -> BoxFuture<'_, Result<Self::Conn, DataError>>;

    /// Static name identifying this data source.
    fn name(&self) -> &'static str;
}
148
149#[derive(Default)]
157pub struct ReadAfterWritePin {
158 wrote: AtomicBool,
159}
160
161impl ReadAfterWritePin {
162 pub fn new() -> Self {
163 Self::default()
164 }
165
166 pub fn apply(&self, intent: AccessIntent) -> AccessIntent {
168 match intent {
169 AccessIntent::Write => {
170 self.wrote.store(true, Ordering::Relaxed);
171 AccessIntent::Write
172 }
173 AccessIntent::Read if self.wrote.load(Ordering::Relaxed) => AccessIntent::Write,
174 AccessIntent::Read => AccessIntent::Read,
175 }
176 }
177}
178
179pub struct DataSourceRegistry<D: DataSource> {
184 default: D,
185 by_name: HashMap<&'static str, D>,
186}
187
188impl<D: DataSource> DataSourceRegistry<D> {
189 pub fn new(default: D) -> Self {
190 Self {
191 default,
192 by_name: HashMap::new(),
193 }
194 }
195
196 pub fn with(mut self, name: &'static str, ds: D) -> Self {
198 self.by_name.insert(name, ds);
199 self
200 }
201
202 pub fn for_tenant(&self, tenant: Option<&TenantConfig>) -> &D {
205 tenant
206 .and_then(|t| self.by_name.get(t.datasource.as_str()))
207 .unwrap_or(&self.default)
208 }
209
210 pub fn iter(&self) -> impl Iterator<Item = (&'static str, &D)> {
213 let mut named: Vec<_> = self.by_name.iter().map(|(k, v)| (*k, v)).collect();
214 named.sort_by_key(|(k, _)| *k);
215 std::iter::once(("", &self.default)).chain(named)
216 }
217
218 pub async fn acquire(
220 &self,
221 ds: &D,
222 intent: AccessIntent,
223 pin: &ReadAfterWritePin,
224 ) -> Result<D::Conn, DataError> {
225 ds.acquire(pin.apply(intent)).await
226 }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use crate::web::tenant::{TenantConfig, TenantId};

    /// Reads pass through until the first write; afterwards they stay upgraded.
    #[test]
    fn pin_upgrades_reads_after_first_write() {
        let pin = ReadAfterWritePin::new();
        let steps = [
            (AccessIntent::Read, AccessIntent::Read),
            (AccessIntent::Write, AccessIntent::Write),
            (AccessIntent::Read, AccessIntent::Write),
            (AccessIntent::Read, AccessIntent::Write),
        ];
        for &(requested, effective) in &steps {
            assert_eq!(pin.apply(requested), effective);
        }
    }

    /// Minimal data source that only reports its name.
    struct FakeDs(&'static str);

    impl DataSource for FakeDs {
        type Conn = ();

        fn acquire(&self, _intent: AccessIntent) -> BoxFuture<'_, Result<(), DataError>> {
            Box::pin(async { Ok(()) })
        }

        fn name(&self) -> &'static str {
            self.0
        }
    }

    /// Builds a tenant whose config points at the data source named `ds`.
    fn tenant(ds: &str) -> TenantConfig {
        TenantConfig {
            id: TenantId::new("t"),
            display_name: "T".into(),
            datasource: ds.into(),
        }
    }

    #[test]
    fn registry_routes_by_tenant_datasource_with_default_fallback() {
        let registry = DataSourceRegistry::new(FakeDs("default"))
            .with("acme", FakeDs("acme"))
            .with("globex", FakeDs("globex"));

        let route = |t: Option<&TenantConfig>| registry.for_tenant(t).name();
        assert_eq!(route(Some(&tenant("acme"))), "acme");
        assert_eq!(route(Some(&tenant("unknown"))), "default");
        assert_eq!(route(None), "default");
    }

    #[test]
    fn data_error_taxonomy_classifies_retryability() {
        let cases = [
            (DataError::connection("pool down"), true),
            (DataError::timeout("slow"), true),
            (DataError::query("syntax"), false),
            (DataError::config("bad url"), false),
            (DataError::conflict("dup key"), false),
        ];
        for (err, retryable) in &cases {
            assert_eq!(err.is_retryable(), *retryable, "{}", err);
        }

        let err = DataError::connection("x");
        assert_eq!(err.kind, DataErrorKind::Connection);
        assert!(err.to_string().contains("Connection"));
    }

    #[test]
    fn registry_iter_is_deterministic_default_first() {
        let registry = DataSourceRegistry::new(FakeDs("default"))
            .with("b", FakeDs("b"))
            .with("a", FakeDs("a"));
        let names: Vec<&str> = registry.iter().map(|(name, _)| name).collect();
        assert_eq!(names, ["", "a", "b"]);
    }
}