// teamtalk 6.0.0
//
// TeamTalk SDK for Rust
// Documentation
#[cfg(feature = "scripts")]
use crate::client::Message;
#[cfg(feature = "scripts")]
use crate::events::Event;
#[cfg(feature = "scripts")]
use crate::types::{
    Channel, FileTransfer, ServerProperties, ServerStatistics, TextMessage, User, UserAccount,
};
#[cfg(feature = "scripts")]
use mlua::{HookTriggers, Lua, Table, Value, VmState};
#[cfg(feature = "scripts")]
use std::collections::HashMap;
#[cfg(feature = "scripts")]
use std::fs;
#[cfg(feature = "scripts")]
use std::path::{Path, PathBuf};
#[cfg(feature = "scripts")]
use std::time::{Duration, Instant};

#[cfg(feature = "scripts")]
/// Owns one Lua state in which named scripts are loaded, unloaded, reloaded and
/// invoked for chat commands and client events.
pub struct ScriptManager {
    /// The single interpreter shared by every loaded script.
    lua: Lua,
    /// Loaded scripts keyed by the name given to `load_script`.
    scripts: HashMap<String, ScriptEntry>,
    /// Wall-clock limit for each script invocation; `None` means no timeout hook is installed.
    max_exec_time: Option<Duration>,
    /// Number of Lua VM instructions between two timeout checks (always at least 1).
    hook_instruction_count: u32,
    /// Error from installing the built-in API in `new`; when set, `ensure_ready` reports it.
    init_error: Option<mlua::Error>,
}

#[cfg(feature = "scripts")]
/// Bookkeeping kept for each loaded script.
struct ScriptEntry {
    /// File the script was read from; reused by `reload_script`.
    path: PathBuf,
    /// Global names the script introduced; removed again by `unload_script`.
    globals: Vec<String>,
}

#[cfg(feature = "scripts")]
impl ScriptManager {
    /// Creates a manager backed by a fresh Lua state with `register_command` installed.
    ///
    /// No execution time limit is configured. Once one is enabled via
    /// [`ScriptManager::set_timeout`], the clock is checked every 50 000 VM instructions.
    /// If installing the built-in API fails, the error is stored. Every operation that
    /// checks readiness then returns a `"lua init error: …"` error instead of running.
    pub fn new() -> Self {
        let mut manager = Self {
            lua: Lua::new(),
            scripts: HashMap::new(),
            max_exec_time: None,
            hook_instruction_count: 50_000,
            init_error: None,
        };
        if let Err(err) = manager.register_builtin_api() {
            manager.init_error = Some(err);
        }
        manager
    }

    /// Limits how long a single script load, command handler or event handler may run.
    ///
    /// The elapsed time is checked from a Lua hook that fires every
    /// `hook_instruction_count` VM instructions (see
    /// [`ScriptManager::set_hook_instruction_count`]). Exceeding the limit aborts the
    /// call with a `"script timeout"` runtime error.
    pub fn set_timeout(&mut self, max_exec_time: Duration) {
        self.max_exec_time = Some(max_exec_time);
    }

    /// Removes the execution time limit; scripts then run without a hook installed.
    pub fn clear_timeout(&mut self) {
        self.max_exec_time = None;
    }

    /// Sets how many VM instructions run between timeout checks. `0` is treated as `1`.
    pub fn set_hook_instruction_count(&mut self, count: u32) {
        self.hook_instruction_count = count.max(1);
    }

    /// Reads the Lua file at `path` and executes it under the script name `name`.
    ///
    /// While the chunk runs, the global `_SCRIPT_NAME` is set to `name`, so commands it
    /// registers via `register_command` are attributed to it. Globals created by the
    /// chunk are recorded so [`ScriptManager::unload_script`] can remove them later.
    /// Loading over an already-loaded name keeps the globals recorded for the earlier
    /// load, so nothing leaks on a later unload.
    ///
    /// # Errors
    /// Returns an error if the manager failed to initialise, the file cannot be read,
    /// or the chunk fails or times out. Lua failures are wrapped as
    /// `"lua load_script error (<name>): …"`.
    pub fn load_script(&mut self, name: &str, path: impl AsRef<Path>) -> mlua::Result<()> {
        self.ensure_ready()?;
        let path = path.as_ref().to_path_buf();
        let contents = fs::read_to_string(&path)?;
        let globals_before = self.collect_global_keys()?;
        self.with_script_name(name, || {
            self.with_timeout(|| self.lua.load(&contents).exec())
        })
        .map_err(|err| self.wrap_error(name, "load_script", err))?;
        let globals_after = self.collect_global_keys()?;
        let mut globals = diff_globals(&globals_before, &globals_after);
        // When re-loading over an existing name, the globals the previous load created
        // were already in the "before" snapshot, so the diff cannot see them. Carry the
        // old list forward; otherwise a later unload would leave them behind.
        if let Some(previous) = self.scripts.remove(name) {
            for key in previous.globals {
                if !globals.contains(&key) {
                    globals.push(key);
                }
            }
        }
        self.scripts
            .insert(name.to_string(), ScriptEntry { path, globals });
        Ok(())
    }

    /// Unloads the script `name` and loads it again from the path it was loaded from.
    ///
    /// If the fresh load fails, the script remains unloaded.
    ///
    /// # Errors
    /// Returns `"script not found"` if no script is loaded under `name`. Otherwise it
    /// returns any error from unloading or loading.
    pub fn reload_script(&mut self, name: &str) -> mlua::Result<()> {
        self.ensure_ready()?;
        let path = self
            .scripts
            .get(name)
            .ok_or_else(|| mlua::Error::RuntimeError("script not found".into()))?
            .path
            .clone();
        self.unload_script(name)?;
        self.load_script(name, path)
    }

    /// Forgets the script `name`, removing the globals it created and the commands it
    /// registered through `register_command`.
    ///
    /// # Errors
    /// Returns `"script not found"` if no script is loaded under `name`.
    pub fn unload_script(&mut self, name: &str) -> mlua::Result<()> {
        self.ensure_ready()?;
        let entry = self
            .scripts
            .remove(name)
            .ok_or_else(|| mlua::Error::RuntimeError("script not found".into()))?;
        let globals = self.lua.globals();
        for key in entry.globals {
            // Best effort: a global the script already cleared itself is not an error.
            let _ = globals.raw_remove(key);
        }
        self.remove_registered_commands(name)?;
        Ok(())
    }

    /// Invokes the Lua handler registered for `command`, passing `args` as a 1-based
    /// Lua array.
    ///
    /// Returns `Ok(false)` when there is no global `commands` table or it holds no
    /// function under `command`. Otherwise it returns the handler's result converted
    /// to `bool` by mlua.
    ///
    /// # Errors
    /// Handler failures and timeouts are wrapped as `"lua call_command error (<command>): …"`.
    pub fn call_command(&self, command: &str, args: &[String]) -> mlua::Result<bool> {
        self.ensure_ready()?;
        let globals = self.lua.globals();
        let handlers: Value = globals.get("commands")?;
        let handlers = match handlers {
            Value::Table(table) => table,
            _ => return Ok(false),
        };
        let func: Value = handlers.get(command)?;
        let func = match func {
            Value::Function(func) => func,
            _ => return Ok(false),
        };
        let args_table = self.lua.create_table()?;
        for (idx, arg) in args.iter().enumerate() {
            args_table.set(idx + 1, arg.clone())?;
        }
        let result = self
            .with_timeout(|| func.call::<bool>(args_table))
            .map_err(|err| self.wrap_error(command, "call_command", err))?;
        Ok(result)
    }

    /// Exposes the Rust closure `func` to every script as the global function `name`,
    /// replacing any existing global with that name.
    pub fn register_fn<A, R, F>(&mut self, name: &str, func: F) -> mlua::Result<()>
    where
        F: for<'lua> Fn(&'lua Lua, A) -> mlua::Result<R> + Send + 'static,
        A: mlua::FromLuaMulti,
        R: mlua::IntoLuaMulti,
    {
        self.ensure_ready()?;
        let f = self.lua.create_function(func)?;
        let globals = self.lua.globals();
        globals.set(name, f)?;
        Ok(())
    }

    /// Dispatches a client event to the scripts.
    ///
    /// First calls the global function `on_event(event_table)` if it exists. Then it
    /// calls `events[<event name>](event_table)` if the global `events` table holds a
    /// function for this event. Each handler gets its own freshly built event table.
    /// Returns `true` if any called handler returned a truthy result.
    ///
    /// # Errors
    /// The first handler failure or timeout aborts dispatch and is returned wrapped with
    /// the event name.
    pub fn handle_event(&self, event: Event, message: &Message) -> mlua::Result<bool> {
        self.ensure_ready()?;
        let globals = self.lua.globals();
        let key = event_name(event);
        let mut handled = false;
        if let Ok(Value::Function(func)) = globals.get::<Value>("on_event") {
            let event_table = self.event_table(event, message)?;
            let result = self
                .with_timeout(|| func.call::<bool>(event_table))
                .map_err(|err| self.wrap_error(key, "on_event", err))?;
            handled |= result;
        }
        if let Ok(Value::Table(table)) = globals.get::<Value>("events") {
            if let Ok(Value::Function(func)) = table.get::<Value>(key) {
                let event_table = self.event_table(event, message)?;
                let result = self
                    .with_timeout(|| func.call::<bool>(event_table))
                    .map_err(|err| self.wrap_error(key, "event", err))?;
                handled |= result;
            }
        }
        Ok(handled)
    }

    /// Runs `func` under the configured execution time limit, if any.
    ///
    /// Installs an instruction-count hook that fails with `"script timeout"` once the
    /// limit has elapsed. The hook is removed again after `func` returns, whether it
    /// succeeded or failed.
    fn with_timeout<F, R>(&self, func: F) -> mlua::Result<R>
    where
        F: FnOnce() -> mlua::Result<R>,
    {
        let max_exec_time = match self.max_exec_time {
            Some(max_exec_time) => max_exec_time,
            None => return func(),
        };
        let start = Instant::now();
        let triggers = HookTriggers::new().every_nth_instruction(self.hook_instruction_count);
        self.lua.set_hook(triggers, move |_lua, _debug| {
            if start.elapsed() > max_exec_time {
                Err(mlua::Error::RuntimeError("script timeout".into()))
            } else {
                Ok(VmState::Continue)
            }
        })?;
        let result = func();
        self.lua.remove_hook();
        result
    }

    /// Runs `func` with the global `_SCRIPT_NAME` set to `name`, removing that global
    /// afterwards. Any previous value of `_SCRIPT_NAME` is not restored.
    fn with_script_name<F, R>(&self, name: &str, func: F) -> mlua::Result<R>
    where
        F: FnOnce() -> mlua::Result<R>,
    {
        let globals = self.lua.globals();
        globals.set("_SCRIPT_NAME", name)?;
        let result = func();
        let _ = globals.raw_remove("_SCRIPT_NAME");
        result
    }

    /// Snapshots the names of all string-keyed globals. Non-string keys are skipped;
    /// a key that is not valid UTF-8 makes the snapshot fail.
    fn collect_global_keys(&self) -> mlua::Result<Vec<String>> {
        self.ensure_ready()?;
        let globals = self.lua.globals();
        let mut keys = Vec::new();
        for pair in globals.pairs::<Value, Value>() {
            let (key, _) = pair?;
            if let Value::String(value) = key {
                keys.push(value.to_str()?.to_string());
            }
        }
        Ok(keys)
    }

    /// Installs the global Lua function `register_command(name, func)`.
    ///
    /// It stores `func` in the global `commands` table (created on demand). When
    /// `_SCRIPT_NAME` is a string, it also appends `name` to
    /// `__tt_commands_by_script[_SCRIPT_NAME]`, so the command can be removed when that
    /// script is unloaded.
    fn register_builtin_api(&mut self) -> mlua::Result<()> {
        let func = self
            .lua
            .create_function(|lua, (name, func): (String, mlua::Function)| {
                let globals = lua.globals();
                let commands = match globals.get::<Value>("commands")? {
                    Value::Table(table) => table,
                    _ => {
                        let table = lua.create_table()?;
                        globals.set("commands", table.clone())?;
                        table
                    }
                };
                commands.set(name.clone(), func)?;
                if let Ok(Value::String(script)) = globals.get::<Value>("_SCRIPT_NAME") {
                    let by_script = match globals.get::<Value>("__tt_commands_by_script")? {
                        Value::Table(table) => table,
                        _ => {
                            let table = lua.create_table()?;
                            globals.set("__tt_commands_by_script", table.clone())?;
                            table
                        }
                    };
                    let key = script.to_str()?.to_string();
                    let list = match by_script.get::<Value>(key.clone())? {
                        Value::Table(table) => table,
                        _ => {
                            let table = lua.create_table()?;
                            by_script.set(key, table.clone())?;
                            table
                        }
                    };
                    let idx = list.len()? + 1;
                    list.set(idx, name)?;
                }
                Ok(())
            })?;
        let globals = self.lua.globals();
        globals.set("register_command", func)?;
        Ok(())
    }

    /// Removes from `commands` every command recorded for script `name` in
    /// `__tt_commands_by_script`, then drops that script's record. Missing tables are
    /// treated as "nothing to remove".
    fn remove_registered_commands(&self, name: &str) -> mlua::Result<()> {
        self.ensure_ready()?;
        let globals = self.lua.globals();
        let by_script = match globals.get::<Value>("__tt_commands_by_script")? {
            Value::Table(table) => table,
            _ => return Ok(()),
        };
        let list = match by_script.get::<Value>(name) {
            Ok(Value::Table(table)) => table,
            _ => return Ok(()),
        };
        let commands = match globals.get::<Value>("commands")? {
            Value::Table(table) => table,
            _ => return Ok(()),
        };
        for pair in list.sequence_values::<String>() {
            let cmd = pair?;
            let _ = commands.raw_remove(cmd);
        }
        let _ = by_script.raw_remove(name);
        Ok(())
    }

    /// Flattens a Lua error into a runtime error tagged with the operation and the
    /// script/command/event name involved.
    fn wrap_error(&self, name: &str, context: &str, err: mlua::Error) -> mlua::Error {
        mlua::Error::RuntimeError(format!("lua {context} error ({name}): {err}"))
    }

    /// Fails if installing the built-in API during [`ScriptManager::new`] failed.
    fn ensure_ready(&self) -> mlua::Result<()> {
        if let Some(err) = &self.init_error {
            return Err(mlua::Error::RuntimeError(format!("lua init error: {err}")));
        }
        Ok(())
    }
}

// Event-table construction lives in a sibling file. Like everything else here, it only
// exists when Lua scripting is enabled.
#[cfg(feature = "scripts")]
#[path = "scripts_tables.rs"]
mod tables;

// `ScriptManager` is only defined with the `scripts` feature, so its `Default` impl
// must be gated too; otherwise builds without the feature fail to compile.
#[cfg(feature = "scripts")]
impl Default for ScriptManager {
    /// Same as [`ScriptManager::new`].
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(feature = "scripts")]
/// Returns the stable name of `event`.
///
/// The name is used as the key into the Lua `events` table and in handler error
/// messages. Variants that carry data are named without their payload.
fn event_name(event: Event) -> &'static str {
    use Event as E;

    match event {
        E::None => "None",
        E::ConnectSuccess => "ConnectSuccess",
        E::ConnectCryptError => "ConnectCryptError",
        E::ConnectFailed => "ConnectFailed",
        E::ConnectionLost => "ConnectionLost",
        E::ConnectMaxPayloadUpdated => "ConnectMaxPayloadUpdated",
        E::CmdProcessing => "CmdProcessing",
        E::CmdError => "CmdError",
        E::CmdSuccess => "CmdSuccess",
        E::MySelfLoggedIn => "MySelfLoggedIn",
        E::MySelfLoggedOut => "MySelfLoggedOut",
        E::MySelfKicked => "MySelfKicked",
        E::UserLoggedIn => "UserLoggedIn",
        E::UserLoggedOut => "UserLoggedOut",
        E::UserUpdate => "UserUpdate",
        E::UserJoined => "UserJoined",
        E::UserLeft => "UserLeft",
        E::TextMessage => "TextMessage",
        E::ChannelCreated => "ChannelCreated",
        E::ChannelUpdated => "ChannelUpdated",
        E::ChannelRemoved => "ChannelRemoved",
        E::ServerUpdate => "ServerUpdate",
        E::ServerStatistics => "ServerStatistics",
        E::FileNew => "FileNew",
        E::FileRemove => "FileRemove",
        E::UserAccount => "UserAccount",
        E::BannedUser => "BannedUser",
        E::UserAccountCreated => "UserAccountCreated",
        E::UserAccountRemoved => "UserAccountRemoved",
        E::UserStateChange => "UserStateChange",
        E::VideoCaptureFrame => "VideoCaptureFrame",
        E::MediaFileVideo => "MediaFileVideo",
        E::DesktopWindow => "DesktopWindow",
        E::DesktopCursor => "DesktopCursor",
        E::DesktopInput => "DesktopInput",
        E::UserRecordMediaFile => "UserRecordMediaFile",
        E::AudioBlock => "AudioBlock",
        E::InternalError => "InternalError",
        E::VoiceActivation => "VoiceActivation",
        E::Hotkey => "Hotkey",
        E::HotkeyTest => "HotkeyTest",
        E::FileTransfer => "FileTransfer",
        E::DesktopWindowTransfer => "DesktopWindowTransfer",
        E::StreamMediaFile => "StreamMediaFile",
        E::LocalMediaFile => "LocalMediaFile",
        E::AudioInput => "AudioInput",
        E::UserFirstVoiceStreamPacket => "UserFirstVoiceStreamPacket",
        E::SoundDeviceAdded => "SoundDeviceAdded",
        E::SoundDeviceRemoved => "SoundDeviceRemoved",
        E::SoundDeviceUnplugged => "SoundDeviceUnplugged",
        E::SoundDeviceNewDefaultInput => "SoundDeviceNewDefaultInput",
        E::SoundDeviceNewDefaultOutput => "SoundDeviceNewDefaultOutput",
        E::SoundDeviceNewDefaultInputComDevice => "SoundDeviceNewDefaultInputComDevice",
        E::SoundDeviceNewDefaultOutputComDevice => "SoundDeviceNewDefaultOutputComDevice",
        // Reconnect / auto-recovery lifecycle events: payload fields are ignored.
        E::BeforeReconnect { .. } => "BeforeReconnect",
        E::Reconnecting { .. } => "Reconnecting",
        E::AfterReconnect { .. } => "AfterReconnect",
        E::ReconnectFailed { .. } => "ReconnectFailed",
        E::BeforeAutoLogin { .. } => "BeforeAutoLogin",
        E::AutoLoginFailed { .. } => "AutoLoginFailed",
        E::BeforeAutoJoin { .. } => "BeforeAutoJoin",
        E::AutoJoinFailed { .. } => "AutoJoinFailed",
        E::AutoRecoverCompleted { .. } => "AutoRecoverCompleted",
        E::Unknown(_) => "Unknown",
    }
}

#[cfg(feature = "scripts")]
/// Returns the global names present in `after` but not in `before`, in `after`'s order.
///
/// Globals the script manager itself maintains (`_SCRIPT_NAME`, `register_command`,
/// `commands`, `__tt_commands_by_script`) are never reported. This keeps unloading a
/// script from deleting shared infrastructure.
fn diff_globals(before: &[String], after: &[String]) -> Vec<String> {
    const EXCLUDED: [&str; 4] = [
        "_SCRIPT_NAME",
        "register_command",
        "commands",
        "__tt_commands_by_script",
    ];
    // Hash the "before" snapshot once so each membership test is O(1) rather than a
    // linear scan per key (previously O(before * after)).
    let existing: std::collections::HashSet<&str> =
        before.iter().map(String::as_str).collect();
    after
        .iter()
        .filter(|key| !EXCLUDED.contains(&key.as_str()) && !existing.contains(key.as_str()))
        .cloned()
        .collect()
}