Skip to main content

mini_chat/
module.rs

1use std::sync::{Arc, Mutex, OnceLock};
2
3use async_trait::async_trait;
4use authn_resolver_sdk::{AuthNResolverClient, ClientCredentialsRequest};
5use authz_resolver_sdk::AuthZResolverClient;
6use mini_chat_sdk::{MiniChatAuditPluginSpecV1, MiniChatModelPolicyPluginSpecV1};
7use modkit::api::OpenApiRegistry;
8use modkit::contracts::RunnableCapability;
9use modkit::{DatabaseCapability, Module, ModuleCtx, RestApiCapability};
10use modkit_db::outbox::{Outbox, OutboxHandle, Partitions};
11use oagw_sdk::ServiceGatewayClientV1;
12use sea_orm_migration::MigrationTrait;
13use tokio_util::sync::CancellationToken;
14use tracing::info;
15use types_registry_sdk::{RegisterResult, TypesRegistryClient};
16
17use crate::api::rest::routes;
18use crate::config::ProviderEntry;
19use crate::domain::ports::MiniChatMetricsPort;
20use crate::domain::service::{AppServices as GenericAppServices, Repositories};
21use crate::infra::metrics::MiniChatMetricsMeter;
22use crate::infra::outbox::{AttachmentCleanupHandler, InfraOutboxEnqueuer, UsageEventHandler};
23
/// Concrete application-services type for this module: the generic
/// `crate::domain::service::AppServices` instantiated with the repository
/// implementations from `crate::infra::db::repo`.
pub(crate) type AppServices = GenericAppServices<
    TurnRepository,
    MessageRepository,
    QuotaUsageRepository,
    ReactionRepository,
    ChatRepository,
    ThreadSummaryRepository,
    AttachmentRepository,
    VectorStoreRepository,
    MessageAttachmentRepository,
>;
35use crate::infra::db::repo::attachment_repo::AttachmentRepository;
36use crate::infra::db::repo::chat_repo::ChatRepository;
37use crate::infra::db::repo::message_attachment_repo::MessageAttachmentRepository;
38use crate::infra::db::repo::message_repo::MessageRepository;
39use crate::infra::db::repo::quota_usage_repo::QuotaUsageRepository;
40use crate::infra::db::repo::reaction_repo::ReactionRepository;
41use crate::infra::db::repo::thread_summary_repo::ThreadSummaryRepository;
42use crate::infra::db::repo::turn_repo::TurnRepository;
43use crate::infra::db::repo::vector_store_repo::VectorStoreRepository;
44use crate::infra::llm::provider_resolver::ProviderResolver;
45use crate::infra::model_policy::ModelPolicyGateway;
46
/// Default URL prefix for all mini-chat REST routes.
///
/// NOTE(review): `init()` in this file takes the effective prefix from
/// `cfg.url_prefix`; this constant is presumably the fallback used by the
/// config layer — confirm against `crate::config`.
pub const DEFAULT_URL_PREFIX: &str = "/mini-chat";
49
/// The mini-chat module: multi-tenant AI chat with SSE streaming.
///
/// Every field uses interior mutability because the lifecycle hooks
/// (`init`, `register_rest`, `start`, `stop`) only receive `&self`.
#[modkit::module(
    name = "mini-chat",
    deps = ["types-registry", "authn-resolver", "authz-resolver", "oagw"],
    capabilities = [db, rest, stateful],
)]
pub struct MiniChatModule {
    /// Application services; set once at the end of `init()` and read by
    /// `register_rest()`.
    service: OnceLock<Arc<AppServices>>,
    /// URL prefix taken from the module config in `init()` and handed to
    /// route registration in `register_rest()`.
    url_prefix: OnceLock<String>,
    /// Handle of the outbox pipeline started in `init()`; taken out and
    /// stopped in `stop()`.
    outbox_handle: Mutex<Option<OutboxHandle>>,
    /// OAGW gateway + provider config for deferred upstream registration in `start()`.
    oagw_deferred: OnceLock<OagwDeferred>,
}
63
/// State needed to register OAGW upstreams in `start()` (after GTS is ready).
///
/// Captured in `init()` and read back, unchanged, by `start()`.
struct OagwDeferred {
    /// Gateway client that upstreams are registered against.
    gateway: Arc<dyn ServiceGatewayClientV1>,
    /// AuthN resolver used to exchange the client credentials below.
    authn: Arc<dyn AuthNResolverClient>,
    /// Client id/secret exchanged for a security context before provisioning.
    client_credentials: crate::config::ClientCredentialsConfig,
    /// Copy of the provider config (keyed by provider id) taken after
    /// `init()` pre-filled missing `upstream_alias` values.
    providers: std::collections::HashMap<String, ProviderEntry>,
}
71
72impl Default for MiniChatModule {
73    fn default() -> Self {
74        Self {
75            service: OnceLock::new(),
76            url_prefix: OnceLock::new(),
77            outbox_handle: Mutex::new(None),
78            oagw_deferred: OnceLock::new(),
79        }
80    }
81}
82
83#[allow(clippy::too_many_lines)]
84#[async_trait]
85impl Module for MiniChatModule {
86    async fn init(&self, ctx: &ModuleCtx) -> anyhow::Result<()> {
87        info!("Initializing {} module", Self::MODULE_NAME);
88
89        let mut cfg: crate::config::MiniChatConfig = ctx.config_expanded()?;
90        cfg.streaming
91            .validate()
92            .map_err(|e| anyhow::anyhow!("streaming config: {e}"))?;
93        cfg.estimation_budgets
94            .validate()
95            .map_err(|e| anyhow::anyhow!("estimation_budgets config: {e}"))?;
96        cfg.quota
97            .validate()
98            .map_err(|e| anyhow::anyhow!("quota config: {e}"))?;
99        cfg.outbox
100            .validate()
101            .map_err(|e| anyhow::anyhow!("outbox config: {e}"))?;
102        cfg.context
103            .validate()
104            .map_err(|e| anyhow::anyhow!("context config: {e}"))?;
105        cfg.client_credentials
106            .validate()
107            .map_err(|e| anyhow::anyhow!("client_credentials config: {e}"))?;
108        for (id, entry) in &cfg.providers {
109            entry
110                .validate(id)
111                .map_err(|e| anyhow::anyhow!("providers config: {e}"))?;
112        }
113
114        let vendor = cfg.vendor.trim().to_owned();
115        if vendor.is_empty() {
116            return Err(anyhow::anyhow!(
117                "{}: vendor must be a non-empty string",
118                Self::MODULE_NAME
119            ));
120        }
121
122        let registry = ctx.client_hub().get::<dyn TypesRegistryClient>()?;
123        register_plugin_schemas(
124            &*registry,
125            &[
126                (
127                    MiniChatModelPolicyPluginSpecV1::gts_schema_with_refs_as_string(),
128                    MiniChatModelPolicyPluginSpecV1::gts_schema_id(),
129                    "model-policy",
130                ),
131                (
132                    MiniChatAuditPluginSpecV1::gts_schema_with_refs_as_string(),
133                    MiniChatAuditPluginSpecV1::gts_schema_id(),
134                    "audit",
135                ),
136            ],
137        )
138        .await?;
139
140        self.url_prefix
141            .set(cfg.url_prefix)
142            .map_err(|_| anyhow::anyhow!("{} url_prefix already set", Self::MODULE_NAME))?;
143
144        let db_provider = ctx.db_required()?;
145        let db = Arc::new(db_provider);
146
147        // Create the model-policy gateway early for both outbox handler and services.
148        let model_policy_gw = Arc::new(ModelPolicyGateway::new(
149            ctx.client_hub(),
150            vendor,
151            ctx.cancellation_token().clone(),
152        ));
153
154        // Start the outbox pipeline eagerly in init() (migrations ran in phase 2, DB is ready).
155        // The framework guarantees stop() is called on init failure, so the pipeline
156        // will be shut down cleanly if any later init step errors.
157        // The handler resolves the plugin lazily on first message delivery,
158        // avoiding a hard dependency on plugin availability during init().
159        let outbox_db = db.db();
160        let num_partitions = cfg.outbox.num_partitions;
161        let queue_name = cfg.outbox.queue_name.clone();
162        let cleanup_queue_name = cfg.outbox.cleanup_queue_name.clone();
163
164        let partitions = Partitions::of(
165            u16::try_from(num_partitions)
166                .map_err(|_| anyhow::anyhow!("num_partitions {num_partitions} exceeds u16"))?,
167        );
168
169        let outbox_handle = Outbox::builder(outbox_db)
170            .queue(&queue_name, partitions)
171            .decoupled(UsageEventHandler {
172                plugin_provider: model_policy_gw.clone(),
173            })
174            .queue(&cleanup_queue_name, partitions)
175            .decoupled(AttachmentCleanupHandler)
176            .start()
177            .await
178            .map_err(|e| anyhow::anyhow!("outbox start: {e}"))?;
179
180        let outbox = Arc::clone(outbox_handle.outbox());
181        let outbox_enqueuer = Arc::new(InfraOutboxEnqueuer::new(
182            outbox,
183            queue_name,
184            cleanup_queue_name,
185            num_partitions,
186        ));
187
188        {
189            let mut guard = self
190                .outbox_handle
191                .lock()
192                .map_err(|e| anyhow::anyhow!("outbox_handle lock: {e}"))?;
193            if guard.is_some() {
194                return Err(anyhow::anyhow!(
195                    "{} outbox_handle already set",
196                    Self::MODULE_NAME
197                ));
198            }
199            *guard = Some(outbox_handle);
200        }
201
202        info!("Outbox pipeline started");
203
204        let authz = ctx
205            .client_hub()
206            .get::<dyn AuthZResolverClient>()
207            .map_err(|e| anyhow::anyhow!("failed to get AuthZ resolver: {e}"))?;
208
209        let authn_client = ctx
210            .client_hub()
211            .get::<dyn AuthNResolverClient>()
212            .map_err(|e| anyhow::anyhow!("failed to get AuthN resolver: {e}"))?;
213
214        let gateway = ctx
215            .client_hub()
216            .get::<dyn ServiceGatewayClientV1>()
217            .map_err(|e| anyhow::anyhow!("failed to get OAGW gateway: {e}"))?;
218
219        // Pre-fill upstream_alias with host as fallback so ProviderResolver
220        // works immediately. The actual OAGW registration is deferred to
221        // start() because GTS instances are not visible via list() until
222        // post_init (types-registry switches to ready mode there).
223        for entry in cfg.providers.values_mut() {
224            if entry.upstream_alias.is_none() {
225                entry.upstream_alias = Some(entry.host.clone());
226            }
227            for ovr in entry.tenant_overrides.values_mut() {
228                if ovr.upstream_alias.is_none()
229                    && let Some(ref h) = ovr.host
230                {
231                    ovr.upstream_alias = Some(h.clone());
232                }
233            }
234        }
235
236        // Save a copy for deferred OAGW registration in start().
237        // Ignore the result: if already set, we keep the first value.
238        drop(self.oagw_deferred.set(OagwDeferred {
239            gateway: Arc::clone(&gateway),
240            authn: Arc::clone(&authn_client),
241            client_credentials: cfg.client_credentials.clone(),
242            providers: cfg.providers.clone(),
243        }));
244
245        let provider_resolver = Arc::new(ProviderResolver::new(&gateway, cfg.providers));
246
247        let repos = Repositories {
248            chat: Arc::new(ChatRepository::new(modkit_db::odata::LimitCfg {
249                default: 20,
250                max: 100,
251            })),
252            attachment: Arc::new(AttachmentRepository),
253            message: Arc::new(MessageRepository::new(modkit_db::odata::LimitCfg {
254                default: 20,
255                max: 100,
256            })),
257            quota: Arc::new(QuotaUsageRepository),
258            turn: Arc::new(TurnRepository),
259            reaction: Arc::new(ReactionRepository),
260            thread_summary: Arc::new(ThreadSummaryRepository),
261            vector_store: Arc::new(VectorStoreRepository),
262            message_attachment: Arc::new(MessageAttachmentRepository),
263        };
264
265        let rag_client = Arc::new(
266            crate::infra::llm::providers::rag_http_client::RagHttpClient::new(Arc::clone(&gateway)),
267        );
268
269        // Build provider-specific file/vector store impls per provider entry.
270        // Dispatch by storage_kind: Azure → Azure impls, OpenAi → OpenAI impls.
271        let mut file_impls: std::collections::HashMap<
272            String,
273            Arc<dyn crate::domain::ports::FileStorageProvider>,
274        > = std::collections::HashMap::new();
275        let mut vs_impls: std::collections::HashMap<
276            String,
277            Arc<dyn crate::domain::ports::VectorStoreProvider>,
278        > = std::collections::HashMap::new();
279        for (provider_id, entry) in provider_resolver.entries() {
280            let (file, vs): (
281                Arc<dyn crate::domain::ports::FileStorageProvider>,
282                Arc<dyn crate::domain::ports::VectorStoreProvider>,
283            ) = match entry.storage_kind {
284                crate::config::StorageKind::Azure => {
285                    let api_version = entry.api_version.clone().unwrap_or_else(|| {
286                        panic!(
287                            "provider '{provider_id}': storage_kind is 'azure' \
288                             but api_version is not set"
289                        )
290                    });
291                    (
292                        Arc::new(
293                            crate::infra::llm::providers::azure_file_storage::AzureFileStorage::new(
294                                Arc::clone(&rag_client),
295                                Arc::clone(&provider_resolver),
296                                api_version.clone(),
297                            ),
298                        ),
299                        Arc::new(
300                            crate::infra::llm::providers::azure_vector_store::AzureVectorStore::new(
301                                Arc::clone(&rag_client),
302                                Arc::clone(&provider_resolver),
303                                api_version,
304                            ),
305                        ),
306                    )
307                }
308                crate::config::StorageKind::OpenAi => (
309                    Arc::new(
310                        crate::infra::llm::providers::openai_file_storage::OpenAiFileStorage::new(
311                            Arc::clone(&rag_client),
312                            Arc::clone(&provider_resolver),
313                        ),
314                    ),
315                    Arc::new(
316                        crate::infra::llm::providers::openai_vector_store::OpenAiVectorStore::new(
317                            Arc::clone(&rag_client),
318                            Arc::clone(&provider_resolver),
319                        ),
320                    ),
321                ),
322            };
323            file_impls.insert(provider_id.clone(), file);
324            vs_impls.insert(provider_id.clone(), vs);
325        }
326        let file_storage: Arc<dyn crate::domain::ports::FileStorageProvider> = Arc::new(
327            crate::infra::llm::providers::dispatching_storage::DispatchingFileStorage::new(
328                file_impls,
329            ),
330        );
331        let vector_store_prov: Arc<dyn crate::domain::ports::VectorStoreProvider> = Arc::new(
332            crate::infra::llm::providers::dispatching_storage::DispatchingVectorStore::new(
333                vs_impls,
334            ),
335        );
336
337        // Create metrics adapter
338        let metrics_prefix = cfg.metrics.effective_prefix(Self::MODULE_NAME);
339        let scope =
340            opentelemetry::InstrumentationScope::builder(Self::MODULE_NAME.to_owned()).build();
341        let metrics: Arc<dyn MiniChatMetricsPort> = Arc::new(MiniChatMetricsMeter::new(
342            &opentelemetry::global::meter_with_scope(scope),
343            &metrics_prefix,
344        ));
345        let services = Arc::new(AppServices::new(
346            &repos,
347            db,
348            authz,
349            &(model_policy_gw.clone() as Arc<dyn crate::domain::repos::ModelResolver>),
350            &provider_resolver,
351            cfg.streaming,
352            model_policy_gw.clone() as Arc<dyn crate::domain::repos::PolicySnapshotProvider>,
353            model_policy_gw as Arc<dyn crate::domain::repos::UserLimitsProvider>,
354            cfg.estimation_budgets,
355            cfg.quota,
356            &(Arc::clone(&outbox_enqueuer) as Arc<dyn crate::domain::repos::OutboxEnqueuer>),
357            cfg.context,
358            file_storage,
359            vector_store_prov,
360            cfg.rag,
361            metrics,
362        ));
363
364        self.service
365            .set(services)
366            .map_err(|_| anyhow::anyhow!("{} module already initialized", Self::MODULE_NAME))?;
367
368        info!("{} module initialized successfully", Self::MODULE_NAME);
369        Ok(())
370    }
371}
372
373impl DatabaseCapability for MiniChatModule {
374    fn migrations(&self) -> Vec<Box<dyn MigrationTrait>> {
375        use sea_orm_migration::MigratorTrait;
376        info!("Providing mini-chat database migrations");
377        let mut m = crate::infra::db::migrations::Migrator::migrations();
378        m.extend(modkit_db::outbox::outbox_migrations());
379        m
380    }
381}
382
383impl RestApiCapability for MiniChatModule {
384    fn register_rest(
385        &self,
386        _ctx: &ModuleCtx,
387        router: axum::Router,
388        openapi: &dyn OpenApiRegistry,
389    ) -> anyhow::Result<axum::Router> {
390        let services = self
391            .service
392            .get()
393            .ok_or_else(|| anyhow::anyhow!("{} not initialized", Self::MODULE_NAME))?;
394
395        info!("Registering mini-chat REST routes");
396        let prefix = self
397            .url_prefix
398            .get()
399            .ok_or_else(|| anyhow::anyhow!("{} not initialized (url_prefix)", Self::MODULE_NAME))?;
400
401        let router = routes::register_routes(router, openapi, Arc::clone(services), prefix);
402        info!("Mini-chat REST routes registered successfully");
403        Ok(router)
404    }
405}
406
407#[async_trait]
408impl RunnableCapability for MiniChatModule {
409    async fn start(&self, _cancel: CancellationToken) -> anyhow::Result<()> {
410        // Register OAGW upstreams now that GTS is in ready mode (post_init
411        // has completed). During init() this fails because types-registry
412        // list() only queries the persistent store which is empty until
413        // switch_to_ready().
414        if let Some(deferred) = self.oagw_deferred.get() {
415            let ctx =
416                exchange_client_credentials(&deferred.authn, &deferred.client_credentials).await?;
417            let mut providers = deferred.providers.clone();
418            crate::infra::oagw_provisioning::register_oagw_upstreams(
419                &deferred.gateway,
420                &ctx,
421                &mut providers,
422            )
423            .await?;
424        }
425        Ok(())
426    }
427
428    async fn stop(&self, cancel: CancellationToken) -> anyhow::Result<()> {
429        let handle = self
430            .outbox_handle
431            .lock()
432            .map_err(|e| anyhow::anyhow!("outbox_handle lock: {e}"))?
433            .take();
434        if let Some(handle) = handle {
435            info!("Stopping outbox pipeline");
436            tokio::select! {
437                () = handle.stop() => {
438                    info!("Outbox pipeline stopped");
439                }
440                () = cancel.cancelled() => {
441                    info!("Outbox pipeline stop cancelled by framework deadline");
442                }
443            }
444        }
445        Ok(())
446    }
447}
448
449/// Exchange `OAuth2` client credentials via the `AuthN` resolver to obtain
450/// a `SecurityContext` for OAGW upstream provisioning.
451async fn exchange_client_credentials(
452    authn: &Arc<dyn AuthNResolverClient>,
453    creds: &crate::config::ClientCredentialsConfig,
454) -> anyhow::Result<modkit_security::SecurityContext> {
455    info!("Exchanging client credentials for OAGW provisioning context");
456    let request = ClientCredentialsRequest {
457        client_id: creds.client_id.clone(),
458        client_secret: creds.client_secret.clone(),
459        scopes: Vec::new(),
460    };
461    let result = authn
462        .exchange_client_credentials(&request)
463        .await
464        .map_err(|e| anyhow::anyhow!("client credentials exchange failed: {e}"))?;
465    info!("Security context obtained for OAGW provisioning");
466    Ok(result.security_context)
467}
468
469async fn register_plugin_schemas(
470    registry: &dyn TypesRegistryClient,
471    schemas: &[(String, &str, &str)],
472) -> anyhow::Result<()> {
473    let mut payload = Vec::with_capacity(schemas.len());
474    for (schema_str, schema_id, _label) in schemas {
475        let mut schema_json: serde_json::Value = serde_json::from_str(schema_str)?;
476        let obj = schema_json
477            .as_object_mut()
478            .ok_or_else(|| anyhow::anyhow!("schema {schema_id} is not a JSON object"))?;
479        obj.insert(
480            "additionalProperties".to_owned(),
481            serde_json::Value::Bool(false),
482        );
483        payload.push(schema_json);
484    }
485    let results = registry.register(payload).await?;
486    RegisterResult::ensure_all_ok(&results)?;
487    for (_schema_str, schema_id, label) in schemas {
488        info!(schema_id = %schema_id, "Registered {label} plugin schema in types-registry");
489    }
490    Ok(())
491}