1use std::sync::{Arc, Mutex, OnceLock};
2
3use async_trait::async_trait;
4use authn_resolver_sdk::{AuthNResolverClient, ClientCredentialsRequest};
5use authz_resolver_sdk::AuthZResolverClient;
6use mini_chat_sdk::{MiniChatAuditPluginSpecV1, MiniChatModelPolicyPluginSpecV1};
7use modkit::api::OpenApiRegistry;
8use modkit::contracts::RunnableCapability;
9use modkit::{DatabaseCapability, Module, ModuleCtx, RestApiCapability};
10use modkit_db::outbox::{Outbox, OutboxHandle, Partitions};
11use oagw_sdk::ServiceGatewayClientV1;
12use sea_orm_migration::MigrationTrait;
13use tokio_util::sync::CancellationToken;
14use tracing::info;
15use types_registry_sdk::{RegisterResult, TypesRegistryClient};
16
17use crate::api::rest::routes;
18use crate::config::ProviderEntry;
19use crate::domain::ports::MiniChatMetricsPort;
20use crate::domain::service::{AppServices as GenericAppServices, Repositories};
21use crate::infra::metrics::MiniChatMetricsMeter;
22use crate::infra::outbox::{AttachmentCleanupHandler, InfraOutboxEnqueuer, UsageEventHandler};
23
24pub(crate) type AppServices = GenericAppServices<
25 TurnRepository,
26 MessageRepository,
27 QuotaUsageRepository,
28 ReactionRepository,
29 ChatRepository,
30 ThreadSummaryRepository,
31 AttachmentRepository,
32 VectorStoreRepository,
33 MessageAttachmentRepository,
34>;
35use crate::infra::db::repo::attachment_repo::AttachmentRepository;
36use crate::infra::db::repo::chat_repo::ChatRepository;
37use crate::infra::db::repo::message_attachment_repo::MessageAttachmentRepository;
38use crate::infra::db::repo::message_repo::MessageRepository;
39use crate::infra::db::repo::quota_usage_repo::QuotaUsageRepository;
40use crate::infra::db::repo::reaction_repo::ReactionRepository;
41use crate::infra::db::repo::thread_summary_repo::ThreadSummaryRepository;
42use crate::infra::db::repo::turn_repo::TurnRepository;
43use crate::infra::db::repo::vector_store_repo::VectorStoreRepository;
44use crate::infra::llm::provider_resolver::ProviderResolver;
45use crate::infra::model_policy::ModelPolicyGateway;
46
/// Route prefix `/mini-chat` for the module's REST API.
///
/// NOTE(review): not referenced in this file; presumably the default for the
/// `url_prefix` config value — confirm against `crate::config`.
pub const DEFAULT_URL_PREFIX: &str = "/mini-chat";
49
50#[modkit::module(
52 name = "mini-chat",
53 deps = ["types-registry", "authn-resolver", "authz-resolver", "oagw"],
54 capabilities = [db, rest, stateful],
55)]
56pub struct MiniChatModule {
57 service: OnceLock<Arc<AppServices>>,
58 url_prefix: OnceLock<String>,
59 outbox_handle: Mutex<Option<OutboxHandle>>,
60 oagw_deferred: OnceLock<OagwDeferred>,
62}
63
64struct OagwDeferred {
66 gateway: Arc<dyn ServiceGatewayClientV1>,
67 authn: Arc<dyn AuthNResolverClient>,
68 client_credentials: crate::config::ClientCredentialsConfig,
69 providers: std::collections::HashMap<String, ProviderEntry>,
70}
71
72impl Default for MiniChatModule {
73 fn default() -> Self {
74 Self {
75 service: OnceLock::new(),
76 url_prefix: OnceLock::new(),
77 outbox_handle: Mutex::new(None),
78 oagw_deferred: OnceLock::new(),
79 }
80 }
81}
82
83#[allow(clippy::too_many_lines)]
84#[async_trait]
85impl Module for MiniChatModule {
86 async fn init(&self, ctx: &ModuleCtx) -> anyhow::Result<()> {
87 info!("Initializing {} module", Self::MODULE_NAME);
88
89 let mut cfg: crate::config::MiniChatConfig = ctx.config_expanded()?;
90 cfg.streaming
91 .validate()
92 .map_err(|e| anyhow::anyhow!("streaming config: {e}"))?;
93 cfg.estimation_budgets
94 .validate()
95 .map_err(|e| anyhow::anyhow!("estimation_budgets config: {e}"))?;
96 cfg.quota
97 .validate()
98 .map_err(|e| anyhow::anyhow!("quota config: {e}"))?;
99 cfg.outbox
100 .validate()
101 .map_err(|e| anyhow::anyhow!("outbox config: {e}"))?;
102 cfg.context
103 .validate()
104 .map_err(|e| anyhow::anyhow!("context config: {e}"))?;
105 cfg.client_credentials
106 .validate()
107 .map_err(|e| anyhow::anyhow!("client_credentials config: {e}"))?;
108 for (id, entry) in &cfg.providers {
109 entry
110 .validate(id)
111 .map_err(|e| anyhow::anyhow!("providers config: {e}"))?;
112 }
113
114 let vendor = cfg.vendor.trim().to_owned();
115 if vendor.is_empty() {
116 return Err(anyhow::anyhow!(
117 "{}: vendor must be a non-empty string",
118 Self::MODULE_NAME
119 ));
120 }
121
122 let registry = ctx.client_hub().get::<dyn TypesRegistryClient>()?;
123 register_plugin_schemas(
124 &*registry,
125 &[
126 (
127 MiniChatModelPolicyPluginSpecV1::gts_schema_with_refs_as_string(),
128 MiniChatModelPolicyPluginSpecV1::gts_schema_id(),
129 "model-policy",
130 ),
131 (
132 MiniChatAuditPluginSpecV1::gts_schema_with_refs_as_string(),
133 MiniChatAuditPluginSpecV1::gts_schema_id(),
134 "audit",
135 ),
136 ],
137 )
138 .await?;
139
140 self.url_prefix
141 .set(cfg.url_prefix)
142 .map_err(|_| anyhow::anyhow!("{} url_prefix already set", Self::MODULE_NAME))?;
143
144 let db_provider = ctx.db_required()?;
145 let db = Arc::new(db_provider);
146
147 let model_policy_gw = Arc::new(ModelPolicyGateway::new(
149 ctx.client_hub(),
150 vendor,
151 ctx.cancellation_token().clone(),
152 ));
153
154 let outbox_db = db.db();
160 let num_partitions = cfg.outbox.num_partitions;
161 let queue_name = cfg.outbox.queue_name.clone();
162 let cleanup_queue_name = cfg.outbox.cleanup_queue_name.clone();
163
164 let partitions = Partitions::of(
165 u16::try_from(num_partitions)
166 .map_err(|_| anyhow::anyhow!("num_partitions {num_partitions} exceeds u16"))?,
167 );
168
169 let outbox_handle = Outbox::builder(outbox_db)
170 .queue(&queue_name, partitions)
171 .decoupled(UsageEventHandler {
172 plugin_provider: model_policy_gw.clone(),
173 })
174 .queue(&cleanup_queue_name, partitions)
175 .decoupled(AttachmentCleanupHandler)
176 .start()
177 .await
178 .map_err(|e| anyhow::anyhow!("outbox start: {e}"))?;
179
180 let outbox = Arc::clone(outbox_handle.outbox());
181 let outbox_enqueuer = Arc::new(InfraOutboxEnqueuer::new(
182 outbox,
183 queue_name,
184 cleanup_queue_name,
185 num_partitions,
186 ));
187
188 {
189 let mut guard = self
190 .outbox_handle
191 .lock()
192 .map_err(|e| anyhow::anyhow!("outbox_handle lock: {e}"))?;
193 if guard.is_some() {
194 return Err(anyhow::anyhow!(
195 "{} outbox_handle already set",
196 Self::MODULE_NAME
197 ));
198 }
199 *guard = Some(outbox_handle);
200 }
201
202 info!("Outbox pipeline started");
203
204 let authz = ctx
205 .client_hub()
206 .get::<dyn AuthZResolverClient>()
207 .map_err(|e| anyhow::anyhow!("failed to get AuthZ resolver: {e}"))?;
208
209 let authn_client = ctx
210 .client_hub()
211 .get::<dyn AuthNResolverClient>()
212 .map_err(|e| anyhow::anyhow!("failed to get AuthN resolver: {e}"))?;
213
214 let gateway = ctx
215 .client_hub()
216 .get::<dyn ServiceGatewayClientV1>()
217 .map_err(|e| anyhow::anyhow!("failed to get OAGW gateway: {e}"))?;
218
219 for entry in cfg.providers.values_mut() {
224 if entry.upstream_alias.is_none() {
225 entry.upstream_alias = Some(entry.host.clone());
226 }
227 for ovr in entry.tenant_overrides.values_mut() {
228 if ovr.upstream_alias.is_none()
229 && let Some(ref h) = ovr.host
230 {
231 ovr.upstream_alias = Some(h.clone());
232 }
233 }
234 }
235
236 drop(self.oagw_deferred.set(OagwDeferred {
239 gateway: Arc::clone(&gateway),
240 authn: Arc::clone(&authn_client),
241 client_credentials: cfg.client_credentials.clone(),
242 providers: cfg.providers.clone(),
243 }));
244
245 let provider_resolver = Arc::new(ProviderResolver::new(&gateway, cfg.providers));
246
247 let repos = Repositories {
248 chat: Arc::new(ChatRepository::new(modkit_db::odata::LimitCfg {
249 default: 20,
250 max: 100,
251 })),
252 attachment: Arc::new(AttachmentRepository),
253 message: Arc::new(MessageRepository::new(modkit_db::odata::LimitCfg {
254 default: 20,
255 max: 100,
256 })),
257 quota: Arc::new(QuotaUsageRepository),
258 turn: Arc::new(TurnRepository),
259 reaction: Arc::new(ReactionRepository),
260 thread_summary: Arc::new(ThreadSummaryRepository),
261 vector_store: Arc::new(VectorStoreRepository),
262 message_attachment: Arc::new(MessageAttachmentRepository),
263 };
264
265 let rag_client = Arc::new(
266 crate::infra::llm::providers::rag_http_client::RagHttpClient::new(Arc::clone(&gateway)),
267 );
268
269 let mut file_impls: std::collections::HashMap<
272 String,
273 Arc<dyn crate::domain::ports::FileStorageProvider>,
274 > = std::collections::HashMap::new();
275 let mut vs_impls: std::collections::HashMap<
276 String,
277 Arc<dyn crate::domain::ports::VectorStoreProvider>,
278 > = std::collections::HashMap::new();
279 for (provider_id, entry) in provider_resolver.entries() {
280 let (file, vs): (
281 Arc<dyn crate::domain::ports::FileStorageProvider>,
282 Arc<dyn crate::domain::ports::VectorStoreProvider>,
283 ) = match entry.storage_kind {
284 crate::config::StorageKind::Azure => {
285 let api_version = entry.api_version.clone().unwrap_or_else(|| {
286 panic!(
287 "provider '{provider_id}': storage_kind is 'azure' \
288 but api_version is not set"
289 )
290 });
291 (
292 Arc::new(
293 crate::infra::llm::providers::azure_file_storage::AzureFileStorage::new(
294 Arc::clone(&rag_client),
295 Arc::clone(&provider_resolver),
296 api_version.clone(),
297 ),
298 ),
299 Arc::new(
300 crate::infra::llm::providers::azure_vector_store::AzureVectorStore::new(
301 Arc::clone(&rag_client),
302 Arc::clone(&provider_resolver),
303 api_version,
304 ),
305 ),
306 )
307 }
308 crate::config::StorageKind::OpenAi => (
309 Arc::new(
310 crate::infra::llm::providers::openai_file_storage::OpenAiFileStorage::new(
311 Arc::clone(&rag_client),
312 Arc::clone(&provider_resolver),
313 ),
314 ),
315 Arc::new(
316 crate::infra::llm::providers::openai_vector_store::OpenAiVectorStore::new(
317 Arc::clone(&rag_client),
318 Arc::clone(&provider_resolver),
319 ),
320 ),
321 ),
322 };
323 file_impls.insert(provider_id.clone(), file);
324 vs_impls.insert(provider_id.clone(), vs);
325 }
326 let file_storage: Arc<dyn crate::domain::ports::FileStorageProvider> = Arc::new(
327 crate::infra::llm::providers::dispatching_storage::DispatchingFileStorage::new(
328 file_impls,
329 ),
330 );
331 let vector_store_prov: Arc<dyn crate::domain::ports::VectorStoreProvider> = Arc::new(
332 crate::infra::llm::providers::dispatching_storage::DispatchingVectorStore::new(
333 vs_impls,
334 ),
335 );
336
337 let metrics_prefix = cfg.metrics.effective_prefix(Self::MODULE_NAME);
339 let scope =
340 opentelemetry::InstrumentationScope::builder(Self::MODULE_NAME.to_owned()).build();
341 let metrics: Arc<dyn MiniChatMetricsPort> = Arc::new(MiniChatMetricsMeter::new(
342 &opentelemetry::global::meter_with_scope(scope),
343 &metrics_prefix,
344 ));
345 let services = Arc::new(AppServices::new(
346 &repos,
347 db,
348 authz,
349 &(model_policy_gw.clone() as Arc<dyn crate::domain::repos::ModelResolver>),
350 &provider_resolver,
351 cfg.streaming,
352 model_policy_gw.clone() as Arc<dyn crate::domain::repos::PolicySnapshotProvider>,
353 model_policy_gw as Arc<dyn crate::domain::repos::UserLimitsProvider>,
354 cfg.estimation_budgets,
355 cfg.quota,
356 &(Arc::clone(&outbox_enqueuer) as Arc<dyn crate::domain::repos::OutboxEnqueuer>),
357 cfg.context,
358 file_storage,
359 vector_store_prov,
360 cfg.rag,
361 metrics,
362 ));
363
364 self.service
365 .set(services)
366 .map_err(|_| anyhow::anyhow!("{} module already initialized", Self::MODULE_NAME))?;
367
368 info!("{} module initialized successfully", Self::MODULE_NAME);
369 Ok(())
370 }
371}
372
373impl DatabaseCapability for MiniChatModule {
374 fn migrations(&self) -> Vec<Box<dyn MigrationTrait>> {
375 use sea_orm_migration::MigratorTrait;
376 info!("Providing mini-chat database migrations");
377 let mut m = crate::infra::db::migrations::Migrator::migrations();
378 m.extend(modkit_db::outbox::outbox_migrations());
379 m
380 }
381}
382
383impl RestApiCapability for MiniChatModule {
384 fn register_rest(
385 &self,
386 _ctx: &ModuleCtx,
387 router: axum::Router,
388 openapi: &dyn OpenApiRegistry,
389 ) -> anyhow::Result<axum::Router> {
390 let services = self
391 .service
392 .get()
393 .ok_or_else(|| anyhow::anyhow!("{} not initialized", Self::MODULE_NAME))?;
394
395 info!("Registering mini-chat REST routes");
396 let prefix = self
397 .url_prefix
398 .get()
399 .ok_or_else(|| anyhow::anyhow!("{} not initialized (url_prefix)", Self::MODULE_NAME))?;
400
401 let router = routes::register_routes(router, openapi, Arc::clone(services), prefix);
402 info!("Mini-chat REST routes registered successfully");
403 Ok(router)
404 }
405}
406
407#[async_trait]
408impl RunnableCapability for MiniChatModule {
409 async fn start(&self, _cancel: CancellationToken) -> anyhow::Result<()> {
410 if let Some(deferred) = self.oagw_deferred.get() {
415 let ctx =
416 exchange_client_credentials(&deferred.authn, &deferred.client_credentials).await?;
417 let mut providers = deferred.providers.clone();
418 crate::infra::oagw_provisioning::register_oagw_upstreams(
419 &deferred.gateway,
420 &ctx,
421 &mut providers,
422 )
423 .await?;
424 }
425 Ok(())
426 }
427
428 async fn stop(&self, cancel: CancellationToken) -> anyhow::Result<()> {
429 let handle = self
430 .outbox_handle
431 .lock()
432 .map_err(|e| anyhow::anyhow!("outbox_handle lock: {e}"))?
433 .take();
434 if let Some(handle) = handle {
435 info!("Stopping outbox pipeline");
436 tokio::select! {
437 () = handle.stop() => {
438 info!("Outbox pipeline stopped");
439 }
440 () = cancel.cancelled() => {
441 info!("Outbox pipeline stop cancelled by framework deadline");
442 }
443 }
444 }
445 Ok(())
446 }
447}
448
449async fn exchange_client_credentials(
452 authn: &Arc<dyn AuthNResolverClient>,
453 creds: &crate::config::ClientCredentialsConfig,
454) -> anyhow::Result<modkit_security::SecurityContext> {
455 info!("Exchanging client credentials for OAGW provisioning context");
456 let request = ClientCredentialsRequest {
457 client_id: creds.client_id.clone(),
458 client_secret: creds.client_secret.clone(),
459 scopes: Vec::new(),
460 };
461 let result = authn
462 .exchange_client_credentials(&request)
463 .await
464 .map_err(|e| anyhow::anyhow!("client credentials exchange failed: {e}"))?;
465 info!("Security context obtained for OAGW provisioning");
466 Ok(result.security_context)
467}
468
469async fn register_plugin_schemas(
470 registry: &dyn TypesRegistryClient,
471 schemas: &[(String, &str, &str)],
472) -> anyhow::Result<()> {
473 let mut payload = Vec::with_capacity(schemas.len());
474 for (schema_str, schema_id, _label) in schemas {
475 let mut schema_json: serde_json::Value = serde_json::from_str(schema_str)?;
476 let obj = schema_json
477 .as_object_mut()
478 .ok_or_else(|| anyhow::anyhow!("schema {schema_id} is not a JSON object"))?;
479 obj.insert(
480 "additionalProperties".to_owned(),
481 serde_json::Value::Bool(false),
482 );
483 payload.push(schema_json);
484 }
485 let results = registry.register(payload).await?;
486 RegisterResult::ensure_all_ok(&results)?;
487 for (_schema_str, schema_id, label) in schemas {
488 info!(schema_id = %schema_id, "Registered {label} plugin schema in types-registry");
489 }
490 Ok(())
491}