1use std::sync::{Arc, Mutex, OnceLock};
2
3use async_trait::async_trait;
4use authn_resolver_sdk::{AuthNResolverClient, ClientCredentialsRequest};
5use authz_resolver_sdk::AuthZResolverClient;
6use mini_chat_sdk::{MiniChatAuditPluginSpecV1, MiniChatModelPolicyPluginSpecV1};
7use modkit::api::OpenApiRegistry;
8use modkit::contracts::RunnableCapability;
9use modkit::{DatabaseCapability, Module, ModuleCtx, RestApiCapability};
10use modkit_db::outbox::{Outbox, OutboxHandle, Partitions};
11use oagw_sdk::ServiceGatewayClientV1;
12use sea_orm_migration::MigrationTrait;
13use tokio_util::sync::CancellationToken;
14use tracing::info;
15use types_registry_sdk::{RegisterResult, TypesRegistryClient};
16
17use crate::api::rest::routes;
18use crate::config::ProviderEntry;
19use crate::domain::ports::MiniChatMetricsPort;
20use crate::domain::service::{AppServices as GenericAppServices, Repositories};
21use crate::infra::metrics::MiniChatMetricsMeter;
22use crate::infra::outbox::{
23 AttachmentCleanupHandler, AuditEventHandler, InfraOutboxEnqueuer, UsageEventHandler,
24};
25
/// Concrete instantiation of the generic domain [`GenericAppServices`] used by
/// this module, with every repository slot bound to its implementation from
/// `crate::infra::db::repo`.
pub(crate) type AppServices = GenericAppServices<
    TurnRepository,
    MessageRepository,
    QuotaUsageRepository,
    ReactionRepository,
    ChatRepository,
    ThreadSummaryRepository,
    AttachmentRepository,
    VectorStoreRepository,
    MessageAttachmentRepository,
>;
37use crate::infra::audit_gateway::AuditGateway;
38use crate::infra::db::repo::attachment_repo::AttachmentRepository;
39use crate::infra::db::repo::chat_repo::ChatRepository;
40use crate::infra::db::repo::message_attachment_repo::MessageAttachmentRepository;
41use crate::infra::db::repo::message_repo::MessageRepository;
42use crate::infra::db::repo::quota_usage_repo::QuotaUsageRepository;
43use crate::infra::db::repo::reaction_repo::ReactionRepository;
44use crate::infra::db::repo::thread_summary_repo::ThreadSummaryRepository;
45use crate::infra::db::repo::turn_repo::TurnRepository;
46use crate::infra::db::repo::vector_store_repo::VectorStoreRepository;
47use crate::infra::llm::provider_resolver::ProviderResolver;
48use crate::infra::model_policy::ModelPolicyGateway;
49
/// Default route prefix for the mini-chat REST API.
///
/// NOTE(review): not referenced in this file; presumably used as the default
/// for the `url_prefix` config value consumed in `init` — confirm at the
/// config definition.
pub const DEFAULT_URL_PREFIX: &str = "/mini-chat";
52
/// The mini-chat module.
///
/// Every field starts empty (see the `Default` impl) and is populated during
/// `Module::init`; the REST, runnable and database capabilities read them
/// afterwards.
#[modkit::module(
    name = "mini-chat",
    deps = ["types-registry", "authn-resolver", "authz-resolver", "oagw"],
    capabilities = [db, rest, stateful],
)]
pub struct MiniChatModule {
    /// Application service graph, set once in `init` and handed to the REST routes.
    service: OnceLock<Arc<AppServices>>,
    /// REST route prefix taken from config in `init`.
    url_prefix: OnceLock<String>,
    /// Handle of the running outbox pipeline; taken and stopped in `stop`.
    outbox_handle: Mutex<Option<OutboxHandle>>,
    /// Inputs captured in `init` for the OAGW upstream provisioning done in `start`.
    oagw_deferred: OnceLock<OagwDeferred>,
}
66
/// Everything `RunnableCapability::start` needs to provision OAGW upstreams,
/// captured during `init` so the provisioning can run later.
struct OagwDeferred {
    /// OAGW client the upstreams are registered with.
    gateway: Arc<dyn ServiceGatewayClientV1>,
    /// AuthN client used to exchange `client_credentials` for a security context.
    authn: Arc<dyn AuthNResolverClient>,
    /// Client credentials exchanged for the provisioning security context.
    client_credentials: crate::config::ClientCredentialsConfig,
    /// Provider configuration keyed by provider id, snapshotted after `init`
    /// has defaulted missing upstream aliases.
    providers: std::collections::HashMap<String, ProviderEntry>,
}
74
75impl Default for MiniChatModule {
76 fn default() -> Self {
77 Self {
78 service: OnceLock::new(),
79 url_prefix: OnceLock::new(),
80 outbox_handle: Mutex::new(None),
81 oagw_deferred: OnceLock::new(),
82 }
83 }
84}
85
86#[allow(clippy::too_many_lines)]
87#[async_trait]
88impl Module for MiniChatModule {
89 async fn init(&self, ctx: &ModuleCtx) -> anyhow::Result<()> {
90 info!("Initializing {} module", Self::MODULE_NAME);
91
92 let mut cfg: crate::config::MiniChatConfig = ctx.config_expanded()?;
93 cfg.streaming
94 .validate()
95 .map_err(|e| anyhow::anyhow!("streaming config: {e}"))?;
96 cfg.estimation_budgets
97 .validate()
98 .map_err(|e| anyhow::anyhow!("estimation_budgets config: {e}"))?;
99 cfg.quota
100 .validate()
101 .map_err(|e| anyhow::anyhow!("quota config: {e}"))?;
102 cfg.outbox
103 .validate()
104 .map_err(|e| anyhow::anyhow!("outbox config: {e}"))?;
105 cfg.context
106 .validate()
107 .map_err(|e| anyhow::anyhow!("context config: {e}"))?;
108 cfg.client_credentials
109 .validate()
110 .map_err(|e| anyhow::anyhow!("client_credentials config: {e}"))?;
111 for (id, entry) in &cfg.providers {
112 entry
113 .validate(id)
114 .map_err(|e| anyhow::anyhow!("providers config: {e}"))?;
115 }
116
117 let vendor = cfg.vendor.trim().to_owned();
118 if vendor.is_empty() {
119 return Err(anyhow::anyhow!(
120 "{}: vendor must be a non-empty string",
121 Self::MODULE_NAME
122 ));
123 }
124
125 let registry = ctx.client_hub().get::<dyn TypesRegistryClient>()?;
126 register_plugin_schemas(
127 &*registry,
128 &[
129 (
130 MiniChatModelPolicyPluginSpecV1::gts_schema_with_refs_as_string(),
131 MiniChatModelPolicyPluginSpecV1::gts_schema_id(),
132 "model-policy",
133 ),
134 (
135 MiniChatAuditPluginSpecV1::gts_schema_with_refs_as_string(),
136 MiniChatAuditPluginSpecV1::gts_schema_id(),
137 "audit",
138 ),
139 ],
140 )
141 .await?;
142
143 self.url_prefix
144 .set(cfg.url_prefix)
145 .map_err(|_| anyhow::anyhow!("{} url_prefix already set", Self::MODULE_NAME))?;
146
147 let db_provider = ctx.db_required()?;
148 let db = Arc::new(db_provider);
149
150 let model_policy_gw = Arc::new(ModelPolicyGateway::new(
152 ctx.client_hub(),
153 vendor.clone(),
154 ctx.cancellation_token().clone(),
155 ));
156
157 let audit_gateway = Arc::new(AuditGateway::new(ctx.client_hub(), vendor));
159
160 let outbox_db = db.db();
166 let num_partitions = cfg.outbox.num_partitions;
167 let queue_name = cfg.outbox.queue_name.clone();
168 let cleanup_queue_name = cfg.outbox.cleanup_queue_name.clone();
169 let audit_queue_name = cfg.outbox.audit_queue_name.clone();
170
171 let partitions = Partitions::of(
172 u16::try_from(num_partitions)
173 .map_err(|_| anyhow::anyhow!("num_partitions {num_partitions} exceeds u16"))?,
174 );
175
176 let metrics_prefix = cfg.metrics.effective_prefix(Self::MODULE_NAME);
178 let scope =
179 opentelemetry::InstrumentationScope::builder(Self::MODULE_NAME.to_owned()).build();
180 let metrics: Arc<dyn MiniChatMetricsPort> = Arc::new(MiniChatMetricsMeter::new(
181 &opentelemetry::global::meter_with_scope(scope),
182 &metrics_prefix,
183 ));
184
185 let outbox_handle = Outbox::builder(outbox_db)
186 .queue(&queue_name, partitions)
187 .decoupled(UsageEventHandler {
188 plugin_provider: model_policy_gw.clone(),
189 })
190 .queue(&cleanup_queue_name, partitions)
191 .decoupled(AttachmentCleanupHandler)
192 .queue(&audit_queue_name, partitions)
193 .decoupled(AuditEventHandler {
194 audit_gateway: Arc::clone(&audit_gateway),
195 metrics: Arc::clone(&metrics),
196 })
197 .start()
198 .await
199 .map_err(|e| anyhow::anyhow!("outbox start: {e}"))?;
200
201 let outbox = Arc::clone(outbox_handle.outbox());
202 let outbox_enqueuer = Arc::new(InfraOutboxEnqueuer::new(
203 outbox,
204 queue_name,
205 cleanup_queue_name,
206 audit_queue_name,
207 num_partitions,
208 ));
209
210 {
211 let mut guard = self
212 .outbox_handle
213 .lock()
214 .map_err(|e| anyhow::anyhow!("outbox_handle lock: {e}"))?;
215 if guard.is_some() {
216 return Err(anyhow::anyhow!(
217 "{} outbox_handle already set",
218 Self::MODULE_NAME
219 ));
220 }
221 *guard = Some(outbox_handle);
222 }
223
224 info!("Outbox pipeline started");
225
226 let authz = ctx
227 .client_hub()
228 .get::<dyn AuthZResolverClient>()
229 .map_err(|e| anyhow::anyhow!("failed to get AuthZ resolver: {e}"))?;
230
231 let authn_client = ctx
232 .client_hub()
233 .get::<dyn AuthNResolverClient>()
234 .map_err(|e| anyhow::anyhow!("failed to get AuthN resolver: {e}"))?;
235
236 let gateway = ctx
237 .client_hub()
238 .get::<dyn ServiceGatewayClientV1>()
239 .map_err(|e| anyhow::anyhow!("failed to get OAGW gateway: {e}"))?;
240
241 for entry in cfg.providers.values_mut() {
246 if entry.upstream_alias.is_none() {
247 entry.upstream_alias = Some(entry.host.clone());
248 }
249 for ovr in entry.tenant_overrides.values_mut() {
250 if ovr.upstream_alias.is_none()
251 && let Some(ref h) = ovr.host
252 {
253 ovr.upstream_alias = Some(h.clone());
254 }
255 }
256 }
257
258 drop(self.oagw_deferred.set(OagwDeferred {
261 gateway: Arc::clone(&gateway),
262 authn: Arc::clone(&authn_client),
263 client_credentials: cfg.client_credentials.clone(),
264 providers: cfg.providers.clone(),
265 }));
266
267 let provider_resolver = Arc::new(ProviderResolver::new(&gateway, cfg.providers));
268
269 let repos = Repositories {
270 chat: Arc::new(ChatRepository::new(modkit_db::odata::LimitCfg {
271 default: 20,
272 max: 100,
273 })),
274 attachment: Arc::new(AttachmentRepository),
275 message: Arc::new(MessageRepository::new(modkit_db::odata::LimitCfg {
276 default: 20,
277 max: 100,
278 })),
279 quota: Arc::new(QuotaUsageRepository),
280 turn: Arc::new(TurnRepository),
281 reaction: Arc::new(ReactionRepository),
282 thread_summary: Arc::new(ThreadSummaryRepository),
283 vector_store: Arc::new(VectorStoreRepository),
284 message_attachment: Arc::new(MessageAttachmentRepository),
285 };
286
287 let rag_client = Arc::new(
288 crate::infra::llm::providers::rag_http_client::RagHttpClient::new(Arc::clone(&gateway)),
289 );
290
291 let mut file_impls: std::collections::HashMap<
294 String,
295 Arc<dyn crate::domain::ports::FileStorageProvider>,
296 > = std::collections::HashMap::new();
297 let mut vs_impls: std::collections::HashMap<
298 String,
299 Arc<dyn crate::domain::ports::VectorStoreProvider>,
300 > = std::collections::HashMap::new();
301 for (provider_id, entry) in provider_resolver.entries() {
302 let (file, vs): (
303 Arc<dyn crate::domain::ports::FileStorageProvider>,
304 Arc<dyn crate::domain::ports::VectorStoreProvider>,
305 ) = match entry.storage_kind {
306 crate::config::StorageKind::Azure => {
307 let api_version = entry.api_version.clone().unwrap_or_else(|| {
308 panic!(
309 "provider '{provider_id}': storage_kind is 'azure' \
310 but api_version is not set"
311 )
312 });
313 (
314 Arc::new(
315 crate::infra::llm::providers::azure_file_storage::AzureFileStorage::new(
316 Arc::clone(&rag_client),
317 Arc::clone(&provider_resolver),
318 api_version.clone(),
319 ),
320 ),
321 Arc::new(
322 crate::infra::llm::providers::azure_vector_store::AzureVectorStore::new(
323 Arc::clone(&rag_client),
324 Arc::clone(&provider_resolver),
325 api_version,
326 ),
327 ),
328 )
329 }
330 crate::config::StorageKind::OpenAi => (
331 Arc::new(
332 crate::infra::llm::providers::openai_file_storage::OpenAiFileStorage::new(
333 Arc::clone(&rag_client),
334 Arc::clone(&provider_resolver),
335 ),
336 ),
337 Arc::new(
338 crate::infra::llm::providers::openai_vector_store::OpenAiVectorStore::new(
339 Arc::clone(&rag_client),
340 Arc::clone(&provider_resolver),
341 ),
342 ),
343 ),
344 };
345 file_impls.insert(provider_id.clone(), file);
346 vs_impls.insert(provider_id.clone(), vs);
347 }
348 let file_storage: Arc<dyn crate::domain::ports::FileStorageProvider> = Arc::new(
349 crate::infra::llm::providers::dispatching_storage::DispatchingFileStorage::new(
350 file_impls,
351 ),
352 );
353 let vector_store_prov: Arc<dyn crate::domain::ports::VectorStoreProvider> = Arc::new(
354 crate::infra::llm::providers::dispatching_storage::DispatchingVectorStore::new(
355 vs_impls,
356 ),
357 );
358
359 let services = Arc::new(AppServices::new(
360 &repos,
361 db,
362 authz,
363 &(model_policy_gw.clone() as Arc<dyn crate::domain::repos::ModelResolver>),
364 &provider_resolver,
365 cfg.streaming,
366 model_policy_gw.clone() as Arc<dyn crate::domain::repos::PolicySnapshotProvider>,
367 model_policy_gw as Arc<dyn crate::domain::repos::UserLimitsProvider>,
368 cfg.estimation_budgets,
369 cfg.quota,
370 &(outbox_enqueuer as Arc<dyn crate::domain::repos::OutboxEnqueuer>),
371 cfg.context,
372 file_storage,
373 vector_store_prov,
374 cfg.rag,
375 metrics,
376 ));
377
378 self.service
379 .set(services)
380 .map_err(|_| anyhow::anyhow!("{} module already initialized", Self::MODULE_NAME))?;
381
382 info!("{} module initialized successfully", Self::MODULE_NAME);
383 Ok(())
384 }
385}
386
387impl DatabaseCapability for MiniChatModule {
388 fn migrations(&self) -> Vec<Box<dyn MigrationTrait>> {
389 use sea_orm_migration::MigratorTrait;
390 info!("Providing mini-chat database migrations");
391 let mut m = crate::infra::db::migrations::Migrator::migrations();
392 m.extend(modkit_db::outbox::outbox_migrations());
393 m
394 }
395}
396
397impl RestApiCapability for MiniChatModule {
398 fn register_rest(
399 &self,
400 _ctx: &ModuleCtx,
401 router: axum::Router,
402 openapi: &dyn OpenApiRegistry,
403 ) -> anyhow::Result<axum::Router> {
404 let services = self
405 .service
406 .get()
407 .ok_or_else(|| anyhow::anyhow!("{} not initialized", Self::MODULE_NAME))?;
408
409 info!("Registering mini-chat REST routes");
410 let prefix = self
411 .url_prefix
412 .get()
413 .ok_or_else(|| anyhow::anyhow!("{} not initialized (url_prefix)", Self::MODULE_NAME))?;
414
415 let router = routes::register_routes(router, openapi, Arc::clone(services), prefix);
416 info!("Mini-chat REST routes registered successfully");
417 Ok(router)
418 }
419}
420
421#[async_trait]
422impl RunnableCapability for MiniChatModule {
423 async fn start(&self, _cancel: CancellationToken) -> anyhow::Result<()> {
424 if let Some(deferred) = self.oagw_deferred.get() {
429 let ctx =
430 exchange_client_credentials(&deferred.authn, &deferred.client_credentials).await?;
431 let mut providers = deferred.providers.clone();
432 crate::infra::oagw_provisioning::register_oagw_upstreams(
433 &deferred.gateway,
434 &ctx,
435 &mut providers,
436 )
437 .await?;
438 }
439 Ok(())
440 }
441
442 async fn stop(&self, cancel: CancellationToken) -> anyhow::Result<()> {
443 let handle = self
444 .outbox_handle
445 .lock()
446 .map_err(|e| anyhow::anyhow!("outbox_handle lock: {e}"))?
447 .take();
448 if let Some(handle) = handle {
449 info!("Stopping outbox pipeline");
450 tokio::select! {
451 () = handle.stop() => {
452 info!("Outbox pipeline stopped");
453 }
454 () = cancel.cancelled() => {
455 info!("Outbox pipeline stop cancelled by framework deadline");
456 }
457 }
458 }
459 Ok(())
460 }
461}
462
463async fn exchange_client_credentials(
466 authn: &Arc<dyn AuthNResolverClient>,
467 creds: &crate::config::ClientCredentialsConfig,
468) -> anyhow::Result<modkit_security::SecurityContext> {
469 info!("Exchanging client credentials for OAGW provisioning context");
470 let request = ClientCredentialsRequest {
471 client_id: creds.client_id.clone(),
472 client_secret: creds.client_secret.clone(),
473 scopes: Vec::new(),
474 };
475 let result = authn
476 .exchange_client_credentials(&request)
477 .await
478 .map_err(|e| anyhow::anyhow!("client credentials exchange failed: {e}"))?;
479 info!("Security context obtained for OAGW provisioning");
480 Ok(result.security_context)
481}
482
483async fn register_plugin_schemas(
484 registry: &dyn TypesRegistryClient,
485 schemas: &[(String, &str, &str)],
486) -> anyhow::Result<()> {
487 let mut payload = Vec::with_capacity(schemas.len());
488 for (schema_str, schema_id, _label) in schemas {
489 let mut schema_json: serde_json::Value = serde_json::from_str(schema_str)?;
490 let obj = schema_json
491 .as_object_mut()
492 .ok_or_else(|| anyhow::anyhow!("schema {schema_id} is not a JSON object"))?;
493 obj.insert(
494 "additionalProperties".to_owned(),
495 serde_json::Value::Bool(false),
496 );
497 payload.push(schema_json);
498 }
499 let results = registry.register(payload).await?;
500 RegisterResult::ensure_all_ok(&results)?;
501 for (_schema_str, schema_id, label) in schemas {
502 info!(schema_id = %schema_id, "Registered {label} plugin schema in types-registry");
503 }
504 Ok(())
505}