Skip to main content

mini_chat/
module.rs

1use std::sync::{Arc, OnceLock};
2
3use async_trait::async_trait;
4use authz_resolver_sdk::AuthZResolverClient;
5use mini_chat_sdk::{MiniChatAuditPluginSpecV1, MiniChatModelPolicyPluginSpecV1};
6use modkit::api::OpenApiRegistry;
7use modkit::{DatabaseCapability, Module, ModuleCtx, RestApiCapability};
8use oagw_sdk::ServiceGatewayClientV1;
9use sea_orm_migration::MigrationTrait;
10use tracing::info;
11use types_registry_sdk::{RegisterResult, TypesRegistryClient};
12
13use crate::api::rest::routes;
14use crate::domain::service::{AppServices as GenericAppServices, Repositories};
15
16pub(crate) type AppServices = GenericAppServices<
17    TurnRepository,
18    MessageRepository,
19    QuotaUsageRepository,
20    ReactionRepository,
21    ChatRepository,
22>;
23use crate::infra::db::repo::attachment_repo::AttachmentRepository;
24use crate::infra::db::repo::chat_repo::ChatRepository;
25use crate::infra::db::repo::message_repo::MessageRepository;
26use crate::infra::db::repo::quota_usage_repo::QuotaUsageRepository;
27use crate::infra::db::repo::reaction_repo::ReactionRepository;
28use crate::infra::db::repo::thread_summary_repo::ThreadSummaryRepository;
29use crate::infra::db::repo::turn_repo::TurnRepository;
30use crate::infra::db::repo::vector_store_repo::VectorStoreRepository;
31use crate::infra::llm::provider_resolver::ProviderResolver;
32use crate::infra::model_policy::ModelPolicyGateway;
33
/// URL prefix under which every mini-chat REST route is mounted unless the
/// configuration overrides it.
///
/// NOTE(review): not referenced elsewhere in this file (`init` reads
/// `cfg.url_prefix`); presumably this is the default for that config field — confirm.
pub const DEFAULT_URL_PREFIX: &str = "/mini-chat";
36
37/// The mini-chat module: multi-tenant AI chat with SSE streaming.
38#[modkit::module(
39    name = "mini-chat",
40    deps = ["types-registry", "authz-resolver", "oagw"],
41    capabilities = [db, rest],
42)]
43pub struct MiniChatModule {
44    service: OnceLock<Arc<AppServices>>,
45    url_prefix: OnceLock<String>,
46}
47
48impl Default for MiniChatModule {
49    fn default() -> Self {
50        Self {
51            service: OnceLock::new(),
52            url_prefix: OnceLock::new(),
53        }
54    }
55}
56
57#[async_trait]
58impl Module for MiniChatModule {
59    async fn init(&self, ctx: &ModuleCtx) -> anyhow::Result<()> {
60        info!("Initializing {} module", Self::MODULE_NAME);
61
62        let cfg: crate::config::MiniChatConfig = ctx.config_expanded()?;
63        cfg.streaming
64            .validate()
65            .map_err(|e| anyhow::anyhow!("streaming config: {e}"))?;
66        cfg.estimation_budgets
67            .validate()
68            .map_err(|e| anyhow::anyhow!("estimation_budgets config: {e}"))?;
69        cfg.quota
70            .validate()
71            .map_err(|e| anyhow::anyhow!("quota config: {e}"))?;
72        cfg.outbox
73            .validate()
74            .map_err(|e| anyhow::anyhow!("outbox config: {e}"))?;
75        for (id, entry) in &cfg.providers {
76            entry
77                .validate(id)
78                .map_err(|e| anyhow::anyhow!("providers config: {e}"))?;
79        }
80
81        let vendor = cfg.vendor.trim().to_owned();
82        if vendor.is_empty() {
83            return Err(anyhow::anyhow!(
84                "{}: vendor must be a non-empty string",
85                Self::MODULE_NAME
86            ));
87        }
88
89        let registry = ctx.client_hub().get::<dyn TypesRegistryClient>()?;
90        register_plugin_schemas(
91            &*registry,
92            &[
93                (
94                    MiniChatModelPolicyPluginSpecV1::gts_schema_with_refs_as_string(),
95                    MiniChatModelPolicyPluginSpecV1::gts_schema_id(),
96                    "model-policy",
97                ),
98                (
99                    MiniChatAuditPluginSpecV1::gts_schema_with_refs_as_string(),
100                    MiniChatAuditPluginSpecV1::gts_schema_id(),
101                    "audit",
102                ),
103            ],
104        )
105        .await?;
106
107        self.url_prefix
108            .set(cfg.url_prefix)
109            .map_err(|_| anyhow::anyhow!("{} url_prefix already set", Self::MODULE_NAME))?;
110
111        let db = Arc::new(ctx.db_required()?);
112
113        let authz = ctx
114            .client_hub()
115            .get::<dyn AuthZResolverClient>()
116            .map_err(|e| anyhow::anyhow!("failed to get AuthZ resolver: {e}"))?;
117
118        let gateway = ctx
119            .client_hub()
120            .get::<dyn ServiceGatewayClientV1>()
121            .map_err(|e| anyhow::anyhow!("failed to get OAGW gateway: {e}"))?;
122
123        // Register OAGW upstreams for each configured provider.
124        crate::infra::oagw_provisioning::register_oagw_upstreams(&gateway, &cfg.providers).await?;
125
126        let provider_resolver = Arc::new(ProviderResolver::new(&gateway, cfg.providers));
127
128        let repos = Repositories {
129            chat: Arc::new(ChatRepository::new(modkit_db::odata::LimitCfg {
130                default: 20,
131                max: 100,
132            })),
133            attachment: Arc::new(AttachmentRepository),
134            message: Arc::new(MessageRepository::new(modkit_db::odata::LimitCfg {
135                default: 20,
136                max: 100,
137            })),
138            quota: Arc::new(QuotaUsageRepository),
139            turn: Arc::new(TurnRepository),
140            reaction: Arc::new(ReactionRepository),
141            thread_summary: Arc::new(ThreadSummaryRepository),
142            vector_store: Arc::new(VectorStoreRepository),
143        };
144
145        let model_policy_gw = Arc::new(ModelPolicyGateway::new(ctx.client_hub(), vendor));
146        let services = Arc::new(AppServices::new(
147            &repos,
148            db,
149            authz,
150            &(model_policy_gw.clone() as Arc<dyn crate::domain::repos::ModelResolver>),
151            provider_resolver,
152            cfg.streaming,
153            model_policy_gw.clone() as Arc<dyn crate::domain::repos::PolicySnapshotProvider>,
154            model_policy_gw as Arc<dyn crate::domain::repos::UserLimitsProvider>,
155            cfg.estimation_budgets,
156            cfg.quota,
157        ));
158
159        self.service
160            .set(services)
161            .map_err(|_| anyhow::anyhow!("{} module already initialized", Self::MODULE_NAME))?;
162
163        info!("{} module initialized successfully", Self::MODULE_NAME);
164        Ok(())
165    }
166}
167
168impl DatabaseCapability for MiniChatModule {
169    fn migrations(&self) -> Vec<Box<dyn MigrationTrait>> {
170        use sea_orm_migration::MigratorTrait;
171        info!("Providing mini-chat database migrations");
172        crate::infra::db::migrations::Migrator::migrations()
173    }
174}
175
176impl RestApiCapability for MiniChatModule {
177    fn register_rest(
178        &self,
179        _ctx: &ModuleCtx,
180        router: axum::Router,
181        openapi: &dyn OpenApiRegistry,
182    ) -> anyhow::Result<axum::Router> {
183        let services = self
184            .service
185            .get()
186            .ok_or_else(|| anyhow::anyhow!("{} not initialized", Self::MODULE_NAME))?;
187
188        info!("Registering mini-chat REST routes");
189        let prefix = self
190            .url_prefix
191            .get()
192            .ok_or_else(|| anyhow::anyhow!("{} not initialized (url_prefix)", Self::MODULE_NAME))?;
193
194        let router = routes::register_routes(router, openapi, Arc::clone(services), prefix);
195        info!("Mini-chat REST routes registered successfully");
196        Ok(router)
197    }
198}
199
200async fn register_plugin_schemas(
201    registry: &dyn TypesRegistryClient,
202    schemas: &[(String, &str, &str)],
203) -> anyhow::Result<()> {
204    let mut payload = Vec::with_capacity(schemas.len());
205    for (schema_str, schema_id, _label) in schemas {
206        let mut schema_json: serde_json::Value = serde_json::from_str(schema_str)?;
207        let obj = schema_json
208            .as_object_mut()
209            .ok_or_else(|| anyhow::anyhow!("schema {schema_id} is not a JSON object"))?;
210        obj.insert(
211            "additionalProperties".to_owned(),
212            serde_json::Value::Bool(false),
213        );
214        payload.push(schema_json);
215    }
216    let results = registry.register(payload).await?;
217    RegisterResult::ensure_all_ok(&results)?;
218    for (_schema_str, schema_id, label) in schemas {
219        info!(schema_id = %schema_id, "Registered {label} plugin schema in types-registry");
220    }
221    Ok(())
222}