kovi_plugin_live_agent/
global_state.rs

//! Global state that is initialized once and then available throughout the lifetime of the plugin
//! (read-only, apart from runtime flags held in atomics and locks).
2
3use indoc::formatdoc;
4use kovi::{tokio::sync::RwLock, PluginBuilder as plugin, RuntimeBot};
5use regex::{Regex, RegexSet};
6use serde::{Deserialize, Serialize};
7use sqlx::SqlitePool;
8use std::{
9    collections::HashMap,
10    fmt::Debug,
11    fs::{create_dir_all, File, OpenOptions},
12    io::{Read, Write},
13    path::PathBuf,
14    process::exit,
15    sync::{
16        atomic::{AtomicBool, AtomicU8},
17        Arc, OnceLock,
18    },
19};
20
21use crate::{
22    exception::{PluginError::*, PluginResult}, std_db_info, std_error, std_info, store
23};
24
25// metadata, not from config
26pub static BOT: OnceLock<Arc<RuntimeBot>> = OnceLock::new();
27pub fn get_bot() -> Arc<RuntimeBot> {
28    Arc::clone(BOT.get().unwrap())
29}
30pub static ADMIN_QQ: OnceLock<i64> = OnceLock::new();
31pub static BOT_QQ: OnceLock<i64> = OnceLock::new();
32pub static DATA_PATH: OnceLock<PathBuf> = OnceLock::new();
33
34// database connection pool
35pub static DB_POOL: OnceLock<SqlitePool> = OnceLock::new();
36
37// configuration
38pub static CONFIG: OnceLock<Config> = OnceLock::new();
39
40fn set_with_err<T>(state: &'static OnceLock<T>, value: T) -> PluginResult<()> {
41    let cause = format!("{} set before init_global_state()", stringify!(state));
42    state.set(value).map_err(|_| InitGlobalState(cause))
43}
44
45fn err_from_cause<T, E>(res: Result<T, E>, cause: &str) -> PluginResult<T> {
46    match res {
47        Ok(val) => Ok(val),
48        Err(_) => Err(InitGlobalState(cause.to_string())),
49    }
50}
51
52pub async fn init_global_state() -> PluginResult<()> {
53    let bot = plugin::get_runtime_bot();
54
55    // load metadata
56    std_info!("Loading metadata...");
57    let data_path = bot.get_data_path();
58    let admin_qq = err_from_cause(bot.get_main_admin(), "bot instance expired")?;
59    let login_info = err_from_cause(bot.get_login_info().await, "login_info api")?;
60    let bot_qq = login_info.data["user_id"]
61        .as_i64()
62        .ok_or(InitGlobalState("login_info deserialize".into()))?;
63
64    // save metadata
65    set_with_err(&DATA_PATH, data_path.clone())?;
66    set_with_err(&ADMIN_QQ, admin_qq)?;
67    set_with_err(&BOT_QQ, bot_qq)?;
68
69    // load config
70    std_info!("Loading configuration...");
71    let (mut config, has_config) = init_config()?;
72    if !has_config {
73        let path = data_path.join("config.toml");
74        let path_str = path.to_string_lossy().to_string();
75        std_info!(
76            "Config template has been generated at {path_str}, please restart after filling."
77        );
78        bot.disable_plugin("kovi-plugin-live-agent").unwrap();
79        exit(1);
80    }
81
82    // save bot
83    set_with_err(&BOT, bot)?;
84
85    // init groups
86    if let Some(groups) = config.groups.as_mut() {
87        // init agent
88        let agents = groups.iter_mut().filter_map(|g| g.agent.as_mut());
89        for agent in agents {
90            agent.load_members();
91            agent.set_model(agent.model.clone()).await;
92        }
93
94        // init command regex
95        let commands = groups.iter_mut().filter_map(|g| g.command.as_mut());
96        for command in commands {
97            if let Err(err) = command.init_regex() {
98                std_error!(
99                    "
100                    Initialize command regex failed.
101                    {err}
102                    ");
103            }
104        }
105    }
106    std_info!("{:?}", config);
107    let max_conn = config.database.max_connections;
108    // save config
109    set_with_err(&CONFIG, config)?;
110
111    // init database
112    std_info!("Initializing database connection pool...");
113    let pool = store::init_sqlite_pool(max_conn).await?;
114    set_with_err(&DB_POOL, pool)?;
115    std_info!("Initializing log table...");
116    store::init_log_table().await?;
117
118
119    std_db_info!("Global state initialization has completed.");
120    Ok(())
121}
122
123/// Initialize config, either read or create.
124///
125/// If no error occurs, returns ([ChatConfig], true) if read from existing config, ([ChatConfig],
126/// false) if created a new config.
127fn init_config() -> PluginResult<(Config, bool)> {
128    let data_path = DATA_PATH.get().unwrap();
129    create_dir_all(data_path)?;
130    let config_path = data_path.join("config.toml");
131
132    // create_new makes sure to fail on config exist
133    match OpenOptions::new()
134        .write(true)
135        .read(true)
136        .create_new(true)
137        .open(&config_path)
138    {
139        // config does not exist, create and return false
140        Ok(mut config_file) => {
141            let empty_config = Config::default();
142            let toml_str =
143                toml::to_string_pretty(&empty_config).map_err(|e| SerializeToml(e.to_string()))?;
144            config_file.write_all(toml_str.as_bytes())?;
145            Ok((empty_config, false))
146        }
147        // config already exists, read and return true
148        Err(_) => {
149            let mut config_file = File::open(&config_path)?;
150            let mut toml_str = String::new();
151            config_file.read_to_string(&mut toml_str)?;
152            let config = toml::from_str(&toml_str).map_err(|e| DeserializeToml(e.to_string()))?;
153            Ok((config, true))
154        }
155    }
156}
157
/// Top-level plugin configuration, (de)serialized from `config.toml` in the plugin data
/// directory. Field order here is the section order of the generated template.
#[derive(Serialize, Deserialize, Debug)]
pub struct Config {
    /// Plugin-wide settings.
    pub global: GlobalSetting,
    /// SQLite connection-pool and table naming settings.
    pub database: DatabaseSetting,
    /// Optional object-storage settings; `None` if the section is absent.
    pub object_storage: Option<ObjectStorageSetting>,
    /// Per-group settings; `None` if no groups are configured.
    pub groups: Option<Vec<GroupSetting>>,
}

/// Plugin-wide settings.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GlobalSetting {
    // NOTE(review): judging by the name, an upper bound on a sleep in seconds; its consumer is
    // not in this file — confirm.
    pub max_sleep_sec: usize,
}

/// Object-storage settings.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ObjectStorageSetting {
    // NOTE(review): presumably a script invoked for uploads; the default `/a/b/c` is a placeholder.
    pub script_path: String,
}

/// Settings for one group; each optional feature is `None` when its section is absent.
#[derive(Serialize, Deserialize, Debug)]
pub struct GroupSetting {
    /// Group id these settings apply to.
    pub id: i64,
    /// Live-stream notification settings.
    pub live: Option<LiveSetting>,
    /// Chat-agent settings.
    pub agent: Option<AgentSetting>,
    /// Admin command settings.
    pub command: Option<CommandSetting>,
}

/// Database settings.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct DatabaseSetting {
    /// Passed to `store::init_sqlite_pool` as the pool size limit.
    pub max_connections: u32,
    // NOTE(review): table/prefix names are consumed in `store`, not here — confirm exact usage.
    pub log_table_name: String,
    pub group_table_prefix: String,
}
190
/// Live-stream notification settings for one group.
#[derive(Serialize, Deserialize, Debug)]
pub struct LiveSetting {
    // Runtime state, never read from or written to config. Encoded as in
    // `LiveSetting::get_switch` / `set_switch`: 0 = Off, 1 = On, 2 = Init (initial), 3 = Trap.
    #[serde(skip, default = "default_switch")]
    pub switch: AtomicU8,

    // NOTE(review): presumably the streaming-platform room id to poll — confirm in the poller.
    pub room_id: String,
    // Messages for going online / offline (template defaults: "XX went live" / "XX went offline").
    pub online_msg: String,
    pub offline_msg: String,
    // Text that triggers a status query (template default: "query live room").
    pub query_message: String,
    // Polling interval in seconds.
    pub poll_interval_sec: u64,
}
/// Initial value of the `switch` field: 2, i.e. `LiveSwitch::Init`.
fn default_switch() -> AtomicU8 {
    AtomicU8::from(2)
}
205
206#[derive(Serialize, Deserialize, Debug)]
207pub struct AgentSetting {
208    #[serde(skip, default = "default_atomic_bool")]
209    pub mute: AtomicBool,
210    #[serde(skip)]
211    pub cur_model: RwLock<String>,
212
213    pub api_url: String,
214    pub api_key: String,
215    pub model: String,
216    pub dev_prompt: String,
217    pub user_prompt: String,
218    pub aware_history_segments: i64,
219    // id -> (name, description)
220    pub known_members: HashMap<String, (String, String)>,
221}
222fn default_atomic_bool() -> AtomicBool {
223    AtomicBool::from(false)
224}
225
/// Admin command settings for one group.
///
/// The public string fields are the configured command words. `CommandSetting::init_regex` uses
/// them verbatim as regex syntax (not escaped) and fills the private, serde-skipped regex fields,
/// which hold placeholders until then.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct CommandSetting {
    // Union of all command patterns, in the index order `parse_command` relies on:
    // 0 mute, 1 unmute, 2 switch_model, 3 dump_history, 4 dump_log. Default (empty) until
    // init_regex() succeeds.
    #[serde(skip)]
    regex_set: RegexSet,
    #[serde(skip, default = "default_regex")]
    regex_mute: Regex,
    #[serde(skip, default = "default_regex")]
    regex_unmute: Regex,
    #[serde(skip, default = "default_regex")]
    regex_switch_model: Regex,
    #[serde(skip, default = "default_regex")]
    regex_dump_history: Regex,
    #[serde(skip, default = "default_regex")]
    regex_dump_log: Regex,

    pub mute: String,
    pub unmute: String,
    pub switch_model: String,
    pub dump_history: String,
    pub dump_log: String,
    // NOTE(review): presumably the user ids allowed to issue commands; the check is not in this file.
    pub admin_ids: Vec<i64>,
}
/// Placeholder for the serde-skipped regex fields before `init_regex()` runs.
// NOTE(review): this pattern matches the literal text "empty". It is harmless only because
// `parse_command` consults these regexes after a `regex_set` hit, and the default set is empty.
fn default_regex() -> Regex {
    Regex::new("empty").unwrap()
}
251
/// A command recognized by `CommandSetting::parse_command`.
pub enum GroupCommand {
    /// The configured `mute` pattern matched.
    Mute,
    /// The configured `unmute` pattern matched.
    Unmute,
    /// The `switch_model` word followed by a supported model name (captured here).
    SwitchModel(String),
    /// The `dump_history` word followed by a decimal count.
    DumpHistory(i64),
    /// The `dump_log` word followed by a decimal count.
    DumpLog(i64),
}
259
260impl CommandSetting {
261    pub fn init_regex(&mut self) -> PluginResult<()> {
262        let mute_pat = self.mute.as_str();
263        let unmute_pat = self.unmute.as_str();
264        let switch_model_pat = format!(
265            r"{}\s+(?<model>gpt4o|chatgpt-4o-latest|gpt-4o-mini|o1-mini|o1-preview)",
266            self.switch_model
267        );
268        let dump_history_pat = format!(r"{}\s+(?<count>\d+)", self.dump_history);
269        let dump_log_pat = format!(r"{}\s+(?<count>\d+)", self.dump_log);
270        self.regex_mute = Regex::new(mute_pat)?;
271        self.regex_unmute = Regex::new(unmute_pat)?;
272        self.regex_switch_model = Regex::new(&switch_model_pat)?;
273        self.regex_dump_history = Regex::new(&dump_history_pat)?;
274        self.regex_dump_log = Regex::new(&dump_log_pat)?;
275        self.regex_set = RegexSet::new([
276            mute_pat,
277            unmute_pat,
278            &switch_model_pat,
279            &dump_history_pat,
280            &dump_log_pat,
281        ])?;
282
283        std_info!(
284            "
285            Initialize regex complete.
286            mute: {mute_pat}
287            unmute: {unmute_pat}
288            switch_model: {switch_model_pat}
289            dump_history: {dump_history_pat}
290            dump_log: {dump_log_pat}
291            "
292        );
293        Ok(())
294    }
295
296    pub fn parse_command(&self, input: &str) -> Option<GroupCommand> {
297        for idx in self.regex_set.matches(input).iter() {
298            match idx {
299            0 => {
300                return Some(GroupCommand::Mute);
301            }
302            1 => {
303                return Some(GroupCommand::Unmute);
304            }
305            2 => {
306                if let Some(caps) = self.regex_switch_model.captures(input) {
307                    if let Some(model_match) = caps.name("model") {
308                        return Some(GroupCommand::SwitchModel(model_match.as_str().to_string()));
309                    }
310                }
311            }
312            3 => {
313                if let Some(caps) = self.regex_dump_history.captures(input) {
314                    if let Some(count_match) = caps.name("count") {
315                        if let Ok(count) = count_match.as_str().parse::<i64>() {
316                            return Some(GroupCommand::DumpHistory(count));
317                        }
318                    }
319                }
320            }
321            4 => {
322                if let Some(caps) = self.regex_dump_log.captures(input) {
323                    if let Some(count_match) = caps.name("count") {
324                        if let Ok(count) = count_match.as_str().parse::<i64>() {
325                            return Some(GroupCommand::DumpLog(count));
326                        }
327                    }
328                }
329            }
330            _ => return None
331            }
332        }
333        None
334    }
335}
336
/// State of a group's live-stream switch, stored in `LiveSetting::switch` as a `u8`.
pub enum LiveSwitch {
    /// Stored as 1.
    On,
    /// Stored as 0.
    Off,
    /// Stored as 2; the initial value. Presumably "status not yet known" — confirm in the poller.
    Init,
    /// Stored as 3; any unrecognized stored value also reads back as `Trap`.
    Trap,
}
343
344impl LiveSetting {
345    pub fn get_switch(&self) -> LiveSwitch {
346        match self.switch.load(std::sync::atomic::Ordering::Acquire) {
347            0 => LiveSwitch::Off,
348            1 => LiveSwitch::On,
349            2 => LiveSwitch::Init,
350            _ => LiveSwitch::Trap,
351        }
352    }
353
354    pub fn set_switch(&self, switch: LiveSwitch) {
355        let value = match switch {
356            LiveSwitch::Off => 0,
357            LiveSwitch::On => 1,
358            LiveSwitch::Init => 2,
359            LiveSwitch::Trap => 3,
360        };
361        self.switch
362            .store(value, std::sync::atomic::Ordering::Release);
363    }
364}
365
366impl AgentSetting {
367    pub fn mute(&self) {
368        self.mute.store(true, std::sync::atomic::Ordering::Release);
369    }
370
371    pub fn unmute(&self) {
372        self.mute.store(false, std::sync::atomic::Ordering::Release);
373    }
374
375    pub fn is_mute(&self) -> bool {
376        self.mute.load(std::sync::atomic::Ordering::Acquire)
377    }
378
379    pub async fn set_model(&self, model: String) {
380        let mut cur_model = self.cur_model.write().await;
381        *cur_model = model;
382    }
383
384    pub async fn get_model(&self) -> String {
385        let cur_model = self.cur_model.read().await;
386        cur_model.to_string()
387    }
388
389    pub fn load_members(&mut self) {
390        let mut buf = String::new();
391        for (name, desc) in self.known_members.values() {
392            buf.push_str("- ");
393            buf.push_str(name);
394            buf.push_str(": ");
395            buf.push_str(desc);
396            buf.push('\n');
397        }
398        self.dev_prompt = self.dev_prompt.replace("<!members!>", &buf);
399        self.user_prompt = self.user_prompt.replace("<!members!>", &buf);
400    }
401}
402
403impl Default for Config {
404    fn default() -> Self {
405        Self {
406            global: GlobalSetting::default(),
407            database: DatabaseSetting::default(),
408            object_storage: Some(ObjectStorageSetting::default()),
409            groups: Some(vec![GroupSetting::default(), GroupSetting::default()]),
410        }
411    }
412}
413
414impl Default for GlobalSetting {
415    fn default() -> Self {
416        Self { max_sleep_sec: 8 }
417    }
418}
419
420impl Default for ObjectStorageSetting {
421    fn default() -> Self {
422        Self {
423            script_path: String::from("/a/b/c"),
424        }
425    }
426}
427
428impl Default for DatabaseSetting {
429    fn default() -> Self {
430        Self {
431            max_connections: 5,
432            log_table_name: String::from("bot_log"),
433            group_table_prefix: String::from("message"),
434        }
435    }
436}
437
438impl Default for GroupSetting {
439    fn default() -> Self {
440        Self {
441            id: 12345678,
442            live: Some(LiveSetting::default()),
443            agent: Some(AgentSetting::default()),
444            command: Some(CommandSetting::default()),
445        }
446    }
447}
448
449impl Default for LiveSetting {
450    fn default() -> Self {
451        Self {
452            switch: default_switch(),
453            room_id: String::from("12345678"),
454            online_msg: String::from("XX开播了"),
455            offline_msg: String::from("XX下播了"),
456            query_message: String::from("查询直播间"),
457            poll_interval_sec: 60,
458        }
459    }
460}
461
462impl Default for AgentSetting {
463    fn default() -> Self {
464        let members = [
465            ("12345678".into(), ("你的昵称".into(), "你的主人".into())),
466            ("23456789".into(), ("张三".into(), "你的敌人".into())),
467        ];
468        let known_members = HashMap::from_iter(members);
469        Self {
470            mute: default_atomic_bool(),
471            cur_model: RwLock::default(),
472
473            api_url: String::from("https://api.openai.com/v1/chat/completions"),
474            api_key: String::from("API KEY"),
475            model: String::from("chatgpt-4o-latest"),
476            dev_prompt: formatdoc!{
477                "
478                You are a cute and smart catgirl with a strong anime-style personality. 
479                You are the loyal attendant of 你的昵称 and participate in group chats with a playful and engaging demeanor. 
480                Speak only in Mandarin Chinese, and ensure your responses are concise, limited to 4 sentences.
481                "
482            },
483            user_prompt: formatdoc!(
484                "
485                Group Members:
486                <!members!>
487                
488                Recent Chat History:
489                <!history!>
490                
491                New message from someone you <!know!>:
492                <!message!>
493                
494                Please respond to this new message in the tone of a playful and lively catgirl.
495                Speak only in Mandarin Chinese, keep your response under 4 sentences, and stay in character.
496                "
497            ),
498            aware_history_segments: 30,
499            known_members,
500        }
501    }
502}
503
504impl Default for CommandSetting {
505    fn default() -> Self {
506        Self {
507            regex_set: RegexSet::default(),
508            regex_mute: default_regex(),
509            regex_unmute: default_regex(),
510            regex_switch_model: default_regex(),
511            regex_dump_history: default_regex(),
512            regex_dump_log: default_regex(),
513            mute: String::from("禁用聊天回复"),
514            unmute: String::from("启用聊天回复"),
515            switch_model: String::from("更换模型"),
516            dump_history: String::from("最近聊天记录"),
517            dump_log: String::from("最近日志"),
518            admin_ids: vec![1234, 5678],
519        }
520    }
521}