1use std::sync::{Arc, Mutex, OnceLock};
2
3use async_trait::async_trait;
4use authz_resolver_sdk::AuthZResolverClient;
5use mini_chat_sdk::{MiniChatAuditPluginSpecV1, MiniChatModelPolicyPluginSpecV1};
6use modkit::api::OpenApiRegistry;
7use modkit::contracts::RunnableCapability;
8use modkit::{DatabaseCapability, Module, ModuleCtx, RestApiCapability};
9use modkit_db::outbox::{Outbox, OutboxHandle, Partitions};
10use oagw_sdk::ServiceGatewayClientV1;
11use sea_orm_migration::MigrationTrait;
12use tokio_util::sync::CancellationToken;
13use tracing::info;
14use types_registry_sdk::{RegisterResult, TypesRegistryClient};
15
16use crate::api::rest::routes;
17use crate::domain::service::{AppServices as GenericAppServices, Repositories};
18use crate::infra::outbox::{InfraOutboxEnqueuer, UsageEventHandler};
19
/// Concrete [`GenericAppServices`] instantiation used by this module, wired to the
/// repository implementations from `crate::infra::db::repo`.
pub(crate) type AppServices = GenericAppServices<
    TurnRepository,
    MessageRepository,
    QuotaUsageRepository,
    ReactionRepository,
    ChatRepository,
    ThreadSummaryRepository,
>;
28use crate::infra::db::repo::attachment_repo::AttachmentRepository;
29use crate::infra::db::repo::chat_repo::ChatRepository;
30use crate::infra::db::repo::message_repo::MessageRepository;
31use crate::infra::db::repo::quota_usage_repo::QuotaUsageRepository;
32use crate::infra::db::repo::reaction_repo::ReactionRepository;
33use crate::infra::db::repo::thread_summary_repo::ThreadSummaryRepository;
34use crate::infra::db::repo::turn_repo::TurnRepository;
35use crate::infra::db::repo::vector_store_repo::VectorStoreRepository;
36use crate::infra::llm::provider_resolver::ProviderResolver;
37use crate::infra::model_policy::ModelPolicyGateway;
38
/// Default URL prefix for the mini-chat REST routes.
// NOTE(review): not referenced in this file; presumably the default value for
// `MiniChatConfig::url_prefix` — confirm against the config module.
pub const DEFAULT_URL_PREFIX: &str = "/mini-chat";
41
/// The mini-chat module: owns the application services, the REST URL prefix,
/// the usage-event outbox pipeline, and deferred OAGW provisioning state.
///
/// Every slot starts empty and is populated once by [`Module::init`].
#[modkit::module(
    name = "mini-chat",
    deps = ["types-registry", "authz-resolver", "oagw"],
    capabilities = [db, rest, stateful],
)]
pub struct MiniChatModule {
    /// Application services handed to the REST routes; set once in `init`.
    service: OnceLock<Arc<AppServices>>,
    /// Prefix under which REST routes are mounted; set once in `init`.
    url_prefix: OnceLock<String>,
    /// Running outbox pipeline; stored in `init`, taken and stopped in `stop`.
    outbox_handle: Mutex<Option<OutboxHandle>>,
    /// Gateway + provider config captured in `init` for upstream provisioning in `start`.
    oagw_deferred: OnceLock<OagwDeferred>,
}
55
/// OAGW state captured during `init` so that upstream provisioning can run later,
/// in `RunnableCapability::start`.
struct OagwDeferred {
    /// Gateway client passed to `register_oagw_upstreams`.
    gateway: Arc<dyn ServiceGatewayClientV1>,
    /// Provider config keyed by provider id, with `upstream_alias` defaults already applied.
    providers: std::collections::HashMap<String, crate::config::ProviderEntry>,
}
61
62impl Default for MiniChatModule {
63 fn default() -> Self {
64 Self {
65 service: OnceLock::new(),
66 url_prefix: OnceLock::new(),
67 outbox_handle: Mutex::new(None),
68 oagw_deferred: OnceLock::new(),
69 }
70 }
71}
72
73#[async_trait]
74impl Module for MiniChatModule {
75 async fn init(&self, ctx: &ModuleCtx) -> anyhow::Result<()> {
76 info!("Initializing {} module", Self::MODULE_NAME);
77
78 let mut cfg: crate::config::MiniChatConfig = ctx.config_expanded()?;
79 cfg.streaming
80 .validate()
81 .map_err(|e| anyhow::anyhow!("streaming config: {e}"))?;
82 cfg.estimation_budgets
83 .validate()
84 .map_err(|e| anyhow::anyhow!("estimation_budgets config: {e}"))?;
85 cfg.quota
86 .validate()
87 .map_err(|e| anyhow::anyhow!("quota config: {e}"))?;
88 cfg.outbox
89 .validate()
90 .map_err(|e| anyhow::anyhow!("outbox config: {e}"))?;
91 cfg.context
92 .validate()
93 .map_err(|e| anyhow::anyhow!("context config: {e}"))?;
94 for (id, entry) in &cfg.providers {
95 entry
96 .validate(id)
97 .map_err(|e| anyhow::anyhow!("providers config: {e}"))?;
98 }
99
100 let vendor = cfg.vendor.trim().to_owned();
101 if vendor.is_empty() {
102 return Err(anyhow::anyhow!(
103 "{}: vendor must be a non-empty string",
104 Self::MODULE_NAME
105 ));
106 }
107
108 let registry = ctx.client_hub().get::<dyn TypesRegistryClient>()?;
109 register_plugin_schemas(
110 &*registry,
111 &[
112 (
113 MiniChatModelPolicyPluginSpecV1::gts_schema_with_refs_as_string(),
114 MiniChatModelPolicyPluginSpecV1::gts_schema_id(),
115 "model-policy",
116 ),
117 (
118 MiniChatAuditPluginSpecV1::gts_schema_with_refs_as_string(),
119 MiniChatAuditPluginSpecV1::gts_schema_id(),
120 "audit",
121 ),
122 ],
123 )
124 .await?;
125
126 self.url_prefix
127 .set(cfg.url_prefix)
128 .map_err(|_| anyhow::anyhow!("{} url_prefix already set", Self::MODULE_NAME))?;
129
130 let db_provider = ctx.db_required()?;
131 let db = Arc::new(db_provider);
132
133 let model_policy_gw = Arc::new(ModelPolicyGateway::new(ctx.client_hub(), vendor));
135
136 let outbox_db = db.db();
142 let num_partitions = cfg.outbox.num_partitions;
143 let queue_name = cfg.outbox.queue_name.clone();
144
145 let outbox_handle =
146 Outbox::builder(outbox_db)
147 .queue(
148 &queue_name,
149 Partitions::of(u16::try_from(num_partitions).map_err(|_| {
150 anyhow::anyhow!("num_partitions {num_partitions} exceeds u16")
151 })?),
152 )
153 .decoupled(UsageEventHandler {
154 plugin_provider: model_policy_gw.clone(),
155 })
156 .start()
157 .await
158 .map_err(|e| anyhow::anyhow!("outbox start: {e}"))?;
159
160 let outbox = Arc::clone(outbox_handle.outbox());
161 let outbox_enqueuer =
162 Arc::new(InfraOutboxEnqueuer::new(outbox, queue_name, num_partitions));
163
164 {
165 let mut guard = self
166 .outbox_handle
167 .lock()
168 .map_err(|e| anyhow::anyhow!("outbox_handle lock: {e}"))?;
169 if guard.is_some() {
170 return Err(anyhow::anyhow!(
171 "{} outbox_handle already set",
172 Self::MODULE_NAME
173 ));
174 }
175 *guard = Some(outbox_handle);
176 }
177
178 info!("Outbox pipeline started");
179
180 let authz = ctx
181 .client_hub()
182 .get::<dyn AuthZResolverClient>()
183 .map_err(|e| anyhow::anyhow!("failed to get AuthZ resolver: {e}"))?;
184
185 let gateway = ctx
186 .client_hub()
187 .get::<dyn ServiceGatewayClientV1>()
188 .map_err(|e| anyhow::anyhow!("failed to get OAGW gateway: {e}"))?;
189
190 for entry in cfg.providers.values_mut() {
195 if entry.upstream_alias.is_none() {
196 entry.upstream_alias = Some(entry.host.clone());
197 }
198 for ovr in entry.tenant_overrides.values_mut() {
199 if ovr.upstream_alias.is_none()
200 && let Some(ref h) = ovr.host
201 {
202 ovr.upstream_alias = Some(h.clone());
203 }
204 }
205 }
206
207 drop(self.oagw_deferred.set(OagwDeferred {
210 gateway: Arc::clone(&gateway),
211 providers: cfg.providers.clone(),
212 }));
213
214 let provider_resolver = Arc::new(ProviderResolver::new(&gateway, cfg.providers));
215
216 let repos = Repositories {
217 chat: Arc::new(ChatRepository::new(modkit_db::odata::LimitCfg {
218 default: 20,
219 max: 100,
220 })),
221 attachment: Arc::new(AttachmentRepository),
222 message: Arc::new(MessageRepository::new(modkit_db::odata::LimitCfg {
223 default: 20,
224 max: 100,
225 })),
226 quota: Arc::new(QuotaUsageRepository),
227 turn: Arc::new(TurnRepository),
228 reaction: Arc::new(ReactionRepository),
229 thread_summary: Arc::new(ThreadSummaryRepository),
230 vector_store: Arc::new(VectorStoreRepository),
231 };
232
233 let services = Arc::new(AppServices::new(
234 &repos,
235 db,
236 authz,
237 &(model_policy_gw.clone() as Arc<dyn crate::domain::repos::ModelResolver>),
238 provider_resolver,
239 cfg.streaming,
240 model_policy_gw.clone() as Arc<dyn crate::domain::repos::PolicySnapshotProvider>,
241 model_policy_gw as Arc<dyn crate::domain::repos::UserLimitsProvider>,
242 cfg.estimation_budgets,
243 cfg.quota,
244 outbox_enqueuer,
245 cfg.context,
246 ));
247
248 self.service
249 .set(services)
250 .map_err(|_| anyhow::anyhow!("{} module already initialized", Self::MODULE_NAME))?;
251
252 info!("{} module initialized successfully", Self::MODULE_NAME);
253 Ok(())
254 }
255}
256
257impl DatabaseCapability for MiniChatModule {
258 fn migrations(&self) -> Vec<Box<dyn MigrationTrait>> {
259 use sea_orm_migration::MigratorTrait;
260 info!("Providing mini-chat database migrations");
261 let mut m = crate::infra::db::migrations::Migrator::migrations();
262 m.extend(modkit_db::outbox::outbox_migrations());
263 m
264 }
265}
266
267impl RestApiCapability for MiniChatModule {
268 fn register_rest(
269 &self,
270 _ctx: &ModuleCtx,
271 router: axum::Router,
272 openapi: &dyn OpenApiRegistry,
273 ) -> anyhow::Result<axum::Router> {
274 let services = self
275 .service
276 .get()
277 .ok_or_else(|| anyhow::anyhow!("{} not initialized", Self::MODULE_NAME))?;
278
279 info!("Registering mini-chat REST routes");
280 let prefix = self
281 .url_prefix
282 .get()
283 .ok_or_else(|| anyhow::anyhow!("{} not initialized (url_prefix)", Self::MODULE_NAME))?;
284
285 let router = routes::register_routes(router, openapi, Arc::clone(services), prefix);
286 info!("Mini-chat REST routes registered successfully");
287 Ok(router)
288 }
289}
290
291#[async_trait]
292impl RunnableCapability for MiniChatModule {
293 async fn start(&self, _cancel: CancellationToken) -> anyhow::Result<()> {
294 if let Some(deferred) = self.oagw_deferred.get() {
299 let mut providers = deferred.providers.clone();
300 crate::infra::oagw_provisioning::register_oagw_upstreams(
301 &deferred.gateway,
302 &mut providers,
303 )
304 .await?;
305 }
306 Ok(())
307 }
308
309 async fn stop(&self, cancel: CancellationToken) -> anyhow::Result<()> {
310 let handle = self
311 .outbox_handle
312 .lock()
313 .map_err(|e| anyhow::anyhow!("outbox_handle lock: {e}"))?
314 .take();
315 if let Some(handle) = handle {
316 info!("Stopping outbox pipeline");
317 tokio::select! {
318 () = handle.stop() => {
319 info!("Outbox pipeline stopped");
320 }
321 () = cancel.cancelled() => {
322 info!("Outbox pipeline stop cancelled by framework deadline");
323 }
324 }
325 }
326 Ok(())
327 }
328}
329
330async fn register_plugin_schemas(
331 registry: &dyn TypesRegistryClient,
332 schemas: &[(String, &str, &str)],
333) -> anyhow::Result<()> {
334 let mut payload = Vec::with_capacity(schemas.len());
335 for (schema_str, schema_id, _label) in schemas {
336 let mut schema_json: serde_json::Value = serde_json::from_str(schema_str)?;
337 let obj = schema_json
338 .as_object_mut()
339 .ok_or_else(|| anyhow::anyhow!("schema {schema_id} is not a JSON object"))?;
340 obj.insert(
341 "additionalProperties".to_owned(),
342 serde_json::Value::Bool(false),
343 );
344 payload.push(schema_json);
345 }
346 let results = registry.register(payload).await?;
347 RegisterResult::ensure_all_ok(&results)?;
348 for (_schema_str, schema_id, label) in schemas {
349 info!(schema_id = %schema_id, "Registered {label} plugin schema in types-registry");
350 }
351 Ok(())
352}