Skip to main content

tt_routing/
store.rs

1//! Where routes come from at runtime.
2//!
3//! The gateway only knows about a [`RoutingStore`] trait. Production wires
4//! [`PostgresRoutingStore`] (behind the `postgres` feature flag) reading the
5//! `routes` table that the cloud dashboard writes; tests use
6//! [`InMemoryRoutingStore`]. A separate [`crate::cache::CachingRoutingStore`]
7//! wraps either with a per-org TTL cache so the hot path isn't a DB round-trip.
8
9use std::collections::HashMap;
10use std::sync::RwLock;
11
12use async_trait::async_trait;
13use uuid::Uuid;
14
15use crate::{Route, RouteAction, RouteConditions};
16
17/// Fields needed to create a route; the store assigns the `id`.
18#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
19pub struct NewRoute {
20    pub name: String,
21    #[serde(default = "default_priority")]
22    pub priority: u32,
23    #[serde(default = "default_enabled")]
24    pub enabled: bool,
25    #[serde(default)]
26    pub when: RouteConditions,
27    pub then: RouteAction,
28}
29
30fn default_priority() -> u32 {
31    100
32}
33fn default_enabled() -> bool {
34    true
35}
36
37/// Source of truth for an org's enabled routes.
38///
39/// Implementations return ALL enabled routes for `org_id`; ordering is the
40/// caller's problem ([`crate::RoutingEngine::with_routes`] sorts internally).
41#[async_trait]
42pub trait RoutingStore: Send + Sync + std::fmt::Debug {
43    /// Fetch the enabled-and-current route list for `org_id`. Returns
44    /// `Ok(vec![])` when the org has no routes — not an error.
45    async fn list_for_org(&self, org_id: Uuid) -> Result<Vec<Route>, RoutingStoreError>;
46
47    /// Management: ALL of an org's routes, including disabled ones.
48    async fn list_all_for_org(&self, _org_id: Uuid) -> Result<Vec<Route>, RoutingStoreError> {
49        Err(RoutingStoreError::Backend(
50            "management unsupported by this store".into(),
51        ))
52    }
53    /// Management: create a route for `org_id`. Returns the created route.
54    async fn create_route(
55        &self,
56        _org_id: Uuid,
57        _spec: NewRoute,
58    ) -> Result<Route, RoutingStoreError> {
59        Err(RoutingStoreError::Backend(
60            "management unsupported by this store".into(),
61        ))
62    }
63    /// Management: fetch one route owned by `org_id`.
64    async fn get_route(
65        &self,
66        _org_id: Uuid,
67        _id: Uuid,
68    ) -> Result<Option<Route>, RoutingStoreError> {
69        Err(RoutingStoreError::Backend(
70            "management unsupported by this store".into(),
71        ))
72    }
73    /// Management: delete one route owned by `org_id`. Returns whether a row was removed.
74    async fn delete_route(&self, _org_id: Uuid, _id: Uuid) -> Result<bool, RoutingStoreError> {
75        Err(RoutingStoreError::Backend(
76            "management unsupported by this store".into(),
77        ))
78    }
79}
80
81#[derive(Debug, thiserror::Error)]
82pub enum RoutingStoreError {
83    #[error("backend error: {0}")]
84    Backend(String),
85}
86
87/// Test / dev backend. Holds a HashMap<org_id, Vec<Route>>; the gateway treats
88/// it like any other store.
89#[derive(Debug, Default)]
90pub struct InMemoryRoutingStore {
91    inner: RwLock<HashMap<Uuid, Vec<Route>>>,
92}
93
94impl InMemoryRoutingStore {
95    pub fn new() -> Self {
96        Self::default()
97    }
98
99    /// Replace the routes for an org. Useful for plant-and-assert tests.
100    pub fn set_routes(&self, org_id: Uuid, routes: Vec<Route>) {
101        let mut g = self.inner.write().expect("inmemory routing store poisoned");
102        g.insert(org_id, routes);
103    }
104}
105
106#[async_trait]
107impl RoutingStore for InMemoryRoutingStore {
108    async fn list_for_org(&self, org_id: Uuid) -> Result<Vec<Route>, RoutingStoreError> {
109        let g = self.inner.read().expect("inmemory routing store poisoned");
110        Ok(g.get(&org_id).cloned().unwrap_or_default())
111    }
112
113    async fn list_all_for_org(&self, org_id: Uuid) -> Result<Vec<Route>, RoutingStoreError> {
114        let g = self.inner.read().expect("inmemory routing store poisoned");
115        Ok(g.get(&org_id).cloned().unwrap_or_default())
116    }
117
118    async fn create_route(&self, org_id: Uuid, spec: NewRoute) -> Result<Route, RoutingStoreError> {
119        let route = Route {
120            id: Uuid::now_v7(),
121            name: spec.name,
122            priority: spec.priority,
123            enabled: spec.enabled,
124            when: spec.when,
125            then: spec.then,
126        };
127        let mut g = self.inner.write().expect("inmemory routing store poisoned");
128        g.entry(org_id).or_default().push(route.clone());
129        Ok(route)
130    }
131
132    async fn get_route(&self, org_id: Uuid, id: Uuid) -> Result<Option<Route>, RoutingStoreError> {
133        let g = self.inner.read().expect("inmemory routing store poisoned");
134        Ok(g.get(&org_id)
135            .and_then(|v| v.iter().find(|r| r.id == id).cloned()))
136    }
137
138    async fn delete_route(&self, org_id: Uuid, id: Uuid) -> Result<bool, RoutingStoreError> {
139        let mut g = self.inner.write().expect("inmemory routing store poisoned");
140        let Some(v) = g.get_mut(&org_id) else {
141            return Ok(false);
142        };
143        let before = v.len();
144        v.retain(|r| r.id != id);
145        Ok(v.len() != before)
146    }
147}
148
#[cfg(feature = "postgres")]
mod pg {
    use super::*;
    use crate::{RouteAction, RouteConditions};
    use sqlx::PgPool;

    /// Reads the `routes` table written by the cloud dashboard / tt-api admin
    /// endpoints, and writes it through the management methods. Schema lives in
    /// tokentrimmer-cloud (`crates/api/migrations/0002_routes.up.sql`):
    ///
    /// ```sql
    /// CREATE TABLE routes (
    ///   id          UUID PRIMARY KEY,
    ///   org_id      UUID NOT NULL,
    ///   name        TEXT NOT NULL,
    ///   priority    INT  NOT NULL,
    ///   conditions  JSONB NOT NULL,
    ///   target      JSONB NOT NULL,
    ///   enabled     BOOLEAN NOT NULL,
    ///   ...
    /// );
    /// ```
    ///
    /// Rows whose `conditions` or `target` JSON fails to decode are skipped
    /// with a warning, both on the hot path and in the management reads. A
    /// single malformed row must not knock out routing for the org. Wrap in
    /// [`crate::cache::CachingRoutingStore`] to amortize the SELECT across
    /// hot-path requests.
    #[derive(Clone, Debug)]
    pub struct PostgresRoutingStore {
        pool: PgPool,
    }

    impl PostgresRoutingStore {
        /// Build a store on top of an existing connection pool.
        pub fn new(pool: PgPool) -> Self {
            Self { pool }
        }
    }

    /// Stringify a backend failure (sqlx or serde_json) into
    /// [`RoutingStoreError::Backend`].
    fn backend_err<E: std::fmt::Display>(e: E) -> RoutingStoreError {
        RoutingStoreError::Backend(e.to_string())
    }

    #[async_trait]
    impl RoutingStore for PostgresRoutingStore {
        /// Hot path: enabled routes only, highest priority first, ties broken
        /// by creation time.
        async fn list_for_org(&self, org_id: Uuid) -> Result<Vec<Route>, RoutingStoreError> {
            let rows = sqlx::query_as::<_, RouteRow>(
                "SELECT id, name, priority, enabled, conditions, target \
                 FROM routes \
                 WHERE org_id = $1 AND enabled = TRUE \
                 ORDER BY priority DESC, created_at ASC",
            )
            .bind(org_id)
            .fetch_all(&self.pool)
            .await
            .map_err(backend_err)?;

            Ok(rows.into_iter().filter_map(RouteRow::into_route).collect())
        }

        /// Management: all routes for `org_id` (disabled included), same ordering
        /// as the hot path.
        async fn list_all_for_org(&self, org_id: Uuid) -> Result<Vec<Route>, RoutingStoreError> {
            let rows = sqlx::query_as::<_, RouteRow>(
                "SELECT id, name, priority, enabled, conditions, target \
                 FROM routes WHERE org_id = $1 ORDER BY priority DESC, created_at ASC",
            )
            .bind(org_id)
            .fetch_all(&self.pool)
            .await
            .map_err(backend_err)?;

            Ok(rows.into_iter().filter_map(RouteRow::into_route).collect())
        }

        /// Management: insert a route for `org_id` and return it as stored.
        ///
        /// The id is generated here with `Uuid::now_v7()`, as the in-memory
        /// store does, rather than relying on a column default. The documented
        /// schema declares `id UUID PRIMARY KEY` without one.
        async fn create_route(
            &self,
            org_id: Uuid,
            spec: NewRoute,
        ) -> Result<Route, RoutingStoreError> {
            let conditions = serde_json::to_value(&spec.when).map_err(backend_err)?;
            let target = serde_json::to_value(&spec.then).map_err(backend_err)?;
            // `priority` is an INT (i32) column; values above i32::MAX are
            // clamped rather than rejected.
            let priority = i32::try_from(spec.priority).unwrap_or(i32::MAX);
            let row = sqlx::query_as::<_, RouteRow>(
                "INSERT INTO routes (id, org_id, name, priority, conditions, target, enabled) \
                 VALUES ($1, $2, $3, $4, $5, $6, $7) \
                 RETURNING id, name, priority, enabled, conditions, target",
            )
            .bind(Uuid::now_v7())
            .bind(org_id)
            .bind(&spec.name)
            .bind(priority)
            .bind(&conditions)
            .bind(&target)
            .bind(spec.enabled)
            .fetch_one(&self.pool)
            .await
            .map_err(backend_err)?;
            row.into_route()
                .ok_or_else(|| RoutingStoreError::Backend("created route failed to decode".into()))
        }

        /// Management: route `id` if owned by `org_id`.
        ///
        /// A row whose JSON fails to decode is logged and reported as `Ok(None)`.
        async fn get_route(
            &self,
            org_id: Uuid,
            id: Uuid,
        ) -> Result<Option<Route>, RoutingStoreError> {
            let row = sqlx::query_as::<_, RouteRow>(
                "SELECT id, name, priority, enabled, conditions, target \
                 FROM routes WHERE org_id = $1 AND id = $2",
            )
            .bind(org_id)
            .bind(id)
            .fetch_optional(&self.pool)
            .await
            .map_err(backend_err)?;
            Ok(row.and_then(RouteRow::into_route))
        }

        /// Management: delete route `id` if owned by `org_id`. Returns whether a
        /// row was removed.
        async fn delete_route(&self, org_id: Uuid, id: Uuid) -> Result<bool, RoutingStoreError> {
            let res = sqlx::query("DELETE FROM routes WHERE org_id = $1 AND id = $2")
                .bind(org_id)
                .bind(id)
                .execute(&self.pool)
                .await
                .map_err(backend_err)?;
            Ok(res.rows_affected() > 0)
        }
    }

    /// One `routes` row as returned by every query above
    /// (`id, name, priority, enabled, conditions, target`).
    #[derive(sqlx::FromRow)]
    struct RouteRow {
        id: Uuid,
        name: String,
        priority: i32,
        enabled: bool,
        conditions: sqlx::types::Json<serde_json::Value>,
        target: sqlx::types::Json<serde_json::Value>,
    }

    impl RouteRow {
        /// Decode the JSON columns into a [`Route`].
        ///
        /// Returns `None` after logging a warning tagged with the route id when
        /// either JSON column fails to decode, so callers can skip the row.
        /// A negative `priority` maps to 0.
        fn into_route(self) -> Option<Route> {
            let when = match serde_json::from_value::<RouteConditions>(self.conditions.0) {
                Ok(c) => c,
                Err(e) => {
                    tracing::warn!(route_id = %self.id, error = %e, "skipping route — conditions JSON failed to decode");
                    return None;
                }
            };
            let then = match serde_json::from_value::<RouteAction>(self.target.0) {
                Ok(t) => t,
                Err(e) => {
                    tracing::warn!(route_id = %self.id, error = %e, "skipping route — target JSON failed to decode");
                    return None;
                }
            };
            Some(Route {
                id: self.id,
                name: self.name,
                priority: u32::try_from(self.priority).unwrap_or(0),
                enabled: self.enabled,
                when,
                then,
            })
        }
    }
}
336
337#[cfg(feature = "postgres")]
338pub use pg::PostgresRoutingStore;
339
#[cfg(test)]
mod tests {
    use super::*;
    // `Route` is presumably already reachable through the glob import above;
    // the explicit import is kept for readability, hence the allow.
    #[allow(unused_imports)]
    use crate::Route;
    use crate::{RouteAction, RouteConditions};

    /// Action pinning `target` with no fallbacks, caching left on, no cost cap.
    fn action(target: &str) -> RouteAction {
        RouteAction {
            target_model: target.into(),
            fallbacks: Vec::new(),
            disable_cache: false,
            max_cost_usd: None,
        }
    }

    /// Enabled route with default conditions and a fresh id.
    fn route(name: &str, priority: u32, target: &str) -> Route {
        Route {
            id: Uuid::now_v7(),
            name: name.into(),
            priority,
            enabled: true,
            when: RouteConditions::default(),
            then: action(target),
        }
    }

    /// Enabled creation spec with default conditions.
    fn spec(name: &str, priority: u32, target: &str) -> NewRoute {
        NewRoute {
            name: name.into(),
            priority,
            enabled: true,
            when: RouteConditions::default(),
            then: action(target),
        }
    }

    #[tokio::test]
    async fn in_memory_returns_empty_for_unknown_org() {
        let store = InMemoryRoutingStore::new();
        let listed = store.list_for_org(Uuid::now_v7()).await.unwrap();
        assert!(listed.is_empty());
    }

    #[tokio::test]
    async fn in_memory_set_and_fetch_round_trips() {
        let store = InMemoryRoutingStore::new();
        let org = Uuid::now_v7();
        store.set_routes(org, vec![route("a", 10, "m1"), route("b", 5, "m2")]);
        let listed = store.list_for_org(org).await.unwrap();
        assert_eq!(listed.len(), 2);
    }

    #[tokio::test]
    async fn in_memory_create_list_get_delete() {
        let store = InMemoryRoutingStore::new();
        let org = Uuid::now_v7();

        let created = store.create_route(org, spec("pin", 100, "m1")).await.unwrap();
        assert_eq!(created.name, "pin");

        let everything = store.list_all_for_org(org).await.unwrap();
        assert_eq!(everything.len(), 1);

        let fetched = store.get_route(org, created.id).await.unwrap();
        assert_eq!(fetched.unwrap().id, created.id);

        assert!(store.delete_route(org, created.id).await.unwrap());
        assert!(store.get_route(org, created.id).await.unwrap().is_none());
        // A second delete finds nothing to remove.
        assert!(!store.delete_route(org, created.id).await.unwrap());
    }

    #[tokio::test]
    async fn in_memory_management_is_org_scoped() {
        let store = InMemoryRoutingStore::new();
        let owner = Uuid::now_v7();
        let stranger = Uuid::now_v7();
        let created = store.create_route(owner, spec("a", 1, "m")).await.unwrap();

        // Another org can neither see nor delete the route.
        assert!(store.get_route(stranger, created.id).await.unwrap().is_none());
        assert!(!store.delete_route(stranger, created.id).await.unwrap());
        assert_eq!(store.list_all_for_org(stranger).await.unwrap().len(), 0);
    }
}