Skip to main content

mini_chat/
module.rs

1use std::sync::{Arc, Mutex, OnceLock};
2
3use async_trait::async_trait;
4use authn_resolver_sdk::{AuthNResolverClient, ClientCredentialsRequest};
5use authz_resolver_sdk::AuthZResolverClient;
6use mini_chat_sdk::{MiniChatAuditPluginSpecV1, MiniChatModelPolicyPluginSpecV1};
7use modkit::api::OpenApiRegistry;
8use modkit::contracts::RunnableCapability;
9use modkit::{DatabaseCapability, Module, ModuleCtx, RestApiCapability};
10use modkit_db::outbox::{Outbox, OutboxHandle, Partitions};
11use oagw_sdk::ServiceGatewayClientV1;
12use sea_orm_migration::MigrationTrait;
13use tokio_util::sync::CancellationToken;
14use tracing::info;
15use types_registry_sdk::{RegisterResult, TypesRegistryClient};
16
17use crate::api::rest::routes;
18use crate::config::ProviderEntry;
19use crate::domain::service::{AppServices as GenericAppServices, Repositories};
20use crate::infra::outbox::{InfraOutboxEnqueuer, UsageEventHandler};
21
22pub(crate) type AppServices = GenericAppServices<
23    TurnRepository,
24    MessageRepository,
25    QuotaUsageRepository,
26    ReactionRepository,
27    ChatRepository,
28    ThreadSummaryRepository,
29>;
30use crate::infra::db::repo::attachment_repo::AttachmentRepository;
31use crate::infra::db::repo::chat_repo::ChatRepository;
32use crate::infra::db::repo::message_repo::MessageRepository;
33use crate::infra::db::repo::quota_usage_repo::QuotaUsageRepository;
34use crate::infra::db::repo::reaction_repo::ReactionRepository;
35use crate::infra::db::repo::thread_summary_repo::ThreadSummaryRepository;
36use crate::infra::db::repo::turn_repo::TurnRepository;
37use crate::infra::db::repo::vector_store_repo::VectorStoreRepository;
38use crate::infra::llm::provider_resolver::ProviderResolver;
39use crate::infra::model_policy::ModelPolicyGateway;
40
/// URL prefix under which every mini-chat REST route is mounted unless
/// configuration overrides it.
pub const DEFAULT_URL_PREFIX: &str = "/mini-chat";
43
44/// The mini-chat module: multi-tenant AI chat with SSE streaming.
45#[modkit::module(
46    name = "mini-chat",
47    deps = ["types-registry", "authn-resolver", "authz-resolver", "oagw"],
48    capabilities = [db, rest, stateful],
49)]
50pub struct MiniChatModule {
51    service: OnceLock<Arc<AppServices>>,
52    url_prefix: OnceLock<String>,
53    outbox_handle: Mutex<Option<OutboxHandle>>,
54    /// OAGW gateway + provider config for deferred upstream registration in `start()`.
55    oagw_deferred: OnceLock<OagwDeferred>,
56}
57
58/// State needed to register OAGW upstreams in `start()` (after GTS is ready).
59struct OagwDeferred {
60    gateway: Arc<dyn ServiceGatewayClientV1>,
61    authn: Arc<dyn AuthNResolverClient>,
62    client_credentials: crate::config::ClientCredentialsConfig,
63    providers: std::collections::HashMap<String, ProviderEntry>,
64}
65
66impl Default for MiniChatModule {
67    fn default() -> Self {
68        Self {
69            service: OnceLock::new(),
70            url_prefix: OnceLock::new(),
71            outbox_handle: Mutex::new(None),
72            oagw_deferred: OnceLock::new(),
73        }
74    }
75}
76
77#[async_trait]
78impl Module for MiniChatModule {
79    async fn init(&self, ctx: &ModuleCtx) -> anyhow::Result<()> {
80        info!("Initializing {} module", Self::MODULE_NAME);
81
82        let mut cfg: crate::config::MiniChatConfig = ctx.config_expanded()?;
83        cfg.streaming
84            .validate()
85            .map_err(|e| anyhow::anyhow!("streaming config: {e}"))?;
86        cfg.estimation_budgets
87            .validate()
88            .map_err(|e| anyhow::anyhow!("estimation_budgets config: {e}"))?;
89        cfg.quota
90            .validate()
91            .map_err(|e| anyhow::anyhow!("quota config: {e}"))?;
92        cfg.outbox
93            .validate()
94            .map_err(|e| anyhow::anyhow!("outbox config: {e}"))?;
95        cfg.context
96            .validate()
97            .map_err(|e| anyhow::anyhow!("context config: {e}"))?;
98        cfg.client_credentials
99            .validate()
100            .map_err(|e| anyhow::anyhow!("client_credentials config: {e}"))?;
101        for (id, entry) in &cfg.providers {
102            entry
103                .validate(id)
104                .map_err(|e| anyhow::anyhow!("providers config: {e}"))?;
105        }
106
107        let vendor = cfg.vendor.trim().to_owned();
108        if vendor.is_empty() {
109            return Err(anyhow::anyhow!(
110                "{}: vendor must be a non-empty string",
111                Self::MODULE_NAME
112            ));
113        }
114
115        let registry = ctx.client_hub().get::<dyn TypesRegistryClient>()?;
116        register_plugin_schemas(
117            &*registry,
118            &[
119                (
120                    MiniChatModelPolicyPluginSpecV1::gts_schema_with_refs_as_string(),
121                    MiniChatModelPolicyPluginSpecV1::gts_schema_id(),
122                    "model-policy",
123                ),
124                (
125                    MiniChatAuditPluginSpecV1::gts_schema_with_refs_as_string(),
126                    MiniChatAuditPluginSpecV1::gts_schema_id(),
127                    "audit",
128                ),
129            ],
130        )
131        .await?;
132
133        self.url_prefix
134            .set(cfg.url_prefix)
135            .map_err(|_| anyhow::anyhow!("{} url_prefix already set", Self::MODULE_NAME))?;
136
137        let db_provider = ctx.db_required()?;
138        let db = Arc::new(db_provider);
139
140        // Create the model-policy gateway early for both outbox handler and services.
141        let model_policy_gw = Arc::new(ModelPolicyGateway::new(
142            ctx.client_hub(),
143            vendor,
144            ctx.cancellation_token().clone(),
145        ));
146
147        // Start the outbox pipeline eagerly in init() (migrations ran in phase 2, DB is ready).
148        // The framework guarantees stop() is called on init failure, so the pipeline
149        // will be shut down cleanly if any later init step errors.
150        // The handler resolves the plugin lazily on first message delivery,
151        // avoiding a hard dependency on plugin availability during init().
152        let outbox_db = db.db();
153        let num_partitions = cfg.outbox.num_partitions;
154        let queue_name = cfg.outbox.queue_name.clone();
155
156        let outbox_handle =
157            Outbox::builder(outbox_db)
158                .queue(
159                    &queue_name,
160                    Partitions::of(u16::try_from(num_partitions).map_err(|_| {
161                        anyhow::anyhow!("num_partitions {num_partitions} exceeds u16")
162                    })?),
163                )
164                .decoupled(UsageEventHandler {
165                    plugin_provider: model_policy_gw.clone(),
166                })
167                .start()
168                .await
169                .map_err(|e| anyhow::anyhow!("outbox start: {e}"))?;
170
171        let outbox = Arc::clone(outbox_handle.outbox());
172        let outbox_enqueuer =
173            Arc::new(InfraOutboxEnqueuer::new(outbox, queue_name, num_partitions));
174
175        {
176            let mut guard = self
177                .outbox_handle
178                .lock()
179                .map_err(|e| anyhow::anyhow!("outbox_handle lock: {e}"))?;
180            if guard.is_some() {
181                return Err(anyhow::anyhow!(
182                    "{} outbox_handle already set",
183                    Self::MODULE_NAME
184                ));
185            }
186            *guard = Some(outbox_handle);
187        }
188
189        info!("Outbox pipeline started");
190
191        let authz = ctx
192            .client_hub()
193            .get::<dyn AuthZResolverClient>()
194            .map_err(|e| anyhow::anyhow!("failed to get AuthZ resolver: {e}"))?;
195
196        let authn_client = ctx
197            .client_hub()
198            .get::<dyn AuthNResolverClient>()
199            .map_err(|e| anyhow::anyhow!("failed to get AuthN resolver: {e}"))?;
200
201        let gateway = ctx
202            .client_hub()
203            .get::<dyn ServiceGatewayClientV1>()
204            .map_err(|e| anyhow::anyhow!("failed to get OAGW gateway: {e}"))?;
205
206        // Pre-fill upstream_alias with host as fallback so ProviderResolver
207        // works immediately. The actual OAGW registration is deferred to
208        // start() because GTS instances are not visible via list() until
209        // post_init (types-registry switches to ready mode there).
210        for entry in cfg.providers.values_mut() {
211            if entry.upstream_alias.is_none() {
212                entry.upstream_alias = Some(entry.host.clone());
213            }
214            for ovr in entry.tenant_overrides.values_mut() {
215                if ovr.upstream_alias.is_none()
216                    && let Some(ref h) = ovr.host
217                {
218                    ovr.upstream_alias = Some(h.clone());
219                }
220            }
221        }
222
223        // Save a copy for deferred OAGW registration in start().
224        // Ignore the result: if already set, we keep the first value.
225        drop(self.oagw_deferred.set(OagwDeferred {
226            gateway: Arc::clone(&gateway),
227            authn: Arc::clone(&authn_client),
228            client_credentials: cfg.client_credentials.clone(),
229            providers: cfg.providers.clone(),
230        }));
231
232        let provider_resolver = Arc::new(ProviderResolver::new(&gateway, cfg.providers));
233
234        let repos = Repositories {
235            chat: Arc::new(ChatRepository::new(modkit_db::odata::LimitCfg {
236                default: 20,
237                max: 100,
238            })),
239            attachment: Arc::new(AttachmentRepository),
240            message: Arc::new(MessageRepository::new(modkit_db::odata::LimitCfg {
241                default: 20,
242                max: 100,
243            })),
244            quota: Arc::new(QuotaUsageRepository),
245            turn: Arc::new(TurnRepository),
246            reaction: Arc::new(ReactionRepository),
247            thread_summary: Arc::new(ThreadSummaryRepository),
248            vector_store: Arc::new(VectorStoreRepository),
249        };
250
251        let services = Arc::new(AppServices::new(
252            &repos,
253            db,
254            authz,
255            &(model_policy_gw.clone() as Arc<dyn crate::domain::repos::ModelResolver>),
256            provider_resolver,
257            cfg.streaming,
258            model_policy_gw.clone() as Arc<dyn crate::domain::repos::PolicySnapshotProvider>,
259            model_policy_gw as Arc<dyn crate::domain::repos::UserLimitsProvider>,
260            cfg.estimation_budgets,
261            cfg.quota,
262            outbox_enqueuer,
263            cfg.context,
264        ));
265
266        self.service
267            .set(services)
268            .map_err(|_| anyhow::anyhow!("{} module already initialized", Self::MODULE_NAME))?;
269
270        info!("{} module initialized successfully", Self::MODULE_NAME);
271        Ok(())
272    }
273}
274
275impl DatabaseCapability for MiniChatModule {
276    fn migrations(&self) -> Vec<Box<dyn MigrationTrait>> {
277        use sea_orm_migration::MigratorTrait;
278        info!("Providing mini-chat database migrations");
279        let mut m = crate::infra::db::migrations::Migrator::migrations();
280        m.extend(modkit_db::outbox::outbox_migrations());
281        m
282    }
283}
284
285impl RestApiCapability for MiniChatModule {
286    fn register_rest(
287        &self,
288        _ctx: &ModuleCtx,
289        router: axum::Router,
290        openapi: &dyn OpenApiRegistry,
291    ) -> anyhow::Result<axum::Router> {
292        let services = self
293            .service
294            .get()
295            .ok_or_else(|| anyhow::anyhow!("{} not initialized", Self::MODULE_NAME))?;
296
297        info!("Registering mini-chat REST routes");
298        let prefix = self
299            .url_prefix
300            .get()
301            .ok_or_else(|| anyhow::anyhow!("{} not initialized (url_prefix)", Self::MODULE_NAME))?;
302
303        let router = routes::register_routes(router, openapi, Arc::clone(services), prefix);
304        info!("Mini-chat REST routes registered successfully");
305        Ok(router)
306    }
307}
308
309#[async_trait]
310impl RunnableCapability for MiniChatModule {
311    async fn start(&self, _cancel: CancellationToken) -> anyhow::Result<()> {
312        // Register OAGW upstreams now that GTS is in ready mode (post_init
313        // has completed). During init() this fails because types-registry
314        // list() only queries the persistent store which is empty until
315        // switch_to_ready().
316        if let Some(deferred) = self.oagw_deferred.get() {
317            let ctx =
318                exchange_client_credentials(&deferred.authn, &deferred.client_credentials).await?;
319            let mut providers = deferred.providers.clone();
320            crate::infra::oagw_provisioning::register_oagw_upstreams(
321                &deferred.gateway,
322                &ctx,
323                &mut providers,
324            )
325            .await?;
326        }
327        Ok(())
328    }
329
330    async fn stop(&self, cancel: CancellationToken) -> anyhow::Result<()> {
331        let handle = self
332            .outbox_handle
333            .lock()
334            .map_err(|e| anyhow::anyhow!("outbox_handle lock: {e}"))?
335            .take();
336        if let Some(handle) = handle {
337            info!("Stopping outbox pipeline");
338            tokio::select! {
339                () = handle.stop() => {
340                    info!("Outbox pipeline stopped");
341                }
342                () = cancel.cancelled() => {
343                    info!("Outbox pipeline stop cancelled by framework deadline");
344                }
345            }
346        }
347        Ok(())
348    }
349}
350
351/// Exchange `OAuth2` client credentials via the `AuthN` resolver to obtain
352/// a `SecurityContext` for OAGW upstream provisioning.
353async fn exchange_client_credentials(
354    authn: &Arc<dyn AuthNResolverClient>,
355    creds: &crate::config::ClientCredentialsConfig,
356) -> anyhow::Result<modkit_security::SecurityContext> {
357    info!("Exchanging client credentials for OAGW provisioning context");
358    let request = ClientCredentialsRequest {
359        client_id: creds.client_id.clone(),
360        client_secret: creds.client_secret.clone(),
361        scopes: Vec::new(),
362    };
363    let result = authn
364        .exchange_client_credentials(&request)
365        .await
366        .map_err(|e| anyhow::anyhow!("client credentials exchange failed: {e}"))?;
367    info!("Security context obtained for OAGW provisioning");
368    Ok(result.security_context)
369}
370
371async fn register_plugin_schemas(
372    registry: &dyn TypesRegistryClient,
373    schemas: &[(String, &str, &str)],
374) -> anyhow::Result<()> {
375    let mut payload = Vec::with_capacity(schemas.len());
376    for (schema_str, schema_id, _label) in schemas {
377        let mut schema_json: serde_json::Value = serde_json::from_str(schema_str)?;
378        let obj = schema_json
379            .as_object_mut()
380            .ok_or_else(|| anyhow::anyhow!("schema {schema_id} is not a JSON object"))?;
381        obj.insert(
382            "additionalProperties".to_owned(),
383            serde_json::Value::Bool(false),
384        );
385        payload.push(schema_json);
386    }
387    let results = registry.register(payload).await?;
388    RegisterResult::ensure_all_ok(&results)?;
389    for (_schema_str, schema_id, label) in schemas {
390        info!(schema_id = %schema_id, "Registered {label} plugin schema in types-registry");
391    }
392    Ok(())
393}