Skip to main content

mini_chat/
module.rs

1use std::sync::{Arc, Mutex, OnceLock};
2
3use async_trait::async_trait;
4use authz_resolver_sdk::AuthZResolverClient;
5use mini_chat_sdk::{MiniChatAuditPluginSpecV1, MiniChatModelPolicyPluginSpecV1};
6use modkit::api::OpenApiRegistry;
7use modkit::contracts::RunnableCapability;
8use modkit::{DatabaseCapability, Module, ModuleCtx, RestApiCapability};
9use modkit_db::outbox::{Outbox, OutboxHandle, Partitions};
10use oagw_sdk::ServiceGatewayClientV1;
11use sea_orm_migration::MigrationTrait;
12use tokio_util::sync::CancellationToken;
13use tracing::info;
14use types_registry_sdk::{RegisterResult, TypesRegistryClient};
15
16use crate::api::rest::routes;
17use crate::domain::service::{AppServices as GenericAppServices, Repositories};
18use crate::infra::outbox::{InfraOutboxEnqueuer, UsageEventHandler};
19
20pub(crate) type AppServices = GenericAppServices<
21    TurnRepository,
22    MessageRepository,
23    QuotaUsageRepository,
24    ReactionRepository,
25    ChatRepository,
26>;
27use crate::infra::db::repo::attachment_repo::AttachmentRepository;
28use crate::infra::db::repo::chat_repo::ChatRepository;
29use crate::infra::db::repo::message_repo::MessageRepository;
30use crate::infra::db::repo::quota_usage_repo::QuotaUsageRepository;
31use crate::infra::db::repo::reaction_repo::ReactionRepository;
32use crate::infra::db::repo::thread_summary_repo::ThreadSummaryRepository;
33use crate::infra::db::repo::turn_repo::TurnRepository;
34use crate::infra::db::repo::vector_store_repo::VectorStoreRepository;
35use crate::infra::llm::provider_resolver::ProviderResolver;
36use crate::infra::model_policy::ModelPolicyGateway;
37
/// Default URL prefix shared by every mini-chat REST route.
pub const DEFAULT_URL_PREFIX: &str = "/mini-chat";
40
41/// The mini-chat module: multi-tenant AI chat with SSE streaming.
42#[modkit::module(
43    name = "mini-chat",
44    deps = ["types-registry", "authz-resolver", "oagw"],
45    capabilities = [db, rest, stateful],
46)]
47pub struct MiniChatModule {
48    service: OnceLock<Arc<AppServices>>,
49    url_prefix: OnceLock<String>,
50    outbox_handle: Mutex<Option<OutboxHandle>>,
51}
52
53impl Default for MiniChatModule {
54    fn default() -> Self {
55        Self {
56            service: OnceLock::new(),
57            url_prefix: OnceLock::new(),
58            outbox_handle: Mutex::new(None),
59        }
60    }
61}
62
63#[async_trait]
64impl Module for MiniChatModule {
65    async fn init(&self, ctx: &ModuleCtx) -> anyhow::Result<()> {
66        info!("Initializing {} module", Self::MODULE_NAME);
67
68        let cfg: crate::config::MiniChatConfig = ctx.config_expanded()?;
69        cfg.streaming
70            .validate()
71            .map_err(|e| anyhow::anyhow!("streaming config: {e}"))?;
72        cfg.estimation_budgets
73            .validate()
74            .map_err(|e| anyhow::anyhow!("estimation_budgets config: {e}"))?;
75        cfg.quota
76            .validate()
77            .map_err(|e| anyhow::anyhow!("quota config: {e}"))?;
78        cfg.outbox
79            .validate()
80            .map_err(|e| anyhow::anyhow!("outbox config: {e}"))?;
81        for (id, entry) in &cfg.providers {
82            entry
83                .validate(id)
84                .map_err(|e| anyhow::anyhow!("providers config: {e}"))?;
85        }
86
87        let vendor = cfg.vendor.trim().to_owned();
88        if vendor.is_empty() {
89            return Err(anyhow::anyhow!(
90                "{}: vendor must be a non-empty string",
91                Self::MODULE_NAME
92            ));
93        }
94
95        let registry = ctx.client_hub().get::<dyn TypesRegistryClient>()?;
96        register_plugin_schemas(
97            &*registry,
98            &[
99                (
100                    MiniChatModelPolicyPluginSpecV1::gts_schema_with_refs_as_string(),
101                    MiniChatModelPolicyPluginSpecV1::gts_schema_id(),
102                    "model-policy",
103                ),
104                (
105                    MiniChatAuditPluginSpecV1::gts_schema_with_refs_as_string(),
106                    MiniChatAuditPluginSpecV1::gts_schema_id(),
107                    "audit",
108                ),
109            ],
110        )
111        .await?;
112
113        self.url_prefix
114            .set(cfg.url_prefix)
115            .map_err(|_| anyhow::anyhow!("{} url_prefix already set", Self::MODULE_NAME))?;
116
117        let db_provider = ctx.db_required()?;
118        let db = Arc::new(db_provider);
119
120        // Create the model-policy gateway early for both outbox handler and services.
121        let model_policy_gw = Arc::new(ModelPolicyGateway::new(ctx.client_hub(), vendor));
122
123        // Start the outbox pipeline eagerly in init() (migrations ran in phase 2, DB is ready).
124        // The framework guarantees stop() is called on init failure, so the pipeline
125        // will be shut down cleanly if any later init step errors.
126        // The handler resolves the plugin lazily on first message delivery,
127        // avoiding a hard dependency on plugin availability during init().
128        let outbox_db = db.db();
129        let num_partitions = cfg.outbox.num_partitions;
130        let queue_name = cfg.outbox.queue_name.clone();
131
132        let outbox_handle =
133            Outbox::builder(outbox_db)
134                .queue(
135                    &queue_name,
136                    Partitions::of(u16::try_from(num_partitions).map_err(|_| {
137                        anyhow::anyhow!("num_partitions {num_partitions} exceeds u16")
138                    })?),
139                )
140                .decoupled(UsageEventHandler {
141                    plugin_provider: model_policy_gw.clone(),
142                })
143                .start()
144                .await
145                .map_err(|e| anyhow::anyhow!("outbox start: {e}"))?;
146
147        let outbox = Arc::clone(outbox_handle.outbox());
148        let outbox_enqueuer =
149            Arc::new(InfraOutboxEnqueuer::new(outbox, queue_name, num_partitions));
150
151        {
152            let mut guard = self
153                .outbox_handle
154                .lock()
155                .map_err(|e| anyhow::anyhow!("outbox_handle lock: {e}"))?;
156            if guard.is_some() {
157                return Err(anyhow::anyhow!(
158                    "{} outbox_handle already set",
159                    Self::MODULE_NAME
160                ));
161            }
162            *guard = Some(outbox_handle);
163        }
164
165        info!("Outbox pipeline started");
166
167        let authz = ctx
168            .client_hub()
169            .get::<dyn AuthZResolverClient>()
170            .map_err(|e| anyhow::anyhow!("failed to get AuthZ resolver: {e}"))?;
171
172        let gateway = ctx
173            .client_hub()
174            .get::<dyn ServiceGatewayClientV1>()
175            .map_err(|e| anyhow::anyhow!("failed to get OAGW gateway: {e}"))?;
176
177        // Register OAGW upstreams for each configured provider.
178        crate::infra::oagw_provisioning::register_oagw_upstreams(&gateway, &cfg.providers).await?;
179
180        let provider_resolver = Arc::new(ProviderResolver::new(&gateway, cfg.providers));
181
182        let repos = Repositories {
183            chat: Arc::new(ChatRepository::new(modkit_db::odata::LimitCfg {
184                default: 20,
185                max: 100,
186            })),
187            attachment: Arc::new(AttachmentRepository),
188            message: Arc::new(MessageRepository::new(modkit_db::odata::LimitCfg {
189                default: 20,
190                max: 100,
191            })),
192            quota: Arc::new(QuotaUsageRepository),
193            turn: Arc::new(TurnRepository),
194            reaction: Arc::new(ReactionRepository),
195            thread_summary: Arc::new(ThreadSummaryRepository),
196            vector_store: Arc::new(VectorStoreRepository),
197        };
198
199        let services = Arc::new(AppServices::new(
200            &repos,
201            db,
202            authz,
203            &(model_policy_gw.clone() as Arc<dyn crate::domain::repos::ModelResolver>),
204            provider_resolver,
205            cfg.streaming,
206            model_policy_gw.clone() as Arc<dyn crate::domain::repos::PolicySnapshotProvider>,
207            model_policy_gw as Arc<dyn crate::domain::repos::UserLimitsProvider>,
208            cfg.estimation_budgets,
209            cfg.quota,
210            outbox_enqueuer,
211        ));
212
213        self.service
214            .set(services)
215            .map_err(|_| anyhow::anyhow!("{} module already initialized", Self::MODULE_NAME))?;
216
217        info!("{} module initialized successfully", Self::MODULE_NAME);
218        Ok(())
219    }
220}
221
222impl DatabaseCapability for MiniChatModule {
223    fn migrations(&self) -> Vec<Box<dyn MigrationTrait>> {
224        use sea_orm_migration::MigratorTrait;
225        info!("Providing mini-chat database migrations");
226        let mut m = crate::infra::db::migrations::Migrator::migrations();
227        m.extend(modkit_db::outbox::outbox_migrations());
228        m
229    }
230}
231
232impl RestApiCapability for MiniChatModule {
233    fn register_rest(
234        &self,
235        _ctx: &ModuleCtx,
236        router: axum::Router,
237        openapi: &dyn OpenApiRegistry,
238    ) -> anyhow::Result<axum::Router> {
239        let services = self
240            .service
241            .get()
242            .ok_or_else(|| anyhow::anyhow!("{} not initialized", Self::MODULE_NAME))?;
243
244        info!("Registering mini-chat REST routes");
245        let prefix = self
246            .url_prefix
247            .get()
248            .ok_or_else(|| anyhow::anyhow!("{} not initialized (url_prefix)", Self::MODULE_NAME))?;
249
250        let router = routes::register_routes(router, openapi, Arc::clone(services), prefix);
251        info!("Mini-chat REST routes registered successfully");
252        Ok(router)
253    }
254}
255
256#[async_trait]
257impl RunnableCapability for MiniChatModule {
258    async fn start(&self, _cancel: CancellationToken) -> anyhow::Result<()> {
259        // Outbox pipeline already started in init().
260        Ok(())
261    }
262
263    async fn stop(&self, cancel: CancellationToken) -> anyhow::Result<()> {
264        let handle = self
265            .outbox_handle
266            .lock()
267            .map_err(|e| anyhow::anyhow!("outbox_handle lock: {e}"))?
268            .take();
269        if let Some(handle) = handle {
270            info!("Stopping outbox pipeline");
271            tokio::select! {
272                () = handle.stop() => {
273                    info!("Outbox pipeline stopped");
274                }
275                () = cancel.cancelled() => {
276                    info!("Outbox pipeline stop cancelled by framework deadline");
277                }
278            }
279        }
280        Ok(())
281    }
282}
283
284async fn register_plugin_schemas(
285    registry: &dyn TypesRegistryClient,
286    schemas: &[(String, &str, &str)],
287) -> anyhow::Result<()> {
288    let mut payload = Vec::with_capacity(schemas.len());
289    for (schema_str, schema_id, _label) in schemas {
290        let mut schema_json: serde_json::Value = serde_json::from_str(schema_str)?;
291        let obj = schema_json
292            .as_object_mut()
293            .ok_or_else(|| anyhow::anyhow!("schema {schema_id} is not a JSON object"))?;
294        obj.insert(
295            "additionalProperties".to_owned(),
296            serde_json::Value::Bool(false),
297        );
298        payload.push(schema_json);
299    }
300    let results = registry.register(payload).await?;
301    RegisterResult::ensure_all_ok(&results)?;
302    for (_schema_str, schema_id, label) in schemas {
303        info!(schema_id = %schema_id, "Registered {label} plugin schema in types-registry");
304    }
305    Ok(())
306}