Skip to main content

mini_chat/
module.rs

1use std::sync::{Arc, Mutex, OnceLock};
2
3use async_trait::async_trait;
4use authz_resolver_sdk::AuthZResolverClient;
5use mini_chat_sdk::{MiniChatAuditPluginSpecV1, MiniChatModelPolicyPluginSpecV1};
6use modkit::api::OpenApiRegistry;
7use modkit::contracts::RunnableCapability;
8use modkit::{DatabaseCapability, Module, ModuleCtx, RestApiCapability};
9use modkit_db::outbox::{Outbox, OutboxHandle, Partitions};
10use oagw_sdk::ServiceGatewayClientV1;
11use sea_orm_migration::MigrationTrait;
12use tokio_util::sync::CancellationToken;
13use tracing::info;
14use types_registry_sdk::{RegisterResult, TypesRegistryClient};
15
16use crate::api::rest::routes;
17use crate::domain::service::{AppServices as GenericAppServices, Repositories};
18use crate::infra::outbox::{InfraOutboxEnqueuer, UsageEventHandler};
19
20pub(crate) type AppServices = GenericAppServices<
21    TurnRepository,
22    MessageRepository,
23    QuotaUsageRepository,
24    ReactionRepository,
25    ChatRepository,
26    ThreadSummaryRepository,
27>;
28use crate::infra::db::repo::attachment_repo::AttachmentRepository;
29use crate::infra::db::repo::chat_repo::ChatRepository;
30use crate::infra::db::repo::message_repo::MessageRepository;
31use crate::infra::db::repo::quota_usage_repo::QuotaUsageRepository;
32use crate::infra::db::repo::reaction_repo::ReactionRepository;
33use crate::infra::db::repo::thread_summary_repo::ThreadSummaryRepository;
34use crate::infra::db::repo::turn_repo::TurnRepository;
35use crate::infra::db::repo::vector_store_repo::VectorStoreRepository;
36use crate::infra::llm::provider_resolver::ProviderResolver;
37use crate::infra::model_policy::ModelPolicyGateway;
38
/// Default URL prefix for all mini-chat REST routes.
pub const DEFAULT_URL_PREFIX: &str = "/mini-chat";
41
42/// The mini-chat module: multi-tenant AI chat with SSE streaming.
43#[modkit::module(
44    name = "mini-chat",
45    deps = ["types-registry", "authz-resolver", "oagw"],
46    capabilities = [db, rest, stateful],
47)]
48pub struct MiniChatModule {
49    service: OnceLock<Arc<AppServices>>,
50    url_prefix: OnceLock<String>,
51    outbox_handle: Mutex<Option<OutboxHandle>>,
52}
53
54impl Default for MiniChatModule {
55    fn default() -> Self {
56        Self {
57            service: OnceLock::new(),
58            url_prefix: OnceLock::new(),
59            outbox_handle: Mutex::new(None),
60        }
61    }
62}
63
64#[async_trait]
65impl Module for MiniChatModule {
66    async fn init(&self, ctx: &ModuleCtx) -> anyhow::Result<()> {
67        info!("Initializing {} module", Self::MODULE_NAME);
68
69        let mut cfg: crate::config::MiniChatConfig = ctx.config_expanded()?;
70        cfg.streaming
71            .validate()
72            .map_err(|e| anyhow::anyhow!("streaming config: {e}"))?;
73        cfg.estimation_budgets
74            .validate()
75            .map_err(|e| anyhow::anyhow!("estimation_budgets config: {e}"))?;
76        cfg.quota
77            .validate()
78            .map_err(|e| anyhow::anyhow!("quota config: {e}"))?;
79        cfg.outbox
80            .validate()
81            .map_err(|e| anyhow::anyhow!("outbox config: {e}"))?;
82        cfg.context
83            .validate()
84            .map_err(|e| anyhow::anyhow!("context config: {e}"))?;
85        for (id, entry) in &cfg.providers {
86            entry
87                .validate(id)
88                .map_err(|e| anyhow::anyhow!("providers config: {e}"))?;
89        }
90
91        let vendor = cfg.vendor.trim().to_owned();
92        if vendor.is_empty() {
93            return Err(anyhow::anyhow!(
94                "{}: vendor must be a non-empty string",
95                Self::MODULE_NAME
96            ));
97        }
98
99        let registry = ctx.client_hub().get::<dyn TypesRegistryClient>()?;
100        register_plugin_schemas(
101            &*registry,
102            &[
103                (
104                    MiniChatModelPolicyPluginSpecV1::gts_schema_with_refs_as_string(),
105                    MiniChatModelPolicyPluginSpecV1::gts_schema_id(),
106                    "model-policy",
107                ),
108                (
109                    MiniChatAuditPluginSpecV1::gts_schema_with_refs_as_string(),
110                    MiniChatAuditPluginSpecV1::gts_schema_id(),
111                    "audit",
112                ),
113            ],
114        )
115        .await?;
116
117        // Register provider upstream + route GTS entities in types-registry.
118        // OAGW will materialise these in post_init() via type provisioning.
119        crate::infra::oagw_provisioning::register_provider_entities(&*registry, &cfg.providers)
120            .await?;
121
122        self.url_prefix
123            .set(cfg.url_prefix)
124            .map_err(|_| anyhow::anyhow!("{} url_prefix already set", Self::MODULE_NAME))?;
125
126        let db_provider = ctx.db_required()?;
127        let db = Arc::new(db_provider);
128
129        // Create the model-policy gateway early for both outbox handler and services.
130        let model_policy_gw = Arc::new(ModelPolicyGateway::new(ctx.client_hub(), vendor));
131
132        // Start the outbox pipeline eagerly in init() (migrations ran in phase 2, DB is ready).
133        // The framework guarantees stop() is called on init failure, so the pipeline
134        // will be shut down cleanly if any later init step errors.
135        // The handler resolves the plugin lazily on first message delivery,
136        // avoiding a hard dependency on plugin availability during init().
137        let outbox_db = db.db();
138        let num_partitions = cfg.outbox.num_partitions;
139        let queue_name = cfg.outbox.queue_name.clone();
140
141        let outbox_handle =
142            Outbox::builder(outbox_db)
143                .queue(
144                    &queue_name,
145                    Partitions::of(u16::try_from(num_partitions).map_err(|_| {
146                        anyhow::anyhow!("num_partitions {num_partitions} exceeds u16")
147                    })?),
148                )
149                .decoupled(UsageEventHandler {
150                    plugin_provider: model_policy_gw.clone(),
151                })
152                .start()
153                .await
154                .map_err(|e| anyhow::anyhow!("outbox start: {e}"))?;
155
156        let outbox = Arc::clone(outbox_handle.outbox());
157        let outbox_enqueuer =
158            Arc::new(InfraOutboxEnqueuer::new(outbox, queue_name, num_partitions));
159
160        {
161            let mut guard = self
162                .outbox_handle
163                .lock()
164                .map_err(|e| anyhow::anyhow!("outbox_handle lock: {e}"))?;
165            if guard.is_some() {
166                return Err(anyhow::anyhow!(
167                    "{} outbox_handle already set",
168                    Self::MODULE_NAME
169                ));
170            }
171            *guard = Some(outbox_handle);
172        }
173
174        info!("Outbox pipeline started");
175
176        let authz = ctx
177            .client_hub()
178            .get::<dyn AuthZResolverClient>()
179            .map_err(|e| anyhow::anyhow!("failed to get AuthZ resolver: {e}"))?;
180
181        let gateway = ctx
182            .client_hub()
183            .get::<dyn ServiceGatewayClientV1>()
184            .map_err(|e| anyhow::anyhow!("failed to get OAGW gateway: {e}"))?;
185
186        // Pre-fill upstream_alias with host as fallback so ProviderResolver
187        // works immediately. OAGW materialises the real upstreams in post_init().
188        for entry in cfg.providers.values_mut() {
189            if entry.upstream_alias.is_none() {
190                entry.upstream_alias = Some(entry.host.clone());
191            }
192            for ovr in entry.tenant_overrides.values_mut() {
193                if ovr.upstream_alias.is_none()
194                    && let Some(ref h) = ovr.host
195                {
196                    ovr.upstream_alias = Some(h.clone());
197                }
198            }
199        }
200
201        let provider_resolver = Arc::new(ProviderResolver::new(&gateway, cfg.providers));
202
203        let repos = Repositories {
204            chat: Arc::new(ChatRepository::new(modkit_db::odata::LimitCfg {
205                default: 20,
206                max: 100,
207            })),
208            attachment: Arc::new(AttachmentRepository),
209            message: Arc::new(MessageRepository::new(modkit_db::odata::LimitCfg {
210                default: 20,
211                max: 100,
212            })),
213            quota: Arc::new(QuotaUsageRepository),
214            turn: Arc::new(TurnRepository),
215            reaction: Arc::new(ReactionRepository),
216            thread_summary: Arc::new(ThreadSummaryRepository),
217            vector_store: Arc::new(VectorStoreRepository),
218        };
219
220        let services = Arc::new(AppServices::new(
221            &repos,
222            db,
223            authz,
224            &(model_policy_gw.clone() as Arc<dyn crate::domain::repos::ModelResolver>),
225            provider_resolver,
226            cfg.streaming,
227            model_policy_gw.clone() as Arc<dyn crate::domain::repos::PolicySnapshotProvider>,
228            model_policy_gw as Arc<dyn crate::domain::repos::UserLimitsProvider>,
229            cfg.estimation_budgets,
230            cfg.quota,
231            outbox_enqueuer,
232            cfg.context,
233        ));
234
235        self.service
236            .set(services)
237            .map_err(|_| anyhow::anyhow!("{} module already initialized", Self::MODULE_NAME))?;
238
239        info!("{} module initialized successfully", Self::MODULE_NAME);
240        Ok(())
241    }
242}
243
244impl DatabaseCapability for MiniChatModule {
245    fn migrations(&self) -> Vec<Box<dyn MigrationTrait>> {
246        use sea_orm_migration::MigratorTrait;
247        info!("Providing mini-chat database migrations");
248        let mut m = crate::infra::db::migrations::Migrator::migrations();
249        m.extend(modkit_db::outbox::outbox_migrations());
250        m
251    }
252}
253
254impl RestApiCapability for MiniChatModule {
255    fn register_rest(
256        &self,
257        _ctx: &ModuleCtx,
258        router: axum::Router,
259        openapi: &dyn OpenApiRegistry,
260    ) -> anyhow::Result<axum::Router> {
261        let services = self
262            .service
263            .get()
264            .ok_or_else(|| anyhow::anyhow!("{} not initialized", Self::MODULE_NAME))?;
265
266        info!("Registering mini-chat REST routes");
267        let prefix = self
268            .url_prefix
269            .get()
270            .ok_or_else(|| anyhow::anyhow!("{} not initialized (url_prefix)", Self::MODULE_NAME))?;
271
272        let router = routes::register_routes(router, openapi, Arc::clone(services), prefix);
273        info!("Mini-chat REST routes registered successfully");
274        Ok(router)
275    }
276}
277
278#[async_trait]
279impl RunnableCapability for MiniChatModule {
280    async fn start(&self, _cancel: CancellationToken) -> anyhow::Result<()> {
281        Ok(())
282    }
283
284    async fn stop(&self, cancel: CancellationToken) -> anyhow::Result<()> {
285        let handle = self
286            .outbox_handle
287            .lock()
288            .map_err(|e| anyhow::anyhow!("outbox_handle lock: {e}"))?
289            .take();
290        if let Some(handle) = handle {
291            info!("Stopping outbox pipeline");
292            tokio::select! {
293                () = handle.stop() => {
294                    info!("Outbox pipeline stopped");
295                }
296                () = cancel.cancelled() => {
297                    info!("Outbox pipeline stop cancelled by framework deadline");
298                }
299            }
300        }
301        Ok(())
302    }
303}
304
305async fn register_plugin_schemas(
306    registry: &dyn TypesRegistryClient,
307    schemas: &[(String, &str, &str)],
308) -> anyhow::Result<()> {
309    let mut payload = Vec::with_capacity(schemas.len());
310    for (schema_str, schema_id, _label) in schemas {
311        let mut schema_json: serde_json::Value = serde_json::from_str(schema_str)?;
312        let obj = schema_json
313            .as_object_mut()
314            .ok_or_else(|| anyhow::anyhow!("schema {schema_id} is not a JSON object"))?;
315        obj.insert(
316            "additionalProperties".to_owned(),
317            serde_json::Value::Bool(false),
318        );
319        payload.push(schema_json);
320    }
321    let results = registry.register(payload).await?;
322    RegisterResult::ensure_all_ok(&results)?;
323    for (_schema_str, schema_id, label) in schemas {
324        info!(schema_id = %schema_id, "Registered {label} plugin schema in types-registry");
325    }
326    Ok(())
327}