1use std::sync::{Arc, Mutex, OnceLock};
2
3use async_trait::async_trait;
4use authn_resolver_sdk::{AuthNResolverClient, ClientCredentialsRequest};
5use authz_resolver_sdk::AuthZResolverClient;
6use mini_chat_sdk::{MiniChatAuditPluginSpecV1, MiniChatModelPolicyPluginSpecV1};
7use modkit::api::OpenApiRegistry;
8use modkit::contracts::RunnableCapability;
9use modkit::{DatabaseCapability, Module, ModuleCtx, RestApiCapability};
10use modkit_db::outbox::{Outbox, OutboxHandle, Partitions};
11use oagw_sdk::ServiceGatewayClientV1;
12use sea_orm_migration::MigrationTrait;
13use tokio_util::sync::CancellationToken;
14use tracing::info;
15use types_registry_sdk::{RegisterResult, TypesRegistryClient};
16
17use crate::api::rest::routes;
18use crate::config::ProviderEntry;
19use crate::domain::service::{AppServices as GenericAppServices, Repositories};
20use crate::infra::outbox::{AttachmentCleanupHandler, InfraOutboxEnqueuer, UsageEventHandler};
21
/// Concrete application-services type used by this module.
///
/// It is the generic `crate::domain::service::AppServices` instantiated with the
/// repository implementations from `crate::infra::db::repo` (imported just below).
/// The type-parameter order must match the generic definition.
pub(crate) type AppServices = GenericAppServices<
    TurnRepository,
    MessageRepository,
    QuotaUsageRepository,
    ReactionRepository,
    ChatRepository,
    ThreadSummaryRepository,
    AttachmentRepository,
    VectorStoreRepository,
    MessageAttachmentRepository,
>;
33use crate::infra::db::repo::attachment_repo::AttachmentRepository;
34use crate::infra::db::repo::chat_repo::ChatRepository;
35use crate::infra::db::repo::message_attachment_repo::MessageAttachmentRepository;
36use crate::infra::db::repo::message_repo::MessageRepository;
37use crate::infra::db::repo::quota_usage_repo::QuotaUsageRepository;
38use crate::infra::db::repo::reaction_repo::ReactionRepository;
39use crate::infra::db::repo::thread_summary_repo::ThreadSummaryRepository;
40use crate::infra::db::repo::turn_repo::TurnRepository;
41use crate::infra::db::repo::vector_store_repo::VectorStoreRepository;
42use crate::infra::llm::provider_resolver::ProviderResolver;
43use crate::infra::model_policy::ModelPolicyGateway;
44
/// Default URL prefix string for the mini-chat module.
///
/// NOTE(review): not referenced in this file. The prefix actually used for routing comes
/// from `cfg.url_prefix` in `init`. Presumably this is the config default; confirm.
pub const DEFAULT_URL_PREFIX: &str = "/mini-chat";
47
/// The mini-chat module.
///
/// Declared to the modkit runtime with database, REST and stateful (runnable)
/// capabilities. It depends on the types registry, the AuthN/AuthZ resolvers and the OAGW
/// service gateway. All runtime state is populated once by `Module::init`.
#[modkit::module(
    name = "mini-chat",
    deps = ["types-registry", "authn-resolver", "authz-resolver", "oagw"],
    capabilities = [db, rest, stateful],
)]
pub struct MiniChatModule {
    // Application services. Set once at the end of `init`; read by `register_rest`.
    service: OnceLock<Arc<AppServices>>,
    // REST route prefix taken from config in `init`; read by `register_rest`.
    url_prefix: OnceLock<String>,
    // Handle of the running outbox pipeline. Stored by `init`; taken and stopped by
    // `stop`. This is a `Mutex<Option<_>>` rather than a `OnceLock` so `stop` can
    // move the handle out.
    outbox_handle: Mutex<Option<OutboxHandle>>,
    // State captured in `init` so `start` can provision OAGW upstreams later.
    oagw_deferred: OnceLock<OagwDeferred>,
}
61
/// Inputs for OAGW upstream provisioning.
///
/// They are collected during `init` and consumed in `RunnableCapability::start`, where
/// client credentials are exchanged for a security context before upstreams are
/// registered.
struct OagwDeferred {
    // Service gateway client used to register upstreams.
    gateway: Arc<dyn ServiceGatewayClientV1>,
    // AuthN client used for the client-credentials exchange.
    authn: Arc<dyn AuthNResolverClient>,
    // Client id/secret handed to the credentials exchange.
    client_credentials: crate::config::ClientCredentialsConfig,
    // Provider config keyed by provider id. `upstream_alias` has already been defaulted
    // to the host in `init`.
    providers: std::collections::HashMap<String, ProviderEntry>,
}
69
70impl Default for MiniChatModule {
71 fn default() -> Self {
72 Self {
73 service: OnceLock::new(),
74 url_prefix: OnceLock::new(),
75 outbox_handle: Mutex::new(None),
76 oagw_deferred: OnceLock::new(),
77 }
78 }
79}
80
81#[allow(clippy::too_many_lines)]
82#[async_trait]
83impl Module for MiniChatModule {
84 async fn init(&self, ctx: &ModuleCtx) -> anyhow::Result<()> {
85 info!("Initializing {} module", Self::MODULE_NAME);
86
87 let mut cfg: crate::config::MiniChatConfig = ctx.config_expanded()?;
88 cfg.streaming
89 .validate()
90 .map_err(|e| anyhow::anyhow!("streaming config: {e}"))?;
91 cfg.estimation_budgets
92 .validate()
93 .map_err(|e| anyhow::anyhow!("estimation_budgets config: {e}"))?;
94 cfg.quota
95 .validate()
96 .map_err(|e| anyhow::anyhow!("quota config: {e}"))?;
97 cfg.outbox
98 .validate()
99 .map_err(|e| anyhow::anyhow!("outbox config: {e}"))?;
100 cfg.context
101 .validate()
102 .map_err(|e| anyhow::anyhow!("context config: {e}"))?;
103 cfg.client_credentials
104 .validate()
105 .map_err(|e| anyhow::anyhow!("client_credentials config: {e}"))?;
106 for (id, entry) in &cfg.providers {
107 entry
108 .validate(id)
109 .map_err(|e| anyhow::anyhow!("providers config: {e}"))?;
110 }
111
112 let vendor = cfg.vendor.trim().to_owned();
113 if vendor.is_empty() {
114 return Err(anyhow::anyhow!(
115 "{}: vendor must be a non-empty string",
116 Self::MODULE_NAME
117 ));
118 }
119
120 let registry = ctx.client_hub().get::<dyn TypesRegistryClient>()?;
121 register_plugin_schemas(
122 &*registry,
123 &[
124 (
125 MiniChatModelPolicyPluginSpecV1::gts_schema_with_refs_as_string(),
126 MiniChatModelPolicyPluginSpecV1::gts_schema_id(),
127 "model-policy",
128 ),
129 (
130 MiniChatAuditPluginSpecV1::gts_schema_with_refs_as_string(),
131 MiniChatAuditPluginSpecV1::gts_schema_id(),
132 "audit",
133 ),
134 ],
135 )
136 .await?;
137
138 self.url_prefix
139 .set(cfg.url_prefix)
140 .map_err(|_| anyhow::anyhow!("{} url_prefix already set", Self::MODULE_NAME))?;
141
142 let db_provider = ctx.db_required()?;
143 let db = Arc::new(db_provider);
144
145 let model_policy_gw = Arc::new(ModelPolicyGateway::new(
147 ctx.client_hub(),
148 vendor,
149 ctx.cancellation_token().clone(),
150 ));
151
152 let outbox_db = db.db();
158 let num_partitions = cfg.outbox.num_partitions;
159 let queue_name = cfg.outbox.queue_name.clone();
160 let cleanup_queue_name = cfg.outbox.cleanup_queue_name.clone();
161
162 let partitions = Partitions::of(
163 u16::try_from(num_partitions)
164 .map_err(|_| anyhow::anyhow!("num_partitions {num_partitions} exceeds u16"))?,
165 );
166
167 let outbox_handle = Outbox::builder(outbox_db)
168 .queue(&queue_name, partitions)
169 .decoupled(UsageEventHandler {
170 plugin_provider: model_policy_gw.clone(),
171 })
172 .queue(&cleanup_queue_name, partitions)
173 .decoupled(AttachmentCleanupHandler)
174 .start()
175 .await
176 .map_err(|e| anyhow::anyhow!("outbox start: {e}"))?;
177
178 let outbox = Arc::clone(outbox_handle.outbox());
179 let outbox_enqueuer = Arc::new(InfraOutboxEnqueuer::new(
180 outbox,
181 queue_name,
182 cleanup_queue_name,
183 num_partitions,
184 ));
185
186 {
187 let mut guard = self
188 .outbox_handle
189 .lock()
190 .map_err(|e| anyhow::anyhow!("outbox_handle lock: {e}"))?;
191 if guard.is_some() {
192 return Err(anyhow::anyhow!(
193 "{} outbox_handle already set",
194 Self::MODULE_NAME
195 ));
196 }
197 *guard = Some(outbox_handle);
198 }
199
200 info!("Outbox pipeline started");
201
202 let authz = ctx
203 .client_hub()
204 .get::<dyn AuthZResolverClient>()
205 .map_err(|e| anyhow::anyhow!("failed to get AuthZ resolver: {e}"))?;
206
207 let authn_client = ctx
208 .client_hub()
209 .get::<dyn AuthNResolverClient>()
210 .map_err(|e| anyhow::anyhow!("failed to get AuthN resolver: {e}"))?;
211
212 let gateway = ctx
213 .client_hub()
214 .get::<dyn ServiceGatewayClientV1>()
215 .map_err(|e| anyhow::anyhow!("failed to get OAGW gateway: {e}"))?;
216
217 for entry in cfg.providers.values_mut() {
222 if entry.upstream_alias.is_none() {
223 entry.upstream_alias = Some(entry.host.clone());
224 }
225 for ovr in entry.tenant_overrides.values_mut() {
226 if ovr.upstream_alias.is_none()
227 && let Some(ref h) = ovr.host
228 {
229 ovr.upstream_alias = Some(h.clone());
230 }
231 }
232 }
233
234 drop(self.oagw_deferred.set(OagwDeferred {
237 gateway: Arc::clone(&gateway),
238 authn: Arc::clone(&authn_client),
239 client_credentials: cfg.client_credentials.clone(),
240 providers: cfg.providers.clone(),
241 }));
242
243 let provider_resolver = Arc::new(ProviderResolver::new(&gateway, cfg.providers));
244
245 let repos = Repositories {
246 chat: Arc::new(ChatRepository::new(modkit_db::odata::LimitCfg {
247 default: 20,
248 max: 100,
249 })),
250 attachment: Arc::new(AttachmentRepository),
251 message: Arc::new(MessageRepository::new(modkit_db::odata::LimitCfg {
252 default: 20,
253 max: 100,
254 })),
255 quota: Arc::new(QuotaUsageRepository),
256 turn: Arc::new(TurnRepository),
257 reaction: Arc::new(ReactionRepository),
258 thread_summary: Arc::new(ThreadSummaryRepository),
259 vector_store: Arc::new(VectorStoreRepository),
260 message_attachment: Arc::new(MessageAttachmentRepository),
261 };
262
263 let rag_client = Arc::new(
264 crate::infra::llm::providers::rag_http_client::RagHttpClient::new(Arc::clone(&gateway)),
265 );
266
267 let mut file_impls: std::collections::HashMap<
270 String,
271 Arc<dyn crate::domain::ports::FileStorageProvider>,
272 > = std::collections::HashMap::new();
273 let mut vs_impls: std::collections::HashMap<
274 String,
275 Arc<dyn crate::domain::ports::VectorStoreProvider>,
276 > = std::collections::HashMap::new();
277 for (provider_id, entry) in provider_resolver.entries() {
278 let (file, vs): (
279 Arc<dyn crate::domain::ports::FileStorageProvider>,
280 Arc<dyn crate::domain::ports::VectorStoreProvider>,
281 ) = match entry.storage_kind {
282 crate::config::StorageKind::Azure => {
283 let api_version = entry.api_version.clone().unwrap_or_else(|| {
284 panic!(
285 "provider '{provider_id}': storage_kind is 'azure' \
286 but api_version is not set"
287 )
288 });
289 (
290 Arc::new(
291 crate::infra::llm::providers::azure_file_storage::AzureFileStorage::new(
292 Arc::clone(&rag_client),
293 Arc::clone(&provider_resolver),
294 api_version.clone(),
295 ),
296 ),
297 Arc::new(
298 crate::infra::llm::providers::azure_vector_store::AzureVectorStore::new(
299 Arc::clone(&rag_client),
300 Arc::clone(&provider_resolver),
301 api_version,
302 ),
303 ),
304 )
305 }
306 crate::config::StorageKind::OpenAi => (
307 Arc::new(
308 crate::infra::llm::providers::openai_file_storage::OpenAiFileStorage::new(
309 Arc::clone(&rag_client),
310 Arc::clone(&provider_resolver),
311 ),
312 ),
313 Arc::new(
314 crate::infra::llm::providers::openai_vector_store::OpenAiVectorStore::new(
315 Arc::clone(&rag_client),
316 Arc::clone(&provider_resolver),
317 ),
318 ),
319 ),
320 };
321 file_impls.insert(provider_id.clone(), file);
322 vs_impls.insert(provider_id.clone(), vs);
323 }
324 let file_storage: Arc<dyn crate::domain::ports::FileStorageProvider> = Arc::new(
325 crate::infra::llm::providers::dispatching_storage::DispatchingFileStorage::new(
326 file_impls,
327 ),
328 );
329 let vector_store_prov: Arc<dyn crate::domain::ports::VectorStoreProvider> = Arc::new(
330 crate::infra::llm::providers::dispatching_storage::DispatchingVectorStore::new(
331 vs_impls,
332 ),
333 );
334
335 let services = Arc::new(AppServices::new(
336 &repos,
337 db,
338 authz,
339 &(model_policy_gw.clone() as Arc<dyn crate::domain::repos::ModelResolver>),
340 &provider_resolver,
341 cfg.streaming,
342 model_policy_gw.clone() as Arc<dyn crate::domain::repos::PolicySnapshotProvider>,
343 model_policy_gw as Arc<dyn crate::domain::repos::UserLimitsProvider>,
344 cfg.estimation_budgets,
345 cfg.quota,
346 &(Arc::clone(&outbox_enqueuer) as Arc<dyn crate::domain::repos::OutboxEnqueuer>),
347 cfg.context,
348 file_storage,
349 vector_store_prov,
350 cfg.rag,
351 ));
352
353 self.service
354 .set(services)
355 .map_err(|_| anyhow::anyhow!("{} module already initialized", Self::MODULE_NAME))?;
356
357 info!("{} module initialized successfully", Self::MODULE_NAME);
358 Ok(())
359 }
360}
361
362impl DatabaseCapability for MiniChatModule {
363 fn migrations(&self) -> Vec<Box<dyn MigrationTrait>> {
364 use sea_orm_migration::MigratorTrait;
365 info!("Providing mini-chat database migrations");
366 let mut m = crate::infra::db::migrations::Migrator::migrations();
367 m.extend(modkit_db::outbox::outbox_migrations());
368 m
369 }
370}
371
372impl RestApiCapability for MiniChatModule {
373 fn register_rest(
374 &self,
375 _ctx: &ModuleCtx,
376 router: axum::Router,
377 openapi: &dyn OpenApiRegistry,
378 ) -> anyhow::Result<axum::Router> {
379 let services = self
380 .service
381 .get()
382 .ok_or_else(|| anyhow::anyhow!("{} not initialized", Self::MODULE_NAME))?;
383
384 info!("Registering mini-chat REST routes");
385 let prefix = self
386 .url_prefix
387 .get()
388 .ok_or_else(|| anyhow::anyhow!("{} not initialized (url_prefix)", Self::MODULE_NAME))?;
389
390 let router = routes::register_routes(router, openapi, Arc::clone(services), prefix);
391 info!("Mini-chat REST routes registered successfully");
392 Ok(router)
393 }
394}
395
396#[async_trait]
397impl RunnableCapability for MiniChatModule {
398 async fn start(&self, _cancel: CancellationToken) -> anyhow::Result<()> {
399 if let Some(deferred) = self.oagw_deferred.get() {
404 let ctx =
405 exchange_client_credentials(&deferred.authn, &deferred.client_credentials).await?;
406 let mut providers = deferred.providers.clone();
407 crate::infra::oagw_provisioning::register_oagw_upstreams(
408 &deferred.gateway,
409 &ctx,
410 &mut providers,
411 )
412 .await?;
413 }
414 Ok(())
415 }
416
417 async fn stop(&self, cancel: CancellationToken) -> anyhow::Result<()> {
418 let handle = self
419 .outbox_handle
420 .lock()
421 .map_err(|e| anyhow::anyhow!("outbox_handle lock: {e}"))?
422 .take();
423 if let Some(handle) = handle {
424 info!("Stopping outbox pipeline");
425 tokio::select! {
426 () = handle.stop() => {
427 info!("Outbox pipeline stopped");
428 }
429 () = cancel.cancelled() => {
430 info!("Outbox pipeline stop cancelled by framework deadline");
431 }
432 }
433 }
434 Ok(())
435 }
436}
437
438async fn exchange_client_credentials(
441 authn: &Arc<dyn AuthNResolverClient>,
442 creds: &crate::config::ClientCredentialsConfig,
443) -> anyhow::Result<modkit_security::SecurityContext> {
444 info!("Exchanging client credentials for OAGW provisioning context");
445 let request = ClientCredentialsRequest {
446 client_id: creds.client_id.clone(),
447 client_secret: creds.client_secret.clone(),
448 scopes: Vec::new(),
449 };
450 let result = authn
451 .exchange_client_credentials(&request)
452 .await
453 .map_err(|e| anyhow::anyhow!("client credentials exchange failed: {e}"))?;
454 info!("Security context obtained for OAGW provisioning");
455 Ok(result.security_context)
456}
457
458async fn register_plugin_schemas(
459 registry: &dyn TypesRegistryClient,
460 schemas: &[(String, &str, &str)],
461) -> anyhow::Result<()> {
462 let mut payload = Vec::with_capacity(schemas.len());
463 for (schema_str, schema_id, _label) in schemas {
464 let mut schema_json: serde_json::Value = serde_json::from_str(schema_str)?;
465 let obj = schema_json
466 .as_object_mut()
467 .ok_or_else(|| anyhow::anyhow!("schema {schema_id} is not a JSON object"))?;
468 obj.insert(
469 "additionalProperties".to_owned(),
470 serde_json::Value::Bool(false),
471 );
472 payload.push(schema_json);
473 }
474 let results = registry.register(payload).await?;
475 RegisterResult::ensure_all_ok(&results)?;
476 for (_schema_str, schema_id, label) in schemas {
477 info!(schema_id = %schema_id, "Registered {label} plugin schema in types-registry");
478 }
479 Ok(())
480}