Skip to main content

mini_chat/
module.rs

1use std::sync::{Arc, Mutex, OnceLock};
2
3use async_trait::async_trait;
4use authz_resolver_sdk::AuthZResolverClient;
5use mini_chat_sdk::{MiniChatAuditPluginSpecV1, MiniChatModelPolicyPluginSpecV1};
6use modkit::api::OpenApiRegistry;
7use modkit::contracts::RunnableCapability;
8use modkit::{DatabaseCapability, Module, ModuleCtx, RestApiCapability};
9use modkit_db::outbox::{Outbox, OutboxHandle, Partitions};
10use oagw_sdk::ServiceGatewayClientV1;
11use sea_orm_migration::MigrationTrait;
12use tokio_util::sync::CancellationToken;
13use tracing::info;
14use types_registry_sdk::{RegisterResult, TypesRegistryClient};
15
16use crate::api::rest::routes;
17use crate::domain::service::{AppServices as GenericAppServices, Repositories};
18use crate::infra::outbox::{InfraOutboxEnqueuer, UsageEventHandler};
19
20pub(crate) type AppServices = GenericAppServices<
21    TurnRepository,
22    MessageRepository,
23    QuotaUsageRepository,
24    ReactionRepository,
25    ChatRepository,
26    ThreadSummaryRepository,
27>;
28use crate::infra::db::repo::attachment_repo::AttachmentRepository;
29use crate::infra::db::repo::chat_repo::ChatRepository;
30use crate::infra::db::repo::message_repo::MessageRepository;
31use crate::infra::db::repo::quota_usage_repo::QuotaUsageRepository;
32use crate::infra::db::repo::reaction_repo::ReactionRepository;
33use crate::infra::db::repo::thread_summary_repo::ThreadSummaryRepository;
34use crate::infra::db::repo::turn_repo::TurnRepository;
35use crate::infra::db::repo::vector_store_repo::VectorStoreRepository;
36use crate::infra::llm::provider_resolver::ProviderResolver;
37use crate::infra::model_policy::ModelPolicyGateway;
38
/// URL prefix that mini-chat REST routes are mounted under unless the
/// configuration provides a different one.
pub const DEFAULT_URL_PREFIX: &str = "/mini-chat";
41
42/// The mini-chat module: multi-tenant AI chat with SSE streaming.
43#[modkit::module(
44    name = "mini-chat",
45    deps = ["types-registry", "authz-resolver", "oagw"],
46    capabilities = [db, rest, stateful],
47)]
48pub struct MiniChatModule {
49    service: OnceLock<Arc<AppServices>>,
50    url_prefix: OnceLock<String>,
51    outbox_handle: Mutex<Option<OutboxHandle>>,
52    /// OAGW gateway + provider config for deferred upstream registration in `start()`.
53    oagw_deferred: OnceLock<OagwDeferred>,
54}
55
56/// State needed to register OAGW upstreams in `start()` (after GTS is ready).
57struct OagwDeferred {
58    gateway: Arc<dyn ServiceGatewayClientV1>,
59    providers: std::collections::HashMap<String, crate::config::ProviderEntry>,
60}
61
62impl Default for MiniChatModule {
63    fn default() -> Self {
64        Self {
65            service: OnceLock::new(),
66            url_prefix: OnceLock::new(),
67            outbox_handle: Mutex::new(None),
68            oagw_deferred: OnceLock::new(),
69        }
70    }
71}
72
73#[async_trait]
74impl Module for MiniChatModule {
75    async fn init(&self, ctx: &ModuleCtx) -> anyhow::Result<()> {
76        info!("Initializing {} module", Self::MODULE_NAME);
77
78        let mut cfg: crate::config::MiniChatConfig = ctx.config_expanded()?;
79        cfg.streaming
80            .validate()
81            .map_err(|e| anyhow::anyhow!("streaming config: {e}"))?;
82        cfg.estimation_budgets
83            .validate()
84            .map_err(|e| anyhow::anyhow!("estimation_budgets config: {e}"))?;
85        cfg.quota
86            .validate()
87            .map_err(|e| anyhow::anyhow!("quota config: {e}"))?;
88        cfg.outbox
89            .validate()
90            .map_err(|e| anyhow::anyhow!("outbox config: {e}"))?;
91        cfg.context
92            .validate()
93            .map_err(|e| anyhow::anyhow!("context config: {e}"))?;
94        for (id, entry) in &cfg.providers {
95            entry
96                .validate(id)
97                .map_err(|e| anyhow::anyhow!("providers config: {e}"))?;
98        }
99
100        let vendor = cfg.vendor.trim().to_owned();
101        if vendor.is_empty() {
102            return Err(anyhow::anyhow!(
103                "{}: vendor must be a non-empty string",
104                Self::MODULE_NAME
105            ));
106        }
107
108        let registry = ctx.client_hub().get::<dyn TypesRegistryClient>()?;
109        register_plugin_schemas(
110            &*registry,
111            &[
112                (
113                    MiniChatModelPolicyPluginSpecV1::gts_schema_with_refs_as_string(),
114                    MiniChatModelPolicyPluginSpecV1::gts_schema_id(),
115                    "model-policy",
116                ),
117                (
118                    MiniChatAuditPluginSpecV1::gts_schema_with_refs_as_string(),
119                    MiniChatAuditPluginSpecV1::gts_schema_id(),
120                    "audit",
121                ),
122            ],
123        )
124        .await?;
125
126        self.url_prefix
127            .set(cfg.url_prefix)
128            .map_err(|_| anyhow::anyhow!("{} url_prefix already set", Self::MODULE_NAME))?;
129
130        let db_provider = ctx.db_required()?;
131        let db = Arc::new(db_provider);
132
133        // Create the model-policy gateway early for both outbox handler and services.
134        let model_policy_gw = Arc::new(ModelPolicyGateway::new(ctx.client_hub(), vendor));
135
136        // Start the outbox pipeline eagerly in init() (migrations ran in phase 2, DB is ready).
137        // The framework guarantees stop() is called on init failure, so the pipeline
138        // will be shut down cleanly if any later init step errors.
139        // The handler resolves the plugin lazily on first message delivery,
140        // avoiding a hard dependency on plugin availability during init().
141        let outbox_db = db.db();
142        let num_partitions = cfg.outbox.num_partitions;
143        let queue_name = cfg.outbox.queue_name.clone();
144
145        let outbox_handle =
146            Outbox::builder(outbox_db)
147                .queue(
148                    &queue_name,
149                    Partitions::of(u16::try_from(num_partitions).map_err(|_| {
150                        anyhow::anyhow!("num_partitions {num_partitions} exceeds u16")
151                    })?),
152                )
153                .decoupled(UsageEventHandler {
154                    plugin_provider: model_policy_gw.clone(),
155                })
156                .start()
157                .await
158                .map_err(|e| anyhow::anyhow!("outbox start: {e}"))?;
159
160        let outbox = Arc::clone(outbox_handle.outbox());
161        let outbox_enqueuer =
162            Arc::new(InfraOutboxEnqueuer::new(outbox, queue_name, num_partitions));
163
164        {
165            let mut guard = self
166                .outbox_handle
167                .lock()
168                .map_err(|e| anyhow::anyhow!("outbox_handle lock: {e}"))?;
169            if guard.is_some() {
170                return Err(anyhow::anyhow!(
171                    "{} outbox_handle already set",
172                    Self::MODULE_NAME
173                ));
174            }
175            *guard = Some(outbox_handle);
176        }
177
178        info!("Outbox pipeline started");
179
180        let authz = ctx
181            .client_hub()
182            .get::<dyn AuthZResolverClient>()
183            .map_err(|e| anyhow::anyhow!("failed to get AuthZ resolver: {e}"))?;
184
185        let gateway = ctx
186            .client_hub()
187            .get::<dyn ServiceGatewayClientV1>()
188            .map_err(|e| anyhow::anyhow!("failed to get OAGW gateway: {e}"))?;
189
190        // Pre-fill upstream_alias with host as fallback so ProviderResolver
191        // works immediately. The actual OAGW registration is deferred to
192        // start() because GTS instances are not visible via list() until
193        // post_init (types-registry switches to ready mode there).
194        for entry in cfg.providers.values_mut() {
195            if entry.upstream_alias.is_none() {
196                entry.upstream_alias = Some(entry.host.clone());
197            }
198            for ovr in entry.tenant_overrides.values_mut() {
199                if ovr.upstream_alias.is_none()
200                    && let Some(ref h) = ovr.host
201                {
202                    ovr.upstream_alias = Some(h.clone());
203                }
204            }
205        }
206
207        // Save a copy for deferred OAGW registration in start().
208        // Ignore the result: if already set, we keep the first value.
209        drop(self.oagw_deferred.set(OagwDeferred {
210            gateway: Arc::clone(&gateway),
211            providers: cfg.providers.clone(),
212        }));
213
214        let provider_resolver = Arc::new(ProviderResolver::new(&gateway, cfg.providers));
215
216        let repos = Repositories {
217            chat: Arc::new(ChatRepository::new(modkit_db::odata::LimitCfg {
218                default: 20,
219                max: 100,
220            })),
221            attachment: Arc::new(AttachmentRepository),
222            message: Arc::new(MessageRepository::new(modkit_db::odata::LimitCfg {
223                default: 20,
224                max: 100,
225            })),
226            quota: Arc::new(QuotaUsageRepository),
227            turn: Arc::new(TurnRepository),
228            reaction: Arc::new(ReactionRepository),
229            thread_summary: Arc::new(ThreadSummaryRepository),
230            vector_store: Arc::new(VectorStoreRepository),
231        };
232
233        let services = Arc::new(AppServices::new(
234            &repos,
235            db,
236            authz,
237            &(model_policy_gw.clone() as Arc<dyn crate::domain::repos::ModelResolver>),
238            provider_resolver,
239            cfg.streaming,
240            model_policy_gw.clone() as Arc<dyn crate::domain::repos::PolicySnapshotProvider>,
241            model_policy_gw as Arc<dyn crate::domain::repos::UserLimitsProvider>,
242            cfg.estimation_budgets,
243            cfg.quota,
244            outbox_enqueuer,
245            cfg.context,
246        ));
247
248        self.service
249            .set(services)
250            .map_err(|_| anyhow::anyhow!("{} module already initialized", Self::MODULE_NAME))?;
251
252        info!("{} module initialized successfully", Self::MODULE_NAME);
253        Ok(())
254    }
255}
256
257impl DatabaseCapability for MiniChatModule {
258    fn migrations(&self) -> Vec<Box<dyn MigrationTrait>> {
259        use sea_orm_migration::MigratorTrait;
260        info!("Providing mini-chat database migrations");
261        let mut m = crate::infra::db::migrations::Migrator::migrations();
262        m.extend(modkit_db::outbox::outbox_migrations());
263        m
264    }
265}
266
267impl RestApiCapability for MiniChatModule {
268    fn register_rest(
269        &self,
270        _ctx: &ModuleCtx,
271        router: axum::Router,
272        openapi: &dyn OpenApiRegistry,
273    ) -> anyhow::Result<axum::Router> {
274        let services = self
275            .service
276            .get()
277            .ok_or_else(|| anyhow::anyhow!("{} not initialized", Self::MODULE_NAME))?;
278
279        info!("Registering mini-chat REST routes");
280        let prefix = self
281            .url_prefix
282            .get()
283            .ok_or_else(|| anyhow::anyhow!("{} not initialized (url_prefix)", Self::MODULE_NAME))?;
284
285        let router = routes::register_routes(router, openapi, Arc::clone(services), prefix);
286        info!("Mini-chat REST routes registered successfully");
287        Ok(router)
288    }
289}
290
291#[async_trait]
292impl RunnableCapability for MiniChatModule {
293    async fn start(&self, _cancel: CancellationToken) -> anyhow::Result<()> {
294        // Register OAGW upstreams now that GTS is in ready mode (post_init
295        // has completed). During init() this fails because types-registry
296        // list() only queries the persistent store which is empty until
297        // switch_to_ready().
298        if let Some(deferred) = self.oagw_deferred.get() {
299            let mut providers = deferred.providers.clone();
300            crate::infra::oagw_provisioning::register_oagw_upstreams(
301                &deferred.gateway,
302                &mut providers,
303            )
304            .await?;
305        }
306        Ok(())
307    }
308
309    async fn stop(&self, cancel: CancellationToken) -> anyhow::Result<()> {
310        let handle = self
311            .outbox_handle
312            .lock()
313            .map_err(|e| anyhow::anyhow!("outbox_handle lock: {e}"))?
314            .take();
315        if let Some(handle) = handle {
316            info!("Stopping outbox pipeline");
317            tokio::select! {
318                () = handle.stop() => {
319                    info!("Outbox pipeline stopped");
320                }
321                () = cancel.cancelled() => {
322                    info!("Outbox pipeline stop cancelled by framework deadline");
323                }
324            }
325        }
326        Ok(())
327    }
328}
329
330async fn register_plugin_schemas(
331    registry: &dyn TypesRegistryClient,
332    schemas: &[(String, &str, &str)],
333) -> anyhow::Result<()> {
334    let mut payload = Vec::with_capacity(schemas.len());
335    for (schema_str, schema_id, _label) in schemas {
336        let mut schema_json: serde_json::Value = serde_json::from_str(schema_str)?;
337        let obj = schema_json
338            .as_object_mut()
339            .ok_or_else(|| anyhow::anyhow!("schema {schema_id} is not a JSON object"))?;
340        obj.insert(
341            "additionalProperties".to_owned(),
342            serde_json::Value::Bool(false),
343        );
344        payload.push(schema_json);
345    }
346    let results = registry.register(payload).await?;
347    RegisterResult::ensure_all_ok(&results)?;
348    for (_schema_str, schema_id, label) in schemas {
349        info!(schema_id = %schema_id, "Registered {label} plugin schema in types-registry");
350    }
351    Ok(())
352}