Skip to main content

mini_chat/
module.rs

1use std::sync::{Arc, Mutex, OnceLock};
2
3use async_trait::async_trait;
4use authn_resolver_sdk::{AuthNResolverClient, ClientCredentialsRequest};
5use authz_resolver_sdk::AuthZResolverClient;
6use mini_chat_sdk::{MiniChatAuditPluginSpecV1, MiniChatModelPolicyPluginSpecV1};
7use modkit::api::OpenApiRegistry;
8use modkit::contracts::RunnableCapability;
9use modkit::{DatabaseCapability, Module, ModuleCtx, RestApiCapability};
10use modkit_db::outbox::{Outbox, OutboxHandle, Partitions};
11use oagw_sdk::ServiceGatewayClientV1;
12use sea_orm_migration::MigrationTrait;
13use tokio_util::sync::CancellationToken;
14use tracing::info;
15use types_registry_sdk::{RegisterResult, TypesRegistryClient};
16
17use crate::api::rest::routes;
18use crate::config::ProviderEntry;
19use crate::domain::ports::MiniChatMetricsPort;
20use crate::domain::service::{AppServices as GenericAppServices, Repositories};
21use crate::infra::metrics::MiniChatMetricsMeter;
22use crate::infra::outbox::{
23    AttachmentCleanupHandler, AuditEventHandler, InfraOutboxEnqueuer, UsageEventHandler,
24};
25
/// Crate-wide `AppServices` alias: the generic domain service container
/// (`crate::domain::service::AppServices`) instantiated with the repository
/// types imported from `crate::infra::db::repo`.
///
/// The order of the type arguments must match the generic parameter order of
/// the domain-level definition.
pub(crate) type AppServices = GenericAppServices<
    TurnRepository,
    MessageRepository,
    QuotaUsageRepository,
    ReactionRepository,
    ChatRepository,
    ThreadSummaryRepository,
    AttachmentRepository,
    VectorStoreRepository,
    MessageAttachmentRepository,
>;
37use crate::infra::audit_gateway::AuditGateway;
38use crate::infra::db::repo::attachment_repo::AttachmentRepository;
39use crate::infra::db::repo::chat_repo::ChatRepository;
40use crate::infra::db::repo::message_attachment_repo::MessageAttachmentRepository;
41use crate::infra::db::repo::message_repo::MessageRepository;
42use crate::infra::db::repo::quota_usage_repo::QuotaUsageRepository;
43use crate::infra::db::repo::reaction_repo::ReactionRepository;
44use crate::infra::db::repo::thread_summary_repo::ThreadSummaryRepository;
45use crate::infra::db::repo::turn_repo::TurnRepository;
46use crate::infra::db::repo::vector_store_repo::VectorStoreRepository;
47use crate::infra::llm::provider_resolver::ProviderResolver;
48use crate::infra::model_policy::ModelPolicyGateway;
49
/// Default URL prefix for all mini-chat REST routes.
///
/// NOTE(review): `init()` takes the effective prefix from `cfg.url_prefix`;
/// presumably the config layer falls back to this constant — confirm there.
pub const DEFAULT_URL_PREFIX: &str = "/mini-chat";
52
/// The mini-chat module: multi-tenant AI chat with SSE streaming.
///
/// Every field starts out empty (see the `Default` impl) and is populated by
/// `init()`.
#[modkit::module(
    name = "mini-chat",
    deps = ["types-registry", "authn-resolver", "authz-resolver", "oagw"],
    capabilities = [db, rest, stateful],
)]
pub struct MiniChatModule {
    /// Application services; set as the last step of `init()` and read by
    /// `register_rest()`.
    service: OnceLock<Arc<AppServices>>,
    /// URL prefix taken from the module config in `init()`; read by
    /// `register_rest()`.
    url_prefix: OnceLock<String>,
    /// Handle of the outbox pipeline started in `init()`; taken out and
    /// stopped in `stop()`.
    outbox_handle: Mutex<Option<OutboxHandle>>,
    /// OAGW gateway + provider config for deferred upstream registration in `start()`.
    oagw_deferred: OnceLock<OagwDeferred>,
}
66
/// State needed to register OAGW upstreams in `start()` (after GTS is ready).
///
/// Captured once in `init()` and only read afterwards.
struct OagwDeferred {
    /// Gateway client handed to `register_oagw_upstreams`.
    gateway: Arc<dyn ServiceGatewayClientV1>,
    /// AuthN resolver used to exchange the client credentials for a security context.
    authn: Arc<dyn AuthNResolverClient>,
    /// Client id/secret configuration used for that exchange.
    client_credentials: crate::config::ClientCredentialsConfig,
    /// Copy of the provider configuration taken after `init()` has filled in
    /// the `upstream_alias` fallbacks.
    providers: std::collections::HashMap<String, ProviderEntry>,
}
74
75impl Default for MiniChatModule {
76    fn default() -> Self {
77        Self {
78            service: OnceLock::new(),
79            url_prefix: OnceLock::new(),
80            outbox_handle: Mutex::new(None),
81            oagw_deferred: OnceLock::new(),
82        }
83    }
84}
85
86#[allow(clippy::too_many_lines)]
87#[async_trait]
88impl Module for MiniChatModule {
89    async fn init(&self, ctx: &ModuleCtx) -> anyhow::Result<()> {
90        info!("Initializing {} module", Self::MODULE_NAME);
91
92        let mut cfg: crate::config::MiniChatConfig = ctx.config_expanded()?;
93        cfg.streaming
94            .validate()
95            .map_err(|e| anyhow::anyhow!("streaming config: {e}"))?;
96        cfg.estimation_budgets
97            .validate()
98            .map_err(|e| anyhow::anyhow!("estimation_budgets config: {e}"))?;
99        cfg.quota
100            .validate()
101            .map_err(|e| anyhow::anyhow!("quota config: {e}"))?;
102        cfg.outbox
103            .validate()
104            .map_err(|e| anyhow::anyhow!("outbox config: {e}"))?;
105        cfg.context
106            .validate()
107            .map_err(|e| anyhow::anyhow!("context config: {e}"))?;
108        cfg.client_credentials
109            .validate()
110            .map_err(|e| anyhow::anyhow!("client_credentials config: {e}"))?;
111        for (id, entry) in &cfg.providers {
112            entry
113                .validate(id)
114                .map_err(|e| anyhow::anyhow!("providers config: {e}"))?;
115        }
116
117        let vendor = cfg.vendor.trim().to_owned();
118        if vendor.is_empty() {
119            return Err(anyhow::anyhow!(
120                "{}: vendor must be a non-empty string",
121                Self::MODULE_NAME
122            ));
123        }
124
125        let registry = ctx.client_hub().get::<dyn TypesRegistryClient>()?;
126        register_plugin_schemas(
127            &*registry,
128            &[
129                (
130                    MiniChatModelPolicyPluginSpecV1::gts_schema_with_refs_as_string(),
131                    MiniChatModelPolicyPluginSpecV1::gts_schema_id(),
132                    "model-policy",
133                ),
134                (
135                    MiniChatAuditPluginSpecV1::gts_schema_with_refs_as_string(),
136                    MiniChatAuditPluginSpecV1::gts_schema_id(),
137                    "audit",
138                ),
139            ],
140        )
141        .await?;
142
143        self.url_prefix
144            .set(cfg.url_prefix)
145            .map_err(|_| anyhow::anyhow!("{} url_prefix already set", Self::MODULE_NAME))?;
146
147        let db_provider = ctx.db_required()?;
148        let db = Arc::new(db_provider);
149
150        // Create the model-policy gateway early for both outbox handler and services.
151        let model_policy_gw = Arc::new(ModelPolicyGateway::new(
152            ctx.client_hub(),
153            vendor.clone(),
154            ctx.cancellation_token().clone(),
155        ));
156
157        // Audit gateway: lazily resolves audit plugin(s) on first emission.
158        let audit_gateway = Arc::new(AuditGateway::new(ctx.client_hub(), vendor));
159
160        // Start the outbox pipeline eagerly in init() (migrations ran in phase 2, DB is ready).
161        // The framework guarantees stop() is called on init failure, so the pipeline
162        // will be shut down cleanly if any later init step errors.
163        // The handler resolves the plugin lazily on first message delivery,
164        // avoiding a hard dependency on plugin availability during init().
165        let outbox_db = db.db();
166        let num_partitions = cfg.outbox.num_partitions;
167        let queue_name = cfg.outbox.queue_name.clone();
168        let cleanup_queue_name = cfg.outbox.cleanup_queue_name.clone();
169        let audit_queue_name = cfg.outbox.audit_queue_name.clone();
170
171        let partitions = Partitions::of(
172            u16::try_from(num_partitions)
173                .map_err(|_| anyhow::anyhow!("num_partitions {num_partitions} exceeds u16"))?,
174        );
175
176        // Metrics are created here (before the outbox) so they can be passed to AuditEventHandler.
177        let metrics_prefix = cfg.metrics.effective_prefix(Self::MODULE_NAME);
178        let scope =
179            opentelemetry::InstrumentationScope::builder(Self::MODULE_NAME.to_owned()).build();
180        let metrics: Arc<dyn MiniChatMetricsPort> = Arc::new(MiniChatMetricsMeter::new(
181            &opentelemetry::global::meter_with_scope(scope),
182            &metrics_prefix,
183        ));
184
185        let outbox_handle = Outbox::builder(outbox_db)
186            .queue(&queue_name, partitions)
187            .decoupled(UsageEventHandler {
188                plugin_provider: model_policy_gw.clone(),
189            })
190            .queue(&cleanup_queue_name, partitions)
191            .decoupled(AttachmentCleanupHandler)
192            .queue(&audit_queue_name, partitions)
193            .decoupled(AuditEventHandler {
194                audit_gateway: Arc::clone(&audit_gateway),
195                metrics: Arc::clone(&metrics),
196            })
197            .start()
198            .await
199            .map_err(|e| anyhow::anyhow!("outbox start: {e}"))?;
200
201        let outbox = Arc::clone(outbox_handle.outbox());
202        let outbox_enqueuer = Arc::new(InfraOutboxEnqueuer::new(
203            outbox,
204            queue_name,
205            cleanup_queue_name,
206            audit_queue_name,
207            num_partitions,
208        ));
209
210        {
211            let mut guard = self
212                .outbox_handle
213                .lock()
214                .map_err(|e| anyhow::anyhow!("outbox_handle lock: {e}"))?;
215            if guard.is_some() {
216                return Err(anyhow::anyhow!(
217                    "{} outbox_handle already set",
218                    Self::MODULE_NAME
219                ));
220            }
221            *guard = Some(outbox_handle);
222        }
223
224        info!("Outbox pipeline started");
225
226        let authz = ctx
227            .client_hub()
228            .get::<dyn AuthZResolverClient>()
229            .map_err(|e| anyhow::anyhow!("failed to get AuthZ resolver: {e}"))?;
230
231        let authn_client = ctx
232            .client_hub()
233            .get::<dyn AuthNResolverClient>()
234            .map_err(|e| anyhow::anyhow!("failed to get AuthN resolver: {e}"))?;
235
236        let gateway = ctx
237            .client_hub()
238            .get::<dyn ServiceGatewayClientV1>()
239            .map_err(|e| anyhow::anyhow!("failed to get OAGW gateway: {e}"))?;
240
241        // Pre-fill upstream_alias with host as fallback so ProviderResolver
242        // works immediately. The actual OAGW registration is deferred to
243        // start() because GTS instances are not visible via list() until
244        // post_init (types-registry switches to ready mode there).
245        for entry in cfg.providers.values_mut() {
246            if entry.upstream_alias.is_none() {
247                entry.upstream_alias = Some(entry.host.clone());
248            }
249            for ovr in entry.tenant_overrides.values_mut() {
250                if ovr.upstream_alias.is_none()
251                    && let Some(ref h) = ovr.host
252                {
253                    ovr.upstream_alias = Some(h.clone());
254                }
255            }
256        }
257
258        // Save a copy for deferred OAGW registration in start().
259        // Ignore the result: if already set, we keep the first value.
260        drop(self.oagw_deferred.set(OagwDeferred {
261            gateway: Arc::clone(&gateway),
262            authn: Arc::clone(&authn_client),
263            client_credentials: cfg.client_credentials.clone(),
264            providers: cfg.providers.clone(),
265        }));
266
267        let provider_resolver = Arc::new(ProviderResolver::new(&gateway, cfg.providers));
268
269        let repos = Repositories {
270            chat: Arc::new(ChatRepository::new(modkit_db::odata::LimitCfg {
271                default: 20,
272                max: 100,
273            })),
274            attachment: Arc::new(AttachmentRepository),
275            message: Arc::new(MessageRepository::new(modkit_db::odata::LimitCfg {
276                default: 20,
277                max: 100,
278            })),
279            quota: Arc::new(QuotaUsageRepository),
280            turn: Arc::new(TurnRepository),
281            reaction: Arc::new(ReactionRepository),
282            thread_summary: Arc::new(ThreadSummaryRepository),
283            vector_store: Arc::new(VectorStoreRepository),
284            message_attachment: Arc::new(MessageAttachmentRepository),
285        };
286
287        let rag_client = Arc::new(
288            crate::infra::llm::providers::rag_http_client::RagHttpClient::new(Arc::clone(&gateway)),
289        );
290
291        // Build provider-specific file/vector store impls per provider entry.
292        // Dispatch by storage_kind: Azure → Azure impls, OpenAi → OpenAI impls.
293        let mut file_impls: std::collections::HashMap<
294            String,
295            Arc<dyn crate::domain::ports::FileStorageProvider>,
296        > = std::collections::HashMap::new();
297        let mut vs_impls: std::collections::HashMap<
298            String,
299            Arc<dyn crate::domain::ports::VectorStoreProvider>,
300        > = std::collections::HashMap::new();
301        for (provider_id, entry) in provider_resolver.entries() {
302            let (file, vs): (
303                Arc<dyn crate::domain::ports::FileStorageProvider>,
304                Arc<dyn crate::domain::ports::VectorStoreProvider>,
305            ) = match entry.storage_kind {
306                crate::config::StorageKind::Azure => {
307                    let api_version = entry.api_version.clone().unwrap_or_else(|| {
308                        panic!(
309                            "provider '{provider_id}': storage_kind is 'azure' \
310                             but api_version is not set"
311                        )
312                    });
313                    (
314                        Arc::new(
315                            crate::infra::llm::providers::azure_file_storage::AzureFileStorage::new(
316                                Arc::clone(&rag_client),
317                                Arc::clone(&provider_resolver),
318                                api_version.clone(),
319                            ),
320                        ),
321                        Arc::new(
322                            crate::infra::llm::providers::azure_vector_store::AzureVectorStore::new(
323                                Arc::clone(&rag_client),
324                                Arc::clone(&provider_resolver),
325                                api_version,
326                            ),
327                        ),
328                    )
329                }
330                crate::config::StorageKind::OpenAi => (
331                    Arc::new(
332                        crate::infra::llm::providers::openai_file_storage::OpenAiFileStorage::new(
333                            Arc::clone(&rag_client),
334                            Arc::clone(&provider_resolver),
335                        ),
336                    ),
337                    Arc::new(
338                        crate::infra::llm::providers::openai_vector_store::OpenAiVectorStore::new(
339                            Arc::clone(&rag_client),
340                            Arc::clone(&provider_resolver),
341                        ),
342                    ),
343                ),
344            };
345            file_impls.insert(provider_id.clone(), file);
346            vs_impls.insert(provider_id.clone(), vs);
347        }
348        let file_storage: Arc<dyn crate::domain::ports::FileStorageProvider> = Arc::new(
349            crate::infra::llm::providers::dispatching_storage::DispatchingFileStorage::new(
350                file_impls,
351            ),
352        );
353        let vector_store_prov: Arc<dyn crate::domain::ports::VectorStoreProvider> = Arc::new(
354            crate::infra::llm::providers::dispatching_storage::DispatchingVectorStore::new(
355                vs_impls,
356            ),
357        );
358
359        let services = Arc::new(AppServices::new(
360            &repos,
361            db,
362            authz,
363            &(model_policy_gw.clone() as Arc<dyn crate::domain::repos::ModelResolver>),
364            &provider_resolver,
365            cfg.streaming,
366            model_policy_gw.clone() as Arc<dyn crate::domain::repos::PolicySnapshotProvider>,
367            model_policy_gw as Arc<dyn crate::domain::repos::UserLimitsProvider>,
368            cfg.estimation_budgets,
369            cfg.quota,
370            &(outbox_enqueuer as Arc<dyn crate::domain::repos::OutboxEnqueuer>),
371            cfg.context,
372            file_storage,
373            vector_store_prov,
374            cfg.rag,
375            metrics,
376        ));
377
378        self.service
379            .set(services)
380            .map_err(|_| anyhow::anyhow!("{} module already initialized", Self::MODULE_NAME))?;
381
382        info!("{} module initialized successfully", Self::MODULE_NAME);
383        Ok(())
384    }
385}
386
387impl DatabaseCapability for MiniChatModule {
388    fn migrations(&self) -> Vec<Box<dyn MigrationTrait>> {
389        use sea_orm_migration::MigratorTrait;
390        info!("Providing mini-chat database migrations");
391        let mut m = crate::infra::db::migrations::Migrator::migrations();
392        m.extend(modkit_db::outbox::outbox_migrations());
393        m
394    }
395}
396
397impl RestApiCapability for MiniChatModule {
398    fn register_rest(
399        &self,
400        _ctx: &ModuleCtx,
401        router: axum::Router,
402        openapi: &dyn OpenApiRegistry,
403    ) -> anyhow::Result<axum::Router> {
404        let services = self
405            .service
406            .get()
407            .ok_or_else(|| anyhow::anyhow!("{} not initialized", Self::MODULE_NAME))?;
408
409        info!("Registering mini-chat REST routes");
410        let prefix = self
411            .url_prefix
412            .get()
413            .ok_or_else(|| anyhow::anyhow!("{} not initialized (url_prefix)", Self::MODULE_NAME))?;
414
415        let router = routes::register_routes(router, openapi, Arc::clone(services), prefix);
416        info!("Mini-chat REST routes registered successfully");
417        Ok(router)
418    }
419}
420
421#[async_trait]
422impl RunnableCapability for MiniChatModule {
423    async fn start(&self, _cancel: CancellationToken) -> anyhow::Result<()> {
424        // Register OAGW upstreams now that GTS is in ready mode (post_init
425        // has completed). During init() this fails because types-registry
426        // list() only queries the persistent store which is empty until
427        // switch_to_ready().
428        if let Some(deferred) = self.oagw_deferred.get() {
429            let ctx =
430                exchange_client_credentials(&deferred.authn, &deferred.client_credentials).await?;
431            let mut providers = deferred.providers.clone();
432            crate::infra::oagw_provisioning::register_oagw_upstreams(
433                &deferred.gateway,
434                &ctx,
435                &mut providers,
436            )
437            .await?;
438        }
439        Ok(())
440    }
441
442    async fn stop(&self, cancel: CancellationToken) -> anyhow::Result<()> {
443        let handle = self
444            .outbox_handle
445            .lock()
446            .map_err(|e| anyhow::anyhow!("outbox_handle lock: {e}"))?
447            .take();
448        if let Some(handle) = handle {
449            info!("Stopping outbox pipeline");
450            tokio::select! {
451                () = handle.stop() => {
452                    info!("Outbox pipeline stopped");
453                }
454                () = cancel.cancelled() => {
455                    info!("Outbox pipeline stop cancelled by framework deadline");
456                }
457            }
458        }
459        Ok(())
460    }
461}
462
463/// Exchange `OAuth2` client credentials via the `AuthN` resolver to obtain
464/// a `SecurityContext` for OAGW upstream provisioning.
465async fn exchange_client_credentials(
466    authn: &Arc<dyn AuthNResolverClient>,
467    creds: &crate::config::ClientCredentialsConfig,
468) -> anyhow::Result<modkit_security::SecurityContext> {
469    info!("Exchanging client credentials for OAGW provisioning context");
470    let request = ClientCredentialsRequest {
471        client_id: creds.client_id.clone(),
472        client_secret: creds.client_secret.clone(),
473        scopes: Vec::new(),
474    };
475    let result = authn
476        .exchange_client_credentials(&request)
477        .await
478        .map_err(|e| anyhow::anyhow!("client credentials exchange failed: {e}"))?;
479    info!("Security context obtained for OAGW provisioning");
480    Ok(result.security_context)
481}
482
483async fn register_plugin_schemas(
484    registry: &dyn TypesRegistryClient,
485    schemas: &[(String, &str, &str)],
486) -> anyhow::Result<()> {
487    let mut payload = Vec::with_capacity(schemas.len());
488    for (schema_str, schema_id, _label) in schemas {
489        let mut schema_json: serde_json::Value = serde_json::from_str(schema_str)?;
490        let obj = schema_json
491            .as_object_mut()
492            .ok_or_else(|| anyhow::anyhow!("schema {schema_id} is not a JSON object"))?;
493        obj.insert(
494            "additionalProperties".to_owned(),
495            serde_json::Value::Bool(false),
496        );
497        payload.push(schema_json);
498    }
499    let results = registry.register(payload).await?;
500    RegisterResult::ensure_all_ok(&results)?;
501    for (_schema_str, schema_id, label) in schemas {
502        info!(schema_id = %schema_id, "Registered {label} plugin schema in types-registry");
503    }
504    Ok(())
505}