Skip to main content

systemprompt_runtime/
registry.rs

1use axum::Router;
2use std::collections::HashMap;
3
4pub use systemprompt_models::modules::{Module, ModuleType, Modules, ServiceCategory};
5
6use crate::AppContext;
7
8#[derive(Debug)]
9pub struct ModuleApiRegistry {
10    registry: HashMap<String, ModuleApiImpl>,
11}
12
13#[derive(Debug)]
14struct ModuleApiImpl {
15    category: ServiceCategory,
16    module_type: ModuleType,
17    router_fn: fn(&AppContext) -> Router,
18    auth_required: bool,
19}
20
21#[derive(Debug, Copy, Clone)]
22pub struct ModuleApiRegistration {
23    pub module_name: &'static str,
24    pub category: ServiceCategory,
25    pub module_type: ModuleType,
26    pub router_fn: fn(&AppContext) -> Router,
27    pub auth_required: bool,
28}
29
30inventory::collect!(ModuleApiRegistration);
31
32#[derive(Debug, Clone, Copy)]
33pub struct WellKnownRoute {
34    pub path: &'static str,
35    pub handler_fn: fn(&AppContext) -> Router,
36    pub methods: &'static [axum::http::Method],
37}
38
39inventory::collect!(WellKnownRoute);
40
41impl Default for ModuleApiRegistry {
42    fn default() -> Self {
43        Self::new()
44    }
45}
46
47impl ModuleApiRegistry {
48    pub fn new() -> Self {
49        let mut registry = HashMap::new();
50
51        for registration in inventory::iter::<ModuleApiRegistration> {
52            let api_impl = ModuleApiImpl {
53                category: registration.category,
54                module_type: registration.module_type,
55                router_fn: registration.router_fn,
56                auth_required: registration.auth_required,
57            };
58            registry.insert(registration.module_name.to_string(), api_impl);
59        }
60
61        Self { registry }
62    }
63
64    pub fn get_routes(&self, module_name: &str, ctx: &AppContext) -> Option<Router> {
65        self.registry
66            .get(module_name)
67            .map(|impl_| (impl_.router_fn)(ctx))
68    }
69
70    pub fn get_category(&self, module_name: &str) -> Option<ServiceCategory> {
71        self.registry.get(module_name).map(|impl_| impl_.category)
72    }
73
74    pub fn get_module_type(&self, module_name: &str) -> Option<ModuleType> {
75        self.registry
76            .get(module_name)
77            .map(|impl_| impl_.module_type)
78    }
79
80    pub fn get_auth_required(&self, module_name: &str) -> Option<bool> {
81        self.registry
82            .get(module_name)
83            .map(|impl_| impl_.auth_required)
84    }
85
86    #[allow(private_interfaces)]
87    pub fn get_registration(&self, module_name: &str) -> Option<&ModuleApiImpl> {
88        self.registry.get(module_name)
89    }
90
91    pub fn modules_by_category(&self, category: ServiceCategory) -> Vec<String> {
92        self.registry
93            .iter()
94            .filter(|(_, impl_)| matches!(impl_.category, c if c as u8 == category as u8))
95            .map(|(name, _)| name.clone())
96            .collect()
97    }
98}
99
100pub trait ModuleRuntime {
101    fn routes(&self, ctx: &AppContext, registry: &ModuleApiRegistry) -> Option<Router>;
102    fn create_api_registry(&self) -> ModuleApiRegistry;
103}
104
105impl ModuleRuntime for Module {
106    fn routes(&self, ctx: &AppContext, registry: &ModuleApiRegistry) -> Option<Router> {
107        if let Some(api) = &self.api {
108            if api.enabled {
109                return registry.get_routes(&self.name, ctx);
110            }
111        }
112        None
113    }
114
115    fn create_api_registry(&self) -> ModuleApiRegistry {
116        ModuleApiRegistry::new()
117    }
118}
119
120impl ModuleRuntime for Modules {
121    fn routes(&self, _ctx: &AppContext, _registry: &ModuleApiRegistry) -> Option<Router> {
122        None
123    }
124
125    fn create_api_registry(&self) -> ModuleApiRegistry {
126        ModuleApiRegistry::new()
127    }
128}