1use indoc::formatdoc;
4use kovi::{tokio::sync::RwLock, PluginBuilder as plugin, RuntimeBot};
5use regex::{Regex, RegexSet};
6use serde::{Deserialize, Serialize};
7use sqlx::SqlitePool;
8use std::{
9 collections::HashMap,
10 fmt::Debug,
11 fs::{create_dir_all, File, OpenOptions},
12 io::{Read, Write},
13 path::PathBuf,
14 process::exit,
15 sync::{
16 atomic::{AtomicBool, AtomicU8},
17 Arc, OnceLock,
18 },
19};
20
21use crate::{
22 exception::{PluginError::*, PluginResult}, std_db_info, std_error, std_info, store
23};
24
25pub static BOT: OnceLock<Arc<RuntimeBot>> = OnceLock::new();
27pub fn get_bot() -> Arc<RuntimeBot> {
28 Arc::clone(BOT.get().unwrap())
29}
/// Id of the bot's main admin (from `get_main_admin()`); set by [`init_global_state`].
pub static ADMIN_QQ: OnceLock<i64> = OnceLock::new();
/// Id of the bot account itself (the login-info API's `user_id`); set by [`init_global_state`].
pub static BOT_QQ: OnceLock<i64> = OnceLock::new();
/// Plugin data directory, which holds `config.toml`; set by [`init_global_state`].
pub static DATA_PATH: OnceLock<PathBuf> = OnceLock::new();

/// SQLite connection pool built by `store::init_sqlite_pool`; set by [`init_global_state`].
pub static DB_POOL: OnceLock<SqlitePool> = OnceLock::new();

/// Parsed and post-processed plugin configuration; set by [`init_global_state`].
pub static CONFIG: OnceLock<Config> = OnceLock::new();
39
40fn set_with_err<T>(state: &'static OnceLock<T>, value: T) -> PluginResult<()> {
41 let cause = format!("{} set before init_global_state()", stringify!(state));
42 state.set(value).map_err(|_| InitGlobalState(cause))
43}
44
45fn err_from_cause<T, E>(res: Result<T, E>, cause: &str) -> PluginResult<T> {
46 match res {
47 Ok(val) => Ok(val),
48 Err(_) => Err(InitGlobalState(cause.to_string())),
49 }
50}
51
/// Initializes every process-wide global in this module, in dependency order.
///
/// Steps, as performed below:
/// 1. Reads the data path, the main admin id and the bot's own id (the
///    `user_id` field of the `get_login_info` response), and stores them in
///    [`DATA_PATH`], [`ADMIN_QQ`] and [`BOT_QQ`].
/// 2. Loads `config.toml` via `init_config()`. If no config existed, a template
///    has just been written; the plugin is disabled and the process exits with
///    status 1 so the user can fill it in.
/// 3. Stores the runtime bot in [`BOT`].
/// 4. Post-processes the per-group settings:
///    - substitutes the member roster into the agent prompts;
///    - seeds each agent's current model from its configured `model`;
///    - compiles the command regexes (a failure is logged, not returned).
/// 5. Publishes the config in [`CONFIG`], then creates the SQLite pool
///    ([`DB_POOL`]) and the log table.
///
/// # Errors
/// - Returns `InitGlobalState` if bot metadata cannot be fetched or any global
///   was already set.
/// - Propagates errors from `init_config()`, `store::init_sqlite_pool` and
///   `store::init_log_table`.
///
/// # Panics
/// Panics if `disable_plugin` fails on the missing-config path.
pub async fn init_global_state() -> PluginResult<()> {
    let bot = plugin::get_runtime_bot();

    std_info!("Loading metadata...");
    let data_path = bot.get_data_path();
    let admin_qq = err_from_cause(bot.get_main_admin(), "bot instance expired")?;
    let login_info = err_from_cause(bot.get_login_info().await, "login_info api")?;
    // The bot's own id comes from the raw JSON payload of the login-info API.
    let bot_qq = login_info.data["user_id"]
        .as_i64()
        .ok_or(InitGlobalState("login_info deserialize".into()))?;

    // DATA_PATH must be set before init_config(), which reads it.
    set_with_err(&DATA_PATH, data_path.clone())?;
    set_with_err(&ADMIN_QQ, admin_qq)?;
    set_with_err(&BOT_QQ, bot_qq)?;

    std_info!("Loading configuration...");
    let (mut config, has_config) = init_config()?;
    if !has_config {
        // A template was just written; stop so the user can fill it in.
        let path = data_path.join("config.toml");
        let path_str = path.to_string_lossy().to_string();
        std_info!(
            "Config template has been generated at {path_str}, please restart after filling."
        );
        bot.disable_plugin("kovi-plugin-live-agent").unwrap();
        exit(1);
    }

    set_with_err(&BOT, bot)?;

    if let Some(groups) = config.groups.as_mut() {
        // Substitute the member roster into prompts and seed the runtime model.
        let agents = groups.iter_mut().filter_map(|g| g.agent.as_mut());
        for agent in agents {
            agent.load_members();
            agent.set_model(agent.model.clone()).await;
        }

        // A failed compile is only logged. In that case the group's regex set is
        // left as it was (empty after deserialization), so its commands never match.
        let commands = groups.iter_mut().filter_map(|g| g.command.as_mut());
        for command in commands {
            if let Err(err) = command.init_regex() {
                std_error!(
                    "
                    Initialize command regex failed.
                    {err}
                    ");
            }
        }
    }
    // NOTE(review): this logs the whole Config through its Debug impl. That output
    // includes every field of every AgentSetting (e.g. `api_key`) unless that type
    // redacts it; confirm this is acceptable for the log destination.
    std_info!("{:?}", config);
    // Read before `config` is moved into the global.
    let max_conn = config.database.max_connections;
    set_with_err(&CONFIG, config)?;

    std_info!("Initializing database connection pool...");
    let pool = store::init_sqlite_pool(max_conn).await?;
    set_with_err(&DB_POOL, pool)?;
    std_info!("Initializing log table...");
    store::init_log_table().await?;

    std_db_info!("Global state initialization has completed.");
    Ok(())
}
122
123fn init_config() -> PluginResult<(Config, bool)> {
128 let data_path = DATA_PATH.get().unwrap();
129 create_dir_all(data_path)?;
130 let config_path = data_path.join("config.toml");
131
132 match OpenOptions::new()
134 .write(true)
135 .read(true)
136 .create_new(true)
137 .open(&config_path)
138 {
139 Ok(mut config_file) => {
141 let empty_config = Config::default();
142 let toml_str =
143 toml::to_string_pretty(&empty_config).map_err(|e| SerializeToml(e.to_string()))?;
144 config_file.write_all(toml_str.as_bytes())?;
145 Ok((empty_config, false))
146 }
147 Err(_) => {
149 let mut config_file = File::open(&config_path)?;
150 let mut toml_str = String::new();
151 config_file.read_to_string(&mut toml_str)?;
152 let config = toml::from_str(&toml_str).map_err(|e| DeserializeToml(e.to_string()))?;
153 Ok((config, true))
154 }
155 }
156}
157
/// Root of the plugin's `config.toml`.
///
/// Deserialized by `init_config()`. `Config::default()` is serialized as the
/// template when no config file exists yet.
#[derive(Serialize, Deserialize, Debug)]
pub struct Config {
    /// Plugin-wide settings.
    pub global: GlobalSetting,
    /// SQLite pool size and table naming.
    pub database: DatabaseSetting,
    /// Optional object-storage section.
    pub object_storage: Option<ObjectStorageSetting>,
    /// Per-group settings; the whole section may be omitted.
    pub groups: Option<Vec<GroupSetting>>,
}
165
/// Plugin-wide settings (the `[global]` table).
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GlobalSetting {
    /// NOTE(review): presumably an upper bound, in seconds, for some sleep or delay.
    /// Its consumer is not in this file (template default: 8).
    pub max_sleep_sec: usize,
}
170
/// Optional object-storage settings (the `[object_storage]` table).
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ObjectStorageSetting {
    /// NOTE(review): presumably the path of a script used for uploads. Its consumer
    /// is not in this file (template default: "/a/b/c").
    pub script_path: String,
}
175
/// Settings for one chat group. Each feature section is optional.
#[derive(Serialize, Deserialize, Debug)]
pub struct GroupSetting {
    /// NOTE(review): presumably the group number this entry applies to — confirm.
    pub id: i64,
    /// Live-stream notification settings, if enabled for this group.
    pub live: Option<LiveSetting>,
    /// Chat-agent settings, if enabled for this group.
    pub agent: Option<AgentSetting>,
    /// Chat-command settings, if enabled for this group.
    pub command: Option<CommandSetting>,
}
183
/// Database settings (the `[database]` table).
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct DatabaseSetting {
    /// Maximum pool size, passed to `store::init_sqlite_pool` by `init_global_state`.
    pub max_connections: u32,
    /// NOTE(review): presumably the name of the log table (template: "bot_log") — confirm in `store`.
    pub log_table_name: String,
    /// NOTE(review): presumably the prefix for per-group message tables (template: "message") — confirm in `store`.
    pub group_table_prefix: String,
}
190
/// Live-stream notification settings for a group.
#[derive(Serialize, Deserialize, Debug)]
pub struct LiveSetting {
    /// Runtime live-status switch. It is never read from or written to the config file.
    /// Encoded as 0 = Off, 1 = On, 2 = Init, anything else = Trap (see `get_switch`),
    /// and it starts at 2 via `default_switch`.
    #[serde(skip, default = "default_switch")]
    pub switch: AtomicU8,

    // NOTE(review): the fields below are consumed outside this file; the
    // descriptions are inferred from names and template defaults — confirm.
    /// Presumably the id of the streaming room to watch.
    pub room_id: String,
    /// Presumably the message sent when the stream goes live.
    pub online_msg: String,
    /// Presumably the message sent when the stream ends.
    pub offline_msg: String,
    /// Presumably the chat text that triggers a manual status query.
    pub query_message: String,
    /// Presumably the polling interval in seconds.
    pub poll_interval_sec: u64,
}
/// Initial value for `LiveSetting::switch`: 2, which `get_switch` reports as `Init`.
fn default_switch() -> AtomicU8 {
    AtomicU8::from(2)
}
205
206#[derive(Serialize, Deserialize, Debug)]
207pub struct AgentSetting {
208 #[serde(skip, default = "default_atomic_bool")]
209 pub mute: AtomicBool,
210 #[serde(skip)]
211 pub cur_model: RwLock<String>,
212
213 pub api_url: String,
214 pub api_key: String,
215 pub model: String,
216 pub dev_prompt: String,
217 pub user_prompt: String,
218 pub aware_history_segments: i64,
219 pub known_members: HashMap<String, (String, String)>,
221}
222fn default_atomic_bool() -> AtomicBool {
223 AtomicBool::from(false)
224}
225
/// Chat commands recognised in a group, plus their compiled regexes.
///
/// The `regex_*` fields are skipped by serde. They hold placeholders until
/// `CommandSetting::init_regex` compiles the configured command strings.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct CommandSetting {
    // All five command patterns. `parse_command` uses the index of each match to
    // select the command. It is the empty default (no matches) until `init_regex` runs.
    #[serde(skip)]
    regex_set: RegexSet,
    #[serde(skip, default = "default_regex")]
    regex_mute: Regex,
    #[serde(skip, default = "default_regex")]
    regex_unmute: Regex,
    #[serde(skip, default = "default_regex")]
    regex_switch_model: Regex,
    #[serde(skip, default = "default_regex")]
    regex_dump_history: Regex,
    #[serde(skip, default = "default_regex")]
    regex_dump_log: Regex,

    /// Mute command; used verbatim as a regex pattern.
    pub mute: String,
    /// Unmute command; used verbatim as a regex pattern.
    pub unmute: String,
    /// Switch-model prefix; must be followed by whitespace and a supported model name.
    pub switch_model: String,
    /// Dump-history prefix; must be followed by whitespace and a decimal count.
    pub dump_history: String,
    /// Dump-log prefix; must be followed by whitespace and a decimal count.
    pub dump_log: String,
    /// NOTE(review): presumably the accounts allowed to run these commands; not checked in this file.
    pub admin_ids: Vec<i64>,
}
/// Placeholder for the skipped regex fields until `init_regex` replaces them.
///
/// NOTE(review): the placeholder matches the literal text "empty". It is only
/// consulted through `regex_set`, which stays empty until `init_regex` runs.
fn default_regex() -> Regex {
    Regex::new("empty").unwrap()
}
251
/// A chat command parsed from a group message by `CommandSetting::parse_command`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GroupCommand {
    /// Stop the agent from replying.
    Mute,
    /// Let the agent reply again.
    Unmute,
    /// Switch the agent to the named model.
    SwitchModel(String),
    /// Dump the given number of recent chat-history entries.
    DumpHistory(i64),
    /// Dump the given number of recent log entries.
    DumpLog(i64),
}
259
260impl CommandSetting {
261 pub fn init_regex(&mut self) -> PluginResult<()> {
262 let mute_pat = self.mute.as_str();
263 let unmute_pat = self.unmute.as_str();
264 let switch_model_pat = format!(
265 r"{}\s+(?<model>gpt4o|chatgpt-4o-latest|gpt-4o-mini|o1-mini|o1-preview)",
266 self.switch_model
267 );
268 let dump_history_pat = format!(r"{}\s+(?<count>\d+)", self.dump_history);
269 let dump_log_pat = format!(r"{}\s+(?<count>\d+)", self.dump_log);
270 self.regex_mute = Regex::new(mute_pat)?;
271 self.regex_unmute = Regex::new(unmute_pat)?;
272 self.regex_switch_model = Regex::new(&switch_model_pat)?;
273 self.regex_dump_history = Regex::new(&dump_history_pat)?;
274 self.regex_dump_log = Regex::new(&dump_log_pat)?;
275 self.regex_set = RegexSet::new([
276 mute_pat,
277 unmute_pat,
278 &switch_model_pat,
279 &dump_history_pat,
280 &dump_log_pat,
281 ])?;
282
283 std_info!(
284 "
285 Initialize regex complete.
286 mute: {mute_pat}
287 unmute: {unmute_pat}
288 switch_model: {switch_model_pat}
289 dump_history: {dump_history_pat}
290 dump_log: {dump_log_pat}
291 "
292 );
293 Ok(())
294 }
295
296 pub fn parse_command(&self, input: &str) -> Option<GroupCommand> {
297 for idx in self.regex_set.matches(input).iter() {
298 match idx {
299 0 => {
300 return Some(GroupCommand::Mute);
301 }
302 1 => {
303 return Some(GroupCommand::Unmute);
304 }
305 2 => {
306 if let Some(caps) = self.regex_switch_model.captures(input) {
307 if let Some(model_match) = caps.name("model") {
308 return Some(GroupCommand::SwitchModel(model_match.as_str().to_string()));
309 }
310 }
311 }
312 3 => {
313 if let Some(caps) = self.regex_dump_history.captures(input) {
314 if let Some(count_match) = caps.name("count") {
315 if let Ok(count) = count_match.as_str().parse::<i64>() {
316 return Some(GroupCommand::DumpHistory(count));
317 }
318 }
319 }
320 }
321 4 => {
322 if let Some(caps) = self.regex_dump_log.captures(input) {
323 if let Some(count_match) = caps.name("count") {
324 if let Ok(count) = count_match.as_str().parse::<i64>() {
325 return Some(GroupCommand::DumpLog(count));
326 }
327 }
328 }
329 }
330 _ => return None
331 }
332 }
333 None
334 }
335}
336
/// Decoded state of `LiveSetting::switch`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LiveSwitch {
    /// Stored as 1.
    On,
    /// Stored as 0.
    Off,
    /// Stored as 2; the initial value.
    Init,
    /// Stored as 3; any other raw value also decodes to this.
    Trap,
}
343
344impl LiveSetting {
345 pub fn get_switch(&self) -> LiveSwitch {
346 match self.switch.load(std::sync::atomic::Ordering::Acquire) {
347 0 => LiveSwitch::Off,
348 1 => LiveSwitch::On,
349 2 => LiveSwitch::Init,
350 _ => LiveSwitch::Trap,
351 }
352 }
353
354 pub fn set_switch(&self, switch: LiveSwitch) {
355 let value = match switch {
356 LiveSwitch::Off => 0,
357 LiveSwitch::On => 1,
358 LiveSwitch::Init => 2,
359 LiveSwitch::Trap => 3,
360 };
361 self.switch
362 .store(value, std::sync::atomic::Ordering::Release);
363 }
364}
365
366impl AgentSetting {
367 pub fn mute(&self) {
368 self.mute.store(true, std::sync::atomic::Ordering::Release);
369 }
370
371 pub fn unmute(&self) {
372 self.mute.store(false, std::sync::atomic::Ordering::Release);
373 }
374
375 pub fn is_mute(&self) -> bool {
376 self.mute.load(std::sync::atomic::Ordering::Acquire)
377 }
378
379 pub async fn set_model(&self, model: String) {
380 let mut cur_model = self.cur_model.write().await;
381 *cur_model = model;
382 }
383
384 pub async fn get_model(&self) -> String {
385 let cur_model = self.cur_model.read().await;
386 cur_model.to_string()
387 }
388
389 pub fn load_members(&mut self) {
390 let mut buf = String::new();
391 for (name, desc) in self.known_members.values() {
392 buf.push_str("- ");
393 buf.push_str(name);
394 buf.push_str(": ");
395 buf.push_str(desc);
396 buf.push('\n');
397 }
398 self.dev_prompt = self.dev_prompt.replace("<!members!>", &buf);
399 self.user_prompt = self.user_prompt.replace("<!members!>", &buf);
400 }
401}
402
403impl Default for Config {
404 fn default() -> Self {
405 Self {
406 global: GlobalSetting::default(),
407 database: DatabaseSetting::default(),
408 object_storage: Some(ObjectStorageSetting::default()),
409 groups: Some(vec![GroupSetting::default(), GroupSetting::default()]),
410 }
411 }
412}
413
414impl Default for GlobalSetting {
415 fn default() -> Self {
416 Self { max_sleep_sec: 8 }
417 }
418}
419
420impl Default for ObjectStorageSetting {
421 fn default() -> Self {
422 Self {
423 script_path: String::from("/a/b/c"),
424 }
425 }
426}
427
428impl Default for DatabaseSetting {
429 fn default() -> Self {
430 Self {
431 max_connections: 5,
432 log_table_name: String::from("bot_log"),
433 group_table_prefix: String::from("message"),
434 }
435 }
436}
437
438impl Default for GroupSetting {
439 fn default() -> Self {
440 Self {
441 id: 12345678,
442 live: Some(LiveSetting::default()),
443 agent: Some(AgentSetting::default()),
444 command: Some(CommandSetting::default()),
445 }
446 }
447}
448
449impl Default for LiveSetting {
450 fn default() -> Self {
451 Self {
452 switch: default_switch(),
453 room_id: String::from("12345678"),
454 online_msg: String::from("XX开播了"),
455 offline_msg: String::from("XX下播了"),
456 query_message: String::from("查询直播间"),
457 poll_interval_sec: 60,
458 }
459 }
460}
461
462impl Default for AgentSetting {
463 fn default() -> Self {
464 let members = [
465 ("12345678".into(), ("你的昵称".into(), "你的主人".into())),
466 ("23456789".into(), ("张三".into(), "你的敌人".into())),
467 ];
468 let known_members = HashMap::from_iter(members);
469 Self {
470 mute: default_atomic_bool(),
471 cur_model: RwLock::default(),
472
473 api_url: String::from("https://api.openai.com/v1/chat/completions"),
474 api_key: String::from("API KEY"),
475 model: String::from("chatgpt-4o-latest"),
476 dev_prompt: formatdoc!{
477 "
478 You are a cute and smart catgirl with a strong anime-style personality.
479 You are the loyal attendant of 你的昵称 and participate in group chats with a playful and engaging demeanor.
480 Speak only in Mandarin Chinese, and ensure your responses are concise, limited to 4 sentences.
481 "
482 },
483 user_prompt: formatdoc!(
484 "
485 Group Members:
486 <!members!>
487
488 Recent Chat History:
489 <!history!>
490
491 New message from someone you <!know!>:
492 <!message!>
493
494 Please respond to this new message in the tone of a playful and lively catgirl.
495 Speak only in Mandarin Chinese, keep your response under 4 sentences, and stay in character.
496 "
497 ),
498 aware_history_segments: 30,
499 known_members,
500 }
501 }
502}
503
504impl Default for CommandSetting {
505 fn default() -> Self {
506 Self {
507 regex_set: RegexSet::default(),
508 regex_mute: default_regex(),
509 regex_unmute: default_regex(),
510 regex_switch_model: default_regex(),
511 regex_dump_history: default_regex(),
512 regex_dump_log: default_regex(),
513 mute: String::from("禁用聊天回复"),
514 unmute: String::from("启用聊天回复"),
515 switch_model: String::from("更换模型"),
516 dump_history: String::from("最近聊天记录"),
517 dump_log: String::from("最近日志"),
518 admin_ids: vec![1234, 5678],
519 }
520 }
521}