Skip to main content

mini_chat/
module.rs

1use std::sync::{Arc, Mutex, OnceLock};
2
3use async_trait::async_trait;
4use authn_resolver_sdk::{AuthNResolverClient, ClientCredentialsRequest};
5use authz_resolver_sdk::AuthZResolverClient;
6use mini_chat_sdk::{MiniChatAuditPluginSpecV1, MiniChatModelPolicyPluginSpecV1};
7use modkit::api::OpenApiRegistry;
8use modkit::contracts::RunnableCapability;
9use modkit::{DatabaseCapability, Module, ModuleCtx, RestApiCapability};
10use modkit_db::outbox::{Outbox, OutboxHandle, Partitions};
11use oagw_sdk::ServiceGatewayClientV1;
12use sea_orm_migration::MigrationTrait;
13use tokio_util::sync::CancellationToken;
14use tracing::info;
15use types_registry_sdk::{RegisterResult, TypesRegistryClient};
16
17use crate::api::rest::routes;
18use crate::config::ProviderEntry;
19use crate::domain::service::{AppServices as GenericAppServices, Repositories};
20use crate::infra::outbox::{AttachmentCleanupHandler, InfraOutboxEnqueuer, UsageEventHandler};
21
22pub(crate) type AppServices = GenericAppServices<
23    TurnRepository,
24    MessageRepository,
25    QuotaUsageRepository,
26    ReactionRepository,
27    ChatRepository,
28    ThreadSummaryRepository,
29    AttachmentRepository,
30    VectorStoreRepository,
31    MessageAttachmentRepository,
32>;
33use crate::infra::db::repo::attachment_repo::AttachmentRepository;
34use crate::infra::db::repo::chat_repo::ChatRepository;
35use crate::infra::db::repo::message_attachment_repo::MessageAttachmentRepository;
36use crate::infra::db::repo::message_repo::MessageRepository;
37use crate::infra::db::repo::quota_usage_repo::QuotaUsageRepository;
38use crate::infra::db::repo::reaction_repo::ReactionRepository;
39use crate::infra::db::repo::thread_summary_repo::ThreadSummaryRepository;
40use crate::infra::db::repo::turn_repo::TurnRepository;
41use crate::infra::db::repo::vector_store_repo::VectorStoreRepository;
42use crate::infra::llm::provider_resolver::ProviderResolver;
43use crate::infra::model_policy::ModelPolicyGateway;
44
/// Fallback URL prefix for the mini-chat REST routes.
///
/// The prefix actually used when mounting routes is the `url_prefix` value
/// read from the module configuration during `init()`.
pub const DEFAULT_URL_PREFIX: &str = "/mini-chat";
47
/// The mini-chat module: multi-tenant AI chat with SSE streaming.
///
/// All fields start empty (see the `Default` impl) and are populated during
/// `init()`; `start()` and `stop()` then consume the stored state.
#[modkit::module(
    name = "mini-chat",
    deps = ["types-registry", "authn-resolver", "authz-resolver", "oagw"],
    capabilities = [db, rest, stateful],
)]
pub struct MiniChatModule {
    /// Application services built in `init()` and handed to the REST routes.
    service: OnceLock<Arc<AppServices>>,
    /// URL prefix for the REST routes, taken from the module config in `init()`.
    url_prefix: OnceLock<String>,
    /// Running outbox pipeline: stored by `init()`, taken and stopped by `stop()`.
    outbox_handle: Mutex<Option<OutboxHandle>>,
    /// OAGW gateway + provider config for deferred upstream registration in `start()`.
    oagw_deferred: OnceLock<OagwDeferred>,
}
61
/// State needed to register OAGW upstreams in `start()` (after GTS is ready).
///
/// Captured once during `init()`, because upstream registration cannot run
/// until the types registry has switched to ready mode.
struct OagwDeferred {
    /// Gateway client through which the upstreams are registered.
    gateway: Arc<dyn ServiceGatewayClientV1>,
    /// `AuthN` resolver used to exchange `client_credentials` for a security context.
    authn: Arc<dyn AuthNResolverClient>,
    /// Client credentials presented to `authn` before provisioning.
    client_credentials: crate::config::ClientCredentialsConfig,
    /// Provider entries keyed by provider id, with `upstream_alias` already
    /// pre-filled from the host when it was unset in config.
    providers: std::collections::HashMap<String, ProviderEntry>,
}
69
70impl Default for MiniChatModule {
71    fn default() -> Self {
72        Self {
73            service: OnceLock::new(),
74            url_prefix: OnceLock::new(),
75            outbox_handle: Mutex::new(None),
76            oagw_deferred: OnceLock::new(),
77        }
78    }
79}
80
81#[allow(clippy::too_many_lines)]
82#[async_trait]
83impl Module for MiniChatModule {
84    async fn init(&self, ctx: &ModuleCtx) -> anyhow::Result<()> {
85        info!("Initializing {} module", Self::MODULE_NAME);
86
87        let mut cfg: crate::config::MiniChatConfig = ctx.config_expanded()?;
88        cfg.streaming
89            .validate()
90            .map_err(|e| anyhow::anyhow!("streaming config: {e}"))?;
91        cfg.estimation_budgets
92            .validate()
93            .map_err(|e| anyhow::anyhow!("estimation_budgets config: {e}"))?;
94        cfg.quota
95            .validate()
96            .map_err(|e| anyhow::anyhow!("quota config: {e}"))?;
97        cfg.outbox
98            .validate()
99            .map_err(|e| anyhow::anyhow!("outbox config: {e}"))?;
100        cfg.context
101            .validate()
102            .map_err(|e| anyhow::anyhow!("context config: {e}"))?;
103        cfg.client_credentials
104            .validate()
105            .map_err(|e| anyhow::anyhow!("client_credentials config: {e}"))?;
106        for (id, entry) in &cfg.providers {
107            entry
108                .validate(id)
109                .map_err(|e| anyhow::anyhow!("providers config: {e}"))?;
110        }
111
112        let vendor = cfg.vendor.trim().to_owned();
113        if vendor.is_empty() {
114            return Err(anyhow::anyhow!(
115                "{}: vendor must be a non-empty string",
116                Self::MODULE_NAME
117            ));
118        }
119
120        let registry = ctx.client_hub().get::<dyn TypesRegistryClient>()?;
121        register_plugin_schemas(
122            &*registry,
123            &[
124                (
125                    MiniChatModelPolicyPluginSpecV1::gts_schema_with_refs_as_string(),
126                    MiniChatModelPolicyPluginSpecV1::gts_schema_id(),
127                    "model-policy",
128                ),
129                (
130                    MiniChatAuditPluginSpecV1::gts_schema_with_refs_as_string(),
131                    MiniChatAuditPluginSpecV1::gts_schema_id(),
132                    "audit",
133                ),
134            ],
135        )
136        .await?;
137
138        self.url_prefix
139            .set(cfg.url_prefix)
140            .map_err(|_| anyhow::anyhow!("{} url_prefix already set", Self::MODULE_NAME))?;
141
142        let db_provider = ctx.db_required()?;
143        let db = Arc::new(db_provider);
144
145        // Create the model-policy gateway early for both outbox handler and services.
146        let model_policy_gw = Arc::new(ModelPolicyGateway::new(
147            ctx.client_hub(),
148            vendor,
149            ctx.cancellation_token().clone(),
150        ));
151
152        // Start the outbox pipeline eagerly in init() (migrations ran in phase 2, DB is ready).
153        // The framework guarantees stop() is called on init failure, so the pipeline
154        // will be shut down cleanly if any later init step errors.
155        // The handler resolves the plugin lazily on first message delivery,
156        // avoiding a hard dependency on plugin availability during init().
157        let outbox_db = db.db();
158        let num_partitions = cfg.outbox.num_partitions;
159        let queue_name = cfg.outbox.queue_name.clone();
160        let cleanup_queue_name = cfg.outbox.cleanup_queue_name.clone();
161
162        let partitions = Partitions::of(
163            u16::try_from(num_partitions)
164                .map_err(|_| anyhow::anyhow!("num_partitions {num_partitions} exceeds u16"))?,
165        );
166
167        let outbox_handle = Outbox::builder(outbox_db)
168            .queue(&queue_name, partitions)
169            .decoupled(UsageEventHandler {
170                plugin_provider: model_policy_gw.clone(),
171            })
172            .queue(&cleanup_queue_name, partitions)
173            .decoupled(AttachmentCleanupHandler)
174            .start()
175            .await
176            .map_err(|e| anyhow::anyhow!("outbox start: {e}"))?;
177
178        let outbox = Arc::clone(outbox_handle.outbox());
179        let outbox_enqueuer = Arc::new(InfraOutboxEnqueuer::new(
180            outbox,
181            queue_name,
182            cleanup_queue_name,
183            num_partitions,
184        ));
185
186        {
187            let mut guard = self
188                .outbox_handle
189                .lock()
190                .map_err(|e| anyhow::anyhow!("outbox_handle lock: {e}"))?;
191            if guard.is_some() {
192                return Err(anyhow::anyhow!(
193                    "{} outbox_handle already set",
194                    Self::MODULE_NAME
195                ));
196            }
197            *guard = Some(outbox_handle);
198        }
199
200        info!("Outbox pipeline started");
201
202        let authz = ctx
203            .client_hub()
204            .get::<dyn AuthZResolverClient>()
205            .map_err(|e| anyhow::anyhow!("failed to get AuthZ resolver: {e}"))?;
206
207        let authn_client = ctx
208            .client_hub()
209            .get::<dyn AuthNResolverClient>()
210            .map_err(|e| anyhow::anyhow!("failed to get AuthN resolver: {e}"))?;
211
212        let gateway = ctx
213            .client_hub()
214            .get::<dyn ServiceGatewayClientV1>()
215            .map_err(|e| anyhow::anyhow!("failed to get OAGW gateway: {e}"))?;
216
217        // Pre-fill upstream_alias with host as fallback so ProviderResolver
218        // works immediately. The actual OAGW registration is deferred to
219        // start() because GTS instances are not visible via list() until
220        // post_init (types-registry switches to ready mode there).
221        for entry in cfg.providers.values_mut() {
222            if entry.upstream_alias.is_none() {
223                entry.upstream_alias = Some(entry.host.clone());
224            }
225            for ovr in entry.tenant_overrides.values_mut() {
226                if ovr.upstream_alias.is_none()
227                    && let Some(ref h) = ovr.host
228                {
229                    ovr.upstream_alias = Some(h.clone());
230                }
231            }
232        }
233
234        // Save a copy for deferred OAGW registration in start().
235        // Ignore the result: if already set, we keep the first value.
236        drop(self.oagw_deferred.set(OagwDeferred {
237            gateway: Arc::clone(&gateway),
238            authn: Arc::clone(&authn_client),
239            client_credentials: cfg.client_credentials.clone(),
240            providers: cfg.providers.clone(),
241        }));
242
243        let provider_resolver = Arc::new(ProviderResolver::new(&gateway, cfg.providers));
244
245        let repos = Repositories {
246            chat: Arc::new(ChatRepository::new(modkit_db::odata::LimitCfg {
247                default: 20,
248                max: 100,
249            })),
250            attachment: Arc::new(AttachmentRepository),
251            message: Arc::new(MessageRepository::new(modkit_db::odata::LimitCfg {
252                default: 20,
253                max: 100,
254            })),
255            quota: Arc::new(QuotaUsageRepository),
256            turn: Arc::new(TurnRepository),
257            reaction: Arc::new(ReactionRepository),
258            thread_summary: Arc::new(ThreadSummaryRepository),
259            vector_store: Arc::new(VectorStoreRepository),
260            message_attachment: Arc::new(MessageAttachmentRepository),
261        };
262
263        let rag_client = Arc::new(
264            crate::infra::llm::providers::rag_http_client::RagHttpClient::new(Arc::clone(&gateway)),
265        );
266
267        // Build provider-specific file/vector store impls per provider entry.
268        // Dispatch by storage_kind: Azure → Azure impls, OpenAi → OpenAI impls.
269        let mut file_impls: std::collections::HashMap<
270            String,
271            Arc<dyn crate::domain::ports::FileStorageProvider>,
272        > = std::collections::HashMap::new();
273        let mut vs_impls: std::collections::HashMap<
274            String,
275            Arc<dyn crate::domain::ports::VectorStoreProvider>,
276        > = std::collections::HashMap::new();
277        for (provider_id, entry) in provider_resolver.entries() {
278            let (file, vs): (
279                Arc<dyn crate::domain::ports::FileStorageProvider>,
280                Arc<dyn crate::domain::ports::VectorStoreProvider>,
281            ) = match entry.storage_kind {
282                crate::config::StorageKind::Azure => {
283                    let api_version = entry.api_version.clone().unwrap_or_else(|| {
284                        panic!(
285                            "provider '{provider_id}': storage_kind is 'azure' \
286                             but api_version is not set"
287                        )
288                    });
289                    (
290                        Arc::new(
291                            crate::infra::llm::providers::azure_file_storage::AzureFileStorage::new(
292                                Arc::clone(&rag_client),
293                                Arc::clone(&provider_resolver),
294                                api_version.clone(),
295                            ),
296                        ),
297                        Arc::new(
298                            crate::infra::llm::providers::azure_vector_store::AzureVectorStore::new(
299                                Arc::clone(&rag_client),
300                                Arc::clone(&provider_resolver),
301                                api_version,
302                            ),
303                        ),
304                    )
305                }
306                crate::config::StorageKind::OpenAi => (
307                    Arc::new(
308                        crate::infra::llm::providers::openai_file_storage::OpenAiFileStorage::new(
309                            Arc::clone(&rag_client),
310                            Arc::clone(&provider_resolver),
311                        ),
312                    ),
313                    Arc::new(
314                        crate::infra::llm::providers::openai_vector_store::OpenAiVectorStore::new(
315                            Arc::clone(&rag_client),
316                            Arc::clone(&provider_resolver),
317                        ),
318                    ),
319                ),
320            };
321            file_impls.insert(provider_id.clone(), file);
322            vs_impls.insert(provider_id.clone(), vs);
323        }
324        let file_storage: Arc<dyn crate::domain::ports::FileStorageProvider> = Arc::new(
325            crate::infra::llm::providers::dispatching_storage::DispatchingFileStorage::new(
326                file_impls,
327            ),
328        );
329        let vector_store_prov: Arc<dyn crate::domain::ports::VectorStoreProvider> = Arc::new(
330            crate::infra::llm::providers::dispatching_storage::DispatchingVectorStore::new(
331                vs_impls,
332            ),
333        );
334
335        let services = Arc::new(AppServices::new(
336            &repos,
337            db,
338            authz,
339            &(model_policy_gw.clone() as Arc<dyn crate::domain::repos::ModelResolver>),
340            &provider_resolver,
341            cfg.streaming,
342            model_policy_gw.clone() as Arc<dyn crate::domain::repos::PolicySnapshotProvider>,
343            model_policy_gw as Arc<dyn crate::domain::repos::UserLimitsProvider>,
344            cfg.estimation_budgets,
345            cfg.quota,
346            &(Arc::clone(&outbox_enqueuer) as Arc<dyn crate::domain::repos::OutboxEnqueuer>),
347            cfg.context,
348            file_storage,
349            vector_store_prov,
350            cfg.rag,
351        ));
352
353        self.service
354            .set(services)
355            .map_err(|_| anyhow::anyhow!("{} module already initialized", Self::MODULE_NAME))?;
356
357        info!("{} module initialized successfully", Self::MODULE_NAME);
358        Ok(())
359    }
360}
361
362impl DatabaseCapability for MiniChatModule {
363    fn migrations(&self) -> Vec<Box<dyn MigrationTrait>> {
364        use sea_orm_migration::MigratorTrait;
365        info!("Providing mini-chat database migrations");
366        let mut m = crate::infra::db::migrations::Migrator::migrations();
367        m.extend(modkit_db::outbox::outbox_migrations());
368        m
369    }
370}
371
372impl RestApiCapability for MiniChatModule {
373    fn register_rest(
374        &self,
375        _ctx: &ModuleCtx,
376        router: axum::Router,
377        openapi: &dyn OpenApiRegistry,
378    ) -> anyhow::Result<axum::Router> {
379        let services = self
380            .service
381            .get()
382            .ok_or_else(|| anyhow::anyhow!("{} not initialized", Self::MODULE_NAME))?;
383
384        info!("Registering mini-chat REST routes");
385        let prefix = self
386            .url_prefix
387            .get()
388            .ok_or_else(|| anyhow::anyhow!("{} not initialized (url_prefix)", Self::MODULE_NAME))?;
389
390        let router = routes::register_routes(router, openapi, Arc::clone(services), prefix);
391        info!("Mini-chat REST routes registered successfully");
392        Ok(router)
393    }
394}
395
396#[async_trait]
397impl RunnableCapability for MiniChatModule {
398    async fn start(&self, _cancel: CancellationToken) -> anyhow::Result<()> {
399        // Register OAGW upstreams now that GTS is in ready mode (post_init
400        // has completed). During init() this fails because types-registry
401        // list() only queries the persistent store which is empty until
402        // switch_to_ready().
403        if let Some(deferred) = self.oagw_deferred.get() {
404            let ctx =
405                exchange_client_credentials(&deferred.authn, &deferred.client_credentials).await?;
406            let mut providers = deferred.providers.clone();
407            crate::infra::oagw_provisioning::register_oagw_upstreams(
408                &deferred.gateway,
409                &ctx,
410                &mut providers,
411            )
412            .await?;
413        }
414        Ok(())
415    }
416
417    async fn stop(&self, cancel: CancellationToken) -> anyhow::Result<()> {
418        let handle = self
419            .outbox_handle
420            .lock()
421            .map_err(|e| anyhow::anyhow!("outbox_handle lock: {e}"))?
422            .take();
423        if let Some(handle) = handle {
424            info!("Stopping outbox pipeline");
425            tokio::select! {
426                () = handle.stop() => {
427                    info!("Outbox pipeline stopped");
428                }
429                () = cancel.cancelled() => {
430                    info!("Outbox pipeline stop cancelled by framework deadline");
431                }
432            }
433        }
434        Ok(())
435    }
436}
437
438/// Exchange `OAuth2` client credentials via the `AuthN` resolver to obtain
439/// a `SecurityContext` for OAGW upstream provisioning.
440async fn exchange_client_credentials(
441    authn: &Arc<dyn AuthNResolverClient>,
442    creds: &crate::config::ClientCredentialsConfig,
443) -> anyhow::Result<modkit_security::SecurityContext> {
444    info!("Exchanging client credentials for OAGW provisioning context");
445    let request = ClientCredentialsRequest {
446        client_id: creds.client_id.clone(),
447        client_secret: creds.client_secret.clone(),
448        scopes: Vec::new(),
449    };
450    let result = authn
451        .exchange_client_credentials(&request)
452        .await
453        .map_err(|e| anyhow::anyhow!("client credentials exchange failed: {e}"))?;
454    info!("Security context obtained for OAGW provisioning");
455    Ok(result.security_context)
456}
457
458async fn register_plugin_schemas(
459    registry: &dyn TypesRegistryClient,
460    schemas: &[(String, &str, &str)],
461) -> anyhow::Result<()> {
462    let mut payload = Vec::with_capacity(schemas.len());
463    for (schema_str, schema_id, _label) in schemas {
464        let mut schema_json: serde_json::Value = serde_json::from_str(schema_str)?;
465        let obj = schema_json
466            .as_object_mut()
467            .ok_or_else(|| anyhow::anyhow!("schema {schema_id} is not a JSON object"))?;
468        obj.insert(
469            "additionalProperties".to_owned(),
470            serde_json::Value::Bool(false),
471        );
472        payload.push(schema_json);
473    }
474    let results = registry.register(payload).await?;
475    RegisterResult::ensure_all_ok(&results)?;
476    for (_schema_str, schema_id, label) in schemas {
477        info!(schema_id = %schema_id, "Registered {label} plugin schema in types-registry");
478    }
479    Ok(())
480}